In [5]:
import os
import cv2
import numpy as np
from ultralytics import YOLO

NUM_IMAGES = 20
CONF_THRESHOLD = 0.7

TEST_IMAGE_DIR = '../data/segmentation-dataset-6by2/train-augmented/images'
SAVE_ROOT = '../../data/yolo-segmentation-results'
MODEL_PATH = '../../models/yolo-segmentation-weights.pt'

# Load model
model = YOLO(MODEL_PATH)

# Read image files
image_files = [f for f in os.listdir(TEST_IMAGE_DIR) if f.endswith((".jpg", ".png"))]
os.makedirs(SAVE_ROOT, exist_ok=True)

# Mapping from class ID to name
class_names = model.names

ORDERED_WAVE_LABELS = ['I', 'V1', 'II', 'V2', 'III', 'V3',
                       'aVL', 'V4', 'aVR', 'V5', 'aVF', 'V6', 'II_ext']

In [6]:
def merge_boxes(box1, box2):
    x1 = min(box1[0], box2[0])
    y1 = min(box1[1], box2[1])
    x2 = max(box1[2], box2[2])
    y2 = max(box1[3], box2[3])
    return [x1, y1, x2, y2]

def sort_wave_boxes(boxes):
    # Group into rows based on y-coordinate proximity (30 pixel threshold)
    rows = []
    for box in sorted(boxes, key=lambda b: b[1]):  # sort by y1
        placed = False
        for row in rows:
            if abs(row[-1][1] - box[1]) < 30:  # y1 close enough
                row.append(box)
                placed = True
                break
        if not placed:
            rows.append([box])
    # Sort each row left to right
    for row in rows:
        row.sort(key=lambda b: b[0])
    # Flatten and return
    flat = [b for row in rows for b in row]
    return flat

for idx in range(15):
    image_file = image_files[idx]
    image_path = os.path.join(TEST_IMAGE_DIR, image_file)
    base_name = os.path.splitext(image_file)[0]

    # Inference
    results = model(image_path)[0]

    # Read and prepare image
    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2RGB)

    # Create save folder
    image_save_dir = os.path.join(SAVE_ROOT, base_name)
    os.makedirs(image_save_dir, exist_ok=True)

    # Get class 0 boxes above threshold
    wave_boxes = []
    for box in results.boxes:
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        if conf >= CONF_THRESHOLD and cls_id == 0:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            wave_boxes.append([x1, y1, x2, y2])

    # Special handling for II extended
    if len(wave_boxes) == 14:
        wave_boxes[-2:] = [merge_boxes(wave_boxes[-2], wave_boxes[-1])]

    # Sort in row-wise (top to bottom, then left to right) order
    wave_boxes = sort_wave_boxes(wave_boxes)

    # Save full image with real class labels
    for box in results.boxes:
        conf = float(box.conf[0])
        if conf < CONF_THRESHOLD:
            continue
        cls_id = int(box.cls[0])
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = f"{class_names[cls_id]} ({conf:.2f})"
        color = (0, 255, 0)
        cv2.rectangle(img_rgb, (x1, y1), (x2, y2), color, 2)
        cv2.putText(img_rgb, label, (x1, max(y1 - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 2)

    # Save cropped waves with correct label names
    for i, box in enumerate(wave_boxes):
        if i >= len(ORDERED_WAVE_LABELS):
            break
        label = ORDERED_WAVE_LABELS[i]
        x1, y1, x2, y2 = box
        crop = img[y1:y2, x1:x2]
        crop_path = os.path.join(image_save_dir, f"{label}.jpg")
        cv2.imwrite(crop_path, crop, [cv2.IMWRITE_JPEG_QUALITY, 100])

    # Save full annotated image
    full_img_path = os.path.join(image_save_dir, "full.jpg")
    cv2.imwrite(full_img_path, cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, 100])


image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/lead-segmentation/scripts/../data/segmentation-dataset-6by2/train-augmented/images/1926_6by2_aug0.jpg: 320x640 12 lead_containers, 1 label_I, 1 label_II, 1 label_III, 1 label_aVR, 1 label_aVL, 1 label_aVF, 1 label_V1, 1 label_V2, 1 label_V3, 1 label_V4, 1 label_V5, 1 label_V6, 210.8ms
Speed: 2.0ms preprocess, 210.8ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 640)

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/lead-segmentation/scripts/../data/segmentation-dataset-6by2/train-augmented/images/686_6by2_aug1.jpg: 320x640 12 lead_containers, 1 label_I, 1 label_II, 1 label_III, 1 label_aVR, 1 label_aVL, 1 label_aVF, 1 label_V1, 1 label_V2, 1 label_V3, 1 label_V4, 236.9ms
Speed: 4.2ms preprocess, 236.9ms inference, 0.9ms postprocess per image at shape (1, 3, 320, 640)

image 1/1 /home/abdullah-bin-mansoor/Desktop/ECG Project/lead-segmentation/scripts/../data/segmentation-dataset-6by2/train-augmented/images/