In [1]:
import mne
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, confusion_matrix



In [2]:
mne.set_log_level('ERROR')

In [3]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [4]:
print(device)

cuda


In [5]:
#creating dataset
class CHBData(Dataset):
    def __init__(self, CHB_files, segment_length):
        self.segments = [] # input for model containing segment of length "l" and torch tensor containing the eeg values for that segment
        self.labels = [] # output of model containing  the labels of each segment, either interictal, preictal, or ictal
        for file_path, info in CHB_files.items():
            processed_data=self.preprocessing(file_path) # preprocess every file in dictionary 
            for start, end, label in info:
                segmented_eeg = self.segment_eeg(processed_data,start, end, label,segment_length) # segments that file
                for segment,label in segmented_eeg:
                    self.segments.append(segment) # adds to final list for model 
                    self.labels.append(label)

    def __len__(self):
        return len(self.segments)
    
    def __getitem__(self, i):
        segment, label = self.segments[i], self.labels[i]
        segment = segment.unsqueeze(0).unsqueeze(0)
        label_encoded = self.label_to_int(label)
        return segment,label_encoded

    def label_to_int(self,label):
        label_mapping = {'interictal': 0, 'preictal': 1, 'ictal': 2}
        return label_mapping[label]
    
    
    def segment_eeg(self,segment_tensor,start, end, label,segment_length):
        start=start*256 #have to mulyiply time by sampling rate BUDDDDYYYY 
        end=end*256
        segments=[] # list to store tuple of each segments pytorch tensor and each label inside the data 
        for i in range(start,end,segment_length):
            segment_end=min(i+segment_length,end)
            segment=segment_tensor[:,i:segment_end] # isolates segment of wtv lenght it is from each torch tensor in preprocessing our data 
            segments.append((segment,label))
            
        return segments

    
    def preprocessing(self,file_path):
        # loading data: 
        raw = mne.io.read_raw_edf(file_path)
        raw.load_data()
        # processing every raw object to remove 60 hz and its multiples:
        eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)
        freqs = (60,120)
        raw_notch = raw.notch_filter(freqs=freqs, picks=eeg_picks)
        # applying a high pass filter of order 4 with a cutoff frequency of 30 Hz to the data to enhance gamma signal to noise ratio:
        raw_notch.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
        numpy_array=raw_notch.get_data()
        segment_tensor=torch.from_numpy(numpy_array)
        return segment_tensor

In [6]:


CHB_files = {
    'CHB-MIT/chb01_01.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_02.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_03.edf': [(0, 2396, 'interictal'), (2396, 2996, 'preictal'), (2996, 3036, 'ictal'), (3036, 3600, 'interictal')],
    'CHB-MIT/chb01_04.edf': [(0, 867, 'interictal'), (867, 1467, 'preictal'), (1467, 1494, 'ictal'), (1494, 3600, 'interictal')],
    'CHB-MIT/chb01_05.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_06.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_07.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_08.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_09.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_10.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_11.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_12.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_13.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_14.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_15.edf': [(0, 1132, 'interictal'), (1132, 1732, 'preictal'), (1732, 1772, 'ictal'), (1772, 3600, 'interictal')],
    'CHB-MIT/chb01_16.edf': [(0, 415, 'interictal'), (415, 1015, 'preictal'), (1015, 1066, 'ictal'), (1066, 3600, 'interictal')],
    'CHB-MIT/chb01_17.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_18.edf': [(0, 1120, 'interictal'), (1120, 1720, 'preictal'), (1720, 1810, 'ictal'), (1810, 3600, 'interictal')],
    'CHB-MIT/chb01_19.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_20.edf': [(0, 2663, 'interictal')],
    'CHB-MIT/chb01_21.edf': [(0, 327, 'preictal'), (327, 420, 'ictal'), (420, 3600, 'interictal')],
    'CHB-MIT/chb01_22.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_23.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_24.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_25.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_26.edf': [(0, 1262, 'interictal'), (1262, 1862, 'preictal'), (1862, 1963, 'ictal'), (1963, 2325, 'interictal')],
    'CHB-MIT/chb01_27.edf': [(0, 600, 'interictal')],
    'CHB-MIT/chb01_29.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_30.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_31.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_32.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_33.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_34.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_36.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_37.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_38.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_39.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_40.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_41.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_42.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_43.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/chb01_46.edf': [(0, 3600, 'interictal')]
    }



In [7]:


# Create the full dataset
full_dataset = CHBData(CHB_files, segment_length=512)

# Split the dataset into train and test sets





: 

In [None]:
train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)

#  split the test set into validation and actual test sets
val_indices, test_indices = train_test_split(test_indices, test_size=0.5, random_state=42)

# dataloaders for training, validation, and test
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

# Create dataloader type shit
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)

In [6]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 23x512
        #first layer temporal shittt(pointwise is the technical term)
        self.conv1=nn.Conv2d(1,8,(1,128),stride=1,padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, False)
        
        #spatial layer(depthwise layer)
        self.conv2=nn.Conv2d(8,32,(23,1))
        self.batchnorm2 = nn.BatchNorm2d(32, False)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(32,32,(1,16),stride=1)
        self.batchnorm3 = nn.BatchNorm2d(32,False) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))
        #apply dropout here in forward
        
        
        
        #lstm layers
        
        #Block 3
        #fc1 layer assumiing no bilstm layer rn can change later 
        #matthew check my math here please its 2am 
        
        self.fc1= nn.Linear(352, 30)
        self.fc2=nn.Linear(30,3)
        
        
        
    def forward(self,x):
        x=self.conv1(x)
        x =F.elu(x)
        x=self.batchnorm1(x)
        x = F.dropout(x, 0.25)
        
        x = self.conv2(x)
        x = F.elu(x)
        x=self.batchnorm2(x)
        x = F.dropout(x, 0.25)
        x=self.avgpool1(x)
        
        x=F.elu(self.conv3(x))
        x=self.batchnorm3(x)
        x = F.dropout(x, 0.25)
        x=self.avgpool2(x)
        
        #bilstm stufff
        
        #Fully connected time
        x=x.view(-1,352)
        x=F.elu(self.fc1(x))
        x=F.sigmoid(self.fc2(x))
        return x

In [12]:
model=SeizureSense()
model.to(device)
output=model(full_dataset.__getitem__(1)[0].float())

print(output)

tensor([[0.4635, 0.5390, 0.5229]], grad_fn=<SigmoidBackward0>)


In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [13]:
#maps back the labels to words
def decode_predictions(predictions):
    label_names = ['interictal', 'preictal', 'ictal']
    _, predicted_labels = torch.max(predictions, 1)
    return [label_names[label] for label in predicted_labels]

predicted_label_names = decode_predictions(output)
print(predicted_label_names)

['preictal']


In [None]:
#training loop time
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters())
num_epochs = 10
for epoch in num_epochs:
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    
    # Validation Phase
    model.eval()  
    val_labels = []
    val_predictions = []
    with torch.no_grad():  # Disable gradient calculation still dont rlly know why this maters but wtv
        for inputs, labels in val_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            val_labels.extend(labels.tolist())
            val_predictions.extend(predicted.tolist())
    
    accuracy = accuracy_score(val_labels, val_predictions)
    recall = recall_score(val_labels, val_predictions, average='macro')
    specificity = recall_score(val_labels, val_predictions, average='macro', pos_label=0)
    