In [1]:
import mne
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from torch import Tensor
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, auc, roc_curve
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
import json
import time
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

mne.set_log_level('ERROR')
mne.cuda.init_cuda(verbose=True)

CUDA not enabled in config, skipping initialization


In [2]:
# Used to read our json file of each dictionary mapped to its file name
def read_dict_from_json_file(filepath):
    with open(filepath, 'r') as file:
        return json.load(file)

In [3]:
class CHBData(Dataset):
    def __init__(self, patient_number, alpha=0.1, ):
        self.patient_number = patient_number
        self.segment_dict = read_dict_from_json_file(f'CHB-MIT/Segment_dicts/segment_dict{self.patient_number}.json')
        # Mapping label names to integers
        self.label_to_int = {'interictal': 0, 'preictal': 1, 'ictal': 2}
        self.alpha = alpha  # Smoothing factor for EMA. Smaller alpha places greater importance on more recent elements
        # Initialize EMA and variance (for standard deviation calculation) for each channel
        #self.EMA = {}
        #self.variance = {}
        # Initial values could be adjusted based on dataset characteristics
        self.initialized_channels = False

    def __len__(self):
        return len(self.segment_dict)

    def __getitem__(self, index):
        segment_filepath = f'CHB-MIT/Segments/chb{self.patient_number}/{index}-chb{self.patient_number}.pt'
        segment = torch.load(segment_filepath)
        segment = segment.to(device)
        segment = self.pad_sequence(segment)

                # Ensure EMA and variance are initialized
        #if not self.initialized_channels:
        #    self.initialize_ema_variance(segment.shape[1])
        #    self.initialized_channels = True

        # Computes z-normalization based on Exponential moving average
        #self.ema_normalize(segment)

        # Retrieve the label and map it to an integer
        label_name = self.segment_dict[segment_filepath]
        label = self.label_to_int[label_name]
        
        segment = segment.float()  # Assuming you're using a device in the outer scope
        
        return segment, label

    def pad_sequence(self, segment):
        total_padding = 28 - segment.shape[1]
        pad_top = total_padding // 2
        pad_bottom = total_padding - pad_top
        padded_tensor = F.pad(segment, (0, 0, pad_top, pad_bottom), mode='constant', value=0)
        return padded_tensor

    def initialize_ema_variance(self, num_channels):
        for i in range(num_channels):
            self.EMA[i] = None
            self.variance[i] = None

    def ema_normalize(self, segment):
        # Update EMA and standard deviation for each channel and normalize
        for i in range(segment.shape[1]):  
            channel_data = segment[0, i, :] # select each channel
            if self.EMA[i] is None:  # First update
                self.EMA[i] = channel_data.mean()
                self.variance[i] = channel_data.var()
            else:
                #update EMA and variance based
                self.EMA[i] = self.alpha * channel_data.mean() + (1 - self.alpha) * self.EMA[i]
                self.variance[i] = self.alpha * ((channel_data - self.EMA[i])**2).mean() + (1 - self.alpha) * self.variance[i]

            std = torch.sqrt(self.variance[i])
            # Normalize this channel
            segment[0,i, :] = (channel_data - self.EMA[i]) / std

In [4]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 28x256
        #first layer temporal filters
        self.conv1=nn.Conv2d(1,8,(1,128),stride=1,padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, False)
        
        #spatial layer(depthwise layer)
        self.conv2_23=nn.Conv2d(8,32,(28,1))
        self.batchnorm2 = nn.BatchNorm2d(32, False)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(32,32,(1,16),stride=1)
        self.batchnorm3 = nn.BatchNorm2d(32,False) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))
        
     
        #Block 3
        
        self.fc1= nn.Linear(96, 30)
        self.fc2=nn.Linear(30,3)

        #apply dropout here in forward
        self.dropout = nn.Dropout(0.25)
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.elu(x)
        x = self.batchnorm1(x)
        x = self.dropout(x)    
        x = self.conv2_23(x)
        x = F.elu(x)
        x = self.batchnorm2(x)
        x = self.dropout(x)
        x = self.avgpool1(x)
        
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.dropout(x)
        x = self.avgpool2(x)
        x = x.view(x.size(0), -1)  
        x = self.fc1(x)
        x = F.elu(x)
        x = torch.sigmoid(self.fc2(x))
        return x

