# KOCH PAIRED

In [3]:
import os
from glob import glob

import numpy as np
from scipy.signal import butter, filtfilt, find_peaks

import torch
from torch.utils.data import Dataset


class KochPairedBeatsDataset(Dataset):
    """Paired ECG / MCG beats loaded from an ``.npz`` archive.

    The archive must contain the keys ``ecg_beats``, ``mcg_beats`` and ``fs``.
    Item ``idx`` is the tuple ``(ecg_beats[idx], mcg_beats[idx])`` returned as
    float32 torch tensors.
    """

    def __init__(self, npz_path="koch_pairs.npz", augment=False):
        """
        npz_path: path to the koch_pairs.npz file
        augment: if True, apply very light noise / jitter augmentations
        """
        # NOTE(review): allow_pickle=True lets a crafted archive run arbitrary
        # code while loading. It is kept for backward compatibility; drop it
        # if the archive is known to hold only plain numeric arrays.
        # The context manager closes the underlying file handle once the
        # arrays have been read into memory.
        with np.load(npz_path, allow_pickle=True) as data:
            # Cast once here so every item is float32 (no copy is made when
            # the stored arrays are already float32).
            self.ecg_beats = np.asarray(data["ecg_beats"], dtype=np.float32)  # (N, C_ecg, T)
            self.mcg_beats = np.asarray(data["mcg_beats"], dtype=np.float32)  # (N, C_mcg, T)
            # ravel() accepts both a scalar ``fs`` and a 1-element array.
            self.fs = int(np.ravel(data["fs"])[0])
        self.augment = augment

    def __len__(self):
        """Number of beats (length of axis 0 of ``ecg_beats``)."""
        return self.ecg_beats.shape[0]

    def _augment(self, x):
        """
        x: numpy array (C, T)
        very light augmentations: small Gaussian noise and tiny time shift

        Returns a new array; the stored beat is not modified.
        """
        # small noise
        x = x + 0.01 * np.random.randn(*x.shape).astype(np.float32)

        # tiny circular time shift up to ±20 samples
        max_shift = 20
        shift = np.random.randint(-max_shift, max_shift + 1)
        if shift != 0:
            x = np.roll(x, shift, axis=1)

        return x

    def __getitem__(self, idx):
        """Return ``(ecg, mcg)`` for beat ``idx`` as float32 tensors."""
        ecg = self.ecg_beats[idx]  # (C_ecg, T)
        mcg = self.mcg_beats[idx]  # (C_mcg, T)

        if self.augment:
            # NOTE(review): each modality draws its own random shift, so a
            # pair can end up offset by up to 40 samples relative to each
            # other -- confirm this is intended for the paired objective.
            ecg = self._augment(ecg)
            mcg = self._augment(mcg)

        # convert to torch tensors
        ecg = torch.from_numpy(ecg)  # float32, shape (C_ecg, T)
        mcg = torch.from_numpy(mcg)  # float32, shape (C_mcg, T)

        return ecg, mcg


#from koch_dataset import KochPairedBeatsDataset
from torch.utils.data import DataLoader

# Smoke test: load the paired dataset without augmentation and look at one batch.
ds = KochPairedBeatsDataset("koch_pairs.npz", augment=False)
print("Num beats:", len(ds))

# shuffle=True, so the batch drawn below differs from run to run.
loader = DataLoader(ds, batch_size=8, shuffle=True)
ecg, mcg = next(iter(loader))
print("ECG batch shape:", ecg.shape)  # expect: torch.Size([8, 32, 2000])
print("MCG batch shape:", mcg.shape)  # expect: torch.Size([8, 100, 2000])

Num beats: 127
ECG batch shape: torch.Size([8, 32, 2000])
MCG batch shape: torch.Size([8, 100, 2000])


# PTB Beats Dataset

In [4]:
import numpy as np
import torch
from torch.utils.data import Dataset


