### Import Packages

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from sbi.inference import SNPE
from sklearn.model_selection import KFold
from sbi import utils as utils
from astropy.io import fits
from astropy.table import Table, Column
from sklearn.preprocessing import StandardScaler
import pickle

  from .autonotebook import tqdm as notebook_tqdm


### Data Loading

In [2]:
# Cross-matched DESI SP x APOGEE DR17 catalogue with DESI SP spectra.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
filename = "/raid/users/heigerm/catalogues/sp_x_apogee_x_spspectra_rvtab.fits"

# Read each extension into an in-memory Table. `Table` copies the HDU data by
# default (copy=True), so the FITS file can be closed as soon as the tables
# are built instead of keeping the handle open for the life of the kernel.
with fits.open(filename) as HDUlist:
    # DESI stellar-parameter (SP) table
    sp_tab = Table(HDUlist['SPTAB'].data)
    # APOGEE DR17 table
    apogee_tab = Table(HDUlist['APOGEEDR17'].data)
    # DESI SP spectra (B/R/Z arm flux and wavelength columns)
    spectra = Table(HDUlist['SPECTRA_SP'].data)

### Data pre-processing

In [18]:
# Define parameters
targets = ['FE_H', 'MG_FE', 'C_FE', 'O_FE', 'CI_FE', 'AL_FE', 'SI_FE', 'K_FE', 'CA_FE', 'MN_FE', 'NI_FE', 'LOGG', 'TEFF']
feh_target, mgfe_target, cfe_target, ofe_target, cife_target, alfe_target, sife_target, kfe_target, cafe_target, mnfe_target, nife_target, log_g, temperature = (np.array(apogee_tab[col]) for col in targets)
targets_arr = [feh_target, kfe_target, cfe_target, cafe_target, nife_target, mnfe_target, ofe_target, cife_target, alfe_target]
# check for Al error = 0 case
alfe_target_err = np.array(apogee_tab['AL_FE_ERR'])
abnormal_rows = np.unique([index for target in targets_arr for index, value in enumerate(target) if value > 10] + 
                          [index for index, value in enumerate(alfe_target_err) if value == 0])


# Mask the abnormal rows across relevant datasets
mask = ~np.isin(np.arange(len(apogee_tab)), abnormal_rows)
apogee_tab_masked = apogee_tab[mask]
spectra_masked = spectra[mask]
sp_tab_masked = sp_tab[mask]

# Reconstruct target arrays with masked data
target_values_masked = {target: np.array(apogee_tab_masked[target]) for target in targets}

In [19]:
# All three tables were filtered with the same mask and must stay row-aligned;
# the displayed value is the number of stars kept after the quality cuts.
assert len(spectra_masked) == len(sp_tab_masked) == len(apogee_tab_masked)
len(spectra_masked)

7336

### Combine and Normalize Spectra (B, R, Z arms)

In [20]:
def _combine_and_normalize(row):
    """Merge the B, R and Z arms of one spectrum and robustly normalise its flux.

    The three arms are concatenated and sorted by wavelength. The flux is then
    centred on its median and divided by its interquartile range.

    Returns (normalized_flux, combined_wavelength) as 1-D arrays.
    """
    combined_flux = np.concatenate([row['flx_B'], row['flx_R'], row['flx_Z']])
    combined_wavelength = np.concatenate([row['B_WAVELENGTH'], row['R_WAVELENGTH'], row['Z_WAVELENGTH']])
    sort_order = np.argsort(combined_wavelength)
    combined_flux = combined_flux[sort_order]
    combined_wavelength = combined_wavelength[sort_order]

    global_median = np.median(combined_flux)
    q25, q75 = np.percentile(combined_flux, [25, 75])
    IQR = q75 - q25
    # NOTE(review): a flat spectrum (IQR == 0) yields inf/NaN here. Left as-is so
    # the inputs match what the saved posteriors were trained on -- confirm.
    normalized_flux = (combined_flux - global_median) / IQR
    return normalized_flux, combined_wavelength


# Fill preallocated object arrays and build the Table once. Calling
# Table.add_row per star reallocates every column each time (quadratic).
n_spectra = len(spectra_masked)
flux_rows = np.empty(n_spectra, dtype=object)
wavelength_rows = np.empty(n_spectra, dtype=object)
for i, row in enumerate(spectra_masked):
    flux_rows[i], wavelength_rows[i] = _combine_and_normalize(row)

