In [1]:
import os
import torch
import random
from torchvision import transforms
from PIL import Image
import pandas as pd

In [2]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7fead4f38150>

In [3]:
images_dir = "../data/images"
output_dir = "../data/augmented_images"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(images_dir, exist_ok=True)

In [4]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=random.uniform(0, 1)),
    transforms.RandomVerticalFlip(p=random.uniform(0, 1)),
    transforms.RandomRotation(degrees=random.uniform(1, 50)),
    transforms.ColorJitter(
        brightness=random.uniform(0.6, 1.4),
        contrast=random.uniform(0.6, 1.4),
        saturation=random.uniform(0.6, 1.4),
        hue=random.uniform(0.01, 0.1)
    ),
    transforms.RandomResizedCrop(
        size=(256, 256),
        scale=(0.8, 1.2),
        ratio=(0.8, 1.2)
    ),
    transforms.GaussianBlur(kernel_size=random.choice([3, 5])),
    transforms.ToTensor()
])

In [5]:
transform

Compose(
    RandomHorizontalFlip(p=0.658732760284061)
    RandomVerticalFlip(p=0.4614267028513931)
    RandomRotation(degrees=[-19.65769329741804, 19.65769329741804], interpolation=nearest, expand=False, fill=0)
    ColorJitter(brightness=(0.2608931058399616, 1.7391068941600385), contrast=(0.3494842903790305, 1.6505157096209695), saturation=(0.021518257933764384, 1.9784817420662355), hue=(-0.04947055436524255, 0.04947055436524255))
    RandomResizedCrop(size=(256, 256), scale=(0.8, 1.2), ratio=(0.8, 1.2), interpolation=bilinear, antialias=warn)
    GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))
    ToTensor()
)

In [6]:
num_augmented_images = 30
labeled_data = []

In [7]:
for filename in os.listdir(images_dir):
    if filename == ".gitkeep":
        continue
    # Load the image
    image_path = os.path.join(images_dir, filename)
    image = Image.open(image_path)
    image_tensor = transforms.ToTensor()(image)
    flattened_tensor = image_tensor.view(-1)
    labeled_data.append((flattened_tensor, 1))


    for i in range(num_augmented_images):
        # Apply the transformations
        augmented_image = transform(image)
        labeled_data.append((augmented_image.view(-1), 1))
        # Save the augmented image
        output_filename = f"{os.path.splitext(filename)[0]}_augmented_{i}.png"
        output_path = os.path.join(output_dir, output_filename)
        transforms.ToPILImage()(augmented_image).save(output_path)

In [8]:
for item in labeled_data:
    assert(item[0].dtype == torch.float32)

In [9]:
image_data, labels = zip(*labeled_data)
df = pd.DataFrame({"data": image_data, "label": labels})

In [10]:
len(labeled_data)

124