class PTBBeatsDataset(Dataset):
    """Single-modality beats loaded from an ``.npz`` archive.

    The archive must contain the keys ``beats`` and ``fs``. Item ``idx`` is
    ``beats[idx]`` returned as a float32 torch tensor.
    """

    def __init__(self, npz_path="ptb_beats.npz", augment=True):
        """
        npz_path: path to the ptb_beats.npz file
        augment: if True, add light noise and a small circular time shift
        """
        # NOTE(review): allow_pickle=True lets a crafted archive run arbitrary
        # code while loading. It is kept for backward compatibility; drop it
        # if the archive is known to hold only plain numeric arrays.
        # The context manager closes the underlying file handle once the
        # arrays have been read into memory.
        with np.load(npz_path, allow_pickle=True) as data:
            # Cast once here so every item is float32 (no copy is made when
            # the stored array is already float32).
            self.beats = np.asarray(data["beats"], dtype=np.float32)  # (N, 12, 2000)
            # ravel() accepts both a scalar ``fs`` and a 1-element array.
            self.fs = int(np.ravel(data["fs"])[0])
        self.augment = augment

    def __len__(self):
        """Number of beats (length of axis 0 of ``beats``)."""
        return self.beats.shape[0]

    def _augment(self, x):
        """Return a noisy, circularly shifted copy of ``x``; input is not modified."""
        # x: (12, T)
        x = x + 0.01 * np.random.randn(*x.shape).astype(np.float32)  # noise
        # circular shift along axis 1 by a random offset in [-20, 20]
        max_shift = 20
        shift = np.random.randint(-max_shift, max_shift + 1)
        if shift != 0:
            x = np.roll(x, shift, axis=1)
        return x

    def __getitem__(self, idx):
        """Return beat ``idx`` (augmented when ``self.augment``) as a float32 tensor."""
        beat = self.beats[idx]
        if self.augment:
            beat = self._augment(beat)
        return torch.from_numpy(beat)



from torch.utils.data import DataLoader

# Smoke test: load the PTB beats with augmentation enabled and draw one shuffled batch.
ds_ptb = PTBBeatsDataset("ptb_beats.npz", augment=True)
loader_ptb = DataLoader(ds_ptb, batch_size=32, shuffle=True)

# Only the first batch is pulled; its shape is printed as a sanity check.
batch = next(iter(loader_ptb))
print(batch.shape)  # torch.Size([32, 12, 2000])


torch.Size([32, 12, 2000])


# Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class DepthwiseSeparableConv1d(nn.Module):
    """Depthwise 1-D convolution followed by a 1x1 pointwise convolution.

    in_ch / out_ch: channel counts before and after the pointwise step
    kernel_size, stride, padding: forwarded to the depthwise convolution only
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups == in_ch: every input channel is filtered by its own kernel.
        self.depthwise = nn.Conv1d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,
        )
        # The 1x1 convolution mixes channels and maps in_ch -> out_ch.
        self.pointwise = nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
        )

    def forward(self, x):
        """Apply the depthwise filter, then the pointwise channel mix."""
        return self.pointwise(self.depthwise(x))


class MultiScaleBlock(nn.Module):
    """
    Multi-scale depthwise-separable residual block.
    Input: (B, C, T) -> Output: (B, C, T)

    channels: number of input (and output) channels
    kernels: odd, positive kernel sizes, one parallel branch per size

    Raises ValueError if ``kernels`` is empty or contains a non-positive or
    even size.
    """
    def __init__(self, channels, kernels=(5, 9, 17)):
        super().__init__()
        kernels = tuple(kernels)
        # forward() divides by the number of branches, so at least one is needed.
        if not kernels:
            raise ValueError("kernels must contain at least one kernel size")
        # With padding k // 2 only odd k keeps the length at T; an even k
        # yields T + 1 and the residual add in forward() would fail.
        invalid = [k for k in kernels if k < 1 or k % 2 == 0]
        if invalid:
            raise ValueError(
                f"kernel sizes must be odd positive integers, got {invalid}"
            )

        self.branches = nn.ModuleList(
            [
                DepthwiseSeparableConv1d(channels, channels, kernel_size=k, padding=k // 2)
                for k in kernels
            ]
        )
        self.bn = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()

    def forward(self, x):
        """Average the branch outputs, apply BN + ReLU, then add the input back."""
        # x: (B, C, T)
        out = sum(conv(x) for conv in self.branches) / len(self.branches)  # average branches
        out = self.bn(out)
        out = self.act(out)
        return x + out  # residual


class SMEEBackbone(nn.Module):
    """
    Shared Multi-Scale Efficient Encoder backbone.

    A stack of ``n_blocks`` MultiScaleBlock modules, then global average
    pooling over time and a linear map to ``feat_dim``.
    Input: (B, C_bottleneck, T) -> Output: (B, feat_dim)
    """
    def __init__(self, bottleneck_channels=32, n_blocks=3, feat_dim=256):
        super().__init__()
        stack = [
            MultiScaleBlock(bottleneck_channels, kernels=(5, 9, 17))
            for _ in range(n_blocks)
        ]
        self.blocks = nn.Sequential(*stack)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(bottleneck_channels, feat_dim)

    def forward(self, x):
        """Run the block stack, pool over the time axis, and project."""
        features = self.blocks(x)                         # (B, C_bottleneck, T)
        pooled = self.global_pool(features).squeeze(-1)   # (B, C_bottleneck)
        return self.fc(pooled)                            # (B, feat_dim)


class ECGEncoderSMEE(nn.Module):
    """
    ECG encoder using SMEE backbone.

    A 1x1 convolution maps ``in_channels`` to ``bottleneck_channels``; the
    result is fed to a 3-block SMEEBackbone that outputs ``feat_dim`` features.
    """
    def __init__(self, in_channels=32, bottleneck_channels=32, feat_dim=256):
        super().__init__()
        # Per-timestep channel projection (kernel_size=1).
        self.input_proj = nn.Conv1d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
        )
        self.backbone = SMEEBackbone(
            bottleneck_channels=bottleneck_channels,
            n_blocks=3,
            feat_dim=feat_dim,
        )

    def forward(self, x):
        """Project the input channels, then encode with the backbone."""
        return self.backbone(self.input_proj(x))


class MCGEncoderSMEE(nn.Module):
    """
    MCG encoder using SMEE backbone. Can share backbone weights with ECG encoder if desired.

    When ``shared_backbone`` is given, that module instance is used directly
    (so its weights are shared); otherwise a new 3-block SMEEBackbone is built.
    """
    def __init__(
        self,
        in_channels: int = 100,
        bottleneck_channels: int = 32,
        feat_dim: int = 256,
        shared_backbone: Optional[SMEEBackbone] = None,
    ):
        super().__init__()
        # Per-timestep channel projection (kernel_size=1).
        self.input_proj = nn.Conv1d(in_channels, bottleneck_channels, kernel_size=1)
        self.backbone = (
            shared_backbone  # weight sharing
            if shared_backbone is not None
            else SMEEBackbone(
                bottleneck_channels=bottleneck_channels,
                n_blocks=3,
                feat_dim=feat_dim,
            )
        )

    def forward(self, x):
        """Project the input channels, then encode with the backbone."""
        return self.backbone(self.input_proj(x))




In [15]:
#from models import ECGEncoderSMEE, MCGEncoderSMEE  # or whatever file name you used
import torch

def count_params(model):
    """Return the total element count of ``model``'s parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build both encoders with default bottleneck / feature sizes (separate backbones).
ecg_smee = ECGEncoderSMEE(in_channels=32)
mcg_smee = MCGEncoderSMEE(in_channels=100)

# Trainable-parameter counts for each encoder.
print("ECG SMEE params:", count_params(ecg_smee))
print("MCG SMEE params:", count_params(mcg_smee))

# quick forward test
# Random inputs: batch of 4, 32 / 100 channels, 2000 time steps.
x_ecg = torch.randn(4, 32, 2000)
x_mcg = torch.randn(4, 100, 2000)
h_e = ecg_smee(x_ecg)
h_m = mcg_smee(x_mcg)
print("Shapes:", h_e.shape, h_m.shape)  # should be: torch.Size([4, 256]) torch.Size([4, 256])


ECG SMEE params: 22464
MCG SMEE params: 24640
Shapes: torch.Size([4, 256]) torch.Size([4, 256])


# Projection Head and Cross-Modal InfoNCE Loss

In [19]:


