## Train a PAE on labeled galaxy spectra

In [1]:
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from spectra_pae.spectra_pae import *

Note: To fully understand what this code does under the hood, read the paper and look at the `Spectra_PAE` class as well as the `Autoencoder` class in the PytorchPAE package. All of these are well documented.

### set parameters

In [2]:
# seed for the global random number generators
SEED               = 287505
# SEED was previously defined but never applied. Seed the global NumPy and PyTorch
# generators so that a fresh "Restart & Run All" starts from the same RNG state.
# NOTE(review): whether Spectra_PAE performs additional seeding of its own is not
# visible from this notebook -- confirm against the class documentation.
np.random.seed(SEED)
torch.manual_seed(SEED)

# dataset name (if you want to add a new dataset, it must be added to the PytorchPAE package in custom_datasets.py)
dataset_name       = 'SDSS_DR16'
# dataset directory
data_dir           = '/global/cscratch1/sd/vboehm/Datasets/sdss/by_model'
# directory for saving trained models
model_dir          = '/global/cscratch1/sd/vboehm/SDSSOutlier/fc'
# dimensionality of the input data
input_dim          = (1000,1)

### Instantiate the class

In [3]:
# all functionalities are described in the class documentation
# Pass the `dataset_name` set in the parameter cell instead of repeating the literal,
# so that changing the parameter there actually takes effect here.
SPAE = Spectra_PAE(data_dir, model_dir, dataset_name=dataset_name, input_dim=input_dim)

### Train the full model. This goes through all training steps described in the publication

In [4]:
# Train the complete model in a single call; the log printed below lists the stages it runs through.
SPAE.train_complete_model(
    nepochs=100,
    use_prior=True,
)

training AE stage 1...
AE stage 1 training completed.
evaluating AE stage 1 train...
evaluating AE stage 1 valid...
evaluating AE stage 1 test...
training AE stage 2...
AE stage 2 training completed.
evaluating AE stage 2 train...
evaluating AE stage 2 valid...
evaluating AE stage 2 test...
loading trained NF1...
computing prior probabilities...
classifying...
classifying...
classifying...
loading trained NF2...


### get log probability of all spectra in the combined validation set under the most likely label

In [7]:
# The initial dataset is divided into training, validation and test sets. For the
# publication, however, the validation and test sets are combined into one dataset.
# The shared data are organised as:
#   training set   = training set
#   validation set = validation + test set
#   test set       = test set
# The original split between validation and test set can be recovered from this.
logps = SPAE.evaluate_NF2(
    SPAE.NF1_data['valid'],
    SPAE.new_labels['valid'],
)

### evaluate the rank (in terms of percentile) of a single spectrum with respect to a reference sample.


In [8]:
# The reference sample here is the validation set, not the training set. This avoids
# biases from potential overfitting: overfitting is not penalized during training as long
# as the validation loss keeps improving. Early stopping is based on the validation loss
# not improving, not the training loss!
# The spectrum being ranked is the first entry of the validation set (slice 0:1).
# NOTE(review): this call uses SPAE.labels, whereas the log-probability cell above uses
# SPAE.new_labels -- confirm that the difference is intended.
rank = SPAE.evaluate_logp_percentile(
    SPAE.NF1_data['valid'],
    SPAE.labels['valid'],
    SPAE.NF1_data['valid'][0:1],
    SPAE.labels['valid'][0:1],
)