In [209]:
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset
from torch import nn, optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchsummary import summary
import seaborn as sns
from copy import deepcopy
from tqdm import tqdm

# Req for package
import sys
sys.path.append("../")
from SkinLearning.NN.Helpers import train, test as t, DEVICE, getParameterLoss, setSeed
from SkinLearning.NN.Models import MultiTemporal
from SkinLearning.Utils.Dataset import getDataset, getSplit
from SkinLearning.Utils.Plotting import plotParameterBars


torch.backends.cudnn.benchmark = True

In [12]:
setSeed()

In [132]:
class SignalAutoencoder(nn.Module):
    def __init__(self, encoding_dim=8):
        super(SignalAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, encoding_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU()
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x1 = x[:, 0, :].reshape(batch_size, 1, -1)
        x2 = x[:, 1, :].reshape(batch_size, 1, -1)
        
        encoded1 = self.encoder(x1)
        encoded2 = self.encoder(x2)
        
        decoded1 = self.decoder(encoded1)
        decoded2 = self.decoder(encoded2)

        return torch.concat([encoded1, encoded2], dim=1), torch.concat([decoded1, decoded2], dim=1)

In [160]:
# In theory wont work
class SiameseRNN(nn.Module):
    def __init__(self, input_size=16, hidden_size=512):
        super(SiameseRNN, self).__init__()
        self.hidden_size = hidden_size
        # num_layers > 5 reduces performance
        self.ae1 = SignalAutoencoder()
        self.ae2 = SignalAutoencoder()
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True)
                
        self.fc = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 6)
        
        )
    
    def forward(self, x):
        batch_size = x.shape[0]
        x = encoder(x)[0]
        
        x1 = x[:, 0, :].reshape(batch_size, 1, -1)
        x2 = x[:, 1, :].reshape(batch_size, 1, -1)
        
        o1, h1 = self.rnn(x1)  # Add a batch dimension
        o2, h2 = self.rnn(x2)  # Add a batch dimension
        
        #out = torch.cat([h1[-1], h2[-1]], dim=1)
        out=h1[-1]
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

In [4]:
dataset, scaler = getDataset()

100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:09<00:00, 229.49it/s]


In [5]:
train_loader, test_loader = getSplit(dataset)

In [163]:
net = SiameseRNN()

In [164]:
train(train_loader, net, val_loader=test_loader, LR=0.001, epochs=1500, early_stopping=True)

Using: cuda


100%|██████████████████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 123.45batch/s]
100%|███████████████████████████| 56/56 [00:00<00:00, 135.73batch/s, counter=0, epoch=1, lastLoss=0.247, valLoss=0.183]
100%|███████████████████████████| 56/56 [00:00<00:00, 131.67batch/s, counter=0, epoch=2, lastLoss=0.187, valLoss=0.181]
100%|███████████████████████████| 56/56 [00:00<00:00, 135.09batch/s, counter=1, epoch=3, lastLoss=0.188, valLoss=0.182]
100%|███████████████████████████| 56/56 [00:00<00:00, 134.13batch/s, counter=2, epoch=4, lastLoss=0.188, valLoss=0.183]
100%|███████████████████████████| 56/56 [00:00<00:00, 133.36batch/s, counter=3, epoch=5, lastLoss=0.188, valLoss=0.185]
100%|███████████████████████████| 56/56 [00:00<00:00, 136.60batch/s, counter=0, epoch=6, lastLoss=0.189, valLoss=0.179]
100%|███████████████████████████| 56/56 [00:00<00:00, 133.01batch/s, counter=1, epoch=7, lastLoss=0.187, valLoss=0.183]
100%|███████████████████████████| 56/56 

KeyboardInterrupt: 

