In [1]:
import os
import gc
import cv2
import time
import random
from monai.inferers import sliding_window_inference

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
from pytorch_toolbelt import losses as L

# Utils
from tqdm.auto import tqdm

# For Image Models
import timm
from timm.utils.model_ema import ModelEmaV3

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

## using gpu:1
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def seed_everything(seed=123):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything()

In [2]:
class WidthAttention(nn.Module):
    def __init__(self, in_ch, width: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=(1, 1)),
            nn.BatchNorm2d(in_ch),
            nn.SiLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(in_ch, width),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention = self.attention(x)
        attention = attention.unsqueeze(1).unsqueeze(1)
        return x * attention

    
    
class Slide_Window_Model(nn.Module):
    def __init__(self, model_name, cls):
        super().__init__()
        self.cls= cls
        self.model = timm.create_model(model_name, 
                                       pretrained=True, 
                                       num_classes=cls,
                                       drop_rate= CFG['drop_out'], 
                                       drop_path_rate= CFG['drop_path'])
        
        self.gp= self.model.global_pool
        self.out= self.model.classifier
        self.model.global_pool= nn.Identity()
        self.model.classifier= nn.Identity()
        self.att= nn.Sequential(
                        WidthAttention(self.out.in_features, 29),
                    )
        
    def forward(self, image):
        x = self.model(image)  ## (1,1280,13,21)
        x = self.att(x)  ## (1,1280,13,21)
        x = self.gp(x)   ## (1,1280)
        x = self.out(x)  ## (1,6)
        return x if self.training else x.view(-1, self.cls, 1, 1)
    
    
# attention= Slide_Window_Model('tf_efficientnet_b0_ns', 6)
# x= torch.rand(1,3,400,656)
# attention(x).shape

In [3]:
params = {
    "num_masks_x": (1, 10),
#     "num_masks_y": (1, 5),    
#     "mask_y_length": (5, 10),
    "mask_x_length": (5, 10),
    "fill_value": 0,
}

def get_train_transform(img_size):
    return A.Compose([
        A.PadIfNeeded(min_height=400, min_width=CFG['img_crop'], border_mode=0, p=1),
        A.RandomCrop(width=CFG['img_crop'], height=400, p=1),
        
        A.RandomBrightnessContrast(brightness_limit=0.6, contrast_limit=0., p=0.5),
        
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.XYMasking(**params, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.05, rotate_limit= 15,
                                        interpolation=cv2.INTER_LINEAR, border_mode=0, p=0.7),
        ToTensorV2(p=1.0),
    ])


def get_test_transform(img_size):
    return A.Compose([
        A.PadIfNeeded(min_height=400, min_width=CFG['img_crop'], border_mode=0, p=1),
        ToTensorV2(p=1.0),
    ])

In [4]:
from preprocessing import spectrogram_from_eeg

class Customize_Dataset(Dataset):
    def __init__(self, df, transforms=None, training=False):
        self.df = df
        self.transforms = transforms
        self.training= training
    
    def mixup_aug(self, img_1, mask_1, 
                        img_2, mask_2):
        """
        img: numpy array of shape (height, width,channel)
        mask: numpy array of shape (height, width,channel)
        """
        ## mixup
        weight= np.random.beta(a=0.5, b=0.5)
        img= img_1*weight + img_2*(1-weight)
        mask= mask_1*weight + mask_2*(1-weight)
        return img, mask
    
    def read_data(self, data):
        img= np.load(data['npy_path'])
        label= np.array(eval(data['soft_label']))
        pl_label= np.array(eval(data[f"PL_prob"]))
        if data['voter']<=7 and CFG['finetune'] and self.training:
            label= label*0. + pl_label*1.0
        return img, label
    
    def __getitem__(self, index):
        data = self.df.loc[index]
        img, label= self.read_data(data)
        
        # use mixup
        if self.training and np.random.rand() >= (1-CFG['mixup']) and img.shape[1]<=656:
            img_1= img
            label_1= np.array(label)
            while True:
                indx= np.random.randint(len(self.df))
                data= self.df.loc[indx]
                img_2, label_2= self.read_data(data)
                if img_2.shape[1] > img.shape[1]: 
                    img_2= img_2[:, :img.shape[1]]
                    break
            img, label= self.mixup_aug(img_1, label_1, 
                                       img_2, label_2)
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': torch.tensor(img, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.float32),
        }
    
    def __len__(self):
        return len(self.df)

