## Import standard libraries

In [1]:
import torch
import torch.nn as nn  # we'll use this a lot going forward!
import torch.nn.functional as F

import numpy as np
import warnings

# Import matplotlib library and setup environment for plots
%matplotlib inline
%config InlineBackend.figure_format='retina'
from matplotlib import pyplot as plt, rc

# Import json library and create function to format dictionaries.
import json
format_json = lambda x: json.dumps(x, indent=4)

# Import pandas and set pandas DataFrame visualization parameters
from IPython.display import display
import pandas as pd
pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Set rendering parameters to use TeX font if not working on Juno app.
from pathlib import Path
import os
# Read the working directory once; it drives both the font setup and the
# parent folder lookup below.
working_dir = os.getcwd()

# Only switch matplotlib to TeX rendering when '/private/var/' is not part of
# the working directory (i.e. when not running inside the Juno app).
if '/private/var/' not in working_dir:
    rc('font', **{'family': 'serif', 'serif': ['Computer Modern'], 'size': 11})
    rc('text', usetex=True)

# Cut the working directory right after the first occurrence of the tool
# parent folder name and print the resulting path.
parent_folder = 'Tool'
parent_end = working_dir.index(parent_folder) + len(parent_folder)
cwd = str(Path(working_dir[:parent_end]))
print('Parent working directory: %s' % cwd)


Parent working directory: /Users/jjrr/Documents/SCA-Project/Tool


## Import user defined libraries

In [2]:
# Import custom libraries from local folder.
import sys
sys.path.append("..")

from library.irplib import utils, eda, config, sdg

## Data preparation

### Import training dataset

In [3]:
# Load the transformed training dataset from the project data folder.
train_csv_path = os.path.join(cwd,'data','esa-challenge','train_data_transformed.csv')
df = eda.import_cdm_data(train_csv_path)

# Number of CDMs per event, obtained by counting the 'time_to_tca' entries
# of every event_id group.
nb_cdms = df.groupby(['event_id']).count()['time_to_tca'].to_numpy(dtype=np.int32)

# Summary of the events holding more than min_cdms CDMs: how many there are,
# which share of all events they represent, and how many CDMs they have in
# excess of min_cdms in total.
min_cdms = 5
is_suitable = nb_cdms>min_cdms
nb_suitable = np.sum(is_suitable)
print(f'Events suitable for training: {nb_suitable}'
      f' ({nb_suitable/len(nb_cdms)*100:.2f}%)')
print(f'TSs with event_id integrity for training: {np.sum(nb_cdms[is_suitable]-min_cdms)}')

# Display the rows of the first three events to explore the data types.
first_events = df['event_id'].isin([0,1,2])
display(df[first_events])

Events suitable for training: 9400 (71.46%)
TSs with event_id integrity for training: 104099


