# Imports

In [1]:
import os
import sys
import gc

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import albumentations as albu
from sklearn.model_selection import KFold, GroupKFold
import torchvision.transforms as transforms
import random

In [2]:
class Config:
    """Global settings: seed, training sizes, data location and the spectrogram image transform."""
    seed = 42
    # NOTE(review): batch_size and num_epochs are not read by the training cell, which
    # passes its own literals to DataLoader / pl.Trainer. TODO: wire these up.
    batch_size = 16
    num_epochs = 20
    num_folds = 5
#     root_path = ""
    root_path = "/kaggle/input/hms-harmful-brain-activity-classification/"
    # Applied to a standardized float32 (1, H, W) spectrogram with negative values.
    # Plain ToPILImage() multiplies float input by 255 and casts to uint8, wrapping the
    # negatives and destroying the signal; mode='F' keeps a 32-bit float PIL image, and
    # ToTensor() then returns the float values unchanged (no /255 for mode 'F').
    image_transform = transforms.Compose([
        transforms.ToPILImage(mode='F'),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229])
    ])


def set_seed(seed):
    """Seed Python, NumPy and torch RNGs and configure cuDNN for reproducible runs.

    Args:
        seed: integer seed applied to ``random``, ``numpy`` and ``torch``.
    """
    # cudnn.benchmark=True auto-tunes and may select a different (non-deterministic)
    # algorithm per run, defeating deterministic=True, so it must be off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Let float32 matmuls use faster reduced-precision internals (speed over precision).
    torch.set_float32_matmul_precision('medium')
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    
# Seed every RNG once, before the data is loaded and any model is built.
set_seed(Config.seed)

def KL_loss(p, q):
    """Mean KL divergence KL(p || softmax(q)) over the batch.

    Args:
        p: target probabilities, shape (batch, classes); clipped away from 0 and 1
           so that log(p) stays finite.
        q: raw logits, same shape as ``p``; converted to log-probabilities along dim 1.

    Returns:
        Scalar tensor: per-row KL summed over classes, averaged over rows.
    """
    floor = 1e-15
    target = p.clamp(min=floor, max=1 - floor)
    log_pred = nn.functional.log_softmax(q, dim=1)
    per_row = (target * (target.log() - log_pred)).sum(dim=1)
    return per_row.mean()

# Run a full garbage collection; the trailing ';' suppresses the returned count
# of unreachable objects, which would otherwise be displayed as the cell output.
gc.collect();

0

# Load the data

In [3]:
df = pd.read_csv(Config.root_path + 'train.csv')
# Pick the expert-vote columns by name rather than by position, so an added or
# reordered CSV column cannot silently change the target set. Column order is kept.
TARGETS = df.columns[df.columns.str.endswith('_vote')]
assert len(TARGETS) == 6, f'expected 6 vote columns, found {list(TARGETS)}'
print('Train shape:', df.shape )
print('Targets', list(TARGETS))
df.head()

Train shape: (106800, 15)
Targets ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']


Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,1628180742,0,0.0,353733,0,0.0,127492639,42516,Seizure,3,0,0,0,0,0
1,1628180742,1,6.0,353733,1,6.0,3887563113,42516,Seizure,3,0,0,0,0,0
2,1628180742,2,8.0,353733,2,8.0,1142670488,42516,Seizure,3,0,0,0,0,0
3,1628180742,3,18.0,353733,3,18.0,2718991173,42516,Seizure,3,0,0,0,0,0
4,1628180742,4,24.0,353733,4,24.0,3080632009,42516,Seizure,3,0,0,0,0,0


In [4]:
by_eeg = df.groupby('eeg_id')

# One row per eeg_id: its first spectrogram, the span of label offsets, and the patient.
train = by_eeg.agg(
    spec_id=('spectrogram_id', 'first'),
    min=('spectrogram_label_offset_seconds', 'min'),
    max=('spectrogram_label_offset_seconds', 'max'),
    patient_id=('patient_id', 'first'),
)

# Total votes per class, normalised so each eeg_id's row sums to 1.
vote_totals = by_eeg[TARGETS].sum()
train = train.join(vote_totals.div(vote_totals.sum(axis=1), axis=0))

train['target'] = by_eeg['expert_consensus'].first()

train = train.reset_index()
print('Train non-overlapp eeg_id shape:', train.shape )
train.head()

Train non-overlapp eeg_id shape: (17089, 12)


