# Imports

In [None]:
import os
import sys
import gc

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import albumentations as albu
from sklearn.model_selection import KFold, GroupKFold
import torchvision.transforms as transforms
import random

In [None]:
class Config:
    """Notebook-wide settings: reproducibility, hyperparameters and data paths."""

    # --- Reproducibility ---
    seed = 42

    # --- Training hyperparameters ---
    # NOTE(review): batch_size / num_epochs look unused — the training cell
    # hard-codes batch_size=32 and max_epochs=4; confirm which is intended.
    batch_size = 16
    num_epochs = 20
    num_folds = 5

    # --- Preprocessing ---
    # NOTE(review): not referenced by the visible cells.
    image_transform = transforms.Resize((512,512))

    # --- Data locations (Kaggle input mounts) ---
    root_path = "/kaggle/input/hms-harmful-brain-activity-classification/"
    spec_path = "/kaggle/input/brain-spectrograms/specs.npy/"
    # "egg" is a typo for "eeg"; the name is kept because other cells read
    # Config.egg_path.
    egg_path = "/kaggle/input/brain-eeg-spectrograms/eeg_specs.npy/"

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN benchmarking picks kernels by timing them, which is itself
    # nondeterministic and defeats `deterministic = True`; keep it off.
    torch.backends.cudnn.benchmark = False
    # Allow reduced internal precision in float32 matmuls for speed.
    torch.set_float32_matmul_precision('medium')
    
# Seed everything once, before any data loading or model construction.
set_seed(seed=Config.seed)

def KL_loss(p,q):
    """Mean over rows of KL(p || softmax(q)).

    Args:
        p: target probabilities, one distribution along dim 1.
        q: raw logits with the same shape as ``p``; converted to
            log-probabilities with ``log_softmax`` along dim 1.

    Returns:
        0-d tensor: the mean of the per-row sums of
        ``p * (log p - log_softmax(q))``.
    """
    eps = 10 ** -15
    # Keep the targets strictly inside (0, 1) so log(p) stays finite.
    p_safe = p.clamp(min=eps, max=1 - eps)
    log_q = torch.log_softmax(q, dim=1)
    per_row = (p_safe * (p_safe.log() - log_q)).sum(dim=1)
    return per_row.mean()

gc.collect()

# Load the data

In [None]:
# One row per labelled window; the last six columns are used as the targets.
df = pd.read_csv(f'{Config.root_path}train.csv')
TARGETS = df.columns[-6:]

print('Train shape:', df.shape)
print('Targets', list(TARGETS))
df.head()

In [None]:
# Collapse the per-window rows into one row per eeg_id:
#   spec_id    - first spectrogram id of the group
#   min / max  - smallest / largest spectrogram label offset (seconds)
#   patient_id - first patient id of the group
#   TARGETS    - summed votes, normalised so each row sums to 1
#   target     - first expert consensus label
# A single groupby object is reused instead of regrouping for every column.
eeg_groups = df.groupby('eeg_id')
train = eeg_groups.agg(
    spec_id=('spectrogram_id', 'first'),
    min=('spectrogram_label_offset_seconds', 'min'),
    max=('spectrogram_label_offset_seconds', 'max'),
    patient_id=('patient_id', 'first'),
)

tmp = eeg_groups[TARGETS].agg('sum')
for t in TARGETS:
    train[t] = tmp[t].values

y_data = train[TARGETS].values
# NOTE: a group whose votes sum to 0 would become NaN here.
y_data = y_data / y_data.sum(axis=1, keepdims=True)
train[TARGETS] = y_data

# Assign a Series (aligned on eeg_id) rather than a one-column DataFrame.
train['target'] = eeg_groups['expert_consensus'].first()

train = train.reset_index()
print('Train non-overlapp eeg_id shape:', train.shape )
train.head()

In [None]:
# Each .npy file stores a pickled dict; `.item()` unwraps the 0-d object
# array. allow_pickle executes pickle code, so only load trusted files.
spectrograms = np.load(f'{Config.spec_path}specs.npy', allow_pickle=True).item()
all_eegs = np.load(f'{Config.egg_path}eeg_specs.npy', allow_pickle=True).item()

## Dataset

