In [1]:
# Fetch the SynCAN dataset repository into ./SynCAN (all command output is silenced,
# so if later cells cannot find the CSVs, re-run these without `&> /dev/null`).
# NOTE(review): prepare_dataset() reads /content/SynCAN, which equals ./SynCAN only
# when the working directory is /content (e.g. Colab) — confirm for other setups.
!git clone https://github.com/etas/SynCAN.git &> /dev/null
# Extract every .zip archive in ./SynCAN into ./SynCAN itself; the escaped \* is passed
# through to unzip so it can match multiple archives.
!unzip ./SynCAN/\*.zip -d ./SynCAN/. &> /dev/null
# Delete the archives after extraction.
!rm ./SynCAN/*.zip &> /dev/null

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
import argparse
from pathlib import Path
import json
import numpy as np
import pandas as pd
from tqdm import tqdm,trange
import time
import pickle
torch.manual_seed(1234)

<torch._C.Generator at 0x7f2cd216d030>

In [3]:
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Call ``step(metric)`` once per epoch; it returns ``True`` when training
    should stop, i.e. after ``patience`` consecutive non-improving epochs or
    as soon as the metric is NaN.

    Args:
        mode: 'min' (lower is better) or 'max' (higher is better).
        min_delta: required improvement margin; an absolute amount, or a
            percentage of the current best when ``percentage`` is True.
        patience: number of non-improving steps tolerated. 0 disables early
            stopping entirely (``step`` then always returns False).
        percentage: interpret ``min_delta`` as a percentage of ``best``.

    Raises:
        ValueError: if ``mode`` is neither 'min' nor 'max'.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None            # best metric seen so far (None until the first step)
        self.num_bad_epochs = 0     # consecutive steps without improvement
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Disabled: everything counts as an improvement and step() never requests a stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one observation of the metric and report whether to stop.

        ``metrics`` must be a torch tensor (``torch.isnan`` is applied to it);
        the training loop in this notebook passes a one-element tensor.
        """
        # The very first observation only seeds the baseline.
        if self.best is None:
            self.best = metrics
            return False

        # A NaN metric means training has diverged: stop right away.
        if torch.isnan(metrics):
            return True

        improved = self.is_better(metrics, self.best)
        if improved:
            self.best = metrics
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Build ``self.is_better(a, best)`` for the given mode and margin type."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')

        # The margin is either a fixed amount or a percentage of the current best.
        if percentage:
            margin = lambda best: best * min_delta / 100
        else:
            margin = lambda best: min_delta

        if mode == 'min':
            self.is_better = lambda a, best: a < best - margin(best)
        else:
            self.is_better = lambda a, best: a > best + margin(best)

In [4]:
# Here we define our model as a class
class INDRA_Enc(nn.Module):
    """GRU encoder of the INDRA encoder-decoder.

    Each time step is first projected by ``Linear(input_dim -> enc_input_dim)``
    followed by Tanh, then fed to a ``num_layers``-deep GRU with hidden size
    ``enc_output_dim``. Tanh is applied again to both GRU outputs, and dropout
    (p=0.2) is applied when ``drp_out`` is truthy (``nn.Dropout`` itself is a
    no-op in eval mode).

    Args:
        input_dim: number of signals per message ID (features per step).
        enc_input_dim: width of the linear projection fed to the GRU.
        enc_output_dim: GRU hidden size.
        num_layers: number of stacked GRU layers.
        batch_size: fixed batch size; ``forward`` reshapes its input to it.
    """

    def __init__(self, input_dim, enc_input_dim,
                enc_output_dim, num_layers, batch_size):
        super().__init__()
        self.input_dim = input_dim
        self.enc_input_dim = enc_input_dim
        self.enc_output_dim = enc_output_dim
        self.num_layers = num_layers
        self.batch_size = batch_size

        # Submodule names and creation order are kept stable: they fix the
        # state_dict keys of saved checkpoints and the seeded weight init.
        self.linear_enc = nn.Sequential(nn.Linear(input_dim, enc_input_dim), nn.Tanh())
        self.gru_enc = nn.GRU(enc_input_dim, enc_output_dim, num_layers=num_layers)
        self.tanh_actv = nn.Tanh()
        self.drop_out = nn.Dropout(p=0.2)

    def init_hidden(self, hidden_dim):
        """Zero initial hidden state of shape (num_layers, batch_size, hidden_dim)."""
        shape = (self.num_layers, self.batch_size, hidden_dim)
        return torch.zeros(shape)

    def forward(self, input, hidden, drp_out):
        """Encode ``input``; returns ``(enc_out, enc_hidden)`` after tanh (and optional dropout)."""
        projected = self.linear_enc(input)
        if drp_out:
            projected = self.drop_out(projected)

        # Reinterpret as (seq_len, batch_size, features) for the (non batch_first) GRU.
        seq_view = projected.view(len(projected), self.batch_size, -1)
        gru_out, gru_hidden = self.gru_enc(seq_view, hidden)
        enc_out, enc_hidden = self.tanh_actv(gru_out), self.tanh_actv(gru_hidden)

        if drp_out:
            enc_out = self.drop_out(enc_out)
            enc_hidden = self.drop_out(enc_hidden)

        return enc_out, enc_hidden

In [5]:
class INDRA_Dec(nn.Module):
    """GRU decoder of the INDRA encoder-decoder.

    A ``num_layers``-deep GRU (hidden size ``dec_output_dim``) runs over the
    encoder output; its per-step output goes through Tanh, optional dropout
    (p=0.2, only when ``drp_out`` is truthy and the module is in train mode),
    and finally ``Linear(dec_output_dim -> linear_output_dim)`` + Tanh.

    Args:
        dec_input_dim: encoder output dimension (features per step).
        dec_output_dim: GRU hidden size.
        linear_output_dim: number of signals per message ID (reconstruction width).
        num_layers: number of stacked GRU layers.
        batch_size: fixed batch size; ``forward`` reshapes its input to it.
    """

    def __init__(self, dec_input_dim, dec_output_dim,
                linear_output_dim, num_layers, batch_size):
        super(INDRA_Dec, self).__init__()
        self.dec_input_dim = dec_input_dim  #Encoder output dimensions
        self.dec_output_dim = dec_output_dim
        self.linear_output_dim = linear_output_dim #Number of signals per MSG ID
        self.num_layers = num_layers
        self.batch_size = batch_size

        # Creation order kept as before (fixes state_dict keys and seeded init).
        self.gru_dec = nn.GRU(self.dec_input_dim,self.dec_output_dim,
                                num_layers=self.num_layers)
        self.tanh_actv = nn.Tanh()
        self.linear_dec = nn.Sequential(
                        nn.Linear(self.dec_output_dim,self.linear_output_dim),
                        nn.Tanh())
        self.drop_out = nn.Dropout(p=0.2)

    def init_hidden(self, hidden_dim):
        """Zero initial hidden state of shape (num_layers, batch_size, hidden_dim)."""
        return torch.zeros(self.num_layers, self.batch_size, hidden_dim)

    def forward(self, input, hidden, drp_out):
        """Decode ``input`` into a reconstruction of shape (seq_len, batch_size, linear_output_dim).

        Only the per-step GRU output is used. The final hidden state was
        previously tanh'd and passed through dropout but then discarded; that
        dead work (which also consumed dropout RNG) has been removed. Returned
        values are unchanged.
        """
        # Reinterpret as (seq_len, batch_size, features) for the (non batch_first) GRU.
        gru_out, _ = self.gru_dec(input.view(len(input), self.batch_size, -1), hidden)
        dec_out = self.tanh_actv(gru_out)
        if drp_out:
            dec_out = self.drop_out(dec_out)

        return self.linear_dec(dec_out)

In [6]:
class INDRA_ED_2L(nn.Module):
    """INDRA encoder-decoder: INDRA_Enc followed by INDRA_Dec.

    Shapes of the layers, per time step:
        encoder: input_dim --Linear--> linear_out_dim --GRU--> context_dim
        decoder: context_dim --GRU--> linear_out_dim --Linear--> input_dim
    so the output has the same feature width as the input (a reconstruction).

    ``hidden1`` is the encoder GRU's initial state (hidden size context_dim);
    ``hidden2`` is the decoder GRU's initial state (hidden size linear_out_dim).
    """

    def __init__(self, input_dim, linear_out_dim, context_dim, num_layers, batch_size):
        super().__init__()
        # Encoder is created first, then decoder (keeps seeded init order and state_dict keys).
        self.encoder = INDRA_Enc(input_dim=input_dim, enc_input_dim=linear_out_dim,
                                 enc_output_dim=context_dim, num_layers=num_layers,
                                 batch_size=batch_size)
        self.decoder = INDRA_Dec(dec_input_dim=context_dim, dec_output_dim=linear_out_dim,
                                 linear_output_dim=input_dim, num_layers=num_layers,
                                 batch_size=batch_size)
        self.linear_out_dim = linear_out_dim

    def forward(self, input, hidden1, hidden2, drp_out):
        """Encode then decode ``input``; the encoder's final hidden state is not used."""
        context, _ = self.encoder(input, hidden1, drp_out)
        return self.decoder(context, hidden2, drp_out)

In [7]:
# def csv2df(dir_path):
#     all_files = [file for file in sorted(dir_path.glob('train_*.zip'))]

#     ## TODO: Maybe create a new dataset using the below steps and avoid
#     ## repeating these steps every time the program is ran

#     all_cols = ['Label','ID', 'Time', 'Signal1','Signal2','Signal3','Signal4']
#     df_list = []
#     for f in tqdm(all_files,bar_format="{l_bar}{bar}|", ncols=50, desc = 'Loading dataset'):
#         if f == [all_files[0]]:
#             tmp_df = pd.read_csv(f,sep=',',header=0)
#         else:
#             tmp_df = pd.read_csv(f,sep=',',header=None,names = all_cols)
#         df_list.append(tmp_df)

#     df = pd.concat(df_list, sort=False, ignore_index=False, axis=0)

#     return df

In [8]:
# def prepare_dataset(dir_path, msg_id):

#     ## Import data to pandas df from csv
#     data_frame = csv2df(dir_path)

#     ## Preprocess dataset to make it ready for the model
#     #Filtering out other msg IDs from the data set
#     data_frame = data_frame[:][data_frame.ID==msg_id]

#     #dropping all columns with all NaN
#     data_frame = data_frame.dropna(axis=1,how='all')

#     #checking if DF corresponds to only one MSG ID
#     assert len(data_frame.ID.unique()) == 1

#     #checking if there are any NaNs
#     assert data_frame.isnull().values.any() == False

#     #Removing columns other than the signal data (3rd Column to the end)
#     data_frame = data_frame.iloc[:,3:len(data_frame.columns)]

#     return data_frame

In [9]:
def csv2df(dir_path):
    """Load SynCAN's train_1.csv .. train_4.csv from ``dir_path`` into one DataFrame.

    ``train_1.csv`` is read with its first row treated as a header;
    ``train_2`` .. ``train_4`` are read as headerless. All four files are given
    the same column names. Previously train_1 kept its own header text, so if
    that text differed from the names used for the other files, ``pd.concat``
    produced disjoint, NaN-filled columns instead of stacking rows.

    Args:
        dir_path: directory containing the CSV files (str or pathlib.Path).

    Returns:
        DataFrame with columns Label, Time, ID, Signal1..Signal4 and a fresh
        0..n-1 RangeIndex (the per-file indices are not kept).
    """
    columns = ["Label",  "Time", "ID", "Signal1",  "Signal2",  "Signal3",  "Signal4"]
    base = Path(dir_path)

    # header=0 + names: skip the file's own header row and use `columns` instead.
    frames = [pd.read_csv(base / "train_1.csv", header=0, names=columns)]
    for i in range(2, 5):
        frames.append(pd.read_csv(base / "train_{0}.csv".format(i), header=None, names=columns))

    return pd.concat(frames, ignore_index=True)

def prepare_dataset(msg_id, dir_path="/content/SynCAN"):
    """Return the signal columns of every SynCAN training row for one message ID.

    Args:
        msg_id: value to match in the ``ID`` column (the notebook builds it as
            ``'id' + <number>``).
        dir_path: folder holding train_1.csv .. train_4.csv. Defaults to the
            location that was previously hard-coded, so existing calls are unchanged.

    Returns:
        DataFrame with only the columns after the first three, restricted to
        columns that are not entirely NaN for this ID.

    Raises:
        AssertionError: if the filtered rows do not have exactly one ID or
            still contain NaNs.
    """
    df = csv2df(str(dir_path))
    # Keep only rows of the requested message ID (explicit .loc, no chained indexing).
    df = df.loc[df.ID == msg_id]
    # Drop columns that are entirely NaN for this ID (presumably unused signal slots).
    df = df.dropna(axis=1, how='all')
    assert len(df.ID.unique()) == 1
    assert not df.isnull().values.any()
    # Drop the first three columns (Label, Time, ID in csv2df's naming); keep the signals.
    return df.iloc[:, 3:]

In [10]:
class SynCAN_Dataset(Dataset):
    """Sliding-window (stride 1) dataset over a DataFrame of signal rows.

    Item ``idx`` is rows ``idx .. idx + ss_len - 1`` as a float32 tensor of
    shape (ss_len, n_columns).

    Args:
        data: pandas DataFrame; rows are consecutive samples, columns are signals.
        ss_len: rolling window size (sub-sequence length).
    """

    def __init__(self,data,ss_len):
        self.data = data
        self.num_samples = len(self.data)
        self.ss_len = ss_len    #ss_len is the rolling window size

    def __len__(self):
        # Number of complete windows. The former `len - ss_len` skipped the last
        # full window (apparently a leftover from a variant that also returned
        # the following row as a target), and went negative when the data was
        # shorter than one window; clamp to 0 instead.
        return max(0, len(self.data) - self.ss_len + 1)

    def __getitem__(self,idx):
        # Explicit bounds check: out-of-range slices would otherwise silently
        # return short/empty windows, and plain iteration would never stop.
        if not 0 <= idx < len(self):
            raise IndexError('window index {0} out of range for {1} windows'.format(idx, len(self)))
        window = self.data.iloc[idx:idx + self.ss_len]
        return torch.tensor(window.values.astype(float)).float()

In [11]:
def update_best_stats(stats_dict,ep,Tr_loss,Vl_loss):
    """Write summary statistics of one epoch's per-iteration losses into ``stats_dict``.

    The dict is updated in place (existing keys keep their position, new keys
    are appended) and nothing is returned.

    Args:
        stats_dict: target dict.
        ep: epoch index, stored under ``'epoch'``.
        Tr_loss: sequence of per-iteration training losses; must be non-empty
            (``np.min``/``np.max`` raise on empty input).
        Vl_loss: sequence of per-iteration validation losses; must be non-empty.

    For each split ``s`` in (train, val) this writes ``running_s_loss`` (sum),
    ``min_s_loss``, ``max_s_loss``, ``avg_s_loss``, the 99.999/99.99/99.9/99/90th
    percentiles and ``median_s_loss``; validation additionally gets the
    92/94/95/96/98th percentiles. All statistics are computed in float32.
    """
    train_losses = np.asarray(Tr_loss, dtype=np.float32)
    val_losses = np.asarray(Vl_loss, dtype=np.float32)

    # (key prefix, percentile) pairs, listed in output order.
    tail_percentiles = [('99_999p', 99.999), ('99_99p', 99.99), ('99_9p', 99.9),
                        ('99p', 99), ('90p', 90)]
    val_only_percentiles = [('92p', 92), ('94p', 94), ('95p', 95), ('96p', 96), ('98p', 98)]
    median = [('median', 50)]

    stats_dict['epoch'] = ep
    per_split = (
        ('train', train_losses, tail_percentiles + median),
        ('val', val_losses, tail_percentiles + val_only_percentiles + median),
    )
    for split, losses, percentiles in per_split:
        key = '{}_' + split + '_loss'
        stats_dict[key.format('running')] = np.sum(losses)
        stats_dict[key.format('min')] = np.min(losses)
        stats_dict[key.format('max')] = np.max(losses)
        stats_dict[key.format('avg')] = np.mean(losses)
        for prefix, q in percentiles:
            stats_dict[key.format(prefix)] = np.percentile(losses, q)

In [12]:
def setup_folders(*args):
    """Create each given directory (one level only) unless something already exists there.

    Args:
        *args: pathlib.Path objects. Parent directories must already exist
            (``parents=False``); new directories get mode 0o775.
    """
    # Lazy generator: existence is re-checked right before each mkdir, so a
    # path listed twice is only created once.
    missing = (folder for folder in args if not folder.exists())
    for folder in missing:
        folder.mkdir(mode=0o775, parents=False, exist_ok=False)

In [15]:
def _save_loss_plot(train_curve, val_curve, es_epoch, title, plt_fname):
    """Plot training vs. validation loss with a dashed early-stopping marker, save as PNG, close the figure."""
    fig = plt.figure()
    plt.title(title, fontweight='bold', fontsize='12')
    plt.plot(train_curve, label='Training')
    plt.plot(val_curve, label='Validation')
    plt.axvline(es_epoch, label='EarlyStopping', color='red', linestyle='--')
    plt.xlabel('epochs', fontweight='bold', fontsize='12')
    plt.ylabel('Loss', fontweight='bold', fontsize='12')
    plt.legend()
    plt.savefig(plt_fname, dpi=500)
    # Figures are never shown (Agg backend); close them so repeated main() calls don't leak.
    plt.close(fig)


def main(args):
    """Train an INDRA encoder-decoder on one SynCAN message ID and save all artefacts.

    Args:
        args: dict with keys
            "config": one of 'ED_2L', 'ED_2L_AL', 'DR_ED_2L', 'DR_ED_2L_AL'.
                "_AL" multiplies the learning rate by 10 and adds
                ReduceLROnPlateau; "DR_" enables dropout.
            "settings": nested hyper-parameter dict; the "ED_2L" entry is used
                for every config.
            "msg_id": message ID number (prefixed with 'id').
            "output_model", "plot_dir", "stats_dir": pathlib.Path output folders
                (created if missing).
        NOTE(review): args["dataset"] is not read; prepare_dataset() uses its own
        default location.

    Side effects: writes the best and last model state_dicts, the pickled
    per-epoch stats DataFrame, the pickled best-epoch stats and two loss plots.

    Raises:
        ValueError: unknown config, loss function or optimiser, or a split too
            small to yield one full batch.
    """
    start_time = time.time()

    config = args["config"]
    if config not in ['ED_2L' , 'ED_2L_AL', 'DR_ED_2L', 'DR_ED_2L_AL']:
        print("Invalid settings configuration")
        raise ValueError("Invalid settings configuration: " + str(config))

    adaptive_lr = "_AL" in config
    lr_factor = 10 if adaptive_lr else 1
    settings_config = "ED_2L"   # all four variants share the ED_2L hyper-parameters
    drp_out = 1 if "DR_" in config else 0

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    ### Checking for existence of required folders
    output_dir = args["output_model"]
    plot_dir = args["plot_dir"]
    stats_dir = args["stats_dir"]
    setup_folders(output_dir, plot_dir, stats_dir)

    ### hyper parameters
    cfg = args["settings"][settings_config]
    num_layers_ = cfg['num_layers']
    num_epochs = cfg['num_epochs']
    batch_size = cfg['batch_size']
    sub_seq_len = cfg['sub_seq_len']
    train_split = cfg['train_split']   # validation uses everything after the training split
    learning_rate = lr_factor * cfg['learning_rate']
    context_dim = cfg['context_dim']
    lin_out_dim = cfg['lin_out_dim']
    L_func = cfg['L_func']
    patience_ = cfg['patience']
    optimizer = cfg['optimizer']
    test_str = str(config)

    print('Settings:')
    print('--------')
    print('Msg ID: id{0}'.format(args["msg_id"]))
    print('num_epochs {0}'.format(num_epochs))
    print('batch_size {0}'.format(batch_size))
    print('sub_seq_len {0}'.format(sub_seq_len))
    print('learning_rate {0}'.format(learning_rate))
    print('patience_ {0}'.format(patience_))
    print('dropout {0}'.format(drp_out))
    print('configuration {0}'.format(config))

    ### Pre-processing data_set
    msg_id = 'id' + str(args["msg_id"])
    data_set = prepare_dataset(msg_id)
    num_features = len(data_set.columns)

    ### Pruning data_set to get integer number of sub_sequences
    num_ss = len(data_set.index) // sub_seq_len
    data_set = data_set.iloc[:(num_ss * sub_seq_len)]

    ### Computing train and validation data sizes (in #samples)
    num_samples_adj = len(data_set.index)
    train_size = int(num_samples_adj * train_split)
    val_size = num_samples_adj - train_size

    ### Creating pyTorch datasets
    train_data = SynCAN_Dataset(data_set[:train_size], sub_seq_len)
    val_data = SynCAN_Dataset(data_set[train_size:train_size + val_size], sub_seq_len)

    ### Creating DataLoaders
    # The model reshapes every batch to exactly `batch_size` and the hidden states
    # are allocated with that batch size, so a trailing partial batch would crash.
    # The number of windows differs from train_size/val_size, so deciding
    # drop_last from the row counts (as before) was wrong: always drop the remainder.
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size,
                              shuffle=False, drop_last=True)
    val_loader = DataLoader(dataset=val_data, batch_size=batch_size,
                            shuffle=False, drop_last=True)
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            "Not enough windows for one full batch of {0} (train: {1}, val: {2})".format(
                batch_size, len(train_data), len(val_data)))

    model = INDRA_ED_2L(input_dim=num_features, linear_out_dim=lin_out_dim, context_dim=context_dim,
                        num_layers=num_layers_, batch_size=batch_size).to(device)

    if L_func == "BCE":
        print("BCE loss")
        # NOTE(review): the decoder ends in Tanh, whose outputs can be negative,
        # while BCELoss requires inputs in [0, 1] — confirm before using BCE.
        loss_fn = torch.nn.BCELoss().to(device)    ## TODO: Change the loss to the custom loss defined in the paper
    elif L_func == "MSE":
        print("MSE loss")
        loss_fn = torch.nn.MSELoss().to(device)
    elif L_func == "RMSE":
        print("RMSE loss")
        loss_fn = torch.nn.MSELoss().to(device)    # sqrt is applied in compute_loss
    else:
        print("Invalid loss function")
        raise ValueError("Invalid loss function: " + str(L_func))

    def compute_loss(prediction, target):
        # Configured loss; RMSE is the square root of the MSE.
        value = loss_fn(prediction, target)
        return torch.sqrt(value) if L_func == "RMSE" else value

    if optimizer == "ADAM":
        optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "SGD":
        optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
    else:
        print("Invalid optimiser")
        raise ValueError("Invalid optimiser: " + str(optimizer))

    scheduler = ReduceLROnPlateau(optimiser, mode='min', patience=5) if adaptive_lr else None

    running_loss_list = []
    val_loss_list = []
    avg_train_loss = []
    avg_val_loss = []
    es_epoch = None            # set when early stopping triggers
    best_val_running_loss = None

    stats_columns = ['epoch',
                    'running_train_loss','min_train_loss','max_train_loss','avg_train_loss',
                    'running_val_loss','min_val_loss','max_val_loss','avg_val_loss',
                    'best_epoch']

    best_stats_keys = ['epoch',
                 'running_train_loss','min_train_loss','max_train_loss','avg_train_loss',
                 '99_999p_train_loss', '99_99p_train_loss','99_9p_train_loss','99p_train_loss','90p_train_loss', 'median_train_loss',
                 'running_val_loss','min_val_loss','max_val_loss','avg_val_loss',
                 '99_999p_val_loss', '99_99p_val_loss','99_9p_val_loss','99p_val_loss','90p_val_loss', 'median_val_loss']

    best_stats = dict.fromkeys(best_stats_keys)
    stats_list = [dict.fromkeys(stats_columns) for _ in range(num_epochs)]

    fname_prefix = str(msg_id + '_' + L_func + '_' + optimizer + '_' + test_str + '_L' + str(num_layers_))
    best_model_path = str(Path(output_dir, fname_prefix + '.pt'))

    es = EarlyStopping(patience=patience_)

    ## plotting settings
    matplotlib.rc('font', weight='bold')

    # ########################### #
    #   Train and validate model  #
    # ########################### #

    for epoch in range(num_epochs):

        print("Epoch: ", epoch+1 , "/" , num_epochs)
        # Zero initial states for the encoder (context_dim) and decoder (lin_out_dim) GRUs.
        hidden1 = torch.zeros(num_layers_, batch_size, context_dim).to(device)
        hidden2 = torch.zeros(num_layers_, batch_size, lin_out_dim).to(device)
        train_itr_loss = []
        val_itr_loss = []

        ### Training ###
        model.train()
        for input_batch in tqdm(train_loader, bar_format="{l_bar}{bar}|", ncols=50, desc='Training'):
            # DataLoader yields (batch, seq, features); the GRUs are not batch_first.
            input_batch = input_batch.permute(1, 0, 2).to(device)

            optimiser.zero_grad()
            reconstruction = model(input_batch, hidden1, hidden2, drp_out)
            loss = compute_loss(reconstruction, input_batch)
            loss.backward()
            optimiser.step()

            train_itr_loss.append(loss.item())

        running_loss = sum(train_itr_loss)
        train_avg = running_loss / len(train_itr_loss)
        running_loss_list.append(running_loss)
        avg_train_loss.append(train_avg)

        ### evaluate performance on validation set
        model.eval()
        val_hidden1 = torch.zeros(num_layers_, batch_size, context_dim).to(device)
        val_hidden2 = torch.zeros(num_layers_, batch_size, lin_out_dim).to(device)

        with torch.no_grad():
            for val_input_batch in tqdm(val_loader, bar_format="{l_bar}{bar}|", ncols=50, desc='Validation'):
                val_input_batch = val_input_batch.permute(1, 0, 2).to(device)
                val_recons = model(val_input_batch, val_hidden1, val_hidden2, drp_out)
                val_itr_loss.append(compute_loss(val_recons, val_input_batch).item())

        val_running_loss = sum(val_itr_loss)
        val_avg = val_running_loss / len(val_itr_loss)
        val_loss_list.append(val_running_loss)
        avg_val_loss.append(val_avg)
        if scheduler is not None:
            scheduler.step(val_running_loss)

        print("")
        print("Train loss: ", running_loss)
        print("Validation loss: ", val_running_loss)
        print("Avg train loss: ", train_avg)
        print("Avg validation loss: ", val_avg)
        print("")

        # Per-epoch stats (previously the max was written into the 'min_*' keys
        # and 'max_*' was never set).
        epoch_stats = stats_list[epoch]
        epoch_stats['epoch'] = epoch
        epoch_stats['running_train_loss'] = running_loss
        epoch_stats['min_train_loss'] = min(train_itr_loss)
        epoch_stats['max_train_loss'] = max(train_itr_loss)
        epoch_stats['avg_train_loss'] = train_avg
        epoch_stats['running_val_loss'] = val_running_loss
        epoch_stats['min_val_loss'] = min(val_itr_loss)
        epoch_stats['max_val_loss'] = max(val_itr_loss)
        epoch_stats['avg_val_loss'] = val_avg
        epoch_stats['best_epoch'] = 0

        ### Checkpointing: keep the model with the lowest running validation loss
        if best_val_running_loss is None or val_running_loss < best_val_running_loss:
            best_val_running_loss = val_running_loss
            torch.save(model.state_dict(), best_model_path)
            epoch_stats['best_epoch'] = 1
            update_best_stats(best_stats, epoch, train_itr_loss, val_itr_loss)
            assert best_stats['epoch'] == epoch  # ensuring correct update

        ### EarlyStopping
        if es.step(torch.tensor([val_running_loss])):
            print("Early Stopping the training!")
            # The last improvement happened `patience_` epochs ago.
            es_epoch = epoch - patience_
            break

    # A None sentinel (not 0): stopping at epoch == patience_ legitimately gives es_epoch == 0.
    if es_epoch is None:   # never hit early stopping
        es_epoch = epoch

    ## save the per-epoch stats as a pickled DataFrame
    stats_df = pd.DataFrame(stats_list)
    df_fname = str(Path(stats_dir, fname_prefix + '_train' + '.pt'))
    stats_df.to_pickle(df_fname)

    # Save the best_stats to a pickle file
    best_stats_fname = str(Path(stats_dir, fname_prefix + '_train_best' + '.pt'))
    with open(best_stats_fname, 'wb') as handle:
        pickle.dump(best_stats, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ## Saving the last model ###
    model_name = str(Path(output_dir, fname_prefix + '_Last' + '.pt'))
    torch.save(model.state_dict(), model_name)

    _save_loss_plot(running_loss_list, val_loss_list, es_epoch, "Running Loss",
                    str(Path(plot_dir, fname_prefix + '_running_loss' + '.png')))
    _save_loss_plot(avg_train_loss, avg_val_loss, es_epoch, "Average Loss",
                    str(Path(plot_dir, fname_prefix + '_avg_loss' + '.png')))

    print("\n\n--- Total time: %s seconds ---" % (time.time() - start_time))

In [16]:
### Simulate args
# Hyper-parameters shared by the encoder-decoder configurations
# ("ED_2L" and "1L_ED" were identical copies); each gets its own dict copy.
_ed_settings = {
    "num_layers" : 1,
    "num_epochs" : 200,
    "batch_size" : 128,
    "sub_seq_len": 20,
    "train_split": 0.90,
    "val_split": 0.10,
    "save_epochs": 3,
    "patience": 8,
    "learning_rate" : 0.0001,
    "lin_out_dim" : 128,
    "context_dim" : 64,
    "L_func" : "MSE",
    "optimizer" : "ADAM",
    "max_plt_count" : 8,
    "IS_config" : "avg"
}
settings = {
    "LinDec" : {
        "num_layers" : 1,
        "num_epochs" : 200,
        "batch_size" : 128,
        "sub_seq_len": 20,
        "train_split": 0.90,
        "val_split": 0.10,
        "save_epochs": 3,
        "patience": 8,
        "learning_rate" : 0.0001,
        "context_dim" : 512,
        "L_func" : "MSE",
        "optimizer" : "ADAM",
        "max_plt_count" : 8,
        "IS_config" : "avg"
    },
    "ED_2L" : dict(_ed_settings),
    "1L_ED" : dict(_ed_settings),
    "attacks" : {
        "no_attack" : 0,
        "plateau" : 1,
        "continuous" : 2,
        "playback" : 3,
        "suppress" : 4,
        "flooding" : 5
    },
}
args = {}
args["settings"] = settings
# The first cell clones the data into ./SynCAN (relative to the working
# directory), not into the home directory.
args["dataset"] = Path(Path.cwd(), 'SynCAN')
args["output_model"] = Path(Path.cwd(),'output_model')
args["plot_dir"] = Path(Path.cwd(),'plots')
args["stats_dir"] = Path(Path.cwd(),'stats')
args["msg_id"] = "2"
args["config"] = "ED_2L"
args["resume"] = False
# main(args)