In [1]:
import os
import shutil

import cv2
import numpy as np
from tqdm import tqdm

In [2]:
# Raw dataset: one sub-folder per class. The class-balanced copy is written to AUG_PATH.
RAW_PATH, AUG_PATH = "../dataset/raw", "../dataset/augmented"

os.makedirs(AUG_PATH, exist_ok=True)  # idempotent: safe to re-run this cell

In [3]:
# Count image files in each class sub-folder of RAW_PATH. Only files with an
# image extension are counted (the same extensions the augmentation cell keeps),
# so stray entries such as .DS_Store or nested folders cannot inflate TARGET_COUNT.
_COUNTED_EXTENSIONS = (".jpg", ".jpeg", ".png")

class_counts = {}
for cls in sorted(os.listdir(RAW_PATH)):  # sorted -> deterministic order
    cls_path = os.path.join(RAW_PATH, cls)
    if not os.path.isdir(cls_path):
        continue  # ignore loose files at the top level
    class_counts[cls] = sum(
        1
        for name in os.listdir(cls_path)
        if name.lower().endswith(_COUNTED_EXTENSIONS)
        and os.path.isfile(os.path.join(cls_path, name))
    )
class_counts

{'cardboard': 259,
 'glass': 401,
 'metal': 328,
 'paper': 476,
 'plastic': 386,
 'trash': 110,
 'unknown': 0}

In [4]:
# Every class is topped up to the size of the largest class.
# Fail with a clear message (same exception type max() would raise) if no classes were found.
if not class_counts:
    raise ValueError(f"No class folders found under {RAW_PATH!r}")
TARGET_COUNT = max(class_counts.values())
TARGET_COUNT

476

In [5]:
def augment_image(img, angles=(-15, 15), alpha=1.2, beta=20):
    """Return simple augmented variants of an image.

    Produces, in order:
      1. a horizontal (left-right) flip,
      2. one rotation per entry in ``angles``, about the integer image centre,
         keeping the original output size,
      3. a brightness/contrast-adjusted copy computed by ``cv2.convertScaleAbs``
         (``|alpha * img + beta|`` saturated to 8-bit).

    Args:
        img: image array as returned by ``cv2.imread``.
        angles: rotation angles in degrees (OpenCV: positive = counter-clockwise).
        alpha: contrast gain for the brightness variant.
        beta: brightness offset for the brightness variant.

    Returns:
        list of image arrays of length ``len(angles) + 2``.

    Raises:
        ValueError: if ``img`` is None or empty (e.g. a failed ``cv2.imread``).
    """
    if img is None or img.size == 0:
        raise ValueError("augment_image() needs a non-empty image")

    # Loop-invariant geometry, computed once instead of once per angle.
    h, w = img.shape[:2]
    center = (w // 2, h // 2)

    aug_images = [cv2.flip(img, 1)]  # flipCode=1 -> mirror around the vertical axis
    for angle in angles:
        M = cv2.getRotationMatrix2D(center, angle, 1)
        # Corners left uncovered by the rotation are filled with black (warpAffine default).
        aug_images.append(cv2.warpAffine(img, M, (w, h)))
    aug_images.append(cv2.convertScaleAbs(img, alpha=alpha, beta=beta))
    return aug_images

In [8]:
# Image extensions treated as dataset samples (matched case-insensitively).
VALID_EXTENSIONS = (".jpg", ".jpeg", ".png")
Valid_Extentions = VALID_EXTENSIONS  # original (misspelled) name, kept for any later cells


def _copy_readable_originals(src_dir, dst_dir):
    """Copy every decodable image from src_dir into dst_dir; return their sorted filenames."""
    readable = []
    for name in sorted(os.listdir(src_dir)):  # sorted -> reproducible augmentation order
        if not name.lower().endswith(VALID_EXTENSIONS):
            continue
        src = os.path.join(src_dir, name)
        if cv2.imread(src) is None:
            continue  # skip corrupt / unreadable files
        # Byte-for-byte copy: avoids a lossy JPEG re-encode of the originals.
        shutil.copy2(src, os.path.join(dst_dir, name))
        readable.append(name)
    return readable


def _next_free_aug_path(dst_dir, index):
    """Return (path, k) for the first ``aug_<k>.jpg`` with k >= index that does not exist yet."""
    while True:
        path = os.path.join(dst_dir, f"aug_{index}.jpg")
        if not os.path.exists(path):
            return path, index
        index += 1


def _top_up_with_augmentations(src_dir, dst_dir, names, target_count):
    """Write augmented variants of ``names`` into dst_dir until it holds target_count entries.

    Cycles through ``names`` in order, writing the outputs of ``augment_image``
    for each one. Stops early (instead of looping forever) if a full pass
    produces nothing. Raises OSError if ``cv2.imwrite`` reports a failure.
    Returns the final number of entries in dst_dir.
    """
    n_files = len(os.listdir(dst_dir))  # listed once, then tracked in memory
    next_index = n_files
    while n_files < target_count:
        wrote_any = False
        for name in names:
            img = cv2.imread(os.path.join(src_dir, name))
            if img is None:
                continue
            for aug in augment_image(img):
                if n_files >= target_count:
                    break
                # Never overwrite an existing file, or the count would stall.
                path, next_index = _next_free_aug_path(dst_dir, next_index)
                if not cv2.imwrite(path, aug):
                    raise OSError(f"cv2.imwrite failed for {path}")
                n_files += 1
                next_index += 1
                wrote_any = True
            if n_files >= target_count:
                break
        if not wrote_any:
            break  # nothing could be generated in a full pass
    return n_files


for cls in tqdm(class_counts, desc="Balancing classes"):
    src_cls = os.path.join(RAW_PATH, cls)
    dst_cls = os.path.join(AUG_PATH, cls)
    os.makedirs(dst_cls, exist_ok=True)

    images = _copy_readable_originals(src_cls, dst_cls)
    if not images:
        tqdm.write(f" Skipping class '{cls}' (no readable images found)")
        continue

    final_count = _top_up_with_augmentations(src_cls, dst_cls, images, TARGET_COUNT)
    if final_count < TARGET_COUNT:
        tqdm.write(f" Class '{cls}' reached only {final_count}/{TARGET_COUNT} images")

 Skipping class 'unknown' (no images found)


In [9]:
# Sanity check: number of files per class folder in the augmented dataset.
# Loose files at the AUG_PATH top level (e.g. .DS_Store) are skipped; calling
# os.listdir on them would raise NotADirectoryError.
aug_counts = {
    cls: len(os.listdir(os.path.join(AUG_PATH, cls)))
    for cls in sorted(os.listdir(AUG_PATH))
    if os.path.isdir(os.path.join(AUG_PATH, cls))
}
aug_counts

{'cardboard': 476,
 'glass': 476,
 'metal': 476,
 'paper': 476,
 'plastic': 476,
 'trash': 476,
 'unknown': 0}