# Koch paired ECG/MCG beats dataset

In [1]:
import os
from glob import glob

import numpy as np
from scipy.signal import butter, filtfilt, find_peaks

import torch
from torch.utils.data import Dataset


class KochPairedBeatsDataset(Dataset):
    """Paired ECG/MCG heartbeats stored in a ``.npz`` archive.

    The archive must provide ``ecg_beats`` and ``mcg_beats`` (the first axis
    indexes beats; both arrays must hold the same number of beats) and ``fs``,
    the sampling rate, stored either as a scalar or as a 1-element array.
    Each item is an ``(ecg, mcg)`` pair of float32 tensors, each shaped like
    one entry of its array, e.g. ``(C_ecg, T)`` and ``(C_mcg, T)``.
    """

    def __init__(self, npz_path="koch_pairs.npz", augment=False):
        """
        npz_path: path to the koch_pairs.npz file
        augment: if True, apply very light noise / jitter augmentations

        Raises ValueError if the ECG and MCG arrays hold different beat counts.
        """
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects -- only point this at archives from a trusted source.
        # Using the NpzFile as a context manager closes its file handle once
        # the arrays have been read into memory.
        with np.load(npz_path, allow_pickle=True) as data:
            self.ecg_beats = data["ecg_beats"]  # (N, C_ecg, T)
            self.mcg_beats = data["mcg_beats"]  # (N, C_mcg, T)
            # Accept both a 0-d scalar and a 1-element array for fs.
            self.fs = int(np.asarray(data["fs"]).reshape(-1)[0])
        if len(self.ecg_beats) != len(self.mcg_beats):
            raise ValueError(
                f"ECG/MCG beat count mismatch in {npz_path!r}: "
                f"{len(self.ecg_beats)} ECG vs {len(self.mcg_beats)} MCG beats"
            )
        self.augment = augment

    def __len__(self):
        return self.ecg_beats.shape[0]

    def _augment(self, x):
        """
        x: numpy array (C, T)
        very light augmentations: small Gaussian noise and tiny time shift
        """
        # small noise (float32 so a float32 input stays float32)
        x = x + 0.01 * np.random.randn(*x.shape).astype(np.float32)

        # tiny circular time shift up to ±20 samples along the last axis
        max_shift = 20
        shift = np.random.randint(-max_shift, max_shift + 1)
        if shift != 0:
            x = np.roll(x, shift, axis=1)

        return x

    def __getitem__(self, idx):
        # Force float32 so batches match the float32 model weights even if the
        # archive was saved as float64; no copy is made when the stored beat is
        # already contiguous float32.
        ecg = np.ascontiguousarray(self.ecg_beats[idx], dtype=np.float32)  # (C_ecg, T)
        mcg = np.ascontiguousarray(self.mcg_beats[idx], dtype=np.float32)  # (C_mcg, T)

        if self.augment:
            ecg = self._augment(ecg)
            mcg = self._augment(mcg)

        # convert to torch tensors: float32, shapes (C_ecg, T) and (C_mcg, T)
        return torch.from_numpy(ecg), torch.from_numpy(mcg)


# Quick sanity check of the paired dataset (class defined earlier in this notebook).
from torch.utils.data import DataLoader

ds = KochPairedBeatsDataset("koch_pairs.npz", augment=False)
print("Num beats:", len(ds))

# Pull one shuffled batch and inspect its shapes; the recorded run shows
# ECG torch.Size([8, 32, 2000]) and MCG torch.Size([8, 100, 2000]).
loader = DataLoader(ds, shuffle=True, batch_size=8)
ecg, mcg = next(iter(loader))
print("ECG batch shape:", ecg.shape)
print("MCG batch shape:", mcg.shape)


Num beats: 127
ECG batch shape: torch.Size([8, 32, 2000])
MCG batch shape: torch.Size([8, 100, 2000])


# PTB Beats Dataset

In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset


class PTBBeatsDataset(Dataset):
    """Single-modality (ECG) beats stored in a ``.npz`` archive.

    The archive must provide ``beats`` (first axis indexes beats) and ``fs``,
    the sampling rate, stored either as a scalar or a 1-element array. Each
    item is one beat as a float32 tensor, e.g. ``(12, T)`` for 12-lead data.
    """

    def __init__(self, npz_path="ptb_beats.npz", augment=True):
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects -- only point this at archives from a trusted source.
        # The context manager closes the NpzFile handle after the arrays are
        # read into memory.
        with np.load(npz_path, allow_pickle=True) as data:
            self.beats = data["beats"]        # (N, 12, 2000)
            # Accept both a 0-d scalar and a 1-element array for fs.
            self.fs = int(np.asarray(data["fs"]).reshape(-1)[0])
        self.augment = augment

    def __len__(self):
        return self.beats.shape[0]

    def _augment(self, x):
        """Add small Gaussian noise and a circular time shift of up to ±20 samples.

        x: numpy array (C, T); the shift is applied along axis 1.
        """
        x = x + 0.01 * np.random.randn(*x.shape).astype(np.float32)  # noise
        max_shift = 20
        shift = np.random.randint(-max_shift, max_shift + 1)
        if shift != 0:
            x = np.roll(x, shift, axis=1)
        return x

    def __getitem__(self, idx):
        # float32 to match the model weights even if the archive is float64;
        # no copy when the stored beat is already contiguous float32.
        beat = np.ascontiguousarray(self.beats[idx], dtype=np.float32)
        if self.augment:
            beat = self._augment(beat)
        return torch.from_numpy(beat)



from torch.utils.data import DataLoader

# Augmented PTB beats served in shuffled batches of 32.
ds_ptb = PTBBeatsDataset("ptb_beats.npz", augment=True)
loader_ptb = DataLoader(ds_ptb, shuffle=True, batch_size=32)

# One batch as a sanity check; the recorded run prints torch.Size([32, 12, 2000]).
batch = next(iter(loader_ptb))
print(batch.shape)


torch.Size([32, 12, 2000])


