In [1]:
import os, gc, random
import numpy as np
import pandas as pd 
from pathlib import Path
import matplotlib.pyplot as plt
from typing import List, Dict
from tqdm.notebook import tqdm
from time import time, ctime

from sklearn.model_selection import KFold, GroupKFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2
from torch.optim.lr_scheduler import OneCycleLR,  CosineAnnealingWarmRestarts
from torch.optim import Adam, AdamW
from torch.cuda.amp import autocast, GradScaler

from scipy.signal import butter, lfilter, freqz
from scipy.stats import entropy
from scipy.special import rel_entr, softmax

In [2]:
def get_logger(log_dir, logger_name="train_model.log"):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger_file = os.path.join(log_dir, logger_name)
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=logger_file, mode="a+")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger


def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

In [3]:
class ModelConfig:
    SEED = 20
    SPLIT_ENTROPY = 5.5
    MODEL_NAME = "ResnetGRU_v1_LB048"
    MODEL_BACKBONE = "reset_gru"
    BATCH_SIZE = 32
    EPOCHS = 20
    EARLY_STOP_ROUNDS = 5
    GRADIENT_ACCUMULATION_STEPS = 1
    DROP_RATE = 0.15 # default: 0.1
    DROP_PATH_RATE = 0.25 # default: 0.2
    WEIGHT_DECAY = 0.01
    AMP = True
    PRINT_FREQ = 100
    NUM_WORKERS = 0 
    MAX_GRAD_NORM = 1e7
    REGULARIZATION = 0.15
    RESNET_GRU_BANDPASS = None #(0.5, 20)
    RESNET_GRU_IN_CHANNELS = 8
    RESNET_GRU_KERNELS = [3, 5, 7, 9, 11]
    RESNET_GRU_FIXED_KERNEL_SIZE = 5
    RESNET_GRU_DOWNSAMPLE = 5 # None #5
    RESNET_GRU_HIDDEN_SIZE = 304 #448 #304
    RESNET_GRU_DILATED = False

In [4]:
N_GPU = torch.cuda.device_count()
if N_GPU > 1:
    DEVICE = torch.device("cuda")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
elif N_GPU == 1:
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

print("Use Device: ", DEVICE)

Use Device:  cuda:0


In [5]:
class KagglePaths:
    OUTPUT_DIR = "/kaggle/working/"
    PRE_LOADED_EEGS = '/kaggle/input/brain-eeg-spectrograms/eeg_specs.npy'
    PRE_LOADED_SPECTROGRAMS = '/kaggle/input/brain-spectrograms/specs.npy'
    TRAIN_CSV = "/kaggle/input/hms-harmful-brain-activity-classification/train.csv"
    TRAIN_EEGS = "/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/"
    TRAIN_SPECTROGRAMS = "/kaggle/input/hms-harmful-brain-activity-classification/train_spectrograms/"
    TEST_CSV = "/kaggle/input/hms-harmful-brain-activity-classification/test.csv"
    TEST_SPECTROGRAMS = "/kaggle/input/hms-harmful-brain-activity-classification/test_spectrograms/"
    TEST_EEGS = "/kaggle/input/hms-harmful-brain-activity-classification/test_eegs/"


class LocalPaths:
    OUTPUT_DIR = "./outputs/"
    PRE_LOADED_EEGS = './inputs/brain-eeg-spectrograms/eeg_specs.npy'
    PRE_LOADED_SPECTROGRAMS = './inputs/brain-spectrograms/specs.npy'
    TRAIN_CSV = "./inputs/hms-harmful-brain-activity-classification/train.csv"
    TRAIN_EEGS = "./inputs/hms-harmful-brain-activity-classification/train_eegs"
    TRAIN_SPECTROGRAMS = "./inputs/hms-harmful-brain-activity-classification/train_spectrograms"
    TEST_CSV = "./inputs/hms-harmful-brain-activity-classification/test.csv"
    TEST_SPECTROGRAMS = "./inputs/hms-harmful-brain-activity-classification/test_spectrograms"
    TEST_EEGS = "./inputs/hms-harmful-brain-activity-classification/test_eegs"

PATHS = KagglePaths if os.path.exists("/kaggle") else LocalPaths

print("Output Dir: ", PATHS.OUTPUT_DIR)

EEG_FEAT_ALL = [
    'Fp1', 'F3', 'C3', 'P3', 
    'F7', 'T3', 'T5', 'O1', 
    'Fz', 'Cz', 'Pz', 'Fp2', 
    'F4', 'C4', 'P4', 'F8', 
    'T4', 'T6', 'O2', 'EKG'
    ]

EEG_FEAT_USE =  ['Fp1','T3','C3','O1','Fp2','C4','T4','O2']
EEG_FEAT_INDEX = {x:y for x,y in zip(EEG_FEAT_USE, range(len(EEG_FEAT_USE)))}

BRAIN_ACTIVITY = ['seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other']
TARGETS = [f"{lb}_vote" for lb in BRAIN_ACTIVITY]
TARGETS_PRED = [f"{lb}_pred" for lb in BRAIN_ACTIVITY]

seed_everything(ModelConfig.SEED)

print(EEG_FEAT_INDEX)

Output Dir:  ./outputs/
{'Fp1': 0, 'T3': 1, 'C3': 2, 'O1': 3, 'Fp2': 4, 'C4': 5, 'T4': 6, 'O2': 7}