Unnamed: 0,event_id,time_to_tca,mission_id,risk,max_risk_estimate,max_risk_scaling,miss_distance,relative_speed,relative_position_r,relative_position_t,relative_position_n,relative_velocity_r,relative_velocity_t,relative_velocity_n,t_time_lastob_start,t_time_lastob_end,t_recommended_od_span,t_actual_od_span,t_obs_available,t_obs_used,t_residuals_accepted,t_weighted_rms,t_rcs_estimate,t_cd_area_over_mass,t_cr_area_over_mass,t_sedr,t_j2k_sma,t_j2k_ecc,t_j2k_inc,t_ct_r,t_cn_r,t_cn_t,t_crdot_r,t_crdot_t,t_crdot_n,t_ctdot_r,t_ctdot_t,t_ctdot_n,t_ctdot_rdot,t_cndot_r,t_cndot_t,t_cndot_n,t_cndot_rdot,t_cndot_tdot,c_object_type,c_time_lastob_start,c_time_lastob_end,c_recommended_od_span,c_actual_od_span,c_obs_available,c_obs_used,c_residuals_accepted,c_weighted_rms,c_rcs_estimate,c_cd_area_over_mass,c_cr_area_over_mass,c_sedr,c_j2k_sma,c_j2k_ecc,c_j2k_inc,c_ct_r,c_cn_r,c_cn_t,c_crdot_r,c_crdot_t,c_crdot_n,c_ctdot_r,c_ctdot_t,c_ctdot_n,c_ctdot_rdot,c_cndot_r,c_cndot_t,c_cndot_n,c_cndot_rdot,c_cndot_tdot,t_span,c_span,t_h_apo,t_h_per,c_h_apo,c_h_per,geocentric_latitude,azimuth,elevation,mahalanobis_distance,t_position_covariance_det,c_position_covariance_det,t_sigma_r,c_sigma_r,t_sigma_t,c_sigma_t,t_sigma_n,c_sigma_n,t_sigma_rdot,c_sigma_rdot,t_sigma_tdot,c_sigma_tdot,t_sigma_ndot,c_sigma_ndot,F10,F3M,SSN,AP
0,0,1.566798,5,-10.204955,-7.834756,8.602101,14923.0,13792.0,453.8,5976.6,-13666.8,-7.2,-12637.0,-5525.9,1.0,0.0,3.78,3.78,459,458,98.9,1.265,0.402,1.648115,0.70659,-9.879172,6996.918867,0.003997,97.806412,-0.397969,0.292258,0.040799,0.394221,-0.999674,-0.038498,-0.981098,0.214612,-0.316493,-0.210247,0.170737,-0.001551,0.531593,0.002117,-0.179278,UNKNOWN,180.0,2.0,15.85,15.85,15,15,100.0,2.36,,0.676499,0.499637,-6.567277,7006.60732,0.003144,74.045735,-0.824859,0.473976,-0.002576,0.825216,-0.999998,0.003565,-0.732954,0.220006,-0.814249,-0.220621,0.249855,0.19662,0.722186,-0.196908,-0.668487,1.0,2.0,646.745439,590.818294,650.497251,606.443389,-73.574095,-23.618769,0.02991,129.430951,13.510814,38.329744,1.400673,5.586208,4.924475,10.90351,0.57741,3.84187,-1.914944,4.065123,-5.498756,-1.801545,-5.813629,-0.950721,89,83,42,11
1,0,1.207494,5,-10.355758,-7.848937,8.956374,14544.0,13792.0,474.3,5821.2,-13319.8,-7.0,-12637.0,-5525.9,1.0,0.0,3.79,3.79,456,455,98.5,1.27,0.402,1.607704,0.900255,-9.721584,6996.920255,0.003996,97.80642,-0.073137,0.297366,0.060541,0.069652,-0.998192,-0.052511,-0.99424,-0.029644,-0.302333,0.03403,0.179696,0.001552,0.561142,-0.005165,-0.181036,UNKNOWN,180.0,2.0,15.85,15.85,15,15,100.0,2.36,,0.676499,0.499637,-6.567277,7006.621053,0.003144,74.045736,-0.818207,0.482754,-0.003578,0.818573,-0.999998,0.004574,-0.728759,0.202595,-0.81749,-0.203216,0.258964,0.195718,0.721903,-0.196008,-0.674979,1.0,2.0,646.743506,590.823004,650.513314,606.454793,-73.57069,-23.618769,0.029079,271.540424,11.645172,38.318093,1.260385,5.569076,4.026603,10.898258,0.588319,3.841512,-2.81889,4.059864,-5.585288,-1.805559,-5.831622,-0.9509,89,83,42,11
2,0,0.952193,5,-10.345631,-7.847406,8.932195,14475.0,13792.0,474.6,5796.2,-13256.1,-7.0,-12637.0,-5525.9,1.0,0.0,3.79,3.8,456,455,98.5,1.257,0.402,1.592208,0.695163,-9.712935,6996.920553,0.003996,97.806418,-0.10923,0.305189,0.043711,0.107079,-0.996235,-0.034287,-0.996674,0.033933,-0.308501,-0.030161,0.12376,0.01963,0.579274,-0.023726,-0.125737,UNKNOWN,180.0,2.0,15.85,15.85,15,15,100.0,2.36,,0.676499,0.499637,-6.567277,7006.623524,0.003144,74.045737,-0.817408,0.483828,-0.003742,0.817774,-0.999998,0.004738,-0.729083,0.201698,-0.817662,-0.20232,0.260092,0.195558,0.721854,-0.195849,-0.675347,1.0,2.0,646.745607,590.821499,650.515082,606.457965,-73.570088,-23.618769,0.029079,347.899292,10.757052,38.31592,1.212547,5.567071,3.624286,10.897246,0.599902,3.841445,-3.237592,4.058849,-5.633474,-1.805745,-5.820248,-0.950932,89,83,42,11
3,0,0.579669,5,-10.337809,-7.84588,8.913444,14579.0,13792.0,472.7,5838.9,-13350.7,-7.0,-12637.0,-5525.9,1.0,0.0,3.86,3.86,443,442,98.4,1.254,0.402,1.608062,0.539818,-9.64591,6996.920276,0.003997,97.806423,0.021588,0.423647,0.157544,-0.032072,-0.990913,-0.149998,-0.998856,-0.059398,-0.427166,0.072214,0.228051,0.039163,0.547564,-0.049976,-0.229497,UNKNOWN,180.0,2.0,15.85,15.85,15,15,100.0,2.36,,0.676499,0.499637,-6.567277,7006.622932,0.003144,74.045736,-0.817557,0.483678,-0.00376,0.817923,-0.999998,0.004755,-0.728092,0.200534,-0.817883,-0.201155,0.259681,0.195808,0.721874,-0.196097,-0.675656,1.0,2.0,646.747747,590.818806,650.515635,606.456229,-73.571021,-23.618769,0.029079,435.376626,9.429859,38.31896,1.123559,5.567427,3.101572,10.89873,0.602548,3.841549,-3.813734,4.060338,-5.714465,-1.806008,-5.83675,-0.950883,89,83,40,14
4,0,0.257806,5,-10.39126,-7.852942,9.036838,14510.0,13792.0,478.7,5811.1,-13288.0,-7.0,-12637.0,-5525.9,1.0,0.0,3.86,3.86,440,439,98.8,1.34,0.402,1.657651,0.722942,-9.602338,6996.920446,0.003996,97.806426,0.417865,0.406002,0.246911,-0.465256,-0.983144,-0.243006,-0.999749,-0.430895,-0.405723,0.47999,0.223469,0.118674,0.543475,-0.143542,-0.22406,UNKNOWN,180.0,2.0,15.85,15.85,15,15,100.0,2.36,,0.676499,0.499637,-6.567277,7006.626646,0.003144,74.045736,-0.81598,0.485794,-0.004081,0.816347,-0.999998,0.005077,-0.727257,0.196662,-0.818575,-0.197284,0.261753,0.195657,0.721794,-0.195947,-0.677041,1.0,2.0,646.745868,590.821024,650.519613,606.459678,-73.570409,-23.618769,0.029079,469.178802,8.965347,38.317072,1.221472,5.563476,2.77949,10.897918,0.672074,3.841487,-4.194687,4.059524,-5.607668,-1.806842,-5.738377,-0.950913,89,83,40,14
5,1,6.530455,5,-7.561299,-7.254301,2.746782,2392.0,3434.0,74.3,2317.1,-589.4,25.9,-847.8,-3328.2,1.0,0.0,4.06,4.06,432,431,98.9,1.447,0.403,1.595188,0.090725,-10.276501,7001.527412,0.00103,97.76702,-0.134713,0.126515,-0.023597,0.108845,-0.998856,0.018619,-0.999598,0.107629,-0.126921,-0.081454,-0.096789,-0.082136,0.810101,0.085601,0.099044,DEBRIS,1.0,0.0,3.28,3.28,33,33,100.0,1.394,0.0045,0.307134,0.0,-4.498246,6880.65211,0.017491,82.431758,-0.858928,0.106509,0.325019,0.858976,-1.0,-0.324987,-0.91019,0.993879,0.240336,-0.99389,-0.682288,0.812345,0.224426,-0.812406,0.803415,1.0,2.0,630.601047,616.179777,622.863863,382.166357,-54.220645,-75.708851,-0.432069,4.183398,17.927486,36.311456,1.869822,4.182252,5.555486,11.2636,1.555683,3.937777,-1.278861,4.438849,-4.964692,-1.137492,-5.472981,-3.049457,71,88,0,2
6,1,5.561646,5,-9.315693,-7.468904,7.223137,3587.0,3434.0,99.0,3475.4,-885.1,24.7,-847.8,-3328.2,1.0,0.0,3.91,3.91,449,448,99.0,1.447,0.403,1.744551,0.552328,-10.302343,7001.52578,0.001029,97.767016,-0.125119,0.221214,0.011789,0.098413,-0.998694,-0.020421,-0.999507,0.095203,-0.222805,-0.068121,0.009246,-0.017973,0.815906,0.015338,-0.009128,DEBRIS,1.0,0.0,3.28,3.28,36,36,100.0,1.428,0.0045,0.305194,0.0,-4.566296,6880.652333,0.017494,82.431471,-0.929685,-0.031789,0.295882,0.929701,-1.0,-0.295874,-0.965251,0.993607,0.219643,-0.993613,-0.582234,0.723177,0.30164,-0.723223,0.690395,1.0,2.0,630.589995,616.187566,622.886965,382.143701,-54.224559,-75.708851,-0.412051,7.267719,17.206642,35.403617,1.782985,4.343231,5.412093,10.981388,1.442052,3.746817,-1.423345,4.156707,-5.051207,-1.329401,-5.592821,-3.218367,70,87,13,14
7,1,5.226504,5,-7.422508,-7.051001,2.956639,7882.0,3434.0,-50.0,-7638.3,1945.7,36.8,-847.7,-3328.2,1.0,0.0,3.99,3.99,434,433,98.9,1.383,0.403,1.701995,0.548782,-10.321625,7001.538661,0.001028,97.767016,-0.259827,0.194564,-0.01736,0.236605,-0.99876,0.010202,-0.99931,0.225534,-0.196229,-0.201741,-0.017222,-0.019988,0.813066,0.018852,0.017768,DEBRIS,1.0,0.0,3.44,3.44,36,36,100.0,1.543,0.0045,0.286323,0.0,-4.638379,6880.573187,0.017487,82.431405,-0.949988,-0.159227,0.029555,0.949971,-1.0,-0.029337,-0.98201,0.991788,0.081376,-0.991781,-0.38717,0.287235,0.855166,-0.287098,0.330801,1.0,2.0,630.601095,616.202226,622.753693,382.11868,-54.181705,-75.710468,-0.613897,11.197928,16.925462,38.229383,1.706974,4.580773,5.416534,10.863491,1.394081,4.93181,-1.42629,4.038918,-5.131717,-1.350506,-5.630159,-2.485949,70,87,13,14
8,1,3.570013,5,-9.248105,-7.327533,7.425994,26899.0,3434.0,-82.0,-26067.0,6638.2,56.8,-847.8,-3328.2,1.0,0.0,4.04,4.04,430,429,99.0,1.402,0.403,1.809278,0.336598,-10.382249,7001.561205,0.001028,97.767002,0.025433,0.318842,0.027919,-0.050648,-0.997384,-0.040625,-0.99956,-0.053437,-0.319966,0.079098,0.016432,-0.04985,0.760496,0.046685,-0.015341,DEBRIS,1.0,0.0,3.54,3.55,33,33,100.0,1.65,0.0045,0.294611,0.0,-4.741029,6880.588349,0.017491,82.431524,-0.896007,-0.341763,0.075631,0.895993,-1.0,-0.07517,-0.98256,0.962694,0.243049,-0.962685,-0.389578,0.249596,0.811324,-0.249363,0.344542,1.0,2.0,630.622566,616.225845,622.796659,382.106039,-54.109782,-75.708851,-0.94748,62.058087,15.444216,36.71646,1.696593,4.604334,4.749726,10.233657,1.329935,4.576352,-2.081025,3.409257,-5.135739,-1.743221,-5.681594,-2.841973,71,87,21,5
9,2,6.983474,2,-10.816161,-6.601713,13.293159,22902.0,14348.0,-1157.6,-6306.2,21986.3,15.8,-13792.0,-3957.1,1.0,0.0,3.92,3.92,444,442,99.4,1.094,3.4505,3.042086,0.92498,-10.894082,7158.39453,0.00086,98.523094,-0.099768,0.357995,-0.122174,0.085472,-0.999674,0.121504,-0.999114,0.057809,-0.353866,-0.043471,-0.025138,0.087954,-0.430583,-0.088821,0.021409,UNKNOWN,180.0,2.0,13.87,13.87,15,15,100.0,1.838,,1.579769,2.227246,-7.228422,7168.396928,0.001367,69.717278,-0.068526,0.63697,-0.038214,0.064305,-0.999989,0.036762,-0.996314,0.153806,-0.634961,-0.149627,0.715984,-0.159057,0.953945,0.156803,-0.723349,12.0,2.0,786.417082,774.097978,800.056782,780.463075,63.955771,-16.008858,-0.063092,115.208802,15.229084,42.445608,2.201549,5.549886,4.994608,10.549895,0.49631,5.385613,-1.875151,3.681239,-4.670266,-1.309462,-5.550399,-1.080559,73,77,27,4


