# Training and validating an LSTM + MLP model with the WHXE loss function

## Imports

In [21]:
import time
import pickle
import platform
import os
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from LSTM_model import get_LSTM_Classifier
from dataloader import LSSTSourceDataSet, load, get_augmented_data, get_static_features, ts_length
from loss import WHXE_Loss
from taxonomy import get_taxonomy_tree, get_prediction_probs, get_highest_prob_path, plot_colored_tree
from vizualizations import make_gif, plot_confusion_matrix, plot_roc_curves
from interpret_results import get_conditional_probabilites, get_all_confusion_matrices

from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from tensorflow import keras

In [7]:
# Record the library versions this notebook was last run with.
for version_label, version_value in (
    ("Tensorflow version", tf.__version__),
    ("Python version", platform.python_version()),
):
    print(version_label, version_value)

Tensorflow version 2.15.0
Python version 3.10.12


## Load and balance the tensors:

This step takes a while because it has to load the data from disk into memory.

In [8]:
# Load the pre-processed training data from disk (time series, static features,
# label vectors and astrophysical class names), in that order.
# NOTE(review): `load` presumably unpickles these .pkl files — only use trusted inputs.
X_ts, X_static, Y, astrophysical_classes = (
    load(pickle_path)
    for pickle_path in (
        "processed/train/x_ts.pkl",
        "processed/train/x_static.pkl",
        "processed/train/y.pkl",
        "processed/train/a_labels.pkl",
    )
)

Small step to convert X_static from a dictionary to an array

In [9]:
# NOTE: `static_list` is currently only referenced by the commented-out
# `feature_list=` argument below, so get_static_features is called without it.
static_list = ['MWEBV', 'MWEBV_ERR', 'HOSTGAL_PHOTOZ', 'HOSTGAL_PHOTOZ_ERR', 'HOSTGAL_SPECZ', 'HOSTGAL_SPECZ_ERR']
n_sources = len(X_static)
for idx in range(n_sources):
    # Progress indicator, overwritten in place every 1000 sources.
    if not idx % 1000:
        print(f"{(idx/n_sources * 100):.3f} %", end="\r")
    # Replace each per-source static record with its feature vector, in place.
    X_static[idx] = get_static_features(X_static[idx])  # , feature_list=static_list)

99.943 %

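Balance the data set by capping each class at `max_class_count` samples. The first samples of each class in load order are kept, and smaller classes are kept in full.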

In [10]:
# Cap every class at `max_class_count` samples to reduce class imbalance.
# NOTE: the *first* max_class_count samples of each class (in load order) are kept,
# not a random subset — if the pickles are ordered meaningfully this introduces bias.
max_class_count = 13000

X_ts_balanced = []
X_static_balanced = []
Y_balanced = []
astrophysical_classes_balanced = []

# Build the label array once, instead of re-converting the whole list for every class.
_class_arr = np.asarray(astrophysical_classes)

for c in np.unique(_class_arr):

    # Indices of this class, truncated to at most max_class_count entries.
    idx = np.flatnonzero(_class_arr == c)[:max_class_count]

    X_ts_balanced += [X_ts[i] for i in idx]
    X_static_balanced += [X_static[i] for i in idx]
    Y_balanced += [Y[i] for i in idx]
    astrophysical_classes_balanced += [astrophysical_classes[i] for i in idx]

# Drop the temporary array so the later `del` of the raw data really frees that memory.
del _class_arr

# Summary of the data set used for training and validation
class_names, class_counts = np.unique(astrophysical_classes_balanced, return_counts=True)
data_summary = pd.DataFrame(data={'Class': class_names, 'Count': class_counts})
data_summary

Unnamed: 0,Class,Count
0,AGN,13000
1,CART,8207
2,Cepheid,13000
3,Delta Scuti,13000
4,Dwarf Novae,8025
5,EB,13000
6,ILOT,7461
7,KN,4426
8,M-dwarf Flare,1859
9,PISN,13000


Free up some memory

In [11]:
# The un-balanced originals are no longer needed; only the *_balanced copies are used below.
del X_ts
del X_static
del Y
del astrophysical_classes

Split into train and validation

In [12]:
# Hold out 5% of the balanced data for validation (shuffled, fixed seed for reproducibility).
val_fraction = 0.05
(
    X_ts_train, X_ts_val,
    X_static_train, X_static_val,
    Y_train, Y_val,
    astrophysical_classes_train, astrophysical_classes_val,
) = train_test_split(
    X_ts_balanced,
    X_static_balanced,
    Y_balanced,
    astrophysical_classes_balanced,
    test_size=val_fraction,
    shuffle=True,
    random_state=40,
)

Free up some more memory

In [13]:
# Everything now lives in the train/val splits; drop the pre-split lists.
del X_ts_balanced
del X_static_balanced
del Y_balanced
del astrophysical_classes_balanced

## Define the Loss function, criterion, etc.

In [14]:
# Loss and optimizer
tree = get_taxonomy_tree()
loss_object = WHXE_Loss(tree, astrophysical_classes_train, alpha=0) 
criterion = loss_object.compute_loss

In [15]:
# Adam optimiser with a fixed learning rate of 1e-3 (0.001).
optimizer = keras.optimizers.Adam(learning_rate=0.001)

2024-06-11 04:17:44.408893: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31127 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB, pci bus id: 0000:3b:00.0, compute capability: 7.0
2024-06-11 04:17:44.410314: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31127 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB, pci bus id: 0000:89:00.0, compute capability: 7.0


In [16]:
# Inputs for model
# Training schedule
num_epochs = 50
batch_size = 1024

# Model input/output sizes
ts_dim = 5                            # features per time step — TODO document which
static_dim = len(X_static_train[0])   # length of one static-feature vector
output_dim = 26                       # presumably one output per taxonomy-tree node — TODO confirm
latent_size = 64

## Train the classifier using WHXE loss and save the model

In [12]:
model = get_LSTM_Classifier(ts_dim, static_dim, output_dim, latent_size)

# Save an architecture diagram (tensor shapes + layer names) to lstm.pdf.
plot_model(model, to_file='lstm.pdf', show_layer_names=True, show_shapes=True)
plt.close()

In [14]:
@tf.function
def train_step(x_ts, x_static, y):
    """Run one optimisation step on a single batch and return its loss.

    Computes ``criterion(y, model((x_ts, x_static)))`` under a GradientTape and
    applies the gradients to ``model.trainable_weights`` with ``optimizer``.
    """
    with tf.GradientTape() as tape:
        logits = model((x_ts, x_static), training=True)
        loss_value = criterion(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    return loss_value


# model.save below fails if the output directory does not exist.
os.makedirs("models", exist_ok=True)

avg_train_losses = []
for epoch in range(num_epochs):

    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    print("Augmenting time series lengths...")

    # Re-create the augmented training set every epoch.
    X_ts_train_aug, X_static_train_aug, Y_train_aug, astrophysical_classes_train_aug = get_augmented_data(
        X_ts_train, X_static_train, Y_train, astrophysical_classes_train
    )
    # The class-name strings are not needed by train_step, so keep them out of the dataset.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (X_ts_train_aug, X_static_train_aug, Y_train_aug)
    ).batch(batch_size)

    # Per-batch mean losses for this epoch, plus a sample-weighted running total so the
    # epoch average covers every batch (the final batch may be smaller than batch_size).
    train_loss_values = []
    epoch_loss_sum = 0.0
    n_seen = 0

    # Iterate over the batches of the dataset.
    for x_ts_batch_train, x_static_batch_train, y_batch_train in train_dataset:
        loss_value = train_step(x_ts_batch_train, x_static_batch_train, y_batch_train)
        batch_loss = float(np.mean(loss_value))  # np.mean also reduces a per-sample loss vector
        batch_n = int(y_batch_train.shape[0])
        train_loss_values.append(batch_loss)
        epoch_loss_sum += batch_loss * batch_n
        n_seen += batch_n

    # Log the avg train loss over the whole epoch (NaN if the dataset was empty).
    avg_train_loss = epoch_loss_sum / n_seen if n_seen else float("nan")
    avg_train_losses.append(avg_train_loss)
    print(f"Avg training loss: {float(avg_train_loss):.4f}")

    print(f"Time taken: {time.time() - start_time:.2f}s")
    model.save(f"models/lstm_epoch_{epoch}.h5")


Start of epoch 0
Augmenting time series lengths...
19.801 %


KeyboardInterrupt



In [None]:
# Plot the log of the per-epoch average training loss.
# Use the number of epochs actually completed: training may have been interrupted
# before num_epochs, and mismatched x/y lengths would make plot() raise.
fig, ax = plt.subplots()
ax.plot(range(len(avg_train_losses)), np.log(avg_train_losses))
ax.set_xlabel("Train Epoch")
ax.set_ylabel("Train log loss")
plt.show()

## Load the saved model and validate that everything looks okay

In [17]:
load_model_epoch = 43
saved_model_path = f"models/lstm_epoch_{load_model_epoch}.h5"

# compile=False: the model is only used for inference below, so no training
# configuration (loss/optimizer) is restored.
saved_model = keras.models.load_model(saved_model_path, compile=False)

In [18]:
# Fractions of each light curve made visible during evaluation: 0.1, 0.2, ..., 0.9 and
# the full curve. k / 10 is correctly rounded, so it equals the literals 0.1 ... 0.9;
# the final entry stays the int 1 (it appears verbatim in output file names).
fractions = [k / 10 for k in range(1, 10)] + [1]

In [30]:
# Root-level evaluation (Transient vs. Variable) for each visible fraction of the light
# curves. Columns 1:3 of the model outputs / label vectors pair with level_order_nodes[1:3].
root_labels = list(loss_object.level_order_nodes)[1:3]
root_label_ids = np.arange(len(root_labels))

# The plotting helpers write into these directories; make sure they exist.
os.makedirs("gif/root_cf", exist_ok=True)
os.makedirs("gif/root_roc", exist_ok=True)

for f in fractions:

    # round() rather than int(): int() truncates float error (e.g. 0.29 * 100 -> 28).
    pct = int(round(f * 100))
    print(f'Running inference for {pct}% light curves...')

    x1, x2, y_true, _ = get_augmented_data(X_ts_val, X_static_val, Y_val, astrophysical_classes_val, fraction=f)

    # Run inference on these
    y_pred = saved_model.predict([x1, x2])

    # Replace the raw root-level outputs with the tree-derived pseudo-probabilities.
    for i in range(y_pred.shape[0]):
        pseudo_probs, _ = get_prediction_probs(y_pred[[i], :])
        y_pred[i, 1:3] = pseudo_probs[0, 1:3]

    y_pred_label = np.argmax(y_pred[:, 1:3], axis=1)
    y_true_label = np.argmax(y_true[:, 1:3], axis=1)

    # Print the stats. Passing `labels` keeps the report valid even if one class is
    # absent from both y_true and y_pred at this fraction.
    print(f'For {pct}% of the light curve, these are the statistics')
    report = classification_report(y_true_label, y_pred_label, labels=root_label_ids, target_names=root_labels)
    print(report)

    # Make plots (title uses the rounded percentage, not e.g. "30.000000000000004%").
    plot_title = f"~{pct}% of each LC visible"
    cf_plot_file = f"gif/root_cf/{f}.png"
    roc_plot_file = f"gif/root_roc/{f}.png"

    plot_confusion_matrix(y_true_label, y_pred_label, root_labels, plot_title, cf_plot_file)
    plt.close()
    plot_roc_curves(y_true[:, 1:3], y_pred[:, 1:3], root_labels, plot_title, roc_plot_file)
    plt.close()

Running inference for 10% light curves...
For 10% of the light curve, these are the statistics
              precision    recall  f1-score   support

   Transient       0.91      1.00      0.95      7312
    Variable       1.00      0.78      0.87      3287

    accuracy                           0.93     10599
   macro avg       0.95      0.89      0.91     10599
weighted avg       0.94      0.93      0.93     10599

Running inference for 20% light curves...
For 20% of the light curve, these are the statistics
              precision    recall  f1-score   support

   Transient       0.95      1.00      0.97      7312
    Variable       1.00      0.88      0.93      3287

    accuracy                           0.96     10599
   macro avg       0.97      0.94      0.95     10599
weighted avg       0.96      0.96      0.96     10599

Running inference for 30% light curves...
For 30% of the light curve, these are the statistics
              precision    recall  f1-score   support

   Tra

In [31]:
# Leaf-level evaluation for each visible fraction of the light curves.
# The last N_LEAF_CLASSES columns of the model outputs / label vectors pair with the
# last N_LEAF_CLASSES entries of level_order_nodes (the leaf classes).
N_LEAF_CLASSES = 19
leaf_labels = list(loss_object.level_order_nodes)[-N_LEAF_CLASSES:]
leaf_label_ids = np.arange(len(leaf_labels))

# The plotting helpers write into these directories; make sure they exist.
os.makedirs("gif/leaf_cf", exist_ok=True)
os.makedirs("gif/leaf_roc", exist_ok=True)

for f in fractions:

    # round() rather than int(): int() truncates float error (e.g. 0.29 * 100 -> 28).
    pct = int(round(f * 100))
    print(f'Running inference for {pct}% light curves...')

    x1, x2, y_true, _ = get_augmented_data(X_ts_val, X_static_val, Y_val, astrophysical_classes_val, fraction=f)

    # Run inference on these
    y_pred = saved_model.predict([x1, x2])

    # Replace the raw leaf outputs with the leaf probabilities from get_highest_prob_path.
    for i in range(y_pred.shape[0]):
        _, weighted_tree = get_prediction_probs(y_pred[[i], :])
        leaf_prob, _ = get_highest_prob_path(weighted_tree)
        y_pred[i, -N_LEAF_CLASSES:] = leaf_prob

    y_pred_label = np.argmax(y_pred[:, -N_LEAF_CLASSES:], axis=1)
    y_true_label = np.argmax(y_true[:, -N_LEAF_CLASSES:], axis=1)

    # Print the stats. Passing `labels` keeps the report valid even when a (small) class
    # is absent from both y_true and y_pred at this fraction.
    print(f'For {pct}% of the light curve, these are the statistics')
    report = classification_report(y_true_label, y_pred_label, labels=leaf_label_ids, target_names=leaf_labels)
    print(report)

    # Make plots (title uses the rounded percentage, not e.g. "30.000000000000004%").
    plot_title = f"~{pct}% of each LC visible"
    cf_plot_file = f"gif/leaf_cf/{f}.png"
    roc_plot_file = f"gif/leaf_roc/{f}.png"

    plot_confusion_matrix(y_true_label, y_pred_label, leaf_labels, plot_title, cf_plot_file)
    plt.close()
    plot_roc_curves(y_true[:, -N_LEAF_CLASSES:], y_pred[:, -N_LEAF_CLASSES:], leaf_labels, plot_title, roc_plot_file)
    plt.close()

Running inference for 10% light curves...
For 10% of the light curve, these are the statistics
               precision    recall  f1-score   support

          AGN       1.00      0.49      0.66       667
         SNIa       0.35      0.35      0.35       672
       SNIb/c       0.33      0.14      0.19       669
        SNIax       0.37      0.15      0.22       667
      SNI91bg       0.57      0.45      0.50       629
         SNII       0.24      0.49      0.32       676
           KN       0.24      0.75      0.36       190
  Dwarf Novae       0.35      0.99      0.52       409
        uLens       0.66      0.39      0.49       664
M-dwarf Flare       0.60      0.08      0.15       109
         SLSN       0.58      0.65      0.62       617
          TDE       0.35      0.46      0.40       617
         ILOT       0.43      0.43      0.43       354
         CART       0.24      0.27      0.26       415
         PISN       0.81      0.63      0.71       624
      Cepheid       0.77

## Making a cool animation:

In [32]:
# Stitch the per-fraction root-level confusion matrices into one animation.
cf_files = list(map("gif/root_cf/{}.png".format, fractions))
make_gif(cf_files, 'gif/root_cf/root_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [33]:
# Stitch the per-fraction root-level ROC curves into one animation.
roc_files = list(map("gif/root_roc/{}.png".format, fractions))
make_gif(roc_files, 'gif/root_roc/root_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [34]:
# Stitch the per-fraction leaf-level confusion matrices into one animation.
cf_files = list(map("gif/leaf_cf/{}.png".format, fractions))
make_gif(cf_files, 'gif/leaf_cf/leaf_cf.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


In [35]:
# Stitch the per-fraction leaf-level ROC curves into one animation.
roc_files = list(map("gif/leaf_roc/{}.png".format, fractions))
make_gif(roc_files, 'gif/leaf_roc/leaf_roc.gif')
plt.close()

MovieWriter ffmpeg unavailable; using Pillow instead.


For the love of everything that is good in this world, please use a different notebook for testing and generating statistics. Keep this notebook simple enough to be converted into a script.

In [None]:
# For each visible fraction: re-run inference on the validation set and draw the
# confusion matrices produced by interpret_results.get_all_confusion_matrices.
for f in fractions:
    pct_label = int(f*100)
    print(f'Running inference for {pct_label}% light curves...')

    x1, x2, y_true, _ = get_augmented_data(
        X_ts_val, X_static_val, Y_val, astrophysical_classes_val, fraction=f
    )
    y_pred = saved_model.predict([x1, x2])

    # NOTE(review): these conditional probabilities are computed but never used below.
    _, pseudo_conditional_probabilities = get_conditional_probabilites(y_pred, tree)

    print(f'For {pct_label}% of the light curve, these are the statistics')
    get_all_confusion_matrices(y_true, y_pred, tree)
    plt.show()