In [4]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
from torch import tensor
import numpy as np
from torch.utils.data import Dataset
from sklearn.metrics import mean_squared_error
import random
import os
import matplotlib.pyplot as plt
#from numba import jit
import pickle
from scipy.interpolate import interp1d
from torch.utils.data import DataLoader, random_split
import torch
from torchsummary import summary
import seaborn as sns
import sys
import torch.nn.functional as F
import pywt
from sklearn.preprocessing import MinMaxScaler
from torch import FloatTensor

# Req for package
sys.path.append("../")
from SkinLearning.Utils.NN import train, test, DEVICE


torch.backends.cudnn.benchmark = True

In [3]:
DEVICE = 'cpu'

In [5]:
# Folder name will correspond to index of sample
class SkinDataset(Dataset):
    def __init__(self, scaler, method="coefficients", combined=False, signalFolder="D:/SamplingResults2", sampleFile="../Data/newSamples.pkl", runs=range(65535), steps=128):
        # Load both disp1 and disp2 from each folder
        # Folders ordered according to index of sample
        # Use the corresponding sample as y -> append probe?
        self.input = []
        self.output = []
        
        with open(f"{sampleFile}", "rb") as f:
             samples = pickle.load(f)
        
        self.min = np.min(samples[runs])
        self.max = np.max(samples[runs])
        
        
        for run in tqdm(runs):
            inp = []
            fail = False
            
            files = os.listdir(f"{signalFolder}/{run}/")
            
            if files != ['Disp1.csv', 'Disp2.csv']:
                continue
            
            for file in files:
                a = pd.read_csv(f"{signalFolder}/{run}/{file}")
                a.rename(columns = {'0':'x', '0.1': 'y'}, inplace = True)
                
                # Skip if unconverged
                if a['x'].max() != 7.0:
                    fail = True
                    break

                # Interpolate curve for consistent x values
                xNew = np.linspace(0, 7, num=steps, endpoint=False)
                interped = interp1d(a['x'], a['y'], kind='cubic', fill_value="extrapolate")(xNew)
                    
                
                inp.append(interped.astype("float32"))
            
            if not fail:
                if len(inp) != 2:
                    raise Exception("sdf")

                self.input.append(inp)
                self.output.append(samples[int(run)])
        
        scaler.fit(self.output)
        self.output = scaler.fit_transform(self.output)
        self.output = tensor(self.output).type(FloatTensor)
        
        self.input = [waveletExtraction(sample, method, combined=combined) for sample in self.input]
        self.input = tensor(np.array(self.input)).type(FloatTensor)
        
        
    def __len__(self):
        return len(self.output)
    
    def __getitem__(self, idx):
        sample = {"input": self.input[idx], "output": self.output[idx]}
        return sample
    
    

In [6]:
"""
    Creates the data set from filtered samples
    Returns the dataset and the scaler
"""
def getDataset(**kwargs):
    # Get filtered data
    if not 'runs' in kwargs.keys():
        with open("../Data/filtered.pkl", "rb") as f:
            runs = pickle.load(f)

        kwargs['runs'] = runs

    scaler = MinMaxScaler()
    dataset = SkinDataset(scaler=scaler, **kwargs)

    return dataset, scaler

In [7]:
"""
    Creates a train/test split from the given data
    Returns train and test data loaders
"""
def getSplit(dataset, p1=0.8):
    train_n = int(p1 * len(dataset))
    test_n = len(dataset) - train_n
    train_set, test_set = random_split(dataset, [train_n, test_n])

    return DataLoader(train_set, batch_size=32, shuffle=True), \
        DataLoader(test_set, batch_size=32, shuffle=True)

In [8]:
datasets['non-combined']['raw'][0][0]['input']

NameError: name 'datasets' is not defined

In [72]:
from scipy.stats import skew, kurtosis, entropy
from math import gcd
from functools import reduce

def waveletExtraction(x, method, combined=False, wavelet='db4', level=6, combine_method="concatenate"):
    #s = MinMaxScaler()
    #x = s.fit_transform(x)
    if combined:
        coeffs = pywt.WaveletPacket(x, wavelet=wavelet, maxlevel=level, mode='symmetric')
        coeffs = extractCoeff(coeffs, level)
        
        
        features = []
        if method == "raw":
                features.append(coeffs)
        else:
            # Get sub-freqency band details
            for i in range(len(coeffs)):  
                    if method == "stats":
                        features.append(stats(coeffs[i]))
                    elif method == "entropy":
                        features.append(entropy(coeffs[i]))
                    elif method == "min-max":
                        features.append(minMax(coeffs[i]))
                    elif method=="energy":
                        features.append(getEnergy(coeffs[i]))
                    
        return features
    
        return combined_coefficients 
    else:
        wp1 = pywt.WaveletPacket(x[0], wavelet=wavelet, maxlevel=level)
        wp2 = pywt.WaveletPacket(x[1], wavelet=wavelet, maxlevel=level)
        
        coeffs1 = extractCoeff(wp1, 0)
        coeffs2 = extractCoeff(wp2, 0)

        features1, features2 = [], []
        
        if method == "raw":
                features1, features2 = coeffs1, coeffs2
        else:
            # Get sub-freqency band details
            for i in range(len(coeffs1)):
                if method == "stats":
                    features1.append(stats(coeffs1[i]))
                    features2.append(stats(coeffs2[i]))
                elif method == "entropy":
                    features1.append(entropy(coeffs1[i]))
                    features2.append(entropy(coeffs2[i]))
                elif method == "min-max":
                    features1.append(minMax(coeffs1[i]))
                    features2.append(minMax(coeffs2[i]))
                elif method=="energy":
                    features1.append(getEnergy(coeffs1[i]))
                    features2.append(getEnergy(coeffs2[i]))
        
        if combine_method == "concatenate":
            return [np.array(features1).flatten(), np.array(features2).flatten()]
        else:
            features1 = np.array(features1)
            features2 = np.array(features2)
            
            signal1_coeffs_transposed = features1.T
            signal2_coeffs_transposed = features2.T

            # Flatten the transposed arrays into 1D arrays
            signal1_coeffs_flattened = signal1_coeffs_transposed.flatten()
            signal2_coeffs_flattened = signal2_coeffs_transposed.flatten()

            # Interleave the flattened arrays element-wise
            combined_coeffs = np.empty(signal1_coeffs_flattened.shape[0] + signal2_coeffs_flattened.shape[0], dtype=signal1_coeffs_flattened.dtype)
            combined_coeffs[0::2] = signal1_coeffs_flattened
            combined_coeffs[1::2] = signal2_coeffs_flattened


            return combined_coeffs #np.concatenate(flattened) -> works better?

def minMax(coeffs):
    return [
        min(coeffs),
        max(coeffs)
    ]
    
def stats(c):
    return [
                np.mean(c),
                np.std(c),
                skew(c),
                kurtosis(c)
            ]

# Find the common denominator of the number of intervals for each level
def reorganise(features):
    # Calculate the number of time intervals for each level
    num_intervals = [len(feature) for feature in features]

    common_denominator = reduce(lambda x, y: x * y // gcd(x, y), num_intervals)

    # Initialize an empty list to store the reorganized features
    reorganized_features = []

    # Iterate through the time intervals and add the corresponding features to the reorganized list
    for t in range(common_denominator):
        interval_features = []

        for level_features, num_intervals_level in zip(features, num_intervals):
            
            # Find the corresponding feature index for the current time interval and level
            index = (t * num_intervals_level) // common_denominator
            interval_features.append(level_features[int(index)])

        reorganized_features.append(interval_features)
    
    return np.array(reorganized_features)

def extractCoeff(wp, level=0):
    coeffs = []
    coeffs =  wp.get_level(6, 'freq')
    features = np.array([c.data for c in coeffs])
    return features

def getEnergy(coefficients):
    energy = np.sum(np.square(coefficients))
    return energy

In [73]:
datasets = {}

for comb in ['non-combined']:
    datasets[comb] = {}
    
    for method in ['min-max', 'stats', 'raw', 'entropy', 'energy']:# , 'stats', 'raw', 'entropy']:
        datasets[comb][method] = getDataset(method=method)#, combined = True if comb == "combined" else False, runs=[1,2,3])

100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:09<00:00, 241.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:09<00:00, 246.63it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:08<00:00, 260.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:09<00:00, 226.88it/s]


In [75]:
datasets['non-combined']['energy'] = getDataset(method='energy')

100%|█████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:10<00:00, 212.14it/s]


In [76]:
datasets['non-combined']['energy'][0][0]

