In [None]:
# Key Concepts
'''
1. Forward problem: θ → x (simulation)
2. Inverse problem: x → θ (what we're solving)
3. Posterior: P(θ|x) probability distribution over parameters
4. Prior: Initial assumptions about parameter ranges
'''
import pandas as pd
import numpy as np
import sys

# separate into train and test set.
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
#from torch.distributions import Uniform, ExpTransform, TransformedDistribution #, AffineTransform
import torch.nn as nn
from sklearn.preprocessing import Normalizer
import joblib
import os 
import ili
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage#, PlotSinglePosterior

from sbi.utils.user_input_checks import process_prior

sys.path.append("/disk/xray15/aem2/camels/proj2")
from setup_params_1P import plot_uvlf, plot_colour
from setup_params_SB import *
from priors_SB import initialise_priors_SB28

from variables_config_28 import uvlf_limits, n_bins_lf, colour_limits, n_bins_colour
# --- Run configuration ------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = "IllustrisTNG"
spec_type = "attenuated"
sps = "BC03"
snap = ["044"]
bands = "all"  # or just GALEX?

# Which summary statistics feed the inference: UV luminosity functions and/or colours.
colours = False
luminosity_functions = True
name = f"{model}_{bands}_{sps}_{spec_type}_{n_bins_lf}_{n_bins_colour}"

# initialize CAMELS and load parameter info using camels.py
cam = camels(model=model, sim_set='SB28')

# (colours, luminosity_functions) -> (model output directory, plot output directory)
_OUTPUT_DIRS = {
    (True, False): (
        "/disk/xray15/aem2/data/28pams/IllustrisTNG/SB/models/colours_only/",
        "/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/sbi_plots/colours_only/",
    ),
    (False, True): (
        "/disk/xray15/aem2/data/28pams/IllustrisTNG/SB/models/lf_only/",
        "/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/sbi_plots/lfs_only/",
    ),
    (True, True): (
        "/disk/xray15/aem2/data/28pams/IllustrisTNG/SB/models/colours_lfs/",
        "/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/sbi_plots/colours_lfs/",
    ),
}
_selection = (bool(colours), bool(luminosity_functions))
# (False, False) has no entry: at least one summary statistic is required.
if _selection not in _OUTPUT_DIRS:
    raise ValueError("At least one of colours or luminosity_functions must be True")
model_out_dir, plots_out_dir = _OUTPUT_DIRS[_selection]

print("Saving model in ", model_out_dir)
print("Saving plots in ", plots_out_dir)


In [None]:
# Parameter tables:
#  - df_pars: parameter values from CosmoAstroSeed_IllustrisTNG_L25n256_SB28.txt
#    (whitespace-separated); used to build theta below. The camels class reads
#    the same file for the actual parameter values.
#  - df_info: per-parameter metadata, used for defining the priors.

# sep=r"\s+" is the supported spelling of delim_whitespace=True, which is
# deprecated since pandas 2.2.
df_pars = pd.read_csv('/disk/xray15/aem2/data/28pams/IllustrisTNG/SB/CosmoAstroSeed_IllustrisTNG_L25n256_SB28.txt', sep=r'\s+')
print(df_pars)


# prior values come from this:
df_info = pd.read_csv("/disk/xray15/aem2/data/28pams/Info_IllustrisTNG_L25n256_28params.txt")
print(df_info)



In [None]:
theta = df_pars.iloc[:, 1:29].to_numpy()  # excluding 'name' column and 'seed' column

print(theta)
print(theta.shape)

In [None]:
print("Column names:")
print(df_pars.columns.tolist())

In [None]:
# plot the first one (omega0) to see shape of prior:
plt.hist(theta[:, 24])

In [None]:
if __name__ == "__main__":
    theta, x = get_theta_x_SB()
    print(theta.shape, x.shape)

In [None]:
if colours:
    fig = plot_colour(x)
    plt.savefig('/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/colours_test/colour_check.png')
    plt.show()

