In [None]:
import csv
import itertools
import json
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa import power_to_db
from librosa.feature import melspectrogram
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms
from tqdm import tqdm

# Check dataset path

In [None]:
# List the dataset directory to confirm the input files are mounted at the path used below.
!ls -al ../input/hms-harmful-brain-activity-classification

In [None]:
# Install einops, beartype and rotary_embedding_torch from local wheel files under /kaggle/input/pip-whl.
# NOTE(review): `--default-timeout 1` presumably makes any package-index lookup fail fast when the
# kernel has no internet access - confirm.
!pip install /kaggle/input/pip-whl/einops-0.7.0-py3-none-any.whl \
    /kaggle/input/pip-whl/beartype-0.18.2-py3-none-any.whl \
    /kaggle/input/pip-whl/rotary_embedding_torch-0.5.3-py3-none-any.whl \
    --default-timeout 1

# Model definitions

In [None]:
class WidthAttention(nn.Module):
    """Gate a feature map along its last (width) axis with learned weights.

    A small head (1x1 conv -> BatchNorm -> SiLU -> global average pool ->
    flatten -> linear -> sigmoid) turns the input into ``width`` sigmoid
    values per sample. Those values are reshaped to ``(B, 1, 1, width)`` and
    multiplied into the input, so the input's last dimension has to be
    broadcast-compatible with ``width``.

    Args:
        in_ch: number of channels of the incoming feature map.
        width: number of gate values produced per sample.
        debug_mode: when False (default) a forward hook stores the latest
            gate tensor in ``self.feat_atten``; when True no hook is installed.
    """

    def __init__(self, in_ch, width: int, debug_mode=False):
        super().__init__()
        hidden = 64
        gate_layers = [
            nn.Conv2d(in_ch, hidden, kernel_size=(1, 1)),
            nn.BatchNorm2d(hidden),
            nn.SiLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(hidden, width),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)  # output shape: (B, width)
        self.feat_atten = None
        if not debug_mode:
            self.attention.register_forward_hook(self._capture_attention)

    def _capture_attention(self, module, input, output):
        """Forward hook: remember the most recent gate tensor."""
        self.feat_atten = output

    def forward(self, x):
        """Return ``x`` scaled by the gate broadcast over channels and height."""
        gate = self.attention(x)
        return x * gate[:, None, None, :]