# Model

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class DepthwiseSeparableConv1d(nn.Module):
    """Depthwise Conv1d (one filter per input channel) followed by a 1x1 conv.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels produced by the 1x1 pointwise conv.
        kernel_size, stride, padding: forwarded to the depthwise ``nn.Conv1d``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups=in_ch gives every input channel its own temporal filter.
        self.depthwise = nn.Conv1d(
            in_ch,
            in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,
        )
        # The 1x1 conv then mixes information across channels.
        self.pointwise = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class MultiScaleBlock(nn.Module):
    """
    Multi-scale depthwise-separable residual block.
    Input: (B, C, T) -> Output: (B, C, T)

    Each kernel size in ``kernels`` gets its own DepthwiseSeparableConv1d
    branch padded by ``k // 2``. The branch outputs are averaged,
    batch-normalised and passed through ReLU, then added back to the input
    as a residual.

    Raises:
        ValueError: if ``kernels`` is empty or contains a size that is not a
            positive odd integer. ``k // 2`` padding keeps the time length
            unchanged only for odd ``k``; an even kernel would make the
            residual addition in ``forward`` fail with a shape mismatch.
    """
    def __init__(self, channels, kernels=(5, 9, 17)):
        super().__init__()
        kernels = tuple(kernels)
        if not kernels:
            raise ValueError("kernels must contain at least one kernel size")
        bad = [k for k in kernels if k < 1 or k % 2 == 0]
        if bad:
            raise ValueError(f"kernel sizes must be positive odd integers, got {bad}")
        self.branches = nn.ModuleList(
            DepthwiseSeparableConv1d(channels, channels, kernel_size=k, padding=k // 2)
            for k in kernels
        )
        self.bn = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()

    def forward(self, x):
        # x: (B, C, T); every branch returns (B, C, T) because kernels are odd.
        out = sum(conv(x) for conv in self.branches) / len(self.branches)  # average branches
        out = self.act(self.bn(out))
        return x + out  # residual


class SMEEBackbone(nn.Module):
    """
    Shared Multi-Scale Efficient Encoder backbone.
    Input: (B, C_bottleneck, T) -> Output: (B, feat_dim)

    ``n_blocks`` stacked MultiScaleBlock layers (kernels 5, 9, 17), followed
    by global average pooling over time and a linear map to ``feat_dim``.
    """

    def __init__(self, bottleneck_channels=32, n_blocks=3, feat_dim=256):
        super().__init__()
        self.blocks = nn.Sequential(
            *[
                MultiScaleBlock(bottleneck_channels, kernels=(5, 9, 17))
                for _ in range(n_blocks)
            ]
        )
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(bottleneck_channels, feat_dim)

    def forward(self, x):
        features = self.blocks(x)                                 # (B, C_bottleneck, T)
        pooled = self.global_pool(features).flatten(start_dim=1)  # (B, C_bottleneck)
        return self.fc(pooled)                                    # (B, feat_dim)


class ECGEncoderSMEE(nn.Module):
    """
    ECG encoder using the SMEE backbone.

    A 1x1 convolution maps ``in_channels`` ECG channels to
    ``bottleneck_channels``. A 3-block SMEEBackbone then turns the result
    into a ``(B, feat_dim)`` feature vector.
    """

    def __init__(self, in_channels=32, bottleneck_channels=32, feat_dim=256):
        super().__init__()
        self.input_proj = nn.Conv1d(in_channels, bottleneck_channels, kernel_size=1)
        self.backbone = SMEEBackbone(
            bottleneck_channels=bottleneck_channels,
            n_blocks=3,
            feat_dim=feat_dim,
        )

    def forward(self, x):
        return self.backbone(self.input_proj(x))


class MCGEncoderSMEE(nn.Module):
    """
    MCG encoder using the SMEE backbone. Can share backbone weights with the
    ECG encoder if desired.

    A 1x1 convolution maps ``in_channels`` MCG channels to
    ``bottleneck_channels`` before the backbone.

    Args:
        in_channels: number of MCG input channels.
        bottleneck_channels: width of the 1x1 input projection; the backbone
            must operate at this width.
        feat_dim: output size of a newly built backbone. Ignored when
            ``shared_backbone`` is given; that backbone's own ``fc`` layer
            decides the output size.
        shared_backbone: existing backbone to reuse (weight sharing), e.g. the
            ``backbone`` attribute of an ECGEncoderSMEE.

    Raises:
        ValueError: if ``shared_backbone.fc`` expects a different number of
            input features than ``bottleneck_channels``. Without this check,
            such a mismatch only surfaces as a shape error inside ``forward``.
    """
    def __init__(
        self,
        in_channels: int = 100,
        bottleneck_channels: int = 32,
        feat_dim: int = 256,
        shared_backbone: Optional[SMEEBackbone] = None,
    ):
        super().__init__()
        if shared_backbone is not None:
            shared_fc = getattr(shared_backbone, "fc", None)
            if isinstance(shared_fc, nn.Linear) and shared_fc.in_features != bottleneck_channels:
                raise ValueError(
                    f"shared_backbone operates on {shared_fc.in_features} channels but "
                    f"bottleneck_channels={bottleneck_channels}"
                )
        self.input_proj = nn.Conv1d(in_channels, bottleneck_channels, kernel_size=1)
        if shared_backbone is None:
            self.backbone = SMEEBackbone(
                bottleneck_channels=bottleneck_channels,
                n_blocks=3,
                feat_dim=feat_dim,
            )
        else:
            self.backbone = shared_backbone  # weight sharing

    def forward(self, x):
        x = self.input_proj(x)
        x = self.backbone(x)
        return x


class Conv1DEncoder(nn.Module):
    """
    Generic 1D CNN encoder for time series (the baseline model).
    Used for ECG (C=12 or 32) and MCG (C=100).
    Input: (B, C, T)
    Output: (B, feat_dim)

    Four stride-2 Conv1d -> BatchNorm1d -> ReLU stages (64, 128, 256, 256
    channels) shrink T about 16x. Global average pooling and a linear
    layer follow.
    """

    def __init__(self, in_channels, feat_dim=256):
        super().__init__()
        # Submodules are registered in the original order so state_dict keys
        # (conv1, bn1, ..., fc) and parameter initialisation order are unchanged.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2)
        self.bn3 = nn.BatchNorm1d(256)
        self.conv4 = nn.Conv1d(256, 256, kernel_size=3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm1d(256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, feat_dim)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),  # -> (B, 64, ~T/2)
            (self.conv2, self.bn2),  # -> (B, 128, ~T/4)
            (self.conv3, self.bn3),  # -> (B, 256, ~T/8)
            (self.conv4, self.bn4),  # -> (B, 256, ~T/16)
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x)))
        pooled = self.global_pool(x).flatten(start_dim=1)  # (B, 256)
        return self.fc(pooled)                             # (B, feat_dim)




In [None]:
#from models import ECGEncoderSMEE, MCGEncoderSMEE  
import torch

def count_params(model):
    """Return the number of trainable (requires_grad=True) parameter elements in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Parameter counts and a forward-pass smoke test for the SMEE encoders.