In [None]:
if luminosity_functions:
    fig = plot_uvlf(x)
    plt.savefig('/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/LFs_test/uvlf_check.png')
    plt.show()

In [None]:

# get the priors and data
prior = initialise_priors_SB28(
    df=df_info, 
    device=device,
    astro=True,
    dust=False  # no dust for testing. set to False to only get the 28 model parameters.
    # with dust = True, prior has 32 dimensions (28 parameters + 4 dust parameters) 
)

# process the data
x_all = np.array([np.hstack(_x) for _x in x])
x_all = torch.tensor(x_all, dtype=torch.float32, device=device)

print("Theta shape:", theta.shape)
print("X shape:", x_all.shape)

In [None]:

# Move data to GPU as early as possible
x_all = x_all.to(device)
x_all


In [None]:
theta = torch.tensor(theta, dtype=torch.float32, device=device)
theta

In [None]:
# Handle NaN values and normalize while on GPU
x_all_cpu = x_all.cpu().numpy()  # Only move to CPU when necessary for sklearn
x_all_cpu


In [None]:
print("Data shape before processing:", x_all_cpu.shape)
print("Number of values:",(x_all_cpu).sum())
print("Number of NaN values:", np.isnan(x_all_cpu).sum())
print("Number of infinite values:", np.isinf(x_all_cpu).sum())

# how many nan values are there? if they are all nan something has gone horribly wrong.
# this looks better - 18th Nov

In [None]:

# get rid of NaN/inf values, replace with small random noise
nan_mask = np.isnan(x_all_cpu) | np.isinf(x_all_cpu)
nan_mask



In [None]:

if nan_mask.any():
    x_all_cpu[nan_mask] = np.random.rand(np.sum(nan_mask)) * 1e-10


In [None]:
x_all_cpu

In [None]:
print("Data shape before processing:", x_all_cpu.shape)
print("Number of NaN values:", np.isnan(x_all_cpu).sum())
print("Number of infinite values:", np.isinf(x_all_cpu).sum())


In [None]:

# Normalize
# sklearn's Normalizer rescales each ROW (one simulation's flattened summary
# vector) to unit L2 norm (default norm='l2'). It does NOT standardise each
# feature/bin across simulations (that would be StandardScaler): relative
# magnitudes between bins within a row are preserved, so large-valued bins still
# dominate that row. Normalizer is stateless, so fit_transform on all
# simulations (train + test) learns nothing from the test set.

norm = Normalizer()

# Add a tiny constant before normalising. Effect: an all-zero row becomes a
# constant row with unit norm instead of staying all zeros (Normalizer leaves
# zero-norm rows unchanged).
epsilon = 1e-10  # Small constant
x_all_shifted = x_all_cpu + epsilon
x_all_normalized = norm.fit_transform(x_all_shifted)
x_all = torch.tensor(x_all_normalized, dtype=torch.float32, device=device)

# Unshifted alternative:
#   x_all_normalized = norm.fit_transform(x_all_cpu)
#   x_all = torch.tensor(x_all_normalized, dtype=torch.float32, device=device)
x_all

In [None]:
# Add some diagnostics for the normalized data
def analyze_normalization(x_all, out_dir=None):
    """Print summary statistics of the normalised features and save their histogram.

    Parameters
    ----------
    x_all : torch.Tensor
        Normalised feature matrix; it is copied to the CPU for the analysis.
    out_dir : str, optional
        Directory for 'normalization_distribution.png'. Defaults to the
        notebook-level ``plots_out_dir``; the directory is created if missing.
    """
    x_numpy = x_all.detach().cpu().numpy()

    print("Normalization Statistics:")
    print(f"Mean: {np.mean(x_numpy):.6f}")
    print(f"Std: {np.std(x_numpy):.6f}")
    print(f"Min: {np.min(x_numpy):.6f}")
    print(f"Max: {np.max(x_numpy):.6f}")
    print(f"Zero elements: {np.sum(x_numpy == 0)} out of {x_numpy.size}")

    if out_dir is None:
        out_dir = plots_out_dir
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    # Plot distribution
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(x_numpy.flatten(), bins=50, density=True)
    ax.set_title('Distribution of Normalized Values')
    ax.set_xlabel('Value')
    ax.set_ylabel('Density')
    ax.grid(True, alpha=0.3)
    fig.savefig(os.path.join(out_dir, 'normalization_distribution.png'))
    plt.close(fig)

