In [1]:
import gc
import logging
import math
import os
import random
import sys
import pickle
import warnings
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from tqdm.notebook import tqdm
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
warnings.simplefilter("ignore")

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from scipy.signal import butter, lfilter
from sklearn.model_selection import GroupKFold
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer, lr_scheduler
from torch.optim.lr_scheduler import _LRScheduler
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import get_cosine_schedule_with_warmup

In [2]:
import ipywidgets as widgets

In [3]:
class _Logger:
    """Customized logger.

    Args:
        logging_level: lowest-severity log message the logger handles
        logging_file: file stream for logging
            *Note: If `logging_file` isn't specified, message is only
                logged to system standard output.
    """

    _logger: logging.Logger = None

    def __init__(
        self,
        logging_level: str = "INFO",
        logging_file: Optional[Path] = None,
    ):
        self.logging_level = logging_level
        self.logging_file = logging_file

        self._build_logger()

    def get_logger(self) -> logging.Logger:
        """Return customized logger."""
        return self._logger

    def _build_logger(self) -> None:
        """Build logger."""
        self._logger = logging.getLogger()
        self._logger.setLevel(self._get_level())
        self._add_handler()

    def _get_level(self) -> int:
        """Return lowest severity of the events the logger handles.

        Returns:
            level: severity of the events
        """
        level = 0

        if self.logging_level == "DEBUG":
            level = logging.DEBUG
        elif self.logging_level == "INFO":
            level = logging.INFO
        elif self.logging_level == "WARNING":
            level = logging.WARNING
        elif self.logging_level == "ERROR":
            level = logging.ERROR
        elif self.logging_level == "CRITICAL":
            level = logging.CRITICAL

        return level

    def _add_handler(self) -> None:
        """Add stream and file (optional) handlers to logger."""
        s_handler = logging.StreamHandler(sys.stdout)
        self._logger.addHandler(s_handler)

        if self.logging_file is not None:
            f_handler = logging.FileHandler(self.logging_file, mode="a")
            self._logger.addHandler(f_handler)

            
