In [1]:
import os
import sys
import gc
from argparse import Namespace

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import lightning.pytorch as pl
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
import albumentations as albu
from sklearn.model_selection import KFold, GroupKFold

import pywt, librosa

# Resolve the data root. On Kaggle the inputs live under /kaggle; locally the
# inputs are expected under ../input and the helper modules imported below
# (kaggle_kl_div, utils) are made importable from ../main.
if os.path.isdir("/kaggle"):
    ROOT = "/kaggle"
else:
    sys.path.append("../main")
    ROOT = ".."
from kaggle_kl_div import score
from utils import grid_search, kfold

In [2]:
def preprocess_meta_data(csv_path):
    """Collapse the per-label training metadata to one row per ``eeg_id``.

    Args:
        csv_path: path (or file-like object) accepted by ``pd.read_csv``. The
            last 6 columns of the CSV are treated as the vote/target columns.

    Returns:
        (train, targets):
          * ``train`` -- DataFrame with columns
            ``eeg_id, spec_id, min, max, patient_id, <targets...>, target``.
            ``min``/``max`` are the min/max ``spectrogram_label_offset_seconds``
            of the group; ``spec_id``, ``patient_id`` and ``target``
            (``expert_consensus``) come from the group's first value; the vote
            columns are summed per ``eeg_id`` and normalized to sum to 1.
          * ``targets`` -- the ``pd.Index`` of the 6 vote column names.
    """
    df = pd.read_csv(csv_path)
    targets = df.columns[-6:]
    print("Train shape:", df.shape)
    print("Targets", list(targets))

    # One groupby shared by every aggregation (the original re-grouped 5 times).
    grouped = df.groupby("eeg_id")
    train = grouped.agg(
        spec_id=("spectrogram_id", "first"),
        min=("spectrogram_label_offset_seconds", "min"),
        max=("spectrogram_label_offset_seconds", "max"),
        patient_id=("patient_id", "first"),
    )

    # Sum the votes of all labels of an eeg_id, then normalize each row to a
    # probability distribution. Same groupby => same (sorted) row order.
    votes = grouped[list(targets)].sum().to_numpy()
    votes = votes / votes.sum(axis=1, keepdims=True)
    for col, t in enumerate(targets):
        train[t] = votes[:, col]

    train["target"] = grouped["expert_consensus"].first()
    train = train.reset_index()
    print("Train non-overlapp eeg_id shape:", train.shape)
    return train, targets


def read_spectrogram(spec_path, read_spec_file):
    """Load the training spectrograms keyed by integer spectrogram id.

    Args:
        spec_path: directory prefix (with trailing separator) containing one
            ``<id>.parquet`` file per spectrogram; it is listed in both modes
            so the file count can be printed.
        read_spec_file: if True, read every file in ``spec_path``; otherwise
            load the pre-built dict from
            ``{ROOT}/input/brain-spectrograms/specs.npy``.

    Returns:
        dict mapping int id -> numpy array of the parquet's columns after the
        first.
    """
    file_names = os.listdir(spec_path)
    print(f"There are {len(file_names)} spectrogram parquets")
    if not read_spec_file:
        # NOTE(review): allow_pickle=True unpickles the file -- only load
        # trusted caches. This path is fixed under ROOT and ignores spec_path.
        cached = np.load(
            f"{ROOT}/input/brain-spectrograms/specs.npy", allow_pickle=True
        )
        return cached.item()

    spectrograms = {}
    for idx, file_name in enumerate(file_names):
        if idx % 100 == 0:
            print(idx, ", ", end="")
        frame = pd.read_parquet(f"{spec_path}{file_name}")
        spec_id = int(file_name.split(".")[0])
        spectrograms[spec_id] = frame.iloc[:, 1:].values
    return spectrograms


