## Import standard libraries

In [1]:
import torch
import torch.nn as nn  # we'll use this a lot going forward!
import torch.nn.functional as F

import numpy as np
import warnings

# Import matplotlib library and setup environment for plots
%matplotlib inline
%config InlineBackend.figure_format='retina'
from matplotlib import pyplot as plt, rc

# Import json library and create function to format dictionaries.
import json


def format_json(x):
    """Return `x` serialised as a JSON string indented with 4 spaces."""
    return json.dumps(x, indent=4)

# Import pandas and set pandas DataFrame visualization parameters
from IPython.display import display
import pandas as pd
pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Figures are rendered with TeX fonts, except when the notebook runs from a
# '/private/var/' location (Juno app), where that set-up is skipped.
from pathlib import Path
import os

current_dir = os.getcwd()
running_on_juno = '/private/var/' in current_dir
if not running_on_juno:
    rc('font', **{'family': 'serif', 'serif': ['Computer Modern'], 'size': 11})
    rc('text', usetex=True)

# Cut the current working directory right after the tool parent folder name
# and print the resulting path. `str.index` raises ValueError when the folder
# name is not part of the current path.
parent_folder = 'Tool'
parent_end = current_dir.index(parent_folder) + len(parent_folder)
cwd = str(Path(current_dir[:parent_end]))
print('Parent working directory: %s' % cwd)


Parent working directory: /Users/jjrr/Documents/SCA-Project/Tool


## Import user defined libraries

In [2]:
# Import custom libraries from local folder.
import sys
sys.path.append("..")

from library.irplib import utils, eda, config, sdg
from library.irplib import rnn

## Data preparation

### Import training dataset

In [3]:
# Load the transformed training dataset from the project data folder.
train_csv = os.path.join(cwd, 'data', 'esa-challenge', 'train_data_transformed.csv')
df = eda.import_cdm_data(train_csv)

# Number of rows (CDMs) recorded for every event_id.
nb_cdms = df.groupby(['event_id']).count()['time_to_tca'].to_numpy(dtype=np.int16)

# Forecasting set-up: length of the input window and number of steps to
# forecast. An event needs at least their sum of CDMs to be usable.
window_size = 5
events_to_forecast = 1
min_cdms = window_size + events_to_forecast

# Report how many events satisfy the minimum number of CDMs.
is_suitable = nb_cdms >= min_cdms
nb_suitable = np.sum(is_suitable)
print(f'Events suitable for training (More than {min_cdms-1} CDMs): {nb_suitable}'
      f' ({nb_suitable/len(nb_cdms)*100:5.1f}%)')
print(f'Time sequences with event_id integrity per feature: {np.sum(nb_cdms[is_suitable]-min_cdms)}')

# Same count as above, kept as a DataFrame indexed by event_id.
ts_events = (
    df[['event_id', 'time_to_tca']]
    .groupby(['event_id'])
    .count()
    .rename(columns={'time_to_tca': 'nb_cdms'})
)

# Identifiers of the events holding at least min_cdms CDMs.
has_enough_cdms = ts_events['nb_cdms'] >= min_cdms
events_filter = list(ts_events[has_enough_cdms].index.values)

# Keep only the rows of the events suitable for TSF to save memory.
df = df[df['event_id'].isin(events_filter)]

# Show first data points and column summary to explore data types.
display(df.head(10))
df.info()

Events suitable for training (More than 5 CDMs): 9400 ( 71.5%)
Time sequences with event_id integrity per feature: 94699


