<a href="https://colab.research.google.com/github/AmulyaMat/EEG-Signal-Seizure-Detection/blob/main/EEG_Model_Final_Amulya.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES
# TO THE CORRECT LOCATION (/kaggle/input) IN YOUR NOTEBOOK,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

# Number of bytes read from the HTTP response per iteration of the download loop.
CHUNK_SIZE = 40960
# Comma-separated "<directory>:<percent-encoded URL>" entries. Each archive is
# downloaded and extracted into KAGGLE_INPUT_PATH/<directory>. The URL is a signed
# link with an expiry (X-Goog-Expires), so it stops working after a while.
DATA_SOURCE_MAPPING = 'hms-harmful-brain-activity-classification:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-competitions-data%2Fkaggle-v2%2F59093%2F7469972%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240401%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240401T023013Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D15071ecb525e95669e7b8e7567f9ed0b64f26474242798b0c825392121d6c87ac24e62cd1164157520fc0f8b103cd5e250e6aedbaa237e23fdf265d65ab11dd7c3e8adf7783dc5b0fefdaf6459c47f5aa44341b0fab0785d781edd47637c5e91019b3db5585d109bf5a30251ead29f4cbb71758b28214040c15a35e874ea557c64cb3f071da84b0152e5d4ddd809f22efef789a7f8da7e21c518853e028ca2f3ae4db8fa169ec0e4f79fef4b015419ed31cb4644168735d0f074ea3c967d0ec548d9f2fa4e68b3c5099105495da83425cf3321da05dffcd7f46803b682570f241458330ad5f96d537d17e01b74d1f533eb543b4620c95277bedf5cd93ad11fff'

KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    directory, download_url_encoded = data_source_mapping.split(':')
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            total_length = fileres.headers['content-length']
            print(f'Downloading {directory}, {total_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                done = int(50 * dl / int(total_length))
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            if filename.endswith('.zip'):
              with ZipFile(tfile) as zfile:
                zfile.extractall(destination_path)
            else:
              with tarfile.open(tfile.name) as tarfile:
                tarfile.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError as e:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
        continue
    except OSError as e:
        print(f'Failed to load {download_url} to path {destination_path}')
        continue

print('Data source import complete.')


# 1. Importing libraries

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

# 2. Importing EEG data (train.csv) and extracting 10-second segments

In [None]:
# Load the competition's training label table and show its size and contents.
TRAIN_CSV_PATH = '/kaggle/input/hms-harmful-brain-activity-classification/train.csv'
train = pd.read_csv(TRAIN_CSV_PATH)
print('Train shape', train.shape)
train

In [None]:
ids = train['eeg_id'].unique()  # distinct EEG recordings referenced by the training table

In [None]:
import random
# Pick n EEG recordings at random. A dedicated, seeded RNG makes the selection
# (and therefore every downstream result) reproducible between runs without
# touching the global `random` state.
SAMPLE_SEED = 42
n = 10
sample_ids = random.Random(SAMPLE_SEED).sample(list(ids), n)
sample_ids

In [None]:
# CREATING DATA LOADER FOR SPECIFIC EEG IDS

import pandas as pd

def create_data_loader(eeg_ids, eeg_data_dir, train_data, segment_length=10,
                       sampling_rate=200, reader=pd.read_parquet):
    """
    Create a data loader function that extracts fixed-length EEG segments for the given EEG IDs.

    For every row of ``train_data`` whose ``eeg_id`` is in ``eeg_ids``, the loader yields
    the ``segment_length``-second window that starts at that row's
    ``eeg_label_offset_seconds``, together with that same row's vote columns.

    Args:
    - eeg_ids (list): EEG IDs for which segments need to be extracted.
    - eeg_data_dir (str): Directory containing one ``<eeg_id>.parquet`` file per EEG.
    - train_data (DataFrame): Must contain 'eeg_id', 'eeg_label_offset_seconds' and the
      six '*_vote' columns listed below.
    - segment_length (int): Length of each segment in seconds.
    - sampling_rate (int): Samples per second of the recordings (the notebook uses 200).
    - reader (callable): Maps a file path to a DataFrame; defaults to ``pd.read_parquet``.

    Returns:
    - data_loader (callable): Zero-argument generator function yielding
      ``(eeg_segment, target_labels)`` pairs. A segment is shorter than
      ``segment_length * sampling_rate`` rows if the window runs past the end of the recording.
    """
    vote_columns = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
    segment_samples = int(segment_length * sampling_rate)

    def data_loader():
        for eeg_id in eeg_ids:
            # Load the full recording for the current EEG ID
            example = reader(f"{eeg_data_dir}/{eeg_id}.parquet")

            # All labelled subsamples (rows) for this EEG ID
            train_eegid = train_data[train_data['eeg_id'] == eeg_id]
            print("Number of offset subsamples for EEG ID", eeg_id, ":", len(train_eegid))

            for _, row in train_eegid.iterrows():
                # The window starts at the offset. Round rather than truncate so that
                # fractional offsets map to the nearest sample.
                segment_start = int(round(row['eeg_label_offset_seconds'] * sampling_rate))
                segment_end = segment_start + segment_samples
                eeg_segment = example.iloc[segment_start:segment_end].reset_index(drop=True)

                # Each subsample carries its own votes; do not reuse the first row's.
                yield eeg_segment, row[vote_columns]

    return data_loader

# Example usage: build a loader of 10-second segments over the sampled EEG IDs.
segment_length = 10  # 10-second EEG segments
train_data = train  # the training DataFrame read from train.csv
eeg_data_dir = "/kaggle/input/hms-harmful-brain-activity-classification/train_eegs"
eeg_ids = sample_ids
loader = create_data_loader(eeg_ids, eeg_data_dir, train_data, segment_length=segment_length)

In [None]:
# Run the loader once and keep every (segment, labels) pair it yields, then split
# the pairs into two parallel lists.
loaded_pairs = list(loader())
all_segments = [segment for segment, _ in loaded_pairs]
all_targets = [labels for _, labels in loaded_pairs]

# Stack the segments end to end into one long DataFrame with a fresh 0..N-1 index.
full_eeg_segments = pd.concat(all_segments, ignore_index=True)
full_eeg_segments

In [None]:
# One row per extracted segment, holding the target labels yielded with it.
full_eeg_targets = pd.DataFrame(data=all_targets)
full_eeg_targets

# 3. Visualizing EEG signals

In [None]:
def plot_eeg(df, title, max_channels=20):
    """Plot the channels of an EEG DataFrame as stacked, tick-free traces.

    Args:
        df: DataFrame with one column per channel and one row per sample.
        title: Figure title.
        max_channels: Plot at most this many leading columns (default 20, the number
            previously hard-coded). DataFrames with fewer columns are plotted in full
            instead of raising an IndexError.

    Raises:
        ValueError: if there is nothing to plot (no columns, or max_channels < 1).
    """
    n_channels = min(df.shape[1], max_channels)
    if n_channels < 1:
        raise ValueError("plot_eeg: df has no channels to plot")

    # squeeze=False always yields a 2-D Axes array, so one channel also works.
    fig, axs = plt.subplots(n_channels, 1, figsize=(30, 20), sharex=True, squeeze=False)

    for i, ax in enumerate(axs[:, 0]):
        ax.plot(df.iloc[:, i], color="black")
        ax.set_ylabel(df.columns[i], rotation=0)
        # Removing the ticks also removes their labels.
        ax.set_yticks([])
        ax.set_xticks([])
        ax.spines[["top", "bottom", "left", "right"]].set_visible(False)

    fig.suptitle(title, fontsize=50, verticalalignment='top')
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the figure title
    plt.show()

In [None]:
# Load one complete recording for inspection and report its length.
eeg_example = pd.read_parquet('/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/941668605.parquet')
n_example_samples = len(eeg_example)
print(f'eeg_id=941668605 has {n_example_samples} samples.')
print()
display(eeg_example)

In [None]:
plot_eeg(df=eeg_example, title='EEG signal, id = 941668605, all subsamples')  # whole recording

In [None]:
# Slice one segment out of eeg_example, starting `offset` seconds into the recording.
# Indices are seconds * 200 (200 samples per second), so the window covers rows
# [offset*200, offset*200 + segment_length*200). segment_length is set in the
# data-loader cell.
offset = 0
segment_start = int(offset) * 200
segment_end = segment_start + segment_length * 200

eeg_segment_0 = eeg_example.iloc[segment_start:segment_end].reset_index(drop=True)

eeg_segment_0

In [None]:
plot_eeg(df=eeg_segment_0, title='EEG signal, id = 941668605, subsamples = 0')  # first segment only

In [None]:
# Plot the first channel of the raw segment against sample index. The x range comes
# from the segment itself, so shorter segments (e.g. near the end of a recording)
# also plot instead of raising a length mismatch.
# NOTE(review): confirm the recording's units before labelling the axis "mV".
plt.plot(range(len(eeg_segment_0)), eeg_segment_0.iloc[:, 0], color="blue")
plt.xlabel("# of samples")
plt.ylabel("voltage mV")
plt.title(f'EEG Signal, {eeg_segment_0.columns[0]}')

In [None]:
# Display the concatenated raw EEG segments (before wavelet denoising).
full_eeg_segments

# 3. Feature Engineering

## 3a. Denoising signals with wavelet transform

In [None]:
# denoising function using wavelet transform
import pywt

def maddest(d, axis=None):
    """Return the mean absolute deviation of ``d`` about its mean along ``axis``.

    Args:
        d: Array-like of numbers.
        axis: Axis to reduce over; ``None`` reduces over all elements.

    Returns:
        A scalar when ``axis`` is None, otherwise an array with ``axis`` removed.
    """
    # keepdims=True keeps the centring mean broadcastable against ``d`` for any
    # axis. Without it, axis=1 on a 2-D array subtracts along the wrong dimension.
    centred = d - np.mean(d, axis=axis, keepdims=True)
    return np.mean(np.absolute(centred), axis=axis)

def denoise(x, wavelet='haar', level=1):
    """Wavelet-denoise every column of ``x`` by hard-thresholding its detail coefficients.

    For each column: decompose with ``pywt.wavedec`` (periodization mode), estimate a
    noise scale ``sigma = maddest(coeff[-level]) / 0.6745``, zero every detail
    coefficient whose magnitude is below ``sigma * sqrt(2 * ln(len(x)))``, and
    reconstruct.

    NOTE(review): the 1/0.6745 factor is the usual scaling for the *median* absolute
    deviation, whereas ``maddest`` computes the *mean* absolute deviation. This is
    kept as-is, but the threshold may not be the textbook universal threshold.

    Args:
        x: DataFrame with one column per signal.
        wavelet: Wavelet name understood by pywt.
        level: Which coefficient array (counted from the end; 1 = finest detail) is
            used to estimate sigma.

    Returns:
        DataFrame with the same columns and index as ``x``.
    """
    n_samples = len(x)
    # The threshold factor depends only on the signal length, so compute it once.
    thresh_scale = np.sqrt(2 * np.log(n_samples))

    ret = {}
    for pos in x.columns:
        coeff = pywt.wavedec(x[pos], wavelet, mode="per")
        sigma = (1/0.6745) * maddest(coeff[-level])

        uthresh = sigma * thresh_scale
        # Keep the approximation (coeff[0]); threshold only the detail levels.
        coeff[1:] = (pywt.threshold(c, value=uthresh, mode='hard') for c in coeff[1:])

        # In periodization mode the reconstruction has even length, so an odd-length
        # input comes back one sample longer; trim it to match x.
        ret[pos] = pywt.waverec(coeff, wavelet, mode='per')[:n_samples]

    return pd.DataFrame(ret, index=x.index)

In [None]:
# Denoise the example segment with a Daubechies-8 wavelet and plot every channel.
eeg_segment_0_denoised = denoise(x=eeg_segment_0, wavelet="db8")
plot_eeg(df=eeg_segment_0_denoised, title='Denoised EEG signals for EEG id =  941668605, offset = 0')

In [None]:
# Plot the first channel of the denoised segment against sample index. The x range
# comes from the data, so segments of any length plot correctly.
# NOTE(review): confirm the recording's units before labelling the axis "mV".
plt.plot(range(len(eeg_segment_0_denoised)), eeg_segment_0_denoised.iloc[:, 0], color="blue")
plt.xlabel("# of samples")
plt.ylabel("voltage mV")
plt.title(f'EEG Signal, {eeg_segment_0_denoised.columns[0]}')

In [None]:
# Denoise entire data.
# NOTE(review): full_eeg_segments is every segment concatenated end to end, so this
# decomposition runs across segment boundaries.
full_eeg_segments_denoised = denoise(x=full_eeg_segments, wavelet="db8")
full_eeg_segments_denoised

## 3b. Discrete wavelet transform features

In [None]:
from pywt import wavedec

def wavelet_decompose_channels(data, level, output=False, wavelet='db4', step=2):
    """Decompose every channel with a multilevel DWT and tabulate the detail coefficients.

    Args:
        data: DataFrame with one column per channel and one row per sample.
        level: Number of decomposition levels passed to ``wavedec``.
        output: Unused; accepted so existing calls keep working.
        wavelet: Wavelet name passed to ``wavedec`` (default ``'db4'``).
        step: Keep every ``step``-th sample before decomposing (default 2).

    Returns:
        DataFrame whose columns are a ('channel', 'level') MultiIndex over the detail
        levels D1..D<level> (the approximation level is dropped). Rows are coefficient
        positions. Any row containing NaN, i.e. a position beyond the length of the
        shortest level, is dropped.
    """
    # Decimate, then name the column axis on a *new* object. Assigning
    # ``data.columns.name`` on the slice would also rename the caller's columns,
    # because the slice shares its column Index with the original.
    data = data[0::step].rename_axis(columns='channel')

    # wavedec works along the last axis, so give it one row per channel.
    coeffs_list = wavedec(data.transpose().values, wavelet=wavelet, level=level)

    # Coefficient names in wavedec's order: A<level>, D<level>, ..., D1.
    names = ['A' + str(level)] + ['D' + str(num) for num in range(level, 0, -1)]

    # Skip index 0 (the approximation) and build one (position x (channel, level))
    # frame per detail level.
    level_frames = []
    for name, array in zip(names[1:], coeffs_list[1:]):
        level_df = pd.DataFrame(array, index=data.columns)
        level_df['level'] = name
        level_frames.append(level_df.set_index('level', append=True).T)

    # Levels have different lengths; concat aligns on position and pads with NaN.
    wavelets = pd.concat(level_frames, axis=1, sort=True)

    # Order the columns by channel, then by level.
    wavelets = wavelets.sort_values(['channel', 'level'], axis=1)

    return wavelets.dropna()

dwt_wavelets = wavelet_decompose_channels(data=full_eeg_segments_denoised, level=5, output=True)  # 5-level DWT per channel

In [None]:
# Display the DWT detail-coefficient table (columns: channel x level).
dwt_wavelets

## 3c. Computing statistics for DWT features

### Log Sum

In [None]:
def minus_small(data):
    """Shift values so the minimum (per column for a DataFrame) becomes exactly 1.

    Args:
        data: pandas Series or DataFrame of numbers.

    Returns:
        Same type as ``data``, equal to ``data - data.min() + 1``.
    """
    return data - data.min() + 1

def log_sum(data, output=False):
    """Return the natural log of shifted column sums, labelled with an ``_LSWT`` suffix.

    The column sums are passed through ``minus_small`` (smallest sum becomes 1), so
    every log is >= 0.

    Args:
        data: DataFrame whose column labels are strings.
        output: If true, also ``display`` the result.

    Returns:
        Series indexed by ``<column>_LSWT``.
    """
    shifted_sums = minus_small(data.sum())
    result = shifted_sums.apply(np.log)
    result.index = result.index + '_LSWT'

    if output:
        display(result)

    return result


def reformat(data, feature_name):
    """Flatten a (level x channel) table into a single-row DataFrame of prefixed features.

    Args:
        data: DataFrame whose index holds string labels (e.g. 'D1'..'D5') and whose
            columns are channels.
        feature_name: Prefix prepended to every index label (e.g. 'Ratio_Mean_').

    Returns:
        One-row DataFrame whose columns form a (channel, feature) MultiIndex.
        ``data`` itself is left unmodified.
    """
    # set_axis / rename_axis return new objects, so the caller's index is not
    # overwritten as a side effect.
    renamed = data.set_axis([feature_name + level for level in data.index], axis=0)
    renamed = renamed.rename_axis('feature')
    return pd.DataFrame(renamed.unstack()).T

def log_sum_channels(data, output=False):
    """Per-row log-sum feature for each channel of a (channel, level) coefficient table.

    Follows the same recipe as ``log_sum``, applied per channel and per row: sum the
    channel's level columns, shift the sums so the smallest is 1 (as ``minus_small``
    does), then take the natural log.

    Args:
        data: DataFrame whose columns are a MultiIndex with the channel in level 0
            (e.g. the output of ``wavelet_decompose_channels``).
        output: Unused; accepted so existing calls keep working.

    Returns:
        DataFrame indexed like ``data`` with one ``<channel>_LSWT`` column per channel
        (same suffix as ``log_sum``).
    """
    features = {}
    for channel in data.columns.get_level_values(0).unique():
        # Reduce the channel's levels to one value per row; assigning the raw
        # multi-column frame to a single column would raise.
        level_sums = data[channel].sum(axis=1)
        shifted = level_sums - level_sums.min() + 1
        features[f"{channel}_LSWT"] = np.log(shifted)

    return pd.DataFrame(features, index=data.index)

In [None]:
log_sum_channels(data=dwt_wavelets, output=True)  # log-sum feature per channel

In [None]:
def ave(data, output=False):
    """Per-row mean of each channel's wavelet-level columns.

    Args:
        data: DataFrame whose columns are a MultiIndex with the channel in level 0
            (e.g. the output of ``wavelet_decompose_channels``).
        output: Unused; accepted so existing calls keep working.

    Returns:
        DataFrame indexed like ``data`` with one ``<channel>_DT_mean`` column per channel.
    """
    # Index the result by the argument itself, not by a notebook global.
    means = pd.DataFrame(index=data.index)

    for channel in data.columns.get_level_values(0).unique():
        # Mean across this channel's level columns, row by row
        means[channel] = data[channel].mean(axis=1)

    means.columns = [f"{col}_DT_mean" for col in means.columns]

    return means


# Use the function on the DWT coefficient table built by wavelet_decompose_channels
# (`example_wavelets` is not defined anywhere in this notebook).
example_wavelet_mean = ave(dwt_wavelets, output=True)
example_wavelet_mean

### Mean

In [None]:
def ave(data, output=False):
    """Per-row mean of each channel's wavelet-level columns.

    Args:
        data: DataFrame whose columns are a MultiIndex with the channel in level 0
            (e.g. the output of ``wavelet_decompose_channels``).
        output: Unused; accepted so existing calls keep working.

    Returns:
        DataFrame indexed like ``data`` with one ``<channel>_DT_mean`` column per channel.
    """
    # Index the result by the argument itself, not by a notebook global.
    means = pd.DataFrame(index=data.index)

    for channel in data.columns.get_level_values(0).unique():
        # Mean across this channel's level columns, row by row
        means[channel] = data[channel].mean(axis=1)

    means.columns = [f"{col}_DT_mean" for col in means.columns]

    return means


# Use the function on the DWT coefficient table built by wavelet_decompose_channels
# (`example_wavelets` is not defined anywhere in this notebook).
example_wavelet_mean = ave(dwt_wavelets, output=True)
example_wavelet_mean

### Mean Absolute Value

In [None]:
def abs_ave(data, output=False):
    """Per-row mean absolute value of each channel's wavelet-level columns.

    Args:
        data: DataFrame whose columns are a MultiIndex with the channel in level 0.
        output: Unused; accepted for call compatibility.

    Returns:
        DataFrame indexed like ``data`` with one ``<channel>_DT_mean_abs`` column per channel.
    """
    channels = data.columns.get_level_values(0).unique()
    return pd.DataFrame(
        {f"{ch}_DT_mean_abs": data[ch].abs().mean(axis=1) for ch in channels},
        index=data.index,
    )

# Mean absolute value per channel of the DWT table (`example_wavelets` is never
# defined in this notebook; dwt_wavelets is the coefficient table).
example_wavelet_meanabs = abs_ave(dwt_wavelets, output=True)
example_wavelet_meanabs

### STD

In [None]:
def std_val(data, output=False):
    """Per-row standard deviation of each channel's wavelet-level columns.

    Args:
        data: DataFrame whose columns are a MultiIndex with the channel in level 0
            (e.g. the output of ``wavelet_decompose_channels``).
        output: Unused; accepted so existing calls keep working.

    Returns:
        DataFrame indexed like ``data`` with one ``<channel>_DT_std`` column per channel.
        The suffix was previously ``_DT_mean_abs``, which collided with ``abs_ave``'s
        feature names.
    """
    stds = pd.DataFrame(index=data.index)

    for channel in data.columns.get_level_values(0).unique():
        # Sample std (pandas default ddof=1) across this channel's levels, per row
        stds[channel] = data[channel].std(axis=1)

    stds.columns = [f"{col}_DT_std" for col in stds.columns]

    return stds

# Standard deviation per channel of the DWT table (`example_wavelets` is never
# defined in this notebook; dwt_wavelets is the coefficient table).
example_wavelet_std = std_val(dwt_wavelets, output=True)
example_wavelet_std

In [None]:
def ratio_channels(epoch_data):
    """Ratio of each decimation level's value to its neighbouring level(s).

    The first and last entries are divided by their single neighbour. Every interior
    entry is divided by the mean of the entries just before and after it.

    Args:
        epoch_data: Series with one value per decimation level, in order.

    Returns:
        float Series with the same labels and its index named 'features'. An empty
        input gives an empty float Series.

    Raises:
        IndexError: if ``epoch_data`` has exactly one entry (no neighbour to compare).
    """
    labels = list(epoch_data.index)
    values = [epoch_data.loc[label] for label in labels]
    last = len(values) - 1

    # Build all ratios first, then create the Series once. This avoids growing an
    # untyped empty Series label by label.
    ratios = {}
    for pos, (label, value) in enumerate(zip(labels, values)):
        if pos == 0:
            ratios[label] = value / values[1]
        elif pos == last:
            ratios[label] = value / values[pos - 1]
        else:
            ratios[label] = value / ((values[pos - 1] + values[pos + 1]) / 2)

    ratio_data = pd.Series(ratios, dtype=float)
    ratio_data.index.name = 'features'

    return ratio_data

# Ratio of each level's mean coefficient to its neighbouring levels, per channel,
# computed on the DWT table (`example_wavelets` is never defined in this notebook).
example_ratio_data = dwt_wavelets.mean().unstack('channel').apply(ratio_channels)
example_ratio_data = reformat(example_ratio_data, 'Ratio_Mean_')
display(example_ratio_data.head())

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, BatchNormalization, Activation, MaxPooling1D, Flatten, Dense, LSTM, Bidirectional
from tensorflow.keras.models import Model

def build_cnn_rnn_model(input_shape=(310, 19), num_classes=2, growth_rate=32,
                        num_dense_blocks=7, layers_per_block=4):
    """Build a 1-D DenseNet-style CNN followed by a bidirectional-LSTM classifier.

    Args:
        input_shape: (timesteps, channels) of one input example.
        num_classes: Number of softmax output units.
        growth_rate: Filters added by each conv unit inside a dense block.
        num_dense_blocks: Number of dense blocks. A transition block (which max-pools
            the sequence by 2) follows every block except the last.
        layers_per_block: Conv units per dense block.

    Returns:
        An uncompiled ``tf.keras`` Model.

    Raises:
        ValueError: if the pooling stages would shrink a known sequence length to zero.
    """
    n_pools = num_dense_blocks - 1
    timesteps = input_shape[0]
    # Fail early with a clear message rather than inside Keras layer construction.
    if timesteps is not None and n_pools > 0 and timesteps // (2 ** n_pools) < 1:
        raise ValueError(
            f"input length {timesteps} is too short for {n_pools} pooling stages"
        )

    # Input layer
    inputs = Input(shape=input_shape)

    # CNN Module (DenseNet-BC architecture): dense blocks separated by transition blocks
    x = inputs
    for block_idx in range(num_dense_blocks):
        x = dense_block(x, growth_rate, layers=layers_per_block)
        if block_idx != num_dense_blocks - 1:
            x = transition_block(x)

    # RNN Module (Bidirectional LSTM) over the downsampled sequence
    x = Bidirectional(LSTM(units=8, return_sequences=True))(x)

    # Fully Connected Layer
    x = Flatten()(x)
    x = Dense(16, activation='relu')(x)

    # Output Layer
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

def dense_block(x, growth_rate, layers=4):
    """Apply ``layers`` successive conv units. Each one concatenates its
    ``growth_rate`` new channels onto its input, so the channel count grows linearly."""
    features = x
    for _unit in range(layers):
        features = conv_block(features, growth_rate)
    return features

def conv_block(x, growth_rate):
    """Bottleneck unit: BN-ReLU-1x1 conv (4*growth_rate filters), then BN-ReLU-width-3
    conv (growth_rate filters). The result is concatenated onto ``x`` along the
    channel axis."""
    branch = x
    # (filters, kernel_size) for the two BN -> ReLU -> Conv1D stages, in order.
    for n_filters, width in ((4 * growth_rate, 1), (growth_rate, 3)):
        branch = BatchNormalization()(branch)
        branch = Activation('relu')(branch)
        branch = Conv1D(filters=n_filters, kernel_size=width, padding='same')(branch)

    return tf.keras.layers.concatenate([x, branch], axis=-1)

def transition_block(x):
    """BN -> ReLU -> 1x1 conv halving the channel count -> max-pool by 2 along time."""
    h = BatchNormalization()(x)
    h = Activation('relu')(h)
    halved_channels = int(h.shape[-1]) // 2
    h = Conv1D(filters=halved_channels, kernel_size=1, padding='same')(h)
    return MaxPooling1D(pool_size=2)(h)

# Build the model with its default input shape and class count, then print its layers.
model = build_cnn_rnn_model(input_shape=(310, 19), num_classes=2)
model.summary()


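We define helper functions for the CNN building blocks (`conv_block`, `dense_block`, `transition_block`); the RNN module is added directly inside `build_cnn_rnn_model`.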

The build_cnn_rnn_model function integrates both modules into a single model.

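The **CNN** module consists of **7 dense blocks, each containing 4 composite convolutional units** (a 1×1 bottleneck convolution followed by a width-3 convolution, each preceded by batch normalization and ReLU) with a **growth rate of 32**.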

Between dense blocks, transition blocks are used to reduce the number of feature maps and downsample the input.

After the CNN module, **a bidirectional LSTM layer** is added to capture long-term dependencies.

Finally, a fully connected layer and an output layer are added for classification.

We then build the model and print its summary to examine the architecture.