In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os

# === Paths ===
BASE_PATH = "/kaggle/input/physionet-ecg-image-digitization"  # <-- update if different
train_csv = os.path.join(BASE_PATH, "train.csv")

# === Load train metadata ===
train_meta = pd.read_csv(train_csv)
print("Training metadata shape:", train_meta.shape)
print(train_meta.head())

# === Pick one example ===
example_id = train_meta.iloc[0]["id"]
print("Example ID:", example_id)
# Every file for a record lives under train/<id>/
example_dir = os.path.join(BASE_PATH, "train", str(example_id))

# === Load corresponding ECG CSV ===
# Columns are named by lead (this cell plots the "II" column).
signal_path = os.path.join(example_dir, f"{example_id}.csv")
signal_df = pd.read_csv(signal_path)
print("Signal shape:", signal_df.shape)
print("Columns:", signal_df.columns)

# Plot one lead
plt.figure(figsize=(12, 3))
plt.plot(signal_df["II"].values[:1000])
plt.title(f"Lead II waveform (first 1000 samples) — {example_id}")
plt.xlabel("Samples")
plt.ylabel("mV")  # NOTE(review): unit assumed to be mV — confirm against the dataset description
plt.show()

# === Load an ECG image (e.g., the clean synthetic one) ===
img_path = os.path.join(example_dir, f"{example_id}-0001.png")
img = cv2.imread(img_path)
# cv2.imread signals a missing/unreadable file by returning None rather than raising;
# without this check the failure surfaces as an opaque cv2.error inside cvtColor.
if img is None:
    raise FileNotFoundError(f"Could not read ECG image: {img_path}")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB order

plt.figure(figsize=(12, 6))
plt.imshow(img_rgb)
plt.title(f"ECG Image: {example_id}-0001.png")
plt.axis("off")
plt.show()


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def preprocess_ecg_image(img, block_size=51, thresh_c=8, open_kernel=2):
    """
    Turn an ECG image into (equalised grayscale, binary ink mask).

    Steps
    -----
    1. Grayscale: a 3-channel BGR image is converted with cv2.cvtColor; a 2-D
       (already single-channel) image is used as-is.
    2. Global histogram equalisation (cv2.equalizeHist) to stretch contrast.
    3. Gaussian adaptive threshold with THRESH_BINARY_INV: pixels darker than the
       Gaussian-weighted mean of their ``block_size`` x ``block_size`` neighbourhood
       minus ``thresh_c`` become 255, all others 0, so dark ink comes out white on black.
    4. Morphological opening with an ``open_kernel`` x ``open_kernel`` square to drop
       specks smaller than the kernel.
    No deskewing is performed.

    Parameters
    ----------
    img : np.ndarray
        uint8 image: (H, W, 3) BGR as returned by cv2.imread, or (H, W) grayscale.
    block_size : int
        Odd neighbourhood size in pixels for the adaptive threshold (default 51).
    thresh_c : int or float
        Constant subtracted from the local mean (default 8).
    open_kernel : int
        Side length of the square opening kernel (default 2).

    Returns
    -------
    gray : np.ndarray
        Equalised uint8 grayscale image, (H, W).
    cleaned : np.ndarray
        uint8 mask with values 0/255, (H, W); dark (ink) pixels are 255.

    Raises
    ------
    ValueError
        If ``img`` is None (e.g. cv2.imread could not read the file).
    """
    if img is None:
        raise ValueError("img is None (did cv2.imread fail to read the file?)")

    # 1) grayscale — accept already-gray input instead of failing in cvtColor
    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # 2) normalise brightness / contrast
    gray = cv2.equalizeHist(gray)

    # 3) locally adaptive threshold, inverted so dark ink becomes the white foreground
    thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV,
        block_size, thresh_c,
    )

    # 4) opening (erode then dilate) removes isolated small blobs
    kernel = np.ones((open_kernel, open_kernel), np.uint8)
    cleaned = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)

    return gray, cleaned


# === Try the preprocessing on the example record ===
img_path = "/kaggle/input/physionet-ecg-image-digitization/train/7663343/7663343-0001.png"
img = cv2.imread(img_path)

gray, cleaned = preprocess_ecg_image(img)

# Side-by-side: equalised grayscale (left) and binary mask (right)
panels = [
    (gray, "Grayscale Equalized Image"),
    (cleaned, "Thresholded + Cleaned Image (ECG trace visible)"),
]
plt.figure(figsize=(15,5))
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture, cmap="gray")
    plt.title(caption)
    plt.axis("off")
plt.show()


In [None]:
import cv2
import matplotlib.pyplot as plt

