In [None]:
import argparse
import os, sys
import torch
import numpy as np
import h5py
from tqdm import tqdm
from helpers import process, to_pil_image, dr2_rgb
from PIL import Image as im
from astropy.table import Table, join, vstack
from astropy.coordinates import SkyCoord, match_coordinates_sky
from astropy import units as u
from torchvision.transforms import (
    Compose,
    ToTensor,
    Normalize,
    Resize,
    InterpolationMode,
    CenterCrop,
)

# Make the local AstroDino checkout and its bundled dinov2 package importable.
# Inserting in this order leaves AstroDino/ ahead of AstroDino/dinov2/ on sys.path.
for _astrodino_path in (
    "/mnt/home/lparker/Documents/AstroFoundationModel/AstroDino/dinov2/",
    "/mnt/home/lparker/Documents/AstroFoundationModel/AstroDino/",
):
    sys.path.insert(0, os.path.abspath(_astrodino_path))
from dinov2.utils.config import setup
from dinov2.models import build_model_from_cfg
from dinov2.fsdp import FSDPCheckpointer
from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.eval.setup import setup_and_build_model
from dinov2.data.transforms import make_normalize_transform

# Image file locations: consecutive HDF5 shards ("npix152") of the Stein et al.
# DECaLS cutouts -- 10 shards for the north, 62 for the south.
def _shard_paths(directory, n_shards):
    """Return the paths of the first `n_shards` HDF5 image shards inside `directory`."""
    return [
        os.path.join(
            directory, f"images_npix152_0{i:02d}000000_0{i + 1:02d}000000.h5"
        )
        for i in range(n_shards)
    ]


files_north = _shard_paths(
    "/mnt/ceph/users/polymathic/external_data/astro/DECALS_Stein_et_al/north/", 10
)
files_south = _shard_paths(
    "/mnt/ceph/users/polymathic/external_data/astro/DECALS_Stein_et_al/south/", 62
)

# Galaxy Zoo volunteer classification catalogs
gz5_decals_path = "/mnt/home/lparker/ceph/gz_decals_volunteers_5.csv"
gz2_sdss_path = "/mnt/home/lparker/ceph/gz2_hart16.csv"

# Transformations for models: ImageNet channel statistics for normalisation.
# NOTE(review): `MEAN`, `STD` and `img_transforms` are redefined in later cells;
# this first pipeline is not used by get_paired_classifications.
MEAN, STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

img_transforms = Compose(
    [
        to_pil_image,
        Resize(152, interpolation=InterpolationMode.BICUBIC),
        ToTensor(),
        CenterCrop(144),
        Normalize(mean=MEAN, std=STD),
    ]
)


class config:
    """Settings for building the AstroDINO model: config file, weights and output dir.

    NOTE(review): `config_file` is relative to the working directory; this class is
    redefined (with an absolute path) in the "Load Pretrained Models" cell.
    """

    config_file = "../astrodino/configs/ssl_default_config.yaml"
    pretrained_weights = "/mnt/home/lparker/ceph/astrodino/vitl12_simplified_better_wd/training_199999/teacher_checkpoint.pth"
    output_dir = "/mnt/home/lparker/ceph/dino_training"
    opts = []  # no additional options


