## Import libraries

In [None]:
# Import custom libraries from local folder.
import sys
sys.path.append("..")

# Import nn module from torch to replicate kessler tool
import torch.nn as nn

# Import utils library containing miscellaneous functions/classes
from scalib import utils

# Import the function that loads the Kelvins challenge data
from scalib.eda import kelvins_challenge_events

# Import SCALIB modules for NN development
import scalib.cdm as cdm            # Conjunction Data Messages
import scalib.event as event        # Conjunction Events
import scalib.nn as snn             # NN models
import scalib.cells as rnn_cells    # RNN cell architectures

# Import kessler RNN model
from scalib.kessler import LSTMPredictor

# Set overall seed for reproducibility
utils.seed(1)

# Import matplotlib library and setup environment for plots
%matplotlib inline
%config InlineBackend.figure_format='retina'
from matplotlib import rc

# Set rendering parameters to use TeX font if not working on Juno app.
if not '/private/var/' in utils.cwd:
    rc('font', **{'family': 'serif', 'serif': ['Computer Modern'], 'size': 11})
    rc('text', usetex=True)

print(utils.cwd)

## Data preparation

In [None]:
# As an example, we first show the case in which the data comes from the
# Kelvins competition. A dedicated converter takes care of the conversion from
# the Kelvins format to the standard CDM format (the data can be downloaded at
# https://kelvins.esa.int/collision-avoidance-challenge/data/).
filepath = utils.os.path.join(utils.cwd, 'data', 'esa-challenge',
                              'train_data.csv')

# Get ConjunctionEventsDataset object, dropping the RCS estimate columns and
# limiting the load to num_events events.
events = kelvins_challenge_events(filepath,
                                  drop_features=['c_rcs_estimate',
                                                 't_rcs_estimate'],
                                  num_events=1000)

# Get the numeric features used to train the model (presumably those common
# to all events, per `common_features`).
nn_features = events.common_features(only_numeric=True)

# The RNN model reads and predicts the same feature vector, so input and
# output sizes are both the number of features.
input_size = len(nn_features)
output_size = len(nn_features)

# Split data into a test set (5% of the total number of events) and a
# training/validation set (the remaining events). An explicit split index is
# used instead of negative slicing: if the 5% rounds down to 0, `events[-0:]`
# would select every event for testing and `events[:-0]` none for training.
len_test_set = int(0.05 * len(events))
split_idx = len(events) - len_test_set

# Get Events to test model: used to compute the error the model would have in
# run-mode.
events_test = events[split_idx:]
print('\nTest data:', events_test)

# Get events used for training and validation:
# - Training set: Used to train the model and backpropagate the loss.
# - Validation set: Used to compute the loss so that hyperparameters can be
#   adjusted.
events_train_and_val = events[:split_idx]
print('Training and validation data:', events_train_and_val)

## Recurrent Neural Network model configuration

### RNN layer and cell architecture definition

#### Kessler configuration

In [None]:
# Initialize the LSTM architecture with PyTorch's built-in nn.LSTM (the torch
# layer used to replicate the Kessler tool), not a SCALIB custom cell.
# nn.LSTM applies dropout=0.2 to the outputs of every stacked layer except the
# last one.
# NOTE(review): nn.LSTM defaults to batch_first=False, i.e. (seq, batch,
# feature) inputs; confirm this matches the tensors ConjunctionEventForecaster
# feeds to its layers.
layers = nn.ModuleDict({'lstm': nn.LSTM(input_size=input_size,
                                        hidden_size=256,
                                        num_layers=2,
                                        dropout=0.2)})


#### LSTM layer with *vanilla* cell architecture

In [None]:
# Two stacked LSTM layers built with SCALIB's LSTM wrapper and the vanilla
# LSTM cell (rnn_cells.LSTM_Vanilla), hidden size 256, dropout 0.2.
layers = nn.ModuleDict(dict(
    lstm=snn.LSTM(input_size=input_size,
                  hidden_size=256,
                  cell=rnn_cells.LSTM_Vanilla,
                  num_layers=2,
                  dropout=0.2),
))

#### LSTM layer with *SLIMx* cell architecture (*x* = 1, 2, or 3)

In [None]:
# Two stacked LSTM layers built with SCALIB's LSTM wrapper and the SLIMx cell.
# The SLIM variant (1, 2 or 3) is chosen via `slim_version`, which is forwarded
# to snn.LSTM as an extra keyword argument.
cell_args = {'slim_version': 1}
layers = nn.ModuleDict(dict(
    lstm=snn.LSTM(input_size=input_size,
                  hidden_size=256,
                  cell=rnn_cells.LSTM_SLIMX,
                  num_layers=2,
                  dropout=0.2,
                  **cell_args),
))

