<a href="https://colab.research.google.com/github/Rishi625/Seizure_detection_ML/blob/main/Seizure_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import numpy as np
import os
from pyedflib import highlevel
from scipy import signal as sgn
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report

In [None]:
# Root of the CHB-MIT Scalp EEG database (one sub-folder per patient).
PATH = "chb-mit-scalp-eeg-database-1.0.0/chb-mit-scalp-eeg-database-1.0.0"

# Keep only patient folders: the database root also contains plain files
# (e.g. RECORDS, SUBJECT-INFO), and save_npys() calls os.listdir() on every
# entry, which raises NotADirectoryError for a file. Sorted for a stable order.
all_chb_folders = sorted(
    entry for entry in os.listdir(PATH)
    if os.path.isdir(os.path.join(PATH, entry))
)

def get_signal(file_path, n_samples=100000):
    """Read an EDF recording and collapse it into one peak-normalised 1-D trace.

    Every channel is resampled with ``scipy.signal.resample`` to ``n_samples``
    points. The element-wise maximum across channels is then taken, and the
    result is divided by its own maximum.

    Args:
        file_path: Path to the ``.edf`` file.
        n_samples: Number of points each channel is resampled to
            (default 100,000, the value used throughout this notebook).

    Returns:
        1-D ``np.ndarray`` of length ``n_samples``. If the trace's maximum is 0
        (e.g. an all-zero recording), the trace is returned unscaled instead of
        being filled with NaN by a 0/0 division.

    Raises:
        ValueError: If the file contains no signal channels.
    """
    signals, _signal_headers, _header = highlevel.read_edf(file_path)
    if len(signals) == 0:
        raise ValueError(f"No signal channels found in {file_path}")

    # Resample channel by channel so this works whether read_edf hands back a
    # 2-D array or a sequence of per-channel arrays.
    resampled = np.array([sgn.resample(channel, n_samples) for channel in signals])

    # Element-wise maximum over channels -> a single trace per recording.
    sample = np.max(resampled, axis=0)

    peak = sample.max()
    if peak == 0:
        return sample
    return sample / peak

# Integer column labels, one per resampled time point (0 .. 99,999).
signal_cols = list(range(100000))

def save_npys(folder_name, output_dir="./converted_artifacts_1lac"):
    """Convert every EDF recording in one patient folder into a single ``.npy`` file.

    Each row of the saved float array is ``get_signal(<edf file>)`` followed by
    a label column. The label is 1.0 when a matching ``<name>.edf.seizures``
    file exists in the same folder, and 0.0 otherwise.

    Args:
        folder_name: Name of the patient folder inside ``PATH``.
        output_dir: Directory that ``<folder_name>.npy`` is written to. It is
            created if it does not exist yet.

    Returns:
        Path of the written ``.npy`` file.
    """
    folder_path = os.path.join(PATH, folder_name)
    # Sorted so rows are written in a reproducible order.
    all_files = sorted(os.listdir(folder_path))

    edf_files = [f for f in all_files if f.endswith('.edf')]
    # "chb01_03.edf.seizures" -> "chb01_03.edf"; a set gives O(1) lookups.
    seizure_files = {os.path.splitext(f)[0] for f in all_files if f.endswith('.seizures')}

    # Collect rows in plain lists and stack once. Growing a DataFrame with
    # df.loc[len(df)] copies on every append and may leave object-dtype
    # columns, which np.load() refuses to read without allow_pickle=True.
    signals = []
    labels = []
    for f in tqdm(edf_files, leave=False):
        signals.append(get_signal(os.path.join(folder_path, f)))
        labels.append(1 if f in seizure_files else 0)

    if signals:
        data = np.column_stack([np.vstack(signals), labels]).astype(np.float64)
    else:
        # Keep the same column count (signals + label) even for an empty folder.
        data = np.empty((0, len(signal_cols) + 1))

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"{folder_name}.npy")
    np.save(out_path, data)
    return out_path

# Convert each patient folder into its own .npy file.
for chb_folder in tqdm(all_chb_folders):
    save_npys(chb_folder)

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """In-memory dataset built from the ``.npy`` files found in one directory.

    Each file must hold a 2-D array whose last column is the label and whose
    remaining columns are signal samples. All files are concatenated row-wise.
    """

    def __init__(self, data_path, signal_length=None):
        """Load and concatenate every ``.npy`` file in ``data_path``.

        Args:
            data_path: Directory containing the converted ``.npy`` files.
            signal_length: Expected number of signal columns per row. When
                ``None`` (default), it is taken from the first file loaded
                instead of being hard-coded.

        Raises:
            ValueError: If a file is not a 2-D array, or its signal width
                disagrees with ``signal_length`` or with the other files.
        """
        super().__init__()
        self.data_path = data_path
        # Only numpy arrays, sorted so the row order is reproducible.
        self.all_files = sorted(f for f in os.listdir(data_path) if f.endswith('.npy'))

        signal_chunks = []
        label_chunks = []
        for file_name in tqdm(self.all_files, leave=False):
            array = np.load(os.path.join(data_path, file_name))
            if array.ndim != 2:
                raise ValueError(f"{file_name}: expected a 2-D array, got shape {array.shape}")

            width = array.shape[1] - 1  # last column is the label
            if signal_length is None:
                signal_length = width
            elif width != signal_length:
                raise ValueError(
                    f"{file_name}: expected {signal_length} signal columns, found {width}"
                )
            signal_chunks.append(array[:, :-1])
            label_chunks.append(array[:, -1])

        if signal_length is None:
            signal_length = 0

        # Concatenate once instead of np.vstack/np.append per file, which
        # re-copies everything loaded so far on every iteration.
        if signal_chunks:
            self.data = np.concatenate(signal_chunks, axis=0).astype(np.float64, copy=False)
            self.labels = np.concatenate(label_chunks).astype(np.float64, copy=False)
        else:
            self.data = np.zeros((0, signal_length))
            self.labels = np.zeros(0)

    def __len__(self):
        """Return the total number of rows (recordings) loaded."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx`` as ``{'signal': float32 array, 'label': int64 scalar}``."""
        return {
            'signal': self.data[idx].astype(np.float32),
            'label': self.labels[idx].astype(np.int64)
        }