Unnamed: 0,eeg_id,spec_id,min,max,patient_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,target
0,568657,789577333,0.0,16.0,20654,0.0,0.0,0.25,0.0,0.166667,0.583333,Other
1,582999,1552638400,0.0,38.0,20230,0.0,0.857143,0.0,0.071429,0.0,0.071429,LPD
2,642382,14960202,1008.0,1032.0,5955,0.0,0.0,0.0,0.0,0.0,1.0,Other
3,751790,618728447,908.0,908.0,38549,0.0,0.0,1.0,0.0,0.0,0.0,GPD
4,778705,52296320,0.0,0.0,40955,0.0,0.0,0.0,0.0,0.0,1.0,Other


## Dataset

In [5]:
class ViTDataset(Dataset):
    """One spectrogram image per row of ``data`` (one row per eeg_id in this notebook).

    Args:
        data: DataFrame with a ``spec_id`` column and, in train mode, per-class vote columns.
        augment: stored but currently unused.
        mode: ``'train'`` returns ``(image, label)``; any other value returns only the image.
        target_cols: vote-fraction columns used for labels; defaults to every column whose
            name ends with ``'_vote'`` (in their DataFrame order).
        soft_labels: if True the label is the float32 vote distribution, shape (n_classes,);
            otherwise it is a scalar int64 index of the largest vote (ties -> first column),
            which is what an ``F.one_hot``-based loss expects.

    Raises:
        ValueError: in train mode when no target columns are available.
    """

    def __init__(self, data, augment=True, mode='train', target_cols=None, soft_labels=False):
        self.data = data
        self.augment = augment
        self.mode = mode
        self.soft_labels = soft_labels
        if target_cols is None:
            target_cols = [c for c in data.columns if str(c).endswith('_vote')]
        self.target_cols = list(target_cols)
        if mode == 'train':
            if not self.target_cols:
                raise ValueError("train mode needs target columns; none ending in '_vote' found")
            # Extract labels once up front instead of a row lookup per item.
            self.targets = data[self.target_cols].to_numpy(dtype=np.float32)
        else:
            self.targets = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Column-first lookup keeps spec_id's integer dtype; a row lookup on an
        # all-numeric frame would upcast it to float and yield "123.0.parquet".
        path = self.data["spec_id"].iloc[index]
        batch_data = self.get_batch(path)
        if self.mode == 'train':
            return batch_data, self._label(index)
        return batch_data

    def _label(self, index):
        """Return the training label for row ``index`` (soft distribution or class index)."""
        probs = self.targets[index]
        if self.soft_labels:
            return torch.from_numpy(probs.copy())
        return torch.tensor(int(np.argmax(probs)), dtype=torch.long)

    def get_batch(self, path):
        """Load spectrogram ``path`` and return it log-scaled, standardized and transformed."""
        eps = 1e-6
        data = pd.read_parquet(f"{Config.root_path}train_spectrograms/{path}.parquet")
        # Drop the first column (presumably the time index -- confirm) and put the
        # remaining columns on the row axis. NaNs become -1, which the clip below maps
        # to exp(-6), i.e. -6 after the log.
        data = data.fillna(-1).values[:, 1:].T
        data = np.clip(data, np.exp(-6), np.exp(10))
        data = np.log(data)

        # Standardize with the global mean/std of this spectrogram.
        data_mean = data.mean(axis=(0, 1))
        data_std = data.std(axis=(0, 1))
        data = (data - data_mean) / (data_std + eps)

        # Add a channel axis -> (1, rows, cols) before the image transform.
        data_tensor = torch.unsqueeze(torch.Tensor(data), dim=0)
        data = Config.image_transform(data_tensor)

        return data

# Model

