In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from astropy.table import Table, join
from sklearn.metrics import r2_score
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler

from utils.models import MLP, zero_shot, few_shot

# Root directory of the AstroCLIP data products. Set the ASTROCLIP_ROOT
# environment variable to point the notebook at a different copy of the data;
# the default is the cluster location the notebook was written against.
ASTROCLIP_ROOT = os.environ.get("ASTROCLIP_ROOT", "/mnt/ceph/users/polymathic/astroclip")
PROVABGS_DIR = os.path.join(ASTROCLIP_ROOT, "datasets", "provabgs")

# Directory with the supervised baselines' saved test predictions. The trailing
# "" component keeps the trailing path separator.
supervised_path = os.path.join(ASTROCLIP_ROOT, "supervised", "")

# Load the data: paired PROVABGS train/test tables (HDF5) holding the
# per-model embedding columns and the Z_HP redshift column used below.
train_provabgs = Table.read(
    os.path.join(PROVABGS_DIR, "provabgs_paired_train_embeddings.hdf5")
)
test_provabgs = Table.read(
    os.path.join(PROVABGS_DIR, "provabgs_paired_test_embeddings.hdf5")
)

In [None]:
models = ["astrodino", "stein"]

# Standardize each model's embeddings. The scaler is fit on the training split
# only and then applied to both splits.
data = {}
for model in models:
    column = model + "_embeddings"
    X_train = train_provabgs[column]
    X_test = test_provabgs[column]
    embedding_scaler = StandardScaler().fit(X_train)
    data[model] = {
        "train": embedding_scaler.transform(X_train),
        "test": embedding_scaler.transform(X_test),
    }

# Redshift targets. Only the training targets are standardized; model outputs
# are mapped back to redshift units with `scaler` before being compared with
# the untouched z_test.
z_test = test_provabgs["Z_HP"]
z_train = train_provabgs["Z_HP"]
scaler = {"mean": z_train.mean(), "std": z_train.std()}
z_train = (z_train - scaler["mean"]) / scaler["std"]

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from numpy import ndarray


class MLP(nn.Sequential):
    """Plain feed-forward network built as an ``nn.Sequential``.

    Every hidden layer is ``Linear -> activation -> Dropout``; the output layer
    is a bare ``Linear`` with no activation.

    Args:
        n_in: number of input features.
        n_out: number of output features.
        n_hidden: widths of the hidden layers, in order (may be empty).
        act: activation modules, exactly ``len(n_hidden) + 1`` of them. Entry
            ``k`` follows hidden layer ``k``; the final entry is accepted but
            never inserted into the network. Defaults to one ``nn.LeakyReLU``
            instance repeated in every slot.
        dropout: dropout probability applied after each hidden activation.
    """

    def __init__(self, n_in, n_out, n_hidden=(16, 16, 16), act=None, dropout=0):
        n_slots = len(n_hidden) + 1
        if act is None:
            act = [nn.LeakyReLU()] * n_slots
        assert len(act) == n_slots

        widths = [n_in, *n_hidden, n_out]
        modules = []
        # Pair consecutive widths for every hidden layer (all but the last link).
        for idx, (w_in, w_out) in enumerate(zip(widths[:-2], widths[1:-1])):
            modules += [nn.Linear(w_in, w_out), act[idx], nn.Dropout(p=dropout)]
        modules.append(nn.Linear(widths[-2], widths[-1]))

        super().__init__(*modules)