ecg_smee = ECGEncoderSMEE(in_channels=32)
mcg_smee = MCGEncoderSMEE(in_channels=100)

print("ECG SMEE params:", count_params(ecg_smee))
print("MCG SMEE params:", count_params(mcg_smee))

# Random batch of 4 beats with T=2000; both outputs should be torch.Size([4, 256]).
x_ecg, x_mcg = torch.randn(4, 32, 2000), torch.randn(4, 100, 2000)
h_e, h_m = ecg_smee(x_ecg), mcg_smee(x_mcg)
print("Shapes:", h_e.shape, h_m.shape)


ECG SMEE params: 22464
MCG SMEE params: 24640
Shapes: torch.Size([4, 256]) torch.Size([4, 256])


# Projection Head

In [5]:


class ProjectionHead(nn.Module):
    """
    Two-layer MLP projection head (Linear -> ReLU -> Linear), L2-normalised.
    Input: (B, in_dim)
    Output: (B, proj_dim), each row scaled to unit L2 norm
    """

    def __init__(self, in_dim=256, proj_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, in_dim)
        self.fc2 = nn.Linear(in_dim, proj_dim)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        # Unit-norm outputs make dot products equal cosine similarities.
        return F.normalize(self.fc2(hidden), p=2, dim=-1)



def cross_modal_info_nce(z_e, z_m, temperature=0.1):
    """
    Symmetric cross-modal InfoNCE loss between ECG (z_e) and MCG (z_m).

    Row i of z_e and row i of z_m form the positive pair; every other row of
    the batch acts as a negative.

    z_e: (B, D) ECG embeddings (expected L2-normalized, so dot = cosine)
    z_m: (B, D) MCG embeddings (expected L2-normalized)
    temperature: divides the similarity logits

    Computes S = z_e @ z_m^T / temperature and returns the mean of the
    ECG→MCG (rows) and MCG→ECG (columns) cross-entropy terms.

    Raises ValueError if the inputs are not 2-D tensors of identical shape
    (an explicit check, unlike ``assert``, survives ``python -O``).
    """
    if z_e.dim() != 2 or z_e.shape != z_m.shape:
        raise ValueError(
            f"z_e and z_m must be 2-D with equal shapes, got {tuple(z_e.shape)} "
            f"and {tuple(z_m.shape)}"
        )
    batch_size = z_e.shape[0]

    logits = z_e @ z_m.T / temperature  # (B, B)
    targets = torch.arange(batch_size, device=z_e.device)

    loss_e2m = F.cross_entropy(logits, targets)    # ECG→MCG
    loss_m2e = F.cross_entropy(logits.T, targets)  # MCG→ECG
    return 0.5 * (loss_e2m + loss_m2e)


# Loss

In [6]:
import torch
import torch.nn.functional as F