analyze_normalization(x_all)

# zero elements might refer to:
# empty magnitude bins (no galaxies in that magnitude range)
# below detection limit regions
# is actually normal for UVLFs

In [None]:

# Save normalizer
joblib.dump(norm, f'/disk/xray15/aem2/data/28pams/IllustrisTNG/SB/models/{name}_scaler.save')

# Print final check
print("Any NaN in normalized data:", torch.isnan(x_all).any().item())
print("Any inf in normalized data:", torch.isinf(x_all).any().item())


In [None]:

# make test mask
test_mask = create_test_mask() # 10% testing
test_mask


In [None]:
train_mask = ~test_mask # 90% for training
train_mask


In [None]:
# NPE (neural posterior estimation): fit a conditional density estimator q(theta | x)
# directly to the simulated (theta, x) pairs from the CAMELS SB28 set.
# With model="nsf" the estimator is a neural spline flow. It represents the whole
# posterior density, not just summary numbers such as a mean and variance.
# Once trained, it is conditioned on an observed x to draw posterior samples.

# Training arguments (passed to the runner as train_args)
train_args = {
    "training_batch_size": 64,
    "learning_rate": 5e-5,
    "stop_after_epochs": 5,  # early stopping: epochs without validation improvement before training stops
    "validation_fraction": 0.1,  # creates another split within the training data for validation
}

# Configure network
hidden_features = 100  # width of the flow's hidden layers
num_transforms = 8     # number of stacked flow transforms (depth)

net = ili.utils.load_nde_sbi(
    engine="NPE",                       # Neural Posterior Estimation
    model="nsf",                        # Neural Spline Flow
    hidden_features=hidden_features,    # Network width
    num_transforms=num_transforms,      # Network depth
    # the device is set on the runner below, not on the network
)



# Data loader over the training split only.
# clone() gives the loader its own copy of the data; detach() drops any autograd history.
loader = NumpyLoader(
    x=x_all[train_mask].clone().detach(),
    theta=theta[train_mask].clone().detach()
)


# Runner setup with device specified here
runner = InferenceRunner.load(
    backend="sbi",
    engine="NPE",
    prior=prior,
    nets=[net],
    device=device,  # Device specified in runner, not network
    train_args=train_args,
    proposal=None,
    out_dir=model_out_dir,
    name=name
)

# Train the posterior estimator. NPE learns p(theta | x) directly; it does not
# learn a likelihood.
posterior_ensemble, summaries = runner(loader=loader)

# process of training:
'''
- the neural network learns P(θ|x): probability of parameters given observations
- uses training data to learn mapping from x → θ
- then we validate on held-out portion of training data
'''


In [None]:

# Add training analysis
def analyze_training_progress(summaries, model_index=0):
    """Print per-epoch training/validation log-probabilities for one trained model.

    Parameters
    ----------
    summaries : list of dict
        Training summaries returned by the runner; the selected entry must have
        the keys 'training_log_probs' and 'validation_log_probs'.
    model_index : int, optional
        Which ensemble member's summary to report (default 0, the first).

    Notes
    -----
    The tracked quantities are log-probabilities (higher is better), not losses.
    A positive Gap (train - val) means the training split is fitted better than
    the validation split, which is a sign of over-fitting.
    """
    summary = summaries[model_index]
    train_log_probs = summary['training_log_probs']
    val_log_probs = summary['validation_log_probs']

    print("\nTraining Progress Analysis:")
    print("-" * 50)
    # zip stops at the shorter list, so unequal lengths cannot raise IndexError.
    for epoch, (train_lp, val_lp) in enumerate(zip(train_log_probs, val_log_probs)):
        gap = train_lp - val_lp
        print(f"Epoch {epoch}: Train log prob = {train_lp:.4f}, "
              f"Val log prob = {val_lp:.4f}, Gap = {gap:.4f}")

