# VAE Evaluation and Analysis

**Imports**

In [None]:
# Basic imports
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import math
import corner

from matplotlib import pyplot as plt

from torch.optim import Adam

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
from torchvision.utils import save_image, make_grid

from astropy.io import fits
from astropy.table import Table

**Hyperparameters and Model Import**

In [None]:
# Hyperparameters
dataset_path = '~/datasets'
# Request the GPU; fall back to the CPU when CUDA is unavailable so that every
# later `.to(DEVICE)` call keeps working on CPU-only machines.
cuda = True
DEVICE = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")

# Data parameters
input_rows = 10000      # spectra sampled per subset by fetch_data (20k was tried earlier; 60k was the next planned value)
batch_size = 100        # spectra per DataLoader batch (100 is the tested value)
validation_split = .2   # fraction to reserve for a test split (the split itself is commented out in fetch_data)
random_seed = 42
np.random.seed(random_seed)
shuffle_toggle = False  # redundant: fetch_data already draws a random subset of indices

# Model dimensions (must match the checkpoint that is loaded later)
x_dim = 7514            # spectral pixels per spectrum (3 stellar parameters are appended on top)
hidden_dim = 400        # originally 400, 200; accepted by Encoder/Decoder but not used by them
latent_dim = 28

# Training settings, kept for reference
lr = 0.001              # default 0.001
clipping_value = 1      # gradient clipping
beta = 1                # VAE beta
gamma = 1               # separability-loss scaling factor
epochs = 2500

In [None]:
"""
    A simple implementation of Gaussian MLP Encoder and Decoder
"""

class Encoder(nn.Module):
  """Gaussian MLP encoder for spectra carrying three appended stellar labels.

  Each input row is a spectrum followed by three extra columns (TEFF, LOGG,
  FE_H as appended by the data-loading code), so the first layer takes
  ``input_dim + 3`` features.  A fixed GELU funnel
  (3757 -> 1878 -> 939 -> 469 -> 234 -> 117 -> 58) feeds two linear heads.

  Args:
    input_dim: number of spectral pixels, excluding the 3 appended labels.
    hidden_dim: accepted for interface compatibility; not used here.
    latent_dim: size of the mean and log-variance outputs.

  ``forward(x)`` returns ``(mean, log_var, passed_parameters)`` where
  ``passed_parameters`` is simply ``x[:, -3:]``.
  """
  def __init__(self, input_dim, hidden_dim, latent_dim):
    super(Encoder, self).__init__()

    # Funnel layers.  Attribute names must stay exactly as they are so that
    # saved state_dicts keep loading.
    self.FC_input = nn.Linear(input_dim + 3, 3757)
    self.FC_input2 = nn.Linear(3757, 1878)
    self.FC_input3 = nn.Linear(1878, 939)
    self.FC_input4 = nn.Linear(939, 469)
    self.FC_input5 = nn.Linear(469, 234)
    self.FC_input6 = nn.Linear(234, 117)
    self.FC_input7 = nn.Linear(117, 58)

    # Heads producing the parameters of the approximate posterior q(z|x).
    self.FC_mean = nn.Linear(58, latent_dim)
    self.FC_var = nn.Linear(58, latent_dim)

    self.LeakyReLU = nn.LeakyReLU()  # not used by forward(); kept for module layout
    self.gelu = torch.nn.GELU()

    # Explicit, although nn.Module already starts in training mode.
    self.training = True

  def forward(self, x):
    trunk = (self.FC_input, self.FC_input2, self.FC_input3, self.FC_input4,
             self.FC_input5, self.FC_input6, self.FC_input7)
    features = x
    for layer in trunk:
      features = self.gelu(layer(features))

    mean = self.FC_mean(features)
    log_var = self.FC_var(features)  # treated as a log-variance by Model.forward

    # The trailing three input columns (TEFF, LOGG, FE_H) are passed straight through.
    return mean, log_var, x[:, -3:]

In [None]:
class Decoder(nn.Module):
  """MLP decoder mapping ``[z, u]`` (latent_dim + 3 values) back to a spectrum.

  It mirrors the encoder funnel in reverse
  (58 -> 117 -> 234 -> 469 -> 939 -> 1878 -> 3757 -> output_dim), with a GELU
  after every layer, including the output layer.

  Args:
    latent_dim: size of the sampled latent vector z (3 labels are appended to it).
    hidden_dim: accepted for interface compatibility; not used here.
    output_dim: number of reconstructed pixels.
  """
  def __init__(self, latent_dim, hidden_dim, output_dim):
    super(Decoder, self).__init__()

    # Attribute names must stay as they are so that saved state_dicts keep loading.
    self.FC_hidden2 = nn.Linear(latent_dim + 3, 58)
    self.FC_hidden3 = nn.Linear(58, 117)
    self.FC_hidden4 = nn.Linear(117, 234)
    self.FC_hidden5 = nn.Linear(234, 469)
    self.FC_hidden6 = nn.Linear(469, 939)
    self.FC_hidden7 = nn.Linear(939, 1878)
    self.FC_hidden8 = nn.Linear(1878, 3757)
    self.FC_output = nn.Linear(3757, output_dim)

    self.LeakyReLU = nn.LeakyReLU()  # not used by forward(); kept for module layout
    self.gelu = torch.nn.GELU()

  def forward(self, x):
    expand = (self.FC_hidden2, self.FC_hidden3, self.FC_hidden4, self.FC_hidden5,
              self.FC_hidden6, self.FC_hidden7, self.FC_hidden8)
    h = x
    for layer in expand:
      h = self.gelu(layer(h))

    # GELU on the output layer.  A sigmoid was used originally, but its output
    # range did not fit the data; Softplus was the other candidate considered.
    return self.gelu(self.FC_output(h))

In [None]:
class Model(nn.Module):
  """Conditional VAE: encode, sample z and the stellar labels u, then decode ``[z, u]``."""
  def __init__(self, Encoder, Decoder):
    super(Model, self).__init__()
    self.Encoder = Encoder
    self.Decoder = Decoder

  def reparameterization(self, mean, var):
    """Return ``mean + var * eps`` with ``eps ~ N(0, I)``.

    Despite its name, ``var`` is used as a standard deviation: forward()
    passes ``exp(0.5 * log_var)`` and the clipped parameter errors.
    """
    noise = torch.randn_like(var).to(DEVICE)
    return mean + var * noise

  # Modified to explicitly pass reduced_parameter_errors
  def forward(self, x, originalIndex, batch_size, reduced_parameter_errors):
    """Encode ``x`` and decode a sample.

    Args:
      x: batch of spectra, each with 3 appended parameters.
      originalIndex: row in ``reduced_parameter_errors`` that matches x[0].
      batch_size: number of error rows to slice, starting at originalIndex.
      reduced_parameter_errors: NumPy array of per-star parameter errors.

    Returns:
      (x_hat, mean, log_var, z), where z is the decoder input ``[z_latent, u]``.
    """
    mean, log_var, passed_parameters = self.Encoder(x)

    # Latent sample using the std derived from the log-variance.
    latent_sample = self.reparameterization(mean, torch.exp(0.5 * log_var))

    # Label sample: jitter the passed parameters by their (non-negative) errors.
    error_rows = reduced_parameter_errors[originalIndex : originalIndex + batch_size]
    error_std = torch.from_numpy(error_rows.astype(np.float32)).to(DEVICE)
    error_std = torch.nn.functional.relu(error_std)
    label_sample = self.reparameterization(passed_parameters, error_std)

    decoder_input = torch.hstack((latent_sample, label_sample))
    # Modified in evaluation file to return latent vector z
    return self.Decoder(decoder_input), mean, log_var, decoder_input

In [None]:
# Build the network skeleton; its weights are overwritten by load_state_dict in the next cell.
encoder = Encoder(x_dim, hidden_dim, latent_dim)
decoder = Decoder(latent_dim, hidden_dim, x_dim)
model = Model(encoder, decoder)
model = model.to(DEVICE)

In [None]:
# map_location remaps tensors saved on another device (e.g. a GPU) onto DEVICE,
# so the checkpoint also loads on CPU-only machines.
model.load_state_dict(torch.load('outputs/3-18-2022-kld-boost/vae_best_300_2022-03-21.pt', map_location=DEVICE)['state_dict'])

In [None]:
# Switch to inference mode (same as model.eval()); returns the model, so the cell still displays it.
model.train(False)

**Data Load**

In [None]:
print("\n ********************** Opening FITS files from drive **********************")

# Each file is closed by its context manager even if a read raises.  The .data
# arrays are bound before closing, exactly as the original access-then-close did.
with fits.open('allStar-r12-l33.fits') as star_hdus:
    star = star_hdus[1].data                   # allStar catalogue (table HDU 1)
with fits.open('apogee_astroNN-DR16-v1.fits') as astroNN_hdus:
    star_astroNN = astroNN_hdus[1].data        # astroNN catalogue (table HDU 1)
with fits.open('contspec_dr16_final.fits') as star_spec:
    star_spectra = star_spec[0].data           # spectra array (primary HDU)

print("Number of spectra: ", len(star))
print("Data points per spectra: ", len(star_spectra[1]))

In [None]:
# star_mask = fits.open('contspec_dr16_mask.fits')
# star_err = fits.open('contspec_dr16_err.fits')

# star_mask_data = star_mask[0].data
# star_err_data = star_err[0].data

# star_mask.close()
# star_err.close()

# print("Number of masks: ", len(star_mask_data))
# print("Data points per maskk: ", len(star_mask_data[1]))

# print("Number of errors: ", len(star_err_data))
# print("Data points per error: ", len(star_err_data[1]))

