In [1]:
import os
import sys
import gc
import logging
import random
import time
import warnings

import rich
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, RichProgressBar
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
import albumentations as albu
from sklearn.model_selection import KFold, GroupKFold

# Inside a Kaggle kernel the data lives under /kaggle; anywhere else fall back
# to the parent of the current working directory.
ROOT = "/kaggle" if os.path.isdir("/kaggle") else ".."
INPUT = f"{ROOT}/input"
# Expose the project helper modules imported just below (data, utils, bottle,
# kaggle_kl_div), which live in <ROOT>/main/.
sys.path.append(f"{ROOT}/main/")
from kaggle_kl_div import score
from data import HMS_DM
from utils import grid_search, TBLogger, ExCB
from bottle import HMS_Lightning

In [2]:
# ---------------------------------------------------------------------------
# Run control
# ---------------------------------------------------------------------------
n_trials = 0  # 0 -> run every trial produced by grid_search(hp_conf, hp_skip)
debug = 0  # truthy -> Lightning fast_dev_run and warnings/logging left enabled
n_workers = 0  # passed straight through to HMS_DM
run_folds = [0]  # fold indices to train in this session
seed = 42  # seeds the global RNGs in the training cell
early_stop = 10  # EarlyStopping patience (epochs); 0 disables early stopping
log_dir = "e0"
pred_dir = "subs"
pred = False  # run trainer.predict after each fold
metric = "loss/V"  # monitored by ModelCheckpoint and EarlyStopping

# Hyper-parameters; grid_search expands these into the list of trials.
# NOTE(review): hp_conf["seed"] (5) is distinct from the global `seed` (42)
# above; presumably it is consumed by HMS_DM / HMS_Lightning - confirm which
# one governs fold splits vs. weight init.
hp_conf = {
    "n_epochs": 30,
    "lr": 2e-3,
    "lr_warmup": 0.02,
    "wt_decay": 0.05,
    "n_grad_accum": 1,
    "seed": 5,
    "n_folds": 5,
    "fold": 0,
    "read_spec_files": False,
    "read_eeg_spec_files": False,
    "use_kaggle_spectrograms": True,
    "use_eeg_spectrograms": True,
    "batch_size": 32,
}
hp_skip = []
# Pretrained EfficientNet-B0 weights. In the training cell this value is also
# used as an on/off switch for saving the best checkpoint.
ckpt = f"{INPUT}/hms-efficientnetb0-pt-ckpts/efficientnet_b0_rwightman-7f5810bc.pth"

# ---------------------------------------------------------------------------
# Data locations
# ---------------------------------------------------------------------------
COMP_DIR = f"{INPUT}/hms-harmful-brain-activity-classification"
train_meta_csv_path = f"{COMP_DIR}/train.csv"
train_spec_path = f"{COMP_DIR}/train_spectrograms/"
train_eeg_path = f"{INPUT}/brain-eeg-spectrograms/"
test_meta_csv_path = f"{COMP_DIR}/test.csv"
test_spec_path = f"{COMP_DIR}/test_spectrograms/"
test_eeg_path = f"{COMP_DIR}/test_eegs/"

LOAD_MODELS_FROM = None  # f"{ROOT}/input/hms-efficientnetb0-pt-ckpts/"
# Mirror the hp_conf flags instead of restating them, so the two copies can
# never disagree.
USE_KAGGLE_SPECTROGRAMS = hp_conf["use_kaggle_spectrograms"]
USE_EEG_SPECTROGRAMS = hp_conf["use_eeg_spectrograms"]

In [3]:
def _quit_window(seconds=3):
    """Count down `seconds` so the user can hit Ctrl-C again to abort the sweep."""
    with rich.get_console().status("Quit?") as status:
        for k in range(seconds):
            status.update(f"Quit? {seconds - k}")
            time.sleep(1)


def _collate_preds(preds, tta_flip):
    """Concatenate per-batch prediction outputs into a single float tensor.

    With `tta_flip`, `preds` is assumed to hold two lists of batch outputs
    (plain and flipped views; TODO confirm against HMS_Lightning's predict
    step), which are concatenated separately and averaged.
    """
    if tta_flip:
        plain, flipped = (torch.concat(preds[k]).float() for k in (0, 1))
        return (plain + flipped) / 2
    return torch.concat(preds).float()