def split_into_leads(image, rows=3, cols=4):
    """
    Cut an image into a ``rows`` x ``cols`` grid of rectangles, in row-major order.

    The grid is purely geometric: it does not look for panel borders, lead labels or
    rhythm strips. Each cell is ``h // rows`` by ``w // cols`` pixels, except that the
    last row and last column absorb any remainder. Every pixel therefore lands in
    exactly one crop, and nothing at the right or bottom edge is discarded.

    Parameters
    ----------
    image : np.ndarray
        Array whose first two axes are (height, width); any trailing channel axis
        is kept.
    rows, cols : int
        Grid shape (default 3 x 4 = 12 cells). Must be positive.

    Returns
    -------
    list of ((x0, y0, x1, y1), np.ndarray)
        Box corners (half-open, pixel units) and the corresponding crop. Crops are
        views into ``image``. They are ordered left-to-right, then top-to-bottom.

    Raises
    ------
    ValueError
        If ``rows`` or ``cols`` is not positive.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be positive, got rows={rows}, cols={cols}")

    h, w = image.shape[:2]
    lead_h, lead_w = h // rows, w // cols
    boxes = []
    for r in range(rows):
        y0 = r * lead_h
        y1 = h if r == rows - 1 else y0 + lead_h  # last row takes the leftover pixels
        for c in range(cols):
            x0 = c * lead_w
            x1 = w if c == cols - 1 else x0 + lead_w  # last column likewise
            boxes.append(((x0, y0, x1, y1), image[y0:y1, x0:x1]))
    return boxes


# --- Reload the example image and preprocess it ---
img_path = "/kaggle/input/physionet-ecg-image-digitization/train/7663343/7663343-0001.png"
img = cv2.imread(img_path)
gray, cleaned = preprocess_ecg_image(img)

# --- Cut the binary mask into a 3 x 4 grid of panels (row-major) ---
lead_boxes = split_into_leads(cleaned)

# Show every panel so the grid cut can be checked by eye
plt.figure(figsize=(12, 8))
for panel_no, (_box, crop) in enumerate(lead_boxes, start=1):
    plt.subplot(3, 4, panel_no)
    plt.imshow(crop, cmap="gray")
    plt.title(f"Lead {panel_no}")
    plt.axis("off")
plt.tight_layout()
plt.show()

# --- The second panel (row 0, column 1) is treated as "Lead II" from here on ---
# NOTE(review): this assumes the panels are ordered I, II, III, ... row-major;
# confirm against the actual printout layout before trusting the label.
_, leadII_img = lead_boxes[1]
plt.figure(figsize=(10, 3))
plt.imshow(leadII_img, cmap="gray")
plt.title("Lead II cropped region")
plt.axis("off")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

def trace_waveform(binary_img):
    """
    Convert a single-trace ECG image into a 1-D waveform (one value per column).

    The image is reduced to grayscale (if it has 3 channels). If it is predominantly
    bright (mean > 127, i.e. dark ink on light paper), it is inverted so that the trace
    is always the bright foreground. The trace position in each column is then the
    intensity-weighted vertical centroid of that column.

    Parameters
    ----------
    binary_img : np.ndarray
        (H, W) grayscale / binary image, or (H, W, 3) BGR image. Both white-on-black
        (e.g. THRESH_BINARY_INV output) and black-on-white inputs are accepted.

    Returns
    -------
    np.ndarray
        Float array of length W: trace height in pixels, measured upward from the
        bottom row, minus its mean.
    """
    if binary_img.ndim == 3:
        gray = cv2.cvtColor(binary_img, cv2.COLOR_BGR2GRAY)
    else:
        gray = binary_img.copy()

    # The centroid below weights *bright* pixels, so the trace must be bright.
    # A mostly-bright image has a light background with a dark trace and needs
    # inverting; a mostly-dark image is already white-on-black and must be kept.
    if np.mean(gray) > 127:
        gray = 255 - gray

    h = gray.shape[0]
    # Tiny epsilon keeps empty columns well-defined (their centroid is the middle row).
    weights = gray.astype(np.float32) / 255.0 + 1e-6
    row_index = np.arange(h)[:, None]
    y_positions = (row_index * weights).sum(axis=0) / weights.sum(axis=0)

    # Image rows grow downward; flip so larger values mean "higher on the page", then centre.
    heights = h - y_positions
    return heights - np.mean(heights)

# === Run the centroid tracer on the Lead II crop ===
leadII_waveform = trace_waveform(leadII_img)
n_columns = len(leadII_waveform)  # one sample per image column

plt.figure(figsize=(12, 3))
plt.plot(leadII_waveform)
plt.title("Extracted Waveform from Lead II image")
plt.xlabel("Column index (time proxy)")
plt.ylabel("Relative amplitude (pixels)")
plt.show()

print("Extracted waveform length:", n_columns)


In [None]:
from scipy.signal import resample
import numpy as np
import matplotlib.pyplot as plt

# --- Ground-truth Lead II for the example record ---
signal_path = "/kaggle/input/physionet-ecg-image-digitization/train/7663343/7663343.csv"
signal_df = pd.read_csv(signal_path)
gt_waveform = signal_df["II"].values
# NOTE(review): any NaN in this column would make the mean/std/corrcoef below NaN —
# confirm the "II" column is complete.

# --- Stretch/squeeze the pixel-column trace onto the ground-truth sample grid ---
resampled_waveform = resample(leadII_waveform, len(gt_waveform))

# --- z-score both series so only their shapes are compared ---
resampled_waveform = (resampled_waveform - resampled_waveform.mean()) / resampled_waveform.std()
gt_norm = (gt_waveform - gt_waveform.mean()) / gt_waveform.std()

# --- Overlay the first 2000 samples ---
# NOTE(review): the "~4 s" in the title presumes a ~500 Hz sampling rate.
preview = slice(None, 2000)
plt.figure(figsize=(12, 4))
plt.plot(gt_norm[preview], label="Ground Truth Lead II", linewidth=2)
plt.plot(resampled_waveform[preview], label="Extracted from Image", alpha=0.7)
plt.title("Overlay: True vs Extracted Waveform (First ~4 s)")
plt.xlabel("Samples")
plt.ylabel("Normalized amplitude")
plt.legend()
plt.show()

# --- Pearson correlation (no time alignment yet) as a rough quality score ---
corr = np.corrcoef(gt_norm, resampled_waveform)[0, 1]
print(f"Correlation between extracted and true waveform: {corr:.3f}")


In [None]:
from scipy.signal import correlate

# --- 1️⃣ Choose polarity and lag from the cross-correlation with the ground truth ---
# scipy: correlate(a, b, "full")[k] = sum_n a[n + lag] * b[n], with lag = k - (len(b) - 1)
# (see scipy.signal.correlation_lags). A strong *negative* peak means the extracted
# trace is upside down, so take whichever of max / (negated) min is larger instead of
# always flipping.
corr_full = correlate(gt_norm, resampled_waveform, mode="full")
if -corr_full.min() > corr_full.max():
    polarity, best_idx = -1.0, int(np.argmin(corr_full))
else:
    polarity, best_idx = 1.0, int(np.argmax(corr_full))
shift = best_idx - (len(resampled_waveform) - 1)
flipped_waveform = polarity * resampled_waveform
print(f"Polarity: {'flipped' if polarity < 0 else 'unchanged'}")
print(f"Best shift: {shift} samples")

# --- 2️⃣ Align so that gt[n + shift] pairs with extracted[n] ---
if shift >= 0:
    aligned_waveform = flipped_waveform
    gt_aligned = gt_waveform[shift:]
else:
    aligned_waveform = flipped_waveform[-shift:]
    gt_aligned = gt_waveform
n_common = min(len(aligned_waveform), len(gt_aligned))
aligned_waveform = aligned_waveform[:n_common]
gt_aligned = gt_aligned[:n_common]

# --- 3️⃣ Normalize and compare again (z-scoring makes any amplitude rescale irrelevant) ---
aligned_waveform = (aligned_waveform - np.mean(aligned_waveform)) / np.std(aligned_waveform)
gt_aligned_norm = (gt_aligned - np.mean(gt_aligned)) / np.std(gt_aligned)

plt.figure(figsize=(12,4))
plt.plot(gt_aligned_norm[:2000], label="Ground Truth (aligned)")
plt.plot(aligned_waveform[:2000], label="Extracted (polarity-corrected + aligned)", alpha=0.7)
plt.title("After alignment and polarity correction")
plt.xlabel("Samples")
plt.ylabel("Normalized amplitude")
plt.legend()
plt.show()

corr2 = np.corrcoef(gt_aligned_norm, aligned_waveform)[0,1]
print(f"Updated correlation after alignment: {corr2:.3f}")


In [None]:
import os

LEAD_NAMES = ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]

def extract_all_leads_from_image(img_path, lead_names=None):
    """
    Digitise every lead panel of one ECG image with the simple centroid tracer.

    Pipeline: cv2.imread -> preprocess_ecg_image (binary mask) -> split_into_leads
    (geometric 3 x 4 grid, row-major) -> trace_waveform on each panel.

    Parameters
    ----------
    img_path : str
        Path to the image file.
    lead_names : sequence of str, optional
        Names assigned to the panels in row-major order (left-to-right, then
        top-to-bottom). Defaults to LEAD_NAMES, so panel 2 (row 0, column 1) is
        labelled "II".
        NOTE(review): some 3 x 4 layouts place I/II/III down the first column
        instead — confirm this ordering matches the images.

    Returns
    -------
    dict
        lead name -> 1-D waveform (pixel units, mean-centred). Returns {} with a
        printed warning when the image cannot be read.

    Raises
    ------
    ValueError
        If fewer names are supplied than there are panels.
    """
    if lead_names is None:
        lead_names = LEAD_NAMES

    img = cv2.imread(img_path)
    if img is None:
        print("⚠️ Image not found:", img_path)
        return {}

    _, cleaned = preprocess_ecg_image(img)  # equalised gray image is not needed here
    lead_boxes = split_into_leads(cleaned)
    if len(lead_names) < len(lead_boxes):
        raise ValueError(
            f"Need at least {len(lead_boxes)} lead names, got {len(lead_names)}"
        )

    return {name: trace_waveform(crop) for name, (_, crop) in zip(lead_names, lead_boxes)}


# === Run the all-leads extractor on the same example ===
img_path = "/kaggle/input/physionet-ecg-image-digitization/train/7663343/7663343-0001.png"
leads_data = extract_all_leads_from_image(img_path)

# Preview the Lead II trace produced by the batch function (pixel units; no ground truth here)
extracted_II = leads_data["II"]
plt.figure(figsize=(12, 4))
plt.plot(extracted_II, label="Extracted Lead II (pixels)")
plt.title("Batch Extracted Lead II (from all-leads function)")
plt.xlabel("Column index")
plt.ylabel("Relative amplitude (pixels)")
plt.legend()
plt.show()

print("Extracted leads:", list(leads_data))


In [None]:
from scipy.signal import savgol_filter
import numpy as np
import cv2
import matplotlib.pyplot as plt

def remove_grid_background(img, span=5):
    """
    Suppress axis-aligned structure (e.g. straight grid lines) with an FFT band-stop.

    The centred 2-D spectrum is multiplied by a cross-shaped mask that zeroes:
    - every row within ``span`` bins of zero vertical frequency, and
    - every column within ``span`` bins of zero horizontal frequency.
    The DC term is included. This removes content that is (nearly) constant along
    either image axis. It also removes low-frequency content of the trace itself, so
    the result presumably retains mostly edges and fine detail.

    Parameters
    ----------
    img : np.ndarray
        BGR (H, W, 3) or grayscale (H, W) image.
    span : int
        Half-width, in frequency bins, of each zeroed band (default 5).

    Returns
    -------
    np.ndarray
        uint8 (H, W) magnitude of the filtered image, min-max scaled to 0..255.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img.copy()
    spectrum = np.fft.fftshift(np.fft.fft2(gray))  # DC moved to (rows//2, cols//2)

    rows, cols = gray.shape
    crow, ccol = rows // 2, cols // 2
    mask = np.ones((rows, cols), np.uint8)
    # Clamp the lower bound at 0. On images smaller than 2 * span, a negative start
    # would be read from the end of the array and zero the wrong bins.
    mask[max(crow - span, 0):crow + span, :] = 0
    mask[:, max(ccol - span, 0):ccol + span] = 0

    filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask))
    img_back = np.abs(filtered)

    return cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