{'input': tensor([[2.0634e-03, 1.9313e-04, 6.1960e-06, 1.3240e-05, 1.9456e-05, 6.7089e-06,
          1.2562e-05, 4.5536e-06, 6.8081e-06, 1.8880e-06, 2.1032e-06, 2.6671e-06,
          2.3956e-06, 1.5530e-06, 7.3691e-06, 1.3685e-05, 7.8473e-06, 2.9935e-06,
          3.7773e-07, 7.1285e-07, 8.2999e-07, 3.8769e-07, 1.7844e-07, 2.2033e-06,
          3.6354e-08, 2.0814e-07, 1.1370e-07, 1.0103e-06, 2.7930e-06, 1.3206e-06,
          5.9242e-07, 1.4897e-05, 5.1210e-09, 3.6606e-09, 5.7816e-10, 7.7265e-10,
          7.4486e-09, 3.2818e-09, 1.2351e-09, 4.3832e-09, 2.5644e-09, 9.9259e-10,
          1.8569e-09, 1.3191e-09, 2.6533e-09, 2.0654e-09, 1.2057e-09, 1.6858e-09,
          1.3479e-08, 4.2453e-09, 1.4717e-09, 9.9297e-09, 8.1621e-10, 3.5505e-10,
          1.3394e-09, 5.1120e-09, 4.2717e-08, 5.0222e-09, 8.4436e-09, 1.3365e-08,
          7.2233e-09, 8.0118e-10, 2.6555e-09, 2.6087e-07],
         [1.5781e-02, 2.4109e-03, 3.9224e-04, 3.6970e-04, 6.8843e-05, 3.5664e-06,
          2.4005e-05, 1.7359e-

In [77]:
loaders = {}
lengths = {}

for comb in datasets.keys():
    loaders[comb] = {}
    lengths[comb] = {}
    
    for method in datasets[comb].keys():
        loaders[comb][method] = {}
        loaders[comb][method]['train'], loaders[comb][method]['test'] = getSplit(datasets[comb][method][0])
        lengths[comb][method] = len(datasets[comb][method][0][0]['input'][0])

In [78]:
loaders, lengths

({'non-combined': {'min-max': {'train': <torch.utils.data.dataloader.DataLoader at 0x16d4b876860>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x16e078f6020>},
   'stats': {'train': <torch.utils.data.dataloader.DataLoader at 0x16e078fd600>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x16e078fdd80>},
   'raw': {'train': <torch.utils.data.dataloader.DataLoader at 0x16e078fca60>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x16e078fc2b0>},
   'entropy': {'train': <torch.utils.data.dataloader.DataLoader at 0x16e078fda20>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x16e078fcd60>},
   'energy': {'train': <torch.utils.data.dataloader.DataLoader at 0x16e078fe1d0>,
    'test': <torch.utils.data.dataloader.DataLoader at 0x16e078ff430>}}},
 {'non-combined': {'min-max': 128,
   'stats': 256,
   'raw': 512,
   'entropy': 64,
   'energy': 64}})

In [79]:
lengths2 = {}

for comb in datasets.keys():
    lengths2[comb] = {}
    
    for method in datasets[comb].keys():
        lengths2[comb][method] = len(datasets[comb][method][0][0]['input'][0])

In [80]:
lengths2

{'non-combined': {'min-max': 128,
  'stats': 256,
  'raw': 512,
  'entropy': 64,
  'energy': 64}}

In [105]:
class RNNMulti(nn.Module):
    def __init__(self, input_size=256, hidden_size=256, num_layers=1, method="concatenate"):
        super(RNNMulti, self).__init__()
        self.method = method
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        if method == 'concatenate':
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        elif method == 'multi_channel':
            self.rnn = nn.RNN(2, hidden_size, num_layers, batch_first=True)
        elif method == 'independent':
            self.rnn = nn.RNN(input_size//2, hidden_size, num_layers, batch_first=True)
        else:
            raise ValueError("Invalid method. Choose from 'concatenate', 'multi_channel', or 'independent'.")

        if method == 'independent':
            self.fc = nn.Linear(hidden_size*2, 6)
        else:
            self.fc = nn.Linear(hidden_size, 6)

    def forward(self, x):
        batch_size = x.size(0)
        signal_size = self.input_size//2
        signal1 = x[..., :signal_size]
        signal2 = x[..., signal_size:]
        
        x = x.reshape(batch_size, -1, self.input_size)
        signal1 = signal1.reshape(batch_size, -1, signal_size)
        signal2 = signal2.reshape(batch_size, -1, signal_size)
        

        if self.method == 'concatenate':
            _, hidden = self.rnn(x)
            hidden = hidden[-1]
        elif self.method == 'multi_channel':
            output, hidden = self.rnn(x.view(batch_size, -1, 2))
            hidden = hidden[-1]
        elif self.method == 'independent':
            output1, hidden1 = self.rnn(signal1)
            output2, hidden2 = self.rnn(signal2)
            hidden = torch.concat((hidden1[-1], hidden2[-1]), axis=0)

        output = self.fc(hidden)

        return output

In [100]:
class LSTMMulti(nn.Module):
    def __init__(self, input_size=256, hidden_size=256, num_layers=1, method="concatenate"):
        super(LSTMMulti, self).__init__()
        self.method = method
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        if method == 'concatenate':
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        elif method == 'multi_channel':
            self.lstm = nn.LSTM(2, hidden_size, num_layers, batch_first=True)
        elif method == 'independent':
            self.lstm = nn.LSTM(input_size//2, hidden_size, num_layers, batch_first=True)
        else:
            raise ValueError("Invalid method. Choose from 'concatenate', 'multi_channel', or 'independent'.")
        
        if method == 'independent':
            self.fc = nn.Linear(hidden_size*2, 6)
        else:
            self.fc = nn.Linear(hidden_size, 6)

    def forward(self, x):
        batch_size = x.size(0)
        signal_size = self.input_size//2
        signal1 = x[..., :signal_size]
        signal2 = x[..., signal_size:]
        
        x = x.reshape(batch_size, -1, self.input_size)
        signal1 = signal1.reshape(batch_size, -1, signal_size)
        signal2 = signal2.reshape(batch_size, -1, signal_size)

        h0 = torch.zeros((self.num_layers, batch_size, self.hidden_size)).to(x.device)  # Initialize hidden state to zeros
        c0 = torch.zeros((self.num_layers, batch_size, self.hidden_size)).to(x.device)  # Initialize cell state to zeros

        if self.method == 'concatenate':
            _, (hidden, cell) = self.lstm(x)
            hidden = hidden[-1]
        elif self.method == 'multi_channel':
            output, (hidden, cell) = self.lstm(x.view(batch_size, -1, 2))
            hidden = hidden[-1]
        elif self.method == 'independent':
            output1, (hidden1, cell1) = self.lstm(signal1)
            output2, (hidden2, cell2) = self.lstm(signal2)
            
            hidden = torch.concat((hidden1[-1], hidden2[-1]), axis=1)
            
        output = self.fc(hidden)

        return output

In [101]:
# Remove one FC LAyer
class RNNCurrent(nn.Module):
    def __init__(self):
        super(RNNCurrent, self).__init__()
        self.conv1 = nn.Conv1d(1, 128, kernel_size=5, padding=1, bias=False)
        self.pool1 = nn.MaxPool1d(kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm1d(128)
        
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1, bias=False)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1, bias=False)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.bn3 = nn.BatchNorm1d(512)
        
        self.rnn = nn.RNN(1, 256, batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128 , 64),
            nn.ReLU(),
            nn.Linear(64, 6),   
        )

    def forward(self, x):
        a = x.shape
        batch_size = x.shape[0]
        
        x = x.reshape(batch_size, 1, -1)
        x = self.pool1(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool2(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool3(torch.relu(self.bn3(self.conv3(x))))
        
        
        out, h = self.rnn(x)
  
        x = out[:, -1, :]
        x = x.reshape(batch_size, -1)

        
        x = self.fc(x)
        
        x = x.view(batch_size, 6)
        return x

In [122]:
models = {}

# Get an LSTM and RNN of each type, paired with each type of dataset
for model_type in ['LSTM', 'RNN']:
    models[model_type] = {}
    
    for comb in datasets.keys():
        models[model_type][comb] = {}
        
        for ext_method in datasets[comb].keys():
            models[model_type][comb][ext_method] = {}
            
            for model_method in ['concatenate', 'multi_channel', 'independent']:
                models[model_type][comb][ext_method][model_method] = \
                LSTMMulti(method=model_method, input_size=lengths[comb][ext_method], hidden_size=2*lengths[comb][ext_method])

In [119]:
lengths

{'non-combined': {'min-max': 128,
  'stats': 256,
  'raw': 512,
  'entropy': 64,
  'energy': 64}}

In [130]:
losses = {}

for model_type in models.keys():
    losses[model_type] = {}
    if model_type == 'LSTM':
        continue
    for comb in models[model_type].keys():
        losses[model_type][comb] = {}
        
        for ext_method in models[model_type][comb].keys():
            losses[model_type][comb][ext_method] = {}
            
            for model_method in models[model_type][comb][ext_method].keys():
                print(f"Running {model_type} based on {model_method}, with dataset using {comb} {ext_method}")
                train_loss, val_loss = train(
                    loaders[comb][ext_method]['train'],
                    models[model_type][comb][ext_method][model_method],
                    val_loader=loaders[comb][ext_method]['test'],
                    LR=0.001,
                    epochs=1500, early_stopping=True)
                     

Running RNN based on concatenate, with dataset using non-combined min-max
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 232.36batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 227.71batch/s, counter=0, lastLoss=0.186, valLoss=0.189]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 225.62batch/s, counter=0, lastLoss=0.183, valLoss=0.181]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 221.61batch/s, counter=0, lastLoss=0.179, valLoss=0.167]
Epoch 5/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 238.17batch/s, counter=1, lastLoss=0.171, valLoss=0.172]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 223.76batch/s, counter=2, lastLoss=0.167, valLoss=0.169]
Epoch 7/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 227.18batch/s, counter=3, lastLoss=0.165, valLoss=0.17]
Epoch 8/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 231.88batch/s, counter=0, lastLoss=0.165, valLoss=0.16]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 224.77batch/s, counter=3, lastLoss=0.115, valLoss=0.113]
Epoch 70/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 233.81batch/s, counter=4, lastLoss=0.117, valLoss=0.114]
Epoch 71/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 240.09batch/s, counter=5, lastLoss=0.115, valLoss=0.113]
Epoch 72/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 233.82batch/s, counter=0, lastLoss=0.115, valLoss=0.109]
Epoch 73/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 235.15batch/s, counter=1, lastLoss=0.114, valLoss=0.117]
Epoch 74/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 229.50batch/s, counter=2, lastLoss=0.113, valLoss=0.111]
Epoch 75/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 240.81batch/s, counter=3, lastLoss=0.114, valLoss=0.113]
Epoch 76/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 233.08batch/s, counter=4, lastLoss=0.113, valLoss=0.115]
Epoch 77/1500: 100%|████████████████████

Epoch 137/1500: 100%|███████████████████| 56/56 [00:00<00:00, 243.03batch/s, counter=5, lastLoss=0.101, valLoss=0.0964]
Epoch 138/1500: 100%|███████████████████| 56/56 [00:00<00:00, 241.89batch/s, counter=6, lastLoss=0.101, valLoss=0.0978]
Epoch 139/1500: 100%|███████████████████| 56/56 [00:00<00:00, 236.80batch/s, counter=7, lastLoss=0.0996, valLoss=0.106]
Epoch 140/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 251.69batch/s, counter=8, lastLoss=0.1, valLoss=0.0986]
Epoch 141/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 241.89batch/s, counter=9, lastLoss=0.1, valLoss=0.0964]
Epoch 142/1500: 100%|████████████████████| 56/56 [00:00<00:00, 247.79batch/s, counter=10, lastLoss=0.1, valLoss=0.0985]
Epoch 143/1500: 100%|███████████████████| 56/56 [00:00<00:00, 248.96batch/s, counter=0, lastLoss=0.099, valLoss=0.0946]
Epoch 144/1500: 100%|██████████████████| 56/56 [00:00<00:00, 249.43batch/s, counter=1, lastLoss=0.0994, valLoss=0.0956]
Epoch 145/1500: 100%|███████████████████

Epoch 205/1500: 100%|██████████████████| 56/56 [00:00<00:00, 244.01batch/s, counter=3, lastLoss=0.0941, valLoss=0.0928]
Epoch 206/1500: 100%|███████████████████| 56/56 [00:00<00:00, 239.30batch/s, counter=4, lastLoss=0.094, valLoss=0.0911]
Epoch 207/1500: 100%|███████████████████| 56/56 [00:00<00:00, 224.45batch/s, counter=5, lastLoss=0.0939, valLoss=0.088]
Epoch 208/1500: 100%|██████████████████| 56/56 [00:00<00:00, 249.77batch/s, counter=6, lastLoss=0.0951, valLoss=0.0888]
Epoch 209/1500: 100%|██████████████████| 56/56 [00:00<00:00, 240.79batch/s, counter=7, lastLoss=0.0944, valLoss=0.0885]
Epoch 210/1500: 100%|██████████████████| 56/56 [00:00<00:00, 244.52batch/s, counter=8, lastLoss=0.0927, valLoss=0.0888]
Epoch 211/1500: 100%|██████████████████| 56/56 [00:00<00:00, 247.79batch/s, counter=9, lastLoss=0.0935, valLoss=0.0853]
Epoch 212/1500: 100%|██████████████████| 56/56 [00:00<00:00, 248.34batch/s, counter=10, lastLoss=0.093, valLoss=0.0877]
Epoch 213/1500: 100%|█████████████████| 

