# ECG Printout Digitization (Kaggle-ready)

This notebook rebuilds the inference pipeline described in **Combining Hough Transform and Deep Learning Approaches to Reconstruct ECG Signals From Printouts**. It performs:

1. Rotation correction via Hough Transform.
2. ECG trace segmentation with **nnU-Net v2**.
3. Lead-wise vectorization to reconstruct 12-lead signals and save them as WFDB/NumPy outputs.

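Fill in the CONFIG cell (paths to your input images and pretrained nnU-Net fold) and run all cells. All outputs are written to `/kaggle/working/`. Note that the first code cell installs its dependencies with `pip`; for the notebook to work offline within Kaggle, those packages must already be available in the environment.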

In [None]:
# %%capture --no-stdout
!pip install -q opencv-python torch torchvision nnunetv2 wfdb tqdm matplotlib pillow numpy nibabel

import importlib, sys
# Import name -> pip distribution name, in the order they are reported.
packages = dict(
    cv2="opencv-python",
    torch="torch",
    nnunetv2="nnunetv2",
    wfdb="wfdb",
    tqdm="tqdm",
    PIL="pillow",
    numpy="numpy",
    matplotlib="matplotlib",
    nibabel="nibabel",
)


def _installed_version(module_name):
    """Import *module_name* and return its ``__version__`` (or ``VERSION``, or None)."""
    loaded = importlib.import_module(module_name)
    legacy_version = getattr(loaded, "VERSION", None)
    return getattr(loaded, "__version__", legacy_version)


# Report one line per package; an import failure is reported, not raised, so
# the remaining packages are still checked.
for module, pkg in packages.items():
    try:
        status = _installed_version(module)
    except Exception as e:
        status = f"not available ({e})"
    print(f"{pkg}: {status}")


## CONFIG: edit input paths and defaults
Update the following cell with your Kaggle Dataset paths. Defaults are aligned with `config.py` and the provided nnU-Net fold layout.

In [None]:
# User-editable configuration
# DEBUG: when True, collect_images() keeps only the first DEBUG_MAX_IMAGES files.
DEBUG = True
DEBUG_MAX_IMAGES = 2
# SAVE_OVERLAY: when True, run_pipeline() also calls plot_overlay() per record.
SAVE_OVERLAY = False

# Input/output paths
# INPUT_IMAGES_DIR is a placeholder and must be replaced with a real dataset path.
INPUT_IMAGES_DIR = "/kaggle/input/<your-ecg-images>/"
OUTPUT_DIR = "/kaggle/working/digitized/"

# nnU-Net pretrained model placement
# MODEL_FOLD_DIR_DEFAULT is the fold directory that is searched for the checkpoint;
# NNUNET_RESULTS_ROOT_DEFAULT is exported as the nnUNet_results environment variable.
MODEL_FOLD_DIR_DEFAULT = "/kaggle/input/nnunet-ecg-dataset500-fold0-best-min/Dataset500_Signals/nnUNetTrainer__nnUNetPlans__2d/fold_0"
NNUNET_RESULTS_ROOT_DEFAULT = "/kaggle/input/nnunet-ecg-dataset500-fold0-best-min"
# The values below are passed to nnUNetv2_predict as -d / -tr / -c / -p / -f.
NNUNET_DATASET_NAME = "Dataset500_Signals"
NNUNET_TRAINER = "nnUNetTrainer"
NNUNET_CONFIG = "2d"
NNUNET_PLANS = "nnUNetPlans"
NNUNET_FOLD = "0"
CHECKPOINT_NAME = "checkpoint_best.pth"  # fallback to checkpoint_final.pth if missing

# Signal configuration (from config.py)
IMAGE_TYPE = "png"  # file extension globbed by collect_images()
FREQUENCY = 500  # output sampling frequency; also written as fs in the WFDB record
# NOTE(review): DATASET_NAME is not referenced in the visible code
# (NNUNET_DATASET_NAME is used for prediction) — confirm whether it is still needed.
DATASET_NAME = "Dataset500_Signals"
LONG_SIGNAL_LENGTH_SEC = 10
SHORT_SIGNAL_LENGTH_SEC = 2.5
SIGNAL_UNITS = "mV"
FMT = "16"
ADC_GAIN = 1000.0
BASELINE = 0

# Lead name -> integer label value looked up in the segmentation mask.
LEAD_LABEL_MAPPING = {
  "I":1,"II":2,"III":3,"aVR":4,"aVL":5,"aVF":6,
  "V1":7,"V2":8,"V3":9,"V4":10,"V5":11,"V6":12
}

# Per-lead vertical offset as a fraction of the image height; vectorise_mask()
# places the zero line at (1 - ratio) * image_height. "full" is used for a lead
# classified as the long strip.
# NOTE(review): 21.59 is presumably the page height in cm (US Letter) — TODO confirm.
Y_SHIFT_RATIO = {
  "I": 12.6/21.59, "II": 9/21.59, "III": 5.4/21.59,
  "aVR": 12.6/21.59, "aVL": 9/21.59, "aVF": 5.4/21.59,
  "V1": 12.59/21.59, "V2": 9/21.59, "V3": 5.4/21.59,
  "V4": 12.59/21.59, "V5": 9/21.59, "V6": 5.4/21.59,
  "full": 2.1/21.59
}

# Hough/segmentation settings
HOUGH_DEGREE_WINDOW = 5  # max |deviation from horizontal| in degrees kept by filter_lines()
PARALLELISM_COUNT = 5  # minimum number of line angles filter_lines() requires
PARALLELISM_WINDOW = 2  # max gap in degrees between neighbouring angles of one group
CANNY_THRESHOLDS = (50, 150)  # unpacked into cv2.Canny by get_lines()
# NOTE(review): despite its name this is the paper speed in mm/s;
# estimate_scaling() multiplies it by seconds-per-pixel to obtain mm-per-pixel.
SEC_PER_PAPER_SECOND = 25  # mm/s
# NOTE(review): MV_PER_MM is not referenced in the visible code;
# estimate_scaling() hard-codes the equivalent "mm_per_pixel / 10".
MV_PER_MM = 0.1  # 10 mm/mV -> 0.1 mV per mm -> mm_per_pixel/10

# Derived writable paths
# These directories are created as a side effect of running this cell.
import os
os.makedirs(OUTPUT_DIR, exist_ok=True)
TMP_ROTATED_DIR = os.path.join(OUTPUT_DIR, "rotated_images")  # input folder for nnUNetv2_predict
TMP_MASKS_DIR = os.path.join(OUTPUT_DIR, "nnunet_masks")  # output folder for nnUNetv2_predict
for d in [TMP_ROTATED_DIR, TMP_MASKS_DIR]:
    os.makedirs(d, exist_ok=True)


## Utilities
Helper functions for path inspection and logging.

In [None]:
import os, json, glob, shutil, math, subprocess, warnings
from typing import List, Tuple, Dict, Optional
from pathlib import Path
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import wfdb


def print_directory_tree(base: str, max_depth: int = 2):
    """Print an indented listing of everything under *base*, down to *max_depth*.

    Args:
        base: Directory whose contents are listed.
        max_depth: Deepest level, relative to ``base``, that is printed.
            Direct children of ``base`` are depth 1.

    Entries are printed in sorted path order, indented two spaces per level,
    and directories get a trailing ``/``.
    """
    base_path = Path(base)
    # The leading blank line must be an escaped "\n": a raw line break inside
    # a single-quoted f-string literal is a syntax error.
    print(f"\nDirectory tree for {base_path} (depth {max_depth}):")
    if not base_path.is_dir():
        # rglob() yields nothing for a missing path, which would be
        # indistinguishable from an empty directory in the output.
        print("  (missing or not a directory)")
        return
    for path in sorted(base_path.rglob('*')):
        depth = len(path.relative_to(base_path).parts)
        if depth > max_depth:
            continue
        prefix = '  ' * depth
        print(f"{prefix}{path.name}{'/' if path.is_dir() else ''}")


def check_paths():
    """Report whether the configured input/model paths exist and list the model folders."""
    labelled_paths = (
        ("INPUT_IMAGES_DIR", INPUT_IMAGES_DIR),
        ("MODEL_FOLD_DIR_DEFAULT", MODEL_FOLD_DIR_DEFAULT),
        ("NNUNET_RESULTS_ROOT_DEFAULT", NNUNET_RESULTS_ROOT_DEFAULT),
    )
    for label, location in labelled_paths:
        found = os.path.exists(location)
        print(f"{label} exists: {found} -> {location}")
    # The results root is listed one level deeper than the fold directory.
    for root, depth in ((NNUNET_RESULTS_ROOT_DEFAULT, 3), (MODEL_FOLD_DIR_DEFAULT, 2)):
        print_directory_tree(root, max_depth=depth)

check_paths()


## Rotation correction (Hough transform)
Follows the paper: detect near-horizontal lines with Canny + HoughLines, keep parallel lines, and rotate by the median angle. If no reliable lines are found, fall back to 0°.

In [None]:
def get_lines(gray_img: np.ndarray, canny_thresholds: Tuple[int, int] = CANNY_THRESHOLDS):
    """Run Canny edge detection followed by the standard Hough line transform.

    Args:
        gray_img: Image handed to ``cv2.Canny``.
        canny_thresholds: Thresholds unpacked into ``cv2.Canny``.

    Returns:
        Whatever ``cv2.HoughLines`` returns for a 1-pixel / 1-degree
        accumulator and a vote threshold of 150.
    """
    edge_map = cv2.Canny(gray_img, *canny_thresholds)
    one_degree = np.pi/180
    return cv2.HoughLines(edge_map, 1, one_degree, threshold=150)


def filter_lines(lines, degree_window=HOUGH_DEGREE_WINDOW, parallelism_count=PARALLELISM_COUNT, parallelism_window=PARALLELISM_WINDOW):
    """Select the largest cluster of mutually parallel, near-horizontal line angles.

    Args:
        lines: Output of ``cv2.HoughLines`` (each entry holds ``(rho, theta)``
            with theta in radians), or None when no line was detected.
        degree_window: A line is kept only if its angle, expressed as
            ``degrees(theta) - 90`` so that horizontal is 0, lies within
            ``[-degree_window, degree_window]``.
        parallelism_count: Minimum number of angles the winning cluster must
            contain to be considered reliable.
        parallelism_window: Two consecutive sorted angles belong to the same
            cluster when they differ by at most this many degrees.

    Returns:
        The angles (in degrees) of the largest cluster, or an empty list when
        there is no cluster with at least ``parallelism_count`` members.
    """
    if lines is None:
        return []
    near_horizontal = sorted(
        degree
        for degree in (np.degrees(line[0][1]) - 90 for line in lines)  # horizontal around 0
        if abs(degree) <= degree_window
    )
    if len(near_horizontal) < parallelism_count:
        return []
    # Chain neighbouring angles into clusters: a gap larger than the window
    # starts a new cluster.
    clusters = []
    for degree in near_horizontal:
        if clusters and abs(clusters[-1][-1] - degree) <= parallelism_window:
            clusters[-1].append(degree)
        else:
            clusters.append([degree])
    best_cluster = max(clusters, key=len) if clusters else []
    # The parallelism requirement applies to the winning cluster itself.
    # Checking only the total number of near-horizontal lines would accept a
    # "best" cluster made of a single stray line.
    if len(best_cluster) < parallelism_count:
        return []
    return best_cluster


def get_median_degrees(filtered_degrees: List[float]) -> float:
    """Return the median of *filtered_degrees* as a float, or 0.0 when the list is empty."""
    if not filtered_degrees:
        return 0.0
    return float(np.median(filtered_degrees))


def get_rotation_angle(gray_img: np.ndarray) -> float:
    """Estimate the rotation angle (in degrees) of *gray_img* from its Hough lines.

    Emits a warning when no usable lines survive filtering; the median helper
    then yields 0 degrees.
    """
    candidate_degrees = filter_lines(get_lines(gray_img))
    if not candidate_degrees:
        warnings.warn("Hough transform failed, using 0 degrees.")
    return get_median_degrees(candidate_degrees)


def rotate_image(img: np.ndarray, angle: float) -> np.ndarray:
    """Rotate *img* about its centre by *angle* degrees, keeping the original size.

    Pixels that fall outside the source are filled by replicating the border,
    and resampling uses bilinear interpolation.
    """
    height, width = img.shape[:2]
    pivot = (width // 2, height // 2)
    affine = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(
        img,
        affine,
        (width, height),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REPLICATE,
    )


## nnU-Net v2 inference
Adapts `digitize.py` to Kaggle: sets `nnUNet_results` from the provided root, resolves the checkpoint (best → final), and uses fold 0 by default. Predictions are batched by first writing rotated images to a temp folder.

In [None]:
def resolve_nnunet_paths(model_fold_dir: str, results_root: str) -> Tuple[str, str]:
    """Export the nnU-Net environment variables and create the scratch folders.

    ``nnUNet_results`` is set to ``<model_fold_dir>/nnUNet_results`` when that
    directory exists and to *results_root* otherwise. ``nnUNet_raw`` and
    ``nnUNet_preprocessed`` point at folders under ``OUTPUT_DIR``, which are
    created if missing.

    Returns:
        ``(resolved_results, model_fold_dir)``.
    """
    nested_results = os.path.join(model_fold_dir, "nnUNet_results")
    resolved_results = nested_results if os.path.isdir(nested_results) else results_root
    scratch_dirs = {
        "nnUNet_raw": os.path.join(OUTPUT_DIR, "nnUNet_raw"),
        "nnUNet_preprocessed": os.path.join(OUTPUT_DIR, "nnUNet_preprocessed"),
    }
    os.environ["nnUNet_results"] = resolved_results
    os.environ.update(scratch_dirs)
    for scratch in scratch_dirs.values():
        os.makedirs(scratch, exist_ok=True)
    print(f"nnUNet_results -> {resolved_results}")
    for variable, scratch in scratch_dirs.items():
        print(f"{variable} -> {scratch}")
    return resolved_results, model_fold_dir


def resolve_checkpoint(model_fold_dir: str, checkpoint_name: str) -> str:
    """Pick the checkpoint file name to use from *model_fold_dir*.

    Returns *checkpoint_name* when that file exists in the fold directory,
    otherwise ``"checkpoint_final.pth"`` when that one exists.

    Raises:
        FileNotFoundError: If neither file is present.
    """
    fallback_name = "checkpoint_final.pth"
    best_path = os.path.join(model_fold_dir, checkpoint_name)
    final_path = os.path.join(model_fold_dir, fallback_name)
    # Candidates in order of preference.
    for name, location in ((checkpoint_name, best_path), (fallback_name, final_path)):
        if os.path.exists(location):
            return name
    raise FileNotFoundError(f"No checkpoint found in {model_fold_dir}. Checked {best_path} and {final_path}.")


def run_nnunet_predict(rotated_dir: str, masks_dir: str, model_fold_dir: str, results_root: str, device: str = "gpu"):
    """Run the ``nnUNetv2_predict`` command-line tool on a folder of images.

    Args:
        rotated_dir: Input folder passed as ``-i``.
        masks_dir: Output folder passed as ``-o``.
        model_fold_dir: Fold directory searched for the checkpoint file.
        results_root: Fallback value for the ``nnUNet_results`` variable.
        device: ``"cpu"`` forces CPU inference. Any other value leaves the
            tool's default device in place, unless CUDA is unavailable, in
            which case CPU is used.

    Raises:
        FileNotFoundError: If no checkpoint is found in *model_fold_dir*.
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. Its captured output is printed first.
    """
    resolve_nnunet_paths(model_fold_dir, results_root)
    checkpoint_flag = resolve_checkpoint(model_fold_dir, CHECKPOINT_NAME)
    use_cpu = device.lower() == "cpu"
    if not use_cpu and not torch.cuda.is_available():
        # Without "-device cpu" the tool targets CUDA and would fail on a
        # CPU-only session.
        warnings.warn("CUDA is not available; running nnUNetv2_predict on CPU.")
        use_cpu = True
    cmd = [
        "nnUNetv2_predict",
        "-d", NNUNET_DATASET_NAME,
        "-i", rotated_dir,
        "-o", masks_dir,
        "-f", NNUNET_FOLD,
        "-tr", NNUNET_TRAINER,
        "-c", NNUNET_CONFIG,
        "-p", NNUNET_PLANS,
        "-chk", checkpoint_flag
    ]
    if use_cpu:
        cmd += ["-device", "cpu", "--verbose"]
    print("Running nnUNetv2_predict:", " ".join(cmd))
    try:
        result = subprocess.run(cmd, check=True, text=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        # Output is captured, so it must be surfaced here; otherwise the
        # reason for the failure is lost.
        if err.stdout:
            print(err.stdout)
        if err.stderr:
            print(err.stderr)
        raise
    print(result.stdout)
    if result.stderr:
        print(result.stderr)


## Mask cutting and vectorization
Recreate `cut_to_mask`, `cut_binary`, and the column-wise vectorization. Each lead mask is cropped to its bounding box; missing leads are recorded as `None` (and written as zeros when saving). Temporal and amplitude scaling use the paper's constants (25 mm/s, 10 mm/mV).

In [None]:
def cut_binary(mask: np.ndarray, lead_id: int) -> Tuple[Optional[np.ndarray], Optional[Tuple[int,int]]]:
    """Crop the binary mask of one label to its bounding box.

    Args:
        mask: Label image.
        lead_id: Label value to extract.

    Returns:
        ``(cropped, (top, left))`` where ``cropped`` is a uint8 array that is 1
        where ``mask == lead_id``, or ``(None, None)`` when the label is absent.
    """
    lead_pixels = (mask == lead_id).astype(np.uint8)
    points = cv2.findNonZero(lead_pixels)
    if points is None:
        return None, None
    left, top, width, height = cv2.boundingRect(points)
    window = lead_pixels[top:top + height, left:left + width]
    return window, (top, left)


def cut_to_mask(mask: np.ndarray, lead_mapping: Dict[str, int]):
    """Crop every lead listed in *lead_mapping* out of the label image *mask*.

    Returns:
        A dict keyed by lead name. Each value is ``{"mask": ..., "origin": ...}``
        holding the two results of ``cut_binary`` for that lead's label.
    """
    fields = ("mask", "origin")
    return {
        lead: dict(zip(fields, cut_binary(mask, label)))
        for lead, label in lead_mapping.items()
    }


def estimate_scaling(mask_width: int) -> Tuple[float, float]:
    """Derive the time and amplitude scale per pixel from a lead mask width.

    The width is taken to span ``SHORT_SIGNAL_LENGTH_SEC`` seconds.

    Returns:
        ``(sec_per_pixel, mV_per_pixel)``.
    """
    widths = np.asarray([mask_width])
    cutoff = 2 * np.median(widths)
    # Mean of the widths below twice the median.
    reference_width = np.mean(widths[widths < cutoff])
    seconds_per_px = SHORT_SIGNAL_LENGTH_SEC / reference_width
    millimetres_per_px = SEC_PER_PAPER_SECOND * seconds_per_px
    return seconds_per_px, millimetres_per_px / 10


def vectorise_mask(mask_info: Dict[str, Dict], image_height: int) -> Dict[str, np.ndarray]:
    """Convert cropped lead masks into 1-D signals resampled to ``FREQUENCY``.

    For each lead the trace is read column by column. The vertical position of
    a column is the mean row index of its foreground pixels, and columns
    without foreground are filled by linear interpolation between their
    neighbours. Positions are converted to millivolts relative to the lead's
    zero line, ``(1 - y_shift_ratio) * image_height``, which is measured in the
    full image and shifted into crop coordinates.

    The pixel scale is derived once from all leads: widths below twice the
    median width are treated as short strips, and their mean width is the
    reference passed to ``estimate_scaling``. A lead whose width then
    corresponds to more than half of ``LONG_SIGNAL_LENGTH_SEC`` is treated as
    the long strip.

    Args:
        mask_info: Mapping ``lead -> {"mask": cropped uint8 mask or None,
            "origin": (top, left) or None}``.
        image_height: Height in pixels of the image the masks were cut from.

    Returns:
        Mapping ``lead -> float32 array``, or ``None`` for a missing lead.
    """
    usable_widths = [
        info["mask"].shape[1]
        for info in mask_info.values()
        if info.get("mask") is not None and info.get("origin") is not None
    ]
    if not usable_widths:
        return {lead: None for lead in mask_info}

    # A single lead's own width cannot tell a short strip from the long one:
    # that would always give exactly SHORT_SIGNAL_LENGTH_SEC and make the
    # long-strip branch below unreachable.
    median_width = np.median(usable_widths)
    short_widths = [w for w in usable_widths if w < 2 * median_width]
    reference_width = float(np.mean(short_widths))
    sec_per_pixel, mV_per_pixel = estimate_scaling(reference_width)

    lead_signals = {}
    for lead, info in mask_info.items():
        lead_mask, origin = info.get("mask"), info.get("origin")
        if lead_mask is None or origin is None:
            lead_signals[lead] = None
            continue
        y1, _x1 = origin
        foreground = (np.asarray(lead_mask) != 0).astype(np.float64)
        h, w = foreground.shape
        hits_per_column = foreground.sum(axis=0)
        traced = hits_per_column > 0
        if not traced.any():
            lead_signals[lead] = None
            continue

        # Column-wise mean row of the trace; a single global mean would
        # collapse the whole lead into a flat line.
        row_sums = np.arange(h, dtype=np.float64) @ foreground
        columns = np.arange(w)
        column_y = np.interp(columns, columns[traced], row_sums[traced] / hits_per_column[traced])

        total_seconds_from_mask = sec_per_pixel * w
        if total_seconds_from_mask > LONG_SIGNAL_LENGTH_SEC / 2:
            total_seconds = LONG_SIGNAL_LENGTH_SEC
            y_shift_ratio = Y_SHIFT_RATIO["full"]
        else:
            total_seconds = SHORT_SIGNAL_LENGTH_SEC
            y_shift_ratio = Y_SHIFT_RATIO[lead]
        values_needed = int(total_seconds * FREQUENCY)

        # Zero line expressed in crop coordinates. Image rows grow downwards,
        # so "zero line minus row" is positive above the zero line.
        signal_cropped_shifted = (1 - y_shift_ratio) * image_height - y1
        predicted_signal = (signal_cropped_shifted - column_y) * mV_per_pixel
        signal = torch.from_numpy(predicted_signal.astype(np.float32))
        resampled = F.interpolate(
            signal.unsqueeze(0).unsqueeze(0),
            size=values_needed,
            mode="linear",
            align_corners=False,
        )[0, 0].numpy()
        lead_signals[lead] = resampled
    return lead_signals


## Saving outputs
Write a WFDB record and a NumPy `.npz` archive for each image, with consistent metadata. Missing leads fall back to zeros for WFDB compatibility.

In [None]:
def save_wfdb(record_id: str, lead_signals: Dict[str, np.ndarray], output_dir: str):
    """Write the reconstructed leads as a WFDB record and an ``.npz`` archive.

    All leads of ``LEAD_LABEL_MAPPING`` are stacked into a ``(samples, leads)``
    array. Missing leads become zeros, and shorter leads are linearly resampled
    to the length of the longest one.

    Args:
        record_id: Record name. wfdb only accepts letters, digits, hyphens and
            underscores here, so the directory is passed separately.
        lead_signals: Mapping ``lead -> 1-D signal`` (or ``None`` if missing).
        output_dir: Directory that receives ``<record_id>.hea/.dat/.npz``.

    If every lead is missing, a warning is issued and nothing is written.
    """
    available_lengths = [len(sig) for sig in lead_signals.values() if sig is not None]
    if not available_lengths:
        # max() of an empty sequence would raise an unhelpful ValueError.
        warnings.warn(f"No lead signals available for {record_id}; nothing written.")
        return
    max_len = max(available_lengths)

    sig_names = list(LEAD_LABEL_MAPPING.keys())
    stacked = []
    for lead in sig_names:
        sig = lead_signals.get(lead)
        if sig is None:
            sig = np.zeros(max_len)
        elif len(sig) != max_len:
            sig = F.interpolate(
                torch.as_tensor(sig).unsqueeze(0).unsqueeze(0),
                size=max_len,
                mode="linear",
                align_corners=False,
            )[0, 0].numpy()
        stacked.append(sig)
    stacked = np.vstack(stacked).T
    # The record name must not contain path separators; the target directory
    # goes into write_dir instead.
    wfdb.wrsamp(
        record_id,
        fs=FREQUENCY,
        units=[SIGNAL_UNITS] * len(sig_names),
        sig_name=sig_names,
        p_signal=stacked,
        fmt=[FMT] * len(sig_names),
        adc_gain=[ADC_GAIN] * len(sig_names),
        baseline=[BASELINE] * len(sig_names),
        write_dir=output_dir,
    )
    np.savez(os.path.join(output_dir, f"{record_id}.npz"), signal=stacked, leads=sig_names)
    print(f"Saved WFDB and NPZ for {record_id} -> {output_dir}")


def plot_overlay(image_path: str, mask_array: np.ndarray, lead_signals: Dict[str, np.ndarray], record_id: str):
    """Save a 4x4 diagnostic figure: image with mask overlay plus one panel per lead.

    Args:
        image_path: Image shown in the first panel.
        mask_array: Mask drawn semi-transparently over that image.
        lead_signals: Mapping ``lead -> signal``; ``None`` is shown as "Missing".
        record_id: Used to name the output ``<record_id>_overlay.png`` inside
            ``OUTPUT_DIR``.
    """
    fig, axes = plt.subplots(4, 4, figsize=(14, 10))
    try:
        axes = axes.flatten()
        # Close the image file once matplotlib has copied the pixel data.
        with Image.open(image_path) as img:
            axes[0].imshow(img, cmap='gray')
        axes[0].imshow(mask_array, alpha=0.3)
        axes[0].set_title("Image + mask")
        lead_names = list(LEAD_LABEL_MAPPING.keys())
        for ax, lead in zip(axes[1:], lead_names):
            sig = lead_signals.get(lead)
            if sig is None:
                ax.text(0.5, 0.5, 'Missing', ha='center', va='center')
            else:
                ax.plot(sig)
            ax.set_title(lead)
            ax.axis('off')
        # Hide the panels of the grid that no lead occupies.
        for ax in axes[1 + len(lead_names):]:
            ax.axis('off')
        fig.tight_layout()
        overlay_path = os.path.join(OUTPUT_DIR, f"{record_id}_overlay.png")
        fig.savefig(overlay_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved overlay -> {overlay_path}")


## Main pipeline
1. Scan input images
2. Rotate via Hough
3. Batch nnU-Net prediction
4. Cut masks and vectorize
5. Save WFDB/NumPy (and optional overlays)

In [None]:
def load_mask(mask_path: str) -> np.ndarray:
    """Load a segmentation mask from ``.npz``, ``.png`` or NIfTI.

    * ``.npz``: returns the ``seg`` array, else ``pred``, else the only array
      in the archive.
    * ``.png``: returns the image as stored; for a multi-channel image only
      the first channel is returned.
    * anything else is opened with nibabel and squeezed.

    Raises:
        ValueError: If an ``.npz`` archive has no recognisable array, or a
            ``.png`` file cannot be read.
    """
    lowered = mask_path.lower()
    if lowered.endswith('.npz'):
        # The context manager closes the archive; the returned array has
        # already been read into memory at that point.
        with np.load(mask_path) as data:
            for key in ('seg', 'pred'):
                if key in data:
                    return data[key]
            if len(data.files) == 1:
                return data[data.files[0]]
        raise ValueError(f"Unknown npz structure in {mask_path}")
    if lowered.endswith('.png'):
        # nnU-Net writes predictions in the dataset's own file ending, which is
        # PNG for datasets built from PNG images.
        png_mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if png_mask is None:
            raise ValueError(f"Could not read mask image {mask_path}")
        return png_mask[..., 0] if png_mask.ndim == 3 else png_mask
    img = nib.load(mask_path)
    return np.asanyarray(img.get_fdata()).squeeze()


def collect_images(input_dir: str, image_ext: str) -> List[str]:
    """Return the sorted paths of ``*.<image_ext>`` files directly inside *input_dir*.

    When ``DEBUG`` is set, only the first ``DEBUG_MAX_IMAGES`` paths are kept.
    """
    pattern = os.path.join(input_dir, f"*.{image_ext}")
    files = sorted(glob.glob(pattern))
    limit = DEBUG_MAX_IMAGES if DEBUG else None
    files = files[:limit]
    print(f"Found {len(files)} image(s) in {input_dir}")
    return files


def prepare_rotated_images(image_paths: List[str]) -> Dict[str, Dict]:
    """Deskew each image and write it to ``TMP_ROTATED_DIR`` for nnU-Net.

    Files are written as ``<record_id>_0000.png``: nnU-Net v2 expects every
    inference input to end in a 4-digit channel identifier and names its
    prediction after the part before it, i.e. ``<record_id>``.

    Args:
        image_paths: Paths of the input images.

    Returns:
        Mapping ``record_id -> {"rotated_path", "height", "width", "angle"}``
        for every image that could be read. Unreadable images are skipped
        with a warning.
    """
    meta = {}
    for img_path in tqdm(image_paths, desc="Rotating"):
        record_id = Path(img_path).stem
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            warnings.warn(f"Could not read image {img_path}; skipping.")
            continue
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        angle = get_rotation_angle(gray)
        rotated = rotate_image(img, angle)
        out_path = os.path.join(TMP_ROTATED_DIR, f"{record_id}_0000.png")
        cv2.imwrite(out_path, rotated)
        meta[record_id] = {
            "rotated_path": out_path,
            "height": rotated.shape[0],
            "width": rotated.shape[1],
            "angle": angle
        }
        print(f"{record_id}: rotation {angle:.2f}° -> {out_path}")
    return meta


def map_masks_to_records(masks_dir: str) -> Dict[str, str]:
    """Map each record id to the mask file found for it in *masks_dir*.

    Recognised endings, in order of preference, are ``.png``, ``.nii.gz``,
    ``.nii`` and ``.npz``. When several files exist for the same record, the
    first ending in that order wins, so a segmentation file is preferred over
    an ``.npz`` archive.

    Returns:
        Mapping ``record_id -> mask file path``, where the record id is the
        file name with the recognised ending removed.
    """
    suffixes_by_priority = (".png", ".nii.gz", ".nii", ".npz")
    mapping = {}
    for suffix in suffixes_by_priority:
        for mf in sorted(glob.glob(os.path.join(masks_dir, f"*{suffix}"))):
            # Strip exactly the matched ending; this keeps dots inside the
            # record id intact.
            stem = os.path.basename(mf)[:-len(suffix)]
            mapping.setdefault(stem, mf)
    print(f"Collected {len(mapping)} mask files")
    return mapping


def run_pipeline():
    """Run the whole digitization: collect, rotate, segment, vectorize and save.

    Records without a predicted mask are skipped with a warning.

    Raises:
        RuntimeError: If no input images are found.
    """
    images = collect_images(INPUT_IMAGES_DIR, IMAGE_TYPE)
    if len(images) == 0:
        raise RuntimeError("No images found. Update INPUT_IMAGES_DIR.")

    records = prepare_rotated_images(images)
    run_nnunet_predict(
        TMP_ROTATED_DIR,
        TMP_MASKS_DIR,
        MODEL_FOLD_DIR_DEFAULT,
        NNUNET_RESULTS_ROOT_DEFAULT,
        device="gpu",
    )
    masks_by_record = map_masks_to_records(TMP_MASKS_DIR)

    for record_id, info in records.items():
        mask_file = masks_by_record.get(record_id)
        if mask_file is not None:
            mask = load_mask(mask_file)
            per_lead_masks = cut_to_mask(mask, LEAD_LABEL_MAPPING)
            lead_signals = vectorise_mask(per_lead_masks, info["height"])
            save_wfdb(record_id, lead_signals, OUTPUT_DIR)
            if SAVE_OVERLAY:
                plot_overlay(info["rotated_path"], mask, lead_signals, record_id)
        else:
            warnings.warn(f"Mask missing for {record_id}")

run_pipeline()