def few_shot(
    model: nn.Module,
    X_train: ndarray,
    y_train: ndarray,
    X_test: ndarray,
    max_epochs: int = 10,
    lr: float = 1e-3,
    batch_size: int = 64,
    device=None,
) -> ndarray:
    """Train a small regression network on embeddings and predict on a test set.

    The network passed as ``model`` is trained in place with Adam on an MSE
    loss for ``max_epochs`` passes over shuffled mini-batches. It is then
    evaluated on ``X_test`` with gradients disabled.

    Args:
        model: network to train. Its input width must equal
            ``X_train.shape[1]``, and its output must be reshapeable to the
            label shape. Pass ``None`` to fall back to a default ``MLP`` with
            two hidden layers of 64 units and ReLU activations, sized from
            the data.
        X_train: training inputs, 2-D ``(n_train, n_features)``.
        y_train: training targets, ``(n_train,)`` or ``(n_train, n_targets)``.
        X_test: test inputs, 2-D ``(n_test, n_features)``.
        max_epochs: number of passes over the training set.
        lr: Adam learning rate.
        batch_size: mini-batch size used for training.
        device: torch device to train and predict on. Defaults to ``"cuda"``
            when available, otherwise ``"cpu"``.

    Returns:
        The model's raw outputs on ``X_test`` as a float32 NumPy array, with
        whatever shape the model produces (e.g. ``(n_test, 1)`` for a
        single-output network).
    """
    # Plain float32 arrays. np.asarray also unwraps ndarray subclasses such as
    # the astropy Column used for the redshift targets.
    X_train = np.asarray(X_train, dtype=np.float32)
    y_train = np.asarray(y_train, dtype=np.float32)
    X_test = np.asarray(X_test, dtype=np.float32)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if model is None:
        n_targets = y_train.shape[1] if y_train.ndim > 1 else 1
        model = MLP(
            n_in=X_train.shape[1],
            n_out=n_targets,
            n_hidden=[64, 64],
            act=[nn.ReLU()] * 3,
        )
    model = model.to(device)

    train_loader = DataLoader(
        TensorDataset(torch.tensor(X_train), torch.tensor(y_train)),
        batch_size=batch_size,
        shuffle=True,
    )
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for _ in range(max_epochs):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            # Match the prediction to the label shape explicitly. A bare
            # .squeeze() turns a final batch of size 1 into a 0-d tensor,
            # which no longer matches the (1,)-shaped labels.
            outputs = model(inputs).reshape(labels.shape)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    model.eval()
    with torch.no_grad():
        preds = model(torch.tensor(X_test).to(device)).cpu().numpy()
    return preds

In [None]:
preds_knn, preds_mlp = {}, {}

for key, model_data in data.items():
    raw_preds_knn = zero_shot(model_data["train"], z_train, model_data["test"])

    # Size the regression head from the embeddings themselves instead of
    # hard-coding a width per model name.
    n_embed = model_data["train"].shape[1]
    model = MLP(n_embed, 1, n_hidden=[128, 64, 32])

    raw_preds_mlp = few_shot(
        model, model_data["train"], z_train, model_data["test"], max_epochs=20
    ).squeeze()

    # Undo the standardization applied to z_train so predictions are in
    # redshift units, comparable with z_test.
    preds_knn[key] = raw_preds_knn * scaler["std"] + scaler["mean"]
    preds_mlp[key] = raw_preds_mlp * scaler["std"] + scaler["mean"]