Unnamed: 0,event_id,time_to_tca,mission_id,risk,max_risk_estimate,max_risk_scaling,miss_distance,relative_speed,relative_position_r,relative_position_t,relative_position_n,relative_velocity_r,relative_velocity_t,relative_velocity_n,t_time_lastob_start,t_time_lastob_end,t_recommended_od_span,t_actual_od_span,t_obs_available,t_obs_used,t_residuals_accepted,t_weighted_rms,t_rcs_estimate,t_cd_area_over_mass,t_cr_area_over_mass,t_sedr,t_j2k_sma,t_j2k_ecc,t_j2k_inc,t_ct_r,t_cn_r,t_cn_t,t_crdot_r,t_crdot_t,t_crdot_n,t_ctdot_r,t_ctdot_t,t_ctdot_n,t_ctdot_rdot,t_cndot_r,t_cndot_t,t_cndot_n,t_cndot_rdot,t_cndot_tdot,c_object_type,c_time_lastob_start,c_time_lastob_end,c_recommended_od_span,c_actual_od_span,c_obs_available,c_obs_used,c_residuals_accepted,c_weighted_rms,c_rcs_estimate,c_cd_area_over_mass,c_cr_area_over_mass,c_sedr,c_j2k_sma,c_j2k_ecc,c_j2k_inc,c_ct_r,c_cn_r,c_cn_t,c_crdot_r,c_crdot_t,c_crdot_n,c_ctdot_r,c_ctdot_t,c_ctdot_n,c_ctdot_rdot,c_cndot_r,c_cndot_t,c_cndot_n,c_cndot_rdot,c_cndot_tdot,t_span,c_span,t_h_apo,t_h_per,c_h_apo,c_h_per,geocentric_latitude,azimuth,elevation,mahalanobis_distance,t_position_covariance_det,c_position_covariance_det,t_sigma_r,c_sigma_r,t_sigma_t,c_sigma_t,t_sigma_n,c_sigma_n,t_sigma_rdot,c_sigma_rdot,t_sigma_tdot,c_sigma_tdot,t_sigma_ndot,c_sigma_ndot,F10,F3M,SSN,AP
9,2,6.983474,2,-10.816161,-6.601713,13.293159,22902.0,14348.0,-1157.6,-6306.2,21986.3,15.8,-13792.0,-3957.1,1.0,0.0,3.92,3.92,444,442,99.4,1.094,3.4505,3.042086,0.92498,-10.894082,7158.39453,0.00086,98.523094,-0.099768,0.357995,-0.122174,0.085472,-0.999674,0.121504,-0.999114,0.057809,-0.353866,-0.043471,-0.025138,0.087954,-0.430583,-0.088821,0.021409,UNKNOWN,180.0,2.0,13.87,13.87,15,15,100.0,1.838,,1.579769,2.227246,-7.228422,7168.396928,0.001367,69.717278,-0.068526,0.63697,-0.038214,0.064305,-0.999989,0.036762,-0.996314,0.153806,-0.634961,-0.149627,0.715984,-0.159057,0.953945,0.156803,-0.723349,12.0,2.0,786.417082,774.097978,800.056782,780.463075,63.955771,-16.008858,-0.063092,115.208802,15.229084,42.445608,2.201549,5.549886,4.994608,10.549895,0.49631,5.385613,-1.875151,3.681239,-4.670266,-1.309462,-5.550399,-1.080559,73,77,27,4
10,2,6.691611,2,-10.850473,-6.603452,13.374242,22966.0,14348.0,-1161.1,-6330.2,22046.3,15.8,-13792.0,-3957.1,1.0,0.0,3.86,3.86,444,442,99.4,1.099,3.4505,2.880922,1.065057,-10.960651,7158.394561,0.000861,98.523097,-0.005874,0.360471,-0.036075,-0.002789,-0.999876,0.03587,-0.997255,-0.068114,-0.357012,0.076754,-0.027154,0.084268,-0.442266,-0.085037,0.020991,UNKNOWN,180.0,2.0,13.87,13.87,15,15,100.0,1.838,,1.579769,2.227246,-7.228422,7168.397641,0.001367,69.717278,-0.06775,0.636974,-0.038143,0.063521,-0.999989,0.036689,-0.996313,0.153053,-0.634998,-0.148865,0.715914,-0.158753,0.953971,0.156495,-0.723302,12.0,2.0,786.42051,774.094612,800.05708,780.464203,63.956674,-16.008858,-0.063092,101.429474,16.265328,42.441549,2.196657,5.549796,5.490139,10.547926,0.516145,5.385589,-1.378157,3.679266,-4.66989,-1.309606,-5.536811,-1.080597,73,77,27,4
11,2,6.269979,2,-30.0,-6.217958,426.808532,18785.0,14347.0,-698.8,-5176.4,18044.8,14.4,-13791.4,-3957.2,1.0,0.0,3.85,3.85,447,445,99.4,1.113,3.4505,2.746222,0.965072,-11.022029,7158.407962,0.000862,98.5231,-0.222621,0.425875,-0.149746,0.206756,-0.999517,0.147289,-0.999479,0.191052,-0.423717,-0.175085,0.082662,0.017007,-0.405439,-0.018617,-0.08382,UNKNOWN,1.0,0.0,14.63,14.63,15,15,100.0,1.641,,1.649021,1.879016,-7.247312,7168.395887,0.001297,69.718437,0.025977,0.563595,0.065183,-0.045196,-0.999602,-0.075887,-0.999774,-0.006036,-0.564147,0.025308,0.703561,-0.027022,0.916588,0.007301,-0.706289,12.0,2.0,786.439755,774.102169,799.554662,780.963112,63.903391,-16.009902,-0.057504,177.272242,15.145344,31.967553,2.295355,3.879431,4.803485,7.832651,0.601252,4.465008,-2.064897,0.961933,-4.579886,-2.98728,-5.530733,-1.644409,71,77,23,8
12,2,6.042352,2,-30.0,-6.271078,181.496778,18842.0,14347.0,-700.0,-5192.1,18099.4,14.4,-13791.4,-3957.2,1.0,0.0,3.83,3.83,451,449,99.4,1.122,3.4479,2.662064,0.996299,-11.058486,7158.407846,0.000862,98.523108,-0.23012,0.236754,-0.04598,0.222933,-0.999848,0.047067,-0.998058,0.169099,-0.236669,-0.161825,0.104967,0.00233,-0.433469,-0.003761,-0.106536,UNKNOWN,1.0,0.0,14.63,14.63,15,15,100.0,1.641,,1.649021,1.879016,-7.247312,7168.396232,0.001297,69.718436,-0.199922,0.552272,0.010836,0.192984,-0.999944,-0.014799,-0.999656,0.192663,-0.554558,-0.185664,0.694842,-0.051859,0.916218,0.044474,-0.699265,12.0,2.0,786.439933,774.101759,799.555083,780.963381,63.904169,-16.009902,-0.057504,134.49467,15.94439,33.940263,2.213602,3.905626,5.30696,8.812877,0.507719,4.465087,-1.562397,1.943879,-4.666855,-2.964674,-5.541783,-1.643388,71,77,23,8
13,2,5.711716,2,-30.0,-6.277448,187.52536,19015.0,14347.0,-709.9,-5242.1,18264.8,14.5,-13791.4,-3957.2,1.0,0.0,3.72,3.72,466,464,99.5,1.143,3.4479,2.808804,1.381067,-11.012908,7158.407228,0.000863,98.523107,0.173348,0.31812,0.169947,-0.187696,-0.999456,-0.170031,-0.999407,-0.207145,-0.321772,0.221392,0.210865,0.006796,-0.381647,-0.010702,-0.209821,UNKNOWN,1.0,0.0,14.63,14.63,15,15,100.0,1.641,,1.649021,1.879016,-7.247312,7168.396846,0.001297,69.718436,-0.194648,0.552833,0.011237,0.187592,-0.999942,-0.015262,-0.999649,0.187206,-0.555115,-0.180088,0.695361,-0.051522,0.916218,0.044023,-0.69978,12.0,2.0,786.444957,774.095498,799.555077,780.964615,63.906532,-16.009902,-0.057903,194.741883,14.47507,33.909269,2.109108,3.904538,4.663119,8.797461,0.541527,4.465003,-2.204913,1.928403,-4.752686,-2.965744,-5.515744,-1.643389,71,77,23,8
14,2,5.377642,2,-30.0,-6.278272,190.090568,19137.0,14347.0,-710.3,-5273.6,18382.8,14.5,-13791.4,-3957.2,1.0,0.0,3.69,3.69,478,475,99.5,1.119,3.4479,2.816552,1.414954,-11.019932,7158.406819,0.000862,98.523103,-0.115261,0.255582,0.007605,0.107766,-0.999832,-0.00841,-0.998007,0.052449,-0.257121,-0.044921,0.288501,-0.023265,-0.359223,0.022849,-0.288807,UNKNOWN,2.0,1.0,14.63,14.63,15,15,100.0,1.641,,1.649021,1.879016,-7.247312,7168.397203,0.001297,69.718436,-0.193936,0.552834,0.011532,0.186793,-0.999941,-0.015606,-0.99964,0.1861,-0.555175,-0.178894,0.695308,-0.051279,0.916219,0.043688,-0.699793,12.0,2.0,786.443328,774.09631,799.555148,780.965258,63.908209,-16.009902,-0.057903,149.531226,15.588817,33.884905,2.122911,3.904485,5.175424,8.785229,0.537283,4.464941,-1.692445,1.916182,-4.750897,-2.9659,-5.584831,-1.643389,70,77,11,5
15,2,5.028915,2,-30.0,-6.283246,188.270059,18918.0,14347.0,-714.8,-5213.9,18172.0,14.5,-13791.4,-3957.2,1.0,0.0,3.63,3.63,478,475,99.5,1.129,3.4479,3.02206,1.570693,-10.953841,7158.407343,0.000862,98.523115,-0.188946,0.174655,-0.026595,0.182648,-0.999807,0.027637,-0.998078,0.12783,-0.174536,-0.121472,0.286378,-0.055995,-0.380081,0.054933,-0.286015,UNKNOWN,2.0,1.0,14.63,14.63,15,15,100.0,1.641,,1.649021,1.879016,-7.247312,7168.395801,0.001297,69.718436,-0.198442,0.552433,0.010928,0.191476,-0.999943,-0.014906,-0.999639,0.190277,-0.55482,-0.183248,0.695002,-0.05178,0.916216,0.044368,-0.699527,12.0,2.0,786.444058,774.096628,799.553665,780.963937,63.905222,-16.009902,-0.057903,157.237923,15.091503,33.932893,2.043324,3.905318,5.0665,8.809227,0.469616,4.465049,-1.802808,1.940231,-4.834537,-2.96516,-5.611228,-1.643385,70,77,11,5
16,2,4.724355,2,-30.0,-6.283579,192.100563,19152.0,14347.0,-717.5,-5277.4,18397.2,14.5,-13791.4,-3957.2,1.0,0.0,3.72,3.72,469,466,99.4,1.159,3.4479,3.398029,1.775686,-10.828875,7158.406303,0.000863,98.523117,0.138205,0.220052,0.043062,-0.146262,-0.999742,-0.042066,-0.998311,-0.195389,-0.22023,0.203368,0.274975,0.056101,-0.395093,-0.058309,-0.275805,UNKNOWN,2.0,1.0,14.63,14.63,15,15,100.0,1.641,,1.649021,1.879016,-7.247312,7168.39682,0.001297,69.718436,-0.192786,0.553006,0.011435,0.185672,-0.999941,-0.01549,-0.99964,0.184898,-0.555334,-0.177721,0.695521,-0.051359,0.916215,0.043803,-0.699987,12.0,2.0,786.443569,774.095036,799.554287,780.965353,63.908418,-16.009902,-0.057903,169.699821,15.011373,33.89398,2.041773,3.904195,4.97637,8.789843,0.522089,4.464934,-1.891187,1.920781,-4.817995,-2.966167,-5.596369,-1.643383,70,77,11,5
17,2,4.334354,2,-30.0,-6.524619,1984.965171,18635.0,14347.0,-704.3,-5134.8,17900.0,14.4,-13791.4,-3957.2,1.0,0.0,3.7,3.7,478,476,99.3,1.103,3.4479,3.008709,1.58409,-10.950861,7158.408216,0.000863,98.523098,-0.010367,0.249339,-0.049027,-0.00141,-0.999242,0.049649,-0.999415,-0.023803,-0.247384,0.035573,0.188556,0.055165,-0.396573,-0.054989,-0.190535,UNKNOWN,1.0,0.0,16.59,16.59,18,18,100.0,1.689,,1.64975,1.871932,-7.274009,7168.393928,0.001295,69.71848,-0.058769,0.545509,-0.108184,0.008024,-0.997869,0.082388,-0.999982,0.06095,-0.54785,-0.010167,0.728109,-0.10533,0.93518,0.060764,-0.730231,12.0,2.0,786.448043,774.09439,799.540589,780.973267,63.901459,-16.009902,-0.057504,205.770885,13.384397,30.401822,1.893778,3.925567,4.334709,6.972808,0.497011,4.485034,-2.53217,0.10443,-4.974756,-2.940259,-5.632836,-1.638699,69,77,0,5
18,2,4.055292,2,-30.0,-6.270511,404.887321,18789.0,14347.0,-700.7,-5177.3,18048.3,14.4,-13791.4,-3957.2,1.0,0.0,3.63,3.63,489,488,99.1,1.104,3.4479,2.691865,1.486174,-11.060869,7158.407999,0.000863,98.523106,-0.155322,0.217544,-0.016202,0.14893,-0.99955,0.018769,-0.998696,0.104859,-0.217809,-0.098427,0.164164,0.027754,-0.407321,-0.028021,-0.16671,UNKNOWN,1.0,0.0,16.59,16.59,18,18,100.0,1.689,,1.64975,1.871932,-7.274009,7168.395399,0.001295,69.71848,-0.159724,0.5405,-0.046468,0.141367,-0.999715,0.037042,-0.999366,0.128234,-0.545087,-0.10978,0.722872,-0.057421,0.935103,0.041131,-0.727637,12.0,2.0,786.446057,774.095941,799.542098,780.9747,63.903569,-16.009902,-0.057504,193.536714,13.45916,32.427031,1.805864,3.936785,4.542696,7.978575,0.417638,4.484984,-2.326744,1.109949,-5.068911,-2.93368,-5.664379,-1.638476,69,77,0,5