def refined_trace(crop_img, window=21, polyorder=3):
    """
    Trace one lead panel after grid suppression, then Savitzky-Golay smooth it.

    Pipeline: remove_grid_background -> trace_waveform -> savgol_filter.

    Parameters
    ----------
    crop_img : np.ndarray
        Lead panel, grayscale (H, W) or BGR (H, W, 3).
    window : int
        Savitzky-Golay window length (default 21). If the panel has fewer columns,
        it is shrunk to the largest odd length that fits.
    polyorder : int
        Savitzky-Golay polynomial order (default 3).

    Returns
    -------
    np.ndarray
        One value per image column. If the panel is too narrow for a window longer
        than ``polyorder``, the waveform is returned unsmoothed.
    """
    no_grid = remove_grid_background(crop_img)
    waveform = trace_waveform(no_grid)

    # savgol_filter rejects windows longer than the signal (mode="interp") or not
    # longer than polyorder, so adapt the window instead of crashing on narrow crops.
    win = min(window, len(waveform))
    if win % 2 == 0:
        win -= 1  # odd window stays centred on each sample
    if win <= polyorder:
        return waveform
    return savgol_filter(waveform, win, polyorder)

# --- Apply the refined tracer to the Lead II panel of the *grayscale* image ---
img_path = "/kaggle/input/physionet-ecg-image-digitization/train/7663343/7663343-0001.png"
img = cv2.imread(img_path)
gray, _ = preprocess_ecg_image(img)  # grid removal works on gray levels, not the binary mask
leads = split_into_leads(gray)
_, leadII_crop = leads[1]

