In [None]:
%pip install tslearn tsai

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import random
import time
import torch.optim as optim
import os
import matplotlib.pyplot as plt
import json

from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.nn.utils import weight_norm
from torch.cuda import device_count
from soft_dtw_cuda import SoftDTW
from tslearn.metrics import dtw, dtw_path
from tsai.imports import *
from tsai.utils import *
from tsai.models.layers import *
from tsai.models.utils import *
from datetime import datetime

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Candidate random seeds for repeated runs; the seeding cell picks one entry by index.
randomlist = [35, 73, 86, 7, 98] 

In [None]:
# For reproducibility: seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device)
# with the same value. `seed` is also read later by the DataLoader generator and result logging.
seed = randomlist[0]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Data Preparation

In [None]:
class dataManager:
  """Load a dataset, normalize it and split it into supervised train/valid/test sets.

  Supported dataset names are 'electricity', 'traffic', 'ecg' and 'synthetic'.
  After construction ``self.datasets`` maps 'train', 'valid' and 'test' to
  ``[x, y]`` pairs of float32 NumPy arrays.
  """

  def __init__(self, data_path, data_set, train_size, valid_size):
    """Load and preprocess the data.

    Args:
      data_path: file (electricity/traffic) or folder (ecg/synthetic) to read.
      data_set: dataset name; selects the loader, the normalization and the windowing.
      train_size: fraction of the first axis used for training.
      valid_size: fraction of the first axis used for validation; the rest is the test set.
    """
    self.dataset = data_set
    self.datasets = self.preprocess(data_path, train_size, valid_size)

  def preprocess(self, data_path, train_size, valid_size):
    """Load, normalize, split chronologically and convert to (x, y) pairs.

    Returns:
      dict with keys 'train', 'valid', 'test', each holding ``[x, y]``.
    """
    data = self.load_data(data_path)
    # These datasets are used as-is; every other dataset is scaled per column.
    norm_type = 0 if self.dataset in ('cost_data', 'ecg', 'synthetic') else 2
    n_data = self.normalized(norm_type, data)
    # Only the first axis is split, so this works for 2-D and 3-D arrays alike.
    n = n_data.shape[0]
    train_end = int(n * train_size)
    valid_end = int((train_size + valid_size) * n)
    train_data = np.array(n_data[0:train_end])
    valid_data = np.array(n_data[train_end:valid_end])
    test_data = np.array(n_data[valid_end:n])
    # convert the data into a supervised problem of x, y
    data_sets = {}
    data_sets['train'], data_sets['valid'], data_sets['test'] = self.generate_data(
        self.dataset, train_data, valid_data, test_data)
    return data_sets

  def parse_ucr(self, filename):
    """Parse a comma-separated UCR file; the first field of each line is skipped.

    Returns:
      list with one ``(length, 1)`` float array per line.
    """
    X = []
    with open(filename) as ucr_file:
      for line in ucr_file:
        arr = line.strip().split(",")
        f_line = list(map(float, arr[1:]))
        X.append(np.array(f_line).reshape(-1, 1))
    return X

  def load_data(self, data_path):
    """Read the raw data for ``self.dataset`` from ``data_path``.

    Raises:
      ValueError: if ``self.dataset`` has no loader.
    """
    if self.dataset == 'electricity':
      data_period = [2014]
      df = pd.read_csv(data_path, sep=';', decimal=',')
      df.rename(columns = {"Unnamed: 0": "time"}, inplace = True)
      df['time'] = pd.to_datetime(df['time'])
      df_included = df[df["time"].dt.year.isin(data_period)].reset_index()
      df_included.drop('index', axis=1, inplace=True)
      # hourly level aggregation
      df_included.index = pd.to_datetime(df_included['time'])
      df_included = df_included.resample('1h').mean()
      total_data = np.array([df_included.loc[i,:] for i in df_included.index])
    elif self.dataset == 'traffic':
      with open(data_path) as f:
        total_data = np.loadtxt(f, delimiter=',')
    elif self.dataset == 'ecg' or self.dataset == 'synthetic':
      data_folder = 'synthetic_control' if self.dataset == 'synthetic' else 'ECG5000'
      tr = os.path.join(data_path, "%s_TRAIN" % data_folder)
      te = os.path.join(data_path, "%s_TEST" % data_folder)
      X_tr = np.array(self.parse_ucr(tr))
      X_te = np.array(self.parse_ucr(te))
      total_data = np.concatenate((X_tr, X_te), axis=0)
    else:
      raise ValueError("Unsupported dataset: %r" % (self.dataset,))
    return total_data

  def normalized(self, normalize, data):
    """Scale ``data`` according to ``normalize``.

    0 returns the data unchanged, 1 divides by the global maximum, 2 divides each
    column (second axis) by its maximum absolute value. Any other value returns zeros.
    """
    if normalize == 0:
      return data
    if normalize == 1:  # normalized by the maximum value of the entire matrix
      return data / np.max(data)
    dat = np.zeros(data.shape)
    if normalize == 2:  # normalized by the maximum absolute value of each column
      for i in range(data.shape[1]):
        scale = np.max(np.abs(data[:, i]))
        # An all-zero column stays zero instead of turning into NaN (0 / 0).
        if scale != 0:
          dat[:, i] = data[:, i] / scale
    return dat

  def split_series(self, data, P=168, horizon=24):
    """Cut a 2-D series into sliding input windows and single-step targets.

    Window ``i`` covers rows ``i .. i+P-1``; its target is row ``i + P + 1 + horizon``.
    Only windows whose target lies inside the series are returned.

    Returns:
      (X, Y) float32 arrays of shape ``(n_samples, P, m)`` and ``(n_samples, m)``.
    """
    n_len, m = data.shape
    n_samples = max(0, n_len - P - horizon - 1)
    X = np.zeros((n_samples, P, m), dtype=np.float32)
    Y = np.zeros((n_samples, m), dtype=np.float32)
    for i in range(n_samples):
      start = i + P
      end = (start + 1) + horizon
      X[i, :, :] = data[i:start, :]
      Y[i, :] = data[end, :]
    return X, Y

  def generate_data(self, dataset, train_data, valid_data, test_data):
    """Turn the three splits into ``[x, y]`` pairs.

    For 'ecg'/'synthetic' the first 60% of each series is the input and the rest
    (first channel only) is the target. For 'electricity'/'traffic' sliding windows
    are built with ``split_series``.

    Raises:
      ValueError: if ``dataset`` is not supported.
    """
    if dataset == 'ecg' or dataset == 'synthetic':
      proportion = 0.6
      len_ts = train_data.shape[1]
      len_input = int(round(len_ts * proportion))
      x_train = np.float32(train_data[:, :len_input, :])
      y_train = np.float32(train_data[:, len_input:, 0])
      x_valid = np.float32(valid_data[:, :len_input, :])
      y_valid = np.float32(valid_data[:, len_input:, 0])
      x_test = np.float32(test_data[:, :len_input, :])
      y_test = np.float32(test_data[:, len_input:, 0])
    elif dataset == 'electricity' or dataset == 'traffic':
      x_train, y_train = self.split_series(train_data)
      x_valid, y_valid = self.split_series(valid_data)
      x_test, y_test = self.split_series(test_data)
    else:
      raise ValueError("Unsupported dataset: %r" % (dataset,))

    return [x_train, y_train], [x_valid, y_valid], [x_test, y_test]

