In [1]:
import os
import glob
import pandas as pd
import sys

sys.path.append('../scripts')
from make_polar_hexbin_images import create_polar_hexbin_from_ecsv
from classify_pytorch_resnet import load_resnet_model, predict_image

# Trained ResNet checkpoints, each a 2-class classifier: the first decides the
# system class (detached vs. overcontact), the second is applied to
# overcontact systems only to predict spots.
class_model_path = '../models/2class_binary_resnet_gaia_hexbin.pth'
overspot_model_path = '../models/2class_overspot_resnet_gaia_hexbin.pth'

# Load both checkpoints a single time (same order as the paths above); the
# resulting models are reused for every image in the classification loop.
class_model, overspot_model = [
    load_resnet_model(checkpoint, num_classes=2)
    for checkpoint in (class_model_path, overspot_model_path)
]

# Folder holding the hexbin PNG images to classify
image_dir = '../data/gaia_real_data/images_hex'

# Classify every hexbin image in `image_dir`:
#   1. the class model separates detached (prediction 0) from overcontact
#      (any other prediction) systems;
#   2. for overcontact systems only, the spot model predicts 'n' (prediction 0)
#      or 's' (any other prediction).
# Each kept system becomes one row of `df_overcontact`.

# Fail loudly on a wrong path instead of silently producing an empty table.
if not os.path.isdir(image_dir):
    raise FileNotFoundError(f"Image directory not found: {image_dir}")

overcontact_results = []

# glob order is filesystem-dependent; sort so the row order is reproducible.
for image_path in sorted(glob.glob(os.path.join(image_dir, '*.png'))):
    # Object name is the part of the file stem before the first '_'
    # (using the stem keeps '.png' out of names that contain no '_').
    stem = os.path.splitext(os.path.basename(image_path))[0]
    object_name = stem.split('_')[0]

    # Stage 1: system class; detached systems are skipped.
    pred_class, class_probs = predict_image(class_model, image_path)
    if pred_class == 0:
        continue

    # Stage 2: spot prediction with the overcontact-specific model.
    spot_pred, spot_probs = predict_image(overspot_model, image_path)
    overcontact_results.append({
        'object': object_name,
        'class': 'overcontact',
        'class_probability': class_probs[pred_class],
        'spots': 'n' if spot_pred == 0 else 's',
        'spot_probability': spot_probs[spot_pred],
    })

# Passing the columns explicitly keeps the schema stable even when no
# overcontact system was found (an empty list would give a column-less frame).
df_overcontact = pd.DataFrame(
    overcontact_results,
    columns=['object', 'class', 'class_probability', 'spots', 'spot_probability'],
)


In [4]:
# Summary of the overcontact results: row count, per-column non-null counts and dtypes.
pd.DataFrame.info(df_overcontact)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 82 entries, 0 to 81
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   object             82 non-null     object 
 1   class              82 non-null     object 
 2   class_probability  82 non-null     float32
 3   spots              82 non-null     object 
 4   spot_probability   82 non-null     float32
dtypes: float32(2), object(3)
memory usage: 2.7+ KB
