In [None]:
import os
from astropy.table import Table, join
import torch
from sklearn.metrics import r2_score
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler

from models import MLP
from heads import zero_shot, few_shot

supervised_path = "/mnt/ceph/users/polymathic/astroclip/supervised/"

# Load the data: one PROVABGS table per split, holding the "<model>_embeddings"
# columns and the "Z_HP" redshift column consumed by the cells below.
train_embeddings_file = "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs/provabgs_paired_train_embeddings.hdf5"
test_embeddings_file = "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs/provabgs_paired_test_embeddings.hdf5"
train_provabgs, test_provabgs = (
    Table.read(path) for path in (train_embeddings_file, test_embeddings_file)
)

In [None]:
models = ["astrodino"]

# Standardise each model's embeddings with statistics fitted on the training split
# only, then apply the same transform to the test split.
# The loop variable is `model_name` (not `model`) so that re-running this cell does
# not overwrite the MLP `model` used by the prediction cell with a string.
data = {}
for model_name in models:
    X_train = train_provabgs[model_name + "_embeddings"]
    X_test = test_provabgs[model_name + "_embeddings"]
    embedding_scaler = StandardScaler().fit(X_train)
    data[model_name] = {
        "train": embedding_scaler.transform(X_train),
        "test": embedding_scaler.transform(X_test),
    }

# Get redshifts
z_train = train_provabgs["Z_HP"]
z_test = test_provabgs["Z_HP"]

# Standardise the training targets only; `scaler` is kept so predictions can be
# mapped back to redshift units later. `z_test` stays in original units.
scaler = {"mean": z_train.mean(), "std": z_train.std()}
# Guard against a zero (or NaN) spread, which would silently produce inf/NaN targets.
assert scaler["std"] > 0, f"training redshift std must be positive, got {scaler['std']}"
z_train = (z_train - scaler["mean"]) / scaler["std"]

In [None]:
# MLP head for the few-shot probe (`MLP` is already imported in the first cell).
# The input width is read from the standardised embeddings instead of being
# hard-coded to 1024, so the head matches whichever embedding model is listed first.
# NOTE(review): assumes the signature is MLP(in_dim, out_dim) — confirm against models.MLP.
model = MLP(data[models[0]]["train"].shape[1], 1)

In [None]:
# Zero-shot and few-shot redshift predictions for every embedding type
# (`zero_shot`, `few_shot` and `MLP` are already imported in the first cell).
preds_knn, preds_mlp = {}, {}

for key, split in data.items():
    raw_preds_knn = zero_shot(split["train"], z_train, split["test"])

    # Build a fresh head per embedding type and per run of this cell. Previously the
    # single module-level `model` was passed for every key; if few_shot trains the
    # module it is given, weights would carry over between keys and between re-runs.
    # NOTE(review): assumes MLP(in_dim, out_dim) — confirm against models.MLP.
    mlp_head = MLP(split["train"].shape[1], 1)
    raw_preds_mlp = few_shot(mlp_head, split["train"], z_train, split["test"])

    # Undo the target standardisation so predictions are back in redshift units.
    preds_knn[key] = raw_preds_knn * scaler["std"] + scaler["mean"]
    preds_mlp[key] = raw_preds_mlp * scaler["std"] + scaler["mean"]

In [None]:
# Get predictions from supervised models.
# `supervised_path` is defined once in the first cell; it is not redefined here.
# Maps the label used in this notebook -> sub-directory holding its test_pred.pt.
supervised_dirs = {"resnet18": "image", "photometry": "photometry"}

# map_location="cpu" lets the saved tensors load (and be passed to sklearn metrics)
# on a machine without the device they may have been saved from.
supervised = {
    name: torch.load(
        os.path.join(supervised_path, subdir, "test_pred.pt"), map_location="cpu"
    )["Z_HP"]
    for name, subdir in supervised_dirs.items()
}

In [None]:
# Test-set R^2 for every method: the supervised baselines plus the zero-shot and
# few-shot predictions computed above for each embedding type.
r2_scores = {name: r2_score(z_test, pred) for name, pred in supervised.items()}
for key in preds_knn:
    r2_scores[f"{key} (zero-shot)"] = r2_score(z_test, preds_knn[key])
    r2_scores[f"{key} (few-shot)"] = r2_score(z_test, preds_mlp[key])
r2_scores