In [2]:
import os
import random
from PIL import Image, ImageOps, ImageEnhance, ImageFilter
from tqdm import tqdm

# ------------------
# Config
# ------------------
RAW_DATA_DIR = "raw_data"  # input root; every class is one subfolder of it
OUTPUT_DIR = "augmented_data"  # destination for the generated train/val splits

# Overall dataset-size targets.
# NOTE(review): none of these three is read by the code visible here -- the
# real output sizes follow from PER_ORIGINAL_COUNTS and the originals on disk.
TARGET_TOTAL = 12000
TRAIN_TARGET = 10000
VAL_TARGET = 2000

# How many original (un-augmented) images per class are held out for validation.
VAL_ORIGINALS = dict(
    non_defect=10,
    hole=1,
    lycra_cut=1,
    needln=1,
    twoply=1,
)

# How many augmented copies are generated from every training original, per class.
PER_ORIGINAL_COUNTS = dict(
    non_defect=120,
    hole=200,
    lycra_cut=200,
    needln=200,
    twoply=200,
)

# Independent firing probability of each augmentation step.
PROBS = dict(
    hflip=0.5,
    vflip=0.3,
    rotate180=0.4,   # rotate by exactly 180 degrees
    brightness=0.4,
    contrast=0.4,
    sharpness=0.3,
    blur=0.2,
    zoom=0.4,        # zoom in only (scale factor above 1)
)

# ------------------
# Helper functions
# ------------------
def collect_images(root=RAW_DATA_DIR):
    """Map each class folder name under *root* to its image file paths.

    Every directory below *root* (at any depth) that directly contains
    ``.jpg``/``.jpeg``/``.png`` files is treated as a class named after the
    directory's basename; images sitting directly in *root* are ignored.

    Directories and files are visited in sorted order, so the returned dict
    and lists are identical on every filesystem and downstream seeded
    shuffling is reproducible. If two directories share a basename, their
    images are merged under that class instead of the later one silently
    replacing the earlier.

    Returns:
        dict mapping class name -> list of image paths.
    """
    classes = {}
    for root_dir, dirs, files in os.walk(root):
        # os.walk honours in-place edits of `dirs`; sorting fixes traversal order.
        dirs.sort()
        if root_dir == root:
            continue
        image_files = [
            os.path.join(root_dir, name)
            for name in sorted(files)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        ]
        if image_files:
            classes.setdefault(os.path.basename(root_dir), []).extend(image_files)
    return classes


def split_originals(classes, val_originals=None, rng=None):
    """Split each class's originals into training and validation lists.

    For every class, a shuffled copy of its images is taken; the first
    ``val_originals[cls]`` entries (0 for classes not listed) go to
    validation and the remainder to training.

    Args:
        classes: mapping of class name -> list of image paths. Not modified.
        val_originals: per-class validation counts; defaults to ``VAL_ORIGINALS``.
        rng: object with a ``shuffle`` method (e.g. ``random.Random``); defaults
            to the module-level ``random`` so ``random.seed`` still applies.

    Returns:
        ``(train_split, val_split)``: two dicts keyed like *classes*.
    """
    if val_originals is None:
        val_originals = VAL_ORIGINALS
    if rng is None:
        rng = random

    train_split = {}
    val_split = {}
    for cls, images in classes.items():
        # Shuffle a sorted copy: the caller's list is left untouched, and the
        # seeded result no longer depends on the order the files were listed in.
        ordered = sorted(images)
        rng.shuffle(ordered)
        val_count = val_originals.get(cls, 0)
        val_split[cls] = ordered[:val_count]
        train_split[cls] = ordered[val_count:]
    return train_split, val_split


def pad_to_square(image, fill_color=(0, 0, 0)):
    """Centre *image* on a square canvas whose side equals its longer edge.

    The missing length on the short axis is split in two; when it is odd the
    extra pixel goes to the right/bottom border. Already-square input is
    returned unchanged (the same object). *fill_color* paints the border.
    """
    width, height = image.size
    if width == height:
        return image
    longest = max(width, height)
    extra_x, extra_y = longest - width, longest - height
    left, top = extra_x // 2, extra_y // 2
    border = (left, top, extra_x - left, extra_y - top)
    return ImageOps.expand(image, border, fill=fill_color)