def get_paired_classifications(sky, gz_survey):
    """Cross-match Galaxy Zoo classifications with the DECaLS image cutouts.

    Reads the ra/dec of every cutout in the HDF5 shards for `sky`, matches them on the
    sky (0.5 arcsec tolerance) to the Galaxy Zoo catalog selected by `gz_survey`, and
    attaches the matched cutouts to the catalog rows.

    Parameters
    ----------
    sky : str
        "north" or "south"; selects `files_north` or `files_south`.
    gz_survey : str
        "gz2" (reads `gz2_sdss_path`) or "gz5" (reads `gz5_decals_path`).

    Returns
    -------
    astropy.table.Table
        Matched catalog rows with three extra columns: "index" (row inside the HDF5
        file), "file" (HDF5 path) and "image" (the cutout, stored in a
        (3, 152, 152) float array per row).

    Raises
    ------
    ValueError
        If `sky` or `gz_survey` is not one of the supported values.
    """
    if sky == "south":
        files = files_south
    elif sky == "north":
        files = files_north
    else:
        raise ValueError("Not supported sky type, choose south or north")

    if gz_survey == "gz2":
        classifications_path = gz2_sdss_path
    elif gz_survey == "gz5":
        classifications_path = gz5_decals_path
    else:
        raise ValueError("Not supported gz_survey type, choose gz2 or gz5")

    print(f"Sky type is {sky}, survey type is {gz_survey}", flush=True)

    morphologies = Table.read(classifications_path, format="ascii")

    ra_list, dec_list, index_list, file_list = [], [], [], []

    print("Processing files", flush=True)
    for file in tqdm(files):
        with h5py.File(file, "r") as f:
            ra = f["ra"][:]
            dec = f["dec"][:]
        ra_list.extend(ra)
        dec_list.extend(dec)
        file_list.extend([file] * len(ra))
        index_list.extend(range(len(ra)))

    positions = Table(
        [ra_list, dec_list, index_list, file_list], names=("ra", "dec", "index", "file")
    )

    coords_images = SkyCoord(
        ra=positions["ra"] * u.degree, dec=positions["dec"] * u.degree
    )
    coords_catalog = SkyCoord(
        ra=morphologies["ra"] * u.degree, dec=morphologies["dec"] * u.degree
    )

    print("Matching coordinates", flush=True)
    # match_to_catalog_sky returns the tuple (idx, sep2d, dist3d); sep2d is an Angle
    # that compares directly against an arcsec Quantity, so no unit conversion is needed.
    idx, d2d, _ = coords_images.match_to_catalog_sky(coords_catalog)

    max_sep = 0.5 * u.arcsec
    sep_constraint = d2d < max_sep

    classifications = morphologies[idx[sep_constraint]]
    positions_matched = positions[sep_constraint]
    classifications["index"] = np.array(positions_matched["index"])
    classifications["file"] = np.array(positions_matched["file"])
    classifications["image"] = np.zeros((len(classifications), 3, 152, 152))

    print("Generating catalog with images", flush=True)
    file_column = np.asarray(classifications["file"])
    index_column = np.asarray(classifications["index"])
    for i, file in enumerate(files):
        print(f"Processing file {i+1}/{len(files)}", flush=True)
        rows = np.flatnonzero(file_column == file)
        if len(rows) == 0:
            continue
        # h5py fancy indexing needs increasing indices, so read the rows in sorted
        # order; each (file, index) pair is unique because every cutout matches once.
        rows = rows[np.argsort(index_column[rows])]
        with h5py.File(file, "r") as f:
            images = f["images"][index_column[rows]]
        classifications["image"][rows] = images
    return classifications

In [None]:
# Cross-match the northern-sky cutouts with the GZ DECaLS (GZ5) volunteer catalog.
classifications = get_paired_classifications(sky="north", gz_survey="gz5")

# Load Pretrained Models

In [None]:
# Build the image (DINO) and spectrum (Fi-LLM Seqformer) encoders and wrap them in AstroCLIP.
# NOTE(review): statement order matters: `DINO` is deep-copied from img_model before
# img_model.forward is replaced and img_model is wrapped in OutputExtractor.
import sys

sys.path.insert(
    0,
    "/mnt/home/lparker/Documents/AstroFoundationModel/AstroCLIP_legacy/notebooks/tutorial/",
)

# NOTE(review): the `config` imported here is shadowed by the class below and later
# rebound to the spectrum model's config (out["config"]).
from leopoldo import AstroCLIP, OutputExtractor, seq_decoder, config, forward_im
from tutorial_helpers import load_model_from_ckpt, forward
from torchvision.transforms import Compose, Normalize
import copy


class config:
    # Settings object instantiated and passed to setup_and_build_model below.
    output_dir = "/mnt/home/lparker/ceph/dino_training"
    config_file = "/mnt/home/lparker/Documents/AstroFoundationModel/AstroDino_legacy/astrodino/configs/ssl_default_config.yaml"
    pretrained_weights = "/mnt/home/lparker/ceph/astrodino/vitl12_simplified_better_wd/training_199999/teacher_checkpoint.pth"
    opts = []


# Specify transforms
# NOTE(review): this Normalize-only pipeline is overwritten by the Resize/ToTensor/
# CenterCrop/Normalize pipeline defined in the "Transforms" cell.
MEAN = (0.485, 0.456, 0.406)  # Imagenet default mean
STD = (0.229, 0.224, 0.225)  # Imagenet default std
img_transforms = Compose([Normalize(MEAN, STD)])

# Shared embedding width passed to both OutputExtractor and seq_decoder.
embed_dim = 512

# Define DINO model
img_model, dtype = setup_and_build_model(config())

# Untouched copy of the DINO backbone, used directly for the "DINO" embeddings.
DINO = copy.deepcopy(img_model)

# Replace img_model.forward with forward_im bound to img_model, then wrap it in
# OutputExtractor with a frozen backbone (this is the AstroCLIP image encoder).
img_model.forward = forward_im.__get__(img_model)
img_model = OutputExtractor(img_model, embed_dim=embed_dim, freeze_backbone=True)
num_params = np.sum(np.fromiter((p.numel() for p in img_model.parameters()), int))
print(f"Number of parameters in image model: {num_params:,}")

# The model is saved in the Seqformer branch of Fi-LLM
model_path = "/mnt/home/sgolkar/ceph/saves/fillm/run-seqformer-2708117"
out = load_model_from_ckpt(model_path)
# NOTE(review): rebinds the global `config` (the class above) to the spectrum model's config.
config = out["config"]
spec_model = out["model"]
# Bind `forward` as a method of spec_model. The owner argument (type(img_model), i.e. the
# OutputExtractor type) is ignored by __get__ when an instance is given.
spec_model.forward = forward.__get__(spec_model, type(img_model))
num_params = np.sum(np.fromiter((p.numel() for p in spec_model.parameters()), int))
print(f"Number of parameters in spectrum model: {num_params:,}")