In [None]:
# Further Star Analysis: load the allStar catalogue as a pandas DataFrame.
# Only columns with at most one dimension are kept (presumably because
# Table.to_pandas() does not accept multidimensional columns).
from astropy.table import Table
dat = Table.read('allStar-r12-l33.fits', format='fits')
names = list(filter(lambda column: len(dat[column].shape) <= 1, dat.colnames))
df = dat[names].to_pandas()
df.head()

**Dataset Class**


In [None]:
# https://visualstudiomagazine.com/articles/2020/09/10/pytorch-dataloader.aspx

class spectraDataset(torch.utils.data.Dataset):
  """In-memory dataset of spectra; each item is returned together with its index.

  Args:
    src: NumPy array with one sample per leading index.
    num_rows: optional cap; only the first ``num_rows`` samples are kept.
    device: torch device for the stored tensor; defaults to the global DEVICE.
  """

  def __init__(self, src, num_rows=None, device=None):
    # Slice before casting so that only the kept rows are converted and copied.
    rows = src if num_rows is None else src[:num_rows]
    # astype() always returns a fresh, native-order float32 array, so the
    # tensor can share its memory instead of copying it a second time.
    spectra = rows.astype(np.float32)
    target_device = DEVICE if device is None else device
    self.x_data = torch.from_numpy(spectra).to(target_device)

  def __len__(self):
    return len(self.x_data)  # required

  def __getitem__(self, idx):
    # Modified March 21 2022 to also return the sample index, so callers can
    # tell which rows ended up in each batch.
    return self.x_data[idx], idx

**Data Filtering**

In [None]:
def minmax_normalize_rows(rows):
    """Min-max scale every row of a 2-D float array to [0, 1], in place.

    Each row has its own minimum subtracted and is then divided by its own
    (shifted) maximum.  Constant rows become all zeros and are not divided,
    so no division by zero occurs.  Returns the same array object.
    """
    rows -= rows.min(axis=1, keepdims=True)
    row_max = rows.max(axis=1, keepdims=True)
    np.divide(rows, row_max, out=rows, where=row_max != 0)
    return rows


def standardize_parameters(values, mean=None, std=None):
    """Z-score the columns of a 2-D float array.

    ``mean`` and ``std`` default to the column statistics of ``values``
    itself; pass reference statistics to put several subsets on one scale.
    Returns ``(standardized_copy, std_used)``; ``values`` is not modified.
    """
    standardized = np.array(values, copy=True)
    col_mean = standardized.mean(axis=0) if mean is None else np.asarray(mean)
    col_std = standardized.std(axis=0) if std is None else np.asarray(std)
    standardized -= col_mean
    standardized /= col_std
    return standardized, col_std


def fetch_data(df, star_spectra, teff_min, teff_max, logg_min, logg_max, snr_min, snr_max,
               param_mean=None, param_std=None):
    """Select stars by TEFF/LOGG/SNR, subsample them and wrap them in a DataLoader.

    A star is kept when teff_min < TEFF_SPEC < teff_max,
    logg_min < LOGG_SPEC < logg_max and snr_min < SNR < snr_max (all strict),
    and its ASPCAPFLAGS do not contain "STAR_BAD".

    Its TEFF_SPEC, LOGG_SPEC and FE_H values are z-scored and appended to its
    spectrum.  The matching errors (teff_err, logg_err, fe_h_err, read from
    the global ``star`` FITS table) are divided by the same standard deviations.

    ``min(input_rows, n_selected)`` stars are then drawn without replacement,
    after seeding NumPy with ``random_seed``.  Each row (spectrum plus the 3
    parameters) is then min-max scaled to [0, 1].

    Args:
        df: allStar catalogue as a DataFrame.  Its positional index is assumed
            to line up with the rows of ``star`` and ``star_spectra``.
        star_spectra: 2-D array, one spectrum per catalogue row.
        teff_min, teff_max, logg_min, logg_max, snr_min, snr_max: exclusive bounds.
        param_mean, param_std: optional length-3 reference statistics for
            (TEFF_SPEC, LOGG_SPEC, FE_H).  By default each call uses its own
            selection's statistics.  That centres every subset at zero, so the
            appended parameters of different subsets are not directly
            comparable; pass shared values to avoid this.

    Returns:
        (test_loader, reduced_parameter_errors):
        - a sequential DataLoader (global ``batch_size``) over the normalized rows;
        - the matching (n, 3) array of scaled parameter errors, row-aligned
          with the loader.

    NOTE(review): the per-row min-max scaling also rescales the 3 appended
    parameters, but the returned errors are not rescaled with them.  This is
    kept as-is because it presumably mirrors the training pipeline; confirm.
    """
    # Isolate critical columns
    star_df = df[['APSTAR_ID', 'TEFF_SPEC', 'LOGG_SPEC', 'SNR', 'ASPCAPFLAGS', 'STARFLAGS', 'FE_H']]

    in_range = ((star_df['TEFF_SPEC'] > teff_min) & (star_df['TEFF_SPEC'] < teff_max) &
                (star_df['LOGG_SPEC'] > logg_min) & (star_df['LOGG_SPEC'] < logg_max) &
                (star_df['SNR'] > snr_min) & (star_df['SNR'] < snr_max))
    # .copy() makes the column assignment below safe without silencing
    # pandas' chained-assignment warning globally.
    star_df_best = star_df.loc[in_range].copy()

    # Decode byte flags into strings, then strip out stars with the STAR_BAD flag
    star_df_best['ASPCAPFLAGS'] = star_df_best['ASPCAPFLAGS'].str.decode("utf-8")
    star_df_best = star_df_best.loc[~star_df_best['ASPCAPFLAGS'].str.contains("STAR_BAD")]

    # Three conditioning parameters ('u'), standardized column-wise
    params, params_std = standardize_parameters(
        star_df_best[["TEFF_SPEC", "LOGG_SPEC", "FE_H"]].to_numpy(), param_mean, param_std)

    row_idx = star_df_best.index.to_numpy()
    star_spectra_filtered = np.hstack((star_spectra[row_idx], params))

    # Errors for TEFF, LOGG, FE_H of the selected stars, in the same standardized units
    parameter_errors = np.column_stack(
        (star['teff_err'][row_idx], star['logg_err'][row_idx], star['fe_h_err'][row_idx])) / params_std

    print("After applying data filters " + str(len(star_spectra_filtered)) + " spectra remaining")

    # Reduce the dataset to a manageable size.  Capping at the number available
    # avoids a ValueError from choice(replace=False) for sparse subsets; when
    # enough stars exist, the draw is identical to sampling exactly input_rows.
    n_keep = min(input_rows, len(star_spectra_filtered))
    np.random.seed(random_seed)
    random_reduced_idx = np.random.choice(len(star_spectra_filtered), n_keep, replace=False)
    print("Reduced spectra count to " + str(n_keep))

    reduced_star_spectra = star_spectra_filtered[random_reduced_idx]
    reduced_parameter_errors = parameter_errors[random_reduced_idx]

    # Normalize each row (spectrum + appended parameters) to [0, 1]
    minmax_normalize_rows(reduced_star_spectra)

    # Sequential loader: Model.forward slices the errors by row position, so no shuffling
    test_loader = DataLoader(spectraDataset(reduced_star_spectra), batch_size=batch_size, num_workers=0)
    print('Batches in test:', len(test_loader))

    return test_loader, reduced_parameter_errors

**High-, Medium- and Low-SNR Stars**

In [None]:
# Three SNR bins for stars with 4000 < TEFF_SPEC < 5500 and 0 < LOGG_SPEC < 3.5:
# low (0, 100), medium (100, 200) and high (200, inf).  fetch_data uses strict
# inequalities, so stars with SNR exactly 100 or 200 fall into no bin.
_snr_bin_kwargs = dict(df=df, star_spectra=star_spectra, teff_min=4000,
                       teff_max=5500, logg_min=0, logg_max=3.5)

low_snr_data, low_snr_parameter_errors = fetch_data(**_snr_bin_kwargs, snr_min=0, snr_max=100)

medium_snr_data, medium_snr_parameter_errors = fetch_data(**_snr_bin_kwargs, snr_min=100, snr_max=200)

high_snr_data, high_snr_parameter_errors = fetch_data(**_snr_bin_kwargs, snr_min=200, snr_max=math.inf)

**High, Medium and Low Effective-Temperature Stars**

In [None]:
# Three TEFF_SPEC bins, (0, 4000), (4000, 5500) and (5500, inf), for stars with
# 0 < LOGG_SPEC < 5 and SNR > 200.
# LOGG is left wide because restricting it left too few samples (needs verifying).
_teff_bin_kwargs = dict(df=df, star_spectra=star_spectra, logg_min=0, logg_max=5,
                        snr_min=200, snr_max=math.inf)

low_teff_data, low_teff_parameter_errors = fetch_data(**_teff_bin_kwargs, teff_min=0, teff_max=4000)

medium_teff_data, medium_teff_parameter_errors = fetch_data(**_teff_bin_kwargs, teff_min=4000, teff_max=5500)

high_teff_data, high_teff_parameter_errors = fetch_data(**_teff_bin_kwargs, teff_min=5500, teff_max=math.inf)

**High, Medium and Low Surface-Gravity Stars**

In [None]:
# Three LOGG_SPEC bins, (0, 2), (2, 3.5) and (3.5, inf), for stars with
# 4000 < TEFF_SPEC < 5500 and SNR > 200.
_logg_bin_kwargs = dict(df=df, star_spectra=star_spectra, teff_min=4000, teff_max=5500,
                        snr_min=200, snr_max=math.inf)