In [None]:
# Dataset configuration.
# NOTE(review): data_path points at the traffic file while data_set selects 'ecg'; for 'ecg'
# and 'synthetic' the loader joins data_path with '<name>_TRAIN' / '<name>_TEST', so confirm
# that the path matches the chosen dataset.
data_path ='./datasets/traffic.txt'
data_sets = ['ecg', 'synthetic', 'electricity', 'traffic']
data_set = data_sets[0]
# dataManager requires the split fractions; the remainder (1 - train - valid) is the test set.
# NOTE(review): 0.6 / 0.2 are assumed values -- set them to the split the experiments should use.
train_size = 0.6
valid_size = 0.2
ds = dataManager(data_path, data_set, train_size, valid_size)
# Load Data: each entry of ds.datasets is an [x, y] pair
trainX, trainY = ds.datasets['train']
validX, validY = ds.datasets['valid']
testX, testY = ds.datasets['test']

In [None]:
def seed_worker(worker_id):
  """Re-seed NumPy and Python's `random` from torch's initial seed (used as DataLoader worker_init_fn)."""
  derived_seed = torch.initial_seed() % (1 << 32)
  for reseed in (np.random.seed, random.seed):
    reseed(derived_seed)
# Dedicated generator so DataLoader shuffling is repeatable for a given seed.
g = torch.Generator()
g.manual_seed(seed)

# Load Data
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)

print('trainX, trainY', trainX.shape, trainY.shape)
print('validX, validY', validX.shape, validY.shape)
print('testX, testY', testX.shape, testY.shape)

# Two loader workers per visible GPU; this is 0 (load in the main process) on CPU-only machines.
workers = 2 * device_count() 
batch_size = 128

# TensorDataset only accepts tensors, while dataManager returns NumPy arrays;
# torch.as_tensor converts them without copying and leaves existing tensors untouched.
datasets = {
  'train': TensorDataset(torch.as_tensor(trainX), torch.as_tensor(trainY)),
  'val': TensorDataset(torch.as_tensor(validX), torch.as_tensor(validY)),
  'test': TensorDataset(torch.as_tensor(testX), torch.as_tensor(testY)),
}
dataloaders_dict = {}
for x in ['train', 'val', 'test']: 
  # Only the training split is shuffled. drop_last stays on for every split because
  # the training/inference code reshapes each batch to exactly `batch_size` rows.
  dataloaders_dict[x] = DataLoader(datasets[x], batch_size=batch_size, shuffle=(x == 'train'), drop_last=True, pin_memory=True, num_workers=workers, worker_init_fn=seed_worker, generator=g)


# Configure the dataset and models here.

