# Multimodal approach using both RAW EEG Features & Spectrograms

#### This is the training notebook for a multimodal solution to Kaggle's Brain competition. It scores 0.607 on a 5-fold GroupKFold split. Training was done locally, so please update the dataset paths to match your own environment.

#### Submission Notebook link: https://www.kaggle.com/nischaydnk/multimodal-1d-2d-eeg-approach-submission


##### 1D EEGNet training: https://www.kaggle.com/code/nischaydnk/lightning-1d-eegnet-training-pipeline-hbs
##### 1D EEGNet Submission: https://www.kaggle.com/code/nischaydnk/hms-submission-1d-eegnet-pipeline-lightning


#### PS: This approach is a continuation of the work I shared earlier in the notebooks linked above. You can refer to that previous work to learn more about the 1D approach used in this notebook. I have also used some code from [@moth](https://www.kaggle.com/alejopaullier)'s baseline 2D approach [notebook](https://www.kaggle.com/code/alejopaullier/hms-efficientnetb0-pytorch-inference) and ideas shared by [@Chris](https://www.kaggle.com/cdeotte).


#### The overall pipeline in this notebook looks roughly like this:
![img2](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F4712534%2Fc4e39641fa8ec13588b62065ae1e9b58%2FScreenshot%202024-02-10%20at%2010.40.36%20PM.png?generation=1707585942607444&alt=media)

In [99]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from glob import glob

from sklearn import model_selection
from scipy.signal import butter, lfilter

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping
from torch.utils.data import Dataset, DataLoader
import torch_audiomentations as tA
from torch_audiomentations import Compose, Gain, PolarityInversion
import torchvision.transforms as transforms
import torchvision.io 
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau, OneCycleLR

from braindecode.augmentation import AugmentedDataLoader, SignFlip, FrequencyShift, ChannelsDropout,ChannelsShuffle
from braindecode.augmentation import Transform

import librosa
from PIL import Image
import albumentations as alb
import gc
import timm 

import warnings
warnings.filterwarnings('ignore')

In [74]:
# Make sure the results directory exists before anything is written to it.
# NOTE(review): `Config` is defined in a later cell of this file, so that cell
# must have been executed before this one.
if not os.path.exists(Config.output_dir):
    os.makedirs(Config.output_dir)
    print(f"Directory has been made {Config.output_dir}")
# Seed the random number generators through Lightning's helper.
pl.seed_everything(Config.seed, workers=True)

Seed set to 0


0

In [75]:
class Config:
    # Central namespace of experiment settings; read as class attributes
    # (e.g. `Config.batch_size`) throughout this file.

    # Not read anywhere in the visible part of this file.
    use_aug = False
    # Passed to pl.seed_everything.
    seed = 0
    
    # Overwritten after preprocessing with the value returned by preprocess_train_data.
    num_classes = 6
    # Batch size for both the training and the validation DataLoader.
    batch_size = 32
    # Also used as T_0 of the cosine warm-restart scheduler in get_optimizer.
    epochs = 10
    # presumably the Trainer precision -- usage is not visible here, TODO confirm
    PRECISION = 16    
    # Patience of the EarlyStopping callback in run_training.
    PATIENCE = 20    
    
    # timm model name for the 2D (spectrogram) branch.
    backbone_2d = 'tf_efficientnet_b0'
    # Default for EEGModel's `pretrained` argument.
    pretrained = True            
    # weight_decay handed to the Adam optimizer in get_optimizer.
    weight_decay = 1e-2
    # Mixup switches; not read anywhere in the visible part of this file.
    use_mixup = False
    mixup_alpha = 0.1   
    # Number of input channels of the 1D branch.
    num_channels = 8
    
    # All paths below are relative to the notebook's working directory.
    # Root folder holding train.csv and train_eegs/ (must end with a separator,
    # it is prepended verbatim in f-strings).
    data_root = "../Data/original_data/"
    # Pickled dict of raw EEG arrays; when None the EEGs are rebuilt from parquet.
    raw_eeg_path = "../Data/raw_eeg/eegs.npy" 
    # Not read in the visible part of this file (same path as PRE_LOADED_EEGS).
    raw_spec_path = "../Data/spec/eeg_specs.npy" 
    # Folder globbed for Kaggle spectrogram parquet files.
    kaggle_spec_path = "../Data/original_data/train_spectrograms/" 
    # Pickled dict of EEG-derived spectrograms.
    PRE_LOADED_EEGS = "../Data/spec/eeg_specs.npy" 
    # Pickled dict of Kaggle spectrograms.
    PRE_LOADED_SPECTOGRAMS = "../Data/kagglespec/specs.npy"
    # Folder globbed for per-EEG spectrogram .npy files.
    TRAIN_EEGS = "../Data/spec/EEG_Spectrograms/"

    # Not read anywhere in the visible part of this file.
    processed_train = None
    # Learning rate passed to get_optimizer.
    LR = 7e-4
    # Directory created at start-up for results.
    output_dir = '../Results/'
    # presumably the folds to train -- usage is not visible here, TODO confirm
    trn_folds = [0,1,2,3,4]

In [76]:
def config_to_dict(cfg):
    """Return the attributes of ``cfg`` as a plain dictionary.

    Every name reported by ``dir(cfg)`` is kept unless it starts with a
    double underscore, so methods and single-underscore names are included.

    Args:
        cfg: Any object or class; typically the ``Config`` class.

    Returns:
        dict mapping attribute name to attribute value.
    """
    public_names = [attr for attr in dir(cfg) if not attr.startswith('__')]
    return {attr: getattr(cfg, attr) for attr in public_names}

def preprocess_train_data(df):
    """Collapse the raw training table to one row per ``eeg_id``.

    The last six columns of ``df`` are treated as the vote columns. For every
    ``eeg_id`` the votes are summed and rescaled so that each row sums to 1.

    Args:
        df: Raw training frame containing ``eeg_id``, ``spectrogram_id``,
            ``spectrogram_label_offset_seconds``, ``patient_id`` and
            ``expert_consensus``, followed by six vote columns.

    Returns:
        Tuple ``(train_df, num_classes, TARS, TARS_INV)`` where ``train_df``
        has the columns ``eeg_id``, ``spectrogram_id``, ``min_offset_seconds``,
        ``max_offset_seconds``, ``patient_id``, the six normalised vote
        columns and ``target``; ``TARS`` maps class name to index and
        ``TARS_INV`` is the inverse mapping.
    """
    vote_cols = list(df.columns[-6:])
    TARS = {'Seizure': 0, 'LPD': 1, 'GPD': 2, 'LRDA': 3, 'GRDA': 4, 'Other': 5}
    TARS_INV = {index: name for name, index in TARS.items()}
    num_classes = len(TARS)

    by_eeg = df.groupby('eeg_id')

    # First spectrogram id plus the offset range covered by each EEG.
    train_df = by_eeg['spectrogram_id'].first().to_frame('spectrogram_id')
    offsets = by_eeg['spectrogram_label_offset_seconds']
    train_df['min_offset_seconds'] = offsets.min()
    train_df['max_offset_seconds'] = offsets.max()

    train_df['patient_id'] = by_eeg['patient_id'].first()

    # Summed votes per EEG, turned into per-row proportions.
    vote_totals = by_eeg[vote_cols].sum().to_numpy()
    vote_shares = vote_totals / vote_totals.sum(axis=1, keepdims=True)
    for position, column in enumerate(vote_cols):
        train_df[column] = vote_shares[:, position]

    train_df['target'] = by_eeg['expert_consensus'].first()

    # Bring eeg_id back as an ordinary column.
    return train_df.reset_index(), num_classes, TARS, TARS_INV