low_logg_data, low_logg_parameter_errors = fetch_data(**_logg_bin_kwargs, logg_min=0, logg_max=2)

medium_logg_data, medium_logg_parameter_errors = fetch_data(**_logg_bin_kwargs, logg_min=2, logg_max=3.5)

high_logg_data, high_logg_parameter_errors = fetch_data(**_logg_bin_kwargs, logg_min=3.5, logg_max=math.inf)

**Latent Space Plotting**

In [None]:
# Dump the weight matrices of the encoder's mean and log-variance heads.
for head in (encoder.FC_mean, encoder.FC_var):
    print(head.weight)

In [None]:
mean_head_weights = encoder.FC_mean.weight  # nn.Linear stores weights as (out_features, in_features)
print(mean_head_weights[0])                 # incoming weights of the first latent mean unit

In [None]:
# Feed model helper function
def feed_model(data_loader, parameter_errors):
  """Run the global ``model`` over every batch and collect one latent vector per star.

  Args:
    data_loader: sequential (unshuffled) DataLoader yielding (spectra, index)
      pairs, as built by fetch_data.
    parameter_errors: per-star parameter errors, row-aligned with the
      loader's dataset; Model.forward slices them per batch.

  Returns:
    list of 1-D NumPy arrays, one per star: the decoder input ``[z, u]``
    (latent_dim + 3 values) sampled by the model.
  """
  latent_vector_array = []
  # Model.forward slices parameter_errors[row_offset : row_offset + n_rows], so
  # it needs the dataset row of this batch's first spectrum, not the batch number.
  row_offset = 0

  with torch.no_grad():
    for x, _ in tqdm(data_loader):
      x = x.reshape(-1, x_dim + 3).to(DEVICE)
      n_rows = x.shape[0]  # may be smaller than batch_size for the final batch

      _, _, _, z = model(x, row_offset, n_rows, parameter_errors)

      # One latent vector per star in the batch
      latent_vector_array.extend(z.cpu().detach().numpy())
      row_offset += n_rows

  return latent_vector_array

**Feed Model with Filtered Data**

In [None]:
# Encode each SNR subset.  Model.forward draws fresh noise from torch's global
# RNG on every call, so reordering these lines changes the sampled vectors.
high_snr_array = feed_model(high_snr_data, high_snr_parameter_errors)
medium_snr_array = feed_model(medium_snr_data, medium_snr_parameter_errors)
low_snr_array = feed_model(low_snr_data, low_snr_parameter_errors)

In [None]:
# Encode each effective-temperature subset.  Order matters: every call
# consumes noise from torch's global RNG inside Model.forward.
high_teff_array = feed_model(high_teff_data, high_teff_parameter_errors)
medium_teff_array = feed_model(medium_teff_data, medium_teff_parameter_errors)
low_teff_array = feed_model(low_teff_data, low_teff_parameter_errors)

In [None]:
# Encode each surface-gravity subset.  Order matters: every call consumes
# noise from torch's global RNG inside Model.forward.
high_logg_array = feed_model(high_logg_data, high_logg_parameter_errors)
medium_logg_array = feed_model(medium_logg_data, medium_logg_parameter_errors)
low_logg_array = feed_model(low_logg_data, low_logg_parameter_errors)

**SNR Plot**

In [None]:
# Test Plot for Selecting categories
from matplotlib.lines import Line2D

# 0-based latent columns to show.  Generate exactly one label per plotted
# column: corner.corner expects len(labels) to equal the number of columns.
snr_plot_dims = [1, 2, 3, 4, 5]
snr_plot_labels = [f'VAE {d}' for d in snr_plot_dims]

figure_snr = corner.corner(np.array(high_snr_array)[:, snr_plot_dims], color="green", labels=snr_plot_labels)
corner.corner(np.array(medium_snr_array)[:, snr_plot_dims], color="yellow", fig=figure_snr)
corner.corner(np.array(low_snr_array)[:, snr_plot_dims], color="red", fig=figure_snr)

# suptitle titles the whole figure rather than a single panel
figure_snr.suptitle("High/Medium/Low SNR Latent Space Comparison")

legend_elements = [Line2D([0], [0], marker='o', color='green', label='High', markerfacecolor='green', markersize=10),
                   Line2D([0], [0], marker='o', color='yellow', label='Medium', markerfacecolor='yellow', markersize=10),
                   Line2D([0], [0], marker='o', color='red', label='Low', markerfacecolor='red', markersize=10)]

plt.legend(handles=legend_elements, loc='upper left')
plt.show()

In [None]:
# Overlay every latent dimension of the three SNR subsets on one corner figure.
figure_snr = corner.corner(np.array(high_snr_array), color="green")
corner.corner(np.array(medium_snr_array), color="yellow", fig=figure_snr)
corner.corner(np.array(low_snr_array), color="red", fig=figure_snr)
# Title the whole figure; plt.title would only title the current Axes (one small panel).
figure_snr.suptitle("High/Medium/Low SNR Latent Space Comparison")
plt.show()

**Effective Temp Plot**

In [None]:
high_teff_latents = np.array(high_teff_array)
# Numeric labels 1..D, one per latent column, derived from the data instead of hard-coded.
teff_labels = list(range(1, high_teff_latents.shape[1] + 1))

figure_teff = corner.corner(high_teff_latents, color="green", labels=teff_labels)
corner.corner(np.array(medium_teff_array), color="yellow", fig=figure_teff)
corner.corner(np.array(low_teff_array), color="red", fig=figure_teff)
# Title the whole figure; plt.title would only title the current Axes (one small panel).
figure_teff.suptitle("High/Medium/Low Effective Temperature Latent Space Comparison")
plt.show()

**Surface Gravity Plot**

In [None]:
high_logg_latents = np.array(high_logg_array)
# Numeric labels 1..D, one per latent column, derived from the data instead of hard-coded.
logg_labels = list(range(1, high_logg_latents.shape[1] + 1))

figure_logg = corner.corner(high_logg_latents, color="green", labels=logg_labels)
corner.corner(np.array(medium_logg_array), color="yellow", fig=figure_logg)
corner.corner(np.array(low_logg_array), color="red", fig=figure_logg)
# Title the whole figure; plt.title would only title the current Axes (one small panel).
figure_logg.suptitle("High/Medium/Low Logg Latent Space Comparison")
plt.show()

*Sample test stars with similar parameters, varying just one parameter (e.g. logg or teff) up or down.
Plot them in latent space, try to trace tracks, and check whether any trends appear.*

**Track Visualization**

In [None]:
# Track samples: draw 50 stars per bin and feed each bin to the model as a single batch.
batch_size = input_rows = 50

In [None]:
# Modified data fetch function for plotting tracks

def fetch_data_single(df, star_spectra, teff_min, teff_max, logg_min, logg_max, snr_min, snr_max,
                      param_mean=None, param_std=None):
    """Build a small "track" sample and report its mean TEFF, LOGG and SNR.

    Selection, standardization, subsampling and per-row min-max scaling follow
    the same rules as fetch_data:
    - strict bounds on TEFF_SPEC, LOGG_SPEC and SNR;
    - stars flagged STAR_BAD are dropped;
    - TEFF_SPEC, LOGG_SPEC and FE_H are z-scored and appended to each spectrum;
    - errors are taken from the global ``star`` table and divided by the same stds;
    - ``min(input_rows, n_selected)`` stars are drawn with ``random_seed``.
    Each selected spectrum stays its own row, so it is normalized with its own
    min and max.

    Args:
        df: allStar catalogue as a DataFrame.  Its positional index is assumed
            to line up with the rows of ``star`` and ``star_spectra``.
        star_spectra: 2-D array, one spectrum per catalogue row.
        teff_min, teff_max, logg_min, logg_max, snr_min, snr_max: exclusive bounds.
        param_mean, param_std: optional length-3 reference statistics for
            (TEFF_SPEC, LOGG_SPEC, FE_H).  By default the sample's own
            statistics are used.  That centres every narrow bin at zero, so the
            appended parameters do not encode where the bin lies; pass shared
            values to compare bins.

    Returns:
        (test_loader, reduced_parameter_errors): a sequential DataLoader
        (global ``batch_size``) and the row-aligned (n, 3) array of scaled errors.
    """
    # Isolate critical columns
    star_df = df[['APSTAR_ID', 'TEFF_SPEC', 'LOGG_SPEC', 'SNR', 'ASPCAPFLAGS', 'STARFLAGS', 'FE_H']]

    in_range = ((star_df['TEFF_SPEC'] > teff_min) & (star_df['TEFF_SPEC'] < teff_max) &
                (star_df['LOGG_SPEC'] > logg_min) & (star_df['LOGG_SPEC'] < logg_max) &
                (star_df['SNR'] > snr_min) & (star_df['SNR'] < snr_max))
    # .copy() makes the column assignment below safe without silencing
    # pandas' chained-assignment warning globally.
    star_df_best = star_df.loc[in_range].copy()

    # Decode byte flags into strings, then strip out stars with the STAR_BAD flag
    star_df_best['ASPCAPFLAGS'] = star_df_best['ASPCAPFLAGS'].str.decode("utf-8")
    star_df_best = star_df_best.loc[~star_df_best['ASPCAPFLAGS'].str.contains("STAR_BAD")]

    # Three conditioning parameters ('u'), standardized column-wise
    params = star_df_best[["TEFF_SPEC", "LOGG_SPEC", "FE_H"]].to_numpy(copy=True)
    params_mean = params.mean(axis=0) if param_mean is None else np.asarray(param_mean)
    params_std = params.std(axis=0) if param_std is None else np.asarray(param_std)
    params -= params_mean
    params /= params_std

    row_idx = star_df_best.index.to_numpy()
    star_spectra_filtered = np.hstack((star_spectra[row_idx], params))

    # Errors for TEFF, LOGG, FE_H of the selected stars, in the same standardized units
    parameter_errors = np.column_stack(
        (star['teff_err'][row_idx], star['logg_err'][row_idx], star['fe_h_err'][row_idx])) / params_std

    print("After applying data filters " + str(len(star_spectra_filtered)) + " spectra remaining")

    # Positional (0-based) view of the selected catalogue rows, aligned with star_spectra_filtered
    star_df_best_reset = star_df_best.reset_index(drop=True)

    # Cap the sample size so that sparse bins do not make choice(replace=False) raise
    n_keep = min(input_rows, len(star_spectra_filtered))
    np.random.seed(random_seed)
    random_reduced_idx = np.random.choice(len(star_spectra_filtered), n_keep, replace=False)
    print("Reduced spectra count to " + str(n_keep))

    # Keep the selection 2-D (one row per star).  Wrapping it in an extra
    # dimension would make the normalization below use one min/max for all stars.
    reduced_star_spectra = star_spectra_filtered[random_reduced_idx]
    reduced_parameter_errors = parameter_errors[random_reduced_idx]

    # Report the average star of this sample
    reduced_star_df_best = star_df_best_reset.iloc[random_reduced_idx]
    print("Mean TEFF:", reduced_star_df_best["TEFF_SPEC"].mean())
    print("Mean LOGG:", reduced_star_df_best["LOGG_SPEC"].mean())
    print("Mean SNR:", reduced_star_df_best["SNR"].mean())

    # Per-row min-max normalization to [0, 1], matching fetch_data.
    # Constant rows become zeros and are not divided.
    reduced_star_spectra -= reduced_star_spectra.min(axis=1, keepdims=True)
    row_max = reduced_star_spectra.max(axis=1, keepdims=True)
    np.divide(reduced_star_spectra, row_max, out=reduced_star_spectra, where=row_max != 0)

    # Sequential loader: Model.forward slices the errors by row position, so no shuffling
    test_loader = DataLoader(spectraDataset(reduced_star_spectra), batch_size=batch_size, num_workers=0)
    print('Batches in test:', len(test_loader))

    return test_loader, reduced_parameter_errors