def read_eeg(meta_data, eeg_path, read_eeg_spec_file):
    """Load pre-computed EEG spectrograms keyed by ``eeg_id``.

    Args:
        meta_data: DataFrame with an ``eeg_id`` column (used only when
            reading individual files).
        eeg_path: directory prefix (with trailing separator).
        read_eeg_spec_file: if True, load ``EEG_Spectrograms/<eeg_id>.npy`` for
            every row; otherwise load the dict stored in ``eeg_specs.npy``.

    Returns:
        dict mapping eeg_id -> numpy array.
    """
    if not read_eeg_spec_file:
        # NOTE(review): allow_pickle=True unpickles the file -- trusted input only.
        return np.load(f"{eeg_path}eeg_specs.npy", allow_pickle=True).item()

    eeg_specs = {}
    for n, eeg_id in enumerate(meta_data.eeg_id.values):
        if n % 100 == 0:
            print(n, ", ", end="")
        eeg_specs[eeg_id] = np.load(f"{eeg_path}EEG_Spectrograms/{eeg_id}.npy")
    return eeg_specs


def load_train(meta_csv_path, spec_path, eeg_path, read_spec_file, read_eeg_spec_file):
    """Load every training input.

    Returns:
        (meta_data, targets, spectrograms, all_eegs) -- the per-eeg_id metadata
        and target column names from ``preprocess_meta_data``, the spectrogram
        dict from ``read_spectrogram`` and the EEG spectrogram dict from
        ``read_eeg``.
    """
    meta_data, targets = preprocess_meta_data(meta_csv_path)
    # Tuple items evaluate left to right: spectrograms are read before EEGs.
    return (
        meta_data,
        targets,
        read_spectrogram(spec_path, read_spec_file),
        read_eeg(meta_data, eeg_path, read_eeg_spec_file),
    )


# Training inputs: competition metadata CSV, Kaggle spectrogram parquets, and
# the directory holding the pre-computed EEG spectrograms.
train_meta_csv_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/train.csv"
spec_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/train_spectrograms/"
eeg_path = f"{ROOT}/input/brain-eeg-spectrograms/"

# Manual loading example (HMS_DM calls load_train itself):
# train_meta, train_targets, train_spectrograms, train_all_eegs = load_train(
#     train_meta_csv_path, spec_path, eeg_path, read_spec_file, read_eeg_spec_file
# )

In [3]:
def test_meta_and_spectrogram(test_meta_path, test_spectrogram_path):
    test_meta_df = pd.read_csv(test_meta_path)
    print("Test shape", test_meta_df.shape)
    test_meta_df.head()
    PATH2 = test_spectrogram_path
    files2 = os.listdir(PATH2)
    print(f"There are {len(files2)} test spectrogram parquets")

    test_spectrograms = {}
    for i, f in enumerate(files2):
        if i % 100 == 0:
            print(i, ", ", end="")
        tmp = pd.read_parquet(f"{PATH2}{f}")
        name = int(f.split(".")[0])
        test_spectrograms[name] = tmp.iloc[:, 1:].values

    # RENAME FOR DATALOADER
    test_meta_df = test_meta_df.rename({"spectrogram_id": "spec_id"}, axis=1)
    return test_meta_df, test_spectrograms


# Wavelet name passed to denoise() (its default is "haar"); None disables
# denoising in spectrogram_from_eeg.
USE_WAVELET = None

# Labels for the four electrode chains.
# NOTE(review): FEATS rows 2 and 3 are the F8-T4 and F4-C4 right-hemisphere
# chains, so "RP"/"RR" here look swapped relative to FEATS -- confirm before
# using NAMES to label channels.
NAMES = "LL LP RP RR".split()

# Each row is an electrode chain; spectrogram_from_eeg differences every
# consecutive pair (4 pairs per chain) and averages them into one channel.
FEATS = [
    "Fp1 F7 T3 T5 O1".split(),
    "Fp1 F3 C3 P3 O1".split(),
    "Fp2 F8 T4 T6 O2".split(),
    "Fp2 F4 C4 P4 O2".split(),
]


