# Import libraries 

In [None]:
# Jax dedicated libraries
from flax import nnx
import optax
import jax
import jax.numpy as jnp # From this point on, there should not be numpy anymore but only jax.numpy
import jax.scipy as jsp

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import corner
import numpy as np

# Module functions
import ximinf.nn_train as nntr
import ximinf.nn_test as nnte
import ximinf.generate_sim as gsim

# Other
import h5py
import gc

from absl import logging

# Only let absl messages of severity ERROR or above through.
logging.set_verbosity(logging.ERROR)

# Set seed

In [None]:
# Root PRNG key of the notebook; later cells split it whenever they need randomness.
key = jax.random.PRNGKey(seed=42)

# Set device type

In [None]:
# Probe accelerator backends in priority order and remember which one matched.
gpu = None
gpu_backend = None
for backend in ("METAL", "cuda", "gpu"):
    try:
        devs = jax.devices(backend)
    except RuntimeError:
        # This backend is not available here; try the next candidate.
        continue
    if devs:
        gpu = devs[0]
        gpu_backend = backend
        break

# CPU device, used as the fallback target.
cpu = jax.devices("cpu")[0]

# `gpu` is a Device object (or None) and never compares equal to the string
# 'cuda', so test the name of the backend that matched instead.
if gpu_backend == "cuda":
    nntr.print_gpu_memory()

# Use the accelerator if one was found, otherwise the CPU.
device = gpu if gpu is not None else cpu

# jax.default_device(...) is a context manager: calling it without `with`
# changes nothing. Set the process-wide default through the config instead.
jax.config.update("jax_default_device", device)

backend = jax.default_backend()
print(backend)

# Import training data

In [None]:
# Path to the HDF5 simulation file (the commented-out line is an alternative dataset)
# file_path = "./data/SIM/simulations_10000_2_1000_1000_brok_alpha_beta_gamma_sigma_XS_gaussian_priors_cosmo_XL_errors_z.h5"
file_path = "../data/SIM/simulations_2000_5_1000_1000_brok_alpha_beta_gamma_sigma_XS_gaussian_exponential_priors_errors_z.h5" #simulations_10000_2_1000_1000_brok_alpha_beta_gamma_sigma_XS_gaussian_priors_cosmo_XL_errors_z

with h5py.File(file_path, "r") as f:

    # Every dataset under "params" and "data" is read in full and stored as float32.
    params = {k: jnp.array(f["params"][k][:], dtype=jnp.float32) for k in f["params"].keys()}
    data = {k: jnp.array(f["data"][k][:], dtype=jnp.float32) for k in f["data"].keys()}

    # Each prior group provides a "range" dataset and a "type" attribute.
    priors = {}
    for name in f["priors"].keys():
        grp = f["priors"][name]
        priors[name] = {
            "range": jnp.array(grp["range"][:], dtype=jnp.float32),
            "type": grp.attrs["type"]
        }

# Sizes are read from the first entry of each dict.
N = next(iter(params.values())).shape[0]   # N simulations
M = next(iter(data.values())).shape[1]     # M SNe per simulation

print(f"The file contains {N} simulations of size {M}")
print("Loaded priors:", priors)

print('Removing cosmology ...')
# 64-bit mode is switched on only for the cosmology correction. The
# try/finally guarantees it is switched off again even if rm_cosmo raises,
# so a failure cannot silently leave the rest of the notebook in x64 mode.
jax.config.update("jax_enable_x64", True)
try:
    # Take care to not run the correction twice
    mu_planck18, magobs = nntr.rm_cosmo(data['z'], data['magobs'], package='cosmologix') #, package='cosmologix'
    print('... done')
finally:
    jax.config.update("jax_enable_x64", False)

data['magobs'] = magobs
# True wherever the corrected magnitude is non-zero; used below as the validity mask.
mask = magobs != 0

