# Model Visualisations

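In this notebook, we load simulated and real photometric observations, train a given machine learning model (a Sequential Autoregressive Network) on the simulated data, and visualise its parameter predictions with corner plots.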

In [None]:
import random

import numpy as np

import spt
import spt.config as cfg
from spt.load_photometry import get_denorm_theta, get_norm_theta, load_simulated_data
from spt.utils import get_median_mode
from spt.visualisation import plot_corner

In [None]:
try: # One-time setup
    assert(_SETUP)
except NameError:
    import os
    import torch as t
    os.chdir(os.path.split(spt.__path__[0])[0])
    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print(f'Using GPU for training')
        !nvidia-smi -L
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

# Inference configuration. To train on a different dataset, set
# `ip.dataset_loc` here, e.g. ip.dataset_loc = './data/dsets/dev'.
ip = cfg.InferenceParams()
# NOTE(review): presumably these force a fresh training run without loading
# or retraining a saved checkpoint — confirm against InferenceParams.
ip.use_existing_checkpoints = False
ip.retrain_model = False

# Build the model from its (SAN) hyperparameters.
mp = cfg.SANParams()
model = ip.model(mp)

# Forward-model parameters; they define the normalisation applied to the
# training targets (theta) below.
fp = cfg.ForwardModelParams()

# Photometry is log-transformed and parameters are normalised with `fp`.
# The test loader yields one observation per batch.
train_loader, test_loader = load_simulated_data(
    path=ip.dataset_loc,
    split_ratio=ip.split_ratio,
    batch_size=model.params.batch_size,
    test_batch_size=1,
    phot_transforms=[lambda x: t.from_numpy(np.log(x))],
    theta_transforms=[get_norm_theta(fp)],
)

model.offline_train(train_loader, ip)

# Test dataset, used by the sampling helpers below.
tds = test_loader.dataset
def new_sample() -> tuple[int, int, tuple[t.Tensor, t.Tensor]]:
    """Pick a uniformly random example from the test dataset ``tds``.

    Returns:
        ``(ds_idx, idx, (xs, ys))`` where ``ds_idx`` is the position within
        ``tds``, ``idx`` is ``tds.indices[ds_idx]`` (presumably the index into
        the full simulated dataset, if ``tds`` is a ``Subset`` — confirm), and
        ``(xs, ys)`` is the example stored at that position.
    """
    position = random.randrange(len(tds))  # same draw as randint(0, len(tds) - 1)
    example = tds[position]
    return position, tds.indices[position], example
def sample_at(ds_idx: int) -> tuple[int, tuple[t.Tensor, t.Tensor]]:
    """Fetch a specific example from the test dataset (for reproducible plots).

    Args:
        ds_idx: Position within the test dataset ``tds``.

    Returns:
        ``(idx, (xs, ys))``: ``idx`` is ``tds.indices[ds_idx]`` and ``(xs, ys)``
        is the example at that position — the same layout as the last two
        elements returned by ``new_sample``.
    """
    xys = tds[ds_idx]
    idx = tds.indices[ds_idx]
    return idx, xys

# OLD: the test sample used to be drawn with the call below; presumably
# new_sample()/sample_at() above replace it.
# sim_xs, sim_ys = spt.utils.new_sample(test_loader, 1)

# Switch the model to evaluation mode before sampling from it (presumably the
# torch nn.Module-style eval() — confirm the model's base class).
# NOTE: this should be done automatically at the end of offline_train...
model.eval()

In [None]:
# Draw a random held-out simulated observation and compare the model's samples
# against its known (normalised) parameters.
ds_idx, sim_idx, (sim_xs, sim_ys) = new_sample()

# For reproducibility, replace the line above with a fixed test-set position, e.g.:
# ds_idx = 947898
# sim_idx, (sim_xs, sim_ys) = sample_at(ds_idx)

n_sim_samples = 10_000  # number of samples drawn from the model for this observation

with t.inference_mode():
    # Same device/dtype conversion as applied to the real observation inputs.
    samples = model.sample(sim_xs.to(device, dtype), n_samples=n_sim_samples).cpu()

plot_corner(samples=samples.squeeze().numpy(), true_params=sim_ys,
            title='Sequential Autoregressive Network',
            description=str(model))

## Real Observations

Next, we run the trained model on a real observation from a catalogue. Because the true parameters of a real object are unknown, the corner plot shows only the model's samples.

In [None]:
# Construct the Prospector wrapper used to visualise the SAN predictions.
# Its `.obs` is replaced with the processed observation before plotting.
# NOTE(review): this loads an observation separately from the one sampled in
# the next cell; if load_observation() is non-deterministic the two differ, and
# it is unclear whether Prospector uses this one beyond initialisation — confirm.
real_p = spt.Prospector(spt.load_observation())

In [None]:
# Load a real catalogue observation and build the model input: the log of the
# per-filter maggie columns, on the same device/dtype as the model.
pd_obs = spt.load_observation()
real_obs = fp.build_obs_fn(fp.filters, pd_obs)
obs_idx = real_obs['_index']
required_cols = [f.maggie_col for f in fp.filters]
real_xs = t.tensor(pd_obs[required_cols].values.astype(np.float64)).log().to(device, dtype)

if not t.isfinite(real_xs).all():
    # Non-positive flux values become -inf/NaN after the log; the samples
    # below are then unreliable.
    print("Warning: some log-fluxes are non-finite; the model input contains NaN/inf values.")

n_real_samples = 10_000  # number of samples drawn from the model for this observation

with t.inference_mode():
    real_samples = model.sample(real_xs, n_samples=n_real_samples).cpu()

# No ground truth exists for real data, so only the samples are plotted.
plot_corner(samples=real_samples.squeeze().numpy(), true_params=None,
            title='Sequential Autoregressive Network',
            description=str(model))

# Point estimates from the samples, which live in the normalised parameter
# space the model was trained on.
real_median, real_mode = get_median_mode(real_samples)

# Invert the normalisation with the same forward-model params `fp` that
# normalised the training targets, so the two transforms are guaranteed to match.
dt = get_denorm_theta(fp)
# [None, :] adds a leading batch axis — assumes the estimates are 1-D; confirm.
real_denorm_mode = dt(real_mode[None, :]).squeeze()
real_denorm_median = dt(real_median[None, :]).squeeze()

# Visualise the Prospector model at the SAN's denormalised mode estimate,
# using the processed real observation.
real_p.obs = real_obs
try:
    real_p.visualise_model(
        real_denorm_mode,
        show=True,
        save=False,
        title=f'SAN Parameter Predictions ({pd_obs["survey"]}:{int(pd_obs["idx"])})',
    )
except Exception as err:
    # Plotting is best-effort, but report why it failed instead of hiding it.
    print(f'Could not plot results: {err!r}')

### 