Epoch 273/1500: 100%|██████████████████| 56/56 [00:00<00:00, 244.54batch/s, counter=0, lastLoss=0.0879, valLoss=0.0788]
Epoch 274/1500: 100%|██████████████████| 56/56 [00:00<00:00, 245.57batch/s, counter=1, lastLoss=0.0861, valLoss=0.0856]
Epoch 275/1500: 100%|██████████████████| 56/56 [00:00<00:00, 242.31batch/s, counter=2, lastLoss=0.0867, valLoss=0.0837]
Epoch 276/1500: 100%|██████████████████| 56/56 [00:00<00:00, 248.33batch/s, counter=3, lastLoss=0.0859, valLoss=0.0959]
Epoch 277/1500: 100%|██████████████████| 56/56 [00:00<00:00, 242.51batch/s, counter=4, lastLoss=0.0861, valLoss=0.0931]
Epoch 278/1500: 100%|██████████████████| 56/56 [00:00<00:00, 244.50batch/s, counter=5, lastLoss=0.0865, valLoss=0.0813]
Epoch 279/1500: 100%|██████████████████| 56/56 [00:00<00:00, 247.84batch/s, counter=6, lastLoss=0.0893, valLoss=0.0901]
Epoch 280/1500: 100%|███████████████████| 56/56 [00:00<00:00, 247.79batch/s, counter=7, lastLoss=0.0882, valLoss=0.086]
Epoch 281/1500: 100%|██████████████████|

Epoch 341/1500: 100%|███████████████████| 56/56 [00:00<00:00, 247.23batch/s, counter=17, lastLoss=0.0804, valLoss=0.08]
Epoch 342/1500: 100%|██████████████████| 56/56 [00:00<00:00, 246.15batch/s, counter=0, lastLoss=0.0826, valLoss=0.0724]
Epoch 343/1500: 100%|██████████████████| 56/56 [00:00<00:00, 246.14batch/s, counter=1, lastLoss=0.0823, valLoss=0.0792]
Epoch 344/1500: 100%|██████████████████| 56/56 [00:00<00:00, 247.34batch/s, counter=2, lastLoss=0.0801, valLoss=0.0749]
Epoch 345/1500: 100%|██████████████████| 56/56 [00:00<00:00, 244.56batch/s, counter=3, lastLoss=0.0789, valLoss=0.0772]
Epoch 346/1500: 100%|████████████████████| 56/56 [00:00<00:00, 245.08batch/s, counter=4, lastLoss=0.08, valLoss=0.0734]
Epoch 347/1500: 100%|██████████████████| 56/56 [00:00<00:00, 247.46batch/s, counter=5, lastLoss=0.0801, valLoss=0.0728]
Epoch 348/1500: 100%|████████████████████| 56/56 [00:00<00:00, 248.75batch/s, counter=6, lastLoss=0.08, valLoss=0.0764]
Epoch 349/1500: 100%|██████████████████|

Epoch 409/1500: 100%|██████████████████| 56/56 [00:00<00:00, 241.47batch/s, counter=2, lastLoss=0.0772, valLoss=0.0818]
Epoch 410/1500: 100%|███████████████████| 56/56 [00:00<00:00, 242.87batch/s, counter=3, lastLoss=0.0763, valLoss=0.074]
Epoch 411/1500: 100%|████████████████████| 56/56 [00:00<00:00, 248.60batch/s, counter=4, lastLoss=0.0771, valLoss=0.07]
Epoch 412/1500: 100%|██████████████████| 56/56 [00:00<00:00, 246.01batch/s, counter=5, lastLoss=0.0759, valLoss=0.0676]
Epoch 413/1500: 100%|██████████████████| 56/56 [00:00<00:00, 245.91batch/s, counter=6, lastLoss=0.0756, valLoss=0.0714]
Epoch 414/1500: 100%|██████████████████| 56/56 [00:00<00:00, 246.13batch/s, counter=7, lastLoss=0.0755, valLoss=0.0706]
Epoch 415/1500: 100%|██████████████████| 56/56 [00:00<00:00, 249.99batch/s, counter=8, lastLoss=0.0764, valLoss=0.0815]
Epoch 416/1500: 100%|██████████████████| 56/56 [00:00<00:00, 248.28batch/s, counter=9, lastLoss=0.0763, valLoss=0.0804]
Epoch 417/1500: 100%|█████████████████| 

Epoch 477/1500: 100%|█████████████████| 56/56 [00:00<00:00, 248.45batch/s, counter=19, lastLoss=0.0716, valLoss=0.0672]
Epoch 478/1500: 100%|██████████████████| 56/56 [00:00<00:00, 242.12batch/s, counter=20, lastLoss=0.073, valLoss=0.0748]
Epoch 479/1500: 100%|█████████████████| 56/56 [00:00<00:00, 250.00batch/s, counter=21, lastLoss=0.0735, valLoss=0.0694]
Epoch 480/1500: 100%|███████████████████| 56/56 [00:00<00:00, 246.97batch/s, counter=22, lastLoss=0.0715, valLoss=0.07]
Epoch 481/1500: 100%|█████████████████| 56/56 [00:00<00:00, 249.99batch/s, counter=23, lastLoss=0.0737, valLoss=0.0694]
Epoch 482/1500: 100%|██████████████████| 56/56 [00:00<00:00, 240.32batch/s, counter=0, lastLoss=0.0718, valLoss=0.0654]
Epoch 483/1500: 100%|██████████████████| 56/56 [00:00<00:00, 251.11batch/s, counter=1, lastLoss=0.0722, valLoss=0.0661]
Epoch 484/1500: 100%|██████████████████| 56/56 [00:00<00:00, 251.11batch/s, counter=2, lastLoss=0.0724, valLoss=0.0662]
Epoch 485/1500: 100%|██████████████████|

Epoch 545/1500: 100%|█████████████████| 56/56 [00:00<00:00, 248.41batch/s, counter=33, lastLoss=0.0698, valLoss=0.0709]
Epoch 546/1500: 100%|█████████████████| 56/56 [00:00<00:00, 238.66batch/s, counter=34, lastLoss=0.0702, valLoss=0.0646]
Epoch 547/1500: 100%|█████████████████| 56/56 [00:00<00:00, 250.62batch/s, counter=35, lastLoss=0.0699, valLoss=0.0683]
Epoch 548/1500: 100%|██████████████████| 56/56 [00:00<00:00, 247.26batch/s, counter=36, lastLoss=0.0694, valLoss=0.067]
Epoch 549/1500: 100%|█████████████████| 56/56 [00:00<00:00, 248.68batch/s, counter=37, lastLoss=0.0699, valLoss=0.0718]
Epoch 550/1500: 100%|███████████████████| 56/56 [00:00<00:00, 236.30batch/s, counter=0, lastLoss=0.0712, valLoss=0.064]
Epoch 551/1500: 100%|██████████████████| 56/56 [00:00<00:00, 234.82batch/s, counter=1, lastLoss=0.0711, valLoss=0.0676]
Epoch 552/1500: 100%|███████████████████| 56/56 [00:00<00:00, 245.08batch/s, counter=2, lastLoss=0.0705, valLoss=0.066]
Epoch 553/1500: 100%|██████████████████|

Epoch 613/1500: 100%|██████████████████| 56/56 [00:00<00:00, 245.96batch/s, counter=0, lastLoss=0.0679, valLoss=0.0628]
Epoch 614/1500: 100%|██████████████████| 56/56 [00:00<00:00, 247.05batch/s, counter=1, lastLoss=0.0703, valLoss=0.0661]
Epoch 615/1500: 100%|██████████████████| 56/56 [00:00<00:00, 246.06batch/s, counter=0, lastLoss=0.0691, valLoss=0.0627]
Epoch 616/1500: 100%|██████████████████| 56/56 [00:00<00:00, 250.12batch/s, counter=1, lastLoss=0.0686, valLoss=0.0733]
Epoch 617/1500: 100%|████████████████████| 56/56 [00:00<00:00, 240.85batch/s, counter=2, lastLoss=0.0713, valLoss=0.07]
Epoch 618/1500: 100%|██████████████████| 56/56 [00:00<00:00, 249.81batch/s, counter=0, lastLoss=0.0695, valLoss=0.0603]
Epoch 619/1500: 100%|██████████████████| 56/56 [00:00<00:00, 243.20batch/s, counter=1, lastLoss=0.0689, valLoss=0.0667]
Epoch 620/1500: 100%|███████████████████| 56/56 [00:00<00:00, 249.32batch/s, counter=2, lastLoss=0.068, valLoss=0.0688]
Epoch 621/1500: 100%|██████████████████|

Early stopping after 667 epochs
Average train loss: 0.039126534281963755
Average validation loss: 0.03812617767887811
Running RNN based on multi_channel, with dataset using non-combined min-max
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 120.89batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 122.73batch/s, counter=0, lastLoss=0.188, valLoss=0.181]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 125.83batch/s, counter=1, lastLoss=0.187, valLoss=0.187]
Epoch 4/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 124.38batch/s, counter=0, lastLoss=0.186, valLoss=0.18]
Epoch 5/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 126.51batch/s, counter=1, lastLoss=0.187, valLoss=0.19]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 125.93batch/s, counter=2, lastLoss=0.187, valLoss=0.183]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 125.94batch/s, counter=3, lastLoss=0.186, valLoss=0.181]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 125.95batch/s, counter=4, lastLoss=0.186, valLoss=0.184]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|████████████████████| 56/56 [00:00<00:00, 126.20batch/s, counter=16, lastLoss=0.186, valLoss=0.187]
Epoch 70/1500: 100%|████████████████████| 56/56 [00:00<00:00, 124.45batch/s, counter=17, lastLoss=0.185, valLoss=0.179]
Epoch 71/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 124.38batch/s, counter=18, lastLoss=0.185, valLoss=0.18]
Epoch 72/1500: 100%|████████████████████| 56/56 [00:00<00:00, 126.17batch/s, counter=19, lastLoss=0.185, valLoss=0.183]
Epoch 73/1500: 100%|████████████████████| 56/56 [00:00<00:00, 126.02batch/s, counter=20, lastLoss=0.185, valLoss=0.185]
Epoch 74/1500: 100%|████████████████████| 56/56 [00:00<00:00, 126.17batch/s, counter=21, lastLoss=0.185, valLoss=0.181]
Epoch 75/1500: 100%|████████████████████| 56/56 [00:00<00:00, 129.30batch/s, counter=22, lastLoss=0.185, valLoss=0.187]
Epoch 76/1500: 100%|████████████████████| 56/56 [00:00<00:00, 127.02batch/s, counter=23, lastLoss=0.186, valLoss=0.187]
Epoch 77/1500: 100%|████████████████████