# Only update 'mabs' if it exists in params
# NOTE(review): the 19.3 offset is applied to both the parameter values and
# the prior range; its physical meaning is not visible here — confirm.
if 'mabs' in params:
    params['mabs'] = params['mabs'] + 19.3
    priors['mabs']['range'] += 19.3

gc.collect()

In [None]:
# Show which parameters are available, then declare the inference groups.
print("Parameter names:", list(params.keys()))
# Groups are sorted from easiest to hardest to infer ('Om0' is currently left out).
param_groups = ['mabs', ['alpha_low', 'alpha_high'], 'beta', ['gamma', 'sigma_int']]
print(param_groups[0])

# Flatten the groups into one ordered list without duplicates.
# dict.fromkeys keeps the first occurrence of every name, in insertion order.
global_param_names = list(dict.fromkeys(
    name
    for group in param_groups
    for name in ([group] if isinstance(group, str) else group)
))

In [None]:
# List the data columns that were loaded (iterating a dict yields its keys).
print("Column names:", list(data))

# Compute residuals

## Apply mask

In [None]:
# Masked copy of every column: entries where `mask` is False become NaN.
data_filt = {
    column_name: jnp.where(mask, column_values, jnp.nan)
    for column_name, column_values in data.items()
}

## Display residuals

In [None]:
index = 102  # simulation (row) to display

# Three-colour maps, one per panel
cmap1 = LinearSegmentedColormap.from_list('custom_red_beige_blue', ['#1F487E', 'beige', '#A31621'])
cmap2 = LinearSegmentedColormap.from_list('custom_green_beige_purple', ['#687444', 'beige', '#5E4983'])
cmap3 = LinearSegmentedColormap.from_list('custom_blue_beige_orange', ['#1F487E', 'beige', '#C07835'])

# One row of three panels
fig, axes = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)

# (column, colormap, title, x label, colorbar label) for each panel
panels = (
    ('z', cmap1, 'Magnitude vs Redshift', 'Redshift (z)', 'Redshift z'),
    ('c', cmap2, 'Magnitude vs Color', 'Color (c)', 'Color value'),
    ('x1', cmap3, 'Magnitude vs Stretch', 'Stretch (x1)', 'Stretch x1'),
)

scatters = []
colorbars = []
for ax, (column, cmap, title, xlabel, cbar_label) in zip(axes, panels):
    # The plotted column is used both as x coordinate and as point colour.
    values = data_filt[column][index, :]
    scatter = ax.scatter(
        values,
        data_filt['magobs'][index, :],
        c=values,
        cmap=cmap,
        edgecolor='k'
    )
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('Observed Magnitude', fontsize=12)
    colorbar = plt.colorbar(scatter, ax=ax)
    colorbar.set_label(cbar_label, fontsize=12)
    scatters.append(scatter)
    colorbars.append(colorbar)

# Keep the per-panel handles available under their individual names.
sc1, sc2, sc3 = scatters
cbar1, cbar2, cbar3 = colorbars

# Figure title: the parameter values of the displayed simulation
title_str = ", ".join(f"{name} = {params[name][index]:.2f}" for name in global_param_names)
fig.suptitle(title_str, fontsize=16)

plt.show()

gc.collect()


In [None]:
# Padded copy of every column: entries where `mask` is False become 0.
data_padded = {name: jnp.where(mask, column, 0) for name, column in data.items()}

# select_columns = ['magobs', 'magobs_err', 'c', 'c_err', 'x1','x1_err', 'mass', 'mass_err', 'localcolor', 'localcolor_err', 'prompt']
select_columns = ['magobs', 'magobs_err', 'c', 'c_err', 'x1','x1_err', 'localcolor', 'localcolor_err'] # , 'z'
select_params  = list(params.keys())  # or subset you want

