In [1]:
import os
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import gc

from kl_divergence import score as kaggle_score 
from engine_hms_trainer import (
    seed_everything, gen_non_overlap_samples, calc_entropy, evaluate_oof, get_logger, 
    Trainer, TARGETS, TARGETS_PRED, BRAIN_ACTIVITY
    )
from engine_hms_model import (
    KagglePaths, LocalPaths, ModelConfig, CustomDataset, CustomEfficientNET, CustomVITMAE, DualEncoderModel, 
)

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from time import ctime, time
from sklearn.model_selection import KFold

import warnings
# warnings.filterwarnings('ignore')

pd.set_option('display.max_columns', None)

In [2]:
PATHS = KagglePaths if os.path.exists(KagglePaths.OUTPUT_DIR) else LocalPaths
print("Output Dir: ", PATHS.OUTPUT_DIR)

ALL_SPECS = np.load(PATHS.PRE_LOADED_SPECTROGRAMS, allow_pickle=True).item()
ALL_EEGS = np.load(PATHS.PRE_LOADED_EEGS, allow_pickle=True).item()

seed_everything(ModelConfig.SEED)

Output Dir:  ./outputs/


In [4]:
# train_easy, train_hard, all_specs, all_eegs = load_kaggle_data(
#     paths.TRAIN_CSV, paths.PRE_LOADED_SPECTOGRAMS, paths.PRE_LOADED_EEGS, split_entropy=ModelConfig.SPLIT_ENTROPY)

train_csv = pd.read_csv(PATHS.TRAIN_CSV)
targets = train_csv.columns[-6:]

print("targets: ", targets.to_list())

train_csv['entropy'] = train_csv.apply(calc_entropy, axis=1, tgt_list=targets)
train_csv['total_votes'] = train_csv[targets].sum(axis=1)

hard_csv = train_csv[train_csv['entropy'] < ModelConfig.SPLIT_ENTROPY].copy().reset_index(drop=True)

train_all = gen_non_overlap_samples(train_csv, targets)
train_hard = gen_non_overlap_samples(hard_csv, targets)
train_all['entropy'] = train_all.apply(calc_entropy, axis=1, tgt_list=targets)

print("train_all.shape = ", train_all.shape)
print("train_all nan_count: ", train_all.isnull().sum().sum())
display(train_all.head())

print(" ")

print("train_hard.shape = ", train_hard.shape)
print("train_hard nan_count: ", train_hard.isnull().sum().sum())
display(train_hard.head())

targets:  ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
train_all.shape =  (17089, 13)
train_all nan_count:  0


Unnamed: 0,eeg_id,spectrogram_id,min,max,patient_id,target,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,total_votes
0,568657,789577333,0.0,16.0,20654,Other,0.0,0.0,0.25,0.0,0.166667,0.583333,48
1,582999,1552638400,0.0,38.0,20230,LPD,0.0,0.857143,0.0,0.071429,0.0,0.071429,154
2,642382,14960202,1008.0,1032.0,5955,Other,0.0,0.0,0.0,0.0,0.0,1.0,2
3,751790,618728447,908.0,908.0,38549,GPD,0.0,0.0,1.0,0.0,0.0,0.0,1
4,778705,52296320,0.0,0.0,40955,Other,0.0,0.0,0.0,0.0,0.0,1.0,2


 
train_hard.shape =  (5536, 13)
train_hard nan_count:  0


Unnamed: 0,eeg_id,spectrogram_id,min,max,patient_id,target,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,total_votes
0,568657,789577333,0.0,16.0,20654,Other,0.0,0.0,0.25,0.0,0.166667,0.583333,48
1,582999,1552638400,0.0,38.0,20230,LPD,0.0,0.857143,0.0,0.071429,0.0,0.071429,154
2,1895581,128369999,1138.0,1138.0,47999,Other,0.076923,0.0,0.0,0.0,0.076923,0.846154,13
3,2482631,978166025,1902.0,1944.0,20606,Other,0.0,0.0,0.133333,0.066667,0.133333,0.666667,105
4,2521897,673742515,0.0,4.0,62117,Other,0.0,0.0,0.083333,0.083333,0.333333,0.5,24


