# PhysioNet ECG U-Net Inference

This notebook uses a trained U-Net model for ECG waveform segmentation.

**Model Details:**
- Architecture: U-Net (31M parameters)
- Training: 30 epochs on 8,793 images
- Best Validation Loss: 0.0525
- Model Size: 356MB

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
import cv2
from tqdm import tqdm

print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## Load U-Net Model

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm2d -> ReLU stages.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes (``in_channels`` -> ``out_channels`` in the
    first stage, ``out_channels`` -> ``out_channels`` in the second). The
    convolutions are created with ``bias=False`` and each one is immediately
    followed by a ``BatchNorm2d`` layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # NOTE: the attribute name ``conv`` and the position of each layer in
        # the Sequential define the state_dict keys (``conv.0.weight`` ...),
        # so neither may change without breaking checkpoint loading.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv`` block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The attribute name and the child order fix the state_dict keys
        # (``pool_conv.1.conv.*``), so they are kept exactly as trained.
        halve = nn.MaxPool2d(kernel_size=2)
        refine = DoubleConv(in_channels, out_channels)
        self.pool_conv = nn.Sequential(halve, refine)

    def forward(self, x):
        """Pool ``x`` by a factor of two, then run the double convolution."""
        pooled_and_convolved = self.pool_conv(x)
        return pooled_and_convolved

class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, convolve."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Kernel 2 / stride 2 transposed conv: doubles H and W while mapping
        # in_channels -> in_channels // 2.
        self.up = nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        # The concatenation of skip + upsampled features is fed to a
        # DoubleConv that expects ``in_channels`` input channels.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, and fuse the two.

        Args:
            x1: Feature map coming from the deeper (lower-resolution) stage.
            x2: Skip-connection feature map from the encoder.

        Returns:
            Output of the double convolution applied to ``cat([x2, x1_up])``
            along the channel dimension.
        """
        upsampled = self.up(x1)

        # Split any size mismatch (dims 2 and 3) as evenly as possible between
        # the two sides; the extra pixel, if any, goes to the right/bottom.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        return self.conv(torch.cat([x2, upsampled], dim=1))

class UNet(nn.Module):
    """Four-level U-Net producing ``n_classes`` channels of raw logits.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 1024 channels; the decoder
    mirrors it back down to 64 using skip connections, and a final 1x1
    convolution maps to ``n_classes``. No activation is applied in
    ``forward``; callers apply their own (e.g. sigmoid).

    Args:
        n_channels: Number of channels in the input image.
        n_classes: Number of output channels.
    """

    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # NOTE: attribute names below are the state_dict key prefixes of the
        # saved checkpoint; keep names and creation order unchanged.

        # Contracting path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        # Expanding path
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        # 1x1 projection to the output channels
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        # Encoder: keep every intermediate feature map for the skip links.
        skips = [self.inc(x)]
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder(skips[-1]))

        # The deepest feature map is the decoder's starting point; the
        # remaining maps are consumed deepest-first as skip connections.
        features = skips.pop()
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            features = decoder(features, skips.pop())

        return self.outc(features)

# Build the network on CPU first, then load the trained weights into it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=2)

# Load weights from dataset
model_path = '/kaggle/input/ecg-unet-trained-model/unet_best.pth'

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the PyTorch version and the
# checkpoint contents allow it).
checkpoint = torch.load(model_path, map_location=device)


def _format_loss(value):
    """Return ``value`` formatted to 4 decimals, or 'N/A' if it is not numeric.

    Applying the ``.4f`` format spec directly to a missing metric (``None`` or
    a placeholder string) raises, which used to abort the whole cell.
    """
    try:
        return f'{value:.4f}'
    except (TypeError, ValueError):
        return 'N/A'


# Handle both checkpoint formats: dict with 'model_state_dict' or direct state_dict
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f'Loaded checkpoint from epoch {checkpoint.get("epoch", "unknown")}')
    print(f'Training loss: {_format_loss(checkpoint.get("train_loss"))}')
    print(f'Validation loss: {_format_loss(checkpoint.get("val_loss"))}')
else:
    model.load_state_dict(checkpoint)

# Move to the target device and switch to inference behaviour
# (BatchNorm layers use their running statistics in eval mode).
model = model.to(device)
model.eval()

print('Model loaded successfully!')
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

## Preprocessing Function

In [None]:
def preprocess_image(image_path, target_size=(512, 1024)):
    """Load an ECG image from disk and convert it to a U-Net input tensor.

    Args:
        image_path: Path (``str`` or ``pathlib.Path``) of the image file.
        target_size: ``(height, width)`` the image is resized to.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, height, width)`` with
        RGB channel order and values scaled to ``[0, 1]``.

    Raises:
        FileNotFoundError: If ``image_path`` does not point to a file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    # cv2.imread does not raise on failure, it returns None; without these
    # checks the error would only surface later as an opaque OpenCV assertion.
    if not Path(image_path).is_file():
        raise FileNotFoundError(f'ECG image not found: {image_path}')

    img = cv2.imread(str(image_path))
    if img is None:
        raise ValueError(f'Could not decode image: {image_path}')

    # OpenCV loads BGR; convert to RGB channel order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # cv2.resize takes (width, height), the reverse of target_size.
    height, width = target_size
    img = cv2.resize(img, (width, height))

    # Normalize to [0, 1]
    img = img.astype(np.float32) / 255.0

    # (H, W, C) -> (C, H, W), then add the batch dimension.
    img_tensor = torch.from_numpy(img.transpose(2, 0, 1))
    return img_tensor.unsqueeze(0)

