In [1]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
from torch import tensor
import numpy as np
from torch.utils.data import Dataset
from sklearn.metrics import mean_squared_error
import random
import os
import matplotlib.pyplot as plt
#from numba import jit
import pickle
from scipy.interpolate import interp1d
from torch.utils.data import DataLoader, random_split
import torch
from torchsummary import summary
import seaborn as sns
import sys
import torch.nn.functional as F
import pywt
from sklearn.preprocessing import MinMaxScaler
from torch.cuda import FloatTensor

# Req for package
sys.path.append("../")
from SkinLearning.Utils.NN import train, test, DEVICE


torch.backends.cudnn.benchmark = True

In [7]:
# Folder name will correspond to index of sample
class SkinDataset(Dataset):
    def __init__(self, scaler, signalFolder="D:/SamplingResults2", sampleFile="../Data/newSamples.pkl", runs=range(65535), steps=128):
        # Load both disp1 and disp2 from each folder
        # Folders ordered according to index of sample
        # Use the corresponding sample as y -> append probe?
        self.input = []
        self.output = []
        
        with open(f"{sampleFile}", "rb") as f:
             samples = pickle.load(f)
        
        self.min = np.min(samples[runs])
        self.max = np.max(samples[runs])
        
        
        for run in tqdm(runs):
            inp = []
            fail = False
            
            files = os.listdir(f"{signalFolder}/{run}/")
            
            if files != ['Disp1.csv', 'Disp2.csv']:
                continue
            
            for file in files:
                a = pd.read_csv(f"{signalFolder}/{run}/{file}")
                a.rename(columns = {'0':'x', '0.1': 'y'}, inplace = True)
                
                # Skip if unconverged
                if a['x'].max() != 7.0:
                    fail = True
                    break

                # Interpolate curve for consistent x values
                xNew = np.linspace(0, 7, num=steps, endpoint=False)
                interped = interp1d(a['x'], a['y'], kind='cubic', fill_value="extrapolate")(xNew)
                    
                
                inp.append(interped.astype("float32"))
            
            if not fail:
                if len(inp) != 2:
                    raise Exception("sdf")

                self.input.append(inp)
                self.output.append(samples[int(run)])
        
        scaler.fit(self.output)
        self.output = scaler.fit_transform(self.output)
        self.output = tensor(self.output).type(FloatTensor)
        
        self.input = [waveletExtraction(sample) for sample in self.input]
        self.input = tensor(self.input).type(FloatTensor)
        
        
    def __len__(self):
        return len(self.output)
    
    def __getitem__(self, idx):
        sample = {"input": self.input[idx], "output": self.output[idx]}
        return sample
    
    

In [8]:
"""
    Creates the data set from filtered samples
    Returns the dataset and the scaler
"""
def getDataset(**kwargs):
    # Get filtered data
    if not 'runs' in kwargs.keys():
        with open("../Data/filtered.pkl", "rb") as f:
            runs = pickle.load(f)

        kwargs['runs'] = runs

    scaler = MinMaxScaler()
    dataset = SkinDataset(scaler=scaler, **kwargs)

    return dataset, scaler

In [9]:
"""
    Creates a train/test split from the given data
    Returns train and test data loaders
"""
def getSplit(dataset, p1=0.8):
    train_n = int(p1 * len(dataset))
    test_n = len(dataset) - train_n
    train_set, test_set = random_split(dataset, [train_n, test_n])

    return DataLoader(train_set, batch_size=32, shuffle=True), \
        DataLoader(test_set, batch_size=32, shuffle=True)

In [10]:
def waveletExtraction(x, wavelet='db1', level=4):
# perform wavelet packet decomposition on signal 1
    wp = pywt.WaveletPacket(x[0], wavelet, mode='symmetric', maxlevel=level)
    coeffs1 = []
    for node in wp.get_level(level, 'natural'):
        if node.path.endswith('a') or node.path.endswith('d'):
            coeffs1.append(node.data)
    coeffs1 = np.concatenate(coeffs1)
    
    # perform wavelet packet decomposition on signal 2
    wp = pywt.WaveletPacket(x[1], wavelet, mode='symmetric', maxlevel=level)
    coeffs2 = []
    for node in wp.get_level(level, 'natural'):
        if node.path.endswith('a') or node.path.endswith('d'):
            coeffs2.append(node.data)
    coeffs2 = np.concatenate(coeffs2)
    
    # concatenate the two coefficient arrays
    feature_vector = np.concatenate((coeffs1, coeffs2))
    
    return feature_vector