In [6]:
logger = get_logger(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_train.log")

# Load Data

In [7]:
def eeg_from_parquet(parquet_path: str, use_feature=EEG_FEAT_USE, display: bool = False) -> np.ndarray:
    # === Extract full length EEG Sequence ===
    # fill missing values with mean
    # first fill missing values with mean of each column
    # then if all values are missing, fill with 0
    eeg = pd.read_parquet(parquet_path, columns=use_feature)
    eeg = eeg.fillna(eeg.mean(skipna=True)).fillna(0)
    data = eeg.values.astype(np.float32)
    
    rows = len(eeg)
    offset = (rows - 10_000) // 2 # 50 * 200 = 10_000
    data = data[offset:offset+10_000, :]

    if display:
        fig, ax = plt.subplots(len(use_feature), 1, figsize=(10, 2*len(use_feature)), sharex=True)
        
        for i, feat in enumerate(use_feature):
            ax[i].plot(data[:, i], label=feat)
            ax[i].legend()
            ax[i].grid()
       
        name = parquet_path.split('/')[-1].split('.')[0]
        ax[0].set_title(f'EEG {name}',size=16)
        fig.tight_layout()
        plt.show()    
    return data

In [8]:
%%time
CREATE_EEGS = False
ALL_EEG_SIGNALS = {}
eeg_paths = list(Path(PATHS.TRAIN_EEGS).glob('*.parquet'))
preload_eegs_path = Path('./inputs/eegs_full.npy')

if CREATE_EEGS:
    count = 0
    for parquet_path in tqdm(eeg_paths, total=len(eeg_paths)):
        eeg_id = int(parquet_path.stem)
        eeg_path = str(parquet_path)
        data = eeg_from_parquet(eeg_path, display=False)
        ALL_EEG_SIGNALS[eeg_id] = data
        count += 1
    np.save("./inputs/eegs_full.npy", ALL_EEG_SIGNALS)
else:
    ALL_EEG_SIGNALS = np.load(preload_eegs_path, allow_pickle=True).item()

CPU times: user 178 ms, sys: 1.27 s, total: 1.45 s
Wall time: 1.45 s


In [9]:
def gen_non_overlap_samples(df_csv, targets):
    # Reference Discussion:
    # https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification/discussion/467021

    tgt_list = targets.tolist()
    brain_activity = ['seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other']

    agg_dict = {
        'spectrogram_id': 'first',
        'spectrogram_label_offset_seconds': ['min', 'max'],
        'patient_id': 'first',
        'expert_consensus': 'first'
    }

    groupby = df_csv.groupby(['eeg_id'] + tgt_list)
    train = groupby.agg(agg_dict)
    train = train.reset_index()
    train.columns = ['_'.join(col).strip() for col in train.columns.values]
    train.columns = ["eeg_id"] + tgt_list + ['spectrogram_id', 'min', 'max', 'patient_id', 'target']
    
    train['total_votes'] = train[tgt_list].sum(axis=1)
    train[tgt_list] = train[tgt_list].div(train['total_votes'], axis=0)
    
    return train

In [10]:
# # Enhanced Samples Split 

# train_csv = pd.read_csv(PATHS.TRAIN_CSV)
# targets = train_csv.columns[-6:].tolist()

# raw_csv_len = len(train_csv)

# subset_counts = train_csv.groupby(['eeg_id']+targets).size().reset_index(name='subset_counts')
# train_csv = train_csv.merge(subset_counts, on=['eeg_id']+targets, how='left')

# tmp_cols = ['expert_consensus', 'eeg_label_offset_seconds', 'subset_counts']

# def sample_rule(x):
#     if (x['subset_counts'].min() > 3) & ((x['expert_consensus']!='Other').any()):
#         return x['eeg_label_offset_seconds'].sample(n=(x['subset_counts'].min()//3))
#     else:
#         return x['eeg_label_offset_seconds'].sample(n=1)

# train_samples = train_csv.groupby(['eeg_id']+targets)[tmp_cols].apply(sample_rule).reset_index()
# train_samples = train_samples.rename(columns={'eeg_label_offset_seconds': 'eeg_off_seconds'})
# train_samples.drop(columns=['level_7'], inplace=True)

# train_meta = train_csv.groupby(['eeg_id']+targets).agg({
#     'spectrogram_id': 'first',
#     'spectrogram_label_offset_seconds': ['min', 'max'],
#     'eeg_sub_id': 'count',
#     'eeg_label_offset_seconds': ['min', 'max'],
#     'patient_id': 'first',
# }).reset_index()

# agged_cols = [
#     'spectrogram_id', 'min', 'max', 'subset_counts', 'eeg_off_min', 'eeg_off_max', 'patient_id'
# ]
# train_meta.columns = ['eeg_id'] + targets + agged_cols
# train_meta = train_meta[['eeg_id'] + agged_cols + targets]

# train_meta['total_votes'] = train_meta[targets].sum(axis=1)
# train_meta['target'] = train_meta[targets].idxmax(axis=1).apply(lambda x: x.split('_')[0])
# train_meta['fold'] = -1

# K_FOLDS = 5
# kf = KFold(n_splits=K_FOLDS, shuffle=False)
# unique_eegs = train_meta['eeg_id'].unique()
# for fold, (_, valid_idx) in enumerate(kf.split(unique_eegs)):
#     train_meta.loc[train_meta['eeg_id'].isin(unique_eegs[valid_idx]), 'fold'] = fold

# train_all = train_samples.merge(train_meta, on=['eeg_id']+targets, how='left')

# train_all[targets] = train_all[targets].div(train_all['total_votes'], axis=0)

# train_all['stage'] = train_all['total_votes'].apply(lambda x: 1 if x < 10 else 2)

# train_all

In [11]:
# Original Split 

train_csv = pd.read_csv(PATHS.TRAIN_CSV)
targets = train_csv.columns[-6:]

print("targets: ", targets.to_list())

train_csv['total_votes'] = train_csv[targets].sum(axis=1)
train_csv[targets] = train_csv[targets].astype('float32')

targets_prob = [f"{t.split('_')[0]}_prob" for t in targets]
train_csv[targets_prob] = train_csv[targets].div(train_csv['total_votes'], axis=0)
# train_csv['rel_entropy'] = train_csv[targets_prob].apply(lambda row: sum(rel_entr([1/6]*6, row.values+1e-5)), axis=1)
# train_csv['entropy'] = train_csv[targets_prob].apply(lambda row: entropy(row.values), axis=1)

# hard_csv = train_csv[train_csv['entropy'] < ModelConfig.SPLIT_ENTROPY].copy().reset_index(drop=True)
# hard_csv = train_csv[train_csv['entropy'] >= 0.75].copy().reset_index(drop=True)
hard_csv = train_csv[train_csv['total_votes'] >= 6].copy().reset_index(drop=True)


train_all = gen_non_overlap_samples(train_csv, targets)
train_hard = gen_non_overlap_samples(hard_csv, targets)

print("train_all.shape = ", train_all.shape)
print("train_all nan_count: ", train_all.isnull().sum().sum())
display(train_all.head())

print(" ")

print("train_hard.shape = ", train_hard.shape)
print("train_hard nan_count: ", train_hard.isnull().sum().sum())
display(train_hard.head())

targets:  ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
train_all.shape =  (20183, 13)
train_all nan_count:  0


Unnamed: 0,eeg_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,spectrogram_id,min,max,patient_id,target,total_votes
0,568657,0.0,0.0,0.25,0.0,0.166667,0.583333,789577333,0.0,16.0,20654,Other,12.0
1,582999,0.0,0.857143,0.0,0.071429,0.0,0.071429,1552638400,0.0,38.0,20230,LPD,14.0
2,642382,0.0,0.0,0.0,0.0,0.0,1.0,14960202,1008.0,1032.0,5955,Other,1.0
3,751790,0.0,0.0,1.0,0.0,0.0,0.0,618728447,908.0,908.0,38549,GPD,1.0
4,778705,0.0,0.0,0.0,0.0,0.0,1.0,52296320,0.0,0.0,40955,Other,2.0


 
train_hard.shape =  (6492, 13)
train_hard nan_count:  0


Unnamed: 0,eeg_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,spectrogram_id,min,max,patient_id,target,total_votes
0,568657,0.0,0.0,0.25,0.0,0.166667,0.583333,789577333,0.0,16.0,20654,Other,12.0
1,582999,0.0,0.857143,0.0,0.071429,0.0,0.071429,1552638400,0.0,38.0,20230,LPD,14.0
2,1895581,0.076923,0.0,0.0,0.0,0.076923,0.846154,128369999,1138.0,1138.0,47999,Other,13.0
3,2482631,0.0,0.0,0.133333,0.066667,0.133333,0.666667,978166025,1902.0,1944.0,20606,Other,15.0
4,2521897,0.0,0.0,0.083333,0.083333,0.333333,0.5,673742515,0.0,4.0,62117,Other,12.0


# Dataset

In [12]:
# Functional Utils
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter(order, [lowcut, highcut], fs=fs, btype='band')
    y = lfilter(b, a, data)
    return y

def denoise_filter(x):
    # Sample rate and desired cutoff frequencies (in Hz).
    fs = 200.0
    lowcut = 1.0
    highcut = 25.0
    
    # Filter a noisy signal.
    T = 50
    nsamples = T * fs
    t = np.arange(0, nsamples) / fs
    y = butter_bandpass_filter(x, lowcut, highcut, fs, order=6)
    y = (y + np.roll(y,-1)+ np.roll(y,-2)+ np.roll(y,-3))/4
    y = y[0:-1:4]
    
    return y

def mu_law_encoding(data, mu):
    mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1)
    return mu_x

def mu_law_expansion(data, mu):
    s = np.sign(data) * (np.exp(np.abs(data) * np.log(mu + 1)) - 1) / mu
    return s

def quantize_data(data, classes):
    mu_x = mu_law_encoding(data, classes)
    return mu_x #quantized

def butter_lowpass_filter(data, cutoff_freq=20, sampling_rate=200, order=4):
    nyquist = 0.5 * sampling_rate
    normal_cutoff = cutoff_freq / nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    filtered_data = lfilter(b, a, data, axis=0)
    return filtered_data