def simclr_nt_xent_loss(z1, z2, temperature=0.1):
    """
    NT-Xent (SimCLR) loss for two views of the same batch.

    z1, z2: (B, D) embeddings of two views; row i of z1 and row i of z2 are
        the positive pair. Inputs are L2-normalized here, so similarities are
        cosine similarities.
    temperature: divides the similarity logits.

    Returns a scalar loss averaged over all 2B anchors.
    Raises ValueError if the inputs are not 2-D tensors of identical shape.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must be 2-D with equal shapes, got {tuple(z1.shape)} "
            f"and {tuple(z2.shape)}"
        )
    batch_size = z1.shape[0]

    # Normalise + matmul gives the (2B, 2B) cosine matrix without the
    # (2B, 2B, D) intermediate that broadcasting cosine_similarity builds.
    z = F.normalize(torch.cat([z1, z2], dim=0), p=2, dim=-1)  # (2B, D)
    sim = (z @ z.T) / temperature  # (2B, 2B)

    # Exclude self-similarity. finfo(dtype).min is representable in every
    # float dtype (a literal -1e9 overflows float16 under autocast).
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

    # Positive for anchor i is the other view of the same sample: i±B.
    labels = (torch.arange(2 * batch_size, device=z.device) + batch_size) % (2 * batch_size)
    return F.cross_entropy(sim, labels)


## New loss: ECG↔MCG contrastive loss

In [25]:
import torch
import torch.nn.functional as F


def simclr_nt_xent_loss(z1, z2, temperature=0.1):
    """
    Standard NT-Xent (SimCLR) loss for two views.

    z1, z2: (B, D) embeddings of two views; row i of z1 and row i of z2 are
        the positive pair. Inputs are L2-normalized here, so similarities are
        cosine similarities.
    temperature: divides the similarity logits.

    Returns a scalar loss averaged over all 2B anchors.
    Raises ValueError if the inputs are not 2-D tensors of identical shape.
    """
    if z1.dim() != 2 or z1.shape != z2.shape:
        raise ValueError(
            f"z1 and z2 must be 2-D with equal shapes, got {tuple(z1.shape)} "
            f"and {tuple(z2.shape)}"
        )
    B = z1.shape[0]

    # Normalise + matmul: same cosine matrix as a broadcast cosine_similarity
    # but without materialising a (2B, 2B, D) tensor.
    z = F.normalize(torch.cat([z1, z2], dim=0), p=2, dim=-1)  # (2B, D)
    sim = (z @ z.T) / temperature  # (2B, 2B)

    # positive for each index is the other view
    labels = (torch.arange(2 * B, device=z.device) + B) % (2 * B)

    # Mask self-similarities with the dtype's most negative finite value;
    # a literal -1e9 overflows float16 under autocast.
    mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, torch.finfo(sim.dtype).min)

    return F.cross_entropy(sim, labels)


def cross_modal_info_nce(z_e, z_m, temperature=0.1):
    """
    Symmetric cross-modal InfoNCE between ECG (z_e) and MCG (z_m).

    z_e, z_m: (B, D) embeddings, expected L2-normalized. Row i of each tensor
        forms the positive pair; the other rows of the batch are negatives.
    temperature: divides the similarity logits.

    Returns the mean of the ECG→MCG and MCG→ECG cross-entropy terms.
    Raises ValueError if the inputs are not 2-D tensors of identical shape.
    """
    if z_e.dim() != 2 or z_e.shape != z_m.shape:
        raise ValueError(
            f"z_e and z_m must be 2-D with equal shapes, got {tuple(z_e.shape)} "
            f"and {tuple(z_m.shape)}"
        )
    B = z_e.shape[0]
    logits = z_e @ z_m.T / temperature  # (B, B)
    targets = torch.arange(B, device=z_e.device)

    loss_e2m = F.cross_entropy(logits, targets)
    loss_m2e = F.cross_entropy(logits.T, targets)
    return 0.5 * (loss_e2m + loss_m2e)


def ecg_mcg_contrastive_loss(
    z_e, z_m,
    z_e_aug1=None, z_e_aug2=None,
    z_m_aug1=None, z_m_aug2=None,
    temperature=0.1,
    lambda_within=0.1,
):
    """
    Task-specific ECG<->MCG objective:

        L_total = L_cross(ECG, MCG) + λ (L_within_ECG + L_within_MCG)

    L_cross is the symmetric cross-modal InfoNCE on (z_e, z_m). Each
    within-modality NT-Xent term is added only when BOTH of its augmented
    views are supplied; if either view is None that term is skipped.

    z_e, z_m: (B, D) main ECG/MCG embeddings
    z_e_aug1, z_e_aug2: (B, D) two augmented ECG views (optional)
    z_m_aug1, z_m_aug2: (B, D) two augmented MCG views (optional)
    temperature: shared by all three terms
    lambda_within: weight λ applied to each within-modality term
    """
    total = cross_modal_info_nce(z_e, z_m, temperature=temperature)

    # ECG pair first, then MCG pair.
    for view_a, view_b in ((z_e_aug1, z_e_aug2), (z_m_aug1, z_m_aug2)):
        if view_a is None or view_b is None:
            continue
        total = total + lambda_within * simclr_nt_xent_loss(
            view_a, view_b, temperature=temperature
        )

    return total


# train_koch_crossmodal_smee

In [None]:
import torch
from torch.utils.data import DataLoader

#from koch_dataset import KochPairedBeatsDataset
#from models import (ECGEncoderSMEE,MCGEncoderSMEE,ProjectionHead, cross_modal_info_nce,   )


import torch




def train_koch_crossmodal_smee(
    npz_path="koch_pairs.npz",
    batch_size=8,
    lr=1e-3,
    epochs=50,
    device="cpu",
    temperature=0.1,
):
    """Train SMEE ECG/MCG encoders and projection heads with cross-modal InfoNCE.

    Args:
        npz_path: paired-beats archive read by KochPairedBeatsDataset, with
            augment=True so the dataset applies noise/shift per item.
        batch_size, lr, epochs: DataLoader and AdamW (weight_decay=1e-4) settings.
        device: torch device string for the models and batches.
        temperature: InfoNCE temperature (0.1 matches the previously
            hard-coded value).

    Side effects: prints the mean loss per epoch and writes four
    ``*_koch_smee.pth`` state_dict files to the working directory.

    Raises:
        ValueError: if the archive contains no beats.
    """
    # 1) Dataset & DataLoader
    dataset = KochPairedBeatsDataset(npz_path=npz_path, augment=True)
    if len(dataset) == 0:
        # With no batches the per-epoch average below would be undefined.
        raise ValueError(f"No beats found in {npz_path!r}; nothing to train on")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # 2) SMEE-based ECG (32 ch) & MCG (100 ch) encoders with separate
    #    backbones, plus one projection head per modality.
    ecg_encoder = ECGEncoderSMEE(in_channels=32, bottleneck_channels=32, feat_dim=256)
    mcg_encoder = MCGEncoderSMEE(in_channels=100, bottleneck_channels=32, feat_dim=256)
    ecg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    mcg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    modules = (ecg_encoder, mcg_encoder, ecg_proj, mcg_proj)
    for module in modules:
        module.to(device)

    # 3) One optimizer over all four modules
    params = [p for module in modules for p in module.parameters()]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4)

    # 4) Training loop
    for module in modules:
        module.train()

    for epoch in range(1, epochs + 1):
        running_loss = 0.0

        for ecg, mcg in loader:
            ecg = ecg.to(device)  # (B, 32, 2000) for the Koch data
            mcg = mcg.to(device)  # (B, 100, 2000)

            optimizer.zero_grad()

            # Encode
            h_e = ecg_encoder(ecg)   # (B, 256)
            h_m = mcg_encoder(mcg)   # (B, 256)

            # Project (unit-norm outputs)
            z_e = ecg_proj(h_e)      # (B, 128)
            z_m = mcg_proj(h_m)      # (B, 128)

            loss = cross_modal_info_nce(z_e, z_m, temperature=temperature)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(loader)
        print(f"Epoch {epoch:03d}/{epochs} - SMEE cross-modal loss: {avg_loss:.4f}")

    # 5) Save new model weights
    torch.save(ecg_encoder.state_dict(), "ecg_encoder_koch_smee.pth")
    torch.save(mcg_encoder.state_dict(), "mcg_encoder_koch_smee.pth")
    torch.save(ecg_proj.state_dict(), "ecg_proj_koch_smee.pth")
    torch.save(mcg_proj.state_dict(), "mcg_proj_koch_smee.pth")
    print("Saved SMEE-based ECG + MCG encoders and projection heads (Koch dataset)")


if __name__ == "__main__":
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)

    # Cross-modal-only SMEE training run on the Koch paired beats.
    train_koch_crossmodal_smee(
        npz_path="koch_pairs.npz",
        device=device,
        batch_size=8,
        lr=1e-3,
        epochs=50,
    )


Using device: cuda
Epoch 001/50 - SMEE cross-modal loss: 2.0827
Epoch 002/50 - SMEE cross-modal loss: 2.0606
Epoch 003/50 - SMEE cross-modal loss: 2.0716
Epoch 004/50 - SMEE cross-modal loss: 2.0653
Epoch 005/50 - SMEE cross-modal loss: 2.0598
Epoch 006/50 - SMEE cross-modal loss: 2.0424
Epoch 007/50 - SMEE cross-modal loss: 2.0075
Epoch 008/50 - SMEE cross-modal loss: 1.9159
Epoch 009/50 - SMEE cross-modal loss: 2.0060
Epoch 010/50 - SMEE cross-modal loss: 1.9069
Epoch 011/50 - SMEE cross-modal loss: 1.7929
Epoch 012/50 - SMEE cross-modal loss: 1.8310
Epoch 013/50 - SMEE cross-modal loss: 1.8830
Epoch 014/50 - SMEE cross-modal loss: 1.6852
Epoch 015/50 - SMEE cross-modal loss: 1.5853
Epoch 016/50 - SMEE cross-modal loss: 1.5584
Epoch 017/50 - SMEE cross-modal loss: 1.4793
Epoch 018/50 - SMEE cross-modal loss: 1.5021
Epoch 019/50 - SMEE cross-modal loss: 1.5005
Epoch 020/50 - SMEE cross-modal loss: 1.3316
Epoch 021/50 - SMEE cross-modal loss: 1.3059
Epoch 022/50 - SMEE cross-modal loss

## Train with the new loss: ECG↔MCG contrastive loss

In [26]:
import torch
from torch.utils.data import DataLoader

def augment_signal(x, noise_std=0.01, max_shift=20):
    """Return a lightly perturbed copy of a batch of signals.

    x: (B, C, T) tensor (the shift acts on the last dimension).
    noise_std: standard deviation of the additive Gaussian noise.
    max_shift: largest circular time shift in samples. A single shift drawn
        from [-max_shift, max_shift] is applied to the whole batch; no shift
        is drawn when max_shift <= 0.

    The input tensor is not modified.
    """
    noisy = x + noise_std * torch.randn_like(x)
    if max_shift <= 0:
        return noisy

    offset = torch.randint(-max_shift, max_shift + 1, (1,)).item()
    return torch.roll(noisy, shifts=offset, dims=-1) if offset != 0 else noisy



#from koch_dataset import KochPairedBeatsDataset
# models import (
#ECGEncoderSMEE,
#    MCGEncoderSMEE,
#    ProjectionHead,
#ecg_mcg_contrastive_loss,
#)
#from train_utils import augment_signal  # or paste augment_signal in this file


def train_koch_crossmodal_smee_ecgmcg_loss(
    npz_path="koch_pairs.npz",
    batch_size=8,
    lr=1e-3,
    epochs=50,
    device="cpu",
    temperature=0.1,
    lambda_within=0.1,
):
    """Train SMEE encoders with cross-modal InfoNCE plus within-modality NT-Xent.

    Each batch is encoded three times per modality: the clean signal (for
    the cross-modal term) and two ``augment_signal`` views (for the SimCLR
    terms weighted by ``lambda_within``).

    Args:
        npz_path: paired-beats archive (dataset augmentation is disabled;
            augmentation happens in the loop).
        batch_size, lr, epochs: DataLoader and AdamW (weight_decay=1e-4) settings.
        device: torch device string for the models and batches.
        temperature: shared contrastive temperature.
        lambda_within: weight of each within-modality term.

    Side effects: prints the mean loss per epoch and writes four
    ``*_koch_smee_loss.pth`` state_dict files to the working directory.

    Raises:
        ValueError: if the archive contains no beats.
    """
    # 1) Dataset (augment=False; augmentation is handled in the loop)
    dataset = KochPairedBeatsDataset(npz_path=npz_path, augment=False)
    if len(dataset) == 0:
        # With no batches the per-epoch average below would be undefined.
        raise ValueError(f"No beats found in {npz_path!r}; nothing to train on")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # 2) New SMEE encoders + projection heads
    ecg_encoder = ECGEncoderSMEE(in_channels=32, bottleneck_channels=32, feat_dim=256)
    mcg_encoder = MCGEncoderSMEE(in_channels=100, bottleneck_channels=32, feat_dim=256)
    ecg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    mcg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    modules = (ecg_encoder, mcg_encoder, ecg_proj, mcg_proj)
    for module in modules:
        module.to(device)

    params = [p for module in modules for p in module.parameters()]
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=1e-4)

    for module in modules:
        module.train()

    for epoch in range(1, epochs + 1):
        running_loss = 0.0

        for ecg, mcg in loader:
            ecg = ecg.to(device)  # (B, 32, 2000) for the Koch data
            mcg = mcg.to(device)  # (B, 100, 2000)

            # Two augmented views per modality for the within-modal SimCLR
            # terms. augment_signal is out-of-place, so no clone is needed.
            ecg1 = augment_signal(ecg)
            ecg2 = augment_signal(ecg)
            mcg1 = augment_signal(mcg)
            mcg2 = augment_signal(mcg)

            optimizer.zero_grad()

            # --- Original (for cross-modal) ---
            h_e = ecg_encoder(ecg)   # (B, 256)
            h_m = mcg_encoder(mcg)   # (B, 256)
            z_e = ecg_proj(h_e)      # (B, 128)
            z_m = mcg_proj(h_m)      # (B, 128)

            # --- Augmented ECG ---
            z_e1 = ecg_proj(ecg_encoder(ecg1))
            z_e2 = ecg_proj(ecg_encoder(ecg2))

            # --- Augmented MCG ---
            z_m1 = mcg_proj(mcg_encoder(mcg1))
            z_m2 = mcg_proj(mcg_encoder(mcg2))

            loss = ecg_mcg_contrastive_loss(
                z_e, z_m,
                z_e_aug1=z_e1, z_e_aug2=z_e2,
                z_m_aug1=z_m1, z_m_aug2=z_m2,
                temperature=temperature,
                lambda_within=lambda_within,
            )

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(loader)
        print(
            f"Epoch {epoch:03d}/{epochs} - SMEE ECG<->MCG loss: "
            f"{avg_loss:.4f} (λ_within={lambda_within})"
        )

    # Save under new names so the cross-modal-only SMEE model is kept as well
    torch.save(ecg_encoder.state_dict(), "ecg_encoder_koch_smee_loss.pth")
    torch.save(mcg_encoder.state_dict(), "mcg_encoder_koch_smee_loss.pth")
    torch.save(ecg_proj.state_dict(), "ecg_proj_koch_smee_loss.pth")
    torch.save(mcg_proj.state_dict(), "mcg_proj_koch_smee_loss.pth")
    print("Saved SMEE ECG+MCG encoders with ECG<->MCG-specific loss.")


if __name__ == "__main__":
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)

    # SMEE training with the combined ECG<->MCG objective.
    train_koch_crossmodal_smee_ecgmcg_loss(
        npz_path="koch_pairs.npz",
        device=device,
        batch_size=8,
        lr=1e-3,
        epochs=50,
        temperature=0.1,
        lambda_within=0.1,
    )


Using device: cuda
Epoch 001/50 - SMEE ECG<->MCG loss: 2.5535 (λ_within=0.1)
Epoch 002/50 - SMEE ECG<->MCG loss: 2.2967 (λ_within=0.1)
Epoch 003/50 - SMEE ECG<->MCG loss: 2.1071 (λ_within=0.1)
Epoch 004/50 - SMEE ECG<->MCG loss: 2.0618 (λ_within=0.1)
Epoch 005/50 - SMEE ECG<->MCG loss: 1.6499 (λ_within=0.1)
Epoch 006/50 - SMEE ECG<->MCG loss: 1.6335 (λ_within=0.1)
Epoch 007/50 - SMEE ECG<->MCG loss: 1.4556 (λ_within=0.1)
Epoch 008/50 - SMEE ECG<->MCG loss: 1.2041 (λ_within=0.1)
Epoch 009/50 - SMEE ECG<->MCG loss: 1.0939 (λ_within=0.1)
Epoch 010/50 - SMEE ECG<->MCG loss: 1.1297 (λ_within=0.1)
Epoch 011/50 - SMEE ECG<->MCG loss: 1.1095 (λ_within=0.1)
Epoch 012/50 - SMEE ECG<->MCG loss: 0.7167 (λ_within=0.1)
Epoch 013/50 - SMEE ECG<->MCG loss: 0.9641 (λ_within=0.1)
Epoch 014/50 - SMEE ECG<->MCG loss: 0.8343 (λ_within=0.1)
Epoch 015/50 - SMEE ECG<->MCG loss: 0.5815 (λ_within=0.1)
Epoch 016/50 - SMEE ECG<->MCG loss: 0.5014 (λ_within=0.1)
Epoch 017/50 - SMEE ECG<->MCG loss: 0.4769 (λ_within=

# Evaluation scripts for the SMEE models on the Koch dataset

## SMEE model trained with the cross-modal InfoNCE loss only

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader

#from koch_dataset import KochPairedBeatsDataset
#from models import ECGEncoderSMEE, MCGEncoderSMEE, ProjectionHead


def eval_koch_retrieval_smee(
    npz_path="koch_pairs.npz",
    ecg_encoder_path="ecg_encoder_koch_smee.pth",
    mcg_encoder_path="mcg_encoder_koch_smee.pth",
    ecg_proj_path="ecg_proj_koch_smee.pth",
    mcg_proj_path="mcg_proj_koch_smee.pth",
    batch_size=64,
    device="cpu",
    topk=5,
):
    """Measure paired ECG<->MCG retrieval accuracy of saved SMEE checkpoints.

    Both encoders and projection heads embed every beat in ``npz_path``. The
    ECG and MCG of beat i form the only correct match. NOTE: the training
    functions default to the same archive, so unless a different file is
    passed these are training-set scores, not held-out ones.

    Args:
        npz_path: paired-beats archive (read without augmentation).
        ecg_encoder_path, mcg_encoder_path, ecg_proj_path, mcg_proj_path:
            state_dict checkpoints for the four modules.
        batch_size: evaluation batch size.
        device: torch device string used for the forward passes.
        topk: K for the top-K metric; clamped to the number of beats.

    Returns:
        dict with "top1_e2m", "top1_m2e", "topk_e2m", "topk_m2e" (fractions
        in [0, 1]) and "k" (the K actually used). Results are also printed.

    Raises:
        ValueError: if topk < 1 or the archive contains no beats.
    """
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    # 1) Dataset (no augmentation)
    dataset = KochPairedBeatsDataset(npz_path=npz_path, augment=False)
    if len(dataset) == 0:
        raise ValueError(f"No beats found in {npz_path!r}; nothing to evaluate")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # 2) Rebuild SMEE encoders + projection heads with the training hyper-parameters
    ecg_encoder = ECGEncoderSMEE(in_channels=32, bottleneck_channels=32, feat_dim=256)
    mcg_encoder = MCGEncoderSMEE(in_channels=100, bottleneck_channels=32, feat_dim=256)
    ecg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    mcg_proj = ProjectionHead(in_dim=256, proj_dim=128)

    # Load weights. NOTE(security): torch.load unpickles the file, so only
    # load checkpoints from a trusted source.
    for module, path in (
        (ecg_encoder, ecg_encoder_path),
        (mcg_encoder, mcg_encoder_path),
        (ecg_proj, ecg_proj_path),
        (mcg_proj, mcg_proj_path),
    ):
        module.load_state_dict(torch.load(path, map_location="cpu"))
        module.to(device).eval()

    # 3) Compute embeddings for all beats
    all_z_e, all_z_m = [], []
    with torch.no_grad():
        for ecg, mcg in loader:
            z_e = ecg_proj(ecg_encoder(ecg.to(device)))  # (B, 128)
            z_m = mcg_proj(mcg_encoder(mcg.to(device)))  # (B, 128)
            all_z_e.append(z_e.cpu())
            all_z_m.append(z_m.cpu())

    z_e = torch.cat(all_z_e, dim=0)  # (N, 128)
    z_m = torch.cat(all_z_m, dim=0)  # (N, 128)
    N = z_e.shape[0]
    print(f"Total beats: {N}")

    # 4) Cosine similarity matrix (ProjectionHead L2-normalises its outputs).
    # sim[i, j] compares ECG i with MCG j: rows rank MCGs for an ECG query,
    # columns rank ECGs for an MCG query.
    sim = z_e @ z_m.T  # (N, N)
    targets = torch.arange(N)
    k = min(topk, N)  # Tensor.topk raises if k exceeds the candidate count

    correct_e2m = (sim.argmax(dim=1) == targets).float().mean().item()
    correct_m2e = (sim.argmax(dim=0) == targets).float().mean().item()

    topk_e2m = sim.topk(k, dim=1).indices  # (N, k)
    correct_topk_e2m = (topk_e2m == targets.unsqueeze(1)).any(dim=1).float().mean().item()
    topk_m2e = sim.topk(k, dim=0).indices  # (k, N)
    correct_topk_m2e = (topk_m2e == targets.unsqueeze(0)).any(dim=0).float().mean().item()

    print(f"SMEE ECG→MCG top-1 retrieval:  {correct_e2m*100:.2f}%")
    print(f"SMEE MCG→ECG top-1 retrieval:  {correct_m2e*100:.2f}%")
    print(f"SMEE ECG→MCG top-{k} retrieval: {correct_topk_e2m*100:.2f}%")
    print(f"SMEE MCG→ECG top-{k} retrieval: {correct_topk_m2e*100:.2f}%")

    return {
        "top1_e2m": correct_e2m,
        "top1_m2e": correct_m2e,
        "topk_e2m": correct_topk_e2m,
        "topk_m2e": correct_topk_m2e,
        "k": k,
    }


if __name__ == "__main__":
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)
    eval_koch_retrieval_smee(device=device)


Using device: cuda
Total beats: 127
SMEE ECG→MCG top-1 retrieval:  18.90%
SMEE MCG→ECG top-1 retrieval:  19.69%
SMEE ECG→MCG top-5 retrieval: 72.44%
SMEE MCG→ECG top-5 retrieval: 58.27%


## SMEE model trained with the new ECG↔MCG contrastive loss

In [27]:
import torch
import numpy as np
from torch.utils.data import DataLoader

#from koch_dataset import KochPairedBeatsDataset
#from models import ECGEncoderSMEE, MCGEncoderSMEE, ProjectionHead


def eval_koch_retrieval_smee(
    npz_path="koch_pairs.npz",
    ecg_encoder_path="ecg_encoder_koch_smee_loss.pth",
    mcg_encoder_path="mcg_encoder_koch_smee_loss.pth",
    ecg_proj_path="ecg_proj_koch_smee_loss.pth",
    mcg_proj_path="mcg_proj_koch_smee_loss.pth",
    batch_size=64,
    device="cpu",
    topk=5,
):
    """Measure paired ECG<->MCG retrieval accuracy of saved SMEE checkpoints.

    This version defaults to the ``*_koch_smee_loss.pth`` checkpoints. Both
    encoders and projection heads embed every beat in ``npz_path``. The ECG
    and MCG of beat i form the only correct match. NOTE: the training
    functions default to the same archive, so unless a different file is
    passed these are training-set scores, not held-out ones.

    Args:
        npz_path: paired-beats archive (read without augmentation).
        ecg_encoder_path, mcg_encoder_path, ecg_proj_path, mcg_proj_path:
            state_dict checkpoints for the four modules.
        batch_size: evaluation batch size.
        device: torch device string used for the forward passes.
        topk: K for the top-K metric; clamped to the number of beats.

    Returns:
        dict with "top1_e2m", "top1_m2e", "topk_e2m", "topk_m2e" (fractions
        in [0, 1]) and "k" (the K actually used). Results are also printed.

    Raises:
        ValueError: if topk < 1 or the archive contains no beats.
    """
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    # 1) Dataset (no augmentation)
    dataset = KochPairedBeatsDataset(npz_path=npz_path, augment=False)
    if len(dataset) == 0:
        raise ValueError(f"No beats found in {npz_path!r}; nothing to evaluate")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # 2) Rebuild SMEE encoders + projection heads with the training hyper-parameters
    ecg_encoder = ECGEncoderSMEE(in_channels=32, bottleneck_channels=32, feat_dim=256)
    mcg_encoder = MCGEncoderSMEE(in_channels=100, bottleneck_channels=32, feat_dim=256)
    ecg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    mcg_proj = ProjectionHead(in_dim=256, proj_dim=128)

    # Load weights. NOTE(security): torch.load unpickles the file, so only
    # load checkpoints from a trusted source.
    for module, path in (
        (ecg_encoder, ecg_encoder_path),
        (mcg_encoder, mcg_encoder_path),
        (ecg_proj, ecg_proj_path),
        (mcg_proj, mcg_proj_path),
    ):
        module.load_state_dict(torch.load(path, map_location="cpu"))
        module.to(device).eval()

    # 3) Compute embeddings for all beats
    all_z_e, all_z_m = [], []
    with torch.no_grad():
        for ecg, mcg in loader:
            z_e = ecg_proj(ecg_encoder(ecg.to(device)))  # (B, 128)
            z_m = mcg_proj(mcg_encoder(mcg.to(device)))  # (B, 128)
            all_z_e.append(z_e.cpu())
            all_z_m.append(z_m.cpu())

    z_e = torch.cat(all_z_e, dim=0)  # (N, 128)
    z_m = torch.cat(all_z_m, dim=0)  # (N, 128)
    N = z_e.shape[0]
    print(f"Total beats: {N}")

    # 4) Cosine similarity matrix (ProjectionHead L2-normalises its outputs).
    # sim[i, j] compares ECG i with MCG j: rows rank MCGs for an ECG query,
    # columns rank ECGs for an MCG query.
    sim = z_e @ z_m.T  # (N, N)
    targets = torch.arange(N)
    k = min(topk, N)  # Tensor.topk raises if k exceeds the candidate count

    correct_e2m = (sim.argmax(dim=1) == targets).float().mean().item()
    correct_m2e = (sim.argmax(dim=0) == targets).float().mean().item()

    topk_e2m = sim.topk(k, dim=1).indices  # (N, k)
    correct_topk_e2m = (topk_e2m == targets.unsqueeze(1)).any(dim=1).float().mean().item()
    topk_m2e = sim.topk(k, dim=0).indices  # (k, N)
    correct_topk_m2e = (topk_m2e == targets.unsqueeze(0)).any(dim=0).float().mean().item()

    print(f"SMEE ECG→MCG top-1 retrieval:  {correct_e2m*100:.2f}%")
    print(f"SMEE MCG→ECG top-1 retrieval:  {correct_m2e*100:.2f}%")
    print(f"SMEE ECG→MCG top-{k} retrieval: {correct_topk_e2m*100:.2f}%")
    print(f"SMEE MCG→ECG top-{k} retrieval: {correct_topk_m2e*100:.2f}%")

    return {
        "top1_e2m": correct_e2m,
        "top1_m2e": correct_m2e,
        "topk_e2m": correct_topk_e2m,
        "topk_m2e": correct_topk_m2e,
        "k": k,
    }


if __name__ == "__main__":
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)
    eval_koch_retrieval_smee(device=device)


Using device: cuda
Total beats: 127
SMEE ECG→MCG top-1 retrieval:  59.06%
SMEE MCG→ECG top-1 retrieval:  56.69%
SMEE ECG→MCG top-5 retrieval: 95.28%
SMEE MCG→ECG top-5 retrieval: 91.34%


In [19]:
import torch
#from models import Conv1DEncoder, ECGEncoderSMEE, MCGEncoderSMEE

def count_params(model):
    """Count the elements of every parameter of *model* that requires gradients."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Old baseline encoders (what you used first on Koch)