class ProjectionHead(nn.Module):
    """
    2-layer MLP projection head.

    Linear(in_dim -> in_dim), ReLU, Linear(in_dim -> proj_dim), then L2
    normalization along the last dimension.
    Input: (B, feat_dim)
    Output: (B, proj_dim) normalized
    """
    def __init__(self, in_dim=256, proj_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_dim, out_features=in_dim)
        self.fc2 = nn.Linear(in_features=in_dim, out_features=proj_dim)

    def forward(self, x):
        """Return the unit-norm projection of ``x``."""
        hidden = F.relu(self.fc1(x))
        projected = self.fc2(hidden)
        return F.normalize(projected, p=2, dim=-1)



def cross_modal_info_nce(z_e, z_m, temperature=0.1):
    """
    Symmetric cross-modal InfoNCE loss between ECG and MCG embeddings.

    z_e: (B, D) ECG embeddings (L2-normalized)
    z_m: (B, D) MCG embeddings (L2-normalized)
    temperature: divisor applied to the similarity matrix

    Row i of ``z_e`` and row i of ``z_m`` form the positive pair; every other
    row in the batch acts as a negative. Returns the mean of the ECG→MCG and
    MCG→ECG cross-entropy terms as a scalar tensor.
    """
    assert z_e.shape == z_m.shape
    n_pairs, _ = z_e.shape

    # Inputs are expected to be unit-norm, so the dot product is the cosine.
    similarity = torch.matmul(z_e, z_m.T) / temperature  # (B, B)

    # The matching pair sits on the diagonal.
    positives = torch.arange(n_pairs, device=z_e.device)

    ecg_to_mcg = F.cross_entropy(similarity, positives)
    mcg_to_ecg = F.cross_entropy(similarity.T, positives)  # transposed direction

    return 0.5 * (ecg_to_mcg + mcg_to_ecg)


# SimCLR NT-Xent Loss

In [20]:
import torch
import torch.nn.functional as F


def simclr_nt_xent_loss(z1, z2, temperature=0.1):
    """
    NT-Xent (SimCLR) contrastive loss.

    z1, z2: (B, D) embeddings from two views of the same batch. They are
        L2-normalized here, so the similarity is the cosine whether or not
        the caller normalized them.
    temperature: divisor applied to the similarity matrix.

    Returns scalar loss: cross-entropy over the 2B anchors, where the positive
    for row i is the other view of the same sample and every remaining row is
    a negative.
    """
    assert z1.shape == z2.shape
    batch_size = z1.shape[0]
    n_anchors = 2 * batch_size

    # Normalizing once and taking a matmul gives the same cosine matrix as a
    # broadcasted cosine_similarity without the (2B, 2B, D) intermediate.
    z = F.normalize(torch.cat([z1, z2], dim=0), p=2, dim=-1)  # (2B, D)
    sim = (z @ z.t()) / temperature                            # (2B, 2B)

    # An anchor must not be its own candidate. -inf removes the diagonal from
    # the softmax and, unlike a large finite constant such as -1e9, is
    # representable in every floating dtype (including float16).
    self_mask = torch.eye(n_anchors, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))

    # Positive index for each anchor: i + B for the first half, i - B for the
    # second half. Logits stay (2B, 2B); the masked diagonal is never a target.
    labels = (torch.arange(n_anchors, device=z.device) + batch_size) % n_anchors

    return F.cross_entropy(sim, labels)


# train_koch_crossmodal_smee

In [22]:
import torch
from torch.utils.data import DataLoader

#from koch_dataset import KochPairedBeatsDataset
#from models import (ECGEncoderSMEE,MCGEncoderSMEE,ProjectionHead, cross_modal_info_nce,   )