In [184]:
class CAE(nn.Module):
    def __init__(self, hidden_size=256, single_fc=True, out="f_hidden"):
        super(CAE, self).__init__()
        
        self.enc = nn.Sequential(
            nn.Conv1d(2, 128, kernel_size=5, padding=1, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=5, stride=2),

            nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),

            nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
        )
        
        self.dec = nn.Sequential(
            
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(),
            nn.ConvTranspose1d(512, 256, kernel_size=5, padding=1, bias=False),
            
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, kernel_size=3, padding=1, bias=False),
            
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 2, kernel_size=3, padding=1, bias=False),
            
        )
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        enc1 = self.enc(x)
        dec1 = self.dec(enc1)
            
        return enc1, dec1

In [185]:
encoder = CAE()

In [186]:
train(train_loader, encoder, val_loader=test_loader, LR=0.0001, epochs=1500, early_stopping=True)

Using: cuda


100%|███████████████████████████████████████████████████████████████████████████████| 56/56 [00:01<00:00, 30.04batch/s]
100%|█████████████████████████| 56/56 [00:00<00:00, 166.20batch/s, counter=0, epoch=1, lastLoss=0.0844, valLoss=0.0114]
100%|█████████████████████████| 56/56 [00:00<00:00, 166.67batch/s, counter=0, epoch=2, lastLoss=0.005, valLoss=0.00348]
100%|███████████████████████| 56/56 [00:00<00:00, 166.92batch/s, counter=0, epoch=3, lastLoss=0.00315, valLoss=0.00216]
100%|███████████████████████| 56/56 [00:00<00:00, 166.03batch/s, counter=0, epoch=4, lastLoss=0.00238, valLoss=0.00146]
100%|███████████████████████| 56/56 [00:00<00:00, 160.58batch/s, counter=0, epoch=5, lastLoss=0.00186, valLoss=0.00136]
100%|████████████████████████| 56/56 [00:00<00:00, 163.99batch/s, counter=0, epoch=6, lastLoss=0.00154, valLoss=0.0013]
100%|███████████████████████| 56/56 [00:00<00:00, 159.85batch/s, counter=1, epoch=7, lastLoss=0.00131, valLoss=0.00138]
100%|███████████████████████| 56/56 [00:

100%|█████████████████████| 56/56 [00:00<00:00, 168.17batch/s, counter=1, epoch=68, lastLoss=0.000102, valLoss=7.97e-5]
100%|██████████████████████| 56/56 [00:00<00:00, 167.92batch/s, counter=2, epoch=69, lastLoss=0.0001, valLoss=0.000128]
100%|██████████████████████| 56/56 [00:00<00:00, 164.21batch/s, counter=3, epoch=70, lastLoss=9.52e-5, valLoss=7.25e-5]
100%|████████████████████| 56/56 [00:00<00:00, 170.47batch/s, counter=4, epoch=71, lastLoss=0.000114, valLoss=0.000192]
100%|██████████████████████| 56/56 [00:00<00:00, 170.59batch/s, counter=5, epoch=72, lastLoss=0.00011, valLoss=9.07e-5]
100%|█████████████████████| 56/56 [00:00<00:00, 166.43batch/s, counter=0, epoch=73, lastLoss=0.000167, valLoss=6.78e-5]
100%|██████████████████████| 56/56 [00:00<00:00, 169.03batch/s, counter=1, epoch=74, lastLoss=0.000114, valLoss=9.6e-5]
100%|█████████████████████| 56/56 [00:00<00:00, 151.73batch/s, counter=2, epoch=75, lastLoss=0.000111, valLoss=7.68e-5]
100%|██████████████████████| 56/56 [00:0

100%|██████████████████████| 56/56 [00:00<00:00, 171.78batch/s, counter=4, epoch=136, lastLoss=2.1e-5, valLoss=2.53e-5]
100%|█████████████████████| 56/56 [00:00<00:00, 169.28batch/s, counter=5, epoch=137, lastLoss=2.03e-5, valLoss=2.88e-5]
100%|███████████████████████| 56/56 [00:00<00:00, 173.12batch/s, counter=6, epoch=138, lastLoss=2.3e-5, valLoss=2.7e-5]
100%|██████████████████████| 56/56 [00:00<00:00, 168.49batch/s, counter=7, epoch=139, lastLoss=2.7e-5, valLoss=1.86e-5]
100%|█████████████████████| 56/56 [00:00<00:00, 171.44batch/s, counter=8, epoch=140, lastLoss=2.25e-5, valLoss=1.82e-5]
100%|█████████████████████| 56/56 [00:00<00:00, 172.94batch/s, counter=0, epoch=141, lastLoss=2.02e-5, valLoss=1.71e-5]
100%|█████████████████████| 56/56 [00:00<00:00, 168.55batch/s, counter=0, epoch=142, lastLoss=1.91e-5, valLoss=1.54e-5]
100%|██████████████████████| 56/56 [00:00<00:00, 171.78batch/s, counter=1, epoch=143, lastLoss=2.12e-5, valLoss=2.3e-5]
100%|█████████████████████| 56/56 [00:00

100%|█████████████████████| 56/56 [00:00<00:00, 166.18batch/s, counter=6, epoch=204, lastLoss=3.64e-6, valLoss=5.03e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 165.44batch/s, counter=7, epoch=205, lastLoss=4.45e-6, valLoss=4.27e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 166.11batch/s, counter=0, epoch=206, lastLoss=2.91e-6, valLoss=1.89e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 162.78batch/s, counter=1, epoch=207, lastLoss=2.62e-6, valLoss=2.86e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 164.76batch/s, counter=2, epoch=208, lastLoss=3.82e-6, valLoss=3.99e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 171.97batch/s, counter=3, epoch=209, lastLoss=3.33e-6, valLoss=3.85e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 167.17batch/s, counter=4, epoch=210, lastLoss=4.56e-6, valLoss=4.13e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 164.01batch/s, counter=0, epoch=211, lastLoss=2.97e-6, valLoss=1.67e-6]
100%|█████████████████████| 56/56 [00:00

100%|████████████████████| 56/56 [00:00<00:00, 161.99batch/s, counter=28, epoch=272, lastLoss=2.99e-6, valLoss=3.27e-6]
100%|█████████████████████| 56/56 [00:00<00:00, 168.93batch/s, counter=29, epoch=273, lastLoss=2.68e-6, valLoss=2.2e-6]
100%|████████████████████| 56/56 [00:00<00:00, 169.66batch/s, counter=30, epoch=274, lastLoss=2.53e-6, valLoss=2.54e-6]
100%|████████████████████| 56/56 [00:00<00:00, 170.01batch/s, counter=31, epoch=275, lastLoss=2.27e-6, valLoss=1.76e-6]
100%|████████████████████| 56/56 [00:00<00:00, 173.45batch/s, counter=32, epoch=276, lastLoss=2.24e-6, valLoss=2.23e-6]
100%|████████████████████| 56/56 [00:00<00:00, 167.29batch/s, counter=33, epoch=277, lastLoss=2.03e-6, valLoss=1.49e-6]
100%|████████████████████| 56/56 [00:00<00:00, 172.03batch/s, counter=34, epoch=278, lastLoss=2.29e-6, valLoss=2.84e-6]
100%|████████████████████| 56/56 [00:00<00:00, 169.76batch/s, counter=35, epoch=279, lastLoss=1.81e-6, valLoss=1.22e-6]
100%|████████████████████| 56/56 [00:00<

100%|██████████████████████| 56/56 [00:00<00:00, 156.44batch/s, counter=4, epoch=340, lastLoss=8.8e-7, valLoss=1.06e-6]
100%|██████████████████████| 56/56 [00:00<00:00, 161.45batch/s, counter=5, epoch=341, lastLoss=9.2e-7, valLoss=6.16e-7]
100%|██████████████████████| 56/56 [00:00<00:00, 157.81batch/s, counter=6, epoch=342, lastLoss=8.66e-7, valLoss=6.3e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 158.84batch/s, counter=0, epoch=343, lastLoss=7.88e-7, valLoss=2.79e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 161.94batch/s, counter=1, epoch=344, lastLoss=6.34e-7, valLoss=7.93e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 156.73batch/s, counter=2, epoch=345, lastLoss=9.71e-7, valLoss=4.45e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 153.23batch/s, counter=3, epoch=346, lastLoss=7.02e-7, valLoss=5.84e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 168.67batch/s, counter=4, epoch=347, lastLoss=5.41e-7, valLoss=3.59e-7]
100%|█████████████████████| 56/56 [00:00

100%|████████████████████| 56/56 [00:00<00:00, 169.36batch/s, counter=10, epoch=408, lastLoss=1.29e-6, valLoss=1.28e-6]
100%|████████████████████| 56/56 [00:00<00:00, 165.44batch/s, counter=11, epoch=409, lastLoss=1.02e-6, valLoss=1.05e-6]
100%|████████████████████| 56/56 [00:00<00:00, 172.84batch/s, counter=12, epoch=410, lastLoss=1.06e-6, valLoss=4.71e-7]
100%|████████████████████| 56/56 [00:00<00:00, 169.70batch/s, counter=13, epoch=411, lastLoss=1.06e-6, valLoss=5.85e-7]
100%|████████████████████| 56/56 [00:00<00:00, 172.15batch/s, counter=14, epoch=412, lastLoss=6.87e-7, valLoss=3.53e-7]
100%|████████████████████| 56/56 [00:00<00:00, 172.05batch/s, counter=15, epoch=413, lastLoss=6.04e-7, valLoss=7.39e-7]
100%|████████████████████| 56/56 [00:00<00:00, 171.34batch/s, counter=16, epoch=414, lastLoss=6.63e-7, valLoss=2.29e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 173.11batch/s, counter=17, epoch=415, lastLoss=6.59e-7, valLoss=3.7e-7]
100%|████████████████████| 56/56 [00:00<

100%|██████████████████████| 56/56 [00:00<00:00, 171.43batch/s, counter=8, epoch=476, lastLoss=2.33e-7, valLoss=2.8e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 171.14batch/s, counter=9, epoch=477, lastLoss=2.69e-7, valLoss=1.29e-7]
100%|██████████████████████| 56/56 [00:00<00:00, 169.70batch/s, counter=0, epoch=478, lastLoss=1.39e-7, valLoss=7.7e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 170.73batch/s, counter=1, epoch=479, lastLoss=1.69e-7, valLoss=8.99e-8]
100%|███████████████████████| 56/56 [00:00<00:00, 169.02batch/s, counter=2, epoch=480, lastLoss=1.8e-7, valLoss=1.3e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 171.52batch/s, counter=3, epoch=481, lastLoss=2.56e-7, valLoss=3.09e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 169.44batch/s, counter=4, epoch=482, lastLoss=2.01e-7, valLoss=1.18e-7]
100%|█████████████████████| 56/56 [00:00<00:00, 167.74batch/s, counter=5, epoch=483, lastLoss=1.55e-7, valLoss=7.78e-8]
100%|█████████████████████| 56/56 [00:00

100%|█████████████████████| 56/56 [00:00<00:00, 172.77batch/s, counter=3, epoch=544, lastLoss=8.99e-8, valLoss=5.64e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 171.26batch/s, counter=4, epoch=545, lastLoss=8.74e-8, valLoss=4.89e-8]
100%|██████████████████████| 56/56 [00:00<00:00, 170.47batch/s, counter=5, epoch=546, lastLoss=1.08e-7, valLoss=7.2e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 163.03batch/s, counter=6, epoch=547, lastLoss=1.25e-7, valLoss=7.68e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 171.52batch/s, counter=7, epoch=548, lastLoss=1.27e-7, valLoss=6.01e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 169.06batch/s, counter=0, epoch=549, lastLoss=7.12e-8, valLoss=4.04e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 172.07batch/s, counter=1, epoch=550, lastLoss=7.42e-8, valLoss=6.83e-8]
100%|█████████████████████| 56/56 [00:00<00:00, 171.94batch/s, counter=2, epoch=551, lastLoss=9.08e-8, valLoss=6.52e-8]
100%|█████████████████████| 56/56 [00:00

100%|████████████████████| 56/56 [00:00<00:00, 173.37batch/s, counter=46, epoch=612, lastLoss=1.17e-7, valLoss=5.88e-8]
100%|████████████████████| 56/56 [00:00<00:00, 171.57batch/s, counter=47, epoch=613, lastLoss=4.83e-8, valLoss=3.77e-8]
100%|████████████████████| 56/56 [00:00<00:00, 171.42batch/s, counter=48, epoch=614, lastLoss=7.91e-8, valLoss=4.12e-8]
100%|████████████████████| 56/56 [00:00<00:00, 167.20batch/s, counter=49, epoch=615, lastLoss=1.03e-7, valLoss=7.24e-8]

Early stopping after 616 epochs
Average train loss: 8.447184908910218e-05
Average validation loss: 3.1040995137877146e-05





([0.0844115265977702,
  0.00499564955576456,
  0.0031544263724104632,
  0.0023758517303836663,
  0.0018608077703642526,
  0.001541190001132366,
  0.001305987811065279,
  0.0012197478364604259,
  0.0010039584272557736,
  0.0009443823281409485,
  0.000949007245903236,
  0.0008195393280142785,
  0.0007020178943223852,
  0.0007280986527413395,
  0.0005692042785605216,
  0.0005910676747069894,
  0.0005443428022512567,
  0.0004654561110198431,
  0.0004995587847328611,
  0.0005306084739069254,
  0.00049250910160481,
  0.00040520072434446774,
  0.00042119490951465977,
  0.0003352799937731886,
  0.0003669122925202828,
  0.0003319588234132555,
  0.0003146752600774302,
  0.0003187810465793258,
  0.00027911947654502,
  0.00028258593530543815,
  0.00035180378550389184,
  0.0004397859177385856,
  0.00039490563566297557,
  0.00035268515057396144,
  0.0003010509229040638,
  0.0002664322504902624,
  0.0002647368933789299,
  0.00023573117713177844,
  0.00019811816933465058,
  0.00021477501858108944,
  0

In [182]:
def train(
    train_loader,
    net,
    LR=0.1,
    epochs=2000,
    val_loader=None,
    early_stopping=False,
    patience=50,
    optimizer=optim.Adam,
    plot=False,
    cluster=False
):
    net.to(DEVICE)
    optimizer = optimizer(net.parameters(), lr=LR)
    criterion = nn.MSELoss()
    val_losses = []        
    losses = []
    best_val_loss = 1e10
    counter = 0
    
    global loss
    loss = 0
    
    if plot:
        _, ax = plt.subplots(1, 1)

    print(f"Using: {DEVICE}")

    def processBatch(ittr):
        global loss
        loss = 0
        
        for _, data in enumerate(ittr):
            inp, out = data['input'].to(DEVICE), data['output'].to(DEVICE)

            optimizer.zero_grad()
            predicted = torch.Tensor(net(inp)[1])

            cost = criterion(inp, predicted)
            loss += cost.item()
            cost.backward()
            optimizer.step()
    
    for epoch in range(epochs):
        net.train()
        
        if plot or cluster:
            processBatch(train_loader)
        else:
            with tqdm(train_loader, unit="batch") as it:
                if epoch > 0:
                    it.set_postfix(lastLoss=losses[-1], valLoss=0 if len(val_losses) \
                         == 0 else val_losses[-1], counter=counter, epoch=epoch+1/epochs)
                processBatch(it)
        
        loss /= len(train_loader)
        losses.append(loss)

        if val_loader:
            val_loss = 0
            net.eval()
            for idx, data in enumerate(val_loader):
                inp, out = data['input'].to(DEVICE), data['output'].to(DEVICE)

                predicted = torch.Tensor(net(inp)[1])
                cost = criterion(inp, predicted)
                val_loss += cost.item()
            val_loss /= len(val_loader)  
            val_losses.append(val_loss)
            
            if plot:
                update_plot(losses, val_losses, ax, "Training", epoch)
            if cluster:
                print(f"Epoch {epoch+1}/{epochs}:")
                print(f"    Training loss: {losses[-1]}")
                print(f"    Validation loss: {val_losses[-1]}")
                print(f"    Stagnation counter: {counter}\n")

            if early_stopping:
                if val_losses[-1] < best_val_loss:
                    best_val_loss = val_loss
                    counter = 0
                else:
                    counter += 1
                if counter >= patience:
                    print(f"Early stopping after {epoch + 1} epochs")
                    break
         
    print(f"Average train loss: {np.sum(losses)/epochs}")
    print(f"Average validation loss: {np.sum(val_losses)/epochs}")
    
    return losses, val_losses


In [196]:
# Remove one FC LAyer
class MultiLSTM(nn.Module):
    def __init__(self, hidden_size=256, single_fc=True, out="f_hidden"):
        super(MultiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.out = out
        
        self.cnn = encoder
        
        self.lstm = nn.LSTM(15, hidden_size, batch_first=True, num_layers=1, bidirectional=True)
        
        input_tensor = torch.zeros(32, 512, 15)
        output, hidden = self.lstm(input_tensor)
        
        fc_in = hidden_size
        if out == 'output':
            fc_in = output.shape[1] * output.shape[2]
        elif out == 'f-output':
            fc_in = output.shape[2]
        elif fc_in == 'hidden' or out == 'cell':
            out = hidden_size * output.shape[2]
        elif out == 'f-hiden' or out == 'f-cell':
            fc_in = output.shape[2]
        elif out == 'h+o' or out == 'h+c' :
            fc_in = output.shape[1]
        
        if single_fc:
            self.fc = nn.Linear(fc_in*2, 6)
        else:
            self.fc = nn.Sequential(
                nn.Linear(fc_in*2, 128),
                nn.ReLU(),
                nn.Linear(128 , 64),
                nn.ReLU(),
                nn.Linear(64, 6),   
            )

    def forward(self, x):
        batch_size = x.shape[0]
        x = encoder(x)[0]
        
        #h0 = torch.zeros(1, batch_size, 256).to(x.device)
        #c0 = torch.zeros((1, batch_size, self.hidden_size)).to(x.device)
        o, (h, c) = self.lstm(x)
        
        if self.out == "f_hidden":
            x = h[-1].reshape(batch_size, -1)
        elif self.out == "hidden":
            x = h.reshape(batch_size, -1)
        elif self.out == "f_output":
            x = o[:, -1, :].reshape(batch_size, -1)
        elif self.out == "output":
            x = o.reshape(batch_size, -1)
        elif self.out == "f_cell":
            x == c[-1].reshape(batch_size, -1)
        elif self.out == "cell":
            x == c.reshape(batch_size, -1)
        elif self.out == "h+c":
            x = torch.concat([h[-1], c[-1]], dim=1).view(batch_size, -1)
        elif self.out == "h+o":
            x = torch.concat([h[-1], o[:, -1, :]], dim=1).view(o.size(0), -1)
            
        x = self.fc(x)
        return x

In [206]:
test = MultiLSTM(out="f_output", single_fc=False)

In [207]:
traint_loss, valt_loss = train(train_loader, test, val_loader=test_loader, LR=0.0001, epochs=1500, early_stopping=True)

Using: cuda


100%|███████████████████████████████████████████████████████████████████████████████| 56/56 [00:02<00:00, 25.91batch/s]
100%|████████████████████████████| 56/56 [00:02<00:00, 26.22batch/s, counter=0, epoch=1, lastLoss=0.263, valLoss=0.179]
100%|████████████████████████████| 56/56 [00:02<00:00, 26.21batch/s, counter=0, epoch=2, lastLoss=0.163, valLoss=0.145]
100%|████████████████████████████| 56/56 [00:02<00:00, 26.26batch/s, counter=0, epoch=3, lastLoss=0.136, valLoss=0.124]
100%|████████████████████████████| 56/56 [00:02<00:00, 26.41batch/s, counter=0, epoch=4, lastLoss=0.114, valLoss=0.108]
100%|████████████████████████████| 56/56 [00:02<00:00, 26.17batch/s, counter=0, epoch=5, lastLoss=0.107, valLoss=0.102]
100%|████████████████████████████| 56/56 [00:02<00:00, 26.12batch/s, counter=1, epoch=6, lastLoss=0.101, valLoss=0.104]
100%|███████████████████████████| 56/56 [00:02<00:00, 26.18batch/s, counter=0, epoch=7, lastLoss=0.0992, valLoss=0.091]
100%|████████████████████████████| 56/56

100%|█████████████████████████| 56/56 [00:02<00:00, 26.11batch/s, counter=6, epoch=68, lastLoss=0.0523, valLoss=0.0585]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.33batch/s, counter=0, epoch=69, lastLoss=0.0518, valLoss=0.0496]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.45batch/s, counter=1, epoch=70, lastLoss=0.0518, valLoss=0.0595]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.20batch/s, counter=2, epoch=71, lastLoss=0.0527, valLoss=0.0577]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.40batch/s, counter=3, epoch=72, lastLoss=0.0514, valLoss=0.0508]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.37batch/s, counter=4, epoch=73, lastLoss=0.0525, valLoss=0.0554]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.19batch/s, counter=5, epoch=74, lastLoss=0.0541, valLoss=0.0561]
100%|█████████████████████████| 56/56 [00:02<00:00, 26.32batch/s, counter=6, epoch=75, lastLoss=0.0517, valLoss=0.0576]
100%|██████████████████████████| 56/56 [

100%|███████████████████████| 56/56 [00:02<00:00, 25.93batch/s, counter=15, epoch=136, lastLoss=0.0459, valLoss=0.0536]
100%|███████████████████████| 56/56 [00:02<00:00, 26.07batch/s, counter=16, epoch=137, lastLoss=0.0453, valLoss=0.0468]
100%|███████████████████████| 56/56 [00:02<00:00, 26.08batch/s, counter=17, epoch=138, lastLoss=0.0469, valLoss=0.0554]
100%|███████████████████████| 56/56 [00:02<00:00, 26.04batch/s, counter=18, epoch=139, lastLoss=0.0467, valLoss=0.0488]
100%|████████████████████████| 56/56 [00:02<00:00, 25.42batch/s, counter=19, epoch=140, lastLoss=0.045, valLoss=0.0468]
100%|████████████████████████| 56/56 [00:02<00:00, 25.89batch/s, counter=20, epoch=141, lastLoss=0.0455, valLoss=0.049]
100%|███████████████████████| 56/56 [00:02<00:00, 25.86batch/s, counter=21, epoch=142, lastLoss=0.0444, valLoss=0.0499]
100%|███████████████████████| 56/56 [00:02<00:00, 25.85batch/s, counter=22, epoch=143, lastLoss=0.0449, valLoss=0.0515]
100%|████████████████████████| 56/56 [00

Early stopping after 171 epochs
Average train loss: 0.00643280007087049
Average validation loss: 0.006802819130900834


In [210]:
t(test_loader, test, scaler)

(93.37990760803223,
 array([94.31275 , 88.29201 , 99.73189 , 99.81301 , 85.045235, 93.08455 ],
       dtype=float32),
 0.0470153254767259)