In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches

In [2]:
import os, cv2
from pathlib import Path
from PIL import Image
import albumentations as A

In [3]:
from collections import defaultdict
from tqdm import tqdm
import random

In [4]:
# Dataset locations are resolved relative to the notebook's working directory:
# the project root is the folder one level above it.
project_root = Path.cwd().parents[0]
dataset_dir = project_root.joinpath('dataset')

In [5]:
# Training split layout: raw inputs and augmented outputs each have their own
# images/labels folder under <dataset>/train.
train_dir = dataset_dir.joinpath('train')
(
    train_raw_images_dir,
    train_aug_images_dir,
    train_aug_labels_dir,
    train_raw_labels_dir,
) = (
    train_dir.joinpath(folder_name)
    for folder_name in ('raw_images', 'aug_images', 'aug_labels', 'raw_labels')
)

In [6]:
# Make sure both augmentation output folders exist (no-op when already present).
for output_dir in (train_aug_images_dir, train_aug_labels_dir):
    output_dir.mkdir(parents=True, exist_ok=True)

In [7]:
# Echo every resolved location so a wrong working directory is spotted early.
for description, location in (
    ("Project Root", project_root),
    ("Dataset Dir", dataset_dir),
    ("Train Dir", train_dir),
    ("Train Raw Images Dir", train_raw_images_dir),
    ("Train Aug Images Dir", train_aug_images_dir),
    ("Train Aug Labels Dir", train_aug_labels_dir),
    ("Train Raw Labels Dir", train_raw_labels_dir),
):
    print(f"{description}: {location}")

Project Root: C:\Users\ADITHYA\OneDrive\Kesari
Dataset Dir: C:\Users\ADITHYA\OneDrive\Kesari\dataset
Train Dir: C:\Users\ADITHYA\OneDrive\Kesari\dataset\train
Train Raw Images Dir: C:\Users\ADITHYA\OneDrive\Kesari\dataset\train\raw_images
Train Aug Images Dir: C:\Users\ADITHYA\OneDrive\Kesari\dataset\train\aug_images
Train Aug Labels Dir: C:\Users\ADITHYA\OneDrive\Kesari\dataset\train\aug_labels
Train Raw Labels Dir: C:\Users\ADITHYA\OneDrive\Kesari\dataset\train\raw_labels


In [11]:
def load_image_and_labels(image_path, label_path):
    """
    Loads an image and its corresponding YOLO-format labels, converting
    bounding boxes to pixel coordinates.

    Each label line is expected to hold five whitespace-separated fields:
    ``class_id x_center y_center width height``, where the four box values are
    fractions of the image size. Label lines are validated one at a time: a
    line with the wrong number of fields, or with a field that cannot be
    parsed as a number, is reported and skipped without affecting the other
    lines. Blank lines are ignored silently.

    Args:
        image_path (Path): Path to the image file.
        label_path (Path): Path to the label file (YOLO format).

    Returns:
        tuple: A tuple containing:
            - image (np.array): The loaded image, or None if it could not be read.
            - bboxes (list): A list of bounding boxes in [class_id, x_min, y_min, x_max, y_max] format.
              Empty when the image could not be read or when the label file
              is missing or unreadable.
    """
    # Read the image
    image = cv2.imread(str(image_path))
    if image is None:
        print(f"Warning: Could not read image at {image_path}. Skipping.")
        return None, []

    # Read the whole label file up front so that an I/O or decoding failure
    # never leaves a half-parsed list of boxes behind.
    try:
        with open(label_path, 'r') as file:
            labels = file.readlines()
    except FileNotFoundError:
        print(f"Warning: Label file not found at {label_path}. Proceeding with no bounding boxes for this image.")
        return image, []
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading label file {label_path}: {e}. Skipping bounding boxes for this image.")
        return image, []

    # The image size is the same for every box, so look it up once.
    img_h, img_w = image.shape[:2]

    bboxes = []
    for label in labels:
        parts = label.split()
        if not parts:
            # Blank lines (e.g. a trailing newline) carry no annotation.
            continue
        if len(parts) != 5:
            print(f"Warning: Malformed label line '{label.strip()}' in {label_path}. Skipping.")
            continue

        try:
            class_id = int(parts[0])
            x_center, y_center, width, height = (float(value) for value in parts[1:])

            # Convert from YOLO format to pixel coordinates (x_min, y_min, x_max, y_max)
            x_min = int((x_center - width / 2) * img_w)
            y_min = int((y_center - height / 2) * img_h)
            x_max = int((x_center + width / 2) * img_w)
            y_max = int((y_center + height / 2) * img_h)
        except (ValueError, OverflowError):
            # Non-numeric, NaN or infinite field: drop this line only and keep
            # parsing the rest of the file.
            print(f"Warning: Malformed label line '{label.strip()}' in {label_path}. Skipping.")
            continue

        bboxes.append([class_id, x_min, y_min, x_max, y_max])

    return image, bboxes