<class 'pandas.core.frame.DataFrame'>
Int64Index: 151099 entries, 9 to 162633
Columns: 103 entries, event_id to AP
dtypes: category(7), float64(88), int16(8)
memory usage: 106.7 MB


## Time-Series Forecasting problem

### Converting data from Pandas DataFrame to Pytorch Tensors

In [5]:
from tqdm import tqdm
from tqdm import trange

# Get input variable features from config file.
in_var_features = list(config.get_features(**{'input':True, 'variable':True}).keys())
in_features     = list(config.get_features(**{'input':True}).keys())

# Get time-series sets for every continuous variable feature
# (constant features by definition do not need to be forecasted).
# The file name encodes the window size and the number of forecasted steps.
tensor_filename = f'training_tsf_ws{window_size}-f{events_to_forecast}.pt'

# Check if file containing tensors is available in the data folder and load it.
# NOTE(review): torch.load unpickles the file content; only load tensor files
# produced by this notebook.
filepath         = os.path.join(cwd,'data','tensors', tensor_filename)
features_tensors = torch.load(filepath) if os.path.exists(filepath) else {}

# NOTE(review): this hard-coded subset overrides the list read from the config
# file above; remove it to process every input variable feature.
in_var_features = ['t_cr_area_over_mass', 't_sedr', 't_ct_r', 't_cn_r',
                      't_cn_t', 't_crdot_r', 't_crdot_t', 't_crdot_n']

