In [None]:
%%capture
!pip install mne
!pip install pytorch-lightning

In [None]:
%%capture
# One-time dataset download + extraction (URL appears to be a Mendeley Data zip cache).
# Left commented out so re-runs do not re-download; uncomment on a fresh runtime.
# NOTE(review): the data paths used later are under /content/Data — presumably where
# this archive unpacks; confirm after extracting.
#!wget https://md-datasets-cache-zipfiles-prod.s3.eu-west-1.amazonaws.com/fshy54ypyh-1.zip -O data.zip
#!unzip data.zip


In [None]:
from glob import glob
import scipy.io
import torch.nn as nn
import torch
import numpy as np
import mne

In [None]:
# Scratch shape check on a random tensor of shape (3, 22, 15000).
# Named `raw_demo` so it no longer shadows Python's built-in `input()`.
raw_demo=torch.randn(3,22,15000)
raw_demo.shape

torch.Size([3, 22, 15000])

In [None]:
class Block(nn.Module):
    """Multi-scale 1-D convolution block used by ChronoNet.

    Three parallel Conv1d branches (kernel sizes 2, 4 and 8, all stride 2) read the
    same input; their ReLU outputs are concatenated along the channel axis. The
    paddings are chosen so every branch produces ``L // 2`` time steps for an input
    of length ``L >= 2``, which makes the concatenation well-defined.

    Args:
        inplace: number of input channels (``in_channels`` of every branch).
        out_channels: channels produced by *each* branch (default 32), so the block
            outputs ``3 * out_channels`` channels (96 by default).

    Input:  (batch, inplace, L)
    Output: (batch, 3 * out_channels, L // 2)
    """

    def __init__(self, inplace, out_channels=32):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=inplace, out_channels=out_channels, kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv1d(in_channels=inplace, out_channels=out_channels, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(in_channels=inplace, out_channels=out_channels, kernel_size=8, stride=2, padding=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x))
        x3 = self.relu(self.conv3(x))
        # Fix: previously concatenated [x1, x3, x3], so the kernel-4 branch (conv2)
        # was computed but discarded and the kernel-8 branch was duplicated.
        return torch.cat([x1, x2, x3], dim=1)

In [None]:
class ChronoNet(nn.Module):
  """ChronoNet-style EEG classifier: three multi-scale conv blocks followed by
  densely connected GRUs.

  Args:
    channel: number of input EEG channels.
    seq_len: time samples per input window (default 512). Each of the three
      ``Block``s halves the time axis (stride 2), and ``gru_linear`` then collapses
      that reduced axis, so its input width must be ``seq_len // 8`` -- 64 for the
      default, i.e. the previously hard-coded ``nn.Linear(64, 1)``.

  Raises:
    ValueError: if ``seq_len // 8 < 1``.

  Input:  (batch, channel, seq_len)
  Output: (batch, 1) raw logits -- no sigmoid is applied here.
  """
  def __init__(self,channel,seq_len=512):
    super().__init__()
    reduced_len=seq_len//8
    if reduced_len<1:
      raise ValueError(f'seq_len must be at least 8, got {seq_len}')
    self.block1=Block(channel)
    self.block2=Block(96)  # 96 = 3 branches x 32 channels from the previous Block
    self.block3=Block(96)
    # Densely connected GRUs: later GRUs consume the concatenated outputs of earlier ones.
    self.gru1=nn.GRU(input_size=96,hidden_size=32,batch_first=True)
    self.gru2=nn.GRU(input_size=32,hidden_size=32,batch_first=True)
    self.gru3=nn.GRU(input_size=64,hidden_size=32,batch_first=True)
    self.gru4=nn.GRU(input_size=96,hidden_size=32,batch_first=True)
    # Collapses the reduced time axis to a single step (applied with time as the last dim).
    self.gru_linear=nn.Linear(reduced_len,1)
    self.flatten=nn.Flatten()
    self.fc1=nn.Linear(32,1)
    self.relu=nn.ReLU()

  def forward(self,x):
    # (batch, channel, T) -> (batch, 96, T // 8) after three stride-2 blocks
    x=self.block1(x)
    x=self.block2(x)
    x=self.block3(x)
    # GRUs are batch_first: (batch, time, features)
    x=x.permute(0,2,1)
    gru_out1,_=self.gru1(x)                                # (batch, T', 32)
    gru_out2,_=self.gru2(gru_out1)                         # (batch, T', 32)
    gru_out=torch.cat([gru_out1,gru_out2],dim=2)           # (batch, T', 64)
    gru_out3,_=self.gru3(gru_out)                          # (batch, T', 32)
    gru_out=torch.cat([gru_out1,gru_out2,gru_out3],dim=2)  # (batch, T', 96)
    # Put time last so gru_linear maps T' -> 1: (batch, 96, 1)
    linear_out=self.relu(self.gru_linear(gru_out.permute(0,2,1)))
    # Back to (batch, 1, 96); gru4 runs over that single step -> (batch, 1, 32)
    gru_out4,_=self.gru4(linear_out.permute(0,2,1))
    x=self.flatten(gru_out4)                               # (batch, 32)
    return self.fc1(x)                                     # (batch, 1)

