In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import sys
from sklearn.model_selection import train_test_split
import GPUtil

import matplotlib.colors as mcolors
import torch
import torch.nn as nn
from sklearn.preprocessing import Normalizer
import joblib
import ili  # Import ili for the SBI functionality
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage, PlotSinglePosterior

from sbi.utils.user_input_checks import process_prior

sys.path.append("/disk/xray15/aem2/camels/proj2")
from setup_params_1P import plot_uvlf, plot_colour
from setup_params_SB import *
from priors_SB import initialise_priors_SB28

from variables_config_28 import uvlf_limits, n_bins_lf, colour_limits, n_bins_colour
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = "IllustrisTNG"
spec_type = "attenuated"
sps = "BC03"
snap = ["044"]
bands = "all" # or just GALEX?

# Label identifying this configuration; it is reused in output filenames.
name = "_".join(
    str(part)
    for part in (model, bands, sps, spec_type, n_bins_lf, n_bins_colour)
)

cam = camels(model=model, sim_set='SB28')

# Which summary statistics make up the data vector.
colours = False
luminosity_functions = True

# (colours, luminosity_functions) -> (grid-search results dir, plot output dir)
# NOTE(review): the first two results directories are relative paths while the
# third is absolute; the relative ones only resolve if the notebook is started
# from the directory that contains them -- confirm this is intended.
_locations = {
    (True, False): (  # colours only
        'grid_search_results_colors_20241201_121508',
        '/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/sbi_best/colours_only',
    ),
    (False, True): (  # lfs
        'grid_search_results_20241201_105510',
        '/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/sbi_best/lfs_only',
    ),
    (True, True): (  # both
        '/disk/xray15/aem2/grid_search_results_colors_lfs_20241201_121621',
        '/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/sbi_best/colours_lfs',
    ),
}

_choice = (bool(colours), bool(luminosity_functions))
if _choice not in _locations:
    raise ValueError("At least one of colours or luminosity_functions must be True")
dir_to_inspect, plots_out_dir = _locations[_choice]

print("Saving plots in ", plots_out_dir)



Load the simulation parameters (theta) and the observed summary statistics (x)

In [None]:
# Two tables are read here:
#   * df_pars -- CosmoAstroSeed_IllustrisTNG_L25n256_SB28.txt, the per-simulation
#     parameter values (whitespace-separated).
#   * df_info -- Info_IllustrisTNG_L25n256_28params.txt, which is passed to the
#     prior construction further down.
# The camels class also reads the CosmoAstroSeed file for the actual parameter values.

# sep=r"\s+" is the documented equivalent of the deprecated delim_whitespace=True.
df_pars = pd.read_csv(
    '/disk/xray15/aem2/data/28pams/IllustrisTNG/SB/CosmoAstroSeed_IllustrisTNG_L25n256_SB28.txt',
    sep=r"\s+",
)
print(df_pars)

# prior values come from this:
df_info = pd.read_csv("/disk/xray15/aem2/data/28pams/Info_IllustrisTNG_L25n256_28params.txt")
print(df_info)

# Columns 1..28: skips column 0 and keeps the next 28 columns.
# NOTE(review): presumably column 0 is 'name' and the trailing column is 'seed'
# -- confirm against the printed column list below.
# This theta is a preview only; it is reassigned later in this cell by
# get_theta_x_SB.
theta = df_pars.iloc[:, 1:29].to_numpy()

print(theta)
print(theta.shape)
print("Column names:")
print(df_pars.columns.tolist())

# Load the parameter matrix (theta) and the summary statistics (x) for the
# selected observables.  The call is unconditional: inside a notebook
# __name__ is always "__main__", and if it were ever skipped `x` would be
# undefined and every statement below would fail with a NameError.
theta, x = get_theta_x_SB(
    luminosity_functions=luminosity_functions,
    colours=colours,  # passed explicitly so the helper's default is not used
)
print(theta.shape, x.shape)

# Quick visual sanity checks of whichever observables were loaded.
if colours:
    fig = plot_colour(x)
    #plt.savefig('/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/colours_test/colours/colour_check.png')
    plt.show()