old_ecg = Conv1DEncoder(in_channels=32, feat_dim=256)
old_mcg = Conv1DEncoder(in_channels=100, feat_dim=256)

# New SMEE encoders
new_ecg = ECGEncoderSMEE(in_channels=32, bottleneck_channels=32, feat_dim=256)
new_mcg = MCGEncoderSMEE(in_channels=100, bottleneck_channels=32, feat_dim=256)

print("Old ECG encoder params:", count_params(old_ecg))
print("Old MCG encoder params:", count_params(old_mcg))
print("New ECG SMEE params:", count_params(new_ecg))
print("New MCG SMEE params:", count_params(new_mcg))


Old ECG encoder params: 483648
Old MCG encoder params: 514112
New ECG SMEE params: 22464
New MCG SMEE params: 24640


In [24]:
import torch
from torch.utils.data import DataLoader

#from koch_dataset import KochPairedBeatsDataset
#from models import Conv1DEncoder, ProjectionHead


def eval_koch_retrieval_baseline(
    npz_path="koch_pairs.npz",
    ecg_encoder_path="ecg_encoder_koch.pth",
    mcg_encoder_path="mcg_encoder_koch.pth",
    ecg_proj_path="ecg_proj_koch.pth",
    mcg_proj_path="mcg_proj_koch.pth",
    batch_size=64,
    device="cpu",
    topk=5,
):
    """Measure paired ECG<->MCG retrieval accuracy of the baseline CNN checkpoints.

    Both Conv1DEncoder models and projection heads embed every beat in
    ``npz_path``. The ECG and MCG of beat i form the only correct match.
    NOTE: if these checkpoints were trained on the same archive, the scores
    are training-set scores, not held-out ones.

    Args:
        npz_path: paired-beats archive (read without augmentation).
        ecg_encoder_path, mcg_encoder_path, ecg_proj_path, mcg_proj_path:
            state_dict checkpoints for the four modules.
        batch_size: evaluation batch size.
        device: torch device string used for the forward passes.
        topk: K for the top-K metric; clamped to the number of beats.

    Returns:
        dict with "top1_e2m", "top1_m2e", "topk_e2m", "topk_m2e" (fractions
        in [0, 1]) and "k" (the K actually used). Results are also printed.

    Raises:
        ValueError: if topk < 1 or the archive contains no beats.
    """
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    dataset = KochPairedBeatsDataset(npz_path=npz_path, augment=False)
    if len(dataset) == 0:
        raise ValueError(f"No beats found in {npz_path!r}; nothing to evaluate")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Baseline encoders + projection heads
    ecg_encoder = Conv1DEncoder(in_channels=32, feat_dim=256)
    mcg_encoder = Conv1DEncoder(in_channels=100, feat_dim=256)
    ecg_proj = ProjectionHead(in_dim=256, proj_dim=128)
    mcg_proj = ProjectionHead(in_dim=256, proj_dim=128)

    # NOTE(security): torch.load unpickles the file, so only load
    # checkpoints from a trusted source.
    for module, path in (
        (ecg_encoder, ecg_encoder_path),
        (mcg_encoder, mcg_encoder_path),
        (ecg_proj, ecg_proj_path),
        (mcg_proj, mcg_proj_path),
    ):
        module.load_state_dict(torch.load(path, map_location="cpu"))
        module.to(device).eval()

    all_z_e, all_z_m = [], []
    with torch.no_grad():
        for ecg, mcg in loader:
            z_e = ecg_proj(ecg_encoder(ecg.to(device)))  # (B, 128)
            z_m = mcg_proj(mcg_encoder(mcg.to(device)))  # (B, 128)
            all_z_e.append(z_e.cpu())
            all_z_m.append(z_m.cpu())

    z_e = torch.cat(all_z_e, dim=0)
    z_m = torch.cat(all_z_m, dim=0)
    N = z_e.shape[0]
    print(f"Total beats: {N}")

    # sim[i, j] compares ECG i with MCG j (unit-norm embeddings -> cosine).
    sim = z_e @ z_m.T  # (N, N)
    targets = torch.arange(N)
    k = min(topk, N)  # Tensor.topk raises if k exceeds the candidate count

    # top-1
    top1_e2m = (sim.argmax(dim=1) == targets).float().mean().item()
    top1_m2e = (sim.argmax(dim=0) == targets).float().mean().item()

    # top-k
    topk_e2m = sim.topk(k, dim=1).indices  # (N, k)
    topk_m2e = sim.topk(k, dim=0).indices  # (k, N)
    topk_e2m_acc = (topk_e2m == targets.unsqueeze(1)).any(dim=1).float().mean().item()
    topk_m2e_acc = (topk_m2e == targets.unsqueeze(0)).any(dim=0).float().mean().item()

    print(f"Baseline ECG→MCG top-1:  {top1_e2m*100:.2f}%")
    print(f"Baseline MCG→ECG top-1:  {top1_m2e*100:.2f}%")
    print(f"Baseline ECG→MCG top-{k}: {topk_e2m_acc*100:.2f}%")
    print(f"Baseline MCG→ECG top-{k}: {topk_m2e_acc*100:.2f}%")

    return {
        "top1_e2m": top1_e2m,
        "top1_m2e": top1_m2e,
        "topk_e2m": topk_e2m_acc,
        "topk_m2e": topk_m2e_acc,
        "k": k,
    }


if __name__ == "__main__":
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)
    eval_koch_retrieval_baseline(device=device)


Using device: cuda
Total beats: 127
Baseline ECG→MCG top-1:  7.09%
Baseline MCG→ECG top-1:  7.87%
Baseline ECG→MCG top-5: 29.13%
Baseline MCG→ECG top-5: 29.13%
