
In [None]:
%matplotlib inline


WB augmenter tutorial
=====================

This is a tutorial on how to use our white-balance (WB) augmenter to augment images on the fly. In this example, we train a neural network for classification on the CIFAR-10 dataset. For simplicity, we use only 500 images per class (5,000 in total) for training.

In particular, this tutorial shows how to use the WB augmenter with PyTorch built-in datasets, such as CIFAR-10.

The approach used in this tutorial can also be applied to custom datasets and data loaders. Another way to apply the WB augmenter on the fly to loaded images is described on the official GitHub page: https://github.com/mahmoudnafifi/WB_color_augmenter

In this tutorial, we clone only the Python version of the WB augmenter: https://github.com/mahmoudnafifi/WB_color_augmenter_python. For more details and examples, please check the official GitHub page.



**Citation:**

*Mahmoud Afifi and Michael S. Brown. What Else Can Fool Deep Learning? Addressing Color Constancy Errors on Deep Neural Network Performance. International Conference on Computer Vision (ICCV), 2019.*


In [None]:
import torch
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from collections import defaultdict, deque
import itertools
import numpy as np
import pickle

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f'Training/testing is going to use {device}')

Training/testing is going to use cuda:0


0. Clone the WB augmenter repo

In [None]:
import shutil
import os

# remove any earlier clone so that the repository is fetched fresh
if os.path.exists('WB_color_augmenter_python'):
    shutil.rmtree('WB_color_augmenter_python')

!git clone https://github.com/mahmoudnafifi/WB_color_augmenter_python.git

Cloning into 'WB_color_augmenter_python'...
remote: Enumerating objects: 37, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (29/29), done.
remote: Total 37 (delta 12), reused 21 (delta 4), pack-reused 0
Unpacking objects: 100% (37/37), done.


In [None]:
import WB_color_augmenter_python.WBEmulator as wb_aug

1. Dataset classes without and with our WB augmenter. Both keep only 500 images per class; in `Cifar10_w_WB_aug`, `__getitem__` applies the WB augmentation on the fly.


In [None]:
class Cifar10_wo_WB_aug(datasets.CIFAR10):
    def __init__(self, path, transforms, train=True, download=True):
        super().__init__(path, train, download=download)
        self.transforms = transforms
        self.n_images_per_class = 500
        self.n_classes = 10
        self.new2old_indices = self.create_idx_mapping()

    def create_idx_mapping(self):
        # keep the last n_images_per_class indices of every class; reading
        # self.targets avoids decoding every image just to get its label
        label2idx = defaultdict(lambda: deque(maxlen=self.n_images_per_class))
        for original_idx, label in enumerate(self.targets):
            label2idx[label].append(original_idx)

        # map contiguous new indices (0..N-1) to the original CIFAR-10 indices
        old_idxs = sorted(itertools.chain(*label2idx.values()))
        return dict(enumerate(old_idxs))

    def __len__(self):
        return len(self.new2old_indices)

    def __getitem__(self, index):
        index = self.new2old_indices[index]
        im, label = super().__getitem__(index)
        return self.transforms(im), label
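The index remapping performed by `create_idx_mapping` can be illustrated in plain Python on synthetic labels, without downloading CIFAR-10. This is a minimal sketch, assuming labels are read in dataset order; `subsample_per_class` and `toy_labels` are hypothetical names for this example. Because each `deque` has `maxlen=500`, older indices are pushed out, so the *last* 500 images of each class are the ones kept.

```python
from collections import defaultdict, deque


def subsample_per_class(labels, n_per_class):
    """Map new contiguous indices to original indices, keeping the last
    n_per_class occurrences of every label (same idea as create_idx_mapping)."""
    label2idx = defaultdict(lambda: deque(maxlen=n_per_class))
    for original_idx, label in enumerate(labels):
        # a full deque silently drops its oldest entry on append
        label2idx[label].append(original_idx)
    kept = sorted(idx for idxs in label2idx.values() for idx in idxs)
    return dict(enumerate(kept))


# hypothetical toy split: 10 classes x 1000 images, like the CIFAR-10 test set
toy_labels = [i % 10 for i in range(10_000)]
mapping = subsample_per_class(toy_labels, n_per_class=500)
print(len(mapping), mapping[0])  # 5000 5000
```

The same logic applied to the 10,000-image CIFAR-10 test split is why a `Cifar10_wo_WB_aug` test set holds 5,000 images.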

In [None]:
class Cifar10_w_WB_aug(datasets.CIFAR10):
    def __init__(self, path, transforms, train=True, download=True):
        super().__init__(path, train, download=download)
        self.path = path
        self.transforms = transforms
        self.n_images_per_class = 500
        self.n_classes = 10
        self.new2old_indices = self.create_idx_mapping()
        self.wb_color_aug = wb_aug.WBEmulator()
        self.mapping = self.compute_mapping()

    def create_idx_mapping(self):
        # keep the last n_images_per_class indices of every class; reading
        # self.targets avoids decoding every image just to get its label
        label2idx = defaultdict(lambda: deque(maxlen=self.n_images_per_class))
        for original_idx, label in enumerate(self.targets):
            label2idx[label].append(original_idx)

        # map contiguous new indices (0..N-1) to the original CIFAR-10 indices
        old_idxs = sorted(itertools.chain(*label2idx.values()))
        return dict(enumerate(old_idxs))

    def __len__(self):
        return len(self.new2old_indices)

    def compute_mapping(self):
        # cache per split so that train and test mapping functions never mix
        split = 'train' if self.train else 'test'
        cache_file = os.path.join(self.path, f'wb_mfs_{split}.pickle')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as handle:
                return pickle.load(handle)

        print('Computing mapping functions for the WB augmenter. '
              'This process may take some time...')
        # only the subsampled images are ever accessed, so compute mapping
        # functions for those alone, keyed by their original CIFAR-10 index
        mapping_funcs = {}
        for idx in self.new2old_indices.values():
            img, _ = super().__getitem__(idx)
            mapping_funcs[idx] = self.wb_color_aug.computeMappingFunc(img)
        with open(cache_file, 'wb') as handle:
            pickle.dump(mapping_funcs, handle, protocol=pickle.HIGHEST_PROTOCOL)

        return mapping_funcs

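    def __getitem__(self, index):
        index = self.new2old_indices[index]
        im, label = super().__getitem__(index)
        # pick one of this image's precomputed mapping functions at random
        mfs = self.mapping[index]
        mf = mfs[np.random.randint(len(mfs))]
        # apply it and keep the result (discarding it would feed the network
        # the unaugmented image). Assumption: changeWB takes and returns a
        # float RGB array in [0, 1]; check WBEmulator.py to confirm.
        out = wb_aug.changeWB(np.asarray(im, dtype=np.float64) / 255.0, mf)
        im = (np.clip(out, 0.0, 1.0) * 255).astype(np.uint8)
        # ToTensor accepts uint8 HxWxC arrays and rescales them to [0, 1]
        return self.transforms(im), label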
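To make the on-the-fly step concrete, here is a plain-Python sketch of applying one randomly chosen mapping function to a single RGB pixel. It assumes, based on our reading of the ICCV 2019 paper cited above, that each mapping function is an 11×3 matrix applied to the polynomial expansion [r, g, b, rg, rb, gb, r², g², b², rgb, 1] of a pixel in [0, 1], followed by clipping. `kernel_p`, `apply_mapping` and the two toy matrices are hypothetical stand-ins, not the `WBEmulator` API.

```python
import random


def kernel_p(r, g, b):
    """11-term polynomial expansion of one RGB pixel (assumed form)."""
    return [r, g, b, r * g, r * b, g * b, r * r, g * g, b * b, r * g * b, 1.0]


def apply_mapping(pixel, m):
    """Apply an 11x3 mapping matrix m (11 rows of 3 floats) to a pixel with
    channels in [0, 1], then clip the result back to [0, 1]."""
    k = kernel_p(*pixel)
    out = [sum(k[i] * m[i][c] for i in range(11)) for c in range(3)]
    return tuple(min(max(v, 0.0), 1.0) for v in out)


def identity_mapping():
    # rows 0-2 pass r, g, b through; all higher-order terms get weight 0
    m = [[0.0, 0.0, 0.0] for _ in range(11)]
    for c in range(3):
        m[c][c] = 1.0
    return m


def warm_mapping():
    # toy "warmer" rendition: boost red by 20%, cut blue by 20%
    m = identity_mapping()
    m[0][0], m[2][2] = 1.2, 0.8
    return m


# precomputed candidates for one image; one is picked at random per access
mfs = [identity_mapping(), warm_mapping()]
mf = random.choice(mfs)
print(apply_mapping((0.5, 0.5, 0.5), mf))
```

In the notebook, `wb_aug.changeWB` applies the selected mapping function to a whole image at once; because a new function is drawn on every access, the same training image can appear under different WB settings across epochs.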

2. Download the dataset and create data loaders

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 32

## training loaders
# w/ white-balance augmenter
trainset_w_WB_aug = Cifar10_w_WB_aug(path="./data", transforms=transform,
                                     train=True, download=True)
trainloader_w_WB_aug = torch.utils.data.DataLoader(
    trainset_w_WB_aug, batch_size=batch_size, shuffle=True, num_workers=2)

# w/o white-balance augmenter
trainset_wo_WB_aug = Cifar10_wo_WB_aug(path="./data", transforms=transform,
                                       train=True, download=True)
trainloader_wo_WB_aug = torch.utils.data.DataLoader(
    trainset_wo_WB_aug, batch_size=batch_size, shuffle=True, num_workers=2)


## testing loader
# w/o white-balance augmenter (original test images, 500 per class)
testset = Cifar10_wo_WB_aug(path="./data", transforms=transform,
                            train=False, download=True)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
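With 5,000 training images and `batch_size = 32`, a quick calculation shows how many mini-batches each epoch contains and which of them the training loop reports (it prints an average every 50 mini-batches). This plain-Python check assumes `drop_last=False`, which is the `DataLoader` default and is not overridden above.

```python
import math

n_train = 10 * 500   # 10 classes x 500 images per class
batch_size = 32

# the final, partial batch is still yielded because drop_last defaults to False
batches_per_epoch = math.ceil(n_train / batch_size)
reported = [i + 1 for i in range(batches_per_epoch) if (i + 1) % 50 == 0]

print(batches_per_epoch)  # 157
print(reported)           # [50, 100, 150]
```

This matches the three log lines per epoch in the training output. Because `running_loss` is reset at the start of every epoch, the last 7 mini-batches of each epoch (the final one holding only 8 images) never contribute to a printed average.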


3. Build network

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 240)
        self.fc2 = nn.Linear(240, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net_w_WB_aug = Net()
net_w_WB_aug.to(device)

net_wo_WB_aug = Net()
net_wo_WB_aug.to(device)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=240, bias=True)
  (fc2): Linear(in_features=240, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
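The `16 * 5 * 5` input size of `fc1` follows from the spatial size of a 32×32 CIFAR-10 image as it passes through the layers. A small helper (hypothetical, using the standard output-size formula for convolution and pooling with dilation 1) makes the arithmetic explicit:

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


s = 32                               # CIFAR-10 images are 32x32
s = out_size(s, kernel=5)            # conv1 -> 28
s = out_size(s, kernel=2, stride=2)  # pool  -> 14
s = out_size(s, kernel=5)            # conv2 -> 10
s = out_size(s, kernel=2, stride=2)  # pool  -> 5
print(s, 16 * s * s)  # 5 400
```

The result agrees with `in_features=400` in the printed model summary.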

4. Loss function and optimizers



In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer_w_WB_aug = optim.SGD(net_w_WB_aug.parameters(), lr=0.001,
                               momentum=0.9)
optimizer_wo_WB_aug = optim.SGD(net_wo_WB_aug.parameters(), lr=0.001,
                                momentum=0.9)
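Both pieces configured in this cell can be sketched in plain Python. `nn.CrossEntropyLoss` combines log-softmax and negative log-likelihood, so an untrained network that outputs equal logits for all 10 classes has a loss of ln 10 ≈ 2.303, which is roughly where the logged training loss starts. The momentum update below follows the form given in the PyTorch `SGD` documentation for the default `dampening=0` and `nesterov=False`; the function names are hypothetical.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def sgd_momentum_step(params, grads, velocity, lr=0.001, momentum=0.9):
    """One in-place SGD step with momentum: v = momentum * v + g; p -= lr * v."""
    for i, g in enumerate(grads):
        velocity[i] = momentum * velocity[i] + g
        params[i] -= lr * velocity[i]


print(round(cross_entropy([0.0] * 10, target=3), 3))  # 2.303

params, grads, velocity = [1.0], [1.0], [0.0]
for _ in range(2):
    sgd_momentum_step(params, grads, velocity)
print(round(params[0], 4))  # 0.9971
```

With a constant gradient, the velocity grows from 1.0 to 1.9 over the two steps, which is how momentum speeds up progress along consistent gradient directions.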

5. Train two networks: one with the WB augmenter and one without it


In [None]:
EPOCHS = 100

## train w/ the WB augmenter
print('Training w/ WB augmentation')
for epoch in range(EPOCHS):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader_w_WB_aug, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        # zero the parameter gradients
        optimizer_w_WB_aug.zero_grad()

        # forward + backward + optimize
        outputs = net_w_WB_aug(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_w_WB_aug.step()

        # print statistics
        running_loss += loss.item()
        if (i + 1) % 50 == 0:    # print every 50 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 50))
            running_loss = 0.0

print('Finished Training')

PATH = './cifar_net_w_WB_aug.pth'
torch.save(net_w_WB_aug.state_dict(), PATH)




## train w/o the WB augmenter
print('\n\n\nTraining w/o WB augmentation')
for epoch in range(EPOCHS):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader_wo_WB_aug, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        # zero the parameter gradients
        optimizer_wo_WB_aug.zero_grad()

        # forward + backward + optimize
        outputs = net_wo_WB_aug(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_wo_WB_aug.step()

        # print statistics
        running_loss += loss.item()
        if (i + 1) % 50 == 0:    # print every 50 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 50))
            running_loss = 0.0

print('Finished Training')

PATH = './cifar_net_wo_WB_aug.pth'
torch.save(net_wo_WB_aug.state_dict(), PATH)

Training w/ WB augmentation
[1,    50] loss: 2.308
[1,   100] loss: 2.303
[1,   150] loss: 2.304
[2,    50] loss: 2.303
[2,   100] loss: 2.303
[2,   150] loss: 2.304
[3,    50] loss: 2.302
[3,   100] loss: 2.301
[3,   150] loss: 2.304
[4,    50] loss: 2.301
[4,   100] loss: 2.302
[4,   150] loss: 2.301
[5,    50] loss: 2.300
[5,   100] loss: 2.300
[5,   150] loss: 2.300
[6,    50] loss: 2.298
[6,   100] loss: 2.298
[6,   150] loss: 2.298
[7,    50] loss: 2.295
[7,   100] loss: 2.296
[7,   150] loss: 2.294
[8,    50] loss: 2.291
[8,   100] loss: 2.290
[8,   150] loss: 2.286
[9,    50] loss: 2.281
[9,   100] loss: 2.274
[9,   150] loss: 2.266
[10,    50] loss: 2.248
[10,   100] loss: 2.240
[10,   150] loss: 2.220
[11,    50] loss: 2.199
[11,   100] loss: 2.182
[11,   150] loss: 2.176
[12,    50] loss: 2.138
[12,   100] loss: 2.143
[12,   150] loss: 2.102
[13,    50] loss: 2.091
[13,   100] loss: 2.083
[13,   150] loss: 2.046
[14,    50] loss: 2.043
[14,   100] loss: 2.054
[14,   150] los

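6. Test our networks on the original test set and the test set with some white-balance degradations. Note that `Cifar10_wo_WB_aug` also keeps only 500 images per class for the test split, so `testset` contains 5,000 of the 10,000 CIFAR-10 test images.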

In [None]:
### test on original test set

correct_wo_WB = 0
correct_w_WB = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.to(device=device)
        labels = labels.to(device=device)
        # calculate outputs by running images through the network
        outputs_wo_WB_aug = net_wo_WB_aug(images)
        outputs_w_WB_aug = net_w_WB_aug(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs_wo_WB_aug.data, 1)
        total += labels.size(0)
        correct_wo_WB += (predicted == labels).sum().item()

        _, predicted = torch.max(outputs_w_WB_aug.data, 1)
        correct_w_WB += (predicted == labels).sum().item()

print('Accuracy of the network trained without the WB augmenter on the %d '
      'test images: %d %%' % (total, 100 * correct_wo_WB / total))

print('Accuracy of the network trained with the WB augmenter on the %d '
      'test images: %d %%' % (total, 100 * correct_w_WB / total))


Accuracy of the network trained without the WB augmenter on the 10000 test images: 44 %
Accuracy of the network trained with the WB augmenter on the 10000 test images: 47 %
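The `'%d %%'` format truncates the percentage instead of rounding it, so, for example, 44.98 % is printed as 44 %. A small helper (hypothetical, in plain Python) that counts correct argmax predictions and reports one decimal place avoids hiding such differences when comparing the two networks:

```python
def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(outputs, labels):
    """Percentage of samples whose highest-scoring class equals the label."""
    correct = sum(argmax(o) == y for o, y in zip(outputs, labels))
    return 100 * correct / len(labels)


# toy scores for 3 samples over 3 classes; the last prediction is wrong
outputs = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.2, 0.1, 0.9]]
labels = [1, 0, 1]

acc = accuracy(outputs, labels)
print('%d %%' % acc)    # 66 %  (truncated)
print('%.1f %%' % acc)  # 66.7 %
```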