In [13]:
class EEGSeqDataset(Dataset):
    def __init__(self, df, config, eegs, mode='train', verbose=False):
        self.df = df
        self.mode = mode
        self.eegs = eegs
        self.verbose = verbose
        self.downsample = config.RESNET_GRU_DOWNSAMPLE
        self.use_bandpass = config.RESNET_GRU_BANDPASS
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        X, y_prob = self.__data_generation(idx)
        
        if self.downsample is not None:
            X = X[::self.downsample,:]
        
        return torch.tensor(X, dtype=torch.float32), torch.tensor(y_prob, dtype=torch.float32)
    
    def __data_generation(self, index):
        row = self.df.iloc[index]
        
        if self.verbose:
            print(f"Row {index}", row[['eeg_id', 'eeg_off_min', 'target']].tolist())

        X = np.zeros((10_000, 8), dtype='float32')
        
        # # start_sec = int((row['eeg_off_min'] + row['eeg_off_max']) // 2)
        # eeg_seq = self.eegs[row.eeg_id]
        # len_seq = eeg_seq.shape[0]
        # start_at = int(row['eeg_off_min']) + (len_seq - 10_000) // 2 
        # # !!! use randomly sampled offset !!!
        # # start_sec = int(row['eeg_off_sample']) 
        # data = eeg_seq[start_at:start_at+10_000, :]
        
        data = self.eegs[row.eeg_id]

        # === Feature engineering ===
        X[:,0] = data[:,EEG_FEAT_INDEX['Fp1']] - data[:,EEG_FEAT_INDEX['T3']]
        X[:,1] = data[:,EEG_FEAT_INDEX['T3']] - data[:,EEG_FEAT_INDEX['O1']]

        X[:,2] = data[:,EEG_FEAT_INDEX['Fp1']] - data[:,EEG_FEAT_INDEX['C3']]
        X[:,3] = data[:,EEG_FEAT_INDEX['C3']] - data[:,EEG_FEAT_INDEX['O1']]

        X[:,4] = data[:,EEG_FEAT_INDEX['Fp2']] - data[:,EEG_FEAT_INDEX['C4']]
        X[:,5] = data[:,EEG_FEAT_INDEX['C4']] - data[:,EEG_FEAT_INDEX['O2']]

        X[:,6] = data[:,EEG_FEAT_INDEX['Fp2']] - data[:,EEG_FEAT_INDEX['T4']]
        X[:,7] = data[:,EEG_FEAT_INDEX['T4']] - data[:,EEG_FEAT_INDEX['O2']]

        # === Standarize ===
        X = np.clip(X,-1024, 1024)
        X = np.nan_to_num(X, nan=0) / 32.0

        # === Butter Low-pass Filter ===
        # ??? change to bandpass filter (low=0.5, hight=20, order=2) ???
        if self.use_bandpass is not None:
            X = butter_lowpass_filter(X, self.use_bandpass[0], self.use_bandpass[1], order=2)
            
        X = butter_lowpass_filter(X) 
        
        if self.mode != 'test':
            y_prob = row[TARGETS].values.astype(np.float32)
        else:
            y_prob = np.zeros(6, dtype='float32')

        return X, y_prob 

In [14]:
# # visualize the dataset
# train_dataset = EEGSeqDataset(train_all, ModelConfig, ALL_EEG_SIGNALS, mode="train")
# train_loader = DataLoader(train_dataset, drop_last=True, batch_size=16, num_workers=4, pin_memory=True, shuffle=False)

# for batch in train_loader:
#     X, y = batch
#     print(f"X shape: {X.shape}")
#     print(f"y shape: {y.shape}")
    
#     fig, axes = plt.subplots(4, 1, figsize=(20, 20))
#     ax_idx = 0
#     for item in np.random.choice(range(X.shape[0]), 4):
#         offset = 0
#         for col in range(X.shape[-1]):
#             if col != 0:
#                 offset -= X[item,:,col].min()
#             axes[ax_idx].plot(np.arange(X.shape[1]), X[item,:,col]+offset, label=f'feature {col+1}')
#             offset += X[item,:,col].max()
#         print(y[item])
#         # axes[ax_idx].set_title(f'Weight = {weights[item]}',size=14)
#         axes[ax_idx].legend()
#         ax_idx += 1
#     fig.tight_layout()
#     plt.show()
#     break

# del train_dataset, train_loader
# torch.cuda.empty_cache()
# gc.collect()

# Model

### Resnet 1D Encoder

In [15]:
class ResNet_1D_Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling, dropout=0.0, dilation=1):
        super(ResNet_1D_Block, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm1d(num_features=in_channels),
            nn.Hardswish(), #nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            nn.Hardswish(), #nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Conv1d(out_channels, out_channels, kernel_size, stride, padding, dilation=dilation, bias=False),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        )
        self.downsampling = downsampling

    def forward(self, x):
        identity = self.downsampling(x)
        out = self.block(x)
        out += identity
        return out

class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling 
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """
    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Linear(input_dim, 1)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension
        attention_weight:
            att_w : size (N, T, 1)
        return:
            utter_rep: size (N, H)
        """
        att_w = self.softmax(self.W(batch_rep).squeeze(-1)).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)

        return utter_rep

