In [None]:
import copy
import os

from astropy.table import Table, join
import torch
from sklearn.metrics import r2_score
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler

# Load the paired train/test tables. Later cells read the "<model>_embeddings"
# and "Z_HP" columns from them. The directory defaults to the shared cluster
# path, but the PROVABGS_DIR environment variable can override it. That lets
# the notebook run where the data lives somewhere else.
PROVABGS_DIR = os.environ.get(
    "PROVABGS_DIR", "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs"
)
train_provabgs = Table.read(
    os.path.join(PROVABGS_DIR, "provabgs_paired_train_embeddings.hdf5")
)
test_provabgs = Table.read(
    os.path.join(PROVABGS_DIR, "provabgs_paired_test_embeddings.hdf5")
)

In [None]:
from helpers import scale_properties, unscale_properties

# The regression target (Z_HP) is the same whichever embedding model is
# evaluated, so it is prepared once here rather than inside a per-model loop.
# Embeddings are extracted and standardized per model in the next cell.
property_train = {"Z_HP": train_provabgs["Z_HP"]}
property_test = {"Z_HP": test_provabgs["Z_HP"]}

# Only the training split is passed, so any scaling statistics come from
# training data alone. `scaler["Z_HP"]` is kept to map predictions back to
# original units later (presumably a standardization; TODO confirm in helpers).
y_train, scaler = scale_properties(property_train)

In [None]:
models = ["astrodino"]

# Standardize each embedding model's features. The scaler is fit on the
# training split only and then applied to both splits, so no test-set
# statistics leak into the features.
data = {}
for model in models:
    X_train = train_provabgs[model + "_embeddings"]
    X_test = test_provabgs[model + "_embeddings"]
    embedding_scaler = StandardScaler().fit(X_train)
    data[model] = {
        "train": embedding_scaler.transform(X_train),
        "test": embedding_scaler.transform(X_test),
    }

# Regression targets.
# - Training target: comes from `scale_properties` (scaled units).
# - Test target: stays in original Z_HP units, because predictions are
#   inverse-transformed with `scaler` before scoring.
train_redshift = torch.tensor(y_train["Z_HP"])
test_redshift = torch.tensor(test_provabgs["Z_HP"])

In [None]:
from helpers import few_shot
from helpers import zero_shot
from models import MLP

# Regression head for the few-shot evaluation below. The arguments are
# presumably (input_dim, output_dim): 1024-d embeddings in, one scalar out.
# TODO confirm against models.MLP and the width of data[...]["train"].
# Note: this rebinds `model`, which earlier cells used as the loop variable
# holding the embedding-model name.
model = MLP(1024, 1)

In [None]:
preds_knn, preds_mlp = {}, {}

for key in data.keys():
    # Zero-shot baseline on the standardized embeddings (presumably k-NN, per
    # the `preds_knn` name). It is trained on the scaled target, so its raw
    # predictions are in scaled units.
    raw_preds_knn = zero_shot(data[key]["train"], train_redshift, data[key]["test"])

    # Give every embedding model its own copy of the untrained MLP template.
    # Passing the same `model` object on every iteration would let weights
    # fitted on one model's embeddings carry over into the next. This assumes
    # few_shot trains the module it receives in place (TODO confirm in helpers).
    mlp = copy.deepcopy(model)
    raw_preds_mlp = few_shot(
        mlp, data[key]["train"], train_redshift, data[key]["test"]
    )

    # Map predictions back to original Z_HP units so they are comparable with
    # `test_redshift`.
    # NOTE(review): a sklearn-style inverse_transform expects 2-D input.
    # Confirm that zero_shot and few_shot return shape (n, 1).
    preds_knn[key] = scaler["Z_HP"].inverse_transform(raw_preds_knn)
    preds_mlp[key] = scaler["Z_HP"].inverse_transform(raw_preds_mlp)

In [None]:
# Test-set R^2 in original Z_HP units, shown as (zero-shot, few-shot MLP).
(
    r2_score(test_redshift, preds_knn["astrodino"]),
    r2_score(test_redshift, preds_mlp["astrodino"]),
)