## Time-Series Forecasting problem

In [8]:
def event_ts_sets(full_seq:np.ndarray, window_size:int, events_to_forecast:int=1) -> list:
    """Split the full Time-Series of an event into all its (sequence, target) pairs.

    A window of ``window_size`` consecutive items slides one position at a time
    over ``full_seq``; the ``events_to_forecast`` items located right after each
    window are the target associated to that window.

    Args:
        full_seq (np.ndarray): Array containing the full sequence of data for
        a given event.
        window_size (int): Number of consecutive items in every input sequence.
        events_to_forecast (int): Number of items following the window that
        form the target. Defaults to 1.

    Returns:
        list: List of tuples (sequence, target) sliced from ``full_seq``. It is
        empty when ``full_seq`` has fewer than ``window_size + events_to_forecast``
        items.
    """

    # Total number of items covered by one sequence plus its target.
    span = window_size + events_to_forecast

    # One pair per valid start position; range() is empty when the full
    # sequence is shorter than the span.
    return [(full_seq[start:start + window_size],
             full_seq[start + window_size:start + span])
            for start in range(len(full_seq) - span + 1)]

### Converting data from Pandas DataFrame to Pytorch Tensors

In [11]:
from tqdm import tqdm

window_size = 5
events_to_forecast = 1