waveform_refined = refined_trace(leadII_crop)

plt.figure(figsize=(12, 3))
plt.plot(waveform_refined)
plt.title("Refined Lead II waveform (grid removed + smoothed)")
plt.xlabel("Column index")
plt.ylabel("Relative amplitude (pixels)")
plt.show()


In [None]:
import numpy as np
from scipy.signal import correlate

def compute_modified_snr(pred, true, fs, max_shift_s=0.2):
    """
    Shift-tolerant SNR in dB between a predicted and a reference signal.

    Procedure:
    1. Replace NaN/inf with 0 and z-score both signals.
    2. Find the lag in [-max_shift_s*fs, +max_shift_s*fs] samples that maximises their
       cross-correlation, then align them on their overlapping samples.
    3. Remove the remaining constant vertical offset.
    4. Return ``10 * log10(sum(true**2) / sum((true - pred)**2))`` over the overlap.

    NOTE(review): the docstring previously called this the PhysioNet "Modified SNR";
    the z-scoring step means absolute amplitude errors are not penalised — confirm
    this matches the official metric before relying on the numbers.

    Parameters
    ----------
    pred, true : array-like, 1-D
        Predicted and reference signals; they may have different lengths.
    fs : float
        Sampling rate (Hz) used to convert ``max_shift_s`` into samples.
    max_shift_s : float
        Largest time shift searched in either direction, in seconds.

    Returns
    -------
    float
        SNR in dB, or -100.0 for degenerate (near-constant) signals.
    """
    pred = np.nan_to_num(np.asarray(pred, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)
    true = np.nan_to_num(np.asarray(true, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)

    # check for constant or empty signals
    if np.std(true) < 1e-8 or np.std(pred) < 1e-8:
        return -100.0  # penalty for degenerate signals

    pred = (pred - np.mean(pred)) / (np.std(pred) + 1e-8)
    true = (true - np.mean(true)) / (np.std(true) + 1e-8)

    shift_max = max(int(max_shift_s * fs), 0)

    # correlate(true, pred)[k] pairs true[n + lag] with pred[n], lag = k - (len(pred) - 1).
    corr = correlate(true, pred, mode="full", method="direct")
    lags = np.arange(-len(pred) + 1, len(true))
    valid = np.abs(lags) <= shift_max
    # argmax is taken over the masked subset, so index the equally-masked lag array.
    best_shift = int(lags[valid][np.argmax(corr[valid])])

    # Align so that true[n + best_shift] lines up with pred[n], then trim to the overlap.
    if best_shift >= 0:
        pred_aligned = pred
        true_aligned = true[best_shift:]
    else:
        pred_aligned = pred[-best_shift:]
        true_aligned = true
    n_common = min(len(pred_aligned), len(true_aligned))
    pred_aligned = pred_aligned[:n_common]
    true_aligned = true_aligned[:n_common]

    offset = np.mean(true_aligned - pred_aligned)
    pred_aligned = pred_aligned + offset

    signal_power = np.sum(true_aligned ** 2)
    if signal_power <= 0:
        return -100.0  # overlap carries no reference energy; log10(0) would be -inf
    noise_power = np.sum((true_aligned - pred_aligned) ** 2) + 1e-12

    return float(10 * np.log10(signal_power / noise_power))



# === Score the refined Lead II trace for the example record ===
from scipy.signal import resample

signal_df = pd.read_csv("/kaggle/input/physionet-ecg-image-digitization/train/7663343/7663343.csv")
true_sig = signal_df["II"].values
meta_row = pd.read_csv("/kaggle/input/physionet-ecg-image-digitization/train.csv").query("id==7663343")
fs = int(meta_row["fs"].iloc[0])  # per-record sampling rate from the metadata table

# Bring the pixel-column trace onto the ground-truth sample grid
pred_sig = resample(waveform_refined, len(true_sig))

snr_val = compute_modified_snr(pred_sig, true_sig, fs)
print(f"Modified SNR for Lead II = {snr_val:.2f} dB")


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2, numpy as np, pandas as pd, os
from scipy.signal import resample 

# --- Dataset ---
class ECGLeadDataset(Dataset):
    """
    (image, single-lead signal) pairs, one item per record id.

    Item ``idx`` loads two files:
    - ``{base_path}/train/{pid}/{pid}-0001.png``, resized to 512x256 and kept in
      OpenCV's BGR channel order;
    - column ``lead`` of ``{pid}.csv``, FFT-resampled to ``target_len`` samples and
      z-scored.

    NOTE(review): NaNs in the lead column would propagate through resample and the
    z-scoring into the target — confirm the CSVs are complete for the chosen lead.
    """

    def __init__(self, base_path, ids, lead="II", target_len=5000, transform=None):
        self.base_path = base_path
        self.ids = ids
        self.lead = lead
        self.target_len = target_len
        self.transform = transform  # applied to the HWC uint8 image; ToTensor if None

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        pid = self.ids[idx]
        img_path = f"{self.base_path}/train/{pid}/{pid}-0001.png"
        sig_path = f"{self.base_path}/train/{pid}/{pid}.csv"

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:  # imread reports failure by returning None, not by raising
            raise FileNotFoundError(f"Could not read ECG image: {img_path}")
        img = cv2.resize(img, (512, 256))  # dsize is (width, height): array becomes 256x512x3

        signal = pd.read_csv(sig_path)[self.lead].values.astype(np.float32)
        # --- resample all signals to same target length ---
        signal = resample(signal, self.target_len)
        signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-8)

        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)

        return img, torch.tensor(signal, dtype=torch.float32)