def eeg_from_parquet(parquet_path, display=False):
    """Load one EEG parquet file and return its central 10,000-row window.

    Only the columns named in the module-level ``FEATS`` are read. NaNs in a
    column are replaced by that column's mean over the window; a column that
    is entirely NaN becomes all zeros.

    Args:
        parquet_path: Path of the EEG parquet file.
        display: When True, plot every channel (stacked with a vertical
            offset) in a matplotlib figure.

    Returns:
        ``numpy.ndarray`` of shape ``(10_000, len(FEATS))``.

    Raises:
        ValueError: If the file holds fewer than 10,000 rows. Previously this
            led to a negative slice offset and a mis-sliced window.
    """
    n_samples = 10_000

    # EXTRACT MIDDLE WINDOW
    eeg = pd.read_parquet(parquet_path, columns=FEATS)
    rows = len(eeg)
    if rows < n_samples:
        raise ValueError(
            f'{parquet_path} has {rows} rows; at least {n_samples} are required'
        )
    start = (rows - n_samples) // 2
    eeg = eeg.iloc[start:start + n_samples]

    if display:
        plt.figure(figsize=(10, 5))
    # Running vertical shift between traces; only used when plotting.
    offset = 0

    # CONVERT TO NUMPY
    data = np.zeros((n_samples, len(FEATS)))
    for j, col in enumerate(FEATS):

        # FILL NAN (astype makes a copy, so the frame itself is never modified)
        x = eeg[col].values.astype('float32')
        if np.isnan(x).all():
            # No valid sample to take a mean from.
            x[:] = 0
        else:
            x = np.nan_to_num(x, nan=np.nanmean(x))

        data[:, j] = x

        if display:
            if j != 0:
                offset += x.max()
            plt.plot(range(n_samples), x - offset, label=col)
            offset -= x.min()

    if display:
        plt.legend()
        # basename copes with whichever path separator the caller used.
        name = os.path.basename(parquet_path).split('.')[0]
        plt.title(f'EEG {name}', size=16)
        plt.show()

    return data


In [77]:
# Read train_df
# `Config.data_root` is prepended verbatim, so it must end with a path separator.
df = pd.read_csv(f'{Config.data_root}train.csv')

# Preprocess dataset 
# Yields one row per eeg_id with normalised vote columns (see preprocess_train_data).
train_df, num_classes, TARS, TARS_INV = preprocess_train_data(df)

# Set Num classes to Config.num_classes
Config.num_classes = num_classes
# NOTE(review): later cells read module-level names `train` and `TARGETS`, neither of
# which is assigned in the visible part of this file -- confirm where they come from.

In [78]:
%%time

CREATE_EEGS = True
df_eeg = pd.read_parquet(f'{Config.data_root}train_eegs/1000913311.parquet')
print(f'There are {len(df_eeg.columns)} raw eeg features')
print(list(df_eeg.columns))

if Config.raw_eeg_path is not None:
    raw_eegs = np.load(Config.raw_eeg_path, allow_pickle=True).item()
else:

    all_eegs = {}
    DISPLAY = 4
    EEG_IDS = train.eeg_id.unique()
    PATH = f'{Config.data_root}train_eegs/'
    
    for i,eeg_id in enumerate(EEG_IDS):
        if (i%100==0)&(i!=0): print(i,', ',end='') 
        
        # SAVE EEG TO PYTHON DICTIONARY OF NUMPY ARRAYS
        data = eeg_from_parquet(f'{PATH}{eeg_id}.parquet', display=i<DISPLAY)              
        all_eegs[eeg_id] = data
        
        if i==DISPLAY:
            if CREATE_EEGS:
                print(f'Processing {train.eeg_id.nunique()} eeg parquets... ',end='')
            else:
                print(f'Reading {len(EEG_IDS)} eeg NumPys from disk.')
                break
                
    if CREATE_EEGS: 
        np.save(f'{Config.data_root}eegs_20ch',all_eegs)
print(f"Length of eegs: {len(raw_eegs)}")

There are 20 raw eeg features
['Fp1', 'F3', 'C3', 'P3', 'F7', 'T3', 'T5', 'O1', 'Fz', 'Cz', 'Pz', 'Fp2', 'F4', 'C4', 'P4', 'F8', 'T4', 'T6', 'O2', 'EKG']
Length of eegs: 17089
CPU times: total: 7.64 s
Wall time: 7.69 s


In [79]:
%%time
READ_SPEC_FILES = False

paths_spectograms = glob(Config.kaggle_spec_path + "*.parquet")
print(f'There are {len(paths_spectograms)} spectrogram parquets')

if READ_SPEC_FILES:    
    all_spectrograms = {}
    for file_path in tqdm(paths_spectograms):
        aux = pd.read_parquet(file_path)
        name = int(file_path.split("/")[-1].split('.')[0])
        all_spectrograms[name] = aux.iloc[:,1:].values
        del aux
else:
    all_spectrograms = np.load(Config.PRE_LOADED_SPECTOGRAMS, allow_pickle=True).item()

There are 11138 spectrogram parquets
CPU times: total: 5.97 s
Wall time: 5.99 s


In [80]:
%%time
READ_EEG_SPEC_FILES = False

paths_eegs = glob(Config.TRAIN_EEGS + "*.npy")
print(f'There are {len(paths_eegs)} EEG spectograms')

if READ_EEG_SPEC_FILES:
    all_eegs = {}
    for file_path in tqdm(paths_eegs):
        eeg_id = file_path.split("/")[-1].split(".")[0]
        eeg_spectogram = np.load(file_path)
        all_eegs[eeg_id] = eeg_spectogram
else:
    all_eegs = np.load(Config.PRE_LOADED_EEGS, allow_pickle=True).item()



There are 17089 EEG spectograms
CPU times: total: 7.64 s
Wall time: 7.66 s


In [81]:
# Sanity check: print the array shapes stored for the eeg_id of row 0.
# NOTE(review): `train` is not assigned in the visible part of this file -- confirm its origin.
print(f"Shape of spectrograms: {all_eegs[train.loc[0,'eeg_id']].shape}")
print(f"Shape of raw eegs: {raw_eegs[train.loc[0,'eeg_id']].shape}")

Shape of spectrograms: (128, 256, 4)
Shape of raw eegs: (10000, 8)


