## Acknowledgements & Additions

This notebook builds upon the open ECG digitizer pipeline released by Ángel Jacinto Sánchez Ruiz (https://www.kaggle.com/code/sacuscreed/open-ecg-digitizer-5e8bfc) and the Open-ECG-Digitizer team. Their models, training code, and shared weights make this solution possible.

Key additions in this revision:
- Dynamic resampling plus configurable post-processing tailored to the PhysioNet ECG Digitization task.
- Lead-wise quality diagnostics with drift, amplitude, and stability metrics exported for offline inspection.
- Optional Gaussian smoothing and baseline correction to denoise extracted waveforms while preserving morphology.
- Caveat: this revision scores much lower than the 7.32 achieved by Ángel Jacinto Sánchez Ruiz's original notebook. :P


In [None]:
%cd /kaggle/input/open-ecg-digitizer/pytorch/default/1
import torch
import glob
import os.path
import yaml
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torchvision.io import read_image
from torch import Tensor
# Imports from https://github.com/Ahus-AIM/Open-ECG-Digitizer
from src.model.unet import UNet
from src.model.perspective_detector import PerspectiveDetector
from src.model.cropper import Cropper
from src.model.pixel_size_finder import PixelSizeFinder
from src.model.signal_extractor import SignalExtractor
from src.model.lead_identifier import LeadIdentifier
from src.model.lead_identifier import LeadIdentifier

In [None]:
def add_noise_to_image(input_img, sigma=2, opacity=0.2):
    """Blend *input_img* with sigmoid-squashed Gaussian noise.

    The noise is ``sigmoid(randn_like(input_img) * sigma)``, so it lies in (0, 1).
    The result is the linear mix ``(1 - opacity) * input_img + opacity * noise``.
    """
    squashed_noise = (torch.randn_like(input_img) * sigma).sigmoid()
    keep_fraction = 1 - opacity
    return keep_fraction * input_img + opacity * squashed_noise

def load_model(weights_path, **kwargs):
    """Build a UNet and load its weights from *weights_path*.

    Optional keyword arguments (defaults in parentheses):
    ``num_in_channels`` (3), ``num_out_channels`` (4),
    ``dims`` ([32, 64, 128, 256, 320, 320, 320, 320]), ``depth`` (2) and
    ``device``. When ``device`` is not given, CUDA is used if available,
    otherwise CPU.

    Returns the model in eval mode, moved to the target device.
    """
    # Honour the caller's device instead of silently reading a module-level global.
    target_device = kwargs.get("device")
    if target_device is None:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(
        num_in_channels=kwargs.get("num_in_channels", 3),
        num_out_channels=kwargs.get("num_out_channels", 4),
        dims=kwargs.get(
            "dims",
            [32, 64, 128, 256, 320, 320, 320, 320],
        ),
        depth=kwargs.get("depth", 2),
    )
    state_dict = torch.load(weights_path, map_location=target_device)
    # Strip the "_orig_mod." prefix from all keys (the model was trained with torch.compile).
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval().to(target_device)
    return model


def load_png_file(path):
    """Load an image file into a float tensor of shape (1, 3, H, W) scaled to [0, 1].

    The tensor returned by ``read_image`` (C, H, W) is first scaled. Integer
    images are divided by their dtype's maximum (255 for 8-bit). A batch
    dimension is then added. Finally the image is coerced to exactly 3
    channels, because the segmentation UNet in this notebook is built with
    ``num_in_channels=3``:

    * 1 channel (grayscale) or 2 channels (grayscale + alpha): the first
      channel is repeated three times;
    * more than 3 channels (e.g. RGBA): only the first three are kept.
    """
    img = read_image(path)
    if img.is_floating_point():
        img = img.float()
    else:
        # 255.0 for uint8 (the common case); also scales 16-bit PNGs correctly.
        img = img.float() / float(torch.iinfo(img.dtype).max)
    img = img.unsqueeze(0)  # Add batch dimension
    num_channels = img.shape[1]
    if num_channels < 3:
        # Grayscale (+ alpha): replicate luminance so the 3-channel model can consume it.
        img = img[:, :1].repeat(1, 3, 1, 1)
    elif num_channels > 3:
        img = img[:, :3, :, :]
    return img


def _crop_y(
    image: Tensor,
    signal_prob: Tensor,
    grid_prob: Tensor,
    text_prob: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Crop all four maps vertically to the rows with above-average signal+grid mass.

    A per-row profile of ``signal_prob + grid_prob`` is computed by summing over
    every axis except dim 2 (height). Rows whose profile exceeds the profile mean
    are "active". Every tensor is cropped on dim 2 from the first to the last
    active row, inclusive. If no row is active, nothing is cropped.

    Assumes NCHW tensors that share the same height (dim 2).

    Returns (image, signal_prob, grid_prob, text_prob), each cropped.
    """

    def get_bounds(tensor: Tensor) -> tuple[int, int]:
        # One value per row; robust to H == 1 and to batch/channel sizes > 1.
        profile = tensor.sum(dim=(0, 1, 3))
        excess = torch.clamp(profile - profile.mean(), min=0)
        active_rows = (excess > 0).nonzero(as_tuple=True)[0]
        if active_rows.numel() == 0:
            return 0, tensor.shape[2] - 1
        return int(active_rows[0].item()), int(active_rows[-1].item())

    y1, y2 = get_bounds(signal_prob + grid_prob)

    slices = (slice(None), slice(None), slice(y1, y2 + 1), slice(None))
    return (
        image[slices],
        signal_prob[slices],
        grid_prob[slices],
        text_prob[slices],
    )


def _align_feature_maps(
    cropper: Cropper,
    image: Tensor,
    signal_prob: Tensor,
    grid_prob: Tensor,
    text_prob: Tensor,
    source_points: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Warp the image and probability maps with the cropper, then crop rows.

    Each tensor goes through
    ``cropper.apply_perspective(t, source_points, fill_value=0)``. The warped
    tensors are then cropped vertically by ``_crop_y``.

    Returns (image, signal_prob, grid_prob, text_prob) after alignment and
    cropping.
    """

    def warp(tensor: Tensor) -> Tensor:
        return cropper.apply_perspective(tensor, source_points, fill_value=0)

    # Warp order is kept as: signal, image, grid, text.
    warped_signal = warp(signal_prob)
    warped_image = warp(image)
    warped_grid = warp(grid_prob)
    warped_text = warp(text_prob)

    return _crop_y(warped_image, warped_signal, warped_grid, warped_text)


def plot_segmentation_and_image(
    image,
    segmentation,
    aligned_signal,
    aligned_grid,
    lines,
):
    """Show a 2x2 diagnostic figure for one ECG image.

    Panels:
      (0, 0) the input image;
      (0, 1) segmentation overlay: channel 2 (signal) darkens all colour
             channels, channel 0 (grid) is drawn in red;
      (1, 0) the same overlay built from the perspective-aligned signal/grid maps;
      (1, 1) the extracted lead traces with fixed vertical offsets.

    The input tensors are left unmodified; normalisation works on new tensors.
    """
    import matplotlib.pyplot as plt

    def _unit_peak(t):
        # Scale to a peak of 1; an all-zero (or NaN-peaked) map is returned
        # unscaled instead of being divided by 0.
        peak = t.max()
        return t / peak if float(peak) > 0 else t

    image_np = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
    probs = segmentation.squeeze(0).cpu()

    # Start from white and subtract the maps so that empty pixels stay white.
    show_featuremap = torch.ones(probs.shape[1], probs.shape[2], 3)
    show_featuremap[:, :, [0, 1, 2]] -= 2 * _unit_peak(probs[2]).unsqueeze(-1)
    show_featuremap[:, :, [1, 2]] -= probs[0].unsqueeze(-1)
    show_featuremap = torch.clamp(show_featuremap, 0, 1).numpy()

    straightened_featuremap = torch.ones(
        aligned_signal.shape[2],
        aligned_signal.shape[3],
        3,
        device=aligned_signal.device,
    )
    straightened_featuremap[:, :, [0, 1, 2]] -= 2 * _unit_peak(aligned_signal)[0, 0].unsqueeze(-1)
    straightened_featuremap[:, :, [1, 2]] -= _unit_peak(aligned_grid)[0, 0].unsqueeze(-1)
    straightened_featuremap = torch.clamp(straightened_featuremap, 0, 1).cpu().numpy()

    fig, ax = plt.subplots(2, 2, figsize=(16, 12))
    ax[0, 0].imshow(image_np)
    ax[0, 0].axis("off")

    ax[0, 1].imshow(show_featuremap)
    ax[0, 1].axis("off")

    ax[1, 0].imshow(straightened_featuremap)
    ax[1, 0].axis("off")

    # Per-lead vertical offsets so the traces do not overlap in the plot.
    offsets = [-0, -9, -6, -0, -3, -6, -0, -3, -6, -0, -3, -6]
    if lines.numel() > 0:
        ax[1, 1].plot(lines.T.cpu().numpy() + offsets[: lines.shape[0]])
    ax[1, 1].axis("off")
    plt.tight_layout()
    plt.show()


def crop_image(image, probs):
    """Rectify the image and its segmentation maps, then crop them vertically.

    The channels of ``probs`` are split as 0 = grid, 1 = text, 2 = signal.
    Channel 0 of the first batch item drives the perspective detector, and
    channel 1 is passed to the cropper along with the detected parameters to
    obtain the source points.

    Returns (aligned_image, aligned_signal_prob, aligned_grid_prob,
    aligned_text_prob).
    """
    detector = PerspectiveDetector(num_thetas=200)
    cropper = Cropper(percentiles=(0.03, 0.97), alpha=0.99)

    perspective_params = detector(probs[0, 0])
    corner_points = cropper(probs[0, 1], perspective_params)

    # Keep a singleton channel dimension on each map (index with a list).
    grid_map = probs[:, [0]]
    text_map = probs[:, [1]]
    signal_map = probs[:, [2]]

    return _align_feature_maps(
        cropper,
        image,
        signal_map,
        grid_map,
        text_map,
        corner_points,
    )


# Lazily populated cache for the lead-name UNet and the lead-layout definitions.
# Neither depends on the image being processed, so both are loaded from disk
# once instead of once per image.
_LEAD_IDENTIFIER_ASSETS = {}


def _load_lead_identifier_assets():
    """Return (layout_unet, layouts), loading them from disk only on first use."""
    if not _LEAD_IDENTIFIER_ASSETS:
        layout_unet = load_model(
            weights_path="/kaggle/input/open-ecg-digitizer-weights/pytorch/default/1/lead_name_unet_weights_07072025.pt",
            num_in_channels=1,
            num_out_channels=13,
            dims=[32, 64, 128, 256, 256],
            depth=2,
            device=device,
        )
        # Relative path: relies on the working directory set by %cd at the top of the notebook.
        with open("src/config/lead_layouts_george-moody-2024.yml", "r") as layout_file:
            layouts = yaml.safe_load(layout_file)
        _LEAD_IDENTIFIER_ASSETS["unet"] = layout_unet
        _LEAD_IDENTIFIER_ASSETS["layouts"] = layouts
    return _LEAD_IDENTIFIER_ASSETS["unet"], _LEAD_IDENTIFIER_ASSETS["layouts"]


def extract_signals(
    aligned_signal_prob: Tensor,
    aligned_grid_prob: Tensor,
    aligned_text_prob: Tensor,
    target_num_samples: int,
):
    """Turn aligned probability maps into per-lead signals.

    Steps:
    1. Estimate the pixel size (mm per pixel, x and y) from the grid map.
    2. Trace signal curves from the signal map.
    3. Let ``LeadIdentifier`` assign the curves to leads, using the text map and
       the average pixels-per-mm, resampled to ``target_num_samples``.

    Returns whatever ``LeadIdentifier`` returns. The inference loop indexes it
    with ``"canonical_lines"``.
    """
    import copy

    pixel_size_finder = PixelSizeFinder(
        min_number_of_grid_lines=30,
        max_number_of_grid_lines=100,
        lower_grid_line_factor=0.1,
    )
    signal_extractor = SignalExtractor()

    layout_unet, layouts = _load_lead_identifier_assets()

    identifier = LeadIdentifier(
        # Deep copy so the cached layouts stay pristine even if LeadIdentifier mutates them.
        layouts=copy.deepcopy(layouts),
        unet=layout_unet,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        possibly_flipped=False,
        target_num_samples=target_num_samples,
        required_valid_samples=2,
    )
    mm_per_pixel_x, mm_per_pixel_y = pixel_size_finder(aligned_grid_prob)

    avg_pixel_per_mm = (1 / mm_per_pixel_x + 1 / mm_per_pixel_y) / 2
    signals = signal_extractor(aligned_signal_prob.squeeze())

    return identifier(
        signals,
        aligned_text_prob,
        avg_pixel_per_mm=avg_pixel_per_mm,
    )

def resample_image(image: Tensor, resample_size: int) -> Tensor:
    """Resize an NCHW image with antialiased bilinear interpolation.

    Args:
        image: tensor whose dims 2 and 3 are height and width.
        resample_size: either of
            * an integer (Python or numpy): cap the longest side at this many
              pixels while keeping the aspect ratio. Images that already fit
              are returned unchanged; this never upsamples.
            * a ``(height, width)`` tuple or list: resize to exactly that size.

    Raises:
        ValueError: for any other ``resample_size``.
    """
    import numbers

    height, width = image.shape[2], image.shape[3]

    if isinstance(resample_size, numbers.Integral):
        longest = max(height, width)
        if longest <= resample_size:
            return image
        scale = resample_size / longest
        # max(1, ...) stops extreme aspect ratios from producing a zero-sized side.
        new_size = (max(1, int(height * scale)), max(1, int(width * scale)))
        return F.interpolate(image, size=new_size, mode="bilinear", align_corners=False, antialias=True)

    if isinstance(resample_size, (tuple, list)) and len(resample_size) == 2:
        target = (int(resample_size[0]), int(resample_size[1]))
        return F.interpolate(
            image, size=target, mode="bilinear", align_corners=False, antialias=True
        )

    raise ValueError(f"Invalid resample_size: {resample_size}. Expected int or tuple of (height, width).")

leads_names = ['I','II','III','aVR','aVL','aVF','V1','V2','V3','V4','V5','V6']

# Column index of each lead in the printed layout: column k covers samples
# [k * number_of_rows, (k + 1) * number_of_rows) of the lead's canonical signal.
# Lead II shares column 0; its (longer) number_of_rows is supplied by the caller.
_LEAD_COLUMN = {
    "I": 0, "II": 0, "III": 0,
    "aVR": 1, "aVL": 1, "aVF": 1,
    "V1": 2, "V2": 2, "V3": 2,
    "V4": 3, "V5": 3, "V6": 3,
}


def get_slice(lead_name: str, number_of_rows: int) -> slice:
    """Return the sample slice for *lead_name*, given its number of rows.

    Raises:
        ValueError: if *lead_name* is not one of ``leads_names``. This is an
            explicit check rather than an ``assert``, so it still fires under
            ``python -O``.
    """
    try:
        column = _LEAD_COLUMN[lead_name]
    except KeyError:
        raise ValueError(f"Unknown lead name: {lead_name!r}; expected one of {leads_names}") from None
    start = column * number_of_rows
    return slice(start, start + number_of_rows)
        


def adaptive_resample_size(image: Tensor, max_dim: int, min_dim: int = 2048) -> int:
    """Choose the longest-side target passed to ``resample_image``.

    Returns:
        * ``max_dim`` for a degenerate (zero-sized) image;
        * ``min_dim`` when it is truthy and the longest side is smaller than it;
        * otherwise the longest side, clamped to at most ``max_dim``.

    NOTE(review): ``resample_image`` only shrinks images for an integer target,
    so returning ``min_dim`` does not upsample small images.
    """
    _, _, rows, cols = image.shape
    longest_side = max(int(rows), int(cols))
    if longest_side <= 0:
        return max_dim
    below_floor = bool(min_dim) and longest_side < min_dim
    return min_dim if below_floor else min(longest_side, max_dim)

def _ensure_odd(value: int) -> int:
    """Truncate *value* with ``int()``, bump it to the next odd number, and floor at 1."""
    as_int = int(value)
    odd_value = as_int if as_int % 2 else as_int + 1
    return max(odd_value, 1)

def _rolling_mean(signal: np.ndarray, window: int) -> np.ndarray:
    """Centered moving average of *signal* with edge padding; output has the same length.

    For ``window <= 1`` an all-zero float32 array is returned rather than the
    signal itself. Callers subtracting this "baseline" then leave the signal
    unchanged.
    """
    if window > 1:
        head = window // 2
        tail = window - 1 - head
        extended = np.pad(signal, (head, tail), mode="edge")
        box_weights = np.ones(window, dtype=np.float32) / window
        averaged = np.convolve(extended, box_weights, mode="valid")
        return averaged.astype(np.float32)
    return np.zeros_like(signal, dtype=np.float32)

def _gaussian_smooth(signal: np.ndarray, kernel_size: int, sigma: float) -> np.ndarray:
    """Convolve *signal* with a normalised Gaussian kernel of *kernel_size* taps.

    The signal is edge-padded so the output keeps its length. ``sigma`` is in
    samples and is floored at 1e-3. For ``kernel_size <= 1`` the signal is
    returned unsmoothed, cast to float32.
    """
    if kernel_size > 1:
        positions = np.arange(kernel_size, dtype=np.float32) - (kernel_size - 1) / 2.0
        width = max(sigma, 1e-3)
        weights = np.exp(-0.5 * (positions / width) ** 2)
        weights = (weights / weights.sum()).astype(np.float32)
        head = kernel_size // 2
        extended = np.pad(signal, (head, kernel_size - 1 - head), mode="edge")
        return np.convolve(extended, weights, mode="valid").astype(np.float32)
    return signal.astype(np.float32)

def _empty_lead_metrics(lead_name):
    """Metrics record reported for a lead that has no samples."""
    return {
        "lead": lead_name,
        "amplitude_range": 0.0,
        "baseline_drift": 0.0,
        "nan_fraction": 1.0,
        "rms": 0.0,
        "stability_score": 0.0,
        "saturation_ratio": 0.0,
        "quality_flag": "empty",
        "baseline_window": 0,
        "smoothing_kernel": 0,
    }


def _fill_non_finite(lead_signal):
    """Replace NaN/inf samples with the median of the finite ones (zeros if none are finite).

    Returns (filled_signal, nan_fraction), where nan_fraction is the share of
    non-finite samples in the input.
    """
    finite_mask = np.isfinite(lead_signal)
    nan_fraction = 1.0 - float(finite_mask.sum() / lead_signal.size)
    if finite_mask.any():
        fill_value = float(np.nanmedian(lead_signal[finite_mask]))
        return np.where(finite_mask, lead_signal, fill_value), nan_fraction
    return np.zeros_like(lead_signal, dtype=np.float32), nan_fraction


def _remove_baseline(lead_signal, sample_rate, baseline_cfg):
    """Subtract a centred rolling-mean baseline; returns (signal, window_used_or_0)."""
    if not baseline_cfg.get("enabled", True):
        return lead_signal, 0
    window = max(
        baseline_cfg.get("min_window", 5),
        int(sample_rate * baseline_cfg.get("seconds", 0.6)),
    )
    window = min(window, lead_signal.size - 1)
    window = _ensure_odd(max(window, 3))
    # The window is reported even when it is too long to apply to this lead.
    if window < lead_signal.size:
        lead_signal = lead_signal - _rolling_mean(lead_signal, window)
    return lead_signal, window


def _apply_smoothing(lead_signal, sample_rate, smoothing_cfg):
    """Gaussian-smooth the lead; returns (signal, kernel_size_used_or_0)."""
    if not smoothing_cfg.get("enabled", True):
        return lead_signal, 0
    kernel = max(
        smoothing_cfg.get("min_kernel_size", 5),
        int(sample_rate * smoothing_cfg.get("seconds", 0.08)),
    )
    kernel = min(kernel, smoothing_cfg.get("max_kernel_size", 49))
    kernel = min(kernel, lead_signal.size - 1)
    kernel = _ensure_odd(max(kernel, 3))
    if kernel < lead_signal.size:
        sigma = smoothing_cfg.get("sigma", max(kernel / 6.0, 1.0))
        lead_signal = _gaussian_smooth(lead_signal, kernel, sigma)
    return lead_signal, kernel


def _lead_quality_metrics(lead_signal, sample_rate, quality_cfg, nan_fraction):
    """Amplitude, drift, RMS, stability and derivative-based diagnostics for one lead."""
    amplitude_range = float(np.max(lead_signal) - np.min(lead_signal))
    # Compare the mean of the first and last eval_window samples (at most a quarter of the lead).
    eval_window = min(
        lead_signal.size // 4,
        max(5, int(sample_rate * quality_cfg.get("drift_seconds", 0.4))),
    )
    if eval_window > 0:
        start_mean = float(np.mean(lead_signal[:eval_window]))
        end_mean = float(np.mean(lead_signal[-eval_window:]))
        baseline_drift = end_mean - start_mean
    else:
        baseline_drift = 0.0
    rms = float(np.sqrt(np.mean(np.square(lead_signal))))
    derivative = np.diff(lead_signal, prepend=lead_signal[0])
    abs_derivative = np.abs(derivative)
    diff_mean = float(np.mean(abs_derivative))
    stability_score = float(rms / (diff_mean + 1e-6))
    threshold = quality_cfg.get("large_derivative_threshold", 1.2)
    saturation_ratio = float(np.mean(abs_derivative > threshold))

    # "nan_rich" takes precedence over "low_amplitude".
    quality_flag = "ok"
    if amplitude_range < quality_cfg.get("min_amplitude", 0.05):
        quality_flag = "low_amplitude"
    if nan_fraction > quality_cfg.get("nan_threshold", 0.05):
        quality_flag = "nan_rich"

    return {
        "amplitude_range": amplitude_range,
        "baseline_drift": baseline_drift,
        "nan_fraction": nan_fraction,
        "rms": rms,
        "stability_score": stability_score,
        "saturation_ratio": saturation_ratio,
        "quality_flag": quality_flag,
    }


def postprocess_signals(signals, lead_names, sample_rate, config=None):
    """Clean each lead and compute per-lead quality diagnostics.

    For every ``lead_names[i]``, row ``signals[i]`` goes through these steps:
    1. Non-finite samples are filled with the median of the finite ones.
    2. Optionally, a rolling-mean baseline is removed.
    3. Optionally, the lead is Gaussian-smoothed.
    4. Quality metrics are computed on the result.
    Window and kernel sizes are derived from ``sample_rate`` and ``config``
    (``POSTPROCESS_CONFIG`` when ``config`` is falsy).

    Args:
        signals: 2-D array-like or tensor with one row per lead.
        lead_names: names for the rows to process, in row order.
        sample_rate: samples per second.
        config: dict with optional "baseline_correction", "smoothing", "quality" sections.

    Returns:
        (processed, metrics). ``processed`` is a float32 CPU tensor with one
        row per lead; it is shape (0, num_samples) when ``lead_names`` is empty.
        ``metrics`` holds one dict per lead.
    """
    config = config or POSTPROCESS_CONFIG
    baseline_cfg = config.get("baseline_correction", {})
    smoothing_cfg = config.get("smoothing", {})
    quality_cfg = config.get("quality", {})

    array = signals
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()
    array = np.asarray(array, dtype=np.float32)

    processed = []
    metrics = []
    for idx, lead_name in enumerate(lead_names):
        lead_signal = array[idx].astype(np.float32).copy()
        if lead_signal.size == 0:
            processed.append(lead_signal)
            metrics.append(_empty_lead_metrics(lead_name))
            continue

        lead_signal, nan_fraction = _fill_non_finite(lead_signal)
        lead_signal, baseline_window = _remove_baseline(lead_signal, sample_rate, baseline_cfg)
        lead_signal, smoothing_kernel = _apply_smoothing(lead_signal, sample_rate, smoothing_cfg)
        quality = _lead_quality_metrics(lead_signal, sample_rate, quality_cfg, nan_fraction)

        processed.append(lead_signal.astype(np.float32))
        metrics.append(
            {
                "lead": lead_name,
                **quality,
                "baseline_window": baseline_window,
                "smoothing_kernel": smoothing_kernel,
            }
        )

    if not processed:
        # np.stack([]) raises ValueError; return an empty (0, num_samples) tensor instead.
        num_samples = array.shape[-1] if array.ndim >= 2 else 0
        return torch.zeros((0, num_samples), dtype=torch.float32), metrics

    processed_array = np.stack(processed).astype(np.float32)
    return torch.from_numpy(processed_array), metrics


# Settings consumed by postprocess_signals (the inference loop passes this dict
# explicitly as config=POSTPROCESS_CONFIG).
POSTPROCESS_CONFIG = {
    # Rolling-mean baseline removal: window = max(min_window, int(fs * seconds))
    # samples, then forced odd and skipped if not shorter than the lead.
    "baseline_correction": {"enabled": True, "seconds": 0.6, "min_window": 15},
    # Gaussian smoothing: kernel = max(min_kernel_size, int(fs * seconds)), capped at
    # max_kernel_size, forced odd; sigma is measured in samples.
    "smoothing": {
        "enabled": True,
        "seconds": 0.08,
        "min_kernel_size": 5,
        "max_kernel_size": 49,
        "sigma": 2.5,
    },
    # Thresholds for the per-lead diagnostics and quality_flag.
    "quality": {
        "min_amplitude": 0.05,  # amplitude range below this -> "low_amplitude"
        "nan_threshold": 0.05,  # non-finite fraction above this -> "nan_rich" (takes precedence)
        "drift_seconds": 0.4,  # nominal length (s) of the start/end windows compared for baseline_drift
        "large_derivative_threshold": 1.2,  # |sample-to-sample diff| above this counts toward saturation_ratio
    },
}

# Inference-level settings: "resample" feeds adaptive_resample_size in the loop,
# "collect_quality_metrics" gates building/writing quality_diagnostics.
# NOTE(review): resample_image only shrinks images for an int target, so "min_dim"
# does not upsample small inputs. "postprocess" is not read by the inference loop,
# which passes POSTPROCESS_CONFIG directly.
INFERENCE_CONFIG = {
    "resample": {"max_dim": 3200, "min_dim": 2400},
    "postprocess": POSTPROCESS_CONFIG,
    "collect_quality_metrics": True,
}

# Segmentation UNet: 3-channel image in, 4 output channels
# (crop_image reads channel 0 as grid, 1 as text, 2 as signal).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_in_channels = 3
num_out_channels = 4
dims = [32, 64, 128, 256, 320, 320, 320, 320]
depth = 2
model = load_model(weights_path="/kaggle/input/open-ecg-digitizer-weights/pytorch/default/1/unet_weights_07072025.pt", num_in_channels=num_in_channels, num_out_channels=num_out_channels, dims=dims, depth=depth, device=device)

In [None]:

# Inference: digitize each test image once, then write one submission row per
# sample for every (id, lead) row of test.csv. Rows are assumed to be grouped by id:
# an image is processed only when the id changes from the previous row.
test = pd.read_csv('/kaggle/input/physionet-ecg-image-digitization/test.csv')
output_path = '/kaggle/working/submission.csv'
quality_records = []

# Start from a header-only file; per-lead chunks are appended below.
if os.path.exists(output_path):
    os.remove(output_path)
pd.DataFrame(columns=["id", "value"]).to_csv(output_path, index=False)

old_id = None
lines = None

for index, row in test.iterrows():
    if row.id != old_id:
        old_id = row.id

        path = f"/kaggle/input/physionet-ecg-image-digitization/test/{row.id}.png"
        target_num_samples = row.fs * 10  # Assuming 10 second signals.
        try:
            input_img = load_png_file(path)

            resample_target = adaptive_resample_size(
                input_img,
                max_dim=INFERENCE_CONFIG["resample"]["max_dim"],
                min_dim=INFERENCE_CONFIG["resample"]["min_dim"],
            )
            input_img = resample_image(image=input_img, resample_size=resample_target)

            with torch.no_grad():
                logits = model(input_img.to(device))
                output_probs = torch.softmax(logits, dim=1)
                aligned_image, aligned_signal, aligned_grid, aligned_text = crop_image(
                    input_img,
                    output_probs,
                )
                extracted = extract_signals(
                    aligned_signal,
                    aligned_grid,
                    aligned_text,
                    target_num_samples=target_num_samples,
                )
                raw_lines = extracted["canonical_lines"] * 1e-3  # microvolt to millivolt

            processed_lines, per_lead_metrics = postprocess_signals(
                raw_lines,
                lead_names=leads_names,
                sample_rate=row.fs,
                config=POSTPROCESS_CONFIG,
            )
        except Exception as err:
            # Top-level boundary: one image that fails to digitize must not abort the
            # whole submission. Log it and submit a flat zero signal for its leads.
            print(f"Digitization failed for id={row.id}: {err!r}. Writing zeros for its leads.")
            lines = torch.zeros((len(leads_names), int(target_num_samples)), dtype=torch.float32)
        else:
            for metric in per_lead_metrics:
                metric.update(
                    {
                        "id": row.id,
                        "fs": row.fs,
                        "num_samples": target_num_samples,
                        "resample_target": resample_target,
                    }
                )
                quality_records.append(metric)
            lines = processed_lines

            if index == 0:
                plot_segmentation_and_image(
                    input_img,
                    output_probs,
                    aligned_signal,
                    aligned_grid,
                    processed_lines,
                )

    file_id = row.id
    lead_name = row.lead
    # int() so get_slice/range receive a plain integer even if pandas parsed the column as float.
    number_of_rows_in_lead = int(row.number_of_rows)

    lead_index = leads_names.index(lead_name)

    lead_data = np.asarray(lines[lead_index], dtype=np.float32)
    lead_data = lead_data[get_slice(lead_name, number_of_rows_in_lead)]

    # Replace any remaining NaNs with the mean of the non-NaN samples (0.0 if all are NaN).
    mean_val = np.nanmean(lead_data)
    if np.isnan(mean_val):
        mean_val = 0.0
    lead_data = np.nan_to_num(lead_data, nan=mean_val)

    if len(lead_data) != number_of_rows_in_lead:
        # The slice comes up short when the canonical signal is shorter than the
        # requested window; pad with mean_val (or truncate) instead of aborting the run.
        print(
            f"Length mismatch for id={file_id} lead={lead_name}: got {len(lead_data)}, "
            f"expected {number_of_rows_in_lead}; padding/truncating."
        )
        resized = np.full(number_of_rows_in_lead, mean_val, dtype=np.float32)
        keep = min(len(lead_data), number_of_rows_in_lead)
        resized[:keep] = lead_data[:keep]
        lead_data = resized

    chunk = [
        {"id": f"{file_id}_{t}_{lead_name}", "value": float(lead_data[t])}
        for t in range(number_of_rows_in_lead)
    ]

    if chunk:
        pd.DataFrame(chunk).to_csv(output_path, mode='a', index=False, header=False)

if INFERENCE_CONFIG.get("collect_quality_metrics", True):
    if quality_records:
        quality_diagnostics = pd.DataFrame(quality_records)
    else:
        quality_diagnostics = pd.DataFrame(
            columns=[
                "id",
                "lead",
                "amplitude_range",
                "baseline_drift",
                "nan_fraction",
                "rms",
                "stability_score",
                "saturation_ratio",
                "quality_flag",
                "baseline_window",
                "smoothing_kernel",
                "fs",
                "num_samples",
                "resample_target",
            ]
        )
    # Always write the file (header-only when nothing was collected) for offline inspection.
    quality_diagnostics.to_csv('/kaggle/working/quality_diagnostics.csv', index=False)


In [None]:
# Print a preview of the per-lead quality diagnostics, plus per-lead aggregates of them.
_lead_summary_spec = {
    "amplitude_mean": ("amplitude_range", "mean"),
    "amplitude_std": ("amplitude_range", "std"),
    "drift_mean": ("baseline_drift", "mean"),
    "drift_std": ("baseline_drift", "std"),
    "nan_fraction_mean": ("nan_fraction", "mean"),
    "stability_mean": ("stability_score", "mean"),
}
if 'quality_diagnostics' not in globals() or quality_diagnostics.empty:
    print("No quality diagnostics were generated. Check configuration or input files.")
else:
    print("Quality diagnostics preview (first 12 rows):")
    _preview = quality_diagnostics.head(12)
    print(_preview.to_string(index=False))
    summary = quality_diagnostics.groupby("lead").agg(**_lead_summary_spec).round(6)
    print("Lead-level diagnostic summary:")
    print(summary.to_string())
print(f"Submission saved to {output_path}")