# --- Simple UNet-like regression model ---
class ECGDigitizer(nn.Module):
    """
    Small convolutional encoder/decoder that regresses one value per image column.

    Architecture:
    - Encoder: three 3x3 convs (3 -> 16 -> 32 -> 64 channels), with a 2x max-pool
      after each of the first two.
    - Decoder: two stride-2 transposed convs upsample back (64 -> 32 -> 16), then a
      1x1 conv produces a single channel.
    - The single-channel map is averaged over image height.

    There are no skip connections. For input (B, 3, H, W) with H and W divisible
    by 4, the output is (B, W).
    """

    def __init__(self):
        super().__init__()
        enc_layers = [
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        ]
        dec_layers = [
            nn.ConvTranspose2d(64, 32, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 2, stride=2), nn.ReLU(),
            nn.Conv2d(16, 1, 1),
        ]
        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        column_map = self.decoder(self.encoder(x))  # (B, 1, H, W)
        return column_map.mean(dim=2).squeeze(1)     # collapse height -> (B, W)

# --- Training setup ---
SEED = 0
torch.manual_seed(SEED)  # fixes weight init and the DataLoader's shuffling order

transform = transforms.Compose([transforms.ToTensor()])
train_meta = pd.read_csv("/kaggle/input/physionet-ecg-image-digitization/train.csv")
ids = train_meta["id"].astype(str).tolist()[:50]  # subset