# For each selected column: compute mean/std over all entries of `data[col]`,
# then standardise the padded column and reset masked-out entries to 0.
# NOTE(review): the statistics are taken from the unmasked `data`, so any
# padded entries contribute to mu/sigma — confirm this is intended.
data_stats = {}
data_padded_normed = {}
for col in select_columns:
    # Single-element list: replace with the training-set list if several sets are used.
    all_values = jnp.concatenate([data[col].ravel()])
    mu = jnp.mean(all_values)
    sigma = jnp.std(all_values) + 1e-8  # epsilon avoids division by zero
    data_stats[col] = {'mu': mu, 'sigma': sigma}

    standardised = (data_padded[col] - mu) / sigma
    data_padded_normed[col] = jnp.where(mask, standardised, 0.0)

# From here on `data_padded` only holds the selected, normalised columns.
data_padded = data_padded_normed
gc.collect()

In [None]:
# Mean / standard deviation of every selected parameter over all simulations.
param_stats = {}
for dic_key in select_params:
    # Single-element list: replace with the training-set list if several sets are used.
    all_values = jnp.concatenate([params[dic_key].ravel()])
    param_stats[dic_key] = {
        'mu': jnp.mean(all_values),
        'sigma': jnp.std(all_values) + 1e-8,  # epsilon avoids division by zero
    }

# Express every prior range in the same normalised units as the parameters;
# the prior type is carried over unchanged.
normalized_priors = {
    name: {
        'range': (prior['range'] - param_stats[name]['mu']) / param_stats[name]['sigma'],
        'type': prior['type'],
    }
    for name, prior in priors.items()
}

gc.collect()

In [None]:
# data   = nntr.normalize(data, data_stats)
# Rebind `params` to the output of nntr.normalize, called with the statistics in `param_stats`.
# NOTE(review): `params` is overwritten in place, so re-running this cell
# presumably normalises a second time — confirm against nntr.normalize.
params = nntr.normalize(params, param_stats)

gc.collect()

# Prepare train and test sets

## Set bounds

## Label data

In [None]:
# -----------------------------------------
# Autoregressive dataset construction
# -----------------------------------------

param_names = list(params.keys())
K = len(param_names)  # total number of parameters

# Draw "false" parameters from the priors, one set per simulation.
# priors['sigma_int']['type'] = 'positive-gaussian'
# priors['sigma_int']['range'] = jnp.array([0.1, 0.4])
false_params = gsim.scan_params(priors, N)

# Per-parameter normalisation constants, in `param_names` order
mus = jnp.array([param_stats[name]['mu'] for name in param_names])
sigmas = jnp.array([param_stats[name]['sigma'] for name in param_names])

# Stack the draws column-wise into an (N, K) array and normalise them
false_params = jnp.stack([false_params[name] for name in param_names], axis=1)
false_params = (false_params - mus) / sigmas

# True (already normalised) parameters in the same (N, K) layout
true_params = jnp.stack([params[name] for name in param_names], axis=1)

# Observational data: stack the columns, then flatten each simulation to one row
data_names = list(data_padded.keys())
n_cols = len(data_names)

data_arrays = [data_padded[name] for name in data_names]
data_stacked = jnp.stack(data_arrays, axis=-1)        # (N, M, n_cols)
data_concat = data_stacked.reshape(N, M * n_cols)     # (N, M * n_cols)

gc.collect()

## Prepare train/test sets

In [None]:
# ----------------------------------------------------
# Global train / test split (shared across all groups)
# ----------------------------------------------------

# Consume one sub-key for the shuffle and keep the rest of the stream in `key`.
key, split_key = jax.random.split(key)

indices = jnp.arange(N)
perm = jax.random.permutation(split_key, indices)

# The first 30% of the shuffled indices form the test set, the rest the training set.
n_test = int(0.3 * N)
test_idx, train_idx = perm[:n_test], perm[n_test:]

# Slice every array once with the shared indices
data_train, data_test = data_concat[train_idx], data_concat[test_idx]
param_true_train, param_true_test = true_params[train_idx], true_params[test_idx]
false_params_train, false_params_test = false_params[train_idx], false_params[test_idx]
mask_train, mask_test = mask[train_idx], mask[test_idx]