if luminosity_functions:
    fig = plot_uvlf(x)
    #plt.savefig('/disk/xray15/aem2/plots/28pams/IllustrisTNG/SB/test/LFs_test/uvlf_check.png')
    plt.show()


# Build the prior from the parameter info table (df_info).
# dust=False leaves the dust parameters out, so the prior covers only the
# 28 model parameters; with dust=True it would have 32 dimensions
# (28 parameters + 4 dust parameters).
prior = initialise_priors_SB28(df=df_info, device=device, astro=True, dust=False)

# --- Assemble the data vector -------------------------------------------
# Flatten each simulation's summary statistics into one feature row.  The
# cleaning and normalisation below run in numpy/scikit-learn on the CPU, so
# the array is only moved to `device` once, after it has been normalised.
x_all_cpu = np.array([np.hstack(_x) for _x in x]).astype(np.float32)

print("Theta shape:", theta.shape)
print("X shape:", x_all_cpu.shape)

theta = torch.tensor(theta, dtype=torch.float32, device=device)
print('theta:', theta)
print('x_all_cpu:', x_all_cpu)

# --- Diagnostics before cleaning ----------------------------------------
print("Data shape before processing:", x_all_cpu.shape)
print("Number of values:", x_all_cpu.size)
print("Number of NaN values:", np.isnan(x_all_cpu).sum())
print("Number of infinite values:", np.isinf(x_all_cpu).sum())

# --- Replace NaN/inf entries with tiny random noise ---------------------
nan_mask = ~np.isfinite(x_all_cpu)  # True wherever a value is NaN or +/-inf
print('nan_mask:', nan_mask)

if nan_mask.any():
    # NOTE(review): the noise is drawn from numpy's global, unseeded RNG, so
    # the replaced entries differ from run to run (at the 1e-10 level).
    x_all_cpu[nan_mask] = np.random.rand(np.sum(nan_mask)) * 1e-10

# --- Diagnostics after cleaning -----------------------------------------
print("Data shape after processing:", x_all_cpu.shape)
print("Number of NaN values:", np.isnan(x_all_cpu).sum())
print("Number of infinite values:", np.isinf(x_all_cpu).sum())

print('x_all_cpu:', x_all_cpu)

# --- Normalise ----------------------------------------------------------
# sklearn's Normalizer rescales each sample (row) independently to unit norm.
norm = Normalizer()

# Small constant added before normalising.
# NOTE(review): presumably this keeps all-zero rows away from a zero norm -- confirm.
epsilon = 1e-10
x_all_shifted = x_all_cpu + epsilon
x_all_normalized = norm.fit_transform(x_all_shifted)
x_all = torch.tensor(x_all_normalized, dtype=torch.float32, device=device)

print('x_all:', x_all)

# --- Train / test split -------------------------------------------------
test_mask = create_test_mask() # 10% testing
train_mask = ~test_mask # 90% for training


Select the best neural network from the hyperparameter grid search

In [None]:
# Scan every results.pkl below dir_to_inspect and keep the run whose best
# validation log-probability is highest (i.e. closest to 0).
#
# The loop variables are deliberately NOT called `name` or `x`: in a notebook
# they are globals, and reusing those names would overwrite the run label
# (used later in output filenames) and the loaded data vectors.
best_results = {}
best_val_loss = float("-inf")  # any finite score beats this
best_nn = None
best_nn_path = None

for root, dirs, files in os.walk(dir_to_inspect, topdown=False):
    for fname in files:
        if fname != 'results.pkl':
            continue

        results_path = os.path.join(root, fname)
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files produced by a trusted grid search.
        with open(results_path, 'rb') as f:
            candidate = pickle.load(f)

        # Only the first summary entry is inspected.
        # NOTE(review): presumably one entry per network in the ensemble -- confirm.
        val_loss = max(candidate['summaries'][0]['validation_log_probs'])
        if val_loss > best_val_loss:
            best_val_loss = val_loss
            best_results = candidate.copy()
            best_nn = files  # kept as before: file listing of the winning run directory
            best_nn_path = results_path
            print(results_path, files)

