In [53]:
import torch.nn as nn
import torch

In [54]:
class Block(nn.Module):
    def __init__(self,inplace):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=inplace, out_channels=32, kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv1d(in_channels=inplace, out_channels=32, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(in_channels=inplace, out_channels=32, kernel_size=8, stride=2, padding=3)
        self.relu = nn.ReLU()

    def forward(self,x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x = torch.cat([x1,x2,x3], dim = 1)
        return x

In [55]:
class ChronoNet(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.block1 = Block(channel)
        self.block2 = Block(96)
        self.block3 = Block(96)

        self.gru1 = nn.GRU(input_size=96, hidden_size=32,batch_first=True)
        self.gru2 = nn.GRU(input_size=32, hidden_size=32,batch_first=True)
        self.gru3 = nn.GRU(input_size=64, hidden_size=32,batch_first=True)
        self.gru4 = nn.GRU(input_size=96, hidden_size=32,batch_first=True)
        
        self.gru_linear = nn.Linear(64,1)
        self.flattern = nn.Flatten()
        self.fcl = nn.Linear(32,1)
        self.relu = nn.ReLU()
    
    def forward(self,x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        x = x.permute(0,2,1)

        gru_out1,_ = self.gru1(x)
        gru_out2,_= self.gru2(gru_out1)
        gru_out = torch.cat([gru_out1, gru_out2], dim=2)

        gru_out3,_ = self.gru3(gru_out)
        gru_out = torch.cat([gru_out1,gru_out2, gru_out3], dim=2)
        linear_out = self.relu(self.gru_linear(gru_out.permute(0,2,1)))

        gru_out4,_ = self.gru4(linear_out.permute(0,2,1))
        
        x = self.flattern(gru_out4)
        x = self.fcl(x)

        return x

In [56]:
input = torch.randn(3,14,512)
input.shape
model = ChronoNet(14)
out = model(input)
out.shape

torch.Size([3, 1])

In [57]:
import os
from pathlib import Path
from glob import glob
import scipy.io
import mne

IDD = Path('EEG dataset/Data/CleanData/CLeanData_IDD/Rest')
TDC = Path('EEG dataset/Data/CleanData/CLeanData_TDC/Rest')

In [58]:
def convertMatToMNE(data):
    ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
    ch_types = ['eeg'] * 14
    sampling_freq = 128
    info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sampling_freq)
    info.set_montage('standard_1020')
    
    data = mne.io.RawArray(data,info)
    # data.set_egg_reference()
    data.filter(l_freq=1, h_freq=30)
    epochs = mne.make_fixed_length_epochs(data, duration=4, overlap=0)
    return epochs.get_data()


In [59]:
%%capture

pattern = os.path.join(IDD, '*.mat')
idd_files = glob(pattern)
idd_subject = []

for idd in idd_files:
    data = scipy.io.loadmat(idd)['clean_data']
    data = convertMatToMNE(data)
    idd_subject.append(data)

In [60]:
%%capture

pattern = os.path.join(TDC, '*.mat')
tdc_files = glob(pattern)
tdc_subject = []

for tdc in tdc_files:
    data = scipy.io.loadmat(tdc)['clean_data']
    data = convertMatToMNE(data)
    tdc_subject.append(data)


In [61]:
healthy_epoch_labels = [len(i) *[0] for i in tdc_subject]
patient_epoch_labels = [len(i) *[1] for i in idd_subject]    
len(healthy_epoch_labels), len(patient_epoch_labels)

(7, 7)

In [62]:
data_list = tdc_subject + idd_subject
labels_list = healthy_epoch_labels + patient_epoch_labels 
groups_list = [[i]*len(j) for i,j in enumerate(data_list)]
len(data_list), len(labels_list), len(groups_list)  

(14, 14, 14)

In [63]:
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler
from sklearn.base import TransformerMixin, BaseEstimator
import numpy as np

gkf = GroupKFold()


In [64]:
class StandardScaler3D(BaseEstimator, TransformerMixin):
    def __init__(self):
        self.scalar = StandardScaler()
    
    def fit(self,X,y=None):
        self.scalar.fit(X.reshape(-1, X.shape[2]))
        return self
    
    def transform(self,X):
        return self.scalar.transform(X.reshape(-1,X.shape[2])).reshape(X.shape)


In [65]:
import numpy as np

data_list = np.concatenate(data_list)
labels_list = np.concatenate(labels_list)
groups_list = np.concatenate(groups_list)
data_list = np.moveaxis(data_list,1,2)

print(data_list.shape, labels_list.shape, groups_list.shape)

(420, 512, 14) (420,) (420,)


In [66]:
accuracy = []

for train_index, val_index in gkf.split(data_list, labels_list, groups=groups_list):
    train_features, train_labels = data_list[train_index], labels_list[train_index]
    val_features, val_labels = data_list[val_index], labels_list[val_index]

    scaler = StandardScaler3D()

    train_features = scaler.fit_transform(train_features)
    train_features = scaler.fit_transform(val_features)
    train_features = np.moveaxis(train_features,1,2)
    val_features = np.moveaxis(val_features,1,2)

    break

In [67]:
train_features = torch.Tensor(train_features)
val_features = torch.Tensor(val_features)
train_labels = torch.Tensor(train_labels)
val_labels = torch.Tensor(val_labels)

len(train_features), len(val_features)

(90, 90)

In [68]:
from pytorch_lightning import LightningModule, Trainer
import torchmetrics
from torch.utils.data import TensorDataset, DataLoader

In [69]:
class ChronoModel(LightningModule):
    def __init__ (self):
        super(ChronoModel, self).__init__()
        self.model=ChronoNet(14)
        self.lr = 1e-3   #learning rate
        self.bs = 12     #batch size
        self.worker = 2   # no of worker
        self.acc = torchmetrics.Accuracy(task='binary')
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self,x):
        x = self.model(x)
        return x
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters, lr=self.lr)

    def train_dataloader(self):
        dataset = TensorDataset(train_features, train_labels)
        dataloader = DataLoader(dataset, batch_size = self.bs, num_workers =self.worker, shuffle = True)
        return dataloader
    
    def training_step(self, batch):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        acc = self.acc(out.flatten(), label.long().flatten())
        
        return { 'loss': loss, 'acc':acc }
    
    def trained_epoch_end(self, outputs):
        acc = torch.stack([x['acc'] for x in outputs]).mean().detach().cpu().numpy().round(2)
        loss = torch.stack([x['loss'] for x in outputs]).mean().detach().cpu().numpy().round(2)

        print('train acc loss ', acc, loss)

    def val_dataloader(self):
        dataset = TensorDataset(val_features, val_labels)
        dataloader = dataloader(dataset, batch_size = self.bs, num_workers =self.worker, shuffle = True)
        return dataloader
    
    def validation_step(self, batch,batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.creterion(out.flatten(), label.float().flatten())
        acc = self.acc(out.flatten(), label.long().flatten())
        
        return { 'loss': loss, 'acc':acc }
    
    def validated_epoch_end(self, outputs):
        acc = torch.stack([x['acc'] for x in outputs]).mean().detach().cpu().numpy().round(2)
        loss = torch.stack([x['loss'] for x in outputs]).mean().detach().cpu().numpy().round(2)

        print('val acc loss ', acc)

In [70]:
model = ChronoModel()

In [71]:
trainer = Trainer(max_epochs=1)

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [72]:
trainer.fit(model)

  if not hasattr(np, "object"):


TypeError: 'method' object is not iterable