In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
import tarfile
from math import sqrt
import torchvision
import torchvision.models as models
from torch.autograd import Variable
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, EMNIST, CIFAR10,CIFAR100
from PIL import Image
import torch.optim as optim
import numpy as np
from sklearn.metrics import confusion_matrix
# from resnetcifar import ResNet18_cifar10, ResNet50_cifar10
import logging
import os
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
from tensorboardX import SummaryWriter
# writer = SummaryWriter("logs")

In [2]:
def partition_data(dataset, datadir, logdir, partition, n_parties, beta=0.4):
    if dataset == 'cifar10':
        X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
    elif dataset == 'cifar100':
        X_train, y_train, X_test, y_test = load_cifar100_data(datadir)
    elif dataset == 'tinyimagenet':
        X_train, y_train, X_test, y_test = load_tinyimagenet_data(datadir)

    n_train = y_train.shape[0]

    if partition == "homo" or partition == "iid":
        idxs = np.random.permutation(n_train)
        batch_idxs = np.array_split(idxs, n_parties)
        net_dataidx_map = {i: batch_idxs[i] for i in range(n_parties)}


    elif partition == "noniid-labeldir" or partition == "noniid":
        min_size = 0
        min_require_size = 10
        K = 10
        if dataset == 'cifar100':
            K = 100
        elif dataset == 'tinyimagenet':
            K = 200
            # min_require_size = 100

        N = y_train.shape[0]
        net_dataidx_map = {}

        while min_size < min_require_size:
            idx_batch = [[] for _ in range(n_parties)]
            for k in range(K):
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(beta, n_parties))
                proportions = np.array([p * (len(idx_j) < N / n_parties) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
                min_size = min([len(idx_j) for idx_j in idx_batch])
                # if K == 2 and n_parties <= 10:
                #     if np.min(proportions) < 200:
                #         min_size = 0
                #         break

        for j in range(n_parties):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]

    traindata_cls_counts = record_net_data_stats(y_train, net_dataidx_map, logdir)
    return (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts)

def load_cifar10_data(datadir):
    transform = transforms.Compose([transforms.ToTensor()])

    cifar10_train_ds = CIFAR10_truncated(datadir, train=True, download=True, transform=transform)
    cifar10_test_ds = CIFAR10_truncated(datadir, train=False, download=True, transform=transform)

    X_train, y_train = cifar10_train_ds.data, cifar10_train_ds.target
    X_test, y_test = cifar10_test_ds.data, cifar10_test_ds.target

    # y_train = y_train.numpy()
    # y_test = y_test.numpy()

    return (X_train, y_train, X_test, y_test)
def record_net_data_stats(y_train, net_dataidx_map, logdir):
    net_cls_counts = {}

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_cls_counts[net_i] = tmp

    data_list=[]
    for net_id, data in net_cls_counts.items():
        n_total=0
        for class_id, n_data in data.items():
            n_total += n_data
        data_list.append(n_total)
    print('mean:', np.mean(data_list))
    print('std:', np.std(data_list))
    logger.info('Data statistics: %s' % str(net_cls_counts))

    return net_cls_counts

In [3]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1., net_id=None, total=0):
        self.std = std
        self.mean = mean
        self.net_id = net_id
        self.num = int(sqrt(total))
        if self.num * self.num < total:
            self.num = self.num + 1

    def __call__(self, tensor):
        if self.net_id is None:
            return tensor + torch.randn(tensor.size()) * self.std + self.mean
        else:
            tmp = torch.randn(tensor.size())
            filt = torch.zeros(tensor.size())
            size = int(28 / self.num)
            row = int(self.net_id / size)
            col = self.net_id % size
            for i in range(size):
                for j in range(size):
                    filt[:,row*size+i,col*size+j] = 1
            tmp = tmp * filt
            return tensor + tmp * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