In [None]:
# Class name -> class index, and the inverse mapping.
TARS = {name: idx for idx, name in enumerate(['Seizure', 'LPD', 'GPD', 'LRDA', 'GRDA', 'Other'])}
TARS2 = dict(zip(TARS.values(), TARS.keys()))


class ViTDataset(Dataset):
    """Build 8-channel float32 inputs of shape (128, 256, 8) for the ViT model.

    Channels 0-3 hold four 100-column slices of the spectrogram for the
    row's ``spec_id`` (log-transformed, standardized, cropped to 256 time
    steps and placed in rows 14:-14). Channels 4-7 hold ``eeg_specs[eeg_id]``.

    Args:
        data: DataFrame with ``spec_id`` and ``eeg_id`` columns; outside test
            mode it also needs ``min``/``max`` and the ``TARGETS`` columns.
        augment: apply a random horizontal flip to each sample.
        mode: ``'train'`` yields ``(X, y)`` pairs; any other mode yields ``X``
            only. ``'test'`` also reads the spectrogram from offset 0 and
            leaves ``y`` at zero.
        specs: mapping ``spec_id -> 2-D spectrogram array``.
        eeg_specs: mapping ``eeg_id -> array`` assignable to (128, 256, 4).
    """

    def __init__(self, data, augment=False, mode='train', specs=spectrograms, eeg_specs=all_eegs):
        self.data = data
        self.augment = augment
        self.mode = mode
        self.specs = specs
        self.eeg_specs = eeg_specs
        # Build the augmentation pipeline once instead of once per sample.
        self._transform = albu.Compose([
            albu.HorizontalFlip(p=0.5),
            # albu.CoarseDropout(max_holes=8,max_height=32,max_width=32,fill_value=0,p=0.5),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Return one sample (not a batch of one) so direct indexing and
        # DataLoaders that fetch item-by-item get the expected structure.
        return self.__getitems__([index])[0]

    def __getitems__(self, indices):
        """Return a batch: a list of ``(X, y)`` pairs in train mode, else an X array."""
        X, y = self._generate_data(indices)
        if self.augment:
            X = self._augment(X)
        if self.mode == 'train':
            return list(zip(X, y))
        return X

    def _generate_data(self, indexes):
        """Assemble inputs X (n, 128, 256, 8) and targets y (n, 6) for ``indexes``."""
        X = np.zeros((len(indexes), 128, 256, 8), dtype='float32')
        y = np.zeros((len(indexes), 6), dtype='float32')
        ep = 1e-6

        for j, i in enumerate(indexes):
            row = self.data.iloc[i]
            # Test rows start at offset 0. Otherwise (min + max) // 4 is used;
            # presumably the midpoint offset converted to rows at 2 s per row
            # — TODO confirm.
            r = 0 if self.mode == 'test' else int((row['min'] + row['max']) // 4)
            spec = self.specs[row.spec_id]

            for k in range(4):
                # EXTRACT 300 ROWS OF SPECTROGRAM (one 100-column slice)
                img = spec[r:r+300, k*100:(k+1)*100].T

                # LOG TRANSFORM SPECTROGRAM
                img = np.clip(img, np.exp(-4), np.exp(8))
                img = np.log(img)

                # STANDARDIZE PER IMAGE
                m = np.nanmean(img.flatten())
                s = np.nanstd(img.flatten())
                img = (img - m) / (s + ep)
                img = np.nan_to_num(img, nan=0.0)

                # CROP TO 256 TIME STEPS
                X[j, 14:-14, :, k] = img[:, 22:-22] / 2.0

            # EEG SPECTROGRAMS
            X[j, :, :, 4:] = self.eeg_specs[row.eeg_id]

            if self.mode != 'test':
                y[j,] = row[TARGETS]

        return X, y

    def _random_transform(self, img):
        """Apply the augmentation pipeline to a single (H, W, C) image."""
        return self._transform(image=img)['image']

    def _augment(self, img_batch):
        """Augment every image of ``img_batch`` in place and return the batch.

        Named with a single underscore so ``__getitems__`` can find it;
        the previous ``__augment`` spelling was name-mangled and unreachable.
        """
        for i in range(img_batch.shape[0]):
            img_batch[i,] = self._random_transform(img_batch[i,])
        return img_batch

# Model

In [None]:
class ViT(pl.LightningModule):
    """ImageNet-pretrained ViT-B/16 with a 6-way head, trained with KL divergence.

    Args:
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.base_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # torchvision's VisionTransformer exposes its classifier as
        # `heads.head`; it has no top-level `head` attribute.
        num_features = self.base_model.heads.head.in_features
        self.base_model.heads.head = nn.Linear(num_features, 6)
        self.prob_out = nn.Softmax(dim=1)
        # Parameter-free, so built once here instead of on every step.
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, x):
        """Tile the 8 channels into one 3-channel image and return logits.

        Assumes ``x`` is channels-last (B, 128, 256, 8), as built by
        ViTDataset. Shape comments below are for that input.
        """
        x1 = torch.concat([x[:, :, :, i:i+1] for i in range(4)], dim=1)    # (B, 512, 256, 1)
        x2 = torch.concat([x[:, :, :, i+4:i+5] for i in range(4)], dim=1)  # (B, 512, 256, 1)
        x = torch.concat([x1, x2], dim=2)                                   # (B, 512, 512, 1)
        x = torch.concat([x, x, x], dim=3)                                  # (B, 512, 512, 3)
        x = x.permute(0, 3, 1, 2)                                           # (B, 3, 512, 512)

        # torchvision's ViT asserts the input equals its configured
        # image_size (224 for the IMAGENET1K_V1 weights), so resize.
        size = self.base_model.image_size
        if x.shape[-2:] != (size, size):
            x = F.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)

        return self.base_model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        log_probs = F.log_softmax(self(x), dim=1)
        loss = self.kl_loss(log_probs, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return F.softmax(self(batch), dim=1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# Hyperparameter search with Optuna (disabled)

In [None]:
# import optuna
# from pytorch_lightning import Trainer

# def objective(trial):
#     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
#     batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
    
#     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=3, persistent_workers=True)
#     model = EEGEffnetB0(lr=lr)

#     trainer = Trainer(
#         max_epochs=10,
#         gpus=1 if torch.cuda.is_available() else 0,
#     )
    
#     trainer.fit(model, train_loader, valid_loader)
#     return trainer.callback_metrics["val_loss"].item()

# study = optuna.create_study(direction="minimize")
# study.optimize(objective, n_trials=10)
# print("Best trial:", study.best_trial.params)


## Cross-validated training

In [None]:
# Train one ViT per GroupKFold split (grouped by patient_id so a patient
# never appears in both train and validation). Keep each fold's validation
# loader and ground truth for the OOF evaluation below.
all_oof = []
all_true = []
valid_loaders = []

gkf = GroupKFold(n_splits=Config.num_folds)
for i, (train_index, valid_index) in enumerate(gkf.split(train, train.target, train.patient_id)):
    print('#'*25)
    print(f'### Fold {i+1}')

    # The abstract torch Dataset cannot be instantiated with data; use the
    # project's ViTDataset for both splits.
    train_ds = ViTDataset(train.iloc[train_index])
    train_loader = DataLoader(train_ds, shuffle=True, batch_size=32, num_workers=3, persistent_workers=True)
    valid_ds = ViTDataset(train.iloc[valid_index], mode='valid')
    valid_loader = DataLoader(valid_ds, shuffle=False, batch_size=64, num_workers=3)

    print(f'### Train size: {len(train_index)}, Valid size: {len(valid_index)}')
    print('#'*25)

    trainer = pl.Trainer(max_epochs=4)
    model = ViT()
    trainer.fit(model=model, train_dataloaders=train_loader)
    trainer.save_checkpoint(f'Vit_f{i}.ckpt')

    valid_loaders.append(valid_loader)
    all_true.append(train.iloc[valid_index][TARGETS].values)
    del trainer, model
    gc.collect()

In [None]:
# Run inference on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Predict each fold's held-out rows with that fold's checkpoint. Iterating
# valid_loaders keeps all_oof row-aligned with all_true for any fold count.
for i, valid_loader in enumerate(valid_loaders):
    print('#'*25)
    print(f'### Validating Fold {i+1}')

    # Must match the name used by trainer.save_checkpoint in the training cell.
    ckpt_file = f'Vit_f{i}.ckpt'
    model = ViT.load_from_checkpoint(ckpt_file, map_location=device)
    model.to(device).eval()
    with torch.inference_mode():
        for val_batch in valid_loader:
            val_batch = val_batch.to(device)
            oof = torch.softmax(model(val_batch), dim=1).cpu().numpy()
            all_oof.append(oof)
    del model
    gc.collect()

all_oof = np.concatenate(all_oof)
all_true = np.concatenate(all_true)

In [None]:
oof = pd.DataFrame(all_oof.copy())
oof['id'] = np.arange(len(oof))

true = pd.DataFrame(all_true.copy())
true['id'] = np.arange(len(true))

# cv = score(solution=true, submission=oof, row_id_column_name='id')
# all_oof holds probabilities, and KLDivLoss expects log-probabilities as
# input, so take the log. Clip first so a zero prediction doesn't give -inf.
# Both tensors are float64 to avoid dtype mismatches.
eps = 1e-15
oof_log_probs = torch.log(torch.as_tensor(np.clip(all_oof.astype(np.float64), eps, 1.0)))
true_probs = torch.as_tensor(all_true, dtype=torch.float64)
kl_loss = nn.KLDivLoss(reduction='batchmean')
loss = kl_loss(oof_log_probs, true_probs).item()
print('CV Score KL-Div for ViT =', loss)

In [None]:
# NOTE: ViTDataset's default arguments were bound to these same objects
# when the class was defined, so this `del` drops the global names but does
# not actually release the memory.
del spectrograms, all_eegs
gc.collect()

test = pd.read_csv(f'{Config.root_path}test.csv')
print('Test shape', test.shape)
test.head()

In [None]:
# Read every test spectrogram parquet into {spectrogram_id: array}, dropping
# the first column of each file.
PATH2 = Config.root_path + 'test_spectrograms/'
files2 = os.listdir(PATH2)
print(f'There are {len(files2)} test spectrogram parquets')

spectrograms2 = {}
for i, f in enumerate(files2):
    # Progress marker every 100 files.
    if not i % 100:
        print(i, ', ', end='')
    spec_df = pd.read_parquet(PATH2 + f)
    spec_id = int(f.partition('.')[0])
    spectrograms2[spec_id] = spec_df.iloc[:, 1:].values

# RENAME FOR DATALOADER (ViTDataset reads `spec_id`)
test = test.rename(columns={'spectrogram_id': 'spec_id'})

In [None]:
# INFER VIT ON TEST: average the softmax outputs of every fold's checkpoint.
# NOTE(review): `all_eegs2` (EEG spectrograms for the test eeg_ids) is not
# built anywhere in the visible notebook — it must exist before this cell runs.
preds = []
test_ds = ViTDataset(test, mode='test', specs=spectrograms2, eeg_specs=all_eegs2)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=64, num_workers=3)

for i in range(Config.num_folds):
    print('#'*25)
    print(f'### Testing Fold {i+1}')

    # Must match the name used by trainer.save_checkpoint in the training cell.
    ckpt_file = f'Vit_f{i}.ckpt'
    model = ViT.load_from_checkpoint(ckpt_file, map_location=device)
    model.to(device).eval()
    fold_preds = []

    with torch.inference_mode():
        for test_batch in test_loader:
            test_batch = test_batch.to(device)
            pred = torch.softmax(model(test_batch), dim=1).cpu().numpy()
            fold_preds.append(pred)
    fold_preds = np.concatenate(fold_preds)

    preds.append(fold_preds)
    del model
    gc.collect()

pred = np.mean(preds,axis=0)
print()
print('Test preds shape',pred.shape)

In [None]:
# One row per test eeg_id with the fold-averaged class probabilities.
sub = pd.DataFrame({'eeg_id': test.eeg_id.values})
sub[TARGETS] = pred
sub.to_csv('submission.csv',index=False)
print('Submission shape',sub.shape)
sub.head()

In [None]:
# Sanity check: each row's six class probabilities should sum to ~1.
sub.iloc[:, -6:].sum(axis='columns')