# Define image and spectrum encoders
image_encoder = img_model
spectrum_encoder = seq_decoder(
    model=spec_model, embed_dim=embed_dim, freeze_backbone=True
)

# Set up AstroCLIP
# NOTE(review): the meaning of the third positional argument (1) is defined by
# leopoldo.AstroCLIP -- confirm.
astroclip = AstroCLIP(image_encoder, spectrum_encoder, 1)

In [None]:
# Load the trained AstroCLIP weights (stored under the checkpoint's "state_dict" key).
ckpt_path = "/mnt/home/lparker/Documents/AstroFoundationModel/AstroCLIP_legacy/notebooks/tutorial/astroclip-clip-explore/03x73csv/checkpoints/epoch=14-step=2310.ckpt"

# load_state_dict copies values into the model's existing parameters, so the checkpoint
# can be materialised on the CPU; this also avoids depending on the GPU it was saved from.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location="cpu")
astroclip.load_state_dict(ckpt["state_dict"])

# Transforms

In [None]:
import lightning as L
import torch.nn as nn
from torchvision import transforms
from PIL import Image as im


def sdss_rgb(imgs, bands, scales=None, m=0.02):
    """Combine per-band images into an arcsinh-stretched RGB image.

    Parameters
    ----------
    imgs : iterable of 2-D torch.Tensor
        One image per band, all with the same (H, W) shape.
    bands : sequence of str
        Band name for each image; must be a key of the (plane, scale) table.
    scales : dict, optional
        Overrides of the default ``band -> (rgb_plane, scale)`` table.
    m : float
        Offset added to every scaled band before clipping/stretching.

    Returns
    -------
    torch.Tensor
        (H, W, 3) float32 image clipped to [0, 1].
    """
    band_table = {
        "u": (2, 1.5),  # 1.0,
        "g": (2, 2.5),
        "r": (1, 1.5),
        "i": (0, 1.0),
        "z": (0, 0.4),  # 0.3
    }
    if scales is not None:
        band_table.update(scales)

    # Mean of the scaled, offset bands, with negative values clipped to zero.
    intensity = 0
    for band_img, band in zip(imgs, bands):
        _, scale = band_table[band]
        intensity = intensity + torch.maximum(torch.tensor(0), band_img * scale + m)
    intensity /= len(bands)

    Q = 20
    stretched = torch.arcsinh(Q * intensity) / torch.sqrt(torch.tensor(Q))
    # Avoid dividing by zero where the mean intensity vanished.
    intensity += (intensity == 0.0) * 1e-6

    height, width = intensity.shape
    rgb = torch.zeros((height, width, 3)).to(torch.float32)
    for band_img, band in zip(imgs, bands):
        plane, scale = band_table[band]
        rgb[:, :, plane] = (band_img * scale + m) * stretched / intensity
    return torch.clip(rgb, 0, 1)


def dr2_rgb(rimgs, bands, **ignored):
    """RGB rendering with DR2 g/r/z scales and offset m=0.03 (extra kwargs are ignored).

    Note: this shadows the `dr2_rgb` imported from `helpers` at the top of the notebook.
    """
    dr2_scales = {"g": (2, 6.0), "r": (1, 3.4), "z": (0, 2.2)}
    return sdss_rgb(rimgs, bands, scales=dr2_scales, m=0.03)


class toRGB(transforms.ToTensor):
    """Convert raw multi-band cutouts into RGB tensors via `dr2_rgb`.

    Accepts one channels-last image (H, W, C) or a batch (N, H, W, C) whose C channels
    correspond to `bands`, and returns channels-first RGB: (3, H, W) or (N, 3, H, W).

    Raises
    ------
    ValueError
        If the input is neither 3-D nor 4-D.
    """

    def __init__(self, bands, scales=None, m=0.02):
        self.bands = bands
        # NOTE(review): `scales` and `m` are stored but not applied: dr2_rgb uses its own
        # fixed DR2 scales and m=0.03. Kept for interface compatibility.
        self.scales = scales
        self.m = m

    @staticmethod
    def _reverse_axes(x):
        """Reverse all axes (what `x.T` did) without the deprecated non-2-D `Tensor.T`."""
        return x.permute(*reversed(range(x.ndim)))

    def _convert_one(self, rimg):
        """Map one (H, W, C) raw image to a (3, H, W) RGB tensor."""
        return self._reverse_axes(dr2_rgb(self._reverse_axes(rimg), self.bands))

    def __call__(self, rimgs):
        if rimgs.ndim == 3:
            return self._convert_one(rimgs)
        if rimgs.ndim == 4:
            return torch.stack([self._convert_one(img) for img in rimgs])
        raise ValueError(
            "toRGB expects a 3-D (H, W, C) or 4-D (N, H, W, C) input, "
            f"got shape {tuple(rimgs.shape)}"
        )


