HYBRID (U-NET + RECTIFICATION) SOLUTION

# 1. IMPORTS AND CONFIGURATION

In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import List, Tuple
import warnings

# Silence pandas PerformanceWarning messages so the notebook output stays readable.
warnings.simplefilter('ignore', category=pd.errors.PerformanceWarning)

print("✅ Libraries imported successfully.")

# --- CONFIGURATION ---
class Config:
    """Static settings for the hybrid pipeline, read as attributes of ``config``."""

    # Input locations on Kaggle
    BASE_DIR = '/kaggle/input/physionet-ecg-image-digitization'
    UNET_MODEL_PATH = '/kaggle/input/kaggle-physionet-hengck-demo-00/checkpoint.pth'

    # Geometry; every size below is an (H, W) pair
    UNET_INPUT_SIZE = (224, 1184)  # resize target fed to the U-Net
    RECTIFIED_SIZE = (1700, 2200)  # canonical page size after rectification
    # Pixel height of one grid square in rectified space (value from hengck's notebook)
    GRID_SQUARE_HEIGHT_GSY = 39.38095238095238
    # Amplitude scale in mV per pixel; the formula makes one grid square span 0.5 mV
    MV_PER_PIXEL = 0.5 / GRID_SQUARE_HEIGHT_GSY

    # Post-processing
    APPLY_EINTHOVEN = True  # run the limb-lead consistency correction after extraction


config = Config()

# The twelve standard ECG lead names.
LEADS = 'I II III aVR aVL aVF V1 V2 V3 V4 V5 V6'.split()
print("✅ Configuration set.")

# 2. U-NET MODEL DEFINITION

