In [None]:
import os
import numpy as np
import torch
from astropy.table import Table
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

from utils.models import few_shot, zero_shot
from utils.plotting import plot_redshift_scatter

# Root directories for the PROVABGS embedding tables and the supervised-model outputs
PROVABGS_ROOT = "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs/"
SUPERVISED_ROOT = "/mnt/ceph/users/polymathic/astroclip/supervised/"

# Embedding model names, kept in separate image and spectrum lists
image_models = ["astroclip_image", "astrodino", "stein"]
spectrum_models = ["astroclip_spectrum", "specformer"]

# Locations of the paired train / test embedding files
train_path = os.path.join(PROVABGS_ROOT, "provabgs_paired_train_embeddings.hdf5")
test_path = os.path.join(PROVABGS_ROOT, "provabgs_paired_test_embeddings.hdf5")

# Read both tables (they hold the embeddings as well as the galaxy properties)
train_provabgs, test_provabgs = Table.read(train_path), Table.read(test_path)

# Regression targets
properties = ["Z_MW", "LOG_MSTAR", "TAGE_MW", "sSFR"]


def _stack_properties(table):
    """Return the `properties` columns of `table` stacked into a single array.

    Every column is squeezed before stacking and the stacked result is
    transposed, so the properties end up along the last axis.
    """
    per_property = [table[name].data.squeeze() for name in properties]
    return np.stack(per_property).T


y_train = _stack_properties(train_provabgs)
y_test = _stack_properties(test_provabgs)

# Standardise the training targets only; the statistics are kept in `scaler`
# so predictions can be mapped back to the original units, in which y_test stays.
scaler = {"mean": y_train.mean(axis=0), "std": y_train.std(axis=0)}
y_train = (y_train - scaler["mean"]) / scaler["std"]

n_train, n_test = len(train_provabgs), len(test_provabgs)
print("Size of training set:", n_train, "\nSize of test set:", n_test)

# Galaxy Property Prediction from Image Embeddings

In [None]:
# Collect standardised train / test embeddings for every image model
data = {}
for model in image_models:
    column = model + "_embeddings"
    X_train = train_provabgs[column]
    X_test = test_provabgs[column]

    # Fit the scaler on the training embeddings only, then apply it to both splits
    embedding_scaler = StandardScaler().fit(X_train)
    data[model] = {
        "train": embedding_scaler.transform(X_train),
        "test": embedding_scaler.transform(X_test),
    }

In [None]:
# Perform kNN and MLP regression on the embeddings of every model in `data`
preds_knn, preds_mlp = {}, {}
for key in data.keys():
    print(f"Evaluating {key} model...")
    raw_preds_knn = zero_shot(data[key]["train"], y_train, data[key]["test"])
    # Pass the model handled in this iteration (`key`). This call previously
    # passed `model`, which is not defined in this cell: it only existed as a
    # loop variable leaked from an earlier cell and therefore always held the
    # last image model, independent of `key`.
    # NOTE(review): how few_shot uses its first argument is not visible here.
    raw_preds_mlp = few_shot(
        key, data[key]["train"], y_train, data[key]["test"]
    ).squeeze()
    # Undo the target standardisation so predictions are in the original units
    preds_knn[key] = raw_preds_knn * scaler["std"] + scaler["mean"]
    preds_mlp[key] = raw_preds_mlp * scaler["std"] + scaler["mean"]

In [None]:
# Load the saved test predictions of the supervised image and photometry models
image_preds_path = os.path.join(
    SUPERVISED_ROOT, "image/global_properties/test_pred.pt"
)
photometry_preds_path = os.path.join(
    SUPERVISED_ROOT, "photometry/global_properties/test_pred.pt"
)
resnet_preds = torch.load(image_preds_path)
photometry_preds = torch.load(photometry_preds_path)

# Register both supervised models in the kNN and the MLP dictionaries.
# Each entry gets its own freshly stacked array, with the properties in the
# order given by `properties`.
supervised_preds = {"resnet18": resnet_preds, "photometry": photometry_preds}
for preds_by_model in (preds_knn, preds_mlp):
    for name, saved_preds in supervised_preds.items():
        preds_by_model[name] = np.stack(
            [saved_preds[prop].squeeze() for prop in properties]
        ).T

In [None]:
# Tabulate the r^2 score of every entry in preds_knn for each property
from sklearn.metrics import r2_score

r2_scores = {}
for key, model_preds in preds_knn.items():
    r2_scores[key] = {}
    # Column i of y_test and of the predictions corresponds to properties[i]
    for i, prop in enumerate(properties):
        r2_scores[key][prop] = r2_score(y_test[:, i], model_preds[:, i])