# Count number of CDMs (full Time-Series sequence) per event
ts_events  = df[['event_id', 'time_to_tca']].groupby(['event_id']).count().rename(columns={'time_to_tca':'nb_events'})

# Exclude those events that do not have a minimum number of CDMs equal to the window_size + events_to_forecast
ts_events = ts_events.drop(ts_events[ts_events['nb_events']<window_size+events_to_forecast].index)
events_filter = list(ts_events.index.values)

# Get time-series sets for every continuous variable feature 
# (constant features by definition do not need to be forecasted)
tensor_filename = f'training_tsf_ws{window_size}-f{events_to_forecast}.pt'

# Check if tensors pt file is available in the data folder
filepath = os.path.join(cwd,'data','tensors', tensor_filename)

# Import tensors already extracted in a previous run (acts as a cache).
# NOTE: torch.load unpickles the file, so only load files written by this notebook.
features_ts = torch.load(filepath) if os.path.exists(filepath) else {}
print(f'Features already available in tensor file: {len(features_ts)}\n')

# Get input variable features from config file.
in_var_features = list(config.get_features(**{'input':True, 'variable':True}).keys())
in_features = list(config.get_features(**{'input':True}).keys())

# Features that still have to be extracted. The list is fixed before the loop
# so the overall progress has a constant denominator and reaches 100% on the
# last feature.
pending_features = [feature for feature in in_var_features if feature not in features_ts]