if not best_results:
    # Fail here with a clear message rather than with a KeyError further down.
    raise FileNotFoundError(
        f"No results.pkl found under {dir_to_inspect!r} "
        f"(searched relative to {os.getcwd()!r})"
    )

In [None]:
def create_run_name(params):
    """Build a compact, filename-friendly tag from a hyper-parameter dict.

    Reads the keys 'hidden_features', 'num_transforms',
    'training_batch_size', 'learning_rate' and 'num_nets' from ``params``.
    The learning rate is rendered in scientific notation with no decimals
    (e.g. ``1e-03``); the other values are rendered as-is.
    """
    pieces = (
        "hf{}".format(params['hidden_features']),
        "nt{}".format(params['num_transforms']),
        "bs{}".format(params['training_batch_size']),
        "lr{:.0e}".format(params['learning_rate']),
        "nets{}".format(params['num_nets']),
    )
    return "_".join(pieces)

# Tag describing the hyper-parameters of the selected run; it is embedded in
# the output filenames below. The bare expression displays it as cell output.
config_str = create_run_name(best_results['parameters'])
config_str

In [None]:
# Learning curves stored in the first summary entry of the selected run.
# Labels and a legend are required: without them the two curves cannot be
# told apart.
plt.plot(best_results['summaries'][0]['validation_log_probs'], label='validation')
plt.plot(best_results['summaries'][0]['training_log_probs'], label='training')
plt.xlabel('Epoch')  # NOTE(review): assumes one entry per training epoch -- confirm
plt.ylabel('Log probability')
plt.legend();


In [None]:
# Posterior object stored with the selected run; it is sampled from and passed
# to the coverage metric in the cells below.
posterior_ensemble = best_results['posterior_ensemble']

Draw posterior samples for each test simulation and summarise them (mean and standard deviation)

In [None]:

# Held-out simulations selected by the test mask.
x_test = x_all[test_mask]
theta_test = theta[test_mask]

# Posterior draws per test simulation.
n_samples = 1000

# One block of draws per test simulation, each conditioned on that
# simulation's data vector reshaped to a single-row batch.
all_samples = np.array([
    posterior_ensemble.sample(
        (n_samples,),
        x=observation.reshape(1, -1)
    ).cpu().numpy()
    for observation in x_test
])

# Per-simulation summaries, reduced over the sample axis of each block.
all_means = np.array([draws.mean(axis=0) for draws in all_samples])
all_stds = np.array([draws.std(axis=0) for draws in all_samples])

Make plots

In [None]:
# Detached copies of the training portion of the data.
# (No trailing comma after the first assignment: a trailing comma would make
# x_train a 1-element tuple wrapping the tensor instead of the tensor itself.)
x_train = x_all[train_mask].clone().detach()
theta_train = theta[train_mask].clone().detach()


In [None]:
# Reminder of where the figures below are written.
print(plots_out_dir)

In [None]:
# True-vs-inferred scatter for every parameter: posterior mean against the
# true value, with the posterior standard deviation as the error bar.
param_names = df_pars.columns[1:29].tolist()  # columns 1..28; column 0 is skipped

# This figure is positioned manually with subplots_adjust below, which is
# incompatible with constrained layout, so constrained layout is switched off
# for it explicitly.  That makes the cell render the same on every run,
# whatever the current value of the rcParam.
fig, axes = plt.subplots(7, 4, figsize=(16, 28), constrained_layout=False)
axes = axes.flatten()

fontsize = 10

# Global default: only affects figures created AFTER this line.
plt.rcParams['figure.constrained_layout.use'] = True

# Copy the test parameters to the host once instead of three times per panel.
theta_test_np = theta_test.cpu().numpy()

