# Data Augmentation using PyTorch
Geometric and photometric (color, noise, erasing) transformations

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import glob
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import numpy as np
from torchvision.utils import save_image

In [None]:
# Class names; each one is also the name of a sub-directory of `directory`
# (input) and of `output_directory` (augmented output).
classes = [
    'Araneae',
    'Coleoptera',
    'Diptera',
    'Hemiptera',
    'Hymenoptera',
    'Lepidoptera',
    'Odonata',
]

# Root folder of the source images and root folder for the generated ones.
directory = "ArTaxOr_data_light/"
output_directory = "Augmentation_data/"

# Number of images per DataLoader batch.
batch_size = 8

In [None]:
class MyDataset(Dataset):
    """Map-style dataset that loads image files as RGB pixel arrays.

    Args:
        image_list: sequence of image file paths.
        transforms: optional callable applied to each image, given as an
            ``(H, W, 3)`` ``uint8`` numpy array (e.g. a ``transforms.Compose``
            starting with ``ToPILImage``).

    ``__getitem__`` returns a ``torch.float`` tensor: the transform's output
    when ``transforms`` is given, otherwise the raw ``(H, W, 3)`` pixel values
    in [0, 255].
    """

    def __init__(self, image_list, transforms=None):
        self.image_list = image_list
        self.transforms = transforms

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, i):
        # Read with PIL directly: plt.imread returns float arrays for PNGs,
        # which Image.fromarray cannot handle. The context manager closes the
        # file handle once the pixels have been converted.
        with Image.open(self.image_list[i]) as im:
            img = np.array(im.convert('RGB'), dtype=np.uint8)

        if self.transforms is not None:
            img = self.transforms(img)
        # as_tensor avoids the copy-construct warning (and a redundant copy)
        # that torch.tensor() emits when the transform already returned a tensor.
        return torch.as_tensor(img, dtype=torch.float)

## Image Rotation

