# Brain Solver Python Training Notebook

This notebook uses the custom `brain_solver` package to analyze brain activity data. The data come from the official Kaggle competition datasets and from additional datasets used for model training and evaluation.

This is the Training notebook.

**Authors: Luppo Sloup, Dick Blankvoort, Tygo Francissen (MLiP Group 9)**

## Data Sources

### Official:

- **HMS - Harmful Brain Activity Classification**
  - **Source:** [Kaggle Competition](https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification)
  - **Description:** This competition focuses on classifying harmful brain activity. It includes a comprehensive dataset for training and testing models.

- **Brain-Spectrograms**
  - **Source:** [Kaggle Dataset](https://www.kaggle.com/datasets/cdeotte/brain-spectrograms)
  - **Description:** The `specs.npy` file contains all the spectrograms from the HMS competition.

### Additional:

- **Brain-EEG-Spectrograms**
  - **Source:** [Kaggle Dataset](https://www.kaggle.com/datasets/cdeotte/brain-eeg-spectrograms)
  - **Description:** The `EEG_Spectrograms` folder contains one NumPy file per EEG ID. Each array has shape (128, 256, 4), corresponding to (frequency, time, montage chain). These spectrograms were computed from the raw EEG data.

- **hms_efficientnetb0_pt_ckpts**
  - **Source:** [Kaggle Dataset](https://www.kaggle.com/datasets/crackle/hms-efficientnetb0-pt-ckpts)
  - **Description:** This dataset offers pre-trained checkpoints for EfficientNetB0 models, tailored for the HMS competition. It's intended for use in fine-tuning models on the specific task of harmful brain activity classification.
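
To make the (frequency, time, montage chain) layout of the Brain-EEG-Spectrograms arrays concrete, the sketch below builds a synthetic array of that shape and slices it. The random data, the `chain_index` name, and the vertical stacking step are illustrative assumptions, not part of `brain_solver`; a real array would be read with `np.load` from the `EEG_Spectrograms` folder.

```python
import numpy as np

# Synthetic stand-in for one per-EEG file (assumption: real data would be
# loaded with np.load from the EEG_Spectrograms folder instead).
rng = np.random.default_rng(0)
eeg_spec = rng.random((128, 256, 4), dtype=np.float32)

n_freq, n_time, n_chains = eeg_spec.shape

# Pick one montage chain (hypothetical index) as a 2D frequency-by-time image.
chain_index = 0
single_chain = eeg_spec[:, :, chain_index]

# Stack the four chains along the frequency axis into one 2D image.
# This is one possible way to feed all chains to an image model (assumption).
stacked = np.concatenate([eeg_spec[:, :, i] for i in range(n_chains)], axis=0)

print(single_chain.shape, stacked.shape)
```

Keeping the chain axis last means that a single chain can be selected with one slice, without copying the other three.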

### Overview:

In addition to the data sources above, the inputs needed for this notebook are shown in the figure below.

<img src="overview_training.png" width="250" />

In [None]:
# Install the packages required by this notebook.
# Run these commands only when executing the notebook on Kaggle.
!pip install d2l --no-index --find-links=file:///kaggle/input/d2l-package/d2l/
!pip install /kaggle/input/brain-solver/brain_solver-1.0.0-py3-none-any.whl

In [None]:
# Imports for the notebook
import os, sys, gc, torch, warnings
import numpy as np, pandas as pd, pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers.utils import logging
from brain_solver import (
    Helpers as hp,
    BrainModel as br,
    EEGDataset,
    Config,
)

# Suppress warnings
warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)

# Setup for CUDA device selection
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## Config Class Summary

The `Config` class manages configurations for the brain activity classification project. It includes:

- **Data and Model Paths**: Centralizes paths for data (e.g., EEG, spectrograms) and model checkpoints.
- **Training Parameters**: Configures training details like epochs, batch size, and learning rate.
- **Feature Flags**: Toggles for model settings, wavelets, spectrograms, and file-reading options.

We designed this class for easy adjustments to facilitate model development and experimentation.
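
The three groups of settings listed above can be pictured with a small dataclass. This is a hypothetical stand-in written for illustration, not the actual `brain_solver.Config`: the class name `SketchConfig`, every field name, the default values, and the `train.csv` file name are assumptions.

```python
import os
from dataclasses import dataclass


@dataclass
class SketchConfig:
    """Hypothetical stand-in; NOT the real brain_solver.Config."""

    # Data and model paths
    data_path: str = ""
    output_path: str = "out/"
    # Training parameters (placeholder values)
    seed: int = 42
    epochs: int = 5
    batch_size: int = 32
    learning_rate: float = 1e-3
    # Feature flags
    use_eeg_spectrograms: bool = True
    use_kaggle_spectrograms: bool = True
    fine_tune: bool = False

    @property
    def train_csv(self) -> str:
        # Derived from data_path, so switching environments updates it too.
        # The file name "train.csv" is an assumption.
        return os.path.join(self.data_path, "train.csv")


local_cfg = SketchConfig()
kaggle_cfg = SketchConfig(data_path="/kaggle/input/", output_path="/kaggle/working/")
print(kaggle_cfg.train_csv)
```

Deriving file paths from a single root is what would let one object serve both a local run and a Kaggle run with only two arguments changed.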

In [None]:
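# Local setup: set a local path for the data here.
# Note: the Kaggle configuration further down in this cell overwrites `config`,
# so comment out that block when running locally.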
full_path = ""
config = Config(
    full_path,
    full_path + "out/",
    USE_EEG_SPECTROGRAMS=True,
    USE_KAGGLE_SPECTROGRAMS=True,
    should_read_brain_spectograms=False,
    should_read_eeg_spectrogram_files=False,
    USE_PRETRAINED_MODEL=False,
    FINE_TUNE=False,
)

# Kaggle setup: this replaces the configuration created earlier in this cell
full_path = "/kaggle/input/"
config = Config(
    full_path,
    "/kaggle/working/",
    USE_EEG_SPECTROGRAMS=True,
    USE_KAGGLE_SPECTROGRAMS=True,
    should_read_brain_spectograms=False,
    should_read_eeg_spectrogram_files=False,
    USE_PRETRAINED_MODEL=False,
    FINE_TUNE=False,
)

# Load scoring function
sys.path.append(full_path + "kaggle-kl-div")
from kaggle_kl_div import score

In [None]:
# Create the output folder if it does not exist
os.makedirs(config.output_path, exist_ok=True)

# Initialize random environment
pl.seed_everything(config.seed, workers=True)

In [None]:
# Read the train CSV file
train_df: pd.DataFrame = hp.load_csv(config.data_train_csv)

# Stop with a clear error instead of calling exit(), which can kill the notebook kernel
if train_df is None:
    raise RuntimeError("Failed to load the CSV file.")

EEG_IDS = train_df.eeg_id.unique()
TARGETS = train_df.columns[-6:]  # the six vote columns
TARS = {"Seizure": 0, "LPD": 1, "GPD": 2, "LRDA": 3, "GRDA": 4, "Other": 5}
TARS_INV = {x: y for y, x in TARS.items()}
print("Train shape:", train_df.shape)

# Preprocess the train data
train_data_preprocessed = hp.preprocess_eeg_data(train_df, TARGETS)
train_data_preprocessed.head()

In [None]:
# Read the Kaggle and EEG spectrograms
spectrograms = hp.read_spectrograms(
    config.data_spectograms,
    config.path_to_brain_spectrograms_npy,
    config.should_read_brain_spectograms,
)
data_eeg_spectrograms = hp.read_eeg_spectrograms(
    train_data_preprocessed,
    config.path_to_eeg_spectrograms_folder,
    config.path_to_eeg_spectrograms_npy,
    config.should_read_eeg_spectrogram_files,
)
print(f"Length of spectrograms: {len(spectrograms)}, Length of all EEGs: {len(data_eeg_spectrograms)}")

In [None]:
# Plot some example spectrograms, then free the temporary objects
dataset = EEGDataset(train_data_preprocessed, spectrograms, data_eeg_spectrograms, TARGETS)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
hp.plot_spectrograms(dataloader, train_data_preprocessed, ROWS=2, COLS=3, BATCHES=2)
del dataset, dataloader, train_df
gc.collect()

In [None]:
# Train the model
all_oof, all_true, valid_loaders = br.cross_validate_eeg(
    config,
    device,
    train_data_preprocessed=train_data_preprocessed,
    spectrograms=spectrograms,
    data_eeg_spectograms=data_eeg_spectrograms,
    TARGETS=TARGETS,
    n_splits=5,
    batch_size_train=32,
    batch_size_valid=64,
    num_workers=3,
    max_epochs_first_stage=5,
    max_epochs_second_stage=3,
)

In [None]:
# Validate the model
all_oof, all_true = br.validate_model_across_folds(config, device, all_oof, all_true, valid_loaders)

In [None]:
# Convert the values to data frames
oof = pd.DataFrame(all_oof.copy())
oof["id"] = np.arange(len(oof))
true = pd.DataFrame(all_true.copy())
true["id"] = np.arange(len(true))

# Calculate the CV score
cv = score(solution=true, submission=oof, row_id_column_name="id")
print("CV Score KL-Div for EfficientNetB0 =", cv)

# Remove the used variables
del data_eeg_spectrograms, spectrograms
gc.collect()
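
The CV score computed above is a Kullback-Leibler divergence between the true vote distributions and the out-of-fold predictions. As a rough illustration of that quantity, the NumPy sketch below averages the per-row divergence over all rows. It is not the `kaggle_kl_div.score` implementation, whose clipping and averaging details may differ; the function name `mean_kl_divergence`, the `eps` value, and the toy arrays are assumptions.

```python
import numpy as np


def mean_kl_divergence(true: np.ndarray, pred: np.ndarray, eps: float = 1e-15) -> float:
    """Mean over rows of sum_k true_k * log(true_k / pred_k)."""
    # Clip predictions away from zero so the logarithm stays finite.
    pred = np.clip(pred, eps, 1.0)
    # Classes with true_k == 0 contribute 0 by convention.
    safe_true = np.where(true > 0, true, 1.0)
    terms = np.where(true > 0, true * np.log(safe_true / pred), 0.0)
    return float(terms.sum(axis=1).mean())


# Toy example with six classes, matching the six targets in this notebook.
true = np.array([
    [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
])
uniform = np.full_like(true, 1.0 / 6.0)

perfect_score = mean_kl_divergence(true, true)
uniform_score = mean_kl_divergence(true, uniform)
print(perfect_score, uniform_score)
```

A perfect prediction scores zero, and the score grows as the predicted distribution moves away from the true one, so lower values are better.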