This notebook predicts [Fe/H] from DESI spectra using a masked autoregressive flow (MAF).

The model is trained using labels from APOGEE

The structure of this notebook is as follows:

1. Import the packages for this notebook: we use torch, and the model comes from the sbi package
2. Load the data from two sources: spectra with signal-to-noise ratio < 50 and spectra with signal-to-noise ratio > 50
3. Cross-match spectra between DESI and APOGEE, normalize the spectra, and remove APOGEE rows with abnormal [Fe/H] values
4. Train the MAF models and save them
5. Test the model performance: 2D histograms (accuracy check) and SBC (whether the posterior is well calibrated)
6. Application to Globular Clusters


## 1. Import Packages

In [4]:
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
import torch.nn as nn
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.table import Table, Column
from astropy.table import vstack
from sbi.inference import SNPE
from sbi import utils as utils
from sbi.analysis import run_sbc
from sbi.analysis import sbc_rank_plot
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

  from .autonotebook import tqdm as notebook_tqdm


## 2. Data Loading

### 2.1 load spectra with lower SNR

In [None]:
# load file path
f1 = "/raid/users/heigerm/catalogues/sptab_spspectra_rvtab_lowsnr.fits"

# Open the file in a context manager so the file handle is released as soon as
# the three tables have been read. Table(...) copies the HDU data by default,
# so the tables stay valid after the file is closed.
with fits.open(f1) as HDUlist1:
    # DESI labels
    sp_tab1 = Table(HDUlist1['SPTAB'].data)

    # APOGEE labels
    # NOTE(review): this extension is selected by position (index 4), whereas the
    # SNR > 50 file selects 'APOGEEDR17' by name -- confirm that extension 4 of
    # this file really is the APOGEE table.
    apogee_tab1 = Table(HDUlist1[4].data)

    # DESI SP Spectra
    spectra1 = Table(HDUlist1['SPECTRA_SP'].data)

### 2.2 load spectra with SNR > 50 

In [None]:
# load file path
f2 = "/raid/users/heigerm/catalogues/sp_x_apogee_x_spspectra_rvtab.fits"

# Open the file in a context manager so the file handle is released as soon as
# the three tables have been read. Table(...) copies the HDU data by default,
# so the tables stay valid after the file is closed.
with fits.open(f2) as HDUlist2:
    # DESI labels
    sp_tab2 = Table(HDUlist2['SPTAB'].data)

    # APOGEE labels
    apogee_tab2 = Table(HDUlist2['APOGEEDR17'].data)

    # DESI SP Spectra
    spectra2 = Table(HDUlist2['SPECTRA_SP'].data)

## 3. Data pre-processing

### 3.1 Cross-match DESI with APOGEE (ra, dec)

In [None]:
# Sky positions of the APOGEE stars and of the low-SNR DESI spectra
apogee_coords = SkyCoord(ra=apogee_tab1['RA']*u.degree, dec=apogee_tab1['DEC']*u.degree)
spectra_coords = SkyCoord(ra=spectra1['TARGET_RA']*u.degree, dec=spectra1['TARGET_DEC']*u.degree)

# For every DESI spectrum find the nearest APOGEE star and its angular separation
idx, d2d, _ = spectra_coords.match_to_catalog_sky(apogee_coords)

# Keep only the pairs that are closer than the matching tolerance
tolerance = 1 * u.arcsec
matches_within_tolerance = d2d < tolerance

# Row i of each *_matched table refers to the same star:
# the APOGEE table is re-ordered by the match index, the DESI tables are filtered.
apogee_tab1_matched = apogee_tab1[idx[matches_within_tolerance]]
spectra1_matched = spectra1[matches_within_tolerance]
sp_tab1_matched = sp_tab1[matches_within_tolerance]