def predict_mask(model, image_tensor, device):
    """Run the model on one batch and return the first item's binary mask.

    Args:
        model: Callable mapping an input tensor to per-channel logits.
        image_tensor: Batched input tensor; it is moved to ``device`` first.
        device: Device on which inference is executed.

    Returns:
        A float32 NumPy array holding the mask of batch item 0 (for the
        two-class model in this notebook that is ``(2, H, W)``). A value is
        1.0 where the sigmoid probability is strictly greater than 0.5 and
        0.0 elsewhere.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        binary = torch.sigmoid(logits).gt(0.5).to(torch.float32)

    # Only the first element of the batch is returned.
    return binary.cpu().numpy()[0]

## Inference on Test Data

In [None]:
# Collect the competition test images (PNG files, sorted by path)
test_dir = Path('/kaggle/input/physionet-ecg-image-digitization/test')
test_images = sorted(test_dir.glob('*.png'))

print(f'Found {len(test_images)} test images')

# Smoke-test the pipeline on the first image, if there is one
if test_images:
    sample_img = test_images[0]
    print(f'Testing on: {sample_img.name}')

    # Image file -> normalized input tensor -> binary mask
    img_tensor = preprocess_image(sample_img)
    mask = predict_mask(model, img_tensor, device)

    # Report how many pixels each output channel marked as positive
    print(f'Mask shape: {mask.shape}')
    print(f'Waveform mask (channel 0): {mask[0].sum()} pixels')
    print(f'Grid mask (channel 1): {mask[1].sum()} pixels')

## Generate Submission

**Note:** This is a basic inference example. For a full submission, you need to:
1. Extract signal from segmentation mask
2. Apply calibration and scaling
3. Convert to competition format (CSV with id, value columns)

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from scipy.interpolate import interp1d

# ============================================================================
# CALIBRATION DETECTION (handles small/large grid correctly)
# ============================================================================

def detect_horizontal_lines(grid_mask):
    """Locate horizontal grid lines in a 2-D grid mask.

    Each row is scored by the sum of its mask values. Rows scoring strictly
    above the 95th percentile of all row scores are line candidates, and
    candidates whose neighbours are at most 5 rows apart are merged into a
    single line.

    Args:
        grid_mask: 2-D array (rows x columns) of grid-pixel indicators.

    Returns:
        A list of Python ints, one per detected line: the truncated mean row
        index of each merged group, in ascending order. Empty if the mask has
        no positive row sum or no row exceeds the percentile threshold.
    """
    row_totals = np.sum(grid_mask, axis=1)
    if row_totals.max() == 0:
        return []

    # Candidate rows: strictly above the 95th-percentile row score.
    cutoff = np.percentile(row_totals, 95)
    candidate_rows = np.where(row_totals > cutoff)[0]
    if len(candidate_rows) == 0:
        return []

    # A gap of more than 5 rows between consecutive candidates starts a new
    # line; split the sorted candidate indices at those gaps.
    break_points = np.where(np.diff(candidate_rows) > 5)[0] + 1
    clusters = np.split(candidate_rows, break_points)

    return [int(np.mean(cluster)) for cluster in clusters]

def predict_mask_with_calibration(model, image_tensor, device):
    """Predict the binary mask and estimate a pixels-per-millivolt scale.

    The mask is obtained exactly as in plain inference (sigmoid, threshold
    at 0.5, first batch item). Channel 1 of the mask is treated as the grid
    channel, and the spacing of its horizontal lines is turned into a
    vertical calibration factor.

    Calibration heuristic (the author's assumptions about ECG paper: 1 mm
    small squares, 5 mm large squares, 10 mm per mV):
      * more than 15 lines  -> lines are taken as the 1 mm grid,
        so 1 mV = 10 * spacing;
      * 2 to 15 lines       -> lines are taken as the 5 mm grid,
        so 1 mV = 2 * spacing;
      * fewer than 2 lines  -> fixed fallback of 120.0 pixels/mV.

    Args:
        model: Callable mapping an input tensor to per-channel logits.
        image_tensor: Batched input tensor; moved to ``device`` first.
        device: Device on which inference is executed.

    Returns:
        ``(mask_np, pixels_per_mv)`` where ``mask_np`` is the float32 mask of
        batch item 0 and ``pixels_per_mv`` is the estimated scale.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        binary = (torch.sigmoid(logits) > 0.5).float()

    mask_np = binary.cpu().numpy()[0]

    # Channel 1 = grid lines; their row positions drive the calibration.
    horizontal_lines = detect_horizontal_lines(mask_np[1])

    # Not enough lines to measure a spacing: fall back to the fixed default.
    if len(horizontal_lines) < 2:
        pixels_per_mv = 120.0
        print(f"No grid lines detected, using default: {pixels_per_mv:.1f} pixels/mV")
        return mask_np, pixels_per_mv

    # Median gap between neighbouring lines is robust to a missed/extra line.
    grid_spacing = np.median(np.diff(sorted(horizontal_lines)))

    if len(horizontal_lines) > 15:
        # Many lines -> treated as the small grid: 10 spacings per mV.
        pixels_per_mv = 10 * grid_spacing
        print(f"Detected small grid: {grid_spacing:.1f}px spacing")
    else:
        # Few lines -> treated as the large grid: 2 spacings per mV.
        pixels_per_mv = 2 * grid_spacing
        print(f"Detected large grid: {grid_spacing:.1f}px spacing")
    print(f"Calibration: {pixels_per_mv:.1f} pixels/mV")

    return mask_np, pixels_per_mv

