### Import Packages

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from sbi.inference import SNPE
from sbi import utils as utils
from sbi.analysis import run_sbc, sbc_rank_plot
from astropy.io import fits
from astropy.table import Table, Column
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
import pickle
import tarp

  from .autonotebook import tqdm as notebook_tqdm


### Data Loading

In [2]:
filename = "/raid/users/heigerm/catalogues/sp_x_apogee_x_spspectra_rvtab.fits"
# sp data
# Read the three extensions inside a context manager so the FITS file handle is
# closed as soon as the tables are built. Table(...) copies its input by default,
# so the resulting tables do not depend on the file staying open.
with fits.open(filename) as HDUlist:
    # DESI
    sp_tab = Table(HDUlist['SPTAB'].data)
    # APOGEE
    apogee_tab = Table(HDUlist['APOGEEDR17'].data)
    # DESI SP Spectra
    spectra = Table(HDUlist['SPECTRA_SP'].data)

### Data pre-processing

In [18]:
# Define parameters
# The order of `targets` fixes the column order of every parameter matrix built from it.
targets = ['FE_H', 'MG_FE', 'C_FE', 'O_FE', 'CI_FE', 'AL_FE', 'SI_FE', 'K_FE', 'CA_FE', 'MN_FE', 'NI_FE', 'LOGG', 'TEFF']
(feh_target, mgfe_target, cfe_target, ofe_target, cife_target, alfe_target, sife_target,
 kfe_target, cafe_target, mnfe_target, nife_target, log_g, temperature) = (np.array(apogee_tab[col]) for col in targets)

# Columns screened for out-of-range values (> 10).
# NOTE(review): MG_FE and SI_FE are not part of this screen — confirm that is intentional.
targets_arr = [feh_target, kfe_target, cfe_target, cafe_target, nife_target, mnfe_target, ofe_target, cife_target, alfe_target]
# check for Al error = 0 case
alfe_target_err = np.array(apogee_tab['AL_FE_ERR'])

# The same boolean mask is applied to all three tables below, so they must be row-aligned.
assert len(sp_tab) == len(apogee_tab) == len(spectra), "catalogue tables are not row-aligned"

# Vectorised screen: a row is abnormal if any screened column exceeds 10 or the Al error is exactly 0.
# (NaN compares False in both tests, so NaN rows are kept, as with an element-wise check.)
is_abnormal = np.logical_or.reduce([column > 10 for column in targets_arr]) | (alfe_target_err == 0)
abnormal_rows = np.flatnonzero(is_abnormal)  # sorted, unique row indices


# Mask the abnormal rows across relevant datasets
mask = ~is_abnormal
apogee_tab_masked = apogee_tab[mask]
spectra_masked = spectra[mask]
sp_tab_masked = sp_tab[mask]

# Reconstruct target arrays with masked data
target_values_masked = {target: np.array(apogee_tab_masked[target]) for target in targets}

In [19]:
# Number of spectra left after the abnormal rows were masked out (shown as the cell output)
len(spectra_masked)

7336

### Spectra

In [20]:
# Stitch the B, R and Z arms of every spectrum into one wavelength-ordered,
# robustly normalised flux vector. Each table cell holds a whole array (object dtype).
gb_combined_spectra = Table(names=['combined_flux', 'combined_wavelength'], dtype=['object', 'object'])

for row in spectra_masked:
    # Concatenate the three arms, then order both arrays by wavelength
    arm_fluxes = [row['flx_B'], row['flx_R'], row['flx_Z']]
    arm_wavelengths = [row['B_WAVELENGTH'], row['R_WAVELENGTH'], row['Z_WAVELENGTH']]
    combined_wavelength = np.concatenate(arm_wavelengths)
    sort_order = np.argsort(combined_wavelength)
    combined_wavelength = combined_wavelength[sort_order]
    combined_flux = np.concatenate(arm_fluxes)[sort_order]

    # Centre on the median and scale by the inter-quartile range
    global_median = np.median(combined_flux)
    lower_quartile, upper_quartile = np.percentile(combined_flux, [25, 75])
    IQR = upper_quartile - lower_quartile
    normalized_flux = (combined_flux - global_median) / IQR

    gb_combined_spectra.add_row([normalized_flux, combined_wavelength])

### Input data

In [22]:
# Parameters: one column per entry of `targets`, in that order
theta = np.column_stack([target_values_masked[name] for name in targets])

# Per-star normalised flux arrays pulled out of the object column
flux = np.array(gb_combined_spectra['combined_flux'])
# Input spectra: one float row per star
X = np.array([np.array(spectrum, dtype=float) for spectrum in flux])

In [23]:
# Report how many spectra there are and how many flux values each one has
n_spectra, n_pixels = X.shape
print("Length of X:", n_spectra)
print("Dimensions of X:", n_pixels)

Length of X: 7336
Dimensions of X: 13787


### Load Saved Posteriors and Run SBC per Fold

In [None]:
num_folds = 5
# Posterior draws per held-out spectrum: used for the SBC ranks ...
num_posterior_samples = 1000
# ... and for the TARP coverage test (must equal n_samples in the TARP cell).
num_tarp_samples = 250

kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Initialize lists to store results
test_posterior_samples, sbc_ranks, sbc_dap_samples, all_x_test, all_y_test = [], [], [], [], []