Early stopping after 102 epochs
Average train loss: 0.012632632607328042
Average validation loss: 0.012475353315803739
Running RNN based on independent, with dataset using non-combined min-max
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 211.85batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 212.58batch/s, counter=0, lastLoss=0.186, valLoss=0.184]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 207.09batch/s, counter=0, lastLoss=0.184, valLoss=0.179]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 210.13batch/s, counter=0, lastLoss=0.179, valLoss=0.167]
Epoch 5/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 208.49batch/s, counter=0, lastLoss=0.17, valLoss=0.162]
Epoch 6/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 209.48batch/s, counter=0, lastLoss=0.167, valLoss=0.16]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 211.86batch/s, counter=1, lastLoss=0.165, valLoss=0.171]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 211.71batch/s, counter=0, lastLoss=0.163, valLoss=0.158]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 209.08batch/s, counter=0, lastLoss=0.113, valLoss=0.109]
Epoch 70/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 204.74batch/s, counter=1, lastLoss=0.113, valLoss=0.111]
Epoch 71/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 208.40batch/s, counter=2, lastLoss=0.113, valLoss=0.111]
Epoch 72/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 207.77batch/s, counter=3, lastLoss=0.113, valLoss=0.11]
Epoch 73/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 210.92batch/s, counter=4, lastLoss=0.113, valLoss=0.109]
Epoch 74/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 204.01batch/s, counter=5, lastLoss=0.112, valLoss=0.115]
Epoch 75/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 208.56batch/s, counter=6, lastLoss=0.111, valLoss=0.111]
Epoch 76/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 208.81batch/s, counter=0, lastLoss=0.112, valLoss=0.106]
Epoch 77/1500: 100%|████████████████████

Epoch 137/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 202.56batch/s, counter=4, lastLoss=0.1, valLoss=0.0961]
Epoch 138/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 208.92batch/s, counter=5, lastLoss=0.1, valLoss=0.0956]
Epoch 139/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 210.43batch/s, counter=6, lastLoss=0.1, valLoss=0.094]
Epoch 140/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 205.20batch/s, counter=7, lastLoss=0.1, valLoss=0.096]
Epoch 141/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 208.56batch/s, counter=8, lastLoss=0.1, valLoss=0.0992]
Epoch 142/1500: 100%|███████████████████| 56/56 [00:00<00:00, 208.02batch/s, counter=9, lastLoss=0.0998, valLoss=0.101]
Epoch 143/1500: 100%|█████████████████| 56/56 [00:00<00:00, 204.74batch/s, counter=10, lastLoss=0.0997, valLoss=0.0951]
Epoch 144/1500: 100%|████████████████████| 56/56 [00:00<00:00, 208.43batch/s, counter=11, lastLoss=0.1, valLoss=0.0949]
Epoch 145/1500: 100%|███████████████████

Epoch 205/1500: 100%|█████████████████| 56/56 [00:00<00:00, 206.41batch/s, counter=10, lastLoss=0.0938, valLoss=0.0914]
Epoch 206/1500: 100%|██████████████████| 56/56 [00:00<00:00, 209.30batch/s, counter=0, lastLoss=0.0942, valLoss=0.0875]
Epoch 207/1500: 100%|██████████████████| 56/56 [00:00<00:00, 209.54batch/s, counter=1, lastLoss=0.0944, valLoss=0.0898]
Epoch 208/1500: 100%|███████████████████| 56/56 [00:00<00:00, 205.12batch/s, counter=2, lastLoss=0.094, valLoss=0.0882]
Epoch 209/1500: 100%|██████████████████| 56/56 [00:00<00:00, 206.36batch/s, counter=3, lastLoss=0.0936, valLoss=0.0903]
Epoch 210/1500: 100%|██████████████████| 56/56 [00:00<00:00, 209.58batch/s, counter=4, lastLoss=0.0941, valLoss=0.0875]
Epoch 211/1500: 100%|██████████████████| 56/56 [00:00<00:00, 205.58batch/s, counter=5, lastLoss=0.0934, valLoss=0.0886]
Epoch 212/1500: 100%|██████████████████| 56/56 [00:00<00:00, 206.26batch/s, counter=6, lastLoss=0.0935, valLoss=0.0888]
Epoch 213/1500: 100%|██████████████████|

Epoch 273/1500: 100%|██████████████████| 56/56 [00:00<00:00, 197.00batch/s, counter=6, lastLoss=0.0882, valLoss=0.0835]
Epoch 274/1500: 100%|██████████████████| 56/56 [00:00<00:00, 207.96batch/s, counter=7, lastLoss=0.0872, valLoss=0.0852]
Epoch 275/1500: 100%|██████████████████| 56/56 [00:00<00:00, 202.30batch/s, counter=8, lastLoss=0.0878, valLoss=0.0881]
Epoch 276/1500: 100%|██████████████████| 56/56 [00:00<00:00, 207.93batch/s, counter=9, lastLoss=0.0873, valLoss=0.0881]
Epoch 277/1500: 100%|██████████████████| 56/56 [00:00<00:00, 207.80batch/s, counter=10, lastLoss=0.0873, valLoss=0.083]
Epoch 278/1500: 100%|█████████████████| 56/56 [00:00<00:00, 203.64batch/s, counter=11, lastLoss=0.0869, valLoss=0.0817]
Epoch 279/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.41batch/s, counter=12, lastLoss=0.0871, valLoss=0.0938]
Epoch 280/1500: 100%|█████████████████| 56/56 [00:00<00:00, 206.25batch/s, counter=13, lastLoss=0.0876, valLoss=0.0855]
Epoch 281/1500: 100%|█████████████████| 

Epoch 341/1500: 100%|█████████████████| 56/56 [00:00<00:00, 208.12batch/s, counter=11, lastLoss=0.0832, valLoss=0.0816]
Epoch 342/1500: 100%|██████████████████| 56/56 [00:00<00:00, 203.35batch/s, counter=0, lastLoss=0.0824, valLoss=0.0739]
Epoch 343/1500: 100%|██████████████████| 56/56 [00:00<00:00, 207.41batch/s, counter=1, lastLoss=0.0812, valLoss=0.0794]
Epoch 344/1500: 100%|██████████████████| 56/56 [00:00<00:00, 196.89batch/s, counter=2, lastLoss=0.0821, valLoss=0.0762]
Epoch 345/1500: 100%|██████████████████| 56/56 [00:00<00:00, 210.53batch/s, counter=3, lastLoss=0.0813, valLoss=0.0809]
Epoch 346/1500: 100%|██████████████████| 56/56 [00:00<00:00, 204.00batch/s, counter=4, lastLoss=0.0834, valLoss=0.0774]
Epoch 347/1500: 100%|██████████████████| 56/56 [00:00<00:00, 210.17batch/s, counter=5, lastLoss=0.0822, valLoss=0.0761]
Epoch 348/1500: 100%|███████████████████| 56/56 [00:00<00:00, 206.00batch/s, counter=6, lastLoss=0.0807, valLoss=0.079]
Epoch 349/1500: 100%|██████████████████|

Epoch 409/1500: 100%|██████████████████| 56/56 [00:00<00:00, 207.04batch/s, counter=12, lastLoss=0.0777, valLoss=0.077]
Epoch 410/1500: 100%|██████████████████| 56/56 [00:00<00:00, 208.95batch/s, counter=13, lastLoss=0.079, valLoss=0.0767]
Epoch 411/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.87batch/s, counter=14, lastLoss=0.0792, valLoss=0.0774]
Epoch 412/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.30batch/s, counter=15, lastLoss=0.0787, valLoss=0.0749]
Epoch 413/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.13batch/s, counter=16, lastLoss=0.0777, valLoss=0.0752]
Epoch 414/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.27batch/s, counter=17, lastLoss=0.0777, valLoss=0.0786]
Epoch 415/1500: 100%|█████████████████| 56/56 [00:00<00:00, 200.74batch/s, counter=18, lastLoss=0.0777, valLoss=0.0762]
Epoch 416/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.88batch/s, counter=19, lastLoss=0.0783, valLoss=0.0738]
Epoch 417/1500: 100%|█████████████████| 

Epoch 477/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.89batch/s, counter=16, lastLoss=0.0755, valLoss=0.0724]
Epoch 478/1500: 100%|███████████████████| 56/56 [00:00<00:00, 206.13batch/s, counter=17, lastLoss=0.0756, valLoss=0.07]
Epoch 479/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.62batch/s, counter=18, lastLoss=0.0772, valLoss=0.0759]
Epoch 480/1500: 100%|█████████████████| 56/56 [00:00<00:00, 208.57batch/s, counter=19, lastLoss=0.0767, valLoss=0.0741]
Epoch 481/1500: 100%|█████████████████| 56/56 [00:00<00:00, 204.33batch/s, counter=20, lastLoss=0.0764, valLoss=0.0806]
Epoch 482/1500: 100%|█████████████████| 56/56 [00:00<00:00, 211.47batch/s, counter=21, lastLoss=0.0745, valLoss=0.0701]
Epoch 483/1500: 100%|█████████████████| 56/56 [00:00<00:00, 206.18batch/s, counter=22, lastLoss=0.0749, valLoss=0.0718]
Epoch 484/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.44batch/s, counter=23, lastLoss=0.0739, valLoss=0.0788]
Epoch 485/1500: 100%|█████████████████| 

Early stopping after 510 epochs
Average train loss: 0.032129140537099116
Average validation loss: 0.03132247010481855
Running RNN based on concatenate, with dataset using non-combined stats
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 226.27batch/s]
Epoch 2/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 232.68batch/s, counter=0, lastLoss=0.115, valLoss=0.11]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 239.23batch/s, counter=0, lastLoss=0.103, valLoss=0.104]
Epoch 4/1500: 100%|████████████████████| 56/56 [00:00<00:00, 241.64batch/s, counter=0, lastLoss=0.0983, valLoss=0.0992]
Epoch 5/1500: 100%|████████████████████| 56/56 [00:00<00:00, 238.05batch/s, counter=0, lastLoss=0.0953, valLoss=0.0879]
Epoch 6/1500: 100%|████████████████████| 56/56 [00:00<00:00, 231.83batch/s, counter=1, lastLoss=0.0933, valLoss=0.0915]
Epoch 7/1500: 100%|████████████████████| 56/56 [00:00<00:00, 241.04batch/s, counter=2, lastLoss=0.0879, valLoss=0.0889]
Epoch 8/1500: 100%|████████████████████| 56/56 [00:00<00:00, 243.46batch/s, counter=3, lastLoss=0.0867, valLoss=0.0906]
Epoch 9/1500: 100%|████████████████████|

Epoch 69/1500: 100%|███████████████████| 56/56 [00:00<00:00, 239.47batch/s, counter=0, lastLoss=0.0356, valLoss=0.0647]
Epoch 70/1500: 100%|████████████████████| 56/56 [00:00<00:00, 239.21batch/s, counter=1, lastLoss=0.0354, valLoss=0.071]
Epoch 71/1500: 100%|███████████████████| 56/56 [00:00<00:00, 242.13batch/s, counter=2, lastLoss=0.0347, valLoss=0.0715]
Epoch 72/1500: 100%|███████████████████| 56/56 [00:00<00:00, 236.55batch/s, counter=3, lastLoss=0.0359, valLoss=0.0669]
Epoch 73/1500: 100%|███████████████████| 56/56 [00:00<00:00, 226.75batch/s, counter=4, lastLoss=0.0345, valLoss=0.0682]
Epoch 74/1500: 100%|███████████████████| 56/56 [00:00<00:00, 238.30batch/s, counter=5, lastLoss=0.0345, valLoss=0.0671]
Epoch 75/1500: 100%|███████████████████| 56/56 [00:00<00:00, 236.28batch/s, counter=6, lastLoss=0.0342, valLoss=0.0705]
Epoch 76/1500: 100%|███████████████████| 56/56 [00:00<00:00, 236.76batch/s, counter=7, lastLoss=0.0354, valLoss=0.0679]
Epoch 77/1500: 100%|███████████████████|

