In [1]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
from torch import tensor
import numpy as np
from torch.utils.data import Dataset
from sklearn.metrics import mean_squared_error
import random
import os
import matplotlib.pyplot as plt
#from numba import jit
import pickle
from scipy.interpolate import interp1d
from torch.utils.data import DataLoader, random_split
import torch
from torchsummary import summary
import seaborn as sns
import sys
import torch.nn.functional as F
import pywt
from sklearn.preprocessing import MinMaxScaler
from torch.cuda import FloatTensor

# Req for package
sys.path.append("../")
from SkinLearning.Utils.NN import train, test, DEVICE, getParameterLoss
from SkinLearning.Utils.Dataset import getDataset, getSplit
from SkinLearning.Utils.Plotting import plotParameterBars


torch.backends.cudnn.benchmark = True

In [11]:
# Remove one FC LAyer
class RNN_hidden(nn.Module):
    def __init__(self):
        super(RNN_hidden, self).__init__()
        self.conv1 = nn.Conv1d(2, 128, kernel_size=5, padding=1, bias=False)
        self.pool1 = nn.MaxPool1d(kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm1d(128)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn3 = nn.BatchNorm1d(512)
        
        self.rnn = nn.RNN(15, 256, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128 , 64),
            nn.ReLU(),
            nn.Linear(64, 6),   
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        h0 = torch.zeros(1, batch_size, 256).to(x.device)
        x, h = self.rnn(x, h0)
        
        x = h[-1].reshape(batch_size, -1)
        x = self.fc(x)
        
        x = x.view(batch_size, 6)
        return x

In [12]:
# Remove one FC LAyer
class RNN_output(nn.Module):
    def __init__(self):
        super(RNN_output, self).__init__()
        self.conv1 = nn.Conv1d(2, 128, kernel_size=5, padding=1, bias=False)
        self.pool1 = nn.MaxPool1d(kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm1d(128)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn3 = nn.BatchNorm1d(512)
        
        self.rnn = nn.RNN(15, 256, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(256, 6),
        )

    def forward(self, x):
        a = x.shape
        batch_size = x.shape[0]
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        
        h0 = torch.zeros(1, batch_size, 256).to(x.device)
        out, h = self.rnn(x, h0)
  
        x = out[:, -1, :]
        x = x.reshape(batch_size, -1)

        
        x = self.fc(x)
        
        x = x.view(batch_size, 6)
        return x

In [13]:
# Remove one FC LAyer
class RNN_orig(nn.Module):
    def __init__(self):
        super(RNN_orig, self).__init__()
        self.conv1 = nn.Conv1d(2, 128, kernel_size=5, padding=1, bias=False)
        self.pool1 = nn.MaxPool1d(kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm1d(128)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn3 = nn.BatchNorm1d(512)
        
        self.rnn = nn.RNN(15, 256, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(131072, 1024),
            nn.ReLU(),
            nn.Linear(1024 , 512),
            nn.ReLU(),
            nn.Linear(512, 6),   
        )

    def forward(self, x):
        a = x.shape
        batch_size = x.shape[0]
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        
        h0 = torch.zeros(1, batch_size, 256).to(x.device)
        out, h = self.rnn(x, h0)
  
        x = out
        x = x.reshape(batch_size, -1)

        
        x = self.fc(x)
        
        x = x.view(batch_size, 6)
        return x

In [14]:
# Remove one FC LAyer
class RNN_both(nn.Module):
    def __init__(self):
        super(RNN_both, self).__init__()
        self.conv1 = nn.Conv1d(2, 128, kernel_size=5, padding=1, bias=False)
        self.pool1 = nn.MaxPool1d(kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm1d(128)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn3 = nn.BatchNorm1d(512)
        
        self.rnn = nn.RNN(15, 256, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256 , 128),
            nn.ReLU(),
            nn.Linear(128, 6),   
        )

    def forward(self, x):
        a = x.shape
        batch_size = x.shape[0]
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        
        h0 = torch.zeros(1, batch_size, 256).to(x.device)
        out, h = self.rnn(x, h0)
  
        x = torch.concat([h[-1], out[:, -1, :]])
        x = x.reshape(batch_size, -1)

        
        x = self.fc(x)
        
        x = x.view(batch_size, 6)
        return x

In [15]:
# Uses full hidden state
class RNN_fh(nn.Module):
    def __init__(self):
        super(RNN_fh, self).__init__()
        self.conv1 = nn.Conv1d(2, 128, kernel_size=5, padding=1, bias=False)
        self.pool1 = nn.MaxPool1d(kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm1d(128)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn3 = nn.BatchNorm1d(512)
        
        self.rnn = nn.RNN(15, 256, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128 , 64),
            nn.ReLU(),
            nn.Linear(64, 6),   
        )

    def forward(self, x):
        a = x.shape
        batch_size = x.shape[0]
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        
        h0 = torch.zeros(1, batch_size, 256).to(x.device)
        out, h = self.rnn(x, h0)
  
        x = h
        x = x.reshape(batch_size, -1)

        
        x = self.fc(x)
        
        x = x.view(batch_size, 6)
        return x

In [16]:
dataset, scaler = getDataset()

100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:09<00:00, 246.42it/s]


In [17]:
train_loader, test_loader = getSplit(dataset)

In [21]:
out = RNN_output()

In [22]:
out_train_loss, out_val_loss =  train(train_loader, out, val_loader=test_loader, LR=0.0001, epochs=550)

Using: cuda:0


Epoch 1/550: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 70.44batch/s]
Epoch 2/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 77.37batch/s, lastLoss=0.186, valLoss=0.265]
Epoch 3/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.97batch/s, lastLoss=0.125, valLoss=0.155]
Epoch 4/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 77.72batch/s, lastLoss=0.116, valLoss=0.116]
Epoch 5/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 77.35batch/s, lastLoss=0.113, valLoss=0.11]
Epoch 6/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 77.92batch/s, lastLoss=0.11, valLoss=0.103]
Epoch 7/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 77.14batch/s, lastLoss=0.107, valLoss=0.11]
Epoch 8/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 78.30batch/s, lastLoss=0.106, valLoss=0.105]
Epoch 9/550: 100%|██████████████████████