# Report per-epoch training vs. validation log-probabilities for the trained model.
analyze_training_progress(summaries=summaries)

In [None]:
def plot_training_diagnostics(summaries, model_index=0):
    """Plot training/validation log-probability curves and their gap for one model.

    Parameters
    ----------
    summaries : list of dict
        Training summaries from the runner; the selected entry must have the keys
        'training_log_probs' and 'validation_log_probs'.
    model_index : int, optional
        Ensemble member to plot (default 0).

    Returns
    -------
    matplotlib.figure.Figure
        Two panels: the two curves, and (train - val) over the shared epochs.
    """
    summary = summaries[model_index]
    train_log_probs = np.asarray(summary['training_log_probs'], dtype=float)
    val_log_probs = np.asarray(summary['validation_log_probs'], dtype=float)

    fig, (ax_curves, ax_gap) = plt.subplots(1, 2, figsize=(15, 5))  # 1 row, 2 columns

    # Each curve gets its own epoch axis, so unequal lengths cannot break plotting.
    ax_curves.plot(np.arange(train_log_probs.size), train_log_probs, '-', label='Training', color='blue')
    ax_curves.plot(np.arange(val_log_probs.size), val_log_probs, '--', label='Validation', color='red')
    ax_curves.set_xlabel('Epoch')
    ax_curves.set_ylabel('Log probability')
    ax_curves.set_title('Training and Validation Log Probability')
    ax_curves.legend()
    ax_curves.grid(True, alpha=0.3)

    # Gap over the epochs both curves share. Positive values mean the training
    # split is fitted better than the validation split (a sign of over-fitting).
    n_common = min(train_log_probs.size, val_log_probs.size)
    gap = train_log_probs[:n_common] - val_log_probs[:n_common]
    ax_gap.plot(np.arange(n_common), gap, '-', color='purple')
    ax_gap.set_xlabel('Epoch')
    ax_gap.set_ylabel('Log-probability difference (train - val)')
    ax_gap.set_title('Overfitting Gap')
    ax_gap.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
    ax_gap.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Use the function
fig = plot_training_diagnostics(summaries)
os.makedirs(plots_out_dir, exist_ok=True)  # savefig does not create missing directories
fig.savefig(os.path.join(plots_out_dir, f'training_analysis_{name}.png'))
plt.show()
plt.close(fig)

In [None]:
# Training (solid) and validation (dashed) log-probability per epoch, one colour per ensemble member.
fig, ax = plt.subplots(1, 1, figsize=(6,4))
palette = list(mcolors.TABLEAU_COLORS)
for idx, summary in enumerate(summaries):
    colour = palette[idx]
    ax.plot(summary['training_log_probs'], ls='-', label=f"{idx}_train", c=colour)
    ax.plot(summary['validation_log_probs'], ls='--', label=f"{idx}_val", c=colour)
ax.set_xlim(0)
ax.set_xlabel('Epoch')
ax.set_ylabel('Log probability')
ax.legend()

In [None]:
# will this work or do we have to use it explicitly?
x_train=x_all[train_mask].clone().detach(),
theta_train=theta[train_mask].clone().detach()


In [None]:
len(x_train[0])


In [None]:
# Now, SBIRunner returns a custom class instance to be able to pass signature strings
# 1. prints our info on model configuration and architecture
print(posterior_ensemble.signatures)


# 2. choose a random input for training
seed_in = 49
np.random.seed(seed_in) # set seed for reproducability
ind = np.random.randint(len(x_train[0])) # choose observation (random index from training data)