Epoch 137/1500: 100%|██████████████████| 56/56 [00:00<00:00, 238.80batch/s, counter=5, lastLoss=0.0232, valLoss=0.0646]
Epoch 138/1500: 100%|███████████████████| 56/56 [00:00<00:00, 238.14batch/s, counter=0, lastLoss=0.0231, valLoss=0.062]
Epoch 139/1500: 100%|███████████████████| 56/56 [00:00<00:00, 241.90batch/s, counter=1, lastLoss=0.024, valLoss=0.0647]
Epoch 140/1500: 100%|███████████████████| 56/56 [00:00<00:00, 238.30batch/s, counter=0, lastLoss=0.0233, valLoss=0.062]
Epoch 141/1500: 100%|██████████████████| 56/56 [00:00<00:00, 240.33batch/s, counter=1, lastLoss=0.0231, valLoss=0.0648]
Epoch 142/1500: 100%|██████████████████| 56/56 [00:00<00:00, 239.89batch/s, counter=2, lastLoss=0.0227, valLoss=0.0631]
Epoch 143/1500: 100%|██████████████████| 56/56 [00:00<00:00, 235.78batch/s, counter=3, lastLoss=0.0223, valLoss=0.0661]
Epoch 144/1500: 100%|███████████████████| 56/56 [00:00<00:00, 239.48batch/s, counter=4, lastLoss=0.023, valLoss=0.0658]
Epoch 145/1500: 100%|██████████████████|

Epoch 205/1500: 100%|█████████████████| 56/56 [00:00<00:00, 237.22batch/s, counter=49, lastLoss=0.0186, valLoss=0.0796]


Early stopping after 205 epochs
Average train loss: 0.004868390732699827
Average validation loss: 0.009553001387789846
Running RNN based on multi_channel, with dataset using non-combined stats
Using: cuda:0


Epoch 1/1500: 100%|█████████████████████████████████████████████████████████████████| 56/56 [00:01<00:00, 47.08batch/s]
Epoch 2/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.55batch/s, counter=0, lastLoss=0.176, valLoss=0.175]
Epoch 3/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.79batch/s, counter=0, lastLoss=0.172, valLoss=0.163]
Epoch 4/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.82batch/s, counter=0, lastLoss=0.165, valLoss=0.158]
Epoch 5/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.71batch/s, counter=0, lastLoss=0.162, valLoss=0.155]
Epoch 6/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.92batch/s, counter=1, lastLoss=0.159, valLoss=0.158]
Epoch 7/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.73batch/s, counter=0, lastLoss=0.155, valLoss=0.147]
Epoch 8/1500: 100%|███████████████████████| 56/56 [00:01<00:00, 47.95batch/s, counter=1, lastLoss=0.154, valLoss=0.158]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|█████████████████████| 56/56 [00:01<00:00, 47.23batch/s, counter=21, lastLoss=0.104, valLoss=0.121]
Epoch 70/1500: 100%|██████████████████████| 56/56 [00:01<00:00, 47.30batch/s, counter=22, lastLoss=0.104, valLoss=0.12]
Epoch 71/1500: 100%|█████████████████████| 56/56 [00:01<00:00, 47.46batch/s, counter=23, lastLoss=0.104, valLoss=0.124]
Epoch 72/1500: 100%|██████████████████████| 56/56 [00:01<00:00, 47.30batch/s, counter=0, lastLoss=0.104, valLoss=0.116]
Epoch 73/1500: 100%|██████████████████████| 56/56 [00:01<00:00, 47.37batch/s, counter=1, lastLoss=0.104, valLoss=0.128]
Epoch 74/1500: 100%|██████████████████████| 56/56 [00:01<00:00, 47.20batch/s, counter=2, lastLoss=0.104, valLoss=0.122]
Epoch 75/1500: 100%|██████████████████████| 56/56 [00:01<00:00, 47.45batch/s, counter=0, lastLoss=0.102, valLoss=0.115]
Epoch 76/1500: 100%|██████████████████████| 56/56 [00:01<00:00, 47.33batch/s, counter=0, lastLoss=0.102, valLoss=0.112]
Epoch 77/1500: 100%|████████████████████

Early stopping after 125 epochs
Average train loss: 0.00935333549900956
Average validation loss: 0.010575502739846705
Running RNN based on independent, with dataset using non-combined stats
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 197.48batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 205.89batch/s, counter=0, lastLoss=0.119, valLoss=0.106]
Epoch 3/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 200.72batch/s, counter=0, lastLoss=0.106, valLoss=0.0989]
Epoch 4/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 199.07batch/s, counter=1, lastLoss=0.101, valLoss=0.0993]
Epoch 5/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 201.07batch/s, counter=2, lastLoss=0.0974, valLoss=0.101]
Epoch 6/1500: 100%|████████████████████| 56/56 [00:00<00:00, 200.38batch/s, counter=0, lastLoss=0.0941, valLoss=0.0979]
Epoch 7/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 202.90batch/s, counter=1, lastLoss=0.0914, valLoss=0.101]
Epoch 8/1500: 100%|████████████████████| 56/56 [00:00<00:00, 189.72batch/s, counter=0, lastLoss=0.0926, valLoss=0.0932]
Epoch 9/1500: 100%|████████████████████|

Epoch 69/1500: 100%|███████████████████| 56/56 [00:00<00:00, 199.24batch/s, counter=7, lastLoss=0.0419, valLoss=0.0732]
Epoch 70/1500: 100%|███████████████████| 56/56 [00:00<00:00, 197.55batch/s, counter=8, lastLoss=0.0432, valLoss=0.0737]
Epoch 71/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 192.12batch/s, counter=9, lastLoss=0.04, valLoss=0.072]
Epoch 72/1500: 100%|███████████████████| 56/56 [00:00<00:00, 199.78batch/s, counter=0, lastLoss=0.0408, valLoss=0.0706]
Epoch 73/1500: 100%|███████████████████| 56/56 [00:00<00:00, 197.53batch/s, counter=1, lastLoss=0.0399, valLoss=0.0736]
Epoch 74/1500: 100%|███████████████████| 56/56 [00:00<00:00, 191.13batch/s, counter=2, lastLoss=0.0391, valLoss=0.0733]
Epoch 75/1500: 100%|███████████████████| 56/56 [00:00<00:00, 197.88batch/s, counter=3, lastLoss=0.0397, valLoss=0.0708]
Epoch 76/1500: 100%|███████████████████| 56/56 [00:00<00:00, 199.58batch/s, counter=4, lastLoss=0.0388, valLoss=0.0733]
Epoch 77/1500: 100%|███████████████████|

Epoch 137/1500: 100%|█████████████████| 56/56 [00:00<00:00, 197.07batch/s, counter=12, lastLoss=0.0283, valLoss=0.0749]
Epoch 138/1500: 100%|█████████████████| 56/56 [00:00<00:00, 193.78batch/s, counter=13, lastLoss=0.0275, valLoss=0.0687]
Epoch 139/1500: 100%|██████████████████| 56/56 [00:00<00:00, 198.53batch/s, counter=14, lastLoss=0.0294, valLoss=0.072]
Epoch 140/1500: 100%|█████████████████| 56/56 [00:00<00:00, 195.80batch/s, counter=15, lastLoss=0.0269, valLoss=0.0707]
Epoch 141/1500: 100%|█████████████████| 56/56 [00:00<00:00, 196.15batch/s, counter=16, lastLoss=0.0268, valLoss=0.0706]
Epoch 142/1500: 100%|█████████████████| 56/56 [00:00<00:00, 189.82batch/s, counter=17, lastLoss=0.0258, valLoss=0.0695]
Epoch 143/1500: 100%|██████████████████| 56/56 [00:00<00:00, 198.67batch/s, counter=18, lastLoss=0.026, valLoss=0.0677]
Epoch 144/1500: 100%|█████████████████| 56/56 [00:00<00:00, 199.99batch/s, counter=19, lastLoss=0.0264, valLoss=0.0709]
Epoch 145/1500: 100%|█████████████████| 

Epoch 205/1500: 100%|███████████████████| 56/56 [00:00<00:00, 189.51batch/s, counter=1, lastLoss=0.0218, valLoss=0.068]
Epoch 206/1500: 100%|██████████████████| 56/56 [00:00<00:00, 201.49batch/s, counter=2, lastLoss=0.0214, valLoss=0.0703]
Epoch 207/1500: 100%|██████████████████| 56/56 [00:00<00:00, 197.47batch/s, counter=3, lastLoss=0.0225, valLoss=0.0701]
Epoch 208/1500: 100%|███████████████████| 56/56 [00:00<00:00, 200.39batch/s, counter=4, lastLoss=0.0211, valLoss=0.076]
Epoch 209/1500: 100%|██████████████████| 56/56 [00:00<00:00, 189.19batch/s, counter=5, lastLoss=0.0207, valLoss=0.0693]
Epoch 210/1500: 100%|██████████████████| 56/56 [00:00<00:00, 196.49batch/s, counter=6, lastLoss=0.0212, valLoss=0.0702]
Epoch 211/1500: 100%|██████████████████| 56/56 [00:00<00:00, 199.60batch/s, counter=7, lastLoss=0.0213, valLoss=0.0735]
Epoch 212/1500: 100%|███████████████████| 56/56 [00:00<00:00, 190.98batch/s, counter=8, lastLoss=0.021, valLoss=0.0684]
Epoch 213/1500: 100%|██████████████████|

Early stopping after 253 epochs
Average train loss: 0.006115040725008363
Average validation loss: 0.012400354069347183
Running RNN based on concatenate, with dataset using non-combined raw
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 201.81batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 200.36batch/s, counter=0, lastLoss=0.179, valLoss=0.164]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 206.64batch/s, counter=0, lastLoss=0.163, valLoss=0.154]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 203.28batch/s, counter=0, lastLoss=0.156, valLoss=0.152]
Epoch 5/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 205.89batch/s, counter=0, lastLoss=0.146, valLoss=0.144]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 203.64batch/s, counter=0, lastLoss=0.141, valLoss=0.143]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 204.76batch/s, counter=0, lastLoss=0.139, valLoss=0.134]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 207.03batch/s, counter=0, lastLoss=0.135, valLoss=0.128]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|███████████████████| 56/56 [00:00<00:00, 207.36batch/s, counter=1, lastLoss=0.0823, valLoss=0.0824]
Epoch 70/1500: 100%|███████████████████| 56/56 [00:00<00:00, 204.76batch/s, counter=2, lastLoss=0.0823, valLoss=0.0867]
Epoch 71/1500: 100%|███████████████████| 56/56 [00:00<00:00, 206.26batch/s, counter=0, lastLoss=0.0821, valLoss=0.0781]
Epoch 72/1500: 100%|███████████████████| 56/56 [00:00<00:00, 205.51batch/s, counter=1, lastLoss=0.0837, valLoss=0.0898]
Epoch 73/1500: 100%|███████████████████| 56/56 [00:00<00:00, 207.07batch/s, counter=2, lastLoss=0.0827, valLoss=0.0812]
Epoch 74/1500: 100%|███████████████████| 56/56 [00:00<00:00, 203.43batch/s, counter=3, lastLoss=0.0805, valLoss=0.0824]
Epoch 75/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 206.87batch/s, counter=0, lastLoss=0.08, valLoss=0.0761]
Epoch 76/1500: 100%|███████████████████| 56/56 [00:00<00:00, 210.13batch/s, counter=1, lastLoss=0.0833, valLoss=0.0799]
Epoch 77/1500: 100%|███████████████████|