Epoch 69/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.36batch/s, lastLoss=0.0794, valLoss=0.0766]
Epoch 70/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.15batch/s, lastLoss=0.0749, valLoss=0.0857]
Epoch 71/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.30batch/s, lastLoss=0.0775, valLoss=0.0772]
Epoch 72/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 78.92batch/s, lastLoss=0.0733, valLoss=0.0734]
Epoch 73/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.38batch/s, lastLoss=0.0797, valLoss=0.0785]
Epoch 74/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.65batch/s, lastLoss=0.0811, valLoss=0.0776]
Epoch 75/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 77.93batch/s, lastLoss=0.0772, valLoss=0.0794]
Epoch 76/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 79.70batch/s, lastLoss=0.075, valLoss=0.0737]
Epoch 77/550: 100%|█████████████████████

Epoch 137/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 81.69batch/s, lastLoss=0.0659, valLoss=0.0633]
Epoch 138/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 81.97batch/s, lastLoss=0.0676, valLoss=0.0665]
Epoch 139/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 82.78batch/s, lastLoss=0.067, valLoss=0.0647]
Epoch 140/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 82.60batch/s, lastLoss=0.0676, valLoss=0.0631]
Epoch 141/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 75.91batch/s, lastLoss=0.0657, valLoss=0.0581]
Epoch 142/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.24batch/s, lastLoss=0.0658, valLoss=0.0662]
Epoch 143/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.28batch/s, lastLoss=0.0678, valLoss=0.0755]
Epoch 144/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.67batch/s, lastLoss=0.0677, valLoss=0.0609]
Epoch 145/550: 100%|████████████████████

Epoch 205/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.57batch/s, lastLoss=0.0583, valLoss=0.0608]
Epoch 206/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.52batch/s, lastLoss=0.0613, valLoss=0.0614]
Epoch 207/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.09batch/s, lastLoss=0.0591, valLoss=0.0572]
Epoch 208/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.40batch/s, lastLoss=0.0589, valLoss=0.0575]
Epoch 209/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.72batch/s, lastLoss=0.0587, valLoss=0.0572]
Epoch 210/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.16batch/s, lastLoss=0.0577, valLoss=0.0573]
Epoch 211/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 78.79batch/s, lastLoss=0.058, valLoss=0.0582]
Epoch 212/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 78.86batch/s, lastLoss=0.0576, valLoss=0.061]
Epoch 213/550: 100%|████████████████████

Epoch 273/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.76batch/s, lastLoss=0.0535, valLoss=0.0578]
Epoch 274/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.23batch/s, lastLoss=0.0526, valLoss=0.0561]
Epoch 275/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.96batch/s, lastLoss=0.0528, valLoss=0.0505]
Epoch 276/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.67batch/s, lastLoss=0.0522, valLoss=0.0507]
Epoch 277/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.89batch/s, lastLoss=0.0542, valLoss=0.0556]
Epoch 278/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.64batch/s, lastLoss=0.0512, valLoss=0.0514]
Epoch 279/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.11batch/s, lastLoss=0.0541, valLoss=0.0543]
Epoch 280/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.54batch/s, lastLoss=0.0529, valLoss=0.0529]
Epoch 281/550: 100%|████████████████████

Epoch 341/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.52batch/s, lastLoss=0.0498, valLoss=0.0493]
Epoch 342/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 78.51batch/s, lastLoss=0.0489, valLoss=0.055]
Epoch 343/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.84batch/s, lastLoss=0.0498, valLoss=0.0535]
Epoch 344/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.06batch/s, lastLoss=0.0499, valLoss=0.0478]
Epoch 345/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.20batch/s, lastLoss=0.0489, valLoss=0.0493]
Epoch 346/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.67batch/s, lastLoss=0.0506, valLoss=0.0455]
Epoch 347/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.47batch/s, lastLoss=0.0481, valLoss=0.0461]
Epoch 348/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.12batch/s, lastLoss=0.0502, valLoss=0.0539]
Epoch 349/550: 100%|████████████████████

Epoch 409/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.91batch/s, lastLoss=0.0487, valLoss=0.0434]
Epoch 410/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.32batch/s, lastLoss=0.0439, valLoss=0.0462]
Epoch 411/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.04batch/s, lastLoss=0.044, valLoss=0.0443]
Epoch 412/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.11batch/s, lastLoss=0.0453, valLoss=0.0449]
Epoch 413/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.34batch/s, lastLoss=0.0451, valLoss=0.0507]
Epoch 414/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.03batch/s, lastLoss=0.0455, valLoss=0.0452]
Epoch 415/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.78batch/s, lastLoss=0.0449, valLoss=0.0474]
Epoch 416/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.50batch/s, lastLoss=0.0452, valLoss=0.0509]
Epoch 417/550: 100%|████████████████████

