In [1]:
#python -m ipykernel install --user --name astrodino --display-name "Python (astrodino)"
import sys
print(sys.executable)   # current python folder
print(sys.version)      # Python ver

import site
print(site.getsitepackages())

!which python

/u/yacheng/conda-envs/astrodino/bin/python
3.10.19 | packaged by conda-forge | (main, Oct 22 2025, 22:29:10) [GCC 14.3.0]
['/u/yacheng/conda-envs/astrodino/lib/python3.10/site-packages']
/mpcdf/soft/SLE_15/packages/x86_64/jupyterlab/3.6.3/bin/python


In [2]:
import dinov2

In [3]:
import os, sys

sys.path.append("../..")

import torch
import numpy as np
from astropy.table import Table

from astroclip.env import format_with_env
from morphology_utils.models import train_eval_on_question
from morphology_utils.plotting import plot_radar

ASTROCLIP_ROOT = format_with_env("{ASTROCLIP_ROOT}")

# Load the data
galaxy_zoo = Table.read(f"/ptmp/yacheng/outthere_ssl/images/galaxy_zoo/galaxy_zoo_embeddings.hdf5")

embedding_keys = [key for key in galaxy_zoo.keys() if 'embeddings' in key]
print("Embedding keys:", embedding_keys)

'''
# use unique names
X = {
    "AstroCLIP": torch.tensor(galaxy_zoo["astroclip_embeddings"]),
    "AstroDINO": torch.tensor(galaxy_zoo["astrodino_embeddings"]),
    "Stein": torch.tensor(galaxy_zoo["stein_embeddings"]),
}
'''

#use unique column embedding names from embedding_keys
X = {
    key: torch.tensor(np.array(galaxy_zoo[key]))
    for key in embedding_keys
}


# Get the names of the columns
names = names = [
    "smooth",
    "disk-edge-on",
    "spiral-arms",
    "bar",
    "bulge-size",
    "how-rounded",
    "edge-on-bulge",
    "spiral-winding",
    "spiral-arm-count",
    "merging",
]

# Get the labels
galaxy_zoo.remove_columns(
    embedding_keys
)
classifications = galaxy_zoo
    
# Get the key list
keys = {
    name: {
        "target": [
            key
            for key in classifications.colnames
            if name in key and "debiased" in key and "mask" not in key
        ],
        "counts": [
            key
            for key in classifications.colnames
            if name in key and "total-votes" in key
        ][0],
    }
    for name in names
}

/u/yacheng/conda-envs/astrodino/lib/python3.10/site-packages/lightning/fabric/__init__.py:41: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.


Embedding keys: ['astrodino_3z_7e5_vit_base_embeddings', 'astrodino_3z_7e5_vit_large_embeddings']


In [5]:
galaxy_zoo

