In [1]:
import os
import gc
import cv2
import time
import random
from monai.inferers import sliding_window_inference

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
from pytorch_toolbelt import losses as L

# Utils
import math
from tqdm.auto import tqdm

# For Image Models
import timm
from timm.utils.model_ema import ModelEmaV3

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

## using gpu:1
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def seed_everything(seed=123):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything()

In [2]:
nsamples= 10_000
eeg_features = ["Fp1", "T3", "C3", "O1", "Fp2", "C4", "T4", "O2"]

def eeg_from_parquet(
    parquet_path: str, display: bool = False, seq_length=50):

    # Вырезаем среднюю 50 секундную часть
    eeg = pd.read_parquet(parquet_path, columns=eeg_features)
    rows = len(eeg)

    # начало смещения данных, чтобы забрать середину
    offset = (rows - nsamples) // 2

    # средние 50 секунд, имеет одинаковое количество показаний слева и справа
    eeg = eeg.iloc[offset : offset + nsamples]

    if display:
        plt.figure(figsize=(10, 5))
        offset = 0

    # Конвертировать в numpy

    # создать заполнитель той же формы с нулями
    data = np.zeros((nsamples, len(eeg_features)))

    for index, feature in enumerate(eeg_features):
        x = eeg[feature].values.astype("float32")  # конвертировать в float32

        # Вычисляет среднее арифметическое вдоль указанной оси, игнорируя NaN.
        mean = np.nanmean(x)
        nan_percentage = np.isnan(x).mean()  # percentage of NaN values in feature

        # Заполнение значения Nan
        # Поэлементная проверка на NaN и возврат результата в виде логического массива.
        if nan_percentage < 1:  # если некоторые значения равны Nan, но не все
            x = np.nan_to_num(x, nan=mean)
        else:  # если все значения — Nan
            x[:] = 0
        data[:, index] = x

    return data

In [3]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    

class WidthAttention(nn.Module):
    def __init__(self, in_ch, width: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, kernel_size=(1, 1)),
            nn.BatchNorm2d(in_ch),
            nn.SiLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(in_ch, width),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention = self.attention(x)
        attention = attention.unsqueeze(1).unsqueeze(1)
        return x * attention

    
class Raw_Signal_Model(nn.Module):
    def __init__(self, model_name, cls):
        super().__init__()
        self.cls= cls
        self.model = timm.create_model(model_name, 
                                       pretrained=True, 
                                       num_classes=cls,
                                       drop_rate= CFG['drop_out'], 
                                       drop_path_rate= CFG['drop_path'])
        
        self.gp= self.model.global_pool
        self.out= self.model.classifier
        self.model.global_pool= nn.Identity()
        self.model.classifier= nn.Identity()
        self.att= nn.Sequential(
                        WidthAttention(self.out.in_features, 313),
                    )
        
    def forward(self, image):
        x = self.model(image)  ## (1,1280,13,21)
        x = self.att(x)  ## (1,1280,13,21)
        x = self.gp(x)   ## (1,1280)
        x = self.out(x)  ## (1,6)
        return x if self.training else x.view(-1, self.cls, 1, 1)
    
    
class Transformer_Model(nn.Module):
    def __init__(self, cls):
        super().__init__()
        self.cls= cls
        
        self.model = timm.create_model('tf_efficientnet_b0_ns', 
                                       pretrained=True, 
                                       num_classes=cls,
                                       drop_rate= CFG['drop_out'], 
                                       drop_path_rate= CFG['drop_path'])
        self.model.global_pool= nn.Identity()
        self.model.classifier= nn.Identity()
        
        self.posen= PositionalEncoding(1280)
        encoder_layer1 = nn.TransformerEncoderLayer(d_model=1280, nhead=8)
        self.transformer_encoder1 = nn.TransformerEncoder(encoder_layer1, num_layers=1)
        self.fc= nn.Linear(1280, self.cls)
        
    def forward(self, image):
        x = image
        
        x = self.model(x).squeeze(dim=-2).permute(0,2,1)
        
        x = self.posen(x)
        x = self.transformer_encoder1(x)
        x = self.fc(x.mean(dim=-2))
        return x if self.training else x.view(-1, self.cls, 1, 1)
    
    
# model= Raw_Signal_Model('tf_efficientnet_b0_ns', 6)
# model= Transformer_Model(6)
# x= torch.rand(1,3,20,2000)
# model(x).shape

