In [1]:
import os
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import gc

from kl_divergence import score as kaggle_score 
from engine_hms_trainer import (
    seed_everything, gen_non_overlap_samples, calc_entropy, evaluate_oof, get_logger, 
    Trainer, TARGETS, TARGETS_PRED, BRAIN_ACTIVITY
    )
from engine_hms_model import (
    KagglePaths, LocalPaths, ModelConfig, CustomDataset, CustomEfficientNET, CustomVITMAE, DualEncoderModel, 
)

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from time import ctime, time
from sklearn.model_selection import KFold

import warnings
# warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', None)

In [2]:
PATHS = KagglePaths if os.path.exists(KagglePaths.OUTPUT_DIR) else LocalPaths
print("Output Dir: ", PATHS.OUTPUT_DIR)

ALL_SPECS = np.load(PATHS.PRE_LOADED_SPECTROGRAMS, allow_pickle=True).item()
ALL_EEGS = np.load(PATHS.PRE_LOADED_EEGS, allow_pickle=True).item()

seed_everything(ModelConfig.SEED)

Output Dir:  ./outputs/


In [3]:
def prepare_k_fold(df, k_folds=5):

    kf = KFold(n_splits=k_folds, shuffle=True, random_state=ModelConfig.SEED)
    unique_spec_id = df['spectrogram_id'].unique()
    df['fold'] = k_folds

    for fold, (train_index, valid_index) in enumerate(kf.split(unique_spec_id)):
        df.loc[df['spectrogram_id'].isin(unique_spec_id[valid_index]), 'fold'] = fold

    return df


def train_fold(model, fold_id, train_folds, valid_folds, logger, stage=1, checkpoint=None):

    train_dataset = CustomDataset(
        train_folds, TARGETS, ModelConfig, ALL_SPECS, ALL_EEGS, mode="train")

    valid_dataset = CustomDataset(
        valid_folds, TARGETS, ModelConfig, ALL_SPECS, ALL_EEGS, mode="valid")

    # ======== DATALOADERS ==========
    loader_kwargs = {
        "batch_size": ModelConfig.BATCH_SIZE,
        "num_workers": ModelConfig.NUM_WORKERS,
        "pin_memory": True,
        "shuffle": False,
    }
    train_loader = DataLoader(train_dataset, drop_last=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, drop_last=False, **loader_kwargs)

    trainer = Trainer(model, ModelConfig, logger)
    best_weights, best_preds, loss_records = trainer.train(
        train_loader, valid_loader, from_checkpoint=checkpoint)

    save_model_name = f"{ModelConfig.MODEL_NAME}_fold_{fold_id}_stage_{stage}.pth"
    torch.save(best_weights, os.path.join(PATHS.OUTPUT_DIR, save_model_name))

    del train_dataset, valid_dataset, train_loader, valid_loader
    torch.cuda.empty_cache()
    gc.collect()

    return best_preds, loss_records

def get_model(pretrained=True):
    
    backbone = ModelConfig.MODEL_BACKBONE

    if "efficientnet" in backbone:
        return CustomEfficientNET(ModelConfig, num_classes=6, pretrained=pretrained)
    elif "vit" in backbone:
        return CustomVITMAE(ModelConfig, num_classes=6, pretrained=pretrained)
    elif "dual" in backbone:
        return DualEncoderModel(ModelConfig, num_classes=6, pretrained=pretrained)
    else:
        return None

In [4]:
TARGET2ID = {'Seizure': 0, 'LPD': 1, 'GPD': 2, 'LRDA': 3, 'GRDA': 4, 'Other': 5}

def calc_kaggle_score(oof_df):
    submission_df = oof_df[['eeg_id']+TARGETS_PRED].copy()
    submission_df.columns = ['eeg_id'] + TARGETS
    solution_df = oof_df[['eeg_id']+TARGETS].copy()
    return kaggle_score(solution_df, submission_df, 'eeg_id')

def analyze_oof(oof_csv):

    kl_criteria = nn.KLDivLoss(reduction='batchmean')
    softmax = nn.Softmax(dim=1)

    oof_df = pd.read_csv(oof_csv)
    oof_df['target_pred'] = oof_df[TARGETS_PRED].apply(lambda x: np.argmax(x), axis=1)
    oof_df['target_id'] = oof_df[TARGETS].apply(lambda x: np.argmax(x), axis=1)
    
    oof_df["kl_loss"] = oof_df.apply(
    lambda row: 
        kl_criteria(
            F.log_softmax(
                    torch.tensor(row[TARGETS_PRED].values.astype(np.float32)).unsqueeze(0)
                , dim=1
                ), 
            torch.tensor(row[TARGETS].values.astype(np.float32))
            ).numpy(),
    axis=1)

    oof_df["kl_loss"] = oof_df['kl_loss'].astype(np.float32)

    oof_df[TARGETS_PRED] = softmax( torch.tensor(oof_df[TARGETS_PRED].values.astype(np.float32)) )

    oof_df.head()

    return oof_df