class ResNetGRU(nn.Module):
    def __init__(self, config=ModelConfig, num_classes=6):
        super(ResNetGRU, self).__init__()

        self.planes = 24
        self.kernels = config.RESNET_GRU_KERNELS
        self.in_channels = config.RESNET_GRU_IN_CHANNELS
        self.use_dilation = config.RESNET_GRU_DILATED

        fixed_kernel_size = config.RESNET_GRU_FIXED_KERNEL_SIZE
        hidden_size = config.RESNET_GRU_HIDDEN_SIZE

        # Define the separate convolutional layers
        self.parallel_conv = self._make_parallel_conv_layers()
        # Define the ResNet part of the model
        self.resnet_part = self._make_resnet_part(fixed_kernel_size, n_blocks=9)
        # Define the GRU part of the model
        self.rnn = nn.GRU(input_size=self.in_channels, hidden_size=128, num_layers=1, bidirectional=True)
        self.pooling = SelfAttentionPooling(256)
        # Define the final fully connected layer
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def _make_parallel_conv_layers(self):
        return nn.ModuleList([
            nn.Conv1d(
                in_channels=self.in_channels, 
                out_channels=self.planes, 
                kernel_size=kernel_size,
                stride=1, 
                padding=0, 
                bias=False
            ) for kernel_size in self.kernels
        ])

    def _make_resnet_part(self, fixed_kernel_size, n_blocks=9):
        # prepare resnet layers
        downsampling = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

        if self.use_dilation:
            dilation_rates = [1, 2, 2, 2, 2, 4, 4, 4, 4] #[1] * n_blocks
        else:
            dilation_rates = [1] * n_blocks

        paddings = [fixed_kernel_size//2 * rate for rate in dilation_rates]
        resnet_layers = [
            ResNet_1D_Block(
                in_channels=self.planes, 
                out_channels=self.planes, 
                kernel_size=fixed_kernel_size, 
                stride=1, 
                padding=paddings[i], 
                downsampling=downsampling,
                dropout=0.0,
                dilation=dilation_rates[i])
            for i in range(n_blocks)
        ]
        # return the resnet encoder
        return nn.Sequential(
            nn.BatchNorm1d(num_features=self.planes),
            nn.SiLU(), #nn.ReLU(inplace=False),
            nn.Conv1d(
                in_channels=self.planes, 
                out_channels=self.planes, 
                kernel_size=fixed_kernel_size, 
                stride=2, 
                padding=2, 
                bias=False
            ),
            *resnet_layers,
            nn.BatchNorm1d(num_features=self.planes),
            nn.SiLU(), #nn.ReLU(inplace=False),
            nn.AvgPool1d(kernel_size=6, stride=6, padding=2)
        )
    
    def forward(self, x):
        # extract features using resnet 
        x = x.permute(0, 2, 1)
        out_sep = [conv(x) for conv in self.parallel_conv]
        out = torch.cat(out_sep, dim=2)
        out = self.resnet_part(out)
        out = out.reshape(out.shape[0], -1)
        # extract features using rnn
        rnn_out, _ = self.rnn(x.permute(0, 2, 1))
        new_rnn_h = self.pooling(rnn_out)
        # concatenate the features
        new_out = torch.cat([out, new_rnn_h], dim=1) 
        # total features = 424 = 24*6 + 128*2 
        # pass through the final fully connected layer
        result = self.fc(new_out)  
        
        return result


### Dilated Inception Wavenet Encoder

In [16]:
# from typing import List

# class DilatedInception(nn.Module):
#     def __init__(self, in_channels: int, out_channels: int, kernel_sizes: List[int], dilation: int) -> None:
#         super().__init__()
#         assert out_channels % len(kernel_sizes) == 0, "`out_channels` must be divisible by the number of kernel sizes."
#         hidden_dim = out_channels // len(kernel_sizes)
#         self.convs = nn.ModuleList([
#             nn.Conv1d(in_channels, hidden_dim, k, padding='same', dilation=dilation)
#             for k in kernel_sizes
#         ])

#     def forward(self, x):
#         outputs = [conv(x) for conv in self.convs]
#         out = torch.cat(outputs, dim=1)
#         return out

# class GatedTCN(nn.Module):
#     def __init__(self, in_dim: int, h_dim: int, kernel_sizes: List[int], dilation_factor: int, dropout: float = 0.0) -> None:
#         super().__init__()
#         self.filt = DilatedInception(in_dim, h_dim, kernel_sizes, dilation=dilation_factor)
#         self.gate = DilatedInception(in_dim, h_dim, kernel_sizes, dilation=dilation_factor)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, x):
#         x_filt = torch.tanh(self.filt(x))
#         x_gate = torch.sigmoid(self.gate(x))
#         h = x_filt * x_gate
#         h = self.dropout(h)
#         return h

# class WaveBlock(nn.Module):
#     def __init__(self, n_layers: int, in_dim: int, h_dim: int, kernel_sizes: List[int]) -> None:
#         super().__init__()
#         self.dilation_rates = [2**i for i in range(n_layers)]
#         self.in_conv = nn.Conv1d(in_dim, h_dim, kernel_size=1)
#         self.gated_tcns = nn.ModuleList([
#             GatedTCN(h_dim, h_dim, kernel_sizes, dilation)
#             for dilation in self.dilation_rates
#         ])
#         self.skip_convs = nn.ModuleList([
#             nn.Conv1d(h_dim, h_dim, kernel_size=1)
#             for _ in range(n_layers)
#             ])
#         self._initialize_weights()

#     def _initialize_weights(self):
#         nn.init.xavier_uniform_(self.in_conv.weight, gain=nn.init.calculate_gain('relu'))
#         nn.init.zeros_(self.in_conv.bias)
#         for conv in self.skip_convs:
#             nn.init.xavier_uniform_(conv.weight, gain=nn.init.calculate_gain('relu'))
#             nn.init.zeros_(conv.bias)

#     def forward(self, x):
#         # x: (B, C, L)
#         x = self.in_conv(x)
#         x_skip = x
#         for gated_tcn, skip_conv in zip(self.gated_tcns, self.skip_convs):
#             x = gated_tcn(x)
#             x = skip_conv(x)
#             x_skip = x_skip + x
#         return x_skip

# class DilatedWaveNet(nn.Module):
#     """WaveNet architecture with dilated inception conv, enhanced with list comprehension for input processing."""

#     def __init__(self, kernel_sizes: List[int]) -> None:
#         super().__init__()
#         self.kernel_sizes = kernel_sizes
        
#         # Initialize wave blocks with specified kernel sizes
#         self.wave_module = nn.Sequential(
#             WaveBlock(9, 8, 128, self.kernel_sizes), #12
#             WaveBlock(6, 128, 256, self.kernel_sizes), #8
#             WaveBlock(3, 256, 512, self.kernel_sizes), #4
#             WaveBlock(1, 512, 512, self.kernel_sizes), #1
#         )
#         self.pool_layer = nn.AdaptiveAvgPool1d(1)

#     def forward(self, x) -> torch.Tensor:
#         # x: (B, L, C)
#         bs, seq_len, n_channels = x.shape
#         x = x.permute(0, 2, 1) # -> (B, C, L)
#         # Process different parts of the input with list comprehension
#         x = self.wave_module(x)
#         x = self.pool_layer(x) # ->(B, 512, 1)
#         x = x.reshape(bs, n_channels, -1).reshape(bs, n_channels//2, 2, 64)
#         features = x.mean(dim=2).reshape(bs, -1) # -> (16, 256)
# #         pooled_outputs = [(x[:, i:i+64] + x[:, i+64:i+128]) / 2 for i in range(0, n_channels, 2)]
# #         # Combine the pooled features and reshape for classification
# #         features = torch.cat(pooled_outputs, dim=1).reshape(bs, -1)
       
#         return features

### Dilated ResNet 1D Encoder

In [17]:
# class ResnetBlock(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, dropout=0.0):
#         super(ResnetBlock, self).__init__()

#         self.bn1 = nn.BatchNorm1d(in_channels)
#         self.relu1 = nn.ReLU()
#         self.conv1 = nn.Conv1d(
#             in_channels, out_channels, kernel_size, 
#             stride=stride, 
#             padding=dilation*(kernel_size//2), 
#             dilation=dilation, 
#             bias=False)
#         self.drop1 = nn.Dropout(p=dropout)
#         self.bn2 = nn.BatchNorm1d(out_channels)
#         self.relu2 = nn.ReLU()
#         self.drop2 = nn.Dropout(p=dropout)
#         self.conv2 = nn.Conv1d(
#             out_channels, out_channels, kernel_size, 
#             stride=stride, 
#             padding=dilation*(kernel_size//2), 
#             dilation=dilation, 
#             bias=False)
        
#         self.bn3 = nn.BatchNorm1d(out_channels)
#         self.relu3 = nn.ReLU()
#         self.downsample = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

#     def forward(self, x):
#         identity = x
#         identity = self.downsample(identity)

#         out = self.bn1(x)
#         out = self.relu1(out)
#         out = self.drop1(out)
#         out = self.conv1(out)

#         out = self.bn2(out)
#         out = self.relu2(out)
#         out = self.drop2(out)
#         out = self.conv2(out)

#         out = self.downsample(out)

#         out += identity
#         out = self.bn3(out)
#         out = self.relu3(out)

#         return out

# class DilatedResnet(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size, n_layers, expansion_factor=4):
#         super(DilatedResnet, self).__init__()

#         self.in_channels = in_channels
#         self.kernel_size = kernel_size
#         self.h_dim = out_channels // n_layers
        
#         fix_kernel_size = 5
#         self.conv1 = nn.Conv1d(
#             self.in_channels, self.h_dim, kernel_size=fix_kernel_size, stride=1, padding=fix_kernel_size//2
#             )

#         dilation_rates = [expansion_factor**i for i in range(n_layers)]

#         self.blocks = nn.ModuleList([
#             ResnetBlock(self.h_dim, self.h_dim, self.kernel_size, dilation=dilation)
#             for dilation in dilation_rates
#         ])

#     def forward(self, x):
#         x = self.conv1(x)
#         outputs = [ block(x) for block in self.blocks ]
#         output = torch.cat(outputs, dim=1)
        
#         return output

# class DilatedResnetEncoder(nn.Module):
#     def __init__(self, kernel_sizes=[3, 5, 7, 9], in_channels=8, planes=24, dilate_layers=[6,3,1], expansion_factor=4):
#         super(DilatedResnetEncoder, self).__init__()

#         self.in_channels = in_channels
#         self.planes = planes
#         self.kernel_sizes = kernel_sizes
#         self.dilate_layers = dilate_layers # must be 3 layers
#         self.expansion_factor = expansion_factor
        
#         # out_channels = self.planes * self.in_channels
#         # fix_kernel_size = 5
#         # self.conv1 = nn.Conv1d(
#         #     self.in_channels, out_channels, kernel_size=fix_kernel_size, stride=1, padding=fix_kernel_size//2
#         #     )
        
#         self.blocks = nn.ModuleList([
#             self._make_dilated_block(kernel_size)
#             for kernel_size in self.kernel_sizes
#         ])

#         bottleneck_in_channels = self.in_channels * self.planes * self.dilate_layers[1] * self.dilate_layers[2]
#         bottoleneck_out_channels = self.in_channels * self.planes

#         self.bottleneck = nn.Sequential(
#             nn.BatchNorm1d(num_features=bottleneck_in_channels),
#             nn.ReLU(),
#             nn.Conv1d(
#                 in_channels=bottleneck_in_channels,
#                 out_channels=bottoleneck_out_channels,
#                 kernel_size=1,
#                 stride=1,
#                 padding=0,
#                 bias=False
#             )
#         )
        
#         self.pooling = nn.AdaptiveAvgPool1d(1)
#         # self.blocks = nn.ModuleList([
#         #     nn.Sequential(*[
#         #         ResidualBlock(
#         #             out_channels, out_channels, kernel_size, dilation=dilation
#         #         ) for dilation in self.dilate_layers
#         #     ])
#         #     for kernel_size in self.kernel_sizes
#         # ])

#     def _make_dilated_block(self, kernel_size):
#         out_channel_1 = self.in_channels * self.planes
#         block_1 = DilatedResnet(self.in_channels, out_channel_1, kernel_size, self.dilate_layers[0], self.expansion_factor)

#         out_channel_2 = out_channel_1 * self.dilate_layers[1]
#         block_2 = DilatedResnet(out_channel_1, out_channel_2, kernel_size, self.dilate_layers[1], self.expansion_factor)

#         out_channel_3 = out_channel_2 * self.dilate_layers[2]
#         block_3 = DilatedResnet(out_channel_2, out_channel_3, kernel_size, self.dilate_layers[2], self.expansion_factor)

#         return nn.Sequential(block_1, block_2, block_3)
        
    
#     def forward(self, x):
#         # <- # [batch_size, seq_len=2000, in_channels=8]
#         x = x.permute(0, 2, 1)
#         # x = self.conv1(x)
#         outputs = [ block(x) for block in self.blocks ]
#         outputs = [ self.bottleneck(out) for out in outputs ]
#         output = torch.cat(outputs, dim=1)
#         output = self.pooling(output).squeeze(-1)
        
#         return output

In [18]:
train_dataset = EEGSeqDataset(train_all, ModelConfig, ALL_EEG_SIGNALS, mode="train")
train_loader = DataLoader(train_dataset, drop_last=True, batch_size=16, num_workers=4, pin_memory=True, shuffle=False)

model = ResNetGRU(config=ModelConfig, num_classes=6)

model.to(DEVICE)
for i, batch in enumerate(train_loader):
    X, y = batch
    X = X.to(DEVICE)
    y = y.to(DEVICE)
    print(f"X shape: {X.shape}")
    print(f"y shape: {y.shape}")
    
    y_pred = model(X)
    print(y_pred.shape)
    break 

del model, train_dataset, train_loader, X, y
torch.cuda.empty_cache()
gc.collect()

X shape: torch.Size([16, 2000, 8])
y shape: torch.Size([16, 6])
torch.Size([16, 6])


0

In [19]:
!nvidia-smi

Fri Apr  5 14:16:19 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.239.06   Driver Version: 470.239.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:0B:00.0 Off |                  N/A |
| 26%   36C    P2    55W / 260W |   1605MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Train

In [20]:
import warnings
warnings.filterwarnings("ignore")

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
class Trainer:

    def __init__(self, model, config, logger):

        self.model = model
        self.logger = logger
        self.config = config
        
        self.early_stop_rounds = config.EARLY_STOP_ROUNDS
        self.early_stop_counter = 0
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()
        self.gamma = config.REGULARIZATION
        
        # self.criterion = nn.KLDivLoss(reduction="batchmean")
    
    def criterion(self, y_pred, y_true, weights=None, mode='train'):
        kl_loss = self.kl_div_loss(F.log_softmax(y_pred, dim=1), y_true)
        if (self.gamma is not None) & (mode == 'train'):
            softmax_probs = F.softmax(y_pred, dim=1)  # Compute softmax probabilities
            entropy_loss = -(softmax_probs * torch.log(softmax_probs + 1e-9)).sum(dim=1).mean(dim=0) # Compute entropy, add epsilon to avoid log(0)
            return kl_loss - self.gamma * entropy_loss
        else:
            return kl_loss
        
    def train(self, train_loader, valid_loader, from_checkpoint=None):

        self.optimizer = AdamW(self.model.parameters(), lr=8e-3, weight_decay=self.config.WEIGHT_DECAY)

        # CosineAnnealingWarmRestarts( 
        #     self.optimizer,
        #     T_0=20,
        #     eta_min=1e-6,
        #     T_mult=1,
        #     last_epoch=-1
        # )
        self.scheduler =  OneCycleLR(
            self.optimizer,
            max_lr=1e-4,
            epochs=self.config.EPOCHS,
            steps_per_epoch=len(train_loader),
            pct_start=0.1,
            anneal_strategy="cos",
            final_div_factor=100,
        )

        if from_checkpoint is not None:
            self.model.load_state_dict(torch.load(from_checkpoint, map_location=self.device))

        self.model.to(self.device)
        best_weights, best_preds, best_loss = None, None, float("inf")
        loss_records = {"train": [], "valid": []}

        for epoch in range(self.config.EPOCHS):
            start_epoch = time()

            train_loss, _ = self._train_or_valid_epoch(epoch, train_loader, is_train=True)
            valid_loss, valid_preds = self._train_or_valid_epoch(epoch, valid_loader, is_train=False)

            loss_records["train"].append(train_loss)
            loss_records["valid"].append(valid_loss)

            elapsed = time() - start_epoch

            info = f"{'-' * 100}\nEpoch {epoch + 1} - "
            info += f"Average Loss: (train) {train_loss:.4f}; (valid) {valid_loss:.4f} | Time: {elapsed:.2f}s"
            self.logger.info(info)

            if valid_loss < best_loss:
                best_loss = valid_loss
                best_weights = self.model.state_dict()
                best_preds = valid_preds
                self.logger.info(f"Best model found in epoch {epoch + 1} | valid loss: {best_loss:.4f}")
                self.early_stop_counter = 0
            
            else:
                self.early_stop_counter += 1
                if self.early_stop_counter >= self.early_stop_rounds:
                    self.logger.info(f"Early stopping at epoch {epoch + 1}")
                    break

        return best_weights, best_preds, loss_records

    def _train_or_valid_epoch(self, epoch_id, dataloader, is_train=True):

        self.model.train() if is_train else self.model.eval()
        mode = "Train" if is_train else "Valid"

        len_loader = len(dataloader)
        scaler = GradScaler(enabled=self.config.AMP)
        loss_meter, predicts_record = AverageMeter(), []

        start = time()
        pbar = tqdm(dataloader, total=len(dataloader), unit="batch", desc=f"{mode} [{epoch_id}]")
        for step, (X, y) in enumerate(pbar):
            X, y = X.to(self.device), y.to(self.device)

            if is_train:
                with autocast(enabled=self.config.AMP):
                    y_pred = self.model(X)
                    loss = self.criterion(y_pred, y)
                if self.config.GRADIENT_ACCUMULATION_STEPS > 1:
                    loss = loss / self.config.GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.MAX_GRAD_NORM)
                if (step + 1) % self.config.GRADIENT_ACCUMULATION_STEPS == 0:
                    scaler.step(self.optimizer)
                    scaler.update()
                    self.optimizer.zero_grad()
                    self.scheduler.step()
            else:
                with torch.no_grad():
                    y_pred = self.model(X)
                    loss = self.criterion(y_pred, y, mode='valid')
                if self.config.GRADIENT_ACCUMULATION_STEPS > 1:
                    loss = loss / self.config.GRADIENT_ACCUMULATION_STEPS
                
                predicts_record.append(y_pred.to('cpu').numpy())
            
            loss_meter.update(loss.item(), y.size(0))
            end = time()

            if (step % self.config.PRINT_FREQ == 0) or (step == (len_loader - 1)):
                lr = self.scheduler.get_last_lr()[0]
                info = f"Epoch {epoch_id + 1} [{step}/{len_loader}] | {mode} Loss: {loss_meter.avg:.4f}"
                if is_train:
                    info += f" Grad: {grad_norm:.4f} LR: {lr:.4e}"
                info += f" | Elapse: {end - start:.2f}s"
                print(info)

        if not is_train:
            predicts_record = np.concatenate(predicts_record)
            
        return loss_meter.avg, predicts_record