In [4]:
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None, noise_level=0, net_id=None, total=0):
    if dataset in ('mnist', 'femnist', 'fmnist', 'cifar10', 'svhn', 'generated', 'covtype', 'a9a', 'rcv1', 'SUSY'):
        if dataset == 'mnist':
            dl_obj = MNIST_truncated

            transform_train = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])

        elif dataset == 'femnist':
            dl_obj = FEMNIST
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])

        elif dataset == 'fmnist':
            dl_obj = FashionMNIST_truncated
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])

        elif dataset == 'svhn':
            dl_obj = SVHN_custom
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])


        elif dataset == 'cifar10':
            dl_obj = CIFAR10_truncated

            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(
                    Variable(x.unsqueeze(0), requires_grad=False),
                    (4, 4, 4, 4), mode='reflect').data.squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)
            ])
            # data prep for test set
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(0., noise_level, net_id, total)])

        else:
            dl_obj = Generated
            transform_train = None
            transform_test = None


        train_ds = dl_obj(datadir, dataidxs=dataidxs, train=True, transform=transform_train, download=True)
        test_ds = dl_obj(datadir, train=False, transform=transform_test, download=True)

        train_dl = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=False)
        test_dl = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=False)

    return train_dl, test_dl, train_ds, test_ds


In [5]:
class CIFAR10_truncated(data.Dataset):

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):

        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):

        cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)

        if torchvision.__version__ == '0.2.1':
            if self.train:
                data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels)
            else:
                data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels)
        else:
            data = cifar_dataobj.data
            target = np.array(cifar_dataobj.targets)

        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]

        return data, target

    def truncate_channel(self, index):
        for i in range(index.shape[0]):
            gs_index = index[i]
            self.data[gs_index, :, :, 1] = 0.0
            self.data[gs_index, :, :, 2] = 0.0

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]
        # img = Image.fromarray(img)
        # print("cifar10 img:", img)
        # print("cifar10 target:", target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

## generate data

In [6]:
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
alpha = 0.5
all_clients = 10
X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_data(
"cifar10", "../data", "../data", "noniid", all_clients, beta=alpha)

Files already downloaded and verified
Files already downloaded and verified


INFO:root:Data statistics: {0: {0: 953, 1: 142, 2: 141, 3: 75, 4: 695, 5: 819, 7: 2482}, 1: {0: 16, 1: 43, 2: 902, 3: 1650, 4: 86, 5: 182, 7: 693, 8: 110, 9: 6}, 2: {0: 9, 1: 8, 2: 290, 3: 769, 4: 841, 5: 283, 6: 119, 7: 1044, 8: 1014, 9: 58}, 3: {0: 395, 1: 1200, 2: 48, 3: 68, 4: 896, 5: 681, 6: 90, 7: 17, 8: 1351, 9: 301}, 4: {0: 504, 1: 2917, 2: 570, 3: 721, 4: 121, 5: 356}, 5: {0: 1262, 1: 71, 2: 325, 3: 119, 4: 1560, 5: 14, 6: 1, 7: 85, 8: 366}, 6: {0: 9, 1: 273, 2: 1657, 3: 40, 4: 1, 5: 130, 6: 1911, 7: 21, 8: 1160}, 7: {0: 722, 1: 3, 2: 281, 3: 738, 4: 22, 5: 974, 6: 624, 7: 1, 8: 82, 9: 878}, 8: {0: 1127, 1: 153, 2: 680, 3: 500, 4: 698, 5: 1139, 6: 92, 7: 23, 8: 1, 9: 3665}, 9: {0: 3, 1: 190, 2: 106, 3: 320, 4: 80, 5: 422, 6: 2163, 7: 634, 8: 916, 9: 92}}


mean: 5000.0
std: 1165.398901664147


In [7]:
train_dl_local, test_dl_local, train_ds, test_ds = get_dataloader("cifar10", 
                                                                      '../data', 64, 32,
                                                                      dataidxs=net_dataidx_map[0],
                                                                  noise_level=0.1,net_id=0, total=1)
# train_ds.data

Files already downloaded and verified
Files already downloaded and verified


In [11]:
train_ds.data[0]

In [1]:
next(iter(train_dl_local))[0][0]

In [2]:
transforms.ToPILImage()(next(iter(train_dl_local))[0][0]).resize((256,256))

In [177]:
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=256, drop_last=False, shuffle=False)


In [84]:
from collections import Counter
import shutil
# data_ = train_ds.data
# labels_=list(train_ds.target)