In [12]:
# Load all images and corresponding labels
image_paths = sorted(train_raw_images_dir.glob("*.jpg"))

# Pair every image with the label file that shares its filename stem
# (image "abc.jpg" -> label "abc.txt"). Pairing two independently sorted
# listings by position would attach the wrong boxes to every image after the
# first missing or extra file, so the label path is derived from the image.
label_paths = [train_raw_labels_dir / f"{image_path.stem}.txt" for image_path in image_paths]

# Summarize mismatches between the two folders up front.
missing_label_count = sum(1 for label_path in label_paths if not label_path.exists())
image_stems = {image_path.stem for image_path in image_paths}
orphan_label_names = sorted(
    label_file.name
    for label_file in train_raw_labels_dir.glob("*.txt")
    if label_file.stem not in image_stems
)
if missing_label_count:
    print(f"Warning: {missing_label_count} image(s) have no matching label file and will be loaded without bounding boxes.")
if orphan_label_names:
    print(f"Warning: {len(orphan_label_names)} label file(s) have no matching image and will be ignored: {orphan_label_names[:5]}")

dataset = []
# Use tqdm for a progress bar while loading the dataset
print("Loading dataset...")
for image_path, label_path in tqdm(zip(image_paths, label_paths), total=len(image_paths), desc="Loading Images and Labels"):
    image, bboxes = load_image_and_labels(image_path, label_path)
    if image is not None:  # Only append if image was loaded successfully
        dataset.append((image, bboxes))

print(f"Loaded {len(dataset)} images and their corresponding labels.")

Loading dataset...


Loading Images and Labels: 100%|████████████████████████████████████████████████████| 201/201 [00:01<00:00, 124.03it/s]

Loaded 201 images and their corresponding labels.





In [None]:
# Preview up to 10 randomly chosen images with their bounding boxes drawn on top
num_images_to_display = min(10, len(dataset))
if num_images_to_display == 0:
    print("No images were loaded to display.")
else:
    print(f"\nDisplaying {num_images_to_display} loaded images with bounding boxes:")
    random_indices = random.sample(range(len(dataset)), num_images_to_display)

    fig = plt.figure(figsize=(15, 10))
    for slot, idx in enumerate(random_indices, start=1):
        image, bboxes = dataset[idx]

        # Fixed 2x5 grid: enough room for the maximum of 10 previews
        ax = fig.add_subplot(2, 5, slot)
        # OpenCV loads images as BGR, matplotlib expects RGB
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        for class_id, x_min, y_min, x_max, y_max in bboxes:
            # Box outline
            ax.add_patch(
                patches.Rectangle(
                    (x_min, y_min),
                    x_max - x_min,
                    y_max - y_min,
                    fill=False,
                    edgecolor='r',
                    linewidth=2,
                )
            )
            # Class label just above the top-left corner of the box
            ax.text(
                x_min,
                y_min - 5,
                f'Class: {class_id}',
                bbox=dict(facecolor='yellow', alpha=0.5),
                fontsize=8,
                color='black',
            )

        ax.set_title(f"Image {idx + 1}")
        ax.axis('off')
    fig.tight_layout()
    plt.show()