In [21]:
def train_fold(model, fold_id, train_folds, valid_folds, logger, stage=1, checkpoint=None):

    train_dataset = EEGSeqDataset(train_folds, ModelConfig, ALL_EEG_SIGNALS, mode="train")
    valid_dataset = EEGSeqDataset(valid_folds, ModelConfig, ALL_EEG_SIGNALS, mode="valid")

    # ======== DATALOADERS ==========
    loader_kwargs = {
        "batch_size": ModelConfig.BATCH_SIZE,
        "num_workers": ModelConfig.NUM_WORKERS,
        "pin_memory": True,
        "shuffle": False,
    }

    train_loader = DataLoader(train_dataset, drop_last=True, collate_fn=None, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, drop_last=False, collate_fn=None, **loader_kwargs)

    if checkpoint is not None:
        print(f"Loading model from checkpoint: {checkpoint}")

    trainer = Trainer(model, ModelConfig, logger)
    best_weights, best_preds, loss_records = trainer.train(
        train_loader, valid_loader, from_checkpoint=checkpoint)

    save_model_name = f"{ModelConfig.MODEL_NAME}_fold_{fold_id}_stage_{stage}.pth"
    torch.save(best_weights, os.path.join(PATHS.OUTPUT_DIR, save_model_name))

    del train_dataset, valid_dataset, train_loader, valid_loader
    torch.cuda.empty_cache()
    gc.collect()

    return best_preds, loss_records