Epoch 477/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 78.69batch/s, lastLoss=0.041, valLoss=0.0427]
Epoch 478/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.44batch/s, lastLoss=0.0405, valLoss=0.045]
Epoch 479/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.38batch/s, lastLoss=0.0425, valLoss=0.0436]
Epoch 480/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 79.22batch/s, lastLoss=0.0416, valLoss=0.051]
Epoch 481/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 78.45batch/s, lastLoss=0.041, valLoss=0.0449]
Epoch 482/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.94batch/s, lastLoss=0.0419, valLoss=0.0485]
Epoch 483/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.30batch/s, lastLoss=0.0413, valLoss=0.0402]
Epoch 484/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.26batch/s, lastLoss=0.0405, valLoss=0.0438]
Epoch 485/550: 100%|████████████████████

Epoch 545/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.36batch/s, lastLoss=0.0405, valLoss=0.0399]
Epoch 546/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.31batch/s, lastLoss=0.0386, valLoss=0.0433]
Epoch 547/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.39batch/s, lastLoss=0.0389, valLoss=0.0393]
Epoch 548/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.04batch/s, lastLoss=0.0393, valLoss=0.0382]
Epoch 549/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 78.55batch/s, lastLoss=0.0387, valLoss=0.0412]
Epoch 550/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 79.06batch/s, lastLoss=0.0401, valLoss=0.0419]


Average train loss: 0.058039054373859776
Average validation loss: 0.05829862871100054


In [25]:
def test(test_loader, net, scaler):
    net.to(DEVICE)
    net.eval()
    criterion = nn.L1Loss()

    losses = []
    p_losses = []
    mae = []
    orig = []

    with torch.no_grad():
        with tqdm(test_loader, unit=" batch") as it:
            for idx, data in enumerate(it):
                inp, out = data['input'].to(DEVICE), data['output'].to(DEVICE)

                predicted = net(inp)
                
                # Denormalise
                p = scaler.inverse_transform(predicted.cpu().numpy())
                o = scaler.inverse_transform(out.cpu().numpy())
                    
                # Get column wise and overall MAPE
                # Since each column is normalised should also be able to use MAE*100
                p_loss = np.mean(100*(np.abs(o-p)/o), axis=0)
                loss = np.mean(100*(np.abs(o-p)/o))

                curr_mae = criterion(predicted, out).item()
                mae.append(curr_mae)
        
            
                p_losses.append(p_loss)
                losses.append(loss)
                orig.append(curr_mae)

            
    average_mape = 100 - np.mean(losses)
    average_p_loss = 100 - np.mean(p_losses, axis=0)
    mae_mean = np.mean(mae)
    mape2 = 100 - np.mean(orig)
    
    return average_mape, average_p_loss, mae_mean, mape2

In [27]:
test(test_loader, out, scaler)

100%|█████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 147.08 batch/s]


(93.49792051315308,
 array([95.07151 , 88.322296, 99.78066 , 99.81861 , 83.50792 , 94.486534],
       dtype=float32),
 0.04341856713096301,
 99.95658143286904)

In [28]:
hidden = RNN_hidden()

In [29]:
hidden_train_loss, hidden_val_loss =  train(train_loader, hidden, val_loader=test_loader, LR=0.0001, epochs=550)

Using: cuda:0


Epoch 1/550: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 69.27batch/s]
Epoch 2/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 74.91batch/s, lastLoss=0.256, valLoss=0.198]
Epoch 3/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 75.99batch/s, lastLoss=0.17, valLoss=0.173]
Epoch 4/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 75.27batch/s, lastLoss=0.15, valLoss=0.135]
Epoch 5/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.26batch/s, lastLoss=0.128, valLoss=0.121]
Epoch 6/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.90batch/s, lastLoss=0.119, valLoss=0.112]
Epoch 7/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 74.87batch/s, lastLoss=0.112, valLoss=0.114]
Epoch 8/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 75.37batch/s, lastLoss=0.109, valLoss=0.106]
Epoch 9/550: 100%|██████████████████████

Epoch 69/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.94batch/s, lastLoss=0.0689, valLoss=0.0684]
Epoch 70/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 77.29batch/s, lastLoss=0.0675, valLoss=0.0725]
Epoch 71/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.87batch/s, lastLoss=0.0669, valLoss=0.0666]
Epoch 72/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 74.28batch/s, lastLoss=0.0708, valLoss=0.0619]
Epoch 73/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 73.15batch/s, lastLoss=0.0688, valLoss=0.0667]
Epoch 74/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 74.22batch/s, lastLoss=0.0677, valLoss=0.0661]
Epoch 75/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 73.54batch/s, lastLoss=0.0667, valLoss=0.0681]
Epoch 76/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 73.06batch/s, lastLoss=0.0664, valLoss=0.0642]
Epoch 77/550: 100%|█████████████████████

Epoch 137/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 74.27batch/s, lastLoss=0.0586, valLoss=0.053]
Epoch 138/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 75.34batch/s, lastLoss=0.0585, valLoss=0.0583]
Epoch 139/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.08batch/s, lastLoss=0.0592, valLoss=0.0581]
Epoch 140/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.70batch/s, lastLoss=0.0581, valLoss=0.0567]
Epoch 141/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.92batch/s, lastLoss=0.0563, valLoss=0.0607]
Epoch 142/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.97batch/s, lastLoss=0.0595, valLoss=0.0587]
Epoch 143/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.09batch/s, lastLoss=0.0578, valLoss=0.058]
Epoch 144/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.55batch/s, lastLoss=0.0581, valLoss=0.0561]
Epoch 145/550: 100%|████████████████████

Epoch 205/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 77.06batch/s, lastLoss=0.053, valLoss=0.0569]
Epoch 206/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.93batch/s, lastLoss=0.0509, valLoss=0.0478]
Epoch 207/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.76batch/s, lastLoss=0.0521, valLoss=0.0491]
Epoch 208/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.58batch/s, lastLoss=0.0512, valLoss=0.0576]
Epoch 209/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.78batch/s, lastLoss=0.0487, valLoss=0.0537]
Epoch 210/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.02batch/s, lastLoss=0.0518, valLoss=0.0471]
Epoch 211/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.23batch/s, lastLoss=0.0492, valLoss=0.0503]
Epoch 212/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.84batch/s, lastLoss=0.051, valLoss=0.0591]
Epoch 213/550: 100%|████████████████████