# Get all input features from which the tensors have not been extracted yet.
remaining_features = [f for f in in_var_features if f not in features_tensors]

print(f'Features already available in tensor file: {len(features_tensors)}')

# Iterate over all remaining features to get the time series subsets
t = trange(len(remaining_features), desc='Extracting sequences of time-series ...', leave=True)

# Defined up-front so the "saving" description below can always read its
# length, even when there are no events to iterate over.
tqdm_desc = ''

for f in t:

    # Collect the time-series sets of feature f in a local list; it is stored
    # in the dictionary only once every event has been processed.
    feature = remaining_features[f]
    feature_sets = []

    for e, event_id in enumerate(events_filter):

        # Update progress bar (parentheses are required for the f-string to
        # continue on the second line).
        tqdm_desc = (f'Extracting sequences of time-series from {feature}'
                     f' {"."*(30-len(feature))} (Progress: {(e+1)/len(events_filter)*100:5.1f}%)')
        t.set_description(tqdm_desc)
        t.refresh()

        # Get full sequence from dataset (event rows are filtered only once)
        # and convert it to a tensor with NaN values replaced.
        event_values = df.loc[df['event_id']==event_id, feature]
        feature_dtype = str(event_values.dtype).lower()
        full_seq = event_values.to_numpy(dtype=feature_dtype)
        full_seq = torch.nan_to_num(torch.FloatTensor(full_seq))

        # Append the Time-Series subsets of this event in place instead of
        # rebuilding the whole list on every event.
        feature_sets.extend(rnn.event_ts_sets(full_seq, window_size))

    features_tensors[feature] = feature_sets

    if (f+1)%10==0 or (f+1==len(remaining_features)):
        # Save tensors containing all Time-Series subsets for training organised by feature.
        t.set_description(f'Saving tensors with sequences of time-series into external file {"."*(len(tqdm_desc)-64)}')
        t.refresh()
        # Make sure the destination folder exists before writing the file.
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        torch.save(features_tensors, filepath)