In [82]:
class ResNet_1D_Block(nn.Module):
    """Residual block for 1D signals with a pooled shortcut.

    The main path applies BatchNorm -> ReLU -> Dropout -> Conv1d twice and
    then a max-pool of size 2; the shortcut is ``downsampling(x)``. Both
    results are added, so they must have identical shapes.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size shared by both convolutions.
        stride: Stride shared by both convolutions.
        padding: Padding shared by both convolutions.
        downsampling: Module applied to the block input to build the shortcut.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        # Attribute names and creation order are kept so state-dict keys and
        # seeded initialisation stay the same.
        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.downsampling = downsampling

    def forward(self, x):
        """Return ``maxpool(main_path(x)) + downsampling(x)``."""
        out = x
        for norm, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2)):
            out = conv(self.dropout(self.relu(norm(out))))

        out = self.maxpool(out)
        shortcut = self.downsampling(x)
        return out + shortcut

In [83]:
class EEGMegaNet(nn.Module):
    """Two-branch network fusing a 1D raw-EEG encoder with a 2D spectrogram encoder.

    * 1D branch: parallel Conv1d layers with different kernel sizes, a stack
      of ``ResNet_1D_Block`` layers and a bidirectional GRU, projected to a
      128-d embedding by ``fc2``.
    * 2D branch: a timm backbone (without its last two children) followed by
      global average pooling, projected to a 128-d embedding by ``fc1``.

    The two embeddings are concatenated for the main classifier ``fc``; each
    embedding also has its own auxiliary classifier (``fc1d`` / ``fc2d``).

    Args:
        backbone_2d: timm model name for the spectrogram branch.
        in_channels_2d: Accepted for interface compatibility; not used.
        kernels: Kernel sizes of the parallel input convolutions.
        pretrained: Whether timm loads pretrained backbone weights.
        in_channels: Channels of the raw EEG input.
        fixed_kernel_size: Kernel size used after the parallel convolutions.
        num_classes: Logits produced by each classifier head.

    Note:
        ``fc1`` expects 1280 backbone features and ``fc2`` expects 736 input
        features (480 flattened conv features + 256 GRU features). That holds
        for the configuration exercised in this file: 10,000-sample input,
        kernels ``[3, 5, 7, 9]``, ``fixed_kernel_size=5``, EfficientNet-B0.
    """

    def __init__(self, backbone_2d,in_channels_2d, kernels, pretrained=False, in_channels=20, fixed_kernel_size=17, num_classes=6):
        super().__init__()

        self.kernels = kernels
        self.planes = 24
        self.parallel_conv = nn.ModuleList()
        self.in_channels = in_channels

        # Build the backbone named by the argument. Previously the global
        # Config.backbone_2d was read here and the argument was ignored.
        self.backbone_2d = timm.create_model(
            backbone_2d,
            pretrained=pretrained,
            drop_rate=0.1,
            drop_path_rate=0.1,
        )

        # Everything except the backbone's last two children, then global
        # average pooling. `self.backbone_2d` stays registered as well, so
        # state-dict keys are unchanged.
        self.features_2d = nn.Sequential(
            *list(self.backbone_2d.children())[:-2] + [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        )

        for kernel_size in list(self.kernels):
            sep_conv = nn.Conv1d(in_channels=in_channels, out_channels=self.planes, kernel_size=kernel_size,
                                 stride=1, padding=0, bias=False)
            self.parallel_conv.append(sep_conv)

        self.bn1 = nn.BatchNorm1d(num_features=self.planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv1d(in_channels=self.planes, out_channels=self.planes, kernel_size=fixed_kernel_size,
                               stride=2, padding=2, bias=False)
        self.block = self._make_resnet_layer(kernel_size=fixed_kernel_size, stride=1, padding=fixed_kernel_size//2)
        self.bn2 = nn.BatchNorm1d(num_features=self.planes)
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=4, padding=2)
        # NOTE(review): this GRU is built without batch_first=True but forward()
        # feeds it a (batch, time, channels) tensor, so PyTorch treats dim 0
        # (the batch) as the sequence axis. Left unchanged because switching it
        # alters what existing checkpoints compute -- confirm the intent.
        self.rnn = nn.GRU(input_size=self.in_channels, hidden_size=128, num_layers=1, bidirectional=True)

        self.fc1 = nn.Linear(in_features=1280, out_features=128)
        self.fc2 = nn.Linear(in_features=736, out_features=128)
        self.fc = nn.Linear(in_features=256, out_features=num_classes)

        self.fc1d = nn.Linear(in_features=128, out_features=num_classes)
        self.fc2d = nn.Linear(in_features=128, out_features=num_classes)

        # Not used in forward(); kept so existing state dicts still load.
        self.rnn1 = nn.GRU(input_size=156, hidden_size=156, num_layers=1, bidirectional=True)

    def _make_resnet_layer(self, kernel_size, stride, blocks=8, padding=0):
        """Stack ``blocks`` residual blocks, each with a max-pool shortcut."""
        layers = []
        for _ in range(blocks):
            downsampling = nn.Sequential(
                nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
            )
            layers.append(ResNet_1D_Block(in_channels=self.planes, out_channels=self.planes, kernel_size=kernel_size,
                                          stride=stride, padding=padding, downsampling=downsampling))

        return nn.Sequential(*layers)

    def _reshape_input(self, spec):
        """Turn the first four spectrogram channels into a 3-channel image.

        Channels 0-3 of the last axis are stacked along dim 1 and the result
        is replicated three times as the colour axis. For a
        ``(B, 128, 256, 8)`` input this gives ``(B, 3, 512, 256)``.
        Channels 4-7 (the EEG spectrograms) are not used.
        """
        spectograms = [spec[:, :, :, i:i+1] for i in range(4)]
        spec = torch.cat(spectograms, dim=1)

        # Replicate the single channel three times for the image backbone.
        spec = torch.cat([spec, spec, spec], dim=3)
        spec = spec.permute(0, 3, 1, 2)
        return spec

    def forward(self, x, spec):
        """Run both branches.

        Args:
            x: Raw EEG tensor, ``(batch, in_channels, time)``.
            spec: Spectrogram tensor with channels on the last axis
                (``(batch, 128, 256, 8)`` in this file's smoke test).

        Returns:
            Tuple ``(result, new_out, spec, out1d, out2d)``: fused logits,
            128-d 1D embedding, 128-d 2D embedding, logits of the 1D head and
            logits of the 2D head.
        """
        # 2D branch: pooled backbone features.
        spec = self._reshape_input(spec)
        spec = self.features_2d(spec)

        # 1D branch: multi-scale convolutions concatenated along time.
        out_sep = [conv(x) for conv in self.parallel_conv]

        out = torch.cat(out_sep, dim=2)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.block(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)

        out = out.reshape(out.shape[0], -1)

        # See the NOTE(review) in __init__ about the GRU's axis convention.
        rnn_out, _ = self.rnn(x.permute(0, 2, 1))
        new_rnn_h = rnn_out[:, -1, :]

        new_out = torch.cat([out, new_rnn_h], dim=1)
        new_out = self.fc2(new_out)
        out1d = self.fc1d(new_out)

        spec = self.fc1(spec)
        out2d = self.fc2d(spec)

        result = torch.cat([new_out, spec], dim=1)
        result = self.fc(result)

        return result, new_out, spec, out1d, out2d

In [84]:
# Smoke test: push a random batch of two through an untrained EEGMegaNet on CPU
# and print the shape of the fused logits.
iot = torch.randn(2, Config.num_channels, 10000)#.cuda()
spec = torch.randn(2, 128, 256, 8)#.cuda()

model = EEGMegaNet(backbone_2d=Config.backbone_2d,in_channels_2d=8,
                   kernels=[3,5,7,9],pretrained=False,
                   in_channels=Config.num_channels, fixed_kernel_size=5,
                   num_classes=6)#.cuda()
output,_,_,_,_ = model(iot, spec)
print(output.shape)

# Release the test input and model; `spec` and `output` stay bound.
del iot, model
gc.collect()

torch.Size([2, 6])


32

In [88]:
def get_train_transform():
    """Build the albumentations pipeline for spectrogram images.

    Applies a horizontal flip with probability 0.5 and, with probability 0.5,
    one of ``Cutout`` or ``CoarseDropout``.

    Returns:
        An ``albumentations.Compose`` instance.
    """
    # This file imports albumentations as `alb`; the alias `A` used here
    # before is never defined and raised a NameError.
    # NOTE(review): `Cutout` is not available in every albumentations
    # release -- confirm against the installed version.
    return alb.Compose([
        alb.HorizontalFlip(p=0.5),
        alb.OneOf([
                alb.Cutout(max_h_size=5, max_w_size=16),
                alb.CoarseDropout(max_holes=4),
            ], p=0.5),
    ])

def get_transforms(*, data):
    """Return the torch_audiomentations pipeline for a split.

    This name is defined again further down in the file; whichever
    definition ran last is the one in effect.

    Args:
        data: ``'train'`` or ``'valid'`` (keyword-only).

    Returns:
        A ``tA.Compose``: coloured noise plus a random shift for ``'train'``,
        an empty pipeline for ``'valid'``. Any other value gives ``None``.
    """
    if data == 'valid':
        return tA.Compose([
        ])

    if data == 'train':
        waveform_augs = [
            # tA.ShuffleChannels(p=0.25,mode="per_channel",p_mode="per_channel",),
            tA.AddColoredNoise(p=0.15, mode="per_channel", p_mode="per_channel", max_snr_in_db=15, sample_rate=200),
            tA.Shift(p=0.5, mode="per_example", p_mode="per_example", max_shift=0.025, min_shift=-0.025, sample_rate=200),
        ]
        return tA.Compose(transforms=waveform_augs)

    return None

# braindecode augmentation objects for the raw EEG signal (sfreq=200 is passed
# as the sampling frequency).
freq_shift = FrequencyShift(
    probability=.3,
    sfreq=200,
    max_delta_freq=2.  # the frequency shifts are sampled now between -2 and 2 Hz
)

# Built but not part of the active list below.
ch_drop = ChannelsDropout(probability=0.2, p_drop=0.2)
sign_flip = SignFlip(probability=.1)
ch_shuffle = ChannelsShuffle(probability=0.25)

# Active augmentation list, read by the get_transforms() defined after the
# Compose class below.
# NOTE: this rebinds the module-level name `transforms`, hiding the
# `torchvision.transforms` module imported under the same name at the top.
transforms = [
    freq_shift,
    sign_flip,
    # ch_drop,
    # ch_shuffle    
]

# NOTE: this definition hides the `Compose` imported from torch_audiomentations
# at the top of the file.
class Compose(Transform):
    """Transform composition.

    Callable class allowing to cast a sequence of Transform objects into a
    single one.

    Parameters
    ----------
    transforms: list
        Sequence of Transforms to be composed.
    """

    def __init__(self, transforms):
        # The list is stored before the base initialiser runs, exactly as written.
        self.transforms = transforms
        super().__init__()

    def forward(self, X, y):
        # Each transform is called with `None` in place of the labels, so only X
        # is transformed; `y` is returned untouched.
        # NOTE(review): this assumes every transform returns just X when called
        # with y=None -- confirm against the braindecode version in use.
        for transform in self.transforms:
            X = transform(X, None)
        return X, y

def get_transforms(*, data):
    """Return the braindecode-style augmentation pipeline for a split.

    Args:
        data: ``'train'`` or ``'valid'`` (keyword-only).

    Returns:
        ``Compose`` wrapping the module-level ``transforms`` list for
        ``'train'``, an empty ``Compose`` for ``'valid'``, and ``None`` for
        any other value.
    """
    if data == 'train':
        pipeline = transforms
    elif data == 'valid':
        pipeline = []
    else:
        return None

    return Compose(transforms=pipeline)



In [89]:
from scipy.signal import butter, lfilter

def quantize_data(data, classes):
    """Compand ``data`` with mu-law encoding, using ``classes`` as mu.

    Despite the name no binning is done: the digitize step is disabled, so
    the continuous companded signal is returned.
    """
    # bins = np.linspace(-1, 1, classes)
    # quantized = np.digitize(mu_x, bins) - 1
    return mu_law_encoding(data, classes)

def mu_law_encoding(data, mu):
    """Apply mu-law companding: ``sign(x) * log(1 + mu*|x|) / log(1 + mu)``.

    Args:
        data: Array-like signal.
        mu: Companding parameter.

    Returns:
        Array with the same shape as ``data``.
    """
    compressed_magnitude = np.log(1 + mu * np.abs(data)) / np.log(mu + 1)
    return np.sign(data) * compressed_magnitude

def mu_law_expansion(data, mu):
    """Invert mu-law companding: ``sign(y) * ((1 + mu)**|y| - 1) / mu``.

    Args:
        data: Array-like companded signal.
        mu: Companding parameter used for encoding.

    Returns:
        Array with the same shape as ``data``.
    """
    expanded_magnitude = np.exp(np.abs(data) * np.log(mu + 1)) - 1
    return np.sign(data) * expanded_magnitude / mu

def butter_lowpass_filter(data, cutoff_freq=20, sampling_rate=200, order=4):
    """Low-pass ``data`` along axis 0 with a Butterworth filter.

    The filter is applied in the forward direction only (``lfilter``).

    Args:
        data: Array whose first axis is time.
        cutoff_freq: Cut-off frequency, in the same unit as ``sampling_rate``.
        sampling_rate: Sampling frequency of ``data``.
        order: Filter order.

    Returns:
        Filtered array with the same shape as ``data``.
    """
    # butter() expects the cut-off as a fraction of the Nyquist frequency.
    normalized_cutoff = cutoff_freq / (0.5 * sampling_rate)
    numerator, denominator = butter(order, normalized_cutoff, btype='low', analog=False)
    return lfilter(numerator, denominator, data, axis=0)

class EEGDataset(torch.utils.data.Dataset):
    """Dataset over raw EEG arrays (first variant).

    The class name ``EEGDataset`` is defined again further down in this file;
    the later definitions replace this one once their cells run.

    Args:
        data: DataFrame with an ``eeg_id`` column (and the target columns when
            ``test`` is False).
        eegs: Mapping from ``eeg_id`` to a ``(time, channels)`` array.
        augmentations: Stored on the instance; not applied in ``__getitem__``.
        test: When True, items carry no label.
    """

    def __init__(self, data, eegs=None, augmentations = None, test = False):
        self.data = data
        self.eegs = eegs
        self.augmentations = augmentations
        self.test = test

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``samples`` (test mode) or ``(samples, label)``.

        ``samples`` is a float tensor of shape ``(channels, time)``.
        """
        row = self.data.iloc[index]

        # Clip, zero-fill NaNs and rescale by 1/32.
        signal = np.clip(self.eegs[row.eeg_id], -1024, 1024)
        signal = np.nan_to_num(signal, nan=0) / 32.0

        # Low-pass filter, then mu-law companding with mu=1.
        signal = quantize_data(butter_lowpass_filter(signal), 1)

        # (time, channels) -> (channels, time)
        samples = torch.from_numpy(signal).float().permute(1, 0)

        if self.test:
            return samples

        label = torch.tensor(row[TARGETS]).float()
        return samples, label
# ================================

In [90]:

class EEGDataset(torch.utils.data.Dataset):
    """Dataset over eight electrode-difference channels (second variant).

    The class name ``EEGDataset`` is defined once more further down in this
    file; that later definition replaces this one once its cell runs.

    Args:
        data: DataFrame with an ``eeg_id`` column (and the target columns when
            ``test`` is False).
        eegs: Mapping from ``eeg_id`` to a ``(time, channels)`` array whose
            first eight columns follow the order in ``FEAT2IDX`` below.
        augmentations: Stored on the instance; not applied in ``__getitem__``.
        test: When True, items carry no label.
    """

    def __init__(self, data, eegs=None, augmentations=None, test=False):
        self.data = data
        self.eegs = eegs
        self.augmentations = augmentations
        self.test = test

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``samples`` (test mode) or ``(samples, label)``.

        ``samples`` is a float tensor of shape ``(8, time)``.
        """
        row = self.data.iloc[index]
        data = self.eegs[row.eeg_id]

        # Column position of every electrode inside `data`.
        FEAT2IDX = {'Fp1': 0, 'T3': 1, 'C3': 2, 'O1': 3, 'Fp2': 4, 'C4': 5, 'T4': 6, 'O2': 7}
        # Electrode pairs whose difference forms each output channel.
        electrode_pairs = (
            ('Fp1', 'T3'), ('T3', 'O1'),
            ('Fp1', 'C3'), ('C3', 'O1'),
            ('Fp2', 'C4'), ('C4', 'O2'),
            ('Fp2', 'T4'), ('T4', 'O2'),
        )

        sample = np.zeros((data.shape[0], 8))
        for channel, (first, second) in enumerate(electrode_pairs):
            sample[:, channel] = data[:, FEAT2IDX[first]] - data[:, FEAT2IDX[second]]

        # Per-channel standardisation; NaNs from a zero std are zeroed below.
        centred = sample - np.mean(sample, axis=0)
        sample = centred / np.std(sample, axis=0)

        sample = np.nan_to_num(np.clip(sample, -1024, 1024), nan=0)

        # Low-pass filter, then mu-law companding with mu=1.
        sample = quantize_data(butter_lowpass_filter(sample), 1)

        # (time, channels) -> (channels, time)
        samples = torch.from_numpy(sample).float().permute(1, 0)

        if self.test:
            return samples

        label = torch.tensor(row[TARGETS]).float()
        return samples, label




In [96]:
class EEGDataset(torch.utils.data.Dataset):
    """Multimodal dataset: raw-EEG difference channels plus spectrogram stack.

    Args:
        data: DataFrame with ``eeg_id`` and ``spectogram_id`` columns, ``min``
            and ``max`` offset columns (read when ``test`` is False) and the
            target columns named by the module-level ``TARGETS``.
        eegs: Mapping ``eeg_id`` -> raw EEG array ``(time, channels)`` whose
            first eight columns follow the order in ``FEAT2IDX``.
        specs: Mapping ``spectogram_id`` -> Kaggle spectrogram array.
        eeg_specs: Mapping ``eeg_id`` -> EEG spectrogram array written into
            channels 4-7 of the spectrogram stack.
        spec_aug: When True, a random horizontal flip is applied to the stack.
        augmentations: Stored on the instance; not applied in ``__getitem__``.
        test: When True, items carry no label and the spectrogram window
            starts at row 0.
    """

    def __init__(self, data, eegs=None, specs=None, eeg_specs=None, spec_aug = False, augmentations=None, test=False):
        self.data = data
        self.eegs = eegs
        self.specs = specs  # Spectrograms for each ID
        self.eeg_specs = eeg_specs  # EEG spectrograms for each ID
        self.augmentations = augmentations
        self.test = test
        self.spec_aug = spec_aug

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(samples, spec)`` in test mode, else ``(samples, spec, label)``.

        ``samples`` is a float tensor ``(8, time)``, ``spec`` a float32 array
        ``(128, 256, 8)`` and ``label`` a float tensor over the target columns.
        """
        row = self.data.iloc[index]

        # Processing EEG signal data
        eeg_data = self.eegs[row.eeg_id]
        sample = np.zeros((eeg_data.shape[0], 8))
        FEAT2IDX = {'Fp1': 0, 'T3': 1, 'C3': 2, 'O1': 3, 'Fp2': 4, 'C4': 5, 'T4': 6, 'O2': 7}

        # Each output channel is the difference between two electrodes.
        for i, (start, end) in enumerate([('Fp1', 'T3'), ('T3', 'O1'), ('Fp1', 'C3'), ('C3', 'O1'),
                                          ('Fp2', 'C4'), ('C4', 'O2'), ('Fp2', 'T4'), ('T4', 'O2')]):
            sample[:, i] = eeg_data[:, FEAT2IDX[start]] - eeg_data[:, FEAT2IDX[end]]

        sample = self.process_sample(sample)

        # (time, channels) -> (channels, time)
        samples = torch.from_numpy(sample).float().permute(1, 0)

        # Processing spectrogram data
        spec = self.__spec_data_generation(row)
        if self.spec_aug:
            spec = self.__transform(spec)

        if self.test:
            return samples, spec

        # Convert through NumPy explicitly: handing the label-indexed Series
        # straight to torch.tensor relies on positional integer indexing of a
        # string-indexed Series.
        label = torch.tensor(row[TARGETS].to_numpy(dtype=np.float32))
        return samples, spec, label

    def process_sample(self, sample):
        """Standardise, clip, low-pass filter and mu-law compand ``sample``.

        A channel with zero standard deviation produces NaNs during
        standardisation; those are replaced by 0 afterwards.
        """
        sample = (sample - np.mean(sample, axis=0)) / np.std(sample, axis=0)
        sample = np.clip(sample, -1024, 1024)
        sample = np.nan_to_num(sample, nan=0)
        sample = butter_lowpass_filter(sample)
        sample = quantize_data(sample, 1)
        return sample

    def __spec_data_generation(self, row):
        """Build the ``(128, 256, 8)`` spectrogram stack for one row.

        Channels 0-3 hold four 100-column regions of the Kaggle spectrogram
        (a 300-row window, log-transformed and standardised per region);
        channels 4-7 hold the EEG spectrogram for ``row.eeg_id``.
        """
        X = np.zeros((128, 256, 8), dtype='float32')
        if not self.test:
            # Window start derived from the row's min/max offsets.
            r = int((row['min'] + row['max']) // 4)
        else:
            r = 0  # Adjust as necessary for test mode

        for region in range(4):
            # 300 rows x 100 columns, transposed to (100, 300).
            img = self.specs[row.spectogram_id][r:r+300, region*100:(region+1)*100].T

            # Log transform spectogram
            img = np.clip(img, np.exp(-4), np.exp(8))
            img = np.log(img)

            # Standarize per image
            ep = 1e-6
            mu = np.nanmean(img.flatten())
            std = np.nanstd(img.flatten())
            img = (img-mu)/(std+ep)
            img = np.nan_to_num(img, nan=0.0)
            # Crop 300 -> 256 columns and centre the 100 rows inside 128.
            X[14:-14, :, region] = img[:, 22:-22] / 2.0

        # EEG spectrogram fills the remaining four channels.
        img = self.eeg_specs[row.eeg_id]
        X[:, :, 4:] = img
        return X

    def __transform(self, img):
        """Apply a random horizontal flip (p=0.5) to the spectrogram stack."""
        # This file imports albumentations as `alb`; the alias `A` used here
        # before is never defined and raised a NameError when spec_aug=True.
        flip = alb.Compose([
            alb.HorizontalFlip(p=0.5),
        ])
        return flip(image=img)['image']



In [97]:
# Augmentation pipelines per split.
# NOTE(review): `BandPass` and `GaussianNoiseSNR` are not defined or imported in
# the visible part of this file; the recorded output of this cell is
# "NameError: name 'BandPass' is not defined", so the statement did not complete.
transforms = dict(
    train=Compose([
        BandPass(lower=12, upper=512),
        GaussianNoiseSNR(min_snr=15, max_snr=30, p=0.5),
    ]),
    test=BandPass(lower=12, upper=512),
    # tta=BandPass(lower=12, upper=512)
)

def get_transforms(*, data):
    """Return the augmentation pipeline for a split (third definition).

    NOTE(review): ``FlipWave`` is not defined in the visible part of this
    file, and the statement preceding this function in the same cell raised a
    NameError in the recorded output -- confirm that this definition is used.

    Args:
        data: ``'train'`` or ``'valid'`` (keyword-only).

    Returns:
        ``Compose`` with a single ``FlipWave(p=0.5)`` for ``'train'``, an
        empty ``Compose`` for ``'valid'``, and ``None`` for any other value.
    """
    if data == 'valid':
        # BandPass(lower=12, upper=512) is disabled here.
        return Compose([])

    if data == 'train':
        # BandPass / GaussianNoiseSNR are disabled; only the flip is active.
        wave_flip = FlipWave(p=0.5)
        return Compose([wave_flip])

    return None

NameError: name 'BandPass' is not defined

In [98]:
from sklearn.model_selection import *

# Patient-grouped 5-fold split: all rows of one patient land in the same fold.
gkf = GroupKFold(n_splits=5)
train['fold'] = 0
# split() yields positional row indices, so they are written with iloc; the
# previous label-based .loc was only right for a default 0..n-1 index.
# The loop variable keeps the name `fold` because it stays bound at module
# level after the loop, as before.
for fold, (tr_idx, val_idx) in enumerate(gkf.split(train, train.target, train.patient_id)):
    train.iloc[val_idx, train.columns.get_loc('fold')] = fold
# train.to_csv('/home/nischay/brain/Data/5gkf_fold.csv',index=False)


In [None]:
def get_fold_dls(df_train, df_valid):
    """Build datasets and dataloaders for one train/validation split.

    Both datasets read the module-level ``raw_eegs``, ``all_spectrograms``
    and ``all_eegs`` dictionaries.

    Args:
        df_train: Rows used for training.
        df_valid: Rows used for validation.

    Returns:
        Tuple ``(dl_train, dl_val, ds_train, ds_val)``; only the training
        loader shuffles.
    """
    shared_sources = dict(
        eegs=raw_eegs,
        specs=all_spectrograms,
        eeg_specs=all_eegs,
        test=False,
    )

    ds_train = EEGDataset(df_train, augmentations=get_transforms(data='train'), **shared_sources)
    ds_val = EEGDataset(df_valid, augmentations=get_transforms(data='valid'), **shared_sources)

    dl_train = DataLoader(ds_train, batch_size=Config.batch_size, shuffle=True, num_workers=2)
    dl_val = DataLoader(ds_val, batch_size=Config.batch_size, num_workers=2)

    return dl_train, dl_val, ds_train, ds_val

In [None]:


def show_batch(img_ds, num_items, num_rows, num_cols, EEG_IDS, predict_arr=None):
    """Plot the EEG channels of a few randomly chosen dataset items.

    Args:
        img_ds: Dataset whose items are ``(eeg, spec, label)`` triples.
        num_items: Number of items to draw (sampled with replacement).
        num_rows: Rows of the subplot grid.
        num_cols: Columns of the subplot grid.
        EEG_IDS: Sequence of EEG ids, used for the titles only when the
            dataset has no ``data`` frame with an ``eeg_id`` column.
        predict_arr: Unused; kept for interface compatibility.
    """
    fig = plt.figure(figsize=(12, 6))
    # randint's upper bound is exclusive: `len(img_ds)` lets every item be
    # drawn (the former `len - 1` skipped the last one and failed on a
    # single-item dataset).
    picks = np.random.randint(0, len(img_ds), num_items)
    frame = getattr(img_ds, 'data', None)

    for plot_pos, item_idx in enumerate(picks):
        img, spec, lb = img_ds[item_idx]
        print(spec.shape)
        ax = fig.add_subplot(num_rows, num_cols, plot_pos + 1, xticks=[], yticks=[])
        if isinstance(img, torch.Tensor):
            img = img.detach().numpy()
            # (channels, time) -> (time, channels)
            img = img.transpose(1, 0)

        # Stack the traces vertically so they do not overlap.
        offset = 0
        for j in range(img.shape[-1]):
            if j != 0: offset -= img[:, j].min()
            ax.plot(img[:, j] + offset, label=f'feature {j+1}')
            offset += img[:, j].max() + 1  # Adding 1 for visual separation

        ax.legend()
        # `item_idx` is a position inside img_ds, not inside the global id
        # list, so read the id from the dataset's own frame when possible.
        if frame is not None and 'eeg_id' in frame:
            eeg_id = frame['eeg_id'].iloc[item_idx]
        else:
            eeg_id = EEG_IDS[item_idx]
        ax.set_title(f'EEG_Id = {eeg_id}', size=14)

    plt.tight_layout()
    plt.show()


In [None]:
# Throw-away split for inspection: fold 0 is held out, the other folds train.
dummy_train = train[train['fold']!=0].copy()
dummy_valid = train[train['fold']==0].copy()


In [None]:
%%time

# Build loaders for the throw-away split and plot 8 validation items on a 2x4 grid.
# NOTE(review): `EEG_IDS` is only assigned in the parquet-rebuild branch of the
# raw-EEG loading cell -- confirm it is defined when the cached file is used.
dl_train, dl_val, ds_train, ds_val = get_fold_dls(dummy_train, dummy_valid)
show_batch(ds_val, 8, 2, 4, EEG_IDS)

In [None]:
%%time

# Same preview as above, this time drawing the 8 items from the training dataset.
dl_train, dl_val, ds_train, ds_val = get_fold_dls(dummy_train, dummy_valid)
show_batch(ds_train, 8, 2, 4, EEG_IDS)

In [None]:
# Config.num_channels = 16

In [None]:
def get_optimizer(lr, params):
    """Create the Adam optimizer and cosine warm-restart schedule.

    Args:
        lr: Learning rate.
        params: Iterable of parameters; those with ``requires_grad=False``
            are left out.

    Returns:
        Dict with the keys ``"optimizer"`` and ``"lr_scheduler"``, the latter
        being a dict that steps the scheduler once per epoch.
    """
    trainable_params = filter(lambda p: p.requires_grad, params)
    model_optimizer = torch.optim.Adam(
        trainable_params,
        lr=lr,
        weight_decay=Config.weight_decay,
    )

    # One restart cycle spans Config.epochs epochs.
    scheduler_settings = {
        "scheduler": CosineAnnealingWarmRestarts(
            model_optimizer,
            T_0=Config.epochs,
            T_mult=1,
            eta_min=1e-7,
            last_epoch=-1,
        ),
        "interval": "epoch",
        "monitor": "val_loss",
        "frequency": 1,
    }

    return {"optimizer": model_optimizer, "lr_scheduler": scheduler_settings}

In [None]:
from torchtoolbox.tools import mixup_data, mixup_criterion
import torch.nn as nn
from torch.nn.functional import cross_entropy
import torchmetrics
import timm
import sklearn.metrics
import sys
sys.path.append('/home/nischay/brain/scripts')

from kaggle_kl_div import score


In [None]:
# Random "predictions" with the same rows as the validation fold, used below to
# try out the metric function.
dummy = dummy_valid.copy()
dummy[TARGETS] = np.random.rand(dummy.shape[0],len(TARGETS))

# Rescale each row so its target columns sum to 1.
dummy[TARGETS] = dummy[TARGETS].div(dummy[TARGETS].sum(axis=1), axis=0)


In [None]:
# Show the first three rows of the held-out fold.
dummy_valid.head(3)

In [None]:
# Show the first three rows of the random predictions.
dummy.head(3)

In [None]:
# Score the random predictions against the fold's targets with the imported
# `score` function, using eeg_id as the row id column.
score(dummy_valid[['eeg_id']+list(TARGETS)], dummy[['eeg_id']+list(TARGETS)],row_id_column_name='eeg_id')

In [None]:
# mixup_criterion(KLDivLossWithLogits)

In [None]:
class KLDivLossWithLogits(nn.KLDivLoss):
    """KL-divergence loss that accepts raw logits.

    The logits are turned into log-probabilities with ``log_softmax`` over
    dim 1 before being passed to ``nn.KLDivLoss`` with
    ``reduction="batchmean"``.
    """

    def __init__(self):
        super().__init__(reduction="batchmean")

    def forward(self, y, t):
        """Return KL(t || softmax(y)).

        Args:
            y: Logits, shape ``(batch, classes)``.
            t: Target probabilities with the same shape.
        """
        log_probs = y.log_softmax(dim=1)
        return super().forward(log_probs, t)


In [None]:
import random


In [None]:
class EEGModel(pl.LightningModule):
    """Lightning wrapper around ``EEGMegaNet`` with auxiliary losses.

    Training minimises the KL loss of the fused head plus half-weighted KL
    losses of the 1D and 2D heads and a half-weighted cosine-embedding loss
    that pulls the two branch embeddings together. Validation scores a fixed
    blend of the three heads' logits.

    Args:
        num_classes: Logits produced by every classifier head.
        pretrained: Whether the 2D backbone loads pretrained weights.
        fold: Fold identifier, used only in progress messages. The default
            used to be whatever a module-level loop variable named ``fold``
            happened to hold; it is now a plain constant.
    """

    def __init__(self, num_classes = Config.num_classes, pretrained = Config.pretrained, fold = 0):
        super().__init__()
        self.num_classes = num_classes
        self.fold = fold
        # self.backbone = EEGNet(kernels=[3,5,7,9], in_channels=Config.num_channels, fixed_kernel_size=5, num_classes=Config.num_classes)
        # `pretrained` and `num_classes` are forwarded; they were previously
        # hard-coded to True / 6 regardless of the arguments.
        self.backbone = EEGMegaNet(backbone_2d=Config.backbone_2d,
                                   in_channels_2d=8,
                                   kernels=[3,5,7,9], pretrained=pretrained,
                                   in_channels=Config.num_channels,
                                   fixed_kernel_size=5, num_classes=num_classes)

        self.contrastive_loss = nn.CosineEmbeddingLoss()  # Using cosine similarity for contrastive loss

        self.loss_function = KLDivLossWithLogits() #nn.KLDivLoss() #nn.BCEWithLogitsLoss()
        # Per-batch validation results collected until the epoch ends.
        self.validation_step_outputs = []
        self.lin = nn.Softmax(dim=1)
        self.best_score = 1000.0

    def forward(self, eeg, spec):
        """Return the backbone's 5-tuple (fused logits, 1D emb, 2D emb, 1D logits, 2D logits)."""
        logits = self.backbone(eeg, spec)
        # logits = self.lin(logits)
        return logits

    def configure_optimizers(self):
        return get_optimizer(lr=Config.LR, params=self.parameters())

    def training_step(self, batch, batch_idx):
        """Compute the combined classification and contrastive loss for a batch."""
        eeg, spec, target = batch
        y_pred, embedding_1d, embedding_2d, yp1, yp2 = self(eeg, spec)
        classification_loss = self.loss_function(y_pred, target)

        classification_loss1 = self.loss_function(yp1, target)
        classification_loss2 = self.loss_function(yp2, target)

        # Calculate contrastive loss
        embedding_1d = torch.nn.functional.normalize(embedding_1d, p=2, dim=1)
        embedding_2d = torch.nn.functional.normalize(embedding_2d, p=2, dim=1)

        # All pairs are treated as similar (target = +1); the tensor is created
        # directly on the embeddings' device.
        contrastive_target = torch.ones(embedding_1d.size(0), device=embedding_1d.device)
        contrastive_loss = self.contrastive_loss(embedding_1d, embedding_2d, contrastive_target)

        total_loss = classification_loss + classification_loss1*0.5 + classification_loss2*0.5 + contrastive_loss*0.5  # Aux losses

        self.log("train_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Score a validation batch on the blended logits (0.5 fused, 0.25 each aux head)."""
        eeg, spec, target = batch
        y_pred, _, _, y1d, y2d = self(eeg, spec)

        y_pred = y_pred*0.5 + y1d*0.25 + y2d*0.25
        val_loss = self.loss_function(y_pred, target)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)

        step_result = {"val_loss": val_loss, "logits": y_pred, "targets": target}
        self.validation_step_outputs.append(step_result)
        return step_result

    # NOTE(review): `_train_dataloader` / `_validation_dataloader` are never
    # assigned in the visible part of this file -- confirm that dataloaders are
    # handed to the trainer directly instead.
    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader

    def on_validation_epoch_end(self):
        """Aggregate the epoch's validation batches and track the best loss."""
        outputs = self.validation_step_outputs
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        output_val = nn.Softmax(dim=1)(torch.cat([x['logits'] for x in outputs], dim=0)).cpu().detach().numpy()
        target_val = torch.cat([x['targets'] for x in outputs], dim=0).cpu().detach().numpy()
        # Start the next epoch with an empty buffer.
        self.validation_step_outputs = []

        # Frames in the layout the metric function takes; only needed once
        # the disabled `score` call / CSV dumps below are switched back on.
        val_df = pd.DataFrame(target_val, columns = list(TARGETS))
        pred_df = pd.DataFrame(output_val, columns = list(TARGETS))

        val_df['id'] = [f'id_{i}' for i in range(len(val_df))]
        pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]

        # The mean batch loss stands in for the competition metric.
        avg_score = avg_loss
        # avg_score = score(val_df, pred_df, row_id_column_name = 'id')

        if avg_score < self.best_score:
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation loss {avg_loss}')
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation KDL score {avg_score}')
            self.best_score = avg_score
            # val_df.to_csv(f'{Config.output_dir}/val_df_f{self.fold}.csv',index=False)
            # pred_df.to_csv(f'{Config.output_dir}/pred_df_f{self.fold}.csv',index=False)

        return {'val_loss': avg_loss,'val_cmap':avg_score}
    


In [None]:
from tqdm import tqdm
# Register tqdm's pandas integration (adds `progress_apply` and related methods to pandas objects).
tqdm.pandas()

In [None]:
def predict(data_loader, model, device='cuda'):
    """Run batched inference with the multimodal model and return class probabilities.

    Args:
        data_loader: iterable yielding ``(x, x2, y)`` batches; ``y`` is ignored.
        model: module called as ``model(x, x2)`` that returns five values. The
            1st, 4th and 5th are blended with weights 0.5 / 0.25 / 0.25, the
            same blend used in ``validation_step``.
        device: device the model and inputs are moved to. Defaults to
            ``'cuda'``, which matches the previous hard-coded behaviour.

    Returns:
        ``np.ndarray`` with one row of softmax probabilities per sample, in
        loader order.

    Raises:
        ValueError: if ``data_loader`` yields no batches (raised by ``np.vstack``).
    """
    model.to(device)
    model.eval()
    predictions = []
    for batch in tqdm(data_loader):
        with torch.no_grad():
            x, x2, _ = batch
            x = x.to(device)
            x2 = x2.to(device)

            outputs, _, _, y1, y2 = model(x, x2)

            # Same fixed-weight blend as used during validation.
            outputs = outputs*0.5 + y1*0.25 + y2*0.25

            outputs = nn.Softmax(dim=1)(outputs)
        # Collect whole batches; vstack below stitches them into one (N, C) array.
        predictions.append(outputs.detach().cpu().numpy())
    predictions = np.vstack(predictions)
    return predictions

def predict2(ds_test, model, device='cuda'):
    """Run sample-by-sample inference for a single-input model.

    Args:
        ds_test: indexable dataset with ``len()``; each item is a pair whose
            first element is the input tensor (the second is ignored).
        model: module called as ``model(x)`` that returns three values, the
            first of which is passed through softmax.
        device: device the model and inputs are moved to. Defaults to
            ``'cuda'``, which matches the previous hard-coded behaviour.

    Returns:
        List with one ``np.ndarray`` of softmax probabilities per sample; each
        array keeps the leading batch dimension of size 1.
    """
    model.to(device)
    model.eval()
    predictions = []
    # Index-based access: the dataset is only assumed to support __len__/__getitem__.
    for en in tqdm(range(len(ds_test))):
        x, _ = ds_test[en]
        # Add the batch dimension the model expects.
        x = x.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs, _, _ = model(x)
            outputs = nn.Softmax(dim=1)(outputs)
            outputs = outputs.detach().cpu().numpy()

        predictions.append(outputs)

    return predictions

In [None]:
# NOTE(review): WandbLogger is not used in the code visible here (run_training passes logger=None).
from pytorch_lightning.loggers import WandbLogger
import gc
# Allow float32 matmuls to use faster, reduced-precision kernels where the hardware supports them.
torch.set_float32_matmul_precision('high')
def run_training(fold_id, Config):
    """Train one fold, reload its best checkpoint and return validation predictions.

    Rows of the global ``train`` frame whose ``fold`` equals ``fold_id`` form
    the validation split; all other rows are used for training. After fitting,
    the best checkpoint (lowest ``val_loss``) is reloaded, the validation
    loader is scored with ``predict``, and the predictions are written next to
    the validation rows in ``{Config.output_dir}/pred_df_f{fold_id}.csv``.

    Args:
        fold_id: value of the ``fold`` column that selects the validation rows.
        Config: configuration object; this function reads ``num_classes``,
            ``pretrained``, ``PATIENCE``, ``output_dir``, ``epochs`` and
            ``PRECISION`` from it.

    Returns:
        The array returned by ``predict`` for the validation loader.
    """
    print(f"Running training for fold {fold_id}...")
    logger = None
    pred_cols = [f'pred_{t}' for t in TARGETS]

    df_train = train[train['fold'] != fold_id].copy()
    df_valid = train[train['fold'] == fold_id].copy()

    print(len(df_train),'train length')
    print(len(df_valid),'valid length')

    # The dataset objects returned alongside the loaders are not needed here.
    dl_train, dl_val, _, _ = get_fold_dls(df_train, df_valid)

    eeg_model = EEGModel(num_classes=Config.num_classes, pretrained=Config.pretrained, fold=fold_id)

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=Config.PATIENCE,
        verbose=True,
        mode="min",
    )
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=f"{Config.output_dir}/",
        save_top_k=1,
        save_last=True,
        save_weights_only=False,
        filename=f'eegnet_best_loss_fold{fold_id}',
        verbose=True,
        mode='min',
    )

    callbacks_to_use = [checkpoint_callback, early_stop_callback]

    trainer = pl.Trainer(
        # gpus=3,
        devices=[1],  # pins training to the GPU with index 1
        val_check_interval=0.5,  # validate twice per training epoch
        deterministic=True,
        max_epochs=Config.epochs,
        # max_epochs=3,
        logger=logger,
        callbacks=callbacks_to_use,
        precision=Config.PRECISION,
        accelerator="gpu",
    )

    print("Running trainer.fit")
    trainer.fit(eeg_model, train_dataloaders=dl_train, val_dataloaders=dl_val)

    # Ask the callback where the best checkpoint actually is: when a file with the
    # configured name already exists (e.g. from an earlier run in the same output
    # directory) Lightning saves under a "-vN" suffixed name, and a hand-built path
    # would silently reload the stale checkpoint. Fall back to the conventional
    # path only if the callback recorded nothing.
    best_ckpt_path = (
        checkpoint_callback.best_model_path
        or f'{Config.output_dir}/eegnet_best_loss_fold{fold_id}.ckpt'
    )
    model = EEGModel.load_from_checkpoint(
        best_ckpt_path,
        train_dataloader=None,
        validation_dataloader=None,
        config=Config,
    )
    preds = predict(dl_val, model)
    print(preds.shape)

    # NOTE(review): assumes dl_val iterates df_valid in row order (no shuffling) — confirm in get_fold_dls.
    df_valid[pred_cols] = preds
    df_valid.to_csv(f'{Config.output_dir}/pred_df_f{fold_id}.csv', index=False)
    gc.collect()
    # torch.cuda.empty_cache()
    return preds
    