# DENOISE FUNCTION
def maddest(d, axis=None):
    """Mean absolute deviation of ``d`` about its mean along ``axis``.

    Args:
        d: array-like input.
        axis: axis to reduce over; None reduces over all elements.

    Returns:
        ``mean(|d - mean(d, axis)|, axis)``.
    """
    # keepdims=True keeps the reduced axis so the mean broadcasts back against
    # ``d`` for ANY axis; without it, e.g. axis=1 on a 2-D array subtracted the
    # row means column-wise.
    return np.mean(np.absolute(d - np.mean(d, axis=axis, keepdims=True)), axis=axis)


def denoise(x, wavelet="haar", level=1):
    """Wavelet-denoise a 1-D signal with a hard universal threshold.

    Args:
        x: 1-D signal.
        wavelet: PyWavelets wavelet name.
        level: which detail-coefficient array, counted from the end of the
            ``wavedec`` output (``coeff[-level]``), estimates the noise scale.

    Returns:
        The reconstructed signal, trimmed to ``len(x)`` samples.
    """
    # Full-depth decomposition (``level`` is not passed to wavedec).
    coeff = pywt.wavedec(x, wavelet, mode="per")
    # Noise scale estimate. NOTE(review): 1/0.6745 is the Gaussian consistency
    # factor for the *median* absolute deviation, but maddest() computes a
    # *mean* absolute deviation -- confirm this scaling is intended.
    sigma = (1 / 0.6745) * maddest(coeff[-level])

    # Universal threshold sigma * sqrt(2 ln N), applied hard to every detail band.
    uthresh = sigma * np.sqrt(2 * np.log(len(x)))
    coeff[1:] = [pywt.threshold(c, value=uthresh, mode="hard") for c in coeff[1:]]

    ret = pywt.waverec(coeff, wavelet, mode="per")
    # Periodization reconstructs an odd-length input one sample longer; trim so
    # the output always aligns sample-for-sample with ``x``.
    return ret[: len(x)]