**Extract SNR/Teff/LogG Track Samples**

In [None]:
# SNR track: fixed 4800 < TEFF < 4900 and 3.0 < LOGG < 3.1, with five SNR bins.
# NOTE(review): SNR 200-250 is not sampled (bin 3 ends at 200, bin 4 starts at 250) - confirm this is intentional.
_snr_track_kwargs = dict(df=df, star_spectra=star_spectra, teff_min=4800, teff_max=4900,
                         logg_min=3, logg_max=3.1)

snr_sample_1, snr_sample_1_error = fetch_data_single(**_snr_track_kwargs, snr_min=50, snr_max=100)

snr_sample_2, snr_sample_2_error = fetch_data_single(**_snr_track_kwargs, snr_min=100, snr_max=150)

snr_sample_3, snr_sample_3_error = fetch_data_single(**_snr_track_kwargs, snr_min=150, snr_max=200)

snr_sample_4, snr_sample_4_error = fetch_data_single(**_snr_track_kwargs, snr_min=250, snr_max=300)

snr_sample_5, snr_sample_5_error = fetch_data_single(**_snr_track_kwargs, snr_min=300, snr_max=math.inf)

In [None]:
# Teff track: five consecutive 100 K bins from 4700 to 5200, at fixed
# 3.0 < LOGG < 3.1 and SNR > 200.
# (An earlier version used six 200 K bins from 4200 to 5400 with 3.45 < LOGG < 3.5.)
_teff_track_kwargs = dict(df=df, star_spectra=star_spectra, logg_min=3, logg_max=3.1,
                          snr_min=200, snr_max=math.inf)

teff_sample_1, teff_sample_1_error = fetch_data_single(**_teff_track_kwargs, teff_min=4700, teff_max=4800)

teff_sample_2, teff_sample_2_error = fetch_data_single(**_teff_track_kwargs, teff_min=4800, teff_max=4900)

teff_sample_3, teff_sample_3_error = fetch_data_single(**_teff_track_kwargs, teff_min=4900, teff_max=5000)

teff_sample_4, teff_sample_4_error = fetch_data_single(**_teff_track_kwargs, teff_min=5000, teff_max=5100)

teff_sample_5, teff_sample_5_error = fetch_data_single(**_teff_track_kwargs, teff_min=5100, teff_max=5200)

In [None]:
# Logg track: five 0.1-wide LOGG bins from 3.0 to 3.5, at fixed
# 4800 < TEFF < 5000 and SNR > 200.
_logg_track_kwargs = dict(df=df, star_spectra=star_spectra, teff_min=4800, teff_max=5000,
                          snr_min=200, snr_max=math.inf)

logg_sample_1, logg_sample_1_error = fetch_data_single(**_logg_track_kwargs, logg_min=3, logg_max=3.1)

logg_sample_2, logg_sample_2_error = fetch_data_single(**_logg_track_kwargs, logg_min=3.1, logg_max=3.2)

logg_sample_3, logg_sample_3_error = fetch_data_single(**_logg_track_kwargs, logg_min=3.2, logg_max=3.3)

logg_sample_4, logg_sample_4_error = fetch_data_single(**_logg_track_kwargs, logg_min=3.3, logg_max=3.4)

logg_sample_5, logg_sample_5_error = fetch_data_single(**_logg_track_kwargs, logg_min=3.4, logg_max=3.5)

In [None]:
# Show every batch yielded by the first Teff track loader, numbered from 0.
for batch_number, batch in enumerate(teff_sample_1):
    print(batch_number, batch)

In [None]:
# Feed model helper function
def feed_model_single(data_loader, parameter_errors):
  """Run the global ``model`` over a track sample and collect one latent array per batch.

  Args:
    data_loader: sequential (unshuffled) DataLoader, as built by fetch_data_single.
    parameter_errors: per-star parameter errors, row-aligned with the
      loader's spectra; Model.forward slices them per batch.

  Returns:
    list with one 2-D NumPy array per batch.  Each array has one row per star
    and the decoder input ``[z, u]`` (latent_dim + 3 values) as columns.
  """
  latent_vector_array = []
  # Model.forward slices parameter_errors[row_offset : row_offset + n_rows], so
  # it needs the row of this batch's first spectrum, not the batch number.
  row_offset = 0

  with torch.no_grad():
    for x, _ in tqdm(data_loader):
      # Flatten to (rows, features) first, then count the rows actually present
      # (the final batch may be smaller than batch_size).
      x = x.reshape(-1, x_dim + 3).to(DEVICE)
      n_rows = x.shape[0]

      _, _, _, z = model(x, row_offset, n_rows, parameter_errors)

      latent_vector_array.append(z.cpu().detach().numpy())
      row_offset += n_rows

  return latent_vector_array

In [None]:
# Encode each SNR track bin.  Model.forward draws fresh noise from torch's
# global RNG on every call, so reordering these lines changes the samples.
snr_sample_1_array = feed_model_single(snr_sample_1, snr_sample_1_error)
snr_sample_2_array = feed_model_single(snr_sample_2, snr_sample_2_error)
snr_sample_3_array = feed_model_single(snr_sample_3, snr_sample_3_error)
snr_sample_4_array = feed_model_single(snr_sample_4, snr_sample_4_error)
snr_sample_5_array = feed_model_single(snr_sample_5, snr_sample_5_error)

In [None]:
# Encode each Teff track bin.  Order matters: every call consumes noise from
# torch's global RNG inside Model.forward.
teff_sample_1_array = feed_model_single(teff_sample_1, teff_sample_1_error)
teff_sample_2_array = feed_model_single(teff_sample_2, teff_sample_2_error)
teff_sample_3_array = feed_model_single(teff_sample_3, teff_sample_3_error)
teff_sample_4_array = feed_model_single(teff_sample_4, teff_sample_4_error)
teff_sample_5_array = feed_model_single(teff_sample_5, teff_sample_5_error)
# teff_sample_6_array = feed_model_single(teff_sample_6)

In [None]:
# Encode each Logg track bin.  Order matters: every call consumes noise from
# torch's global RNG inside Model.forward.
logg_sample_1_array = feed_model_single(logg_sample_1, logg_sample_1_error)
logg_sample_2_array = feed_model_single(logg_sample_2, logg_sample_2_error)
logg_sample_3_array = feed_model_single(logg_sample_3, logg_sample_3_error)
logg_sample_4_array = feed_model_single(logg_sample_4, logg_sample_4_error)
logg_sample_5_array = feed_model_single(logg_sample_5, logg_sample_5_error)

**Average latent space values**

In [None]:
# Shape check for the first Teff track bin.  Element [0] is its first batch:
# one row per sampled star, each with latent_dim + 3 values (z plus the 3 labels).
first_teff_batch = teff_sample_1_array[0]
print("Number of samples:", len(first_teff_batch))
print("Number of vals in single sample:", len(first_teff_batch[0]))

In [None]:
# Averaging helper function
def latentValAverager(input):
    """Return the column-wise means of a 2-D sequence as a list.

    Example: [[2, 5, 1, 9], [4, 9, 5, 10]] -> [3.0, 7.0, 3.0, 9.5].
    Columns are paired with zip, so rows of unequal length truncate the result
    to the shortest row; an empty input gives an empty list.
    """
    rows = list(map(np.array, input))
    column_means = []
    for column in zip(*rows):
        column_means.append(np.mean(column))
    return column_means