In [None]:
# Available prediction networks and training losses.
nets = ['mlp', 'tcn']
loss_types = ['mse', 'soft_dtw', 'dtw_surrogate']

# Active configuration: change the indices to select another network or loss.
model_type = nets[0]
loss = loss_types[0]

In [None]:
# Derive the network input/output sizes from the training arrays.
if model_type == 'mlp':
  # The MLP consumes the flattened window: second axis times third axis of trainX.
  N_input, N_output = trainX.shape[1] * trainX.shape[2], trainY.shape[1]
elif model_type == 'tcn':
  # The TCN receives the second axis of trainX as its input-channel count.
  N_input, N_output = trainX.shape[1], trainY.shape[1]
print('N_input, N_output', N_input, N_output)

# Network classes and helper functions

In [None]:
class MLP(nn.Module):
  '''
  Single-hidden-layer perceptron used as the prediction model:
  Linear(input_features -> hidden_units) -> ReLU -> Linear(hidden_units -> output_features).
  '''
  def __init__(self, input_features, output_features, hidden_units): 
    super().__init__()
    # Creation order is kept (hidden, fc, relu) so seeded initialisation and
    # state_dict keys stay the same.
    self.hidden = nn.Linear(input_features, hidden_units)
    self.fc = nn.Linear(hidden_units, output_features)
    self.relu = nn.ReLU()

  def forward(self, x):
    activated = self.relu(self.hidden(x))
    return self.fc(activated)