#### LSTM-Attention-LSTM layer

In [None]:
# LSTM-Attention-LSTM: an encoder LSTM, a self-attention layer and a decoder
# LSTM, both LSTMs using the vanilla cell and dropout 0.2.
# NOTE(review): num_layers is left at snn.LSTM's default here; confirm the
# dropout setting has the intended effect for that depth.
encoder = snn.LSTM(input_size=input_size, hidden_size=256,
                   cell=rnn_cells.LSTM_Vanilla, dropout=0.2)

# The decoder's input size and hidden size are both set to the encoder's
# hidden size.
decoder = snn.LSTM(input_size=encoder.hidden_size,
                   hidden_size=encoder.hidden_size,
                   cell=rnn_cells.LSTM_Vanilla, dropout=0.2)

# Self-attention layer sized from the encoder output and the decoder input.
attention = snn.SelfAttentionLayer(encoder_output_size=encoder.hidden_size,
                                   decoder_input_size=decoder.input_size)

# Register the layers in encoder -> attention -> decoder insertion order.
layers = nn.ModuleDict(dict(lstm_encoder=encoder,
                            attention=attention,
                            lstm_decoder=decoder))

### Model instantiation

In [None]:
# Size of the recurrent output fed to the linear head. It is read from the
# last configured layer exposing `hidden_size` (every configuration above ends
# with an nn.LSTM or snn.LSTM, both of which expose it), so the head stays
# consistent if the hidden size is changed in a configuration cell. Falls
# back to 256, the size used throughout, if no such layer is found.
rnn_output_size = next((module.hidden_size
                        for module in reversed(list(layers.values()))
                        if hasattr(module, 'hidden_size')),
                       256)

# Add remaining layers for the model instantiation.
layers.update({'dropout': nn.Dropout(p=0.2),
               'relu': nn.ReLU(),
               'linear': nn.Linear(rnn_output_size, output_size)})

# Initialize model.
model = snn.ConjunctionEventForecaster(layers=layers, features=nn_features)

# Print model.
print(f'\n{model}\n')

### Model training

In [None]:
# Train the forecaster on the training/validation events.
# - epochs: one epoch is one full pass through the training dataset.
# - lr: learning rate; decrease it if training diverges.
# - batch_size: minibatch size; decrease it if there are memory issues.
# - device: may be 'cuda' when a GPU is available.
# - valid_proportion: share of the data used internally as a validation set.
# - num_workers: multithreaded dataloader workers; 4 is good for performance,
#   but if there are issues or errors try num_workers=1, which solves issues
#   with PyTorch most of the time.
# - event_samples_for_stats: events used to compute the NN normalization
#   factors; keep it as big as possible (ideally at least a few thousand).
model.learn(events_train_and_val,
            epochs=10,
            lr=1e-3,
            batch_size=16,
            device='cpu',
            valid_proportion=0.15,
            num_workers=4,
            event_samples_for_stats=100)

##### Training vs validation loss

In [None]:
# Show the MSE loss over the iterations on a linear scale (filepath=None,
# presumably meaning the figure is only displayed, not saved).
model.plot_loss(log_scale=False, filepath=None)

##### Conjunction event forecasting

In [None]:
# Take a single event from the test dataset and remove its last CDM; the model
# forecasts the evolution from the remaining CDMs.
# NOTE(review): `event` rebinds the name of the `scalib.event` module imported
# in the first cell; the name is kept in case later cells rely on it.
event_idx = 3
event = events_test[event_idx]
event_beginning = event[0:len(event) - 1]

# Print information about the event to forecast: the count is the number of
# CDMs actually given to the model (the last CDM has been removed).
print(f'Forecasting next CDM from previous {len(event_beginning)} CDM(s)...')

# Predict the evolution of the conjunction event until TCA or the number of CDMs
# is max_length.
event_evolution = model.predict_event(event=event_beginning,
                                      num_samples=10,
                                      max_length=14)

# Features to plot.
features = ['RELATIVE_SPEED', 'MISS_DISTANCE', 'OBJECT1_CT_T']

# Plot the prediction in red...
axs = event_evolution.plot_features(features=features, return_axs=True,
                                    linewidth=0.1, color='red', alpha=0.33,
                                    label='Prediction')
# ...and overlay the ground-truth values on the same axes. Uses the same
# `features=` keyword as the prediction call above (the original passed
# `feature=`, which does not match it).
event.plot_features(features=features, axs=axs, label='Actual', legend=True)