In [22]:
def evaluate_oof(oof_df):
    '''
    Evaluate the out-of-fold dataframe using KL Divergence (torch and kaggle)
    '''
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    labels = torch.tensor(oof_df[TARGETS].values.astype('float32'))
    preds = F.log_softmax(
        torch.tensor(oof_df[TARGETS_PRED].values.astype('float32'), requires_grad=False),
        dim=1
    )
    kl_torch = kl_loss(preds, labels).item()

    return kl_torch

In [23]:
from kl_divergence import score as kaggle_score 
from sklearn.metrics import confusion_matrix
import seaborn as sns

TARGET2ID = {'Seizure': 0, 'LPD': 1, 'GPD': 2, 'LRDA': 3, 'GRDA': 4, 'Other': 5}

def calc_kaggle_score(oof_df):
    submission_df = oof_df[['eeg_id']+TARGETS_PRED].copy()
    submission_df.columns = ['eeg_id'] + TARGETS
    solution_df = oof_df[['eeg_id']+TARGETS].copy()
    return kaggle_score(solution_df, submission_df, 'eeg_id')

def analyze_oof(oof_csv):

    kl_criteria = nn.KLDivLoss(reduction='batchmean')
    softmax = nn.Softmax(dim=1)

    oof_df = pd.read_csv(oof_csv)
    oof_df['target_pred'] = oof_df[TARGETS_PRED].apply(lambda x: np.argmax(x), axis=1)
    oof_df['target_id'] = oof_df[TARGETS].apply(lambda x: np.argmax(x), axis=1)
    
    oof_df["kl_loss"] = oof_df.apply(
    lambda row: 
        kl_criteria(
            F.log_softmax(
                    torch.tensor(row[TARGETS_PRED].values.astype(np.float32)).unsqueeze(0)
                , dim=1
                ), 
            torch.tensor(row[TARGETS].values.astype(np.float32))
            ).numpy(),
    axis=1)

    oof_df["kl_loss"] = oof_df['kl_loss'].astype(np.float32)

    oof_df[TARGETS_PRED] = softmax( torch.tensor(oof_df[TARGETS_PRED].values.astype(np.float32)) )

    oof_df.head()

    return oof_df

In [24]:
def prepare_k_fold(df, k_folds=5):

    kf = KFold(n_splits=k_folds, shuffle=True, random_state=ModelConfig.SEED)
    unique_spec_id = df['spectrogram_id'].unique()
    df['fold'] = k_folds

    for fold, (train_index, valid_index) in enumerate(kf.split(unique_spec_id)):
        df.loc[df['spectrogram_id'].isin(unique_spec_id[valid_index]), 'fold'] = fold

    return df

In [25]:
# Major Train Loop
# ================== Logger ==================
logger.info(f"{'*' * 100}")
logger.info(f"Script Start: {ctime()}")
logger.info(f"Model Configurations:")
for key, value in ModelConfig.__dict__.items():
    if not key.startswith("__"):
        logger.info(f"{key}: {value}")
logger.info(f"{'*' * 100}")

# ================== Prepare Training ==================
oof_stage_1, oof_stage_2 = pd.DataFrame(), pd.DataFrame()
loss_history_1, loss_history_2 = [], []
t_start = time()

K_FOLDS = 5
train_all = prepare_k_fold(train_all, k_folds=K_FOLDS)