# ============================================================================
# SIGNAL EXTRACTION WITH CALIBRATION
# ============================================================================

def extract_signal_from_mask(mask, pixels_per_mv, target_length=5000, debug=False):
    """Turn a 2-D waveform mask into a 1-D trace expressed in millivolts.

    Steps:
      1. Baseline = median row index over all positive mask pixels.
      2. Per column, the trace height is ``baseline - median(row indices)``
         divided by ``pixels_per_mv`` (rows grow downward, hence the sign).
         Columns without positive pixels repeat the previous value, or 0.0
         when no value exists yet.
      3. The median of the trace is subtracted and the result is clipped to
         [-2.0, 2.0].
      4. The trace is linearly resampled to ``target_length`` samples.

    Args:
        mask: 2-D array (H, W); entries greater than 0 are waveform pixels.
        pixels_per_mv: Vertical scale, in pixels per millivolt.
        target_length: Number of samples in the returned signal.
        debug: When True, print intermediate diagnostics.

    Returns:
        1-D NumPy array of length ``target_length``; all zeros when the mask
        contains no positive pixel.
    """
    _, w = mask.shape

    # Row index of every positive pixel in the mask (order is irrelevant for
    # the median and the count below).
    all_y_positions = np.nonzero(mask > 0)[0]
    if all_y_positions.size == 0:
        if debug:
            print("WARNING: No waveform pixels detected!")
        return np.zeros(target_length)

    # Baseline is the median position (0 mV reference)
    baseline = np.median(all_y_positions)

    if debug:
        print(f"Baseline position: {baseline:.1f} pixels")
        print(f"Total waveform pixels: {len(all_y_positions)}")
        print(f"Calibration: {pixels_per_mv:.1f} pixels/mV")

    signal = []
    columns_with_signal = 0
    for column in mask.T:
        rows = np.flatnonzero(column > 0)
        if rows.size == 0:
            # Gap in the trace: hold the last value (0.0 at the very start).
            signal.append(signal[-1] if signal else 0.0)
            continue

        columns_with_signal += 1
        # Median row is robust to stray pixels; image rows increase
        # downward, so "above baseline" must come out positive.
        signal.append((baseline - np.median(rows)) / pixels_per_mv)

    if debug:
        print(f"Columns with signal: {columns_with_signal}/{w}")
        if signal:
            print(f"Signal range before clipping: [{min(signal):.3f}, {max(signal):.3f}] mV")

    # Center on the median, then clip to the +/-2 mV working range.
    signal = np.asarray(signal)
    signal = np.clip(signal - np.median(signal), -2.0, 2.0)

    if debug:
        print(f"Signal range after processing: [{signal.min():.3f}, {signal.max():.3f}] mV")
        print(f"Signal mean: {signal.mean():.3f}, std: {signal.std():.3f}")

    if len(signal) == 0:
        return np.zeros(target_length)

    # Linear resampling on a normalized [0, 1] axis.
    x_old = np.linspace(0, 1, len(signal))
    x_new = np.linspace(0, 1, target_length)
    resampler = interp1d(x_old, signal, kind='linear', fill_value='extrapolate')
    return resampler(x_new)

