In [None]:
# Key Concepts
'''
1. Forward problem: θ → x (simulation)
2. Inverse problem: x → θ (what we're solving)
3. Posterior: P(θ|x) probability distribution over parameters
4. Prior: Initial assumptions about parameter ranges
'''
import pandas as pd
import numpy as np
import sys

# Separate into train and test sets.
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
#from torch.distributions import Uniform, ExpTransform, TransformedDistribution #, AffineTransform
import torch.nn as nn
from sklearn.preprocessing import Normalizer
import joblib
import os 
import ili
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage#, PlotSinglePosterior

from sbi.utils.user_input_checks import process_prior

sys.path.append("/disk/xray15/aem2/camels/proj1/")
from setup_params import plot_uvlf, plot_colour
from setup_params import *
from priors import initialise_priors
from variables_config import n_bins_lf, n_bins_colour #, colour_limits, uvlf_limits

# --- Run configuration -------------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = "IllustrisTNG"
spec_type = "attenuated"
sps = "BC03"
snap = ['044', '052', '060', '086']
bands = "all" # or just GALEX?

# Which summary statistics to use: colours and/or UV luminosity functions.
colours = True
luminosity_functions = True

# Run label reused in the file names of the scaler, the models and the plots.
name = f"{model}_{bands}_{sps}_{spec_type}_{n_bins_lf}_{n_bins_colour}"

# initialize CAMELS and load parameter info using camels.py
cam = camels(model=model, sim_set='LH')

# --- Output locations --------------------------------------------------------
# Models live under the data tree and figures under the plots tree; both are
# split by which summary statistics are enabled.
models_root = "/disk/xray15/aem2/data/6pams/IllustrisTNG/LH/models/"
plots_root = "/disk/xray15/aem2/plots/6pams/IllustrisTNG/LH/test/sbi_plots/"

if colours and luminosity_functions:
    run_subdir = "colours_lfs"
elif colours:
    run_subdir = "colours_only"
elif luminosity_functions:
    run_subdir = "lf_only"
else:
    raise ValueError("At least one of colours or luminosity_functions must be True")

model_out_dir = f"{models_root}{run_subdir}/"
# The lf_only plots previously went under the data tree; they now follow the
# same plots tree as the other two modes.
plots_out_dir = f"{plots_root}{run_subdir}/"

# Create the directories up front so later savefig/model-save calls cannot
# fail on a missing folder.
for _out_dir in (model_out_dir, plots_out_dir):
    os.makedirs(_out_dir, exist_ok=True)

print("Saving model in ", model_out_dir)
print("Saving plots in ", plots_out_dir)


In [None]:
# Check available snapshots first.
# NOTE(review): get_available_snapshots is not defined in this notebook;
# presumably it arrives via the `from setup_params import *` wildcard — confirm.
available_snaps = get_available_snapshots()
print(f"Available snapshots: {available_snaps}")

### Chris version of sbi:


In [None]:

# Build the prior with the astro flag switched on and the dust flag off.
prior = initialise_priors(device=device, astro=True, dust=False)

# Arguments for loading parameters (theta) and summary statistics (x).
# Alternative photometry location used previously:
#   photo_dir=f"/mnt/ceph/users/clovell/CAMELS_photometry/{model}/"
theta_x_kwargs = dict(
    photo_dir="/disk/xray15/aem2/data/6pams/",
    spec_type=spec_type,
    model=model,
    snap=snap,
    sps=sps,
    n_bins_lf=n_bins_lf,
    n_bins_colour=n_bins_colour,
    colours=colours,
    luminosity_functions=luminosity_functions,
    device=device,
)
theta, x = get_theta_x(**theta_x_kwargs)


In [None]:
# Quick-look figure of the luminosity-function summaries; skipped when
# luminosity functions are disabled.
if luminosity_functions:
    uvlf_check_path = '/disk/xray15/aem2/plots/6pams/IllustrisTNG/LH/test/LFs_test/uvlf_check.png'
    fig = plot_uvlf(x)
    plt.savefig(uvlf_check_path)
    plt.show()

In [None]:
# Quick-look figure of the colour summaries; skipped when colours are disabled.
if colours:
    # Keep this check alongside the other 6pams/LH outputs of this notebook
    # (it previously pointed at the 28pams/SB tree of a different run).
    colour_check_dir = '/disk/xray15/aem2/plots/6pams/IllustrisTNG/LH/test/colours_test/'
    os.makedirs(colour_check_dir, exist_ok=True)

    fig = plot_colour(x)
    plt.savefig(os.path.join(colour_check_dir, 'colour_check.png'))
    plt.show()

In [None]:

# Flatten every entry of x into a single 1-D feature vector, then stack the
# vectors into one 2-D array (one row per entry of x).
flattened_rows = [np.hstack(sample) for sample in x]
x_all = np.array(flattened_rows)

# Disabled workaround, kept for reference:
# # Make sure no constant variables, to avoid nan loss with lampe NDE models
# x_all[x_all == 0.0] = np.array(
#     np.random.rand(np.sum((x_all == 0.0))) * 1e-10
# )

# NOTE(review): sklearn's Normalizer operates per sample (row), not per
# feature — confirm this is the intended scaling.
norm = Normalizer()
x_normalised = norm.fit_transform(X=x_all)
x_all = torch.tensor(x_normalised, dtype=torch.float32, device=device)

