In [1]:
import os
from PIL import Image
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm.notebook import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

In [2]:
# How many independently augmented variants to write for every source image.
copies_per_image = 3

# Input tree is read with torchvision's ImageFolder further down.
source_dir = "archive/train"
# The output directory name embeds the copy count, so runs with different
# settings land in different folders.
target_dir = f"archive/train_augment_standard_{copies_per_image}_copies"
os.makedirs(name=target_dir, exist_ok=True)

In [3]:
# Standard augmentation pipeline. The random parameters are re-sampled on
# every call, and process_image calls it once per copy.
augment_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
        # Crop 70-100% of the area, then resize to a 32x32 output.
        transforms.RandomResizedCrop(size=32, scale=(0.7, 1.0)),
        # Tensor last: save_image takes a tensor, not a PIL image.
        transforms.ToTensor(),
    ]
)

In [4]:
# Index the source tree; ImageFolder assigns an integer label to each class folder.
dataset = datasets.ImageFolder(source_dir)
class_to_idx = dataset.class_to_idx
# Reverse lookup so a worker can map a numeric label back to its folder name.
idx_to_class = {label: name for name, label in class_to_idx.items()}

# Mirror every class sub-directory under the output root before any worker writes into it.
for class_name in class_to_idx:
    os.makedirs(name=os.path.join(target_dir, class_name), exist_ok=True)

def process_image(idx: int, img_path: str, label: int) -> list:
    """Write ``copies_per_image`` augmented PNG copies of one source image.

    Args:
        idx: Position of the image in ``dataset.imgs``; used to build unique
            output file names (``aug_{idx}_{i}.png``).
        img_path: Path of the source image file.
        label: Integer class label; mapped back to its folder name via
            ``idx_to_class``.

    Returns:
        List of the file paths written, one per copy, under
        ``target_dir/<class_name>/``.
    """
    # The context manager closes the source file deterministically, even if
    # decoding fails. convert() produces a new in-memory image, so `img`
    # stays usable after the file is closed.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    class_dir = os.path.join(target_dir, idx_to_class[label])
    saved_paths = []
    for i in range(copies_per_image):
        # Each call to the random pipeline draws fresh augmentation parameters.
        augmented = augment_transform(img)
        save_path = os.path.join(class_dir, f"aug_{idx}_{i}.png")
        save_image(augmented, save_path)
        saved_paths.append(save_path)

    return saved_paths

# One worker thread per reported CPU; fall back to 4 if the count is unknown.
max_workers = os.cpu_count() or 4
print(f"Using {max_workers} worker threads")

futures = []
n_saved = 0
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [
        executor.submit(process_image, idx, img_path, label)
        for idx, (img_path, label) in enumerate(dataset.imgs)
    ]

    try:
        for future in tqdm(as_completed(futures), total=len(futures), desc="Standard Augmentation"):
            # .result() re-raises any exception raised inside the worker.
            n_saved += len(future.result())
    except BaseException:
        # Fail fast on a worker error or a kernel interrupt. Without this,
        # leaving the `with` block calls shutdown(wait=True), which would still
        # run every queued task before the error surfaces. cancel() is a
        # no-op for futures already running or finished.
        for pending in futures:
            pending.cancel()
        raise

expected = len(dataset.imgs) * copies_per_image
assert n_saved == expected, f"expected {expected} augmented images, wrote {n_saved}"
print("Augment DONE.")

32


Standard Augmentation:   0%|          | 0/90000 [00:00<?, ?it/s]

Augment DONE.
