In [None]:
import cv2
import albumentations as A
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


# Define the augmentation pipeline


In [None]:
# Augmentations sampled independently for every call; each fires with probability p.
augmentations = [
    # Rotate by an angle drawn from [-20, 20] degrees.
    A.Rotate(limit=20, p=0.5),
    # Resize by a factor drawn from [1.1, 1.4] (albumentations adds 1 to a tuple
    # scale_limit), i.e. 10%-40% zoom; the output image size changes accordingly.
    A.RandomScale(scale_limit=(0.1, 0.4), p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
]

# Boxes are YOLO-normalized (x_center, y_center, width, height); the per-box class
# ids must be passed as the `category_ids` keyword so they stay aligned with the
# boxes that survive augmentation.
yolo_bbox_params = A.BboxParams(format='yolo', label_fields=['category_ids'])

transform = A.Compose(augmentations, bbox_params=yolo_bbox_params)

In [None]:
# Function to process each image and its corresponding label file
def process_image_and_label(image_path, bbox_file, output_images_dir, output_labels_dir, num_augments=5):
    # Load the image
    image = cv2.imread(image_path)
    height, width, _ = image.shape

    # Read bounding boxes
    with open(bbox_file, 'r') as f:
        lines = f.readlines()

    # YOLO format: (class, x_center, y_center, width, height)
    bboxes = [list(map(float, line.strip().split()[1:])) for line in lines]
    classes = [int(line.strip().split()[0]) for line in lines]

    # Function to save YOLO-format bbox file for augmented images
    def save_bbox_file(output_image_path, augmented_bboxes, augmented_classes, output_labels_dir):
        # Save the bounding boxes in YOLO format
        base_filename = os.path.basename(output_image_path)
        bbox_file_path = os.path.join(output_labels_dir, base_filename.replace('.jpg', '.txt'))

        with open(bbox_file_path, 'w') as f:
            for bbox, cls in zip(augmented_bboxes, augmented_classes):
                # The bounding boxes are in [x_center, y_center, width, height] (normalized)
                line = f"{cls} " + " ".join(map(str, bbox)) + "\n"
                f.write(line)
        print(f"Saved augmented bbox file at {bbox_file_path}")

    # Generate and save augmented images and bounding boxes
    for i in tqdm(range(num_augments)):  # Generate 'num_augments' examples
        augmented = transform(image=image, bboxes=bboxes, category_ids=classes)
        augmented_image = augmented['image']
        augmented_bboxes = augmented['bboxes']
        augmented_classes = augmented['category_ids']

        # Save augmented image
        output_image_path = os.path.join(output_images_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_aug_{i}.jpg")
        cv2.imwrite(output_image_path, augmented_image)

        # Save the augmented bounding boxes to a new file
        save_bbox_file(output_image_path, augmented_bboxes, augmented_classes, output_labels_dir)

        # Optional: Display original and augmented images side by side
        if i == 0:  # Only display the first augmented image (can remove this condition to display all)
            plt.figure(figsize=(12, 6))

            # Original Image
            plt.subplot(1, 2, 1)
            visualize(image, bboxes, classes, title="Original")

            # Augmented Image
            plt.subplot(1, 2, 2)
            visualize(augmented_image, augmented_bboxes, augmented_classes, title="Augmented")

            plt.tight_layout()
            # plt.show()

def visualize(image, bboxes, category_ids, output_path=None, title=None):
    """Draw YOLO boxes with class-id labels on a copy of ``image`` and plot it.

    The result is drawn into the current matplotlib axes, so callers can position it
    with ``plt.subplot`` beforehand. ``image`` itself is not modified.

    Args:
        image: 3-channel BGR image (as loaded by ``cv2.imread``); converted to RGB for display.
        bboxes: Sequence of ``(x_center, y_center, width, height)`` boxes normalized to the
            image size (multiplied by the pixel width/height below).
        category_ids: Class ids parallel to ``bboxes``; each is written next to its box.
        output_path: If given, the current figure is saved there with ``plt.savefig``.
        title: Optional axes title.
    """
    annotated = image.copy()
    height, width, _ = image.shape
    for bbox, cls in zip(bboxes, category_ids):
        x_center, y_center, box_width, box_height = bbox
        # Normalized center/size -> integer pixel corners.
        x_min = int((x_center - box_width / 2) * width)
        y_min = int((y_center - box_height / 2) * height)
        x_max = int((x_center + box_width / 2) * width)
        y_max = int((y_center + box_height / 2) * height)
        cv2.rectangle(annotated, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        # putText anchors the text at its bottom-left corner. For boxes touching the top
        # or left edge, (x_min, y_min - 10) would fall outside the image and the label
        # would be invisible, so clamp the anchor to stay inside.
        label_origin = (max(x_min, 0), max(y_min - 10, 12))
        cv2.putText(annotated, str(cls), label_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    if title:
        plt.title(title)
    plt.axis('off')
    if output_path:
        plt.savefig(output_path)

def process_dataset(image_folder, label_folder, output_images_dir, output_labels_dir, num_augments=5):
    """Augment every image in ``image_folder`` that has a matching YOLO label file.

    Images are paired with labels by file stem (``foo.jpg`` -> ``label_folder/foo.txt``);
    images without a label file are skipped with a warning.

    Args:
        image_folder: Directory with source images (.jpg, .jpeg, .png, any letter case).
        label_folder: Directory with YOLO ``.txt`` label files.
        output_images_dir: Directory for augmented images; created if missing.
        output_labels_dir: Directory for augmented label files; created if missing.
        num_augments: Number of augmented copies generated per image.
    """
    os.makedirs(output_images_dir, exist_ok=True)
    os.makedirs(output_labels_dir, exist_ok=True)

    # Case-insensitive match so ".JPG"/".PNG" files are not silently ignored;
    # sorted() makes the processing order deterministic (listdir order is arbitrary).
    image_files = sorted(
        f for f in os.listdir(image_folder)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    for image_file in tqdm(image_files):
        image_path = os.path.join(image_folder, image_file)
        # splitext replaces only the real extension; chained str.replace could also
        # rewrite a ".jpg"/".png" substring earlier in the file name.
        label_file = os.path.join(label_folder, os.path.splitext(image_file)[0] + '.txt')

        if os.path.exists(label_file):
            process_image_and_label(image_path, label_file, output_images_dir, output_labels_dir, num_augments)
        else:
            print(f"Warning: Label file for {image_file} not found. Skipping.")

In [1]:
# Source folders: images and YOLO labels of the training split
source_split_dir = "/content/Hockey-Puck-Detection-1/train"
image_folder = os.path.join(source_split_dir, "images")
label_folder = os.path.join(source_split_dir, "labels")

# Destination folders for the augmented copies
output_root = "/content/drive/MyDrive/Puck Dataset"
output_images_dir = os.path.join(output_root, "augmented_images")
output_labels_dir = os.path.join(output_root, "augmented_labels")

# Generate 5 augmented examples per labeled image
process_dataset(
    image_folder,
    label_folder,
    output_images_dir,
    output_labels_dir,
    num_augments=5,
)

print(f"Augmented images saved in {output_images_dir} and labels saved in {output_labels_dir}.")