# Brain Solver Python Processing Notebook

This notebook uses the custom `brain_solver` package to analyze brain activity (EEG) data. The data sources are the official Kaggle competition datasets plus additional datasets used for model training and evaluation.

This is the training notebook.

## Data Sources

### Official:

- **HMS - Harmful Brain Activity Classification**
  - **Source:** [Kaggle Competition](https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification)
  - **Description:** This competition focuses on classifying harmful brain activity from EEG recordings. Its dataset provides the training and test data used in this notebook, including `train.csv` and the spectrogram parquet files.

- **Brain-Spectrograms**
  - **Source:** [Kaggle Dataset](https://www.kaggle.com/datasets/cdeotte/brain-spectrograms)
  - **Description:** The `specs.npy` file contains all the spectrograms from the HMS competition in a single NumPy file.

### Additional:

- **Brain-EEG-Spectrograms**
  - **Source:** [Kaggle Dataset](https://www.kaggle.com/datasets/cdeotte/brain-eeg-spectrograms)
  - **Description:** The `EEG_Spectrograms` folder includes one NumPy file per EEG ID. Each array has shape `(128, 256, 4)`, with axes (frequency, time, montage chain).

- **hms_efficientnetb0_pt_ckpts**
  - **Source:** [Kaggle Dataset](https://www.kaggle.com/datasets/crackle/hms-efficientnetb0-pt-ckpts)
  - **Description:** This dataset offers pre-trained checkpoints for EfficientNetB0 models, tailored for the HMS competition. It's intended for use in fine-tuning models on the specific task of harmful brain activity classification.
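
To make the `(128, 256, 4)` layout concrete, here is a minimal sketch that works on a synthetic array of the same shape. The random data and the vertical stacking are assumptions for illustration only; they are not taken from `brain_solver`.

```python
import numpy as np

# Synthetic stand-in for one file from the EEG_Spectrograms folder
# (assumption: real files share this (frequency, time, montage chain) layout).
spec = np.random.default_rng(0).random((128, 256, 4)).astype(np.float32)

n_freq, n_time, n_chains = spec.shape

# Slice out each montage chain as a 2-D (frequency, time) image.
chains = [spec[:, :, i] for i in range(n_chains)]

# One possible way to present all chains to an image model at once:
# stack them vertically into a single (4 * 128, 256) image.
stacked = np.concatenate(chains, axis=0)
print(stacked.shape)  # (512, 256)
```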


In [1]:
import os
import sys
import gc
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from brain_solver import Helpers as hp, Trainer as tr, BrainModel as br, EEGDataset
from brain_solver import Wav2Vec2 as w2v
from brain_solver import Filters, FilterType
from transformers.utils import logging
from tqdm import tqdm

# Suppress warnings if desired
import warnings

warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

# Setup for CUDA device selection
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [2]:
from brain_solver import Config
full_path = "/home/osloup/NoodleNappers/brain/data/" # Luppo
# full_path = "C:/Users/tygof/Documents/Semester 8/MLiP/NoodleNappers/brain/data/" # Tygo
# full_path = "C:/Users/dahbl/Documents/TrueDocs/Uni/Year 4/Semester 2/Machine Learning in Practice/brain/brain/data/" # Dick
config = Config(
    full_path,
    full_path + "out/",
    USE_EEG_SPECTROGRAMS=True,
    USE_KAGGLE_SPECTROGRAMS=True,
    should_read_brain_spectograms=False,
    should_read_eeg_spectrogram_files=False,
    USE_PRETRAINED_MODEL=False,
)

# Kaggle Pull
# full_path = "/kaggle/input/"
# config = Config(full_path, "/kaggle/working/", USE_EEG_SPECTROGRAMS=True, USE_KAGGLE_SPECTROGRAMS=True, should_read_brain_spectograms=False, should_read_eeg_spectrogram_files=False, USE_PRETRAINED_MODEL=False)

# Make the local kaggle-kl-div folder importable (sys is already imported above)
sys.path.append(full_path + 'kaggle-kl-div')
# from kaggle_kl_div import score

In [3]:
# Create the output folder if it does not exist
os.makedirs(config.output_path, exist_ok=True)

# Initialize random environment
pl.seed_everything(config.seed, workers=True)

print(config.data_train_csv)

Seed set to 2024


/home/osloup/NoodleNappers/brain/data/hms-harmful-brain-activity-classification/train.csv


In [4]:
train_df: pd.DataFrame = hp.load_csv(config.data_train_csv)

if train_df is None:
    # Stop with a clear error rather than continuing without data
    raise RuntimeError(f"Failed to load the CSV file: {config.data_train_csv}")

EEG_IDS = train_df.eeg_id.unique()
TARGETS = train_df.columns[-6:]  # The six vote columns
TARS = {"Seizure": 0, "LPD": 1, "GPD": 2, "LRDA": 3, "GRDA": 4, "Other": 5}
TARS_INV = {x: y for y, x in TARS.items()}
print("Train shape:", train_df.shape)

Train shape: (106800, 15)


In [5]:
train_data_preprocessed = hp.preprocess_eeg_data(train_df, TARGETS)

Train non-overlap eeg_id shape: (17089, 12)


In [6]:
train_data_preprocessed.head()

|   | eeg_id | spec_id | min_offset | max_offset | patient_id | seizure_vote | lpd_vote | gpd_vote | lrda_vote | grda_vote | other_vote | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 568657 | 789577333 | 0.0 | 16.0 | 20654 | 0.0 | 0.0 | 0.25 | 0.0 | 0.166667 | 0.583333 | Other |
| 1 | 582999 | 1552638400 | 0.0 | 38.0 | 20230 | 0.0 | 0.857143 | 0.0 | 0.071429 | 0.0 | 0.071429 | LPD |
| 2 | 642382 | 14960202 | 1008.0 | 1032.0 | 5955 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | Other |
| 3 | 751790 | 618728447 | 908.0 | 908.0 | 38549 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | GPD |
| 4 | 778705 | 52296320 | 0.0 | 0.0 | 40955 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | Other |
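
The preprocessed frame has one row per `eeg_id`, with vote columns that sum to 1 and a `target` label. The sketch below shows one way such an aggregation could be written with pandas. It is not the `hp.preprocess_eeg_data` implementation, and the raw column names `spectrogram_id` and `spectrogram_label_offset_seconds` are assumptions about the layout of `train.csv`.

```python
import pandas as pd

VOTE_COLS = ["seizure_vote", "lpd_vote", "gpd_vote",
             "lrda_vote", "grda_vote", "other_vote"]
LABELS = ["Seizure", "LPD", "GPD", "LRDA", "GRDA", "Other"]


def aggregate_by_eeg_id(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse overlapping rows into one row per eeg_id (hypothetical helper)."""
    grouped = df.groupby("eeg_id")
    out = grouped.agg(
        spec_id=("spectrogram_id", "first"),
        min_offset=("spectrogram_label_offset_seconds", "min"),
        max_offset=("spectrogram_label_offset_seconds", "max"),
        patient_id=("patient_id", "first"),
    )
    # Sum the expert votes per EEG, then normalize each row to sum to 1.
    votes = grouped[VOTE_COLS].sum()
    votes = votes.div(votes.sum(axis=1), axis=0)
    out = out.join(votes)
    # The target is the class with the largest share of votes.
    out["target"] = votes.idxmax(axis=1).map(dict(zip(VOTE_COLS, LABELS)))
    return out.reset_index()


# Toy data: two overlapping rows for eeg_id 1, one row for eeg_id 2.
toy = pd.DataFrame({
    "eeg_id": [1, 1, 2],
    "spectrogram_id": [10, 10, 20],
    "spectrogram_label_offset_seconds": [0.0, 4.0, 8.0],
    "patient_id": [7, 7, 9],
    "seizure_vote": [3, 1, 0],
    "lpd_vote": [0, 0, 2],
    "gpd_vote": [0, 0, 0],
    "lrda_vote": [0, 0, 0],
    "grda_vote": [0, 0, 0],
    "other_vote": [0, 2, 0],
})
agg = aggregate_by_eeg_id(toy)
print(agg[["eeg_id", "min_offset", "max_offset", "target"]])
```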


In [7]:
# Initialize the Filters class
# ft = Filters(order=5)

# train_eegs/: EEG data from one or more overlapping samples. Use the metadata in
# train.csv to select specific annotated subsets. The column names are the names of
# the individual electrode locations for EEG leads, with one exception: the EKG
# column is for an electrocardiogram lead that records data from the heart. All of
# the EEG data (for both train and test) was collected at 200 samples per second.

# Define filter parameters
# cutoff_low = 0.1  # Low cutoff frequency (Hz)
# cutoff_high = 50.0  # High cutoff frequency (Hz)
# fs = 200  # Sampling rate (Hz)

# filtered_brain_spectrograms = {key: ft.apply_filter_to_spectrogram(spectrogram, [cutoff_low, cutoff_high], fs, FilterType.BANDPASS) for key, spectrogram in spectrograms.items()}
# filtered_eeg_spectrograms = {key: ft.apply_filter_to_spectrogram(spectrogram, [cutoff_low, cutoff_high], fs, FilterType.BANDPASS) for key, spectrogram in data_eeg_spectrograms.items()}

# combined_brain_spectrograms = {key: {'raw': spectrograms[key], 'filtered': filtered_brain_spectrograms[key]} for key in spectrograms}
# combined_eeg_spectrograms = {key: {'raw': data_eeg_spectrograms[key], 'filtered': filtered_eeg_spectrograms[key]} for key in data_eeg_spectrograms}

In [8]:
read_path = config.data_spectograms

files = os.listdir(read_path)
print(f"There are {len(files)} spectrogram parquets")

There are 11138 spectrogram parquets


In [9]:
# Create the wav2vec output folder if it does not exist
os.makedirs(config.data_w2v_specs, exist_ok=True)

In [10]:
force_regenerate = False

In [11]:
# Initialize the Filters class
ft = Filters(order=5)

# Define filter parameters for each filter type
cutoffs = {
    FilterType.LOWPASS: 50.0,  # Cutoff frequency for lowpass
    FilterType.HIGHPASS: 0.1,  # Cutoff frequency for highpass
    FilterType.BANDPASS: [0.1, 50.0],  # Low and high cutoff frequencies for bandpass
    FilterType.BANDSTOP: [45.0, 55.0],  # Low and high cutoff frequencies for bandstop
}
fs = 200  # Sampling rate (Hz); the EEG data was collected at 200 samples per second
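
`Filters` belongs to `brain_solver`, so its internals are not shown in this notebook. As a rough illustration of what an order-5 filter with these cutoff types can look like, the sketch below uses SciPy's Butterworth design with zero-phase filtering. The function name `butter_filter`, the string filter types, and the toy signal (sampled at 200 Hz) are assumptions for the example, not the package's API.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt


def butter_filter(x, cutoff, fs, btype, order=5, axis=-1):
    """Zero-phase Butterworth filter along `axis` (hypothetical helper).

    `cutoff` is a scalar for "lowpass"/"highpass" and a [low, high] pair
    for "bandpass"/"bandstop", all in Hz.
    """
    # Second-order sections stay numerically stable for very low cutoffs.
    sos = butter(order, cutoff, btype=btype, fs=fs, output="sos")
    return sosfiltfilt(sos, x, axis=axis)


# Toy signal: a 5 Hz component plus an 80 Hz component.
fs_demo = 200
t = np.arange(0, 5, 1 / fs_demo)
slow = np.sin(2 * np.pi * 5 * t)
fast = np.sin(2 * np.pi * 80 * t)

# A 50 Hz lowpass should keep the 5 Hz part and suppress the 80 Hz part.
low = butter_filter(slow + fast, 50.0, fs_demo, "lowpass")
residual_rms = np.sqrt(np.mean((low - slow) ** 2))
print(f"RMS difference from the 5 Hz component: {residual_rms:.4f}")
```

The same helper accepts the pair-valued cutoffs, for example `butter_filter(x, [45.0, 55.0], fs_demo, "bandstop")`.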

In [12]:
# for i, f in tqdm(enumerate(files), total=len(files)):
#     name = f[:-8]

#     # First, handle the wav2vec-only processing
#     # Define a directory for wav2vec processed data without filtering
#     w2v_only_dir = os.path.join(config.data_w2v_specs, 'wav2vec_only')
#     os.makedirs(w2v_only_dir, exist_ok=True)
    
#     # Define output filename for wav2vec processed data without filtering
#     w2v_only_output_filename = os.path.join(w2v_only_dir, f"{name}.npy")
    
#     # Check if wav2vec processed data without filtering needs to be generated
#     if not os.path.exists(w2v_only_output_filename) or force_regenerate:
#         try:
#             # Load the data from the parquet file
#             parquet_file = pd.read_parquet(os.path.join(read_path, f))
#             data_for_processing = parquet_file.iloc[:, 1:].values
#             # Assuming data_for_processing needs to be in a specific format for wav2vec, adjust as necessary
            
#             # Process with wav2vec and save
#             w2v_only_output = w2v.wav2vec2(data_for_processing)
#             np.save(w2v_only_output_filename, w2v_only_output)
#         except Exception as e:
#             print(f"ERROR: An unexpected error occurred for {name} (wav2vec only): {e}")

#     # Then, continue with the existing loop for filtered data processing
#     for filter_type, cutoff in cutoffs.items():
#         # Define directories for raw and w2v processed data
#         raw_dir = os.path.join(config.data_w2v_specs, filter_type.name, 'raw')
#         w2v_dir = os.path.join(config.data_w2v_specs, filter_type.name, 'w2v')
        
#         # Ensure directories exist
#         os.makedirs(raw_dir, exist_ok=True)
#         os.makedirs(w2v_dir, exist_ok=True)
        
#         # Define output filenames for raw and w2v processed data
#         raw_output_filename = os.path.join(raw_dir, f"{name}.npy")
#         w2v_output_filename = os.path.join(w2v_dir, f"{name}.npy")
        
#         try:
#             # Check if raw filtered data needs to be processed
#             if not os.path.exists(raw_output_filename) or force_regenerate:
#                 if 'data_for_processing' not in locals():
#                     # Load the data only if it hasn't been loaded already
#                     parquet_file = pd.read_parquet(os.path.join(read_path, f))
#                     data_for_processing = parquet_file.iloc[:, 1:].values
#                 parquet_file_non_nan = np.nan_to_num(data_for_processing, nan=0)
                
#                 filtered_spectrogram = ft.apply_filter_to_spectrogram(parquet_file_non_nan, cutoff, fs, filter_type)
#                 np.save(raw_output_filename, filtered_spectrogram)
                
#             # Check if w2v processed data needs to be generated
#             if not os.path.exists(w2v_output_filename) or force_regenerate:
#                 if 'filtered_spectrogram' not in locals():
#                     # Load the existing raw filtered data if it wasn't just generated
#                     filtered_spectrogram = np.load(raw_output_filename)
                
#                 w2v_output = w2v.wav2vec2(filtered_spectrogram)
#                 np.save(w2v_output_filename, w2v_output)

#         except Exception as e:
#             print(f"ERROR: An unexpected error occurred for {name}: {e}")


In [13]:
read_path_eeg = config.path_to_eeg_spectrograms_folder

files_eeg = os.listdir(read_path_eeg)
print(f"There are {len(files_eeg)} EEG spectrogram NPYs")

There are 17089 EEG spectrogram NPYs


In [14]:
# Create the wav2vec EEG output folder if it does not exist
os.makedirs(config.data_w2v_specs_eeg, exist_ok=True)

In [15]:
force_regenerate = False
min_length = 1600
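
The commented-out `pad_sequence` calls in the next cell suggest that inputs shorter than `min_length` were meant to be padded before the wav2vec step. That helper is not defined in this notebook, so the following is only a sketch of what zero-padding to a minimum length could look like. The name `pad_to_min_length` and the choice to pad at the end of the first axis are assumptions.

```python
import numpy as np


def pad_to_min_length(x: np.ndarray, min_length: int, axis: int = 0) -> np.ndarray:
    """Zero-pad `x` at the end of `axis` so it has at least `min_length` entries."""
    missing = min_length - x.shape[axis]
    if missing <= 0:
        return x  # Already long enough; return unchanged
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, missing)
    return np.pad(x, pad_width, mode="constant", constant_values=0)


short = np.ones((1000, 4))
padded = pad_to_min_length(short, 1600)
print(padded.shape)  # (1600, 4)
```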

In [20]:
# Define a directory for wav2vec processed data without filtering
w2v_only_dir = os.path.join(config.data_w2v_specs_eeg, 'wav2vec_only')
os.makedirs(w2v_only_dir, exist_ok=True)

for spectrogram_path in tqdm(files_eeg, total=len(files_eeg)):
    # Use the file stem as the ID (assumes files are named "<eeg_id>.npy")
    eeg_id = os.path.splitext(spectrogram_path)[0]
    spectrogram = None  # Loaded lazily, only when an output is missing

    # Define output filename for wav2vec processed data without filtering
    w2v_only_output_filename = os.path.join(w2v_only_dir, f"{eeg_id}.npy")

    # Process with wav2vec (no filter)
    if force_regenerate or not os.path.exists(w2v_only_output_filename):
        # Assuming w2v.wav2vec2 can handle the preprocessed_data directly
        # padded_spectrogram = pad_sequence(padded_spectrogram, min_length)
        spectrogram = np.load(os.path.join(read_path_eeg, spectrogram_path))
        w2v_output = w2v.wav2vec2(spectrogram, proc_eegs=True)
        np.save(w2v_only_output_filename, w2v_output)

    for filter_type, cutoff in cutoffs.items():
        # Define directories for raw and w2v processed data within each filter type folder
        raw_dir = os.path.join(config.data_w2v_specs_eeg, filter_type.name, 'raw')
        w2v_dir = os.path.join(config.data_w2v_specs_eeg, filter_type.name, 'w2v')
        os.makedirs(raw_dir, exist_ok=True)
        os.makedirs(w2v_dir, exist_ok=True)

        # Define output filenames for raw and w2v processed data within their respective directories
        raw_output_filename = os.path.join(raw_dir, f"{eeg_id}.npy")
        w2v_output_filename = os.path.join(w2v_dir, f"{eeg_id}.npy")

        if force_regenerate or not os.path.exists(raw_output_filename) or not os.path.exists(w2v_output_filename):
            # Load the spectrogram if the unfiltered branch above was skipped
            if spectrogram is None:
                spectrogram = np.load(os.path.join(read_path_eeg, spectrogram_path))

            # Apply filter
            filtered_spectrogram = ft.apply_filter_to_spectrogram(spectrogram, cutoff, fs, filter_type)

            # Save raw filtered data
            np.save(raw_output_filename, filtered_spectrogram)

            # Process filtered data with wav2vec
            # padded_filtered_spectrogram = pad_sequence(filtered_spectrogram, min_length)
            w2v_filtered_output = w2v.wav2vec2(filtered_spectrogram, proc_eegs=True)

            np.save(w2v_output_filename, w2v_filtered_output)

  0%|          | 0/17089 [00:00<?, ?it/s]
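
The processing loop skips a file when its output already exists, unless `force_regenerate` is set. The sketch below isolates that caching pattern in a small helper so it can be checked on its own. The helper name `compute_if_missing` and the temporary directory are assumptions for the example; the notebook itself inlines the check.

```python
import os
import tempfile

import numpy as np


def compute_if_missing(out_path, compute, force=False):
    """Run `compute()` and save the result unless `out_path` already exists.

    Returns True if the output was (re)generated, False if it was skipped.
    """
    if force or not os.path.exists(out_path):
        np.save(out_path, compute())
        return True
    return False


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.npy")
    first = compute_if_missing(path, lambda: np.zeros(3))   # Generated
    second = compute_if_missing(path, lambda: np.zeros(3))  # Skipped (cached)
    forced = compute_if_missing(path, lambda: np.zeros(3), force=True)  # Regenerated

print(first, second, forced)  # True False True
```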