In [None]:
# Per-dimension latent means for each SNR track bin.  Element [0] is the first
# batch from feed_model_single (the only one when batch_size == input_rows),
# averaged column-wise into one list of latent_dim + 3 values.
(snr_sample_1_array_mean,
 snr_sample_2_array_mean,
 snr_sample_3_array_mean,
 snr_sample_4_array_mean,
 snr_sample_5_array_mean) = (latentValAverager(batches[0]) for batches in (
    snr_sample_1_array, snr_sample_2_array, snr_sample_3_array,
    snr_sample_4_array, snr_sample_5_array))

In [None]:
# Per-dimension latent means for each Teff track bin.  Element [0] is the first
# batch from feed_model_single (the only one when batch_size == input_rows),
# averaged column-wise into one list of latent_dim + 3 values.
(teff_sample_1_array_mean,
 teff_sample_2_array_mean,
 teff_sample_3_array_mean,
 teff_sample_4_array_mean,
 teff_sample_5_array_mean) = (latentValAverager(batches[0]) for batches in (
    teff_sample_1_array, teff_sample_2_array, teff_sample_3_array,
    teff_sample_4_array, teff_sample_5_array))

In [None]:
# Per-dimension latent means for each Logg track bin.  Element [0] is the first
# batch from feed_model_single (the only one when batch_size == input_rows),
# averaged column-wise into one list of latent_dim + 3 values.
(logg_sample_1_array_mean,
 logg_sample_2_array_mean,
 logg_sample_3_array_mean,
 logg_sample_4_array_mean,
 logg_sample_5_array_mean) = (latentValAverager(batches[0]) for batches in (
    logg_sample_1_array, logg_sample_2_array, logg_sample_3_array,
    logg_sample_4_array, logg_sample_5_array))

**SNR Tracking Plot**

In [None]:
ndim = 31
# The corner figure's axes form an ndim x ndim grid (row = y dimension, column = x dimension).
axes = np.array(figure_snr.axes).reshape((ndim, ndim))

snr_track_means = [snr_sample_1_array_mean, snr_sample_2_array_mean, snr_sample_3_array_mean,
                   snr_sample_4_array_mean, snr_sample_5_array_mean]

# Only the lower triangle (xi < yi) holds 2-D scatter panels.
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]

        # Mark each bin's mean latent position ...
        for track_mean in snr_track_means:
            ax.plot(track_mean[xi], track_mean[yi], marker="o", markersize=2, color="darkviolet")

        # ... and number the markers 1-5 in SNR-bin order.
        for bin_number, track_mean in enumerate(snr_track_means, start=1):
            ax.annotate(str(bin_number), xy=(track_mean[xi], track_mean[yi]), fontsize=8)

figure_snr

In [None]:
# Test Plot for Selecting categories
from matplotlib.lines import Line2D

# 0-based latent columns shown in the cropped plot; the track loop below offsets
# its axis indices by +1 to match them.  Generate exactly one label per plotted
# column (corner.corner expects len(labels) to equal the number of columns).
crop_dims = [1, 2, 3, 4, 5]
crop_labels = [f'VAE {d}' for d in crop_dims]

figure_snr_crop = corner.corner(np.array(high_snr_array)[:, crop_dims], color="green", labels=crop_labels)
corner.corner(np.array(medium_snr_array)[:, crop_dims], color="yellow", fig=figure_snr_crop)
corner.corner(np.array(low_snr_array)[:, crop_dims], color="red", fig=figure_snr_crop)

# suptitle titles the whole figure rather than a single panel
figure_snr_crop.suptitle("High/Medium/Low SNR Latent Space Comparison")

legend_elements = [Line2D([0], [0], marker='o', color='green', label='High', markerfacecolor='green', markersize=10),
                   Line2D([0], [0], marker='o', color='yellow', label='Medium', markerfacecolor='yellow', markersize=10),
                   Line2D([0], [0], marker='o', color='red', label='Low', markerfacecolor='red', markersize=10),
                   Line2D([0], [0], marker='o', color='blue', label='Track', markerfacecolor='blue', markersize=10)
                   ]

plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.25, 2))

ndim = 5
# Extract the axes
axes = np.array(figure_snr_crop.axes).reshape((ndim, ndim))

# Loop over scatter plots
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        #print("ypos:",yi+1,"xpos:",xi+1)

        # ax.text(0.9, 0.95,(xi+1,yi+1),
        # horizontalalignment='center',
        # verticalalignment='center',
        # transform = ax.transAxes)
        # ax.annotate((xi+1,yi+1), xy=(0,0))

        ax.plot(snr_sample_1_array_mean[xi+1], snr_sample_1_array_mean[yi+1], marker="o", markersize=4, color="blue")
        ax.plot(snr_sample_2_array_mean[xi+1], snr_sample_2_array_mean[yi+1], marker="o", markersize=4, color="blue")
        ax.plot(snr_sample_3_array_mean[xi+1], snr_sample_3_array_mean[yi+1], marker="o", markersize=4, color="blue")
        ax.plot(snr_sample_4_array_mean[xi+1], snr_sample_4_array_mean[yi+1], marker="o", markersize=4, color="blue")
        ax.plot(snr_sample_5_array_mean[xi+1], snr_sample_5_array_mean[yi+1], marker="o", markersize=4, color="blue")

        # Add text label
        # ax.annotate("1",xy=(snr_sample_1_array_mean[xi+1], snr_sample_1_array_mean[yi+1]), fontsize=8)
        # ax.annotate("2",xy=(snr_sample_2_array_mean[xi+1], snr_sample_2_array_mean[yi+1]), fontsize=8)
        # ax.annotate("3",xy=(snr_sample_3_array_mean[xi+1], snr_sample_3_array_mean[yi+1]), fontsize=8)
        # ax.annotate("4",xy=(snr_sample_4_array_mean[xi+1], snr_sample_4_array_mean[yi+1]), fontsize=8)
        # ax.annotate("5",xy=(snr_sample_5_array_mean[xi+1], snr_sample_5_array_mean[yi+1]), fontsize=8)

plt.show()

**Effective Temperature Tracking Plot**

In [None]:
ndim = 31

# Grid of the full Teff corner plot's axes (row = y dimension, column = x dimension).
axes = np.array(figure_teff.axes).reshape((ndim, ndim))

# Latent means of the five Teff track bins, in bin order 1-5.
teff_track_means = [
    teff_sample_1_array_mean,
    teff_sample_2_array_mean,
    teff_sample_3_array_mean,
    teff_sample_4_array_mean,
    teff_sample_5_array_mean,
]

# Overlay the track on each off-diagonal (scatter) panel: markers first,
# then the bin number "1".."5" next to each mean.
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        for means in teff_track_means:
            ax.plot(means[xi], means[yi], marker="o", markersize=4, color="purple")
        for bin_number, means in enumerate(teff_track_means, start=1):
            ax.annotate(str(bin_number), xy=(means[xi], means[yi]), fontsize=8)

figure_teff
    

In [None]:
# Test Plot for Selecting categories
from matplotlib.lines import Line2D

# Cropped view of latent columns 8-12 (0-based), labelled "VAE 9".."VAE 13" (1-based).
offset = 8
ndim = 5
crop_cols = list(range(offset, offset + ndim))
crop_labels = [f"VAE {col + 1}" for col in crop_cols]

figure_teff_crop = corner.corner(np.array(high_teff_array)[:, crop_cols], color="green", labels=crop_labels)
corner.corner(np.array(medium_teff_array)[:, crop_cols], color="yellow", fig=figure_teff_crop)
corner.corner(np.array(low_teff_array)[:, crop_cols], color="red", fig=figure_teff_crop)

figure_teff_crop.suptitle("High/Medium/Low Effective Temperature Latent Space Comparison")

legend_elements = [Line2D([0], [0], marker='o', color='green', label='High', markerfacecolor='green', markersize=10),
                   Line2D([0], [0], marker='o', color='yellow', label='Medium', markerfacecolor='yellow', markersize=10),
                   Line2D([0], [0], marker='o', color='red', label='Low', markerfacecolor='red', markersize=10),
                   Line2D([0], [0], marker='o', color='blue', label='Track', markerfacecolor='dodgerblue', markersize=10)
                   ]

plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.25, 2))

# Grid of the cropped corner plot's axes (row = y dimension, column = x dimension).
axes = np.array(figure_teff_crop.axes).reshape((ndim, ndim))

# Latent means of the five Teff track bins, in bin order 1-5.
teff_track_means = [
    teff_sample_1_array_mean,
    teff_sample_2_array_mean,
    teff_sample_3_array_mean,
    teff_sample_4_array_mean,
    teff_sample_5_array_mean,
]

# Overlay the track on each off-diagonal panel; `offset` maps the panel index
# back to the column of the full latent vector.
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        for means in teff_track_means:
            ax.plot(means[xi + offset], means[yi + offset], marker="o", markerfacecolor="dodgerblue", markersize=3, color="blue")
        for bin_number, means in enumerate(teff_track_means, start=1):
            ax.annotate(str(bin_number), xy=(means[xi + offset], means[yi + offset]), fontsize=12)

plt.show()

**Surface Gravity Tracking Plot**

In [None]:
ndim = 31

# Grid of the full logg corner plot's axes (row = y dimension, column = x dimension).
axes = np.array(figure_logg.axes).reshape((ndim, ndim))

# Latent means of the five logg track bins, in bin order 1-5.
logg_track_means = [
    logg_sample_1_array_mean,
    logg_sample_2_array_mean,
    logg_sample_3_array_mean,
    logg_sample_4_array_mean,
    logg_sample_5_array_mean,
]

