In [1]:
from src.datasets import patientDataset, eegDataset
from src.resnet import ResNet1d
from src.lstm import Lstm
from tqdm import tqdm
import numpy as np
import mne
from helper_code import *

In [2]:
from torch.utils.data import DataLoader
import torch

no_of_segs = 72
pattrainset = patientDataset("../train/", no_of_segs)
# trainset = eegDataset(pattrainset, [i for i in range(len(pattrainset))])
trainloader = DataLoader(dataset=pattrainset, batch_size=1, shuffle=True)

pattestset = patientDataset("../split_5/", no_of_segs)
# testset = eegDataset(pattestset, [i for i in range(len(pattestset))])
testloader = DataLoader(dataset=pattestset, batch_size=1, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def validate(model, val_loader, criterion):
    val_loss = []
    model.eval()
    for i, (inputs, labels) in enumerate(val_loader):
        inputs = torch.Tensor(inputs).view(len(inputs), 1, -1)
        inputs = inputs.type(torch.FloatTensor).to(device)
        labels = torch.Tensor(labels)
        labels = labels.type(torch.FloatTensor).to(device)
        
        outputs = model(inputs).flatten()
        loss    = criterion(outputs, labels)
        val_loss.append(loss.item())
    
    return np.mean(val_loss)

In [10]:
def train_one_epoch(model, optimizer, criterion ,train_loader, epoch):
    model.train()
    model.zero_grad()
    train_loss = []
    for i, (inputs, labels) in enumerate(train_loader): 
            
        inputs = torch.Tensor(inputs).view(len(inputs), 1, -1)
        inputs = inputs.type(torch.FloatTensor).to(device)
        labels = torch.Tensor(labels)
        labels = labels.type(torch.FloatTensor).to(device)
        
        outputs = model(inputs).flatten()
        print(outputs.shape, labels.shape)
        loss    = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss.append(loss.cpu().detach().numpy())
        print("Epoch {}: {}/{} Loss: {}".format(epoch, i, len(train_loader), loss.item()), end='\r')
    return np.mean(train_loss)

In [11]:
# train_config = { 'num_epochs':20, 'learning_rate':1e-4 }
# arch_config  = {
#                 'n_input_channels':18,
#                 'signal_length':600,
#                 'net_filter_size':[ 18, 36],
#                 'net_signal_length':[ 200, 40],
#                 'kernel_size': 3,
#                 'n_classes':2,
#                 'dropout_rate':0.3
#                }

# model = ResNet1d(input_dim=(arch_config['n_input_channels'], arch_config['signal_length']), 
#                  blocks_dim=list(zip(arch_config['net_filter_size'], arch_config['net_signal_length'])),
#                  kernel_size=arch_config['kernel_size'],
#                  dropout_rate=arch_config['dropout_rate'])
# model = model.to(device)

# criterion = torch.nn.CrossEntropyLoss(reduction='sum')
# optimizer = torch.optim.AdamW(model.parameters(), lr=train_config['learning_rate'], weight_decay=1e-7)

In [12]:
import torch.nn as nn
class Lstm(nn.Module):
    """lstm+mlp"""
    def __init__(self, inp_dim, hidden_dim, target_size=2):
        super(Lstm, self).__init__()
        self.lstm = nn.LSTM(inp_dim, hidden_dim, dropout=0.4, num_layers=2)
        self.mlp  = nn.Linear(hidden_dim, 2)
        self.sigm = nn.Softmax(-1)
    def forward(self, x):
        _, hidden = self.lstm( x )
        x = self.mlp(hidden[1])
        x = self.sigm(x)
        return x

In [13]:
train_config = { 'num_epochs':20, 'learning_rate':5e-4 }
arch_config  = {
                'inp_size': 1530,
                'hidden_size': 500,
               }

model = Lstm(arch_config["inp_size"], arch_config["hidden_size"])
model = model.to(device)
model = torch.compile(model)

criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.AdamW(model.parameters(), lr=train_config['learning_rate'], weight_decay=1e-7)

In [14]:
train_losses = []
val_losses = []

In [15]:
for epoch in range(train_config['num_epochs']):
    train_losses.append(train_one_epoch(model, optimizer, criterion, pattrainset, epoch))
    val_losses.append(validate(model, pattestset, criterion))
    print(train_losses[-1], val_losses[-1])

torch.Size([4]) torch.Size([2])


RuntimeError: size mismatch (got input: [4], target: [2])

In [None]:
import matplotlib.pyplot as plt
plt.plot(train_losses)
plt.plot(val_losses)

In [None]:
device