In [1]:
import autoaug
import autoaug.autoaugment_learners as aal

import torchvision
import torch.nn as nn
import torchvision.datasets as datasets

### Defining our CNN Classifier

In [2]:
class LeNet(nn.Module):
    def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(img_channels, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, num_labels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

### Defining the training and validation datasets

In [3]:
import torchvision.datasets as datasets

train_dataset = datasets.MNIST(
                        root='./autoaug/datasets/mnist/train',
                        train=True,
                        download=True,
                        transform=None
                        )
val_dataset = datasets.MNIST(
                        root='./autoaug/datasets/mnist/test',
                        train=False,
                        download=True,
                        transform=torchvision.transforms.ToTensor()
                        )

### Defining the child network architecture

In [4]:
child_network_architecture = LeNet

### specifying parameters for the auto-augment learner

In [5]:
search_space_hyp = {
        'exclude_method': ['Invert', 'Solarize']
        }
child_network_hyp = {
        'learning_rate': 0.01,
        'early_stop_num': 5,
        'batch_size': 32,
        'toy_size': 0.005
        }


learner = aal.RsLearner(
                        **search_space_hyp,
                        **child_network_hyp,
                        )

### Training the auto-augment learner

In [6]:
learner.learn(
        train_dataset=train_dataset,
        test_dataset=val_dataset,
        child_network_architecture=child_network_architecture,
        iterations = 9)

[(('Posterize', 0.8, 3), ('Rotate', 0.6, 4)),
 (('Rotate', 1.0, 9), ('Sharpness', 0.1, 1)),
 (('ShearY', 0.5, 4), ('Sharpness', 0.9, 3)),
 (('Posterize', 0.0, 6), ('Rotate', 0.0, 9)),
 (('ShearX', 0.3, 9), ('Brightness', 0.1, 7))]
main.train_child_network best accuracy:  tensor(0.0800)
main.train_child_network best accuracy:  tensor(0.1400)
main.train_child_network best accuracy:  tensor(0.1400)
main.train_child_network best accuracy:  tensor(0.1400)
main.train_child_network best accuracy:  tensor(0.1400)
main.train_child_network best accuracy:  tensor(0.1400)
[(('Sharpness', 0.6, 9), ('Rotate', 0.4, 2)),
 (('AutoContrast', 0.3, None), ('ShearX', 0.7, 3)),
 (('Contrast', 0.3, 7), ('ShearX', 0.8, 8)),
 (('Equalize', 0.0, None), ('Sharpness', 0.8, 0)),
 (('AutoContrast', 0.0, None), ('Sharpness', 0.3, 3))]
main.train_child_network best accuracy:  tensor(0.1200)
main.train_child_network best accuracy:  tensor(0.1200)
main.train_child_network best accuracy:  tensor(0.1200)
main.train_child

### Viewing the Results

In [7]:
# pretty print
from pprint import pprint

pprint(learner.get_n_best_policies(2))

[([(('Posterize', 0.8, 3), ('Rotate', 0.6, 4)),
   (('Rotate', 1.0, 9), ('Sharpness', 0.1, 1)),
   (('ShearY', 0.5, 4), ('Sharpness', 0.9, 3)),
   (('Posterize', 0.0, 6), ('Rotate', 0.0, 9)),
   (('ShearX', 0.3, 9), ('Brightness', 0.1, 7))],
  0.14000000059604645),
 ([(('Rotate', 0.6, 0), ('Posterize', 0.8, 4)),
   (('Contrast', 0.4, 8), ('Color', 0.8, 9)),
   (('Contrast', 0.4, 3), ('Equalize', 0.8, None)),
   (('TranslateX', 0.5, 4), ('TranslateY', 0.8, 4)),
   (('TranslateX', 0.8, 2), ('AutoContrast', 0.0, None))],
  0.14000000059604645)]
