### Import Packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from sbi.inference import SNPE
from sbi import utils as utils
from sbi.analysis import run_sbc, sbc_rank_plot
from astropy.io import fits
from astropy.table import Table, Column
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
import pickle
import tarp
from scipy.stats import gaussian_kde

### Data Loading

In [None]:
# Globular Cluster Data
# NOTE(review): hard-coded absolute path on a specific machine; adjust when
# running elsewhere.
gc_filename = "/raid/users/heigerm/catalogues/sptab_spspectra_rvtab_gcs.fits"
# NOTE(review): this HDUList is never closed in this cell; the tables below
# are built from its HDUs.
gc_HDUlist = fits.open(gc_filename)

# Globular Cluster parameters from sp pipeline (FITS extension 'SPTAB')
gc_sp_tab = Table(gc_HDUlist['SPTAB'].data)  

# Globular Cluster Observed Spectra (FITS extension 'SPECTRA_SP'); its
# 'gcname' column is used below to split the rows by cluster
gc_spectra = Table(gc_HDUlist['SPECTRA_SP'].data)

# Show the distinct globular cluster names present in gc_spectra (as the last
# expression of a notebook cell the set is displayed, not printed)
set(gc_spectra['gcname'])

### GC Spectra

In [None]:
# Create the spectra table for each globular cluster.
# gc_names lists the clusters analysed in the rest of the notebook; gc_tables
# maps each name to the rows of gc_spectra whose 'gcname' equals that name.
# A name that does not occur in the data maps to an empty table.
gc_names = [
    'NGC_2419', 'NGC_5024_M_53', 'NGC_5053', 'NGC_5272_M_3', 'NGC_5466',
    'NGC_5634', 'NGC_5904_M_5', 'NGC_6205_M_13', 'NGC_6218_M_12', 'NGC_6229',
    'NGC_6341_M_92', 'NGC_7078_M_15', 'NGC_7089_M_2', 'Pal_14', 'Pal_5'
]

gc_tables = {name: gc_spectra[gc_spectra['gcname'] == name] for name in gc_names}

In [None]:
# Combine the spectra from the three arms for each globular cluster
def process_gc_spectra(gc_table):
    """Merge the B, R and Z arm spectra of every row into one normalised spectrum.

    For each row, the three arms' flux and wavelength arrays are concatenated
    (B, then R, then Z) and re-ordered by increasing wavelength. The flux is
    then robustly standardised: the median is subtracted and the result is
    divided by the interquartile range (75th minus 25th percentile).

    Parameters
    ----------
    gc_table : astropy.table.Table
        Rows providing the columns 'flx_B', 'flx_R', 'flx_Z',
        'B_WAVELENGTH', 'R_WAVELENGTH' and 'Z_WAVELENGTH'.

    Returns
    -------
    astropy.table.Table
        One row per input row, with object columns 'combined_flux' and
        'combined_wavelength' holding the merged arrays.
    """
    # (flux column, wavelength column) for each spectrograph arm, in merge order.
    arm_columns = (
        ('flx_B', 'B_WAVELENGTH'),
        ('flx_R', 'R_WAVELENGTH'),
        ('flx_Z', 'Z_WAVELENGTH'),
    )

    result = Table()
    for column_name in ('combined_flux', 'combined_wavelength'):
        result.add_column(Column(name=column_name, dtype='object'))

    for entry in gc_table:
        flux = np.concatenate([entry[flux_key] for flux_key, _ in arm_columns])
        wave = np.concatenate([entry[wave_key] for _, wave_key in arm_columns])

        by_wavelength = np.argsort(wave)
        flux, wave = flux[by_wavelength], wave[by_wavelength]

        # Centre on the median (in place on the re-ordered copy), then scale by
        # the interquartile range.
        flux -= np.median(flux)
        iqr = np.percentile(flux, 75) - np.percentile(flux, 25)
        flux = flux / iqr

        result.add_row([flux, wave])

    return result


gc_processed_spectra = {gc: process_gc_spectra(rows) for gc, rows in gc_tables.items()}

In [None]:
# The gc data for testing: for each GC, a float64 array with one row per
# processed spectrum (assumes all spectra of a GC have the same length, as
# the stacking into one array requires).
# np.asarray converts each spectrum in one vectorised call instead of calling
# float() on every pixel in Python; the resulting values and dtype are the same.
X_gc = {
    name: np.array([np.asarray(flux, dtype=np.float64) for flux in spectra['combined_flux']])
    for name, spectra in gc_processed_spectra.items()
}

In [None]:
# Shape check for one cluster: the first axis counts its spectra
X_gc['NGC_2419'].shape # only 1 object in this gc

### Load Per-Fold Models and Predict [Fe/H]

In [None]:
num_folds = 5
# Number of posterior samples drawn for each observed spectrum.
num_posterior_samples = 250
# Prefix of the per-fold posterior pickles: "SBI_Unif" for the posteriors
# trained on the uniformly resampled data, "SBI" for the original training data.
model_prefix = "SBI_Unif"