In [11]:
dataset, scaler = getDataset()

100%|██████████████████████████████████████████████████████████████████████████████| 2241/2241 [00:35<00:00, 63.39it/s]
  self.input = tensor(self.input).type(FloatTensor)


In [12]:
train_loader, test_loader = getSplit(dataset)

In [193]:
class SiameseRNN(nn.Module):
    def __init__(self, input_size=256, hidden_size=512):
        super(SiameseRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=3, batch_first=True)
                
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 6)
        
        )
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        x = x.reshape(batch_size, 1, -1)
        
        #x1 = x[:, 0, :, :].reshape(batch_size, 1, -1)
        #x2 = x[:, 0, :, :].reshape(batch_size, 1, -1)
        
        h0 = torch.zeros(3, batch_size, self.hidden_size).to(DEVICE)
        _, h1 = self.rnn(x, h0)  # Add a batch dimension
        #_, h2 = self.rnn(x2, h0)  # Add a batch dimension
        
        out = h1[-1]
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

In [194]:
class SiameseLSTM(nn.Module):
    def __init__(self, input_size=256, hidden_size=1024, num_layers=3):
        super(SiameseLSTM, self).__init__()
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        # Define the LSTM layer
        self.lstm = nn.LSTM(input_size=input_size, 
                            hidden_size=hidden_size, 
                            num_layers=num_layers, 
                            batch_first=True)
        
        self.fc = nn.Sequential(
            nn.Linear(2048 , 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 6),
        )

     
    def forward(self, x):
        batch_size = x.shape[0]
        #x1 = x[:, 0, :].unsqueeze(1)
        #x2 = x[:, 0, :].unsqueeze(1)
        x = x.unsqueeze(1)
        
        
        h0 = torch.zeros(self.num_layers*1, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers*1, x.size(0), self.hidden_size).to(x.device)
        
        # Forward pass through the LSTM layer
        o, h = self.lstm(x, (h0, c0))
        #out2 = self.lstm(x2, (h0, c0))[1][-1]

        
        #out = torch.cat([out1, out2], dim=0)
        out = h[-1]
        out = out.reshape(batch_size, -1)
        # Pass the last hidden state through the fully connected layer
        out = self.fc(out)
        
        return out

In [195]:
sRNN = SiameseRNN()

In [196]:
sRNN_train_loss, sRNN_val_loss =  train(train_loader, sRNN, val_loader=test_loader, LR=0.001, epochs=400)

Using: cuda:0


Epoch 1/400: 100%|█████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 127.27batch/s]
Epoch 2/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 132.70batch/s, lastLoss=0.204, valLoss=0.178]
Epoch 3/400: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 124.86batch/s, lastLoss=0.18, valLoss=0.182]
Epoch 4/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 139.65batch/s, lastLoss=0.168, valLoss=0.171]
Epoch 5/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 136.25batch/s, lastLoss=0.164, valLoss=0.155]
Epoch 6/400: 100%|███████████████████████████████████| 56/56 [00:00<00:00, 135.26batch/s, lastLoss=0.16, valLoss=0.155]
Epoch 7/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 139.13batch/s, lastLoss=0.154, valLoss=0.156]
Epoch 8/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 134.13batch/s, lastLoss=0.151, valLoss=0.146]
Epoch 9/400: 100%|██████████████████████

KeyboardInterrupt: 

In [None]:
sLSTM = SiameseLSTM()

In [108]:
sLSTM_train_loss, sLSTM_val_loss =  train(train_loader, sLSTM, val_loader=test_loader, LR=0.001, epochs=400)

Using: cuda:0


Epoch 1/400: 100%|█████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 136.42batch/s]
Epoch 2/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 136.75batch/s, lastLoss=0.211, valLoss=0.178]
Epoch 3/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 134.61batch/s, lastLoss=0.178, valLoss=0.173]
Epoch 4/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 132.86batch/s, lastLoss=0.167, valLoss=0.165]
Epoch 5/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 140.00batch/s, lastLoss=0.161, valLoss=0.164]
Epoch 6/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 136.92batch/s, lastLoss=0.159, valLoss=0.163]
Epoch 7/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 144.14batch/s, lastLoss=0.155, valLoss=0.148]
Epoch 8/400: 100%|██████████████████████████████████| 56/56 [00:00<00:00, 132.86batch/s, lastLoss=0.151, valLoss=0.146]
Epoch 9/400: 100%|██████████████████████