gc.collect()

## Concatenate data

In [None]:
# ----------------------------------------------------
# Build parameter slices per group (correct semantics)
# ----------------------------------------------------

all_group_param_slices = []

for g, group in enumerate(param_groups):

    # Names of the current group and of all groups that come before it
    group_list = [group] if isinstance(group, str) else group
    prev_groups = []
    for earlier in param_groups[:g]:
        prev_groups.extend(earlier if isinstance(earlier, list) else [earlier])

    # Column indices into the (N, K) parameter arrays
    group_idx = jnp.array([param_names.index(name) for name in group_list])
    # Visible columns: the earlier groups first, then the current group
    visible_idx = jnp.array([param_names.index(name) for name in prev_groups + group_list])

    # --------------------
    # Labels: True keeps the true values, False swaps in the false ones
    # --------------------
    key, label_key1, label_key2 = jax.random.split(key, 3)
    labels_train = jax.random.uniform(label_key1, (train_idx.shape[0],)) > 0.5
    labels_test  = jax.random.uniform(label_key2, (test_idx.shape[0],))  > 0.5

    # --------------------
    # Build params: only the current group's columns are ever replaced
    # --------------------
    mixed_train = jnp.where(
        labels_train[:, None],
        param_true_train[:, group_idx],
        false_params_train[:, group_idx],
    )
    params_train = param_true_train.at[:, group_idx].set(mixed_train)

    mixed_test = jnp.where(
        labels_test[:, None],
        param_true_test[:, group_idx],
        false_params_test[:, group_idx],
    )
    params_test = param_true_test.at[:, group_idx].set(mixed_test)

    all_group_param_slices.append({
        "chosen_train": params_train[:, visible_idx],
        "chosen_test":  params_test[:, visible_idx],
        "labels_train": labels_train,
        "labels_test":  labels_test,
    })


# Build a neural network

In [None]:
# Sizes of the network layers (earlier trials: e=32, p=64, r=256)
Nsize_e, Nsize_p, Nsize_r = 32, 64, 128

# Input dimensions, taken from the prepared dataset
n_cols = len(data_names)
n_params = len(param_names)
print('# of columns :', n_cols)
print('# of params :', n_params)


In [None]:
# ----------------------------
# One classifier per parameter group
# ----------------------------

models_per_group = []
group_configs = []
rng = nnx.Rngs(0)

for g, group in enumerate(param_groups):
    # A group's model sees every parameter of the earlier groups plus its own.
    prev_groups = []
    for earlier in param_groups[:g]:
        prev_groups.extend(earlier if isinstance(earlier, list) else [earlier])
    group_list = [group] if isinstance(group, str) else group

    visible_param_names = prev_groups + group_list
    n_params_visible = len(visible_param_names)

    print(
        f"Group {g}: visible parameters = {visible_param_names}, "
        f"total = {n_params_visible}"
    )

    # All models share the same `rng` stream object.
    models_per_group.append(
        nntr.DeepSetClassifier(
            dropout_rate=0.1,
            Nsize_p=Nsize_p,
            Nsize_r=Nsize_r,
            n_cols=n_cols,
            n_params=n_params_visible,
            rngs=rng,
        )
    )
    model_g = models_per_group[-1]

    # Record what this group's model was built with.
    group_configs.append({
        "group_id": g,
        "n_params_visible": n_params_visible,
        "visible_param_names": visible_param_names,
    })

# Settings common to all groups, stored alongside the per-group entries.
shared_config = {
    "Nsize_p": Nsize_p,
    "Nsize_r": Nsize_r,
    "Nsize_e": Nsize_e,
    # "n_cols": n_cols,
    "columns": select_columns,
    "param_groups": param_groups,
    "global_param_names": global_param_names,
    "priors": priors,
    "param_stats": param_stats,
    "data_stats": data_stats
}
model_config = {"shared": shared_config, "groups": group_configs}