for fold in range(0, K_FOLDS):
    tik_total = time()
    tik = time()

    valid_folds = train_all[(train_all['fold'] == fold) ].reset_index(drop=True)
    train_folds = train_all[(train_all['fold'] != fold) ].reset_index(drop=True)
    train_size, valid_size = train_folds.shape[0], valid_folds.shape[0]

    # ================== Stage 1: Train ====================
    # model = ResNetGRU(
    #     kernels=ModelConfig.RESNET_GRU_KERNELS, 
    #     in_channels=8, 
    #     fixed_kernel_size=ModelConfig.RESNET_GRU_FIXED_KERNEL_SIZE,
    #     hidden_size=ModelConfig.RESNET_GRU_HIDDEN_SIZE,
    #     num_classes=6
    #     )
    model = ResNetGRU(config=ModelConfig, num_classes=6)

    ## STAGE 1
    logger.info(f"{'=' * 100}\nFold: {fold}\n{'=' * 100}")
    logger.info(f"- Stage 1 | Train: {train_size}; Valid: {valid_size} -")
    valid_predicts, loss_records = train_fold(
        model, fold, train_folds, valid_folds, logger, stage=1, checkpoint=None)

    loss_history_1.append(loss_records)
    valid_folds[TARGETS_PRED] = valid_predicts
    kl_loss_torch = evaluate_oof(valid_folds)
    info = f"{'=' * 100}\nFold {fold} Valid Loss: {kl_loss_torch}\n"
    info += f"Elapse: {(time() - tik) / 60:.2f} min \n{'=' * 100}"
    logger.info(info)

    oof_stage_1 = pd.concat([oof_stage_1, valid_folds], axis=0).reset_index(drop=True)
    oof_stage_1.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_1.csv"), index=False)

    # ================== Stage 2: Train ====================
    tik = time()
    # model = ResNetGRU(
    #     kernels=ModelConfig.RESNET_GRU_KERNELS, 
    #     in_channels=8, 
    #     fixed_kernel_size=ModelConfig.RESNET_GRU_FIXED_KERNEL_SIZE,
    #     hidden_size=ModelConfig.RESNET_GRU_HIDDEN_SIZE,
    #     num_classes=6
    #     )
    model = ResNetGRU(config=ModelConfig, num_classes=6)
    
    train_folds_2 = train_hard[~train_hard['eeg_id'].isin(valid_folds['eeg_id'])].reset_index(drop=True)
    valid_folds_2 = train_hard[ train_hard['eeg_id'].isin(valid_folds['eeg_id'])].reset_index(drop=True)
    train_size = train_folds_2.shape[0]
    valid_size = valid_folds_2.shape[0]
    
    ## STAGE 2
    logger.info(f"- Stage 2 | Train: {train_size}; Valid: {valid_size} -")

    # model_dir = "/home/shiyi/kaggle_hms/outputs/ResnetGRU_Originalsplit/Reg015"
    # checkpoint = list(Path(model_dir).glob(f"*_fold_{fold}_stage_1.pth"))[0]
    checkpoint = list(Path(PATHS.OUTPUT_DIR).glob(f"{ModelConfig.MODEL_NAME}_fold_{fold}_stage_1.pth"))[0]

    valid_predicts, loss_records = train_fold(
        model, fold, train_folds_2, valid_folds_2, logger, stage=2, checkpoint=checkpoint)
    
    loss_history_2.append(loss_records)
    valid_folds_2[TARGETS_PRED] = valid_predicts
    kl_loss_torch = evaluate_oof(valid_folds_2)
    info = f"{'=' * 100}\nFold {fold} Valid Loss: {kl_loss_torch}\n"
    info += f"Elapse: {(time() - tik) / 60:.2f} min \n{'=' * 100}"
    logger.info(info)

    oof_stage_2 = pd.concat([oof_stage_2, valid_folds_2], axis=0).reset_index(drop=True)
    oof_stage_2.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_2.csv"), index=False)

    logger.info(f"Fold {fold} Elapse: {(time() - tik_total) / 60:.2f} min")

info = f"{'=' * 100}\nTraining Complete!\n"
cv_results_1 = evaluate_oof(oof_stage_1)
cv_results_2 = evaluate_oof(oof_stage_2)
info += f"CV Result: Stage 1: {cv_results_1} | Stage 2: {cv_results_2}\n"
info += f"Elapse: {(time() - t_start) / 60:.2f} min \n{'=' * 100}"
logger.info(info)

****************************************************************************************************
Script Start: Fri Apr  5 14:16:28 2024
Model Configurations:
SEED: 20
SPLIT_ENTROPY: 5.5
MODEL_NAME: ResnetGRU_v1_LB048
MODEL_BACKBONE: reset_gru
BATCH_SIZE: 32
EPOCHS: 20
EARLY_STOP_ROUNDS: 5
GRADIENT_ACCUMULATION_STEPS: 1
DROP_RATE: 0.15
DROP_PATH_RATE: 0.25
WEIGHT_DECAY: 0.01
AMP: True
PRINT_FREQ: 100
NUM_WORKERS: 0
MAX_GRAD_NORM: 10000000.0
REGULARIZATION: 0.15
RESNET_GRU_BANDPASS: None
RESNET_GRU_IN_CHANNELS: 8
RESNET_GRU_KERNELS: [3, 5, 7, 9, 11]
RESNET_GRU_FIXED_KERNEL_SIZE: 5
RESNET_GRU_DOWNSAMPLE: 5
RESNET_GRU_HIDDEN_SIZE: 304
RESNET_GRU_DILATED: False
****************************************************************************************************
Fold: 0
- Stage 1 | Train: 16195; Valid: 3988 -


Train [0]:   0%|          | 0/506 [00:00<?, ?batch/s]

Epoch 1 [0/506] | Train Loss: 1.2684 Grad: 80195.3828 LR: 4.0002e-06 | Elapse: 0.23s
Epoch 1 [100/506] | Train Loss: 1.1780 Grad: 91877.4531 LR: 6.3447e-06 | Elapse: 6.03s
Epoch 1 [200/506] | Train Loss: 1.1851 Grad: 78362.3906 LR: 1.3062e-05 | Elapse: 11.85s
Epoch 1 [300/506] | Train Loss: 1.1739 Grad: 52261.8555 LR: 2.3509e-05 | Elapse: 17.65s
Epoch 1 [400/506] | Train Loss: 1.1569 Grad: 65165.1328 LR: 3.6686e-05 | Elapse: 23.46s
Epoch 1 [500/506] | Train Loss: 1.1345 Grad: 50491.7734 LR: 5.1329e-05 | Elapse: 29.27s
Epoch 1 [505/506] | Train Loss: 1.1335 Grad: 49006.5547 LR: 5.2075e-05 | Elapse: 29.56s


Valid [0]:   0%|          | 0/125 [00:00<?, ?batch/s]

Epoch 1 [0/125] | Valid Loss: 1.2688 | Elapse: 0.07s
Epoch 1 [100/125] | Valid Loss: 1.3073 | Elapse: 5.09s


----------------------------------------------------------------------------------------------------
Epoch 1 - Average Loss: (train) 1.1335; (valid) 1.3002 | Time: 35.85s
Best model found in epoch 1 | valid loss: 1.3002


Epoch 1 [124/125] | Valid Loss: 1.3002 | Elapse: 6.28s


Train [1]:   0%|          | 0/506 [00:00<?, ?batch/s]

Epoch 2 [0/506] | Train Loss: 1.0906 Grad: 53102.0742 LR: 5.2224e-05 | Elapse: 0.06s
Epoch 2 [100/506] | Train Loss: 1.0051 Grad: 114642.4766 LR: 6.6890e-05 | Elapse: 5.93s
Epoch 2 [200/506] | Train Loss: 1.0109 Grad: 54453.7422 LR: 8.0129e-05 | Elapse: 11.80s
Epoch 2 [300/506] | Train Loss: 1.0020 Grad: 49530.3711 LR: 9.0674e-05 | Elapse: 17.65s
Epoch 2 [400/506] | Train Loss: 0.9903 Grad: 42100.6680 LR: 9.7515e-05 | Elapse: 23.48s
Epoch 2 [500/506] | Train Loss: 0.9761 Grad: 49098.9688 LR: 9.9996e-05 | Elapse: 29.30s
Epoch 2 [505/506] | Train Loss: 0.9753 Grad: 72887.2188 LR: 1.0000e-04 | Elapse: 29.59s


Valid [1]:   0%|          | 0/125 [00:00<?, ?batch/s]

Epoch 2 [0/125] | Valid Loss: 0.9886 | Elapse: 0.05s
Epoch 2 [100/125] | Valid Loss: 1.1559 | Elapse: 5.03s


----------------------------------------------------------------------------------------------------
Epoch 2 - Average Loss: (train) 0.9753; (valid) 1.1520 | Time: 35.80s
Best model found in epoch 2 | valid loss: 1.1520


Epoch 2 [124/125] | Valid Loss: 1.1520 | Elapse: 6.21s


Train [2]:   0%|          | 0/506 [00:00<?, ?batch/s]

