# In [None]:
import cv2
import numpy as np
from skimage.morphology import skeletonize
from scipy.signal import medfilt
import pandas as pd
import torch
from torch import nn
from torchvision import transforms
from PIL import Image

# Dummy U-Net model for demonstration
# In a real-world scenario, you would define, train, and load a proper U-Net model here.
class UNet(nn.Module):
    """
    Minimal stand-in for a U-Net: two 3x3 convolutions followed by a sigmoid.
    The layers keep the spatial size (padding=1) and map a 1-channel input
    to a 1-channel output with values produced by torch.sigmoid.
    """

    def __init__(self):
        """Creates the two convolution layers (1 -> 64 -> 1 channels)."""
        super().__init__()
        # Simplified for demonstration. A full U-Net has multiple encoder/decoder blocks.
        # The attribute names are part of the saved state_dict keys, so they stay as-is.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """
        Runs both convolutions back to back and squashes the result with a sigmoid.
        Args:
            x (torch.Tensor): Input batch with a single channel.
        Returns:
            torch.Tensor: Sigmoid of the second convolution's output.
        """
        features = self.conv1(x)
        logits = self.conv2(features)
        return torch.sigmoid(logits)

def load_unet_model(model_path, device):
    """
    Builds a UNet, restores its weights from disk and switches it to eval mode.
    Args:
        model_path (str): The path to the saved model state (a state_dict).
        device (torch.device): The device to map the weights to and move the model onto.
    Returns:
        UNet: The model with loaded weights, placed on `device`, in eval mode.
    """
    model = UNet()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources (consider weights_only=True if the installed PyTorch supports it).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # Module.to() and Module.eval() both return the module, so they can be chained.
    return model.to(device).eval()

# 1. Pre-process the ECG images
def _estimate_skew_degrees(lines, max_skew_deg=45.0):
    """
    Estimates the skew of the page, relative to horizontal, from Hough lines.
    Args:
        lines (np.ndarray | None): Result of cv2.HoughLines; every entry holds (rho, theta)
            with theta in radians.
        max_skew_deg (float): Lines whose deviation from horizontal is at least this large
            are ignored.
    Returns:
        float: Median deviation (theta - 90 degrees) of the near-horizontal lines,
            or 0.0 when there is no usable line.
    """
    if lines is None:
        return 0.0

    # Flatten to (N, 2) so the (N, 1, 2) layout returned by OpenCV is handled uniformly.
    thetas_deg = np.degrees(np.asarray(lines, dtype=np.float64).reshape(-1, 2)[:, 1])

    # theta is the angle of the line's normal, so a horizontal line has theta == 90 degrees.
    deviations = thetas_deg - 90.0

    # Only near-horizontal lines are considered so that a vertical grid line
    # cannot trigger a ~90 degree rotation of the whole image.
    near_horizontal = deviations[np.abs(deviations) < max_skew_deg]
    if near_horizontal.size == 0:
        return 0.0
    return float(np.median(near_horizontal))


def preprocess_image(image_path):
    """
    Performs all necessary preprocessing steps on an ECG image.
    The image is loaded as grayscale, inverted, deskewed using Hough lines
    and denoised with a 3x3 median filter.
    Args:
        image_path (str): Path to the input image.
    Returns:
        np.ndarray: The preprocessed image ready for segmentation.
    Raises:
        FileNotFoundError: If the image cannot be read from `image_path`.
    """
    # Load image in grayscale
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Image not found at {image_path}")

    # Invert colors (assuming black signal on white paper), so the background becomes 0.
    img = cv2.bitwise_not(img)

    # Correct skew using the dominant near-horizontal direction found by the Hough transform.
    edges = cv2.Canny(img, 50, 150, apertureSize=3)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
    skew_deg = _estimate_skew_degrees(lines)

    # Skip the warp entirely when there is nothing to correct (this includes the
    # "no lines detected" case, which previously rotated the image by -90 degrees).
    if skew_deg != 0.0:
        height, width = img.shape[:2]
        rotation = cv2.getRotationMatrix2D((width / 2, height / 2), skew_deg, 1)
        # After inversion the background is black, so uncovered corners are filled with 0.
        img = cv2.warpAffine(
            img,
            rotation,
            (width, height),
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=0,
        )

    # Basic denoising using a median filter
    img = medfilt(img, kernel_size=3)

    return img

# 2. Segment the ECG waveform (using a trained U-Net model)
def segment_waveform(preprocessed_img, unet_model, device):
    """
    Produces a binary waveform mask for a preprocessed ECG image with a U-Net.
    The image is resized to 256x256 for the model, and the model output is
    resized back to the input size and thresholded at 0.5.
    Args:
        preprocessed_img (np.ndarray): The preprocessed input image.
        unet_model (UNet): The trained U-Net model.
        device (torch.device): The computation device.
    Returns:
        np.ndarray: A uint8 mask holding 255 where the model output exceeds 0.5, else 0.
    """
    original_height = preprocessed_img.shape[0]
    original_width = preprocessed_img.shape[1]

    # Fixed-size model input: resize, then scale to a float tensor
    to_model_input = transforms.Compose([
        transforms.Resize((256, 256)),  # U-Net typically works with fixed-size images
        transforms.ToTensor(),
    ])

    grayscale = Image.fromarray(preprocessed_img).convert('L')
    batch = to_model_input(grayscale)
    batch = batch.unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        prediction = unet_model(batch)

    # Bring the prediction back to the input resolution (cv2 expects (width, height))
    probability_map = prediction.squeeze().cpu().numpy()
    probability_full = cv2.resize(probability_map, (original_width, original_height))

    # Threshold at 0.5 and express the result as a 0/255 uint8 mask
    foreground = probability_full > 0.5
    return foreground.astype(np.uint8) * 255

# 3. Extract and refine the pixel coordinates
def extract_coordinates(mask):
    """
    Extracts x, y pixel coordinates from the binary mask.
    The mask is skeletonized, the median row is taken for every column that
    contains skeleton pixels, and the resulting trace is smoothed with a
    5-sample median filter.
    Args:
        mask (np.ndarray): The binary mask of the ECG waveform (non-zero = waveform).
    Returns:
        tuple: Arrays for x-coordinates (time) and y-coordinates (voltage).
            Both arrays are empty when the mask contains no waveform pixels.
    """
    # Thin the waveform to a single pixel line using skeletonization
    skeleton = skeletonize(mask > 0)

    # Row (y) and column (x) indices of every skeleton pixel
    rows, cols = np.nonzero(skeleton)
    if cols.size == 0:
        # Nothing was segmented; return empty arrays instead of failing in the filter.
        return np.array([], dtype=int), np.array([], dtype=float)

    # Sort by column (time) first and by row within each column, so that every
    # column's rows form one contiguous, ascending run.
    order = np.lexsort((rows, cols))
    rows = rows[order]
    cols = cols[order]

    # One entry per distinct column, plus where its run starts and how long it is
    time_pixels, starts, counts = np.unique(cols, return_index=True, return_counts=True)

    # Median of a sorted run = mean of its two middle elements (identical for odd lengths)
    lower_mid = starts + (counts - 1) // 2
    upper_mid = starts + counts // 2
    voltage_pixels = (rows[lower_mid] + rows[upper_mid]) / 2.0

    # Path refinement with a median filter. medfilt zero-pads its input, which would
    # drag the first and last samples toward row 0, so the trace is edge-padded first
    # and the padding is cut off again afterwards.
    half_window = 2
    padded = np.pad(voltage_pixels, half_window, mode='edge')
    filtered = medfilt(padded, kernel_size=2 * half_window + 1)
    voltage_pixels_refined = filtered[half_window:-half_window]

    return time_pixels, voltage_pixels_refined

# 4. Calibrate pixels to time-series data
def calibrate_to_signal(image, time_pixels, voltage_pixels, *, pixel_per_mm=10, mm_per_mv=10, mm_per_sec=25):
    """
    Calibrates pixel coordinates to time and voltage values.
    Args:
        image (np.ndarray): The image the coordinates refer to; only its height is used.
        time_pixels (array-like): Extracted time (x) pixel coordinates.
        voltage_pixels (array-like): Extracted voltage (y) pixel coordinates.
        pixel_per_mm (float): Image resolution in pixels per millimetre of paper.
            Grid detection is not implemented, so this is an assumed value.
        mm_per_mv (float): Paper gain in millimetres per millivolt (default 10).
        mm_per_sec (float): Paper speed in millimetres per second (default 25).
    Returns:
        pd.DataFrame: Columns 'time_s' (seconds since the first sample) and
            'voltage_mV' (height above the bottom image edge, in millivolts).
            The frame is empty when no coordinates are given.
    Raises:
        ValueError: If a scale factor is not positive or the two coordinate
            arrays differ in shape.
    """
    # Calculate scaling factors
    pixel_per_mv = pixel_per_mm * mm_per_mv
    pixel_per_sec = pixel_per_mm * mm_per_sec
    if pixel_per_mv <= 0 or pixel_per_sec <= 0:
        raise ValueError("pixel_per_mm, mm_per_mv and mm_per_sec must be positive")

    # Accept plain lists as well as arrays
    time_pixels = np.asarray(time_pixels, dtype=np.float64)
    voltage_pixels = np.asarray(voltage_pixels, dtype=np.float64)
    if time_pixels.shape != voltage_pixels.shape:
        raise ValueError("time_pixels and voltage_pixels must have the same shape")

    if time_pixels.size == 0:
        # No waveform was extracted; there is no first sample to use as time origin.
        return pd.DataFrame({
            'time_s': np.array([], dtype=np.float64),
            'voltage_mV': np.array([], dtype=np.float64),
        })

    # Invert voltage axis as image y-axis is inverted (row 0 is the top of the image)
    max_y = image.shape[0]
    voltage_calibrated = (max_y - voltage_pixels) / pixel_per_mv

    # Time is measured from the first extracted column
    time_calibrated = (time_pixels - time_pixels[0]) / pixel_per_sec

    return pd.DataFrame({'time_s': time_calibrated, 'voltage_mV': voltage_calibrated})

# Main execution block
def main(image_path, model_path, output_csv):
    """
    Runs the digitization pipeline end to end and writes the result as CSV.
    The stages are: preprocess the image, segment the waveform with the U-Net,
    extract pixel coordinates, calibrate them and save the table.
    Args:
        image_path (str): Path to the ECG image.
        model_path (str): Path to the trained U-Net model.
        output_csv (str): Path for the output CSV file.
    """
    # Prefer the GPU when one is available
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    # Stage 1: image clean-up
    print("Preprocessing image...")
    cleaned_image = preprocess_image(image_path)

    # Stage 2: waveform segmentation
    print("Segmenting waveform with U-Net...")
    model = load_unet_model(model_path, device)
    waveform_mask = segment_waveform(cleaned_image, model, device)

    # Stage 3: mask -> pixel trace
    print("Extracting pixel coordinates...")
    x_pixels, y_pixels = extract_coordinates(waveform_mask)

    # Stage 4: pixel trace -> physical units, written without the index column
    print("Calibrating and saving data...")
    signal_table = calibrate_to_signal(cleaned_image, x_pixels, y_pixels)
    signal_table.to_csv(output_csv, index=False)

    print(f"Digitization complete. Output saved to {output_csv}")

if __name__ == "__main__":
    # Example usage:
    # Set up dummy image and model file paths for demonstration.
    # In a real project, replace with actual paths.
    DUMMY_IMAGE_PATH = 'dummy_ecg_image.png'
    DUMMY_MODEL_PATH = 'dummy_unet_model.pth'
    OUTPUT_CSV_PATH = 'digitized_ecg.csv'

    # Create dummy files for demonstration purposes.
    # preprocess_image() inverts its input on the assumption of a black trace on
    # white paper, so the demo image is drawn that way: white background, black line.
    dummy_img = np.full((500, 1000), 255, dtype=np.uint8)
    cv2.line(dummy_img, (0, 250), (1000, 250), 0, 2)
    cv2.imwrite(DUMMY_IMAGE_PATH, dummy_img)

    # Untrained weights: enough to exercise the pipeline, not to produce a meaningful mask.
    torch.save(UNet().state_dict(), DUMMY_MODEL_PATH)

    main(DUMMY_IMAGE_PATH, DUMMY_MODEL_PATH, OUTPUT_CSV_PATH)
    