In [5]:
# train_easy, train_hard, all_specs, all_eegs = load_kaggle_data(
#     paths.TRAIN_CSV, paths.PRE_LOADED_SPECTOGRAMS, paths.PRE_LOADED_EEGS, split_entropy=ModelConfig.SPLIT_ENTROPY)

train_csv = pd.read_csv(PATHS.TRAIN_CSV)
targets = train_csv.columns[-6:]

print("targets: ", targets.to_list())

train_csv['entropy'] = train_csv.apply(calc_entropy, axis=1, tgt_list=targets)
train_csv['total_votes'] = train_csv[targets].sum(axis=1)

hard_csv = train_csv[train_csv['entropy'] < ModelConfig.SPLIT_ENTROPY].copy().reset_index(drop=True)

train_all = gen_non_overlap_samples(train_csv, targets)
train_hard = gen_non_overlap_samples(hard_csv, targets)

print("train_all.shape = ", train_all.shape)
print("train_all nan_count: ", train_all.isnull().sum().sum())
display(train_all.head())

print(" ")

print("train_hard.shape = ", train_hard.shape)
print("train_hard nan_count: ", train_hard.isnull().sum().sum())
display(train_hard.head())

targets:  ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
train_all.shape =  (20183, 12)
train_all nan_count:  0


Unnamed: 0,eeg_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,spectrogram_id,min,max,patient_id,target
0,568657,0.0,0.0,0.25,0.0,0.166667,0.583333,789577333,0.0,16.0,20654,Other
1,582999,0.0,0.857143,0.0,0.071429,0.0,0.071429,1552638400,0.0,38.0,20230,LPD
2,642382,0.0,0.0,0.0,0.0,0.0,1.0,14960202,1008.0,1032.0,5955,Other
3,751790,0.0,0.0,1.0,0.0,0.0,0.0,618728447,908.0,908.0,38549,GPD
4,778705,0.0,0.0,0.0,0.0,0.0,1.0,52296320,0.0,0.0,40955,Other


 
train_hard.shape =  (6187, 12)
train_hard nan_count:  0


Unnamed: 0,eeg_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,spectrogram_id,min,max,patient_id,target
0,568657,0.0,0.0,0.25,0.0,0.166667,0.583333,789577333,0.0,16.0,20654,Other
1,582999,0.0,0.857143,0.0,0.071429,0.0,0.071429,1552638400,0.0,38.0,20230,LPD
2,1895581,0.076923,0.0,0.0,0.0,0.076923,0.846154,128369999,1138.0,1138.0,47999,Other
3,2482631,0.0,0.0,0.133333,0.066667,0.133333,0.666667,978166025,1902.0,1944.0,20606,Other
4,2521897,0.0,0.0,0.083333,0.083333,0.333333,0.5,673742515,0.0,4.0,62117,Other


In [6]:
# # EfficientNet_B2_resplit (CV=0.5330731377333943)
# ModelConfig.EPOCHS = 6
# ModelConfig.BATCH_SIZE = 16
# ModelConfig.GRADIENT_ACCUMULATION_STEPS = 1
# ModelConfig.MODEL_BACKBONE = 'tf_efficientnet_b2'
# ModelConfig.MODEL_NAME = "EfficientNet_B2_resplit"
# ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
# ModelConfig.USE_EEG_SPECTROGRAMS = True
# ModelConfig.REGULARIZATION = None
# ModelConfig.AUGMENT = True
# ModelConfig.AUGMENTATIONS = ['xy_masking']

# # config DualEncoder
# ModelConfig.EPOCHS = 6
# ModelConfig.BATCH_SIZE = 16
# ModelConfig.GRADIENT_ACCUMULATION_STEPS = 1
# ModelConfig.MODEL_BACKBONE = 'dual_encoder'
# ModelConfig.MODEL_NAME = "DualEncoder_B0"
# ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
# ModelConfig.USE_EEG_SPECTROGRAMS = True
# ModelConfig.REGULARIZATION = None
# ModelConfig.AUGMENT = False
# ModelConfig.AUGMENTATIONS = []
# ModelConfig.DUAL_ENCODER_BACKBONE = 'tf_efficientnet_b0'

# Config ViTMAE
ModelConfig.EPOCHS = 15
ModelConfig.BATCH_SIZE = 16
ModelConfig.GRADIENT_ACCUMULATION_STEPS = 1
ModelConfig.MODEL_BACKBONE = 'vit_mae_base'
ModelConfig.MODEL_NAME = "MAE_RawBase_SeqPool_epoch_15"
ModelConfig.AUGMENT = True
ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
ModelConfig.USE_EEG_SPECTROGRAMS = True
ModelConfig.REGULARIZATION = None
ModelConfig.AUGMENTATIONS = ['xy_masking']
ModelConfig.MAE_PRETRAINED_WEIGHTS = "facebook/vit-mae-base" #"./outputs/vit_mae_pretraining/ViTMAE_PreTrained_Best.pth"
ModelConfig.MAE_HIDDEN_DROPOUT_PROB = 0.1
ModelConfig.MAE_ATTENTION_DROPOUT_PROB = 0.1
ModelConfig.DROP_RATE = 0.2