In [None]:
def prepare_k_fold(df, k_folds=5):

    kf = KFold(n_splits=k_folds, shuffle=True, random_state=ModelConfig.SEED)
    unique_spec_id = df['spectrogram_id'].unique()
    df['fold'] = k_folds

    for fold, (train_index, valid_index) in enumerate(kf.split(unique_spec_id)):
        df.loc[df['spectrogram_id'].isin(unique_spec_id[valid_index]), 'fold'] = fold

    return df


def train_fold(model, fold_id, train_folds, valid_folds, logger, stage=1, checkpoint=None):

    train_dataset = CustomDataset(
        train_folds, TARGETS, ModelConfig, ALL_SPECS, ALL_EEGS, mode="train")

    valid_dataset = CustomDataset(
        valid_folds, TARGETS, ModelConfig, ALL_SPECS, ALL_EEGS, mode="valid")

    # ======== DATALOADERS ==========
    loader_kwargs = {
        "batch_size": ModelConfig.BATCH_SIZE,
        "num_workers": ModelConfig.NUM_WORKERS,
        "pin_memory": True,
        "shuffle": False,
    }
    train_loader = DataLoader(train_dataset, drop_last=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, drop_last=False, **loader_kwargs)

    trainer = Trainer(model, ModelConfig, logger)
    best_weights, best_preds, loss_records = trainer.train(
        train_loader, valid_loader, from_checkpoint=checkpoint)

    save_model_name = f"{ModelConfig.MODEL_NAME}_fold_{fold_id}_stage_{stage}.pth"
    torch.save(best_weights, os.path.join(PATHS.OUTPUT_DIR, save_model_name))

    del train_dataset, valid_dataset, train_loader, valid_loader
    torch.cuda.empty_cache()
    gc.collect()

    return best_preds, loss_records

def get_model(pretrained=True):
    
    backbone = ModelConfig.MODEL_BACKBONE

    if "efficientnet" in backbone:
        return CustomEfficientNET(ModelConfig, num_classes=6, pretrained=pretrained)
    elif "vit" in backbone:
        return CustomVITMAE(ModelConfig, num_classes=6, pretrained=pretrained)
    elif "dual" in backbone:
        return DualEncoderModel(ModelConfig, num_classes=6, pretrained=pretrained)
    else:
        return None

In [None]:
TARGET2ID = {'Seizure': 0, 'LPD': 1, 'GPD': 2, 'LRDA': 3, 'GRDA': 4, 'Other': 5}

def calc_kaggle_score(oof_df):
    submission_df = oof_df[['eeg_id']+TARGETS_PRED].copy()
    submission_df.columns = ['eeg_id'] + TARGETS
    solution_df = oof_df[['eeg_id']+TARGETS].copy()
    return kaggle_score(solution_df, submission_df, 'eeg_id')

def analyze_oof(oof_csv):

    kl_criteria = nn.KLDivLoss(reduction='batchmean')
    softmax = nn.Softmax(dim=1)

    oof_df = pd.read_csv(oof_csv)
    oof_df['target_pred'] = oof_df[TARGETS_PRED].apply(lambda x: np.argmax(x), axis=1)
    oof_df['target_id'] = oof_df[TARGETS].apply(lambda x: np.argmax(x), axis=1)
    
    oof_df["kl_loss"] = oof_df.apply(
    lambda row: 
        kl_criteria(
            F.log_softmax(
                    torch.tensor(row[TARGETS_PRED].values.astype(np.float32)).unsqueeze(0)
                , dim=1
                ), 
            torch.tensor(row[TARGETS].values.astype(np.float32))
            ).numpy(),
    axis=1)

    oof_df["kl_loss"] = oof_df['kl_loss'].astype(np.float32)

    oof_df[TARGETS_PRED] = softmax( torch.tensor(oof_df[TARGETS_PRED].values.astype(np.float32)) )

    oof_df.head()

    return oof_df