# Iterate through the folds
for fold, (train_index, test_index) in enumerate(kf.split(X, theta)):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = theta[train_index], theta[test_index]

    # Standardize the data (scaler is fitted on the training part only)
    scaler = StandardScaler().fit(X_train)
    X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

    # Convert to PyTorch tensors
    X_train, X_test, y_train, y_test = map(torch.Tensor, (X_train, X_test, y_train, y_test))

    # Store test sets for later analysis
    all_x_test.append(X_test)
    all_y_test.append(y_test)

    # load the posterior model saved
    # NOTE(review): presumably each pickle was trained on this same split and scaling — confirm.
    # SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
    model_pkl_file = f"SBI_fold_{fold}.pkl"

    with open(model_pkl_file, 'rb') as file:
        posterior = pickle.load(file)

    # Simulation-Based Calibration (SBC)
    ranks, dap_samples = run_sbc(y_test, X_test, posterior, num_posterior_samples=num_posterior_samples, reduce_fns='marginals')
    sbc_ranks.append(ranks)
    sbc_dap_samples.append(dap_samples)

    # Posterior draws for the TARP test. This list was initialised above but never
    # filled, so the later torch.cat(test_posterior_samples) had nothing to concatenate.
    # Stored per fold as (n_test_in_fold, num_tarp_samples, n_dims), in the same
    # observation order as all_y_test.
    fold_posterior_samples = torch.stack([
        posterior.sample((num_tarp_samples,), x=x_obs, show_progress_bars=False)
        for x_obs in X_test
    ])
    test_posterior_samples.append(fold_posterior_samples)
    

### Error Checks - Simulation Based Calibration

In [None]:
# Concatenate the per-fold results so the diagnostics cover every held-out spectrum
sbc_x_test = torch.cat(all_x_test, dim = 0)
sbc_x_test = sbc_x_test.numpy()

sbc_y_test = torch.cat(all_y_test, dim = 0)
sbc_y_test = sbc_y_test.numpy()

sbc_ranks_test = torch.cat(sbc_ranks, dim = 0)
sbc_ranks_test = sbc_ranks_test.numpy()

sbc_dap_samples_test = torch.cat(sbc_dap_samples, dim = 0)
sbc_dap_samples_test = sbc_dap_samples_test.numpy()

sbc_ranks_test_tensor = torch.tensor(sbc_ranks_test)
sbc_dap_samples_tensor = torch.tensor(sbc_dap_samples_test)
sbc_y_test_tensor = torch.tensor(sbc_y_test)

# KS test (disabled; check_sbc is not imported in this notebook)
#check_stats = check_sbc(sbc_ranks_test_tensor, sbc_y_test, sbc_dap_samples_tensor, num_posterior_samples=num_posterior_samples)
#print(f"kolmogorov-smirnov p-values \ncheck_stats['ks_pvals'] = {check_stats['ks_pvals'].numpy()}")

# Panel labels must follow the column order of theta, i.e. the order of `targets`:
# FE_H, MG_FE, C_FE, O_FE, CI_FE, AL_FE, SI_FE, K_FE, CA_FE, MN_FE, NI_FE, LOGG, TEFF
sbc_parameter_labels = ['Fe/H', 'Mg/Fe', 'C/Fe', 'O/Fe', 'CI/Fe', 'Al/Fe', 'Si/Fe', 'K/Fe', 'Ca/Fe', 'Mn/Fe', 'Ni/Fe',
                        'Log_g', 'Teff']
assert len(sbc_parameter_labels) == sbc_ranks_test_tensor.shape[1], "one label is needed per parameter column"

# SBC Rank Plot
# Use the ranks pooled over all folds; the loop variable `ranks` only holds the last fold.
f, ax = sbc_rank_plot(
    ranks=sbc_ranks_test_tensor,
    num_posterior_samples=num_posterior_samples,
    parameter_labels=sbc_parameter_labels,
    plot_type="hist",
    num_bins=None)

# SBC CDF Plot
f, ax = sbc_rank_plot(
    ranks=sbc_ranks_test_tensor,
    num_posterior_samples=num_posterior_samples,
    parameter_labels=sbc_parameter_labels,
    plot_type="cdf")

### TARP (Tests of Accuracy with Random Points)

In [None]:
n_dims = 13
n_samples = 250

# torch.cat on an empty list fails with an unhelpful message; say what is actually missing.
if len(test_posterior_samples) == 0:
    raise RuntimeError(
        "test_posterior_samples is empty: draw n_samples posterior samples for every "
        "test spectrum (per fold) before running the TARP test."
    )

post_samples = torch.cat(test_posterior_samples, dim = 0)
post_samples = post_samples.numpy()
# Assumes the draws are grouped by observation, i.e. per fold either
# (n_obs, n_samples, n_dims) or flattened (n_obs * n_samples, n_dims) — TODO confirm.
# tarp wants (n_samples, n_sims, n_dims). A bare reshape to that shape would only
# reinterpret the memory and mix up observations, so first recover the
# observation-major cube and then move the sample axis to the front.
posterior_samples = post_samples.reshape((len(sbc_x_test), n_samples, n_dims))
posterior_samples = np.ascontiguousarray(posterior_samples.transpose(1, 0, 2))

# NOTE(review): confirm the return order of get_drp_coverage for the installed tarp
# version (ecp/alpha vs alpha/ecp); the axis labels below rely on it.
coverage_values, ecp = tarp.get_drp_coverage(posterior_samples, sbc_y_test, references='random', metric='euclidean')

plt.plot(coverage_values, ecp, marker='o')
plt.xlabel("Credibility Level")
plt.ylabel("Expected Coverage Probability")
plt.title("Credibility Level vs. Expected Coverage Probability")
plt.grid(True)
plt.show()

significance_level = 0.05

# Step 2: Compare each ECP value with the significance level
is_significant = ecp <= significance_level

# Step 3: Calculate the p-value
# NOTE(review): this is the fraction of ECP grid points <= significance_level,
# not a p-value from a statistical test — confirm this is the intended summary.
p_value = is_significant.mean()

print("p-value:", p_value)