iauname,ra,dec,redshift,elpetro_absmag_r,elpetro_absmag_r.mask,sersic_nmgy_r,petro_th50,petro_th90,petro_theta,upload_group,active_learning_on,in_gzd_ab,png_loc,smooth-or-featured_total-votes,smooth-or-featured_smooth,smooth-or-featured_smooth_fraction,smooth-or-featured_smooth_debiased,smooth-or-featured_smooth_debiased.mask,smooth-or-featured_featured-or-disk,smooth-or-featured_featured-or-disk_fraction,smooth-or-featured_featured-or-disk_debiased,smooth-or-featured_featured-or-disk_debiased.mask,smooth-or-featured_artifact,smooth-or-featured_artifact_fraction,smooth-or-featured_artifact_debiased,smooth-or-featured_artifact_debiased.mask,disk-edge-on_total-votes,disk-edge-on_yes,disk-edge-on_yes_fraction,disk-edge-on_yes_fraction.mask,disk-edge-on_yes_debiased,disk-edge-on_yes_debiased.mask,disk-edge-on_no,disk-edge-on_no_fraction,disk-edge-on_no_fraction.mask,disk-edge-on_no_debiased,disk-edge-on_no_debiased.mask,has-spiral-arms_total-votes,has-spiral-arms_yes,has-spiral-arms_yes_fraction,has-spiral-arms_yes_fraction.mask,has-spiral-arms_yes_debiased,has-spiral-arms_yes_debiased.mask,has-spiral-arms_no,has-spiral-arms_no_fraction,has-spiral-arms_no_fraction.mask,has-spiral-arms_no_debiased,has-spiral-arms_no_debiased.mask,bar_total-votes,bar_strong,bar_strong_fraction,bar_strong_fraction.mask,bar_strong_debiased,bar_strong_debiased.mask,bar_weak,bar_weak_fraction,bar_weak_fraction.mask,bar_weak_debiased,bar_weak_debiased.mask,bar_no,bar_no_fraction,bar_no_fraction.mask,bar_no_debiased,bar_no_debiased.mask,bulge-size_total-votes,bulge-size_dominant,bulge-size_dominant_fraction,bulge-size_dominant_fraction.mask,bulge-size_dominant_debiased,bulge-size_dominant_debiased.mask,bulge-size_large,bulge-size_large_fraction,bulge-size_large_fraction.mask,bulge-size_large_debiased,bulge-size_large_debiased.mask,bulge-size_moderate,bulge-size_moderate_fraction,bulge-size_moderate_fraction.mask,bulge-size_moderate_debiased,bulge-size_moderate_debiased.mask,bulge-size_small,bulge-size_small_fraction,bulge-size_small_fraction.mask,bulge-size_small_debiased,bulge-size_small_debiased.mask,bulge-size_none,bulge-size_none_fraction,bulge-size_none_fraction.mask,bulge-size_none_debiased,bulge-size_none_debiased.mask,how-rounded_total-votes,how-rounded_round,how-rounded_round_fraction,how-rounded_round_fraction.mask,how-rounded_round_debiased,how-rounded_round_debiased.mask,how-rounded_in-between,how-rounded_in-between_fraction,how-rounded_in-between_fraction.mask,how-rounded_in-between_debiased,how-rounded_in-between_debiased.mask,how-rounded_cigar-shaped,how-rounded_cigar-shaped_fraction,how-rounded_cigar-shaped_fraction.mask,how-rounded_cigar-shaped_debiased,how-rounded_cigar-shaped_debiased.mask,edge-on-bulge_total-votes,edge-on-bulge_boxy,edge-on-bulge_boxy_fraction,edge-on-bulge_boxy_fraction.mask,edge-on-bulge_boxy_debiased,edge-on-bulge_boxy_debiased.mask,edge-on-bulge_none,edge-on-bulge_none_fraction,edge-on-bulge_none_fraction.mask,edge-on-bulge_none_debiased,edge-on-bulge_none_debiased.mask,edge-on-bulge_rounded,edge-on-bulge_rounded_fraction,edge-on-bulge_rounded_fraction.mask,edge-on-bulge_rounded_debiased,edge-on-bulge_rounded_debiased.mask,spiral-winding_total-votes,spiral-winding_tight,spiral-winding_tight_fraction,spiral-winding_tight_fraction.mask,spiral-winding_tight_debiased,spiral-winding_tight_debiased.mask,spiral-winding_medium,spiral-winding_medium_fraction,spiral-winding_medium_fraction.mask,spiral-winding_medium_debiased,spiral-winding_medium_debiased.mask,spiral-winding_loose,spiral-winding_loose_fraction,spiral-winding_loose_fraction.mask,spiral-winding_loose_debiased,spiral-winding_loose_debiased.mask,spiral-arm-count_total-votes,spiral-arm-count_1,spiral-arm-count_1_fraction,spiral-arm-count_1_fraction.mask,spiral-arm-count_1_debiased,spiral-arm-count_1_debiased.mask,spiral-arm-count_2,spiral-arm-count_2_fraction,spiral-arm-count_2_fraction.mask,spiral-arm-count_2_debiased,spiral-arm-count_2_debiased.mask,spiral-arm-count_3,spiral-arm-count_3_fraction,spiral-arm-count_3_fraction.mask,spiral-arm-count_3_debiased,spiral-arm-count_3_debiased.mask,spiral-arm-count_4,spiral-arm-count_4_fraction,spiral-arm-count_4_fraction.mask,spiral-arm-count_4_debiased,spiral-arm-count_4_debiased.mask,spiral-arm-count_more-than-4,spiral-arm-count_more-than-4_fraction,spiral-arm-count_more-than-4_fraction.mask,spiral-arm-count_more-than-4_debiased,spiral-arm-count_more-than-4_debiased.mask,spiral-arm-count_cant-tell,spiral-arm-count_cant-tell_fraction,spiral-arm-count_cant-tell_fraction.mask,spiral-arm-count_cant-tell_debiased,spiral-arm-count_cant-tell_debiased.mask,merging_total-votes,merging_none,merging_none_fraction,merging_none_fraction.mask,merging_none_debiased,merging_none_debiased.mask,merging_minor-disturbance,merging_minor-disturbance_fraction,merging_minor-disturbance_fraction.mask,merging_minor-disturbance_debiased,merging_minor-disturbance_debiased.mask,merging_major-disturbance,merging_major-disturbance_fraction,merging_major-disturbance_fraction.mask,merging_major-disturbance_debiased,merging_major-disturbance_debiased.mask,merging_merger,merging_merger_fraction,merging_merger_fraction.mask,merging_merger_debiased,merging_merger_debiased.mask,wrong_size_statistic,wrong_size_warning,file_id,index
bytes19,float64,float64,float64,float64,bool,float64,float64,float64,float64,bytes15,bytes5,bytes5,bytes32,int64,int64,float64,float64,bool,int64,float64,float64,bool,int64,float64,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,int64,float64,bool,float64,bool,float64,bytes5,int32,int32
J112953.88-000427.4,172.47452574804296,-0.07428063790376414,0.12414213,-21.253342,False,167.26288,1.8350658,5.414066,4.2024717,pre_active,False,True,dr5/J112/J112953.88-000427.4.png,84,57,0.6785714285714286,0.10256410256410256,False,23,0.27380952380952384,0.9166666666666666,False,4,0.047619047619047616,0.02574176551028673,False,23,7,0.30434782608695654,False,0.04878048780487805,False,16,0.6956521739130435,False,0.8055015139329738,False,16,1,0.0625,False,0.8205128205128205,False,15,0.9375,False,0.10817139052102372,False,16,2,0.125,False,0.14583333333333334,False,1,0.0625,False,0.18421052631578946,False,13,0.8125,False,0.7317073170731707,False,16,1,0.0625,False,0.038461538461538464,False,6,0.375,False,0.21568627450980396,False,9,0.5625,False,0.5055038295783774,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,57,0,0.0,False,0.0,False,20,0.3508771929824561,False,0.35714285714285715,False,37,0.6491228070175439,False,0.8666666666666667,False,7,0,0.0,False,0.0,False,2,0.2857142857142857,False,0.2352941176470588,False,5,0.7142857142857143,False,0.5573770491803278,False,1,1,1.0,False,1.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,1,0,0.0,False,0.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,1,1.0,False,1.0,False,80,67,0.8375,False,0.7727272727272727,False,8,0.1,False,0.12770785184557293,False,5,0.0625,False,0.07692307692307693,False,0,0.0,False,0.0,False,133.59268406353215,False,14,438459
J104325.29+190335.0,160.85653322901592,19.06044073355132,0.049088165,-21.77541,False,2941.7292,4.405412,14.320828,10.757237,targeted,False,False,dr5/J104/J104325.29+190335.0.png,37,33,0.8918918918918919,0.8571428571428571,False,2,0.05405405405405406,0.038461538461538464,False,2,0.05405405405405406,0.022165799543054487,False,2,0,0.0,False,0.0,False,2,1.0,False,1.0,False,2,0,0.0,False,0.0,False,2,1.0,False,1.0,False,2,0,0.0,False,0.0,False,0,0.0,False,0.0,False,2,1.0,False,1.0,False,2,1,0.5,False,0.2,False,0,0.0,False,0.0,False,1,0.5,False,0.4697643948915614,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,33,4,0.12121212121212122,False,0.3103448275862069,False,27,0.8181818181818182,False,0.8888888888888888,False,2,0.06060606060606061,False,0.03125,False,0,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0.0,True,0.0,False,0,0.0,True,0.0,False,35,0,0.0,False,0.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,2,0.05714285714285714,False,0.06075684859652457,False,132.44312905608277,False,14,12246
J104629.54+115415.1,161.62313506655573,11.904197333613393,0.09290579,-19.947397,False,125.076324,4.71987,13.139816,9.726173,active_baseline,False,False,dr5/J104/J104629.54+115415.1.png,5,1,0.2,0.0,True,4,0.8,0.0,True,0,0.0,0.0,True,4,0,0.0,False,0.0,True,4,1.0,False,0.0,True,4,4,1.0,False,0.0,True,0,0.0,False,0.0,True,4,0,0.0,False,0.0,True,2,0.5,False,0.0,True,2,0.5,False,0.0,True,4,0,0.0,False,0.0,True,0,0.0,False,0.0,True,2,0.5,False,0.0,True,2,0.5,False,0.0,True,0,0.0,False,0.0,True,1,0,0.0,False,0.0,True,1,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,4,2,0.5,False,0.0,True,1,0.25,False,0.0,True,1,0.25,False,0.0,True,4,0,0.0,False,0.0,True,4,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,5,4,0.8,False,0.0,True,1,0.2,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,156.33630171865005,False,15,72305
J082950.68+125621.8,127.46119002717477,12.939386194422855,0.06661941,-19.800629,False,173.59203,2.929065,7.2249613,6.5283313,active_baseline,False,False,dr5/J082/J082950.68+125621.8.png,8,2,0.25,0.0,True,6,0.75,0.0,True,0,0.0,0.0,True,6,6,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,2,0,0.0,False,0.0,True,0,0.0,False,0.0,True,2,1.0,False,0.0,True,6,0,0.0,False,0.0,True,3,0.5,False,0.0,True,3,0.5,False,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,8,7,0.875,False,0.0,True,1,0.125,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,155.38637876717456,False,14,540350
J122056.00-015022.0,185.23334172546194,-1.839339610760786,0.070041895,-19.013256,False,88.82688,2.5218973,6.4968877,5.054409,active_baseline,False,False,dr5/J122/J122056.00-015022.0.png,5,2,0.4,0.0,True,3,0.6,0.0,True,0,0.0,0.0,True,3,0,0.0,False,0.0,True,3,1.0,False,0.0,True,3,2,0.6666666666666666,False,0.0,True,1,0.3333333333333333,False,0.0,True,3,0,0.0,False,0.0,True,0,0.0,False,0.0,True,3,1.0,False,0.0,True,3,0,0.0,False,0.0,True,0,0.0,False,0.0,True,2,0.6666666666666666,False,0.0,True,1,0.3333333333333333,False,0.0,True,0,0.0,False,0.0,True,2,0,0.0,False,0.0,True,2,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,2,0,0.0,False,0.0,True,1,0.5,False,0.0,True,1,0.5,False,0.0,True,2,0,0.0,False,0.0,True,2,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,5,5,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,143.73549146425276,False,15,343495
J100927.56+071112.4,152.36482996579448,7.186803120438202,0.10094784,-21.110653,False,283.3018,3.7220948,9.933681,9.42301,pre_active,False,False,dr5/J100/J100927.56+071112.4.png,34,14,0.4117647058823529,0.05,False,18,0.5294117647058824,0.9387755102040816,False,2,0.058823529411764705,0.019948665610765925,False,18,1,0.05555555555555555,False,0.027777777777777776,False,17,0.9444444444444444,False,0.9789584054354445,False,17,7,0.4117647058823529,False,0.9074074074074074,False,10,0.5882352941176471,False,0.31074329617904756,False,17,12,0.7058823529411765,False,0.8333333333333334,False,4,0.23529411764705882,False,0.1724137931034483,False,1,0.058823529411764705,False,0.030303030303030307,False,17,0,0.0,False,0.0,False,2,0.11764705882352941,False,0.04651162790697674,False,12,0.7058823529411765,False,0.656247160265868,False,3,0.17647058823529413,False,0.42372881355932196,False,0,0.0,False,0.0,False,14,0,0.0,False,0.0,False,13,0.9285714285714286,False,0.9,False,1,0.07142857142857142,False,0.1111111111111111,False,1,0,0.0,False,0.0,False,0,0.0,False,0.0,False,1,1.0,False,1.0,False,7,2,0.2857142857142857,False,0.5021147934944031,False,5,0.7142857142857143,False,0.6368387948547768,False,0,0.0,False,0.0,False,7,0,0.0,False,0.0,False,4,0.5714285714285714,False,0.2258064516129032,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,3,0.42857142857142855,False,0.4782608695652174,False,32,9,0.28125,False,0.3793103448275862,False,11,0.34375,False,0.36342403741914225,False,8,0.25,False,0.3783783783783784,False,4,0.125,False,0.12973556664448804,False,151.7612214530218,False,14,347249
J151949.21+280418.7,229.95505514677487,28.071839897306823,0.045993112,-19.293116,False,275.67108,1.5636468,5.200698,3.7963612,active_baseline,False,False,dr5/J151/J151949.21+280418.7.png,11,8,0.7272727272727273,0.0,True,2,0.18181818181818182,0.0,True,1,0.09090909090909091,0.0,True,2,0,0.0,False,0.0,True,2,1.0,False,0.0,True,2,0,0.0,False,0.0,True,2,1.0,False,0.0,True,2,0,0.0,False,0.0,True,0,0.0,False,0.0,True,2,1.0,False,0.0,True,2,0,0.0,False,0.0,True,1,0.5,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,1,0.5,False,0.0,True,8,3,0.375,False,0.0,True,5,0.625,False,0.0,True,0,0.0,False,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,10,4,0.4,False,0.0,True,2,0.2,False,0.0,True,0,0.0,False,0.0,True,4,0.4,False,0.0,True,135.63854519999168,False,14,319047
J143254.45+034938.1,218.22690475069928,3.8272346054241573,0.14696836,-20.958143,False,92.36639,1.9810072,4.9577327,4.7013,active_priority,True,False,dr5/J143/J143254.45+034938.1.png,52,21,0.40384615384615385,0.0,False,29,0.5576923076923077,1.0,False,2,0.038461538461538464,0.0,False,29,1,0.034482758620689655,False,0.019230769230769232,False,28,0.9655172413793104,False,0.9924054032168147,False,28,23,0.8214285714285714,False,1.0,False,5,0.17857142857142858,False,0.0050447146205390885,False,28,5,0.17857142857142858,False,0.28125,False,8,0.2857142857142857,False,0.35135135135135137,False,15,0.5357142857142857,False,0.3636363636363637,False,28,1,0.03571428571428571,False,0.012195121951219513,False,6,0.21428571428571427,False,0.02631578947368421,False,14,0.5,False,0.4663063354929633,False,5,0.17857142857142858,False,0.6,False,2,0.07142857142857142,False,0.01927631821910397,False,21,5,0.23809523809523808,False,0.32142857142857145,False,15,0.7142857142857143,False,0.6451612903225806,False,1,0.047619047619047616,False,0.04838709677419355,False,1,0,0.0,False,0.0,False,0,0.0,False,0.0,False,1,1.0,False,1.0,False,23,8,0.34782608695652173,False,0.6882152304624496,False,10,0.43478260869565216,False,0.337679486757324,False,5,0.21739130434782608,False,0.04931118703919616,False,23,1,0.043478260869565216,False,0.015384615384615384,False,20,0.8695652173913043,False,0.5263157894736842,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,0,0.0,False,0.0,False,2,0.08695652173913043,False,0.2058823529411765,False,50,36,0.72,False,0.6515151515151515,False,5,0.1,False,0.12562879805862665,False,7,0.14,False,0.3448275862068966,False,2,0.04,False,0.04626767086152703,False,145.17405458324777,False,15,186723
J121421.61+271034.5,183.59008332238614,27.176257381071284,0.13253377,-20.639168,False,78.11458,1.9619702,5.7860413,4.1267076,active_baseline,False,False,dr5/J121/J121421.61+271034.5.png,5,4,0.8,0.0,True,1,0.2,0.0,True,0,0.0,0.0,True,1,0,0.0,False,0.0,True,1,1.0,False,0.0,True,1,0,0.0,False,0.0,True,1,1.0,False,0.0,True,1,0,0.0,False,0.0,True,0,0.0,False,0.0,True,1,1.0,False,0.0,True,1,0,0.0,False,0.0,True,0,0.0,False,0.0,True,1,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,4,4,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,0,0.0,True,0.0,True,5,5,1.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,0,0.0,False,0.0,True,150.7240868635341,False,15,252857
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [6]:
# Select first 80% for train and last 20% for test
train_indices = int(0.9 * len(classifications))

