In [None]:
from astropy.table import Table, join
import torch
import tqdm
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, random_split

import os, sys

sys.path.append("../")

from models import ResNet18, SpectrumEncoder, MLP

In [None]:
def setup_supervised_data(
    train_data: Table,
    test_data: Table,
    modality: str,
    properties: list = None,
    batch_size: int = 128,
    train_size: float = 0.8,
):
    """Build normalized train/val/test data loaders for supervised regression.

    Args:
        train_data: Table-like object indexed by column name; supplies both the
            input features and the regression targets.
        test_data: Table-like object with the same feature columns; only its
            features are used here.
        modality: One of "image", "spectrum" or "photometry". For "image" and
            "spectrum" the column named after the modality is used; for
            "photometry" the MAG_G, MAG_R and MAG_Z columns are stacked into a
            (N, 3) tensor.
        properties: Names of the target columns. Defaults to
            ["Z_HP", "PROVABGS_LOGMSTAR_BF", "Z_MW", "TAGE_MW", "AVG_SFR"].
        batch_size: Batch size used by all three loaders.
        train_size: Fraction of ``train_data`` kept for training; the rest
            becomes the validation split.

    Returns:
        (train_loader, val_loader, test_loader, scale) where ``train_loader`` and
        ``val_loader`` yield (features, targets) batches, ``test_loader`` yields
        feature batches only, and ``scale`` maps each property name to
        {"mean": ..., "std": ...} computed on the training targets so that
        predictions can be mapped back to the original units.

    Raises:
        ValueError: If ``modality`` is not one of the supported values.
    """
    if properties is None:
        # Same column names as the default lists used elsewhere in this notebook.
        properties = ["Z_HP", "PROVABGS_LOGMSTAR_BF", "Z_MW", "TAGE_MW", "AVG_SFR"]

    # Set up the input features
    if modality == "image":
        X_train = torch.tensor(train_data[modality], dtype=torch.float32)
        X_test = torch.tensor(test_data[modality], dtype=torch.float32)

    elif modality == "spectrum":
        # A single squeeze() already drops every singleton dimension.
        X_train = torch.tensor(train_data[modality], dtype=torch.float32).squeeze()
        X_test = torch.tensor(test_data[modality], dtype=torch.float32).squeeze()

    elif modality == "photometry":
        bands = ("MAG_G", "MAG_R", "MAG_Z")
        # Stacking gives (3, N); permute to the (N, 3) sample-major layout.
        X_train = torch.tensor(
            np.stack([train_data[band] for band in bands]),
            dtype=torch.float32,
        ).permute(1, 0)
        X_test = torch.tensor(
            np.stack([test_data[band] for band in bands]),
            dtype=torch.float32,
        ).permute(1, 0)

    else:
        raise ValueError(
            f"Unknown modality {modality!r}; expected 'image', 'spectrum' or 'photometry'."
        )

    # Scale the features with a single scalar mean/std taken over the whole
    # training tensor; the test features reuse the training statistics.
    X_mean, X_std = X_train.mean(), X_train.std()
    X_train = (X_train - X_mean) / X_std
    X_test = (X_test - X_mean) / X_std

    # Standardize each target independently and remember its statistics
    property_data, scale = {}, {}
    for p in properties:
        data = torch.tensor(train_data[p].data, dtype=torch.float32)
        mean, std = data.mean(), data.std()
        property_data[p] = ((data - mean) / std).squeeze()
        scale[p] = {"mean": mean.numpy(), "std": std.numpy()}
    y_train = torch.stack([property_data[p] for p in properties], dim=1)

    # Randomly split the training table into training and validation subsets
    total_size = len(X_train)
    n_train = int(train_size * total_size)
    train_dataset, val_dataset = random_split(
        TensorDataset(X_train, y_train), [n_train, total_size - n_train]
    )

    # Set up the data loaders (only the training loader is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(X_test, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, scale


