# ECG2Signal Complete Demo

This notebook demonstrates the complete ECG image-to-signal conversion pipeline.

## What You'll Learn

1. Loading and preprocessing ECG images
2. Grid detection and calibration
3. Layout detection and OCR
4. Signal segmentation and extraction
5. Signal reconstruction and post-processing
6. Clinical feature extraction
7. Quality assessment
8. Exporting to multiple formats

## Setup

In [None]:
# Install dependencies (if needed)
# !pip install -r requirements.txt

import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add project to path
sys.path.insert(0, '..')

print("✅ Imports successful!")

## Step 1: Create Sample ECG Image

First, let's create a realistic synthetic ECG image with grid and waveforms.
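The notebook does not show how `synth_ecg` and `render` work internally. As a rough mental model only, the NumPy sketch below makes two assumptions. First, each beat is a sum of Gaussian bumps for the P, Q, R, S, and T waves. Second, rendering maps time to image columns via paper speed and voltage to image rows via gain, both scaled by a pixels-per-mm resolution. The wave parameters, `px_per_mm=10`, and the helper names `gaussian_beat_ecg` and `render_trace` are hypothetical and are not taken from the library.

```python
import numpy as np

# Hypothetical (offset from R peak [s], amplitude [mV], width sigma [s]) per wave.
# Illustrative values chosen to look ECG-like, not clinical constants.
WAVES = [
    (-0.20, 0.15, 0.025),   # P
    (-0.03, -0.10, 0.010),  # Q
    (0.00, 1.00, 0.012),    # R
    (0.03, -0.20, 0.010),   # S
    (0.25, 0.30, 0.040),    # T
]

def gaussian_beat_ecg(duration=10.0, heart_rate=75.0, fs=500.0):
    """Return (t, signal_mV) with one Gaussian-sum beat every 60/heart_rate s."""
    t = np.arange(int(round(duration * fs))) / fs
    rr = 60.0 / heart_rate
    sig = np.zeros_like(t)
    for r in np.arange(rr / 2, duration, rr):   # R-peak times
        for offset, amp, width in WAVES:
            sig += amp * np.exp(-0.5 * ((t - (r + offset)) / width) ** 2)
    return t, sig

def render_trace(t, sig, paper_speed=25.0, gain=10.0, px_per_mm=10.0, height_mm=40.0):
    """Rasterize a trace onto a white grayscale canvas with 1 mm grid lines."""
    width = int(np.ceil(t[-1] * paper_speed * px_per_mm)) + 1
    height = int(height_mm * px_per_mm)
    img = np.full((height, width), 255, dtype=np.uint8)
    step = int(px_per_mm)
    img[::step, :] = 200          # light horizontal grid lines every 1 mm
    img[:, ::step] = 200          # light vertical grid lines every 1 mm
    # time -> column; voltage -> row (image rows grow downward, so subtract)
    cols = np.round(t * paper_speed * px_per_mm).astype(int)
    rows = np.round(height / 2 - sig * gain * px_per_mm).astype(int)
    keep = (rows >= 0) & (rows < height)
    img[rows[keep], cols[keep]] = 0   # black waveform pixels
    return img

t, sig = gaussian_beat_ecg()
img = render_trace(t, sig)
print(img.shape)
```

Plotting one pixel per sample leaves gaps on steep QRS slopes. A more complete renderer would draw connected line segments instead, for example with `cv2.polylines`.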

In [None]:
from ecg2signal.training.data_synth import synth_ecg, render

# Generate synthetic 12-lead ECG signals
print("Generating synthetic ECG signals...")
signals = synth_ecg.generate_12lead_ecg(
    duration=10.0,  # 10 seconds
    heart_rate=75,  # BPM
    sampling_rate=500  # Hz
)

print(f"Generated signals for {len(signals)} leads")
print(f"Signal duration: {len(signals['I'])/500:.1f} seconds")
print(f"Sampling rate: 500 Hz")

# Render to image with grid
print("\nRendering ECG to image with grid...")
ecg_image = render.render_ecg_to_image(
    signals=signals,
    paper_speed=25.0,  # mm/s
    gain=10.0,  # mm/mV
    add_grid=True,
    add_labels=True,
    add_noise=True,
    image_size=(3000, 2400)
)

# Display the image
plt.figure(figsize=(15, 12))
plt.imshow(ecg_image, cmap='gray')
plt.title('Synthetic ECG Image with Grid', fontsize=16)
plt.axis('off')
plt.tight_layout()
plt.show()

print(f"✅ Image generated: {ecg_image.shape}")

## Step 2: Load and Preprocess

Apply the preprocessing steps to the synthetic image generated above: page detection and cropping, perspective correction, and denoising.
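The algorithms inside `detect_page`, `dewarp`, and `denoise` are not shown here. For intuition, the sketch below substitutes two deliberately simple techniques. Page detection is approximated by cropping to the bounding box of non-background pixels, which assumes a near-uniform light background. Denoising is approximated by a 3x3 median filter, a common choice for salt-and-pepper noise. Perspective correction is omitted. The helper names are hypothetical, and `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def crop_to_content(gray, background_thresh=240, margin=2):
    """Crop to the bounding box of pixels darker than `background_thresh`."""
    ys, xs = np.nonzero(gray < background_thresh)
    if ys.size == 0:
        return gray  # nothing but background; leave unchanged
    y1 = max(ys.min() - margin, 0)
    y2 = min(ys.max() + margin + 1, gray.shape[0])
    x1 = max(xs.min() - margin, 0)
    x2 = min(xs.max() + margin + 1, gray.shape[1])
    return gray[y1:y2, x1:x2]

def median3x3(gray):
    """3x3 median filter; borders are handled by replicating edge pixels."""
    padded = np.pad(gray, 1, mode="edge")
    windows = sliding_window_view(padded, (3, 3))   # shape (H, W, 3, 3)
    return np.median(windows, axis=(-2, -1)).astype(gray.dtype)

# Toy page: white border, gray content block, one isolated dark speck
page = np.full((20, 30), 255, dtype=np.uint8)
page[5:15, 8:22] = 180
page[10, 15] = 0
cropped = crop_to_content(page)
clean = median3x3(cropped)
print(cropped.shape, clean[7, 9])   # the speck now sits at (7, 9) after cropping
```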

In [None]:
from ecg2signal.io import image_io
from ecg2signal.preprocess import detect_page, dewarp, denoise

# Step 2.1: Page detection
print("Step 2.1: Detecting page boundaries...")
page_detected = detect_page.detect_and_crop_page(ecg_image)

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].imshow(ecg_image, cmap='gray')
axes[0].set_title('Original Image')
axes[0].axis('off')
axes[1].imshow(page_detected, cmap='gray')
axes[1].set_title('Page Detected & Cropped')
axes[1].axis('off')
plt.tight_layout()
plt.show()

# Step 2.2: Perspective correction
print("\nStep 2.2: Correcting perspective distortion...")
dewarped = dewarp.correct_perspective(page_detected)

# Step 2.3: Denoising
print("Step 2.3: Denoising image...")
denoised = denoise.denoise_image(dewarped)

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
axes[0].imshow(dewarped, cmap='gray')
axes[0].set_title('After Dewarp')
axes[0].axis('off')
axes[1].imshow(denoised, cmap='gray')
axes[1].set_title('After Denoising')
axes[1].axis('off')
plt.tight_layout()
plt.show()

print("✅ Preprocessing complete")

## Step 3: Grid Detection and Calibration

Detect the ECG grid and calculate pixel-to-physical unit calibration.
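On standard ECG paper one small box is 1 mm, so the small-box spacing in pixels equals pixels-per-mm directly. The resolutions printed in this step then follow from paper speed and gain: ms/pixel is `1000 / (px_per_mm * paper_speed)` and mV/pixel is `1 / (px_per_mm * gain)`. The sketch below estimates the spacing from a 1-D grid-line intensity profile using autocorrelation. This is one plausible technique assumed for illustration, not necessarily what `grid_detect` does, and the helper names are hypothetical. For a real grayscale image with dark lines on light paper, a profile such as `255 - gray.mean(axis=1)` would be high on horizontal grid lines.

```python
import numpy as np

def estimate_grid_spacing(profile, min_lag=3, max_lag=None):
    """Dominant line period (pixels) of a 1-D profile that is high on grid lines.

    Returns the lag with the largest autocorrelation in [min_lag, max_lag].
    `min_lag` skips the trivial peak at lag 0; for a finite periodic profile
    the fundamental period typically wins because overlap shrinks at larger lags.
    """
    x = np.asarray(profile, dtype=float)
    x = x - x.mean()
    n = len(x)
    if max_lag is None:
        max_lag = n // 2
    ac = np.correlate(x, x, mode="full")[n - 1:]   # keep lags 0..n-1
    lags = np.arange(min_lag, max_lag + 1)
    return int(lags[np.argmax(ac[min_lag:max_lag + 1])])

def calibration_from_spacing(small_box_px, paper_speed=25.0, gain=10.0):
    """Return (px_per_mm, ms_per_px, mv_per_px), assuming 1 small box = 1 mm."""
    px_per_mm = float(small_box_px)
    ms_per_px = 1000.0 / (px_per_mm * paper_speed)
    mv_per_px = 1.0 / (px_per_mm * gain)
    return px_per_mm, ms_per_px, mv_per_px

# Synthetic profile: a grid line every 10 px over 200 px
profile = np.zeros(200)
profile[::10] = 1.0
spacing = estimate_grid_spacing(profile)
px_per_mm, ms_per_px, mv_per_px = calibration_from_spacing(spacing)
print(spacing, ms_per_px, mv_per_px)
```

At 10 px/mm and 25 mm/s this gives 4 ms per pixel, and at 10 mm/mV it gives 0.01 mV per pixel.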

In [None]:
from ecg2signal.preprocess import grid_detect, scale_calibrate

# Detect grid
print("Detecting grid lines...")
grid_info = grid_detect.detect_grid(denoised)

if grid_info:
    print(f"✅ Grid detected!")
    print(f"   Horizontal spacing: {grid_info['horizontal_spacing']:.2f} pixels")
    print(f"   Vertical spacing: {grid_info['vertical_spacing']:.2f} pixels")
    print(f"   Detected {len(grid_info.get('horizontal_lines', []))} horizontal lines")
    print(f"   Detected {len(grid_info.get('vertical_lines', []))} vertical lines")
    
    # Visualize grid detection
    grid_vis = denoised.copy()
    if len(grid_vis.shape) == 2:
        grid_vis = cv2.cvtColor(grid_vis, cv2.COLOR_GRAY2BGR)
    
    # Draw detected lines
    # Draw detected lines (cv2.line needs integer pixel coordinates; colors are BGR)
    for line in grid_info.get('horizontal_lines', [])[:20]:  # Show first 20
        y = int(round(line))
        cv2.line(grid_vis, (0, y), (grid_vis.shape[1], y), (0, 255, 0), 2)
    for line in grid_info.get('vertical_lines', [])[:20]:  # Show first 20
        x = int(round(line))
        cv2.line(grid_vis, (x, 0), (x, grid_vis.shape[0]), (255, 0, 0), 2)
    
    plt.figure(figsize=(15, 10))
    plt.imshow(cv2.cvtColor(grid_vis, cv2.COLOR_BGR2RGB))
    plt.title('Grid Detection (Green=Horizontal, Blue=Vertical)', fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print("⚠️ No grid detected, using default calibration")
    grid_info = None

# Calibrate
print("\nCalculating calibration...")
calibration = scale_calibrate.calibrate_from_grid(
    grid_info,
    paper_speed=25.0,  # mm/s
    gain=10.0  # mm/mV
)

print(f"\n📏 Calibration Results:")
print(f"   Pixels per mm: {calibration.pixels_per_mm:.2f}")
print(f"   Paper speed: {calibration.paper_speed} mm/s")
print(f"   Gain: {calibration.gain} mm/mV")
print(f"   Time resolution: {1000 / (calibration.pixels_per_mm * calibration.paper_speed):.2f} ms/pixel")
print(f"   Amplitude resolution: {1.0 / (calibration.pixels_per_mm * calibration.gain):.3f} mV/pixel")

## Step 4: Layout Detection

Detect the positions of individual lead panels and rhythm strips.
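When detection is unreliable, a simple fallback is to assume the standard 3x4 printout with an optional full-width rhythm strip below it, then split the plotting area evenly. In the 3x4 printout, the columns hold I/II/III, aVR/aVL/aVF, V1-V3, and V4-V6, each covering 2.5 s at 25 mm/s. The sketch below returns boxes in the same `{'x1', 'y1', 'x2', 'y2'}` dict shape this notebook uses. The even split, the helper name, and the `rhythm_II` key are assumptions for illustration, not the behavior of `lead_layout`.

```python
# Standard 3x4 arrangement: rows top-to-bottom, columns left-to-right
LEADS_3X4 = [
    ["I", "aVR", "V1", "V4"],
    ["II", "aVL", "V2", "V5"],
    ["III", "aVF", "V3", "V6"],
]

def fixed_grid_layout(width, height, rhythm_strip=True):
    """Split a (width x height) plotting area into equal lead panels.

    Returns {lead_name: {'x1', 'y1', 'x2', 'y2'}} in pixel coordinates.
    """
    n_rows = len(LEADS_3X4) + (1 if rhythm_strip else 0)
    n_cols = len(LEADS_3X4[0])
    row_h = height // n_rows
    col_w = width // n_cols
    layout = {}
    for r, row in enumerate(LEADS_3X4):
        for c, name in enumerate(row):
            layout[name] = {"x1": c * col_w, "y1": r * row_h,
                            "x2": (c + 1) * col_w, "y2": (r + 1) * row_h}
    if rhythm_strip:
        # Hypothetical key for the full-width rhythm strip (often lead II)
        r = len(LEADS_3X4)
        layout["rhythm_II"] = {"x1": 0, "y1": r * row_h, "x2": width, "y2": height}
    return layout

layout_sketch = fixed_grid_layout(3000, 2400)
print(len(layout_sketch), layout_sketch["V4"])
```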

In [None]:
from ecg2signal.layout import lead_layout, ocr_labels

# Detect lead layout
print("Detecting lead layout...")
layout = lead_layout.detect_lead_layout(denoised)

print(f"✅ Detected {len(layout)} leads")
print(f"\nLead positions:")
for lead_name, bbox in list(layout.items())[:5]:  # Show first 5
    print(f"   {lead_name}: ({bbox['x1']}, {bbox['y1']}) to ({bbox['x2']}, {bbox['y2']})")

# Visualize layout
layout_vis = denoised.copy()
if len(layout_vis.shape) == 2:
    layout_vis = cv2.cvtColor(layout_vis, cv2.COLOR_GRAY2BGR)

colors = {
    'I': (255, 0, 0), 'II': (0, 255, 0), 'III': (0, 0, 255),
    'aVR': (255, 255, 0), 'aVL': (255, 0, 255), 'aVF': (0, 255, 255),
    'V1': (128, 0, 0), 'V2': (0, 128, 0), 'V3': (0, 0, 128),
    'V4': (128, 128, 0), 'V5': (128, 0, 128), 'V6': (0, 128, 128)
}

for lead_name, bbox in layout.items():
    color = colors.get(lead_name, (200, 200, 200))
    cv2.rectangle(layout_vis, 
                  (bbox['x1'], bbox['y1']), 
                  (bbox['x2'], bbox['y2']), 
                  color, 3)
    cv2.putText(layout_vis, lead_name, 
                (bbox['x1'] + 5, bbox['y1'] + 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)

plt.figure(figsize=(15, 12))
plt.imshow(cv2.cvtColor(layout_vis, cv2.COLOR_BGR2RGB))
plt.title('Lead Layout Detection', fontsize=16)
plt.axis('off')
plt.tight_layout()
plt.show()

# OCR for metadata
print("\nExtracting metadata with OCR...")
metadata = ocr_labels.extract_labels(denoised)
print(f"📄 Extracted metadata:")
for key, value in metadata.items():
    print(f"   {key}: {value}")

## Step 5: Segmentation

Segment the image into layers: grid, waveforms, and text.
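As a baseline for what `segment_layers` produces, the sketch below separates layers by color alone. It assumes the common case of a red or pink grid and dark ink traces on white paper, with images in OpenCV's BGR channel order. Masks are `uint8` with 255 for foreground, which is compatible with the `masks['waveform'] > 128` test used in this step. The thresholds and helper name are illustrative assumptions. Color rules like these break down on grayscale scans, which is one reason a learned segmenter can help.

```python
import numpy as np

def color_threshold_layers(bgr, dark_thresh=100, red_margin=40):
    """Split a BGR uint8 image into 'waveform' and 'grid' masks (0/255 uint8).

    waveform: pixels dark in every channel (ink traces)
    grid:     non-waveform pixels whose red channel clearly exceeds green and blue
    """
    img = bgr.astype(np.int16)          # avoid uint8 wrap-around in differences
    b, g, r = img[..., 0], img[..., 1], img[..., 2]
    waveform = (b < dark_thresh) & (g < dark_thresh) & (r < dark_thresh)
    grid = ~waveform & (r - g > red_margin) & (r - b > red_margin)
    return {
        "waveform": waveform.astype(np.uint8) * 255,
        "grid": grid.astype(np.uint8) * 255,
    }

# 1x3 toy image in BGR order: white paper, pink grid pixel, black trace pixel
toy = np.array([[[255, 255, 255], [180, 180, 250], [20, 20, 20]]], dtype=np.uint8)
toy_masks = color_threshold_layers(toy)
print(toy_masks["grid"][0], toy_masks["waveform"][0])
```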

In [None]:
from ecg2signal.segment import separate_layers

print("Segmenting image layers...")
masks = separate_layers.segment_layers(denoised)

print(f"✅ Generated masks for: {list(masks.keys())}")

# Visualize masks
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

axes[0, 0].imshow(denoised, cmap='gray')
axes[0, 0].set_title('Original Image', fontsize=14)
axes[0, 0].axis('off')

axes[0, 1].imshow(masks['grid'], cmap='gray')
axes[0, 1].set_title('Grid Mask', fontsize=14)
axes[0, 1].axis('off')

axes[1, 0].imshow(masks['waveform'], cmap='gray')
axes[1, 0].set_title('Waveform Mask', fontsize=14)
axes[1, 0].axis('off')

axes[1, 1].imshow(masks.get('text', np.zeros_like(denoised)), cmap='gray')
axes[1, 1].set_title('Text Mask', fontsize=14)
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()

# Show waveform overlay
overlay = denoised.copy()
if len(overlay.shape) == 2:
    overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2BGR)
overlay[masks['waveform'] > 128] = [0, 255, 0]  # Highlight waveforms in green

plt.figure(figsize=(15, 10))
plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
plt.title('Detected Waveforms (Green)', fontsize=16)
plt.axis('off')
plt.tight_layout()
plt.show()

## Step 6: Signal Extraction and Reconstruction

Extract waveforms and convert them to calibrated time-series signals.
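The core of raster-to-signal conversion is a column-by-column scan. Each image column is one time step, and the vertical position of the trace in that column encodes voltage. The sketch below assumes a binary lead mask, a known baseline row, and the pixels-per-mm, paper-speed, and gain values from Step 3. It takes the mean row of the trace pixels in each column, which is one simple choice (`trace_curve` may use something more robust). It then marks empty columns as NaN and resamples to a uniform 500 Hz grid with linear interpolation. The function name and arguments are hypothetical.

```python
import numpy as np

def mask_to_signal(mask, baseline_row, px_per_mm, paper_speed=25.0, gain=10.0,
                   target_fs=500.0):
    """Convert a binary trace mask (H x W) to (t_seconds, mV) sampled at target_fs."""
    on = np.asarray(mask) > 0
    rows = np.arange(on.shape[0])[:, None]
    counts = on.sum(axis=0)
    # Mean row of trace pixels per column; NaN where a column has no trace
    with np.errstate(invalid="ignore", divide="ignore"):
        mean_row = (on * rows).sum(axis=0) / counts
    mean_row[counts == 0] = np.nan

    # Pixels -> physical units (image rows grow downward, so flip the sign)
    mv = (baseline_row - mean_row) / (px_per_mm * gain)
    t_col = np.arange(on.shape[1]) / (px_per_mm * paper_speed)

    # Resample to a uniform grid; np.interp bridges gaps linearly using valid columns
    valid = ~np.isnan(mv)
    t_out = np.arange(0.0, t_col[-1], 1.0 / target_fs)
    v_out = np.interp(t_out, t_col[valid], mv[valid])
    return t_out, v_out

# Toy mask: flat trace on the baseline for the first half, 10 px higher after
mask = np.zeros((100, 250), dtype=np.uint8)
mask[50, :125] = 255
mask[40, 125:] = 255
t, v = mask_to_signal(mask, baseline_row=50, px_per_mm=10.0)
print(len(t), v[0], v[-1])
```

With 10 px/mm and 10 mm/mV, a trace 10 px above the baseline corresponds to 0.1 mV.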

In [None]:
from ecg2signal.segment import trace_curve
from ecg2signal.reconstruct import raster_to_signal, resample, align_leads, postprocess

# Extract signals for each lead
print("Extracting signals from each lead...")
extracted_signals = {}

for lead_name, bbox in list(layout.items())[:3]:  # Process first 3 for demo
    print(f"\nProcessing {lead_name}...")
    
    # Extract lead region
    lead_mask = masks['waveform'][bbox['y1']:bbox['y2'], bbox['x1']:bbox['x2']]
    
    # Trace the waveform curve
    curve = trace_curve.trace_waveform(lead_mask)
    print(f"  Traced {len(curve)} points")
    
    # Convert to signal with calibration
    signal = raster_to_signal.pixels_to_signal(
        curve,
        calibration,
        roi_offset=(bbox['x1'], bbox['y1'])
    )
    print(f"  Signal length: {len(signal)} samples")
    
    # Resample to standard rate (500 Hz)
    resampled = resample.resample_signal(signal, target_fs=500.0)
    print(f"  Resampled to: {len(resampled)} samples")
    
    # Post-process (filter, baseline removal)
    clean_signal = postprocess.postprocess_signal(resampled)
    print(f"  Post-processed: {len(clean_signal)} samples")
    
    extracted_signals[lead_name] = clean_signal

# Visualize extracted signals
fig, axes = plt.subplots(len(extracted_signals), 1, figsize=(15, 4*len(extracted_signals)))
if len(extracted_signals) == 1:
    axes = [axes]

for idx, (lead_name, signal) in enumerate(extracted_signals.items()):
    time = np.arange(len(signal)) / 500.0  # Time in seconds
    axes[idx].plot(time, signal, 'b-', linewidth=1.5)
    axes[idx].set_title(f'Lead {lead_name}', fontsize=14)
    axes[idx].set_xlabel('Time (s)', fontsize=12)
    axes[idx].set_ylabel('Amplitude (mV)', fontsize=12)
    axes[idx].grid(True, alpha=0.3)
    axes[idx].set_xlim([0, 2])  # Show first 2 seconds

plt.tight_layout()
plt.show()

print(f"\n✅ Extracted {len(extracted_signals)} signals")

## Step 7: Clinical Feature Extraction

Extract clinical intervals and compute heart rate.
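Heart rate follows directly from the R-R intervals: HR = 60 / mean RR, with RR in seconds. QTc is commonly computed with Bazett's formula, QTc = QT / sqrt(RR), with RR in seconds. The sketch below pairs those formulas with a deliberately simple R-peak detector that uses an amplitude threshold plus a 200 ms refractory period. The detector is an assumption for illustration, far less robust than Pan-Tompkins-style detectors, and is not necessarily what `intervals.extract_intervals` does. The helper names are hypothetical.

```python
import numpy as np

def detect_r_peaks(sig, fs, thresh_frac=0.6, refractory_s=0.2):
    """Indices of local maxima above thresh_frac * max, at least refractory_s apart."""
    sig = np.asarray(sig, dtype=float)
    thresh = thresh_frac * sig.max()
    # Candidate samples: above threshold and >= both neighbors
    cand = np.nonzero((sig[1:-1] > thresh) &
                      (sig[1:-1] >= sig[:-2]) &
                      (sig[1:-1] >= sig[2:]))[0] + 1
    peaks = []
    min_gap = int(refractory_s * fs)
    for i in cand:
        if peaks and i - peaks[-1] < min_gap:
            if sig[i] > sig[peaks[-1]]:
                peaks[-1] = i          # keep the taller peak within the window
        else:
            peaks.append(i)
    return np.array(peaks, dtype=int)

def heart_rate_bpm(r_peaks, fs):
    rr_s = np.diff(r_peaks) / fs
    return 60.0 / rr_s.mean()

def qtc_bazett(qt_ms, rr_s):
    return qt_ms / np.sqrt(rr_s)

# Toy signal: narrow Gaussian R waves every 0.8 s (75 BPM) at 500 Hz
fs = 500.0
t = np.arange(int(10 * fs)) / fs
centers = 0.4 + 0.8 * np.arange(12)
sig = sum(np.exp(-0.5 * ((t - c) / 0.01) ** 2) for c in centers)
peaks = detect_r_peaks(sig, fs)
print(len(peaks), heart_rate_bpm(peaks, fs), qtc_bazett(400.0, 0.8))
```

For example, a QT of 400 ms at RR = 0.8 s gives a Bazett QTc of about 447 ms.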

In [None]:
from ecg2signal.clinical import intervals

# Use first lead for clinical analysis
first_lead_name = list(extracted_signals.keys())[0]
first_signal = extracted_signals[first_lead_name]

print(f"Extracting clinical intervals from {first_lead_name}...")
clinical_intervals = intervals.extract_intervals(first_signal, fs=500.0)

print(f"\n❤️ Clinical Measurements:")
print(f"   Heart Rate: {clinical_intervals.get('heart_rate', 'N/A')} BPM")
print(f"   PR Interval: {clinical_intervals.get('pr_interval', 'N/A')} ms")
print(f"   QRS Duration: {clinical_intervals.get('qrs_duration', 'N/A')} ms")
print(f"   QT Interval: {clinical_intervals.get('qt_interval', 'N/A')} ms")
print(f"   QTc (Corrected): {clinical_intervals.get('qtc_interval', 'N/A')} ms")

# Visualize with annotations
if 'r_peaks' in clinical_intervals:
    time = np.arange(len(first_signal)) / 500.0
    
    plt.figure(figsize=(15, 5))
    plt.plot(time, first_signal, 'b-', linewidth=1.5, label='ECG Signal')
    
    # Mark R-peaks
    r_peaks = clinical_intervals['r_peaks']
    plt.plot(time[r_peaks], first_signal[r_peaks], 'ro', 
             markersize=10, label='R-peaks')
    
    plt.title(f'Lead {first_lead_name} with R-peak Detection', fontsize=16)
    plt.xlabel('Time (s)', fontsize=12)
    plt.ylabel('Amplitude (mV)', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)
    plt.xlim([0, min(5, len(first_signal)/500.0)])  # Show first 5 seconds
    plt.tight_layout()
    plt.show()

## Step 8: Quality Assessment

Assess signal quality metrics.
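`quality.assess_quality` computes its metrics internally. To make the printed numbers concrete, here is a sketch of three simple proxies. Each is an assumed definition for illustration rather than the library's:

- Coverage is the fraction of finite samples.
- Clipping is the fraction of samples pinned at the signal's extreme values.
- SNR in dB (10·log10 of signal power over noise power) treats a moving-average-smoothed copy as the signal and the residual as noise.

The helper names are hypothetical.

```python
import numpy as np

def coverage(sig):
    """Fraction of samples that are finite (not NaN/inf from tracing gaps)."""
    sig = np.asarray(sig, dtype=float)
    return float(np.isfinite(sig).mean())

def clipping_fraction(sig, tol=1e-6):
    """Fraction of finite samples within tol of the min or max value."""
    x = np.asarray(sig, dtype=float)
    x = x[np.isfinite(x)]
    at_edge = (np.abs(x - x.max()) < tol) | (np.abs(x - x.min()) < tol)
    return float(at_edge.mean())

def snr_db(sig, window=5):
    """Rough SNR: moving average as 'signal', residual as 'noise' (NaNs dropped)."""
    x = np.asarray(sig, dtype=float)
    x = x[np.isfinite(x)]
    smooth = np.convolve(x, np.ones(window) / window, mode="same")
    noise = x - smooth
    return float(10 * np.log10(np.mean(smooth ** 2) / np.mean(noise ** 2)))

rng = np.random.default_rng(0)
t = np.arange(5000) / 500.0
clean = np.sin(2 * np.pi * 1.0 * t)
noisy = clean + 0.05 * rng.standard_normal(t.size)
noisier = clean + 0.3 * rng.standard_normal(t.size)
gappy = noisy.copy()
gappy[:500] = np.nan                      # simulate a 1 s tracing gap
clipped = np.clip(clean, -0.5, 0.5)       # simulate amplitude clipping
print(coverage(gappy), clipping_fraction(clipped), snr_db(noisy), snr_db(noisier))
```

The smoothing-based SNR is only meaningful when the moving-average window is short compared with the waveform features. On a real ECG, a narrow window avoids counting the QRS complex itself as noise.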

In [None]:
from ecg2signal.clinical import quality

print("Assessing signal quality...")
quality_metrics = quality.assess_quality(list(extracted_signals.values()))

print(f"\n📊 Quality Metrics:")
print(f"   Overall Quality: {quality_metrics['overall_quality']:.2%}")
print(f"   SNR: {quality_metrics.get('snr', 'N/A')} dB")
bw = quality_metrics.get('baseline_wander')
bw_text = f"{bw:.3f}" if bw is not None else 'N/A'  # ':.3f' would raise on the 'N/A' string
print(f"   Baseline Wander: {bw_text}")
print(f"   Clipping Detected: {'Yes' if quality_metrics.get('clipping_detected') else 'No'}")
print(f"   Coverage: {quality_metrics['coverage']:.2%}")

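# Visualize quality (SNR normalized so 40 dB maps to 1.0, clamped to [0, 1];
# a missing SNR is shown as 0 rather than as a made-up default value)
snr = quality_metrics.get('snr')
snr_norm = float(np.clip(snr / 40, 0.0, 1.0)) if snr is not None else 0.0
metrics_names = ['Overall Quality', 'SNR (norm)', 'Coverage']
metrics_values = [
    quality_metrics['overall_quality'],
    snr_norm,
    quality_metrics['coverage']
]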

colors = ['green' if v > 0.8 else 'orange' if v > 0.6 else 'red' for v in metrics_values]

plt.figure(figsize=(10, 6))
bars = plt.barh(metrics_names, metrics_values, color=colors, alpha=0.7)
plt.xlim([0, 1])
plt.xlabel('Score', fontsize=12)
plt.title('Signal Quality Assessment', fontsize=16)
plt.grid(axis='x', alpha=0.3)

# Add value labels
for i, (bar, val) in enumerate(zip(bars, metrics_values)):
    plt.text(val + 0.02, i, f'{val:.2%}', va='center', fontsize=12)

plt.tight_layout()
plt.show()

## Step 9: Export to Multiple Formats

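Export the extracted signals to clinical waveform and interoperability formats (WFDB, EDF+, DICOM Waveform, HL7 FHIR) and to general-purpose formats (CSV, JSON).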
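Binary formats such as EDF+ store samples as 16-bit integers plus a scale rather than as floating-point millivolts, so every export quantizes the signal. EDF maps the digital range [digital_min, digital_max] linearly onto [physical_min, physical_max]. The sketch below applies that mapping and checks that the round-trip error stays within half of one quantization step. Taking the physical min and max from the data itself is an assumption for illustration, since writers may use fixed ranges instead. The helper names are hypothetical.

```python
import numpy as np

DIG_MIN, DIG_MAX = -32768, 32767   # full int16 range

def to_digital(mv, phys_min, phys_max):
    """Quantize physical values (mV) to int16 using EDF-style linear scaling."""
    scale = (DIG_MAX - DIG_MIN) / (phys_max - phys_min)   # digital units per mV
    d = np.round((np.asarray(mv) - phys_min) * scale + DIG_MIN)
    return np.clip(d, DIG_MIN, DIG_MAX).astype(np.int16)

def to_physical(d, phys_min, phys_max):
    """Inverse mapping: int16 digital values back to mV."""
    return (d.astype(float) - DIG_MIN) * (phys_max - phys_min) / (DIG_MAX - DIG_MIN) + phys_min

sig = 1.5 * np.sin(np.linspace(0, 4 * np.pi, 1000))
p_min, p_max = float(sig.min()), float(sig.max())
digital = to_digital(sig, p_min, p_max)
restored = to_physical(digital, p_min, p_max)
step = (p_max - p_min) / (DIG_MAX - DIG_MIN)        # mV per digital unit
print(digital.dtype, np.abs(restored - sig).max(), step / 2)
```

With a 3 mV physical range, one digital unit is roughly 46 nV, well below anything visible on a paper ECG.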

In [None]:
from ecg2signal.io import wfdb_io, edf_io, fhir, dcm_waveform
import json
import pandas as pd
from pathlib import Path

# Create output directory
output_dir = Path('demo_outputs')
output_dir.mkdir(exist_ok=True)

print("Exporting to multiple formats...\n")

# 1. WFDB Format
print("1. Exporting to WFDB (MIT format)...")
wfdb_io.write_wfdb(
    extracted_signals,
    str(output_dir / 'ecg'),
    fs=500.0
)
print(f"   ✅ Saved: {output_dir}/ecg.dat, {output_dir}/ecg.hea")

# 2. CSV Format
print("\n2. Exporting to CSV...")
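# Leads can differ in length; wrapping each in a Series pads shorter ones with NaN
df = pd.DataFrame({name: pd.Series(sig) for name, sig in extracted_signals.items()})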
df.index.name = 'sample'
df.to_csv(output_dir / 'ecg_signals.csv')
print(f"   ✅ Saved: {output_dir}/ecg_signals.csv")
print(f"   Preview:")
print(df.head())

# 3. JSON Format
print("\n3. Exporting to JSON...")
json_data = {
    'signals': {k: v.tolist() for k, v in extracted_signals.items()},
    'sampling_rate': 500.0,
    'calibration': {
        'paper_speed': calibration.paper_speed,
        'gain': calibration.gain,
        'pixels_per_mm': calibration.pixels_per_mm
    },
    'clinical_intervals': clinical_intervals,
    'quality_metrics': quality_metrics
}

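def to_jsonable(obj):
    """json.dump fallback for NumPy types (e.g. an r_peaks array, numpy ints/bools)."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

with open(output_dir / 'ecg_data.json', 'w') as f:
    json.dump(json_data, f, indent=2, default=to_jsonable)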
print(f"   ✅ Saved: {output_dir}/ecg_data.json")

# 4. EDF+ Format
print("\n4. Exporting to EDF+ format...")
edf_io.write_edf(
    extracted_signals,
    str(output_dir / 'ecg.edf'),
    fs=500.0
)
print(f"   ✅ Saved: {output_dir}/ecg.edf")

# 5. FHIR Format
print("\n5. Exporting to HL7 FHIR format...")
fhir_obs = fhir.to_fhir_observation(
    extracted_signals,
    fs=500.0,
    patient_id='demo-patient-001'
)
with open(output_dir / 'ecg_fhir.json', 'w') as f:
    json.dump(fhir_obs, f, indent=2)
print(f"   ✅ Saved: {output_dir}/ecg_fhir.json")

# 6. DICOM Waveform
print("\n6. Exporting to DICOM Waveform format...")
dcm_waveform.write_dicom_waveform(
    extracted_signals,
    str(output_dir / 'ecg.dcm'),
    fs=500.0,
    patient_name='Demo Patient'
)
print(f"   ✅ Saved: {output_dir}/ecg.dcm")

print(f"\n✅ All exports complete! Files saved to: {output_dir.absolute()}")

## Step 10: Complete Pipeline with ECGConverter

Now let's use the high-level `ECGConverter` API to process everything in one call.

In [None]:
from ecg2signal import ECGConverter
from ecg2signal.config import Settings

# Save our test image
test_image_path = output_dir / 'test_ecg.png'
cv2.imwrite(str(test_image_path), ecg_image)

print("Using ECGConverter for complete pipeline...\n")

# Initialize converter
settings = Settings()
converter = ECGConverter(settings)

# Convert image to signals
print("Converting ECG image to signals...")
result = converter.convert(
    str(test_image_path),
    paper_speed=25.0,
    gain=10.0,
    sample_rate=500
)

print(f"\n✅ Conversion complete!")
print(f"\nResult summary:")
print(f"   Number of leads: {len(result.signals)}")
print(f"   Sampling rate: {result.sample_rate} Hz")
print(f"   Paper speed: {result.paper_settings.paper_speed} mm/s")
print(f"   Gain: {result.paper_settings.gain} mm/mV")

if result.intervals:
    print(f"\n   Heart rate: {result.intervals.get('heart_rate', 'N/A')} BPM")
    print(f"   QRS duration: {result.intervals.get('qrs_duration', 'N/A')} ms")

if result.quality_metrics:
    print(f"\n   Quality score: {result.quality_metrics.overall_score:.2%}")

# Export using result object
print("\nExporting via result object...")
result.export_wfdb(str(output_dir / 'pipeline_output'))
print(f"✅ Exported to: {output_dir}/pipeline_output.*")

## Summary

### What We Accomplished

1. ✅ **Generated** synthetic ECG image with realistic grid and waveforms
2. ✅ **Preprocessed** image (page detection, dewarping, denoising)
3. ✅ **Detected** grid and calibrated pixel-to-physical units
4. ✅ **Identified** lead positions and extracted metadata
5. ✅ **Segmented** image into grid, waveforms, and text
6. ✅ **Reconstructed** digital signals from pixel traces
7. ✅ **Extracted** clinical intervals (HR, PR, QRS, QT)
8. ✅ **Assessed** signal quality metrics
9. ✅ **Exported** to 6 different formats (WFDB, CSV, JSON, EDF, FHIR, DICOM)
10. ✅ **Demonstrated** high-level API usage

### Next Steps

- **Train models** with real ECG data for better accuracy
- **Test with real ECG scans** from hospitals or PhysioNet
- **Tune parameters** for specific ECG types and qualities
- **Deploy** the API for production use
- **Integrate** with EHR/EMR systems

### Resources

- **Documentation**: See `docs/` directory
- **API Reference**: Run the API and visit `/docs`
- **Test Data**: PhysioNet databases (MIT-BIH, PTB-XL)
- **Support**: Check GitHub issues or documentation

---

**🎉 Congratulations!** You've completed the full ECG2Signal demo!