MEAN = (0.485, 0.456, 0.406)  # ImageNet channel means
STD = (0.229, 0.224, 0.225)  # ImageNet channel standard deviations

# Preprocessing for RGB PIL images: bicubic resize (shorter side -> 152 px), tensor
# conversion, 144 px centre crop, then per-channel normalisation.
img_transforms = Compose(
    [
        Resize(152, interpolation=InterpolationMode.BICUBIC),
        ToTensor(),
        CenterCrop(144),
        Normalize(mean=MEAN, std=STD),
    ]
)

# Converts raw g/r/z cutouts into RGB tensors.
to_rgb = toRGB(bands=["g", "r", "z"])

# Embeddings and morphology evaluation

In [None]:
# import lightning as L
# from pl_bolts.models.self_supervised import Moco_v2

# class OutputExtractor(L.LightningModule):
#     """
#     Pass data through network to extract model outputs
#     """
#     def __init__(self, backbone: torch.nn.Module):
#         super(OutputExtractor, self).__init__()
#         self.backbone = backbone
#         self.backbone.eval()

#     def forward(self, batch):
#         x = batch
#         z_emb = self.backbone(x)
#         return z_emb

#     def predict(self, batch, batch_idx: int, dataloader_idx: int=None):
#         return self(batch)

# # Extract encoder_q from Moco_v2 model
# moco_model = Moco_v2.load_from_checkpoint(checkpoint_path='/mnt/ceph/users/flanusse/resnet50.ckpt')
# backbone = moco_model.encoder_q
# moco_model = OutputExtractor(backbone).to('cuda')

In [None]:
astroclip.cuda()
DINO.cuda()
# moco_model.cuda()
CLIP_embeddings = []
DINO_embeddings = []
# Stays empty while the MoCo ("Stein") model cell above is commented out.
stein_embeddings = []

batch_size = 512


def _embed_batch(raw_images):
    """Embed a batch of raw cutouts with AstroCLIP and DINO.

    raw_images : torch.Tensor of shape (N, 3, H, W), as stored in classifications["image"].
    Returns (clip_embeddings, dino_embeddings) as numpy arrays.
    """
    # to_rgb takes channels-last input and returns (N, 3, H, W) clipped to [0, 1];
    # convert to uint8 HWC arrays so each image can go through PIL-based transforms.
    rgb = (
        np.array(to_rgb(raw_images.permute(0, 2, 3, 1)) * 255)
        .astype("uint8")
        .transpose(0, 2, 3, 1)
    )
    batch = torch.stack([img_transforms(im.fromarray(img)) for img in rgb]).cuda()
    with torch.no_grad():
        clip_emb = astroclip(batch, image=True).detach().cpu().numpy()
        dino_emb = DINO.forward(batch).detach().cpu().numpy()
    return clip_emb, dino_emb


# Iterate over slices so the final (possibly partial) batch is handled exactly once
# and no empty batch is ever stacked.
for start in tqdm(range(0, len(classifications), batch_size)):
    raw_images = torch.tensor(
        np.asarray(classifications["image"][start : start + batch_size])
    )

    # with torch.no_grad():
    #     stein_embeddings.append(moco_model(raw_images.to(torch.float32).cuda()).detach().cpu().numpy())

    clip_emb, dino_emb = _embed_batch(raw_images)
    CLIP_embeddings.append(clip_emb)
    DINO_embeddings.append(dino_emb)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