Epoch 137/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.88batch/s, counter=11, lastLoss=0.0684, valLoss=0.0741]
Epoch 138/1500: 100%|██████████████████| 56/56 [00:00<00:00, 206.26batch/s, counter=0, lastLoss=0.0679, valLoss=0.0659]
Epoch 139/1500: 100%|███████████████████| 56/56 [00:00<00:00, 207.14batch/s, counter=1, lastLoss=0.069, valLoss=0.0671]
Epoch 140/1500: 100%|██████████████████| 56/56 [00:00<00:00, 193.00batch/s, counter=0, lastLoss=0.0673, valLoss=0.0653]
Epoch 141/1500: 100%|██████████████████| 56/56 [00:00<00:00, 208.96batch/s, counter=1, lastLoss=0.0675, valLoss=0.0657]
Epoch 142/1500: 100%|████████████████████| 56/56 [00:00<00:00, 207.41batch/s, counter=2, lastLoss=0.0664, valLoss=0.07]
Epoch 143/1500: 100%|██████████████████| 56/56 [00:00<00:00, 204.76batch/s, counter=3, lastLoss=0.0668, valLoss=0.0717]
Epoch 144/1500: 100%|██████████████████| 56/56 [00:00<00:00, 208.19batch/s, counter=4, lastLoss=0.0681, valLoss=0.0659]
Epoch 145/1500: 100%|██████████████████|

Epoch 205/1500: 100%|███████████████████| 56/56 [00:00<00:00, 206.70batch/s, counter=10, lastLoss=0.06, valLoss=0.0599]
Epoch 206/1500: 100%|█████████████████| 56/56 [00:00<00:00, 206.48batch/s, counter=11, lastLoss=0.0613, valLoss=0.0691]
Epoch 207/1500: 100%|██████████████████| 56/56 [00:00<00:00, 199.64batch/s, counter=0, lastLoss=0.0622, valLoss=0.0588]
Epoch 208/1500: 100%|██████████████████| 56/56 [00:00<00:00, 208.95batch/s, counter=1, lastLoss=0.0616, valLoss=0.0646]
Epoch 209/1500: 100%|██████████████████| 56/56 [00:00<00:00, 206.23batch/s, counter=2, lastLoss=0.0607, valLoss=0.0692]
Epoch 210/1500: 100%|██████████████████| 56/56 [00:00<00:00, 203.14batch/s, counter=3, lastLoss=0.0643, valLoss=0.0594]
Epoch 211/1500: 100%|██████████████████| 56/56 [00:00<00:00, 208.96batch/s, counter=4, lastLoss=0.0613, valLoss=0.0684]
Epoch 212/1500: 100%|██████████████████| 56/56 [00:00<00:00, 206.63batch/s, counter=0, lastLoss=0.0609, valLoss=0.0581]
Epoch 213/1500: 100%|██████████████████|

Epoch 273/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.67batch/s, counter=11, lastLoss=0.0579, valLoss=0.0572]
Epoch 274/1500: 100%|█████████████████| 56/56 [00:00<00:00, 204.38batch/s, counter=12, lastLoss=0.0565, valLoss=0.0609]
Epoch 275/1500: 100%|█████████████████| 56/56 [00:00<00:00, 202.17batch/s, counter=13, lastLoss=0.0599, valLoss=0.0636]
Epoch 276/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.80batch/s, counter=14, lastLoss=0.0569, valLoss=0.0553]
Epoch 277/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.13batch/s, counter=15, lastLoss=0.0571, valLoss=0.0686]
Epoch 278/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.53batch/s, counter=16, lastLoss=0.0584, valLoss=0.0634]
Epoch 279/1500: 100%|█████████████████| 56/56 [00:00<00:00, 204.76batch/s, counter=17, lastLoss=0.0567, valLoss=0.0606]
Epoch 280/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.14batch/s, counter=18, lastLoss=0.0565, valLoss=0.0615]
Epoch 281/1500: 100%|██████████████████|

Epoch 341/1500: 100%|███████████████████| 56/56 [00:00<00:00, 204.01batch/s, counter=7, lastLoss=0.054, valLoss=0.0601]
Epoch 342/1500: 100%|███████████████████| 56/56 [00:00<00:00, 208.02batch/s, counter=8, lastLoss=0.0551, valLoss=0.056]
Epoch 343/1500: 100%|██████████████████| 56/56 [00:00<00:00, 203.87batch/s, counter=9, lastLoss=0.0554, valLoss=0.0571]
Epoch 344/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.82batch/s, counter=10, lastLoss=0.0529, valLoss=0.0592]
Epoch 345/1500: 100%|█████████████████| 56/56 [00:00<00:00, 196.15batch/s, counter=11, lastLoss=0.0553, valLoss=0.0611]
Epoch 346/1500: 100%|█████████████████| 56/56 [00:00<00:00, 206.69batch/s, counter=12, lastLoss=0.0546, valLoss=0.0592]
Epoch 347/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.88batch/s, counter=13, lastLoss=0.0541, valLoss=0.0542]
Epoch 348/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.06batch/s, counter=14, lastLoss=0.0543, valLoss=0.0539]
Epoch 349/1500: 100%|█████████████████| 

Epoch 409/1500: 100%|█████████████████| 56/56 [00:00<00:00, 208.57batch/s, counter=22, lastLoss=0.0536, valLoss=0.0509]
Epoch 410/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.43batch/s, counter=23, lastLoss=0.0535, valLoss=0.0597]
Epoch 411/1500: 100%|█████████████████| 56/56 [00:00<00:00, 205.88batch/s, counter=24, lastLoss=0.0535, valLoss=0.0517]
Epoch 412/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.80batch/s, counter=25, lastLoss=0.0535, valLoss=0.0582]
Epoch 413/1500: 100%|███████████████████| 56/56 [00:00<00:00, 205.89batch/s, counter=26, lastLoss=0.0525, valLoss=0.06]
Epoch 414/1500: 100%|█████████████████| 56/56 [00:00<00:00, 206.40batch/s, counter=27, lastLoss=0.0524, valLoss=0.0649]
Epoch 415/1500: 100%|█████████████████| 56/56 [00:00<00:00, 204.91batch/s, counter=28, lastLoss=0.0576, valLoss=0.0622]
Epoch 416/1500: 100%|█████████████████| 56/56 [00:00<00:00, 207.02batch/s, counter=29, lastLoss=0.0536, valLoss=0.0503]
Epoch 417/1500: 100%|██████████████████|

Early stopping after 436 epochs
Average train loss: 0.019859525808754063
Average validation loss: 0.020240845171569124
Running RNN based on multi_channel, with dataset using non-combined raw
Using: cuda:0


Epoch 1/1500: 100%|█████████████████████████████████████████████████████████████████| 56/56 [00:05<00:00, 10.56batch/s]
Epoch 2/1500: 100%|███████████████████████| 56/56 [00:05<00:00, 10.55batch/s, counter=0, lastLoss=0.188, valLoss=0.185]
Epoch 3/1500: 100%|███████████████████████| 56/56 [00:05<00:00, 10.56batch/s, counter=1, lastLoss=0.187, valLoss=0.188]
Epoch 4/1500: 100%|█████████████████████████| 56/56 [00:05<00:00, 10.58batch/s, counter=2, lastLoss=0.186, valLoss=0.2]
Epoch 5/1500: 100%|███████████████████████| 56/56 [00:05<00:00, 10.56batch/s, counter=0, lastLoss=0.187, valLoss=0.182]
Epoch 6/1500: 100%|███████████████████████| 56/56 [00:05<00:00, 10.54batch/s, counter=1, lastLoss=0.186, valLoss=0.191]
Epoch 7/1500: 100%|███████████████████████| 56/56 [00:05<00:00, 10.54batch/s, counter=2, lastLoss=0.186, valLoss=0.183]
Epoch 8/1500: 100%|███████████████████████| 56/56 [00:05<00:00, 10.54batch/s, counter=3, lastLoss=0.186, valLoss=0.185]
Epoch 9/1500: 100%|█████████████████████

Early stopping after 66 epochs
Average train loss: 0.008159620154294228
Average validation loss: 0.008178285130858422
Running RNN based on independent, with dataset using non-combined raw
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 162.80batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 163.52batch/s, counter=0, lastLoss=0.178, valLoss=0.173]
Epoch 3/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 163.28batch/s, counter=0, lastLoss=0.164, valLoss=0.16]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 156.67batch/s, counter=0, lastLoss=0.154, valLoss=0.149]
Epoch 5/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 162.46batch/s, counter=0, lastLoss=0.147, valLoss=0.147]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 161.40batch/s, counter=0, lastLoss=0.142, valLoss=0.137]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 161.87batch/s, counter=0, lastLoss=0.138, valLoss=0.133]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 160.48batch/s, counter=1, lastLoss=0.136, valLoss=0.133]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|███████████████████| 56/56 [00:00<00:00, 164.72batch/s, counter=0, lastLoss=0.0865, valLoss=0.0816]
Epoch 70/1500: 100%|███████████████████| 56/56 [00:00<00:00, 159.56batch/s, counter=1, lastLoss=0.0855, valLoss=0.0837]
Epoch 71/1500: 100%|████████████████████| 56/56 [00:00<00:00, 163.05batch/s, counter=2, lastLoss=0.086, valLoss=0.0829]
Epoch 72/1500: 100%|███████████████████| 56/56 [00:00<00:00, 163.76batch/s, counter=3, lastLoss=0.0863, valLoss=0.0908]
Epoch 73/1500: 100%|███████████████████| 56/56 [00:00<00:00, 164.43batch/s, counter=4, lastLoss=0.0854, valLoss=0.0878]
Epoch 74/1500: 100%|███████████████████| 56/56 [00:00<00:00, 161.63batch/s, counter=5, lastLoss=0.0864, valLoss=0.0868]
Epoch 75/1500: 100%|███████████████████| 56/56 [00:00<00:00, 163.52batch/s, counter=6, lastLoss=0.0847, valLoss=0.0916]
Epoch 76/1500: 100%|███████████████████| 56/56 [00:00<00:00, 163.28batch/s, counter=7, lastLoss=0.0852, valLoss=0.0877]
Epoch 77/1500: 100%|███████████████████|

Epoch 137/1500: 100%|███████████████████| 56/56 [00:00<00:00, 161.63batch/s, counter=2, lastLoss=0.074, valLoss=0.0722]
Epoch 138/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.80batch/s, counter=3, lastLoss=0.0732, valLoss=0.0751]
Epoch 139/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.16batch/s, counter=4, lastLoss=0.0727, valLoss=0.0744]
Epoch 140/1500: 100%|██████████████████| 56/56 [00:00<00:00, 156.47batch/s, counter=5, lastLoss=0.0722, valLoss=0.0737]
Epoch 141/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.48batch/s, counter=6, lastLoss=0.0734, valLoss=0.0825]
Epoch 142/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.09batch/s, counter=7, lastLoss=0.0714, valLoss=0.0742]
Epoch 143/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.11batch/s, counter=8, lastLoss=0.0726, valLoss=0.0789]
Epoch 144/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.86batch/s, counter=0, lastLoss=0.0719, valLoss=0.0695]
Epoch 145/1500: 100%|███████████████████