Epoch 273/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.20batch/s, lastLoss=0.0461, valLoss=0.0523]
Epoch 274/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.80batch/s, lastLoss=0.0452, valLoss=0.0419]
Epoch 275/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.87batch/s, lastLoss=0.0444, valLoss=0.0464]
Epoch 276/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.88batch/s, lastLoss=0.0459, valLoss=0.0547]
Epoch 277/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.27batch/s, lastLoss=0.0477, valLoss=0.05]
Epoch 278/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.86batch/s, lastLoss=0.0447, valLoss=0.0429]
Epoch 279/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.33batch/s, lastLoss=0.0446, valLoss=0.0489]
Epoch 280/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.42batch/s, lastLoss=0.0458, valLoss=0.05]
Epoch 281/550: 100%|████████████████████

Epoch 341/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.04batch/s, lastLoss=0.0425, valLoss=0.0423]
Epoch 342/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.10batch/s, lastLoss=0.0417, valLoss=0.0443]
Epoch 343/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.54batch/s, lastLoss=0.0417, valLoss=0.0429]
Epoch 344/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 77.06batch/s, lastLoss=0.042, valLoss=0.0416]
Epoch 345/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.17batch/s, lastLoss=0.0406, valLoss=0.0442]
Epoch 346/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.76batch/s, lastLoss=0.0423, valLoss=0.0475]
Epoch 347/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.37batch/s, lastLoss=0.0455, valLoss=0.0436]
Epoch 348/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.06batch/s, lastLoss=0.0417, valLoss=0.0423]
Epoch 349/550: 100%|████████████████████

Epoch 409/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.60batch/s, lastLoss=0.0379, valLoss=0.0424]
Epoch 410/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.04batch/s, lastLoss=0.0414, valLoss=0.0455]
Epoch 411/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.91batch/s, lastLoss=0.0387, valLoss=0.0422]
Epoch 412/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.73batch/s, lastLoss=0.0395, valLoss=0.0483]
Epoch 413/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.04batch/s, lastLoss=0.0386, valLoss=0.0402]
Epoch 414/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.16batch/s, lastLoss=0.0383, valLoss=0.0387]
Epoch 415/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.03batch/s, lastLoss=0.0388, valLoss=0.0437]
Epoch 416/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.06batch/s, lastLoss=0.0364, valLoss=0.0408]
Epoch 417/550: 100%|████████████████████

Epoch 477/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.85batch/s, lastLoss=0.036, valLoss=0.0403]
Epoch 478/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.15batch/s, lastLoss=0.0348, valLoss=0.0374]
Epoch 479/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.87batch/s, lastLoss=0.0337, valLoss=0.041]
Epoch 480/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.99batch/s, lastLoss=0.0364, valLoss=0.0381]
Epoch 481/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.62batch/s, lastLoss=0.034, valLoss=0.0361]
Epoch 482/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.40batch/s, lastLoss=0.0351, valLoss=0.038]
Epoch 483/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.40batch/s, lastLoss=0.0343, valLoss=0.0357]
Epoch 484/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.77batch/s, lastLoss=0.0355, valLoss=0.0377]
Epoch 485/550: 100%|████████████████████

Epoch 545/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.54batch/s, lastLoss=0.0331, valLoss=0.0374]
Epoch 546/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.40batch/s, lastLoss=0.0356, valLoss=0.0395]
Epoch 547/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.49batch/s, lastLoss=0.0364, valLoss=0.0377]
Epoch 548/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.01batch/s, lastLoss=0.0342, valLoss=0.0392]
Epoch 549/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.54batch/s, lastLoss=0.034, valLoss=0.0356]
Epoch 550/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 75.38batch/s, lastLoss=0.0339, valLoss=0.044]


Average train loss: 0.05106937584231378
Average validation loss: 0.05266756271373367


In [30]:
orig = RNN_orig()

In [31]:
orig_train_loss, orig_val_loss =  train(train_loader, orig, val_loader=test_loader, LR=0.0001, epochs=550)

Using: cuda:0


Epoch 1/550: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:02<00:00, 20.01batch/s]
Epoch 2/550: 100%|███████████████████████████████████| 56/56 [00:02<00:00, 20.25batch/s, lastLoss=0.157, valLoss=0.266]
Epoch 3/550: 100%|███████████████████████████████████| 56/56 [00:02<00:00, 20.27batch/s, lastLoss=0.116, valLoss=0.166]
Epoch 4/550: 100%|███████████████████████████████████| 56/56 [00:02<00:00, 20.31batch/s, lastLoss=0.108, valLoss=0.108]
Epoch 5/550: 100%|██████████████████████████████████| 56/56 [00:02<00:00, 20.30batch/s, lastLoss=0.104, valLoss=0.0958]
Epoch 6/550: 100%|████████████████████████████████████| 56/56 [00:02<00:00, 20.32batch/s, lastLoss=0.1, valLoss=0.0911]
Epoch 7/550: 100%|█████████████████████████████████| 56/56 [00:02<00:00, 20.34batch/s, lastLoss=0.0946, valLoss=0.0959]
Epoch 8/550: 100%|█████████████████████████████████| 56/56 [00:02<00:00, 20.32batch/s, lastLoss=0.0909, valLoss=0.0914]
Epoch 9/550: 100%|██████████████████████