In [None]:
def rotation(img_list, batch):
    """Return a shuffled DataLoader yielding randomly rotated images.

    Each image is resized to 244x244 and rotated by a random angle in
    [-50, 50] degrees with ``expand=True`` (the canvas grows so no corner is
    cut off). It is then resized back to 244x244 and converted to a tensor.

    Args:
        img_list: paths of the images to load.
        batch: number of images per batch.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((244, 244)),
        # expand=True enlarges the output, hence the second Resize below.
        transforms.RandomRotation(50, expand=True),
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
    ]
    dataset = MyDataset(img_list, transforms.Compose(steps))
    return DataLoader(dataset, batch_size=batch, shuffle=True)

## Image Cropping

In [None]:
def cropping(img_list, batch):
    """Return a shuffled DataLoader yielding random 600x600 crops.

    Each image gets a random 600x600 crop, which is resized to 244x244 and
    converted to a tensor.

    Args:
        img_list: paths of the images to load.
        batch: number of images per batch.
    """
    transform = transforms.Compose([
        transforms.ToPILImage(),
        # Without pad_if_needed, any image smaller than 600 px on either side
        # makes RandomCrop raise ValueError and the whole batch is lost.
        # With it, such images are zero-padded up to the crop size instead.
        transforms.RandomCrop((600, 600), pad_if_needed=True),
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
    ])
    return DataLoader(MyDataset(img_list, transform), batch_size=batch, shuffle=True)

## Image Flipping 

In [None]:
def flip(img_list, batch):
    """Return a shuffled DataLoader yielding randomly flipped images.

    Images are resized to 244x244. They are then flipped vertically and
    horizontally, each independently with probability 0.4, so roughly
    0.6 * 0.6 = 36% of the outputs are not flipped at all.

    Args:
        img_list: paths of the images to load.
        batch: number of images per batch.
    """
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((244, 244)),
        transforms.RandomVerticalFlip(p=0.4),
        transforms.RandomHorizontalFlip(p=0.4),
        transforms.ToTensor(),
    ])
    dataset = MyDataset(img_list, transforms=pipeline)
    return DataLoader(dataset, batch_size=batch, shuffle=True)

## Brightness, contrast, saturation, hue

In [None]:
def hue(img_list, batch):
    """Return a shuffled DataLoader yielding color-jittered images.

    Images are resized to 244x244 and passed through ``ColorJitter``. The
    brightness factor is drawn from [0.9, 1.1] and the contrast factor from
    [0.8, 1.2]. Despite the function's name, saturation and hue jitter are
    both disabled (0).

    Args:
        img_list: paths of the images to load.
        batch: number of images per batch.
    """
    jitter = transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0, hue=0)
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((244, 244)),
        jitter,
        transforms.ToTensor(),
    ])
    return DataLoader(MyDataset(img_list, pipeline), batch_size=batch, shuffle=True)

## Gaussian Noise

In [None]:
class AddGaussianNoise:
    """Tensor transform that adds element-wise Gaussian noise.

    Computes ``tensor + N(0, 1) * std + mean``. The result is not clamped,
    so values can leave the input's original range.

    Args:
        mean: constant offset added to every element.
        std: standard deviation of the noise.
    """

    def __init__(self, mean: float = 0., std: float = 1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        # Draw the noise on the input's device and, for floating inputs, in
        # its dtype. This lets CUDA and half-precision tensors work too.
        # Integer inputs still get default-dtype noise and are promoted to
        # float, as before.
        dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
        noise = torch.randn(tensor.shape, dtype=dtype, device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self) -> str:
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
def gauss(img_list, batch):
    """Return a shuffled DataLoader yielding images with additive Gaussian noise.

    Images are resized to 244x244 and converted to tensors. Noise with
    mean 0.1 and standard deviation 0.08 is then added by
    ``AddGaussianNoise``; the noisy values are not clamped.

    Args:
        img_list: paths of the images to load.
        batch: number of images per batch.
    """
    noise = AddGaussianNoise(mean=0.1, std=0.08)
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
        noise,  # must come after ToTensor: it operates on tensors
    ])
    dataset = MyDataset(img_list, transforms=pipeline)
    return DataLoader(dataset, batch_size=batch, shuffle=True)

## Random Erasing

In [None]:
def erase(img_list, batch):
    """Return a shuffled DataLoader yielding images with random erasing.

    Images are resized to 244x244 and converted to tensors, then passed
    through ``RandomErasing`` with its default settings. Those defaults
    erase a random rectangle with probability 0.5.

    Args:
        img_list: paths of the images to load.
        batch: number of images per batch.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((244, 244)),
        transforms.ToTensor(),
        transforms.RandomErasing(),  # tensor-only transform, so it goes last
    ]
    dataset = MyDataset(img_list, transforms.Compose(steps))
    return DataLoader(dataset, batch_size=batch, shuffle=True)

## Creating and saving new data
Run the next cell to generate six augmented versions of every image (six times the original volume of data).

In [None]:
# Generate the augmented dataset: every image of every class goes once through
# each of the six augmentation pipelines and is saved as a JPEG under
# output_directory/<class>/img_augm_<k>.jpg.
import os

for c in classes:
    print("Data Augmetation for class %s" % c)
    img_list = glob.glob(os.path.join(directory, c, '*.jpg'))
    # save_image does not create folders, so make sure the destination exists.
    class_output_dir = os.path.join(output_directory, c)
    os.makedirs(class_output_dir, exist_ok=True)
    loaders = [rotation(img_list, batch_size),
               cropping(img_list, batch_size),
               flip(img_list, batch_size),
               hue(img_list, batch_size),
               gauss(img_list, batch_size),
               erase(img_list, batch_size)]
    k = 0  # running index that keeps file names unique within the class
    for loader in loaders:
        batches = iter(loader)
        # len(loader) includes the final partial batch, so no image is dropped.
        # The loop is bounded so that a failing batch can simply be skipped.
        for _ in range(len(loader)):
            try:
                # DataLoader iterators have no .next() method; use next().
                images_batch = next(batches)
            except StopIteration:
                break
            except Exception as err:
                # Best effort: skip a batch that fails to load or transform,
                # but report it instead of hiding it.
                print("Skipping a batch for class %s: %r" % (c, err))
                continue
            for img1 in images_batch:
                save_image(img1, os.path.join(class_output_dir, 'img_augm_' + str(k) + '.jpg'))
                k += 1