<a href="https://colab.research.google.com/github/AlexHeyman/FewShotGANTraining/blob/main/Differential_Data_Augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [13]:
def rand_brightness(a):
    """Add a random brightness offset to each sample of a batch.

    One scalar offset is drawn per batch element, uniformly in [-0.5, 0.5),
    and broadcast over the remaining dimensions of that element. Only
    element-wise arithmetic is used, so gradients flow back to ``a``.

    Args:
        a: Floating-point batch whose first dimension is the batch dimension
            (assumed (N, C, H, W) to match the (N, 1, 1, 1) offset shape).

    Returns:
        A new tensor with the same shape as ``a``.
    """
    n = a.size(0)
    noise = torch.rand(n, 1, 1, 1, dtype=a.dtype, device=a.device)
    return a + (noise - 0.5)

In [14]:
def rand_saturation(a):
    """Randomly rescale each pixel's deviation from its mean across dim 1.

    The mean over dim 1 (the channel axis, assuming (N, C, H, W) input) is
    held fixed. The deviation from that mean is multiplied by one factor
    per batch element, drawn uniformly in [0, 2). A factor of 0 collapses
    every channel onto the mean; factors above 1 exaggerate channel
    differences.

    Args:
        a: Floating-point batch; dim 0 is the batch dimension.

    Returns:
        A new tensor with the same shape as ``a``.
    """
    per_pixel_mean = a.mean(dim=1, keepdim=True)
    deviation = a - per_pixel_mean
    scale = torch.rand(a.size(0), 1, 1, 1, dtype=a.dtype, device=a.device).mul(2)
    return deviation * scale + per_pixel_mean

In [15]:
def rand_contrast(a):
    """Randomly rescale each sample's deviation from its own global mean.

    For every batch element, the mean over dims 1, 2 and 3 is held fixed.
    All values are pushed toward or away from it by one factor drawn
    uniformly in [0.5, 1.5).

    Args:
        a: Floating-point 4-D batch; dim 0 is the batch dimension.

    Returns:
        A new tensor with the same shape as ``a``.
    """
    sample_mean = a.mean(dim=[1, 2, 3], keepdim=True)
    deviation = a - sample_mean
    scale = torch.rand(a.size(0), 1, 1, 1, dtype=a.dtype, device=a.device).add(0.5)
    return deviation * scale + sample_mean

In [16]:
def rand_translation(a, ratio=0.125):
    """Randomly translate each image in a batch, filling exposed pixels with zeros.

    Each batch element gets its own integer shift along dim 2 and dim 3.
    Each shift is drawn uniformly from ``[-shift, shift]``, where ``shift``
    is ``size * ratio`` rounded half-up for that dimension. Pixels shifted
    in from outside the image are zero. Only pad, permute and indexing ops
    are used, so gradients flow back to ``a``.

    Args:
        a: Image batch indexed as (batch, channels, height, width); the
            permute below treats dim 1 as the channel axis.
        ratio: Maximum shift as a fraction of each spatial size.

    Returns:
        A tensor with the same shape as ``a``.
    """
    batch, height, width = a.size(0), a.size(2), a.size(3)
    shift_a, shift_b = int(height * ratio + 0.5), int(width * ratio + 0.5)
    translation_a = torch.randint(-shift_a, shift_a + 1, size=[batch, 1, 1], device=a.device)
    translation_b = torch.randint(-shift_b, shift_b + 1, size=[batch, 1, 1], device=a.device)
    # Broadcastable index tensors of shape (B,1,1), (B,H,1) and (B,1,W).
    # Advanced indexing broadcasts them to the full (B,H,W) grid. This is the
    # same grid torch.meshgrid built, but without meshgrid's `indexing=`
    # deprecation warning and without materializing three full grids.
    grid_batch = torch.arange(batch, dtype=torch.long, device=a.device).view(-1, 1, 1)
    grid_a = torch.arange(height, dtype=torch.long, device=a.device).view(1, -1, 1)
    grid_b = torch.arange(width, dtype=torch.long, device=a.device).view(1, 1, -1)
    # +1 accounts for the one-pixel zero border added below. Anything shifted
    # further out is clamped onto that border (index 0 or size + 1), giving zeros.
    grid_a = torch.clamp(grid_a + translation_a + 1, 0, height + 1)
    grid_b = torch.clamp(grid_b + translation_b + 1, 0, width + 1)
    # Pad dims 3 and 2 by one pixel on each side; channel and batch dims are untouched.
    a_pad = F.pad(a, [1, 1, 1, 1, 0, 0, 0, 0])
    # Move channels last so each (b, i, j) index selects a whole pixel, then restore layout.
    a = a_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_a, grid_b].permute(0, 3, 1, 2)
    return a

