In [None]:
import h5py
import numpy as np
import pyro
import pyro.distributions as dist
import pyro.distributions.transforms as T
import seaborn as sns
import torch
import tqdm
from astropy.table import Table, join
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset


def _standardize_pair(train_values, test_values):
    """Fit a ``StandardScaler`` on the training values and apply it to both splits.

    The scaler is fit on the training split only, so no test-set statistics
    enter the transform.

    Returns:
        Tuple ``(scaled_train, scaled_test, fitted_scaler)``.
    """
    scaler = StandardScaler().fit(train_values)
    return scaler.transform(train_values), scaler.transform(test_values), scaler


def get_data(train_dataset, test_dataset, source: str = "images"):
    """Load the train/test tables and build standardized feature/target tensors.

    Args:
        train_dataset: Passed straight to ``astropy.table.Table.read`` to load
            the training table (presumably a file path -- confirm with callers).
        test_dataset: Same as ``train_dataset``, for the test table.
        source: Which input features to return. One of ``"images"``
            (column ``image_features``), ``"spectra"`` (column
            ``spectra_features``) or ``"photometry"`` (columns ``MAG_G``,
            ``MAG_R``, ``MAG_Z``).

    Returns:
        Tuple ``(data, prop_scalers)``. ``data`` is a dict with float32 tensors
        under ``"X_train"``, ``"X_test"``, ``"y_train"`` and ``"y_test"``.
        ``prop_scalers`` maps each property name to the ``StandardScaler``
        fitted on its training values.

    Raises:
        ValueError: If ``source`` is not one of the supported names.

    Note:
        The list of target properties comes from the module-level name
        ``properties``, which is defined outside this function.
    """
    feature_columns = {"images": "image_features", "spectra": "spectra_features"}
    photometry_bands = ("MAG_G", "MAG_R", "MAG_Z")

    # Fail fast on a bad ``source`` instead of after loading and scaling everything.
    if source not in feature_columns and source != "photometry":
        raise ValueError("Invalid source. Must be one of: images, spectra, photometry")

    train_provabgs = Table.read(train_dataset)
    test_provabgs = Table.read(test_dataset)

    # Scale the galaxy property data: one standardized column per property.
    prop_scalers = {}
    y_train = np.zeros((len(train_provabgs), len(properties)))
    y_test = np.zeros((len(test_provabgs), len(properties)))
    for i, p in enumerate(properties):
        prop_train = train_provabgs[p].reshape(-1, 1)
        prop_test = test_provabgs[p].reshape(-1, 1)
        if p == "Z_MW":
            prop_train, prop_test = np.log(prop_train), np.log(prop_test)
        if p == "SFR":
            # SFR is replaced by log(SFR) - LOG_MSTAR before scaling.
            # NOTE(review): np.log is the natural log; if LOG_MSTAR is a
            # base-10 log the two terms use different bases -- confirm intent.
            prop_train = np.log(prop_train) - train_provabgs["LOG_MSTAR"].reshape(-1, 1)
            prop_test = np.log(prop_test) - test_provabgs["LOG_MSTAR"].reshape(-1, 1)
        prop_train, prop_test, prop_scalers[p] = _standardize_pair(prop_train, prop_test)
        y_train[:, i], y_test[:, i] = prop_train.ravel(), prop_test.ravel()

    if source == "photometry":
        # Selecting several columns from a Table yields another Table, which
        # torch cannot convert; stack the bands into an (N, 3) array instead.
        # Photometry is intentionally left unscaled, as in the original code.
        x_train = np.column_stack(
            [np.asarray(train_provabgs[band], dtype=np.float32) for band in photometry_bands]
        )
        x_test = np.column_stack(
            [np.asarray(test_provabgs[band], dtype=np.float32) for band in photometry_bands]
        )
    else:
        column = feature_columns[source]
        x_train, x_test, _ = _standardize_pair(
            np.asarray(train_provabgs[column]), np.asarray(test_provabgs[column])
        )

    data = {
        # All tensors share float32 so inputs and targets have the same dtype.
        "X_train": torch.tensor(x_train, dtype=torch.float32),
        "X_test": torch.tensor(x_test, dtype=torch.float32),
        "y_train": torch.tensor(y_train, dtype=torch.float32),
        "y_test": torch.tensor(y_test, dtype=torch.float32),
    }

    return data, prop_scalers