# The Plan

**1. Dataset Acquisition & Understanding**

- **Download the CHB-MIT EEG dataset:** Go to [https://physionet.org/content/chbmit/](https://physionet.org/content/chbmit/) and follow the instructions to download the dataset.
- **Familiarize yourself with the data:**
  - Open the plain-text files (e.g., the summary files) in a text editor or with data exploration tools in Python like pandas; the EEG recordings themselves are binary and need a reader such as MNE.
  - Identify the format of the data (e.g., EDF, CSV).
  - Understand the channels, sampling rate, and meaning of each column in the data.
  - Explore the provided annotation files and seizure labels.
- **Choose environment:**
  - **MacBook Air M1:** Install Python (version 3.7 or above) and the necessary libraries (e.g., MNE, NumPy, Matplotlib, PyTorch).

**2. Preprocessing & Feature Extraction**

- **Import libraries:**

```python
import mne
import numpy as np
import matplotlib.pyplot as plt
```

- **Load data:** Use mne functions like `mne.io.read_raw_edf` to load the EEG data.

```python
raw = mne.io.read_raw_edf('chb01_01.edf', preload=False)  # Replace with your filename
```

- **Cleaning and Filtering:**
  - Apply basic filtering: a band-pass filter with `raw.filter()` and a notch filter with `raw.notch_filter()` to remove power line noise.
  - Perform visual inspection (e.g., plotting the data) to identify and remove artifacts like muscle movement or equipment noise.
  - Learn about Independent Component Analysis (ICA) for more advanced noise reduction.
- **Resampling:**
  - If needed, resample the data to a consistent sampling rate using `raw.resample()`.
- **Segmentation:**
  - Use event markers or annotations to segment the data into relevant epochs (e.g., ictal and interictal periods) using `mne.Epochs`.
  - Consider different epoch lengths based on your chosen features and seizure type.
- **Feature Extraction:**
  - Implement functions to calculate the desired time-domain features (e.g., mean, variance, amplitude) and frequency-domain features (e.g., power spectral density using FFT) using libraries like NumPy or SciPy.
  - Explore libraries like NeuroKit2 for specific EEG feature extraction functionalities.
  - Consider advanced features like connectivity metrics (coherence, phase lag) using MNE-Python for later source localization.
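
The time-domain and frequency-domain features listed above could be sketched roughly as follows. This is a minimal illustration with NumPy only, not the notebook's implementation: the function names, the band edges in `BANDS`, and the plain periodogram estimate are all assumptions chosen for brevity.

```python
import numpy as np

SFREQ = 256  # Hz, the sampling rate of the dataset

# Conventional EEG band edges in Hz (an assumption; adjust as needed)
BANDS = {"delta": (0.5, 4), "theta": (4, 8), "alpha": (8, 13), "beta": (13, 30)}

def time_features(epoch):
    """Per-channel mean, variance and peak-to-peak amplitude.

    epoch: array of shape (n_channels, n_samples).
    """
    return {
        "mean": epoch.mean(axis=1),
        "variance": epoch.var(axis=1),
        "peak_to_peak": np.ptp(epoch, axis=1),
    }

def band_powers(epoch, sfreq=SFREQ, bands=BANDS):
    """Mean periodogram power per channel in each frequency band."""
    n_samples = epoch.shape[1]
    freqs = np.fft.rfftfreq(n_samples, d=1.0 / sfreq)
    # Simple periodogram estimate of the power spectral density
    psd = np.abs(np.fft.rfft(epoch, axis=1)) ** 2 / (sfreq * n_samples)
    return {
        name: psd[:, (freqs >= lo) & (freqs < hi)].mean(axis=1)
        for name, (lo, hi) in bands.items()
    }

# Usage: a synthetic 2-second, single-channel 10 Hz sine
t = np.arange(2 * SFREQ) / SFREQ
epoch = np.sin(2 * np.pi * 10 * t)[np.newaxis, :]
powers = band_powers(epoch)
strongest = max(powers, key=lambda name: powers[name][0])
print(strongest)
```

A windowed estimate such as Welch's method is usually less noisy than a raw periodogram; the periodogram is used here only to keep the sketch short.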

**3. Testing Different Models**
- **Import libraries:**

```python
import torch
```

# Dataset

## Description of CHB-MIT EEG Dataset
### Dataset: Scalp EEG Recordings from Children with Intractable Seizures

This dataset contains electroencephalography (EEG) recordings from **22 subjects** with intractable seizures. The data is grouped into **23 cases**, with some subjects contributing multiple recordings.

**Key Points:**

* **Subjects:** 22 (5 males, 17 females; ages 1.5-22)
* **Cases:** 23 (chb01 to chb23)
* **Sampling Frequency:** 256 Hz
* **Recordings per Subject:** 9-42 (each lasting 1-4 hours)
* **Seizures:** 198 total (182 in original set)
* **File Types:**
    * `.edf`: Raw EEG data files (664 total)
    * `.seizures`: Annotations for seizure start and end times (for files containing seizures)

**Additional Notes:**

* Case `chb21` is from the same subject as `chb01`, but recorded 1.5 years later.
* Case `chb24` is not included in the `SUBJECT-INFO` file.
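
To get a feel for the data volume implied by the figures above, the sample counts can be worked out directly. This is only back-of-the-envelope arithmetic; the helper names are hypothetical and the channel count of 23 is an assumed montage size used for illustration.

```python
SFREQ = 256  # Hz, from the key points above

def samples_in_recording(hours, sfreq=SFREQ):
    """Number of samples per channel in a recording of the given length."""
    return int(hours * 3600 * sfreq)

def float32_megabytes(n_samples, n_channels):
    """Approximate in-memory size if every sample is stored as a 4-byte float."""
    return n_samples * n_channels * 4 / 1e6

one_hour = samples_in_recording(1)    # samples per channel in a 1-hour file
four_hours = samples_in_recording(4)  # samples per channel in a 4-hour file
# n_channels=23 is an assumed montage size for illustration
print(one_hour, four_hours, float32_megabytes(one_hour, 23))
```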

In [None]:
import os
from dotenv import load_dotenv
load_dotenv()

LOCAL_PATH = os.getenv("LOCAL_PATH")

In [None]:
from glob import glob
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

### Load paths

In [None]:
#? Load the record paths listed in the RECORDS file
records_txt = LOCAL_PATH + 'RECORDS'
with open(records_txt, 'r') as file:
    records_path = [LOCAL_PATH + line.strip() for line in file if line.strip()]
print(len(records_path), records_path)

#? Load the record paths listed in the RECORDS-WITH-SEIZURES file
records_seizure_txt = LOCAL_PATH + 'RECORDS-WITH-SEIZURES'
with open(records_seizure_txt, 'r') as file:
    #* Skip blank lines instead of popping the last element afterwards
    records_seizure_path = [LOCAL_PATH + line.strip() for line in file if line.strip()]
print(len(records_seizure_path), records_seizure_path)

#? Get the records that are not in the records-with-seizures.txt file
records_seizure_set = set(records_seizure_path) #* Convert records_seizure_path to a set for faster lookup
records_normal_path = [record for record in records_path if record not in records_seizure_set]
print(len(records_normal_path), records_normal_path)

#? Load the chbxx-summary.txt file
summary_files = glob(LOCAL_PATH + '*/chb*-summary.txt')
summary_files.sort()
print(len(summary_files), summary_files)
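
Once the record lists are loaded, it can help to see how the seizure recordings are distributed over the cases. The following is a small sketch under the assumption that each entry has the form `case/file.edf`; the helper names and the example entries are hypothetical and do not come from the real `RECORDS` file.

```python
from collections import Counter
from pathlib import PurePosixPath

def case_of(record):
    """Return the case directory of an entry such as 'chb01/chb01_03.edf'."""
    return PurePosixPath(record).parent.name

def seizure_counts_by_case(all_records, seizure_records):
    """Map each case to (recordings with seizures, recordings without)."""
    totals = Counter(case_of(r) for r in all_records)
    with_seizures = Counter(case_of(r) for r in seizure_records)
    return {case: (with_seizures[case], totals[case] - with_seizures[case])
            for case in sorted(totals)}

# Hypothetical entries in the RECORDS layout, not the real file contents
all_records = ["chb01/chb01_01.edf", "chb01/chb01_03.edf", "chb02/chb02_01.edf"]
seizure_records = ["chb01/chb01_03.edf"]
counts = seizure_counts_by_case(all_records, seizure_records)
print(counts)
```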

### Visualizing the data

In [None]:
%%capture
%matplotlib qt
rawEEG = mne.io.read_raw_edf(records_path[0], preload=True)

In [None]:
%%capture
rawEEG.plot(block=False, duration=5, title="rawEEG")

In [None]:
rawEEG.info

In [None]:
%%capture
filteredEEG = rawEEG.copy().filter(l_freq=0.5, h_freq=45) #* Same 0.5-45 Hz band as used in read_edf
filteredEEG.plot(block=False, duration=5, title="Filtered rawEEG")

### Add Seizure label to the columns based on the summary file

#### Extracting seizure times from summary file

In [None]:
import re

In [None]:
def get_seizure_data(file):
    seizure_data = []
    with open(file, 'r') as f:
        blocks = f.read().split('\n\n')
    for text in blocks:
        name = re.findall(r'File Name: (.*\.edf)', text)
        num_seizures = re.findall(r'Number of Seizures in File: (\d+)', text)
        if not (name and num_seizures):
            continue
        #* The seizure index is optional so that both "Seizure Start Time" and
        #* "Seizure 1 Start Time" are matched
        starts = re.findall(r'Seizure\s+(?:\d+\s+)?Start Time:\s*(\d+)\s*seconds', text)
        ends = re.findall(r'Seizure\s+(?:\d+\s+)?End Time:\s*(\d+)\s*seconds', text)
        for start, end in zip(starts, ends):
            seizure_data.append([name[0], int(start), int(end)])
    return pd.DataFrame(seizure_data, columns=['name', 'seizure_start', 'seizure_end'])


In [None]:
seizures = get_seizure_data(summary_files[0])
seizures
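
To make the parsing step easier to follow, here is a standalone sketch that runs on an in-memory string instead of a real summary file. The sample text only imitates the layout the regular expressions expect: the file names and times are made up, and the variant without a seizure index is an assumption about how some summary files are written. `parse_summary` is a hypothetical helper, not part of the notebook.

```python
import re

# Illustrative text in the assumed summary layout; the values are made up.
SAMPLE_SUMMARY = """File Name: example_01.edf
Number of Seizures in File: 0

File Name: example_02.edf
Number of Seizures in File: 1
Seizure Start Time: 100 seconds
Seizure End Time: 140 seconds

File Name: example_03.edf
Number of Seizures in File: 2
Seizure 1 Start Time: 30 seconds
Seizure 1 End Time: 50 seconds
Seizure 2 Start Time: 900 seconds
Seizure 2 End Time: 960 seconds
"""

# The seizure index is optional so both 'Seizure Start Time' and
# 'Seizure 1 Start Time' match.
START = re.compile(r"Seizure\s+(?:\d+\s+)?Start Time:\s*(\d+)\s*seconds")
END = re.compile(r"Seizure\s+(?:\d+\s+)?End Time:\s*(\d+)\s*seconds")
NAME = re.compile(r"File Name:\s*(\S+\.edf)")

def parse_summary(text):
    """Return (file name, start second, end second) for every seizure."""
    rows = []
    for block in text.split("\n\n"):
        name = NAME.search(block)
        if not name:
            continue
        starts = START.findall(block)
        ends = END.findall(block)
        rows.extend((name.group(1), int(s), int(e)) for s, e in zip(starts, ends))
    return rows

rows = parse_summary(SAMPLE_SUMMARY)
print(rows)
```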

#### Updating the DF to include the seizure data

In [None]:
#Updating the DF to include the seizure data
def add_seizure(table, patient):
    # Extract the file name (e.g. 'chb01_03.edf') and the case identifier (e.g. 'chb01')
    patient_file = os.path.basename(patient)
    patient_id = patient_file.split('_')[0]

    # Special case for chb17: the files are named chb17a_*, chb17b_* and chb17c_*,
    # but the summary file is looked up as chb17/chb17-summary.txt
    if patient_id in ('chb17a', 'chb17b', 'chb17c'):
        patient_id = 'chb17'

    # Construct the path to the patient's summary file
    patient_summary_file = LOCAL_PATH + patient_id + "/" + patient_id + '-summary.txt'
    
    # Get seizure data
    seizures = get_seizure_data(patient_summary_file)
    
    # Initialize the 'seizure' column in the table
    table['seizure'] = 0
    
    # Update the 'seizure' column based on the seizure data
    for name, start, end in seizures[['name', 'seizure_start', 'seizure_end']].values:
        if name == patient_file:
            print(name, start, end)
            table.loc[(table['time'] >= start) & (table['time'] <= end), 'seizure'] = 1
    
    return table
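
The labelling rule used in `add_seizure` (a sample is marked 1 when its timestamp lies between the seizure start and end, both inclusive) can be checked on toy data without loading a recording. This is a sketch with NumPy only; `label_samples` is a hypothetical helper and the seizure times are invented.

```python
import numpy as np

def label_samples(n_samples, sfreq, seizures):
    """Return a 0/1 label per sample given (start, end) pairs in seconds."""
    times = np.arange(n_samples) / sfreq
    labels = np.zeros(n_samples, dtype=int)
    for start, end in seizures:
        labels[(times >= start) & (times <= end)] = 1  # inclusive, as in add_seizure
    return labels

# Toy example: 10 seconds at 256 Hz with one seizure from second 2 to second 4
labels = label_samples(10 * 256, 256, [(2, 4)])
print(labels.sum())
```

Because both ends are inclusive, a seizure of `d` seconds covers `d * sfreq + 1` samples.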

### edf -> DataFrame

In [None]:
#? Convert an edf file to a pandas DataFrame with a 'seizure' label column
def read_edf(file):
    data = mne.io.read_raw_edf(file, preload=True) #* Load the data
    data.set_eeg_reference() #* Set the reference to the average of the channels
    data.filter(l_freq=0.5, h_freq=45) #* Filter the data to remove noise and artifacts and to get the frequency band of interest (0.5-45Hz)
    data = data.to_data_frame() #* Convert the data to a pandas dataframe
    common_cols = ['time','F7-T7','T7-P7','P7-O1','FP1-F3','F3-C3','C3-P3','P3-O1','FP2-F4','F4-C4','C4-P4','P4-O2','FP2-F8','F8-T8','T8-P8-0','P8-O2','FZ-CZ','CZ-PZ','P7-T7','T7-FT9','FT9-FT10','FT10-T8','T8-P8-1']
    data = data.filter(items=common_cols) #* Filter the data to include only the common channels
    data = add_seizure(data, file) #* Add the seizure data to the dataframe
    return data

In [None]:
%%capture
test = read_edf(records_path[0]) #* test is a DataFrame: one row per sample, one column per channel plus 'time' and 'seizure'

In [None]:
test #? time is in seconds

In [None]:
test.info()

# Model

Testing and training sets<br>
* Total Normal = 545<br>
* Total Seizure = 142<br>

So we can use around 84% of this data for training and the remaining 16% for testing.<br>
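
One way to apply this 84/16 split at the level of whole recordings is to split the normal and the seizure recordings separately, so that both sets keep roughly the same class mix. The sketch below uses placeholder names rather than the real paths; `split_files` and its defaults are assumptions for illustration.

```python
import random

def split_files(files, test_fraction=0.16, seed=77):
    """Shuffle a list of recordings and split it into (train, test)."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]

# Placeholder names standing in for the 545 normal and 142 seizure recordings
normal = [f"normal_{i}" for i in range(545)]
seizure = [f"seizure_{i}" for i in range(142)]

normal_train, normal_test = split_files(normal)
seizure_train, seizure_test = split_files(seizure)
print(len(normal_train), len(normal_test), len(seizure_train), len(seizure_test))
```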

## Neural Networks

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pandas as pd
from sklearn.model_selection import train_test_split

In [None]:
seed = 77

In [None]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Device: {device}")

### Feed Forward Neural Network

In [None]:
#? Create the Feed Forward Neural Network Model
class FFNN_Model(nn.Module):
    def __init__(self, in_channels, h1, h2):
        super(FFNN_Model, self).__init__()
        self.fc1 = nn.Linear(in_channels, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, 16)
        self.fc4 = nn.Linear(16, 8)
        self.out = nn.Linear(8, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.out(x)
        return x
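
To see how small this network is, its trainable parameters can be counted by hand: each `nn.Linear(n_in, n_out)` layer holds `n_in * n_out` weights plus `n_out` biases. The sketch below does that arithmetic in plain Python; the input width of 22 assumes that all 22 channel columns are present, and the helper names are hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

def ffnn_params(in_channels, h1, h2):
    """Trainable parameters of FFNN_Model: in -> h1 -> h2 -> 16 -> 8 -> 2."""
    sizes = [in_channels, h1, h2, 16, 8, 2]
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

# Assuming 22 EEG channels as input and the hidden sizes 64 and 32
n_params = ffnn_params(22, 64, 32)
print(n_params)
```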

##### Training on two recordings from one patient

In [None]:
%%capture
torch.manual_seed(seed) #* Set the seed for reproducibility

#? Create the model
data = pd.concat([read_edf(records_path[2]), read_edf(records_path[3])])
model = FFNN_Model(data.shape[1]-2, 64, 32).to(device)

In [None]:
data.shape

In [None]:
# Train test split
X = data.drop(columns=['time', 'seizure']).values
y = data['seizure'].values

#? Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)

#? Convert the data to tensors
X_train = torch.FloatTensor(X_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_train = torch.LongTensor(y_train).to(device)
y_test = torch.LongTensor(y_test).to(device)

In [None]:
# Set the criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [None]:
# Train the model
epochs = 500
losses = []
for i in range(epochs):
    # Forward pass
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)

    losses.append(loss.item()) #* Store a plain float so the losses can be plotted later
    if i % 10 == 0:
        print(f'Epoch: {i} Loss: {loss.item()}')
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Plot the training
plt.plot(range(epochs), losses)
plt.ylabel('Loss')
plt.xlabel('Epoch')

In [None]:
# Evaluate Model
with torch.no_grad():
    y_eval = model.forward(X_test)
    loss = criterion(y_eval, y_test)
loss

In [None]:
# Accuracy over all test samples, plus the number of seizure samples detected
with torch.no_grad():
    y_hat = model(X_test).argmax(dim=1)
correct = (y_hat == y_test).sum().item()
seizure_total = (y_test == 1).sum().item()
seizure_hits = ((y_hat == 1) & (y_test == 1)).sum().item()
print(f"Total correct: {correct}")
print(f"Accuracy: {correct/len(X_test)}")
print(f"Seizure samples detected: {seizure_hits}/{seizure_total}")
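
With so few seizure samples, plain accuracy can look good even when no seizure is detected, so it helps to report sensitivity and specificity alongside it. The following sketch uses plain Python lists and invented labels; `binary_metrics` is a hypothetical helper, not part of the notebook.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, sensitivity and specificity for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return {
        "accuracy": (tp + tn) / len(pairs),
        "sensitivity": tp / (tp + fn) if tp + fn else 0.0,
        "specificity": tn / (tn + fp) if tn + fp else 0.0,
    }

# Toy labels: 98 normal samples, 2 seizure samples, and a model that always says 0
y_true = [0] * 98 + [1] * 2
y_pred = [0] * 100
metrics = binary_metrics(y_true, y_pred)
print(metrics)
```

In this toy case the accuracy is high while the sensitivity is zero, which is the failure mode to watch for on imbalanced seizure data.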

#### Clearly, the FFNN can't be used for this data

### Recurrent Neural Network

In [None]:
# Define the RNN model
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out

In [None]:
def split_training_testing(data, window=256):
    X = data.drop(columns=['time', 'seizure']).values
    y = data['seizure'].values

    # Split chronologically (shuffle=False) into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.16, random_state=seed, shuffle=False)

    def to_windows(X_part, y_part):
        # Drop trailing samples that do not fill a whole window (256 samples = 1 second)
        n = (len(X_part) // window) * window
        X_win = torch.tensor(X_part[:n], dtype=torch.float32).reshape(-1, window, X_part.shape[1])
        # Label a window 1 if any of its samples is a seizure sample, so targets stay in {0, 1}
        y_win = torch.tensor(y_part[:n].reshape(-1, window).max(axis=1), dtype=torch.float32)
        return X_win.to(device), y_win.to(device)

    # Reshape the data for RNN input: (n_windows, window, n_channels)
    X_train, y_train = to_windows(X_train, y_train)
    X_test, y_test = to_windows(X_test, y_test)

    return X_train, X_test, y_train, y_test
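
The windowing step can be tried on toy data to check the resulting shapes. This sketch uses NumPy only and assumes a window gets the label 1 when any of its samples is a seizure sample; `make_windows` is a hypothetical helper and the data is all zeros.

```python
import numpy as np

def make_windows(X, y, window=256):
    """Cut (n_samples, n_channels) data into non-overlapping windows.

    Returns X of shape (n_windows, window, n_channels) and one 0/1 label
    per window, set to 1 if any sample in the window is a seizure sample.
    """
    n_windows = len(X) // window
    n_used = n_windows * window  # drop the incomplete tail, if any
    X_win = X[:n_used].reshape(n_windows, window, X.shape[1])
    y_win = y[:n_used].reshape(n_windows, window).max(axis=1)
    return X_win, y_win

# Toy data: 1000 samples, 22 channels, seizure samples 300..399
X = np.zeros((1000, 22))
y = np.zeros(1000, dtype=int)
y[300:400] = 1
X_win, y_win = make_windows(X, y)
print(X_win.shape, y_win.tolist())
```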

In [None]:
# Train the model
def train_model(model, criterion, optimizer, X_train, y_train, epochs=30):
    model.train()
    model.to(device)
    losses = []

    for epoch in range(epochs):
        optimizer.zero_grad()
        y_pred = model(X_train).squeeze(-1)  # (n_windows, 1) -> (n_windows,)
        loss = criterion(y_pred, y_train)
        losses.append(loss.item())

        if epoch % 10 == 0:
            print(f'Epoch: {epoch} Loss: {loss.item()}')

        loss.backward()
        optimizer.step()

    # Plot the training loss
    plt.plot(range(epochs), losses)
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.show()

    return model

In [None]:
# Evaluation: fraction of seizure windows that are detected (sensitivity)
def evaluate_model(model, X_test, y_test):
    model.eval()

    with torch.no_grad():
        y_val = model(X_test).squeeze(-1)
        predictions = torch.sigmoid(y_val) >= 0.5
        positives = y_test > 0  # Windows that contain seizure samples
        total = int(positives.sum().item())
        correct = int((predictions & positives).sum().item())

    print(f"Correct:{correct}/Total:{total}")

    # Return 0 when the test set contains no seizure windows
    return correct / total if total else 0


In [None]:
def RecNet():
    input_size = 22  # Number of EEG channel columns (everything except 'time' and 'seizure')
    hidden_size = 64
    num_layers = 2
    num_classes = 1
    skipped_files = []

    model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    for file in records_seizure_path:
        data = read_edf(file)
        if data.shape[1]-2 != input_size:
            skipped_files.append(file)
            continue
        X_train, X_test, y_train, y_test = split_training_testing(data)

        model = train_model(model, criterion, optimizer, X_train, y_train, epochs=50)

    accuracy = evaluate_model(model, X_test, y_test)
    print(f"Skipped files: {skipped_files}")
    print(f"Accuracy: {accuracy}")

## Main

In [None]:
def main():
    print("main")
    # FeedForward()
    RecNet()
    # LongShort()


In [None]:
if __name__ == "__main__":
    main()