Features already available in tensor file: 8


Extracting sequences of time-series ...: 0it [00:00, ?it/s]


#### Embedding categorical input features

An embedding is a vector representation of a categorical variable. The representation of this vector is computed through the use of NN models/techniques that take into account potential relation between categories in order to create the vector representation for each category.

In practice, an embedding matrix is a lookup table for a vector. Each row of an embedding matrix is a vector for a unique category.

The main advantage of using embeddings instead of One Hot/Dummy Encoding techniques (one column per unique value of categorical feature with 0s and 1s) is that it can preserve the natural order and common relationships between the categorical features. For example, we could represent the days of the week with 4 floating-point numbers each, and two consecutive days would look more similar than two weekdays that are days apart from each other.


The rule of thumb for determining the embedding size (number of elements per array) is to divide the number of unique entries in each column by 2, but not to exceed 50.

In [6]:
class EventPropagation(nn.Module):
    """Single-layer LSTM followed by a linear layer, processing one sequence at a time.

    Parameters
    ----------
    input_size : int
        Number of values per time step fed to the LSTM.
    hidden_size : int
        Size of the LSTM hidden and cell states.
    out_size : int
        Number of values produced by the linear layer per time step.

    Notes
    -----
    The (h, c) state is stored in ``self.hidden`` and is overwritten by every
    call to ``forward``; it is not reset automatically between calls.
    """

    def __init__(self, input_size=1, hidden_size=100, out_size=1):
        super().__init__()
        self.hidden_size = hidden_size

        # Recurrent layer followed by the fully-connected read-out layer
        # (created in this order so seeded initialisation is unchanged).
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, out_size)

        # Zero-valued h0 and c0 for one layer and a batch of one sequence.
        state_shape = (1, 1, hidden_size)
        self.hidden = (torch.zeros(*state_shape), torch.zeros(*state_shape))

    def forward(self, seq):
        """Run `seq` through the network and return the last time-step prediction."""
        n_steps = len(seq)

        # Arrange the sequence as (time steps, batch of 1, features) and pass
        # it through the LSTM, keeping the updated (h, c) state.
        lstm_in = seq.view(n_steps, 1, -1)
        lstm_out, self.hidden = self.lstm(lstm_in, self.hidden)

        # One prediction per time step; only the last one is returned.
        per_step_pred = self.linear(lstm_out.view(n_steps, -1))
        return per_step_pred[-1]