# stack the apogee labels from the two tables
apogee_tab1_selected = apogee_tab1_matched['APOGEE_ID', 'RA', 'DEC', 'FE_H', 'FE_H_ERR']
apogee_tab2_selected = apogee_tab2['APOGEE_ID', 'RA', 'DEC', 'FE_H', 'FE_H_ERR']

# obtain the apogee table
apogee_tab_combined = vstack([apogee_tab1_selected, apogee_tab2_selected])

# Combine the spectra from the two tables, restricted to the columns both share.
# The shared columns are listed in the order of the first table so the result
# is the same on every run (iterating over a set of strings is not).
spectra_common_cols = [name for name in spectra1_matched.colnames if name in spectra2.colnames]
spectra1_common = spectra1_matched[spectra_common_cols]
spectra2_common = spectra2[spectra_common_cols]

# obtain the spectra table
spectra_combined = vstack([spectra1_common, spectra2_common])

# Stack DESI labels from the two tables, again restricted to the shared columns
sp_common_cols = [name for name in sp_tab1_matched.colnames if name in sp_tab2.colnames]
sp1_common = sp_tab1_matched[sp_common_cols]
sp2_common = sp_tab2[sp_common_cols]

# obtain the DESI table
sp_combined = vstack([sp1_common, sp2_common])

# The next step applies one row mask to all three tables, which is only
# meaningful when they have the same number of rows.
assert len(apogee_tab_combined) == len(spectra_combined) == len(sp_combined), \
    "APOGEE labels, DESI spectra and DESI labels are no longer row-aligned"

### 3.2 Remove abnormal rows for [Fe/H]

In [None]:
# A row is unusable when [Fe/H] or its error is NaN or exactly zero,
# or when [Fe/H] is larger than 10.
feh_col = apogee_tab_combined['FE_H']
feh_err_col = apogee_tab_combined['FE_H_ERR']

is_abnormal = (
    (np.isnan(feh_col))
    | (feh_col > 10)
    | (feh_col == 0)
    | (feh_err_col == 0)
    | (np.isnan(feh_err_col))
)

# Positions of the rows to drop
abnormal_rows = np.where(is_abnormal)[0]

# Boolean keep-mask: True everywhere except at the abnormal positions
mask = np.ones(len(apogee_tab_combined), dtype=bool)
mask[abnormal_rows] = False

# The same mask is applied to the labels and to the spectra so that the
# three tables stay row-aligned.
apogee_tab_masked = apogee_tab_combined[mask]
spectra_masked = spectra_combined[mask]
sp_masked = sp_combined[mask]

### 3.3 Combine the spectra from three arms and normalize (x-median/iqr)

In [None]:
# For every star: merge the B, R and Z arms into one wavelength-sorted spectrum
# and normalise the flux as (flux - median) / IQR.
# The per-star arrays are collected in plain lists and the table is built once
# at the end, instead of calling Table.add_row for every spectrum.
normalized_fluxes = []
sorted_wavelengths = []

for row in spectra_masked:
    # Combine and sort flux and wavelength from all arms
    combined_flux = np.concatenate([row['flx_B'], row['flx_R'], row['flx_Z']])
    combined_wavelength = np.concatenate([row['B_WAVELENGTH'], row['R_WAVELENGTH'], row['Z_WAVELENGTH']])
    sort_order = np.argsort(combined_wavelength)
    combined_flux, combined_wavelength = combined_flux[sort_order], combined_wavelength[sort_order]

    # Normalize flux
    # NOTE(review): a spectrum whose IQR is 0 yields inf/NaN here -- confirm
    # that such spectra cannot occur or filter them out before training.
    global_median = np.median(combined_flux)
    q25, q75 = np.percentile(combined_flux, [25, 75])
    IQR = q75 - q25
    normalized_flux = (combined_flux - global_median) / IQR

    normalized_fluxes.append(normalized_flux)
    sorted_wavelengths.append(combined_wavelength)