In [7]:
logger = get_logger(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_train.log")

logger.info(f"{'*' * 100}")
logger.info(f"Script Start: {ctime()}")
logger.info(f"Model Configurations:")
for key, value in ModelConfig.__dict__.items():
    if not key.startswith("__"):
        logger.info(f"{key}: {value}")
logger.info(f"{'*' * 100}")

k_folds = 5
train_all = prepare_k_fold(train_all, k_folds=k_folds)

oof_stage_1, oof_stage_2 = pd.DataFrame(), pd.DataFrame()
loss_history_1, loss_history_2 = [], []

for fold in range(k_folds):
    tik = time()

    model = get_model(pretrained=True)
    # model = CustomVITMAE(ModelConfig, num_classes=6, pretrained=False)

    valid_folds = train_all[train_all['fold'] == fold].reset_index(drop=True)
    train_folds = train_all[train_all['fold'] != fold].reset_index(drop=True)

    # ============== STAGE 1 ==============
    logger.info(f"{'=' * 100}\nFold: {fold} || Valid size {valid_folds.shape[0]} \n{'=' * 100}")
    logger.info(f"- First Stage -")
    valid_predicts, loss_records = train_fold(
        model, fold, train_folds, valid_folds, logger, stage=1, checkpoint=None)

    loss_history_1.append(loss_records)
    valid_folds[TARGETS_PRED] = valid_predicts
    oof_stage_1 = pd.concat([oof_stage_1, valid_folds], axis=0).reset_index(drop=True)
    kl_loss_torch = evaluate_oof(valid_folds)
    info = f"{'=' * 100}\nFold {fold} Valid Loss: {kl_loss_torch}\n"
    info += f"Elapse: {(time() - tik) / 60:.2f} min \n{'=' * 100}"
    logger.info(info)
    oof_stage_1.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_1.csv"), index=False)

    # ============== STAGE 2 ==============
    tik = time()
    logger.info(f"- Second Stage -")
    check_point = os.path.join(
        PATHS.OUTPUT_DIR,
        f"{ModelConfig.MODEL_NAME}_fold_{fold}_stage_1.pth"
    )
    logger.info(f"Use Checkpoint: {check_point.split('/')[-1]}")

    model = get_model(pretrained=True)
    valid_predicts, loss_records = train_fold(
        model, fold, train_hard, valid_folds, logger, stage=2, checkpoint=check_point)

    loss_history_2.append(loss_records)
    valid_folds[TARGETS_PRED] = valid_predicts
    oof_stage_2 = pd.concat([oof_stage_2, valid_folds], axis=0).reset_index(drop=True)
    kl_loss_torch = evaluate_oof(valid_folds)
    info = f"{'=' * 100}\nFold {fold} Valid Loss: {kl_loss_torch}\n"
    info += f"Elapse: {(time() - tik) / 60:.2f} min \n{'=' * 100}"
    logger.info(info)
    oof_stage_2.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_2.csv"), index=False)

****************************************************************************************************
Script Start: Mon Mar 25 16:33:33 2024
Model Configurations:
SEED: 20
SPLIT_ENTROPY: 5.5
MODEL_NAME: MAE_RawBase_SeqPool_epoch_15
MODEL_BACKBONE: vit_mae_base
BATCH_SIZE: 16
EPOCHS: 15
GRADIENT_ACCUMULATION_STEPS: 1
DROP_RATE: 0.2
DROP_PATH_RATE: 0.25
WEIGHT_DECAY: 0.01
REGULARIZATION: None
USE_KAGGLE_SPECTROGRAMS: True
USE_EEG_SPECTROGRAMS: True
AMP: True
AUGMENT: True
AUGMENTATIONS: ['xy_masking']
PRINT_FREQ: 50
FREEZE: False
NUM_FROZEN_LAYERS: 0
NUM_WORKERS: 0
MAX_GRAD_NORM: 10000000.0
DUAL_ENCODER_BACKBONE: tf_efficientnet_b2
MAE_PRETRAINED_WEIGHTS: facebook/vit-mae-base
MAE_HIDDEN_DROPOUT_PROB: 0.1
MAE_ATTENTION_DROPOUT_PROB: 0.1
****************************************************************************************************


Loading pretrained weights from facebook/vit-mae-base


Fold: 0 || Valid size 3988 
- First Stage -


Train [0]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 1 [0/1012] | Train Loss: 1.4969 Grad: 119964.2812 LR: 4.0001e-06 | Elapse: 0.91s
Epoch 1 [50/1012] | Train Loss: 1.3669 Grad: 40882.8789 LR: 4.2675e-06 | Elapse: 4.86s
Epoch 1 [100/1012] | Train Loss: 1.3231 Grad: 46299.6016 LR: 5.0462e-06 | Elapse: 8.82s
Epoch 1 [150/1012] | Train Loss: 1.2943 Grad: 42300.8984 LR: 6.3278e-06 | Elapse: 12.76s
Epoch 1 [200/1012] | Train Loss: 1.2733 Grad: 144970.4531 LR: 8.0988e-06 | Elapse: 16.73s
Epoch 1 [250/1012] | Train Loss: 1.2660 Grad: 83684.8203 LR: 1.0340e-05 | Elapse: 20.67s
Epoch 1 [300/1012] | Train Loss: 1.2547 Grad: 93656.4375 LR: 1.3027e-05 | Elapse: 24.62s
Epoch 1 [350/1012] | Train Loss: 1.2509 Grad: 153152.7031 LR: 1.6132e-05 | Elapse: 28.54s
Epoch 1 [400/1012] | Train Loss: 1.2397 Grad: 183264.5625 LR: 1.9622e-05 | Elapse: 32.46s
Epoch 1 [450/1012] | Train Loss: 1.2220 Grad: 103367.3438 LR: 2.3458e-05 | Elapse: 36.40s
Epoch 1 [500/1012] | Train Loss: 1.2089 Grad: 119292.9609 LR: 2.7599e-05 | Elapse: 40.35s
Epoch 1 [550/1012] | 

Valid [0]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 1 [0/250] | Valid Loss: 0.7515 | Elapse: 0.06s
Epoch 1 [50/250] | Valid Loss: 0.9421 | Elapse: 2.51s
Epoch 1 [100/250] | Valid Loss: 0.9596 | Elapse: 5.07s
Epoch 1 [150/250] | Valid Loss: 0.9713 | Elapse: 7.67s
Epoch 1 [200/250] | Valid Loss: 0.9784 | Elapse: 10.30s


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Loss: (train) 1.1018; (valid) 0.9758 | Time: 93.30s
Best model found in epoch 1 | valid loss: 0.9758


Epoch 1 [249/250] | Valid Loss: 0.9758 | Elapse: 12.85s


Train [1]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 2 [0/1012] | Train Loss: 1.1342 Grad: 222867.9844 LR: 7.6143e-05 | Elapse: 0.09s
Epoch 2 [50/1012] | Train Loss: 0.9043 Grad: 101021.7891 LR: 8.0302e-05 | Elapse: 3.97s
Epoch 2 [100/1012] | Train Loss: 0.9019 Grad: 99565.1719 LR: 8.4158e-05 | Elapse: 7.88s
Epoch 2 [150/1012] | Train Loss: 0.9092 Grad: 101741.6719 LR: 8.7669e-05 | Elapse: 11.80s
Epoch 2 [200/1012] | Train Loss: 0.9074 Grad: 135650.0312 LR: 9.0798e-05 | Elapse: 15.78s
Epoch 2 [250/1012] | Train Loss: 0.9121 Grad: 208963.3906 LR: 9.3511e-05 | Elapse: 19.67s
Epoch 2 [300/1012] | Train Loss: 0.9058 Grad: 83757.8047 LR: 9.5780e-05 | Elapse: 23.54s
Epoch 2 [350/1012] | Train Loss: 0.9103 Grad: 164514.6562 LR: 9.7580e-05 | Elapse: 27.58s
Epoch 2 [400/1012] | Train Loss: 0.9060 Grad: 156066.2812 LR: 9.8891e-05 | Elapse: 31.47s
Epoch 2 [450/1012] | Train Loss: 0.9062 Grad: 63541.8789 LR: 9.9700e-05 | Elapse: 35.39s
Epoch 2 [500/1012] | Train Loss: 0.9034 Grad: 130645.9922 LR: 9.9998e-05 | Elapse: 39.27s
Epoch 2 [550/1012] 

Valid [1]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 2 [0/250] | Valid Loss: 0.6560 | Elapse: 0.05s
Epoch 2 [50/250] | Valid Loss: 0.8125 | Elapse: 2.50s
Epoch 2 [100/250] | Valid Loss: 0.7851 | Elapse: 5.05s
Epoch 2 [150/250] | Valid Loss: 0.8023 | Elapse: 7.66s
Epoch 2 [200/250] | Valid Loss: 0.8018 | Elapse: 10.29s


----------------------------------------------------------------------------------------------------
Epoch 2 - Average Loss: (train) 0.8663; (valid) 0.8039 | Time: 91.91s
Best model found in epoch 2 | valid loss: 0.8039


Epoch 2 [249/250] | Valid Loss: 0.8039 | Elapse: 12.83s


Train [2]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 3 [0/1012] | Train Loss: 0.7860 Grad: 172501.9844 LR: 9.9659e-05 | Elapse: 0.10s
Epoch 3 [50/1012] | Train Loss: 0.7720 Grad: 97949.8672 LR: 9.9589e-05 | Elapse: 4.04s
Epoch 3 [100/1012] | Train Loss: 0.7588 Grad: 88034.0312 LR: 9.9512e-05 | Elapse: 7.94s
Epoch 3 [150/1012] | Train Loss: 0.7666 Grad: 120735.2578 LR: 9.9429e-05 | Elapse: 11.96s
Epoch 3 [200/1012] | Train Loss: 0.7683 Grad: 101304.0859 LR: 9.9339e-05 | Elapse: 15.87s
Epoch 3 [250/1012] | Train Loss: 0.7723 Grad: 142041.2500 LR: 9.9243e-05 | Elapse: 19.75s
Epoch 3 [300/1012] | Train Loss: 0.7668 Grad: 72669.3438 LR: 9.9140e-05 | Elapse: 23.64s
Epoch 3 [350/1012] | Train Loss: 0.7755 Grad: 148030.7344 LR: 9.9030e-05 | Elapse: 27.55s
Epoch 3 [400/1012] | Train Loss: 0.7770 Grad: 134449.6250 LR: 9.8914e-05 | Elapse: 31.45s
Epoch 3 [450/1012] | Train Loss: 0.7771 Grad: 80175.9453 LR: 9.8792e-05 | Elapse: 35.33s
Epoch 3 [500/1012] | Train Loss: 0.7735 Grad: 174752.9219 LR: 9.8663e-05 | Elapse: 39.28s
Epoch 3 [550/1012] |

Valid [2]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 3 [0/250] | Valid Loss: 0.4342 | Elapse: 0.06s
Epoch 3 [50/250] | Valid Loss: 0.6990 | Elapse: 2.48s
Epoch 3 [100/250] | Valid Loss: 0.6841 | Elapse: 5.01s
Epoch 3 [150/250] | Valid Loss: 0.7000 | Elapse: 7.56s
Epoch 3 [200/250] | Valid Loss: 0.7012 | Elapse: 10.21s


----------------------------------------------------------------------------------------------------
Epoch 3 - Average Loss: (train) 0.7571; (valid) 0.7087 | Time: 92.18s
Best model found in epoch 3 | valid loss: 0.7087


Epoch 3 [249/250] | Valid Loss: 0.7087 | Elapse: 12.77s


Train [3]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 4 [0/1012] | Train Loss: 0.7487 Grad: 149668.0156 LR: 9.6978e-05 | Elapse: 0.09s
Epoch 4 [50/1012] | Train Loss: 0.7082 Grad: 70096.8359 LR: 9.6778e-05 | Elapse: 4.05s
Epoch 4 [100/1012] | Train Loss: 0.6876 Grad: 74807.9141 LR: 9.6572e-05 | Elapse: 7.97s
Epoch 4 [150/1012] | Train Loss: 0.6960 Grad: 122819.4688 LR: 9.6360e-05 | Elapse: 11.89s
Epoch 4 [200/1012] | Train Loss: 0.6954 Grad: 110480.6172 LR: 9.6141e-05 | Elapse: 15.80s
Epoch 4 [250/1012] | Train Loss: 0.6986 Grad: 124092.7109 LR: 9.5917e-05 | Elapse: 19.72s
Epoch 4 [300/1012] | Train Loss: 0.6957 Grad: 75263.4141 LR: 9.5686e-05 | Elapse: 23.63s
Epoch 4 [350/1012] | Train Loss: 0.7015 Grad: 158884.6562 LR: 9.5450e-05 | Elapse: 27.52s
Epoch 4 [400/1012] | Train Loss: 0.7035 Grad: 118381.5625 LR: 9.5207e-05 | Elapse: 31.44s
Epoch 4 [450/1012] | Train Loss: 0.7027 Grad: 87112.3203 LR: 9.4959e-05 | Elapse: 35.33s
Epoch 4 [500/1012] | Train Loss: 0.6993 Grad: 142761.5625 LR: 9.4704e-05 | Elapse: 39.24s
Epoch 4 [550/1012] |

Valid [3]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 4 [0/250] | Valid Loss: 0.4259 | Elapse: 0.05s
Epoch 4 [50/250] | Valid Loss: 0.6840 | Elapse: 2.49s
Epoch 4 [100/250] | Valid Loss: 0.6525 | Elapse: 5.04s
Epoch 4 [150/250] | Valid Loss: 0.6782 | Elapse: 7.67s
Epoch 4 [200/250] | Valid Loss: 0.6850 | Elapse: 10.32s


----------------------------------------------------------------------------------------------------
Epoch 4 - Average Loss: (train) 0.6964; (valid) 0.6945 | Time: 92.33s
Best model found in epoch 4 | valid loss: 0.6945


Epoch 4 [249/250] | Valid Loss: 0.6945 | Elapse: 12.86s


Train [4]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 5 [0/1012] | Train Loss: 0.7112 Grad: 160418.1719 LR: 9.1765e-05 | Elapse: 0.10s
Epoch 5 [50/1012] | Train Loss: 0.6515 Grad: 90979.2734 LR: 9.1446e-05 | Elapse: 4.09s
Epoch 5 [100/1012] | Train Loss: 0.6373 Grad: 76329.0156 LR: 9.1122e-05 | Elapse: 8.09s
Epoch 5 [150/1012] | Train Loss: 0.6387 Grad: 100398.1953 LR: 9.0792e-05 | Elapse: 12.01s
Epoch 5 [200/1012] | Train Loss: 0.6449 Grad: 77309.0312 LR: 9.0457e-05 | Elapse: 15.94s
Epoch 5 [250/1012] | Train Loss: 0.6471 Grad: 86744.7656 LR: 9.0117e-05 | Elapse: 19.86s
Epoch 5 [300/1012] | Train Loss: 0.6428 Grad: 84923.4766 LR: 8.9771e-05 | Elapse: 23.81s
Epoch 5 [350/1012] | Train Loss: 0.6474 Grad: 153732.4688 LR: 8.9420e-05 | Elapse: 27.74s
Epoch 5 [400/1012] | Train Loss: 0.6489 Grad: 119066.8594 LR: 8.9064e-05 | Elapse: 31.67s
Epoch 5 [450/1012] | Train Loss: 0.6526 Grad: 83390.6172 LR: 8.8703e-05 | Elapse: 35.55s
Epoch 5 [500/1012] | Train Loss: 0.6503 Grad: 113584.0469 LR: 8.8336e-05 | Elapse: 39.43s
Epoch 5 [550/1012] | T

Valid [4]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 5 [0/250] | Valid Loss: 0.5469 | Elapse: 0.05s
Epoch 5 [50/250] | Valid Loss: 0.6528 | Elapse: 2.53s
Epoch 5 [100/250] | Valid Loss: 0.6426 | Elapse: 5.10s
Epoch 5 [150/250] | Valid Loss: 0.6672 | Elapse: 7.72s
Epoch 5 [200/250] | Valid Loss: 0.6690 | Elapse: 10.37s


----------------------------------------------------------------------------------------------------
Epoch 5 - Average Loss: (train) 0.6468; (valid) 0.6736 | Time: 92.48s
Best model found in epoch 5 | valid loss: 0.6736


Epoch 5 [249/250] | Valid Loss: 0.6736 | Elapse: 12.93s


Train [5]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 6 [0/1012] | Train Loss: 0.5598 Grad: 158147.5469 LR: 8.4302e-05 | Elapse: 0.10s
Epoch 6 [50/1012] | Train Loss: 0.6468 Grad: 86923.8203 LR: 8.3881e-05 | Elapse: 4.02s
Epoch 6 [100/1012] | Train Loss: 0.6122 Grad: 91893.5312 LR: 8.3456e-05 | Elapse: 7.94s
Epoch 6 [150/1012] | Train Loss: 0.6163 Grad: 104865.3438 LR: 8.3027e-05 | Elapse: 11.83s
Epoch 6 [200/1012] | Train Loss: 0.6115 Grad: 75960.3516 LR: 8.2593e-05 | Elapse: 15.74s
Epoch 6 [250/1012] | Train Loss: 0.6100 Grad: 106711.6406 LR: 8.2155e-05 | Elapse: 19.66s
Epoch 6 [300/1012] | Train Loss: 0.6085 Grad: 89312.8984 LR: 8.1713e-05 | Elapse: 23.58s
Epoch 6 [350/1012] | Train Loss: 0.6103 Grad: 136169.1875 LR: 8.1267e-05 | Elapse: 27.51s
Epoch 6 [400/1012] | Train Loss: 0.6119 Grad: 136466.6406 LR: 8.0816e-05 | Elapse: 31.40s
Epoch 6 [450/1012] | Train Loss: 0.6126 Grad: 94497.2500 LR: 8.0361e-05 | Elapse: 35.31s
Epoch 6 [500/1012] | Train Loss: 0.6093 Grad: 95490.2578 LR: 7.9903e-05 | Elapse: 39.25s
Epoch 6 [550/1012] | T

Valid [5]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 6 [0/250] | Valid Loss: 0.4617 | Elapse: 0.05s
Epoch 6 [50/250] | Valid Loss: 0.6550 | Elapse: 2.53s
Epoch 6 [100/250] | Valid Loss: 0.6457 | Elapse: 5.11s
Epoch 6 [150/250] | Valid Loss: 0.6660 | Elapse: 7.77s
Epoch 6 [200/250] | Valid Loss: 0.6715 | Elapse: 10.41s


----------------------------------------------------------------------------------------------------
Epoch 6 - Average Loss: (train) 0.6099; (valid) 0.6744 | Time: 92.64s


Epoch 6 [249/250] | Valid Loss: 0.6744 | Elapse: 12.98s


Train [6]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 7 [0/1012] | Train Loss: 0.5924 Grad: 149223.6562 LR: 7.4990e-05 | Elapse: 0.09s
Epoch 7 [50/1012] | Train Loss: 0.5989 Grad: 93245.5859 LR: 7.4491e-05 | Elapse: 4.03s
Epoch 7 [100/1012] | Train Loss: 0.5895 Grad: 80092.4844 LR: 7.3988e-05 | Elapse: 7.90s
Epoch 7 [150/1012] | Train Loss: 0.5853 Grad: 89566.9531 LR: 7.3482e-05 | Elapse: 11.79s
Epoch 7 [200/1012] | Train Loss: 0.5802 Grad: 83732.4297 LR: 7.2973e-05 | Elapse: 15.66s
Epoch 7 [250/1012] | Train Loss: 0.5823 Grad: 85803.3672 LR: 7.2461e-05 | Elapse: 19.53s
Epoch 7 [300/1012] | Train Loss: 0.5806 Grad: 98000.6797 LR: 7.1946e-05 | Elapse: 23.41s
Epoch 7 [350/1012] | Train Loss: 0.5800 Grad: 158732.1250 LR: 7.1428e-05 | Elapse: 27.30s
Epoch 7 [400/1012] | Train Loss: 0.5749 Grad: 101585.0703 LR: 7.0908e-05 | Elapse: 31.34s
Epoch 7 [450/1012] | Train Loss: 0.5752 Grad: 86790.9844 LR: 7.0384e-05 | Elapse: 35.28s
Epoch 7 [500/1012] | Train Loss: 0.5725 Grad: 89048.9219 LR: 6.9858e-05 | Elapse: 39.16s
Epoch 7 [550/1012] | Tra

Valid [6]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 7 [0/250] | Valid Loss: 0.3921 | Elapse: 0.05s
Epoch 7 [50/250] | Valid Loss: 0.6713 | Elapse: 2.49s
Epoch 7 [100/250] | Valid Loss: 0.6519 | Elapse: 5.05s
Epoch 7 [150/250] | Valid Loss: 0.6738 | Elapse: 7.62s
Epoch 7 [200/250] | Valid Loss: 0.6742 | Elapse: 10.25s


----------------------------------------------------------------------------------------------------
Epoch 7 - Average Loss: (train) 0.5682; (valid) 0.6773 | Time: 91.67s


Epoch 7 [249/250] | Valid Loss: 0.6773 | Elapse: 12.80s


Train [7]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 8 [0/1012] | Train Loss: 0.3970 Grad: 145703.7969 LR: 6.4332e-05 | Elapse: 0.09s
Epoch 8 [50/1012] | Train Loss: 0.5390 Grad: 73951.5625 LR: 6.3781e-05 | Elapse: 4.01s
Epoch 8 [100/1012] | Train Loss: 0.5379 Grad: 59853.1484 LR: 6.3228e-05 | Elapse: 7.89s
Epoch 8 [150/1012] | Train Loss: 0.5381 Grad: 108148.7422 LR: 6.2672e-05 | Elapse: 11.77s
Epoch 8 [200/1012] | Train Loss: 0.5382 Grad: 79860.5391 LR: 6.2116e-05 | Elapse: 15.67s
Epoch 8 [250/1012] | Train Loss: 0.5397 Grad: 101969.1719 LR: 6.1557e-05 | Elapse: 19.57s
Epoch 8 [300/1012] | Train Loss: 0.5381 Grad: 95052.9375 LR: 6.0997e-05 | Elapse: 23.45s
Epoch 8 [350/1012] | Train Loss: 0.5374 Grad: 191226.6094 LR: 6.0436e-05 | Elapse: 27.43s
Epoch 8 [400/1012] | Train Loss: 0.5381 Grad: 95677.6484 LR: 5.9873e-05 | Elapse: 31.36s
Epoch 8 [450/1012] | Train Loss: 0.5405 Grad: 126936.6719 LR: 5.9309e-05 | Elapse: 35.28s
Epoch 8 [500/1012] | Train Loss: 0.5387 Grad: 82434.2891 LR: 5.8744e-05 | Elapse: 39.20s
Epoch 8 [550/1012] | T

Valid [7]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 8 [0/250] | Valid Loss: 0.4070 | Elapse: 0.05s
Epoch 8 [50/250] | Valid Loss: 0.7081 | Elapse: 2.51s
Epoch 8 [100/250] | Valid Loss: 0.6699 | Elapse: 5.07s
Epoch 8 [150/250] | Valid Loss: 0.6854 | Elapse: 7.66s
Epoch 8 [200/250] | Valid Loss: 0.6949 | Elapse: 10.32s


----------------------------------------------------------------------------------------------------
Epoch 8 - Average Loss: (train) 0.5340; (valid) 0.7034 | Time: 92.41s


Epoch 8 [249/250] | Valid Loss: 0.7034 | Elapse: 12.89s


Train [8]:   0%|          | 0/1012 [00:00<?, ?batch/s]

Epoch 9 [0/1012] | Train Loss: 0.4009 Grad: 174624.1406 LR: 5.2903e-05 | Elapse: 0.10s
Epoch 9 [50/1012] | Train Loss: 0.5355 Grad: 111552.3516 LR: 5.2329e-05 | Elapse: 3.99s
Epoch 9 [100/1012] | Train Loss: 0.5089 Grad: 50621.9766 LR: 5.1755e-05 | Elapse: 8.01s
Epoch 9 [150/1012] | Train Loss: 0.5147 Grad: 100171.1641 LR: 5.1181e-05 | Elapse: 11.95s
Epoch 9 [200/1012] | Train Loss: 0.5111 Grad: 73752.4062 LR: 5.0606e-05 | Elapse: 15.93s
Epoch 9 [250/1012] | Train Loss: 0.5087 Grad: 86965.5312 LR: 5.0031e-05 | Elapse: 19.90s
Epoch 9 [300/1012] | Train Loss: 0.5059 Grad: 141423.3281 LR: 4.9457e-05 | Elapse: 23.79s
Epoch 9 [350/1012] | Train Loss: 0.5106 Grad: 79033.5156 LR: 4.8882e-05 | Elapse: 27.72s
Epoch 9 [400/1012] | Train Loss: 0.5052 Grad: 111368.2734 LR: 4.8308e-05 | Elapse: 31.62s
Epoch 9 [450/1012] | Train Loss: 0.5051 Grad: 141484.8125 LR: 4.7734e-05 | Elapse: 35.55s
Epoch 9 [500/1012] | Train Loss: 0.5020 Grad: 89760.7266 LR: 4.7160e-05 | Elapse: 39.47s
Epoch 9 [550/1012] | 

Valid [8]:   0%|          | 0/250 [00:00<?, ?batch/s]

Epoch 9 [0/250] | Valid Loss: 0.3888 | Elapse: 0.06s
Epoch 9 [50/250] | Valid Loss: 0.6493 | Elapse: 2.51s
Epoch 9 [100/250] | Valid Loss: 0.6101 | Elapse: 5.07s


In [None]:
csv_path = f'./outputs/{ModelConfig.MODEL_NAME}_oof_2.csv'
print("CSV Path: ", csv_path)

oof_df = analyze_oof(csv_path)

print("Kaggle Score: ", calc_kaggle_score(oof_df))
print("Average KL Loss: ", oof_df["kl_loss"].mean())

display(oof_df.head())

# plot confusion matrix
cm = confusion_matrix(oof_df['target_id'], oof_df['target_pred']) # (y_true, y_pred)
cm = cm / cm.sum(axis=1)[:, np.newaxis]

fig = plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=TARGET2ID.keys(), yticklabels=TARGET2ID.keys())
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('True', fontsize=12)
plt.title(csv_path.split('/')[-1].split('.')[0], fontsize=12)
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_CM.png")
plt.show()