# ============================================================================
# SUBMISSION GENERATION
# ============================================================================

def _print_submission_report(df, output_path):
    """Print summary statistics and heuristic sanity checks for a saved submission."""
    print(f'\n{"="*60}')
    print(f'Submission saved to {output_path}')
    print(f'Total rows: {len(df)}')
    print(f'\n{"="*60}')
    print(f'Signal statistics:')
    print(f'{"="*60}')
    stats = df['value'].describe()
    print(stats)

    # Quality checks
    print(f'\n{"="*60}')
    print(f'QUALITY CHECKS:')
    print(f'{"="*60}')

    mean_val = stats['mean']
    std_val = stats['std']
    min_val = stats['min']
    max_val = stats['max']

    # Check 1: Mean should be near 0 (tolerance: |mean| < 0.3)
    if abs(mean_val) < 0.3:
        print(f'✓ Mean value OK: {mean_val:.3f} (expected: ~0)')
    else:
        print(f'⚠️  Mean value high: {mean_val:.3f} (expected: ~0)')

    # Check 2: Std is accepted inside the wider band (0.05, 0.8)
    if 0.05 < std_val < 0.8:
        print(f'✓ Standard deviation OK: {std_val:.3f}')
    else:
        print(f'⚠️  Standard deviation unusual: {std_val:.3f} (expected: 0.1-0.5)')

    # Check 3: Both extremes must be present and within +/-2.5
    if -2.5 < min_val < -0.2 and 0.2 < max_val < 2.5:
        print(f'✓ Signal range OK: [{min_val:.3f}, {max_val:.3f}]')
    else:
        print(f'⚠️  Signal range unusual: [{min_val:.3f}, {max_val:.3f}] (expected: ±0.5-2 mV)')

    # Sanity check: a near-constant output has very few distinct values
    unique_values = df['value'].nunique()
    if unique_values < 100:
        print(f'⚠️  WARNING: Only {unique_values} unique values detected!')
    else:
        print(f'✓ Good diversity: {unique_values} unique values')

    print(f'\nPreview (first 20 rows):')
    print(df.head(20))