# Optional: visualize one model
nnx.display(models_per_group[0])


# Train NN

In [None]:
# Training / early-stopping settings
n_batch = 100   # target number of batches per epoch
patience = 20 #40
epochs = 1000

metrics_histories = []

# init_values = jnp.linspace(3e-4, 3e-4, len(all_group_param_slices))
# patiences   = jnp.linspace(40, 20,  len(all_group_param_slices))
# epochss = [1,1,1,1000]
# init_values = [1e-3, 1e-3, 1e-3, 1e-3]

for g, group_data in enumerate(all_group_param_slices):

    print(f"\n=== Training model for group {g}: {param_groups[g]} ===")

    chosen_train = group_data["chosen_train"]
    chosen_test  = group_data["chosen_test"]
    labels_train = group_data["labels_train"]
    labels_test  = group_data["labels_test"]

    # ------------------------------------------------
    # Concatenate inputs (no parameter logic here):
    # each row is [flattened data | mask | visible parameters]
    # ------------------------------------------------
    train_data = jnp.concatenate(
        [data_train, mask_train, chosen_train], axis=-1
    )
    test_data = jnp.concatenate(
        [data_test, mask_test, chosen_test], axis=-1
    )

    train_labels = labels_train.astype(jnp.int32)[:, None]
    test_labels  = labels_test.astype(jnp.int32)[:, None]

    # The full datasets are placed on the CPU device before being handed to train_loop.
    train_data   = jax.device_put(train_data, cpu)
    train_labels = jax.device_put(train_labels, cpu)
    test_data    = jax.device_put(test_data, cpu)
    test_labels  = jax.device_put(test_labels, cpu)

    # ------------------------------------------------
    # Optimiser: AdamW with an exponentially decaying learning rate
    # ------------------------------------------------
    learning_rate_schedule = optax.exponential_decay(
        init_value=5e-4,
        transition_steps=1000,
        decay_rate=0.9,
    )

    optimizer = nnx.Optimizer(
        models_per_group[g],
        optax.adamw(learning_rate_schedule, 0.9),
    )

    # Use the actual number of training rows. The split takes int(0.3 * N)
    # rows for testing, so int(N * 0.7) is not always the training size.
    n_train = train_data.shape[0]

    # Never let the batch size drop to 0 when there are fewer rows than n_batch.
    batch_size = max(1, n_train // n_batch)

    model_g, metrics_history, key = nntr.train_loop(
        model=models_per_group[g],
        optimizer=optimizer,
        train_data=train_data,
        train_labels=train_labels,
        test_data=test_data,
        test_labels=test_labels,
        key=key,
        epochs=epochs, #epochs #epochss[g]
        batch_size=batch_size,
        patience=patience,
        metrics_history={
            'train_loss': [],
            'train_accuracy': [],
            'test_loss': [],
            'test_accuracy': []
        },
        M=M,
        N=n_train,
        cpu=cpu,
        gpu=gpu,
        group_id=g,
        group_params=param_groups[g],
        plot_flag=True,
    )

    # Keep the trained model and its metric curves for this group.
    models_per_group[g] = model_g
    metrics_histories.append(metrics_history)


# Test NN

## Test Accuracy

In [None]:
# Put every model in evaluation mode (disables dropout, etc.)
for model_g in models_per_group:
    model_g.eval()

batch_size = 128
metrics_per_group = []

for g, model_g in enumerate(models_per_group):

    print(f"\n=== Evaluating model for group {g}: {param_groups[g]} ===")

    group_slices = all_group_param_slices[g]
    chosen_test = group_slices["chosen_test"]
    labels_test = group_slices["labels_test"]

    num_samples = labels_test.shape[0]

    # Run the model batch by batch on rows laid out as [data | mask | visible parameters]
    all_logits = []
    all_labels = []
    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        batch_inputs = jnp.concatenate(
            [data_test[start:stop], mask_test[start:stop], chosen_test[start:stop]],
            axis=-1,
        )
        all_logits.append(nntr.pred_step(model_g, batch_inputs))
        all_labels.append(labels_test[start:stop, None].astype(jnp.int32))

    # Merge batches
    all_logits = jnp.concatenate(all_logits, axis=0)
    all_labels = jnp.concatenate(all_labels, axis=0)

    # Sigmoid of the logits, thresholded at 0.5
    all_preds = (jsp.special.expit(all_logits) > 0.5).astype(jnp.int32)

    # Confusion matrix components
    predicted_pos, predicted_neg = all_preds == 1, all_preds == 0
    actual_pos, actual_neg = all_labels == 1, all_labels == 0
    TP = jnp.sum(predicted_pos & actual_pos)
    TN = jnp.sum(predicted_neg & actual_neg)
    FP = jnp.sum(predicted_pos & actual_neg)
    FN = jnp.sum(predicted_neg & actual_pos)

    # The small epsilon keeps the ratios finite when a denominator is zero.
    accuracy    = (TP + TN) / (TP + TN + FP + FN)
    precision   = TP / (TP + FP + 1e-8)
    sensitivity = TP / (TP + FN + 1e-8)
    specificity = TN / (TN + FP + 1e-8)

    print(
        f"Group {g} ({param_groups[g]}): "
        f"Accuracy={accuracy:.3f}, "
        f"Precision={precision:.3f}, "
        f"Sensitivity={sensitivity:.3f}, "
        f"Specificity={specificity:.3f}"
    )

    metrics_per_group.append({
        "accuracy": accuracy,
        "precision": precision,
        "sensitivity": sensitivity,
        "specificity": specificity,
    })


# Test posterior

In [None]:
# Evaluation mode for all models (disables dropout, etc.)
for model_g in models_per_group:
    model_g.eval()

# ----------------------------
# Parameter info
# ----------------------------
param_names = list(params.keys())
N_SIM_PARAMS = len(param_names)
NDIM = len(global_param_names)

# ----------------------------
# Last group: its precomputed parameters and labels
# ----------------------------
last_group_slices = all_group_param_slices[-1]
chosen_test = last_group_slices["chosen_test"]
labels_test = last_group_slices["labels_test"]

# ----------------------------
# Keep at most 100 test rows whose label is 1
# ----------------------------
mask_true = labels_test == 1
N_sims = int(jnp.minimum(100, jnp.sum(mask_true)))

# Indices of the first N_sims rows with label 1
(true_idx,) = jnp.nonzero(mask_true, size=N_sims, fill_value=0)

# ----------------------------
# Test inputs and matching parameters
# ----------------------------
# Visible parameters of the selected rows
theta_star = chosen_test[true_idx]

# Network inputs without the parameters: [data | mask]
xy_test = jnp.concatenate(
    [data_test[true_idx, :], mask_test[true_idx, :]],
    axis=-1
)

alpha_grid = jnp.linspace(0, 1, 50)

In [None]:
index = 10  # which of the selected test rows to inspect

# Name every entry of the selected row so it can be un-normalised per parameter
theta_star_row = theta_star[index]
theta_star_dict = {
    name: theta_star_row[position]
    for position, name in enumerate(global_param_names)
}

# Undo the normalisation using the stored parameter statistics
theta_star_unnormed_dict = nntr.unnormalize(theta_star_dict, param_stats)

# Back to an array, in `global_param_names` order
theta_star_unnormed = jnp.array(
    [theta_star_unnormed_dict[name] for name in global_param_names]
)

print(param_groups)
print(theta_star_unnormed)

In [None]:
# Same groups as `param_groups`, but every group is a list:
# a single parameter name is wrapped, list groups are passed through as they are.
group_names_list = [
    [entry] if isinstance(entry, str) else entry
    for entry in param_groups
]

In [None]:
# Display the (un-normalised) priors dictionary as the cell output.
priors

In [None]:
# Single test sample used as the network input, shape (n_features,)
test_data = xy_test[index, :]

# Start the sampler at the sample's own (normalised) parameters;
# the midpoint-of-bounds start is kept below for reference only.
theta_init = theta_star[index, :]  # (BOUNDS[:, 0] + BOUNDS[:, 1]) / 2.0

visible_indices, group_indices = nnte.preprocess_groups(param_groups, global_param_names)

# Priors used at inference time (currently the training priors)
priors_inference = priors

# Express each prior range in normalised parameter units; the type is unchanged.
normalized_priors_inference = {
    name: {
        'range': (prior['range'] - param_stats[name]['mu']) / param_stats[name]['sigma'],
        'type': prior['type'],
    }
    for name, prior in priors_inference.items()
}


In [None]:
 def log_post(theta):
    """Return `nnte.log_prob_fn_groups` evaluated at `theta`.

    Thin wrapper that binds the notebook-level models, the selected test
    sample, the normalised inference priors and the group index structures,
    so the result is a function of `theta` only.
    """
    bound_args = (
        models_per_group,             # list of models, one per group
        test_data,
        normalized_priors_inference,
        visible_indices,
        group_indices,
        group_names_list,
    )
    return nnte.log_prob_fn_groups(theta, *bound_args)

In [None]:
# Sample the posterior, starting from `theta_init`
print("Launch MCMC ...")
mcmc_settings = {
    "n_warmup": 100,
    "n_samples": 100,
    "init_position": theta_init,
    "rng_key": key,
}
# The sampler hands back a key (stored again in `key`) together with the samples.
key, post = nnte.sample_posterior(log_post, **mcmc_settings)
print("...finished")

In [None]:
n_samples, n_params = post.shape

# Name the sample columns so they can be un-normalised per parameter
post_dict = {
    name: post[:, column]
    for column, name in enumerate(global_param_names)
}

# Undo the normalisation using the stored parameter statistics
post_unnormed_dict = nntr.unnormalize(post_dict, param_stats)

# Back to an (n_samples, n_params) array for plotting; columns are in real units
post_unnormed = jnp.stack(
    [post_unnormed_dict[name] for name in global_param_names],
    axis=1,
)

x = post_unnormed[:, 0]

# Plot range per parameter: the prior range widened by half its width on each side
ranges_from_priors = []
for name in global_param_names:
    lower = float(priors[name]['range'][0])
    upper = float(priors[name]['range'][1])
    half_width = (upper - lower) / 2
    ranges_from_priors.append((lower - half_width, upper + half_width))

In [None]:
# Corner plot of the un-normalised posterior samples, with the true values overlaid
corner_options = {
    "labels": global_param_names,
    "range": ranges_from_priors,
    "quantiles": [0.16, 0.5, 0.84],      # 1D marginal: 1σ
    "levels": [0.393469, 0.864665],      # 2D contours: 1σ, 2σ
    "show_titles": True,
    "title_fmt": ".4f",
    "bins": 30,
    "smooth": 1.0,
    "color": "#1F487E",
    "truths": theta_star_unnormed[:],
    "truth_color": "#A31621",
    "truth_alpha": 0.8,
}
fig = corner.corner(np.array(post_unnormed), **corner_options)

plt.savefig("./corner.png",dpi=150)
plt.show()

# Save NN to disk

In [None]:
# Save the trained per-group models together with `model_config` for future use.
# The commented name at the end of the next line is an alternative output name.
save_path = '../data/NNs/nn_model_priors_M1000_cosmo_err_z_small_sample' #nn_model_priors_M1000_cosmo_err_z_smaller_dropout
nntr.save_autoregressive_nn(models_per_group,save_path, model_config)
print('NNs saved to ' + save_path)