# Group the rows by event once, so that every per-event sequence is retrieved
# from the group index instead of scanning the whole DataFrame with a boolean
# mask for every (feature, event) pair.
cdms_by_event = df.groupby('event_id')

for f, feature in enumerate(pending_features):

    # The dtype of a column does not depend on the event, so it is read once
    # per feature from the full column.
    feature_dtype = str(df[feature].dtype).lower()
    feature_by_event = cdms_by_event[feature]
    overall_progress = (f+1)/len(pending_features)*100

    # Time-Series subsets collected for feature f
    feature_sets = []

    for event_id in tqdm(events_filter, desc=f'Extracting time-series subsets for {feature:<25s}'
                                             f' (Overall progress: {overall_progress:4.1f}%)'):

        # Get full sequence from dataset and convert it to a tensor.
        full_seq = torch.FloatTensor(feature_by_event.get_group(event_id).dropna().to_numpy(dtype=feature_dtype))

        # Add the Time-Series subsets of this event. events_to_forecast is passed
        # explicitly so the targets match the value encoded in the file name;
        # extend() avoids rebuilding the whole list on every event.
        feature_sets.extend(event_ts_sets(full_seq, window_size, events_to_forecast))

    features_ts[feature] = feature_sets

# Save tensors containing all Time-Series subsets for training organised by feature.
# The destination folder is created first so a fresh checkout does not fail on save.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
print('Saving list of Time-Series subsets tensors...', end='\r')
torch.save(features_ts, filepath)
print('Saving list of Time-Series subsets tensors... Saved.\n')