def train_model(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    scalers: dict[str, StandardScaler],
    properties: list,
    device="cuda",
    num_epochs=50,
    learning_rate=1e-3,
):
    """Train a regression model with MSE loss and return its best weights.

    Args:
        model: Network mapping a feature batch to one output per property.
        train_loader: Yields (features, targets) batches used for optimization.
        val_loader: Yields (features, targets) batches used for model selection.
        scalers: Maps each property name to a mapping with "mean" and "std"
            entries, used to undo the target standardization before computing R^2.
        properties: Property names in the column order of the targets. If None,
            the key order of ``scalers`` is used.
        device: Device the model and batches are moved to.
        num_epochs: Maximum number of epochs.
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        A copy of the ``state_dict`` from the epoch with the lowest mean
        validation loss (the initial weights if no epoch improves on infinity,
        e.g. when ``num_epochs`` is 0).
    """
    import copy  # local import: deep copies of the state dict are needed below

    model.to(device)

    # Define the loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    prop_names = list(properties) if properties is not None else list(scalers)

    # state_dict() returns references to the live parameters, so a deep copy is
    # required for the snapshot to survive later optimizer steps.
    best_model = copy.deepcopy(model.state_dict())
    best_val_loss = float("inf")

    epochs = tqdm.trange(num_epochs, desc="Training Model: ", leave=True)

    # Training loop
    for epoch in epochs:
        train_loss = 0
        model.train()
        for X_batch, y_batch in train_loader:
            y_batch = y_batch.to(device)
            # Reshape to the target shape instead of squeeze(): squeeze() drops
            # the batch axis for a batch of one and the property axis for a
            # single property, which silently broadcasts inside MSELoss.
            y_pred = model(X_batch.to(device)).reshape(y_batch.shape)
            loss = criterion(y_pred, y_batch)
            train_loss += loss.item()

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate in eval mode so layers such as dropout / batch norm behave
        # deterministically; model.train() is restored at the next epoch.
        model.eval()
        val_pred, val_true, val_loss = [], [], 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                y_pred = model(X_batch.to(device)).detach().cpu().reshape(y_batch.shape)
                loss = criterion(y_pred, y_batch)
                val_loss += loss.item()
                val_pred.append(y_pred)
                val_true.append(y_batch)

        val_pred = torch.cat(val_pred).numpy()
        val_true = torch.cat(val_true).numpy()
        mean_val_loss = val_loss / len(val_loader)

        # R^2 per property, computed in the original (un-standardized) units
        val_r2s = {}
        for i, prop in enumerate(prop_names):
            pred_i = (val_pred[:, i] * scalers[prop]["std"]) + scalers[prop]["mean"]
            true_i = (val_true[:, i] * scalers[prop]["std"]) + scalers[prop]["mean"]
            val_r2s[prop] = r2_score(true_i, pred_i)

        if mean_val_loss < best_val_loss:
            best_model = copy.deepcopy(model.state_dict())
            best_val_loss = mean_val_loss

        # Early stopping
        if epoch > 10 and mean_val_loss > 1.5 * best_val_loss:
            break

        # Iterating the trange already advances the bar, so no manual update().
        epochs.set_description(
            "epoch: {}, train loss: {:.4f}, val loss: {:.4f}, z_hp: {:.4f}".format(
                epoch + 1,
                train_loss / len(train_loader),
                mean_val_loss,
                # NaN when Z_HP is not among the trained properties.
                val_r2s.get("Z_HP", float("nan")),
            )
        )

    return best_model

In [None]:
# Locations of the paired PROVABGS train / test tables.
PROVABGS_DIR = "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs"
train_dataset = f"{PROVABGS_DIR}/provabgs_paired_train.hdf5"
test_dataset = f"{PROVABGS_DIR}/provabgs_paired_test.hdf5"

# None selects the default list of target properties in the next cell.
properties = None

In [None]:
# Fall back to the default regression targets when none were selected.
properties = (
    ["Z_HP", "PROVABGS_LOGMSTAR_BF", "Z_MW", "TAGE_MW", "AVG_SFR"]
    if properties is None
    else properties
)

# Read the training table first, then the test table.
train_provabgs, test_provabgs = (
    Table.read(path) for path in (train_dataset, test_dataset)
)

In [None]:
# Build the photometry loaders; `scale` holds the per-property mean/std
# needed to map predictions back to the original units.
train_loader, val_loader, test_loader, scale = setup_supervised_data(
    train_data=train_provabgs,
    test_data=test_provabgs,
    modality="photometry",
    properties=properties,
)

In [None]:
# Model for the photometry inputs.
# (Alternative tried for spectra: SpectrumEncoder(n_latent=5).)
model = MLP()

# Fit for a single epoch, then restore the weights returned by train_model.
best_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    scalers=scale,
    properties=properties,
    num_epochs=1,
    learning_rate=5e-4,
)
model.load_state_dict(best_model)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# inference cells below still run on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Predict in eval mode so layers such as dropout / batch norm are deterministic.
model.eval()

test_pred_batches = []
with torch.no_grad():
    for X_batch in test_loader:
        y_pred = model(X_batch.to(device)).detach().cpu()
        # Keep an explicit (batch, n_properties) shape: squeeze() would drop the
        # batch axis for a trailing batch of one and break torch.cat below.
        test_pred_batches.append(y_pred.reshape(len(X_batch), -1))