def _seed_everything(seed: int) -> None:
    """Seed current experiment to guarantee reproducibility.

    Args:
        seed: manually specified seed number
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # When running with cudnn backend
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True 
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)

In [4]:
DATA_PATH = Path("/data2/users/koushani/HMS_data")

class CFG:
    train_models = True
    seed = 42
    
    exp_id = datetime.now().strftime("%m%d-%H-%M-%S")
    # Define experiment ID and path
    exp_dump_path = Path(DATA_PATH/"kaggle"/"working"/exp_id)

    # Create the directory
    exp_dump_path.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # == Data ==
    gen_eegs = False
    # Chris' 8 channels
    feats = [
        "Fp1", "T3", "C3", "O1",
        "Fp2", "C4", "T4", "O2"
    ]
    cast_eegs = True
    dataset = {
        "eeg": {
            "n_feats": 8,
            "apply_chris_magic_ch8": True,
            "normalize": True,
            "apply_butter_lowpass_filter": True,
            "apply_mu_law_encoding": False,
            "downsample": 5
        }
    }

    # == Trainer ==
    trainer = {
        "epochs": 5,
        "lr": 1e-3,
        "dataloader": {
            "batch_size": 32,
            "shuffle": True,
            "num_workers": 2
        },
        "use_amp": True,
        "grad_accum_steps": 1,
        "model_ckpt": {
            "ckpt_metric": "kldiv",
            "ckpt_mode": "min",
            "best_ckpt_mid": "last"
        },
        "es": {"patience": 0},
        "step_per_batch": True,
        "one_batch_only": False
    }
    
    # == Debug ==
    one_fold_only = True
    
    
N_CLASSES = 6
TGT_VOTE_COLS = [
    "seizure_vote", "lpd_vote", "gpd_vote", "lrda_vote",
    "grda_vote", "other_vote"
]
TGT_COL = "target"
EEG_FREQ = 200  # Hz
EEG_WLEN = 50  # sec
EEG_PTS = int(EEG_FREQ * EEG_WLEN)

In [5]:
if not CFG.exp_dump_path.exists():
    os.mkdir(CFG.exp_dump_path)
    
logger = _Logger(logging_file=CFG.exp_dump_path / "train_eval.log").get_logger()
_seed_everything(CFG.seed)

In [6]:
def _get_eeg_window(file: Path) -> np.ndarray:
    """Return cropped EEG window.

    Default setting is to return the middle 50-sec window.

    Args:
        file: EEG file path
        test: if True, there's no need to truncate EEGs

    Returns:
        eeg_win: cropped EEG window 
    """
    eeg = pd.read_parquet(file, columns=CFG.feats)
    n_pts = len(eeg)
    offset = (n_pts - EEG_PTS) // 2
    eeg = eeg.iloc[offset:offset + EEG_PTS]
    
    eeg_win = np.zeros((EEG_PTS, len(CFG.feats)))
    for j, col in enumerate(CFG.feats):
        if CFG.cast_eegs:
            eeg_raw = eeg[col].values.astype("float32")
        else:
            eeg_raw = eeg[col].values 

        # Fill missing values
        mean = np.nanmean(eeg_raw)
        if np.isnan(eeg_raw).mean() < 1:
            eeg_raw = np.nan_to_num(eeg_raw, nan=mean)
        else: 
            # All missing
            eeg_raw[:] = 0
        eeg_win[:, j] = eeg_raw 
        
    return eeg_win 

In [7]:
train = pd.read_csv(DATA_PATH / "train.csv")
logger.info(f"Train data shape | {train.shape}")

Train data shape | (106800, 15)


In [8]:
# Define paths
eeg_file_path = DATA_PATH/"kaggle" /"input"/ "brain-eegs" / "eegs.npy"

# Ensure directories exist
eeg_file_path.parent.mkdir(parents=True, exist_ok=True)

# Unique EEG IDs
uniq_eeg_ids = train["eeg_id"].unique()
n_uniq_eeg_ids = len(uniq_eeg_ids)

logger.info("Load cropped EEGs...")
all_eegs = np.load(eeg_file_path, allow_pickle=True).item()
assert len(all_eegs) == n_uniq_eeg_ids

# Debug: Print a sample EEG shape
logger.info(f"Demo EEG shape | {list(all_eegs.values())[0].shape}")

Load cropped EEGs...
Demo EEG shape | (10000, 8)


In [9]:
logger.info(f"Process labels...")
df_tmp = train.groupby("eeg_id")[["patient_id"]].agg("first")
labels_tmp = train.groupby("eeg_id")[TGT_VOTE_COLS].agg("sum")
for col in TGT_VOTE_COLS:
    df_tmp[col] = labels_tmp[col].values

# Normalize target columns
y_data = df_tmp[TGT_VOTE_COLS].values
y_data = y_data / y_data.sum(axis=1, keepdims=True)
df_tmp[TGT_VOTE_COLS] = y_data

tgt = train.groupby("eeg_id")[["expert_consensus"]].agg("first")
df_tmp[TGT_COL] = tgt 

train = df_tmp.reset_index()
logger.info(f"Training DataFrame shape | {train.shape}")

Process labels...
Training DataFrame shape | (17089, 9)


In [10]:
class _EEGTransformer(object):
    """Data transformer for raw EEG signals."""

    FEAT2CODE = {f: i for i, f in enumerate(CFG.feats)}

    def __init__(
        self,
        n_feats: int,
        apply_chris_magic_ch8: bool = True,
        normalize: bool = True,
        apply_butter_lowpass_filter: bool = True,
        apply_mu_law_encoding: bool = False,
        downsample: Optional[int] = None,
    ) -> None:
        self.n_feats = n_feats
        self.apply_chris_magic_ch8 = apply_chris_magic_ch8
        self.normalize = normalize
        self.apply_butter_lowpass_filter = apply_butter_lowpass_filter
        self.apply_mu_law_encoding = apply_mu_law_encoding
        self.downsample = downsample

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Apply transformation on raw EEG signals.
        
        Args:
            x: raw EEG signals, with shape (L, C)

        Return:
            x_: transformed EEG signals
        """
        x_ = x.copy()
        if self.apply_chris_magic_ch8:
            x_ = self._apply_chris_magic_ch8(x_)

        if self.normalize:
            x_ = np.clip(x_, -1024, 1024)
            x_ = np.nan_to_num(x_, nan=0) / 32.0

        if self.apply_butter_lowpass_filter:
            x_ = self._butter_lowpass_filter(x_) 

        if self.apply_mu_law_encoding:
            x_ = self._quantize_data(x_, 1)

        if self.downsample is not None:
            x_ = x_[::self.downsample, :]

        return x_

    def _apply_chris_magic_ch8(self, x: np.ndarray) -> np.ndarray:
        """Generate features based on Chris' magic formula.""" 
        x_tmp = np.zeros((EEG_PTS, self.n_feats), dtype="float32")

        # Generate features
        x_tmp[:, 0] = x[:, self.FEAT2CODE["Fp1"]] - x[:, self.FEAT2CODE["T3"]]
        x_tmp[:, 1] = x[:, self.FEAT2CODE["T3"]] - x[:, self.FEAT2CODE["O1"]]
        
        x_tmp[:, 2] = x[:, self.FEAT2CODE["Fp1"]] - x[:, self.FEAT2CODE["C3"]]
        x_tmp[:, 3] = x[:, self.FEAT2CODE["C3"]] - x[:, self.FEAT2CODE["O1"]]
        
        x_tmp[:, 4] = x[:, self.FEAT2CODE["Fp2"]] - x[:, self.FEAT2CODE["C4"]]
        x_tmp[:, 5] = x[:, self.FEAT2CODE["C4"]] - x[:, self.FEAT2CODE["O2"]]
        
        x_tmp[:, 6] = x[:, self.FEAT2CODE["Fp2"]] - x[:, self.FEAT2CODE["T4"]]
        x_tmp[:, 7] = x[:, self.FEAT2CODE["T4"]] - x[:, self.FEAT2CODE["O2"]]

        return x_tmp

    def _butter_lowpass_filter(self, data, cutoff_freq=20, sampling_rate=200, order=4):
        nyquist = 0.5 * sampling_rate
        normal_cutoff = cutoff_freq / nyquist
        b, a = butter(order, normal_cutoff, btype="low", analog=False)
        filtered_data = lfilter(b, a, data, axis=0)

        return filtered_data
                
    def _quantize_data(self, data, classes):
        mu_x = self._mu_law_encoding(data, classes)
        
        return mu_x

    def _mu_law_encoding(self, data, mu):
        mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1)

        return mu_x