def _as_object_column(arrays):
    """Pack a list of arrays into a 1-D object array holding one array per element."""
    column = np.empty(len(arrays), dtype=object)
    # Element-wise assignment keeps the array 1-D even when all spectra have
    # the same length (a direct conversion would produce a 2-D array).
    for position, array in enumerate(arrays):
        column[position] = array
    return column


gb_combined_spectra = Table(
    [_as_object_column(normalized_fluxes), _as_object_column(sorted_wavelengths)],
    names=['combined_flux', 'combined_wavelength'],
)

flux = np.array(gb_combined_spectra['combined_flux'])

### 3.4 Set up Input data

In [None]:
# Model input: one row per star, each row being that star's normalised flux as floats
spectra_rows = [np.array(spectrum, dtype=float) for spectrum in flux]
X = np.array(spectra_rows)

# Target parameter: APOGEE [Fe/H] of the same stars
theta = np.array(apogee_tab_masked["FE_H"])

## 4. Model Training

### 4.1 Cross-Validation Set-up

In [None]:
# Shuffled 5-fold cross validation; the fixed random_state makes the split repeatable
num_folds = 5
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Per-star results collected during testing, one list per quantity
results = {key: [] for key in ('exp', 'pred', 'res', 'var')}

# Containers for simulation-based calibration and for the held-out data of each fold
test_posterior_samples = []
sbc_ranks = []
sbc_dap_samples = []
all_x_test = []
all_y_test = []

### 4.2 Train and Save Models (Masked Autoregressive Flow)

In [None]:
# Train one MAF posterior estimator per cross-validation fold and pickle it
for fold, (train_index, test_index) in enumerate(kf.split(X, theta)):

    # Spectra of this fold as PyTorch tensors
    X_train = torch.Tensor(X[train_index])
    X_test = torch.Tensor(X[test_index])

    # [Fe/H] labels of this fold as PyTorch tensors with an extra trailing axis
    y_train = torch.Tensor(theta[train_index]).unsqueeze(-1)
    y_test = torch.Tensor(theta[test_index]).unsqueeze(-1)

    # Remember the held-out data of this fold
    all_x_test.append(X_test)
    all_y_test.append(y_test)

    # Masked Autoregressive Flow: the parameter is [Fe/H], the observation is the spectrum
    inference = SNPE(density_estimator="maf")
    density_estimator = inference.append_simulations(y_train, X_train).train()

    # Turn the trained density estimator into a posterior object
    posterior = inference.build_posterior(density_estimator)

    # Persist the posterior of this fold
    model_pkl_file = f"MAF_fold{fold}.pkl"
    with open(model_pkl_file, 'wb') as file:
        pickle.dump(posterior, file)

## 5. Testing

### 5.1 Obtain the predictions and Simulation-based Calibration ranks

In [None]:
# Posterior draws per test spectrum used for the point estimate and its variance
n_samples = 250
# Posterior draws per test spectrum used by simulation-based calibration
num_posterior_samples = 1000

# Start from empty containers so that re-running this cell does not append a
# second copy of every prediction / rank.
results = {'exp': [], 'pred': [], 'res': [], 'var': []}
sbc_ranks, sbc_dap_samples = [], []