# Overlay the track on each off-diagonal (scatter) panel: markers first,
# then the bin number "1".."5" next to each mean.
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        for means in logg_track_means:
            ax.plot(means[xi], means[yi], marker="o", markersize=2, color="darkviolet")
        for bin_number, means in enumerate(logg_track_means, start=1):
            ax.annotate(str(bin_number), xy=(means[xi], means[yi]), fontsize=8)

figure_logg
    

In [None]:
# Test Plot for Selecting categories
from matplotlib.lines import Line2D

# Cropped view of latent columns 16-20 (0-based), labelled "VAE 17".."VAE 21" (1-based).
offset = 16
ndim = 5
crop_cols = list(range(offset, offset + ndim))
crop_labels = [f"VAE {col + 1}" for col in crop_cols]

figure_logg_crop = corner.corner(np.array(high_logg_array)[:, crop_cols], color="green", labels=crop_labels)
corner.corner(np.array(medium_logg_array)[:, crop_cols], color="yellow", fig=figure_logg_crop)
corner.corner(np.array(low_logg_array)[:, crop_cols], color="red", fig=figure_logg_crop)

figure_logg_crop.suptitle("High/Medium/Low Surface Gravity Latent Space Comparison")

legend_elements = [Line2D([0], [0], marker='o', color='green', label='High', markerfacecolor='green', markersize=10),
                   Line2D([0], [0], marker='o', color='yellow', label='Medium', markerfacecolor='yellow', markersize=10),
                   Line2D([0], [0], marker='o', color='red', label='Low', markerfacecolor='red', markersize=10),
                   Line2D([0], [0], marker='o', color='blue', label='Track', markerfacecolor='dodgerblue', markersize=10)
                   ]

plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.25, 2))

# Grid of the cropped corner plot's axes (row = y dimension, column = x dimension).
axes = np.array(figure_logg_crop.axes).reshape((ndim, ndim))

# Latent means of the five logg track bins, in bin order 1-5.
logg_track_means = [
    logg_sample_1_array_mean,
    logg_sample_2_array_mean,
    logg_sample_3_array_mean,
    logg_sample_4_array_mean,
    logg_sample_5_array_mean,
]

# Overlay the track on each off-diagonal panel; `offset` maps the panel index
# back to the column of the full latent vector.
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        for means in logg_track_means:
            ax.plot(means[xi + offset], means[yi + offset], marker="o", markerfacecolor="dodgerblue", markersize=4, color="blue")
        for bin_number, means in enumerate(logg_track_means, start=1):
            ax.annotate(str(bin_number), xy=(means[xi + offset], means[yi + offset]), fontsize=12)

plt.show()

**[Fe/H] and [O/Fe] Corner Plots – Data**

In [None]:
# Modified data fetch function for metallicity/abundance

def fetch_data_abundance(df, star_spectra, teff_min, teff_max, logg_min, logg_max, snr_min, snr_max,
                         metallicity_min=None, metallicity_max=None, abundance_min=None, abundance_max=None,
                         abundance_flag=None):
    """Select one parameter bin of spectra and wrap it in an unshuffled DataLoader.

    A row of ``df`` is kept when every cut holds (all bounds strict):
    ``teff_min < TEFF_SPEC < teff_max``, ``logg_min < LOGG_SPEC < logg_max``,
    ``snr_min < SNR < snr_max``, ``metallicity_min < FE_H < metallicity_max`` and,
    when ``abundance_flag`` names a column, ``abundance_min < df[abundance_flag] < abundance_max``.
    Rows whose decoded ASPCAPFLAGS contain ``STAR_BAD`` are then dropped.

    A metallicity/abundance bound given as ``None`` means "no cut on that side".
    ``abundance_min``/``abundance_max`` are ignored when ``abundance_flag`` is None.

    At most ``input_rows`` of the surviving stars are drawn at random (seeded with
    ``random_seed``). If fewer stars survive, all of them are used and a message is printed.

    Returns:
        test_loader: ``DataLoader`` (batch_size=``batch_size``, no shuffling) over
            ``spectraDataset`` built from the selected rows. Each row is the spectrum
            followed by the standardized TEFF/LOGG/FE_H values, min-max normalized
            per row.
        reduced_parameter_errors: numpy array of TEFF/LOGG/FE_H errors for the same
            rows, divided by the per-column std devs of the selected parameters.

    Notes:
        Reads the notebook globals ``star`` (columns ``teff_err``, ``logg_err``,
        ``fe_h_err``), ``spectraDataset``, ``input_rows``, ``random_seed`` and
        ``batch_size``. It sets ``pd.options.mode.chained_assignment = None``
        globally and reseeds NumPy's global RNG.
    """
    # Suppress pandas' SettingWithCopyWarning (global option) for the ASPCAPFLAGS assignment below.
    pd.options.mode.chained_assignment = None  # default='warn'

    def _open_bounds(lo, hi):
        """Map a None bound to -inf / +inf so that side of the cut keeps every row."""
        return (-math.inf if lo is None else lo), (math.inf if hi is None else hi)

    # Without this, pandas compares `x < None` as False for every row, which empties the selection.
    metallicity_min, metallicity_max = _open_bounds(metallicity_min, metallicity_max)

    # Isolate the columns used for the cuts and the summary printout.
    columns = ['APSTAR_ID', 'TEFF_SPEC', 'LOGG_SPEC', 'SNR', 'ASPCAPFLAGS', 'STARFLAGS', 'FE_H']
    if abundance_flag is not None:
        columns.append(abundance_flag)
    star_df = df[columns]

    keep = ((star_df['TEFF_SPEC'] < teff_max) & (star_df['TEFF_SPEC'] > teff_min) &
            (star_df['LOGG_SPEC'] < logg_max) & (star_df['LOGG_SPEC'] > logg_min) &
            (star_df['SNR'] < snr_max) & (star_df['SNR'] > snr_min) &
            (star_df['FE_H'] < metallicity_max) & (star_df['FE_H'] > metallicity_min))
    if abundance_flag is not None:
        abundance_min, abundance_max = _open_bounds(abundance_min, abundance_max)
        keep = keep & (star_df[abundance_flag] < abundance_max) & (star_df[abundance_flag] > abundance_min)
    star_df_best = star_df.loc[keep]

    # Decode the byte-string flags and drop stars flagged STAR_BAD.
    star_df_best['ASPCAPFLAGS'] = star_df_best['ASPCAPFLAGS'].str.decode("utf-8")
    star_df_best = star_df_best.loc[~(star_df_best['ASPCAPFLAGS'].str.contains("STAR_BAD"))]

    # Standardize TEFF/LOGG/FE_H over the selected stars (zero mean, unit std per column).
    # The raw std devs are kept to put the errors on the same scale.
    parameters = star_df_best[["TEFF_SPEC", "LOGG_SPEC", "FE_H"]].to_numpy()
    parameter_std = parameters.std(axis=0)
    centered = parameters - parameters.mean(axis=0)
    parameters = centered / centered.std(axis=0)

    # TEFF/LOGG/FE_H errors from the global `star` table as an (N, 3) array,
    # restricted to the selected stars and scaled by the parameter std devs.
    parameter_errors = np.column_stack((star['teff_err'], star['logg_err'], star['fe_h_err']))
    parameter_errors = parameter_errors[star_df_best.index] / parameter_std

    # df's index labels are used as row positions into star_spectra and `star`
    # (assumes df has a default RangeIndex aligned with both - TODO confirm).
    # The three standardized parameters are appended after each spectrum.
    star_spectra_filtered = np.hstack((star_spectra[star_df_best.index], parameters))

    print("After applying data filters " + str(len(star_spectra_filtered)) + " spectra remaining")

    # Positional (0..n-1) index so rows line up with star_spectra_filtered.
    star_df_best_reset = star_df_best.reset_index(drop=True)

    # Reproducible random subset of at most `input_rows` stars. Narrow bins may hold
    # fewer stars than requested; sampling without replacement would then raise.
    n_available = len(star_spectra_filtered)
    n_rows = min(input_rows, n_available)
    if n_rows < input_rows:
        print("Only " + str(n_available) + " spectra pass the filters; requested " + str(input_rows))
    np.random.seed(random_seed)
    random_reduced_idx = list(np.random.choice(n_available, n_rows, replace=False))

    print("Reduced spectra count to " + str(n_rows))

    reduced_star_spectra = star_spectra_filtered[random_reduced_idx]
    reduced_parameter_errors = parameter_errors[random_reduced_idx]

    # Summary of the physical (unstandardized) parameters of the selected subset.
    reduced_star_df_best = star_df_best_reset.iloc[random_reduced_idx]
    print("Mean TEFF:", reduced_star_df_best["TEFF_SPEC"].mean())
    print("Mean LOGG:", reduced_star_df_best["LOGG_SPEC"].mean())
    print("Mean SNR:", reduced_star_df_best["SNR"].mean())
    print("Mean Fe/H:", reduced_star_df_best["FE_H"].mean())
    if abundance_flag is not None:
        print("Mean Abundance:", reduced_star_df_best[abundance_flag].mean())

    # Min-max normalize each row in place. The appended parameter columns are rescaled
    # together with the spectrum. NOTE(review): presumably this matches the preprocessing
    # used at training time - confirm. Constant rows are only shifted to zero (avoids 0/0).
    for star_row in reduced_star_spectra:
        star_row -= star_row.min()
        row_max = star_row.max()
        if row_max != 0:
            star_row /= row_max

    # Unshuffled loader over the whole subset. The random subset above already does the
    # sampling, so no sampler or shuffle is applied.
    test_loader = DataLoader(spectraDataset(reduced_star_spectra), batch_size=batch_size, num_workers=0)

    print('Batches in test:', len(test_loader))

    return test_loader, reduced_parameter_errors