# Setup: quiet noisy loggers unless debugging, seed RNGs, create output dirs and
# expand hp_conf into the list of trials.
with rich.get_console().status("Reticulating Splines"):
    if not debug:
        warnings.filterwarnings("ignore")
        for name in logging.root.manager.loggerDict:
            logging.getLogger(name).setLevel(logging.WARN)
    torch.set_float32_matmul_precision("medium")
    # Seed every global RNG in play (numpy was previously left unseeded).
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(pred_dir, exist_ok=True)
    trials = grid_search(hp_conf, hp_skip)
    # n_trials == 0 means "all trials"; never ask for more than exist.
    n_trials = min(n_trials or len(trials), len(trials))

print(f"Log: {log_dir} | EStop: {early_stop} | Ckpt: {ckpt} | Pred: {pred}")
try:
    for i, hp in enumerate(trials[:n_trials]):
        for j, f in enumerate(run_folds):
            print(f"Trial {i + 1}/{n_trials} Fold {j + 1}/{len(run_folds)} ({f})")
            hp.fold = f
            tbl = TBLogger(os.getcwd(), log_dir, default_hp_metric=False)
            cb = [RichProgressBar(), ExCB()]
            # `ckpt` doubles as the switch for saving the best checkpoint.
            cb += [ModelCheckpoint(tbl.log_dir, None, metric)] if ckpt else []
            cb += [EarlyStopping(metric, 0, early_stop)] if early_stop else []
            dm = HMS_DM(
                hp,
                n_workers,
                train_meta_csv_path,
                train_spec_path,
                train_eeg_path,
                test_meta_csv_path,
                test_spec_path,
                test_eeg_path,
            )

            model = HMS_Lightning(hp)
            trainer = pl.Trainer(
                precision="bf16-mixed",
                accelerator="gpu",
                benchmark=True,
                max_epochs=hp.n_epochs,
                accumulate_grad_batches=hp.n_grad_accum,
                # gradient_clip_val=hp.grad_clip,
                fast_dev_run=debug,
                num_sanity_val_steps=0,
                enable_model_summary=False,
                logger=tbl,
                callbacks=cb,
            )
            gc.collect()
            try:
                trainer.fit(model, datamodule=dm)
            except KeyboardInterrupt:
                print("Fit Interrupted")
                # Offer a quit window whenever another fold *or* trial is queued;
                # a second Ctrl-C there is caught by the outer handler.
                if i + 1 < n_trials or j + 1 < len(run_folds):
                    _quit_window()
                continue
            if pred:
                # "best" requires a ModelCheckpoint callback (only added when
                # `ckpt` is set) and fast_dev_run disables checkpointing, so fall
                # back to the in-memory weights otherwise.
                cp = "best" if ckpt and not debug else None
                try:
                    preds = trainer.predict(model, datamodule=dm, ckpt_path=cp)
                except KeyboardInterrupt:
                    print("Prediction Interrupted")
                    continue
                with rich.get_console().status("Processing Submission"):
                    # `tta_flip` is not part of hp_conf; default to no TTA.
                    preds = _collate_preds(preds, getattr(hp, "tta_flip", False))
                    fn = f"{pred_dir}/{log_dir}v{tbl.version:02}"
                    # TODO: re-enable once a `submission` helper is imported.
                    # submission(preds, fn if not debug else None)
                    del preds
except KeyboardInterrupt:
    # Reached when the user presses Ctrl-C again during the quit window.
    print("Goodbye")

Log: e0 | EStop: 10 | Ckpt: ../input/hms-efficientnetb0-pt-ckpts/efficientnet_b0_rwightman-7f5810bc.pth | Pred: False
Trial 1/1 Fold 1/1 (0)
Train shape: (106800, 15)
Targets ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
Train non-overlapp eeg_id shape: (17089, 12)
There are 11138 spectrogram parquets
Test shape (1, 3)
There are 1 test spectrogram parquets
0 , 

Fit Interrupted