# Pass the transform explicitly so the image pipeline is visible in one place
# (ToTensor is also the dataset's fallback, so the tensors are the same).
dataset = ECGLeadDataset("/kaggle/input/physionet-ecg-image-digitization", ids, transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

model = ECGDigitizer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

for epoch in range(3):
    model.train()
    epoch_loss, n_batches = 0.0, 0
    for imgs, signals in loader:
        imgs, signals = imgs.to(device), signals.to(device)
        preds = model(imgs)  # (B, W)
        # Linearly resample the target_len-sample targets onto the W output columns
        signals_resamp = torch.nn.functional.interpolate(
            signals.unsqueeze(1), size=preds.shape[1], mode="linear"
        ).squeeze(1)
        loss = criterion(preds, signals_resamp)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Report the epoch mean; the final batch alone is a noisy estimate.
    print(f"Epoch {epoch+1}, loss={epoch_loss / max(n_batches, 1):.6f}")


In [None]:
import torch
import matplotlib.pyplot as plt
from scipy.signal import resample

# Pick one record by position in the training subset (index 5 — change if you like).
# NOTE(review): `ids` is the subset the model was just trained on, so this is an
# in-sample check, not a held-out evaluation.
pid = ids[5]
record_dir = f"/kaggle/input/physionet-ecg-image-digitization/train/{pid}"
img_path = f"{record_dir}/{pid}-0001.png"
sig_path = f"{record_dir}/{pid}.csv"

# Image -> (1, 3, 256, 512) float tensor on the model's device
img = cv2.resize(cv2.imread(img_path, cv2.IMREAD_COLOR), (512, 256))
img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)

# Ground truth -> 5000 samples, z-scored (same recipe as the dataset's target_len)
true_sig = resample(pd.read_csv(sig_path)["II"].values.astype(np.float32), 5000)
true_sig = (true_sig - np.mean(true_sig)) / (np.std(true_sig) + 1e-8)

# Inference without autograd bookkeeping
model.eval()
with torch.no_grad():
    pred = model(img_tensor).cpu().numpy().flatten()

# z-score the prediction so both curves share a scale
pred = (pred - np.mean(pred)) / (np.std(pred) + 1e-8)

plt.figure(figsize=(12,4))
plt.plot(true_sig, label="Ground Truth Lead II", linewidth=2)
plt.plot(pred, label="Model Prediction", alpha=0.7)
plt.title(f"Model Prediction vs Ground Truth — ID {pid}")
plt.xlabel("Samples")
plt.ylabel("Normalized Amplitude")
plt.legend()
plt.show()


In [None]:
def snr_loss(pred, target, eps=1e-6):
    """
    Negative mean SNR (dB) over a batch, usable as a differentiable training loss.

    Computation:
    - Inputs are (B, L) tensors; NaN/inf entries are replaced by 0 before use.
    - Per row: SNR = 10*log10((sum target^2 + eps) / (sum (target - pred)^2 + eps)).
    - The power ratio is clamped to [eps, 1e6], so each row's SNR is capped at 60 dB.
    - Any non-finite SNR is replaced by 0.
    Returns the negated batch mean.
    """
    def _finite(t):
        return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)

    pred, target = _finite(pred), _finite(target)

    residual = target - pred
    sig_pow = target.pow(2).sum(dim=1) + eps
    err_pow = residual.pow(2).sum(dim=1) + eps
    ratio = torch.clamp(sig_pow / err_pow, min=eps, max=1e6)

    snr_db = _finite(10 * torch.log10(ratio))
    return -snr_db.mean()



In [None]:
class ECGLeadDataset(Dataset):
    """
    (image, signal) pairs for every (record, lead) combination.

    Indexing:
    - Item ``idx`` maps to record ``ids[idx // len(leads)]`` and lead
      ``leads[idx % len(leads)]``, so consecutive indices walk through the leads of
      one record.
    - All leads of a record share the same image file.

    Each item:
    - the image, resized to 512x256 and kept in OpenCV's BGR order;
    - the lead column, FFT-resampled to ``target_len`` samples and z-scored.
    The lead index itself is not returned.
    """

    def __init__(self, base_path, ids, leads=None, target_len=5000, transform=None):
        self.base_path = base_path
        self.ids = ids
        self.leads = leads or ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]
        self.target_len = target_len
        self.transform = transform  # applied to the HWC uint8 image; ToTensor if None

    def __len__(self):
        return len(self.ids) * len(self.leads)

    def __getitem__(self, idx):
        pid = self.ids[idx // len(self.leads)]
        lead = self.leads[idx % len(self.leads)]
        img_path = f"{self.base_path}/train/{pid}/{pid}-0001.png"
        sig_path = f"{self.base_path}/train/{pid}/{pid}.csv"

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:  # imread reports failure by returning None, not by raising
            raise FileNotFoundError(f"Could not read ECG image: {img_path}")
        img = cv2.resize(img, (512, 256))
        signal = pd.read_csv(sig_path)[lead].values.astype(np.float32)
        signal = resample(signal, self.target_len)
        signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-8)

        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)

        return img, torch.tensor(signal, dtype=torch.float32)


In [None]:
lambda_snr = 0.7  # weight of the SNR term; the MSE term gets 1 - lambda_snr
criterion_mse = nn.MSELoss()