X_train, X_test = {}, {}
for key in X.keys():
    X_train[key] = X[key][:train_indices]
    X_test[key] = X[key][train_indices:]

classifications_train, classifications_test = (
    classifications[:train_indices],
    classifications[train_indices:],
)

In [7]:
# This is the total number of possible votes
total_counts_train = classifications_train[keys["smooth"]["counts"]].data

# Get accuracy and F1 score on each question
outputs = {key: {} for key in X.keys()}
for name in names:
    question, num_classes = name, len(keys[name]["target"])

    # Get the train samples above 50% answered
    counts_train = classifications_train[keys[name]["counts"]].data
    #train_mask = np.where(counts_train / total_counts_train > 0.5)[0]
    train_mask = [True] * len(counts_train)

    # Get the test samples above 34 answers
    counts_test = classifications_test[keys[name]["counts"]].data
    test_mask = np.where(counts_test > 34)[0]
    #test_mask = [True] * len(counts_test)

    # Get train and test
    y_train = torch.tensor(
        classifications_train[keys[name]["target"]].to_pandas().values
    )[train_mask]
    y_test = torch.tensor(
        classifications_test[keys[name]["target"]].to_pandas().values
    )[test_mask]

    train_nan_mask = torch.isnan(y_train).any(axis=1)
    test_nan_mask = torch.isnan(y_test).any(axis=1)

    # Train and evaluate on each model
    print(f"Training on question: {question}...")
    for model in X.keys():
        X_train_local = X_train[model][train_mask][~train_nan_mask]
        X_test_local = X_test[model][test_mask][~test_nan_mask]
        outputs[model][name] = train_eval_on_question(
            X_train_local,
            X_test_local,
            y_train,
            y_test,
            X_train_local.shape[1],
            num_classes=num_classes,
            MLP_dim=256,
            epochs=25,
            dropout=0.2,
        )
        print(
            f"Model: {model}, Accuracy: {outputs[model][name]['Accuracy']:.4f}, F1: {outputs[model][name]['F1 Score']:.4f}"
        )
    print("Done!")

