# 2D ECG Masking Visualization

This notebook visualizes the different masking strategies for ECG-JEPA:

1. **1D Temporal Masking (Standard)**: Same mask applied to all 12 leads
2. **2D Random Block Masking**: Rectangular blocks in (lead x time) space
3. **2D Lead Group Masking**: Clinically motivated groups (inferior, lateral, anterior) at random time ranges

The 2D approaches allow different leads to have different masks, enabling the model to learn cross-lead relationships explicitly.


In [None]:
import sys
import os
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap

# Make the project-local `data` package importable by prepending the kernel's
# current working directory to the module search path.
# NOTE(review): this treats the working directory as the project root, while
# DATA_PATH in the configuration cell is relative ("../data/...") — confirm
# which directory the kernel is expected to start in.
project_root = os.getcwd()
sys.path.insert(0, project_root)

from data.masks import MaskCollator, MaskCollator2D, MaskCollator2DLeadGroup

# Plot appearance: whitegrid base style, sans-serif text, axis grid disabled.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'font.family': 'sans-serif',
    'axes.grid': False,
})

print("Modules loaded successfully!")


In [None]:
# Configuration
DATA_PATH = Path("../data/ptb-xl.npy")  # Adjust path as needed
PATCH_SIZE = 25                          # samples per patch
NUM_PATCHES = 200                        # patches per lead
CROP_SIZE = PATCH_SIZE * NUM_PATCHES     # samples per lead after crop/pad
CHANNELS = 12
SAMPLING_FREQUENCY = 500                 # samples per second

# Fixed seed so that Restart & Run All reproduces the same ECG selection and
# crop (both drawn from numpy's global RNG) and the same torch-based sampling.
# NOTE(review): if the collators in data.masks draw from Python's `random`
# module, seed that as well — their RNG source is not visible from here.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Row order of the 12 leads; index i in the ECG array / mask grid is LEAD_NAMES[i].
LEAD_NAMES = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Clinical lead groupings (values are indices into LEAD_NAMES)
LEAD_GROUPS = {
    'inferior': [1, 2, 5],       # II, III, aVF
    'lateral': [0, 4, 10, 11],   # I, aVL, V5, V6
    'anterior': [6, 7, 8, 9],    # V1, V2, V3, V4
}

def load_single_ecg(path: Path) -> torch.Tensor:
    """Load one random ECG from a ``.npy`` file, or synthesize one if it is missing.

    The returned signal always has ``CROP_SIZE`` samples per lead: shorter
    recordings are zero-padded at the end, longer ones are cropped at a random
    offset.

    Args:
        path: NumPy array file holding a stack of recordings, each shaped
            ``(CHANNELS, samples)`` or ``(samples, CHANNELS)``.

    Returns:
        A float32 tensor of shape ``(CHANNELS, CROP_SIZE)``.

    Raises:
        ValueError: If the selected recording is not 2D or has no axis of
            length ``CHANNELS``.
    """
    if path.is_file():
        # Memory-map the file and copy out only the selected recording.
        arr = np.load(path, mmap_mode='r')
        idx = np.random.randint(0, len(arr))
        sample = np.array(arr[idx])
        if sample.ndim != 2:
            raise ValueError(
                f"Expected a 2D recording at index {idx}, got shape {sample.shape}"
            )
        # Accept either (leads, time) or (time, leads) layouts.
        ecg = sample if sample.shape[0] == CHANNELS else sample.T
        if ecg.shape[0] != CHANNELS:
            raise ValueError(
                f"Recording {idx} has shape {sample.shape}; expected one axis of length {CHANNELS}"
            )
        if ecg.shape[1] < CROP_SIZE:
            # Too short: zero-pad on the right up to CROP_SIZE samples.
            pad = np.zeros((CHANNELS, CROP_SIZE), dtype=ecg.dtype)
            pad[:, :ecg.shape[1]] = ecg
            ecg = pad
        elif ecg.shape[1] > CROP_SIZE:
            # Too long: take a random window of CROP_SIZE samples.
            start = np.random.randint(0, ecg.shape[1] - CROP_SIZE + 1)
            ecg = ecg[:, start:start + CROP_SIZE]
        print(f"Loaded ECG #{idx} from {path}")
        return torch.from_numpy(ecg).float()

    print("File not found. Using synthetic signal.")
    # Derive the signal duration from the configuration instead of hard-coding
    # 10 s, so the synthetic time axis stays consistent if CROP_SIZE or
    # SAMPLING_FREQUENCY change.
    duration_s = CROP_SIZE / SAMPLING_FREQUENCY
    t = np.linspace(0, duration_s, CROP_SIZE)
    ecg = np.zeros((CHANNELS, CROP_SIZE))
    for i in range(CHANNELS):
        # Slow sinusoidal baseline with a per-lead phase offset.
        ecg[i] = 0.1 * np.sin(2 * np.pi * 1.2 * t + i * 0.3)
        # Add a 20-sample noise burst every 0.8 s, starting at 0.5 s.
        for qrs_time in np.arange(0.5, duration_s, 0.8):
            qrs_idx = int(qrs_time * SAMPLING_FREQUENCY)
            if qrs_idx < CROP_SIZE - 50:
                ecg[i, qrs_idx:qrs_idx+20] += np.random.randn(20) * 0.5
    return torch.from_numpy(ecg).float()