In [5]:
class Customize_loss(nn.Module):
    def  __init__(self):
        super().__init__()
        self.CrossEntropy= nn.CrossEntropyLoss(weight= None, label_smoothing=0.)
        self.FocalCosineLoss= L.FocalCosineLoss()
        self.kl_loss = nn.KLDivLoss()
        self.bce= nn.BCELoss()
        self.mse= nn.MSELoss()
    
    def forward(self, y_pred, y_true):
        loss= 1*self.kl_loss(y_pred.log_softmax(dim=-1), y_true) + 1*self.mse(y_pred.softmax(dim=-1), y_true)
        return loss

In [6]:
def train_epoch(dataloader, model, criterion, optimizer, model_ema):
    scaler= amp.GradScaler()
    model.train()

    ep_loss= []
    for i, data in enumerate(tqdm(dataloader)):

        imgs= data['image'].to('cuda')
        labels= data['label'].to('cuda')
        
        with amp.autocast():
            preds= model(imgs)
            loss= criterion(preds, labels)
            ep_loss.append(loss.item())
            loss/= CFG['gradient_accumulation']
            scaler.scale(loss).backward()
            
            if (i+1) % CFG['gradient_accumulation']== 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                
            if model_ema: model_ema.update(model)
                
    return np.mean(ep_loss)

In [7]:
from metrics import *

def valid_epoch(dataloader, model, criterion):
    model.eval()
    
    ep_loss= []
    all_pred= []
    all_label= []
    for i, data in enumerate(tqdm(dataloader)):

        imgs= data['image'].to('cuda')
        labels= data['label'].to('cuda')
        all_label.extend(labels.cpu().numpy())
        
        with torch.no_grad():
            preds= sliding_window_inference(imgs, 
                                            roi_size=(-1,CFG['img_crop']), 
                                            mode= 'gaussian',
                                            sw_batch_size=1, 
                                            predictor=model)
            preds= preds.view(preds.shape[0],model.cls,-1).mean(dim=-1)
            loss= criterion(preds, labels)
            ep_loss.append(loss.item())
        all_pred.extend(preds.cpu().softmax(dim=-1).numpy())
    
    ## caculate metrics
    soft_label= all_label.copy()
    all_label= np.array(all_label).argmax(1)
    all_pred= np.array(all_pred)
    
    acc= Accuracy(all_pred, all_label)
    print(f'accuracy: {acc}')
    recall= Mean_Recall(all_pred, all_label)
    print(f'mean_recall: {recall}')
    kl_score= kl_divergence(soft_label, all_pred)
    print(f'kl_divergence: {kl_score}')
    
    score= kl_score
    return np.mean(ep_loss), score

# CFG

In [8]:
CFG= {
    'fold': 1,
    'epoch': 25,
    'model_name': 'tf_efficientnet_b0_ns',
    
    'img_size': None,
    'img_crop': 912,
    
    'batch_size': 16,
    'gradient_accumulation': 1,
    'gradient_checkpoint': False,
    'drop_out': 0.3,
    'drop_path': 0.2,
    'mixup': 0.3,
    'EMA': 0.995,
    
    'lr': 3e-4,
    'weight_decay': 0.,
    
    'num_classes': 6,
    'load_model': False,
    'save_model': './train_model',
    
    'finetune': True,
}

if CFG['finetune']:
    print('finetune')
    CFG['epoch']= 20
    CFG['load_model']= f"{CFG['save_model']}/cv{CFG['fold']}_best.pth"
    CFG['lr']= 3e-4

finetune


# Prepare Dataset

In [9]:
df= pd.read_csv('../Data/train_npy_PL1.csv')
if CFG['finetune']: 
    df.loc[ (df['voter']<7)&(df['fold']==CFG['fold']), 'fold' ]= -1
    df= df[(df['voter']>7) | (df['fold']!=CFG['fold'])]
df= df.drop_duplicates(subset=['spectrogram_id'])

train_df= df[df['fold']!=CFG['fold']].reset_index(drop=True)
valid_df= df[df['fold']==CFG['fold']].reset_index(drop=True)
print(f'train dataset: {len(train_df)}')
print(f'valid dataset: {len(valid_df)}')

train_dataset= Customize_Dataset(train_df, get_train_transform(CFG['img_size']), training=True)
valid_dataset= Customize_Dataset(valid_df, get_test_transform(CFG['img_size']), training=False)

train_loader= DataLoader(train_dataset, batch_size= CFG['batch_size'], shuffle=True, num_workers=0)
valid_loader= DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
train_df.head()

train dataset: 10417
valid dataset: 721