In [None]:
# NOTE(review): this is a small placeholder network, not a reconstruction of the
# checkpoint's real architecture. `load_state_dict(..., strict=False)` skips every
# key whose name does not match these attributes, so any parameter the checkpoint
# does not provide under exactly these names keeps its random initialisation.
class UNet(nn.Module):
    """Three 3x3 conv layers feeding three 2-channel 1x1 output heads.

    Args:
        in_channels: Number of channels of ``batch['image']``.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # Attribute names are the state_dict keys and their creation order fixes
        # the random-initialisation sequence, so both are kept stable.
        self.encoder1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.decoder1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        # One 1x1 head per mask type, each with 2 output channels
        # (assumed background / foreground).
        self.head_lead = nn.Conv2d(64, 2, kernel_size=1)
        self.head_horizontal = nn.Conv2d(64, 2, kernel_size=1)
        self.head_vertical = nn.Conv2d(64, 2, kernel_size=1)
        self.output_type = ['infer']  # kept for compatibility with the original snippet

    def forward(self, batch):
        """Return raw 2-channel logits for the lead, horizontal and vertical masks.

        Args:
            batch: dict whose ``'image'`` entry is a (B, in_channels, H, W) tensor.

        Returns:
            dict with keys ``'lead'``, ``'horizontal'``, ``'vertical'``; each value
            is a (B, 2, H, W) tensor (padding=1 keeps the spatial size unchanged).
        """
        features = batch['image']
        for layer in (self.encoder1, self.encoder2, self.decoder1):
            features = F.relu(layer(features))

        heads = {
            'lead': self.head_lead,
            'horizontal': self.head_horizontal,
            'vertical': self.head_vertical,
        }
        return {name: head(features) for name, head in heads.items()}

# 3. CORE LOGIC - THE HYBRID PIPELINE

In [None]:
def _cluster_line_positions(positions, tol=10.0):
    """Merge 1-D line positions into clusters and return the sorted cluster means.

    Positions are sorted first, then grouped greedily: each cluster takes every
    remaining value within ``tol`` pixels of the smallest remaining value.
    Sorting makes the result independent of the order in which HoughLinesP
    happened to return its segments.
    """
    remaining = np.sort(np.asarray(positions, dtype=np.float64))
    centres = []
    while remaining.size:
        in_cluster = np.abs(remaining - remaining[0]) < tol
        centres.append(remaining[in_cluster].mean())
        remaining = remaining[~in_cluster]
    return np.asarray(centres, dtype=np.float64)


def get_gridpoint_xy_from_masks(h_mask: np.ndarray, v_mask: np.ndarray, num_h_lines=12, num_v_lines=12) -> "np.ndarray | None":
    """
    Convert horizontal/vertical grid-line masks into a lattice of intersection points.

    Line segments are found with a probabilistic Hough transform. Nearly
    horizontal segments give y positions and nearly vertical segments give
    x positions. Positions closer than 10 px are clustered into single lines.
    The resulting (lines x lines) intersection lattice is resampled to a
    fixed ``num_h_lines x num_v_lines`` size.

    Args:
        h_mask (np.ndarray): 8-bit single-channel mask of horizontal lines (non-zero = line).
        v_mask (np.ndarray): 8-bit single-channel mask of vertical lines, same size as h_mask.
        num_h_lines (int): Minimum number of distinct horizontal lines required,
            and the number of rows in the returned lattice.
        num_v_lines (int): Same, for vertical lines / lattice columns.

    Returns:
        np.ndarray | None: A (num_h_lines, num_v_lines, 2) float32 array of (x, y)
        coordinates in mask pixels. Returns None if no segments are found, or if
        fewer than the requested number of lines is detected in either direction.
    """
    h_lines = cv2.HoughLinesP(h_mask, 1, np.pi / 180, threshold=50, minLineLength=100, maxLineGap=20)
    v_lines = cv2.HoughLinesP(v_mask, 1, np.pi / 180, threshold=50, minLineLength=100, maxLineGap=20)

    if h_lines is None or v_lines is None:
        return None

    # HoughLinesP returns an (N, 1, 4) array of (x1, y1, x2, y2) segments.
    h_segments = h_lines.reshape(-1, 4)
    v_segments = v_lines.reshape(-1, 4)

    # Nearly horizontal segments, each summarised by its mean y.
    is_horizontal = np.abs(h_segments[:, 3] - h_segments[:, 1]) < 10
    h_positions = h_segments[is_horizontal][:, [1, 3]].mean(axis=1)
    # Nearly vertical segments, each summarised by its mean x.
    is_vertical = np.abs(v_segments[:, 2] - v_segments[:, 0]) < 10
    v_positions = v_segments[is_vertical][:, [0, 2]].mean(axis=1)

    if h_positions.size == 0 or v_positions.size == 0:
        return None

    final_h_y = _cluster_line_positions(h_positions)
    if len(final_h_y) < num_h_lines:
        return None
    final_v_x = _cluster_line_positions(v_positions)
    if len(final_v_x) < num_v_lines:
        return None

    # grid_points[i, j] = (x of vertical line j, y of horizontal line i)
    xs, ys = np.meshgrid(final_v_x, final_h_y)
    grid_points = np.stack([xs, ys], axis=-1).astype(np.float32)

    # Resample the detected lattice to the fixed size expected by rectify_mask.
    # Fewer lines than requested already returned None above. When MORE lines
    # were found, the lattice is interpolated rather than matched, so the output
    # points need not lie on detected lines.
    grid_tensor = torch.from_numpy(grid_points).permute(2, 0, 1).unsqueeze(0)
    resized = F.interpolate(grid_tensor, size=(num_h_lines, num_v_lines), mode='bilinear', align_corners=True)
    return resized.squeeze(0).permute(1, 2, 0).numpy()

def rectify_mask(mask: np.ndarray, gridpoint_xy: np.ndarray, out_size=None) -> np.ndarray:
    """
    Warp a single-channel mask so the given grid lattice becomes a regular grid.

    The sparse lattice of (x, y) points is bilinearly upsampled to a dense
    sampling map of the output size. The mask is then sampled at those
    locations. Pixels sampled outside the mask are zero.

    Args:
        mask: (H, W) array to warp.
        gridpoint_xy: (R, C, 2) array of (x, y) pixel coordinates in ``mask``.
            Row 0 / column 0 map to the top / left edge of the output, and the
            last row / column map to the bottom / right edge.
        out_size: Optional (H, W) of the output. Defaults to ``config.RECTIFIED_SIZE``.

    Returns:
        (H_out, W_out) float32 array.
    """
    H_rect, W_rect = config.RECTIFIED_SIZE if out_size is None else out_size
    H_orig, W_orig = mask.shape

    # Normalise pixel coordinates to [-1, 1], mapping pixel 0 -> -1 and pixel
    # (size - 1) -> +1. This is the align_corners=True convention, so
    # grid_sample below must use align_corners=True as well; otherwise every
    # sample is shifted by half a pixel and edge pixels blend with zero padding.
    scale = np.array([W_orig - 1, H_orig - 1], dtype=np.float64)
    sparse_map = np.asarray(gridpoint_xy, dtype=np.float64) / scale * 2 - 1
    sparse_map = torch.from_numpy(np.ascontiguousarray(sparse_map.transpose(2, 0, 1))).unsqueeze(0).float()

    # Interpolate the sparse lattice to a dense per-output-pixel sampling map.
    dense_map = F.interpolate(sparse_map, size=(H_rect, W_rect), mode='bilinear', align_corners=True)

    distort_mask = torch.from_numpy(np.ascontiguousarray(mask)).float()[None, None]

    rectified = F.grid_sample(
        distort_mask, dense_map.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True
    )

    # Index explicitly instead of squeeze() so size-1 output dims are preserved.
    return rectified[0, 0].numpy()

def get_lead_regions_in_rectified_space():
    """
    Return fixed per-lead crop boxes and baselines in the rectified page.

    The values are hard-coded for the canonical 1700x2200 (H, W) rectified
    space rather than detected per image. The 12 standard leads sit in
    three row bands and four column bands. ``'II_long'`` is a single band
    spanning all four columns.

    Returns:
        dict: ``{lead_name: (y_start, y_end, x_start, x_end, zero_mv_y_line)}``.
    """
    # (y_start, y_end) of the three stacked rows, top to bottom.
    row_bands = ((708, 992), (992, 1276), (1276, 1560))
    # (x_start, x_end) of each column and the leads stacked in it, top to bottom.
    columns = (
        ((118, 610), ('I', 'II', 'III')),
        ((610, 1102), ('aVR', 'aVL', 'aVF')),
        ((1102, 1594), ('V1', 'V2', 'V3')),
        ((1594, 2087), ('V4', 'V5', 'V6')),
    )

    # NOTE(review): every baseline below equals its box's y_start, so the
    # crop-relative zero line is row 0 of the crop. mask_to_signal therefore
    # can only return values <= 0 for these boxes, and any deflection above
    # the baseline falls outside the crop. Confirm against rectified
    # training images.
    regions = {}
    for (x_start, x_end), lead_names in columns:
        for lead_name, (y_start, y_end) in zip(lead_names, row_bands):
            regions[lead_name] = (y_start, y_end, x_start, x_end, y_start)

    # Rhythm strip: one row band covering the full width of all four columns.
    regions['II_long'] = (424, 708, 118, 2087, 424)
    return regions

def mask_to_signal(prob_mask_crop, num_samples, zero_mv_y, mv_per_pixel):
    """
    Convert a cropped, rectified lead mask into a 1-D voltage signal.

    For each column, the trace position is the mask-weighted mean row index.
    Columns with total weight <= 0.1 are treated as empty and filled by linear
    interpolation (edge gaps take the nearest valid value). Row positions are
    converted to mV relative to ``zero_mv_y``; rows above the baseline are
    positive. The result is linearly resampled to ``num_samples`` points.

    Args:
        prob_mask_crop: (H, W) array of per-pixel weights / probabilities.
        num_samples: Length of the returned signal.
        zero_mv_y: Row index of the 0 mV baseline, relative to the crop.
        mv_per_pixel: Millivolts per pixel of vertical displacement.

    Returns:
        float32 numpy array of length ``num_samples``. All zeros if the crop
        contains no valid column.
    """
    num_samples = int(num_samples)
    m = torch.as_tensor(np.asarray(prob_mask_crop), dtype=torch.float32)
    H, W = m.shape
    if W == 0 or num_samples <= 0:
        return np.zeros(max(num_samples, 0), dtype=np.float32)

    # Weighted mean of row indices per column gives the trace's vertical position.
    idx = torch.arange(H, dtype=m.dtype).view(H, 1)
    num = (m * idx).sum(dim=0)
    den = m.sum(dim=0)

    # Small threshold so near-empty columns do not produce unstable averages.
    valid_cols = den > 0.1
    if not bool(valid_cols.any()):
        # Nothing to interpolate from (np.interp would raise on empty input):
        # fall back to a flat 0 mV signal.
        return np.zeros(num_samples, dtype=np.float32)

    signal_y_pos = torch.full((W,), float('nan'), dtype=m.dtype)
    signal_y_pos[valid_cols] = num[valid_cols] / den[valid_cols]

    # Fill empty columns by linear interpolation over the valid ones.
    if not bool(valid_cols.all()):
        x = np.arange(W)
        valid_np = valid_cols.numpy()
        filled = np.interp(x, x[valid_np], signal_y_pos[valid_cols].numpy())
        signal_y_pos = torch.from_numpy(filled).to(m.dtype)

    # Row index -> mV (image rows grow downward), then resample to the target length.
    signal_mv = (zero_mv_y - signal_y_pos) * mv_per_pixel
    resampled = F.interpolate(
        signal_mv.view(1, 1, W), size=num_samples, mode="linear", align_corners=False
    ).view(-1)

    return resampled.numpy()

def apply_einthoven(preds):
    """
    Make limb-lead predictions satisfy Einthoven's and Goldberger's identities.

    Enforces ``I + III == II`` and ``aVR + aVL + aVF == 0`` sample by sample.
    It does this by removing one third of the residual from each lead in the
    relation (with the sign of that lead's coefficient), which is the smallest
    correction that satisfies the identity.

    A relation is applied only when all three of its leads are present AND have
    the same shape. Otherwise it is skipped rather than raising a broadcasting
    error, e.g. when lead II is a long rhythm strip and I/III are short segments.

    Args:
        preds: dict mapping lead name -> numpy array. Arrays in an applied
            relation are modified in place.

    Returns:
        The same ``preds`` dict.
    """
    def _aligned(*names):
        # All leads present and sample-for-sample comparable.
        if not all(name in preds for name in names):
            return False
        return len({np.shape(preds[name]) for name in names}) == 1

    if _aligned('I', 'II', 'III'):
        correction = (preds['I'] + preds['III'] - preds['II']) / 3
        preds['I'] -= correction
        preds['III'] -= correction
        preds['II'] += correction
    if _aligned('aVR', 'aVL', 'aVF'):
        correction = (preds['aVR'] + preds['aVL'] + preds['aVF']) / 3
        for name in ('aVR', 'aVL', 'aVF'):
            preds[name] -= correction
    return preds

# 4. INFERENCE PIPELINE

In [None]:
print("\n--- Initializing U-Net Model ---")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet().to(device)

try:
    checkpoint = torch.load(config.UNET_MODEL_PATH, map_location=device)
    # Accept either {'state_dict': weights, ...} or a bare weights mapping.
    state_dict = checkpoint.get('state_dict', checkpoint)
    # strict=False silently skips keys that do not match the model. Report the
    # mismatch so a wrong architecture is visible instead of quietly running
    # on randomly initialised weights.
    incompatible = net.load_state_dict(state_dict, strict=False)
    net.eval()

    n_model_tensors = len(net.state_dict())
    n_missing = len(incompatible.missing_keys)
    n_unexpected = len(incompatible.unexpected_keys)
    if n_missing:
        print(f"⚠️ {n_missing}/{n_model_tensors} model tensors were not found in the checkpoint "
              "and keep their random initialisation.")
    if n_unexpected:
        print(f"⚠️ {n_unexpected} checkpoint tensors matched no model parameter and were ignored.")

    if n_missing == n_model_tensors:
        print("⚠️ No checkpoint weights matched the model; the U-Net is running on random weights.")
    else:
        print("✅ U-Net model loaded successfully.")
except Exception as e:
    print(f"❌ Failed to load U-Net model: {e}. The pipeline will not run.")
    net = None

# Preprocessing applied to every test image before the U-Net: resize to
# config.UNET_INPUT_SIZE (height, width), normalise with the usual ImageNet
# per-channel mean/std, and convert to a torch tensor.
infer_transform = A.Compose(
    [
        A.Resize(*config.UNET_INPUT_SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
)

# The lead boxes are fixed in rectified space, so build the lookup table once.
LEAD_REGIONS = get_lead_regions_in_rectified_space()

# --- Main Inference Loop ---
print("\n--- Starting Inference using U-Net Hybrid Strategy ---")
test_df = pd.read_csv(os.path.join(config.BASE_DIR, 'test.csv'))
submission_data = []


def _scale_gridpoints_to_image(gridpoint_xy, mask_shape, image_shape):
    """Map (x, y) grid points from mask-pixel to original-image-pixel coordinates.

    The grid is detected on masks at the U-Net output resolution, but it is
    applied to a lead mask resized back to the original image, which
    rectify_mask normalises by the image size. The points must therefore be
    rescaled first. Pixel centres are mapped onto pixel centres.
    """
    mask_h, mask_w = mask_shape[:2]
    image_h, image_w = image_shape[:2]
    scale = np.array([image_w / mask_w, image_h / mask_h], dtype=np.float32)
    return (gridpoint_xy + 0.5) * scale - 0.5


def _append_signal(image_id, lead_name, signal):
    """Append one lead's samples as ``{image_id}_{row}_{lead}`` submission rows."""
    submission_data.extend(
        {'id': f"{image_id}_{i}_{lead_name}", 'value': val}
        for i, val in enumerate(signal)
    )


# Without a model, leave submission_data empty so the next cell writes the
# dummy submission.
image_groups = test_df.groupby('id') if net is not None else []

# Process images one by one
for image_id, group in tqdm(image_groups, desc="Processing Test Images"):
    image_path = os.path.join(config.BASE_DIR, 'test', f"{image_id}.png")
    image = cv2.imread(image_path)
    if image is None:
        print(f"Warning: Could not read image {image_id}. Skipping.")
        # Every requested lead still needs rows in the submission: fill with zeros.
        for _, row in group.iterrows():
            _append_signal(image_id, row['lead'], np.zeros(int(row['number_of_rows'])))
        continue

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # --- Stage 1: U-Net Segmentation ---
    transformed = infer_transform(image=image_rgb)
    batch = {'image': transformed['image'].unsqueeze(0).to(device)}

    with torch.no_grad():
        output = net(batch)

    # Channel 1 of each 2-class head is taken as the foreground probability.
    lead_prob = torch.softmax(output['lead'], 1)[0, 1].cpu().numpy()
    h_prob = torch.softmax(output['horizontal'], 1)[0, 1].cpu().numpy()
    v_prob = torch.softmax(output['vertical'], 1)[0, 1].cpu().numpy()

    lead_mask = (lead_prob > 0.5).astype(np.uint8)
    h_mask = (h_prob > 0.5).astype(np.uint8)
    v_mask = (v_prob > 0.5).astype(np.uint8)

    # --- Stage 2: Geometric Rectification ---
    gridpoint_xy = get_gridpoint_xy_from_masks(h_mask, v_mask)

    rectified_lead_mask = None
    if gridpoint_xy is None:
        print(f"Info: Failed to form grid for {image_id}. Falling back to zeros for all leads.")
    else:
        image_h, image_w = image.shape[:2]
        lead_mask_orig_size = cv2.resize(lead_mask, (image_w, image_h), interpolation=cv2.INTER_NEAREST)
        # The grid was found in mask coordinates; move it to image coordinates
        # so it lines up with the resized lead mask being rectified.
        gridpoint_xy_image = _scale_gridpoints_to_image(gridpoint_xy, h_mask.shape, image.shape)
        rectified_lead_mask = rectify_mask(lead_mask_orig_size, gridpoint_xy_image)

    # --- Stage 3: Signal Extraction ---
    predictions = {}
    for _, row in group.iterrows():
        lead_name = row['lead']
        num_rows = int(row['number_of_rows'])

        # A lead-II request longer than 5 s is read from the rhythm-strip region.
        region_key = 'II_long' if lead_name == 'II' and num_rows > (row.fs * 5) else lead_name

        if rectified_lead_mask is not None and region_key in LEAD_REGIONS:
            y0, y1, x0, x1, zero_mv_y = LEAD_REGIONS[region_key]
            lead_crop_mask = rectified_lead_mask[y0:y1, x0:x1]
            # The baseline is passed relative to the crop's top edge.
            predictions[lead_name] = mask_to_signal(lead_crop_mask, num_rows, zero_mv_y - y0, config.MV_PER_PIXEL)
        else:
            # No rectified mask (or unknown region): fall back to a flat signal.
            predictions[lead_name] = np.zeros(num_rows)

    # Apply post-processing
    if config.APPLY_EINTHOVEN:
        predictions = apply_einthoven(predictions)

    for lead_name, signal in predictions.items():
        _append_signal(image_id, lead_name, signal)

# 5. CREATE SUBMISSION FILE

In [None]:
# Write the submission. If inference produced no rows (e.g. the model failed
# to load), fall back to the sample submission with every value set to 0.
if not submission_data:
    print("\n❌ No data was processed. Creating a dummy submission file.")
    sample_submission = pd.read_csv(os.path.join(config.BASE_DIR, 'sample_submission.csv'))
    sample_submission['value'] = 0
    sample_submission.to_csv('submission.csv', index=False)
else:
    submission_df = pd.DataFrame(submission_data)
    submission_df.to_csv('submission.csv', index=False)
    print("\n✅ Submission file 'submission.csv' created successfully!")