# 3. generate posterior samples
seed_samp = 32
torch.manual_seed(seed_samp)# set seed for reproducability
# then, for the chosen training sample (as chosen above in 2.)
# generate 1000 samples from the posterior distribution using accept/reject sampling
samples = posterior_ensemble.sample(
    (1000,), 
    torch.Tensor(x_train[0][ind]).to(device))

# 4. calculate the probability densities for each sample
# i.e for each generated sample, calculate how likely it is using learned posterior distribution
log_prob = posterior_ensemble.log_prob(
    samples, # the generated samples from 3.
    torch.Tensor(x_train[0][ind]).to(device) # the chosen observation from 2.
    )

# convert to numpy so can read easier.
samples = samples.cpu().numpy()
log_prob = log_prob.cpu().numpy()

# Parameter names from df_pars, skipping the leading 'name' column.
param_names = df_pars.columns[1:5].tolist()  # first 4 parameters only, as a quick test

def plot_posterior_samples_basic(samples, log_prob, param_names):
    """
    Plot a histogram of the posterior samples of each parameter, one panel per row,
    annotated with the sample mean (dashed red line) and standard deviation.

    Named ``..._basic`` because a richer ``plot_posterior_samples`` (which also
    takes df_info) is defined right after this call.

    Parameters
    ----------
    samples : np.ndarray
        Posterior samples; column i is taken to be ``param_names[i]``.
    log_prob : np.ndarray
        Not used by the plot.
    param_names : list of str
        Axis labels, in the column order of ``samples``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_params = len(param_names)
    # squeeze=False keeps `axes` 2-D, so a single parameter still gives an iterable column.
    fig, axes = plt.subplots(n_params, 1, figsize=(10, 4*n_params), squeeze=False)

    for i, (ax, param_name) in enumerate(zip(axes[:, 0], param_names)):
        column = samples[:, i]
        ax.hist(column, bins=50, density=True, alpha=0.6)
        ax.set_xlabel(param_name)
        ax.set_ylabel('Density')

        # Add mean and std
        mean = column.mean()
        std = column.std()
        ax.axvline(mean, color='r', linestyle='--')
        ax.text(0.02, 0.95, f'Mean: {mean:.3f}\nStd: {std:.3f}',
                transform=ax.transAxes, verticalalignment='top')

    plt.tight_layout()
    return fig

# Print parameter names to verify
print("Parameter names:", param_names)

# Create the plot
fig = plot_posterior_samples_basic(samples, log_prob, param_names)

# adding log flag to better interpret results:
def plot_posterior_samples(samples, log_prob, param_names, df_info):
    """
    Plot the posterior of each parameter, one panel per row, using df_info metadata.

    Each name is looked up in ``df_info['ParamName']``. From that row:
    - 'LogFlag' == 1 selects a log x-axis, and the statistics are then reported in log10.
    - 'MinVal', 'MaxVal' and 'FiducialVal' are drawn as vertical lines.
    - 'Description' goes into the x-axis label.

    Parameters
    ----------
    samples : np.ndarray
        Posterior samples; column i is assumed to belong to ``param_names[i]``.
    log_prob : np.ndarray
        Not used by the plot.
    param_names : list of str
        Parameters to draw, in the column order of ``samples``.
    df_info : pandas.DataFrame
        Parameter metadata table.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    KeyError
        If a name has no matching row in ``df_info``.
    """
    n_params = len(param_names)
    # squeeze=False keeps `axes` 2-D, so a single parameter still gives an iterable column.
    fig, axes = plt.subplots(n_params, 1, figsize=(10, 4*n_params), squeeze=False)

    for i, (ax, name) in enumerate(zip(axes[:, 0], param_names)):
        data = samples[:, i]
        matches = df_info[df_info['ParamName'] == name]
        if matches.empty:
            raise KeyError(f"parameter {name!r} not found in df_info['ParamName']")
        param_info = matches.iloc[0]
        is_log = param_info['LogFlag'] == 1

        ax.hist(data, bins=50, density=True, alpha=0.6)
        if is_log:
            ax.set_xscale('log')
            log_data = np.log10(data)
            stats_text = f'Log10 Mean: {np.mean(log_data):.3f}\nLog10 Std: {np.std(log_data):.3f}'
        else:
            stats_text = f'Mean: {np.mean(data):.3f}\nStd: {np.std(data):.3f}'

        # Prior range and fiducial value (same for both axis scales)
        ax.axvline(param_info['MinVal'], color='g', linestyle=':', alpha=0.5, label='Min')
        ax.axvline(param_info['MaxVal'], color='g', linestyle=':', alpha=0.5, label='Max')
        ax.axvline(param_info['FiducialVal'], color='r', linestyle='--', alpha=0.5, label='Fiducial')

        # Add statistics text in top left
        ax.text(0.02, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))

        ax.set_xlabel(f"{name}\n{param_info['Description']}")
        ax.set_ylabel('Density')

        # Place legend in top right
        ax.legend(loc='upper right', bbox_to_anchor=(0.98, 0.98))

    plt.tight_layout()
    return fig

# Create the plot using df_info
fig = plot_posterior_samples(samples, log_prob, param_names, df_info)

In [None]:
def plot_posterior_samples_grid(samples, log_prob, param_names, df_info):
    """
    Plot the posterior of every parameter on a 4-column grid.

    Each name is looked up in ``df_info['ParamName']``. From that row:
    - 'LogFlag' == 1 selects a log x-axis, with statistics reported in log10.
    - 'MinVal', 'MaxVal' and 'FiducialVal' are drawn as vertical lines.
    - 'Description' goes into the panel title.

    Parameters
    ----------
    samples : np.ndarray
        Posterior samples; column i is assumed to belong to ``param_names[i]``
        (TODO confirm that df_info's row order matches theta's column order).
    log_prob : np.ndarray
        Not used by the plot.
    param_names : list of str
        Parameters to draw.
    df_info : pandas.DataFrame
        Parameter metadata table.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``param_names`` is empty.
    KeyError
        If a name has no matching row in ``df_info``.
    """
    n_params = len(param_names)
    if n_params == 0:
        raise ValueError("param_names is empty; nothing to plot")
    n_cols = 4  # 4 columns for 28 parameters
    n_rows = (n_params + n_cols - 1) // n_cols  # Ceiling division

    # squeeze=False always returns a 2-D array, so flatten() works for any grid size.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5*n_rows), squeeze=False)
    axes = axes.flatten()

    for i, (ax, name) in enumerate(zip(axes, param_names)):
        data = samples[:, i]
        matches = df_info[df_info['ParamName'] == name]
        if matches.empty:
            raise KeyError(f"parameter {name!r} not found in df_info['ParamName']")
        param_info = matches.iloc[0]
        is_log = param_info['LogFlag'] == 1

        ax.hist(data, bins=50, density=True, alpha=0.6)
        if is_log:
            ax.set_xscale('log')
            log_data = np.log10(data)
            stats_text = f'Log10 Mean: {np.mean(log_data):.3f}\nLog10 Std: {np.std(log_data):.3f}'
        else:
            stats_text = f'Mean: {np.mean(data):.3f}\nStd: {np.std(data):.3f}'

        # Add parameter limits
        ax.axvline(param_info['MinVal'], color='g', linestyle=':', alpha=0.5, label='Min')
        ax.axvline(param_info['MaxVal'], color='g', linestyle=':', alpha=0.5, label='Max')
        ax.axvline(param_info['FiducialVal'], color='r', linestyle='--', alpha=0.5, label='Fiducial')

        # Add statistics text in top left
        ax.text(0.02, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment='top', fontsize=8, bbox=dict(facecolor='white', alpha=0.8))

        # Make title from parameter name and description
        ax.set_title(f"{name}\n{param_info['Description']}", fontsize=8, pad=5)
        ax.tick_params(labelsize=8)

        # Only show legend for first plot
        if i == 0:
            ax.legend(loc='upper right', bbox_to_anchor=(0.98, 0.98), fontsize=8)

    # Remove the unused panels at the end of the last row
    for unused_ax in axes[n_params:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    return fig

# Get all parameter names
param_names = df_info['ParamName'].tolist()

# Create the grid plot
fig = plot_posterior_samples_grid(samples, log_prob, param_names, df_info)
os.makedirs(plots_out_dir, exist_ok=True)  # savefig does not create missing directories
plt.savefig(os.path.join(plots_out_dir, f'parameter_posteriors_grid_{name}.png'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
from matplotlib.gridspec import GridSpec
def plot_posterior_samples_grid(samples, log_prob, param_names, df_info, model_name, train_args):
    """
    Plot the posterior of every parameter on a 4-column grid, with a text box in
    the top-left corner listing the network and training configuration.

    Parameters
    ----------
    samples : np.ndarray
        Posterior samples; column i is assumed to belong to ``param_names[i]``
        (TODO confirm that df_info's row order matches theta's column order).
    log_prob : np.ndarray
        Not used by the plot.
    param_names : list of str
        Names looked up in ``df_info['ParamName']``.
    df_info : pandas.DataFrame
        Parameter metadata ('LogFlag', 'MinVal', 'MaxVal', 'FiducialVal', 'Description').
    model_name : str
        Run name shown in the info box.
    train_args : dict
        Training arguments; 'training_batch_size', 'learning_rate' and
        'stop_after_epochs' are shown in the info box.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``param_names`` is empty.
    KeyError
        If a name has no matching row in ``df_info``.

    Notes
    -----
    The network settings in the info box come from the notebook-level
    ``hidden_features`` and ``num_transforms`` variables.
    """
    n_params = len(param_names)
    if n_params == 0:
        raise ValueError("param_names is empty; nothing to plot")
    n_cols = 4
    n_rows = (n_params + n_cols - 1) // n_cols  # ceiling division

    # Create figure
    fig = plt.figure(figsize=(20, 5*n_rows))

    # Plain grid; room for the info box is made at the end via subplots_adjust(top=...).
    gs = GridSpec(n_rows, n_cols, figure=fig)

    # Add model info text to figure (not in grid)
    model_info = (
        f"Model Config:\n"
        f"Name: {model_name}\n"
        f"Hidden Features: {hidden_features}\n"
        f"Num Transforms: {num_transforms}\n"
        f"\nTraining Args:\n"
        f"Batch Size: {train_args['training_batch_size']}\n"
        f"Learning Rate: {train_args['learning_rate']}\n"
        f"Stop After Epochs: {train_args['stop_after_epochs']}"
    )

    fig.text(0.02, 0.98, model_info,
             fontsize=8,
             bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
             verticalalignment='top')

    # Plot parameters
    for i, name in enumerate(param_names):
        row, col = divmod(i, n_cols)
        ax = fig.add_subplot(gs[row, col])

        data = samples[:, i]
        matches = df_info[df_info['ParamName'] == name]
        if matches.empty:
            raise KeyError(f"parameter {name!r} not found in df_info['ParamName']")
        param_info = matches.iloc[0]
        is_log = param_info['LogFlag'] == 1

        ax.hist(data, bins=50, density=True, alpha=0.6)
        if is_log:
            ax.set_xscale('log')
            log_data = np.log10(data)
            stats_text = f'Log10 Mean: {np.mean(log_data):.3f}\nLog10 Std: {np.std(log_data):.3f}'
        else:
            stats_text = f'Mean: {np.mean(data):.3f}\nStd: {np.std(data):.3f}'

        # Add parameter limits
        ax.axvline(param_info['MinVal'], color='g', linestyle=':', alpha=0.5, label='Min')
        ax.axvline(param_info['MaxVal'], color='g', linestyle=':', alpha=0.5, label='Max')
        ax.axvline(param_info['FiducialVal'], color='r', linestyle='--', alpha=0.5, label='Fiducial')

        ax.text(0.02, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment='top', fontsize=8,
                bbox=dict(facecolor='white', alpha=0.8))

        ax.set_title(f"{name}\n{param_info['Description']}", fontsize=8, pad=5)
        ax.tick_params(labelsize=8)

        if i == 0:
            ax.legend(loc='upper right', fontsize=8)

    plt.tight_layout()
    # Shrink the subplot area so the top 5% of the figure is left for the info text
    plt.subplots_adjust(top=0.95)
    return fig

# Get all parameter names from df_info
param_names = df_info['ParamName'].tolist()

# Plot all parameters with the run configuration annotated
fig = plot_posterior_samples_grid(
    samples,
    log_prob,
    param_names,  # all 28 parameter names from df_info
    df_info,
    model_name=name,
    train_args=train_args
)

# Save with model config in filename
save_name = (f'parameter_posteriors_grid_{name}_'
            f'h{hidden_features}_t{num_transforms}_'
            f'b{train_args["training_batch_size"]}_'
            f'e{train_args["stop_after_epochs"]}.png')

os.makedirs(plots_out_dir, exist_ok=True)
plt.savefig(os.path.join(plots_out_dir, save_name),
            dpi=300,
            bbox_inches='tight')

In [None]:
"""
Coverage plots for each model
"""
metric = PosteriorCoverage(
    num_samples=int(4e3),
    sample_method='direct',
    # sample_method="slice_np_vectorized",
    # sample_params={'num_chains': 1},
    # sample_method="vi",
    # sample_params={"dist": "maf", "n_particles": 32, "learning_rate": 1e-2},
    labels=cam.labels,
    plot_list=["coverage", "histogram", "predictions", "tarp"],
    out_dir=plots_out_dir,
)
metric

In [None]:

# 6. Evaluation Metrics (computed on the held-out test simulations)
#
# - Coverage: how often the true parameters fall inside the posterior credible
#   regions, as a function of credibility level (read as credibility level on x,
#   empirical coverage on y — confirm on the figure):
#   -- on the diagonal: well calibrated
#   -- above the diagonal: truths fall inside the regions more often than nominal
#      -> posteriors too wide (conservative / under-confident)
#   -- below the diagonal: posteriors too narrow (over-confident)
#   -- large deviations from the diagonal: poorly calibrated model
#
# - Histogram (rank of the true value among the posterior samples):
#   -- flat/uniform: well calibrated
#   -- U-shaped: truths often land in the tails -> over-confident (too narrow)
#   -- bell/∩-shaped: truths pile up in the centre -> under-confident (too wide)
#   -- skewed/sloped: biased posterior
#
# - Predictions: compare model predictions with the true test-set values
#   -- true values: the parameters (e.g. Omega0, sigma8) used to run each test simulation
#   -- predictions: the NPE posterior for those parameters, given the simulation's
#      observables (luminosity functions and/or colours)
#   -- scatter along the diagonal: good predictions
#   -- systematic offset: bias in predictions
#   -- wide spread / large error bars: high uncertainty
#   -- clustering in certain regions: model performs better for some parameter ranges
#
# - TARP ("Tests of Accuracy with Random Points", Lemos et al. 2023):
#   -- an expected-coverage curve computed jointly over ALL parameters
#   -- on the diagonal: accurate posteriors; above/below read as for Coverage
#   -- being a joint test, it does not by itself say which parameter is miscalibrated

# use test data here.
fig = metric(
    posterior=posterior_ensemble,
    x=x_all[test_mask].cpu(),
    theta=theta[test_mask, :].cpu(),
    signature=f"coverage_{name}_",
)