In [None]:
import torch
import torchvision
import kornia as K
from kornia.augmentation import *
from kornia.geometry.transform import translate, scale, shear, rotate
import PIL.Image
import os
import pyspng
import numpy as np
from matplotlib import pyplot as plt
import cv2
from torchvision import transforms
from cleanfid import fid
import random
# Sets KMP_DUPLICATE_LIB_OK for this process (commonly used to silence the
# "duplicate OpenMP runtime" abort). NOTE(review): this runs after torch/numpy
# are already imported above -- confirm it still takes effect in this environment.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Probability forwarded as `p=` to the kornia augmentations built in aug_dict.
p = 1
# Output size handed to RandomResizedCrop in aug_dict.
size = [512, 512]
# Directory of input images; results are written to sibling directories named
# source_path + '_<augmentation name>'.
source_path = '/home/featurize/data/Dawn'

# When True, the RGB-only augmentations are removed from aug_dict and every
# output image is asserted to have a single channel.
grayscale = True

In [None]:
class Translation(object):
    """Randomly shift a single image along x, y, or both axes.

    Args:
        ratio: Largest shift, as a fraction of the image width (for x) and
            height (for y). The actual shift is drawn uniformly from
            [-ratio * extent, +ratio * extent].
        choice: 'x', 'y' or 'both' -- which axes receive a non-zero shift.
    """

    def __init__(self, ratio = 0.25, choice = 'both'):
        self.choice = choice
        self.ratio = ratio

    def __call__(self, tensor_in : torch.tensor):
        """Shift a float32 [C, H, W] tensor and return a tensor of the same rank.

        Raises:
            AssertionError: input is not 3-D float32.
            KeyError: `choice` is not one of 'x', 'y', 'both'.
        """
        assert tensor_in.ndim == 3 # [C, H, W]
        assert tensor_in.dtype == torch.float32

        height, width = tensor_in.shape[1], tensor_in.shape[2]
        limit_x = self.ratio * width
        limit_y = self.ratio * height

        # Both offsets are always drawn (x first, then y) so the RNG stream
        # does not depend on `choice`.
        dx = random.uniform(-limit_x, limit_x)
        dy = random.uniform(-limit_y, limit_y)

        offsets = {
            'x': [dx, 0],
            'y': [0, dy],
            'both': [dx, dy],
        }
        shift = torch.tensor([offsets[self.choice]])  # shape (1, 2), one row per batch item

        batch = tensor_in.unsqueeze(0)  # translate() is called on a 4-D batch
        shifted = translate(batch, shift, padding_mode='reflection')
        return shifted.squeeze(0)


In [None]:
class Zoom(object):
    """Randomly rescale a single image about its centre via kornia's `scale`.

    Note the parameter naming is crossed relative to the branch that uses it:
    `choice='in'` draws scale factors from [1, max_out_ratio] (magnify) and
    `choice='out'` draws them from [min_in_ratio, 1] (shrink).

    Args:
        min_in_ratio: Lower bound of the scale factor used by choice='out'.
        max_out_ratio: Upper bound of the scale factor used by choice='in'.
        choice: 'in' or 'out'.
    """

    def __init__(self, min_in_ratio = 0.8, max_out_ratio = 1.2, choice = 'in'):
        self.choice = choice
        self.max_out_ratio = max_out_ratio
        self.min_in_ratio = min_in_ratio

    def __call__(self, tensor_in : torch.tensor):
        """Rescale a float32 [C, H, W] tensor; x and y factors are independent.

        Raises:
            AssertionError: input is not 3-D float32.
            KeyError: `choice` is neither 'in' nor 'out'.
        """
        assert tensor_in.ndim == 3 # [C, H, W]
        assert tensor_in.dtype == torch.float32

        # Four draws happen on every call (magnify x, y, then shrink x, y),
        # regardless of which pair ends up being used.
        magnify = [random.uniform(1, self.max_out_ratio) for _ in range(2)]
        shrink = [random.uniform(self.min_in_ratio, 1) for _ in range(2)]

        candidates = {'in': magnify, 'out': shrink}
        factors = torch.tensor([candidates[self.choice]])  # shape (1, 2)

        batch = tensor_in.unsqueeze(0)
        zoomed = scale(batch, factors, padding_mode='reflection')
        return zoomed.squeeze(0)


In [None]:
class Shear(object):
    """Randomly shear a single image via kornia's `shear`.

    Args:
        min_angle: Lower bound of the sampled shear factor.
        max_angle: Upper bound of the sampled shear factor.

    Despite the names, the two bounds are forwarded unchanged as the shear
    factors; with the defaults both factors are always positive.
    """

    def __init__(self, min_angle = 0.15, max_angle = 0.2):
        self.max_angle = max_angle
        self.min_angle = min_angle

    def __call__(self, tensor_in : torch.tensor):
        """Shear a float32 [C, H, W] tensor; x and y factors are drawn independently.

        Raises:
            AssertionError: input is not 3-D float32.
        """
        assert tensor_in.ndim == 3 # [C, H, W]
        assert tensor_in.dtype == torch.float32

        # Draw order is x first, then y.
        amounts = [random.uniform(self.min_angle, self.max_angle) for _ in range(2)]
        factors = torch.tensor([amounts])  # shape (1, 2)

        batch = tensor_in.unsqueeze(0)
        sheared = shear(batch, factors, padding_mode='reflection')
        return sheared.squeeze(0)