In [4]:
def get_train_transform(img_size):
    return A.Compose([
#         A.PadIfNeeded(min_height=8, min_width=CFG['img_crop'], border_mode=0, p=1),
#         A.RandomCrop(width=CFG['img_crop'], height=8, p=1),
        
#         A.HorizontalFlip(p=0.5),
#         A.VerticalFlip(p=0.5),
        
        ToTensorV2(p=1.0),
    ])


def get_test_transform(img_size):
    return A.Compose([
#         A.PadIfNeeded(min_height=8, min_width=CFG['img_crop'], border_mode=0, p=1),
        ToTensorV2(p=1.0),
    ])

In [5]:
class Customize_Dataset(Dataset):
    def __init__(self, df, transforms=None, mixup=False):
        self.df = df
        self.transforms = transforms
        self.mixup= mixup
    
    def mixup_aug(self, img_1, mask_1, 
                        img_2, mask_2):
        """
        img: numpy array of shape (height, width,channel)
        mask: numpy array of shape (height, width,channel)
        """
        ## mixup
        weight= np.random.beta(a=0.5, b=0.5)
        img= img_1*weight + img_2*(1-weight)
        mask= mask_1*weight + mask_2*(1-weight)
        return img, mask
    
    def read_data(self, data):
        def norm_to_standard(img):
            ep = 1e-6
            m = np.nanmean(img.flatten())
            s = np.nanstd(img.flatten())
            img = (img-m)/(s+ep)
            img = np.nan_to_num(img, nan=0.0)
            return img
        
        img= eeg_from_parquet(data['image_path']).transpose(1,0)
        for i in range(img.shape[0]):
            img[i]= norm_to_standard(img[i])
        img= img[:,:,None]
        img= np.concatenate([img, img, img], axis=-1)
        label= np.array(eval(data['soft_label']))
        return img, label
    
    def __getitem__(self, index):
        data = self.df.loc[index]
        img, label= self.read_data(data)
        
        # use mixup
        if self.mixup and np.random.rand() >= (1-self.mixup) and img.shape[1]<=656:
            img_1= img
            label_1= np.array(label)
            while True:
                indx= np.random.randint(len(self.df))
                data= self.df.loc[indx]
                img_2, label_2= self.read_data(data)
                if img_2.shape[1] > img.shape[1]: 
                    img_2= img_2[:, :img.shape[1]]
                    break
            img, label= self.mixup_aug(img_1, label_1, 
                                       img_2, label_2)
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': torch.tensor(img, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.float32),
        }
    
    def __len__(self):
        return len(self.df)

In [6]:
class Customize_loss(nn.Module):
    def  __init__(self):
        super().__init__()
        self.CrossEntropy= nn.CrossEntropyLoss(weight= None, label_smoothing=0.)
        self.FocalCosineLoss= L.FocalCosineLoss()
        self.kl_loss = nn.KLDivLoss()
    
    def forward(self, y_pred, y_true):
        loss= 1*self.kl_loss(y_pred.log_softmax(dim=-1), y_true)
        return loss

In [7]:
def train_epoch(dataloader, model, criterion, optimizer, model_ema):
    scaler= amp.GradScaler()
    model.train()

    ep_loss= []
    for i, data in enumerate(tqdm(dataloader)):

        imgs= data['image'].to('cuda')
        labels= data['label'].to('cuda')
        
        with amp.autocast():
            preds= model(imgs)
            loss= criterion(preds, labels)
            ep_loss.append(loss.item())
            loss/= CFG['gradient_accumulation']
            scaler.scale(loss).backward()
            
            if (i+1) % CFG['gradient_accumulation']== 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                
            if model_ema: model_ema.update(model)
                
    return np.mean(ep_loss)

In [8]:
from metrics import *

