In [None]:
import ast

import numpy as np
import pandas as pd
import wfdb


def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(path + "/" + f) for f in df.filename_lr[:1000]]
    else:
        data = [wfdb.rdsamp(path + "/" + f) for f in df.filename_hr[:1000]]
    data = np.array([signal for signal, meta in data])
    return data


path = "/Users/shadyali/Self-Supervied-Contrastive-Representation-Learning-ECG-Signals/data/raw/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1"
sampling_rate = 500

# load and convert annotation data
Y = pd.read_csv(path + "/ptbxl_database.csv", index_col="ecg_id")[:1000]
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load raw signal data
X = load_raw_data(Y, sampling_rate, path)

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(path + "/scp_statements.csv", index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]


def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))


# Apply diagnostic superclass
Y["diagnostic_superclass"] = Y.scp_codes.apply(aggregate_diagnostic)

# Split data into train and test
test_fold = 10
# Train
X_train = X[np.where(Y.strat_fold != test_fold)]
y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass
# Test
X_test = X[np.where(Y.strat_fold == test_fold)]
y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass


In [8]:
X.shape, Y.shape

((1000, 5000, 12), (1000, 28))

## Model's Architecture

In [None]:
import math
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class ECGModelConfig:
    sequence_length: int = 5000
    num_channels: int = 12
    d_model: int = 96
    time_heads: int = 2
    channel_heads: int = 2
    time_layers: int = 1
    channel_layers: int = 1
    ff_multiplier: int = 2
    dropout: float = 0.2
    temperature: float = 0.1
    projection_dim: int = 256
    dtype: torch.dtype = torch.bfloat16


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int, dtype: torch.dtype) -> None:
        super().__init__()
        position = torch.arange(0, max_len, dtype=dtype).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=dtype) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model, dtype=dtype)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, : x.size(1)]


class TimeTransformer(nn.Module):
    def __init__(self, config: ECGModelConfig) -> None:
        super().__init__()
        self.input_proj = nn.Linear(
            config.num_channels, config.d_model, dtype=config.dtype
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.time_heads,
            dim_feedforward=config.d_model * config.ff_multiplier,
            dropout=config.dropout,
            batch_first=True,
            activation="relu",
            dtype=config.dtype,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.time_layers
        )
        self.positional_encoding = SinusoidalPositionalEncoding(
            config.d_model, config.sequence_length, config.dtype
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype != self.input_proj.weight.dtype:
            x = x.to(self.input_proj.weight.dtype)
        x = self.input_proj(x)
        x = self.positional_encoding(x)
        return self.encoder(x)


class ChannelTransformer(nn.Module):
    def __init__(self, config: ECGModelConfig) -> None:
        super().__init__()
        self.channel_proj = nn.Conv1d(
            in_channels=config.num_channels,
            out_channels=config.num_channels * config.d_model,
            kernel_size=1,
            groups=config.num_channels,
            dtype=config.dtype,
        )
        self.pool = nn.AdaptiveAvgPool1d(1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.channel_heads,
            dim_feedforward=config.d_model * config.ff_multiplier,
            dropout=config.dropout,
            batch_first=True,
            activation="relu",
            dtype=config.dtype,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.channel_layers
        )
        self.num_channels = config.num_channels
        self.d_model = config.d_model
        self.dtype = config.dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype != self.dtype:
            x = x.to(self.dtype)
        
        batch_size = x.size(0)
        x = x.permute(0, 2, 1)  # (batch, channels, time)
        x = self.channel_proj(x)  # (batch, channels * d_model, time)
        x = x.view(batch_size, self.num_channels, self.d_model, -1)  # (batch, channels, d_model, time)
        x = self.pool(x.flatten(1, 2)).view(batch_size, self.num_channels, self.d_model)  # (batch, channels, d_model)
        
        return self.encoder(x)


class BidirectionalCrossAttention(nn.Module):
    def __init__(self, config: ECGModelConfig) -> None:
        super().__init__()
        self.time_to_channel = nn.MultiheadAttention(
            config.d_model, config.time_heads, batch_first=True, dtype=config.dtype
        )
        self.channel_to_time = nn.MultiheadAttention(
            config.d_model, config.channel_heads, batch_first=True, dtype=config.dtype
        )
        self.time_norm = nn.LayerNorm(config.d_model, dtype=config.dtype)
        self.channel_norm = nn.LayerNorm(config.d_model, dtype=config.dtype)

    def forward(
        self, time_tokens: torch.Tensor, channel_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        time_cross, _ = self.time_to_channel(
            time_tokens, channel_tokens, channel_tokens
        )
        fused_time = self.time_norm(time_tokens + time_cross)
        channel_cross, _ = self.channel_to_time(
            channel_tokens, time_tokens, time_tokens
        )
        fused_channel = self.channel_norm(channel_tokens + channel_cross)
        return fused_time, fused_channel


class FusionHead(nn.Module):
    def __init__(self, config: ECGModelConfig) -> None:
        super().__init__()
        self.linear = nn.Linear(config.d_model * 2, config.d_model, dtype=config.dtype)
        self.norm = nn.LayerNorm(config.d_model, dtype=config.dtype)
        self.activation = nn.ReLU()

    def forward(
        self, time_repr: torch.Tensor, channel_repr: torch.Tensor
    ) -> torch.Tensor:
        fused = torch.cat([time_repr, channel_repr], dim=-1)
        mapped = self.linear(fused)
        return self.activation(self.norm(mapped))


class ProjectionHead(nn.Module):
    def __init__(self, config: ECGModelConfig) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.d_model, config.d_model, dtype=config.dtype),
            nn.ReLU(),
            nn.Linear(config.d_model, config.projection_dim, dtype=config.dtype),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.net(x), dim=-1)