Unnamed: 0,eeg_id,spectrogram_id,image_path,expert_consensus,patient_id,label,soft_label,fold,npy_path,time_length,voter,PL_prob_cv0,PL_prob_cv1,PL_prob_cv2,PL_prob_cv3,PL_prob_cv4,PL_prob
0,1628180742,353733,../Data/train_eegs/1628180742.parquet,Seizure,42516,0,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]",4,../Data/train_npy/0.npy,576,3,"[0.3489752411842346, 0.10252505540847778, 0.03...","[0.4545011520385742, 0.08543295413255692, 0.00...","[0.35866212844848633, 0.06348183751106262, 0.0...","[0.42922917008399963, 0.04415779933333397, 0.0...","[0.32202839851379395, 0.1123746782541275, 0.00...","[0.4263441264629364, 0.058971207588911057, 0.0..."
1,387987538,1084844,../Data/train_eegs/387987538.parquet,LRDA,4264,3,"[0.0, 0.0, 0.0, 1.0, 0.0, 0.0]",0,../Data/train_npy/3.npy,562,3,"[0.039219897240400314, 0.024717289954423904, 0...","[0.03628334030508995, 0.024995330721139908, 0....","[0.021519465371966362, 0.022249963134527206, 0...","[0.028398025780916214, 0.030236368998885155, 0...","[0.07453320920467377, 0.04387536272406578, 0.0...","[0.05356224253773689, 0.041943296790122986, 0...."
2,2175806584,1219001,../Data/train_eegs/2175806584.parquet,Seizure,23435,0,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]",3,../Data/train_npy/4.npy,666,3,"[0.2235855609178543, 0.11480796337127686, 0.50...","[0.22932660579681396, 0.04908479005098343, 0.6...","[0.12347868829965591, 0.033894166350364685, 0....","[0.08356039226055145, 0.06833972036838531, 0.7...","[0.12395842373371124, 0.053466178476810455, 0....","[0.411081999540329, 0.04547161981463432, 0.428..."
3,1202099836,1353070,../Data/train_eegs/1202099836.parquet,Other,34554,5,"[0.0, 0.0, 0.35714285714285715, 0.0, 0.0, 0.64...",0,../Data/train_npy/7.npy,556,14,"[0.008067328482866287, 0.029164904728531837, 0...","[0.0032608057372272015, 0.07425374537706375, 0...","[0.0032928376458585262, 0.17271824181079865, 0...","[0.002575099002569914, 0.054019488394260406, 0...","[0.00583513593301177, 0.15544316172599792, 0.3...","[0.01853444054722786, 0.08432004600763321, 0.4..."
4,3037445252,1730458,../Data/train_eegs/3037445252.parquet,Other,10187,5,"[0.0, 0.0, 0.3333333333333333, 0.0, 0.0, 0.666...",-1,../Data/train_npy/8.npy,693,3,"[0.06870211660861969, 0.010298708453774452, 0....","[0.00406707264482975, 0.006945385131984949, 0....","[0.015403089113533497, 0.029696019366383553, 0...","[0.09617599099874496, 0.02372109889984131, 0.3...","[0.021986497566103935, 0.03224774822592735, 0....","[0.0807790607213974, 0.033510494977235794, 0.1..."


# Train

In [10]:
## create model
if CFG['load_model']:
    print(f"load_model: {CFG['load_model']}")
    model= torch.load(CFG['load_model'], map_location= 'cuda')
else:
    model= Slide_Window_Model(CFG['model_name'], CFG['num_classes'])
    
if CFG['gradient_checkpoint']: 
    print('use gradient checkpoint')
    model.model.set_grad_checkpointing(enable=True)
    
## EMA
model.to('cuda')
if CFG['EMA']:
    print(f"Use EMA: {CFG['EMA']}")
    model_ema= ModelEmaV3(model, decay=CFG['EMA'])
    model_ema.to('cuda')
else:
    model_ema= type('model_ema', (object,), {'module':{}})
    
## hyperparameter
criterion= Customize_loss()
optimizer= optim.AdamW(model.parameters(), lr= CFG['lr'], weight_decay= CFG['weight_decay'])

## start training
best_score= 100000
for ep in range(1, CFG['epoch']+1):
    print(f'\nep: {ep}')
    
    if CFG['EMA']: train_loss= train_epoch(train_loader, model, criterion, optimizer, model_ema)
    else: 
        train_loss= train_epoch(train_loader, model, criterion, optimizer, False)
        model_ema.module= model
    valid_loss, valid_acc= valid_epoch(valid_loader, model_ema.module, criterion)
    print(f'train loss: {round(train_loss, 5)}')
    print(f'valid loss: {round(valid_loss, 5)}, valid_acc: {round(valid_acc, 5)}')
    
    if valid_acc <= best_score:
        best_score= valid_acc
        torch.save(model_ema.module, f"{CFG['save_model']}/cv{CFG['fold']}_best.pth")
        print(f'model save at score: {round(best_score, 5)}')
        
        ## save model every epoch
#         torch.save(model_ema.module, f"{CFG['save_model']}/cv{CFG['fold']}_ep{ep}.pth")

load_model: ./train_model/cv1_best.pth
Use EMA: 0.995

ep: 1


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7628294036061026
mean_recall: 0.6460050686372891
kl_divergence: 0.30165845822044013
train loss: 0.04038
valid loss: 0.06896, valid_acc: 0.30166
model save at score: 0.30166

ep: 2


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7850208044382802
mean_recall: 0.655042271119728
kl_divergence: 0.2889690387088194
train loss: 0.03627
valid loss: 0.06556, valid_acc: 0.28897
model save at score: 0.28897

ep: 3


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7850208044382802
mean_recall: 0.6768444546201485
kl_divergence: 0.28433378551589106
train loss: 0.03442
valid loss: 0.06455, valid_acc: 0.28433
model save at score: 0.28433

ep: 4


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7808599167822469
mean_recall: 0.6897965562819276
kl_divergence: 0.28007699821104903
train loss: 0.03314
valid loss: 0.0635, valid_acc: 0.28008
model save at score: 0.28008

ep: 5


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7739251040221914
mean_recall: 0.6733944037957542
kl_divergence: 0.28844916085762917
train loss: 0.03148
valid loss: 0.06557, valid_acc: 0.28845

ep: 6


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7766990291262136
mean_recall: 0.6574551714227258
kl_divergence: 0.28264518630375923
train loss: 0.03061
valid loss: 0.06411, valid_acc: 0.28265

ep: 7


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.782246879334258
mean_recall: 0.6806746329439504
kl_divergence: 0.2753605716848419
train loss: 0.03041
valid loss: 0.0625, valid_acc: 0.27536
model save at score: 0.27536

ep: 8


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7877947295423023
mean_recall: 0.6823721352416025
kl_divergence: 0.2774261439639997
train loss: 0.02862
valid loss: 0.06272, valid_acc: 0.27743

ep: 9


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7697642163661581
mean_recall: 0.6636170203265103
kl_divergence: 0.27633669165718466
train loss: 0.02751
valid loss: 0.06264, valid_acc: 0.27634

ep: 10


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7864077669902912
mean_recall: 0.6836918442272907
kl_divergence: 0.2736205153802159
train loss: 0.02745
valid loss: 0.06198, valid_acc: 0.27362
model save at score: 0.27362

ep: 11


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7753120665742025
mean_recall: 0.676195103970798
kl_divergence: 0.2742026028852993
train loss: 0.02645
valid loss: 0.06232, valid_acc: 0.2742

ep: 12


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7850208044382802
mean_recall: 0.6742152867762395
kl_divergence: 0.26768702593334076
train loss: 0.02631
valid loss: 0.06061, valid_acc: 0.26769
model save at score: 0.26769

ep: 13


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7780859916782247
mean_recall: 0.673326326711548
kl_divergence: 0.2719498707098494
train loss: 0.0256
valid loss: 0.06172, valid_acc: 0.27195

ep: 14


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7739251040221914
mean_recall: 0.6803949972882831
kl_divergence: 0.26978308832451264
train loss: 0.02428
valid loss: 0.06124, valid_acc: 0.26978

ep: 15


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7669902912621359
mean_recall: 0.6698135978799895
kl_divergence: 0.2691899247944097
train loss: 0.02433
valid loss: 0.0611, valid_acc: 0.26919

ep: 16


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7794729542302358
mean_recall: 0.6923328031683246
kl_divergence: 0.2657491799809015
train loss: 0.02353
valid loss: 0.06043, valid_acc: 0.26575
model save at score: 0.26575

ep: 17


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7739251040221914
mean_recall: 0.6700733421991862
kl_divergence: 0.26631485845312075
train loss: 0.0231
valid loss: 0.06062, valid_acc: 0.26631

ep: 18


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7808599167822469
mean_recall: 0.6501228594486285
kl_divergence: 0.2672289726736259
train loss: 0.02287
valid loss: 0.06082, valid_acc: 0.26723

ep: 19


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7725381414701803
mean_recall: 0.6408286527800782
kl_divergence: 0.2677640898287163
train loss: 0.02248
valid loss: 0.06101, valid_acc: 0.26776

ep: 20


  0%|          | 0/652 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

accuracy: 0.7780859916782247
mean_recall: 0.6602280034294288
kl_divergence: 0.26610458222294553
train loss: 0.02252
valid loss: 0.06052, valid_acc: 0.2661
