In [1]:
import os
import random
import cv2
import albumentations as A
from collections import defaultdict

# Dataset locations: each label file in `label_dir` pairs with the image of the
# same stem in `image_dir`.
image_dir = "./YOLODatasetFullOri/images/train"
label_dir = "./YOLODatasetFullOri/labels/train"

# Seed Python's `random` so the choice of source files below is reproducible
# across Restart & Run All.
# NOTE(review): depending on the installed albumentations version, the
# transforms may draw from their own RNG rather than `random`; newer releases
# accept `A.Compose(..., seed=...)` — confirm before relying on bit-identical output.
SEED = 42
random.seed(SEED)

# Photometric + mild geometric augmentations. `bbox_params` declares the boxes
# as normalized YOLO (x_center, y_center, width, height) and keeps the
# `class_labels` list aligned with `bboxes` when a transform drops a box.
augment = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.GaussNoise(p=0.3),
    A.Rotate(limit=10, p=0.5)
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

# Count labelled instances (one per label line, i.e. boxes rather than images)
# per class id across the training label files.
#
# - Sorted so the order of `label_files` (and therefore which file a seeded
#   `random.choice` picks later) does not depend on directory-listing order.
# - Files whose stem contains "_aug" are synthetic outputs of a previous run of
#   this cell; excluding them keeps a re-run from counting (and re-augmenting)
#   its own output. NOTE(review): assumes no original file name contains "_aug".
label_files = sorted(
    f for f in os.listdir(label_dir)
    if f.endswith('.txt') and '_aug' not in os.path.splitext(f)[0]
)

class_counts = defaultdict(int)  # class id (as the string in the file) -> instance count
for lf in label_files:
    with open(os.path.join(label_dir, lf), 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:  # tolerate blank / trailing empty lines
                continue
            class_counts[parts[0]] += 1

if not class_counts:
    raise ValueError("No class labels found. Check if your label files are formatted correctly.")

# Every class below the target gets topped up: twice the rarest class's
# instance count, but never fewer than 20.
min_count = min(class_counts.values())
target_count = max(min_count * 2, 20)

# Image extensions tried, in priority order, when locating the image for a label file.
IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg', '.bmp')


def _find_image(stem):
    """Return (path, extension) of the first existing image `stem + ext` in `image_dir`.

    Returns (None, None) when no candidate exists.
    """
    for ext in IMAGE_EXTENSIONS:
        candidate = os.path.join(image_dir, stem + ext)
        if os.path.exists(candidate):
            return candidate, ext
    return None, None


def _read_yolo_labels(path):
    """Parse a YOLO detection label file.

    Returns two parallel lists: class ids (strings, exactly as written) and
    boxes as [x, y, w, h] floats. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line does not have exactly 5 fields
            (e.g. a segmentation-polygon line, which this script cannot augment).
    """
    class_ids, boxes = [], []
    with open(path, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if len(parts) != 5:
                raise ValueError(f"Malformed YOLO label line in {path}: {line!r}")
            class_ids.append(parts[0])
            boxes.append([float(v) for v in parts[1:]])
    return class_ids, boxes


# Parse every label file once instead of re-reading all of them for each class.
# Dict order follows `label_files`, so `random.choice` below sees the same order.
parsed_labels = {lf: _read_yolo_labels(os.path.join(label_dir, lf)) for lf in label_files}

for class_id, count in class_counts.items():
    if count >= target_count:
        continue
    # NOTE: `needed` is measured in instances, but each iteration writes one
    # image that may hold several target-class boxes, so the class can overshoot.
    needed = target_count - count
    print(f"Augmenting class {class_id}: need {needed} more samples")
    matching_files = [lf for lf, (ids, _) in parsed_labels.items() if class_id in ids]
    if not matching_files:
        continue  # counts and parsed labels disagree; nothing to sample from

    for i in range(needed):
        src_label = random.choice(matching_files)
        base_name = os.path.splitext(src_label)[0]

        img_path, img_ext = _find_image(base_name)
        if img_path is None:
            print(f"Image not found for label: {src_label}")
            continue
        image = cv2.imread(img_path)
        if image is None:
            print(f"Could not read image: {img_path}")
            continue

        # Augment ALL boxes in the image, not just the target class: the other
        # objects remain visible in the pixels, so dropping their labels would
        # teach the detector to treat them as background. Copies keep the
        # cached parse untouched.
        src_classes, src_boxes = parsed_labels[src_label]
        augmented = augment(
            image=image,
            bboxes=[list(b) for b in src_boxes],
            class_labels=list(src_classes),
        )
        aug_classes = list(augmented['class_labels'])
        if class_id not in aug_classes:
            continue  # the transform cropped out every target-class box

        # The class id in the name keeps two classes that draw the same source
        # file at the same `i` from overwriting each other's output.
        save_stem = f"{base_name}_aug_c{class_id}_{i}"
        if not cv2.imwrite(os.path.join(image_dir, save_stem + img_ext), augmented['image']):
            print(f"Failed to write augmented image for: {src_label}")
            continue
        with open(os.path.join(label_dir, save_stem + '.txt'), 'w') as f:
            for label, box in zip(aug_classes, augmented['bboxes']):
                f.write(f"{label} {' '.join(f'{v:.6f}' for v in box)}\n")


  check_for_updates()


Augmenting class 0: need 585 more samples
Augmenting class 1: need 690 more samples
Augmenting class 2: need 574 more samples
Augmenting class 3: need 686 more samples