In [None]:
# # EfficientNet_B2_resplit (CV=0.5330731377333943)
# ModelConfig.EPOCHS = 6
# ModelConfig.BATCH_SIZE = 16
# ModelConfig.GRADIENT_ACCUMULATION_STEPS = 1
# ModelConfig.MODEL_BACKBONE = 'tf_efficientnet_b2'
# ModelConfig.MODEL_NAME = "EfficientNet_B2_resplit"
# ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
# ModelConfig.USE_EEG_SPECTROGRAMS = True
# ModelConfig.REGULARIZATION = None
# ModelConfig.AUGMENT = True
# ModelConfig.AUGMENTATIONS = ['xy_masking']

# # config DualEncoder
# ModelConfig.EPOCHS = 6
# ModelConfig.BATCH_SIZE = 16
# ModelConfig.GRADIENT_ACCUMULATION_STEPS = 1
# ModelConfig.MODEL_BACKBONE = 'dual_encoder'
# ModelConfig.MODEL_NAME = "DualEncoder_B0"
# ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
# ModelConfig.USE_EEG_SPECTROGRAMS = True
# ModelConfig.REGULARIZATION = None
# ModelConfig.AUGMENT = False
# ModelConfig.AUGMENTATIONS = []
# ModelConfig.DUAL_ENCODER_BACKBONE = 'tf_efficientnet_b0'

# Config ViTMAE
ModelConfig.EPOCHS = 15
ModelConfig.BATCH_SIZE = 16
ModelConfig.GRADIENT_ACCUMULATION_STEPS = 1
ModelConfig.MODEL_BACKBONE = 'vit_mae_base'
ModelConfig.MODEL_NAME = "MAE_RawBase_SeqPool_epoch_15"
ModelConfig.AUGMENT = True
ModelConfig.USE_KAGGLE_SPECTROGRAMS = True
ModelConfig.USE_EEG_SPECTROGRAMS = True
ModelConfig.REGULARIZATION = None
ModelConfig.AUGMENTATIONS = ['xy_masking']
ModelConfig.MAE_PRETRAINED_WEIGHTS = "facebook/vit-mae-base" #"./outputs/vit_mae_pretraining/ViTMAE_PreTrained_Best.pth"
ModelConfig.MAE_HIDDEN_DROPOUT_PROB = 0.1
ModelConfig.MAE_ATTENTION_DROPOUT_PROB = 0.1
ModelConfig.DROP_RATE = 0.2

