In [20]:
import cv2
import numpy as np
import os
import random

## Global Parameters

In [21]:
IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")
ROTATION_RANGE = (-10, 10)
BRIGHTNESS_RANGE = (-15, 15)
CONTRAST_RANGE = (0.9, 1.1)
NOISE_STD_RANGE = (2, 6)
ZOOM_RANGE = (0.9, 1.0)

## Dataset Paths

In [22]:
INPUT_PATH = "../data_split/train"
OUTPUT_PATH = "../data_split/train_aug"
os.makedirs(OUTPUT_PATH, exist_ok=True)


## Image Augmentation Functions

In [23]:
def random_augment(img):
    h, w = img.shape[:2]

    angle = random.uniform(*ROTATION_RANGE)
    M = cv2.getRotationMatrix2D((w//2, h//2), angle, 1)
    img = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    if random.random() < 0.5:
        img = cv2.flip(img, 1)

    alpha = random.uniform(*CONTRAST_RANGE)
    beta = random.randint(*BRIGHTNESS_RANGE)
    img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

    noise_std = random.uniform(*NOISE_STD_RANGE)
    noise = np.random.normal(0, noise_std, img.shape)
    img = np.clip(img + noise, 0, 255).astype(np.uint8)

    zoom = random.uniform(*ZOOM_RANGE)
    nh, nw = int(h * zoom), int(w * zoom)
    y1 = (h - nh) // 2
    x1 = (w - nw) // 2
    img = img[y1:y1+nh, x1:x1+nw]
    img = cv2.resize(img, (w, h))

    return img


In [24]:
def count_images(path):
    return {
        cls: len([
            f for f in os.listdir(os.path.join(path, cls))
            if f.lower().endswith(IMAGE_EXTENSIONS)
        ])
        for cls in os.listdir(path)
        if os.path.isdir(os.path.join(path, cls))
    }

In [25]:
counts = count_images(INPUT_PATH)

before_total = sum(counts.values())
num_classes = len(counts)

target_total = int(before_total * 1.5)
target_per_class = target_total // num_classes

print("Target total images:", target_total)
print("Target per class:", target_per_class)

cnt = 0
for cls, count in counts.items():
    in_cls = os.path.join(INPUT_PATH, cls)
    out_cls = os.path.join(OUTPUT_PATH, cls)
    os.makedirs(out_cls, exist_ok=True)

    images = [
        f for f in os.listdir(in_cls)
        if f.lower().endswith(IMAGE_EXTENSIONS)
    ]

    # Copy originals safely
    for img_name in images:
        img_path = os.path.join(in_cls, img_name)
        img = cv2.imread(img_path)

        if img is None:
            print(f"[SKIP] Cannot read {img_path}")
            cnt += 1
            continue

        cv2.imwrite(os.path.join(out_cls, img_name), img)

    current = count
    idx = 0

    while current < target_per_class:
        img_name = random.choice(images)
        img_path = os.path.join(in_cls, img_name)
        img = cv2.imread(img_path)

        if img is None:
            continue

        aug = random_augment(img)

        cv2.imwrite(
            os.path.join(out_cls, f"{os.path.splitext(img_name)[0]}_aug{idx}.jpg"),
            aug
        )

        current += 1
        idx += 1

print("Train augmentation completed")
print("Skipped images:", cnt)

Target total images: 2055
Target per class: 342
[SKIP] Cannot read ../data_split/train\cardboard\2ec9d19b-8027-4c77-a13f-5eee033b9868.jpg
[SKIP] Cannot read ../data_split/train\cardboard\31381a44-38d6-4a44-9384-7690727801bc.jpg
[SKIP] Cannot read ../data_split/train\cardboard\345bdb67-4190-4235-a16f-b60c1556a28d.jpg
[SKIP] Cannot read ../data_split/train\cardboard\4840d678-7af4-4a2d-bda1-338c2f2a59c5.jpg
[SKIP] Cannot read ../data_split/train\cardboard\509251d8-4e3a-4f1e-aabc-4d034b0f2455.jpg
[SKIP] Cannot read ../data_split/train\cardboard\5b7da318-c2ab-4c29-8ace-19895a890840.jpg
[SKIP] Cannot read ../data_split/train\cardboard\8617221e-dc90-48fe-a116-46350b5f814e.jpg
[SKIP] Cannot read ../data_split/train\cardboard\88ce5fbf-e9c7-40ad-87a6-deffe95d8ee8.jpg
[SKIP] Cannot read ../data_split/train\cardboard\bff223bf-1a84-4d38-a486-c3f4c9bfef5e.jpg
[SKIP] Cannot read ../data_split/train\cardboard\d5856b01-c157-4e34-b921-80f29252976a.jpg
[SKIP] Cannot read ../data_split/train\glass\14bbfb5

In [26]:
def count_images_per_class(dataset_path):
    counts = {}
    total = 0

    for cls in sorted(os.listdir(dataset_path)):
        cls_path = os.path.join(dataset_path, cls)
        if not os.path.isdir(cls_path):
            continue

        num_images = len([
            f for f in os.listdir(cls_path)
            if f.lower().endswith(IMAGE_EXTENSIONS)
        ])

        counts[cls] = num_images
        total += num_images

    return counts, total

In [27]:
print("\nClass distribution after augmentation:")
for cls in os.listdir(OUTPUT_PATH):
    cls_path = os.path.join(OUTPUT_PATH, cls)
    if os.path.isdir(cls_path):
        count = len([f for f in os.listdir(cls_path)
                    if f.lower().endswith(IMAGE_EXTENSIONS)])
        print(f"   {cls}: {count} images")


Class distribution after augmentation:
   cardboard: 332 images
   glass: 332 images
   metal: 333 images
   paper: 321 images
   plastic: 323 images
   trash: 339 images


In [28]:
before_counts, before_total = count_images_per_class(INPUT_PATH)

print("Images BEFORE augmentation:")

for cls, cnt in before_counts.items():
    print("="*40)
    print(f"{cls:10s}: {cnt}")
    print(f"Percentage of {cls:10s}: {cnt / before_total * 100:.2f}%")
    print(f"count after 50% augmentation --> {int(cnt * 1.5)}")
print("TOTAL:", before_total)
print ("After 50 % augmentation --> ", before_total * 1.5)


Images BEFORE augmentation:
cardboard : 181
Percentage of cardboard : 13.21%
count after 50% augmentation --> 271
glass     : 280
Percentage of glass     : 20.44%
count after 50% augmentation --> 420
metal     : 229
Percentage of metal     : 16.72%
count after 50% augmentation --> 343
paper     : 333
Percentage of paper     : 24.31%
count after 50% augmentation --> 499
plastic   : 270
Percentage of plastic   : 19.71%
count after 50% augmentation --> 405
trash     : 77
Percentage of trash     : 5.62%
count after 50% augmentation --> 115
TOTAL: 1370
After 50 % augmentation -->  2055.0


In [29]:
after_counts, after_total = count_images_per_class(OUTPUT_PATH)

print("\nImages AFTER augmentation:")
for cls, cnt in after_counts.items():
    print(f"{cls:10s}: {cnt}")
print("TOTAL:", after_total)



Images AFTER augmentation:
cardboard : 332
glass     : 332
metal     : 333
paper     : 321
plastic   : 323
trash     : 339
TOTAL: 1980


In [30]:
print("Augmentation ratio ", (after_total - before_total) / before_total * 100)

Augmentation ratio  44.52554744525548
