# The Plan

**1. Dataset Acquisition & Understanding**

- **Download the CHBM-MIT EEG dataset:** Go to [https://physionet.org/content/chbmit/](https://physionet.org/content/chbmit/) and follow the instructions to download the dataset.
- **Familiarize yourself with the data:**
  - Open the downloaded files in a text editor or use data exploration tools in Python like pandas.
  - Identify the format of the data (e.g., EDF, CSV).
  - Understand the channels, sampling rate, and meaning of each column in the data.
  - Explore the provided annotation files and seizure labels.
- **Choose environment:**
  - **Macbook Air M1:** Install Python (version 3.7 or above) and the necessary libraries

**2. Preprocessing & Feature Extraction**

- **Import libraries:**

```python
import mne
import numpy as np
import matplotlib.pyplot as plt
```

- **Load data:** Use mne functions like `mne.io.read_raw_edf` to load the EEG data.

```python
raw = mne.io.read_raw_edf('chb01_01.edf', preload=False)  # Replace with your filename
```

- **Cleaning and Filtering:**
  - Apply basic filtering (e.g., notch filter to remove power line noise) using `raw.filter()`.
  - Perform visual inspection (e.g., plotting the data) to identify and remove artifacts like muscle movement or equipment noise.
  - Learn about advanced cleaning techniques like Independent Component Analysis (ICA) for advanced noise reduction.
- **Resampling:**
  - If needed, resample the data to a consistent sampling rate using `raw.resample()`
- **Segmentation:**
  - Use event markers or annotations to segment the data into relevant epochs (e.g., ictal and interictal periods) using `mne.Epochs`.
  - Consider different epoch lengths based on your chosen features and seizure type.
- **Feature Extraction:**
  - Implement functions to calculate desired time-domain features (e.g., mean, variance, amplitude) and frequency-domain features (e.g., power spectral density using FFT) using libraries like NumPy or scikit-learn.
  - Explore libraries like NeuroKit2 for specific EEG feature extraction functionalities.
  - Consider advanced features like connectivity metrics (coherence, phase lag) using MNE-Python for later source localization.

**3. Testing Different Models**
 - **Import libraries:**

```python
import torch
```

# Dataset

## Description of CHBM-MIT EEG Dataset
### Dataset: Scalp EEG Recordings from Children with Intractable Seizures

This dataset contains electroencephalography (EEG) recordings from **22 subjects** with intractable seizures. The data is grouped into **23 cases**, with some subjects contributing multiple recordings.

**Key Points:**

* **Subjects:** 22 (5 males, 17 females; ages 1.5-22)
* **Cases:** 23 (chb01 to chb23)
* **Sampling Frequency:** 256 Hz
* **Recordings per Subject:** 9-42 (each lasting 1-4 hours)
* **Seizures:** 198 total (182 in original set)
* **File Types:**
    * `.edf`: Raw EEG data files (664 total)
    * `.seizure`: Annotations for seizure start and end times (for files containing seizures)

**Additional Notes:**

* Case `chb21` is from the same subject as `chb01`, but recorded 1.5 years later.
* Case `chb24` is not included in the `SUBJECT-INFO` file.

In [1]:
import os
from dotenv import load_dotenv
load_dotenv()

LOCAL_PATH = os.getenv("LOCAL_PATH")

In [2]:
from glob import glob
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

### Load paths

In [3]:
#? Load the data from the records.txt file
records_txt = LOCAL_PATH + 'RECORDS'
with open(records_txt, 'r') as file:
    records_path = [LOCAL_PATH + line.strip() for line in file]
print(len(records_path), records_path)

#? Load the data from the records-with-seizures.txt file
records_seizure_txt = LOCAL_PATH + 'RECORDS-WITH-SEIZURES'
with open(records_seizure_txt, 'r') as file:
    records_seizure_path = [LOCAL_PATH + line.strip() for line in file]
print(len(records_seizure_path), records_seizure_path)

#? Get the records that are not in the records-with-seizures.txt file
records_seizure_set = set(records_seizure_path) #* Convert records_seizure_path to a set for faster lookup
records_normal_path = [record for record in records_path if record not in records_seizure_set]
print(len(records_normal_path), records_normal_path)

#? Load the chbxx-summary.txt file
summary_files = glob(LOCAL_PATH + '*/chb*-summary.txt')
summary_files.sort()
print(len(summary_files), summary_files)

686 ['/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_01.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_02.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_03.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_04.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_05.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_06.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_07.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_08.edf', '/Users/aaryaashokk/Documents/Coding/Projects/DataSets/chb-mit-scalp-eeg-database-1.0.0/chb01/chb01_09.edf', '/Users/aaryaa

### Visualizing the data

In [4]:
%%capture
%matplotlib qt
rawEEG = mne.io.read_raw_edf(records_path[0], preload=True)

In [5]:
%%capture
rawEEG.plot(block=False, duration=5, title="rawEEG")

In [6]:
rawEEG.info

0,1
Measurement date,"November 06, 2076 11:42:54 GMT"
Experimenter,Unknown
Participant,Surrogate

0,1
Digitized points,Not available
Good channels,23 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,256.00 Hz
Highpass,0.00 Hz
Lowpass,128.00 Hz


In [7]:
%%capture
filteredEEG = rawEEG.copy().filter(.05, 45)
filteredEEG.plot(block=False, duration=5, title="Filtered rawEEG")

### Add Seizure label to the columns based on the summary file

#### Extracting seizure times from summary file

In [8]:
import re

In [9]:
def get_seizure_data(file):
    seizure_data = []
    with open(file, 'r') as f:
            block = f.read().split('\n\n')
            for text in block:
                num_seizures = re.findall(r'Number of Seizures in File: (\d+)', text)
                if num_seizures:
                    for i in range(1, int(num_seizures[0])+1):
                        name = re.findall(r'File Name: (.*\.edf)', text)
                        start_time = re.findall(r'Seizure {} Start Time: (\d+) seconds'.format(i), text)
                        end_time = re.findall(r'Seizure {} End Time: (\d+) seconds'.format(i), text)
                        if name and start_time and end_time:
                            seizure_data.append([name[0], int(start_time[0]), int(end_time[0])])
    return pd.DataFrame(seizure_data, columns=['name', 'seizure_start', 'seizure_end'])


In [10]:
seizures = get_seizure_data(summary_files[0])
seizures

Unnamed: 0,name,seizure_start,seizure_end
0,chb01_03.edf,2996,3036
1,chb01_04.edf,1467,1494
2,chb01_15.edf,1732,1772
3,chb01_16.edf,1015,1066
4,chb01_18.edf,1720,1810
5,chb01_21.edf,327,420
6,chb01_26.edf,1862,1963


#### Updating the DF to include the seizure data

In [11]:
#Updating the DF to include the seizure data
def add_seizure(table, patient):
    # Extract the patient file and patient identifier
    patient_file = patient.split('/')[-1]
    patient_id = patient.split('/')[-1].split('_')[0]
    
    # Construct the path to the patient's summary file
    patient_summary_file = LOCAL_PATH + patient_id + "/" + patient_id + '-summary.txt'
    
    # Get seizure data
    seizures = get_seizure_data(patient_summary_file)
    
    # Initialize the 'seizure' column in the table
    table['seizure'] = 0
    
    # Update the 'seizure' column based on the seizure data
    for name, start, end in seizures[['name', 'seizure_start', 'seizure_end']].values:
        if name == patient_file:
            print(name, start, end)
            table.loc[(table['time'] >= start) & (table['time'] <= end), 'seizure'] = 1
    
    return table

### edf -> DataFrame

In [12]:
#? Convert the edf files to a numpy array
def read_edf(file):
    data = mne.io.read_raw_edf(file, preload=True) #* Load the data
    data.set_eeg_reference() #* Set the reference to the average of the channels
    data.filter(l_freq=0.5, h_freq=45) #* Filter the data to remove noise and artifacts and to get the frequency band of interest (0.5-45Hz)
    data = data.to_data_frame() #* Convert the data to a pandas dataframe
    data = add_seizure(data, file) #* Add the seizure data to the dataframe
    return data

In [13]:
%%capture
test = read_edf(records_path[0]) #* test = (n_epochs, n_channels, n_times)

In [14]:
test #? time is on seconds

Unnamed: 0,time,FP1-F7,F7-T7,T7-P7,P7-O1,FP1-F3,F3-C3,C3-P3,P3-O1,FP2-F4,...,T8-P8-0,P8-O2,FZ-CZ,CZ-PZ,P7-T7,T7-FT9,FT9-FT10,FT10-T8,T8-P8-1,seizure
0,0.000000,-1.101143e-14,6.776264e-15,-3.388132e-15,-3.388132e-15,4.319868e-14,1.524659e-14,-8.470329e-16,-5.505714e-15,3.896352e-14,...,-2.032879e-14,-2.710505e-14,4.235165e-15,-2.371692e-14,-1.185846e-14,-1.524659e-14,6.522154e-14,3.388132e-15,-2.032879e-14,0
1,0.003906,7.055020e+01,4.215925e+01,4.246527e-01,-6.569185e+00,8.777450e+01,4.693861e+01,-3.806049e+01,9.762790e+00,6.623274e+01,...,-6.092636e+01,-8.490027e+01,4.342311e+01,-8.819665e+01,-6.004795e+01,8.573601e+00,1.559610e+02,-9.431075e+01,-6.092636e+01,0
2,0.007812,1.109660e+02,6.660589e+01,-2.229159e-01,-1.041840e+01,1.381268e+02,7.336417e+01,-5.972773e+01,1.490357e+01,1.040902e+02,...,-9.535299e+01,-1.317778e+02,6.826041e+01,-1.384215e+02,-9.354706e+01,1.391102e+01,2.428639e+02,-1.488214e+02,-9.535299e+01,0
3,0.011719,1.150999e+02,6.902914e+01,-1.244642e+00,-1.142124e+01,1.431842e+02,7.532187e+01,-6.174902e+01,1.448634e+01,1.076785e+02,...,-9.821620e+01,-1.319506e+02,7.050686e+01,-1.431228e+02,-9.609835e+01,1.583888e+01,2.466743e+02,-1.550397e+02,-9.821620e+01,0
4,0.015625,1.014857e+02,5.967384e+01,-9.441963e-01,-1.134372e+01,1.257844e+02,6.585548e+01,-5.438283e+01,1.167978e+01,9.459313e+01,...,-8.635133e+01,-1.095636e+02,6.153675e+01,-1.260197e+02,-8.515616e+01,1.632351e+01,2.116679e+02,-1.369789e+02,-8.635133e+01,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
921595,3599.980469,-1.966181e+00,-2.142653e+01,7.536038e+00,2.584352e+01,-5.263448e+00,-4.254891e+01,9.008678e+00,4.826194e+01,-2.606238e+00,...,-8.741520e+00,-1.865593e+01,-1.471764e+01,-8.999999e+00,-1.485327e+01,-4.186132e+01,5.323803e+01,1.779193e+01,-8.741520e+00,0
921596,3599.984375,-2.079331e+01,-1.185444e+01,1.661293e+01,1.858470e+01,-2.505590e+01,-2.795592e+01,1.807515e+01,3.698395e+01,-1.727932e+00,...,-8.342460e+00,-2.070563e+01,-5.068659e+00,-2.648837e+00,-1.707417e+01,-3.202083e+01,5.419064e+01,1.606031e+01,-8.342460e+00,0
921597,3599.988281,-3.058849e+01,-3.581231e+00,1.957185e+01,9.808446e+00,-3.530229e+01,-9.906149e+00,1.945038e+01,2.060774e+01,3.630578e-01,...,-6.370221e+00,-1.477411e+01,1.737872e+00,3.991179e+00,-1.536983e+01,-1.520150e+01,3.855946e+01,9.585113e+00,-6.370221e+00,0
921598,3599.992188,-2.244143e+01,1.475079e-01,1.324309e+01,3.444270e+00,-2.574197e+01,-5.632689e-02,1.238904e+01,7.631659e+00,1.035744e+00,...,-3.364490e+00,-6.727399e+00,3.028862e+00,4.935477e+00,-9.170710e+00,-3.482139e+00,1.819841e+01,3.554529e+00,-3.364490e+00,0


In [15]:
test.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 921600 entries, 0 to 921599
Data columns (total 25 columns):
 #   Column    Non-Null Count   Dtype  
---  ------    --------------   -----  
 0   time      921600 non-null  float64
 1   FP1-F7    921600 non-null  float64
 2   F7-T7     921600 non-null  float64
 3   T7-P7     921600 non-null  float64
 4   P7-O1     921600 non-null  float64
 5   FP1-F3    921600 non-null  float64
 6   F3-C3     921600 non-null  float64
 7   C3-P3     921600 non-null  float64
 8   P3-O1     921600 non-null  float64
 9   FP2-F4    921600 non-null  float64
 10  F4-C4     921600 non-null  float64
 11  C4-P4     921600 non-null  float64
 12  P4-O2     921600 non-null  float64
 13  FP2-F8    921600 non-null  float64
 14  F8-T8     921600 non-null  float64
 15  T8-P8-0   921600 non-null  float64
 16  P8-O2     921600 non-null  float64
 17  FZ-CZ     921600 non-null  float64
 18  CZ-PZ     921600 non-null  float64
 19  P7-T7     921600 non-null  float64
 20  T7-F

# Model

Testing and training sets<br>
* Total Normal = 545<br>
* Total Seizure = 142<br>
So we can use around 75% of this data for training ang remaining 25% for testing<br>
We can use 75% of the data for training (normal 436, seizure 113) and 25% for testing (normal 109, seizure 28)

## Neural Networks

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
from sklearn.model_selection import train_test_split

In [17]:
seed = 77

### Feed Forward Neural Network<br>

In [18]:
#? Create the Feed Forward Neural Network Model
class FFNN_Model(nn.Module):
    def __init__(self, in_channels, h1, h2):
        super(FFNN_Model, self).__init__()
        self.fc1 = nn.Linear(in_channels, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.out = nn.Linear(h2, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)
        return x

##### Training on one patient file

In [39]:
%%capture
torch.manual_seed(seed) #* Set the seed for reproducibility

#? Create the model
data = read_edf(records_path[0])
model = FFNN_Model(data.shape[1]-2, 64, 32) 

In [40]:
data.shape

(921600, 25)

In [41]:
# Train test split
X = data.drop(columns=['time', 'seizure']).values
y = data['seizure'].values

#? Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

#? Convert the data to tensors
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

In [42]:
# Set the criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [43]:
# Train the model
epochs = 100
losses = []
for i in range(epochs):
    # Forward pass
    y_pred = model.forward(X_train)
    loss = criterion(y_pred, y_train)
    losses.append(loss.detach().numpy())
    if i % 10 == 0:
        print(f'Epoch: {i} Loss: {loss.item()}')
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (737280x23 and 25x64)

In [None]:
# Plot the training
plt.plot(range(epochs), losses)
plt.ylabel('Loss')
plt.xlabel('Epoch')

Text(0.5, 0, 'Epoch')

## Main

In [None]:
def main():
    FeedForward()
    # ConvNet()
    # RecNet()
    # LongShort()


In [None]:
if __name__:
    main()

NameError: name 'FeedForward' is not defined

Channels marked as bad:
none
Channels marked as bad:
none