Epoch 205/1500: 100%|███████████████████| 56/56 [00:00<00:00, 162.45batch/s, counter=4, lastLoss=0.067, valLoss=0.0701]
Epoch 206/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.70batch/s, counter=5, lastLoss=0.0659, valLoss=0.0675]
Epoch 207/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.04batch/s, counter=6, lastLoss=0.0669, valLoss=0.0685]
Epoch 208/1500: 100%|██████████████████| 56/56 [00:00<00:00, 155.40batch/s, counter=7, lastLoss=0.0676, valLoss=0.0714]
Epoch 209/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.28batch/s, counter=8, lastLoss=0.0675, valLoss=0.0708]
Epoch 210/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.28batch/s, counter=9, lastLoss=0.0672, valLoss=0.0658]
Epoch 211/1500: 100%|█████████████████| 56/56 [00:00<00:00, 161.63batch/s, counter=10, lastLoss=0.0665, valLoss=0.0699]
Epoch 212/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.87batch/s, counter=11, lastLoss=0.067, valLoss=0.0751]
Epoch 213/1500: 100%|█████████████████| 

Epoch 273/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.58batch/s, counter=6, lastLoss=0.0637, valLoss=0.0669]
Epoch 274/1500: 100%|███████████████████| 56/56 [00:00<00:00, 160.90batch/s, counter=7, lastLoss=0.0627, valLoss=0.066]
Epoch 275/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.16batch/s, counter=0, lastLoss=0.0646, valLoss=0.0615]
Epoch 276/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.59batch/s, counter=1, lastLoss=0.0638, valLoss=0.0673]
Epoch 277/1500: 100%|███████████████████| 56/56 [00:00<00:00, 160.70batch/s, counter=2, lastLoss=0.0645, valLoss=0.066]
Epoch 278/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.28batch/s, counter=3, lastLoss=0.0647, valLoss=0.0694]
Epoch 279/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.56batch/s, counter=0, lastLoss=0.0651, valLoss=0.0612]
Epoch 280/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.88batch/s, counter=1, lastLoss=0.0635, valLoss=0.0642]
Epoch 281/1500: 100%|███████████████████

Epoch 341/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.79batch/s, counter=0, lastLoss=0.0616, valLoss=0.0577]
Epoch 342/1500: 100%|██████████████████| 56/56 [00:00<00:00, 160.49batch/s, counter=1, lastLoss=0.0601, valLoss=0.0612]
Epoch 343/1500: 100%|██████████████████| 56/56 [00:00<00:00, 161.23batch/s, counter=2, lastLoss=0.0609, valLoss=0.0618]
Epoch 344/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.81batch/s, counter=3, lastLoss=0.0603, valLoss=0.0604]
Epoch 345/1500: 100%|██████████████████| 56/56 [00:00<00:00, 163.75batch/s, counter=4, lastLoss=0.0612, valLoss=0.0641]
Epoch 346/1500: 100%|██████████████████| 56/56 [00:00<00:00, 160.02batch/s, counter=5, lastLoss=0.0615, valLoss=0.0624]
Epoch 347/1500: 100%|██████████████████| 56/56 [00:00<00:00, 162.81batch/s, counter=6, lastLoss=0.0608, valLoss=0.0614]
Epoch 348/1500: 100%|██████████████████| 56/56 [00:00<00:00, 160.93batch/s, counter=7, lastLoss=0.0616, valLoss=0.0606]
Epoch 349/1500: 100%|██████████████████|

Early stopping after 390 epochs
Average train loss: 0.019488499510811552
Average validation loss: 0.01972105626927482
Running RNN based on concatenate, with dataset using non-combined entropy
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 249.47batch/s]
Epoch 2/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 251.19batch/s, counter=1, lastLoss=nan, valLoss=nan]
Epoch 3/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 255.72batch/s, counter=2, lastLoss=nan, valLoss=nan]
Epoch 4/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 247.18batch/s, counter=3, lastLoss=nan, valLoss=nan]
Epoch 5/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 250.59batch/s, counter=4, lastLoss=nan, valLoss=nan]
Epoch 6/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 253.69batch/s, counter=5, lastLoss=nan, valLoss=nan]
Epoch 7/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 252.27batch/s, counter=6, lastLoss=nan, valLoss=nan]
Epoch 8/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 246.93batch/s, counter=7, lastLoss=nan, valLoss=nan]
Epoch 9/1500: 100%|█████████████████████

Early stopping after 50 epochs
Average train loss: nan
Average validation loss: nan
Running RNN based on multi_channel, with dataset using non-combined entropy
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 212.29batch/s]
Epoch 2/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 208.12batch/s, counter=1, lastLoss=nan, valLoss=nan]
Epoch 3/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 210.08batch/s, counter=2, lastLoss=nan, valLoss=nan]
Epoch 4/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 209.10batch/s, counter=3, lastLoss=nan, valLoss=nan]
Epoch 5/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 207.97batch/s, counter=4, lastLoss=nan, valLoss=nan]
Epoch 6/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 210.48batch/s, counter=5, lastLoss=nan, valLoss=nan]
Epoch 7/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 211.61batch/s, counter=6, lastLoss=nan, valLoss=nan]
Epoch 8/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 211.12batch/s, counter=7, lastLoss=nan, valLoss=nan]
Epoch 9/1500: 100%|█████████████████████

Early stopping after 50 epochs
Average train loss: nan
Average validation loss: nan
Running RNN based on independent, with dataset using non-combined entropy
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 201.10batch/s]
Epoch 2/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 209.14batch/s, counter=1, lastLoss=nan, valLoss=nan]
Epoch 3/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 204.30batch/s, counter=2, lastLoss=nan, valLoss=nan]
Epoch 4/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 209.03batch/s, counter=3, lastLoss=nan, valLoss=nan]
Epoch 5/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 208.53batch/s, counter=4, lastLoss=nan, valLoss=nan]
Epoch 6/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 209.37batch/s, counter=5, lastLoss=nan, valLoss=nan]
Epoch 7/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 210.15batch/s, counter=6, lastLoss=nan, valLoss=nan]
Epoch 8/1500: 100%|██████████████████████████| 56/56 [00:00<00:00, 211.95batch/s, counter=7, lastLoss=nan, valLoss=nan]
Epoch 9/1500: 100%|█████████████████████

Early stopping after 50 epochs
Average train loss: nan
Average validation loss: nan
Running RNN based on concatenate, with dataset using non-combined energy
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 239.85batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 251.39batch/s, counter=0, lastLoss=0.186, valLoss=0.178]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 248.92batch/s, counter=1, lastLoss=0.186, valLoss=0.183]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 246.91batch/s, counter=2, lastLoss=0.186, valLoss=0.186]
Epoch 5/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 248.54batch/s, counter=3, lastLoss=0.186, valLoss=0.183]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 249.43batch/s, counter=4, lastLoss=0.186, valLoss=0.187]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 251.15batch/s, counter=5, lastLoss=0.186, valLoss=0.183]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 245.10batch/s, counter=6, lastLoss=0.186, valLoss=0.183]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 251.14batch/s, counter=0, lastLoss=0.161, valLoss=0.157]
Epoch 70/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 248.36batch/s, counter=1, lastLoss=0.16, valLoss=0.163]
Epoch 71/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 232.10batch/s, counter=2, lastLoss=0.161, valLoss=0.16]
Epoch 72/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 250.03batch/s, counter=0, lastLoss=0.16, valLoss=0.157]
Epoch 73/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 252.27batch/s, counter=1, lastLoss=0.16, valLoss=0.165]
Epoch 74/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 247.72batch/s, counter=0, lastLoss=0.159, valLoss=0.157]
Epoch 75/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 243.95batch/s, counter=1, lastLoss=0.16, valLoss=0.161]
Epoch 76/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 246.73batch/s, counter=2, lastLoss=0.161, valLoss=0.16]
Epoch 77/1500: 100%|████████████████████

Epoch 137/1500: 100%|███████████████████| 56/56 [00:00<00:00, 250.13batch/s, counter=31, lastLoss=0.159, valLoss=0.157]
Epoch 138/1500: 100%|███████████████████| 56/56 [00:00<00:00, 237.85batch/s, counter=32, lastLoss=0.158, valLoss=0.166]
Epoch 139/1500: 100%|███████████████████| 56/56 [00:00<00:00, 246.72batch/s, counter=33, lastLoss=0.158, valLoss=0.161]
Epoch 140/1500: 100%|███████████████████| 56/56 [00:00<00:00, 246.73batch/s, counter=34, lastLoss=0.158, valLoss=0.161]
Epoch 141/1500: 100%|███████████████████| 56/56 [00:00<00:00, 250.08batch/s, counter=35, lastLoss=0.158, valLoss=0.159]
Epoch 142/1500: 100%|███████████████████| 56/56 [00:00<00:00, 252.30batch/s, counter=36, lastLoss=0.158, valLoss=0.163]
Epoch 143/1500: 100%|███████████████████| 56/56 [00:00<00:00, 250.02batch/s, counter=37, lastLoss=0.158, valLoss=0.157]
Epoch 144/1500: 100%|███████████████████| 56/56 [00:00<00:00, 247.28batch/s, counter=38, lastLoss=0.158, valLoss=0.155]
Epoch 145/1500: 100%|███████████████████

Early stopping after 155 epochs
Average train loss: 0.017257807063737085
Average validation loss: 0.01711208337826861
Running RNN based on multi_channel, with dataset using non-combined energy
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 204.12batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 213.32batch/s, counter=0, lastLoss=0.188, valLoss=0.183]
Epoch 3/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 213.14batch/s, counter=1, lastLoss=0.187, valLoss=0.184]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 200.37batch/s, counter=2, lastLoss=0.187, valLoss=0.185]
Epoch 5/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 198.86batch/s, counter=0, lastLoss=0.186, valLoss=0.181]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 212.95batch/s, counter=0, lastLoss=0.187, valLoss=0.178]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 210.27batch/s, counter=1, lastLoss=0.187, valLoss=0.183]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 208.83batch/s, counter=2, lastLoss=0.186, valLoss=0.181]
Epoch 9/1500: 100%|█████████████████████

Early stopping after 60 epochs
Average train loss: 0.007440499497134062
Average validation loss: 0.007371233327521216
Running RNN based on independent, with dataset using non-combined energy
Using: cuda:0


Epoch 1/1500: 100%|████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 211.81batch/s]
Epoch 2/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 204.02batch/s, counter=0, lastLoss=0.186, valLoss=0.184]
Epoch 3/1500: 100%|███████████████████████| 56/56 [00:00<00:00, 209.86batch/s, counter=1, lastLoss=0.186, valLoss=0.19]
Epoch 4/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 197.93batch/s, counter=2, lastLoss=0.186, valLoss=0.185]
Epoch 5/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 193.79batch/s, counter=3, lastLoss=0.186, valLoss=0.186]
Epoch 6/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 193.45batch/s, counter=4, lastLoss=0.186, valLoss=0.193]
Epoch 7/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 196.86batch/s, counter=5, lastLoss=0.186, valLoss=0.187]
Epoch 8/1500: 100%|██████████████████████| 56/56 [00:00<00:00, 206.68batch/s, counter=0, lastLoss=0.186, valLoss=0.181]
Epoch 9/1500: 100%|█████████████████████