class MultiHeadAttention(nn.Module):
    """Several width-attention heads whose outputs are concatenated and fused.

    Each head projects the input to ``in_ch // (heads * 2)`` channels with a
    1x1 convolution and gates it with its own ``WidthAttention``. The head
    outputs are concatenated on the channel axis (``in_ch // 2`` channels in
    total) and passed through a 1x1 -> 3x3 -> 1x1 conv/BN/SiLU stack that
    restores ``in_ch`` channels.

    Args:
        in_ch: channels of the incoming feature map.
        width: number of gate values each head produces (see ``WidthAttention``).
        heads: number of attention heads.
        debug_mode: forwarded to every ``WidthAttention``.

    Raises:
        AssertionError: if ``in_ch`` is not divisible by ``heads``, or if the
            per-head channels do not concatenate to exactly ``in_ch // 2``
            (the input size of the fuse stack).
    """

    def __init__(self, in_ch, width: int, heads: int, debug_mode=False):
        super().__init__()
        # Validate before building anything so a bad configuration fails here
        # instead of as a channel-mismatch error inside forward().
        assert in_ch % heads == 0, f'in_ch: {in_ch} must be divisible by heads: {heads}'
        head_ch = in_ch // (heads * 2)
        assert heads * head_ch == in_ch // 2, (
            f'in_ch: {in_ch} with heads: {heads} gives {heads * head_ch} concatenated channels, '
            f'but the fuse block expects {in_ch // 2}'
        )
        self.attentions = nn.ModuleList([WidthAttention(head_ch, width, debug_mode) for _ in range(heads)])
        self.projections = nn.ModuleList([nn.Conv2d(in_ch, head_ch, kernel_size=(1, 1)) for _ in range(heads)])
        self.fuse = nn.Sequential(
            nn.Conv2d(in_ch // 2, in_ch // 4, kernel_size=(1, 1)),
            nn.BatchNorm2d(in_ch // 4),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_ch // 4, in_ch // 4, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(in_ch // 4),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_ch // 4, in_ch, kernel_size=(1, 1)),
            nn.BatchNorm2d(in_ch),
            nn.SiLU(inplace=True),
        )

    def forward(self, x):
        """Project, gate and concatenate every head, then fuse back to ``in_ch`` channels."""
        heads = [atten(proj(x)) for proj, atten in zip(self.projections, self.attentions)]
        return self.fuse(torch.cat(heads, 1))


class EfficientNet(nn.Module):
    """torchvision CNN backbone with a replaced stem convolution and a linear classifier.

    Despite the name, ``cnn_type`` selects between EfficientNet-B0 ('b0'),
    EfficientNetV2-S ('v2s'), ConvNeXt-Tiny ('convnext-tiny') and
    MobileNetV3-Large ('mobilenet'). The first convolution of the backbone is
    swapped for a 17x17 convolution taking ``in_ch`` input channels, and the
    pooled features feed a dropout + bias-free linear layer.

    Args:
        width: input image width; only used to size the optional width attention.
        in_ch: number of input channels for the new stem convolution.
        num_classes: size of the output logits.
        weights: forwarded to the torchvision model factory.
        use_attention: append a ``MultiHeadAttention`` block after the backbone.
        debug_mode: forwarded to the attention block.
        cnn_type: backbone selector (see above).

    Raises:
        NotImplementedError: for an unknown ``cnn_type``.
    """

    def __init__(self, width: int, in_ch=4, num_classes=6, weights='IMAGENET1K_V1', use_attention=False, debug_mode=False, cnn_type='b0'):
        super().__init__()
        # cnn_type -> (torchvision factory name, feature channels, stem out channels,
        #              w_factor, whether to unpack `.features` into a plain list)
        backbone_specs = {
            'b0': ('efficientnet_b0', 1280, 32, 5, True),
            'v2s': ('efficientnet_v2_s', 1280, 24, 5, True),
            'convnext-tiny': ('convnext_tiny', 768, 96, 4, False),
            'mobilenet': ('mobilenet_v3_large', 960, 16, 5, False),
        }
        if cnn_type not in backbone_specs:
            raise NotImplementedError(f'cnn_type: {cnn_type} not implemented')
        factory_name, cnn_ch, c0_ch, w_factor, unpack_features = backbone_specs[cnn_type]

        backbone = getattr(models, factory_name)(weights=weights)
        if unpack_features:
            ori_net = list(backbone.features.children())
        else:
            # First child module of the torchvision model, used as the layer container.
            ori_net = list(backbone.children())[0]

        self.feat_atten = None
        if use_attention:
            attention_width = width // (2 ** 5) + 1
            self.width_attention = MultiHeadAttention(cnn_ch, attention_width, 4, debug_mode)
            ori_net.append(self.width_attention)
        self.features = nn.Sequential(*ori_net)

        # Replace the stem conv: 17x17 kernel, stride 2 for w_factor 5 and 4 for w_factor 4.
        stem_stride = 2 ** (5 - w_factor + 1)
        self.features[0][0] = nn.Conv2d(
            in_ch,
            c0_ch,
            kernel_size=(17, 17),
            stride=(stem_stride, stem_stride),
            padding=(8, 8),
            bias=False,
        )
        self.adv_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(cnn_ch, num_classes, bias=False),
        )

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)``."""
        pooled = self.adv_pool(self.features(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


class ElanBlock(nn.Module):
    """ELAN-style block: split, run a chain of depthwise-separable necks, concatenate, fuse.

    ``cv1`` halves the channels, the result is split in two along the channel
    axis, the second half is pushed through ``depth`` necks in sequence, and
    the first half plus every neck output are concatenated and mapped back to
    ``h_dim`` channels by ``cv2``.

    Args:
        h_dim: input and output channel count.
        depth: number of chained necks.
    """

    def __init__(self, h_dim, depth=3):
        super().__init__()
        self.depth = depth
        branch_ch = h_dim // 4
        self.cv1 = nn.Conv2d(h_dim, h_dim // 2, kernel_size=(1, 1))
        self.necks = nn.ModuleList([self._make_neck(branch_ch) for _ in range(depth)])
        self.cv2 = nn.Sequential(
            nn.Conv2d(branch_ch * (1 + depth), h_dim, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(h_dim),
            nn.GELU(),
        )

    @staticmethod
    def _make_neck(channels):
        """5x5 depthwise conv + BN + GELU followed by a 1x1 pointwise conv + BN."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=(5, 5), padding=(2, 2), groups=channels, bias=False),
            nn.BatchNorm2d(channels),
            nn.GELU(),
            nn.Conv2d(channels, channels, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        # The last chunk is the one that gets refined; it is not itself kept in the concat.
        *branches, running = torch.chunk(self.cv1(x), 2, 1)
        for neck in self.necks:
            running = neck(running)
            branches.append(running)
        return self.cv2(torch.cat(branches, 1))


class MultiHeadRotaryAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and a residual connection.

    Two input layouts are supported:

    * ``is_2d=False``: ``x`` is a token sequence ``(B, S, h_dim)``; a single
      linear layer produces q, k and v.
    * ``is_2d=True``: ``x`` is a feature map ``(B, h_dim, H, W)``; one 1x1
      convolution produces a single projection that is used as q, k and v,
      and the output is flattened to ``(B, H*W, h_dim)``.

    In both cases q and k are rotated with ``RotaryEmbedding``, the attended
    values are added to the input, and the sum goes through Linear + LayerNorm.

    Args:
        h_dim: model (channel) dimension; must be divisible by ``heads``.
        heads: number of attention heads.
        is_2d: choose the feature-map layout instead of the sequence layout.
    """

    def __init__(self, h_dim: int, heads: int, is_2d=False):
        super().__init__()
        # Imported here rather than in the notebook's import cell because the
        # wheel is pip-installed by a cell that runs after that import cell.
        from rotary_embedding_torch import RotaryEmbedding

        self.heads = heads
        self.h_dim = h_dim
        self.is_2d = is_2d
        self.rope = RotaryEmbedding(dim=h_dim // heads // 2)
        if is_2d:
            self.projection = nn.Conv2d(h_dim, h_dim, kernel_size=(1, 1))
        else:
            self.projection = nn.Linear(h_dim, h_dim * 3)
        self.out = nn.Sequential(
            nn.Linear(h_dim, h_dim),
            nn.LayerNorm(h_dim),
        )

    def forward(self, x):
        """Return the attended sequence ``(B, S, h_dim)`` (``S = H*W`` in 2D mode)."""
        # Local import for the same reason as in __init__ (einops is installed from a wheel).
        from einops import rearrange

        xp = self.projection(x)
        # NOTE(review): logits are scaled by h_dim ** -0.5 (full model dim, not the
        # per-head dim); left untouched so trained weights keep behaving the same.
        if self.is_2d:
            q, k, v = map(lambda t: rearrange(t, 'b (hs d) h w -> b hs h w d', hs=self.heads), [xp, xp, xp])
            q, k = map(self.rope.rotate_queries_or_keys, [q, k])
            q, k, v = map(lambda t: rearrange(t, 'b hs h w d -> b hs (h w) d', hs=self.heads), [q, k, v])
            attention = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (self.h_dim ** -0.5)
            attention = torch.nn.functional.softmax(attention, dim=-1)
            xv = torch.einsum('b h i j, b h j c -> b h i c', attention, v)
            # Residual: bring the input feature map into the same (B, H*W, C) layout.
            x = rearrange(xv, 'b h s c -> b s (h c)') + rearrange(x, 'b c h w -> b (h w) c')
        else:
            q, k, v = xp.chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b s (h d) -> b h s d', h=self.heads), [q, k, v])
            q, k = map(self.rope.rotate_queries_or_keys, [q, k])
            attention = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (self.h_dim ** -0.5)
            attention = torch.nn.functional.softmax(attention, dim=-1)
            xv = torch.einsum('b h i j, b h j d -> b h i d', attention, v)
            x = rearrange(xv, 'b h s d -> b s (h d)') + x
        return self.out(x)


class EfficientTransNet(nn.Module):
    """EfficientNet-B0 feature extractor followed by a small transformer classifier.

    The CNN (all but the last two children of ``efficientnet_b0().features``,
    with a 17x17 stem convolution) produces a 192-channel feature map. It is
    projected to ``hid_dim`` channels, flattened into a token sequence, given
    a learnable class token, passed through one attention layer (rotary or
    sinusoidal-PE ``nn.MultiheadAttention``) and a 2-layer transformer
    encoder. The class-token output feeds the classification head.

    Args:
        in_ch: input channels of the stem convolution.
        num_classes: size of the output logits.
        hid_dim: transformer model dimension.
        weights: forwarded to ``models.efficientnet_b0``.
        debug_mode: unused; kept for interface compatibility.
        width: unused; kept for interface compatibility.
        pe_type: 'rotary' uses ``MultiHeadRotaryAttention``; any other value
            uses sinusoidal positional encoding with ``nn.MultiheadAttention``.
        use_pe_2d: forwarded as ``is_2d`` to ``MultiHeadRotaryAttention``.
            NOTE(review): ``forward`` always feeds a (B, S, C) token sequence,
            which does not look compatible with the 2D (Conv2d) mode - confirm
            before enabling.
        heads: number of attention heads.
    """

    def __init__(self,
                 in_ch=4,
                 num_classes=6,
                 hid_dim=128,
                 weights='IMAGENET1K_V1',
                 debug_mode=False,
                 width=656 + 256,
                 pe_type='rotary',
                 use_pe_2d=False,
                 heads=8
         ):
        super().__init__()
        efficientnet = models.efficientnet_b0(weights=weights)
        ori_net = list(efficientnet.features.children())
        self.feat_atten = None
        self.features = nn.Sequential(*ori_net[:-2])
        self.features[0][0] = nn.Conv2d(in_ch, 32, kernel_size=(17, 17), stride=(2, 2), padding=(8, 8), bias=False)
        self.cnn_proj = nn.Sequential(
            nn.Conv2d(192, hid_dim, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(hid_dim),
            nn.GELU(),
        )
        # Transformer related
        # Learnable class token, prepended to the token sequence in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, hid_dim), requires_grad=True)
        self.heads = heads
        self.hid_dim = hid_dim
        self.pe_type = pe_type
        if pe_type == 'rotary':
            self.mha = MultiHeadRotaryAttention(hid_dim, self.heads, is_2d=use_pe_2d)
        else:
            # forward() feeds (B, S, C) tensors, so attention must be batch-first.
            self.mha = nn.MultiheadAttention(hid_dim, self.heads, batch_first=True)
        self.trans_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hid_dim, nhead=self.heads, dim_feedforward=hid_dim, dropout=0.1, batch_first=True),
            num_layers=2
        )
        # output layer
        self.cls_head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hid_dim, hid_dim // 4),
            nn.Linear(hid_dim // 4, num_classes, bias=False)
        )

    def generate_positional_encoding(self, seq_len):
        """Return a ``(seq_len, hid_dim)`` sinusoidal positional-encoding table.

        Even columns hold sines and odd columns hold cosines of
        ``position * 10000 ** (-2i / hid_dim)``.
        """
        hid_dim = self.hid_dim
        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hid_dim, 2).float() * (-math.log(10000.0) / hid_dim))
        angles = position * div_term
        pe = torch.zeros(seq_len, hid_dim)
        pe[:, 0::2] = torch.sin(angles)
        # An odd hid_dim has one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :hid_dim // 2])
        return pe

    @staticmethod
    def get_size(w, level=5):
        """Return ``w`` after ``level`` successive round-up halvings."""
        for _ in range(level):
            b = 1 if w % 2 == 1 else 0
            w = int(w / 2 + b)
        return int(w)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)``."""
        # extract CNN features
        x = self.features(x)  # [B, 192, h, w]
        x = self.cnn_proj(x)  # [B, hid_dim, h, w]
        # Flatten the spatial grid into a channel-last token sequence: [B, h*w, hid_dim]
        x = x.flatten(2).transpose(1, 2)
        # Prepend the class token
        cls_token = self.cls_token.repeat_interleave(x.size(0), 0)
        x = torch.cat([cls_token, x], 1)
        if self.pe_type == 'rotary':
            x = self.mha(x)
        else:
            pe = self.generate_positional_encoding(x.size(1)).to(x.device)
            tokens = x + pe
            v, _ = self.mha(tokens, tokens, tokens)
            x = v + x
        # Keep only the class-token output
        x = self.trans_encoder(x)[:, 0]
        # projection head
        x = self.cls_head(x)
        return x

# Data loader

In [None]:


class InferHMSDataset(Dataset):
    """Inference dataset yielding EEG mel-spectrogram and spectrogram tensors per CSV row.

    Each item is ``(eeg_windows, spec_windows, eeg_id)``. The EEG part is a
    set of mel-spectrograms computed from the raw EEG parquet at several
    durations; the spectrogram part comes from the spectrogram parquet. Both
    are concatenated into one image each and cut into fixed-width windows.

    Args:
        dataset_path: root directory holding the csv and the
            ``{train|test}_eegs`` / ``{train|test}_spectrograms`` folders.
        csv_path: csv file name relative to ``dataset_path``; ignored when
            ``df`` is given.
        s_mean, s_std: stored on the instance and reported by ``log``; the
            current normalisation does not use them.
        cat_type: one of 'ch', 'x', 'y'; selects tensor dim 0, 1 or 2 for
            concatenating the per-sensor images.
        df: already-loaded DataFrame; takes precedence over ``csv_path``.
        crop_size_sp, crop_size_eeg: ``[height, width]``; only the width
            (index 1) is used, as the window length.
        slide_wd_ratio: window stride as a fraction of the window length;
            a value <= 0 disables the sliding window.
        is_train: read from the ``train_*`` folders instead of ``test_*``.
        transform_sp, transform_eeg: per-image transforms; default ``ToTensor``.
        duration: list of EEG durations in seconds; default ``[10, 15, 30, 45]``.

    Raises:
        ValueError: if neither ``df`` nor both ``dataset_path`` and
            ``csv_path`` are provided.
    """

    def __init__(
            self,
            dataset_path=None,
            csv_path=None,
            s_mean=None,
            s_std=None,
            cat_type='x',
            df=None,
            crop_size_sp=None,
            crop_size_eeg=None,
            slide_wd_ratio=0.5,  # -1 for no sliding window
            is_train=False,
            transform_sp=None,
            transform_eeg=None,
            duration=None,
    ):
        self.num_class = 6
        # Resolve the row table. `df` is compared against None explicitly because
        # the truth value of a DataFrame is ambiguous and raises ValueError.
        if df is not None:
            self.df = df
        elif dataset_path and csv_path:
            self.df = pd.read_csv(os.path.join(dataset_path, csv_path))
        else:
            raise ValueError('Either df or dataset_path and csv_path must be provided')

        if s_std is None:
            s_std = [2.3019, 2.2844, 2.2947, 2.3129]
        if s_mean is None:
            s_mean = [-.1119, -.129, -.1395, -.1689]
        self.dataset_path = dataset_path
        self.mean = s_mean
        self.std = s_std
        self.cat_type: str = cat_type
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.transform_sp = transform_sp if transform_sp else self.transform
        self.transform_eeg = transform_eeg if transform_eeg else self.transform
        if crop_size_sp is None:
            crop_size_sp = [100, 256]
        self.crop_size_sp = crop_size_sp
        if crop_size_eeg is None:
            crop_size_eeg = [100, 256]
        self.crops_size_eeg = crop_size_eeg
        self.slide_wd_ratio = slide_wd_ratio
        self.base_dir = 'train' if is_train else 'test'
        self.is_train = is_train
        self.sensor_names = ['LL', 'RL', 'RP', 'LP']
        # Electrode chains; consecutive electrodes in a chain are differenced in read_eeg.
        self.eeg_feat = [['Fp1', 'F7', 'T3', 'T5', 'O1'],
                         ['Fp1', 'F3', 'C3', 'P3', 'O1'],
                         ['Fp2', 'F8', 'T4', 'T6', 'O2'],
                         ['Fp2', 'F4', 'C4', 'P4', 'O2']]
        self.eeg_sample_duration = duration if duration is not None else [10, 15, 30, 45]
        self.spec_sample_duration = [600]

    def __len__(self):
        """Number of rows in the csv / DataFrame."""
        return len(self.df)

    def log(self, kwargs: dict):
        """Log normalisation stats, transform names and ``kwargs`` to mlflow.

        Every key is prefixed with 'train' or 'valid' depending on ``is_train``.
        """
        import mlflow
        prefix = 'valid' if not self.is_train else 'train'
        transform_sp_list = [t.__class__.__name__ for t in self.transform_sp.transforms]
        transform_eeg_list = [t.__class__.__name__ for t in self.transform_eeg.transforms]
        kwargs = {f'{prefix}-{k}': v for k, v in kwargs.items()}
        mlflow.log_params({
            f'{prefix}-mean': self.mean,
            f'{prefix}-std': self.std,
            f'{prefix}-transform_sp': transform_sp_list,
            f'{prefix}-transform_eeg': transform_eeg_list,
            **kwargs
        })

    def norm_img(self, img: np.ndarray, idx: int) -> np.ndarray:
        """Standardise ``img`` with its own NaN-aware mean and std over all elements.

        ``idx`` is unused (it belonged to the disabled per-sensor normalisation
        below) and is kept so existing callers keep working.
        """
        # std_img = (img - self.mean[idx]) / self.std[idx]
        eps = 1e-6  # avoids division by zero for constant images
        std_img = (img - np.nanmean(img)) / (np.nanstd(img) + eps)
        return std_img

    def read_eeg(self, eeg_id, offset, h=50, w=256) -> list[list[np.ndarray]]:
        """Build mel-spectrogram images from the raw EEG parquet of ``eeg_id``.

        For every duration in ``eeg_sample_duration`` a window centred at
        ``offset + 25`` seconds is cut (200 rows per second). For each of the
        four electrode chains the four neighbouring-electrode differences are
        turned into mel-spectrograms and averaged.

        Returns:
            One entry per duration; each entry is a list of four ``(h, w)``
            float32 images (one per electrode chain).
        """
        parquet_path = f'{self.base_dir}_eegs/{eeg_id}.parquet'
        parquet_path = os.path.join(self.dataset_path, parquet_path)
        eeg_ori = pd.read_parquet(parquet_path)
        # Load multiple eeg from durations
        d_imgs = []
        for duration in self.eeg_sample_duration:
            middle = int(offset + 25)
            # Integer half-duration: an odd duration yields a window one second shorter.
            eeg = eeg_ori.iloc[(middle - duration // 2) * 200:(middle + duration // 2) * 200]
            # VARIABLE TO HOLD SPECTROGRAM
            img = np.zeros((h, w, 4), dtype='float32')
            signals = []
            for k in range(4):
                COLS = self.eeg_feat[k]
                for kk in range(4):
                    # COMPUTE PAIR DIFFERENCES
                    x = eeg[COLS[kk]].values - eeg[COLS[kk + 1]].values
                    # FILL NANS with the signal mean (or zeros when everything is NaN)
                    m = np.nanmean(x)
                    m = 0 if np.isnan(m) else m
                    if np.isnan(x).mean() < 1:
                        x = np.nan_to_num(x, nan=m)
                    else:
                        x[:] = 0
                    signals.append(x)
                    # RAW SPECTROGRAM: hop length chosen so the result has roughly `w` frames
                    mel_spec = melspectrogram(y=x, sr=200, hop_length=len(x) // w, n_fft=1024, n_mels=h, fmin=0,
                                              fmax=20, win_length=128)
                    # LOG TRANSFORM, trimming the frame count to a multiple of 32
                    width = (mel_spec.shape[1] // 32) * 32
                    mel_spec_db = power_to_db(mel_spec, ref=np.max).astype(np.float32)[:, :width]
                    # STANDARDIZE TO -1 TO 1
                    mel_spec_db = (mel_spec_db + 40) / 40
                    img[:, :, k] += mel_spec_db
                # AVERAGE THE 4 MONTAGE DIFFERENCES
                img[:, :, k] /= 4.0
            d_imgs.append([img[:, :, i] for i in range(img.shape[-1])])
        return d_imgs

    def read_spectrogram(self, spectrogram_id, offset) -> list[np.ndarray]:
        """Read the given spectrogram and apply normalization.

        Rows with ``offset <= time < offset + 600`` are kept, NaNs are filled
        with 0, columns are grouped by sensor name (LL, RL, RP, LP),
        ``log1p``-transformed, and the four sensor images are standardised
        jointly.

        Returns:
            Four 2-D arrays (one per sensor), transposed relative to the
            parquet so that rows of the file become image columns.
        """
        spectrogram_path = f'{self.base_dir}_spectrograms/{spectrogram_id}.parquet'
        raw = pd.read_parquet(os.path.join(self.dataset_path, spectrogram_path)).fillna(0)
        sensor_types = self.sensor_names
        raw = raw.loc[(raw.time >= offset) & (raw.time < offset + 600)]
        # Column names whose label contains the sensor name
        sensor_data = [list(raw.filter(like=s, axis=1)) for s in sensor_types]
        sensor_data = [np.log1p(raw[s].T.values) for s in sensor_data]
        sensor_data = self.norm_img(np.stack(sensor_data, 0), -1)
        sensor_data = np.split(sensor_data, sensor_data.shape[0], 0)
        sensor_data = [np.nan_to_num(s, nan=0)[0] for s in sensor_data]
        return sensor_data

    def cat_imgs(self, img_list: list[torch.Tensor], data_type: str):
        """Concatenate per-sensor tensors into a single image tensor.

        Tensors are joined along the dim selected by ``cat_type``
        ('ch' -> 0, 'x' -> 1, 'y' -> 2). For ``data_type == 'eeg'`` the list
        is first split into one group per duration; when there are exactly
        four groups they are tiled 2x2 (pairs joined on the last dim, then the
        two pairs on dim 1).
        """
        cat_dim = 'ch,x,y'.split(',').index(self.cat_type)
        if data_type == 'eeg':
            n = len(img_list) // len(self.eeg_sample_duration)
            img_list = [img_list[i:i + n] for i in range(0, len(img_list), n)]
        else:
            img_list = [img_list]

        result = [torch.cat(group, cat_dim) for group in img_list]
        if len(result) == 4:
            result = [torch.cat(result[:2], -1), torch.cat(result[2:], -1)]
        return torch.cat(result, 1)

    def length_proces(self, t: torch.Tensor, data_type: str):
        """Process the length of the given tensor on the last dimension (image width).

        Shorter inputs are zero-padded on the right to the base length. Longer
        inputs are cut into windows of the base length with a stride of
        ``base_length * slide_wd_ratio`` (the final window is zero-padded),
        unless the sliding window is disabled. The result always has a new
        leading window dimension.
        """
        assert data_type in ['spec', 'eeg'], f'data_type {data_type} not implemented'
        base_length = self.crop_size_sp[1] if data_type == 'spec' else self.crops_size_eeg[1]
        if data_type == 'eeg':
            # With more than two durations the EEG images are tiled side by side
            # (see cat_imgs), so the expected width grows accordingly.
            base_length *= 1 if len(self.eeg_sample_duration) <= 2 else len(self.eeg_sample_duration) // 2
        min_sliding_len = base_length * self.slide_wd_ratio
        # padding short image
        if t.shape[-1] < base_length:
            pad = base_length - t.shape[-1]
            t = torch.nn.functional.pad(t, (0, pad), 'constant', 0)
            return t.unsqueeze(0)
        # sliding window
        if self.slide_wd_ratio > 0 and t.shape[-1] > base_length:
            t_list = []
            while t.shape[-1] > min_sliding_len:
                t_list.append(t[:, :, :base_length])
                t = t[:, :, int(base_length * self.slide_wd_ratio):]
            # Only the last window can be shorter than the base length.
            if t_list[-1].shape[-1] < base_length:
                pad = base_length - t_list[-1].shape[-1]
                t_list[-1] = torch.nn.functional.pad(t_list[-1], (0, pad), 'constant', 0)
            t = torch.stack(t_list, 0)
            return t
        return t.unsqueeze(0)

    def process_spectrogram(self, spectrogram: list[torch.Tensor], data_type: str):
        """Concatenate the images and cut the result into fixed-width windows."""
        spectrogram = self.cat_imgs(spectrogram, data_type)
        spectrogram = self.length_proces(spectrogram, data_type)
        return spectrogram

    def get_spectrogram_by_id(self, spectrogram_id, offset: float) -> list[np.ndarray]:
        """Return the per-sensor spectrogram images for ``spectrogram_id``."""
        spectrogram_images = self.read_spectrogram(spectrogram_id, offset)
        return spectrogram_images

    def get_eeg_by_id(self, eeg_id, offset) -> list[np.ndarray]:
        """Return the EEG images of all durations as one flat, duration-major list."""
        eeg_images = self.read_eeg(eeg_id, offset)
        eeg_image = [eeg for eeg_list in eeg_images for eeg in eeg_list]
        return eeg_image

    def __getitem__(self, idx):
        """Return ``(eeg_windows, spec_windows, eeg_id)`` for row ``idx`` (offsets fixed at 0)."""
        row = self.df.iloc[idx]
        eeg_id = row['eeg_id']
        spectrogram_id = row['spectrogram_id']
        offset_eeg = 0
        offset_spec = 0
        eeg, spec = self.get_eeg_by_id(eeg_id, offset_eeg), self.get_spectrogram_by_id(spectrogram_id, offset_spec)
        eeg, spec = [[trans(s) for s in ss] for trans, ss in zip([self.transform_eeg, self.transform_sp], [eeg, spec])]
        return self.process_spectrogram(eeg, 'eeg'), self.process_spectrogram(spec, 'spec'), eeg_id

    @staticmethod
    def collate_fn(batch):
        """Join each sample's spectrogram windows with its EEG image along the width.

        The EEG tensor is repeated for every spectrogram window of the same
        sample. Samples are returned as a list (not stacked) because the
        number of windows may differ between samples.

        Returns:
            ``(data, eeg_id)``: a list with one tensor per sample and the
            tuple of ids.
        """
        transposed_batch = list(zip(*batch))
        eeg, spec, eeg_id = transposed_batch
        data = []
        for sidx, s in enumerate(spec):
            bidx = s.size()[0]
            e = eeg[sidx] * torch.ones((bidx, 1, 1, 1))
            data.append(torch.cat([s, e], -1))
        return data, eeg_id

# Helper functions

In [None]:
def load_ema_model(checkpoint_path, model, device):
    """Load the weights stored under the checkpoint's 'ema' key into ``model`` in place.

    Args:
        checkpoint_path: file readable by ``torch.load`` that contains a dict
            with an 'ema' state dict.
        model: module whose parameters are overwritten.
        device: ``map_location`` used when loading the checkpoint.

    Note:
        ``torch.load`` is pickle-based; only load checkpoints from trusted sources.
    """
    ema_state = torch.load(checkpoint_path, map_location=device)['ema']
    model.load_state_dict(ema_state)

def sliding_pred(imgs: torch.Tensor, model: torch.nn.Module, aggregate_method):
    """Run ``model`` on every window of ``imgs`` and aggregate the predictions.

    Args:
        imgs: tensor whose first dimension indexes the windows; each window is
            passed to the model as a batch of one.
        model: callable mapping a single-window batch to a prediction tensor.
        aggregate_method: 'max' (element-wise maximum over windows) or
            'mean' (element-wise average over windows).

    Returns:
        The aggregated prediction, with the same shape as one model output.

    Raises:
        NotImplementedError: for any other ``aggregate_method`` (raised after
            the model has been evaluated on all windows).
    """
    per_window = [model(window) for window in torch.split(imgs, 1, dim=0)]
    if aggregate_method not in ('max', 'mean'):
        raise NotImplementedError(f'aggregate_method: {aggregate_method} not implemented')
    stacked = torch.stack(per_window)
    if aggregate_method == 'mean':
        return stacked.mean(0)
    return stacked.max(0).values

# Inference code

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Test-split dataset: rows come from test.csv, signals from the test_* parquet folders.
dataset = InferHMSDataset(
    csv_path='test.csv',
    dataset_path='../input/hms-harmful-brain-activity-classification',
    is_train=False,
    cat_type='x',
)

In [None]:
# One entry per ensemble member: 'w' is the checkpoint path, every other key is
# forwarded to the EfficientNet constructor.
kwargs = [
    {'w': '/kaggle/input/hms-baseline-efficient-b0/pytorch/efficientb0-wo-pl/1/f0.pt', 'cnn_type': 'b0'},
    {'w': '/kaggle/input/hms-baseline-efficient-b0/pytorch/efficientb0-wo-pl/1/f2.pt', 'cnn_type': 'b0'},
    {'w': '/kaggle/input/hms-baseline-efficient-b0/pytorch/efficientb0-wo-pl/1/f3.pt', 'cnn_type': 'b0'},
    {'w': '/kaggle/input/hms-baseline-efficient-b0/pytorch/mobilenetv3/1/f0.pt', 'cnn_type': 'mobilenet'},
    {'w': '/kaggle/input/hms-baseline-efficient-b0/pytorch/mobilenetv3/1/f2.pt', 'cnn_type': 'mobilenet'},
    {'w': '/kaggle/input/hms-baseline-efficient-b0/pytorch/mobilenetv3/1/f3.pt', 'cnn_type': 'mobilenet'},
]

# Ensemble weight of each model, in the same order as `kwargs`.
model_weights = [1.2, 1.2, 1.2, 1, 1, 1]
# zip() in the inference loop would silently drop models on a length mismatch.
assert len(model_weights) == len(kwargs), 'model_weights must have one entry per model'

ensemble_models = []
for kwarg in kwargs:
    # Split the path from the constructor arguments without mutating the config entry.
    model_kwargs = {key: value for key, value in kwarg.items() if key != 'w'}
    w_path = kwarg['w']
    # in_ch=1: single-channel images; width = 256 (spectrogram window) + 2 * 256 (EEG image).
    net = EfficientNet(in_ch=1, width=256 + 256*2, weights=None, use_attention=False, **model_kwargs)
    load_ema_model(w_path, net, device)
    net.to(device)
    net.eval()
    ensemble_models.append(net)

In [None]:
# Run the weighted ensemble over the test set and stream one csv row per sample.
with open('submission_david_lb31.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow('eeg_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote'.split(','))

    data_loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4, collate_fn=dataset.collate_fn)
    # tqdm iterates over batches, so its total is the number of batches, not samples.
    dataset_iter = tqdm(data_loader, total=len(data_loader), desc='Inferring')
    weight_total = sum(model_weights)
    for img_tensor, eeg_id in dataset_iter:
        outputs = []
        with torch.no_grad():
            # collate_fn returns a list with one (windows, 1, H, W) tensor per sample.
            for x in img_tensor:
                x = x.to(device)  # moved once and shared by every model
                pred_m = None
                for model, w in zip(ensemble_models, model_weights):
                    # Logits are averaged over the sliding windows, then turned into probabilities.
                    pred = sliding_pred(x, model, 'mean')
                    pred_prob = torch.nn.functional.softmax(pred, dim=1)[0]
                    pred_m = pred_prob * w if pred_m is None else pred_m + pred_prob * w
                # Weighted average of the per-model probabilities.
                outputs.append(pred_m / weight_total)

        for sample_id, output in zip(eeg_id, outputs):
            writer.writerow([sample_id.item()] + output.cpu().tolist())

In [None]:
# Read the written submission back so the cell displays it for a quick visual check.
pd.read_csv('submission_david_lb31.csv')