def generate_submission(test_dir, model, device, output_path='submission.csv',
                        debug_first=True, target_length=5000, lead_name='I'):
    """
    Generate submission CSV with proper calibration.

    Every ``*.png`` file in ``test_dir`` is preprocessed, segmented, and
    converted to a 1-D signal; each sample becomes one CSV row with
    ``id = "<image stem>_<sample index>_<lead_name>"`` and ``value`` = the
    sample.

    Args:
        test_dir: Path to test images directory
        model: Trained U-Net model
        device: torch device
        output_path: Output CSV path
        debug_first: If True, show detailed debug info for first image
        target_length: Number of samples written per image (default 5000)
        lead_name: Lead label used as the id suffix (default 'I'; only a
            single lead is produced per image)

    Returns:
        The ``pandas.DataFrame`` (columns ``id``, ``value``) that was written
        to ``output_path``. When no images are found, an empty frame with the
        same columns is written and returned and the quality report is
        skipped.
    """
    test_images = sorted(Path(test_dir).glob('*.png'))

    print(f'Processing {len(test_images)} test images...')

    # Without this guard an empty directory produced a column-less frame and
    # the statistics below crashed with KeyError on df['value'].
    if not test_images:
        df = pd.DataFrame({
            'id': pd.Series(dtype=object),
            'value': pd.Series(dtype=float),
        })
        df.to_csv(output_path, index=False)
        print(f'⚠️  WARNING: No PNG images found in {test_dir}; '
              f'wrote header-only file to {output_path}')
        return df

    # Collect ids and per-image value arrays, then build the frame once;
    # this avoids allocating one dict per sample.
    ids = []
    value_chunks = []

    for idx, img_path in enumerate(tqdm(test_images)):
        image_id = img_path.stem
        show_debug = debug_first and idx == 0

        # Preprocess and predict WITH CALIBRATION
        img_tensor = preprocess_image(img_path)
        mask, pixels_per_mv = predict_mask_with_calibration(model, img_tensor, device)

        if show_debug:
            print(f"\n{'='*60}")
            print(f"DEBUG INFO for first image: {image_id}")
            print(f"{'='*60}")
            print(f"Mask shape: {mask.shape}")
            print(f"Channel 0 (waveform) pixels: {np.sum(mask[0] > 0)}")
            print(f"Channel 1 (grid) pixels: {np.sum(mask[1] > 0)}")
            print(f"Final calibration: {pixels_per_mv:.1f} pixels/mV")

        # Channel 0 is the waveform mask; convert it to a calibrated signal.
        signal = extract_signal_from_mask(
            mask[0],
            pixels_per_mv,
            target_length=target_length,
            debug=show_debug,
        )

        ids.extend(
            f"{image_id}_{sample_idx}_{lead_name}"
            for sample_idx in range(len(signal))
        )
        value_chunks.append(np.asarray(signal, dtype=float))

    # Save
    df = pd.DataFrame({'id': ids, 'value': np.concatenate(value_chunks)})
    df.to_csv(output_path, index=False)

    _print_submission_report(df, output_path)

    return df

# =============================================================================
# MAIN EXECUTION
# =============================================================================

# Run the whole pipeline over the competition test set and write the CSV
test_dir = '/kaggle/input/physionet-ecg-image-digitization/test'
submission_df = generate_submission(
    test_dir,
    model,
    device,
    output_path='submission.csv',
    debug_first=True,
)

# Closing banner: blank line, rule, title, rule, output file name
print(
    '\n' + '='*60,
    'SUBMISSION COMPLETE!',
    '='*60,
    'File: submission.csv',
    sep='\n',
)