Epoch 69/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 211.35batch/s, counter=13, lastLoss=0.159, valLoss=0.16]
Epoch 70/1500: 100%|████████████████████| 56/56 [00:00<00:00, 203.59batch/s, counter=14, lastLoss=0.159, valLoss=0.158]
Epoch 71/1500: 100%|████████████████████| 56/56 [00:00<00:00, 212.14batch/s, counter=15, lastLoss=0.159, valLoss=0.159]
Epoch 72/1500: 100%|████████████████████| 56/56 [00:00<00:00, 210.15batch/s, counter=16, lastLoss=0.159, valLoss=0.161]
Epoch 73/1500: 100%|█████████████████████| 56/56 [00:00<00:00, 204.40batch/s, counter=17, lastLoss=0.159, valLoss=0.16]
Epoch 74/1500: 100%|████████████████████| 56/56 [00:00<00:00, 210.93batch/s, counter=18, lastLoss=0.159, valLoss=0.157]
Epoch 75/1500: 100%|████████████████████| 56/56 [00:00<00:00, 207.04batch/s, counter=19, lastLoss=0.159, valLoss=0.155]
Epoch 76/1500: 100%|████████████████████| 56/56 [00:00<00:00, 207.41batch/s, counter=20, lastLoss=0.159, valLoss=0.155]
Epoch 77/1500: 100%|████████████████████

Early stopping after 105 epochs
Average train loss: 0.011660678441325824
Average validation loss: 0.01154915235340595


In [129]:
losses = {}
modelss = []

for comb in loaders.keys():
            for a in models[model_type][comb].keys():
                for ext_method in loaders[comb].keys():
                    print(f"Running {model_type} based on {model_method}, with dataset using {comb} {ext_method}")
                    modelss.append(SiameseRNN(input_size=lengths[comb][ext_method]))
                    train_loss, val_loss = train(
                        loaders[comb][ext_method]['train'],
                        modelss[-1],
                        val_loader=loaders[comb][ext_method]['test'],
                        LR=0.001,
                        epochs=300, early_stopping=True)

            print(test(loaders[comb][ext_method]['test'], m[-1], datasets[comb][ext_method][1]))


for i, method in enumerate(loaders['non-combined'].keys()):
    print(f"Running based on {method}")
    print(test(loaders['non-combined'][method]['test'], modelss[i], datasets['non-combined'][method][1]), "\n")

Running RNN based on independent, with dataset using non-combined min-max
Using: cuda:0


Epoch 1/300: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 91.36batch/s]
Epoch 2/300: 100%|████████████████████████| 56/56 [00:00<00:00, 89.03batch/s, counter=0, lastLoss=0.231, valLoss=0.189]
Epoch 3/300: 100%|███████████████████████| 56/56 [00:00<00:00, 104.09batch/s, counter=0, lastLoss=0.186, valLoss=0.187]
Epoch 4/300: 100%|███████████████████████| 56/56 [00:00<00:00, 103.90batch/s, counter=0, lastLoss=0.187, valLoss=0.187]
Epoch 5/300: 100%|███████████████████████| 56/56 [00:00<00:00, 103.23batch/s, counter=0, lastLoss=0.184, valLoss=0.178]
Epoch 6/300: 100%|███████████████████████| 56/56 [00:00<00:00, 102.38batch/s, counter=0, lastLoss=0.178, valLoss=0.173]
Epoch 7/300: 100%|████████████████████████| 56/56 [00:00<00:00, 96.14batch/s, counter=0, lastLoss=0.171, valLoss=0.164]
Epoch 8/300: 100%|████████████████████████| 56/56 [00:00<00:00, 97.56batch/s, counter=0, lastLoss=0.162, valLoss=0.157]
Epoch 9/300: 100%|██████████████████████

KeyboardInterrupt: 

In [126]:
for i, method in enumerate(loaders['non-combined'].keys()):
    print(f"Running based on {method}")
    print(test(loaders['non-combined'][method]['test'], modelss[i], datasets['non-combined'][method][1]), "\n")

Running based on min-max


100%|█████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 315.79 batch/s]

(90.46535587310791, array([94.535805, 76.38266 , 99.71    , 99.56899 , 83.740036, 88.854645],
      dtype=float32), 0.07161155641078949) 

Running based on stats





IndexError: list index out of range

In [None]:
losses = {}

for model_type in models.keys():
    losses[model_type] = {}
    if model_type == 'LSTM':
        continue
    for comb in models[model_type].keys():
        losses[model_type][comb] = {}
        
        for ext_method in models[model_type][comb].keys():
            losses[model_type][comb][ext_method] = {}
            
            for model_method in models[model_type][comb][ext_method].keys():
                print(f"Running {model_type} based on {model_method}, with dataset using {comb} {ext_method}")
                m = SiameseRNN(input_size=lengths[comb][ext_method])
                train_loss, val_loss = train(
                    loaders[comb][ext_method]['train'],
                    m,
                    val_loader=loaders[comb][ext_method]['test'],
                    LR=0.001,
                    epochs=300, early_stopping=True)
                
                print(test(loaders[comb][ext_method]['test'], m, datasets[comb][ext_method][1]))
                
                

In [90]:
datasets['non-combined'].keys()

dict_keys(['min-max', 'stats', 'raw', 'entropy', 'energy'])

In [88]:
test(loaders[comb][ext_method]['test'], m, datasets[comb][ext_method][1])

100%|█████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 168.58 batch/s]


(91.24422645568848,
 array([94.51197 , 79.451096, 99.72503 , 99.62737 , 83.16258 , 90.987305],
       dtype=float32),
 0.06617546553413073)

In [67]:
losses = {}

for model_type in models.keys():
    losses[model_type] = {}
    if model_type == 'LSTM':
        continue
    for comb in models[model_type].keys():
        losses[model_type][comb] = {}
        
        for ext_method in models[model_type][comb].keys():
            losses[model_type][comb][ext_method] = {}
            
            for model_method in models[model_type][comb][ext_method].keys():
                print(f"Running {model_type} based on {model_method}, with dataset using {comb} {ext_method}")
                train_loss, val_loss = train(
                    loaders[comb][ext_method]['train'],
                    SiameseRNN(input_size=lengths[comb][ext_method]),
                    val_loader=loaders[comb][ext_method]['test'],
                    LR=0.001,
                    epochs=300, early_stopping=True)
                
                losses[model_type]['train'] = train_loss
                losses[model_type]['val'] = val_loss
                
                

Running RNN based on concatenate, with dataset using non-combined min-max
Using: cuda:0


Epoch 1/300: 100%|██████████████████████████████████████████████████████████████████| 56/56 [00:01<00:00, 54.34batch/s]
Epoch 2/300: 100%|████████████████████████| 56/56 [00:01<00:00, 55.83batch/s, counter=0, lastLoss=0.214, valLoss=0.191]
Epoch 3/300: 100%|████████████████████████| 56/56 [00:01<00:00, 55.92batch/s, counter=0, lastLoss=0.187, valLoss=0.183]
Epoch 4/300: 100%|████████████████████████| 56/56 [00:00<00:00, 76.71batch/s, counter=0, lastLoss=0.185, valLoss=0.183]
Epoch 5/300: 100%|████████████████████████| 56/56 [00:01<00:00, 55.56batch/s, counter=0, lastLoss=0.181, valLoss=0.173]
Epoch 6/300: 100%|████████████████████████| 56/56 [00:01<00:00, 55.34batch/s, counter=0, lastLoss=0.172, valLoss=0.163]
Epoch 7/300: 100%|█████████████████████████| 56/56 [00:01<00:00, 55.69batch/s, counter=0, lastLoss=0.161, valLoss=0.16]
Epoch 8/300: 100%|████████████████████████| 56/56 [00:00<00:00, 56.08batch/s, counter=0, lastLoss=0.153, valLoss=0.142]
Epoch 9/300: 100%|██████████████████████

KeyboardInterrupt: 

In [68]:
# In theory wont work
class SiameseRNN(nn.Module):
    def __init__(self, input_size=16, hidden_size=256):
        super(SiameseRNN, self).__init__()
        self.hidden_size = hidden_size
        # num_layers > 5 reduces performance
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=3, batch_first=True)
                
        self.fc = nn.Sequential(
            nn.Linear(hidden_size*2, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 6)
        
        )
    
    def forward(self, x):
        batch_size = x.shape[0]
        #x = x.reshape(batch_size, 1, -1)
        
        x1 = x[:, 0, :].reshape(batch_size, 1, -1)
        x2 = x[:, 1, :].reshape(batch_size, 1, -1)
        

        h0 = torch.zeros(5, batch_size, self.hidden_size).cpu()
        
        o1, h1 = self.rnn(x1)  # Add a batch dimension
        o2, h2 = self.rnn(x2)  # Add a batch dimension
        
        out = torch.cat([o1[:, -1, :], o2[:, -1, :]], dim=1)
        #out=h1[-1]
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

In [29]:
def train(
    train_loader,
    net,
    LR=0.1,
    epochs=2000,
    val_loader=None,
    early_stopping=False,
    patience=50,
    optimizer=optim.Adam
):
    net.to(DEVICE)
    optimizer = optimizer(net.parameters(), lr=LR)
    criterion = nn.L1Loss()
    val_losses = []        
    losses = []
    last_loss = 0
    best_val_loss = 1e10
    counter = 0

    print(f"Using: {DEVICE}")
    
    for epoch in range(epochs):
        loss = 0
        net.train()
        with tqdm(train_loader, unit="batch") as it:
            if epoch > 0:
                it.set_postfix(lastLoss=last_loss, valLoss=0 if len(val_losses) \
                     == 0 else val_losses[-1], counter=counter)
            for idx, data in enumerate(it):
                it.set_description(f"Epoch {epoch+1}/{epochs}")
                inp, out = data['input'].to(DEVICE), data['output'].to(DEVICE)
                
                optimizer.zero_grad()
                predicted = net(inp)

                cost = criterion(out, predicted)
                loss += cost.item()
                cost.backward()
                optimizer.step()
        
        loss /= len(it)
        losses.append(loss)
        last_loss = loss
        
        if val_loader:
            val_loss = 0
            net.eval()
            for idx, data in enumerate(val_loader):
                inp, out = data['input'].to(DEVICE), data['output'].to(DEVICE)

                predicted = net(inp)
                cost = criterion(out, predicted)
                val_loss += cost.item()
            val_loss /= len(val_loader)  
            val_losses.append(val_loss)

            if early_stopping:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    counter = 0
                else:
                    counter += 1
                if counter >= patience:
                    print(f"Early stopping after {epoch + 1} epochs")
                    break
         
    print(f"Average train loss: {np.sum(losses)/epochs}")
    print(f"Average validation loss: {np.sum(val_losses)/epochs}")
    
    return losses, val_losses