In [1]:
import mne
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader,Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import json
from torch import tensor
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

mne.set_log_level('ERROR')

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
CHB_files = {
    'CHB-MIT/CHB_Database/chb01/chb01_01.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_02.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_03.edf': [(0, 2396, 'interictal'), (2396, 2996, 'preictal'), (2996, 3036, 'ictal'), (3036, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_04.edf': [(0, 867, 'interictal'), (867, 1467, 'preictal'), (1467, 1494, 'ictal'), (1494, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_05.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_06.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_07.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_08.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_09.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_10.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_11.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_12.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_13.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_14.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_15.edf': [(0, 1132, 'interictal'), (1132, 1732, 'preictal'), (1732, 1772, 'ictal'), (1772, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_16.edf': [(0, 415, 'interictal'), (415, 1015, 'preictal'), (1015, 1066, 'ictal'), (1066, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_17.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_18.edf': [(0, 1120, 'interictal'), (1120, 1720, 'preictal'), (1720, 1810, 'ictal'), (1810, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_19.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_20.edf': [(0, 2663, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_21.edf': [(0, 327, 'preictal'), (327, 420, 'ictal'), (420, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_22.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_23.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_24.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_25.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_26.edf': [(0, 1262, 'interictal'), (1262, 1862, 'preictal'), (1862, 1963, 'ictal'), (1963, 2325, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_27.edf': [(0, 600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_29.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_30.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_31.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_32.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_33.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_34.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_36.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_37.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_38.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_39.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_40.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_41.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_42.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_43.edf': [(0, 3600, 'interictal')],
    'CHB-MIT/CHB_Database/chb01/chb01_46.edf': [(0, 3600, 'interictal')]
    }

In [3]:
def write_dict_to_json_file(segment_dict, filepath):
    with open(filepath, 'w') as file:
        json.dump(segment_dict, file)
        
def read_dict_from_json_file(filepath):
    with open(filepath, 'r') as file:
        return json.load(file)

In [19]:
def segment_data(filepath, info, segment_length, segment_dict, dict_length):
    raw = mne.io.read_raw_edf(filepath, preload=True)  # Load the EDF file
    eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True)  # Select EEG channels
    freqs = (60, 120)  # Define notch filter frequencies to remove power line noise
    raw.notch_filter(freqs=freqs, picks=eeg_picks)  # Apply notch filter
    raw.set_meas_date(None)  # Reset measurement date

    for segment_info in info:
        start, end, label = segment_info
        current_position = start
        while current_position < end:
            tmin = current_position
            tmax = min(current_position + segment_length, end)  # Ensure tmax does not exceed the label's end
            segment_filename = f"{dict_length}-chb01.fif"
            segment_path = f"CHB-MIT/Segments/chb01/{segment_filename}"
            raw_segment = raw.copy().crop(tmin=tmin, tmax=tmax, include_tmax=False)
            segment_dict[segment_path] = label
            raw_segment.save(segment_path, overwrite=True)
            dict_length += 1  # Increment for a unique filename
            current_position += segment_length  # Move to the next segment start

    return dict_length

In [20]:
#chatgpt attempt nexct cell
segment_dict = {}
dict_length = 0  # Initialize dict_length to 0
segment_length = 1  # Define the length of each segment to be 1 second

# Adapted example usage to process each file in CHB_files with its associated segment info
for file_path, segment_info in CHB_files.items():
    # Now, segment_data processes all segments for a given file in one call
    # The function itself handles iteration over the segments defined in segment_info
    dict_length = segment_data(file_path, segment_info, segment_length, segment_dict, dict_length)

# After processing all files, save the segment dictionary to a JSON file
write_dict_to_json_file(segment_dict, 'CHB-MIT/segment_dict1.json')

In [39]:
class CHBData(Dataset):
    def __init__(self, segment_dict):
        self.segment_dict = segment_dict
        # Mapping label names to integers
        self.label_to_int = {'interictal': 0, 'preictal': 1, 'ictal': 2}

    def __len__(self):
        return len(self.segment_dict)

    def __getitem__(self, index,):
        
        segment = mne.io.read_raw_fif(f'CHB-MIT/Segments/chb01/{index}-chb01.fif', preload=True)
        segment.filter(l_freq=30, h_freq=None, fir_design='firwin', filter_length='auto', phase='zero', fir_window='hamming')
        numpy_array = segment.get_data()
        segment_tensor = torch.from_numpy(numpy_array).unsqueeze(0).unsqueeze(0)

        # Retrieve the label and map it to an integer
        label_name = self.segment_dict[f'CHB-MIT/Segments/chb01/{index}-chb01.fif']
        label = self.label_to_int[label_name]

        return segment_tensor, label


In [36]:
segment_dict = read_dict_from_json_file('CHB-MIT/segment_dict1.json')

In [41]:
full_dataset = CHBData(segment_dict)
#Nut big brain shit right here


torch.Size([1, 1, 23, 256])


In [42]:


# Split the dataset into train and test sets
train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.3, random_state=42)

#  split the test set into validation and actual test sets
val_indices, test_indices = train_test_split(test_indices, test_size=0.5, random_state=42)

# dataloaders for training, validation, and test
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

# Create dataloader type shit
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)


In [2]:
class SeizureSense(nn.Module):
    
    def __init__(self):
        super(SeizureSense,self).__init__()
        
        #Block 1
        #should be taking in an input of 23x512
        #first layer temporal shittt(pointwise is the technical term)
        self.conv1=nn.Conv2d(1,8,(1,128),stride=1,padding=0)
        self.batchnorm1 = nn.BatchNorm2d(8, True)
        
        #spatial layer(depthwise layer)
        self.conv2=nn.Conv2d(8,32,(23,1))
        self.batchnorm2 = nn.BatchNorm2d(32, True)
        self.avgpool1 = nn.AvgPool2d((1,2))
        #apply dropout here in forward
        
        
        #Block 2
        #sepereable convolutional 2d
        self.conv3=nn.Conv2d(32,32,(1,16),stride=1)
        self.batchnorm3 = nn.BatchNorm2d(32,True) #CHECK THIS!
        self.avgpool2 = nn.AvgPool2d((1,16))
        #apply dropout here in forward
        
        
        
        #lstm layers
        
        #Block 3
        #fc1 layer assumiing no bilstm layer rn can change later 
        #matthew check my math here please its 2am 
        
        self.fc1= nn.Linear(96, 30)
        self.fc2=nn.Linear(30,3)
        
        
        
    def forward(self,x):
        x=self.conv1(x)
        x =F.elu(x)
        x=self.batchnorm1(x)
        x = F.dropout(x, 0.25)
        
        x = self.conv2(x)
        x = F.elu(x)
        x=self.batchnorm2(x)
        x = F.dropout(x, 0.25)
        x=self.avgpool1(x)
        
        x=F.elu(self.conv3(x))
        x=self.batchnorm3(x)
        x = F.dropout(x, 0.25)
        x=self.avgpool2(x)
        
        #bilstm stufff
        
        #Fully connected time
        x=x.view(-1,96)
        x=F.elu(self.fc1(x))
        #fOR TRAINING NO NEED FOR SIGMOUD ON LAST THING BC CROSS ENTRTOPY LOSS SO ADD IT IN INFERENCE DONT FORGET
        x=self.fc2(x)
        return x

In [21]:
#maps back the labels to words
def decode_predictions(predictions):
    label_names = ['interictal', 'preictal', 'ictal']
    _, predicted_labels = torch.max(predictions, 1)
    return [label_names[label] for label in predicted_labels]

predicted_label_names = decode_predictions(output)
print(predicted_label_names)

['ictal']


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score
import matplotlib.pyplot as plt

# Assuming model, train_loader, and val_loader are already defined
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
num_classes = 3
losses = []

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    
    # Validation Phase
    model.eval()  
    val_labels = []
    val_probabilities = []  # Store probabilities for AUROC calculation
    with torch.no_grad():
        for inputs, labels in val_loader:
            logits = model(inputs)
            outputs = F.softmax(logits, dim=1)
            _, predicted = torch.max(outputs, 1)
            val_labels.extend(labels.cpu().numpy())  # Convert to numpy array
            val_probabilities.extend(outputs.cpu().numpy())  # Convert to numpy array
    
    # Calculate metrics
    accuracy = accuracy_score(val_labels, predicted.cpu().numpy())
    recall = recall_score(val_labels, predicted.cpu().numpy(), average='macro')
    
    
    # Calculate AUROC
    # For multi-class AUROC, we need to binarize the labels
    val_labels_binarized = label_binarize(val_labels, classes=range(num_classes))
    auroc = roc_auc_score(val_labels_binarized, val_probabilities, multi_class='ovr', average='macro')
    plot=plt.plot(losses)
    
    print(f'Epoch {epoch+1}/{num_epochs} - Loss: {np.mean(losses):.4f}, Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, AUROC: {auroc:.4f}')

    # Reset losses for next epoch
    losses = []

