In [1]:
import torch
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision.io import read_image, write_jpeg
from torchvision.utils import save_image
import pandas as pd

#from mpmath.identification import transforms

In [2]:
# Source image that every single-image augmentation cell below transforms.
image = read_image('./data/original/AKU01.jpg')

# The augmentation cells write into ./data/augmented; write_jpeg / save_image do not
# create missing parent directories, so create it up front (no-op if it exists).
os.makedirs('./data/augmented', exist_ok=True)

# Seed torch's global RNG: the v2 random transforms, CutMix/MixUp and the shuffling
# DataLoader all draw from it, so a Restart-and-Run-All reproduces the same files.
torch.manual_seed(0)

# Random Crop

In [3]:
# Crop parameters are sampled on every call, so a single transform instance
# yields a different 256x256 crop for each of the ten output files.
crop = v2.RandomCrop(size=256)
for k in range(10):
    write_jpeg(crop(image), './data/augmented/AKU01_cropped%s.jpg' % k)

# Random Erasing

In [4]:
# p=1 forces an erased rectangle on every call; its position/size is re-sampled each time.
eraser = v2.RandomErasing(p=1)
for k in range(10):
    write_jpeg(eraser(image), './data/augmented/AKU01_erased%s.jpg' % k)

# Random Perspective

In [5]:
# p=1 applies a (freshly sampled) perspective warp to every output image.
warp = v2.RandomPerspective(p=1)
for k in range(10):
    write_jpeg(warp(image), './data/augmented/AKU01_perspective%s.jpg' % k)

# Random Affine

In [6]:
# Rotation up to +/-70 deg, scaling 0.2x-2x, translation up to 20%/70% of width/height,
# shear up to 5 deg -- re-sampled for each of the ten outputs.
affine = v2.RandomAffine(degrees=70, scale=[0.2, 2], translate=[0.2, 0.7], shear=5)
for k in range(10):
    write_jpeg(affine(image), './data/augmented/AKU01_affine%s.jpg' % k)

# Elastic Transform

In [7]:
# Elastic deformation: alpha controls displacement magnitude, sigma its smoothness.
elastic = v2.ElasticTransform(alpha=90.0, sigma=9.0)
for k in range(10):
    write_jpeg(elastic(image), './data/augmented/AKU01_elastic%s.jpg' % k)

# Gaussian Noise

In [8]:
def gauss_noise_tensor(img, sig):
    """Add zero-mean Gaussian noise to an image tensor.

    Args:
        img: image tensor. Integer tensors are processed in float32 and converted
            back to their original dtype; floating-point tensors keep their dtype.
        sig: standard deviation of the noise, in the same units as ``img``'s
            values (e.g. 0-255 for a uint8 image).

    Returns:
        A new tensor with the same shape and dtype as ``img``. For integer dtypes
        the noisy values are rounded and clamped to the dtype's representable
        range, so out-of-range values saturate instead of wrapping around.
        Floating-point results are returned unclamped.

    Raises:
        AssertionError: if ``img`` is not a ``torch.Tensor``.
    """
    assert isinstance(img, torch.Tensor), "gauss_noise_tensor expects a torch.Tensor"
    orig_dtype = img.dtype
    work = img if img.is_floating_point() else img.to(torch.float32)

    noisy = work + sig * torch.randn_like(work)

    if noisy.dtype != orig_dtype:
        # Casting out-of-range floats to an integer dtype does not saturate in torch
        # (negative values typically wrap to large ones, e.g. dark pixels turning
        # white), so clamp first. Round instead of truncating so the noise is not
        # biased downward.
        if orig_dtype != torch.bool:
            limits = torch.iinfo(orig_dtype)
            noisy = noisy.round().clamp(limits.min, limits.max)
        noisy = noisy.to(orig_dtype)
    return noisy

# Ten independently-noised copies of the source image (noise std = 10, in the
# image's own intensity units).
for k in range(10):
    noisy = gauss_noise_tensor(image, 10)
    write_jpeg(noisy, './data/augmented/AKU01_gaussian%s.jpg' % k)

# Composed Transformations

In [9]:
# Chain every augmentation above (crop -> erase -> perspective -> affine -> elastic);
# each stage re-samples its own random parameters on every call.
pipeline = v2.Compose([
    v2.RandomCrop(size=256),
    v2.RandomErasing(p=1),
    v2.RandomPerspective(p=1),
    v2.RandomAffine(degrees=70, scale=[0.2, 2], translate=[0.2, 0.7], shear=5),
    v2.ElasticTransform(alpha=90.0, sigma=9.0),
])
for k in range(10):
    write_jpeg(pipeline(image), './data/augmented/AKU01_composed%s.jpg' % k)

# CutMix & MixUp

In [10]:
class CustomImageDataset(Dataset):
    """Map-style dataset of images listed in an annotations CSV.

    Column 0 of the CSV holds image file names (joined onto ``img_dir``);
    column 1 holds the corresponding label.

    Args:
        annotations_file: anything ``pd.read_csv`` accepts (path or buffer).
        img_dir: directory the file names are relative to.
        transform: optional callable applied to the decoded image.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # One sample per CSV row.
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``, with optional transforms applied."""
        file_name = self.img_labels.iloc[idx, 0]
        sample = read_image(os.path.join(self.img_dir, file_name))
        target = self.img_labels.iloc[idx, 1]
        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)
        return sample, target
    
# Per-sample preprocessing for the CutMix/MixUp dataset.
# - No PIL conversion step: read_image already returns a uint8 CHW tensor, and
#   v2.PILToTensor would simply pass a tensor through unchanged.
# - RandomCrop(400) gives every sample the same spatial size so the default
#   collate can stack a batch; it raises if an image is smaller than 400 px on
#   either side (pass pad_if_needed=True to pad instead).
# - ToDtype(float32, scale=True) rescales uint8 [0, 255] to float [0, 1], the range
#   save_image expects when the mixed images are written out later.
# NOTE(review): batching also requires every image to have the same channel count;
# grayscale or RGBA files in ./data/original/ would break collation -- confirm all are RGB.
prepoc = v2.Compose([v2.RandomCrop(size=400), v2.ToDtype(torch.float32, scale=True)])

aku_dataset = CustomImageDataset('./labels.csv', './data/original/', transform=prepoc)
dataloader = DataLoader(aku_dataset, batch_size=4, shuffle=True)

In [13]:
# Labels in labels.csv must be integers in [0, NUM_CLASSES); CutMix/MixUp one-hot
# encode them.
NUM_CLASSES = 2

cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)

# CutMix and MixUp combine samples *within* a batch, so they run on the batches the
# DataLoader yields rather than inside the per-sample transform. Each mixer makes
# its own full pass over the dataset and numbers its output files from 0.
# NOTE(review): the mixed soft labels returned alongside the images are discarded;
# persist them too if the saved images are meant to be used for training.
for mix_name, mixer in (('cutmix', cutmix), ('mixup', mixup)):
    x = 0
    for images, labels in dataloader:
        mixed_images, _mixed_labels = mixer(images, labels)
        for single_image in mixed_images:
            save_image(single_image, './data/augmented/{}{}.jpg'.format(mix_name, x))
            x += 1