In [None]:
# Define Multivariate LSTM network class
class EventPropagation(nn.Module):
    """Multivariate LSTM network with a linear read-out over the whole sequence.

    Parameters
    ----------
    input_size : int
        Number of input features per time step.
    hidden_size : int
        Number of hidden neurons of the LSTM.
    output_size : int
        Number of outputs of the linear layer.
    seq_length : int
        Number of time steps of every input sequence; the linear layer expects
        ``hidden_size * seq_length`` values per sequence.
    num_layers : int, optional
        Number of recurrent (stacked) layers.
    """

    def __init__(self,input_size, hidden_size, output_size, seq_length, num_layers=1):
        super(EventPropagation, self).__init__()
        self.input_size = input_size    # Number of input features
        self.hidden_size = hidden_size  # Number of hidden neurons
        self.output_size = output_size  # Number of outputs
        self.num_layers = num_layers    # Number of recurrent (stacked) layers
        self.seq_length = seq_length

        self.lstm = nn.LSTM(input_size = self.input_size,
                            hidden_size = self.hidden_size,
                            num_layers = self.num_layers,
                            batch_first = True)
        # according to pytorch docs LSTM output is
        # (batch_size,seq_len, num_directions * hidden_size)
        # when considering batch_first = True
        self.linear = nn.Linear(self.hidden_size*self.seq_length,
                                self.output_size)

        # (h, c) states; created by init_hidden() or lazily by forward().
        self.hidden = None


    def init_hidden(self, n_sequences):
        """Reset hidden and cell states to zeros for a batch of `n_sequences`."""
        # Create the states with the dtype/device of the model parameters so
        # the model keeps working after .to(device) or .double().
        reference = next(self.parameters())
        state_shape = (self.num_layers, n_sequences, self.hidden_size)

        # Initialize states. Even with batch_first = True this remains same as docs
        h_state = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device) # Hidden state
        c_state = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device) # Cell state
        self.hidden = (h_state, c_state)


    def forward(self, inputs):
        """Return the forecast for a batch of sequences.

        `inputs` is unpacked as (n_sequences, seq_length, n_features); the
        returned tensor has one row of `output_size` values per sequence.
        The states stored in `self.hidden` are used as initial states and
        replaced by the ones returned by the LSTM.
        """
        n_sequences, seq_length, n_features = inputs.size()

        # Create zero states when init_hidden() has not been called yet or
        # when the stored states were built for a different batch size.
        hidden = getattr(self, 'hidden', None)
        if hidden is None or hidden[0].size(1) != n_sequences:
            self.init_hidden(n_sequences)

        lstm_out, self.hidden = self.lstm(inputs, self.hidden)
        # lstm_out(with batch_first = True) is
        # (batch_size,seq_len,num_directions * hidden_size)
        # for following linear layer we want to keep batch_size dimension and merge rest
        # .contiguous() -> solves tensor compatibility error
        flat_out = lstm_out.contiguous().view(n_sequences,-1)
        outputs = self.linear(flat_out)

        return outputs