for fold, (train_index, test_index) in enumerate(kf.split(X, theta)):

    # Held-out data of THIS fold. It is rebuilt from test_index instead of
    # reusing whatever X_test / y_test were left over from the training loop
    # (which would evaluate every model on the last fold's test set).
    # kf uses a fixed random_state, so the split is the same as during training.
    X_test = torch.Tensor(X[test_index])
    y_test = torch.Tensor(theta[test_index]).unsqueeze(-1)

    # Load the posterior trained without this fold.
    # NOTE: pickle.load can execute arbitrary code -- only load files that were
    # written by the training cell of this notebook.
    model_pkl_file = f"MAF_fold{fold}.pkl"
    with open(model_pkl_file, 'rb') as file:
        posterior = pickle.load(file)

    # Sample from the posterior for every spectrum of the test set
    for idx in range(len(X_test)):
        samples = posterior.sample((n_samples,), x=X_test[idx])
        target_samples = samples[:, 0]
        target_exp = y_test[idx]                      # APOGEE value
        target_pred = torch.mean(target_samples)      # posterior mean
        target_res = target_pred - target_exp         # prediction minus APOGEE
        target_var = torch.var(target_samples)        # posterior variance

        results['exp'].append(target_exp)
        results['pred'].append(target_pred)
        results['res'].append(target_res)
        results['var'].append(target_var)

    # Simulation-Based Calibration (SBC) on the same held-out data
    ranks, dap_samples = run_sbc(y_test, X_test, posterior, num_posterior_samples=num_posterior_samples, reduce_fns='marginals')
    sbc_ranks.append(ranks)
    sbc_dap_samples.append(dap_samples)


# [Fe/H] truth from APOGEE
feh_exp = torch.stack(results['exp'])

# [Fe/H] predictions
feh_pred = torch.stack(results['pred'])

# [Fe/H] residuals
feh_res = torch.stack(results['res'])

### 5.2. 2D Histogram: DESI SP vs. APOGEE, MAF (our model predictions) vs. APOGEE

In [None]:
# White-to-dark-blue colormap used for the 2D histograms
cmap = mcolors.LinearSegmentedColormap.from_list('greenblue', ['white', 'dodgerblue', 'dodgerblue', 'royalblue', 'royalblue', 'mediumblue', 'mediumblue', 'midnightblue'])

# Convert the stacked tensors to flat 1-D numpy arrays.
# feh_exp and feh_res are stacked from one-element tensors and therefore come
# out with shape (N, 1); the 2D histogram needs 1-D inputs.
feh_exp = np.array(feh_exp).reshape(-1)
feh_pred = np.array(feh_pred).reshape(-1)
feh_res = np.array(feh_res).reshape(-1)

In [None]:
def _percentile_bands(x, residuals, bin_edges, percentiles=(16, 50, 84)):
    """Running percentiles of ``residuals`` in bins of ``x``.

    Bin ``n`` holds the points with ``bin_edges[n] <= x < bin_edges[n+1]``.
    The requested percentiles are computed in every non-empty bin and then
    linearly interpolated from the bin centres onto ``bin_edges``.

    Parameters
    ----------
    x, residuals : array-like with the same number of elements (flattened internally)
    bin_edges : 1-D array of increasing bin edges
    percentiles : percentiles to evaluate in each bin

    Returns
    -------
    tuple of 1-D arrays, one per percentile, each evaluated at ``bin_edges``.

    Raises
    ------
    ValueError if ``x`` and ``residuals`` differ in size or no point falls in any bin.
    """
    x = np.ravel(np.asarray(x, dtype=float))
    residuals = np.ravel(np.asarray(residuals, dtype=float))
    if x.size != residuals.size:
        raise ValueError("x and residuals must have the same number of elements")

    centers, bands = [], []
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = residuals[(x >= lo) & (x < hi)]
        if in_bin.size:
            centers.append((lo + hi) / 2)
            bands.append(np.percentile(in_bin, percentiles))

    if not centers:
        raise ValueError("no data point falls inside the bin range")

    centers = np.array(centers, dtype=float)
    bands = np.array(bands, dtype=float)      # shape (populated bins, n percentiles)
    return tuple(np.interp(bin_edges, centers, bands[:, k]) for k in range(bands.shape[1]))


xbin = np.linspace(-2, 0.5, 100)       # range of APOGEE Fe/H values
ybin = np.linspace(-0.75, 0.75, 50)    # range of residuals

# --- MAF (our model) vs. APOGEE -------------------------------------------
# feh_exp / feh_res come from the cross-validation loop, i.e. they are ordered
# fold by fold and are row-aligned with each other.
maf_x = np.ravel(feh_exp)
maf_res = np.ravel(feh_res)
maf_lower, maf_median, maf_upper = _percentile_bands(maf_x, maf_res, xbin)