In [None]:
logger = get_logger(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_train.log")

logger.info(f"{'*' * 100}")
logger.info(f"Script Start: {ctime()}")
logger.info(f"Model Configurations:")
for key, value in ModelConfig.__dict__.items():
    if not key.startswith("__"):
        logger.info(f"{key}: {value}")
logger.info(f"{'*' * 100}")

k_folds = 5
train_all = prepare_k_fold(train_all, k_folds=k_folds)

oof_stage_1, oof_stage_2 = pd.DataFrame(), pd.DataFrame()
loss_history_1, loss_history_2 = [], []

for fold in range(k_folds):
    tik = time()

    model = get_model(pretrained=True)
    # model = CustomVITMAE(ModelConfig, num_classes=6, pretrained=False)

    valid_folds = train_all[train_all['fold'] == fold].reset_index(drop=True)
    train_folds = train_all[train_all['fold'] != fold].reset_index(drop=True)

    # ============== STAGE 1 ==============
    logger.info(f"{'=' * 100}\nFold: {fold} || Valid size {valid_folds.shape[0]} \n{'=' * 100}")
    logger.info(f"- First Stage -")
    valid_predicts, loss_records = train_fold(
        model, fold, train_folds, valid_folds, logger, stage=1, checkpoint=None)

    loss_history_1.append(loss_records)
    valid_folds[TARGETS_PRED] = valid_predicts
    oof_stage_1 = pd.concat([oof_stage_1, valid_folds], axis=0).reset_index(drop=True)
    kl_loss_torch = evaluate_oof(valid_folds)
    info = f"{'=' * 100}\nFold {fold} Valid Loss: {kl_loss_torch}\n"
    info += f"Elapse: {(time() - tik) / 60:.2f} min \n{'=' * 100}"
    logger.info(info)
    oof_stage_1.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_1.csv"), index=False)

    # ============== STAGE 2 ==============
    tik = time()
    logger.info(f"- Second Stage -")
    check_point = os.path.join(
        PATHS.OUTPUT_DIR,
        f"{ModelConfig.MODEL_NAME}_fold_{fold}_stage_1.pth"
    )
    logger.info(f"Use Checkpoint: {check_point.split('/')[-1]}")

    model = get_model(pretrained=True)
    valid_predicts, loss_records = train_fold(
        model, fold, train_hard, valid_folds, logger, stage=2, checkpoint=check_point)

    loss_history_2.append(loss_records)
    valid_folds[TARGETS_PRED] = valid_predicts
    oof_stage_2 = pd.concat([oof_stage_2, valid_folds], axis=0).reset_index(drop=True)
    kl_loss_torch = evaluate_oof(valid_folds)
    info = f"{'=' * 100}\nFold {fold} Valid Loss: {kl_loss_torch}\n"
    info += f"Elapse: {(time() - tik) / 60:.2f} min \n{'=' * 100}"
    logger.info(info)
    oof_stage_2.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_2.csv"), index=False)

In [None]:
csv_path = f'./outputs/{ModelConfig.MODEL_NAME}_oof_1.csv'
print("CSV Path: ", csv_path)

oof_df = analyze_oof(csv_path)

print("Kaggle Score: ", calc_kaggle_score(oof_df))
print("Average KL Loss: ", oof_df["kl_loss"].mean())

display(oof_df.head())

# plot confusion matrix
cm = confusion_matrix(oof_df['target_id'], oof_df['target_pred']) # (y_true, y_pred)
cm = cm / cm.sum(axis=1)[:, np.newaxis]

fig = plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=TARGET2ID.keys(), yticklabels=TARGET2ID.keys())
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('True', fontsize=12)
plt.title(csv_path.split('/')[-1].split('.')[0], fontsize=12)
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_CM.png")
plt.show()

In [None]:
# check distribution of targets
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
train_all["target"].value_counts().plot(kind="bar", ax=axes[0])
train_hard["target"].value_counts().plot(kind="bar", ax=axes[1])
axes[0].set_title("Easy")
axes[1].set_title("Hard")
fig.tight_layout()
plt.show()

In [None]:
# hms_predictor = HMSPredictor(paths.OUTPUT_DIR, ModelConfig, k_fold=5)

In [None]:
# hms_predictor.train_model(train_easy, train_hard, all_specs, all_eegs)

In [None]:


# hms_predictor = HMSPredictor(paths.OUTPUT_DIR, ModelConfig, k_fold=5)

In [None]:
# new figure
fig, axes = plt.subplots(6, 5, figsize=(18, 16), sharex=True, sharey=True)

plot_oof = oof_df[oof_df['kl_loss'] > 0.2]

for row in range(axes.shape[0]):
    row_selects = plot_oof[plot_oof['target_id']==row]
    target_label = BRAIN_ACTIVITY[row]
    for col in range(axes.shape[1]):
        ax = axes[row, col]
        idx = np.random.choice(row_selects.index)
        df_rows = plot_oof.loc[idx]
        ax.plot(df_rows[TARGETS].values , label='True')
        ax.plot(df_rows[TARGETS_PRED].values, label='Pred')
        ax.set_title(f"{idx} | KL: {df_rows['kl_loss']:.4f} ") #
        ax.set_xticks(range(6))
        ax.set_xticklabels(BRAIN_ACTIVITY)
        ax.grid(True)
        ax.legend()
        if col == 0:
            ax.set_ylabel(target_label, fontsize=12)
       
fig.tight_layout()
plt.show()