Features already available in tensor file: 0



Extracting time-series subsets for time_to_tca               (Overall progress:  1.0%): 100%|██████████| 9400/9400 [00:07<00:00, 1189.52it/s]
Extracting time-series subsets for miss_distance             (Overall progress:  2.1%): 100%|██████████| 9400/9400 [00:08<00:00, 1162.12it/s]
Extracting time-series subsets for relative_position_r       (Overall progress:  3.2%): 100%|██████████| 9400/9400 [00:08<00:00, 1114.28it/s]
Extracting time-series subsets for relative_position_t       (Overall progress:  4.3%): 100%|██████████| 9400/9400 [00:08<00:00, 1119.65it/s]
Extracting time-series subsets for relative_position_n       (Overall progress:  5.4%): 100%|██████████| 9400/9400 [00:08<00:00, 1104.57it/s]
Extracting time-series subsets for relative_velocity_r       (Overall progress:  6.5%): 100%|██████████| 9400/9400 [00:08<00:00, 1153.61it/s]
Extracting time-series subsets for t_recommended_od_span     (Overall progress:  7.7%): 100%|██████████| 9400/9400 [00:08<00:00, 1082.32it/s]
Extrac

Saving list of Time-Series subsets tensors... Saved.



#### Embedding categorical input features

An embedding is a vector representation of a categorical variable. The representation of this vector is computed through the use of NN models/techniques that take into account potential relation between categories in order to create the vector representation for each category.

In practice, an embedding matrix is a lookup table for a vector. Each row of an embedding matrix is a vector for a unique category.

The main advantage of using embeddings instead of One Hot/Dummy Encoding techniques (one column per unique value of the categorical feature, filled with 0s and 1s) is that they can preserve the natural order and common relationships between the categories. For example, we could represent the days of the week with 4 floating-point numbers each, and two consecutive days would look more similar than two weekdays that are days apart from each other.


The rule of thumb for determining the embedding size (number of elements per array) is to divide the number of unique entries in each column by 2 (rounding up), without exceeding 50.

In [16]:
# Categorical input features as declared in the config file.
cat_features = list(config.get_features(**{'input':True, 'continuous':False}).keys())

