In [1]:
import os
from PIL import Image
from torchvision import transforms, datasets
from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy
from torchvision.utils import save_image
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

In [2]:
# Number of augmented copies written for every source image.
copies_per_image = 3

# Input tree, read below with datasets.ImageFolder (one sub-folder per class).
source_dir = "archive/train"
# The output folder name embeds the copy count so runs with different settings
# land in different directories.
target_dir = "archive/train_augment_advanced_{}_copies".format(copies_per_image)
os.makedirs(target_dir, exist_ok=True)

In [3]:
# Per-sample pipeline applied to each PIL image:
#   1. downscale to 32x32 (the resolution the CIFAR-10 AutoAugment policy targets),
#   2. apply a randomly chosen AutoAugment sub-policy,
#   3. convert to a float tensor (values in [0, 1] for RGB input), which is what
#      torchvision.utils.save_image consumes downstream.
augment_transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        AutoAugment(AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
    ]
)

In [4]:
# Index the raw training images. No transform is attached here because
# process_image opens and augments each file itself.
dataset = datasets.ImageFolder(source_dir)
class_to_idx = dataset.class_to_idx
# Reverse lookup so an integer label can be mapped back to its folder name.
idx_to_class = {index: name for name, index in class_to_idx.items()}

# Mirror the source class-folder layout under target_dir.
for folder in class_to_idx:
    os.makedirs(os.path.join(target_dir, folder), exist_ok=True)

def process_image(idx, img_path, label):
    """Write ``copies_per_image`` augmented copies of one source image to disk.

    Relies on module-level ``copies_per_image``, ``target_dir``,
    ``idx_to_class`` and ``augment_transform`` defined in earlier cells.

    Args:
        idx: Position of the image in ``dataset.imgs``. It is used to build
            unique output file names of the form ``aug_{idx}_{i}.png``.
        img_path: Path of the source image file.
        label: Integer class index, mapped to its class folder via ``idx_to_class``.

    Returns:
        List of the file paths written, one per augmented copy.
    """
    # Context manager guarantees the source file handle is released, even for
    # multi-frame formats or if convert() raises. convert() returns a new,
    # fully loaded image, so `img` remains usable after the file is closed.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    # Every copy goes to the same class folder, so resolve it once.
    class_dir = os.path.join(target_dir, idx_to_class[label])
    saved_paths = []

    for i in range(copies_per_image):
        # AutoAugment is randomized, so each call is expected to give a different copy.
        augmented = augment_transform(img)
        save_path = os.path.join(class_dir, f"aug_{idx}_{i}.png")
        save_image(augmented, save_path)
        saved_paths.append(save_path)

    return saved_paths

# One thread per logical CPU. os.cpu_count() can return None, in which case fall back to 4.
cpu_total = os.cpu_count()
max_workers = cpu_total if cpu_total else 4
print(max_workers)

# Fan the per-image work out over a thread pool and track completion with tqdm.
futures = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    for idx, (img_path, label) in enumerate(dataset.imgs):
        futures.append(executor.submit(process_image, idx, img_path, label))

    try:
        for future in tqdm(as_completed(futures), total=len(futures), desc="Advanced Augmentation"):
            # result() re-raises any exception raised inside process_image.
            future.result()
    except BaseException:
        # Leaving the `with` block calls shutdown(wait=True), which waits for
        # every queued task. Without cancelling, a single failure (or a kernel
        # interrupt) would only surface after the entire remaining queue had run.
        # cancel() is a no-op for tasks that are already running or finished.
        for pending in futures:
            pending.cancel()
        raise

print("Augment DONE.")

32


Advanced Augmentation: 100%|██████████| 90000/90000 [10:06<00:00, 148.50it/s]

Augment DONE.