# for i in range(10):
#     if i not in Counter(train_ds.target).keys():
#         for j in range(10):
#             data_=np.concatenate((data_,np.zeros_like(data_[0]).reshape((1,32, 32, 3))))
#             labels_.append(i)

In [85]:
data_.shape

(4479, 3072)

In [110]:
## put the data that after partition into disk from memory, so CAREFULLY DO IT TO NOT OVERWRITE THE EXISTING DATA
for client in range(10):
    train_dl_local, test_dl_local, train_ds, test_ds = get_dataloader("cifar10", 
                                                                      '../data', 64, 32,
                                                                      dataidxs=net_dataidx_map[client])
    data_ = train_ds.data
    labels_=list(train_ds.target)

    for i in range(10):
        if i not in Counter(train_ds.target).keys():
            for j in range(10):
                data_=np.concatenate((data_,np.zeros_like(data_[0]).reshape((1,32, 32, 3))))
                labels_.append(i)
            
    num = data_.shape[0]
    step = num//5
    data_ = data_.reshape(num,-1)
#     labels_=list(train_ds.target)
    os.makedirs("cifar-10-batches-py",exist_ok=True)
    for i in range(5):
        dic = {}
        dic['data']=data_[step*i:step*(i+1)]
        dic['labels']=labels_[step*i:step*(i+1)]
        with open("cifar-10-batches-py/data_batch_"+str(i+1),"wb") as f:
            pickle.dump(dic,f)

    ## verify
    data = []
    targets = []
    downloaded_list = [
        "cifar-10-batches-py/data_batch_1",
        "cifar-10-batches-py/data_batch_2",
        "cifar-10-batches-py/data_batch_3",
        "cifar-10-batches-py/data_batch_4",
        "cifar-10-batches-py/data_batch_5"
    ]
    # now load the picked numpy arrays
    for file_name in downloaded_list:
        file_path = file_name
        with open(file_path, "rb") as f:
            entry = pickle.load(f, encoding="latin1")
            data.append(entry["data"])
            if "labels" in entry:
                targets.extend(entry["labels"])
            else:
                targets.extend(entry["fine_labels"])

    data = np.vstack(data).reshape(-1, 3, 32, 32)
    data = data.transpose((0, 2, 3, 1))  # convert to HWC
    # data.shape
    #  should like this: (5305, 32, 32, 3)

    # compress file to tar.gz
    def Converter(path, tar):
        with tarfile.open(tar, "w:gz") as t:
            for root, dirs, files in os.walk(path):
                for file in files:
                    t.add(file)
        
    Converter("cifar-10-batches-py", "cifar-10-python.tar.gz")
    # move to new dir
    ne_path = "data/alpha-"+str(alpha)+"/partition_client_"+str(client+1)
    os.makedirs(ne_path) # ,exist_ok=True)
    shutil.move("cifar-10-python.tar.gz",ne_path)

In [109]:
current_dir = os.getcwd()


'/home/aikedaer/Desktop/FedMOON/ipynb'