In [None]:

class Rotate(object):
    """Rotate a single image via kornia's `rotate`.

    Args:
        angle: Fixed rotation angle. When given (not None) every call rotates
            by exactly this amount. When None, the angle is sampled per call.
        min_angle: Lower bound of the sampled angle (used only if angle is None).
        max_angle: Upper bound of the sampled angle (used only if angle is None).

    Angles are forwarded unchanged to kornia's `rotate`, which documents them
    as degrees.
    """

    def __init__(self, angle = None, min_angle = -15, max_angle =15):
        self.angle = angle
        self.min_angle = min_angle
        self.max_angle = max_angle
        self.i = 0  # not used by this class; kept so existing attribute access still works

    def __call__(self, tensor_in : torch.tensor):
        """Rotate a float32 [C, H, W] tensor and return a tensor of the same rank.

        Raises:
            AssertionError: input is not 3-D float32.
        """
        assert tensor_in.ndim == 3 # [C, H, W]
        assert tensor_in.dtype == torch.float32

        # Honour an explicitly configured angle; previously it was always
        # overwritten by a random sample, so Rotate(angle=90.0) never rotated
        # by 90. The chosen value is kept local so the configuration in
        # self.angle is not clobbered between calls.
        if self.angle is None:
            chosen = random.uniform(self.min_angle, self.max_angle)
        else:
            chosen = float(self.angle)

        # One angle per batch item: shape (1,).
        angle = torch.tensor([chosen], dtype=tensor_in.dtype)

        tensor_out = rotate(tensor_in.unsqueeze(0), angle, padding_mode='reflection').squeeze(0)

        return tensor_out


In [None]:
def aug_img(img_in : np.ndarray, aug):
    """Apply one augmentation, or a pipeline of augmentations, to a uint8 image.

    Args:
        img_in: uint8 array, [H, W] for grayscale or [H, W, C] for colour
            (C must be 1 or 3), values in [0, 255].
        aug: A callable taking and returning a float32 [C, H, W] tensor, or a
            list/tuple of such callables that are applied in order, each one
            receiving the previous one's output.

    Returns:
        uint8 ndarray laid out as [C, H, W], values in [0, 255], produced by
        `tensor_transform_reverse`.

    Raises:
        AssertionError: wrong dtype, rank, or channel count.
    """
    if img_in.ndim == 2:
        img_in = img_in[:, :, np.newaxis] # HW => HWC
    img_in = img_in.transpose(2, 0, 1) # HWC => CHW
    assert img_in.ndim == 3 # [C, H, W]
    assert img_in.dtype == np.uint8

    tensor_in = torch.tensor(img_in / 255).to(torch.float32)  # [0,255] -> [0,1]
    assert tensor_in.shape[0] == 1 or tensor_in.shape[0] == 3 # tensor_in [C, H, W]

    if isinstance(aug, (list, tuple)):
        # Chain the steps: each one must see the previous result. Feeding
        # every step the original tensor would discard all but the last.
        # An empty pipeline leaves the image unchanged.
        tensor_out = tensor_in
        for aug_ in aug:
            tensor_out = aug_(tensor_out)
    else:
        tensor_out = aug(tensor_in)

    img_out = tensor_transform_reverse(tensor_out) # ndarray [C, H, W] range[0, 255]
    return img_out

In [None]:
def tensor_transform_reverse(image_tensor):
    """Convert a [C, H, W] float tensor in [0, 1] to a uint8 ndarray in [0, 255].

    Values outside [0, 1] are clamped first: without this they wrap around on
    the uint8 cast and show up as speckles. The scaled values are rounded to
    the nearest integer rather than truncated, so a uint8 -> float -> uint8
    round trip returns the original pixel values.

    Args:
        image_tensor: 3-D tensor laid out as [C, H, W]. It is not modified.

    Returns:
        uint8 ndarray with the same [C, H, W] shape.

    Raises:
        AssertionError: input is not 3-D.
    """
    assert len(image_tensor.shape) == 3

    # detach()/cpu() keep .numpy() valid for tensors that track gradients or
    # live on an accelerator; both are no-ops for plain CPU tensors.
    scaled = image_tensor.detach().cpu().clamp(0.0, 1.0) * 255
    image_np = scaled.round().numpy().astype(np.uint8)
    return image_np