for epoch in range(5):
    model.train()  # an earlier inference cell may have switched the model to eval mode
    total_sum, snr_sum, n_ok, n_skipped = 0.0, 0.0, 0, 0
    for batch in loader:
        # The loader may yield (imgs, signals) from the single-lead dataset or
        # (imgs, signals, lead_idx) from the lead-indexed one. Only the latter feeds
        # lead indices to a lead-conditioned model, so handle both.
        batch = [t.to(device) for t in batch]
        if len(batch) == 3:
            imgs, signals, lead_idx = batch
            preds = model(imgs, lead_idx)
        else:
            imgs, signals = batch
            preds = model(imgs)

        signals_resamp = torch.nn.functional.interpolate(
            signals.unsqueeze(1), size=preds.shape[1], mode='linear'
        ).squeeze(1)

        loss_mse = criterion_mse(preds, signals_resamp)
        loss_snr = snr_loss(preds, signals_resamp)
        loss = (1 - lambda_snr) * loss_mse + lambda_snr * loss_snr

        # skip batches whose loss is NaN or inf instead of corrupting the weights
        if not torch.isfinite(loss):
            print("⚠️  Non-finite loss detected — skipping batch")
            optimizer.zero_grad(set_to_none=True)
            n_skipped += 1
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_sum += loss.item()
        snr_sum += loss_snr.item()
        n_ok += 1

    # Epoch means over the batches actually used (the last batch alone is noisy and
    # may have been skipped).
    if n_ok:
        print(f"Epoch {epoch+1}, total_loss={total_sum / n_ok:.6f}, snr_loss={snr_sum / n_ok:.4f}, skipped={n_skipped}")
    else:
        print(f"Epoch {epoch+1}: no finite batches (skipped={n_skipped})")


In [None]:
import numpy as np
from scipy.signal import resample
import matplotlib.pyplot as plt

def safe_normalize(sig):
    """
    Z-score ``sig`` after replacing NaN/inf with 0.

    A (near-)constant signal (std < 1e-8) is only mean-centred, which yields all
    zeros; otherwise the result is divided by ``std + 1e-8``.
    """
    clean = np.nan_to_num(sig, nan=0.0, posinf=0.0, neginf=0.0)
    centered = clean - np.mean(clean)
    spread = np.std(clean)
    if spread < 1e-8:
        return centered
    return centered / (spread + 1e-8)

def evaluate_record(pid, model, base_path="/kaggle/input/physionet-ecg-image-digitization", fs=500, target_len=5000):
    """
    Score one record: predict a waveform from its image and compute the SNR against each lead.

    Steps:
    - The image is loaded, resized to 512x256 and run through ``model`` once.
    - The model receives only the image, so its prediction is the same for every lead.
      That single prediction is compared with every column of ``{pid}.csv``.
    - Both signals are resampled to ``target_len`` samples and z-scored before
      ``compute_modified_snr``.

    Parameters
    ----------
    pid : str or int
        Record id (directory name under ``{base_path}/train``).
    model : torch.nn.Module
        Image-only model called as ``model(img_tensor)``.
    base_path : str
        Dataset root.
    fs : float
        Sampling rate of the CSV signals (callers pass the value from train.csv).
    target_len : int
        Number of samples both signals are resampled to (default 5000).

    Returns
    -------
    (float, dict)
        nan-mean SNR over leads, and lead name -> SNR in dB.

    Raises
    ------
    FileNotFoundError
        If the image cannot be read.
    """
    img_path = f"{base_path}/train/{pid}/{pid}-0001.png"
    sig_path = f"{base_path}/train/{pid}/{pid}.csv"
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:  # imread reports failure by returning None, not by raising
        raise FileNotFoundError(f"Could not read ECG image: {img_path}")
    img = cv2.resize(img, (512, 256))
    img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)

    true_df = pd.read_csv(sig_path)
    # After resampling len(true_df) samples to target_len, one sample no longer spans
    # 1/fs seconds; rescale fs so compute_modified_snr's shift window stays in seconds.
    eval_fs = fs * target_len / len(true_df)

    # The prediction does not depend on the lead, so run the model once, not per lead.
    model.eval()
    with torch.no_grad():
        raw_pred = model(img_tensor).cpu().numpy().flatten()
    pred = safe_normalize(resample(raw_pred, target_len))

    lead_snr = {}
    for lead in true_df.columns:
        # NOTE(review): NaNs in a lead column make resample return all-NaN, which
        # safe_normalize maps to zeros and the SNR to the -100 penalty — confirm the
        # CSV columns are complete.
        true_sig = safe_normalize(resample(true_df[lead].values, target_len))
        lead_snr[lead] = compute_modified_snr(pred, true_sig, eval_fs)

    avg_snr = np.nanmean(list(lead_snr.values()))
    return avg_snr, lead_snr


# --- Score the first three records of the training subset ---
test_ids = ids[:3]
for pid in test_ids:
    # per-record sampling rate from the metadata table
    fs = int(train_meta.query(f"id=={pid}")["fs"].iloc[0])
    avg_snr, per_lead = evaluate_record(pid, model, fs=fs)
    report = [f"\nRecord {pid}: Average SNR = {avg_snr:.2f} dB"]
    report += [f"  {k:>4}: {v:6.2f} dB" for k, v in per_lead.items()]
    print("\n".join(report))


In [None]:
import torch.nn.functional as F