In [10]:
x=traindata_cls_counts
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 5, sharex='col', sharey='row',figsize=(25,10))
for i in range(10):
    if i==0:
        ax[i//5,i%5].set_ylabel("Number of corresponding label")
    elif i>5:
        ax[i//5,i%5].set_xlabel("Class_label")
    elif i==5:
        ax[i//5,i%5].set_ylabel("Number of corresponding label")
        ax[i//5,i%5].set_xlabel("Class_label")
        
    ax[i//5,i%5].set_xlim(-1,10)
    ax[i//5,i%5].bar(list(x[i].keys()),list(x[i].values()),width = 1)
    ax[i//5,i%5].set_title("client-{0}".format(i))
    for m,n in zip(list(x[i].keys()),list(x[i].values())):
        ax[i//5,i%5].text(m+0.05,n+0.05,'%d' %n, ha='center',va='bottom')
# plt.title("partition-"+str(alpha)) not work
plt.savefig("data/alpha-"+str(alpha)+"/partition.png",dpi=330)
plt.show()

## load data

In [8]:
from typing import Any, Callable, Optional, Tuple
import hashlib
import os.path
import pickle
import numpy as np
from PIL import Image

class MYCIFAR10(torch.utils.data.Dataset):
    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:

        #         super().__init__(root, transform=transform, target_transform=target_transform)
        self.root = root
        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform
        self.downlod = download

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, "rb") as f:
                entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                if "labels" in entry:
                    self.targets.extend(entry["labels"])
                else:
                    self.targets.extend(entry["fine_labels"])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        with open(path, "rb") as infile:
            data = pickle.load(infile, encoding="latin1")
            self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)
    
#     def extract_file(self):
#         # open file
#         file = tarfile.open('gfg.tar.gz')

#         # extracting file
#         file.extractall('./Destination_FolderName')

#         file.close()



In [9]:
transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.RandomRotation(50,expand=True),  
                              transforms.ToTensor()
                              ])

In [14]:
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
cifar_dataobj = MYCIFAR10("data/alpha-0.5/partition_client_1", True, transform)
cifar_dataobj.data.shape

(5335, 32, 32, 3)

In [16]:
cifar_dataobj.data[0].shape

(32, 32, 3)

In [249]:
# mean = [0.4914, 0.4822, 0.4465]
# # std = [0.2023, 0.1994, 0.2010]
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
# cifar_dataobj = CIFAR10("../data", True, transform,download=True)
# cifar_dataobj

In [251]:
cifar10_obj = MYCIFAR10("data/alpha-0.5/partition_client_1",True,transform)

In [9]:
cifar10_obj[0]

In [234]:
from collections import Counter
for i in Counter(cifar10_obj.targets).keys():
    print(i)

7
0
1
5
3
4
2


In [206]:
cifar10_obj.classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [80]:
import pickle
data = []
targets = []
downloaded_list = [
    "/home/aikedaer/Desktop/cifar-10-batches-py/data_batch_1",
    "/home/aikedaer/Desktop/cifar-10-batches-py/data_batch_2",
    "/home/aikedaer/Desktop/cifar-10-batches-py/data_batch_3",
    "/home/aikedaer/Desktop/cifar-10-batches-py/data_batch_4",
    "/home/aikedaer/Desktop/cifar-10-batches-py/data_batch_5"
]
# now load the picked numpy arrays
for file_name in downloaded_list:
    file_path = file_name
    with open(file_path, "rb") as f:
        entry = pickle.load(f, encoding="latin1")
        data.append(entry["data"])
        if "labels" in entry:
            targets.extend(entry["labels"])
        else:
            targets.extend(entry["fine_labels"])

data = np.vstack(data).reshape(-1, 3, 32, 32)
data = data.transpose((0, 2, 3, 1))  # convert to HWC

In [81]:
dat = data.transpose((0,3,1,2))


In [27]:
entry.keys()

dict_keys(['batch_label', 'labels', 'data', 'filenames'])

In [3]:
entry['filenames']

In [4]:
entry['labels']

In [31]:
entry['data'].shape

(10000, 3072)

In [34]:
data.shape

(50000, 32, 32, 3)

In [89]:
with open("partition","wb") as f:
    h = pickle.dump(data.reshape(-1),f)

In [90]:
with open("partition", "rb") as f:
    ed = pickle.load(f, encoding="latin1")
ed

array([ 59,  62,  63, ..., 163, 163, 161], dtype=uint8)

In [5]:
data[0].reshape(-1).reshape((32,32,3))==data[0]

In [102]:
entry['data']

array([[255, 252, 253, ..., 173, 231, 248],
       [127, 126, 127, ..., 102, 108, 112],
       [116,  64,  19, ...,   7,   6,   5],
       ...,
       [ 35,  40,  42, ...,  77,  66,  50],
       [189, 186, 185, ..., 169, 171, 171],
       [229, 236, 234, ..., 173, 162, 161]], dtype=uint8)


https://datamahadev.com/performing-image-augmentation-using-pytorch/


In [159]:
transform1 = torchvision.transforms.Compose([
    torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
#     torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(20),
    torchvision.transforms.ToTensor()
])
transform2 = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=mean, std=std)])
# root_folder is the string containing address of the root image data directory 

transform_train = transforms.Compose([
    transforms.ToTensor(),
#     transforms.Lambda(lambda x: F.pad(
#         Variable(x.unsqueeze(0), requires_grad=False),
#         (4, 4, 4, 4), mode='reflect').data.squeeze()),
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=0.1),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
#     transforms.Normalize(mean=mean, std=std)
])