def spectrogram_from_eeg(parquet_path, display=False):
    """Build a 4-channel log-mel spectrogram image from one EEG recording.

    The middle 10_000 rows are used (50 s at the 200 Hz rate passed to
    librosa). For each electrode chain ``FEATS[k]``, the 4 consecutive pair
    differences are turned into log-mel spectrograms and averaged into
    channel ``k``.

    Args:
        parquet_path: EEG parquet file whose columns include every electrode
            named in ``FEATS``.
        display: if True a matplotlib figure is opened (nothing is drawn on it
            in this function).

    Returns:
        float32 array of shape (128, 256, 4).
    """
    # LOAD MIDDLE 50 SECONDS OF EEG SERIES. Clamp the start at 0: for a
    # recording shorter than 10_000 rows the unclamped offset is negative and
    # iloc would count it from the END, silently keeping only the tail.
    eeg = pd.read_parquet(parquet_path)
    middle = max(0, (len(eeg) - 10_000) // 2)
    eeg = eeg.iloc[middle : middle + 10_000]

    # VARIABLE TO HOLD SPECTROGRAM
    img = np.zeros((128, 256, 4), dtype="float32")

    if display:
        plt.figure(figsize=(10, 7))
    for k in range(4):
        COLS = FEATS[k]

        for kk in range(4):

            # COMPUTE PAIR DIFFERENCES
            x = eeg[COLS[kk]].values - eeg[COLS[kk + 1]].values

            # FILL NANS with the mean of the observed samples; an all-NaN (or
            # empty) channel becomes zeros. Checked first so nanmean never runs
            # on an all-NaN slice (which only emits a RuntimeWarning).
            if np.isnan(x).all():
                x[:] = 0
            else:
                x = np.nan_to_num(x, nan=np.nanmean(x))

            # DENOISE
            if USE_WAVELET:
                x = denoise(x, wavelet=USE_WAVELET)

            # RAW SPECTROGRAM
            mel_spec = librosa.feature.melspectrogram(
                y=x,
                sr=200,
                hop_length=len(x) // 256,
                n_fft=1024,
                n_mels=128,
                fmin=0,
                fmax=20,
                win_length=128,
            )

            # LOG TRANSFORM; keep a multiple of 32 time frames
            width = (mel_spec.shape[1] // 32) * 32
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max).astype(np.float32)[
                :, :width
            ]

            # STANDARDIZE TO -1 TO 1
            mel_spec_db = (mel_spec_db + 40) / 40
            img[:, :, k] += mel_spec_db

        # AVERAGE THE 4 MONTAGE DIFFERENCES
        img[:, :, k] /= 4.0

    return img


def all_spectrogram_from_eeg(eeg_path, eeg_ids):
    """Run ``spectrogram_from_eeg`` on every id and collect the images.

    Args:
        eeg_path: directory prefix (with trailing separator) containing
            ``<eeg_id>.parquet`` files.
        eeg_ids: iterable of EEG ids.

    Returns:
        dict mapping eeg_id -> spectrogram image.
    """
    return {
        eeg_id: spectrogram_from_eeg(f"{eeg_path}{eeg_id}.parquet", False)
        for eeg_id in eeg_ids
    }


# Test-set input locations: metadata CSV, spectrogram parquets, and raw EEG
# parquets (turned into spectrogram images by all_spectrogram_from_eeg).
test_meta_csv_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/test.csv"
test_spectrogram_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/test_spectrograms/"
test_eeg_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/test_eegs/"


def load_test(meta_path, spectrogram_path, eeg_path):
    """Load test metadata, test spectrograms and per-EEG spectrogram images.

    Args:
        meta_path: test metadata CSV.
        spectrogram_path: directory prefix (with trailing separator) of the
            test spectrogram parquets.
        eeg_path: directory prefix (with trailing separator) of the test EEG
            parquets.

    Returns:
        (meta_df, test_spectrograms, all_eegs_test) -- the three values that
        HMS_DM unpacks.
    """
    # Use this function's own arguments rather than the module-level
    # test_* path globals, so callers can point at other directories.
    meta_df, test_spectrograms = test_meta_and_spectrogram(meta_path, spectrogram_path)

    test_eeg_ids = meta_df.eeg_id.unique()
    all_eegs_test = all_spectrogram_from_eeg(eeg_path, test_eeg_ids)
    return meta_df, test_spectrograms, all_eegs_test


# test_meta_df, test_spectrograms, test_eegs = load_test(
#    test_meta_csv_path, test_spectrogram_path, test_eeg_path
# )

# Dataset

In [4]:
class HMS_DS(Dataset):
    """Dataset producing 8-channel (128, 256, 8) float32 images per metadata row.

    Channels 0-3 are 300-row windows of the row's spectrogram (one 100-column
    block per channel), log-transformed, standardized and cropped; channels 4-7
    are the pre-computed EEG spectrogram image for the row's ``eeg_id``.

    Args:
        meta_data: DataFrame with ``spec_id`` and ``eeg_id`` columns, plus
            ``min``/``max`` unless ``mode == "test"``, and the ``targets``
            columns when labels are produced.
        targets: names of the 6 label columns (not read in ``"test"`` mode).
        augment: if True, apply ``_random_transform`` to every image.
        mode: ``"train"`` -> batches are lists of (image, label) tuples;
            ``"test"`` -> windows start at row 0 and only images are returned.
            Other values read offsets/labels but return images only.
        specs: dict spec_id -> 2-D spectrogram array (presumably time x 400
            columns; four 100-column blocks are read -- confirm).
        eeg_specs: dict eeg_id -> array assignable to shape (128, 256, 4).
    """

    def __init__(
        self,
        meta_data,
        targets=None,
        augment=False,
        mode="train",
        specs=None,
        eeg_specs=None,
    ):
        self.meta_data = meta_data
        self.targets = targets
        self.augment = augment
        self.mode = mode
        self.specs = specs
        self.eeg_specs = eeg_specs

    def __len__(self):
        return len(self.meta_data)

    def __getitem__(self, index):
        # Single-sample access: unwrap the 1-element batch so ds[i] is one
        # sample ((image, label) in "train" mode, else an image) as the
        # Dataset protocol expects. DataLoader batches go via __getitems__.
        return self.__getitems__([index])[0]

    def __getitems__(self, indices):
        """Build a batch: list of (image, label) tuples in ``"train"`` mode,
        otherwise the (B, 128, 256, 8) image array."""
        X, y = self._generate_data(indices)
        if self.augment:
            X = self._augment(X)
        if self.mode == "train":
            return list(zip(X, y))
        return X

    def _generate_data(self, indexes):
        """Return (X, y): images (B, 128, 256, 8) and labels (B, 6), both float32.

        Labels stay zero in ``"test"`` mode.
        """
        X = np.zeros((len(indexes), 128, 256, 8), dtype="float32")
        y = np.zeros((len(indexes), 6), dtype="float32")
        ep = 1e-6  # keeps a constant window from dividing by zero

        for j, i in enumerate(indexes):
            row = self.meta_data.iloc[i]
            if self.mode == "test":
                r = 0
            else:
                # (min + max) / 2 is the mid offset; the extra /2 presumably
                # converts seconds to spectrogram rows (2 s per row -- confirm).
                r = int((row["min"] + row["max"]) // 4)
            spec = self.specs[row.spec_id]

            for k in range(4):
                # EXTRACT 300 ROWS OF SPECTROGRAM (transposed: columns x rows)
                img = spec[r : r + 300, k * 100 : (k + 1) * 100].T

                # LOG TRANSFORM SPECTROGRAM
                img = np.log(np.clip(img, np.exp(-4), np.exp(8)))

                # STANDARDIZE PER IMAGE (NaNs ignored, then zeroed)
                m = np.nanmean(img.flatten())
                s = np.nanstd(img.flatten())
                img = np.nan_to_num((img - m) / (s + ep), nan=0.0)

                # CROP TO 256 TIME STEPS, centered in the 128-row frequency axis
                X[j, 14:-14, :, k] = img[:, 22:-22] / 2.0

            # EEG SPECTROGRAMS
            X[j, :, :, 4:] = self.eeg_specs[row.eeg_id]

            if self.mode != "test":
                y[j] = row[self.targets]

        return X, y

    def _random_transform(self, img):
        """Randomly flip one (H, W, C) image horizontally with p=0.5."""
        composition = albu.Compose(
            [
                albu.HorizontalFlip(p=0.5),
                # albu.CoarseDropout(max_holes=8,max_height=32,max_width=32,fill_value=0,p=0.5),
            ]
        )
        return composition(image=img)["image"]

    def _augment(self, img_batch):
        """Apply ``_random_transform`` to each image of the batch in place.

        Named with a single underscore: the previous ``__augment`` was
        name-mangled, so ``self._augment`` in ``__getitems__`` raised
        AttributeError whenever ``augment=True``.
        """
        for i in range(img_batch.shape[0]):
            img_batch[i] = self._random_transform(img_batch[i])
        return img_batch

In [5]:
# train_dataset = HMS_DS(
#    train_meta, train_targets, False, "train", train_spectrograms, train_all_eegs
# )

In [6]:
class HMS_DM(pl.LightningDataModule):
    """LightningDataModule serving HMS_DS for the train/valid fold and test set.

    All data is loaded eagerly in ``__init__``.

    Args:
        hp: hyper-parameters; reads ``read_spec_files``, ``read_eeg_spec_files``,
            ``batch_size`` and, when training data is loaded, ``n_folds`` and
            ``fold``.
        n_workers: DataLoader worker count (``pin_memory`` is on when > 0).
        train_meta_csv_path / train_spec_path / train_eeg_path: training inputs
            forwarded to ``load_train``; skipped when ``train_meta_csv_path``
            is falsy.
        test_meta_csv_path / test_spec_path / test_eeg_path: test inputs
            forwarded to ``load_test``; skipped when ``test_meta_csv_path`` is
            falsy.
    """

    def __init__(
        self,
        hp: Namespace,
        n_workers: int = 0,
        train_meta_csv_path: str | None = None,
        train_spec_path: str | None = None,
        train_eeg_path: str | None = None,
        test_meta_csv_path: str | None = None,
        test_spec_path: str | None = None,
        test_eeg_path: str | None = None,
    ):
        super().__init__()
        # Defaults so every dataloader's assert fails cleanly (rather than with
        # AttributeError) when its split was not loaded.
        self.train_meta_df = None
        self.targets = None
        self.train_spectrograms = None
        self.train_eegs = None
        self.train_split_df = None
        self.valid_split_df = None
        self.test_meta_df = None
        self.test_spectrograms = None
        self.test_eegs = None
        # DataLoader options shared by every split; set unconditionally so a
        # test-only module can still build its predict dataloader.
        self.kwargs = {
            "batch_size": hp.batch_size,
            "num_workers": n_workers,
            "pin_memory": bool(n_workers),
        }

        args = hp.read_spec_files, hp.read_eeg_spec_files
        if train_meta_csv_path:
            (
                self.train_meta_df,
                self.targets,
                self.train_spectrograms,
                self.train_eegs,
            ) = load_train(train_meta_csv_path, train_spec_path, train_eeg_path, *args)
            # Presumably kfold(...)[fold] is (train_positions, valid_positions)
            # -- confirm against utils.kfold.
            fold_idxs = kfold(self.train_meta_df, hp.n_folds)[hp.fold]
            self.train_split_df = self.train_meta_df.iloc[fold_idxs[0]]
            self.valid_split_df = self.train_meta_df.iloc[fold_idxs[1]]
        if test_meta_csv_path:
            self.test_meta_df, self.test_spectrograms, self.test_eegs = load_test(
                test_meta_csv_path, test_spec_path, test_eeg_path
            )

    def train_dataloader(self):
        assert self.train_split_df is not None
        ds = HMS_DS(
            self.train_split_df,
            targets=self.targets,
            augment=False,
            mode="train",
            specs=self.train_spectrograms,
            eeg_specs=self.train_eegs,
        )
        return DataLoader(ds, **(self.kwargs | {"shuffle": True}))

    def val_dataloader(self):
        assert self.valid_split_df is not None
        ds = HMS_DS(
            self.valid_split_df,
            targets=self.targets,
            augment=False,
            mode="train",
            specs=self.train_spectrograms,
            eeg_specs=self.train_eegs,
        )
        return DataLoader(ds, **(self.kwargs | {"shuffle": False}))

    def predict_dataloader(self):
        assert self.test_meta_df is not None
        # Keyword arguments: previously test_spectrograms was passed
        # positionally as ``mode`` and eeg_specs was never set.
        ds = HMS_DS(
            self.test_meta_df,
            targets=self.targets,
            augment=False,
            mode="test",
            specs=self.test_spectrograms,
            eeg_specs=self.test_eegs,
        )
        return DataLoader(ds, **(self.kwargs | {"shuffle": False}))

In [7]:
# Experiment driver: expand the hyper-parameter grid and build a data module
# for (at most) the first n_trials configurations.
n_trials = 1
hp_conf = {
    "seed": 5,
    "n_folds": 5,
    "fold": 0,
    "read_spec_files": False,
    "read_eeg_spec_files": False,
    "batch_size": 32,
}
trials = grid_search(hp_conf, [])
hp = trials
# A falsy n_trials means "all trials"; never exceed the grid size.
n_trials = min(n_trials or len(trials), len(trials))

train_meta_csv_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/train.csv"
spec_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/train_spectrograms/"
eeg_path = f"{ROOT}/input/brain-eeg-spectrograms/"

test_meta_csv_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/test.csv"
test_spec_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/test_spectrograms/"
test_eeg_path = f"{ROOT}/input/hms-harmful-brain-activity-classification/test_eegs/"

for i, hp in enumerate(trials[:n_trials]):
    hms_dm = HMS_DM(
        hp,
        n_workers=0,
        train_meta_csv_path=train_meta_csv_path,
        train_spec_path=spec_path,
        train_eeg_path=eeg_path,
        test_meta_csv_path=test_meta_csv_path,
        test_spec_path=test_spec_path,
        test_eeg_path=test_eeg_path,
    )

Train shape: (106800, 15)
Targets ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
Train non-overlapp eeg_id shape: (17089, 12)
There are 11138 spectrogram parquets
Test shape (1, 3)
There are 1 test spectrogram parquets
0 , 