Training on question: smooth...
Model: astrodino_3z_7e5_vit_base_embeddings, Accuracy: 0.7889, F1: 0.6961
Model: astrodino_3z_7e5_vit_large_embeddings, Accuracy: 0.7889, F1: 0.6961
Done!
Training on question: disk-edge-on...
Model: astrodino_3z_7e5_vit_base_embeddings, Accuracy: 0.9393, F1: 0.9321
Model: astrodino_3z_7e5_vit_large_embeddings, Accuracy: 0.8850, F1: 0.8310
Done!
Training on question: spiral-arms...
Model: astrodino_3z_7e5_vit_base_embeddings, Accuracy: 0.8619, F1: 0.8987
Model: astrodino_3z_7e5_vit_large_embeddings, Accuracy: 0.8536, F1: 0.8940
Done!
Training on question: bar...
Model: astrodino_3z_7e5_vit_base_embeddings, Accuracy: 0.5230, F1: 0.3562
Model: astrodino_3z_7e5_vit_large_embeddings, Accuracy: 0.5230, F1: 0.3562
Done!
Training on question: bulge-size...
Model: astrodino_3z_7e5_vit_base_embeddings, Accuracy: 0.7280, F1: 0.7184
Model: astrodino_3z_7e5_vit_large_embeddings, Accuracy: 0.7741, F1: 0.7681
Done!
Training on question: how-rounded...
Model: astrodino