class ECGEncoder(nn.Module):
    def __init__(self, config: ECGModelConfig) -> None:
        super().__init__()
        self.time_encoder = TimeTransformer(config)
        self.channel_encoder = ChannelTransformer(config)
        self.cross_attention = BidirectionalCrossAttention(config)
        self.fusion = FusionHead(config)
        self.projection = ProjectionHead(config)
        self.dtype = config.dtype
        self.to(dtype=config.dtype)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if x.dtype != self.dtype:
            x = x.to(self.dtype)
        time_tokens = self.time_encoder(x)
        channel_tokens = self.channel_encoder(x)
        fused_time, fused_channel = self.cross_attention(time_tokens, channel_tokens)
        time_repr = fused_time.mean(dim=1)
        channel_repr = fused_channel.mean(dim=1)
        representation = self.fusion(time_repr, channel_repr)
        projection = self.projection(representation)
        return representation, projection

In [None]:
def info_nce_loss(z: torch.Tensor, temperature: float) -> torch.Tensor:
    num_views, batch_size, _ = z.shape
    embeddings = F.normalize(z, dim=-1).view(num_views * batch_size, -1)
    similarity = torch.matmul(embeddings, embeddings.T) / temperature
    identity_mask = torch.eye(
        similarity.size(0), device=similarity.device, dtype=torch.bool
    )
    similarity = similarity.masked_fill(identity_mask, -torch.inf)
    labels = torch.arange(batch_size, device=similarity.device).repeat(num_views)
    positives_mask = labels.unsqueeze(0) == labels.unsqueeze(1)
    positives_mask = positives_mask & (~identity_mask)
    exp_similarity = torch.exp(similarity)
    positives_sum = (exp_similarity * positives_mask.float()).sum(dim=1)
    denominator = exp_similarity.sum(dim=1)
    loss = -torch.log(positives_sum / (denominator + 1e-8) + 1e-8)
    return loss.mean()

In [17]:
from typing import Tuple

def summarize_encoder_params(model) -> Tuple[int, int]:
    """
    Print a concise summary of the encoder's parameters and return (total, trainable).
    Uses existing `ecg_encoder`, `encoder_config`, `total_params`, `trainable_params` if available.
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Model: {model.__class__.__name__}")
    print(f"Total params:     {total:,}")
    print(f"Trainable params: {trainable:,}\n")

    print("Per top-level module parameter counts:")
    for name, module in model.named_children():
        mod_total = sum(p.numel() for p in module.parameters())
        mod_train = sum(p.numel() for p in module.parameters() if p.requires_grad)
        print(f"  {name:20s} | total: {mod_total:10,} | trainable: {mod_train:10,}")

    print("\nParameter shapes (name, shape, trainable):")
    for name, p in model.named_parameters():
        print(f"  {name:60s} {tuple(p.shape):20s} {'trainable' if p.requires_grad else 'frozen'}")

    return total, trainable

ecg_encoder = ECGEncoder(ECGModelConfig())

# Use the existing ecg_encoder variable in the notebook
total, trainable = summarize_encoder_params(ecg_encoder)

# If encoder_config and provided totals exist, print a quick consistency check
try:
    print("\nConsistency check with precomputed values (if present):")
    print(f"encoder_config: {encoder_config}")
    print(f"precomputed total_params:     {total_params:,}")
    print(f"precomputed trainable_params: {trainable_params:,}")
except NameError:
    pass

Model: ECGEncoder
Total params:     278,752
Trainable params: 278,752

Per top-level module parameter counts:
  time_encoder         | total:     76,032 | trainable:     76,032
  channel_encoder      | total:     74,976 | trainable:     74,976
  cross_attention      | total:     74,880 | trainable:     74,880
  fusion               | total:     18,720 | trainable:     18,720
  projection           | total:     34,144 | trainable:     34,144

Parameter shapes (name, shape, trainable):


TypeError: unsupported format string passed to tuple.__format__