In [6]:
model = SeizureSense().to(device)

Training loop

In [7]:
weights = torch.tensor([1,5,5], dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5
num_classes = 3
num_patients = 16
losses = []


for patient_number in range(1,5):
    #ensures correct formatting for dataset
    if patient_number < 10:
        full_dataset = CHBData(f"0{patient_number}")
    else:
        full_dataset = CHBData(patient_number)

    # Split the dataset into train and test sets
    train_indices, val_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)



    # dataloaders for training, validation, and test
    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)

    # Instantiate dataloaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


    # Instantiate dataloaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
    start_time = time.time()  # Start timing
    for epoch in range(num_epochs):        
        model.train()
        for inputs, labels in train_loader:
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            
            loss = criterion(outputs, labels)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

        # Validation Phase
        model.eval()
        val_labels = []
        val_probabilities = []  # Store probabilities for AUROC calculation
        val_predictions = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                labels = labels.to(device)
                logits = model(inputs)
                outputs = F.softmax(logits, dim=1)

                # finds the class with the highest probability and returns it as a one hot encoded vector (i.e [1,0,0])
                _, predictions = torch.max(outputs, 1) 

                predictions = predictions.cpu().numpy()
                val_labels.extend(labels.cpu().numpy())  # Convert to numpy array
                val_predictions.extend(predictions)
                #val_probabilities.extend(outputs.cpu().numpy())  # Convert to numpy array

        # Calculate metrics
        accuracy = accuracy_score(val_labels, val_predictions)
        recall = recall_score(val_labels, val_predictions, average=None, zero_division=0)  
        # Calculate AUROC
        # For multi-class AUROC, adjust according to your class strategy
        #val_labels_binarized = label_binarize(val_labels, classes=np.unique(val_labels))
        #auroc = roc_auc_score(val_labels_binarized, np.array(val_probabilities), average='weighted',multi_class='ovr')  
        # Timing end and calculate duration
        end_time = time.time()
        epoch_duration = end_time - start_time
        print(recall)
        print(f'Epoch {epoch+1}/{num_epochs} - Loss: {np.mean(losses):.4f}, Accuracy: {accuracy:.4f}, Recall: {recall[1]:.4f}, Time: {epoch_duration:.2f}s')
        #AUROC: {auroc:.4f},
        # Reset losses for next epoch
        losses = []

KeyboardInterrupt: 

In [None]:
print(accuracy)

1.0


In [31]:
torch.save(model.state_dict(), 'SeizureSenseStateDict.pth')

In [None]:
optimal_idx = np.argmin(np.sqrt(np.square(1-tpr) + np.square(fpr)))
optimal_threshold = thresholds[optimal_idx]
print(f"Optimal threshold: {optimal_threshold}")

In [166]:

optimal_threshold=.3 #actual=.087
def decode_predictions(predictions):
    agregated_outputs=predictions.mean(dim=1)
    probabilities=F.softmax(agregated_outputs,dim=-1)
    print(probabilities)
    if probabilities[0][1]>optimal_threshold:
        return 'preictal'
    else:
        max_values, max_indices = torch.max(probabilities, dim=1)
        if max_indices==0:
            return 'interictal'
        else:
            return 'ictal'

In [176]:
input=torch.rand((1,1,28,256)).float()

model.eval()
with torch.no_grad():  # Disable gradient computation for inference
    predictions = model(input.to(device))
    print(predictions.shape)
    #print(predictions)
    
    print(decode_predictions(predictions))
    #predictions=decode_predictions(predictions)
    #print('Predictions:',predictions)


torch.Size([1, 77, 3])
tensor([[1., 0., 0.]], device='cuda:0')
interictal