def train_koch_crossmodal_smee(
    npz_path="koch_pairs.npz",
    batch_size=8,
    lr=1e-3,
    epochs=50,
    device="cpu",
):
    """
    Train SMEE ECG / MCG encoders with the cross-modal InfoNCE loss.

    npz_path: archive passed to KochPairedBeatsDataset (augmentation enabled)
    batch_size: DataLoader batch size (shuffled, last partial batch kept)
    lr: AdamW learning rate (weight decay is fixed at 1e-4)
    epochs: number of passes over the dataset
    device: torch device string for the modules and batches

    Prints the mean loss per epoch and writes four state dicts to the current
    working directory: ``ecg_encoder_koch_smee.pth``,
    ``mcg_encoder_koch_smee.pth``, ``ecg_proj_koch_smee.pth`` and
    ``mcg_proj_koch_smee.pth``. Returns None.

    Raises ValueError if the dataset yields no batches.
    """
    # 1) Dataset & DataLoader
    dataset = KochPairedBeatsDataset(npz_path=npz_path, augment=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Fail early with a clear message rather than hitting an unbound loop
    # variable when the per-epoch average is computed.
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError(f"No training batches available from {npz_path!r}")

    # 2) New encoders: SMEE-based ECG & MCG
    #    (no shared backbone yet, we keep it simple first)
    #    Input channel counts come from the loaded arrays (axis 1), so the
    #    encoders always match the data.
    ecg_channels = dataset.ecg_beats.shape[1]
    mcg_channels = dataset.mcg_beats.shape[1]
    ecg_encoder = ECGEncoderSMEE(in_channels=ecg_channels, bottleneck_channels=32, feat_dim=256)
    mcg_encoder = MCGEncoderSMEE(in_channels=mcg_channels, bottleneck_channels=32, feat_dim=256)

    ecg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    mcg_proj = ProjectionHead(in_dim=256, proj_dim=128)

    modules = (ecg_encoder, mcg_encoder, ecg_proj, mcg_proj)
    for module in modules:
        module.to(device)

    # 3) Optimizer over all four modules
    params = [p for module in modules for p in module.parameters()]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4)

    # 4) Training loop
    for module in modules:
        module.train()

    for epoch in range(1, epochs + 1):
        running_loss = 0.0

        for ecg, mcg in loader:
            ecg = ecg.to(device)  # (B, C_ecg, T)
            mcg = mcg.to(device)  # (B, C_mcg, T)

            optimizer.zero_grad()

            # Encode
            h_e = ecg_encoder(ecg)   # (B, 256)
            h_m = mcg_encoder(mcg)   # (B, 256)

            # Project
            z_e = ecg_proj(h_e)      # (B, 128)
            z_m = mcg_proj(h_m)      # (B, 128)

            # Cross-modal InfoNCE
            loss = cross_modal_info_nce(z_e, z_m, temperature=0.1)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_loss = running_loss / n_batches
        print(f"Epoch {epoch:03d}/{epochs} - SMEE cross-modal loss: {avg_loss:.4f}")

    # 5) Save new model weights
    torch.save(ecg_encoder.state_dict(), "ecg_encoder_koch_smee.pth")
    torch.save(mcg_encoder.state_dict(), "mcg_encoder_koch_smee.pth")
    torch.save(ecg_proj.state_dict(), "ecg_proj_koch_smee.pth")
    torch.save(mcg_proj.state_dict(), "mcg_proj_koch_smee.pth")
    print("Saved SMEE-based ECG + MCG encoders and projection heads (Koch dataset)")


if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    train_koch_crossmodal_smee(
        npz_path="koch_pairs.npz",
        batch_size=8,
        lr=1e-3,
        epochs=10,   # short sanity-check run (function default is 50); increase for real training
        device=device,
    )


Using device: cuda
Epoch 001/10 - SMEE cross-modal loss: 2.0847
Epoch 002/10 - SMEE cross-modal loss: 2.0599
Epoch 003/10 - SMEE cross-modal loss: 2.0683
Epoch 004/10 - SMEE cross-modal loss: 1.9906
Epoch 005/10 - SMEE cross-modal loss: 1.8966
Epoch 006/10 - SMEE cross-modal loss: 1.9340
Epoch 007/10 - SMEE cross-modal loss: 1.7252
Epoch 008/10 - SMEE cross-modal loss: 1.5543
Epoch 009/10 - SMEE cross-modal loss: 1.3880
Epoch 010/10 - SMEE cross-modal loss: 1.2995
Saved SMEE-based ECG + MCG encoders and projection heads (Koch dataset)