Epoch 69/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.30batch/s, lastLoss=0.0568, valLoss=0.0552]
Epoch 70/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.25batch/s, lastLoss=0.0536, valLoss=0.0547]
Epoch 71/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.27batch/s, lastLoss=0.0552, valLoss=0.0561]
Epoch 72/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.26batch/s, lastLoss=0.0566, valLoss=0.0586]
Epoch 73/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0554, valLoss=0.0568]
Epoch 74/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.27batch/s, lastLoss=0.0542, valLoss=0.0629]
Epoch 75/550: 100%|█████████████████████████████████| 56/56 [00:02<00:00, 20.26batch/s, lastLoss=0.058, valLoss=0.0603]
Epoch 76/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.20batch/s, lastLoss=0.0585, valLoss=0.0715]
Epoch 77/550: 100%|█████████████████████

Epoch 137/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0406, valLoss=0.041]
Epoch 138/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.21batch/s, lastLoss=0.0413, valLoss=0.0487]
Epoch 139/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.13batch/s, lastLoss=0.0423, valLoss=0.0481]
Epoch 140/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.14batch/s, lastLoss=0.0402, valLoss=0.0457]
Epoch 141/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.13batch/s, lastLoss=0.0376, valLoss=0.0451]
Epoch 142/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.18batch/s, lastLoss=0.0387, valLoss=0.0473]
Epoch 143/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.13batch/s, lastLoss=0.0378, valLoss=0.0442]
Epoch 144/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0401, valLoss=0.0434]
Epoch 145/550: 100%|████████████████████

Epoch 205/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0365, valLoss=0.0493]
Epoch 206/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.18batch/s, lastLoss=0.0342, valLoss=0.0447]
Epoch 207/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.23batch/s, lastLoss=0.033, valLoss=0.0456]
Epoch 208/550: 100%|█████████████████████████████████| 56/56 [00:02<00:00, 20.19batch/s, lastLoss=0.0329, valLoss=0.04]
Epoch 209/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0348, valLoss=0.0428]
Epoch 210/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.18batch/s, lastLoss=0.0359, valLoss=0.0431]
Epoch 211/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0317, valLoss=0.0396]
Epoch 212/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.13batch/s, lastLoss=0.0332, valLoss=0.0484]
Epoch 213/550: 100%|████████████████████

Epoch 273/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.22batch/s, lastLoss=0.0281, valLoss=0.0406]
Epoch 274/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0276, valLoss=0.0404]
Epoch 275/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0269, valLoss=0.0431]
Epoch 276/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.18batch/s, lastLoss=0.0264, valLoss=0.0434]
Epoch 277/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0268, valLoss=0.0414]
Epoch 278/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0282, valLoss=0.0464]
Epoch 279/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0264, valLoss=0.0378]
Epoch 280/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.18batch/s, lastLoss=0.0286, valLoss=0.0508]
Epoch 281/550: 100%|████████████████████

Epoch 341/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.14batch/s, lastLoss=0.0238, valLoss=0.0368]
Epoch 342/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.19batch/s, lastLoss=0.024, valLoss=0.0383]
Epoch 343/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.023, valLoss=0.0363]
Epoch 344/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.20batch/s, lastLoss=0.0252, valLoss=0.0388]
Epoch 345/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0236, valLoss=0.0395]
Epoch 346/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0228, valLoss=0.0398]
Epoch 347/550: 100%|█████████████████████████████████| 56/56 [00:02<00:00, 20.25batch/s, lastLoss=0.0248, valLoss=0.04]
Epoch 348/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0232, valLoss=0.0385]
Epoch 349/550: 100%|████████████████████

Epoch 409/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0208, valLoss=0.0373]
Epoch 410/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.17batch/s, lastLoss=0.0199, valLoss=0.0374]
Epoch 411/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0208, valLoss=0.0405]
Epoch 412/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0209, valLoss=0.0389]
Epoch 413/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0202, valLoss=0.041]
Epoch 414/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0194, valLoss=0.0409]
Epoch 415/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.16batch/s, lastLoss=0.0217, valLoss=0.0386]
Epoch 416/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.15batch/s, lastLoss=0.0211, valLoss=0.0357]
Epoch 417/550: 100%|████████████████████

Epoch 477/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.32batch/s, lastLoss=0.0194, valLoss=0.0354]
Epoch 478/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.28batch/s, lastLoss=0.018, valLoss=0.0357]
Epoch 479/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.29batch/s, lastLoss=0.0201, valLoss=0.0388]
Epoch 480/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.29batch/s, lastLoss=0.0195, valLoss=0.0391]
Epoch 481/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.28batch/s, lastLoss=0.0193, valLoss=0.0353]
Epoch 482/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.30batch/s, lastLoss=0.0187, valLoss=0.038]
Epoch 483/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.28batch/s, lastLoss=0.0213, valLoss=0.0397]
Epoch 484/550: 100%|█████████████████████████████████| 56/56 [00:02<00:00, 20.28batch/s, lastLoss=0.02, valLoss=0.0368]
Epoch 485/550: 100%|████████████████████

Epoch 545/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.33batch/s, lastLoss=0.0161, valLoss=0.0411]
Epoch 546/550: 100%|████████████████████████████████| 56/56 [00:02<00:00, 20.31batch/s, lastLoss=0.0162, valLoss=0.035]
Epoch 547/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.29batch/s, lastLoss=0.0173, valLoss=0.0352]
Epoch 548/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.29batch/s, lastLoss=0.0177, valLoss=0.0378]
Epoch 549/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.27batch/s, lastLoss=0.0167, valLoss=0.0355]
Epoch 550/550: 100%|███████████████████████████████| 56/56 [00:02<00:00, 20.30batch/s, lastLoss=0.0162, valLoss=0.0362]