In [None]:
# Get predictions from supervised models
def _load_supervised_z(relative_path):
    """Return the ``Z_HP`` predictions saved at ``supervised_path/relative_path``.

    The file is loaded with ``map_location="cpu"`` so tensors saved from a GPU
    still load on a CPU-only machine. The predictions are returned as a
    squeezed NumPy array, ready for the NumPy arithmetic and plotting that
    follows.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    saved = torch.load(os.path.join(supervised_path, relative_path), map_location="cpu")
    # Assumes saved["Z_HP"] is a tensor or numeric array-like — TODO confirm.
    return torch.as_tensor(saved["Z_HP"]).detach().cpu().numpy().squeeze()


resnet_preds = _load_supervised_z("image/test_pred.pt")
photometry_preds = _load_supervised_z("photometry/test_pred.pt")

# The supervised baselines have a single set of predictions, shown in both
# the k-NN and the MLP comparison figures.
for preds_by_model in (preds_knn, preds_mlp):
    preds_by_model["resnet18"] = resnet_preds
    preds_by_model["photometry"] = photometry_preds

In [None]:
def plot_redshift_scatter(preds, z_test, save_loc="scatter.png"):
    """Plot predicted vs. true redshift and binned residual scatter per model.

    Builds a 2 x N grid, where N is the number of entries in ``preds``:

    * top row: predicted vs. true redshift (points, 2-D histogram and KDE
      contours), a dashed y = x reference line and the R^2 score;
    * bottom row: the normalized residual (z_true - z_pred) / (1 + z_true) for
      every object, plus its standard deviation in redshift bins.

    Args:
        preds: mapping of model name -> predicted redshifts aligned with
            ``z_test``.
        z_test: true redshifts.
        save_loc: path the figure is written to (300 dpi).
    """
    names = list(preds.keys())
    z_true = np.asarray(z_test, dtype=float)

    # squeeze=False keeps ``ax`` 2-D even when a single model is plotted.
    fig, ax = plt.subplots(2, len(names), figsize=(4 * len(names), 10), squeeze=False)

    # Redshift bins for the residual-scatter curve, shared by all panels.
    bins = np.linspace(0, 0.62, 20)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])
    bin_index = np.digitize(z_true, bins)

    for col, name in enumerate(names):
        z_pred = np.asarray(preds[name], dtype=float)
        top, bottom = ax[0, col], ax[1, col]

        # --- Top row: predicted vs. true redshift ---
        sns.scatterplot(ax=top, x=z_true, y=z_pred, s=5, color=".15")
        sns.histplot(ax=top, x=z_true, y=z_pred, bins=50, pthresh=0.1, cmap="mako")
        sns.kdeplot(ax=top, x=z_true, y=z_pred, levels=5, color="k", linewidths=1)

        # Dashed y = x reference line (perfect predictions).
        top.plot([0, 0.65], [0, 0.65], "--", linewidth=1.5, alpha=0.5, color="grey")
        top.set_xlim(0, 0.6)
        top.set_ylim(0, 0.6)
        top.text(
            0.9,
            0.1,
            "$R^2$ score: %0.2f" % r2_score(z_true, z_pred),
            horizontalalignment="right",
            verticalalignment="top",
            fontsize=22,
            transform=top.transAxes,
        )
        top.set_title(name, fontsize=25)

        # --- Bottom row: normalized residuals ---
        residual = (z_true - z_pred) / (1 + z_true)
        # Std of the residual in each bin, plotted at the bin center; empty
        # bins become NaN (a gap in the curve) instead of raising warnings.
        resid_std = np.array(
            [
                residual[bin_index == k].std() if np.any(bin_index == k) else np.nan
                for k in range(1, len(bins))
            ]
        )

        sns.scatterplot(ax=bottom, x=z_true, y=residual, s=2, alpha=0.3, color="black")
        sns.lineplot(ax=bottom, x=bin_centers, y=resid_std, color="r", label="std")

        # horizontal line on y = 0
        bottom.axhline(0, color="grey", linewidth=1.5, alpha=0.5, linestyle="--")
        bottom.set_xlim(0, 0.6)
        bottom.set_ylim(-0.3, 0.3)
        bottom.set_xlabel("$Z_{true}$", fontsize=25)
        bottom.legend(fontsize=15, loc="upper right")

    ax[0, 0].set_ylabel("$Z_{pred}$", fontsize=25)
    ax[1, 0].set_ylabel("$(Z_{true}-Z_{pred})/(1+Z_{true})$", fontsize=25)

    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(save_loc, dpi=300)

In [None]:
# Write one diagnostic figure per regression approach: k-NN and MLP.
for preds_by_model, out_path in (
    (preds_knn, "scatter_knn.png"),
    (preds_mlp, "scatter_mlp.png"),
):
    plot_redshift_scatter(preds_by_model, z_test, save_loc=out_path)