In [None]:
# Get list of features and number of sequences to process from tensor file
features    = list(features_tensors.keys())
n_features  = len(features)
n_sequences = len(features_tensors[features[0]])
seq_length  = window_size # Window size for the TSF

# Lists holding one tensor per sequence; they are stacked once at the end.
inputs_list  = []
outputs_list = []

# Initialize trange object for sequences to print progress bar.
sequences = trange(n_sequences, desc='Getting training and target tensors ...', leave=True)
for s in sequences:

    # Initialize list for sequence s
    inputs_s    = []
    outputs_s   = []
    # Get sequence s from all features
    for feature in features:

        # Get sequence s (input and output) from feature f
        # - inputs_sf  = [f1_t1, f1_t2, ..., f1_tn]
        # - outputs_sf = [f1_tn+1]
        inputs_sf, outputs_sf = features_tensors[feature][s]

        # Collect the flattened float32 series of every feature
        # - inputs_s  = [[f1_t1, f1_t2, ..., f1_tn], [f2_t1, f2_t2, ..., f2_tn], ...]
        # - outputs_s = [[f1_tn+1], [f2_tn+1], ...]
        inputs_s.append(torch.as_tensor(inputs_sf, dtype=torch.float32).reshape(-1))
        outputs_s.append(torch.as_tensor(outputs_sf, dtype=torch.float32).reshape(-1))

    # Arrange sequence s with time steps as rows and features as columns
    # - inputs  = [[f1_t1, f2_t1, ..., fn_t1], [f1_t2, f2_t2, ..., fn_t2], ...]
    # - outputs = [f1_tn+1, f2_tn+1, ...]
    inputs_list.append(torch.stack(inputs_s, dim=1))
    outputs_list.append(torch.cat(outputs_s))