### LSTM Model

In [None]:
# Model: stacked LSTM over the input sequence, classified from the last time step.
class LSTMModel(torch.nn.Module):
    """LSTM classifier mapping each input sequence to ``num_classes`` logits.

    Layers: LSTM -> last time step -> ReLU -> Dropout -> Linear -> Linear.
    Every constructor argument defaults to the previously hard-coded value, so
    ``LSTMModel()`` builds the same network with the same parameter names.
    """

    def __init__(self, input_size=10000, hidden_size=100, num_layers=2,
                 num_classes=2, dropout=0.2, fc_size=32):
        super().__init__()
        self.lstm1 = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                                   num_layers=num_layers, batch_first=True)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)
        self.fc1 = torch.nn.Linear(hidden_size, fc_size)
        self.out = torch.nn.Linear(fc_size, num_classes)
        # Used only by predict_proba(); forward() returns logits.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        ``x`` is (batch, seq_len, input_size) because the LSTM is batch_first.
        No softmax is applied here: ``torch.nn.CrossEntropyLoss`` applies
        log-softmax itself, and softmaxing twice flattens the gradients.
        ``argmax`` over logits gives the same prediction as over probabilities.
        """
        # Omitting (h_0, c_0) starts the LSTM from zero states, which is what
        # the explicit torch.zeros tensors did before.
        lstm_out, _ = self.lstm1(x)
        last_step = lstm_out[:, -1, :]
        hidden = self.dropout(self.relu(last_step))
        return self.out(self.fc1(hidden))

    def predict_proba(self, x):
        """Return class probabilities (softmax over the logits from ``forward``)."""
        return self.softmax(self.forward(x))

In [None]:
# NOTE(review): the conversion cell writes 100,000-point rows to
# ./converted_artifacts_1lac, while this directory and the 10,000-feature LSTM
# refer to an earlier conversion. Keep the path, resample length and
# input_size in sync.
DATA_DIR = "/content/drive/MyDrive/converted_artifacts"
SEED = 42
TRAIN_FRACTION = 0.73  # reproduces the previous 500 / 186 split of 686 rows

dataset = CustomDataset(DATA_DIR)

# Derive split sizes from the data rather than hard-coding [500, 186], which
# raises as soon as the number of recordings changes. Seed both the split
# and the shuffle so runs are reproducible.
n_train = int(TRAIN_FRACTION * len(dataset))
n_test = len(dataset) - n_train
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [n_train, n_test], generator=torch.Generator().manual_seed(SEED)
)

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                              generator=torch.Generator().manual_seed(SEED))
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

  0%|          | 0/24 [00:00<?, ?it/s]

### LSTM Model Training


In [None]:
torch.manual_seed(42)  # reproducible weight init and dropout masks

model = LSTMModel()

epochs = 10

# TODO(review): the saved evaluation output shows the seizure class is rarely
# predicted; consider CrossEntropyLoss(weight=...) to counter class imbalance.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

model.train()
for epoch in tqdm(range(epochs), leave=True):
    running_loss = 0.0
    for batch in tqdm(train_dataloader, leave=False):
        # The DataLoader already yields tensors: add the seq_len=1 axis
        # directly instead of re-wrapping with torch.tensor(), which copies
        # and raises a UserWarning. The feature width is no longer hard-coded.
        x = batch['signal'].unsqueeze(1)
        y = batch['label']
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - Loss: {running_loss / len(train_dataloader)}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

  x = torch.tensor(x.reshape(-1, 1, 10000))


Loss: 0.5612158608436585


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.5189746122360229


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.5177609615325927


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.5174465768337251


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.5171509516239167


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.5168721299171447


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.5028171620368957


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.48405555796623245


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.4464812395572661


  0%|          | 0/125 [00:00<?, ?it/s]

Loss: 0.44747619056701643


### Model Evaluation

In [None]:
model.eval()

preds = []
labels = []
# Inference only: no_grad() skips building the autograd graph for each batch.
with torch.no_grad():
    for batch in test_dataloader:
        # Tensors already; add the seq_len=1 axis without re-wrapping.
        x = batch['signal'].unsqueeze(1)
        y_hat = model(x)
        preds.extend(y_hat.argmax(dim=1).tolist())
        labels.extend(batch['label'].tolist())

# Guard against an empty test split instead of dividing by zero.
accuracy = sum(p == l for p, l in zip(preds, labels)) / len(preds) if preds else float('nan')

print(accuracy)
print(classification_report(labels, preds))

  x = torch.tensor(x.reshape(-1, 1, 10000))


0.7688172043010753
              precision    recall  f1-score   support

           0       0.79      0.96      0.87       147
           1       0.25      0.05      0.09        39

    accuracy                           0.77       186
   macro avg       0.52      0.51      0.48       186
weighted avg       0.68      0.77      0.70       186