Epoch 69/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 140.88batch/s, lastLoss=0.0841, valLoss=0.0918]
Epoch 70/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 135.75batch/s, lastLoss=0.0835, valLoss=0.0868]
Epoch 71/400: 100%|████████████████████████████████| 56/56 [00:00<00:00, 141.77batch/s, lastLoss=0.083, valLoss=0.0845]
Epoch 72/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 140.35batch/s, lastLoss=0.0835, valLoss=0.0861]
Epoch 73/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 142.67batch/s, lastLoss=0.0833, valLoss=0.0866]
Epoch 74/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 140.35batch/s, lastLoss=0.0849, valLoss=0.0844]
Epoch 75/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 136.92batch/s, lastLoss=0.0853, valLoss=0.0917]
Epoch 76/400: 100%|████████████████████████████████| 56/56 [00:00<00:00, 143.04batch/s, lastLoss=0.0855, valLoss=0.087]
Epoch 77/400: 100%|█████████████████████

Epoch 137/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 138.10batch/s, lastLoss=0.0812, valLoss=0.0823]
Epoch 138/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 148.93batch/s, lastLoss=0.0823, valLoss=0.0868]
Epoch 139/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 139.30batch/s, lastLoss=0.0837, valLoss=0.0875]
Epoch 140/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 144.51batch/s, lastLoss=0.0836, valLoss=0.0937]
Epoch 141/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 140.00batch/s, lastLoss=0.0821, valLoss=0.083]
Epoch 142/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 138.27batch/s, lastLoss=0.0804, valLoss=0.0874]
Epoch 143/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 137.08batch/s, lastLoss=0.0822, valLoss=0.088]
Epoch 144/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 143.40batch/s, lastLoss=0.0814, valLoss=0.0841]
Epoch 145/400: 100%|████████████████████

Epoch 205/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 171.25batch/s, lastLoss=0.0793, valLoss=0.0876]
Epoch 206/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 185.43batch/s, lastLoss=0.0779, valLoss=0.0856]
Epoch 207/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 217.89batch/s, lastLoss=0.0775, valLoss=0.0929]
Epoch 208/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 204.56batch/s, lastLoss=0.0792, valLoss=0.0877]
Epoch 209/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 209.03batch/s, lastLoss=0.0769, valLoss=0.0899]
Epoch 210/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 218.66batch/s, lastLoss=0.0806, valLoss=0.0946]
Epoch 211/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 204.74batch/s, lastLoss=0.0787, valLoss=0.0831]
Epoch 212/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 193.43batch/s, lastLoss=0.0786, valLoss=0.0923]
Epoch 213/400: 100%|████████████████████

Epoch 273/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 212.12batch/s, lastLoss=0.0764, valLoss=0.0941]
Epoch 274/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 208.56batch/s, lastLoss=0.0762, valLoss=0.0907]
Epoch 275/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 173.37batch/s, lastLoss=0.0766, valLoss=0.0835]
Epoch 276/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 178.34batch/s, lastLoss=0.0751, valLoss=0.0881]
Epoch 277/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 197.18batch/s, lastLoss=0.0764, valLoss=0.0862]
Epoch 278/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 197.53batch/s, lastLoss=0.0753, valLoss=0.0981]
Epoch 279/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 210.92batch/s, lastLoss=0.077, valLoss=0.0841]
Epoch 280/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 198.93batch/s, lastLoss=0.0762, valLoss=0.084]
Epoch 281/400: 100%|████████████████████

Epoch 341/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 217.39batch/s, lastLoss=0.0712, valLoss=0.0838]
Epoch 342/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 220.46batch/s, lastLoss=0.0685, valLoss=0.0776]
Epoch 343/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 216.79batch/s, lastLoss=0.0717, valLoss=0.0898]
Epoch 344/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 216.62batch/s, lastLoss=0.0771, valLoss=0.0782]
Epoch 345/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 219.17batch/s, lastLoss=0.0704, valLoss=0.0819]
Epoch 346/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 217.05batch/s, lastLoss=0.0715, valLoss=0.0818]
Epoch 347/400: 100%|██████████████████████████████| 56/56 [00:00<00:00, 220.61batch/s, lastLoss=0.0719, valLoss=0.0818]
Epoch 348/400: 100%|███████████████████████████████| 56/56 [00:00<00:00, 209.99batch/s, lastLoss=0.0702, valLoss=0.084]
Epoch 349/400: 100%|████████████████████

KeyboardInterrupt: 