In [None]:
# Registry of augmentations: name -> callable, or a list of callables that is
# meant to be applied as an ordered pipeline. The name is also used as the
# suffix of the output directory. Every kornia augmentation receives p=p so it
# fires with the configured probability, and keepdim=True so a [C, H, W] input
# comes back as [C, H, W].
aug_dict = {
    'Translation_x':Translation(choice = 'x'),
    'Translation_y':Translation(choice = 'y'),
    'Translation_both':Translation(choice = 'both'),
    'Zoom_in':Zoom(choice = 'in', min_in_ratio = 0.80, max_out_ratio = 1.20),
    'Zoom_out':Zoom(choice = 'out'),
    'Shear':Shear(),
    'Rotate':Rotate(),
    'FlipRot_Horizontal':[RandomHorizontalFlip(p = p, keepdim = True),Rotate(angle = 90.0)],
    'FlipRot_Vertical':[RandomVerticalFlip(p = p, keepdim = True),Rotate(angle = 90.0)],
    'RandomPlanckianJitter': RandomPlanckianJitter(mode='CIED',p = p, keepdim = True),
    'RandomPlasmaShadow': RandomPlasmaShadow(roughness=(0.1, 0.7),p =p, keepdim = True),
    'RandomPlasmaBrightness': RandomPlasmaBrightness(roughness=(0.1, 0.7), p=p, keepdim = True),
    'RandomPlasmaContrast':RandomPlasmaContrast(roughness=(0.1, 0.7), p=p, keepdim = True),
    'ColorJiggle': ColorJiggle(0.1, 0.1, 0.1, 0.1, p=p, keepdim = True),
    # p=p was missing here, so the blur fell back to the library default
    # probability and only part of this output set was actually blurred.
    'RandomBoxBlur':RandomBoxBlur((7,7), p=p, keepdim = True),
    'RandomChannelShuffle': RandomChannelShuffle(p = p, keepdim = True),
    'RandomGaussianBlur':RandomGaussianBlur((3, 3), (0.1, 2.0), p=p, keepdim = True),
    'RandomGaussianNoise':RandomGaussianNoise(mean=0., std=0.015, p=p, keepdim = True),
    'RandomMotionBlur':RandomMotionBlur(3, 35., 0.5, p=p, keepdim = True),
    'RandomPosterize': RandomPosterize(3, p=p, keepdim = True),
    'RandomRGBShift':RandomRGBShift(p=p, keepdim = True),
    'RandomSharpness':RandomSharpness(1, p = p, keepdim = True),
    'RandomSolarize':RandomSolarize(0.1, 0.1, p=p, keepdim = True),
    'RandomAffine':RandomAffine((-15,15),padding_mode = 'reflection', p=p, keepdim = True),
    'RandomElasticTransform': RandomElasticTransform(p = p, keepdim = True),
    'HorizontalFlip':RandomHorizontalFlip(p = p, keepdim = True),
    'VerticalFlip': RandomVerticalFlip(p = p, keepdim = True),
    'RandomInvert':RandomInvert(p =p, keepdim = True),    
    'RandomResizedCrop':RandomResizedCrop(size = size ,p =p, keepdim = True),
    'RandomThinPlateSpline':RandomThinPlateSpline(p =p, keepdim = True),
}

# These entries are dropped for single-channel data.
# NOTE(review): presumably because they require 3-channel input -- confirm
# whether the remaining plasma/colour-space entries are fine on grayscale.
if grayscale:
    aug_dict.pop('RandomPlanckianJitter')
    aug_dict.pop('RandomPlasmaShadow')
    aug_dict.pop('ColorJiggle')
    aug_dict.pop('RandomChannelShuffle')
    aug_dict.pop('RandomRGBShift')

keys = list(aug_dict.keys())
len(keys)

In [None]:
# For every registered augmentation, write an augmented copy of each source
# image into a sibling directory named source_path + '_<augmentation name>',
# keeping the original file names.
for key in keys:
    print("Execute Data Augmentation {}".format(key))
    aug = aug_dict[key]
    dest_path = source_path + '_{}'.format(key)

    # exist_ok makes re-runs safe without a separate existence check.
    os.makedirs(dest_path, exist_ok=True)

    # Sorted so the processing order does not depend on the filesystem.
    for img_name in sorted(os.listdir(source_path)):
        src_file = os.path.join(source_path, img_name)
        # Skip sub-directories (e.g. .ipynb_checkpoints) instead of crashing
        # when trying to open them as images.
        if not os.path.isfile(src_file):
            continue

        # Context manager releases the file handle; np.array() copies the
        # pixel data out before the file is closed.
        with PIL.Image.open(src_file) as img:
            image_in = np.array(img)
        image_out = aug_img(image_in, aug)

        # [C, H, W] -> [H, W] for grayscale or [H, W, C] for rgb
        if grayscale:
            assert image_out.shape[0] == 1
            image_out = image_out.transpose(1,2,0).squeeze(-1)
        else:
            assert image_out.shape[0] == 3
            image_out = image_out.transpose(1,2,0)

        image_out = PIL.Image.fromarray(image_out)

        # Output keeps the source file name; the augmentation is identified
        # by the destination directory.
        new_img_name = img_name
        image_out.save(os.path.join(dest_path, new_img_name))

print('Completed!')