# Plot each parameter
for i, param_name in enumerate(param_names):
    ax = axes[i]

    truth = theta_test_np[:, i]
    inferred = all_means[:, i]
    spread = all_stds[:, i]

    # True vs predicted with error bars
    ax.errorbar(
        truth,
        inferred,
        yerr=spread,
        fmt='.',
        color='k',
        ecolor='blue',
        capsize=0,
        elinewidth=0.8,
        alpha=0.3,
        markersize=5
    )

    # One-to-one line spanning the union of the current x and y ranges
    lims = [
        min(ax.get_xlim()[0], ax.get_ylim()[0]),
        max(ax.get_xlim()[1], ax.get_ylim()[1])
    ]
    ax.plot(lims, lims, '--', color='black', alpha=0.5, linewidth=1)

    # Metrics: RMSE of the posterior mean, squared Pearson correlation
    # (reported as R²), and the mean squared residual in units of the
    # posterior variance (reported as χ²).
    residual = truth - inferred
    rmse = np.sqrt(np.mean(residual**2))
    r2 = np.corrcoef(truth, inferred)[0, 1]**2
    chi2 = np.mean((residual**2) / (spread**2))

    # Metrics box in the top-left corner
    stats_text = f'RMSE = {rmse:.2f}\n' + \
                 f'R² = {r2:.2f}\n' + \
                 f'χ² = {chi2:.2f}'
    ax.text(0.05, 0.95, stats_text,
            transform=ax.transAxes,
            bbox=dict(facecolor='white', alpha=0.8),
            verticalalignment='top',
            fontsize=fontsize-1)  # Slightly smaller font for stats

    # Title: parameter name
    ax.set_title(param_name, fontsize=fontsize, pad=5)  # Reduced padding

    # Axis labels
    ax.set_xlabel('True', fontsize=fontsize-1)
    ax.set_ylabel('Inferred', fontsize=fontsize-1)

    # Tick labels
    ax.tick_params(axis='both', which='major', labelsize=fontsize-2)

    # Internal padding
    ax.margins(x=0.05, y=0.05)

# Subplot spacing.  The values are fractions of the figure: the grid spans
# x in [0.01, 0.7] and y in [0.05, 0.7], i.e. it does not fill the canvas.
# The unused margin is cropped by bbox_inches='tight' when saving.
plt.subplots_adjust(
    left=0.01,
    right=0.7,
    bottom=0.05,
    top=0.7,
    wspace=0.2,   # horizontal gap between panels
    hspace=0.2    # vertical gap between panels
)


# Save figure with detailed filename (create the output directory if needed)
os.makedirs(plots_out_dir, exist_ok=True)
save_path = f'{plots_out_dir}/posterior_predictions_{config_str}.png'
plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0.1)
print(save_path)
plt.show()

In [None]:

# Coverage diagnostics on the test set.
# The list of requested plots is defined once and used both for the metric
# and for naming the saved files, so the two cannot drift apart.
plot_types = ["tarp"]  # also available: "coverage", "histogram", "predictions"

os.makedirs(plots_out_dir, exist_ok=True)

metric = PosteriorCoverage(
    num_samples=int(4e3),
    sample_method='direct',
    labels=cam.labels,
    plot_list=plot_types,
    out_dir=plots_out_dir,
)

# NOTE(review): `name` is expected to hold the run label built in the
# configuration cell; verify it has not been reassigned in between, since any
# later value would leak into the filenames below.
figs = metric(
    posterior=posterior_ensemble,
    x=x_all[test_mask].cpu(),
    theta=theta[test_mask, :].cpu(),
    signature=f"coverage_{name}_{config_str}_"  # Add config to filename
)

config_text = config_str

# Annotate and save each returned figure.
# NOTE(review): assumes the metric returns one figure per entry of
# plot_types, in the same order -- confirm against the ili version in use.
for plot_type, fig in zip(plot_types, figs):
    # Work on the figure object directly rather than through pyplot's
    # "current figure" state.
    fig.text(0.02, 0.98, config_text,
             fontsize=8,
             bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
             verticalalignment='top')

    # Save each figure with type indicator
    fig.savefig(os.path.join(plots_out_dir,
                f'metric_{plot_type}_{name}_{config_str}_bestnn_test.png'),
                dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)