In [11]:
class EEGDataset(Dataset):
    """Dataset for pure raw EEG signals.

    Args:
        data: processed data
        split: data split

    Attributes:
        _n_samples: number of samples
        _infer: if True, the dataset is constructed for inference
            *Note: Ground truth is not provided.
    """

    def __init__(
        self,
        data: Dict[str,  Any],
        split: str,
        **dataset_cfg: Any,
    ) -> None:
        self.metadata = data["meta"]
        self.all_eegs = data["eeg"]
        self.dataset_cfg = dataset_cfg

        # Raw EEG data transformer
        self.eeg_params = dataset_cfg["eeg"]
        self.eeg_trafo = _EEGTransformer(**self.eeg_params)

        self._set_n_samples()
        self._infer = True if split == "test" else False

        self._stream_X = True if self.all_eegs is None else False
        self._X, self._y = self._transform()

    def _set_n_samples(self) -> None:
        assert len(self.metadata) == self.metadata["eeg_id"].nunique()
        self._n_samples = len(self.metadata)

    def _transform(self) -> Tuple[Optional[np.ndarray], np.ndarray]:
        """Transform feature and target matrices."""
        if self.eeg_params["downsample"] is not None:
            eeg_len = int(EEG_PTS / self.eeg_params["downsample"])
        else:
            eeg_len = int(EEG_PTS)
        if not self._stream_X:
            X = np.zeros((self._n_samples, eeg_len, self.eeg_params["n_feats"]), dtype="float32")
        else:
            X = None
        y = np.zeros((self._n_samples, N_CLASSES), dtype="float32") if not self._infer else None

        for i, row in tqdm(self.metadata.iterrows(), total=len(self.metadata)):
            # Process raw EEG signals
            if not self._stream_X:
                # Retrieve raw EEG signals
                eeg = self.all_eegs[row["eeg_id"]]

                # Apply EEG transformer
                x = self.eeg_trafo.transform(eeg)

                X[i] = x

            if not self._infer:
                y[i] = row[TGT_VOTE_COLS] 

        return X, y

    def __len__(self) -> int:
        return self._n_samples

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
        if self._X is None:
            # Load data here...
#             x = np.load(...)
#             x = self.eeg_trafo.transform(x)
            pass
        else:
            x = self._X[idx, ...]
        data_sample = {"x": torch.tensor(x, dtype=torch.float32)}
        if not self._infer:
            data_sample["y"] = torch.tensor(self._y[idx, :], dtype=torch.float32)

        return data_sample

In [12]:
        
if CFG.train_models:
    oof = np.zeros((len(train), N_CLASSES))
    prfs = []

    cv = GroupKFold(n_splits=5)
    for fold, (tr_idx, val_idx) in enumerate(cv.split(train, train[TGT_COL], train["patient_id"])):
        logger.info(f"== Train and Eval Process - Fold{fold} ==")

        # Build dataloaders
        data_tr, data_val = train.iloc[tr_idx].reset_index(drop=True), train.iloc[val_idx].reset_index(drop=True)
        train_loader = DataLoader(
            EEGDataset({"meta": data_tr, "eeg": all_eegs}, "train", **CFG.dataset),
            shuffle=CFG.trainer["dataloader"]["shuffle"],
            batch_size=CFG.trainer["dataloader"]["batch_size"],
            num_workers=CFG.trainer["dataloader"]["num_workers"]
        )
        val_loader = DataLoader(
            EEGDataset({"meta": data_val, "eeg": all_eegs}, "valid", **CFG.dataset),
            shuffle=False,
            batch_size=CFG.trainer["dataloader"]["batch_size"],
            num_workers=CFG.trainer["dataloader"]["num_workers"]
        )
else:
    file_path = DATA_PATH / "kaggle/input/hms-oof-demo/oof_seed0.npy"
    print(f"Checking file path: {file_path}")
    print(f"File exists: {os.path.exists(file_path)}")
    oof = np.load(file_path)

== Train and Eval Process - Fold0 ==


Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


AttributeError: 'FloatProgress' object has no attribute 'style'