# Stack all sequences along a new first (batch) dimension:
# - inputs:  (n_sequences, seq_length, n_features)
# - outputs: (n_sequences, number of forecasted values)
inputs  = torch.stack(inputs_list)
outputs = torch.stack(outputs_list)

# The model defined above expects exactly seq_length steps and n_features values per step.
assert inputs.shape == (n_sequences, seq_length, n_features), inputs.shape

In [None]:
# Instantiate model with required inputs. The seed is fixed so the initial
# weights are reproducible.
torch.manual_seed(42)

# One input and one output per feature available in the tensors dictionary.
n_features = len(features)
model = EventPropagation(input_size = n_features, 
                         hidden_size = 100,
                         output_size = n_features,
                         seq_length = seq_length)

# Define criterion and optimizer
# NOTE(review): lr=1e-1 is large for Adam (default is 1e-3) -- confirm it is intended.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

# Print model
model

In [None]:
import time
start_time = time.time()

# Make sure the model is in training mode.
model.train()

epochs = 500
batch_size = 16

# Initialize array of inf values for losses (one value per epoch).
losses = np.ones(epochs)*np.inf

# Iterate over all epochs showing the progress of the training.
t = trange(epochs, desc='Training Conjunction Event Propagation model', leave=True)

for e in t:

    # Sum of the batch losses weighted by the number of sequences per batch.
    epoch_loss = 0.0

    # Train model by passing n batches depending on the batch_size
    for b in range(0, n_sequences, batch_size):

        # Get inputs and outputs for batch b
        inputs_b  = inputs[b:b+batch_size, :, :]
        outputs_b = outputs[b:b+batch_size]

        # Reset the gradients accumulated on the model parameters.
        optimizer.zero_grad()

        # Initialize hidden and cell states (sized for this batch) and compute outputs
        model.init_hidden(inputs_b.size(0))
        forecast = model(inputs_b)

        # Compute loss over every sequence of batch b (not only the last one).
        loss = criterion(forecast, outputs_b)
        epoch_loss += loss.item()*inputs_b.size(0)

        # Back propagate loss and update the model parameters
        loss.backward()
        optimizer.step()

        # Update progress bar
        t.set_description(f'Training Conjunction Event Propagation model | MSE loss = {loss.item():10.8f} ')
        t.refresh()

    # Store the mean loss per sequence of epoch e.
    if n_sequences > 0:
        losses[e] = epoch_loss/n_sequences

print(f'\nDuration: {time.time() - start_time:.0f} seconds')


## Save the trained model to a file
Right now <strong><tt>model</tt></strong> has been trained on the CDM time-series sequences built above. Let's save it to disk.<br>
The tools we'll use are <a href='https://pytorch.org/docs/stable/torch.html#torch.save'><strong><tt>torch.save()</tt></strong></a> and <a href='https://pytorch.org/docs/stable/torch.html#torch.load'><strong><tt>torch.load()</tt></strong></a><br>

There are two basic ways to save a model.<br>

The first saves/loads the `state_dict` (learned parameters) of the model, but not the model class. The syntax follows:<br>
<tt><strong>Save:</strong>&nbsp;torch.save(model.state_dict(), PATH)<br><br>
<strong>Load:</strong>&nbsp;model = TheModelClass(\*args, \*\*kwargs)<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;model.load_state_dict(torch.load(PATH))<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;model.eval()</tt>

The second saves the entire model including its class and parameters as a pickle file. Care must be taken if you want to load this into another notebook to make sure all the target data is brought in properly.<br>
<tt><strong>Save:</strong>&nbsp;torch.save(model, PATH)<br><br>
<strong>Load:</strong>&nbsp;model = torch.load(PATH)<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;model.eval()</tt>

In either method, you must call <tt>model.eval()</tt> to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

For more information visit https://pytorch.org/tutorials/beginner/saving_loading_models.html