# Load (or synthesize) the ECG used by every masking example below.
ecg = load_single_ecg(DATA_PATH)
# Describe the expected layout from the configuration rather than hard-coded
# numbers, so the message stays correct if CHANNELS or CROP_SIZE are changed.
print(f"ECG shape: {ecg.shape} ({CHANNELS} leads x {CROP_SIZE} samples)")


## Create Mask Collators


In [None]:
# Instantiate one collator per masking strategy. All three use the same patch
# size and keep-ratio range; the 2D variants additionally receive the lead count.
keep_ratio_range = dict(min_keep_ratio=0.15, max_keep_ratio=0.25)

collator_1d = MaskCollator(
    patch_size=PATCH_SIZE,
    min_block_size=10,
    **keep_ratio_range,
)

collator_2d_random = MaskCollator2D(
    patch_size=PATCH_SIZE,
    num_leads=CHANNELS,
    min_block_leads=2,
    max_block_leads=6,
    min_block_time=15,
    max_block_time=50,
    **keep_ratio_range,
)

collator_2d_lead_group = MaskCollator2DLeadGroup(
    patch_size=PATCH_SIZE,
    num_leads=CHANNELS,
    min_block_time=20,
    max_block_time=60,
    **keep_ratio_range,
)

print("Collators created!")


## Generate Masks and Statistics


In [None]:
# Run each collator on a batch holding the single ECG. The 1D collator returns
# the context/target masks directly; the 2D collators return each mask as the
# first element of a pair (the second element is discarded here).
_, mask_1d_ctx, mask_1d_tgt = collator_1d([ecg])
_, (mask_2d_rand_ctx, _), (mask_2d_rand_tgt, _) = collator_2d_random([ecg])
_, (mask_2d_group_ctx, _), (mask_2d_group_tgt, _) = collator_2d_lead_group([ecg])


def _pct(count, total):
    """Return `count` as a percentage of `total`."""
    return count / total * 100


divider = "=" * 60
total_tokens = CHANNELS * NUM_PATCHES

# 1D: counts are time patches, measured against NUM_PATCHES.
ctx_patches = mask_1d_ctx.shape[1]
tgt_patches = mask_1d_tgt.shape[1]
print(divider)
print("1D Temporal Masking:")
print(f"  Context patches: {ctx_patches} ({_pct(ctx_patches, NUM_PATCHES):.1f}%)")
print(f"  Target patches: {tgt_patches} ({_pct(tgt_patches, NUM_PATCHES):.1f}%)")
print("  Same mask for all 12 leads")

# 2D: counts are (lead, time) tokens, measured against CHANNELS * NUM_PATCHES.
reports_2d = (
    ("2D Random Block Masking:", mask_2d_rand_ctx, mask_2d_rand_tgt,
     "  Different masks per lead (2D blocks)"),
    ("2D Lead Group Masking:", mask_2d_group_ctx, mask_2d_group_tgt,
     "  Clinically motivated lead groups"),
)
for heading, ctx_mask, tgt_mask, remark in reports_2d:
    ctx_tokens = ctx_mask.shape[1]
    tgt_tokens = tgt_mask.shape[1]
    print("\n" + divider)
    print(heading)
    print(f"  Context tokens: {ctx_tokens} ({_pct(ctx_tokens, total_tokens):.1f}%)")
    print(f"  Target tokens: {tgt_tokens} ({_pct(tgt_tokens, total_tokens):.1f}%)")
    print(remark)