In [8]:
outputs.keys()

dict_keys(['astrodino_3z_7e5_vit_base_embeddings', 'astrodino_3z_7e5_vit_large_embeddings'])

In [9]:
# Clean up labels: remove substrings from keys
REMOVE_SUBSTRINGS = ["_3z_band_stable", "_embeddings"]

def _strip_name(name: str) -> str:
    out = name
    for s in REMOVE_SUBSTRINGS:
        out = out.replace(s, "")
    # collapse consecutive underscores
    while "__" in out:
        out = out.replace("__", "_")
    return out.strip("_")

# Rename all keys in outputs dict
outputs = { _strip_name(k): v for k, v in outputs.items() }
print("Renamed output keys:", list(outputs.keys()))

# (Optional) map to prettier display labels if desired
# label_map = {
#     "astrodino": "AstroDINO Base",
#     "astrodino_multi_epochs": "AstroDINO Multi Epochs",
#     "astrodino_multi_epochs_vit_large": "AstroDINO Large",
# }
# outputs = { label_map.get(k, k): v for k, v in outputs.items() }

# Plot radar plots
plot_radar(outputs, metric="Accuracy", file_path=f"./outputs/radar_accuracy.png")
plot_radar(outputs, metric="F1 Score", file_path=f"./outputs/radar_f1_score.png")

Renamed output keys: ['astrodino_3z_7e5_vit_base', 'astrodino_3z_7e5_vit_large']


In [78]:
outputs.keys()

dict_keys(['astrodino', 'astrodino_multi_epochs', 'astrodino_multi_epochs_vit_large'])