In [63]:
import os
import glob

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.io import read_image, ImageReadMode
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import random
from tqdm.notebook import tqdm
from torchvision.utils import make_grid

In [2]:
# Seed every RNG the notebook uses so Restart & Run All reproduces the same
# augmented images and splits: the stdlib `random` module (imported above and
# used in the augmentation cell) as well as torch's global generator.
SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fdc9334e430>

In [3]:
# Root of the generated dataset: one sub-folder per class, which is the
# layout ImageFolder expects when the dataset is loaded further down.
images_dir = "../images"
for _class_name in ("positive", "negative"):
    os.makedirs(f"{images_dir}/{_class_name}", exist_ok=True)

In [4]:
# Offline augmentation pipeline used to synthesise the training images.
# All randomness lives *inside* the torchvision transforms, which sample fresh
# parameters for every call from torch's global RNG (seeded above). Previously
# the flip probabilities, rotation range, jitter strengths and blur kernel were
# each drawn a single time, at definition, from the unseeded `random` module,
# so every run used a different (and sometimes degenerate) pipeline.
generate_df_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=50),  # angle sampled from [-50, 50] per image
    transforms.ColorJitter(
        # (min, max) tuples give the factor range directly. A bare float b
        # means [max(0, 1 - b), 1 + b], i.e. up to [0, 2.4] for b = 1.4,
        # which can turn images completely black.
        brightness=(0.6, 1.4),
        contrast=(0.6, 1.4),
        saturation=(0.6, 1.4),
        hue=0.1,  # hue shift sampled from [-0.1, 0.1]
    ),
    # Blur strength varies per image through sigma (sampled from (0.1, 2.0)),
    # instead of one kernel size being fixed for the whole run.
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

In [5]:
# Load-time preprocessing for the generated dataset: resize to 128x128,
# mirror left-right with probability 0.5, then convert to a float tensor.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(128, 128)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
    ]
)

In [6]:

def generate_augmented_data(images_dir, output_dir, num_augmented_images=100, label=None, transform=generate_df_transform):
    """Write randomly augmented PNG copies of every image in ``images_dir``.

    Args:
        images_dir: Directory holding the source images. Hidden files (such as
            ``.gitkeep``) and sub-directories are skipped.
        output_dir: Directory the augmented images are written to; it is
            created if it does not exist.
        num_augmented_images: Number of augmented copies written per source image.
        label: Unused; kept only so existing call sites keep working.
        transform: Callable applied to each RGB PIL image. Its result must be
            accepted by ``transforms.ToPILImage`` (defaults to
            ``generate_df_transform``).

    Each output is named ``<source stem>_augmented_<i>.png``. Files that
    already exist under that name are overwritten.
    """
    os.makedirs(output_dir, exist_ok=True)
    to_pil = transforms.ToPILImage()  # stateless converter, built once

    # Sorted traversal gives a stable file order, so the seeded random draws
    # are consumed identically on every machine and every re-run.
    for filename in sorted(os.listdir(images_dir)):
        image_path = os.path.join(images_dir, filename)
        if filename.startswith(".") or not os.path.isfile(image_path):
            continue

        stem = os.path.splitext(filename)[0]
        # `with` releases the file handle right away. Converting to RGB keeps
        # the colour transforms on a single mode; palette/alpha sources can
        # misbehave there. ImageFolder's default loader converts to RGB anyway.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        for i in range(num_augmented_images):
            augmented_image = transform(image)
            output_path = os.path.join(output_dir, f"{stem}_augmented_{i}.png")
            to_pil(augmented_image).save(output_path)

In [7]:
# Expand the positive-class source images into the generated dataset folder.
generate_augmented_data(
    "../train/positive/",
    f"{images_dir}/positive",
)

In [8]:
# Expand the negative-class source images into the generated dataset folder.
generate_augmented_data(
    "../train/negative/",
    f"{images_dir}/negative",
)

In [9]:
# Read from `images_dir` (the folder the augmentation step writes to) instead
# of repeating the literal path, so the two can never drift apart.
dataset = ImageFolder(root=images_dir, transform=transform)
# ImageFolder derives classes from sorted sub-folder names, giving
# negative -> 0 and positive -> 1. Any stray folder would silently add a class.
assert dataset.classes == ["negative", "positive"], dataset.classes

In [10]:
# Total number of generated samples ImageFolder discovered across both classes.
num_samples = len(dataset)
num_samples

800