# Persist the fitted normaliser next to the models so it can be reloaded later.
joblib.dump(norm, f'/disk/xray15/aem2/data/6pams/IllustrisTNG/LH/models/{name}_scaler.save')

# Boolean mask selecting the held-out test rows. It is loaded from disk rather
# than regenerated; the commented lines show how such a mask was produced:
# test_mask = np.random.rand(1000) > 0.9
# np.savetxt('../data/test_mask.txt', test_mask, fmt='%i')
test_mask = np.loadtxt("/disk/xray15/aem2/data/6pams/IllustrisTNG/LH/test_mask.txt", dtype=bool)


In [None]:
# --- Density-estimator architecture ------------------------------------------
hidden_features = 60  # raised from 30 for more capacity
num_transforms = 4    # number of flow transforms
# num_bins = 10 # spline bins, this is default in sbi package anyway.

# Ensemble of two "nsf" NPE networks sharing the same hyper-parameters.
n_ensemble = 2
nets = [
    ili.utils.load_nde_sbi(
        engine="NPE",
        model="nsf",
        hidden_features=hidden_features,
        num_transforms=num_transforms,
    )
    for _ in range(n_ensemble)
]

# --- Training hyper-parameters -----------------------------------------------
train_args = {
    "training_batch_size": 4,
    "learning_rate": 5e-4,
    # NOTE(review): in sbi this is the early-stopping patience (epochs without
    # validation improvement) — confirm against the installed version.
    "stop_after_epochs": 50,
}

# Training split = every row NOT flagged in test_mask.
# torch.as_tensor accepts both arrays and tensors and avoids the
# copy-construct warning torch.tensor() emits when handed a tensor.
loader = NumpyLoader(
    x=x_all[~test_mask],
    theta=torch.as_tensor(theta[~test_mask, :], device=device),
)

# Save the trained networks in the directory announced by the configuration
# cell ("Saving model in ...") rather than a "models/" folder relative to the
# current working directory.
os.makedirs(model_out_dir, exist_ok=True)

runner = InferenceRunner.load(
    backend="sbi",
    engine="NPE",
    prior=prior,
    nets=nets,
    device=device,
    train_args=train_args,
    proposal=None,
    out_dir=model_out_dir,
    name=name,
)

# Train the ensemble; `summaries` holds one training-history entry per network.
posterior_ensemble, summaries = runner(loader=loader)


In [None]:
def plot_training_diagnostics(summaries, net_index=0):
    """Plot the training history of one network of the ensemble.

    The left panel shows the training and validation log-probabilities per
    epoch; the right panel shows their difference (training minus validation)
    as a simple overfitting indicator.

    Parameters
    ----------
    summaries : sequence of dict
        Training summaries, one entry per network. The selected entry must
        provide the keys 'training_log_probs' and 'validation_log_probs'.
    net_index : int, optional
        Index into ``summaries`` of the network to plot. Defaults to 0, i.e.
        the first network, which is what this function always plotted before
        the parameter existed.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding both panels.

    Raises
    ------
    IndexError
        If ``net_index`` is outside the range of ``summaries``.
    """
    history = summaries[net_index]
    train_losses = np.asarray(history['training_log_probs'])
    val_losses = np.asarray(history['validation_log_probs'])
    epochs = range(len(train_losses))

    # One row, two columns: no empty subplots.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: log-probability curves.
    ax1.plot(epochs, train_losses, '-', label='Training', color='blue')
    ax1.plot(epochs, val_losses, '--', label='Validation', color='red')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Log probability')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: training minus validation, with a zero reference line.
    gap = train_losses - val_losses
    ax2.plot(epochs, gap, '-', color='purple')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss difference')
    ax2.set_title('Overfitting Gap')
    ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Draw the diagnostics, write them into the run's plot directory, then
# display and release the figure.
training_plot_path = os.path.join(plots_out_dir, f'training_analysis_{name}.png')
fig = plot_training_diagnostics(summaries)
plt.savefig(training_plot_path)
plt.show()
plt.close()

In [None]:
"""
Coverage plots for each model
"""
# Diagnostics requested from the coverage metric, in this order.
coverage_plots = ["coverage", "histogram", "predictions", "tarp"]

# Direct sampling from the posterior; alternatives tried previously:
#   sample_method="slice_np_vectorized", sample_params={'num_chains': 1}
#   sample_method="vi", sample_params={"dist": "maf", "n_particles": 32, "learning_rate": 1e-2}
# out_dir is left unset (was: out_dir="../plots/"); figures are saved below.
metric = PosteriorCoverage(
    num_samples=int(4e3),
    sample_method='direct',
    labels=cam.labels,
    plot_list=coverage_plots,
)

# Evaluate on the held-out rows only, moved to the CPU.
held_out_x = x_all[test_mask].cpu()
held_out_theta = theta[test_mask, :].cpu()
fig = metric(
    posterior=posterior_ensemble,
    x=held_out_x,
    theta=held_out_theta,
    signature=f"coverage_{name}_",
)

# Clamp both axes of the fourth figure (the "tarp" entry above) to [0, 1].
tarp_axis = fig[3].axes[0]
tarp_axis.set_xlim(0,1)
tarp_axis.set_ylim(0,1)

# Save figures with descriptive names
plot_names = ["coverage", "histogram", "predictions", "TARP"]
for figure_obj, suffix in zip(fig, plot_names):
    target_file = os.path.join(plots_out_dir, f'coverage_{name}_plot_{suffix}.png')
    figure_obj.savefig(target_file, bbox_inches='tight', dpi=200)
    