gb_combined_spectra = Table([flux_rows, wavelength_rows], names=['combined_flux', 'combined_wavelength'])

### Input data

In [22]:
# Normalised flux vectors, one per kept star.
flux = np.array(gb_combined_spectra['combined_flux'])
# Input spectra: a (n_stars, n_pixels) float matrix. np.vstack fails loudly if
# the spectra differ in length instead of producing a ragged object array.
X = np.vstack([np.asarray(flux_val, dtype=float) for flux_val in flux])
# Parameters: one column per entry of `targets`, in the same order.
theta = np.column_stack([target_values_masked[target] for target in targets])
# Spectra and labels were filtered with the same mask, so rows must line up.
assert X.shape[0] == theta.shape[0], "spectra and labels are not row-aligned"

In [23]:
# Report the size of the design matrix: number of spectra, then pixels per spectrum.
for label, value in (("Length of X:", len(X)), ("Dimensions of X:", len(X[0]))):
    print(label, value)

Length of X: 7336
Dimensions of X: 13787


### Load Trained Per-Fold Posteriors and Evaluate on Held-Out Stars

In [None]:
# Evaluate each saved per-fold posterior on the stars held out from that fold.
# NOTE(review): this reproduces the training split only if X, shuffle and
# random_state match the training run exactly -- confirm against it.
num_folds = 5
n_samples = 250  # posterior draws per held-out star

# Posterior sampling is stochastic; seed torch so every statistic below is reproducible.
torch.manual_seed(42)

kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Per target: expected label, posterior mean, residual (mean - expected), posterior variance.
results = {target: {'var': [], 'pred': [], 'res': [], 'exp': []} for target in targets}

# KFold ignores y, so only X is needed to regenerate the split.
for fold, (train_index, test_index) in enumerate(kf.split(X)):
    # Fit the scaler on this fold's training spectra only and apply it to the held-out ones.
    scaler = StandardScaler().fit(X[train_index])
    X_test = torch.Tensor(scaler.transform(X[test_index]))
    y_test = torch.Tensor(theta[test_index])

    # SECURITY: pickle.load can execute arbitrary code -- only load model files you created.
    model_pkl_file = f"SBI_fold_{fold}.pkl"
    with open(model_pkl_file, 'rb') as file:
        posterior = pickle.load(file)

    for x_obs, y_obs in zip(X_test, y_test):
        # samples has one column per entry of `targets` (same order as theta).
        samples = posterior.sample((n_samples,), x=x_obs, show_progress_bars=False)
        pred = samples.mean(dim=0)
        var = samples.var(dim=0)  # unbiased, as torch.var on each column
        res = pred - y_obs

        for i, target in enumerate(targets):
            results[target]['exp'].append(y_obs[i])
            results[target]['pred'].append(pred[i])
            results[target]['res'].append(res[i])
            results[target]['var'].append(var[i])


In [None]:
# After collecting all the results in the results dictionary
# Turn each per-star list of 0-d tensors into one 1-D tensor per statistic.
# Only entries that are still lists are converted, so re-running this cell is
# safe (torch.stack on an already-stacked tensor would raise).
for target_results in results.values():
    for key, values in list(target_results.items()):
        if isinstance(values, list):
            target_results[key] = torch.stack(values)

### Summary Table of Residuals

In [None]:
# Create a DataFrame to store the summary statistics
summary_df = pd.DataFrame(columns=["Target", "Mean", "Median", "IQR", "Variance"])

# Calculate the statistics for each target and add them to the DataFrame
for target in results:
    residuals = results[target]['res'].numpy()
    mean_res = np.mean(residuals)
    median_res = np.median(residuals)
    iqr_res = np.percentile(residuals, 75) - np.percentile(residuals, 25)
    var_res = np.var(residuals)
    
    # Append a new row to the DataFrame
    summary_df = summary_df.append({
        "Target": target,
        "Mean": mean_res,
        "Median": median_res,
        "IQR": iqr_res,
        "Variance": var_res
    }, ignore_index=True)