dataset = torchvision.datasets.CIFAR10(root = "../data", transform=transform_train)
dataset.data.shape

(50000, 32, 32, 3)

In [6]:
transform_aug = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256,256)),
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(32),
    torchvision.transforms.ToTensor()
])
transform_aug(dataset[0][0]).cpu().numpy()

In [7]:
(torch.randn_like(dataset[0][0])*(1/10)).cpu().numpy()

In [8]:
testloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
for idx, (x,y) in enumerate(testloader):
    display(transforms.ToPILImage()(x[0]).resize((256,256)))
print(idx)

Performing Image Augmentation using Pytorch

by Aman Sharma · September 7, 2020
6+	
A Detailed Guide on How to Use Image Augmentation in PyTorch to Give Your Models a Data Boost.

In the last few years, there have been some major breakthroughs and developments in the field of Deep Learning. The constant research and rapid developments have made Deep Learning an industry-standard in the field of AI and the main topic of discussion in almost every AI and Data Science conventions, overthrowing its parent and predecessor— traditional Machine Learning. 

When it comes to Computer Vision, that strictly deals with video and image data, and problems like object detection, body pose detection, image segmentation, etc., Deep Learning has proved out to be a much more reliable option as compared to the traditional Machine Learning.

The reason behind this is that Deep Learning specializes is tackling high dimensionality problems. While machine learning works perfectly fine when you only have a few hundred features to train your model on, the performance starts to deteriorate as the dimensionality of your data increases. With the evolution of Data Science and Big Data over the years, the complexity of the problems and the type of data the Data Scientists have to work with has increased a lot. 

To give you an idea of this massive increase in the scale of data, we will consider an example here. Let’s say that you are working on an image dataset, where you have to deal with 3000×4000 px RGB images.

Considering each pixel data to be a feature, every single data instance (i.e., the images) will have (3000 x 4000 x 3) = 36,000,000 features. Yes, the number of features that the model will have to train on is in millions, which is, frankly speaking, not feasible for almost any traditional machine learning algorithm to handle.

Deep Learning, on the other hand, performs exceptionally well when we have to deal with high dimensional data, like the images in the example we discussed above. This makes Deep Learning the ideal choice for Computer Vision problems. 

This ability to deal with high dimensional data makes Deep Learning seem so powerful, right? And now that GPUs don’t cost an arm and a leg (and that you can access free, high-speed GPUs for free via services like Google Colab), it might seem like there’s no need for traditional Machine Learning at all! 

However, there’s a catch here. Generally, it is observed that the performance of a Deep Learning model is directly proportional to the size of the dataset, i.e., the total number of data instances within a dataset.

Upon a glance at the graph given above, you will observe a rather strange pattern in it. You will see that when the size of the dataset is small, Machine Learning tends to perform slightly better. However, as the size of the dataset that the model is to be trained on increases, Deep Learning models really start to outperform their Machine Learning counterparts by a huge margin.

The reason? Deep Learning model architectures, in general, have millions of parameters to train in order to effectively adapt to certain patterns within the data. To facilitate this extensive training task, a very large amount of data is required. If there isn’t enough data for the model to train, the model’s inference performance will take a huge hit, and you might not get the results that you expected.

Therefore, if you are working on a Computer Vision problem, say an image segmentation problem, then in order to get a good performance out of a Deep Learning model, you obviously need large amounts of image data. Now one solution to this can be a collection of more images for training your model. But the downside of data collection is that it can be a very expensive task, both economically as well as technologically. 

A more economically feasible option would a technique known as Image Augmentation. If you are just getting started in Deep Learning, this might be an entirely new term for you. In that case, you have ended up in the right place. In this article, we will understand what Image Augmentation is, as well as have a look at how to apply image augmentation to training data in Python using PyTorch.

So, let’s get started.
Image Augmentation