class MLP(nn.Module):
    """Feed-forward probe: three hidden Linear/Dropout/ReLU layers and a softmax output.

    Parameters
    ----------
    input_dim : int
        Size of the input embedding.
    num_classes : int
        Number of output classes.
    hidden_dim : int
        Width of each hidden layer.
    dropout_rate : float
        Dropout probability after each hidden Linear layer.
    """

    def __init__(self, input_dim, num_classes, hidden_dim, dropout_rate):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch, num_classes).

        Only a trailing singleton class axis is squeezed (num_classes == 1), so a
        batch holding a single sample keeps its batch axis.
        """
        x = self.layers(x)
        x = self.softmax(x)
        return x.squeeze(-1)


def train_eval_MLP(
    X_train,
    X_test,
    y_train,
    y_test,
    embed_dim,
    num_classes,
    MLP_dim=128,
    lr=1e-3,
    epochs=25,
    dropout=0.2,
):
    """Train an MLP probe on soft labels and report test accuracy / weighted F1.

    10% of the training data (random_state=42) is held out for validation; the weights
    with the lowest validation loss are restored before evaluating on the test set.

    Parameters
    ----------
    X_train, X_test : array-like or torch.Tensor, shape (n, embed_dim)
        Embeddings; cast to float32.
    y_train, y_test : array-like or torch.Tensor, shape (n, num_classes)
        Per-class target fractions (e.g. volunteer vote fractions).
    embed_dim, num_classes : int
        Input and output sizes of the MLP.
    MLP_dim : int
        Hidden width.
    lr : float
        Adam learning rate.
    epochs : int
        Number of training epochs.
    dropout : float
        Dropout probability.

    Returns
    -------
    dict
        {"Accuracy": float, "F1 Score": float} computed on one-hot arg-max labels.
    """
    X_train = torch.as_tensor(X_train, dtype=torch.float32)
    y_train = torch.as_tensor(y_train, dtype=torch.float32)
    X_test = torch.as_tensor(X_test, dtype=torch.float32)
    y_test = torch.as_tensor(y_test)

    # Split the dataset into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.1, random_state=42
    )

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)

    # Sample each training galaxy with probability proportional to its largest fraction.
    samples_weight = y_train.max(dim=1).values

    train_loader = DataLoader(
        train_dataset,
        batch_size=256,
        sampler=WeightedRandomSampler(samples_weight, len(samples_weight)),
    )
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    mlp = MLP(embed_dim, num_classes, MLP_dim, dropout)
    optimizer = optim.Adam(mlp.parameters(), lr=lr)

    best_val_loss = float("inf")
    # state_dict() returns references to the live parameters, so snapshot with clone();
    # starting from the initial weights also covers the case where no epoch improves.
    best_state = {k: v.detach().clone() for k, v in mlp.state_dict().items()}

    for _ in range(epochs):
        mlp.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = _soft_cross_entropy(mlp(data), target)
            loss.backward()
            optimizer.step()

        # Validation loop
        mlp.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                val_loss += _soft_cross_entropy(mlp(data), target).item()
        val_loss /= len(val_loader)

        # Save best model based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.detach().clone() for k, v in mlp.state_dict().items()}

    mlp.load_state_dict(best_state)
    mlp.eval()
    with torch.no_grad():
        y_pred = mlp(X_test).reshape(-1, num_classes)

    # One-hot of the maximal class (ties mark every maximal class).
    y_pred = (y_pred == torch.max(y_pred, dim=1, keepdim=True).values).int()
    y_true = (y_test == torch.max(y_test, dim=1, keepdim=True).values).int()

    accuracy = accuracy_score(y_true.numpy(), y_pred.numpy())
    f1_score = precision_recall_fscore_support(
        y_true.numpy(), y_pred.numpy(), average="weighted", zero_division=0
    )[2]
    return {"Accuracy": accuracy, "F1 Score": f1_score}


def _soft_cross_entropy(probs, target, eps=1e-12):
    """Cross-entropy between predicted class probabilities and soft target fractions.

    The MLP already applies a softmax, so the loss is taken on probabilities (a
    logits-based loss such as BCEWithLogitsLoss would apply a second nonlinearity).
    `probs` is reshaped to `target`'s shape so a squeezed one-sample batch lines up.
    """
    probs = probs.reshape(target.shape)
    return -(target * torch.log(probs.clamp_min(eps))).sum(dim=1).mean()

In [None]:
# Stack the per-batch embeddings into one tensor per model.
X_CLIP = torch.tensor(np.concatenate(CLIP_embeddings))
X_DINO = torch.tensor(np.concatenate(DINO_embeddings))
# stein_embeddings stays empty while the MoCo pass is commented out, and
# np.concatenate([]) raises -- only build X_Stein when there is something to stack.
X_Stein = torch.tensor(np.concatenate(stein_embeddings)) if stein_embeddings else None

X = {"CLIP": X_CLIP, "DINO": X_DINO}
if X_Stein is not None:
    X["Stein"] = X_Stein

In [None]:
# Load previously saved embeddings and their catalog (directory name suggests GZ5,
# southern sky). NOTE(review): this replaces `X` and `classifications` built above.
# `data_dir` avoids shadowing the builtin `dir` for the rest of the session.
data_dir = "/mnt/home/lparker/ceph/gz5_south/"

X = {
    model_name: torch.load(os.path.join(data_dir, f"X_{model_name}.pt"))
    for model_name in ("CLIP", "DINO", "Stein")
}

classifications = Table.read(os.path.join(data_dir, "classifications.csv"))

In [None]:
# Shuffle the catalog and every embedding tensor with one shared, seeded permutation
# so that rows stay aligned.
np.random.seed(42)
shuffled_indices = np.random.permutation(len(classifications))

X = {key: embeddings[shuffled_indices] for key, embeddings in X.items()}
classifications = classifications[shuffled_indices]

In [None]:
# Deterministic 80/20 split of the shuffled rows: first 80% train, remaining 20% test.
# (`train_indices` is the number of training rows.)
train_indices = int(0.8 * len(classifications))

X_train = {key: embeddings[:train_indices] for key, embeddings in X.items()}
X_test = {key: embeddings[train_indices:] for key, embeddings in X.items()}

classifications_train = classifications[:train_indices]
classifications_test = classifications[train_indices:]

In [None]:
names = [
    "smooth",
    "disk-edge-on",
    "spiral-arms",
    "bar",
    "bulge-size",
    "how-rounded",
    "edge-on-bulge",
    "spiral-winding",
    "spiral-arm-count",
    "merging",
]

# For each question, collect the catalog columns whose names contain the question
# name: the debiased vote fractions (used as targets) and the total-vote counts.
keys = {
    name: {
        "debiased": [
            col for col in classifications.colnames if name in col and "debiased" in col
        ],
        "counts": [
            col
            for col in classifications.colnames
            if name in col and "total-votes" in col
        ],
    }
    for name in names
}

In [None]:
# Train/evaluate one MLP probe per (question, embedding) pair.
# The fraction of volunteers reaching a question is measured against the vote total
# of the "smooth" question (presumably answered by every volunteer -- confirm).
total_counts_train = classifications_train[keys["smooth"]["counts"]].to_pandas().values

# One result dict per embedding actually present in X (e.g. Stein may be absent).
outputs = {model: {} for model in X}

for name in names:
    question, num_classes = name, len(keys[name]["debiased"])

    # Train only on galaxies where more than half of the volunteers answered.
    counts_train = classifications_train[keys[name]["counts"]].to_pandas().values
    pct_answered = counts_train / total_counts_train
    above50 = np.where(pct_answered > 0.5)[0]

    y_train = torch.tensor(
        classifications_train[keys[name]["debiased"]].to_pandas().values
    )[above50]
    train_mask = torch.isnan(y_train).any(axis=1)
    y_train = y_train[~train_mask]

    # Test only on galaxies with more than 35 votes for this question.
    counts_test = classifications_test[keys[name]["counts"]].to_pandas().values
    above35 = np.where(counts_test > 35)[0]

    y_test = torch.tensor(
        classifications_test[keys[name]["debiased"]].to_pandas().values
    )[above35]
    test_mask = torch.isnan(y_test).any(axis=1)
    y_test = y_test[~test_mask]

    categories = keys[name]["debiased"]
    print(f"Question: {question}, Classes: {categories}")
    print(f"Number of classes: {num_classes}, Number of samples: {len(y_test)}")

    for model in X.keys():
        # Apply the same row filters as the labels so embeddings stay aligned.
        X_train_local = X_train[model][above50][~train_mask]
        X_test_local = X_test[model][above35][~test_mask]

        outputs[model][name] = train_eval_MLP(
            X_train_local,
            X_test_local,
            y_train,
            y_test,
            X_train_local.shape[1],
            num_classes=num_classes,
            MLP_dim=128,
            epochs=25,
            dropout=0.2,
        )
        print(
            f'{model} - Accuracy: {outputs[model][name]["Accuracy"]:.4f}, F1 Score: {outputs[model][name]["F1 Score"]:.4f}'
        )

    print("")

In [None]:
from math import pi
import matplotlib.pyplot as plt


def plot_radar(
    outputs, metric, file_path, title="Galaxy Property Estimation", fontsize=25
):
    """Save a radar (polar) chart comparing one metric across models and questions.

    Parameters
    ----------
    outputs : dict
        Nested mapping ``{model: {question: {metric_name: value}}}``. The question
        order of the first model defines the spokes; every model is looked up by
        question name, so all models must report those questions.
    metric : str
        Metric to plot, e.g. ``"Accuracy"`` or ``"F1 Score"``.
    file_path : str
        Path the figure is saved to.
    title : str
        Figure title.
    fontsize : int
        Font size for tick labels, title and legend.

    Raises
    ------
    ValueError
        If `outputs` is empty.
    """
    if not outputs:
        raise ValueError("outputs must contain at least one model")

    question_names = list(next(iter(outputs.values())).keys())

    # One spoke per question; repeat the first angle to close the polygon.
    angles = np.linspace(0, 2 * pi, len(question_names), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))

    # Per-model line colors/styles, cycled if there are more models than entries.
    colors = ["red", "red", "black", "blue"]
    styles = ["solid", "dashed", "solid", "solid"]

    for i, (model_name, results) in enumerate(outputs.items()):
        stats = [results[question][metric] for question in question_names]
        stats += stats[:1]
        ax.plot(
            angles,
            stats,
            label=model_name,
            linewidth=2,
            linestyle=styles[i % len(styles)],
            color=colors[i % len(colors)],
        )

    labels = [question.capitalize() for question in question_names]

    # Start at the top and go clockwise.
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    ax.tick_params(axis="y", labelsize=fontsize)
    ax.set_xticks(angles[:-1], labels, fontsize=fontsize, color="black")
    ax.set_ylim(0, 1.0)

    ax.set_title(title, fontsize=fontsize, pad=30)
    ax.legend(loc="upper right", bbox_to_anchor=(1.1, 1.1), fontsize=fontsize)

    fig.savefig(file_path)
    plt.close(fig)

In [None]:
# Reference ZooBot scores, hard-coded rather than computed in this notebook.
# NOTE(review): record the source of these numbers.
outputs["ZooBot"] = {
    question: {"Accuracy": accuracy, "F1 Score": f1_value}
    for question, accuracy, f1_value in (
        ("smooth", 0.94, 0.94),
        ("disk-edge-on", 0.99, 0.99),
        ("spiral-arms", 0.93, 0.94),
        ("bar", 0.82, 0.81),
        ("bulge-size", 0.84, 0.84),
        ("how-rounded", 0.93, 0.93),
        ("edge-on-bulge", 0.91, 0.90),
        ("spiral-winding", 0.78, 0.79),
        ("spiral-arm-count", 0.77, 0.76),
        ("merging", 0.88, 0.85),
    )
}

In [None]:
# Radar chart of per-question accuracy for every model.
plot_radar(
    outputs,
    metric="Accuracy",
    file_path="accuracy.png",
    title="Galaxy Property Estimation",
    fontsize=16,
)

In [None]:
# Radar chart of per-question F1 score for every model.
plot_radar(
    outputs,
    metric="F1 Score",
    file_path="f1_score.png",
    title="Galaxy Property Estimation",
    fontsize=16,
)

In [None]:
import pandas as pd

acc, f1 = {}, {}
for key, per_question in outputs.items():
    acc[key] = [scores["Accuracy"] for scores in per_question.values()]
    f1[key] = [scores["F1 Score"] for scores in per_question.values()]

# Rows are labelled with the question order of the last model iterated.
# NOTE(review): assumes every model reports the same questions in the same order.
acc = pd.DataFrame(acc, index=outputs[key].keys())
f1 = pd.DataFrame(f1, index=outputs[key].keys())

In [None]:
# Per-question accuracy table (rows: questions, columns: models).
acc

In [None]:
# Persist both metric tables as CSV files in the working directory.
for table, csv_path in ((acc, "accuracy.csv"), (f1, "f1_score.csv")):
    table.to_csv(csv_path)

In [None]:
# Mean accuracy of each model across all questions.
acc.mean(axis="index")

In [None]:
# Per-question F1 score table (rows: questions, columns: models).
f1

In [None]:
# Mean F1 score of each model across all questions (averaged over rows).
f1.mean(axis="index")

# GZ2 (Galaxy Zoo 2, SDSS) morphology questions

In [None]:
# Get Results for Smooth Question
# NOTE(review): the GZ2 cells assume X_CLIP / X_DINO / X_Stein and `classifications`
# are row-aligned GZ2 data (from get_paired_classifications(..., "gz2")); the GZ5
# cells above rebind `classifications` to a GZ5 catalog.

question, num_classes = "Smooth", 3

smooth = torch.tensor(
    classifications[
        "t01_smooth_or_features_a01_smooth_fraction",
        "t01_smooth_or_features_a02_features_or_disk_fraction",
        "t01_smooth_or_features_a03_star_or_artifact_fraction",
    ]
    .to_pandas()
    .values
)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model, smooth, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Edge On Question
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Edge On", 2

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t02_edgeon_a04_yes_count"]
    + classifications["t02_edgeon_a05_no_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications["t02_edgeon_a04_yes_fraction", "t02_edgeon_a05_no_fraction"]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Bar Question
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Bar", 2

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t03_bar_a06_bar_count"]
    + classifications["t03_bar_a07_no_bar_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications["t03_bar_a06_bar_fraction", "t03_bar_a07_no_bar_fraction"]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Spiral Question (t04: are spiral arms present?)
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Spiral Count", 2

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t04_spiral_a08_spiral_count"]
    + classifications["t04_spiral_a09_no_spiral_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications[
    "t04_spiral_a08_spiral_fraction", "t04_spiral_a09_no_spiral_fraction"
]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Bulge Prominence Question
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Bulge Prominence", 4

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t05_bulge_prominence_a10_no_bulge_count"]
    + classifications["t05_bulge_prominence_a11_just_noticeable_count"]
    + classifications["t05_bulge_prominence_a12_obvious_count"]
    + classifications["t05_bulge_prominence_a13_dominant_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications[
    "t05_bulge_prominence_a10_no_bulge_fraction",
    "t05_bulge_prominence_a11_just_noticeable_fraction",
    "t05_bulge_prominence_a12_obvious_fraction",
    "t05_bulge_prominence_a13_dominant_fraction",
]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for How Rounded Question
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "How Rounded", 3

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t07_rounded_a16_completely_round_count"]
    + classifications["t07_rounded_a17_in_between_count"]
    + classifications["t07_rounded_a18_cigar_shaped_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications[
    "t07_rounded_a16_completely_round_fraction",
    "t07_rounded_a17_in_between_fraction",
    "t07_rounded_a18_cigar_shaped_fraction",
]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Bulge Shape Question
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Bulge Shape", 3

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t09_bulge_shape_a25_rounded_count"]
    + classifications["t09_bulge_shape_a26_boxy_count"]
    + classifications["t09_bulge_shape_a27_no_bulge_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

# Targets are the t09 bulge-shape fractions, matching the t09 counts used above
# (previously copy-pasted t07 "how rounded" columns).
y = classifications[
    "t09_bulge_shape_a25_rounded_fraction",
    "t09_bulge_shape_a26_boxy_fraction",
    "t09_bulge_shape_a27_no_bulge_fraction",
]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Spiral Arm Type
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Spiral Arm Type", 3

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t10_arms_winding_a28_tight_count"]
    + classifications["t10_arms_winding_a29_medium_count"]
    + classifications["t10_arms_winding_a30_loose_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications[
    "t10_arms_winding_a28_tight_fraction",
    "t10_arms_winding_a29_medium_fraction",
    "t10_arms_winding_a30_loose_fraction",
]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
# Get Results for Spiral Arm Count
# NOTE(review): assumes X_CLIP / X_DINO / X_Stein and `classifications` are row-aligned
# GZ2 data (from get_paired_classifications(..., "gz2")).

question, num_classes = "Spiral Arm Count", 6

# Keep galaxies where more than half of all classifications answered this question.
counts = (
    classifications["t11_arms_number_a31_1_count"]
    + classifications["t11_arms_number_a32_2_count"]
    + classifications["t11_arms_number_a33_3_count"]
    + classifications["t11_arms_number_a34_4_count"]
    + classifications["t11_arms_number_a36_more_than_4_count"]
    + classifications["t11_arms_number_a37_cant_tell_count"]
)
total = classifications["total_classifications"]
pct_answered = np.array(counts / total)
above50 = np.where(pct_answered > 0.5)[0]

y = classifications[
    "t11_arms_number_a31_1_fraction",
    "t11_arms_number_a32_2_fraction",
    "t11_arms_number_a33_3_fraction",
    "t11_arms_number_a34_4_fraction",
    "t11_arms_number_a36_more_than_4_fraction",
    "t11_arms_number_a37_cant_tell_fraction",
]
y = torch.tensor(y[above50].to_pandas().values)

# train_eval_MLP needs explicit train/test splits; the fixed random_state gives every
# embedding the same 80/20 split of galaxies.
for model_name, X_model in (("CLIP", X_CLIP), ("DINO", X_DINO), ("Stein", X_Stein)):
    print(model_name)
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_model[above50], y, test_size=0.2, random_state=42
    )
    outputs[model_name][question] = train_eval_MLP(
        X_tr,
        X_te,
        y_tr,
        y_te,
        embed_dim=X_model.shape[1],
        num_classes=num_classes,
        MLP_dim=32,
        epochs=100,
    )

In [None]:
def plot_radar(
    outputs, metric, file_path, title="Galaxy Property Estimation", fontsize=22
):
    """Save a radar (polar) chart comparing one metric across models and questions.

    NOTE(review): this redefines the earlier `plot_radar` (default colors, default
    fontsize 22) for the cells below.

    Parameters
    ----------
    outputs : dict
        Nested mapping ``{model: {question: {metric_name: value}}}``. The question
        order of the first model defines the spokes; every model is looked up by
        question name, so all models must report those questions.
    metric : str
        Metric to plot, e.g. ``"Accuracy"`` or ``"F1 Score"``.
    file_path : str
        Path the figure is saved to.
    title : str
        Figure title.
    fontsize : int
        Font size for tick labels, title and legend.

    Raises
    ------
    ValueError
        If `outputs` is empty.
    """
    if not outputs:
        raise ValueError("outputs must contain at least one model")

    question_names = list(next(iter(outputs.values())).keys())

    # One spoke per question; repeat the first angle to close the polygon.
    angles = np.linspace(0, 2 * pi, len(question_names), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))

    for model_name, results in outputs.items():
        stats = [results[question][metric] for question in question_names]
        stats += stats[:1]
        ax.plot(angles, stats, label=model_name)

    # Start at the top and go clockwise.
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    ax.tick_params(axis="y", labelsize=fontsize)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(question_names, fontsize=fontsize)
    ax.set_ylim(0, 1)

    ax.set_title(title, fontsize=fontsize, pad=30)
    ax.legend(loc="upper right", bbox_to_anchor=(1.1, 1.1), fontsize=fontsize)

    fig.savefig(file_path)
    plt.close(fig)

In [None]:
# Radar chart of per-question accuracy (including the GZ2 questions).
plot_radar(
    outputs,
    metric="Accuracy",
    file_path="accuracy.png",
    title="Galaxy Property Estimation",
    fontsize=16,
)

In [None]:
# Radar chart of per-question F1 score (including the GZ2 questions).
plot_radar(
    outputs,
    metric="F1 Score",
    file_path="f1_score.png",
    title="Galaxy Property Estimation",
    fontsize=16,
)