## Visualization Functions


In [None]:
def convert_1d_to_2d_mask(mask_ctx, mask_tgt, num_leads=12, num_patches=200):
    """Expand a 1D (time-only) context mask into a lead x time grid for plotting.

    Args:
        mask_ctx: Batched context patch indices; only the first batch element
            (``mask_ctx[0]``) is read, via ``.numpy()``.
        mask_tgt: Not used. Every patch that is not listed as context is drawn
            as target.
        num_leads: Number of grid rows.
        num_patches: Number of grid columns.

    Returns:
        ``(num_leads, num_patches)`` array with 0 for context and 1 for target.
    """
    grid = np.ones((num_leads, num_patches))
    context_columns = mask_ctx[0].numpy()
    if context_columns.size:
        # The mask is shared by all leads, so each context index clears a
        # whole column in a single assignment.
        grid[:, context_columns] = 0
    return grid


def convert_2d_indices_to_grid(mask_ctx, mask_tgt, num_leads=12, num_patches=200):
    """Convert 2D mask indices (lead, time) pairs to 2D grid for visualization.

    Args:
        mask_ctx: Batched context indices; only the first batch element
            (``mask_ctx[0]``) is read, via ``.numpy()``, and is expected to be
            a ``(K, 2)`` array of ``(lead, time)`` pairs.
        mask_tgt: Not used. Every cell that is not listed as context is drawn
            as target.
        num_leads: Number of grid rows.
        num_patches: Number of grid columns.

    Returns:
        ``(num_leads, num_patches)`` array with 0 for context and 1 for target.

    Raises:
        ValueError: If the index array is non-empty but not shaped ``(K, 2)``.
    """
    grid = np.ones((num_leads, num_patches))  # Default: target
    ctx_indices = mask_ctx[0].numpy()  # (K, 2)
    if ctx_indices.size == 0:
        return grid
    if ctx_indices.ndim != 2 or ctx_indices.shape[1] != 2:
        raise ValueError(
            f"Expected (K, 2) (lead, time) index pairs, got shape {ctx_indices.shape}"
        )

    leads, times = ctx_indices[:, 0], ctx_indices[:, 1]
    # Skip pairs outside the grid. The lower bound matters as much as the upper
    # one: a negative index would otherwise wrap around and mark a cell at the
    # opposite edge of the grid as context.
    in_bounds = (
        (leads >= 0) & (leads < num_leads)
        & (times >= 0) & (times < num_patches)
    )
    grid[leads[in_bounds], times[in_bounds]] = 0
    return grid