In [None]:
# Smoke test: a random batch of 3 windows x 14 channels x 512 samples should produce
# exactly one logit per window. Uses its own names so it neither shadows the builtin
# `input()` nor reuses `model`, which later holds the Lightning wrapper.
dummy_batch=torch.randn(3,14,512)
chrono_net=ChronoNet(14)
with torch.no_grad():
    dummy_out=chrono_net(dummy_batch)
assert dummy_out.shape==(3,1)
dummy_out.shape

torch.Size([3, 1])

In [None]:
IDD_data_path='/content/Data/CleanData/CleanData_TDC/Rest'
TDC_data_path='/content/Data/Data/CleanData/Data/Data/CleanData/CleanData_IDD/Rest'
!rm '/content/Data/Data/CleanData/Data/Data/CleanData/CleanData_IDD/Rest/NDS001_Rest_CD(1).mat'

In [None]:
def convertmat2mne(data, sfreq=128, l_freq=1, h_freq=30, duration=4, overlap=0):
  """Preprocess one 14-channel EEG recording and cut it into fixed-length epochs.

  Steps: wrap the array in an MNE ``RawArray`` with the ``standard_1020`` montage,
  re-reference (``set_eeg_reference()`` with MNE's defaults), band-pass filter,
  then split into fixed-length epochs.

  Args:
    data: array of shape (14, n_times). NOTE(review): assumes rows follow the
      ``ch_names`` order below -- confirm against the .mat source.
    sfreq: sampling frequency in Hz (default 128).
    l_freq: high-pass edge of the band-pass filter in Hz (default 1).
    h_freq: low-pass edge of the band-pass filter in Hz (default 30).
    duration: epoch length in seconds (default 4, i.e. 512 samples at 128 Hz).
    overlap: overlap between consecutive epochs in seconds (default 0).

  Returns:
    ndarray of shape (n_epochs, 14, samples_per_epoch).
  """
  ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
  ch_types = ['eeg'] * len(ch_names)
  info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sfreq)
  info.set_montage('standard_1020')
  raw = mne.io.RawArray(data, info)
  # Both calls below modify `raw` in place.
  raw.set_eeg_reference()
  raw.filter(l_freq=l_freq, h_freq=h_freq)
  epochs = mne.make_fixed_length_epochs(raw, duration=duration, overlap=overlap)
  return epochs.get_data()

In [None]:
%%capture
idd_subject=[]
for idd in glob(IDD_data_path+'/*.mat'):
  data=scipy.io.loadmat(idd)['clean_data']
  data=convertmat2mne(data)
  idd_subject.append(data)

In [None]:
%%capture
tdc_subject=[]
for tdc in glob(TDC_data_path+'/*.mat'):
  data=scipy.io.loadmat(tdc)['clean_data']
  data=convertmat2mne(data)
  tdc_subject.append(data)
  

In [None]:
tuple(len(group) for group in (idd_subject, tdc_subject))

(7, 7)

In [None]:
# One label per epoch, grouped per subject: 0 for every TDC (control) epoch and
# 1 for every IDD (patient) epoch.
control_epochs_labels=[[0]*len(subject_epochs) for subject_epochs in tdc_subject]
patients_epochs_labels=[[1]*len(subject_epochs) for subject_epochs in idd_subject]
print(len(control_epochs_labels),len(patients_epochs_labels))