In [None]:
# check distribution of targets
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
train_all["target"].value_counts().plot(kind="bar", ax=axes[0])
train_hard["target"].value_counts().plot(kind="bar", ax=axes[1])
axes[0].set_title("Easy")
axes[1].set_title("Hard")
fig.tight_layout()
plt.show()

In [None]:
# hms_predictor = HMSPredictor(paths.OUTPUT_DIR, ModelConfig, k_fold=5)

In [None]:
# hms_predictor.train_model(train_easy, train_hard, all_specs, all_eegs)

In [None]:


# hms_predictor = HMSPredictor(paths.OUTPUT_DIR, ModelConfig, k_fold=5)

In [None]:
# new figure
fig, axes = plt.subplots(6, 5, figsize=(18, 16), sharex=True, sharey=True)

plot_oof = oof_df[oof_df['kl_loss'] > 0.2]

for row in range(axes.shape[0]):
    row_selects = plot_oof[plot_oof['target_id']==row]
    target_label = BRAIN_ACTIVITY[row]
    for col in range(axes.shape[1]):
        ax = axes[row, col]
        idx = np.random.choice(row_selects.index)
        df_rows = plot_oof.loc[idx]
        ax.plot(df_rows[TARGETS].values , label='True')
        ax.plot(df_rows[TARGETS_PRED].values, label='Pred')
        ax.set_title(f"{idx} | KL: {df_rows['kl_loss']:.4f} ") #
        ax.set_xticks(range(6))
        ax.set_xticklabels(BRAIN_ACTIVITY)
        ax.grid(True)
        ax.legend()
        if col == 0:
            ax.set_ylabel(target_label, fontsize=12)
       
fig.tight_layout()
plt.show()


