In [None]:
# Key Concepts
'''
1. Forward problem: θ → x (simulation)
2. Inverse problem: x → θ (what we're solving)
3. Posterior: P(θ|x) probability distribution over parameters
4. Prior: Initial assumptions about parameter ranges
'''
import pandas as pd
import numpy as np
import sys

# Separate into train and test sets.
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
#from torch.distributions import Uniform, ExpTransform, TransformedDistribution #, AffineTransform
import torch.nn as nn
from sklearn.preprocessing import Normalizer
import joblib
import os 
import ili
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage, PlotSinglePosterior

from sbi.utils.user_input_checks import process_prior

sys.path.append("/disk/xray15/aem2/camels/proj1/")
from camels.proj1.setup_params_LH import *
from priors import initialise_priors
from variables_config import n_bins_lf, n_bins_colour #, colour_limits, uvlf_limits

# Run configuration: compute device, simulation suite and photometry options.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = "IllustrisTNG"
sps = "BC03"
spec_type = "attenuated"

# Snapshot lookup (presumably snapshot number = redshift):
# 084 = 0.1, 060 = 1.05, 052 = 1.48, 044 = 2
# Other snapshots that have been used here: '044', '052', '060'
snap = ['086']
bands = "all"  # or just GALEX?

# Use both the UV luminosity functions and the colours this time.
luminosity_functions = True
colours = True

# Run identifier: the configuration fields joined by underscores.
name = "_".join(
    str(part)
    for part in (model, bands, sps, spec_type, n_bins_lf, n_bins_colour)
)

# Initialise CAMELS and load parameter info.
# NOTE(review): `camels` is presumably provided by the star import from
# camels.proj1.setup_params_LH -- confirm.
cam = camels(model=model, sim_set='LH')

# Pick the output sub-directory from the summary statistics in use. Both
# output paths are then built from the same template so that the model and
# plot directories always share the same layout (.../6pams/LH/IllustrisTNG/...)
# and always end with a trailing slash.
if colours and luminosity_functions:
    feature_set = "colours_lfs"
elif colours:
    feature_set = "colours_only"
elif luminosity_functions:
    feature_set = "lf_only"
else:
    # Neither summary statistic selected: there is nothing to train on.
    raise ValueError("At least one of colours or luminosity_functions must be True")

model_out_dir = f"/disk/xray15/aem2/data/6pams/LH/IllustrisTNG/models/{feature_set}/"
plots_out_dir = f"/disk/xray15/aem2/plots/6pams/LH/IllustrisTNG/test/sbi_plots/{feature_set}/"

print("Saving model in ", model_out_dir)
print("Saving plots in ", plots_out_dir)


In [None]:
# Check which snapshots are available before configuring the run.
# NOTE(review): get_available_snapshots is not defined in this file; presumably
# it comes from the star import of camels.proj1.setup_params_LH -- confirm.
available_snaps = get_available_snapshots()
print(f"Available snapshots: {available_snaps}")

### Chris version of sbi:


In [None]:
import sys

sys.path.insert(0, "..")

import torch

torch.set_default_dtype(torch.float32)

import numpy as np
import ili
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage
from sklearn.preprocessing import Normalizer
import joblib

from priors import initialise_priors
from camels.proj1.setup_params_LH import get_theta_x
from camels import camels


# Example of the run-name format built below:
# IllustrisTNG_all_BC03_attenuated_12_12_086

# Simulation suite; alternatives: "Swift-EAGLE", "Astrid", "Simba".
model = "IllustrisTNG"
spec_type = "intrinsic"
sps = "BC03"
snap = ["086"]  # further snapshots that can be added: "060", "044"
n_bins_lf = 12
n_bins_colour = 12
cam = camels(model)

# Photometric bands and the summary statistics to include.
bands = "all"
colours = True
luminosity_functions = True

# Run identifier: the configuration fields joined by underscores, followed by
# one "_<snapshot>" suffix per snapshot (`snap` may be a list or a single value).
name = f"{model}_{bands}_{sps}_{spec_type}_{n_bins_lf}_{n_bins_colour}"
snap_labels = snap if isinstance(snap, list) else [snap]
name += "".join(f"_{label}" for label in snap_labels)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
# Prior from the project helper, with the astro flag on and the dust flag off.
prior = initialise_priors(device=device, astro=True, dust=False)