Epoch 3 [0/506] | Train Loss: 1.0148 Grad: 54527.7500 LR: 1.0000e-04 | Elapse: 0.06s
Epoch 3 [100/506] | Train Loss: 0.8859 Grad: 143155.5938 LR: 9.9969e-05 | Elapse: 5.90s
Epoch 3 [200/506] | Train Loss: 0.8889 Grad: 76853.5703 LR: 9.9879e-05 | Elapse: 11.75s
Epoch 3 [300/506] | Train Loss: 0.8759 Grad: 64058.4180 LR: 9.9729e-05 | Elapse: 17.59s
Epoch 3 [400/506] | Train Loss: 0.8652 Grad: 101814.7109 LR: 9.9520e-05 | Elapse: 23.43s


In [None]:
# plot loss history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

for i, loss in enumerate(loss_history_1):
    ax1.plot(loss['train'], marker="*", ls="-", label=f"Fold {i} Train")
    ax1.plot(loss['valid'], marker="o", ls=":", label=f"Fold {i} Valid")

for i, loss in enumerate(loss_history_2):
    ax2.plot(loss['train'], marker="*", ls="-", label=f"Fold {i} Train")
    ax2.plot(loss['valid'], marker="o", ls=":", label=f"Fold {i} Valid")

ax1.set_title("Stage 1 Loss")
ax2.set_title("Stage 2 Loss")

for ax in (ax1, ax2):
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)

fig.tight_layout()
fig.savefig(Path(PATHS.OUTPUT_DIR) / f"{ModelConfig.MODEL_NAME}_loss_history.png")
plt.show()

In [None]:
csv_path = f'./outputs/{ModelConfig.MODEL_NAME}_oof_1.csv'
print("CSV Path: ", csv_path)

oof_df = analyze_oof(csv_path)

print("Kaggle Score: ", calc_kaggle_score(oof_df))
print("Average KL Loss: ", oof_df["kl_loss"].mean())

display(oof_df.head())

# plot confusion matrix
cm = confusion_matrix(oof_df['target_id'], oof_df['target_pred']) # (y_true, y_pred)
cm = cm / cm.sum(axis=1)[:, np.newaxis]

fig = plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=TARGET2ID.keys(), yticklabels=TARGET2ID.keys())
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('True', fontsize=12)
plt.title(csv_path.split('/')[-1].split('.')[0], fontsize=12)
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_CM.png")
plt.show()

In [None]:
fig, axes = plt.subplots(5, 5, figsize=(15, 15), sharex=True, sharey=True)
oof_samples = oof_df.sample(axes.size)

for i, ax in enumerate(axes.flatten()):
    row = oof_samples.iloc[i]
    x = np.arange(6)
    ax.plot(x, row[TARGETS].T, marker="o", ls="-", label="True")
    ax.plot(x, row[TARGETS_PRED].T, marker="*", ls="--", label="Predicted")
    ax.set_title(f"{row['target']} | KL Loss: {row['kl_loss']:.4f}")
    ax.legend()
    
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_samples.png")
plt.show()

In [None]:
csv_path = f'./outputs/{ModelConfig.MODEL_NAME}_oof_2.csv'
print("CSV Path: ", csv_path)

oof_df = analyze_oof(csv_path)

print("Kaggle Score: ", calc_kaggle_score(oof_df))
print("Average KL Loss: ", oof_df["kl_loss"].mean())

display(oof_df.head())

# plot confusion matrix
cm = confusion_matrix(oof_df['target_id'], oof_df['target_pred']) # (y_true, y_pred)
cm = cm / cm.sum(axis=1)[:, np.newaxis]

fig = plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=TARGET2ID.keys(), yticklabels=TARGET2ID.keys())
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('True', fontsize=12)
plt.title(csv_path.split('/')[-1].split('.')[0], fontsize=12)
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_CM.png")
plt.show()

In [None]:
fig, axes = plt.subplots(5, 5, figsize=(15, 15), sharex=True, sharey=True)
oof_samples = oof_df.sample(axes.size)

for i, ax in enumerate(axes.flatten()):
    row = oof_samples.iloc[i]
    x = np.arange(6)
    ax.plot(x, row[TARGETS].T, marker="o", ls="-", label="True")
    ax.plot(x, row[TARGETS_PRED].T, marker="*", ls="--", label="Predicted")
    ax.set_title(f"{row['target']} | KL Loss: {row['kl_loss']:.4f}")
    ax.legend()
    
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_samples.png")
plt.show()

In [None]:
oof_stage_2_full = pd.DataFrame()

for fold in range(1):

    valid_folds = train_all[train_all['fold'] == fold].reset_index(drop=True)

    # predict labels using stage-2 models
    model = ResNetGRU(
        kernels=ModelConfig.RESNET_GRU_KERNELS, 
        in_channels=8, 
        fixed_kernel_size=ModelConfig.RESNET_GRU_FIXED_KERNEL_SIZE,
        hidden_size=ModelConfig.RESNET_GRU_HIDDEN_SIZE,
        num_classes=6
        )
    
    check_point = os.path.join(
        PATHS.OUTPUT_DIR,
        f"{ModelConfig.MODEL_NAME}_fold_{fold}_stage_2.pth"
    )

    model.load_state_dict(torch.load(check_point, map_location=DEVICE))

    loader_kwargs = {
        "batch_size": ModelConfig.BATCH_SIZE,
        "num_workers": ModelConfig.NUM_WORKERS,
        "pin_memory": True,
        "shuffle": False,
    }

    valid_dataset = EEGSeqDataset(
        valid_folds, ModelConfig, ALL_EEG_SIGNALS, mode="valid", downsample=ModelConfig.RESNET_GRU_DOWNSAMPLE)
    valid_loader = DataLoader(valid_dataset, drop_last=False, collate_fn=None, **loader_kwargs)

    model.to(DEVICE)
    model.eval()

    valid_predicts = []
    with torch.no_grad():
        for X, y in valid_loader:
            X = X.to(DEVICE)
            y_pred = model(X)
            valid_predicts.append(y_pred.to('cpu').numpy())

    valid_predicts = np.concatenate(valid_predicts)
    valid_folds[TARGETS_PRED] = valid_predicts
    oof_stage_2_full = pd.concat([oof_stage_2, valid_folds], axis=0).reset_index(drop=True)

    del valid_dataset, valid_loader
    torch.cuda.empty_cache()
    gc.collect()

    oof_stage_2_full.to_csv(os.path.join(PATHS.OUTPUT_DIR, f"{ModelConfig.MODEL_NAME}_oof_2_full.csv"), index=False)

cv_results = evaluate_oof(oof_stage_2_full)
logger.info(f"{'=' * 100}\nCV Result (Stage 2 Full): {cv_results}\n{'=' * 100}")


Reg = 0.15, Downsample = 0, CV Result (Stage 2 Full): 0.639643669128418

In [None]:
csv_path = f'./outputs/Resnet_SeqGRU_ChrisNO_NoReg_oof_2_full.csv'
print("CSV Path: ", csv_path)

oof_df = analyze_oof(csv_path)

print("Kaggle Score: ", calc_kaggle_score(oof_df))
print("Average KL Loss: ", oof_df["kl_loss"].mean())

display(oof_df.head())

# plot confusion matrix
cm = confusion_matrix(oof_df['target_id'], oof_df['target_pred']) # (y_true, y_pred)
cm = cm / cm.sum(axis=1)[:, np.newaxis]

fig = plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, cmap='Blues', xticklabels=TARGET2ID.keys(), yticklabels=TARGET2ID.keys())
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('True', fontsize=12)
plt.title(csv_path.split('/')[-1].split('.')[0], fontsize=12)
fig.tight_layout()
fig.savefig(f"./outputs/{csv_path.split('/')[-1].split('.')[0]}_CM.png")
plt.show()