# --- DESI SP vs. APOGEE -----------------------------------------------------
# These residuals are in catalogue order, so they must be binned and plotted
# against apogee_feh (same order), not against the fold-ordered feh_exp.
desi_feh = np.array(sp_masked['FEH'])
apogee_feh = np.array(apogee_tab_masked['FE_H'])
r1 = desi_feh - apogee_feh  # residuals between DESI and APOGEE
desi_lower, desi_median, desi_upper = _percentile_bands(apogee_feh, r1, xbin)

# Plot the 2D histograms!
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12))

# Top panel: DESI pipeline residuals ("before")
h1, _, _, img1 = ax1.hist2d(apogee_feh, r1, bins=(xbin, ybin), cmap=cmap)
ax1.plot(xbin, desi_lower, c='black', linestyle='--', label=r'-1$\sigma$', lw=2)
ax1.plot(xbin, desi_median, c='black', label='Median', lw=2)
ax1.plot(xbin, desi_upper, c='black', linestyle='dotted', label=r'+1$\sigma$', lw=2)
ax1.set_ylabel(r'$\Delta$ Before', fontsize=18)
ax1.tick_params(axis='y', labelsize=16)
ax1.tick_params(axis='x', labelbottom=False)
ax1.legend(fontsize=16)

# Bottom panel: MAF residuals ("after")
h2, _, _, img2 = ax2.hist2d(maf_x, maf_res, bins=(xbin, ybin), cmap=cmap)
ax2.plot(xbin, maf_lower, c='black', linestyle='--', label=r'-1$\sigma$', lw=2)
ax2.plot(xbin, maf_median, c='black', label='Median', lw=2)
ax2.plot(xbin, maf_upper, c='black', linestyle='dotted', label=r'+1$\sigma$', lw=2)
ax2.set_xlabel('APOGEE [Fe/H]', fontsize=18)
ax2.set_ylabel(r'$\Delta$ After', fontsize=18)
ax2.tick_params(axis='both', labelsize=16)
ax2.legend(fontsize=18)

# One colorbar for both panels: give both histograms the same colour scale and
# take the mappable straight from hist2d (drawing imshow(h) on the same axes
# would overlay a pixel-coordinate image on the data-coordinate histogram).
shared_max = max(h1.max(), h2.max())
img1.set_clim(0, shared_max)
img2.set_clim(0, shared_max)

cbar = fig.colorbar(img1, ax=[ax1, ax2])
cbar.ax.tick_params(labelsize=16)
cbar.set_label('# of stars', fontsize=18)

### 5.3. Simulation-based Calibration Checks

In [None]:
# Pool the held-out data and the SBC outputs of all folds (fold after fold)
sbc_y_test_tensor = torch.cat(all_y_test, dim=0)
sbc_ranks_test_tensor = torch.cat(sbc_ranks, dim=0)
sbc_dap_samples_tensor = torch.cat(sbc_dap_samples, dim=0)

# numpy versions of the pooled quantities
sbc_x_test = torch.cat(all_x_test, dim=0).numpy()
sbc_y_test = sbc_y_test_tensor.numpy()
sbc_ranks_test = sbc_ranks_test_tensor.numpy()
sbc_dap_samples_test = sbc_dap_samples_tensor.numpy()

# Both plots use the ranks pooled over ALL folds; the bare `ranks` variable
# only holds the ranks of the last fold processed by the testing loop.

# SBC Rank Plot
f, ax = sbc_rank_plot(
    ranks=sbc_ranks_test_tensor,
    num_posterior_samples=num_posterior_samples,
    parameter_labels=['Fe/H'],
    plot_type="hist",
    num_bins=None)

