In [None]:
import os, sys

sys.path.append("../..")

import torch
import numpy as np
from astropy.table import Table

from astroclip.env import format_with_env
from morphology_utils.models import train_eval_on_question
from morphology_utils.plotting import plot_radar

ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}")

# Minimum value of "smooth-or-featured_total-votes" for a galaxy to be kept
MIN_TOTAL_VOTES = 3

# Model label -> name of the table column holding that model's embeddings.
# Single source of truth: used both to build X and to strip the embedding
# columns from the label table below.
EMBEDDING_COLUMNS = {
    "AstroCLIP": "astroclip_embeddings",
    "AstroDINO": "astrodino_embeddings",
    "Stein": "stein_embeddings",
}

# Load the crossmatched Galaxy Zoo table (labels + embeddings)
galaxy_zoo = Table.read(
    f"{ASTROCLIP_ROOT}/datasets/galaxy_zoo/gz5_decals_crossmatched_embeddings.h5"
)

# Remove the galaxies with fewer than MIN_TOTAL_VOTES votes
galaxy_zoo = galaxy_zoo[
    galaxy_zoo["smooth-or-featured_total-votes"] >= MIN_TOTAL_VOTES
]

# Get the embeddings, one tensor per model
X = {
    label: torch.tensor(galaxy_zoo[column])
    for label, column in EMBEDDING_COLUMNS.items()
}

# Names of the questions; each is matched as a substring of the column names
names = [
    "smooth",
    "disk-edge-on",
    "spiral-arms",
    "bar",
    "bulge-size",
    "how-rounded",
    "edge-on-bulge",
    "spiral-winding",
    "spiral-arm-count",
    "merging",
]

# Get the labels: drop the embedding columns in place, leaving only the
# classification columns. `classifications` is an alias of `galaxy_zoo`,
# not a copy.
galaxy_zoo.remove_columns(list(EMBEDDING_COLUMNS.values()))
classifications = galaxy_zoo

# Get the key list. For every question:
#   "target": all columns containing the question name and "debiased",
#             excluding any column containing "mask"
#   "counts": the first column containing the question name and "total-votes"
#             (raises IndexError if the table has no such column)
keys = {
    name: {
        "target": [
            key
            for key in classifications.colnames
            if name in key and "debiased" in key and "mask" not in key
        ],
        "counts": [
            key
            for key in classifications.colnames
            if name in key and "total-votes" in key
        ][0],
    }
    for name in names
}

In [None]:
# Sequential split: rows before the cut-off index form the train set (first
# 80%), the remaining rows form the test set (last 20%). No shuffling is done.
train_indices = int(0.8 * len(classifications))

# Split every model's embeddings at the same cut-off
X_train, X_test = {}, {}
for key, embeddings in X.items():
    X_train[key] = embeddings[:train_indices]
    X_test[key] = embeddings[train_indices:]

# Split the label table at the same cut-off so rows stay aligned with X
classifications_train = classifications[:train_indices]
classifications_test = classifications[train_indices:]

In [None]:
# A test galaxy is evaluated on a question only if that question's vote count
# is strictly greater than this value
TEST_VOTE_THRESHOLD = 34

# Vote counts of the "smooth" question on the train split. Only needed by the
# optional vote-fraction filter described inside the loop, which is disabled.
total_counts_train = classifications_train[keys["smooth"]["counts"]].data

# Get accuracy and F1 score on each question: outputs[model][question] -> metrics
outputs = {key: {} for key in X.keys()}
for name in names:
    question, num_classes = name, len(keys[name]["target"])

    # Train on every sample (all-True mask). To filter on vote fraction instead,
    # replace the mask with:
    #     np.where(counts_train / total_counts_train > 0.5)[0]
    counts_train = classifications_train[keys[name]["counts"]].data
    train_mask = torch.ones(len(counts_train), dtype=torch.bool)

    # Get the test samples with more than TEST_VOTE_THRESHOLD answers
    counts_test = classifications_test[keys[name]["counts"]].data
    test_mask = np.where(counts_test > TEST_VOTE_THRESHOLD)[0]

    # Get train and test labels for this question
    y_train = torch.tensor(
        classifications_train[keys[name]["target"]].to_pandas().values
    )[train_mask]
    y_test = torch.tensor(
        classifications_test[keys[name]["target"]].to_pandas().values
    )[test_mask]

    # Rows whose label vector contains at least one NaN
    train_nan_mask = torch.isnan(y_train).any(dim=1)
    test_nan_mask = torch.isnan(y_test).any(dim=1)

    # Drop those rows from the labels here; the same masks are applied to the
    # features below, so features and labels stay aligned row for row.
    y_train = y_train[~train_nan_mask]
    y_test = y_test[~test_nan_mask]

    # Train and evaluate on each model
    print(f"Training on question: {question}...")
    for model in X.keys():
        X_train_local = X_train[model][train_mask][~train_nan_mask]
        X_test_local = X_test[model][test_mask][~test_nan_mask]
        outputs[model][name] = train_eval_on_question(
            X_train_local,
            X_test_local,
            y_train,
            y_test,
            X_train_local.shape[1],
            num_classes=num_classes,
            MLP_dim=256,
            epochs=25,
            dropout=0.2,
        )
        print(
            f"Model: {model}, Accuracy: {outputs[model][name]['Accuracy']:.4f}, F1: {outputs[model][name]['F1 Score']:.4f}"
        )
    print("Done!")

In [None]:
# Clean up labels: rename result keys to the names shown in the plots.
# `outputs` is renamed in place. The membership check makes the cell safe to
# re-run: after the first run the old keys no longer exist, and an unguarded
# pop() would raise KeyError.
display_names = {
    "AstroDINO": "Unaligned Transformer",
    "Stein": "Stein, et al.",
}
for old_label, new_label in display_names.items():
    if old_label in outputs:
        outputs[new_label] = outputs.pop(old_label)

# Make sure the target directory for the figures exists
os.makedirs("./outputs", exist_ok=True)

# Plot radar plots, one per metric
plot_radar(outputs, metric="Accuracy", file_path="./outputs/radar_accuracy.png")
plot_radar(outputs, metric="F1 Score", file_path="./outputs/radar_f1_score.png")