In [16]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision.transforms import functional as TF
# Image transformations
from torchvision import transforms
from PIL import Image

In [2]:
# Device passed to the augmenter; everything in this notebook runs on the CPU.
device = torch.device("cpu")

# Seed torch's global RNG so the random crop sizes/offsets and augmentations
# (and therefore the saved PNGs) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

In [3]:
class Agumenter:
    """Cut random square crops out of an image batch and augment each one.

    Calling an instance on a tensor assumed to be (B, C, H, W) returns the
    ``cutn`` crops concatenated along dim 0, i.e. ``cutn * B`` images of size
    ``cut_size`` x ``cut_size`` (assuming ``self.augs`` preserves spatial size).
    The class name misspells "Augmenter"; it is kept so existing references
    keep working.
    """

    def __init__(self, device, cutn=10, cut_size=512, cut_pow=1.):
        """Build the resize / augmentation / normalisation modules.

        Args:
            device: passed to ``.to()`` on every module built here. The tensor
                given to ``__call__`` is NOT moved; callers place it themselves.
            cutn: number of crops produced per call (must be >= 1).
            cut_size: side length each square crop is resized to; also the
                smallest crop side drawn when the image is larger than it.
            cut_pow: exponent on the uniform draw that picks the crop side;
                > 1 biases towards small crops, < 1 towards large ones.

        Raises:
            ValueError: if ``cutn`` < 1 (``__call__`` would otherwise fail in
                ``torch.cat`` on an empty list).
        """
        if cutn < 1:
            raise ValueError(f"cutn must be >= 1, got {cutn}")
        self.cut_pow = cut_pow
        self.cutn = cutn
        self.cut_size = cut_size

        # nn.Sequential (rather than transforms.Compose, which is not an
        # nn.Module) keeps the pipeline a module, so it supports .to(device).
        # NOTE(review): the intent is to backpropagate through these
        # augmentations — confirm each transform is differentiable for the
        # input dtype actually used.
        self.augs = nn.Sequential(
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=30),
            transforms.RandomPerspective(0.2, p=0.4),
            transforms.ColorJitter(hue=0.01, saturation=0.01),
        ).to(device)

        # NOTE(review): never applied by __call__. These look like CLIP's image
        # mean/std — apply self.norm yourself before feeding such a model.
        self.norm = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                         (0.26862954, 0.26130258, 0.27577711)).to(device)

        # An int size resizes the shorter edge; crops are square, so the
        # result is cut_size x cut_size.
        self.resize = transforms.Resize(cut_size).to(device)

    def __call__(self, input):
        """Return ``cutn`` independently augmented random crops of ``input``.

        Args:
            input: image tensor indexed as (B, C, H, W). TODO: confirm callers
                never pass an unbatched (C, H, W) image.

        Returns:
            Tensor of the ``cutn`` augmented crops concatenated along dim 0.
        """
        side_y, side_x = input.shape[2:4]
        max_size = min(side_x, side_y)
        min_size = min(side_x, side_y, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Crop side between min_size and max_size; cut_pow skews the draw.
            size = int(torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size)
            offset_x = torch.randint(0, side_x - size + 1, ())
            offset_y = torch.randint(0, side_y - size + 1, ())
            cutout = input[:, :, offset_y:offset_y + size, offset_x:offset_x + size]
            # Augment each crop on its own. torchvision's random transforms draw
            # one set of random parameters per call and apply it to every image
            # in a batched tensor. Augmenting the concatenated batch would
            # therefore give all crops the same flip / rotation / perspective /
            # jitter.
            cutouts.append(self.augs(self.resize(cutout)))
        return torch.cat(cutouts, dim=0)

In [4]:
# Default settings: 10 crops per call, each resized to 512 x 512.
agumenter = Agumenter(device=device)

In [8]:
from pathlib import Path

# Source image for the augmentation demo.
# NOTE(review): absolute path into a personal home directory — point this at
# your own copy of the image before running.
IMAGE_PATH = Path("/home/krzys/monarch.jpg")

# convert("RGB") guarantees exactly 3 channels. The pipeline assumes RGB: the
# augmenter's Normalize has 3-entry mean/std, and a grayscale or CMYK file would
# otherwise yield a different channel count. The with-block also releases the
# file handle that the lazy Image.open keeps.
with Image.open(IMAGE_PATH) as src_image:
    img = src_image.convert("RGB")

In [12]:
# PIL image -> (C, H, W) tensor (values are not rescaled), then prepend a
# batch dimension so the augmenter receives (1, C, H, W).
tensor = TF.pil_to_tensor(img)[None]

In [13]:
# Display the batched image shape: (batch, channels, height, width).
tensor.size()

torch.Size([1, 3, 1536, 1920])

In [32]:
# A single call produces all augmented crops, stacked along dim 0.
aug = agumenter(input=tensor)

In [33]:
# Save every augmented crop as <index>.png.
# Iterating over `aug` instead of a hard-coded range(10) keeps this in sync with
# the augmenter's cutn. Using a separate loop name avoids clobbering `img`, the
# source image loaded earlier.
for i, cutout in enumerate(aug):
    TF.to_pil_image(cutout).save(f"{i}.png")