7 7


In [None]:
# Controls first, then patients. Each subject's group id is its position in data_list,
# repeated once per epoch, so cross-validation can keep a subject's epochs together.
data_list = tdc_subject + idd_subject
label_list = control_epochs_labels + patients_epochs_labels
groups_list = [[subject_id] * len(subject_epochs)
               for subject_id, subject_epochs in enumerate(data_list)]
print(len(data_list), len(label_list), len(groups_list))


14 14 14


In [None]:
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler
# Explicit n_splits: GroupKFold's default changed (3 -> 5) across scikit-learn versions.
gkf=GroupKFold(n_splits=5)
from sklearn.base import TransformerMixin,BaseEstimator
from sklearn.preprocessing import StandardScaler
#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
class StandardScaler3D(BaseEstimator,TransformerMixin):
    """Per-feature standardisation for arrays whose LAST axis holds the features.

    Intended layout is (batch, sequence, channels): each channel gets its own mean/std,
    pooled over all batch items and time steps. The array is flattened to
    ``(-1, n_features)`` for a regular ``StandardScaler`` and reshaped back, so any
    array with ndim >= 2 is accepted.
    """
    def __init__(self):
        self.scaler = StandardScaler()

    @staticmethod
    def _flatten(X):
        # Collapse all leading axes: one row per observation of the last-axis features.
        return X.reshape(-1, X.shape[-1])

    def fit(self,X,y=None):
        """Learn per-feature mean and std from X; ``y`` is ignored (sklearn API)."""
        self.scaler.fit(self._flatten(X))
        return self

    def transform(self,X):
        """Standardise X with the statistics from ``fit``; output has X's shape."""
        return self.scaler.transform(self._flatten(X)).reshape(X.shape)

In [None]:
import numpy as np
# Stack every subject's epochs, then swap the last two axes so the array is
# (n_epochs, samples, channels) -- the (batch, sequence, channels) layout the
# 3-D scaler works on. Labels and group ids become flat per-epoch vectors.
data_array = np.concatenate(data_list).transpose(0, 2, 1)
label_array = np.concatenate(label_list)
group_array = np.concatenate(groups_list)

print(data_array.shape, label_array.shape, group_array.shape)

(420, 512, 14) (420,) (420,)


In [None]:
accuracy=[]  # NOTE(review): never appended to; only a single fold is evaluated below.
# Evaluate one GroupKFold split: all epochs of a subject land on the same side, so the
# validation score reflects unseen subjects. Change FOLD to choose another split.
FOLD=0
train_index,val_index=list(gkf.split(data_array,label_array,groups=group_array))[FOLD]
# Guard against subject leakage between the two partitions.
assert not set(group_array[train_index]) & set(group_array[val_index])
train_features,train_labels=data_array[train_index],label_array[train_index]
val_features,val_labels=data_array[val_index],label_array[val_index]
# Scaling statistics come from the training subjects only.
scaler=StandardScaler3D()
train_features=scaler.fit_transform(train_features)
val_features=scaler.transform(val_features)
# Back to (n_epochs, channels, samples), the layout ChronoNet's Conv1d layers consume.
train_features=np.moveaxis(train_features,1,2)
val_features=np.moveaxis(val_features,1,2)

In [None]:
# Convert the NumPy splits to tensors of the default floating dtype
# (labels included), all as fresh copies.
_default_dtype = torch.get_default_dtype()
train_features, val_features, train_labels, val_labels = [
    torch.tensor(array, dtype=_default_dtype)
    for array in (train_features, val_features, train_labels, val_labels)
]

In [None]:
tuple(map(len, (val_features, val_labels)))

(90, 90)

In [None]:
train_features.size()

torch.Size([330, 14, 512])

In [None]:
from pytorch_lightning import LightningModule,Trainer
import torchmetrics
from torch.utils.data import TensorDataset,DataLoader