Average train loss: 0.034347546356609214
Average validation loss: 0.046026647793693524


In [32]:
fh = RNN_fh()

In [33]:
fh_train_loss, fh_val_loss =  train(train_loader, fh, val_loader=test_loader, LR=0.0001, epochs=550)

Using: cuda:0


Epoch 1/550: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 75.20batch/s]
Epoch 2/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.34batch/s, lastLoss=0.248, valLoss=0.245]
Epoch 3/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.26batch/s, lastLoss=0.161, valLoss=0.169]
Epoch 4/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 76.04batch/s, lastLoss=0.14, valLoss=0.124]
Epoch 5/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.25batch/s, lastLoss=0.121, valLoss=0.121]
Epoch 6/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.21batch/s, lastLoss=0.113, valLoss=0.107]
Epoch 7/550: 100%|████████████████████████████████████| 56/56 [00:00<00:00, 76.28batch/s, lastLoss=0.107, valLoss=0.11]
Epoch 8/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 76.01batch/s, lastLoss=0.103, valLoss=0.0983]
Epoch 9/550: 100%|██████████████████████

Epoch 69/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.17batch/s, lastLoss=0.0695, valLoss=0.0691]
Epoch 70/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.12batch/s, lastLoss=0.0723, valLoss=0.0678]
Epoch 71/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.23batch/s, lastLoss=0.0698, valLoss=0.0651]
Epoch 72/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.20batch/s, lastLoss=0.0659, valLoss=0.0635]
Epoch 73/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.57batch/s, lastLoss=0.0698, valLoss=0.0643]
Epoch 74/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 75.29batch/s, lastLoss=0.0698, valLoss=0.0612]
Epoch 75/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 75.74batch/s, lastLoss=0.0686, valLoss=0.0682]
Epoch 76/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.27batch/s, lastLoss=0.0708, valLoss=0.0672]
Epoch 77/550: 100%|█████████████████████

Epoch 137/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.47batch/s, lastLoss=0.0542, valLoss=0.0531]
Epoch 138/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.54batch/s, lastLoss=0.0558, valLoss=0.051]
Epoch 139/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.65batch/s, lastLoss=0.0528, valLoss=0.0571]
Epoch 140/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.94batch/s, lastLoss=0.0563, valLoss=0.0531]
Epoch 141/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.91batch/s, lastLoss=0.0582, valLoss=0.0562]
Epoch 142/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.39batch/s, lastLoss=0.0588, valLoss=0.062]
Epoch 143/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.60batch/s, lastLoss=0.0574, valLoss=0.0671]
Epoch 144/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.71batch/s, lastLoss=0.0557, valLoss=0.0515]
Epoch 145/550: 100%|████████████████████

Epoch 205/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.57batch/s, lastLoss=0.0496, valLoss=0.0514]
Epoch 206/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.07batch/s, lastLoss=0.0499, valLoss=0.0525]
Epoch 207/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 75.83batch/s, lastLoss=0.0511, valLoss=0.0519]
Epoch 208/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.41batch/s, lastLoss=0.0535, valLoss=0.0542]
Epoch 209/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.93batch/s, lastLoss=0.0498, valLoss=0.05]
Epoch 210/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.75batch/s, lastLoss=0.0493, valLoss=0.0538]
Epoch 211/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.10batch/s, lastLoss=0.0498, valLoss=0.0499]
Epoch 212/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.77batch/s, lastLoss=0.0514, valLoss=0.0538]
Epoch 213/550: 100%|████████████████████

Epoch 273/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.63batch/s, lastLoss=0.0468, valLoss=0.0473]
Epoch 274/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.71batch/s, lastLoss=0.0476, valLoss=0.0468]
Epoch 275/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.06batch/s, lastLoss=0.0461, valLoss=0.0432]
Epoch 276/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.01batch/s, lastLoss=0.0445, valLoss=0.0472]
Epoch 277/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.02batch/s, lastLoss=0.0456, valLoss=0.0478]
Epoch 278/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.65batch/s, lastLoss=0.0454, valLoss=0.0443]
Epoch 279/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.17batch/s, lastLoss=0.0448, valLoss=0.0439]
Epoch 280/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.01batch/s, lastLoss=0.0455, valLoss=0.0462]
Epoch 281/550: 100%|████████████████████

Epoch 341/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.74batch/s, lastLoss=0.0406, valLoss=0.042]
Epoch 342/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.87batch/s, lastLoss=0.041, valLoss=0.0427]
Epoch 343/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 75.92batch/s, lastLoss=0.0413, valLoss=0.0459]
Epoch 344/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.62batch/s, lastLoss=0.0417, valLoss=0.0412]
Epoch 345/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.86batch/s, lastLoss=0.0429, valLoss=0.049]
Epoch 346/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.12batch/s, lastLoss=0.0425, valLoss=0.04]
Epoch 347/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.42batch/s, lastLoss=0.041, valLoss=0.0422]
Epoch 348/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.91batch/s, lastLoss=0.0445, valLoss=0.0499]
Epoch 349/550: 100%|████████████████████

Epoch 409/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.43batch/s, lastLoss=0.0392, valLoss=0.0409]
Epoch 410/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.54batch/s, lastLoss=0.0399, valLoss=0.0418]
Epoch 411/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.76batch/s, lastLoss=0.0388, valLoss=0.0395]
Epoch 412/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.58batch/s, lastLoss=0.0415, valLoss=0.0461]
Epoch 413/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.12batch/s, lastLoss=0.0382, valLoss=0.0417]
Epoch 414/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.72batch/s, lastLoss=0.0374, valLoss=0.0487]
Epoch 415/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 75.82batch/s, lastLoss=0.0373, valLoss=0.0416]
Epoch 416/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.52batch/s, lastLoss=0.0387, valLoss=0.0396]
Epoch 417/550: 100%|████████████████████

