In [1]:
import os
import random
from collections import defaultdict
from pathlib import Path

import cv2
import matplotlib.pyplot as plt

In [2]:
# Dataset configuration
BASE_DIR = Path("..") / "datasets" / "processed"  # root holding <split>/images and <split>/labels
CLASS_NAMES = [
    "pedestrian",  # class id 0
    "car",         # class id 1
    "cyclist",     # class id 2
]
NUM_CLASSES = len(CLASS_NAMES)

In [None]:
# Validate and separate data-label pairs
def _label_file_is_valid(label_path, num_classes):
    """Return True if every line of a YOLO label file is well-formed; print the first problem otherwise."""
    with open(label_path, "r") as f:
        for line in f:
            text = line.strip()
            parts = text.split()
            if len(parts) != 5:
                print(f"Invalid line format in {label_path.name}: {text}")
                return False
            try:
                cls_id, x, y, w, h = map(float, parts)
            except ValueError:
                # Non-numeric tokens would otherwise abort the whole validation run.
                print(f"Non-numeric value in {label_path.name}: {text}")
                return False
            # float() happily parses "1.5"; a class id must be a whole number in range.
            if not (cls_id.is_integer() and 0 <= cls_id < num_classes):
                print(f"Invalid class ID in {label_path.name}: {parts[0]}")
                return False
            if not all(0 <= v <= 1 for v in (x, y, w, h)):
                print(f"Out-of-bounds values in {label_path.name}: {text}")
                return False
    return True


def validate_yolo_dataset(split="train", base_dir=None, num_classes=None):
    """Pair every image of a YOLO split with its label file and check the labels.

    Each ``*.png`` under ``<base_dir>/<split>/images`` is matched with
    ``<base_dir>/<split>/labels/<stem>.txt``. A label file is valid when every
    line has exactly five numeric fields ``cls x y w h``, ``cls`` is an integer
    in ``[0, num_classes)``, and ``x, y, w, h`` all lie in ``[0, 1]``. An empty
    label file counts as valid; a blank line inside a file does not.

    Args:
        split: Sub-directory of ``base_dir`` to check (e.g. ``"train"``, ``"val"``).
        base_dir: Dataset root (str or Path); defaults to the notebook's ``BASE_DIR``.
        num_classes: Number of classes; defaults to the notebook's ``NUM_CLASSES``.

    Returns:
        ``(valid_data, invalid_data)``: lists of ``(image_path, label_path)``
        tuples in sorted image order. ``label_path`` is ``None`` for images
        with no label file.
    """
    base_dir = BASE_DIR if base_dir is None else Path(base_dir)
    if num_classes is None:
        num_classes = NUM_CLASSES

    print(f"\nChecking {split.upper()} set...")
    image_dir = base_dir / split / "images"
    label_dir = base_dir / split / "labels"
    if not image_dir.is_dir():
        # Path.glob on a missing directory silently yields nothing; make that visible.
        print(f"⚠️ Image directory not found: {image_dir}")

    valid_data = []
    invalid_data = []

    for img_path in sorted(image_dir.glob("*.png")):
        label_path = label_dir / f"{img_path.stem}.txt"

        if not label_path.exists():
            print(f"Missing label for {img_path.name}")
            invalid_data.append((img_path, None))
            continue

        if _label_file_is_valid(label_path, num_classes):
            valid_data.append((img_path, label_path))
        else:
            invalid_data.append((img_path, label_path))

    print(f"✅ Valid samples: {len(valid_data)}")
    print(f"❌ Invalid samples: {len(invalid_data)}")
    return valid_data, invalid_data

In [None]:
# Validate each split: *_data holds usable (image, label) pairs, *_invalid the rejects
train_data, train_invalid = validate_yolo_dataset(split="train")
val_data, val_invalid = validate_yolo_dataset(split="val")

In [None]:
# Visualize a few valid samples
def show_yolo_sample(data_pair):
    """Draw a sample's YOLO boxes and class names on its image and display it.

    Args:
        data_pair: ``(image_path, label_path)`` tuple, as found in the lists
            returned by ``validate_yolo_dataset``. Each non-blank label line is
            ``cls x_center y_center width height`` with coordinates normalised
            to ``[0, 1]``. If ``label_path`` is ``None`` (image without a label
            file) the image is shown without boxes.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    img_path, label_path = data_pair
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread reports unreadable/missing files by returning None, not by raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_h, img_w = img.shape[:2]

    # Colours are RGB because the image was converted above; other class ids get yellow.
    class_colors = {0: (255, 0, 0), 1: (0, 255, 0)}
    fallback_color = (255, 255, 0)

    if label_path is not None:
        with open(label_path, "r") as f:
            for line in f:
                parts = line.split()
                if not parts:  # tolerate blank lines instead of failing the unpack
                    continue
                cls_id, x, y, bw, bh = map(float, parts)
                cls_id = int(cls_id)
                # Convert normalised centre/size to pixel corner coordinates.
                x1 = int((x - bw / 2) * img_w)
                y1 = int((y - bh / 2) * img_h)
                x2 = int((x + bw / 2) * img_w)
                y2 = int((y + bh / 2) * img_h)

                color = class_colors.get(cls_id, fallback_color)
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                # putText's origin is the text's bottom-left corner; clamp it so
                # captions of boxes touching the top edge stay inside the image.
                text_y = max(y1 - 10, 15)
                cv2.putText(img, CLASS_NAMES[cls_id], (x1, text_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(img)
    ax.set_title(img_path.name)
    ax.axis("off")
    plt.show()

In [None]:
# Show up to 3 random valid samples from train (seeded, so Restart & Run All shows the same images)
N_PREVIEW = 3
PREVIEW_SEED = 0
if train_data:
    preview_rng = random.Random(PREVIEW_SEED)
    # sample() draws without replacement, so no image is shown twice.
    for pair in preview_rng.sample(train_data, k=min(N_PREVIEW, len(train_data))):
        show_yolo_sample(pair)
else:
    print("No valid TRAIN samples to display.")