In [None]:
import os
import numpy as np
import torch
from astropy.table import Table
from sklearn.preprocessing import StandardScaler

from property_utils.models import few_shot, zero_shot
from property_utils.plotting import plot_scatter

# Data locations. The defaults point at the shared cluster storage; set the
# PROVABGS_ROOT / SUPERVISED_ROOT environment variables to run elsewhere.
PROVABGS_ROOT = os.environ.get(
    "PROVABGS_ROOT", "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs/"
)
SUPERVISED_ROOT = os.environ.get(
    "SUPERVISED_ROOT", "/mnt/ceph/users/polymathic/astroclip/supervised/"
)

# Embedding models to evaluate. For each name `m`, the paired tables are read
# through an `m + "_embeddings"` column.
image_models = ["astroclip_image", "astrodino", "stein"]
spectrum_models = ["astroclip_spectrum", "specformer"]

# Paired train/test tables holding the Z_HP redshift column plus the
# per-model embedding columns.
train_path = os.path.join(PROVABGS_ROOT, "provabgs_paired_train_embeddings.hdf5")
test_path = os.path.join(PROVABGS_ROOT, "provabgs_paired_test_embeddings.hdf5")

train_provabgs = Table.read(train_path)
test_provabgs = Table.read(test_path)

# Redshift targets
z_train = train_provabgs["Z_HP"]
z_test = test_provabgs["Z_HP"]

# Standardize the training targets only. `scaler` is kept so predictions can be
# mapped back to redshift units (pred * std + mean); z_test stays unscaled.
scaler = {"mean": z_train.mean(), "std": z_train.std()}
z_train = (z_train - scaler["mean"]) / scaler["std"]

print(
    "Size of training set:",
    len(train_provabgs),
    "\nSize of test set:",
    len(test_provabgs),
)

# Redshift estimation from image embeddings

In [None]:
# Standardize each image model's embeddings: the scaler is fit on the training
# split only and then applied unchanged to the test split.
data = {}
for model in image_models:
    train_emb = train_provabgs[f"{model}_embeddings"]
    test_emb = test_provabgs[f"{model}_embeddings"]
    emb_scaler = StandardScaler().fit(train_emb)
    data[model] = {
        "train": emb_scaler.transform(train_emb),
        "test": emb_scaler.transform(test_emb),
    }

In [None]:
# Perform k-NN (zero_shot) and MLP (few_shot) regression on each image model's
# standardized embeddings, then undo the target scaling so predictions are in
# the same units as z_test.
preds_knn, preds_mlp = {}, {}
for key in data.keys():
    print(f"Evaluating {key} model...")
    raw_preds_knn = zero_shot(data[key]["train"], z_train, data[key]["test"])
    # Pass the current model name `key`; the previously used `model` was a
    # leftover from the data-preparation loop and always held the last name.
    raw_preds_mlp = few_shot(
        key, data[key]["train"], z_train, data[key]["test"], hidden_dims=[32]
    ).squeeze()
    preds_knn[key] = raw_preds_knn * scaler["std"] + scaler["mean"]
    preds_mlp[key] = raw_preds_mlp * scaler["std"] + scaler["mean"]

In [None]:
# Save scatter plots comparing each image model's k-NN and MLP predictions
# against the test redshifts.
save_path = "./outputs/redshift/image"
# exist_ok avoids the check-then-create race and is a no-op if the dir exists
os.makedirs(save_path, exist_ok=True)
plot_scatter(preds_knn, z_test, save_loc=f"{save_path}/redshift_scatter_knn.png")
plot_scatter(preds_mlp, z_test, save_loc=f"{save_path}/redshift_scatter_mlp.png")

In [None]:
# Get test-set redshift predictions from the supervised baselines and plot them
# against the test redshifts.
supervised_pred_files = {
    "resnet18": "image/ResNet18/redshift/test_pred.pt",
    "photometry": "photometry/MLP/redshift/test_pred.pt",
}
# map_location="cpu" remaps any CUDA tensors in the saved files to CPU, so the
# notebook also runs on machines without a GPU.
preds_supervised = {
    name: torch.load(
        os.path.join(SUPERVISED_ROOT, rel_path), map_location="cpu"
    )["Z_HP"]
    for name, rel_path in supervised_pred_files.items()
}

save_path = "./outputs/redshift/image"
os.makedirs(save_path, exist_ok=True)
plot_scatter(
    preds_supervised, z_test, save_loc=f"{save_path}/redshift_scatter_supervised.png"
)

# Redshift estimation from spectrum embeddings

In [None]:
# Standardize each spectrum model's embeddings: the scaler is fit on the
# training split only and then applied unchanged to the test split.
data = {}
for model in spectrum_models:
    train_emb = train_provabgs[f"{model}_embeddings"]
    test_emb = test_provabgs[f"{model}_embeddings"]
    emb_scaler = StandardScaler().fit(train_emb)
    data[model] = {
        "train": emb_scaler.transform(train_emb),
        "test": emb_scaler.transform(test_emb),
    }

In [None]:
# Perform k-NN (zero_shot) and MLP (few_shot) regression on each spectrum
# model's standardized embeddings, then undo the target scaling so predictions
# are in the same units as z_test.
preds_knn, preds_mlp = {}, {}
for key in data.keys():
    print(f"Evaluating {key} model...")
    raw_preds_knn = zero_shot(data[key]["train"], z_train, data[key]["test"])
    # Pass the current model name `key`; the previously used `model` was a
    # leftover from the data-preparation loop and always held the last name.
    # NOTE(review): no hidden_dims here, so few_shot's default is used (the
    # image cell passes hidden_dims=[32]) — confirm this is intended.
    raw_preds_mlp = few_shot(
        key, data[key]["train"], z_train, data[key]["test"]
    ).squeeze()
    preds_knn[key] = raw_preds_knn * scaler["std"] + scaler["mean"]
    preds_mlp[key] = raw_preds_mlp * scaler["std"] + scaler["mean"]

In [None]:
# Get test-set predictions from the supervised spectrum baseline (Conv+Att).
# map_location="cpu" remaps any CUDA tensors in the file to CPU so it loads
# without a GPU.
spectrum_preds = torch.load(
    os.path.join(SUPERVISED_ROOT, "spectrum/Conv+Att/redshift/test_pred.pt"),
    map_location="cpu",
)["Z_HP"]

# Add the baseline to both prediction dicts so each scatter plot includes it.
# No inverse scaling is applied: these are presumably already in redshift
# units — confirm against how test_pred.pt was written.
preds_knn["conv+att"] = spectrum_preds
preds_mlp["conv+att"] = spectrum_preds

In [None]:
# Save scatter plots comparing each spectrum model's k-NN and MLP predictions
# (plus the supervised baseline) against the test redshifts.
save_path = "./outputs/redshift/spectrum"
# exist_ok avoids the check-then-create race and is a no-op if the dir exists
os.makedirs(save_path, exist_ok=True)
plot_scatter(preds_knn, z_test, save_loc=f"{save_path}/redshift_scatter_knn.png")
plot_scatter(preds_mlp, z_test, save_loc=f"{save_path}/redshift_scatter_mlp.png")