In [None]:
class ChronoModel(LightningModule):
  """Lightning wrapper that trains ChronoNet as a binary classifier (BCE-with-logits loss).

  The dataloaders read the notebook globals ``train_features``/``train_labels`` and
  ``val_features``/``val_labels`` when Lightning requests them, so those must exist
  before ``trainer.fit`` is called.

  Args:
    lr: Adam learning rate (default 1e-3).
    bs: batch size for both dataloaders (default 12).
    worker: DataLoader worker processes (default 2).
    n_channels: EEG channels passed to ``ChronoNet`` (default 14).
  """
  def __init__(self,lr=1e-3,bs=12,worker=2,n_channels=14):
    super().__init__()
    self.model=ChronoNet(n_channels)
    self.lr=lr
    self.bs=bs
    self.worker=worker
    self.creterion=nn.BCEWithLogitsLoss()
    # Step outputs are buffered here and summarised in the on_*_epoch_end hooks.
    # Previously the train summary lived in `trained_epoch_end`, a misspelling Lightning
    # never calls, and the `*_epoch_end(outputs)` hooks no longer exist in Lightning 2.x.
    self._train_outputs=[]
    self._val_outputs=[]

  def forward(self,x):
    """Return ChronoNet logits, shape (batch, 1)."""
    return self.model(x)

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(),lr=self.lr)

  def train_dataloader(self):
    dataset=TensorDataset(train_features,train_labels)
    return DataLoader(dataset,batch_size=self.bs,num_workers=self.worker,shuffle=True)

  def val_dataloader(self):
    dataset=TensorDataset(val_features,val_labels)
    # No shuffling for evaluation: metrics are order-independent and a fixed order
    # keeps per-batch results reproducible.
    return DataLoader(dataset,batch_size=self.bs,num_workers=self.worker,shuffle=False)

  @staticmethod
  def _binary_accuracy(logits,labels):
    """Fraction of samples where (logit > 0), i.e. sigmoid(logit) > 0.5, matches the 0/1 label.

    Computed by hand so it works on raw logits and does not depend on the
    torchmetrics version.
    """
    return ((logits>0).long()==labels.long()).float().mean()

  def _shared_step(self,batch):
    """Return (loss, accuracy) for one (signal, label) batch."""
    signal,label=batch
    out=self(signal.float()).flatten()
    label=label.flatten()
    loss=self.creterion(out,label.float())
    acc=self._binary_accuracy(out,label)
    return loss,acc

  def training_step(self,batch,batch_idx):
    loss,acc=self._shared_step(batch)
    self._train_outputs.append({'loss':loss.detach(),'acc':acc})
    return {'loss':loss,'acc':acc}

  def validation_step(self,batch,batch_idx):
    loss,acc=self._shared_step(batch)
    self._val_outputs.append({'loss':loss.detach(),'acc':acc})
    return {'loss':loss,'acc':acc}

  @staticmethod
  def _summarize(outputs):
    """Mean accuracy and loss over buffered step outputs, rounded to 2 decimals."""
    acc=torch.stack([x['acc'] for x in outputs]).mean().detach().cpu().numpy().round(2)
    loss=torch.stack([x['loss'] for x in outputs]).mean().detach().cpu().numpy().round(2)
    return acc,loss

  def on_train_epoch_end(self):
    if self._train_outputs:
      acc,loss=self._summarize(self._train_outputs)
      print('train acc loss',acc,loss)
    self._train_outputs.clear()

  def on_validation_epoch_end(self):
    if self._val_outputs:
      acc,loss=self._summarize(self._val_outputs)
      print('val acc loss',acc,loss)
    self._val_outputs.clear()
    











In [None]:
# Seed PyTorch's global RNG so weight initialisation is reproducible; this also seeds
# the default generator that DataLoader shuffling draws from.
torch.manual_seed(42)
model=ChronoModel()


In [None]:
MAX_EPOCHS = 1  # number of passes over the training fold
trainer = Trainer(max_epochs=MAX_EPOCHS)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Dataloaders are supplied by the LightningModule's own train/val_dataloader methods.
trainer.fit(model=model)


  | Name      | Type              | Params
------------------------------------------------
0 | model     | ChronoNet         | 133 K 
1 | acc       | Accuracy          | 0     
2 | creterion | BCEWithLogitsLoss | 0     
------------------------------------------------
133 K     Trainable params
0         Non-trainable params
133 K     Total params
0.534     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


val acc loss 0.25 0.7


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

val acc loss 0.32 0.66