# For every categorical column store its number of categories (cat_szs) and
# the pair (number of categories, embedding length) in emb_szs. The embedding
# length is the length of the array into which every category is converted:
# half the number of categories rounded up, capped at 50.
cat_szs = []
emb_szs = []
for name in cat_features:
    n_categories = len(df[name].cat.categories)
    cat_szs.append(n_categories)
    emb_szs.append((n_categories, min(50, (n_categories+1)//2)))

# Report the embedding configuration of every categorical feature.
for feature, (n_vectors, length) in zip(cat_features, emb_szs):
    print(f'Feature {feature:20s} Embedding size: {(n_vectors, length)} '
          f'(Unique vectors: {n_vectors} | Length: {length})')

Feature t_time_lastob_start  Embedding size: (3, 2) (Unique vectors: 3 | Length: 2)
Feature t_time_lastob_end    Embedding size: (3, 2) (Unique vectors: 3 | Length: 2)
Feature c_object_type        Embedding size: (5, 3) (Unique vectors: 5 | Length: 3)
Feature c_time_lastob_start  Embedding size: (3, 2) (Unique vectors: 3 | Length: 2)
Feature c_time_lastob_end    Embedding size: (3, 2) (Unique vectors: 3 | Length: 2)


In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Get input continuous features from config file.
cont_features = list(config.get_features(**{'input':True, 'continuous':True}).keys())

# Convert continuous features to a tensor (one column per feature)
conts = np.stack([df[f].values for f in cont_features], 1)
conts = torch.tensor(conts, dtype=torch.float)

# Convert categorical variables to a tensor (one column per feature). The
# tensor is built from the integer category codes of every column, not from
# the list of column names, which cannot be converted to an int64 tensor.
# NOTE(review): pandas encodes missing values with the code -1, which is not a
# valid index for nn.Embedding - confirm the categorical columns have no NaNs.
cats = np.stack([df[f].cat.codes.values for f in cat_features], 1)
cats = torch.tensor(cats, dtype=torch.int64)

# Convert target features to a tensor
# NOTE(review): out_features is not defined in the cells shown in this
# notebook - confirm it is defined before this cell is run.
y = torch.tensor(df[out_features].values, dtype=torch.float).reshape(-1,1)


# # Get input data and target data
# data    = df[in_features].values
# target  = df[out_features].values

# # Create tensor using TensorDataset from the input and target features.
# cdm_td = TensorDataset(torch.FloatTensor(data), torch.FloatTensor(target))
# print(cdm_td[0])

In [None]:
class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear layer, applied to one sequence at a time.

    The hidden and cell states are stored in ``self.hidden`` and are overwritten
    by every forward pass, so the state produced by one call is the initial
    state of the next one unless ``self.hidden`` is reset in between.

    Args:
        input_size (int): Number of features per step of the sequence. Defaults to 1.
        hidden_size (int): Number of features of the LSTM hidden state. Defaults to 50.
        out_size (int): Number of values predicted per step. Defaults to 1.
    """

    def __init__(self, input_size=1, hidden_size=50, out_size=1):
        super().__init__()
        self.hidden_size = hidden_size

        # Recurrent layer followed by the fully-connected output layer.
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, out_size)

        # Initial states (h0, c0): zeros for one layer and a batch of size one.
        state_shape = (1, 1, hidden_size)
        self.hidden = (torch.zeros(*state_shape), torch.zeros(*state_shape))

    def forward(self, seq):
        """Process a full sequence and return the prediction of its last step.

        Args:
            seq (torch.Tensor): Sequence to process; it is reshaped to
            (len(seq), 1, -1) before being passed to the LSTM layer.

        Returns:
            torch.Tensor: Output of the linear layer for the last step.
        """
        n_steps = len(seq)

        # Reshape to (steps, batch of one, features) and update the stored state.
        lstm_in = seq.view(n_steps, 1, -1)
        lstm_out, self.hidden = self.lstm(lstm_in, self.hidden)

        # Map every step to the output space; only the last prediction is returned.
        step_preds = self.linear(lstm_out.view(n_steps, -1))
        return step_preds[-1]

In [None]:
class TabularModel(nn.Module):
    """Feed-forward network for tabular data mixing categorical and continuous features.

    Every categorical feature goes through its own embedding; the embeddings are
    concatenated with the batch-normalized continuous features and passed
    through a stack of Linear -> ReLU -> BatchNorm1d -> Dropout blocks followed
    by a final Linear output layer.

    Args:
        emb_szs (list): List of tuples (number of categories, embedding length),
        one per categorical feature. May be empty.
        n_cont (int): Number of continuous features.
        out_sz (int): Number of outputs of the model.
        layers (list): Number of neurons of every hidden layer (i.e.
        [100, 50, 200]). May be empty, in which case the inputs are connected
        directly to the output layer.
        p (float): Dropout probability applied to the embeddings and to every
        hidden layer. Defaults to 0.5.
    """

    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):
        
        # Inherit attributes from nn.Module class
        super().__init__()
        
        #############################################################################
        # Instantiate functions to use on the forward operation:
        
        # self.embeds: Creates a list of pre-configured Embedding operations (it is 
        # configured by passing the number of categories ni and the length of the 
        # embedding nf)
        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])
        
        # self.emb_drop: Cancels a proportion p of the embeddings.
        self.emb_drop = nn.Dropout(p)
        
        # self.bn_cont = Normalizes continuous features. This function is configured
        # by passing the number of continuous features to normalize.
        self.bn_cont = nn.BatchNorm1d(n_cont)
        
        #############################################################################
        # Count total number of embeddings (Total number of vector components for
        # every feature)
        n_emb = sum(nf for _, nf in emb_szs)
        
        # Compute total number of inputs to pass to the initial layer (data point = 
        # Nb. of embeddings + Nb. of continuous variables)
        n_in = n_emb + n_cont
        
        # Run through every layer to set up the operations to perform per layer.
        # (i.e. layers=[100, 50, 200])
        layerlist = []
        for n_neurons in layers:
            # On every layer, which contains n_neurons, perform the following operations:
            # 1. Apply linear neural network regression (z = Sum(wi*xi+bi))
            layerlist.append(nn.Linear(n_in,n_neurons))
            
            # 2. Apply ReLU activation function (al(z))
            layerlist.append(nn.ReLU(inplace=True))
            
            # 3. Normalize data using the n_neurons
            layerlist.append(nn.BatchNorm1d(n_neurons))
            
            # 4. Cancel out a random proportion p of the neurons to avoid overfitting
            layerlist.append(nn.Dropout(p))
            
            # 5. Set new number of input features n_in for the next layer.
            n_in = n_neurons
        
        # Set the last layer of the list which corresponds to the final output.
        # n_in holds the size of the last hidden layer, or the number of model
        # inputs when no hidden layer is requested, so an empty `layers` list
        # does not fail.
        layerlist.append(nn.Linear(n_in,out_sz))
        
        # Instantiate layers as a Neural Network sequential task
        self.layers = nn.Sequential(*layerlist)
    
    def forward(self, x_cat, x_cont):
        """Compute the model outputs for a batch of data points.

        Args:
            x_cat (torch.Tensor): Integer tensor whose column i holds the
            category indices of the i-th categorical feature. It is not read
            when the model has no embeddings.
            x_cont (torch.Tensor): Tensor with one column per continuous feature.

        Returns:
            torch.Tensor: Output of the last linear layer for every data point.
        """
        # Normalize continuous variables
        if len(self.embeds) == 0:
            # Without categorical features there is nothing to embed or to
            # concatenate: the normalized continuous features are the input.
            return self.layers(self.bn_cont(x_cont))
        
        # Apply embedding function e from self.embeds to the category i
        # in x_cat array
        embeddings = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
        
        # Concatenate embedding sections into 1
        x = torch.cat(embeddings, 1)
        
        # Apply dropout function to the embeddings tensor
        x = self.emb_drop(x)
        
        # Normalize continuous variables
        x_cont = self.bn_cont(x_cont)
        
        # Concatenate embeddings with continuous variables into one tensor
        x = torch.cat([x, x_cont], 1)
        
        # Process all data points with the layers functions (sequential of operations)
        x = self.layers(x)
        
        return x


## Save the trained model to a file
Right now <strong><tt>model</tt></strong> has been trained and validated, and seems to correctly classify an iris 97% of the time. Let's save this to disk.<br>
The tools we'll use are <a href='https://pytorch.org/docs/stable/torch.html#torch.save'><strong><tt>torch.save()</tt></strong></a> and <a href='https://pytorch.org/docs/stable/torch.html#torch.load'><strong><tt>torch.load()</tt></strong></a><br>

There are two basic ways to save a model.<br>

The first saves/loads the `state_dict` (learned parameters) of the model, but not the model class. The syntax follows:<br>
<tt><strong>Save:</strong>&nbsp;torch.save(model.state_dict(), PATH)<br><br>
<strong>Load:</strong>&nbsp;model = TheModelClass(\*args, \*\*kwargs)<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;model.load_state_dict(torch.load(PATH))<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;model.eval()</tt>

The second saves the entire model including its class and parameters as a pickle file. Care must be taken if you want to load this into another notebook to make sure all the target data is brought in properly.<br>
<tt><strong>Save:</strong>&nbsp;torch.save(model, PATH)<br><br>
<strong>Load:</strong>&nbsp;model = torch.load(PATH)<br>
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;model.eval()</tt>

In either method, you must call <tt>model.eval()</tt> to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

For more information visit https://pytorch.org/tutorials/beginner/saving_loading_models.html