In [17]:
def rand_cutout(a, ratio=0.5):
    """Zero out one random rectangle per image (cutout).

    The rectangle spans ``height * ratio`` rows by ``width * ratio`` columns,
    each rounded half-up. It is centred at an offset chosen independently
    per batch element and clipped to the image bounds. The zeroing is a
    multiplicative mask, so gradients still flow to the pixels that are kept.

    Args:
        a: Image batch indexed as (batch, channels, height, width); the
            mask is built over dims 0, 2, 3 and broadcast across dim 1.
        ratio: Cutout side length as a fraction of each spatial size.

    Returns:
        A tensor with the same shape as ``a``.
    """
    batch, height, width = a.size(0), a.size(2), a.size(3)
    cutout_size = int(height * ratio + 0.5), int(width * ratio + 0.5)
    # The offset range is one wider for even cutout sizes. This makes the
    # extreme placements at the two opposite edges cover equal areas.
    offset_a = torch.randint(0, height + (1 - cutout_size[0] % 2), size=[batch, 1, 1], device=a.device)
    offset_b = torch.randint(0, width + (1 - cutout_size[1] % 2), size=[batch, 1, 1], device=a.device)
    # Broadcastable index tensors of shape (B,1,1), (B,h,1) and (B,1,w).
    # Advanced indexing broadcasts them to (B,h,w), replacing the
    # torch.meshgrid call and its `indexing=` deprecation warning.
    grid_batch = torch.arange(batch, dtype=torch.long, device=a.device).view(-1, 1, 1)
    grid_a = torch.arange(cutout_size[0], dtype=torch.long, device=a.device).view(1, -1, 1)
    grid_b = torch.arange(cutout_size[1], dtype=torch.long, device=a.device).view(1, 1, -1)
    # Clamping folds out-of-image rows/columns onto the border, which clips the rectangle.
    grid_a = torch.clamp(grid_a + offset_a - cutout_size[0] // 2, min=0, max=height - 1)
    grid_b = torch.clamp(grid_b + offset_b - cutout_size[1] // 2, min=0, max=width - 1)
    mask = torch.ones(batch, height, width, dtype=a.dtype, device=a.device)
    mask[grid_batch, grid_a, grid_b] = 0
    a = a * mask.unsqueeze(1)
    return a

In [20]:
class Differential_Augmentation(nn.Module):
    """Apply a chain of differentiable augmentation functions to an image batch.

    Each augmentation is a callable taking and returning a tensor. They are
    applied in order, and the final result is made contiguous. With no
    arguments, the chain is brightness, saturation, contrast, translation,
    then cutout.

    Args:
        augmentations: Optional sequence of ``tensor -> tensor`` callables to
            use instead of the default chain. They are stored in a plain
            Python list, so any ``nn.Module`` passed here is NOT registered
            as a submodule.
    """

    def __init__(self, augmentations=None):
        super().__init__()
        if augmentations is None:
            augmentations = (
                rand_brightness,
                rand_saturation,
                rand_contrast,
                rand_translation,
                rand_cutout,
            )
        # Plain list of callables: nothing here is a parameter or submodule.
        self._augment = list(augmentations)

    # Excluded from TorchScript compilation; the list of Python callables is
    # iterated in Python.
    @torch.jit.ignore
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Apply every augmentation in order and return a contiguous tensor."""
        for augment in self._augment:
            images = augment(images)
        return images.contiguous()