# SBC CDF Plot
f, ax = sbc_rank_plot(
    ranks=sbc_ranks_test_tensor,
    num_posterior_samples=num_posterior_samples,
    parameter_labels=['Fe/H'],
    plot_type="cdf")

## 6. Application to Globular Clusters 

### 6.1. Load Globular Cluster Data

In [None]:
# File path
# Globular Clusters with 4 truths of [Fe/H]
gc_filename1 = "/home/jupyter-tingli/DESI/gc_iron_240123.fits"

# Open the file in a context manager so the file handle is released once the
# table has been read. Table(...) copies the HDU data by default, so the table
# stays valid after the file is closed.
with fits.open(gc_filename1) as gc_HDUlist1:
    # Globular Cluster Spectra
    gc_spectra1 = Table(gc_HDUlist1[1].data)

# show the globular clusters present in gc_spectra1
set(gc_spectra1['gcname'])

### 6.2. Extract true [Fe/H] from the literature, each GC has 4 truths

In [None]:
unique_gcnames = set(gc_spectra1['gcname'])

# The four literature [Fe/H] columns stored alongside every GC spectrum
literature_columns = ('metallicity_H10', 'metallicity_K19', 'metallicity_B19', 'metallicity_V20')

# gcname -> {literature column -> value}
gc_metallicities = {}

for gcname in unique_gcnames:
    # Rows belonging to this cluster; the literature values are read from the
    # first of them (assumed to be identical on every row of a cluster).
    first_row = gc_spectra1[gc_spectra1['gcname'] == gcname][0]

    gc_metallicities[gcname] = {column: first_row[column] for column in literature_columns}

### 6.3. Create the spectra table for each globular cluster

In [None]:
# Globular clusters for which a per-cluster spectra table is built
gc_names = [
    'NGC_2419', 'NGC_5024_M_53', 'NGC_5053', 'NGC_5272_M_3', 'NGC_5466',
    'NGC_5634', 'NGC_5904_M_5', 'NGC_6205_M_13', 'NGC_6218_M_12', 'NGC_6229',
    'NGC_6341_M_92', 'NGC_7078_M_15', 'NGC_7089_M_2', 'Pal_14', 'Pal_5'
]

# GC name -> rows of the loaded GC table that belong to that cluster.
# The loaded table is called gc_spectra1; `gc_spectra` is not defined anywhere.
gc_tables = {name: gc_spectra1[gc_spectra1['gcname'] == name] for name in gc_names}

#### 6.3.1. Preprocess the GC spectra (normalization)

In [None]:
def process_gc_spectra(gc_table):
    """Merge the three arms of every spectrum in ``gc_table`` and normalise the flux.

    For each row the B, R and Z fluxes and wavelengths are concatenated and
    sorted by wavelength; the flux is then centred on its median and divided
    by its interquartile range.

    Parameters
    ----------
    gc_table : table whose rows provide ``flx_B``/``flx_R``/``flx_Z`` and
        ``B_WAVELENGTH``/``R_WAVELENGTH``/``Z_WAVELENGTH``.

    Returns
    -------
    Table with object columns ``combined_flux`` (normalised flux) and
    ``combined_wavelength``, one row per input row.
    """
    processed_spectra = Table(names=['combined_flux', 'combined_wavelength'],
                              dtype=['object', 'object'])

    for row in gc_table:
        # Stitch the arms together
        flux_all = np.concatenate([row['flx_B'], row['flx_R'], row['flx_Z']])
        wave_all = np.concatenate([row['B_WAVELENGTH'], row['R_WAVELENGTH'], row['Z_WAVELENGTH']])

        # Order both arrays by increasing wavelength
        order = np.argsort(wave_all)
        flux_sorted = flux_all[order]
        wave_sorted = wave_all[order]

        # Centre on the median, then scale by the interquartile range
        centred_flux = flux_sorted - np.median(flux_sorted)
        IQR = np.percentile(centred_flux, 75) - np.percentile(centred_flux, 25)
        scaled_flux = centred_flux / IQR

        processed_spectra.add_row([scaled_flux, wave_sorted])

    return processed_spectra