Image Augmentation can be defined as the process by which we can generate new images by creating randomized variations in the existing image data. The technique can be used to increase the size of your dataset by creating additional data instances that can be used to train your model on. For an image classifications model, this simply translates to better performance.

I think the definition will become clearer once you see an example. In the example given below, we have the original image of an SUV on a street. 

In the first augmented image, by zooming in and increasing the brightness, we got a new image. The second augmented image was generated by tweaking the hue and temperature of the original image. In the third augmented image, the original image was vertically flipped. Thus, just by tweaking the color and the alignment of the images, we were able to create 3 more data instances that a model can train on. 

For human eyes, all these images in the example given above might look alike. But for a Deep Learning model that deals with the images as individual pixels (with values ranging from 0-255) spanning across the 3 color channels (RGB), all these images are different, since the individual pixel values of these images are different. Thus, image augmentation allows us to generate new image data for training our deep learning model without having to go extra lengths to collect the data manually. 

One more advantage that the image augmentation technique provides for Deep Learning is that by creating randomization in the image data, it significantly reduces the chances of the model overfitting on the training data. This allows the model to generalize better, and hence, improves the inference accuracy of the model.
Image Augmentation Using PyTorch

Now that we know what the image augmentation technique is used for, let us have a look at how you can implement a variety of image augmentations in PyTorch.

For this tutorial, first, we will understand the use and the effect of different image augmentation methods individually on a single image. Once we are done with that, we will see how to perform image augmentations in a Deep Learning project for a real-world dataset.

Let us begin by importing all the necessary PyData modules and PyTorch.

Now, before we start performing the transformations, let us have a look at our original image.

Now, let us have a look at some of the most used image augmentation techniques in PyTorch and the purpose they are used for.

    `CenterCrop` – The CenterCrop image augmentation is used to crop the input image at the center. The size of the crop is determined by the help of the ‘size’ attribute. A single integer value as the size argument performs a square cropping on the image of dimension size x size. To set a custom size, the value of the size attribute should be a tuple size = (width, height).

Here’s how to implement CenterCrop in PyTorch:

    `ColorJitter`– ColorJitter augmentation technique is used to randomly change the brightness, contrast, saturation, and hue of the image. Unlike the CenterCrop image augmentation that we saw earlier, ColorJitter doesn’t have a fixed behavior. Rather, it results in a random color augmentation each time. 

Here’s how to implement ColorJitter in PyTorch:

img = Image.open('/content/2_city_car_.jpg')

color_jitter = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
img = color_jitter(img)
plt.imshow(img)

view raw
color_jitter.py hosted with ❤ by GitHub

    `Grayscale` – The Grayscale image augmentation is used to convert a multi-channeled (RGB, CYAN, etc.) image into a single-channeled (gray-scaled) or triple-channeled (r==g==b) image.

Here’s how to implement Grayscale in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

gray = torchvision.transforms.Grayscale(num_output_channels=1)
img = gray(img)
plt.imshow(img, cmap='gray')
view raw
gray.py hosted with ❤ by GitHub

    `Pad`– The Pad image transform is used to pad the given image on all sides. The thickness of the padding is determined by the ‘padding’ argument. 

Here’s how to implement Pad in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

pad = torchvision.transforms.Pad(50, fill=0, padding_mode='constant')
img = pad(img)
plt.imshow(img)
view raw
pad.py hosted with ❤ by GitHub

    `RandomCrop`– The RandomCrop image augmentation acts in a way similar to that as the CenterCrop. The only difference is that it crops the original image at any random location rather than from just the center. Again, the size of the crop is determined by the ‘size’ attribute.

Here’s how to implement RandomCrop in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

random_crop = torchvision.transforms.RandomCrop((200, 300), fill=0, padding_mode='constant')
img = random_crop(img)
plt.imshow(img)
view raw
random_crop.py hosted with ❤ by GitHub

    `RandomHorizontalFlip` – The RandomHorizontalFlip image augmentation horizontally flips the image. The probability of the flipping operation can be controlled using the ‘p’ attribute, its value ranging from 0 <= p <=1. 

Here’s how to implement RandomHorizontalFlip in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