# Standardized predictions, one column per property.
test_pred = torch.cat(test_pred_batches).numpy()

In [None]:
# Undo the target standardization into a new container rather than in place,
# so re-running this cell does not rescale the predictions a second time.
test_pred_physical = {}
for i, p in enumerate(scale.keys()):
    test_pred_physical[p] = (test_pred[:, i] * scale[p]["std"]) + scale[p]["mean"]
    print(f"{p} R^2: {r2_score(test_provabgs[p], test_pred_physical[p])}")

In [None]:
def main(
    train_dataset: str,
    test_datset: str,
    save_dir: str,
    modality: str,
    num_epochs: int = 50,
    learning_rate: float = 5e-4,
    properties: list = None,
):
    """Train a supervised baseline for one modality and save model + predictions.

    Args:
        train_dataset: Path of the training table.
        test_datset: Path of the test table (parameter name kept as-is for
            backward compatibility with keyword callers).
        save_dir: Directory the model and the test predictions are written to;
            created if it does not exist.
        modality: One of "image", "spectrum" or "photometry".
        num_epochs: Maximum number of training epochs.
        learning_rate: Learning rate forwarded to ``train_model``.
        properties: Target column names. Defaults to
            ['Z_HP', 'PROVABGS_LOGMSTAR_BF', 'Z_MW', 'TAGE_MW', 'AVG_SFR'].

    Raises:
        ValueError: If ``modality`` is not supported.
    """
    if properties is None:
        properties = ['Z_HP', 'PROVABGS_LOGMSTAR_BF', 'Z_MW', 'TAGE_MW', 'AVG_SFR']

    # Fail before any expensive I/O if the modality cannot be handled.
    if modality not in ('image', 'spectrum', 'photometry'):
        raise ValueError(
            f"Unknown modality {modality!r}; expected 'image', 'spectrum' or 'photometry'."
        )

    # Resolve the device locally instead of relying on a notebook global.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the data (from the arguments, not from notebook globals)
    train_provabgs = Table.read(train_dataset)
    test_provabgs = Table.read(test_datset)

    # Get the data loaders & scalers
    train_loader, val_loader, test_loader, scale = setup_supervised_data(
        train_provabgs, test_provabgs, modality, properties=properties
    )

    # Initialize the model
    if modality == 'image':
        model = ResNet18(num_classes=len(properties))
    elif modality == 'spectrum':
        # NOTE(review): assumes n_latent is the number of regression outputs,
        # mirroring the SpectrumEncoder(n_latent=5) call tried earlier in the
        # notebook for the five default properties -- confirm against models.py.
        model = SpectrumEncoder(n_latent=len(properties))
    else:
        # NOTE(review): MLP() is constructed with no arguments, as in the
        # photometry cell above; confirm its output size matches len(properties).
        model = MLP()

    # Train the model
    best_model = train_model(
        model,
        train_loader,
        val_loader,
        scale,
        properties,
        device=device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
    )
    model.load_state_dict(best_model)

    # Get the predictions (eval mode; explicit 2-D shape so a trailing batch of
    # one sample still concatenates with the others)
    model.eval()
    test_pred = []
    with torch.no_grad():
        for X_batch in test_loader:
            y_pred = model(X_batch.to(device)).detach().cpu()
            test_pred.append(y_pred.reshape(len(X_batch), -1))
    test_pred = torch.cat(test_pred).numpy()

    # Map the standardized predictions back to the original units
    pred_dict = {}
    for i, p in enumerate(properties):
        pred_dict[p] = (test_pred[:, i] * scale[p]['std']) + scale[p]['mean']
        print(f'{p} R^2: {r2_score(test_provabgs[p], pred_dict[p])}')

    # Save the model and the predictions. The image model keeps its original
    # file name; the other modalities get a name that reflects the modality.
    os.makedirs(save_dir, exist_ok=True)
    model_filename = 'resnet.pt' if modality == 'image' else f'{modality}_model.pt'
    torch.save(model, os.path.join(save_dir, model_filename))
    torch.save(pred_dict, os.path.join(save_dir, 'test_pred.pt'))

In [None]:
# `modality` is a required argument of main(); "image" matches the ResNet18
# branch and the 'resnet.pt' file name main() saves to.
main(
    "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs/provabgs_paired_train.hdf5",
    "/mnt/ceph/users/polymathic/astroclip/datasets/provabgs/provabgs_paired_test.hdf5",
    "/mnt/ceph/users/polymathic/astroclip/supervised/",
    modality="image",
    num_epochs=1,
)