# Set the 'Target' column as the index
summary_df.set_index('Target', inplace=True)

# Print the summary DataFrame
print(summary_df)

### Metallicity Prediction

In [None]:
# Posterior-mean [Fe/H] vs. the APOGEE ("expected") label for every held-out star.
# The dashed 1:1 line marks perfect agreement.
feh_limits = (-2, 0.5)

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(results['FE_H']['exp'].numpy(), results['FE_H']['pred'].numpy(), alpha=0.5, color='b', edgecolors='k', marker='o')
ax.plot(feh_limits, feh_limits, 'r--', lw=1.5, label='1:1')
ax.set_xlim(feh_limits)
ax.set_ylim(feh_limits)
ax.set_xlabel("[Fe/H] Expected Values", fontsize=14)
ax.set_ylabel("[Fe/H] Predicted Values", fontsize=14)
ax.set_title("Predicted vs. Expected [Fe/H]", fontsize=16)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)

plt.show()

In [None]:
# Define the bins and colormap
xbin = np.linspace(-2, 0.5, 100)
ybin = np.linspace(-0.75, 0.75, 50)
cmap = plt.cm.viridis

# Extract the expected and residual values for FE_H from the results dictionary
feh_exp = results['FE_H']['exp'].numpy()
feh_res = results['FE_H']['res'].numpy()

# Calculate the statistics for the residuals across bins
def calculate_stats(bin_edges, values):
    centers = []
    lower, median, upper = [], [], []
    for n in range(len(bin_edges) - 1):
        in_bin = (values >= bin_edges[n]) & (values < bin_edges[n + 1])
        if in_bin.any():
            center = (bin_edges[n + 1] + bin_edges[n]) / 2
            centers.append(center)
            percentiles = np.percentile(values[in_bin], [16, 50, 84])
            lower.append(percentiles[0])
            median.append(percentiles[1])
            upper.append(percentiles[2])
    return centers, lower, median, upper

# Calculate stats for the expected FE_H values
centers, lower, median, upper = calculate_stats(xbin, feh_exp)
l = np.interp(xbin, centers, lower)
m = np.interp(xbin, centers, median)
u = np.interp(xbin, centers, upper)

# Get the DESI and APOGEE [Fe/H] values from the masked tables
desi_feh = np.array(sp_tab_masked['Fe_H_sp'])
apogee_feh = np.array(apogee_tab_masked['FE_H'])
r1 = desi_feh - apogee_feh

# Calculate stats for the APOGEE FE_H values
centers1, lower1, median1, upper1 = calculate_stats(xbin, apogee_feh)
l1 = np.interp(xbin, centers1, lower1)
m1 = np.interp(xbin, centers1, median1)
u1 = np.interp(xbin, centers1, upper1)

# Plot the 2D histograms
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12))

# Plot for DESI - APOGEE [Fe/H]
ax1.hist2d(apogee_feh, r1, bins=(xbin, ybin), cmap=cmap)
ax1.plot(xbin, l1, 'k--', label='-1$\sigma$', lw=2)
ax1.plot(xbin, m1, 'k', label='Median', lw=2)
ax1.plot(xbin, u1, 'k:', label='+1$\sigma$', lw=2)
ax1.set_ylabel('$\Delta$ Before', fontsize=18)
ax1.legend(fontsize=16)
ax1.tick_params(axis='x', labelbottom=False)

# Plot for Predicted - Expected [Fe/H]
ax2.hist2d(feh_exp, feh_res, bins=(xbin, ybin), cmap=cmap)
ax2.plot(xbin, l, 'k--', label='-1$\sigma$', lw=2)
ax2.plot(xbin, m, 'k', label='Median', lw=2)
ax2.plot(xbin, u, 'k:', label='+1$\sigma$', lw=2)
ax2.set_xlabel('APOGEE [Fe/H]', fontsize=18)
ax2.set_ylabel('$\Delta$ After', fontsize=18)
ax2.legend(fontsize=18)

plt.subplots_adjust(hspace=0.05)

# Add a colorbar
cbar = fig.colorbar(ax1.images[0], ax=[ax1, ax2], orientation='vertical')
cbar.ax.tick_params(labelsize=16) 
cbar.set_label('# of stars', fontsize=18)

plt.show()