# GC name -> table of normalised spectra
gc_processed_spectra = {name: process_gc_spectra(table) for name, table in gc_tables.items()}

### 6.4. Obtain the [Fe/H] predictions for the GCs

In [None]:
num_folds = 5

# Posterior draws per spectrum and per model
n_gc_samples = 250

# Model input per globular cluster: the normalised spectra produced by
# process_gc_spectra, one float array per star.
X_gc = {
    gc_name: [np.asarray(flux_val, dtype=float) for flux_val in table['combined_flux']]
    for gc_name, table in gc_processed_spectra.items()
}

# store the predictions and variance (filled across ALL folds)
feh_pred_gc = {gc_name: [] for gc_name in X_gc}
feh_var_gc = {gc_name: [] for gc_name in X_gc}

for fold in range(num_folds):

    # Load the model of this fold.
    # NOTE: pickle.load can execute arbitrary code -- only load files that were
    # written by the training cell of this notebook.
    model_pkl_file = f"MAF_fold{fold}.pkl"
    with open(model_pkl_file, 'rb') as file:
        posterior = pickle.load(file)

    # For each globular cluster, make predictions using the loaded model.
    # The lists are NOT reset here: resetting them per fold would discard the
    # predictions of all previous folds and keep only the last model's output.
    for gc_name, gc_data in X_gc.items():

        # Iterate through each observation in the GC data
        for observation in gc_data:
            # Convert the single observation to a PyTorch tensor and add a leading batch axis
            gc_data_tensor = torch.Tensor(observation).unsqueeze(0)

            # Draw posterior samples for the single observation
            samples = posterior.sample((n_gc_samples,), x=gc_data_tensor)
            fe_h = samples[:, 0]  # Extract [Fe/H] predictions

            # Posterior mean and variance of [Fe/H] for this observation and this model
            feh_pred_gc[gc_name].append(torch.mean(fe_h).item())
            feh_var_gc[gc_name].append(torch.var(fe_h).item())

# Calculate the overall mean and variance of [Fe/H] predictions for each GC across all folds
feh_pred_gc_mean = {gc_name: np.mean(values) for gc_name, values in feh_pred_gc.items()}
feh_var_gc_mean = {gc_name: np.mean(values) for gc_name, values in feh_var_gc.items()}

### 6.5. Histograms: Fe/H predictions (with truths) for each Globular Clusters

In [None]:
# One histogram of the predicted [Fe/H] per globular cluster, with the
# available literature values drawn as dashed vertical lines.
for gc_name, predictions in feh_pred_gc.items():
    flat_predictions = np.array(predictions)

    # A histogram needs more than one prediction
    if flat_predictions.size <= 1:
        print(f"Skipping {gc_name} due to insufficient data for histogram.")
        continue

    plt.figure(figsize=(8, 6))
    plt.hist(flat_predictions, bins=30, density=True, alpha=0.6, color = 'purple', label=f'Histogram for {gc_name}')

    # The 4 literature truths; a measure is skipped when it is missing or NaN
    metallicity_measures = ['metallicity_H10', 'metallicity_K19', 'metallicity_B19', 'metallicity_V20']
    known_truths = gc_metallicities.get(gc_name, {})
    for measure in metallicity_measures:
        if measure not in known_truths:
            continue
        true_value = known_truths[measure]
        if np.isnan(true_value):
            continue
        plt.axvline(x=true_value, linestyle='--', label=f'True {measure} for {gc_name}')

    plt.title(f'Histogram of [Fe/H] Predictions for {gc_name}')
    plt.xlabel('[Fe/H] Predictions')
    plt.ylabel('Frequency')
    plt.legend()
    plt.tight_layout()
    plt.show()