In [None]:
# run_training()

# Out-of-fold (OOF) frame: a copy of `train` plus one zero-initialised `pred_*`
# column per target, filled in fold by fold below.
pred_cols = [f'pred_{t}' for t in TARGETS]
oof_df = train.copy()
oof_df[pred_cols] = 0.0

for f in Config.trn_folds:
    # Index labels of the rows held out for this fold.
    val_idx = train.index[train['fold'] == f].tolist()
    print(len(val_idx))

    # Train the fold and collect its validation predictions.
    val_preds = run_training(f, Config)
    # val_df = pd.read_csv(f'{Config.output_dir}/val_df_f{f}.csv')
    # pred_df = pd.read_csv(f'{Config.output_dir}/pred_df_f{f}.csv')

    # Write the predictions into the held-out rows only.
    oof_df.loc[val_idx, pred_cols] = val_preds
    

In [None]:
# Display the out-of-fold frame (original columns plus the `pred_*` columns) for inspection.
oof_df

In [None]:
# Prediction frame: `eeg_id` plus the `pred_*` columns, renamed to the plain
# target names so it lines up column-for-column with the ground-truth frame.
oof_pred_df = oof_df[['eeg_id', *['pred_' + i for i in TARGETS]]]
oof_pred_df.columns = ['eeg_id', *TARGETS]