def valid_epoch(dataloader, model, criterion):
    model.eval()
    
    ep_loss= []
    all_pred= []
    all_label= []
    for i, data in enumerate(tqdm(dataloader)):

        imgs= data['image'].to('cuda')
        labels= data['label'].to('cuda')
        all_label.extend(labels.cpu().numpy())
        
        with torch.no_grad():
            preds= sliding_window_inference(imgs, 
                                            roi_size=(-1,CFG['img_crop']), 
                                            mode= 'gaussian',
                                            sw_batch_size=1, 
                                            predictor=model)
            preds= preds.view(preds.shape[0],model.cls,-1).mean(dim=-1)
            loss= criterion(preds, labels)
            ep_loss.append(loss.item())
        all_pred.extend(preds.cpu().softmax(dim=-1).numpy())
    
    ## caculate metrics
    soft_label= all_label.copy()
    all_label= np.array(all_label).argmax(1)
    all_pred= np.array(all_pred)
    
    acc= Accuracy(all_pred, all_label)
    print(f'accuracy: {acc}')
    recall= Mean_Recall(all_pred, all_label)
    print(f'mean_recall: {recall}')
    kl_score= kl_divergence(soft_label, all_pred)
    print(f'kl_divergence: {kl_score}')
    
    score= kl_score
    return np.mean(ep_loss), score

# CFG

In [9]:
CFG= {
    'fold': 0,
    'epoch': 20,
    
    'model_name': 'tf_efficientnet_b0_ns',
    'img_size': None,
    'img_crop': 10000,
    
    'batch_size': 16,
    'gradient_accumulation': 1,
    'gradient_checkpoint': False,
    'drop_out': 0.3,
    'drop_path': 0.2,
    'mixup': 0.,
    'EMA': 0.995,
    
    'lr': 3e-4,
    'weight_decay': 0.,
    
    'num_classes': 6,
    'load_model': False,
    'save_model': './train_model_copy',
    
    'finetune': False,
}

if CFG['finetune']:
    print('finetune')
    CFG['epoch']= 10
    CFG['load_model']= f"{CFG['save_model']}/cv{CFG['fold']}_best.pth"
    CFG['lr']= 3e-4

# Prepare Dataset

In [10]:
df= pd.read_csv('../Data/train_npy.csv')
df= df[df['voter']>7]

train_df= df[df['fold']!=CFG['fold']].reset_index(drop=True)
valid_df= df[df['fold']==CFG['fold']].reset_index(drop=True)
valid_df= valid_df.drop_duplicates(subset=['spectrogram_id']).reset_index(drop=True)
print(f'train dataset: {len(train_df)}')
print(f'valid dataset: {len(valid_df)}')

train_dataset= Customize_Dataset(train_df, get_train_transform(CFG['img_size']), mixup=CFG['mixup'])
valid_dataset= Customize_Dataset(valid_df, get_test_transform(CFG['img_size']), mixup=False)

train_loader= DataLoader(train_dataset, batch_size= CFG['batch_size'], shuffle=True, num_workers=0)
valid_loader= DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)
train_df.head()

train dataset: 4269
valid dataset: 921


Unnamed: 0,eeg_id,spectrogram_id,image_path,expert_consensus,patient_id,label,soft_label,fold,npy_path,time_length,voter
0,2277392603,924234,../Data/train_eegs/2277392603.parquet,GPD,30539,2.0,"[0.0, 0.0, 0.45454545454545453, 0.0, 0.0909090...",1.0,../Data/train_npy/1.npy,557.0,11.0
1,722738444,999431,../Data/train_eegs/722738444.parquet,LRDA,56885,3.0,"[0.0, 0.0625, 0.0, 0.875, 0.0, 0.0625]",1.0,../Data/train_npy/2.npy,568.0,16.0
2,374504640,3452193,../Data/train_eegs/374504640.parquet,GRDA,29847,4.0,"[0.0, 0.0, 0.16666666666666666, 0.0, 0.6666666...",2.0,../Data/train_npy/19.npy,582.0,12.0
3,1445780287,4004824,../Data/train_eegs/1445780287.parquet,Other,22597,5.0,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0]",4.0,../Data/train_npy/20.npy,564.0,17.0
4,893864755,4367732,../Data/train_eegs/893864755.parquet,Other,2338,5.0,"[0.0, 0.0, 0.0, 0.0, 0.23076923076923078, 0.76...",3.0,../Data/train_npy/23.npy,571.0,13.0


# Train

In [11]:
## create model
if CFG['load_model']:
    print(f"load_model: {CFG['load_model']}")
    model= torch.load(CFG['load_model'], map_location= 'cuda')
else:
    model= Raw_Signal_Model(CFG['model_name'], CFG['num_classes'])
#     model= Transformer_Model(CFG['num_classes'])
    Transformer_Model
    
if CFG['gradient_checkpoint']: 
    print('use gradient checkpoint')
    model.model.set_grad_checkpointing(enable=True)
    
