# [CV 0.68 | LB 0.46] DilatedInception WaveNet in PyTorch - Inference

## Introduction
After joining this competition, I focused on validating how far raw EEG signals can go through many experiments. Finally, I found a simple architecture mixing the concepts of **dilation** and **inception**, which can be seen as an extension of [Chris' version](https://www.kaggle.com/code/cdeotte/wavenet-starter-lb-0-52). And, I'm happy to announce that we can achieve CV 0.68 (below 0.7) and LB 0.46 with this architecture. Don't forget to upvote Chris' notebook!

## About this Notebook
In this kernel, I run the inference process with 5-fold models (equally-weighted blending) and obtain LB 0.46. If you're also interested in the training part, please see [[LB 0.46] DilatedInception WaveNet - Training](https://www.kaggle.com/code/abaojiang/lb-0-46-dilatedinception-wavenet-training).

## Acknowledgements
Special thanks to [@cdeotte](https://www.kaggle.com/cdeotte)'s sharing, [WaveNet Starter - [LB 0.52]](https://www.kaggle.com/code/cdeotte/wavenet-starter-lb-0-52).

<a id="toc"></a>
## Table of Contents
* [1. Load Data](#load_data)
* [2. Define Dataset](#dataset)
* [3. Create Test Loader](#test_loader)
* [4. Define Model Architecture](#model)
* [5. Load Models](#load_models)
* [6. Run Inference](#infer)
* [7. Submission](#sub)

## Import Packages

In [None]:
import gc
import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import pickle
import warnings
from pathlib import Path
from tqdm.notebook import tqdm
warnings.simplefilter("ignore")

import numpy as np
import pandas as pd
import yaml
from scipy.signal import butter, lfilter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

## Define Data Paths and Configuration and Metadata

In [None]:
# Root directory of the competition data
DATA_PATH = Path("/kaggle/input/hms-harmful-brain-activity-classification")


class CFG:
    """Inference configuration (checkpoint location, device and data options)."""

    # == Experiment ==
    exp_id = "0311-17-20-55"
    model_path = Path("/kaggle/input/dilated-wavenet/0311-17-20-55")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # == Data ==
    # Chris' 8 channels
    feats = ["Fp1", "T3", "C3", "O1", "Fp2", "C4", "T4", "O2"]
    # If True, raw EEG columns are cast to float32 when loaded
    cast_eegs = True
    # Options of the raw EEG transformer, grouped under the "eeg" key
    dataset = {
        "eeg": dict(
            n_feats=8,
            apply_chris_magic_ch8=True,
            normalize=True,
            apply_butter_lowpass_filter=True,
            apply_mu_law_encoding=False,
            downsample=5,
        )
    }

    # == Data Loader ==
    batch_size = 32


# Target columns and number of classes
N_CLASSES = 6
TGT_VOTE_COLS = [
    "seizure_vote",
    "lpd_vote",
    "gpd_vote",
    "lrda_vote",
    "grda_vote",
    "other_vote",
]

# EEG window geometry
EEG_FREQ = 200  # Hz
EEG_WLEN = 50  # sec
EEG_PTS = EEG_FREQ * EEG_WLEN  # number of time steps per window

<a id="load_data"></a>
## 1. Load Data
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

Note that the test data will be replaced with the **hidden test set** during rerun.

In [None]:
def _get_eeg_window(file: Path) -> np.ndarray:
    """Return the centered EEG window of `EEG_PTS` time steps.

    Only the channels listed in `CFG.feats` are read. Missing values of a
    channel are replaced with the mean of that channel over the window,
    and a channel that is entirely missing is filled with zeros.

    Args:
        file: EEG parquet file path

    Returns:
        eeg_win: cropped EEG window, with shape (EEG_PTS, len(CFG.feats))

    Raises:
        ValueError: if the recording has fewer than `EEG_PTS` time steps
    """
    eeg = pd.read_parquet(file, columns=CFG.feats)
    n_pts = len(eeg)
    if n_pts < EEG_PTS:
        # A negative offset would silently crop the wrong rows and fail
        # later with an opaque broadcasting error
        raise ValueError(
            f"EEG file {file} has {n_pts} time steps, but at least {EEG_PTS} are required."
        )
    offset = (n_pts - EEG_PTS) // 2
    eeg = eeg.iloc[offset:offset + EEG_PTS]

    eeg_win = np.zeros((EEG_PTS, len(CFG.feats)))
    for j, col in enumerate(CFG.feats):
        eeg_raw = eeg[col].values
        if CFG.cast_eegs:
            eeg_raw = eeg_raw.astype("float32")

        if np.isnan(eeg_raw).all():
            # All missing: keep the zero-initialized column and never
            # write into `eeg_raw`, which may be a read-only view
            continue

        # Fill missing values with the channel mean
        eeg_win[:, j] = np.nan_to_num(eeg_raw, nan=np.nanmean(eeg_raw))

    return eeg_win

In [None]:
# Test metadata
test_csv_path = DATA_PATH / "test.csv"
test = pd.read_csv(test_csv_path)
print(f"Test data shape | {test.shape}")

In [None]:
# Load the cropped EEG window of every unique EEG id
test_eeg_dir = DATA_PATH / "test_eegs"
uniq_eeg_ids = test["eeg_id"].unique()
n_uniq_eeg_ids = len(uniq_eeg_ids)

all_eegs = {
    eeg_id: _get_eeg_window(test_eeg_dir / f"{eeg_id}.parquet")
    for eeg_id in tqdm(uniq_eeg_ids, total=n_uniq_eeg_ids)
}

print(f"Demo EEG shape | {list(all_eegs.values())[0].shape}")

<a id="dataset"></a>
## 2. Define Dataset
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

In [None]:
class _EEGTransformer(object):
    """Data transformer for raw EEG signals.

    The enabled steps always run in the same order: Chris' 8-channel
    differences, clip-and-scale normalization, Butterworth low-pass
    filtering, mu-law encoding and temporal downsampling.

    Args:
        n_feats: number of output channels allocated by the 8-channel
            difference step
        apply_chris_magic_ch8: if True, replace the raw channels with
            differences of electrode pairs
        normalize: if True, clip to [-1024, 1024], zero-fill NaNs and
            divide by 32
        apply_butter_lowpass_filter: if True, apply the low-pass filter
            along the time axis
        apply_mu_law_encoding: if True, apply mu-law encoding with mu=1
        downsample: keep every `downsample`-th time step, None disables
            downsampling
    """

    # Column index of each raw channel in the input matrix
    FEAT2CODE = {f: i for i, f in enumerate(CFG.feats)}

    # Electrode pairs of Chris' formula, output column k is the signal of
    # the first electrode minus the signal of the second one
    _CH8_PAIRS = (
        ("Fp1", "T3"), ("T3", "O1"),
        ("Fp1", "C3"), ("C3", "O1"),
        ("Fp2", "C4"), ("C4", "O2"),
        ("Fp2", "T4"), ("T4", "O2"),
    )

    def __init__(
        self,
        n_feats: int,
        apply_chris_magic_ch8: bool = True,
        normalize: bool = True,
        apply_butter_lowpass_filter: bool = True,
        apply_mu_law_encoding: bool = False,
        downsample: Optional[int] = None,
    ) -> None:
        self.n_feats = n_feats
        self.apply_chris_magic_ch8 = apply_chris_magic_ch8
        self.normalize = normalize
        self.apply_butter_lowpass_filter = apply_butter_lowpass_filter
        self.apply_mu_law_encoding = apply_mu_law_encoding
        self.downsample = downsample

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Apply transformation on raw EEG signals.

        The input is copied first, so it is never modified in place.

        Args:
            x: raw EEG signals, with shape (L, C)

        Return:
            x_: transformed EEG signals
        """
        x_ = x.copy()
        if self.apply_chris_magic_ch8:
            x_ = self._apply_chris_magic_ch8(x_)

        if self.normalize:
            x_ = np.clip(x_, -1024, 1024)
            x_ = np.nan_to_num(x_, nan=0) / 32.0

        if self.apply_butter_lowpass_filter:
            x_ = self._butter_lowpass_filter(x_)

        if self.apply_mu_law_encoding:
            x_ = self._quantize_data(x_, 1)

        if self.downsample is not None:
            x_ = x_[::self.downsample, :]

        return x_

    def _apply_chris_magic_ch8(self, x: np.ndarray) -> np.ndarray:
        """Generate features based on Chris' magic formula.

        Args:
            x: raw EEG signals, with shape (L, C), columns ordered as in
                `FEAT2CODE`

        Return:
            x_tmp: float32 pairwise differences, with shape (L, n_feats)
        """
        # Follow the input length instead of assuming a fixed window size
        x_tmp = np.zeros((x.shape[0], self.n_feats), dtype="float32")

        # Generate features
        for k, (feat_a, feat_b) in enumerate(self._CH8_PAIRS):
            x_tmp[:, k] = x[:, self.FEAT2CODE[feat_a]] - x[:, self.FEAT2CODE[feat_b]]

        return x_tmp

    def _butter_lowpass_filter(
        self,
        data: np.ndarray,
        cutoff_freq: float = 20,
        sampling_rate: float = 200,
        order: int = 4,
    ) -> np.ndarray:
        """Apply a digital Butterworth low-pass filter along axis 0.

        Args:
            data: signals to filter, time steps along axis 0
            cutoff_freq: cutoff frequency, in the unit of `sampling_rate`
            sampling_rate: sampling rate of `data`
            order: filter order

        Return:
            filtered_data: filtered signals
        """
        # `butter` expects the cutoff as a fraction of the Nyquist frequency
        nyquist = 0.5 * sampling_rate
        normal_cutoff = cutoff_freq / nyquist
        b, a = butter(order, normal_cutoff, btype="low", analog=False)
        filtered_data = lfilter(b, a, data, axis=0)

        return filtered_data

    def _quantize_data(self, data: np.ndarray, classes: float) -> np.ndarray:
        """Return the mu-law encoding of `data`, with `classes` used as mu."""
        mu_x = self._mu_law_encoding(data, classes)

        return mu_x

    def _mu_law_encoding(self, data: np.ndarray, mu: float) -> np.ndarray:
        """Compute sign(data) * log(1 + mu * |data|) / log(1 + mu)."""
        mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1)

        return mu_x

In [None]:
class EEGDataset(Dataset):
    """Dataset for pure raw EEG signals.

    When raw EEGs are provided, every window is transformed once at
    construction time and cached in memory, so that `__getitem__` only
    slices the cached arrays.

    Args:
        data: processed data, with entries
            "meta": metadata table with one row per unique `eeg_id`
            "eeg": raw EEG windows indexed by `eeg_id`, or None
        split: data split, ground truth is skipped if "test"
        **dataset_cfg: dataset configuration, the "eeg" entry holds the
            keyword arguments of `_EEGTransformer`

    Attributes:
        _n_samples: number of samples
        _infer: if True, the dataset is constructed for inference
            *Note: Ground truth is not provided.
        _stream_X: if True, no raw EEG is held in memory
        _X: transformed EEG signals, None if `_stream_X` is True
        _y: target matrix, None if `_infer` is True
    """

    def __init__(
        self,
        data: Dict[str,  Any],
        split: str,
        **dataset_cfg: Any,
    ) -> None:
        self.metadata = data["meta"]
        self.all_eegs = data["eeg"]
        self.dataset_cfg = dataset_cfg

        # Raw EEG data transformer
        self.eeg_params = dataset_cfg["eeg"]
        self.eeg_trafo = _EEGTransformer(**self.eeg_params)

        self._set_n_samples()
        self._infer = split == "test"

        self._stream_X = self.all_eegs is None
        self._X, self._y = self._transform()

    def _set_n_samples(self) -> None:
        """Set the number of samples, one sample per unique `eeg_id`."""
        assert len(self.metadata) == self.metadata["eeg_id"].nunique(), (
            "Each row of the metadata must have a unique `eeg_id`."
        )
        self._n_samples = len(self.metadata)

    def _transform(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Transform feature and target matrices.

        Return:
            X: transformed EEG signals, None if EEGs aren't held in memory
            y: target matrix, None for inference
        """
        downsample = self.eeg_params["downsample"]
        if downsample is not None:
            # Strided slicing `x[::d]` keeps ceil(L / d) time steps
            eeg_len = -(-EEG_PTS // downsample)
        else:
            eeg_len = int(EEG_PTS)
        if not self._stream_X:
            X = np.zeros((self._n_samples, eeg_len, self.eeg_params["n_feats"]), dtype="float32")
        else:
            X = None
        y = np.zeros((self._n_samples, N_CLASSES), dtype="float32") if not self._infer else None

        # Use the row position rather than the index label, the metadata
        # index isn't guaranteed to be 0..n-1
        rows = tqdm(self.metadata.iterrows(), total=len(self.metadata))
        for i, (_, row) in enumerate(rows):
            # Process raw EEG signals
            if not self._stream_X:
                # Retrieve raw EEG signals
                eeg = self.all_eegs[row["eeg_id"]]

                # Apply EEG transformer
                X[i] = self.eeg_trafo.transform(eeg)

            if not self._infer:
                y[i] = row[TGT_VOTE_COLS]

        return X, y

    def __len__(self) -> int:
        return self._n_samples

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
        """Return the sample at position `idx`.

        Return:
            data_sample: "x" holds the transformed EEG signals, "y" holds
                the targets and is only present when not inferring

        Raises:
            NotImplementedError: if the dataset holds no EEG in memory
        """
        if self._X is None:
            raise NotImplementedError(
                "Loading EEGs on the fly is not implemented, please provide raw EEGs in `data['eeg']`."
            )

        x = self._X[idx, ...]
        data_sample = {"x": torch.tensor(x, dtype=torch.float32)}
        if not self._infer:
            data_sample["y"] = torch.tensor(self._y[idx, :], dtype=torch.float32)

        return data_sample

<a id="test_loader"></a>
## 3. Create Test Loader
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

In [None]:
# Wrap metadata and raw EEGs into the dataset, then iterate it in order
test_data = dict(meta=test, eeg=all_eegs)
test_dataset = EEGDataset(test_data, "test", **CFG.dataset)
test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=0)
print(f"There are {len(test_loader.dataset)} test samples to infer.")

<a id="model"></a>
## 4. Define Model Architecture
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

[![Screenshot-2024-02-19-at-1-11-40-PM.png](https://i.postimg.cc/MKp8xVVV/Screenshot-2024-02-19-at-1-11-40-PM.png)](https://postimg.cc/7bdRnCBZ)

In [None]:
class _WaveBlock(nn.Module):
    """WaveNet block.

    A 1x1 input conv is followed by `n_layers` gated temporal conv
    layers, where layer `l` uses dilation `2**l`. The output of each
    layer goes through a 1x1 conv and is accumulated onto the output of
    the input conv.

    Args:
        n_layers: number of gated temporal conv layers
        in_dim: input channel size
        h_dim: hidden channel size
        kernel_size: kernel size, pass a list of kernel sizes for
            inception
        conv_module: customized convolution module used by each gated
            temporal conv layer
    """

    def __init__(
        self,
        n_layers: int,
        in_dim: int,
        h_dim: int,
        kernel_size: Union[int, List[int]],
        conv_module: Optional[Type[nn.Module]] = None,
    ) -> None:
        super().__init__()

        self.n_layers = n_layers
        self.dilation_rates = [2**power for power in range(n_layers)]

        # Model blocks
        self.in_conv = nn.Conv2d(in_dim, h_dim, kernel_size=(1, 1))
        self.gated_tcns = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        for dilation in self.dilation_rates:
            gated_tcn = _GatedTCN(
                in_dim=h_dim,
                h_dim=h_dim,
                kernel_size=kernel_size,
                dilation_factor=dilation,
                conv_module=conv_module,
            )
            self.gated_tcns.append(gated_tcn)
            self.skip_convs.append(nn.Conv2d(h_dim, h_dim, kernel_size=(1, 1)))

        # Initialize parameters of all 1x1 convs
        relu_gain = nn.init.calculate_gain("relu")
        for conv in (self.in_conv, *self.skip_convs):
            nn.init.xavier_uniform_(conv.weight, gain=relu_gain)
            nn.init.zeros_(conv.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Shape:
            x: (B, C, N, L), where C denotes in_dim
            x_skip: (B, C', N, L), where C' denotes h_dim
        """
        # Input convolution
        x = self.in_conv(x)

        x_skip = x
        for gated_tcn, skip_conv in zip(self.gated_tcns, self.skip_convs):
            # Each layer consumes the output of the previous one
            x = skip_conv(gated_tcn(x))

            # Skip-connection
            x_skip = x_skip + x

        return x_skip


class _GatedTCN(nn.Module):
    """Gated temporal convolution layer.

    Two parallel convolutions are applied to the input: the tanh of the
    first one is multiplied with the sigmoid of the second one.

    Parameters:
        in_dim: input channel size
        h_dim: output channel size
        kernel_size: kernel size along the last axis, forwarded as is to
            `conv_module` if it's given
        dilation_factor: dilation of both convolutions
        dropout: dropout ratio applied on the gated output, None
            disables dropout
        conv_module: customized convolution module, a plain `nn.Conv2d`
            with kernel (1, kernel_size) is used if None
    """

    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        kernel_size: Union[int, List[int]],
        dilation_factor: int,
        dropout: Optional[float] = None,
        conv_module: Optional[Type[nn.Module]] = None,
    ) -> None:
        super().__init__()

        # Resolve the convolution type and its kernel argument
        if conv_module is None:
            conv_cls, conv_kernel = nn.Conv2d, (1, kernel_size)
        else:
            conv_cls, conv_kernel = conv_module, kernel_size
        conv_kwargs = dict(
            in_channels=in_dim,
            out_channels=h_dim,
            kernel_size=conv_kernel,
            dilation=dilation_factor,
        )

        # Model blocks
        self.filt = conv_cls(**conv_kwargs)
        self.gate = conv_cls(**conv_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Parameters:
            x: input sequence

        Return:
            h: output sequence

        Shape:
            x: (B, C, N, L), where L denotes the input sequence length
            h: (B, h_dim, N, L')
        """
        h = F.tanh(self.filt(x)) * F.sigmoid(self.gate(x))
        if self.dropout is None:
            return h

        return self.dropout(h)


class _DilatedInception(nn.Module):
    """Dilated inception layer.

    One "same"-padded dilated convolution is built per kernel size, and
    the outputs of all convolutions are concatenated along the channel
    axis. Note that `out_channels` will be split across #kernels.

    Parameters:
        in_channels: input channel size
        out_channels: total output channel size over all kernels
        kernel_size: kernel sizes along the last axis
        dilation: dilation shared by all convolutions
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: List[int],
        dilation: int
    ) -> None:
        super().__init__()

        # Network parameters
        n_kernels = len(kernel_size)
        assert out_channels % n_kernels == 0, "`out_channels` must be divisible by #kernels."
        branch_dim = out_channels // n_kernels

        # Model blocks
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=branch_dim,
                    kernel_size=(1, k),
                    padding="same",
                    dilation=dilation,
                )
                for k in kernel_size
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Parameters:
            x: input sequence

        Return:
            h: output sequence

        Shape:
            x: (B, C, N, L), where C = in_channels
            h: (B, C', N, L'), where C' = out_channels
        """
        return torch.cat([conv(x) for conv in self.convs], dim=1)

In [None]:
class DilatedInceptionWaveNet(nn.Module):
    """WaveNet architecture with dilated inception conv.

    The same stack of wave blocks is applied to each of the 8 input
    channels separately. Pooled features of consecutive channel pairs
    are averaged, and the 4 resulting vectors are concatenated and fed
    into the output head.
    """

    def __init__(self,) -> None:
        super().__init__()

        kernel_size = [2, 3, 6, 7]
        # (n_layers, in_dim, h_dim) of each wave block
        block_specs = [(12, 1, 16), (8, 16, 32), (4, 32, 64), (1, 64, 64)]

        # Model blocks
        self.wave_module = nn.Sequential(
            *[
                _WaveBlock(n_layers, in_dim, h_dim, kernel_size, _DilatedInception)
                for n_layers, in_dim, h_dim in block_specs
            ]
        )
        self.output = nn.Sequential(
            nn.Linear(64 * 4, 64),
            nn.ReLU(),
            nn.Linear(64, N_CLASSES)
        )

    def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
        """Forward pass.

        Shape:
            x: (B, L, C)
        """
        x = inputs["x"]
        bs, _, _ = x.shape
        x = x.transpose(1, 2).unsqueeze(dim=2)  # (B, C, N, L), N is redundant

        # Channel pairs (0, 1), (2, 3), (4, 5) and (6, 7) are processed
        # one channel at a time, then averaged after global pooling
        pair_feats = []
        for ch in range(0, 8, 2):
            h_1 = self.wave_module(x[:, ch:ch + 1, :])
            h_2 = self.wave_module(x[:, ch + 1:ch + 2, :])
            pooled_1 = F.adaptive_avg_pool2d(h_1, (1, 1))
            pooled_2 = F.adaptive_avg_pool2d(h_2, (1, 1))
            pair_feats.append((pooled_1 + pooled_2) / 2)

        x = torch.cat(pair_feats, dim=1).reshape(bs, -1)

        return self.output(x)

<a id="load_models"></a>
## 5. Load Models
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

In [None]:
# Load one model per checkpoint found in the model directory.
# NOTE(review): `torch.load` unpickles the checkpoint files, so only
# checkpoints from a trusted source should be loaded here.
model_files = sorted(CFG.model_path.glob("./*.pth"))
if not model_files:
    # Fail early, otherwise the error only shows up at inference time
    raise FileNotFoundError(f"No model checkpoint (*.pth) found under {CFG.model_path}.")

models = []
for file in model_files:
    print(f"Load model from {file}...")
    fold_model = DilatedInceptionWaveNet()
    fold_model.load_state_dict(torch.load(file, map_location=CFG.device))
    fold_model = fold_model.to(CFG.device)
    fold_model.eval()
    models.append(fold_model)

<a id="infer"></a>
## 6. Run Inference
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

In [None]:
@torch.no_grad()
def _infer(inputs: Dict[str, Tensor], models: List[nn.Module]) -> Tensor:
    """Return the equally-weighted blend of class probabilities.

    Every model is switched to evaluation mode, its output is turned
    into probabilities with a softmax over the last axis, and the
    probabilities of all models are averaged.

    Args:
        inputs: model inputs, passed as is to each model
        models: models to blend

    Returns:
        y_pred: blended probabilities, with shape (B, N_CLASSES)

    Raises:
        ValueError: if `models` is empty
    """
    n_models = len(models)
    if n_models == 0:
        raise ValueError("`models` must contain at least one model.")

    y_pred = None
    for model in models:
        model.eval()
        # Name the class axis explicitly, the implicit choice of the
        # softmax dimension is deprecated
        y_pred_fold = F.softmax(model(inputs), dim=-1) / n_models    # (B, N_CLASSES)
        y_pred = y_pred_fold if y_pred is None else y_pred + y_pred_fold

    return y_pred

In [None]:
# Predict batch by batch, then stack the per-batch predictions
batch_preds = []
for batch_data in test_loader:
    batch_data["x"] = batch_data["x"].to(CFG.device)
    batch_pred = _infer(batch_data, models)
    batch_preds.append(batch_pred.detach().cpu().numpy())
y_preds = np.vstack(batch_preds)
print(f"Sum of row 0 in y_preds {np.sum(y_preds[0, :])}.")

<a id="sub"></a>
## 7. Submission
[**<span style="color:#FEF1FE; background-color:#535d70;border-radius: 5px; padding: 2px">Go to Table of Content</span>**](#toc)

In [None]:
# Assemble the submission: one row per test sample, one column per class
eeg_ids = test["eeg_id"].values
sub = pd.DataFrame({"eeg_id": eeg_ids})
sub[TGT_VOTE_COLS] = y_preds
sub.to_csv("submission.csv", index=False)

print("===== Submission Demo =====")
sub.head()