Epoch 477/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.53batch/s, lastLoss=0.0361, valLoss=0.0439]
Epoch 478/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.18batch/s, lastLoss=0.0358, valLoss=0.0428]
Epoch 479/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 75.92batch/s, lastLoss=0.036, valLoss=0.042]
Epoch 480/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.63batch/s, lastLoss=0.0351, valLoss=0.0409]
Epoch 481/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.74batch/s, lastLoss=0.037, valLoss=0.0389]
Epoch 482/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.17batch/s, lastLoss=0.0366, valLoss=0.0403]
Epoch 483/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 77.19batch/s, lastLoss=0.035, valLoss=0.0397]
Epoch 484/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.65batch/s, lastLoss=0.0375, valLoss=0.0408]
Epoch 485/550: 100%|████████████████████

Epoch 545/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.69batch/s, lastLoss=0.0347, valLoss=0.0448]
Epoch 546/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.88batch/s, lastLoss=0.0341, valLoss=0.0404]
Epoch 547/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.67batch/s, lastLoss=0.0333, valLoss=0.0403]
Epoch 548/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 76.54batch/s, lastLoss=0.0332, valLoss=0.0371]
Epoch 549/550: 100%|███████████████████████████████| 56/56 [00:00<00:00, 77.13batch/s, lastLoss=0.0351, valLoss=0.0403]
Epoch 550/550: 100%|████████████████████████████████| 56/56 [00:00<00:00, 76.88batch/s, lastLoss=0.0349, valLoss=0.038]


Average train loss: 0.05092195909754555
Average validation loss: 0.05235694397427142


In [34]:
both = RNN_both()

In [35]:
both_train_loss, both_val_loss =  train(train_loader, both, val_loader=test_loader, LR=0.0001, epochs=550)

Using: cuda:0


Epoch 1/550: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 73.70batch/s]
Epoch 2/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 74.52batch/s, lastLoss=0.237, valLoss=0.219]
Epoch 3/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 77.18batch/s, lastLoss=0.186, valLoss=0.189]
Epoch 4/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.29batch/s, lastLoss=0.186, valLoss=0.184]
Epoch 5/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 77.22batch/s, lastLoss=0.186, valLoss=0.192]
Epoch 6/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.83batch/s, lastLoss=0.186, valLoss=0.183]
Epoch 7/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.97batch/s, lastLoss=0.186, valLoss=0.186]
Epoch 8/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.19batch/s, lastLoss=0.186, valLoss=0.188]
Epoch 9/550: 100%|██████████████████████

Epoch 69/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 77.67batch/s, lastLoss=0.185, valLoss=0.183]
Epoch 70/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 76.72batch/s, lastLoss=0.185, valLoss=0.185]
Epoch 71/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 77.25batch/s, lastLoss=0.185, valLoss=0.189]
Epoch 72/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 76.30batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 73/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 77.22batch/s, lastLoss=0.185, valLoss=0.186]
Epoch 74/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 77.22batch/s, lastLoss=0.185, valLoss=0.187]
Epoch 75/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 75.51batch/s, lastLoss=0.185, valLoss=0.181]
Epoch 76/550: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 76.36batch/s, lastLoss=0.185, valLoss=0.19]
Epoch 77/550: 100%|█████████████████████

Epoch 137/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.69batch/s, lastLoss=0.185, valLoss=0.182]
Epoch 138/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 75.38batch/s, lastLoss=0.185, valLoss=0.177]
Epoch 139/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.84batch/s, lastLoss=0.185, valLoss=0.183]
Epoch 140/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.71batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 141/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.15batch/s, lastLoss=0.185, valLoss=0.182]
Epoch 142/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.79batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 143/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.73batch/s, lastLoss=0.185, valLoss=0.183]
Epoch 144/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.29batch/s, lastLoss=0.185, valLoss=0.182]
Epoch 145/550: 100%|████████████████████

Epoch 205/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.16batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 206/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.73batch/s, lastLoss=0.185, valLoss=0.186]
Epoch 207/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.87batch/s, lastLoss=0.185, valLoss=0.194]
Epoch 208/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.84batch/s, lastLoss=0.185, valLoss=0.178]
Epoch 209/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 77.19batch/s, lastLoss=0.185, valLoss=0.19]
Epoch 210/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.95batch/s, lastLoss=0.184, valLoss=0.188]
Epoch 211/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 75.90batch/s, lastLoss=0.185, valLoss=0.187]
Epoch 212/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 75.22batch/s, lastLoss=0.184, valLoss=0.183]
Epoch 213/550: 100%|████████████████████

Epoch 273/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.01batch/s, lastLoss=0.184, valLoss=0.187]
Epoch 274/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.75batch/s, lastLoss=0.185, valLoss=0.188]
Epoch 275/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.77batch/s, lastLoss=0.185, valLoss=0.181]
Epoch 276/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.13batch/s, lastLoss=0.185, valLoss=0.188]
Epoch 277/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.70batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 278/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.06batch/s, lastLoss=0.185, valLoss=0.186]
Epoch 279/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.33batch/s, lastLoss=0.185, valLoss=0.187]
Epoch 280/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 76.40batch/s, lastLoss=0.184, valLoss=0.18]
Epoch 281/550: 100%|████████████████████