In [None]:
class TemporalBlock(Module):
  """Three weight-normalised dilated Conv1d layers, each followed by ReLU and Dropout, then a final ReLU.

  NOTE(review): `Module` is not `nn.Module`; it presumably comes from the tsai star-imports
  and takes care of calling the parent constructor -- TODO confirm.
  """
  def __init__(self, ni, nf, ks, stride, dilation, padding, dropout=0.):
    # ni -> nf channels in the first convolution; the other two map nf -> nf.
    # All three share kernel size, stride, padding and dilation.
    self.conv1 =  weight_norm(nn.Conv1d(ni,nf,ks,stride=stride,padding=padding,dilation=dilation))
    self.relu1 = nn.ReLU()
    self.dropout1 = nn.Dropout(dropout)

    self.conv2 =  weight_norm(nn.Conv1d(nf,nf,ks,stride=stride,padding=padding,dilation=dilation))
    self.relu2 = nn.ReLU()
    self.dropout2 = nn.Dropout(dropout)

    self.conv3 =  weight_norm(nn.Conv1d(nf,nf,ks,stride=stride,padding=padding,dilation=dilation))
    self.relu3 = nn.ReLU()
    self.dropout3 = nn.Dropout(dropout)

    # No chomp layers are applied after the padded convolutions (see the trailing note).
    self.net = nn.Sequential(self.conv1, self.relu1, self.dropout1,
                              self.conv2, self.relu2, self.dropout2,
                              self.conv3, self.relu3, self.dropout3) #, self.chomp1, self.chomp2, self.chomp3
    self.relu = nn.ReLU()
    self.init_weights()

  def init_weights(self):
    """Apply Kaiming-normal initialisation to the `weight` attribute of the three convolutions."""
    # NOTE(review): the convolutions are wrapped in weight_norm, which reparameterises `weight`;
    # confirm that initialising `.weight` here has the intended effect.
    # kaiming initialization
    nn.init.kaiming_normal_(self.conv1.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_normal_(self.conv2.weight, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_normal_(self.conv3.weight, mode='fan_in', nonlinearity='relu')

  def forward(self, x):
    """Run the conv stack and apply a final ReLU."""
    out = self.net(x)
    return self.relu(out)

def TemporalConvNet(c_in, layers, ks=2, dropout=0.):
  """Stack one TemporalBlock per entry of `layers`; the dilation doubles with depth (1, 2, 4, ...).

  The first block takes `c_in` channels, every later block takes the previous block's
  output channels. Padding is `(ks-1) * dilation` for each block.
  """
  in_channels = [c_in] + list(layers[:-1])
  blocks = []
  for depth, (ni, nf) in enumerate(zip(in_channels, layers)):
    dilation = 2 ** depth
    blocks.append(TemporalBlock(ni, nf, ks, stride=1, dilation=dilation, padding=(ks-1) * dilation, dropout=dropout))
  return nn.Sequential(*blocks)


class TCN(Module):
  '''
  TCN prediction model: TemporalConvNet -> GAP1d pooling -> optional dropout -> linear head.
  '''
  def __init__(self, c_in, c_out, layers=[128, 64, 32], ks=7, conv_dropout=0., fc_dropout=0.):
    self.tcn = TemporalConvNet(c_in, layers, ks=ks, dropout=conv_dropout)
    self.gap = GAP1d()
    # Dropout before the head is only created for a non-zero rate.
    if fc_dropout:
      self.dropout = nn.Dropout(fc_dropout)
    else:
      self.dropout = None
    self.linear = nn.Linear(layers[-1],c_out)
    self.init_weights()

  def init_weights(self):
    # Kaiming-normal initialisation of the linear head.
    nn.init.kaiming_normal_(self.linear.weight, mode='fan_in', nonlinearity='relu')

  def forward(self, x):
    pooled = self.gap(self.tcn(x))
    if self.dropout is not None:
      pooled = self.dropout(pooled)
    return self.linear(pooled)

In [None]:
class SurrogateNetwork(nn.Module):
  '''
  Surrogate DTW model used as a learned loss function.

  Both input series go through the same convolutional encoder and linear layer; the
  absolute difference of the two per-series scalars is then mixed across the batch
  by `fcOut` into one value. Because `fcOut` has `batch_size` input features, every
  forward call must receive exactly `batch_size` series.
  '''
  def __init__(self, dataset_name, batch_size): 
    '''
    Args:
      dataset_name: one of 'ecg', 'synthetic', 'traffic' (or the legacy 'traffic_data'),
        'electricity'; selects the target length. Unknown names raise KeyError.
      batch_size: number of series per forward call.
    '''
    super().__init__()
    self.input_dim = 1 
    self.output_dim = 1
    # 'traffic' is the dataset name used elsewhere in this notebook; 'traffic_data' is kept for
    # backward compatibility.
    self.horizon = {'ecg': 56, 'synthetic': 24, 'traffic': 862, 'traffic_data': 862, 'electricity': 370 }[dataset_name]
    self.bs = batch_size
    self.conv_layer = nn.Sequential(
      nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3),
      nn.BatchNorm1d(64),
      nn.ReLU(),
      nn.Conv1d(64, 32, 3, dilation=2),
      nn.BatchNorm1d(32),
      nn.ReLU(),
      nn.Conv1d(32, 16, 3),
      nn.BatchNorm1d(16),
      nn.ReLU(),
      nn.Conv1d(16, 8, 3),
      nn.ReLU()
    )
    # The unpadded convolutions shorten the sequence by 2 + 4 + 2 + 2 = 10 steps and end with
    # 8 channels, so the flattened feature size is 8 * (horizon - 10)
    # (368 for 56, 112 for 24, 6816 for 862, 2880 for 370).
    self.fc1 = nn.Linear(8 * (self.horizon - 10), 1) 
    
    # L1 norm without a built-in loss: a learned combination of the per-series differences.
    self.fcOut = nn.Linear(batch_size, 1)
  
  def _embed(self, x):
    # (batch, length, 1) -> (batch, 1, length) for Conv1d, then flatten and project to one scalar.
    x = torch.transpose(x, 1, 2)
    x = self.conv_layer(x)
    x = x.reshape(x.size(0), -1)
    return self.fc1(x)

  def forward(self, x1, x2):
    '''Return a (1, 1) tensor estimating the distance between the two batches of series.'''
    x = torch.abs(self._embed(x1) - self._embed(x2))
    # (batch, 1) -> (1, batch) so fcOut can combine the whole batch.
    x = x.permute(1, 0)
    return self.fcOut(x)

In [None]:
class Record():
  """Computes the evaluation metrics (MSE and mean DTW) for one batch of predictions."""
  def __init__(self):
    self.metric_criterion = nn.MSELoss()
  
  # calculate the metrics here
  def update(self, output, y):
    """Return ``(mse, mean_dtw)`` for one batch.

    Args:
      output: predictions, indexed as ``output[k, :, 0:1]`` per series.
      y: targets with the same layout as ``output``.

    The DTW average runs over the actual first dimension of ``output`` rather than the
    notebook-level ``batch_size``, so the result does not depend on a global.
    """
    # MSE
    loss_mse = self.metric_criterion(output, y)   
    # DTW: computed series by series on the CPU with tslearn
    n_series = output.shape[0]
    loss_dtw = 0 
    for k in range(n_series):         
      target_k_cpu = y[k,:,0:1].view(-1).detach().cpu().numpy()
      output_k_cpu = output[k,:,0:1].view(-1).detach().cpu().numpy()
      loss_dtw += dtw(target_k_cpu, output_k_cpu)
    loss_dtw = loss_dtw / n_series
    return loss_mse.item(), loss_dtw

In [None]:
# Helper functions
def plot_losses(plot_values_1, plot_values_2, plt_name, label_1, label_2, plt_title):
  """Plot two curves on one figure and save it as ``plots/<plt_name>.png``.

  Args:
    plot_values_1, plot_values_2: sequences of values to plot.
    plt_name: file name (without extension) inside the ``plots`` directory.
    label_1, label_2: legend labels of the two curves.
    plt_title: figure title.
  """
  plt.plot(plot_values_1, label=label_1)
  plt.plot(plot_values_2, label=label_2)
  plt.legend()
  plt.title(plt_title)
  # savefig does not create missing directories.
  os.makedirs("plots", exist_ok=True)
  plt.savefig(f"plots/{plt_name}.png")
  # plt.show()
  plt.close()

def MetaLoss(my_outputs, my_labels):
  """Mean absolute error (L1 loss, mean reduction) between outputs and labels."""
  return nn.functional.l1_loss(my_outputs, my_labels)

def set_parameter_requires_grad_true(model):
  """Enable gradients for every parameter of `model` and return the same model."""
  for weight in model.parameters():
    weight.requires_grad_(True)
  return model

def set_parameter_requires_grad_false(model):
  """Freeze every parameter of `model` (no gradients) and return the same model."""
  for weight in model.parameters():
    weight.requires_grad_(False)
  return model

def format_variables(var_1, var_2, batch_size): 
  '''
  View both tensors as (batch_size, -1, 1).

  var_1 == y      (returned first)
  var_2 == y_hat  (returned second)
  '''
  target_shape = (batch_size, -1, 1)
  return var_1.view(*target_shape), var_2.view(*target_shape)

# Training routines

In [None]:
import copy
def train_model(model, dataloaders, criterion, optimizer, batch_size, loss_type='dtw_surrogate', num_epochs=25):
  """Train the prediction model and keep the weights with the lowest validation loss.

  Args:
    model: prediction network (moved to the notebook-level `device`).
    dataloaders: dict with 'train' and 'val' DataLoaders.
    criterion: callable `criterion(outputs, targets)`; for 'soft_dtw' its result is averaged.
    optimizer: optimizer over `model`'s parameters.
    batch_size: batch size used to reshape targets/outputs to (batch_size, -1, 1).
    loss_type: 'mse', 'soft_dtw' or 'dtw_surrogate'; also part of the checkpoint file name.
    num_epochs: maximum number of epochs; training stops early after more than 20
      consecutive epochs without validation improvement.

  Returns:
    (model, loss_history, criterion) when loss_type == 'dtw_surrogate',
    otherwise (model, loss_history). `model` holds the best validation weights.
  """
  since = time.time()
  loss_history = {'train loss':list(), 'test loss':list()} 
  best_loss = float("inf")
  best_model_wts = copy.deepcopy(model.state_dict())
  # Early-stopping counter. It must exist before the first non-improving epoch
  # (e.g. a NaN validation loss in epoch 0).
  es = 0
  title = f"{data_set}_prediction_model_{loss_type}_{model_type}"
  # torch.save does not create missing directories.
  os.makedirs("stored_model", exist_ok=True)
  model = model.to(device)

  def batch_loss(inputs, targets):
    # Forward pass and loss for one batch; shared by the training and validation loops.
    if model_type == 'mlp':
      inputs = inputs.reshape((inputs.shape[0], N_input)).to(device)
    if model_type == 'tcn':
      inputs = inputs.to(device) 
    targets = targets.to(device)
    outputs = model(inputs)
    targets, outputs = format_variables(targets, outputs, batch_size)
    loss = criterion(outputs, targets)
    if loss_type == 'soft_dtw':
      # SoftDTW returns one value per series; average over the batch.
      loss = loss.mean()
    return loss

  for epoch in range(num_epochs):
    print('\n Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    # Training
    phase = 'train'
    model.train()
    running_train_loss = []
    for inputs, targets in dataloaders[phase]:
      optimizer.zero_grad()
      loss = batch_loss(inputs, targets)
      loss.backward()
      optimizer.step()
      running_train_loss.append(loss.item())
    train_epoch_loss = np.mean(running_train_loss)
    loss_history['train loss'].append(train_epoch_loss)
    print('{} Loss: {:.5f}'.format(phase, train_epoch_loss))

    # Evaluating
    phase = 'val' 
    model.eval()
    running_test_loss = []
    with torch.no_grad():
      for inputs, targets in dataloaders[phase]:
        running_test_loss.append(batch_loss(inputs, targets).item())
    test_epoch_loss = np.mean(running_test_loss)
    # Record the validation loss before the early-stopping check so both histories
    # always have the same length.
    loss_history['test loss'].append(test_epoch_loss)
    print('{} Loss: {:.5f}'.format(phase, test_epoch_loss))
    if test_epoch_loss < best_loss:
      best_loss = test_epoch_loss
      best_model_wts = copy.deepcopy(model.state_dict()) 
      torch.save(model.state_dict(), f"stored_model/{title}.pt")
      es = 0
    else:
      es += 1
      print("Counter {} of 20".format(es))
      if es > 20:
        print("Early stopping with best_validation_loss: ", best_loss)
        break
    print('='*60)
  time_elapsed = time.time() - since
  print('Training completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  print('Best validation Loss: {:5f}'.format(best_loss))

  # load best model weights
  model.load_state_dict(best_model_wts)
  if loss_type == 'dtw_surrogate':
    return model, loss_history, criterion
  else:
    return model, loss_history

In [None]:
def inference(model, dataloaders, criterion, batch_size, loss_type='dtw_surrogate'):
  """Evaluate `model` on the 'test' DataLoader.

  Returns:
    (mean test loss, mean MSE metric, mean DTW metric), each averaged over batches.
  """
  metrics = Record()
  mse_per_batch, dtw_per_batch = [], []
  model = model.to(device)
  model.eval() 
  batch_losses = []
  phase = 'test'
  with torch.no_grad():
    for inputs, targets in dataloaders[phase]:
      if model_type == 'mlp':
        inputs = inputs.reshape((inputs.shape[0], N_input)).to(device)
      if model_type == 'tcn':
        inputs = inputs.to(device) 
      targets = targets.to(device)
      targets, outputs = format_variables(targets, model(inputs), batch_size)
      loss = criterion(outputs, targets)
      if loss_type == 'soft_dtw':
        # SoftDTW yields one value per series; reduce to a scalar.
        loss = loss.mean()
      # statistics
      batch_losses.append(loss.item())
      # compute metrics
      mse_value, dtw_value = metrics.update(outputs, targets)
      mse_per_batch.append(mse_value)
      dtw_per_batch.append(dtw_value)
  epoch_loss = np.mean(batch_losses) 
  print('For Seed:', seed)
  print('{} Loss: {:.5f}'.format(phase, epoch_loss))
  test_mse_metric = np.mean(mse_per_batch)
  test_dtw_metric = np.mean(dtw_per_batch)
  print('Metric MSE: {:.5f}, Metric DTW: {:.5f}'.format(test_mse_metric, test_dtw_metric))
  print('='*60)
  return epoch_loss, test_mse_metric, test_dtw_metric

In [None]:
import copy
def train_model_surrogate(model, model_pred, dataloaders, optimizer, batch_size, num_epochs=25):
  """Train the surrogate loss network to reproduce the batch-mean DTW of the prediction model.

  Args:
    model: surrogate network, called as `model(targets, outputs)`.
    model_pred: prediction network; kept in eval mode and only used for forward passes.
    dataloaders: dict with 'train' and 'val' DataLoaders.
    optimizer: optimizer over the surrogate's parameters.
    batch_size: number of series per batch (used for reshaping and for the DTW average).
    num_epochs: maximum number of epochs; training stops early after more than 20
      consecutive epochs without validation improvement.

  Returns:
    (model, loss_history, model_pred); `model` holds the best validation weights.
  """
  since = time.time()
  loss_history = {'train loss':list(), 'test loss':list()} 
  best_loss = float("inf")
  best_model_wts = copy.deepcopy(model.state_dict())
  # Early-stopping counter. It must exist before the first non-improving epoch.
  es = 0
  title =  f"{data_set}_surrogate_model"
  # torch.save does not create missing directories.
  os.makedirs("stored_model", exist_ok=True)
  model = model.to(device)
  model_pred = model_pred.to(device)
  model_pred.eval()

  def surrogate_loss(inputs, targets):
    # L1 distance between the surrogate's estimate and the true batch-mean DTW.
    if model_type == 'mlp':
      inputs = inputs.reshape((inputs.shape[0], N_input)).to(device)
    if model_type == 'tcn':
      inputs = inputs.to(device) 
    targets = targets.to(device)
    outputs = model_pred(inputs)
    targets, outputs = format_variables(targets, outputs, batch_size)
    # Ground truth: exact DTW per series, computed on the CPU with tslearn.
    loss_dtw_true = 0
    for k in range(batch_size):
      target_k_cpu = targets[k,:,0:1].view(-1).detach().cpu().numpy()
      output_k_cpu = outputs[k,:,0:1].view(-1).detach().cpu().numpy()
      loss_dtw_true += dtw(target_k_cpu, output_k_cpu)
    loss_dtw_true = torch.tensor((loss_dtw_true/batch_size), dtype=torch.float32)
    loss_dtw_hat = model(targets, outputs)
    return MetaLoss(loss_dtw_hat.cpu(), loss_dtw_true)

  for epoch in range(num_epochs):
    print('\n Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    # Training
    phase = 'train'
    model.train()
    running_train_loss = []
    for inputs, targets in dataloaders[phase]:
      optimizer.zero_grad()
      loss = surrogate_loss(inputs, targets)
      loss.backward()
      optimizer.step()
      running_train_loss.append(loss.item())
    train_epoch_loss = np.mean(running_train_loss)
    loss_history['train loss'].append(train_epoch_loss)
    print('{} Loss: {:.5f}'.format(phase, train_epoch_loss))

    # Evaluating
    phase = 'val'
    model.eval()
    running_test_loss = []
    with torch.no_grad():
      for inputs, targets in dataloaders[phase]:
        running_test_loss.append(surrogate_loss(inputs, targets).item())
    test_epoch_loss = np.mean(running_test_loss)
    # Record the validation loss before the early-stopping check so both histories
    # always have the same length.
    loss_history['test loss'].append(test_epoch_loss)
    print('{} Loss: {:.5f}'.format(phase, test_epoch_loss))
    if test_epoch_loss < best_loss:
      best_loss = test_epoch_loss
      best_model_wts = copy.deepcopy(model.state_dict()) 
      torch.save(model.state_dict(), f"stored_model/{title}.pt")
      es = 0
    else:
      es += 1
      print("Counter {} of 20".format(es))
      if es > 20:
        print("Early stopping with best_validation_loss: ", best_loss)
        break
    print('='*60)

  time_elapsed = time.time() - since
  print('Training completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  print('Best validation Loss: {:5f}'.format(best_loss))

  # load best model weights
  model.load_state_dict(best_model_wts)
  return model, loss_history, model_pred

In [None]:
def optimizer_run_surrogate():
  """Alternately train the prediction model (FM) and the surrogate DTW loss model (SM/LM).

  Requires the MSE-trained checkpoint `stored_model/<data_set>_prediction_model_mse_<model_type>.pt`
  (written by a previous run with the 'mse' loss). Appends the test metrics to
  `stored_model/results_tcn.json`.
  """
  k = 10
  # step-1: start from the FM that was pre-trained with MSE
  if model_type == 'mlp':
    model_pred = MLP(N_input, N_output, hidden_units=128)
  if model_type == 'tcn':
    model_pred = TCN(N_input, N_output, ks=4)
  title = f"{data_set}_prediction_model_mse_{model_type}"
  # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
  model_pred.load_state_dict(torch.load(f"stored_model/{title}.pt", map_location=device))

  # step-2: pre-train SM with the frozen FM
  model_loss = SurrogateNetwork(data_set, batch_size=batch_size)
  model_pred = set_parameter_requires_grad_false(model_pred)
  optimizer = optim.Adam(model_loss.parameters())
  model_loss, loss_list, model_pred =  train_model_surrogate(model_loss, model_pred, dataloaders_dict, optimizer, batch_size, num_epochs=50)
  plot_losses(loss_list['train loss'], loss_list['test loss'], f"{data_set}_pre_train_surrogate_losses_{model_type}", 'train_loss', 'validation_loss', 'Loss plot')
 
  # step-3: train FM with SM, alternating k times
  for epoch in range(k):
    print('Iteration:', epoch)
    print('='*60)
    # train FM for alpha-steps
    print('Setting requires_grad to True for all modules in FM')
    model_pred = set_parameter_requires_grad_true(model_pred) 
    optimizer = torch.optim.Adam(model_pred.parameters())
    print('Setting requires_grad False for LM')
    criterion = set_parameter_requires_grad_false(model_loss) 
    model_pred, _, model_loss = train_model(model_pred, dataloaders_dict, criterion, optimizer, batch_size, num_epochs=50, loss_type='dtw_surrogate')

    # train LM for beta-steps
    print('Setting requires_grad False for FM')
    model_pred = set_parameter_requires_grad_false(model_pred)
    print('Setting requires_grad True for LM')
    model_loss = set_parameter_requires_grad_true(model_loss)
    optimizer = optim.Adam(model_loss.parameters()) 
    model_loss, _, model_pred =  train_model_surrogate(model_loss, model_pred, dataloaders_dict, optimizer, batch_size, num_epochs=50)
    print('#'*60)

  # step-4 inference
  print('Setting requires_grad to True for all layers in FM')
  model_pred = set_parameter_requires_grad_true(model_pred)
  criterion = set_parameter_requires_grad_false(model_loss)
  test_loss, test_mse_metric, test_dtw_metric = inference(model_pred, dataloaders_dict, criterion, batch_size, loss_type='dtw_surrogate')

  seedperf = {}
  seedperf['Seed'] = data_set + '_' + str(seed)
  # store the results
  seedperf['output_timestamp'] = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
  seedperf['loss_type'], seedperf['test_loss'], seedperf['mse_metric'], seedperf['dtw_metric'] = 'dtw_surrogate', test_loss, test_mse_metric, test_dtw_metric
  # NOTE(review): the results file is named "results_tcn" whatever model_type is -- confirm this is intended.
  # The context manager closes (and flushes) the file once the record has been appended.
  with open("stored_model/results_tcn.json", "a") as results_file:
    json.dump(seedperf, results_file, indent=4)
  print('Results stored!')

In [None]:
def optimizer_run(loss_type):
  """Train and evaluate the prediction model with a fixed loss ('mse' or 'soft_dtw').

  Saves the loss curves, runs inference on the test split and appends the metrics to
  `stored_model/results_tcn.json`.

  Raises:
    ValueError: if `loss_type` is neither 'mse' nor 'soft_dtw'.
  """
  seedperf = {}
  seedperf['Seed'] = data_set + '_' + str(seed)
  # initialize the prediction model
  if model_type == 'mlp':
    model_pred = MLP(N_input, N_output, hidden_units=128)
  if model_type == 'tcn':
    model_pred = TCN(N_input, N_output, ks=4)
  if loss_type == 'mse':
    criterion = nn.MSELoss()
  elif loss_type == 'soft_dtw':
    # Use the CUDA kernel only when a GPU is present, so the notebook also runs on CPU.
    criterion = SoftDTW(use_cuda=torch.cuda.is_available(), gamma=0.01)
  else:
    raise ValueError(f"optimizer_run supports 'mse' and 'soft_dtw', got {loss_type!r}")
  optimizer = optim.Adam(model_pred.parameters())
  model_pred, loss_list = train_model(model_pred, dataloaders_dict, criterion, optimizer, batch_size, loss_type, num_epochs=25) 
  #plot_values_1, plot_values_2, plt_name, label_1, label_2, plt_title
  plot_losses(loss_list['train loss'], loss_list['test loss'], f"{data_set}_train_prediction_losses_for_{loss_type}_{seed}_{model_type}", 'train_loss', 'validation_loss', f"{data_set}_{loss_type}_{seed}")
  test_loss, test_mse_metric, test_dtw_metric = inference(model_pred, dataloaders_dict, criterion, batch_size, loss_type)
  
  # store the results
  seedperf['output_timestamp'] = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
  seedperf['loss_type'], seedperf['test_loss'], seedperf['mse_metric'], seedperf['dtw_metric'] = loss_type, test_loss, test_mse_metric, test_dtw_metric
  os.makedirs("stored_model", exist_ok=True)
  # The context manager closes (and flushes) the file once the record has been appended.
  with open("stored_model/results_tcn.json", "a") as results_file:
    json.dump(seedperf, results_file, indent=4)
  print('Results stored!')

# Start the execution from here.

In [None]:
# Entry point: the surrogate loss has its own alternating training routine;
# every other loss goes through optimizer_run.
if loss == 'dtw_surrogate':
  optimizer_run_surrogate()
else:
  optimizer_run(loss_type=loss)

# Plot the results here.

In [None]:
def visualize_results(nets, dataset):
  """Plot input, target and prediction of several test series for each trained network.

  Args:
    nets: trained networks, ordered as ['mse', 'soft_dtw', 'dtw_surrogate'] (used as subplot titles).
    dataset: dataset name, used in the output file name.

  One figure per series index 1..4 of the first test batch is shown and saved to
  `plots/<model_type>_<dataset>_<ind>.png`.
  """
  loss_list = ['mse', 'soft_dtw', 'dtw_surrogate']
  test_inputs, test_targets = next(iter(dataloaders_dict['test']))

  # The loader already yields tensors; move/cast them instead of re-wrapping with torch.tensor.
  test_inputs = test_inputs.to(device=device, dtype=torch.float32)
  test_targets = test_targets.to(device=device, dtype=torch.float32)
  print(test_inputs.shape, test_targets.shape)

  if model_type == 'mlp':
    inputs = test_inputs.reshape((test_inputs.shape[0], N_input))
  if model_type == 'tcn':
    inputs = test_inputs

  # Predictions do not depend on the plotted index, so compute them once per network,
  # in eval mode and without tracking gradients.
  predictions = []
  with torch.no_grad():
    for net in nets:
      net = net.to(device)
      net.eval()
      pred = net(inputs)
      print(test_inputs.shape, test_targets.shape, pred.shape)
      test_targets, pred = format_variables(test_targets, pred, batch_size)
      predictions.append(pred.detach().cpu().numpy())

  inputs_np = test_inputs.detach().cpu().numpy()
  targets_np = test_targets.detach().cpu().numpy()
  # savefig does not create missing directories.
  os.makedirs("plots", exist_ok=True)

  for ind in range(1,5):
    # Pass the size at creation time; setting rcParams after plt.figure() would not resize this figure.
    fig1 = plt.figure(figsize=(17.0, 5.0))
    for k, (loss_type, pred_np) in enumerate(zip(loss_list, predictions), start=1):
      series_in = inputs_np[ind,:,:]
      target = targets_np[ind,:,:]
      preds = pred_np[ind,:,:]

      plt.subplot(1,3,k)
      plt.plot(range(0,N_input), series_in, label='input', linewidth=3)
      # Prepend the last input point so target/prediction connect to the input curve.
      plt.plot(range(N_input-1,N_input+N_output), np.concatenate([ series_in[N_input-1:N_input], target ]), label='target', linewidth=3)
      plt.plot(range(N_input-1,N_input+N_output), np.concatenate([ series_in[N_input-1:N_input], preds ]), label='prediction', linewidth=3)
      plt.legend()
      plt.title(loss_type)
    # Save before showing: some backends release the figure once it has been shown.
    fig1.savefig(f'plots/{model_type}_{dataset}_{ind}.png')
    plt.show()
    plt.close(fig1)

In [None]:
# Visualize results
def _load_trained_net(loss_name):
  """Rebuild the network selected by `model_type` and load its checkpoint for `loss_name`."""
  title = f"{data_set}_prediction_model_{loss_name}_{model_type}"
  # The architecture must match the one that produced the checkpoint (see optimizer_run).
  if model_type == 'mlp':
    net = MLP(N_input, N_output, hidden_units=128)
  else:
    net = TCN(N_input, N_output, ks=4)
  # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
  net.load_state_dict(torch.load(f"stored_model/{title}.pt", map_location=device))
  return net

net_mse = _load_trained_net('mse')
net_soft_dtw = _load_trained_net('soft_dtw')
net_s_dtw = _load_trained_net('dtw_surrogate')

# NOTE: this rebinds `nets` (previously the list of model-type names) to the trained networks.
nets = [net_mse, net_soft_dtw, net_s_dtw]
visualize_results(nets, data_set)