def plot_mask_heatmap(grid, title, lead_names=LEAD_NAMES):
    """Draw a context/target mask grid as a two-colour heatmap and show it.

    Args:
        grid: 2D array (leads x time patches) with 0 for context cells and
            1 for target cells.
        title: Figure title.
        lead_names: Y-axis tick labels, one per grid row.

    Returns:
        The matplotlib Figure. The figure has already been shown via
        ``plt.show()`` by the time it is returned.
    """
    context_color, target_color = '#2E86AB', '#E94F37'
    n_leads = len(lead_names)

    fig, ax = plt.subplots(figsize=(16, 5), dpi=100)
    # Value 0 maps to the first colour (context), value 1 to the second (target).
    ax.imshow(
        grid,
        aspect='auto',
        cmap=ListedColormap([context_color, target_color]),
        vmin=0,
        vmax=1,
    )

    ax.set_yticks(range(n_leads))
    ax.set_yticklabels(lead_names)
    ax.set_xlabel('Time Patch Index', fontsize=12)
    ax.set_ylabel('Lead', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(np.arange(0, grid.shape[1], 20))

    # Thin separators on the boundary between each pair of adjacent lead rows.
    for row in range(n_leads - 1):
        ax.axhline(row + 0.5, color='white', linewidth=0.5, alpha=0.5)

    ax.legend(
        handles=[
            mpatches.Patch(facecolor=context_color, label='Context (visible to encoder)'),
            mpatches.Patch(facecolor=target_color, label='Target (to predict)'),
        ],
        loc='upper right',
        bbox_to_anchor=(1.15, 1),
    )

    # Coverage summary placed just right of the axes.
    ctx_pct = np.count_nonzero(grid == 0) / grid.size * 100
    tgt_pct = np.count_nonzero(grid == 1) / grid.size * 100
    summary = f'Context: {ctx_pct:.1f}%\nTarget: {tgt_pct:.1f}%'
    ax.text(1.02, 0.5, summary,
            transform=ax.transAxes, fontsize=10, verticalalignment='center')

    plt.tight_layout()
    plt.show()
    return fig


## 1. Standard 1D Temporal Masking

The same mask is applied to all 12 leads - this is the current ECG-JEPA approach.


In [None]:
grid_1d = convert_1d_to_2d_mask(mask_1d_ctx, mask_1d_tgt)
# plot_mask_heatmap already calls plt.show(); the trailing semicolon keeps the
# returned Figure from being rendered a second time as the cell's output.
plot_mask_heatmap(grid_1d, '1D Temporal Masking (Standard ECG-JEPA)\nSame mask for all leads');


## 2. 2D Random Block Masking

Random rectangular blocks are masked in the (lead x time) space. Different leads can have different masks.


In [None]:
grid_2d_random = convert_2d_indices_to_grid(mask_2d_rand_ctx, mask_2d_rand_tgt)
# plot_mask_heatmap already calls plt.show(); the trailing semicolon keeps the
# returned Figure from being rendered a second time as the cell's output.
plot_mask_heatmap(grid_2d_random, '2D Random Block Masking\nRectangular blocks in (lead x time) space');


## 3. 2D Lead Group Masking (Clinically Motivated)

Clinically meaningful lead groups are masked at random time ranges:
- **Inferior**: II, III, aVF (right coronary artery territory)
- **Lateral**: I, aVL, V5, V6 (circumflex artery territory)
- **Anterior**: V1, V2, V3, V4 (left anterior descending (LAD) artery territory)


In [None]:
grid_2d_group = convert_2d_indices_to_grid(mask_2d_group_ctx, mask_2d_group_tgt)
# plot_mask_heatmap already calls plt.show(); the trailing semicolon keeps the
# returned Figure from being rendered a second time as the cell's output.
plot_mask_heatmap(grid_2d_group, '2D Lead Group Masking (Clinically Motivated)\nInferior, Lateral, Anterior lead groups');


## Side-by-Side Comparison


In [None]:
# One stacked panel per masking strategy, top to bottom.
panels = [
    (grid_1d, '1D Temporal Masking (Standard)'),
    (grid_2d_random, '2D Random Block Masking'),
    (grid_2d_group, '2D Lead Group Masking (Clinical)'),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(16, 12), dpi=100)

# Value 0 (context) maps to the first colour, value 1 (target) to the second.
cmap = ListedColormap(['#2E86AB', '#E94F37'])

for ax, (panel_grid, panel_title) in zip(axes, panels):
    ax.imshow(panel_grid, aspect='auto', cmap=cmap, vmin=0, vmax=1)
    ax.set_yticks(range(len(LEAD_NAMES)))
    ax.set_yticklabels(LEAD_NAMES)
    ax.set_title(panel_title, fontsize=12, fontweight='bold')
    # Share of context cells, annotated just right of the panel.
    ctx_pct = (panel_grid == 0).sum() / panel_grid.size * 100
    ax.text(1.01, 0.5, f'Context: {ctx_pct:.1f}%', transform=ax.transAxes, fontsize=10, va='center')

# Only the bottom panel carries the x-axis label.
axes[-1].set_xlabel('Time Patch Index', fontsize=12)

# Common legend
legend_elements = [
    mpatches.Patch(facecolor='#2E86AB', label='Context (visible)'),
    mpatches.Patch(facecolor='#E94F37', label='Target (predict)'),
]
fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.98))

plt.tight_layout()
# Leave room on the right for the per-panel annotations.
plt.subplots_adjust(right=0.92)
plt.show()


## Key Observations

**1D Temporal Masking**: The same mask is applied to all 12 leads, so the model can only use temporal context.

**2D Random Block Masking**: Rectangular blocks give each lead a different mask, which forces cross-lead reasoning.

**2D Lead Group Masking**: Masks clinically meaningful lead groups (inferior, lateral, anterior), making it the most interpretable of the three strategies.


In [None]:
# List each clinical lead group with the names of its member leads.
print("Clinical Lead Groups:")
print("=" * 50)
for group_name, lead_indices in LEAD_GROUPS.items():
    label = group_name.capitalize()
    members = ', '.join(LEAD_NAMES[k] for k in lead_indices)
    print(f"{label:12s}: {members}")