In [None]:
# Metallicity ([Fe/H]) groups: three adjacent [Fe/H] bins sharing the same Teff/logg/SNR cuts.
input_rows = 10000
batch_size = 100

fe_h_cuts = dict(df=df, star_spectra=star_spectra, teff_min=4000, teff_max=5500,
                 logg_min=0, logg_max=3.5, snr_min=200, snr_max=math.inf)

low_fe_h_data, low_fe_h_errors = fetch_data_abundance(
    **fe_h_cuts, metallicity_min=-0.5, metallicity_max=-0.25)

medium_fe_h_data, medium_fe_h_errors = fetch_data_abundance(
    **fe_h_cuts, metallicity_min=-0.25, metallicity_max=0)

high_fe_h_data, high_fe_h_errors = fetch_data_abundance(
    **fe_h_cuts, metallicity_min=0, metallicity_max=0.25)

In [None]:
# Oxygen abundance ([O/Fe]) groups: three O_FE bins for stars with -0.25 < [Fe/H] < 0.25.
input_rows = 10000
batch_size = 100

o_fe_cuts = dict(df=df, star_spectra=star_spectra, teff_min=4000, teff_max=5500,
                 logg_min=0, logg_max=3.5, snr_min=200, snr_max=math.inf,
                 metallicity_min=-0.25, metallicity_max=0.25, abundance_flag='O_FE')

low_o_fe_data, low_o_fe_errors = fetch_data_abundance(
    **o_fe_cuts, abundance_min=-math.inf, abundance_max=0)

medium_o_fe_data, medium_o_fe_errors = fetch_data_abundance(
    **o_fe_cuts, abundance_min=0, abundance_max=0.05)

high_o_fe_data, high_o_fe_errors = fetch_data_abundance(
    **o_fe_cuts, abundance_min=0.05, abundance_max=0.2)

In [None]:
# Run each [Fe/H] group through the model (evaluated high -> medium -> low).
high_fe_h_array, medium_fe_h_array, low_fe_h_array = (
    feed_model(high_fe_h_data, high_fe_h_errors),
    feed_model(medium_fe_h_data, medium_fe_h_errors),
    feed_model(low_fe_h_data, low_fe_h_errors),
)

In [None]:
# Run each [O/Fe] group through the model (evaluated high -> medium -> low).
high_o_fe_array, medium_o_fe_array, low_o_fe_array = (
    feed_model(high_o_fe_data, high_o_fe_errors),
    feed_model(medium_o_fe_data, medium_o_fe_errors),
    feed_model(low_o_fe_data, low_o_fe_errors),
)

**Fe/H, O/Fe Track Sample Data**

In [None]:
# Track samples: up to 50 spectra per [Fe/H] bin, later averaged into one plotted point per bin.
input_rows = 50
batch_size = 50

fe_h_track_cuts = dict(df=df, star_spectra=star_spectra, teff_min=4800, teff_max=4900,
                       logg_min=3.2, logg_max=3.3, snr_min=150, snr_max=math.inf)

# Five consecutive 0.1-wide [Fe/H] bins from -0.5 to 0.
fe_h_sample_1, fe_h_sample_1_errors = fetch_data_abundance(
    **fe_h_track_cuts, metallicity_min=-0.5, metallicity_max=-0.4)

fe_h_sample_2, fe_h_sample_2_errors = fetch_data_abundance(
    **fe_h_track_cuts, metallicity_min=-0.4, metallicity_max=-0.3)

fe_h_sample_3, fe_h_sample_3_errors = fetch_data_abundance(
    **fe_h_track_cuts, metallicity_min=-0.3, metallicity_max=-0.2)

fe_h_sample_4, fe_h_sample_4_errors = fetch_data_abundance(
    **fe_h_track_cuts, metallicity_min=-0.2, metallicity_max=-0.1)

fe_h_sample_5, fe_h_sample_5_errors = fetch_data_abundance(
    **fe_h_track_cuts, metallicity_min=-0.1, metallicity_max=0)

In [None]:
# Track samples: up to 50 spectra per [O/Fe] bin, later averaged into one plotted point per bin.
input_rows = 50
batch_size = 50

# No metallicity cut here (the [Fe/H] range was originally -0.3 to 0.1).
o_fe_track_cuts = dict(df=df, star_spectra=star_spectra, teff_min=4000, teff_max=5500,
                       logg_min=3.2, logg_max=3.3, snr_min=150, snr_max=math.inf,
                       metallicity_min=-math.inf, metallicity_max=math.inf, abundance_flag='O_FE')

o_fe_sample_1, o_fe_sample_1_errors = fetch_data_abundance(
    **o_fe_track_cuts, abundance_min=-math.inf, abundance_max=-0.5)

o_fe_sample_2, o_fe_sample_2_errors = fetch_data_abundance(
    **o_fe_track_cuts, abundance_min=-0.5, abundance_max=0)

o_fe_sample_3, o_fe_sample_3_errors = fetch_data_abundance(
    **o_fe_track_cuts, abundance_min=0, abundance_max=0.25)

# Bin widths stop increasing evenly from here on: there were not enough stars for wider bins.
o_fe_sample_4, o_fe_sample_4_errors = fetch_data_abundance(
    **o_fe_track_cuts, abundance_min=0.25, abundance_max=0.35)

o_fe_sample_5, o_fe_sample_5_errors = fetch_data_abundance(
    **o_fe_track_cuts, abundance_min=0.35, abundance_max=math.inf)

# Earlier variant (not run): teff 4800-4900 with finer O_FE bin edges 0, 0.025, 0.05, 0.075, 0.12, inf.

In [None]:
# Run each [Fe/H] track bin through the model, in bin order 1-5.
(
    fe_h_sample_1_array,
    fe_h_sample_2_array,
    fe_h_sample_3_array,
    fe_h_sample_4_array,
    fe_h_sample_5_array,
) = [
    feed_model_single(sample, sample_errors)
    for sample, sample_errors in (
        (fe_h_sample_1, fe_h_sample_1_errors),
        (fe_h_sample_2, fe_h_sample_2_errors),
        (fe_h_sample_3, fe_h_sample_3_errors),
        (fe_h_sample_4, fe_h_sample_4_errors),
        (fe_h_sample_5, fe_h_sample_5_errors),
    )
]

In [None]:
# Run each [O/Fe] track bin through the model, in bin order 1-5.
(
    o_fe_sample_1_array,
    o_fe_sample_2_array,
    o_fe_sample_3_array,
    o_fe_sample_4_array,
    o_fe_sample_5_array,
) = [
    feed_model_single(sample, sample_errors)
    for sample, sample_errors in (
        (o_fe_sample_1, o_fe_sample_1_errors),
        (o_fe_sample_2, o_fe_sample_2_errors),
        (o_fe_sample_3, o_fe_sample_3_errors),
        (o_fe_sample_4, o_fe_sample_4_errors),
        (o_fe_sample_5, o_fe_sample_5_errors),
    )
]

In [None]:
# Mean latent vector for each [Fe/H] track bin (up to input_rows = 50 spectra each),
# computed by latentValAverager on the first element returned by feed_model_single.
(
    fe_h_sample_1_array_mean,
    fe_h_sample_2_array_mean,
    fe_h_sample_3_array_mean,
    fe_h_sample_4_array_mean,
    fe_h_sample_5_array_mean,
) = [
    latentValAverager(sample_array[0])
    for sample_array in (
        fe_h_sample_1_array,
        fe_h_sample_2_array,
        fe_h_sample_3_array,
        fe_h_sample_4_array,
        fe_h_sample_5_array,
    )
]

In [None]:
# Mean latent vector for each [O/Fe] track bin (up to input_rows = 50 spectra each),
# computed by latentValAverager on the first element returned by feed_model_single.
(
    o_fe_sample_1_array_mean,
    o_fe_sample_2_array_mean,
    o_fe_sample_3_array_mean,
    o_fe_sample_4_array_mean,
    o_fe_sample_5_array_mean,
) = [
    latentValAverager(sample_array[0])
    for sample_array in (
        o_fe_sample_1_array,
        o_fe_sample_2_array,
        o_fe_sample_3_array,
        o_fe_sample_4_array,
        o_fe_sample_5_array,
    )
]

**Metallicity Track Plot**

In [None]:
# Full corner plot of the high/medium/low [Fe/H] groups, with numeric axis labels 1..31.
ndim = 31

figure_fe_h = corner.corner(np.array(high_fe_h_array), color="green", labels=list(range(1, ndim + 1)))
corner.corner(np.array(medium_fe_h_array), color="yellow", fig=figure_fe_h)
corner.corner(np.array(low_fe_h_array), color="red", fig=figure_fe_h)
# Figure-level title. plt.title would title only the current Axes (a single panel),
# not the figure. The crop cells use suptitle for the same reason.
figure_fe_h.suptitle("High/Medium/Low Metallicity (Fe/H) Latent Space Comparison")

# Grid of the corner plot's axes (row = y dimension, column = x dimension).
axes = np.array(figure_fe_h.axes).reshape((ndim, ndim))

# Latent means of the five [Fe/H] track bins, in bin order 1-5.
fe_h_track_means = [
    fe_h_sample_1_array_mean,
    fe_h_sample_2_array_mean,
    fe_h_sample_3_array_mean,
    fe_h_sample_4_array_mean,
    fe_h_sample_5_array_mean,
]