class MultiLeadDigitizer(nn.Module):
    """
    Lead-conditioned conv encoder/decoder that regresses one value per image column.

    ``forward(x, lead_idx)`` maps an image batch (B, 3, H, W) and integer lead ids
    (B,) to a waveform of shape (B, W), for H and W divisible by 4.

    How the lead enters the model:
    - A learned 64-d lead vector is added channel-wise to the 64-channel encoder
      features, so the decoder can emit a lead-specific waveform *shape*.
    - Reducing the embedding to one scalar and adding it to the output would only
      shift the whole waveform by a constant.

    The output is lightly smoothed with a 5-tap moving average. Edge samples are
    averaged over real samples only (no zero padding pulls them toward 0).
    """

    def __init__(self, n_leads=12):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 2, stride=2), nn.ReLU(),
            nn.Conv2d(16, 1, 1)
        )
        self.lead_embed = nn.Embedding(n_leads, 64)  # one 64-d vector per lead (matches encoder channels)

    def forward(self, x, lead_idx):
        feat = self.encoder(x)                                   # (B, 64, H/4, W/4)

        # Per-lead, per-channel bias on the encoder features
        lead_vec = self.lead_embed(lead_idx).reshape(x.shape[0], -1)  # (B, 64)
        feat = feat + lead_vec[:, :, None, None]

        feat = self.decoder(feat)                                # (B, 1, H, W)

        # Collapse height -> 1D waveform
        wave = feat.mean(dim=2).squeeze(1)                       # (B, W)

        # Temporal smoothing (avg_pool1d needs B x C x L); exclude padding from the mean
        wave = F.avg_pool1d(
            wave.unsqueeze(1), kernel_size=5, stride=1, padding=2, count_include_pad=False
        ).squeeze(1)                                             # (B, W)

        return wave




In [None]:
LEADS = ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]
lead_to_idx = {l:i for i,l in enumerate(LEADS)}

class ECGLeadDataset(Dataset):
    """
    (image, signal, lead index) triples for every (record, lead) combination.

    Indexing:
    - Item ``idx`` maps to record ``ids[idx // len(leads)]`` and lead
      ``leads[idx % len(leads)]``.
    - All leads of a record share the same image file.

    Each item:
    - the image, resized to 512x256 (BGR order) and converted by ``transform`` or
      ToTensor;
    - the lead column, FFT-resampled to ``target_len`` samples and z-scored
      (float32);
    - the lead's position in the global ``LEADS`` list (int64), looked up via
      ``lead_to_idx``.
    """

    def __init__(self, base_path, ids, leads=LEADS, target_len=5000, transform=None):
        self.base_path = base_path
        self.ids = ids
        self.leads = list(leads)  # private copy so the shared default list is never aliased
        self.target_len = target_len
        self.transform = transform  # applied to the HWC uint8 image; ToTensor if None

    def __len__(self):
        return len(self.ids) * len(self.leads)

    def __getitem__(self, idx):
        pid = self.ids[idx // len(self.leads)]
        lead = self.leads[idx % len(self.leads)]
        img_path = f"{self.base_path}/train/{pid}/{pid}-0001.png"
        sig_path = f"{self.base_path}/train/{pid}/{pid}.csv"

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:  # imread reports failure by returning None, not by raising
            raise FileNotFoundError(f"Could not read ECG image: {img_path}")
        img = cv2.resize(img, (512, 256))
        sig = pd.read_csv(sig_path)[lead].values.astype(np.float32)
        sig = resample(sig, self.target_len)
        sig = (sig - np.mean(sig)) / (np.std(sig) + 1e-8)

        img = self.transform(img) if self.transform else transforms.ToTensor()(img)
        return (
            img,
            torch.tensor(sig, dtype=torch.float32),  # explicit, matching the other dataset variants
            torch.tensor(lead_to_idx[lead], dtype=torch.long),
        )


In [None]:
dataset = ECGLeadDataset("/kaggle/input/physionet-ecg-image-digitization", ids)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = MultiLeadDigitizer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Defined here so this cell does not depend on an earlier training cell having run.
criterion_mse = nn.MSELoss()
lambda_snr = 0.7  # weight of the SNR term; MSE gets 1 - lambda_snr (0.3)

for epoch in range(5):
    model.train()
    loss_sum, n_ok, n_skipped = 0.0, 0, 0
    for imgs, signals, lead_idx in loader:
        imgs, signals, lead_idx = imgs.to(device), signals.to(device), lead_idx.to(device)
        preds = model(imgs, lead_idx)
        # Linearly resample the targets onto the model's W output columns
        signals_resamp = F.interpolate(signals.unsqueeze(1), size=preds.shape[1], mode='linear').squeeze(1)
        loss_mse = criterion_mse(preds, signals_resamp)
        loss_snr = snr_loss(preds, signals_resamp)
        loss = (1 - lambda_snr) * loss_mse + lambda_snr * loss_snr

        # Never step on a NaN/inf loss; it would corrupt the weights
        if not torch.isfinite(loss):
            n_skipped += 1
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        n_ok += 1

    # Epoch mean over the batches actually used
    mean_loss = loss_sum / n_ok if n_ok else float("nan")
    print(f"Epoch {epoch+1}: loss={mean_loss:.5f}, skipped={n_skipped}")
