In [1]:
from functools import partial
from typing import Any, Callable, List, Optional, Sequence

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor, nn
from torch.nn import functional as F
from torchvision.ops.stochastic_depth import StochasticDepth


class CNBlockConfig1D:
    """Configuration of one ConvNeXt1D stage.

    Attributes:
        input_channels: Channel width of the blocks inside the stage.
        out_channels: Channel width produced by the downsampling layer that
            follows the stage, or ``None`` when the stage has no downsampling layer.
        num_layers: Number of blocks stacked in the stage.
    """

    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels, self.out_channels, self.num_layers = (
            input_channels,
            out_channels,
            num_layers,
        )

    def __repr__(self) -> str:
        # Render as ClassName(field=value, ...) in declaration order.
        fields = ", ".join(
            f"{key}={getattr(self, key)}"
            for key in ("input_channels", "out_channels", "num_layers")
        )
        return f"{type(self).__name__}({fields})"


class LayerNorm1D(nn.LayerNorm):
    """
    Layer Normalization over the channel axis of channels-first 1D inputs.

    nn.LayerNorm normalizes the trailing dimension, so the (N, C, S) input is rearranged to
    (N, S, C), normalized, and rearranged back to (N, C, S) before it is returned.
    """

    def forward(self, x: Tensor) -> Tensor:
        channels_last = rearrange(x, "N C S -> N S C")
        normalized = F.layer_norm(
            channels_last,
            normalized_shape=self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return rearrange(normalized, "N S C -> N C S")


class ResNetBlock1D(nn.Module):
    """
    Residual block for 1D feature maps: Conv1d -> BatchNorm -> ReLU -> Conv1d -> BatchNorm,
    followed by the skip connection and a final ReLU.

    Args:
        dim: Number of input and output channels.
        hidden_dim: Channel count between the two convolutions; falls back to ``dim`` when
            it is omitted (or otherwise falsy).
        downsample: When true, the first convolution and the skip path both use stride 2.
    """

    def __init__(self, dim, hidden_dim=None, downsample=False):
        super().__init__()
        if not hidden_dim:
            hidden_dim = dim
        if downsample:
            stride = 2
        else:
            stride = 1

        self.conv1 = nn.Conv1d(
            in_channels=dim,
            out_channels=hidden_dim,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=dim,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(dim)

        # A strided 1x1 convolution keeps the skip path the same length as the main path.
        if downsample:
            self.downsample = nn.Conv1d(
                dim, dim, kernel_size=1, stride=stride, bias=False
            )
        else:
            self.downsample = None

    def forward(self, x: Tensor) -> Tensor:
        # Main branch.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip branch: the identity, or its strided projection when downsampling.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class CNBlock1D(nn.Module):
    """
    A ConvNeXt residual block for 1D inputs with depthwise convolution, LayerNorm, LayerScale,
    GELU activation, inverted bottleneck, and stochastic depth regularization.

    The block consumes and returns channels-first tensors of shape (N, C, S) with C == dim.

    Args:
        dim: Number of channels of the input (and output).
        layer_scale: Initial value of every entry of the learnable per-channel scale.
        stochastic_depth_prob: Drop probability handed to StochasticDepth (mode "row").
        norm_layer: Factory called as ``norm_layer(dim)``; it is applied to the
            channels-last tensor. Defaults to nn.LayerNorm with eps=1e-6.
        bottleneck_inversion_factor: Width multiplier of the hidden Linear layer.
    """

    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        bottleneck_inversion_factor: int = 4,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        inverted_bottleneck_dim = dim * bottleneck_inversion_factor

        self.block = nn.Sequential(
            # Depthwise convolution (groups=dim) on the channels-first tensor.
            nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            # Move channels last so the norm and the Linear layers act on C.
            # The axis names now match the real layout; the permutation is unchanged.
            Rearrange("N C S -> N S C"),
            norm_layer(dim),
            nn.Linear(
                in_features=dim,
                out_features=inverted_bottleneck_dim,
                bias=True,
            ),
            nn.GELU(),
            nn.Linear(
                in_features=inverted_bottleneck_dim,
                out_features=dim,
                bias=True,
            ),
            # Back to channels-first for the residual addition.
            Rearrange("N S C -> N C S"),
        )
        # Shape (dim, 1) broadcasts over the batch and sequence axes of (N, C, S).
        self.layer_scale = nn.Parameter(torch.ones(dim, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input + stochastic_depth(layer_scale * block(input))``."""
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result


class ConvNeXt1D(nn.Module):
    """ConvNeXt1D model for ECG data classification."""

    def __init__(
        self,
        block_setting: List[CNBlockConfig1D],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        channels=12,
        num_classes: int = 5,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Args:
            block_setting: List of CNBlockConfig1D for each stage.
            stochastic_depth_prob: Probability of dropping out a block. The probability is
                linearly increased from 0 for the first block to this value for the last.
            layer_scale: Layer scale for LayerScale module.
            channels: Number of input channels.
            num_classes: Number of output classes.
            block: Block module to use. Defaults to CNBlock1D.
            norm_layer: Normalization layer used in the classifier head. Defaults to
                LayerNorm1D. The downsampling layers always use LayerNorm1D.

        Raises:
            ValueError: If block_setting is empty.
            TypeError: If block_setting is not a sequence of CNBlockConfig1D.
        """
        super().__init__()

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (
            isinstance(block_setting, Sequence)
            and all(isinstance(s, CNBlockConfig1D) for s in block_setting)
        ):
            raise TypeError("The block_setting should be List[CNBlockConfig1D]")

        if block is None:
            block = CNBlock1D

        if norm_layer is None:
            norm_layer = partial(LayerNorm1D, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem: non-overlapping patches of 4 samples.
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            nn.Conv1d(
                channels,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        # With a single block there is no depth range to interpolate over. Clamping the
        # denominator gives that block probability 0 instead of raising ZeroDivisionError.
        sd_denominator = max(total_stage_blocks - 1.0, 1.0)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling: normalize, then halve the sequence length.
                layers.append(
                    nn.Sequential(
                        LayerNorm1D(cnf.input_channels, eps=1e-6),
                        nn.Conv1d(
                            cnf.input_channels,
                            cnf.out_channels,
                            kernel_size=2,
                            stride=2,
                        ),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        # The head width is whatever the final stage (or its downsampling layer) emits.
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels
            if lastblock.out_channels is not None
            else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels),
            nn.Flatten(1),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        # Truncated-normal weights and zero biases for every Conv1d / Linear layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Map a channels-first batch (N, channels, S) to logits of shape (N, num_classes)."""
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x


def convnext1d_tiny(**kwargs: Any) -> ConvNeXt1D:
    """Build the "tiny" ConvNeXt1D: three stages of 3 blocks with widths 24, 48 and 96.

    Keyword arguments are forwarded unchanged to the ConvNeXt1D constructor.
    """
    # (stage width, width after the following downsampling layer or None)
    stage_widths = [(24, 48), (48, 96), (96, None)]
    block_setting = [
        CNBlockConfig1D(input_channels=width, out_channels=next_width, num_layers=3)
        for width, next_width in stage_widths
    ]
    return ConvNeXt1D(block_setting, **kwargs)


def convnext1d_small(**kwargs: Any) -> ConvNeXt1D:
    """Build the "small" ConvNeXt1D: four stages of 3 blocks with widths 64, 96, 128 and 256.

    Keyword arguments are forwarded unchanged to the ConvNeXt1D constructor.
    """
    # (stage width, width after the following downsampling layer or None)
    stage_widths = [(64, 96), (96, 128), (128, 256), (256, None)]
    block_setting = [
        CNBlockConfig1D(input_channels=width, out_channels=next_width, num_layers=3)
        for width, next_width in stage_widths
    ]
    return ConvNeXt1D(block_setting, **kwargs)


def convnext1d_large(**kwargs: Any) -> ConvNeXt1D:
    """Build the "large" ConvNeXt1D: four stages of 3 blocks with widths 96, 128, 256 and 512.

    Keyword arguments are forwarded unchanged to the ConvNeXt1D constructor.
    """
    # (stage width, width after the following downsampling layer or None)
    stage_widths = [(96, 128), (128, 256), (256, 512), (512, None)]
    block_setting = [
        CNBlockConfig1D(input_channels=width, out_channels=next_width, num_layers=3)
        for width, next_width in stage_widths
    ]
    return ConvNeXt1D(block_setting, **kwargs)


def get_convnext(name: str, **kwargs: Any) -> ConvNeXt1D:
    """Build a ConvNeXt1D variant by name.

    Args:
        name: One of "tiny", "small" or "large".
        **kwargs: Forwarded to the selected builder.

    Raises:
        ValueError: If ``name`` is not a known variant.
    """
    # Pick the builder first, then call it in a single place.
    if name == "tiny":
        builder = convnext1d_tiny
    elif name == "small":
        builder = convnext1d_small
    elif name == "large":
        builder = convnext1d_large
    else:
        raise ValueError(f"Unknown ConvNeXt model name: {name}")
    return builder(**kwargs)


In [2]:
import base64
import pickle
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

def read_face_meta(path, max_categories=100):
    """Load the pickled metadata table and convert low-cardinality columns to categoricals.

    Args:
        path: Location of the pickle file holding the metadata DataFrame.
        max_categories: Columns with fewer than this many distinct values are cast to
            the pandas "category" dtype.

    Returns:
        The unpickled DataFrame, with the qualifying columns converted.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only open trusted files.
    with open(path, mode="rb") as file:
        meta = pickle.load(file)

    # Nothing to inspect in an empty table; looking up a first row would fail.
    if len(meta) == 0:
        return meta

    for col in meta.columns:
        # Look at the first row by position, so tables whose index does not contain
        # the label 0 (filtered or re-indexed frames) are handled as well.
        # Columns holding lists are left untouched.
        if isinstance(meta[col].iloc[0], list):
            continue

        if len(meta[col].unique()) < max_categories:
            meta[col] = meta[col].astype("category")

    return meta


def load_ecg(path, encoding, strip_index, n_channels=2):
    """Read one strip from the "ecg_strips" dataset of an HDF5 file as a float32 array.

    Args:
        path: HDF5 file to open.
        encoding: "numpy" or "binary"; selects how the stored strip is decoded.
        strip_index: Index of the strip inside the "ecg_strips" dataset.
        n_channels: Number of rows the decoded buffer is reshaped to ("binary" only).

    Returns:
        For "numpy": the strip cast to float32 with its first axis moved to the end.
        For "binary": the base64-decoded float32 buffer reshaped to (n_channels, -1),
        returned as a copy.
    """
    assert encoding in ["binary", "numpy"]

    with h5py.File(path, "r") as h5:
        strip = h5["ecg_strips"][strip_index]

        if encoding != "numpy":
            decoded = base64.b64decode(strip)
            signal = np.frombuffer(decoded, dtype=np.float32)
            return signal.reshape(n_channels, -1).copy()

        as_float = strip.astype(np.float32)
        return np.moveaxis(as_float, 0, -1)

# Dataset root directory.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
ROOT = Path("/sc-scratch/sc-scratch-gbm-radiomics/ecg/face")
# Sub-directory of ROOT named "files" (presumably the HDF5 strip files - confirm).
FILES = ROOT / "files"
# Integer class id -> label name (presumably rhythm classes - confirm).
LABEL_MAPPING = {0: "SR", 1: "AFIB", 2: "OTHER", 3: "NOISE"}

In [3]:
# Load the merged metadata table stored under the dataset root.
meta = read_face_meta(ROOT / "merged_meta.pkl")

  meta = pickle.load(file)


In [4]:
# train_meta = meta.loc[meta.use_for_train == True].copy()
# train_meta = train_meta[~train_meta.label.isna()]

# valid_meta = meta.loc[meta.use_for_train == False].copy()
# valid_meta = valid_meta[~valid_meta.label.isna()]

# print(f"Train examples: {len(train_meta):,}\nValid examples: {len(valid_meta):,}")
# del train_meta, valid_meta

In [5]:
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from torch.nn.functional import pad
import random

def random_crop(x: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Return a randomly positioned window of ``seq_len`` samples along the last axis."""
    # Number of valid start positions; the last valid start is size - seq_len.
    n_positions = x.size(-1) - seq_len + 1
    begin = torch.randint(low=0, high=n_positions, size=(1,)).item()
    window = slice(begin, begin + seq_len)
    return x[..., window]


class CropOrPad:
    """Transform that forces the last axis of a signal to exactly ``seq_len`` samples.

    Longer signals are randomly cropped. Shorter signals are padded on both sides using
    ``pad_mode``, with the extra sample going to the right when the deficit is odd.
    Signals that already have the target length are returned unchanged.
    """

    def __init__(self, seq_len=1_000, pad_mode="constant"):
        self.seq_len = seq_len
        self.pad_mode = pad_mode

    def __call__(self, x):
        length = x.size(-1)
        if length == self.seq_len:
            return x
        if length > self.seq_len:
            return random_crop(x, self.seq_len)

        # Split the deficit as evenly as possible between the two ends.
        missing = self.seq_len - length
        left_pad = missing // 2
        right_pad = missing - left_pad
        return pad(x, (left_pad, right_pad), mode=self.pad_mode)

class Dataset:
    """Map-style dataset yielding ``(ecg, label)`` pairs described by a metadata table.

    Args:
        root: Dataset root; strips are read from ``<root>/files/<name>.hdf5``.
        meta: Table with one row per example. Rows are read positionally and must
            provide ``label``, ``hdf5-type``, ``strip_index`` and
            ``extracted_strips_filename``.
        transform: Optional callable applied to the ECG tensor.
    """

    def __init__(self, root, meta, transform=None):
        self.root = Path(root)
        self.file_path = self.root / "files"
        self.meta = meta
        self.transform = transform

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        row = self.meta.iloc[index]
        label = row.label.item()

        # Locate and decode the strip referenced by this metadata row.
        hdf5_path = self.file_path / f"{row.extracted_strips_filename}.hdf5"
        signal = load_ecg(hdf5_path, row["hdf5-type"], int(row.strip_index))
        ecg = torch.tensor(signal)

        if self.transform:
            ecg = self.transform(ecg)

        return ecg, label

In [6]:
# Keep only the rows that have a label.
meta_nona = meta[~meta.label.isna()]

In [12]:
# Tiny smoke-test split: the first 10 labelled rows for training, the next 10 for validation.
train_meta = meta_nona.iloc[np.arange(10)]
valid_meta = meta_nona.iloc[np.arange(10, 20)]

# Both datasets crop or pad every strip to CropOrPad's default length.
train_dataset = Dataset(ROOT, train_meta, transform=CropOrPad())
valid_dataset = Dataset(ROOT, valid_meta, transform=CropOrPad())

# The validation loader keeps DataLoader's default batch size.
train_loader = DataLoader(train_dataset, batch_size=4)
valid_loader = DataLoader(valid_dataset)

In [13]:
# Pull a single batch from the training loader to sanity-check the pipeline.
ecg, label = next(iter(train_loader))

In [14]:
# Shape of the batched ECG tensor.
ecg.shape

torch.Size([4, 2, 1000])

In [16]:
# Labels of the same batch.
label

tensor([1, 1, 1, 1])

In [9]:
def count_params(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [10]:
# Report the trainable-parameter count of each variant, built for 2 input channels
# and 4 output classes.
for m in ["tiny", "small", "large"]:
    model = get_convnext(m, channels=2, num_classes=4)
    params = count_params(model)
    print(m, f"{params:,}")

tiny 310,972
small 2,417,252
large 8,884,388