def zoom_out(image, scale=0.8, fill_color=(0, 0, 0)):
    """Rescale *image* about its centre while keeping its original size.

    * ``scale < 1`` zooms out: the image is shrunk by *scale* and centred on
      an RGB canvas of the original size filled with *fill_color*.
    * ``scale > 1`` zooms in: the central ``1/scale`` region is cropped and
      resized back to the original size. This gives the same framing as
      upscaling and centre-cropping, but without resampling the whole image
      at the larger size.
    * ``scale == 1`` returns an unchanged copy.

    Raises:
        ValueError: if *scale* is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    w, h = image.size
    if scale == 1:
        return image.copy()

    if scale > 1:
        # Zoom in: crop the centre, then enlarge it to the full frame.
        crop_w = max(1, round(w / scale))
        crop_h = max(1, round(h / scale))
        left = (w - crop_w) // 2
        top = (h - crop_h) // 2
        region = image.crop((left, top, left + crop_w, top + crop_h))
        return region.resize((w, h), Image.LANCZOS)

    # Zoom out: shrink, then centre on a padded canvas of the original size.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    image_resized = image.resize((new_w, new_h), Image.LANCZOS)
    background = Image.new("RGB", (w, h), fill_color)
    background.paste(image_resized, ((w - new_w) // 2, (h - new_h) // 2))
    return background


def augment_image(image):
    """Return a randomly augmented version of *image*.

    Each step fires independently with the probability stored under its key
    in ``PROBS``. Steps run in a fixed order:

    * geometric: mirror, vertical flip, 180-degree rotation;
    * photometric: brightness, contrast, sharpness, Gaussian blur;
    * finally ``zoom_out`` with a scale drawn from [1.1, 1.3].

    Because the sequence of ``random`` draws is fixed, results are
    reproducible after ``random.seed``.
    """
    def _enhance(enhancer_cls, low, high):
        return lambda img: enhancer_cls(img).enhance(random.uniform(low, high))

    pipeline = (
        ("hflip", ImageOps.mirror),
        ("vflip", ImageOps.flip),
        ("rotate180", lambda img: img.rotate(180, expand=True)),
        ("brightness", _enhance(ImageEnhance.Brightness, 0.8, 1.2)),
        ("contrast", _enhance(ImageEnhance.Contrast, 0.8, 1.2)),
        ("sharpness", _enhance(ImageEnhance.Sharpness, 0.8, 1.5)),
        ("blur", lambda img: img.filter(
            ImageFilter.GaussianBlur(radius=random.uniform(0.5, 1.5)))),
        ("zoom", lambda img: zoom_out(img, scale=random.uniform(1.1, 1.3))),
    )
    for key, transform in pipeline:
        if random.random() < PROBS[key]:
            image = transform(image)
    return image


def save_augmented(images, cls, split, count_per_original):
    """Save ``count_per_original`` augmented, square-padded copies of each image.

    Files are written as ``OUTPUT_DIR/<split>/<cls>/<cls>_<idx>.png``. The
    counter ``idx`` keeps increasing across all source images, so names never
    collide within one call. The directory is created if needed, and existing
    files with the same names are overwritten.
    """
    save_dir = os.path.join(OUTPUT_DIR, split, cls)
    os.makedirs(save_dir, exist_ok=True)

    idx = 0
    for img_path in tqdm(images, desc=f"Augmenting {split} - {cls}"):
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once the pixels have been converted.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        for _ in range(count_per_original):
            aug_img = pad_to_square(augment_image(img))
            aug_img.save(os.path.join(save_dir, f"{cls}_{idx}.png"))
            idx += 1


def copy_originals(images, cls, split):
    """Copy each original into ``OUTPUT_DIR/<split>/<cls>/`` padded to square.

    Images are converted to RGB and padded but not augmented. The original
    basename, including its extension, is kept. Two originals with the same
    basename would therefore overwrite each other.
    """
    save_dir = os.path.join(OUTPUT_DIR, split, cls)
    os.makedirs(save_dir, exist_ok=True)
    for img_path in images:
        # Release the source file handle as soon as the pixels are loaded.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        pad_to_square(img).save(os.path.join(save_dir, os.path.basename(img_path)))


# ------------------
# Main
# ------------------
def main():
    """Build the augmented dataset from ``RAW_DATA_DIR`` into ``OUTPUT_DIR``.

    Steps:

    1. Seed ``random`` so the split and the augmentations are repeatable.
    2. Collect the images of each class.
    3. Split the originals into train and validation sets.
    4. Write ``PER_ORIGINAL_COUNTS[cls]`` augmented copies (default 1) of every
       training original.
    5. Copy the validation originals, square-padded but not augmented.
    """
    random.seed(42)

    classes = collect_images(RAW_DATA_DIR)
    if not classes:
        # Without this guard a wrong or missing RAW_DATA_DIR still reports success.
        print(f"\nNo images found under {RAW_DATA_DIR!r}; nothing to augment.")
        return

    print("\nFound images per class:")
    for cls, imgs in classes.items():
        print(f"{cls}: {len(imgs)}")

    train_split, val_split = split_originals(classes)

    for cls, images in train_split.items():
        if not images:
            print(f"Warning: class {cls!r} has no training originals left after the validation split.")
        count_per_original = PER_ORIGINAL_COUNTS.get(cls, 1)
        save_augmented(images, cls, "train", count_per_original)

    for cls, images in val_split.items():
        copy_originals(images, cls, "val")

    print("\n✅ Augmentation complete. Output at:", OUTPUT_DIR)


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()



Found images per class:
lycra_cut: 2
twoply: 3
needln: 5
hole: 2
non_defect: 50


Augmenting train - lycra_cut: 100%|██████████| 1/1 [00:16<00:00, 16.05s/it]
Augmenting train - twoply: 100%|██████████| 2/2 [00:33<00:00, 16.55s/it]
Augmenting train - needln: 100%|██████████| 4/4 [01:02<00:00, 15.74s/it]
Augmenting train - hole: 100%|██████████| 1/1 [00:15<00:00, 15.14s/it]
Augmenting train - non_defect: 100%|██████████| 40/40 [06:13<00:00,  9.34s/it]


✅ Augmentation complete. Output at: augmented_data