# Ground-truth frame: the same column names, taken from the original target columns.
oof_true_df = oof_df[oof_pred_df.columns].copy()

In [None]:
# Overall out-of-fold score: compare ground-truth and predicted frames, matched via `eeg_id`.
# NOTE(review): `score` is defined outside this view; presumably the competition's KL-divergence metric — confirm.
oof_score = score(solution=oof_true_df, submission=oof_pred_df, row_id_column_name='eeg_id')
print('OOF Score for solution =',oof_score)


In [None]:

# Display the directory where checkpoints and prediction CSVs are written.
Config.output_dir


In [None]:

# Persist the full out-of-fold frame (original columns plus `pred_*` columns) alongside the checkpoints.
oof_df.to_csv(f'{Config.output_dir}/oof.csv',index=False)
# pred_df[TARGETS].values.shape
# a.sum(axis=1)


In [None]:
# Inspect the TARGETS columns (not the `pred_*` columns) for the rows held out in fold 0.
val_idx = list(train[train['fold']==0].index)
oof_df.loc[val_idx, TARGETS]

In [None]:


# F0: 0.803
# F1: 0.719
# F2: 0.761
# F3: 0.743
# F4: 0.717

### OOF Score for solution = 0.749223413337422



# F0: 0.797
# F1: 0.703
# F2: 0.781
# F3: 0.696
# F4: 0.705

### OOF Score for solution = 0.736


# F0: 0.776
# F1: 0.691
# F2: 0.759
# F3: 0.743
# F4: 0.718

### OOF Score for solution = 0.7376