## EMA
model.to('cuda')
if CFG['EMA']:
    print(f"Use EMA: {CFG['EMA']}")
    model_ema= ModelEmaV3(model, decay=CFG['EMA'])
    model_ema.to('cuda')
else:
    model_ema= type('model_ema', (object,), {'module':{}})
    
## hyperparameter
criterion= Customize_loss()
optimizer= optim.AdamW(model.parameters(), lr= CFG['lr'], weight_decay= CFG['weight_decay'])

## start training
best_score= 100000
for ep in range(1, CFG['epoch']+1):
    print(f'\nep: {ep}')
    
    if CFG['EMA']: train_loss= train_epoch(train_loader, model, criterion, optimizer, model_ema)
    else: 
        train_loss= train_epoch(train_loader, model, criterion, optimizer, False)
        model_ema.module= model
    valid_loss, valid_acc= valid_epoch(valid_loader, model_ema.module, criterion)
    print(f'train loss: {round(train_loss, 5)}')
    print(f'valid loss: {round(valid_loss, 5)}, valid_acc: {round(valid_acc, 5)}')
    
    if valid_acc <= best_score:
        best_score= valid_acc
        torch.save(model_ema.module, f"{CFG['save_model']}/cv{CFG['fold']}_best.pth")
        print(f'model save at score: {round(best_score, 5)}')
        
        ## save model every epoch
        torch.save(model_ema.module, f"{CFG['save_model']}/cv{CFG['fold']}_ep{ep}.pth")

Use EMA: 0.995

ep: 1


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.24104234527687296
mean_recall: 0.2410876058579675
kl_divergence: 1.07531218572957
train loss: 0.10983
valid loss: 0.17922, valid_acc: 1.07531
model save at score: 1.07531

ep: 2


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.5895765472312704
mean_recall: 0.2797297530789899
kl_divergence: 0.6308234080358068
train loss: 0.08269
valid loss: 0.10514, valid_acc: 0.63082
model save at score: 0.63082

ep: 3


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6047774158523345
mean_recall: 0.3241537464945872
kl_divergence: 0.5937011917312157
train loss: 0.07157
valid loss: 0.09895, valid_acc: 0.5937
model save at score: 0.5937

ep: 4


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6373507057546145
mean_recall: 0.3671536744940447
kl_divergence: 0.5421598338429052
train loss: 0.06343
valid loss: 0.09036, valid_acc: 0.54216
model save at score: 0.54216

ep: 5


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6568946796959826
mean_recall: 0.3983843334358363
kl_divergence: 0.5255691252529925
train loss: 0.05669
valid loss: 0.08759, valid_acc: 0.52557
model save at score: 0.52557

ep: 6


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6525515743756786
mean_recall: 0.4014212002208968
kl_divergence: 0.5394992904865191
train loss: 0.05088
valid loss: 0.08992, valid_acc: 0.5395

ep: 7


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6438653637350705
mean_recall: 0.3929874441127977
kl_divergence: 0.5526290030664559
train loss: 0.04582
valid loss: 0.0921, valid_acc: 0.55263

ep: 8


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6482084690553745
mean_recall: 0.3920462321523798
kl_divergence: 0.5542368708224287
train loss: 0.04207
valid loss: 0.09237, valid_acc: 0.55424

ep: 9


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6514657980456026
mean_recall: 0.3934431580549003
kl_divergence: 0.5532945601574226
train loss: 0.03844
valid loss: 0.09222, valid_acc: 0.55329

ep: 10


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6547231270358306
mean_recall: 0.40999568254393276
kl_divergence: 0.5654962362569478
train loss: 0.03422
valid loss: 0.09425, valid_acc: 0.5655

ep: 11


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6536373507057546
mean_recall: 0.39732591803641126
kl_divergence: 0.561824114520448
train loss: 0.03267
valid loss: 0.09364, valid_acc: 0.56182

ep: 12


  0%|          | 0/267 [00:00<?, ?it/s]

  0%|          | 0/921 [00:00<?, ?it/s]

accuracy: 0.6568946796959826
mean_recall: 0.38913973802566487
kl_divergence: 0.5534440388888611
train loss: 0.0301
valid loss: 0.09224, valid_acc: 0.55344

ep: 13


  0%|          | 0/267 [00:00<?, ?it/s]

KeyboardInterrupt: 