# Arguments for the project helper that returns the parameters (theta) and
# the summary statistics (x) for the configured run.
# Alternative photometry location used elsewhere:
#   f"/mnt/ceph/users/clovell/CAMELS_photometry/{model}/"
photometry_kwargs = dict(
    photo_dir=f"/disk/xray15/aem2/data/6pams/",
    spec_type=spec_type,
    model=model,
    snap=snap,
    sps=sps,
    n_bins_lf=n_bins_lf,
    n_bins_colour=n_bins_colour,
    colours=colours,
    luminosity_functions=luminosity_functions,
    device=device,
)
theta, x = get_theta_x(**photometry_kwargs)

# Flatten the statistics of each entry of x into a single feature vector.
x_all = np.array([np.hstack(stats) for stats in x])

# NOTE: constant (all-zero) variables can give a NaN loss with lampe NDE
# models; a previously used workaround replaced exact zeros with tiny random
# values (np.random.rand(...) * 1e-10).

# scikit-learn's Normalizer rescales each sample (row) to unit norm.
norm = Normalizer()
x_scaled = norm.fit_transform(X=x_all)
x_all = torch.tensor(x_scaled, dtype=torch.float32, device=device)

# Persist the fitted normaliser next to the trained model. `os` and
# `model_out_dir` are defined in the first cell of the notebook.
scaler_path = os.path.join(model_out_dir, f'{name}_scaler.save')
joblib.dump(norm, scaler_path)


# Boolean hold-out mask (True = held out for validation), loaded from disk so
# that the train/test split is the same on every run. It was generated with:
#     test_mask = np.random.rand(1000) > 0.9
#     np.savetxt('../data/test_mask.txt', test_mask, fmt='%i')
test_mask = np.loadtxt("/disk/xray15/aem2/data/6pams/test_mask.txt", dtype=bool)

# Ensemble of identically configured NSF density estimators for NPE.
# Alternatives tried previously: NLE with a MAF (hidden_features=50,
# num_transforms=5) and lampe NSF nets (hidden_features=20, num_transforms=2).
hidden_features = 30
num_transforms = 4
n_ensemble = 3
nets = [
    ili.utils.load_nde_sbi(
        engine="NPE",
        model="nsf",
        hidden_features=hidden_features,
        num_transforms=num_transforms,
    )
    for _ in range(n_ensemble)
]

train_args = {"training_batch_size": 4, "learning_rate": 5e-4, 'stop_after_epochs': 20}

# Training split only. torch.as_tensor is used instead of torch.tensor so that
# an existing tensor is moved to the device without the copy-construct
# UserWarning, while array input is still converted.
loader = NumpyLoader(
    x=x_all[~test_mask],
    theta=torch.as_tensor(theta[~test_mask, :], device=device),
)

# Settings for the inference runner: the sbi backend with the NPE engine,
# using the prior, ensemble nets and training arguments defined above.
# Trained models are written to `model_out_dir` (set in the first cell).
runner_settings = {
    "backend": "sbi",  # alternative backend: 'lampe'
    "engine": "NPE",
    "prior": prior,
    "nets": nets,
    "device": device,
    "train_args": train_args,
    "proposal": None,
    # "embedding_net": None,
    "out_dir": model_out_dir,
    "name": name,
}
runner = InferenceRunner.load(**runner_settings)

# Run the inference on the training loader; returns the posterior ensemble
# together with the run summaries.
posterior_ensemble, summaries = runner(loader=loader)


# Coverage plots for each model, evaluated on the held-out test split.
metric = PosteriorCoverage(
    num_samples=int(4e3),
    sample_method='direct',
    # Alternative sampling configurations:
    # sample_method="slice_np_vectorized",
    # sample_params={'num_chains': 1},
    # sample_method="vi",
    # sample_params={"dist": "maf", "n_particles": 32, "learning_rate": 1e-2},
    labels=cam.labels,
    plot_list=["coverage", "histogram", "predictions", "tarp"],
)

fig = metric(
    posterior=posterior_ensemble,
    x=x_all[test_mask].cpu(),
    theta=theta[test_mask, :].cpu(),
    signature=f"coverage_{name}_",
)

# fig[3] is taken to be the TARP figure (fourth entry of plot_list) --
# NOTE(review): confirm the ordering of the figures returned by the metric.
tarp_fig = fig[3]
tarp_ax = tarp_fig.axes[0]
tarp_ax.set_xlim(0, 1)
tarp_ax.set_ylim(0, 1)

# Save into the plots directory configured (and announced) in the first cell
# rather than a path relative to the current working directory. `os` is
# imported in the first cell as well.
os.makedirs(plots_out_dir, exist_ok=True)
tarp_fig.savefig(
    os.path.join(plots_out_dir, f'coverage_{name}_plot_TARP.png'),
    bbox_inches='tight',
    dpi=200,
)