Epoch 341/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.22batch/s, lastLoss=0.184, valLoss=0.192]
Epoch 342/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.77batch/s, lastLoss=0.184, valLoss=0.183]
Epoch 343/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 76.92batch/s, lastLoss=0.184, valLoss=0.18]
Epoch 344/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.24batch/s, lastLoss=0.184, valLoss=0.178]
Epoch 345/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.80batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 346/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.79batch/s, lastLoss=0.184, valLoss=0.186]
Epoch 347/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.24batch/s, lastLoss=0.184, valLoss=0.188]
Epoch 348/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.13batch/s, lastLoss=0.184, valLoss=0.183]
Epoch 349/550: 100%|████████████████████

Epoch 409/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.58batch/s, lastLoss=0.184, valLoss=0.185]
Epoch 410/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.70batch/s, lastLoss=0.184, valLoss=0.186]
Epoch 411/550: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 77.19batch/s, lastLoss=0.184, valLoss=0.19]
Epoch 412/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.69batch/s, lastLoss=0.184, valLoss=0.186]
Epoch 413/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.86batch/s, lastLoss=0.185, valLoss=0.184]
Epoch 414/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.07batch/s, lastLoss=0.184, valLoss=0.184]
Epoch 415/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.25batch/s, lastLoss=0.184, valLoss=0.186]
Epoch 416/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.60batch/s, lastLoss=0.184, valLoss=0.189]
Epoch 417/550: 100%|████████████████████

Epoch 477/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.56batch/s, lastLoss=0.184, valLoss=0.181]
Epoch 478/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.86batch/s, lastLoss=0.184, valLoss=0.191]
Epoch 479/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.96batch/s, lastLoss=0.184, valLoss=0.186]
Epoch 480/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.35batch/s, lastLoss=0.183, valLoss=0.196]
Epoch 481/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.19batch/s, lastLoss=0.184, valLoss=0.184]
Epoch 482/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.29batch/s, lastLoss=0.184, valLoss=0.184]
Epoch 483/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.52batch/s, lastLoss=0.184, valLoss=0.184]
Epoch 484/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.17batch/s, lastLoss=0.184, valLoss=0.184]
Epoch 485/550: 100%|████████████████████

Epoch 545/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.32batch/s, lastLoss=0.184, valLoss=0.187]
Epoch 546/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.17batch/s, lastLoss=0.184, valLoss=0.186]
Epoch 547/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 77.11batch/s, lastLoss=0.184, valLoss=0.185]
Epoch 548/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.35batch/s, lastLoss=0.184, valLoss=0.181]
Epoch 549/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 75.70batch/s, lastLoss=0.184, valLoss=0.187]
Epoch 550/550: 100%|█████████████████████████████████| 56/56 [00:00<00:00, 76.55batch/s, lastLoss=0.184, valLoss=0.187]


Average train loss: 0.18457281245523458
Average validation loss: 0.18523676806462536


In [None]:
df = getParameterLoss([out, hidden, orig, fh, both],["Output[-1]", "Hidden[-1]", "Original (output)", "Hidden", "Hidden[-1] + Output[-1]"])

In [36]:
models = [out, hidden, orig, fh, both]
train = [out_train_loss, hidden_train_loss, orig_train_loss, fh_train_loss, both_train_loss]
test = [out_val_loss, hidden_val_loss, orig_val_loss, fh_val_loss, both_val_loss]

data = [models, train, test]

In [39]:
import pickle

with open("../Results/RNNOut.pkl", "wb") as f:
    pickle.dump(data, f)

In [None]:
plotParameterBars(df)

In [None]:
sns.set_theme()

"""
    Plots train and test curves of given models
"""
def printCurves(names, train_loss, val_loss, epochs=400, name="train_test2"):
    if len(names)%2 == 0:
        rs = len(names)//2
        cs = len(names)//2
    else:
        rs, cs = len(names), 1
        
    fig, ax = plt.subplots(
        rs, cs, figsize=(35 if len(names) != 1 else 10, len(names)*6), sharex=True, sharey=True, constrained_layout=True)
    x = range(0, epochs)
    fig.supxlabel("Epoch", fontsize=45)
    fig.supylabel("MAE", fontsize=45)

    #ax[0][1].set_xlabel("Training Loss")
    ##ax[1][1].set_xlabel("Validation Loss")
    
    tv_loss = list(zip(train_loss, val_loss))
    if cs == 1 and rs == 1:
        ax.set_title(names[0], fontsize=40)
        ax.xaxis.set_tick_params(labelsize=35)
        ax.yaxis.set_tick_params(labelsize=35)
        for j in range(2):
            ax.plot(x, tv_loss[0][j], c = "b" if j == 0 else "y", label = "Train loss" if j == 0 else "Validation loss")
    elif cs == 1:
        for i in range(len(names)):
            ax[i].set_title(names[i], fontsize=40)
            ax[i].xaxis.set_tick_params(labelsize=35)
            ax[i].yaxis.set_tick_params(labelsize=35)
            for j in range(2):
                ax[i].plot(x, tv_loss[i][j], c = "b" if j == 0 else "y", label = "Train loss" if j == 0 else "Validation loss")
    else:
        pos = 0
        for i in range(rs):
            for j in range(cs):
                ax[i][j].set_title(names[pos], fontsize=40)
                ax[i][j].xaxis.set_tick_params(labelsize=35)
                ax[i][j].yaxis.set_tick_params(labelsize=35)
                for k in range(2):
                    ax[i][j].plot(x, tv_loss[pos][k], c = "b" if k == 0 else "y", label = "Train loss" if k == 0 else "Validation loss")
    
                pos += 1
        
    #plt.xticks(fontsize=14, rotation=90)
    plt.legend(loc='best', prop={'size':30})
    plt.show()
    fig.savefig(f"../Results/{name}.svg", bbox_inches='tight')