In [6]:
class ViT(pl.LightningModule):
    """ImageNet-pretrained ViT-B/16 with a new linear head, trained with KL divergence.

    Args:
        num_classes: size of the replacement classification head (default 6).
        lr: Adam learning rate (default 1e-3).
    """

    def __init__(self, num_classes=6, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.lr = lr
        self.base_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        last_layer = self.base_model.heads[0]
        # Swap the ImageNet head for a num_classes-way linear head.
        self.base_model.heads = nn.Linear(last_layer.in_features, num_classes)

        self.loss_fn = nn.KLDivLoss(reduction='batchmean')
        self.resize_transform = T.Resize(size=(224, 224), interpolation=InterpolationMode.BILINEAR)

    def forward(self, x):
        """Map a (B, C, H, W) batch (C == 1 or 3) to raw logits of shape (B, num_classes).

        ViTDataset yields single-channel 224x224 images. The previous channel-last
        slicing (x[..., i:i+1]) therefore kept only the first 8 pixel columns. Instead,
        the grey channel is now replicated to the 3 channels the backbone expects.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a (B, C, H, W) batch, got shape {tuple(x.shape)}")
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        # vit_b_16 only accepts 224x224 inputs.
        if tuple(x.shape[-2:]) != (224, 224):
            x = self.resize_transform(x)
        return self.base_model(x)

    def training_step(self, batch, batch_idx):
        """KL(target || softmax(logits)) for one batch.

        ``y`` may be integer class indices of any shape with one label per sample
        (e.g. (B,) or (B, 1)), or float per-class distributions of shape (B, num_classes).
        """
        x, y = batch
        log_probs = F.log_softmax(self(x), dim=1)
        if torch.is_floating_point(y):
            target = y.to(log_probs.dtype)
        else:
            # Flatten first: one-hot of a (B, 1) tensor would be (B, 1, C) and
            # broadcast wrongly against (B, C) log-probabilities.
            target = F.one_hot(y.reshape(-1), num_classes=self.num_classes).to(log_probs.dtype)
        loss = self.loss_fn(log_probs, target)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return class probabilities; accepts a bare image batch or an (x, y) pair."""
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return F.softmax(self(x), dim=1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Hyper-parameter search (Optuna, currently disabled)

In [7]:
# import optuna
# from pytorch_lightning import Trainer

# def objective(trial):
#     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
#     batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
    
#     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=3, persistent_workers=True)
#     model = EEGEffnetB0(lr=lr)

#     trainer = Trainer(
#         max_epochs=10,
#         gpus=1 if torch.cuda.is_available() else 0,
#     )
    
#     trainer.fit(model, train_loader, valid_loader)
#     return trainer.callback_metrics["val_loss"].item()

# study = optuna.create_study(direction="minimize")
# study.optimize(objective, n_trials=10)
# print("Best trial:", study.best_trial.params)


## Training

In [8]:
all_oof = []
all_true = []
valid_loaders = []

# Group folds by patient so that no patient appears in both train and validation.
gkf = GroupKFold(n_splits=Config.num_folds)
for i, (train_indices, valid_indices) in enumerate(gkf.split(train, train.target, train.patient_id)):  
    print('#'*25)
    print(f'### Fold {i+1}')
    
    train_ds = ViTDataset(train.iloc[train_indices])
    # NOTE(review): batch sizes and max_epochs below are literals and differ from
    # Config.batch_size / Config.num_epochs -- TODO decide which is intended.
    train_loader = DataLoader(train_ds, shuffle=True, batch_size=32, num_workers=3, persistent_workers=True)
    valid_ds = ViTDataset(train.iloc[valid_indices], mode='valid')
    valid_loader = DataLoader(valid_ds, shuffle=False, batch_size=64, num_workers=3)
    
    print(f'### Train size: {len(train_indices)}, Valid size: {len(valid_indices)}')
    print('#'*25)
    
    trainer = pl.Trainer(max_epochs=4)
    model = ViT()
    trainer.fit(model=model, train_dataloaders=train_loader)
    trainer.save_checkpoint(f'Vit_f{i}.ckpt')

    # Keep each fold's validation loader and ground-truth vote distributions for later OOF scoring.
    valid_loaders.append(valid_loader)
    all_true.append(train.iloc[valid_indices][TARGETS].values)
    del trainer, model
    gc.collect()
    # Release cached GPU blocks so the next fold's model starts from a clean allocator.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

#########################
### Fold 1
### Train size: 13671, Valid size: 3418
#########################


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 173MB/s]
2024-02-11 09:50:22.039082: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-11 09:50:22.039189: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-11 09:50:22.159420: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Training: |          | 0/? [00:00<?, ?it/s]



#########################
### Fold 2
### Train size: 13671, Valid size: 3418
#########################


Training: |          | 0/? [00:00<?, ?it/s]

#########################
### Fold 3
### Train size: 13671, Valid size: 3418
#########################


Training: |          | 0/? [00:00<?, ?it/s]

#########################
### Fold 4
### Train size: 13671, Valid size: 3418
#########################


Training: |          | 0/? [00:00<?, ?it/s]

#########################
### Fold 5
### Train size: 13672, Valid size: 3417
#########################


Training: |          | 0/? [00:00<?, ?it/s]