horizontal_flip = torchvision.transforms.RandomHorizontalFlip(p=1)
img = horizontal_flip(img)
plt.imshow(img)
view raw
horizontal_flip.py hosted with ❤ by GitHub

    `RandomVerticalFlip` – Just like the horizontal flip augmentation that we saw earlier, RandomVerticalFlip also flips the image. The only difference is the flipping occurs across the x-axis, i.e., in simple words, in the vertical direction. The probability of the flipping operation can be controlled using the ‘p’ attribute, its value ranging from 0 <= p <=1. 

Here’s how to implement RandomVerticalFlip in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

vertical_flip = torchvision.transforms.RandomVerticalFlip(p=1)
img = vertical_flip(img)
plt.imshow(img)
view raw
vertical_flip.py hosted with ❤ by GitHub

    `RandomPerspective` – The RandomPerspective image augmentation is used to randomly distort the image along with a given perspective. The probability of the flipping operation can be controlled using the ‘p’ attribute, its value ranging from 0 <= p <=1.; and the scale of the distortion can be controlled using the ‘distortion_scale’ attribute, its value also ranging between 0-1.

Here’s how to implement RandomPerspective in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

random_persp = torchvision.transforms.RandomPerspective(distortion_scale=0.5, p=1, interpolation=3, fill=0)
img = random_persp(img)
plt.imshow(img)
view raw
random_perspective.py hosted with ❤ by GitHub

    `RandomRotation` – The RandomRotation randomly rotates the image. The degree of rotation of the image is determined using the ‘degree’ attribute. 

Here’s how to implement RandomRotation in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

random_rotation = torchvision.transforms.RandomRotation(degrees = 45)
img = random_rotation(img)
plt.imshow(img)
view raw
random_rotate.py hosted with ❤ by GitHub

    `RandomErasing` – The RandomErasing image augmentation technique randomly selects a rectangular region in the original image and erases all the pixels in the region. The probability or the erase operation can be controlled using the ‘p’ attribute, its value ranging from 0 <= p <=1.

Here’s how to implement RandomErasing in PyTorch:
img = Image.open('/content/2_city_car_.jpg')

tensor = torchvision.transforms.ToTensor()
random_erase = torchvision.transforms.RandomErasing(p=1)

img = tensor(img)
img = random_erase(img)
plt.imshow(img.permute(1,2,0))
view raw
random_erase.py hosted with ❤ by GitHub

Now that we have seen some of the most used image augmentation techniques in PyTorch, let us have a look at how to apply these in a real-world project. Generally, the augmentations/transforms are applied in a sequence, all at once. For this, we have to use torchvision.transforms.Compose() method. The augmentations that are to be performed on the images are passes to the compose method as an argument.

Let us see how to implement this using PyTorch:
img = Image.open('/content/2_city_car_.jpg')

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((300,400)),
    torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(20)
])

img = transforms(img)
plt.imshow(img)
view raw
compose.py hosted with ❤ by GitHub

Up until now, we saw how to apply the transformations/augmentations on a single image. But in real-world problems, the datasets may have thousands of images. 

Unlike the Pandas DataFrames that we see in many traditional machine learning problems, it is generally not possible to store all the images in the memory (RAM) at once in the form of DataFrames. Therefore, PyTorch handles these images via the various Dataset classes available in PyTorch.In order to apply the transforms on an entire dataset, all you need to do is pass the torchvision.transforms.Compose method object (or an individual image augmentation method object, if you want) as the value to the ‘transform’ attribute. There are several Dataset classes in PyTorch, but as an example, we will see how to apply the image augmentation to an ImageFolder dataset.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((300,400)),
    torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(20)
])

root_folder is the string containing address of the root image data directory 
`dataset = torchvision.datasets.ImageFolder(root = root_folder, transform=transforms)`
view raw
dataset.py hosted with ❤ by GitHub

With this, we come to an end of our tutorial part where we learned why image augmentation is necessary for Deep Learning and how to apply different image augmentations in PyTorch. By using multiple combinations of augmentations on different batches of data and retraining your model on this augmented and original data again and again over several epochs (training cycles), you can overcome the barrier that has a small image dataset poses against your model.