# Overlay the track on each off-diagonal (scatter) panel: markers first,
# then the bin number "1".."5" next to each mean.
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        for means in fe_h_track_means:
            ax.plot(means[xi], means[yi], marker="o", markersize=2, color="darkviolet")
        for bin_number, means in enumerate(fe_h_track_means, start=1):
            ax.annotate(str(bin_number), xy=(means[xi], means[yi]), fontsize=8)

figure_fe_h

In [None]:
# Test Plot for Selecting categories
from matplotlib.lines import Line2D

# Cropped view of latent columns 16-20 (0-based), labelled "VAE 17".."VAE 21" (1-based).
# Offset = first index in crop (used by the overlay loop below).
offset = 16
ndim = 5
crop_cols = list(range(offset, offset + ndim))
crop_labels = [f"VAE {col + 1}" for col in crop_cols]

figure_fe_h_crop = corner.corner(np.array(high_fe_h_array)[:, crop_cols], color="green", labels=crop_labels)
corner.corner(np.array(medium_fe_h_array)[:, crop_cols], color="yellow", fig=figure_fe_h_crop)
corner.corner(np.array(low_fe_h_array)[:, crop_cols], color="red", fig=figure_fe_h_crop)

figure_fe_h_crop.suptitle("High/Medium/Low [Fe/H] Latent Space Comparison")

legend_elements = [Line2D([0], [0], marker='o', color='green', label='High', markerfacecolor='green', markersize=10),
                   Line2D([0], [0], marker='o', color='yellow', label='Medium', markerfacecolor='yellow', markersize=10),
                   Line2D([0], [0], marker='o', color='red', label='Low', markerfacecolor='red', markersize=10),
                   Line2D([0], [0], marker='o', color='blue', label='Track', markerfacecolor='dodgerblue', markersize=10)
                   ]

plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.25, 2))

# Grid of the cropped corner plot's axes (row = y dimension, column = x dimension).
axes = np.array(figure_fe_h_crop.axes).reshape((ndim, ndim))

# Loop over scatter plots
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        #print("ypos:",yi+1,"xpos:",xi+1)

        # ax.text(0.15, 0.15,(xi,yi),
        # horizontalalignment='center',
        # verticalalignment='center',
        # transform = ax.transAxes)
        # ax.annotate((xi+1,yi+1), xy=(0,0))

        ax.plot(fe_h_sample_1_array_mean[xi+offset], fe_h_sample_1_array_mean[yi+offset], marker="o", markerfacecolor="dodgerblue",markersize=4, color="blue")
        ax.plot(fe_h_sample_2_array_mean[xi+offset], fe_h_sample_2_array_mean[yi+offset], marker="o", markerfacecolor="dodgerblue",markersize=4, color="blue")
        ax.plot(fe_h_sample_3_array_mean[xi+offset], fe_h_sample_3_array_mean[yi+offset], marker="o", markerfacecolor="dodgerblue",markersize=4, color="blue")
        ax.plot(fe_h_sample_4_array_mean[xi+offset], fe_h_sample_4_array_mean[yi+offset], marker="o", markerfacecolor="dodgerblue",markersize=4, color="blue")
        ax.plot(fe_h_sample_5_array_mean[xi+offset], fe_h_sample_5_array_mean[yi+offset], marker="o", markerfacecolor="dodgerblue",markersize=4, color="blue")

        # Add text label
        ax.annotate("1",xy=(fe_h_sample_1_array_mean[xi+offset], fe_h_sample_1_array_mean[yi+offset]), fontsize=12)
        ax.annotate("2",xy=(fe_h_sample_2_array_mean[xi+offset], fe_h_sample_2_array_mean[yi+offset]), fontsize=12)
        ax.annotate("3",xy=(fe_h_sample_3_array_mean[xi+offset], fe_h_sample_3_array_mean[yi+offset]), fontsize=12)
        ax.annotate("4",xy=(fe_h_sample_4_array_mean[xi+offset], fe_h_sample_4_array_mean[yi+offset]), fontsize=12)
        ax.annotate("5",xy=(fe_h_sample_5_array_mean[xi+offset], fe_h_sample_5_array_mean[yi+offset]), fontsize=12)

plt.show()

**Oxygen Abundance Track Plot**

In [None]:
# Full latent-space corner plot for the three [O/Fe] bins (high = green,
# medium = yellow, low = red), with the five track samples overlaid.

# Take the grid size from the data rather than hard-coding it, so the labels
# and the axes reshape below always match the number of columns plotted.
ndim = np.array(high_o_fe_array).shape[1]

figure_o_fe = corner.corner(np.array(high_o_fe_array), color="green", labels=list(range(1, ndim + 1)))
corner.corner(np.array(medium_o_fe_array), color="yellow", fig=figure_o_fe)
corner.corner(np.array(low_o_fe_array), color="red", fig=figure_o_fe)
# plt.title would only label the current (last) subplot; title the whole figure.
figure_o_fe.suptitle("High/Medium/Low Oxygen Abundance (O/Fe) Latent Space Comparison")

axes = np.array(figure_o_fe.axes).reshape((ndim, ndim))

# Mean latent vectors of the five [O/Fe] track samples; position in the list
# (starting at 1) is the number annotated next to each point.
o_fe_tracks = [
    o_fe_sample_1_array_mean,
    o_fe_sample_2_array_mean,
    o_fe_sample_3_array_mean,
    o_fe_sample_4_array_mean,
    o_fe_sample_5_array_mean,
]

# Overlay the tracks on the lower-triangle scatter panels only (xi < yi).
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        for track in o_fe_tracks:
            ax.plot(track[xi], track[yi], marker="o", markersize=2, color="darkviolet")
        for number, track in enumerate(o_fe_tracks, start=1):
            ax.annotate(str(number), xy=(track[xi], track[yi]), fontsize=8)

figure_o_fe

In [None]:
# Cropped corner plot of selected latent dimensions for the three [O/Fe] bins,
# with track samples 2-5 overlaid.
from matplotlib.lines import Line2D

# Zero-based latent columns to display. Labels, grid size and track indexing
# are all derived from this single list so they cannot drift apart.
o_fe_crop_cols = [13, 14, 15, 16, 17, 18]

figure_o_fe_crop = corner.corner(
    np.array(high_o_fe_array)[:, o_fe_crop_cols],
    color="green",
    labels=[f"VAE {col + 1}" for col in o_fe_crop_cols],
)
corner.corner(np.array(medium_o_fe_array)[:, o_fe_crop_cols], color="yellow", fig=figure_o_fe_crop)
corner.corner(np.array(low_o_fe_array)[:, o_fe_crop_cols], color="red", fig=figure_o_fe_crop)

figure_o_fe_crop.suptitle("High/Medium/Low [O/Fe] Latent Space Comparison")

legend_elements = [Line2D([0], [0], marker='o', color='green', label='High', markerfacecolor='green', markersize=10),
                   Line2D([0], [0], marker='o', color='yellow', label='Medium', markerfacecolor='yellow', markersize=10),
                   Line2D([0], [0], marker='o', color='red', label='Low', markerfacecolor='red', markersize=10),
                   Line2D([0], [0], marker='o', color='blue', label='Track', markerfacecolor='dodgerblue', markersize=10)
                   ]

# plt.legend attaches to the current axes; bbox_to_anchor is in that axes' coordinates.
plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.25, 2))

ndim = len(o_fe_crop_cols)
# Extract the axes as an ndim x ndim grid
axes = np.array(figure_o_fe_crop.axes).reshape((ndim, ndim))

# Still defined in case later cells read it; the loop below indexes through
# o_fe_crop_cols instead, so non-contiguous crops also work.
offset = o_fe_crop_cols[0]

# (annotation text, mean latent vector) for each track shown.
# Track 1 is omitted from this cropped view.
o_fe_crop_tracks = [
    ("2", o_fe_sample_2_array_mean),
    ("3", o_fe_sample_3_array_mean),
    ("4", o_fe_sample_4_array_mean),
    ("5", o_fe_sample_5_array_mean),
]

# Overlay the tracks on the lower-triangle scatter panels only (xi < yi).
for yi in range(ndim):
    for xi in range(yi):
        ax = axes[yi, xi]
        col_x, col_y = o_fe_crop_cols[xi], o_fe_crop_cols[yi]
        for _, track in o_fe_crop_tracks:
            ax.plot(track[col_x], track[col_y], marker="o", markerfacecolor="dodgerblue", markersize=4, color="blue")
        for text, track in o_fe_crop_tracks:
            ax.annotate(text, xy=(track[col_x], track[col_y]), fontsize=12)

plt.show()

**Network Visualization**

In [None]:
!pip install graphviz
!pip install torchviz

In [None]:
# Render the autograd graph of a VAE forward pass with torchviz.
from torchviz import make_dot

# Shrink the data settings to a single sample for the visualization.
# NOTE(review): these overwrite the notebook-wide hyperparameters set earlier;
# presumably fetch_data_single reads them — confirm, and reset them before
# re-running any training cells.
input_rows = 1
batch_size = 1

dummy_sample_1 = fetch_data_single(
    df=df,
    star_spectra=star_spectra,
    teff_min=4800,
    teff_max=4900,
    logg_min=3.2,
    logg_max=3.3,
    snr_min=150,
    snr_max=math.inf,
)

# Forward each batch (flattened to (batch_size, x_dim) on DEVICE); the x_hat
# left over from the final iteration is the root of the drawn graph.
for batch_idx, x in enumerate(tqdm(dummy_sample_1)):
    x = x.view(batch_size, x_dim).to(DEVICE)
    x_hat, mean, log_var, z = model(x)

named_params = {name: param for name, param in model.named_parameters()}
make_dot(x_hat, params=named_params)