class Model(nn.Module):
    """Fully connected GELU network mapping a 13787-dim input to 50 outputs.

    It is never instantiated in this cell. NOTE(review): presumably it must
    exist under the module-level name ``Model`` so that ``pickle.load`` can
    resolve a class referenced by the pickled posteriors; confirm before
    removing. It is defined once here instead of on every fold iteration.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(13787, 4000),
            nn.GELU(),
            nn.Linear(4000, 2000),
            nn.GELU(),
            nn.Linear(2000, 1000),
            nn.GELU(),
            nn.Linear(1000, 500),
            nn.GELU(),
            nn.Linear(500, 100),
            nn.GELU(),
            nn.Linear(100, 50))

    def forward(self, x):
        return self.model(x)


# Per-GC accumulators with one entry per (fold, observation) pair. They are
# initialised once, outside the fold loop, so every fold contributes.
# Previously they were reset inside the loop and only the last fold survived.
feh_pred_gc = {gc_name: [] for gc_name in X_gc}
feh_var_gc = {gc_name: [] for gc_name in X_gc}

for fold in range(num_folds):
    model_pkl_file = f"{model_prefix}_fold_{fold}.pkl"
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(model_pkl_file, 'rb') as fh:
        posterior = pickle.load(fh)

    for gc_name, gc_data in X_gc.items():
        for observation in gc_data:
            # float32 tensor with a leading batch dimension of 1
            x_obs = torch.as_tensor(observation, dtype=torch.float32).unsqueeze(0)

            samples = posterior.sample((num_posterior_samples,), x=x_obs)
            fe_h = samples[:, 0]  # column 0 holds the [Fe/H] samples

            # Posterior mean and (unbiased) variance of [Fe/H] for this spectrum
            feh_pred_gc[gc_name].append(torch.mean(fe_h).item())
            feh_var_gc[gc_name].append(torch.var(fe_h).item())

# Average over all folds and observations of each GC.
# NOTE: feh_var_gc_mean is the mean per-spectrum posterior variance; it does
# not include the scatter of the posterior means across spectra/folds.
feh_pred_gc_mean = {gc_name: np.mean(values) for gc_name, values in feh_pred_gc.items()}
feh_var_gc_mean = {gc_name: np.mean(values) for gc_name, values in feh_var_gc.items()}

In [None]:
# Display the mean predicted [Fe/H] for each GC
feh_pred_gc_mean

In [None]:
# Display the mean posterior [Fe/H] variance for each GC
feh_var_gc_mean

In [None]:
import pandas as pd

# Reference [Fe/H] value for each GC. This cell runs before the "Globular
# Cluster Metallicity Predictions" cell, which also defines `true_values`.
# Without this definition a top-to-bottom run raises NameError here. Keep the
# values in sync with that cell.
true_values = {
    'NGC_2419': -2.1, 'NGC_5024_M_53': -1.86, 'NGC_5053': -2.3,
    'NGC_5272_M_3': -1.5, 'NGC_5466': -2.2, 'NGC_5634': -1.98,
    'NGC_5904_M_5': -1.29, 'NGC_6205_M_13': -1.33, 'NGC_6218_M_12': -1.14,
    'NGC_6229': -1.13, 'NGC_6341_M_92': -2.31, 'NGC_7078_M_15': -2.37,
    'NGC_7089_M_2': -1.65, 'Pal_14': -1.62, 'Pal_5': -1.41
}

# Wrap each per-GC dictionary in a named Series so the columns are labelled
pred_mean_series = pd.Series(feh_pred_gc_mean, name='Pred Mean')
pred_var_series = pd.Series(feh_var_gc_mean, name='Pred Variance')
true_values_series = pd.Series(true_values, name='True [Fe/H]')

# One row per GC; pd.concat aligns the three Series on the GC-name index
summary_df = pd.concat([pred_mean_series, pred_var_series, true_values_series], axis=1)

# Display the summary table
print(summary_df)

### Globular Cluster Metallicity Predictions

In [None]:
# True [Fe/H] values for each GC
true_values = {
    'NGC_2419': -2.1, 'NGC_5024_M_53': -1.86, 'NGC_5053': -2.3,
    'NGC_5272_M_3': -1.5, 'NGC_5466': -2.2, 'NGC_5634': -1.98,
    'NGC_5904_M_5': -1.29, 'NGC_6205_M_13': -1.33, 'NGC_6218_M_12': -1.14,
    'NGC_6229': -1.13, 'NGC_6341_M_92': -2.31, 'NGC_7078_M_15': -2.37,
    'NGC_7089_M_2': -1.65, 'Pal_14': -1.62, 'Pal_5': -1.41
}

# One KDE plot of the per-spectrum [Fe/H] predictions for each GC, with the
# true value marked when known.
for gc_name, predictions in feh_pred_gc.items():
    # gaussian_kde cannot handle NaN/inf, so keep only finite predictions
    flat_predictions = np.asarray(predictions, dtype=float)
    flat_predictions = flat_predictions[np.isfinite(flat_predictions)]

    # gaussian_kde needs at least two points with non-zero spread; identical
    # values give a singular covariance matrix and raise an error.
    if flat_predictions.size < 2 or np.ptp(flat_predictions) == 0:
        print(f"Skipping {gc_name} due to insufficient data for KDE.")
        continue

    kde = gaussian_kde(flat_predictions)

    # Evaluate on a grid padded by 3 kernel bandwidths, widened to include the
    # true value if known, so the density tails are not clipped at the sample
    # min/max.
    bandwidth = np.sqrt(kde.covariance[0, 0])
    x_min = flat_predictions.min() - 3 * bandwidth
    x_max = flat_predictions.max() + 3 * bandwidth
    true_feh = true_values.get(gc_name)
    if true_feh is not None:
        x_min = min(x_min, true_feh)
        x_max = max(x_max, true_feh)
    x_range = np.linspace(x_min, x_max, 500)
    kde_values = kde(x_range)

    # Create a new figure for the current GC
    plt.figure(figsize=(8, 6))
    plt.plot(x_range, kde_values, label=f'KDE {gc_name}')

    # Mark the true [Fe/H] value for the current GC
    if true_feh is not None:
        plt.axvline(x=true_feh, color='r', linestyle='--', label=f'True {gc_name}')

    plt.title(f'KDE of [Fe/H] Predictions for {gc_name}')
    plt.xlabel('[Fe/H] Predictions')
    plt.ylabel('Density')
    plt.legend()
    plt.tight_layout()
    plt.show()