### Visualise examples from the dataset
Run the cell below multiple times to see various images. (They might look a bit blurry because we've blown up the small images.)

In [0]:
# Plot a 5x5 grid of randomly chosen training images with their class names.
plt.figure(figsize=(10,10))
for i in range(25):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  # matplotlib expects a boolean here; the string 'off' is not treated as False
  # by recent versions
  plt.grid(False)

  # sample from the whole array instead of a hard-coded 40000
  img_index = np.random.randint(0, len(train_images))
  plt.imshow(train_images[img_index])
  # labels may be shaped (N,) or (N, 1); extract the scalar class id either way
  plt.xlabel(cifar_labels[int(np.ravel(train_labels[img_index])[0])])

### Evaluating the model on the test set

In [0]:
# Evaluate on the held-out test set and report every compiled metric by name.
# NOTE(review): the tf.data pipeline below z-score normalises its test split (named
# `test`), whereas `test_images` here is the raw array loaded by the most recent
# cell -- confirm it is preprocessed the same way as the training data.
metric_values = model.evaluate(x=test_images, y=test_labels)

print('Final TEST performance')
for metric_name, metric_value in zip(model.metrics_names, metric_values):
  line = '{}: {}'.format(metric_name, metric_value)
  print(line)

### Classifying examples
We now use our trained model to classify a sample of 25 images from the test set. We pass these 25 images to the ```model.predict``` function, which returns a [25, 10] dimensional matrix of scores. Because the model's last layer has no softmax (it is trained with ```from_logits=True```), the entry at position $(i, j)$ is a logit; applying a softmax to each row turns it into the probability that image $i$ belongs to class $j$. We obtain the most-likely prediction using the ```np.argmax``` function with ```axis=1```, which returns, for each row (image), the column index of its largest entry. Finally, we plot the result with the prediction and prediction probability labelled underneath each image and the true label on the side. 

In [0]:
# Draw 25 distinct test images at random and look up their true class names.
img_indices = np.random.choice(len(test_images), size=25, replace=False)
sample_test_images = test_images[img_indices]
# test_labels may be (N, 1) (as loaded by keras) or (N,); flatten to 1D class ids
sample_test_labels = [cifar_labels[i] for i in np.asarray(test_labels)[img_indices].reshape(-1)]

# Both models end in a Dense layer without softmax and are trained with
# from_logits=True, so predict() returns logits. Convert them to probabilities
# with a (numerically stable) softmax over the class axis.
logits = np.asarray(model.predict(sample_test_images), dtype=np.float32)
shifted = logits - logits.max(axis=1, keepdims=True)
exp_shifted = np.exp(shifted)
predictions = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)
max_prediction = np.argmax(predictions, axis=1)
prediction_probs = np.max(predictions, axis=1)

In [0]:
# Show each sampled test image with the predicted class and its probability
# underneath and the true class on the left.
plt.figure(figsize=(10,10))
for i, (img, prediction, prob, true_label) in enumerate(
    zip(sample_test_images, max_prediction, prediction_probs, sample_test_labels)):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  # matplotlib expects a boolean here; the string 'off' is not treated as False
  # by recent versions
  plt.grid(False)

  plt.imshow(img)
  plt.xlabel('{} ({:0.3f})'.format(cifar_labels[prediction], prob))
  plt.ylabel('{}'.format(true_label))


## Your Tasks
1. [**ALL**] Experiment with the network architecture: try changing the numbers, types and sizes of layers, the sizes of filters, using different padding etc. How do these decisions affect the performance of the model? In particular, try building a *fully convolutional* network, with no (max-)pooling layers. 
2. [**ALL**] Add BATCH NORMALISATION ([Tensorflow documentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers/BatchNormalization) and [research paper](http://proceedings.mlr.press/v37/ioffe15.pdf)) to improve the model's generalisation.
3. [**ADVANCED**] Read about Residual networks ([original paper](https://arxiv.org/pdf/1512.03385.pdf)) and add **shortcut connections** to the model architecture. Try to build a simple reusable "residual block" as a [Keras Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model). 
4. [**OPTIONAL**] Visualise the filters of the convolutional layers using Matplotlib. **HINT**: You can retrieve a reference to an individual layer from the sequential Keras model by calling ```model.get_layer(name)```, replacing "name" with the name of the layer. 

# ResNet blocks

In [0]:
# may need to update tensorflow_addons if using google collab
!pip install --upgrade tensorflow_addons
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import tensorflow_addons as tfa
%matplotlib inline

# Train in mixed precision (float16 computation, float32 variables). The
# `experimental` mixed-precision module was deprecated in TF 2.4 and removed in
# later releases, so use the stable API when it exists and fall back otherwise.
try:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    policy = tf.keras.mixed_precision.global_policy()
except AttributeError:
    # TF < 2.4: only the experimental API is available
    from tensorflow.keras.mixed_precision import experimental as mixed_precision
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_policy(policy)


print("TensorFlow executing eagerly: {}".format(tf.executing_eagerly()))

In [0]:
class Conv2DBatchNormReLU(tf.keras.layers.Layer):
    """
    Creates a layer of the form:

    Convolutional layer 2D -> Batch Norm -> ReLU

    The conv layer uses He normal weight initialization, no bias and 'SAME'
    padding (equivalent to padding=1 for the default 3x3 kernel).
    Batch Norm uses epsilon=1e-5, momentum=0.9.

    Arguments:
        filters (int): number of filters used in the convolution; gives the number of channels of the output tensor
        kernel_size (int or tuple of 2 integers): window size used in the convolution
        input_shape (tuple of 3 integers): only needed if this is the first layer in the network. Use (rows, cols, channels).
            Passed through **kwargs to tf.keras.layers.Layer.

    Output shape:
        (batch_size, rows, cols, filters)

    Returns:
        A tensor of rank 4
    """

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        # stored so get_config() can re-create the layer
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv2d = tf.keras.layers.Conv2D(filters=filters,
                                             kernel_size=kernel_size,
                                             padding='SAME',
                                             use_bias=False,
                                             kernel_initializer=tf.keras.initializers.he_normal())
        self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9)
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, training=False):
        x = self.conv2d(inputs)
        # From documentation: training=False (in call - not in layer!)
        # layer will normalize its inputs using the mean and variance of its
        # moving statistics, learned during training
        x = self.batch_norm(x, training=training)
        x = self.relu(x)
        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from its config (e.g. after model save/load)."""
        config = super().get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config

class ResNet9Block(tf.keras.layers.Layer):
    """
    Creates a Residual Block of the type used in the network in
    https://github.com/davidcpage/cifar10-fast (see demo.ipynb, which details the structure). Structure:

    Let CBR = Convolutional layer 2D -> Batch Norm -> ReLU
    Let x = CBR -> Pooling Layer (factor of 2)

    Then the structure is (x -> CBR -> CBR) + x

    Arguments:
        filters (int): number of filters used in all convolutions, and gives num channels of output tensor
        kernel_size (int or tuple of 2 integers): window size used in all convolutions
        input_shape (tuple of 3 integers): only needed if this is the first layer in the network. Use (rows, cols, channels)

    Output shape:
        (batch_size, rows/2, cols/2, filters)

    Returns:
        A tensor of rank 4
    """

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        # stored so get_config() can re-create the layer
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv_bn_relu_1 = Conv2DBatchNormReLU(filters=filters, kernel_size=kernel_size)
        self.conv_bn_relu_2 = Conv2DBatchNormReLU(filters=filters, kernel_size=kernel_size)
        self.conv_bn_relu_3 = Conv2DBatchNormReLU(filters=filters, kernel_size=kernel_size)
        self.max_pool2d = tf.keras.layers.MaxPool2D()

    def call(self, inputs, training=None):
        """
        Residual behaviour implemented here.

        `training` is forwarded explicitly to each CBR sub-layer so their batch
        norms switch between batch statistics (fit) and moving statistics
        (inference). Without it the sub-layers depend on Keras inferring the mode
        from the outer call context, and otherwise use their own default of
        training=False.
        """
        x = self.conv_bn_relu_1(inputs, training=training)
        x = self.max_pool2d(x)
        y = self.conv_bn_relu_2(x, training=training)
        y = self.conv_bn_relu_3(y, training=training)
        return x + y

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from its config."""
        config = super().get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config

class ScalarMultiply(tf.keras.layers.Layer):
    """
    Layer that multiplies its input by a fixed scalar.

    Arguments:
        scalar (float): scalar to multiply by.

    Output shape:
        same as the input shape

    Returns:
        A tensor with the same shape and dtype as the input
    """

    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)
        # Keep a plain Python number and cast it to the input dtype in call():
        # under the 'mixed_float16' policy the inputs are float16, and a float32
        # tf.constant cannot be converted to float16 by tf.math.scalar_mul.
        self.scalar = scalar

    def call(self, inputs):
        scalar = tf.cast(self.scalar, inputs.dtype)
        return tf.math.scalar_mul(scalar, inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from its config."""
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config

class ResNet9BlockShort(tf.keras.layers.Layer):
    """
    Creates a Residual Block of the type used in the network in
    https://github.com/davidcpage/cifar10-fast (see demo.ipynb, which details the structure).

    Let CRB = Convolutional layer 2D with ReLU activation -> Batch Norm
    (note: unlike cifar10-fast's Conv -> BN -> ReLU, the ReLU here is applied
    by the Conv2D itself, i.e. before the batch norm).

    Then the structure is (x -> CRB -> CRB) + x

    Both convolutions use 'SAME' padding, no bias, He uniform initialization and
    L2(1e-3) kernel regularization. Because the input is added to the output,
    the input must already have `filters` channels.

    Arguments:
        filters (int): number of filters used in all convolutions, and gives num channels of output tensor
        kernel_size (int or tuple of 2 integers): window size used in all convolutions
        input_shape (tuple of 3 integers): only needed if this is the first layer in the network. Use (rows, cols, channels)

    Output shape:
        (batch_size, rows, cols, filters)

    Returns:
        A tensor of rank 4
    """

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        # stored so get_config() can re-create the layer
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv2d_1 = tf.keras.layers.Conv2D(filters=filters,
                                               kernel_size=kernel_size,
                                               padding='SAME',
                                               use_bias=False,
                                               activation='relu',
                                               kernel_initializer=tf.keras.initializers.he_uniform(),
                                               kernel_regularizer=tf.keras.regularizers.l2(1e-3))
        self.conv2d_2 = tf.keras.layers.Conv2D(filters=filters,
                                               kernel_size=kernel_size,
                                               padding='SAME',
                                               activation='relu',
                                               use_bias=False,
                                               kernel_initializer=tf.keras.initializers.he_uniform(),
                                               kernel_regularizer=tf.keras.regularizers.l2(1e-3))
        self.batch_norm_1 = tf.keras.layers.BatchNormalization(epsilon=1.0e-5, momentum=0.9)
        self.batch_norm_2 = tf.keras.layers.BatchNormalization(epsilon=1.0e-5, momentum=0.9)

    def call(self, inputs, training=None):
        """
        Residual behaviour implemented here.

        `training` is forwarded explicitly to the batch norms so they use batch
        statistics during fit() and moving statistics at inference.
        """
        x = self.conv2d_1(inputs)
        x = self.batch_norm_1(x, training=training)
        x = self.conv2d_2(x)
        x = self.batch_norm_2(x, training=training)
        return x + inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from its config."""
        config = super().get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size})
        return config

# Augment

In [0]:
class NormalizeZscore:
    """
    Normalize numpy array of images by z-score, independently for each channel.
    Note: converts to float32 and not the default of float64. Statistics are
    accumulated in float64 for accuracy, then stored as float32.

    Attributes:
        self.means (numpy array) = per-channel means, calculated across the fitted dataset (float32)
        self.stds (numpy array) = per-channel stds, calculated across the fitted dataset (float32)

    Methods:
        self.fit = calculate means and stds of train dataset and store as class attribute
        self.transform = normalize dataset by z-score
    """
    def __init__(self):
        self.means = None
        self.stds = None

    def fit(self, train_dataset):
        """
        Compute per-channel mean and std over axes (0, 1, 2), i.e. over all
        images and pixel positions of a channels-last image array.

        Returns self so calls can be chained (callers that ignore the result are unaffected).
        """
        self.means = np.mean(train_dataset, axis=(0, 1, 2), dtype=np.float64).astype(np.float32)
        self.stds = np.std(train_dataset, axis=(0, 1, 2), dtype=np.float64).astype(np.float32)
        return self

    def transform(self, dataset):
        """
        Return (dataset - means) / stds as a new float32 array.

        Raises:
            RuntimeError: if fit() has not been called yet.
        """
        if self.means is None or self.stds is None:
            raise RuntimeError('NormalizeZscore.transform called before fit')
        # cast first so uint8 / float64 inputs do not promote the result to float64
        dataset = np.asarray(dataset, dtype=np.float32)
        return (dataset - self.means) / self.stds

def flip(images):
    """
    Randomly mirror images left-right.

    Accepts a single 3D image (height, width, channels) or a 4D batch
    (batch_size, height, width, channels).
    """
    return tf.image.random_flip_left_right(images)

# TODO investigate speed of just using tf function for shifts, padding often expensive
def pad_and_crop(images, pad=4):
    """
    Randomly translate each image of a batch by reflect-padding and re-cropping.

    Every image is padded by `pad` pixels on each side (REFLECT mode), then a
    crop of the original height/width is taken at an independent random offset
    per image, i.e. a random shift of up to +-`pad` pixels.

    Args:
        images: 4D tensor (batch_size, height, width, channels).
        pad (int): pixels of padding on each side (default 4).

    Returns:
        Tensor with the same shape and dtype as `images`.
    """
    image_shape = tf.shape(images)[1:]
    paddings = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
    padded = tf.pad(images, paddings=paddings, mode="REFLECT")
    # tf.image.random_crop applied to the whole 4D batch draws ONE offset that is
    # shared by every image; map over the batch so each image gets its own shift.
    cropped = tf.map_fn(lambda img: tf.image.random_crop(img, size=image_shape), padded)
    # restore the static shape (e.g. (None, 32, 32, 3)) lost by the dynamic crop size
    cropped.set_shape(images.shape)
    return cropped
  
def cutout(images, cutout_dim=(8, 8)):
    '''
    Apply cutout with tfa.image.random_cutout (mask of size cutout_dim per image).

    Takes a 4D tensor (batch_size, height, width, channels), channels last.
    Needs package tensorflow_addons.

    Args:
        images: 4D image tensor, channels last.
        cutout_dim (tuple or list of 2 integers): (height, width) of the mask;
            tfa documents that the mask size should be divisible by 2.
    '''
    # tuple default instead of a (shared, mutable) list default
    return tfa.image.random_cutout(images, mask_size=tf.constant(cutout_dim))

@tf.function # TODO check if decorating here improves performance
def augment(images, labels):
    """
    Augment one batch; labels are returned unchanged.

    Steps, in order: random shift (pad_and_crop: reflect-pad by 4 pixels, then
    random crop back to the original size), random left-right flip, then 8x8
    cutout.

    Expects 4D images (batch_size, height, width, channels), so map it over a
    dataset that has already been batched, e.g.

    dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(batch_size).map(augment)
    """
    shifted = pad_and_crop(images)
    mirrored = flip(shifted)
    erased = cutout(mirrored)
    return erased, labels

# Check that the augmentation works

In [0]:
# modified from https://www.wouterbulten.nl/blog/tech/data-augmentation-using-tensorflow-data-dataset/
def plot_augmented_images(dataset, n_images, samples_per_image):
    """
    Plot repeated draws from an augmented dataset as a grid of images.

    Each row is one image; each column is one pass over the dataset, i.e. one
    fresh random augmentation. The dataset must yield (images, labels) where
    `images` is a single batch of n_images images, so it must already be
    batched, e.g.

        dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(n_images).map(augment)

    If the dataset yields several batches, only the first samples_per_image
    batches are drawn (instead of overflowing the grid).

    Args:
        dataset: batched tf.data.Dataset of (images, labels).
        n_images (int): number of images (rows) to show; extra images in a batch are ignored.
        samples_per_image (int): number of augmented versions (columns) per image.
    """
    output = None
    batches = dataset.repeat(samples_per_image).take(samples_per_image)
    for column, (images, _labels) in enumerate(batches):
        batch = images.numpy()[:n_images]
        n, height, width, channels = batch.shape
        if output is None:
            # size the grid from the data instead of assuming 32x32 images
            output = np.zeros((height * n_images, width * samples_per_image, channels))
        # stack the batch vertically and write it into this pass's column
        output[:n * height, column * width:(column + 1) * width, :] = np.vstack(batch)

    if output is None:
        return  # empty dataset: nothing to plot

    plt.figure(figsize=(10,10))
    plt.imshow(output)
    plt.show()

# Load CIFAR-10 so the augmentation pipeline can be checked visually.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# How many training images to show (one per row of the grid).
n_images = 8

# Images load as uint8 in [0, 255]; scale them to float32 in [0, 1] so they can
# be augmented and displayed with imshow.
train_sample = (train_images[:n_images] / 255).astype(np.float32)
train_sample_labels = train_labels[:n_images]

# Wrap the sample in tf.data and augment it as ONE batch: augment() (pad + crop,
# flip, cutout) expects 4D batched images.
train_sample_dataset = tf.data.Dataset.from_tensor_slices((train_sample, train_sample_labels))
augmented_train_sample_dataset = train_sample_dataset.batch(n_images).map(
    augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Every repeat re-runs the random augmentation, giving one new column per sample.
plot_augmented_images(augmented_train_sample_dataset, n_images=n_images, samples_per_image=10)

## Train

In [0]:
# Build the model (structure, sizes of filters, convolutional kernels, pooling sizes can be found at https://github.com/davidcpage/cifar10-fast/blob/master/demo.ipynb)

# The reference implementation scales the logits by 0.125; this function is used
# as the activation of the output layer and is wrapped in tf.function for speed.
@tf.function
def scalar_multiply(x):
    """Return x scaled by 0.125 (the cifar10-fast logit scaling)."""
    logit_scale = 0.125
    return logit_scale*x

# The output layer is pinned to float32: under the global 'mixed_float16' policy
# its logits (and so the softmax cross-entropy) would otherwise be float16, which
# the TensorFlow mixed-precision guide advises against for numerical stability.
# Its activation multiplies the logits by 0.125, as in the reference implementation.
model = tf.keras.models.Sequential([
    Conv2DBatchNormReLU(filters=64, input_shape=(32,32,3)),
    ResNet9Block(filters=128),
    Conv2DBatchNormReLU(filters=256),
    tf.keras.layers.MaxPool2D(2),
    ResNet9Block(filters=512),
    tf.keras.layers.MaxPool2D(4),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, use_bias=False, activation=scalar_multiply, dtype='float32'),
])

model.summary()

batch_size = 512
num_epochs = 24
warmup_epochs = 5
max_learning_rate = 0.4

# learning rate initially linearly increases (from 0.08), then decreases linearly towards 0
learning_rates = list(np.linspace(0.08, max_learning_rate, warmup_epochs-1, endpoint=False))
learning_rates += list(np.linspace(max_learning_rate, 0.0, num_epochs-warmup_epochs+1, endpoint=False))
# LearningRateScheduler indexes this list by epoch, so it needs exactly one entry per epoch
assert len(learning_rates) == num_epochs, (len(learning_rates), num_epochs)

lr_schedule = tf.keras.callbacks.LearningRateScheduler(lambda epoch: learning_rates[epoch], verbose=1)

# Needs package tensorflow_addons for SGDW
# Compiling the model adds a loss function, optimiser and metrics to track during training
# (the learning_rate given here is replaced each epoch by lr_schedule)
model.compile(tfa.optimizers.SGDW(learning_rate=1e-3, momentum=0.9, nesterov=True, weight_decay=5e-4),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# get data
(train, train_labels), (test, test_labels) = tf.keras.datasets.cifar10.load_data()
cifar_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Take the last 10000 images from the training set to form a validation set
num_validation = 10000
train_labels = train_labels.squeeze()
validation = train[-num_validation:, :, :]
validation_labels = train_labels[-num_validation:]
train = train[:-num_validation, :, :]
train_labels = train_labels[:-num_validation]

# normalize by z-score, fitting parameters from train only; make sure the arrays
# are float32 (the z-score arithmetic can promote uint8 inputs to float64)
zscore = NormalizeZscore()
zscore.fit(train)
train = zscore.transform(train).astype(np.float32, copy=False)
validation = zscore.transform(validation).astype(np.float32, copy=False)
test = zscore.transform(test).astype(np.float32, copy=False)

# create datasets
train_dataset = tf.data.Dataset.from_tensor_slices((train, train_labels))
validation_dataset = tf.data.Dataset.from_tensor_slices((validation, validation_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test, test_labels))

# augment
train_dataset = (train_dataset
                 .cache()
                 # shuffle BEFORE repeat so every epoch is a full permutation of the
                 # training set (shuffling after repeat mixes examples across epochs)
                 .shuffle(len(train), reshuffle_each_iteration=True)
                 .repeat()  # keep augmenting on the fly
                 .batch(batch_size)
                 .map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                 .prefetch(tf.data.experimental.AUTOTUNE)
                 )

# no augmentation on validation and test
validation_dataset = (validation_dataset
                      .cache()
                      .batch(batch_size)
                      .prefetch(tf.data.experimental.AUTOTUNE)
                      )
test_dataset = (test_dataset
                .cache()
                .batch(batch_size)
                .prefetch(tf.data.experimental.AUTOTUNE)
                )

# train; the training dataset repeats forever, so an epoch is defined by steps_per_epoch
steps_per_epoch = len(train) // batch_size
model_fit = model.fit(train_dataset,
                      epochs=num_epochs,
                      steps_per_epoch=steps_per_epoch,
                      callbacks=[lr_schedule],
                      validation_data=validation_dataset
)

print('Training complete')

def plot_loss(model_fit):
    """
    Plot a Keras training history in three panels: loss, log(loss) and accuracy.

    Each panel shows the training curve (blue) and, when model.fit was given
    validation data, the validation curve (green).

    Args:
        model_fit: History object returned by model.fit. Its .history must contain
            'loss' and 'accuracy', and may contain 'val_loss' and 'val_accuracy'.
    """
    history = model_fit.history
    figure, axes = plt.subplots(1, 3, figsize=(30,10))
    panels = [
        (axes[0], 'loss', 'loss', np.asarray),
        (axes[1], 'loss', 'Log loss', np.log),
        (axes[2], 'accuracy', 'accuracy', np.asarray),
    ]
    for ax, metric, ylabel, transform in panels:
        ax.plot(model_fit.epoch, transform(history[metric]), color='b', label='Train')
        # validation keys only exist when validation data was passed to fit()
        if 'val_' + metric in history:
            ax.plot(model_fit.epoch, transform(history['val_' + metric]), color='g', label='Validation')
        ax.set(xlabel='Epoch', ylabel=ylabel)
        ax.legend()
    plt.show()

plot_loss(model_fit)

## Using fewer custom layer groupings – this runs at the same speed as my custom-layer implementation, which suggests the custom layers add no overhead

In [0]:
cifar = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar.load_data()
cifar_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Take the last 10000 images from the training set to form a validation set
train_labels = train_labels.squeeze()
validation_images = train_images[-10000:, :, :]
validation_labels = train_labels[-10000:]
train_images = train_images[:-10000, :, :]
train_labels = train_labels[:-10000]


def _conv_relu_bn(filters, **conv_kwargs):
    """Return [Conv2D(3x3, SAME, ReLU, no bias, He uniform, L2 1e-3), BatchNormalization] for `filters` channels."""
    return [
        tf.keras.layers.Conv2D(filters=filters,
                               kernel_size=3,
                               padding='SAME',
                               activation='relu',
                               use_bias=False,
                               kernel_initializer=tf.keras.initializers.he_uniform(),
                               kernel_regularizer=tf.keras.regularizers.l2(1e-3),
                               **conv_kwargs),
        tf.keras.layers.BatchNormalization(epsilon=1.0e-5, momentum=0.9),
    ]


# Build the model (structure, sizes of filters, convolutional kernels, pooling sizes can be found at https://github.com/davidcpage/cifar10-fast/blob/master/demo.ipynb)
# The output layer is pinned to float32 so the logits are not computed in float16
# when the global 'mixed_float16' policy is active.
# Their implementation also multiplies the logits by 0.125 (not done here).
model = tf.keras.models.Sequential([
        *_conv_relu_bn(64, input_shape=(32,32,3)),
        *_conv_relu_bn(128),
        tf.keras.layers.MaxPool2D(2),
        ResNet9BlockShort(filters=128),
        *_conv_relu_bn(256),
        tf.keras.layers.MaxPool2D(2),
        *_conv_relu_bn(512),
        tf.keras.layers.MaxPool2D(2),
        ResNet9BlockShort(filters=512),
        tf.keras.layers.MaxPool2D(4),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, use_bias=False, dtype='float32'),
])

model.summary()

# TODO visualize (tried using functional API, but some more needed to get at residual blocks)
# tf.keras.utils.plot_model(
#     model, to_file='model.png', show_shapes=True, show_layer_names=True,
#     rankdir='TB', expand_nested=True
# )

batch_size = 512
num_epochs = 35

# Parameters of the learning-rate schedule built further down.
amplitude = 0.001        # peak learning rate, chosen by an initial search (see the commented-out LR search below; default 0.001)
offset = amplitude / 10  # floor of the schedule; 1/10 or 1/20 of the amplitude is recommended
warmup_epochs = 5        # epochs over which the learning rate ramps up
last_epochs = 10         # trailing epochs that use a very small learning rate

def triangle_list(amplitude, epochs, offset):
    """
    Triangular learning-rate schedule with `epochs` entries.

    Starts at `offset` (epoch 0), rises linearly to `amplitude + offset` at the
    midpoint, then falls linearly back towards `offset`.
    """
    rates = []
    for epoch in range(epochs):
        distance_from_peak = abs(epoch*2*amplitude/epochs - amplitude)
        rates.append(-distance_from_peak + amplitude + offset)
    return rates

# first train with triangular learning rate, then small learning rate for last few epochs
#learning_rates = triangle_list(amplitude, triangle_epochs, offset) + list(np.linspace(offset, offset/100, num_epochs-triangle_epochs))
#learning_rates = list(np.linspace(amplitude, offset, triangle_epochs)) + list(np.linspace(offset, offset/10, num_epochs-triangle_epochs))
#learning_rates = [1e0]*3 + [1e-1]*7 + [1e-2]*10

# from their github
#learning_rates = list(np.linspace(0.08, 0.4, triangle_epochs, endpoint=False)) + list(np.linspace(0.4, 0.0, num_epochs-triangle_epochs+1, endpoint=False))

# my version: piecewise-linear schedule with exactly one value per epoch
#   warmup: offset -> amplitude       (warmup_epochs epochs, endpoint excluded)
#   decay:  amplitude -> offset/2     (decay_epochs epochs)
#   tail:   offset/2 -> offset/100    (last_epochs epochs)
#   final:  1e-5 -> 1e-6              (final_epochs epochs)
# The final segment previously used last_epochs (10) points although only 5
# epochs were reserved for it (the "-5" in the decay length), which produced 40
# rates for 35 epochs; the last 5 were silently never used.
final_epochs = 5
decay_epochs = num_epochs - warmup_epochs - last_epochs - final_epochs
learning_rates = list(np.linspace(offset, amplitude, warmup_epochs, endpoint=False))
learning_rates += list(np.linspace(amplitude, offset/2, decay_epochs, endpoint=True))
learning_rates += list(np.linspace(offset/2, offset/100, last_epochs, endpoint=True))
learning_rates += list(np.linspace(1e-5, 1e-6, final_epochs, endpoint=True))
# LearningRateScheduler indexes this list by epoch, so it needs exactly one entry per epoch
assert len(learning_rates) == num_epochs, (len(learning_rates), num_epochs)


#learning_rates = [i*10**j for j in range(-5,-1) for i in range(1,10,2)] # 20 epochs, search for max learning rate initially
lr_schedule = tf.keras.callbacks.LearningRateScheduler(lambda epoch: learning_rates[epoch], verbose=1)

# Attach the loss, optimiser and tracked metrics to the model. Adam's default
# learning rate is overridden each epoch by lr_schedule during fit().
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

# cutout from https://github.com/yu4u/cutout-random-erasing
def get_random_eraser(p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255, pixel_level=False):
    """
    Build a Random Erasing / cutout function for single images.

    Args:
        p: the probability that random erasing is performed
        s_l, s_h: minimum / maximum proportion of erased area against input image
        r_1, r_2: minimum / maximum aspect ratio of erased area
        v_l, v_h: minimum / maximum value for erased area
        pixel_level: pixel-level randomization for erased area

    Returns:
        eraser(input_img): takes a (height, width, channels) array and returns an
        array of the same shape with one random rectangle filled with random
        values. The input array is never modified; when no erasing happens it is
        returned as is.
    """
    # Bound on rectangle draws so that impossible settings (e.g. s_l > 1, where
    # the rectangle can never fit) give up instead of looping forever.
    max_attempts = 1000

    def eraser(input_img):
        img_h, img_w, img_c = input_img.shape
        p_1 = np.random.rand()

        if p_1 > p:
            return input_img

        # Rejection-sample a rectangle that fits inside the image.
        for _ in range(max_attempts):
            s = np.random.uniform(s_l, s_h) * img_h * img_w
            r = np.random.uniform(r_1, r_2)
            w = int(np.sqrt(s / r))
            h = int(np.sqrt(s * r))
            left = np.random.randint(0, img_w)
            top = np.random.randint(0, img_h)

            if left + w <= img_w and top + h <= img_h:
                break
        else:
            return input_img

        if pixel_level:
            c = np.random.uniform(v_l, v_h, (h, w, img_c))
        else:
            c = np.random.uniform(v_l, v_h)

        # work on a copy so the caller's array (e.g. the training set) is untouched
        erased = input_img.copy()
        erased[top:top + h, left:left + w, :] = c

        return erased

    return eraser

# To always apply cutout with pixel value 0, size 8x8 (8^2/32^2 = 0.0625), as in their implementation:
#eraser = get_random_eraser(p=1.0, s_l=0.0625, s_h=.0625, r_1=1.0, r_2=1.0, v_l=0, v_h=0)
# using some erasing defaults (leads to random choices of width and height and size)
eraser = get_random_eraser(p=0.8, s_h=0.0625)

# Data preprocessing (z-score normalize) and augmentation (LR flip, +-4 pixel shifts, cutout).
# NOTE(review): the eraser fills with values in 0..255, which assumes Keras runs
# preprocessing_function on raw pixels before featurewise standardisation -- confirm
# for the installed TF version.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    horizontal_flip=True,
    # An int shift range is a number of pixels drawn from (-range, +range), so 4
    # gives the intended +-4 pixel translation (8 meant +-8 pixels).
    height_shift_range=4,
    width_shift_range=4,
    # fill uncovered pixels by reflection, matching the REFLECT padding in pad_and_crop
    fill_mode='reflect',
    preprocessing_function=eraser # cutout
    )


test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True)

# compute quantities required for normalization based on train
train_datagen.fit(train_images)
test_datagen.fit(train_images)

# train
model_fit = model.fit(train_datagen.flow(train_images, train_labels, batch_size=batch_size),
                      epochs=num_epochs,
                      callbacks=[lr_schedule],
                      # evaluate validation data in large, unshuffled batches
                      # (flow() otherwise defaults to batch_size=32 and shuffle=True)
                      validation_data=test_datagen.flow(validation_images, validation_labels,
                                                        batch_size=batch_size, shuffle=False)
)

print('Training complete')

def plot_loss(model_fit):
    """
    Plot a Keras training history in three panels: loss, log(loss) and accuracy.

    Each panel shows the training curve (blue) and, when model.fit was given
    validation data, the validation curve (green).

    Args:
        model_fit: History object returned by model.fit. Its .history must contain
            'loss' and 'accuracy', and may contain 'val_loss' and 'val_accuracy'.
    """
    history = model_fit.history
    figure, axes = plt.subplots(1, 3, figsize=(30,10))
    panels = [
        (axes[0], 'loss', 'loss', np.asarray),
        (axes[1], 'loss', 'Log loss', np.log),
        (axes[2], 'accuracy', 'accuracy', np.asarray),
    ]
    for ax, metric, ylabel, transform in panels:
        ax.plot(model_fit.epoch, transform(history[metric]), color='b', label='Train')
        # validation keys only exist when validation data was passed to fit()
        if 'val_' + metric in history:
            ax.plot(model_fit.epoch, transform(history['val_' + metric]), color='g', label='Validation')
        ax.set(xlabel='Epoch', ylabel=ylabel)
        ax.legend()
    plt.show()

plot_loss(model_fit)

# Timing their implementation

In [0]:
####################
## CORE
#####################

import inspect
from collections import namedtuple, defaultdict
from functools import partial
import functools
from itertools import chain, count, islice as take

#####################
## dict utils
#####################

def union(*dicts):
    """Merge dicts left to right; on duplicate keys the later dict's value wins."""
    merged = {}
    for mapping in dicts:
        for key, value in mapping.items():
            merged[key] = value
    return merged

def make_tuple(path):
    """Wrap a single string path component in a 1-tuple; return any other path unchanged."""
    if isinstance(path, str):
        return (path,)
    return path

def path_iter(nested_dict, pfx=()):
    """
    Depth-first iterator over the leaves of a nested dict.

    Yields (path, value) pairs, where path is the tuple of keys from the root
    (prepended with pfx). String keys contribute one component; tuple keys are
    spliced in as several components.
    """
    for key, value in nested_dict.items():
        key_path = pfx + ((key,) if isinstance(key, str) else key)
        if not isinstance(value, dict):
            yield (key_path, value)
            continue
        yield from path_iter(value, key_path)
            
def map_values(func, dct):
    """Return a new dict with func applied to every value of dct (keys unchanged)."""
    result = {}
    for key, value in dct.items():
        result[key] = func(value)
    return result

def map_nested(func, nested_dict):
    """Apply func to every leaf of a nested dict, returning a new dict of the same nesting."""
    result = {}
    for key, value in nested_dict.items():
        if isinstance(value, dict):
            result[key] = map_nested(func, value)
        else:
            result[key] = func(value)
    return result

def group_by_key(seq):
    """Group an iterable of (key, value) pairs into a defaultdict(list), keeping value order."""
    groups = defaultdict(list)
    for pair in seq:
        key, value = pair
        groups[key] += [value]
    return groups

def reorder(dct, keys):
    """Return a new dict containing only `keys`, in the given order (KeyError if one is missing)."""
    return dict((key, dct[key]) for key in keys)

#####################
## graph building
#####################

def identity(value):
    """Return `value` unchanged (Network uses it as the default loss)."""
    return value

def build_graph(net, path_map='_'.join):
    """
    Flatten a nested network description into a graph dict.

    Each leaf of `net` (as produced by path_iter) is either (type, value) or
    (type, value, inputs). The result maps path_map(path) to
    (type, value, [input names]):

    * when inputs is omitted (or None), the single input is the previous node in
      iteration order; the first node's input is 'input';
    * explicit input names are resolved relative to the node's enclosing scope,
      moving outwards one level at a time until prefix + name is a known path
      (at the root the name is used as given).

    Args:
        net: nested dict describing the network.
        path_map: function turning a path tuple into a flat node name
            (default: join with '_').
    """
    # `len(node) is 3` compared identity rather than value (SyntaxWarning on 3.8+)
    net = {path: node if len(node) == 3 else (*node, None) for path, node in path_iter(net)}
    default_inputs = chain([('input',)], net.keys())

    def resolve_path(path, pfx):
        # Iterative form of the original recursive lambda, which passed an extra
        # `net` argument to itself and so raised TypeError whenever a name was not
        # found in the innermost scope.
        while pfx and pfx + path not in net:
            pfx = pfx[:-1]
        return pfx + path

    return {path_map(path): (typ, value, ([path_map(default)] if inputs is None
                                          else [path_map(resolve_path(make_tuple(k), path[:-1])) for k in inputs]))
            for (path, (typ, value, inputs)), default in zip(net.items(), default_inputs)}

#####################
## network visualisation (requires pydot)
#####################
import IPython.display

class ColorMap(dict):
    """
    dict that gives each new key the next colour (hex string without '#') from a
    fixed 30-colour palette, cycling when the palette is exhausted, so a key keeps
    the same colour once assigned.
    """
    palette = (
        'bebada,ffffb3,fb8072,8dd3c7,80b1d3,fdb462,b3de69,fccde5,bc80bd,ccebc5,ffed6f,1f78b4,33a02c,e31a1c,ff7f00,'
        '4dddf8,e66493,b07b87,4e90e3,dea05e,d0c281,f0e189,e9e8b1,e0eb71,bbd2a4,6ed641,57eb9c,3ca4d4,92d5e7,b15928'
    ).split(',')

    def __missing__(self, key):
        colour = self.palette[len(self) % len(self.palette)]
        self[key] = colour
        return colour

def make_pydot(nodes, edges, direction='LR', sep='_', **kwargs):
    """
    Build a pydot.Dot graph from (name, attrs) nodes and (src, dst, attrs) edges.

    A node name split on `sep` gives its nesting: every leading component becomes
    a rounded, filled cluster and the last component is the node label.
    Extra kwargs are passed to pydot.Dot.
    """
    from pydot import Dot, Cluster, Node, Edge

    class Subgraphs(dict):
        # Requesting an unknown path creates its Cluster and attaches it to the
        # (recursively looked-up) parent; the new cluster is not cached here.
        def __missing__(self, path):
            *parent, label = path
            cluster = Cluster(sep.join(path), label=label, style='rounded, filled', fillcolor='#77777744')
            self[tuple(parent)].add_subgraph(cluster)
            return cluster

    graph = Dot(rankdir=direction, directed=True, **kwargs)
    graph.set_node_defaults(shape='box', style='rounded, filled', fillcolor='#ffffff')
    clusters = Subgraphs({(): graph})
    for name, attrs in nodes:
        *parent, label = name.split(sep)
        clusters[tuple(parent)].add_node(Node(name=name, label=label, **attrs))
    for src, dst, attrs in edges:
        graph.add_edge(Edge(src, dst, **attrs))
    return graph

class DotGraph():
    """
    Renderable wrapper around a graph dict {name: (type, value, inputs)}.

    When pydot can be imported it renders as SVG in Jupyter via _repr_svg_;
    otherwise its repr says that pydot is needed.
    """
    # Shared by all instances so a node type keeps the same colour across graphs.
    colors = ColorMap()

    def __init__(self, graph, size=15, direction='LR'):
        self.nodes = []
        for name, (typ, value, inputs) in graph.items():
            attributes = {
                'tooltip': '%s %.1000r' % (typ, value),
                'fillcolor': '#' + self.colors[typ],
            }
            self.nodes.append((name, attributes))
        self.edges = [
            (source, name, {})
            for name, (_, _, inputs) in graph.items()
            for source in inputs
        ]
        self.size = size
        self.direction = direction

    def dot_graph(self, **kwargs):
        """Build the pydot graph for these nodes and edges."""
        return make_pydot(self.nodes, self.edges, size=self.size,
                          direction=self.direction, **kwargs)

    def svg(self, **kwargs):
        """Render the graph to an SVG string."""
        svg_bytes = self.dot_graph(**kwargs).create(format='svg')
        return svg_bytes.decode('utf-8')

    # Offer the rich SVG repr only when pydot is importable.
    try:
        import pydot
        def _repr_svg_(self):
            return self.svg()
    except ImportError:
        def __repr__(self):
            return 'pydot is needed for network visualisation'


#####################
## Layers
##################### 

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from collections import namedtuple
import copy

# Let cuDNN benchmark convolution algorithms and pick the fastest ones.
torch.backends.cudnn.benchmark = True
cpu = torch.device('cpu')
# First GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    
class Network(nn.Module):
    """torch module assembled from a nested network description.

    `build_graph(net)` supplies path -> (type, params, inputs). Each node is
    instantiated as `type(**params)` and registered as an attribute named by
    its path. `loss` (default `identity`) is kept for the training steps.
    """
    def __init__(self, net, loss=None):
        super().__init__()
        built = {}
        for path, (typ, params, inputs) in build_graph(net).items():
            built[path] = (typ, typ(**params), inputs)
        self.graph = built
        self.loss = loss or identity
        for path, (_, node, _) in self.graph.items():
            setattr(self, path, node)

    def nodes(self):
        """Generator over the instantiated node objects, in graph order."""
        return (entry[1] for entry in self.graph.values())

    def forward(self, inputs):
        """Evaluate every node in graph order.

        `inputs` seeds the table of named values (e.g. 'input', 'target'). Each
        node is called with the values named in its input list. The returned
        dict holds the seeds plus every node's output.
        """
        values = dict(inputs)
        for path, (_, node, arg_names) in self.graph.items():
            values[path] = node(*[values[name] for name in arg_names])
        return values

    def half(self):
        """Cast every nn.Module node except BatchNorm2d layers to float16; returns self."""
        for node in self.nodes():
            if not isinstance(node, nn.Module):
                continue
            if isinstance(node, nn.BatchNorm2d):
                continue  # keep batch-norm statistics in float32
            node.half()
        return self

def build_model(network, loss):
    """Instantiate `network` with `loss`, cast it to half precision (BatchNorm excepted) and move it to `device`."""
    return Network(network, loss).half().to(device)

def show(network, size=15):
    """Display a Network instance, or a raw network description, as a DotGraph."""
    graph = network.graph if isinstance(network, Network) else build_graph(network)
    return display(DotGraph(graph, size=size))
    
class Add(namedtuple('Add', [])):
    """Graph node returning the sum of its two inputs."""
    def __call__(self, x, y):
        total = x + y
        return total

class AddWeighted(namedtuple('AddWeighted', ['wx', 'wy'])):
    """Graph node returning wx * x + wy * y."""
    def __call__(self, x, y):
        weighted_x = self.wx * x
        weighted_y = self.wy * y
        return weighted_x + weighted_y

class Identity(namedtuple('Identity', [])):
    """Graph node that passes its single input through unchanged."""
    def __call__(self, x):
        return x

class BatchNorm(nn.BatchNorm2d):
    """BatchNorm2d with scale initialised to 1 and shift to 0, each optionally frozen.

    Args:
        weight: whether the scale parameter receives gradients.
        bias: whether the shift parameter receives gradients.
    """
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight=True, bias=True):
        super().__init__(num_features, eps=eps, momentum=momentum)
        for param, init_value, trainable in ((self.weight, 1.0, weight),
                                             (self.bias, 0.0, bias)):
            param.data.fill_(init_value)
            param.requires_grad = trainable

class GhostBatchNorm(BatchNorm):
    """Batch norm with statistics computed separately over `num_splits` ghost batches.

    In training mode the batch is reshaped so that each split gets its own
    batch statistics and its own slice of the running-stat buffers (hence
    buffers of length `num_features * num_splits`). On the switch to eval mode
    the per-split running stats are averaged, and eval-mode forward uses the
    first `num_features` entries.
    """
    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits
        # Replace the running-stat buffers with one slot per (split, channel).
        self.register_buffer('running_mean', torch.zeros(num_features*self.num_splits))
        self.register_buffer('running_var', torch.ones(num_features*self.num_splits))

    def train(self, mode=True):
        """Set train/eval mode; on a train -> eval switch, first average the per-split running stats."""
        if (self.training is True) and (mode is False): #lazily collate stats when we are going to use them
            # Mean over the split axis, then tiled back so every split slot holds the average.
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
        return super().train(mode)
        
    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # view(-1, C*num_splits, H, W) folds num_splits consecutive samples into
            # one row, so sample n lands in split n % num_splits. This requires N to be
            # a multiple of num_splits. weight/bias are tiled to match the widened
            # channel axis.
            return F.batch_norm(
                input.view(-1, C*self.num_splits, H, W), self.running_mean, self.running_var, 
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W) 
        else:
            # Eval: after train(False) every split slot holds the averaged stats,
            # so the first num_features entries suffice.
            return F.batch_norm(
                input, self.running_mean[:self.num_features], self.running_var[:self.num_features], 
                self.weight, self.bias, False, self.momentum, self.eps)
        
class Mul(nn.Module):
    """Multiply the input by a fixed factor.

    Args:
        weight: value the input is multiplied by (a scalar in this notebook).
            It is stored as given, as a plain attribute.
    """
    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    # Implement `forward` rather than overriding `__call__`: nn.Module.__call__
    # then runs hooks as usual, and `module.forward(x)` works too (previously it
    # fell through to the unimplemented base forward). `module(x)` is unchanged.
    def forward(self, x):
        return x * self.weight
    
class Flatten(nn.Module):
    """Flatten all non-batch dimensions: (N, d1, d2, ...) -> (N, d1*d2*...)."""
    def forward(self, x):
        # After global pooling the input is (N, C, 1, 1) and this yields (N, C),
        # exactly as the former view(N, C). Unlike that view, flatten also
        # handles larger spatial maps and non-contiguous inputs, and it keeps an
        # empty batch working.
        return torch.flatten(x, 1)

# Losses
class CrossEntropyLoss(namedtuple('CrossEntropyLoss', [])):
    """Per-example negative log-likelihood of `target` under `log_probs` (no reduction)."""
    def __call__(self, log_probs, target):
        return F.nll_loss(log_probs, target, reduction='none')
    
class KLLoss(namedtuple('KLLoss', [])):
    """Per-example negative mean of the log-probabilities over dim 1 (the label-smoothing term)."""
    def __call__(self, log_probs):
        return log_probs.mean(dim=1).neg()

class Correct(namedtuple('Correct', [])):
    """Per-example boolean: does the argmax over dim 1 equal `target`?"""
    def __call__(self, classifier, target):
        _, predicted = classifier.max(dim=1)
        return predicted == target

class LogSoftmax(namedtuple('LogSoftmax', ['dim'])):
    """log-softmax along `dim`."""
    def __call__(self, x):
        return F.log_softmax(x, self.dim, _stacklevel=5)

    
# node definitions   
import inspect
from inspect import signature

# A signature with no parameters. The module `inspect` is imported explicitly
# above: with only `from inspect import signature`, the name `inspect` would be
# unbound here (NameError) unless some other cell had imported it.
empty_signature = inspect.Signature()

class node_def(namedtuple('node_def', ['type'])):
    """Factory for (type, params) node specs.

    Calling it binds the arguments against `type`'s signature (raising TypeError
    on a mismatch). It returns the type plus a dict of only the explicitly
    passed arguments.
    """
    def __call__(self, *args, **kwargs):
        bound = signature(self.type).bind(*args, **kwargs)
        params = dict(bound.arguments)
        return self.type, params

# Shorthand node constructors: each call returns (type, params), with params
# bound against the type's constructor signature (see node_def).
conv = node_def(nn.Conv2d)
linear = node_def(nn.Linear)
batch_norm = node_def(BatchNorm)  # the project's BatchNorm (scale/shift freeze flags), not nn.BatchNorm2d
pool = node_def(nn.MaxPool2d)
relu = node_def(nn.ReLU)
    
def map_types(mapping, net):
    """Swap node types throughout `net` via `mapping`; unmapped types are left alone.

    Every node tuple keeps its remaining fields (params, inputs, ...) unchanged.
    """
    def _swap_type(node):
        node_type, *remainder = node
        return (mapping.get(node_type, node_type),) + tuple(remainder)
    return map_nested(_swap_type, net)

#####################
## Compat
##################### 

def to_numpy(x):
    """Detach a torch tensor and return it as a CPU NumPy array; pass anything else through."""
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach().cpu().numpy()
  
def flip_lr(x):
    """Reverse the last axis (horizontal flip) of a torch tensor or a NumPy array.

    NumPy input gets a contiguous copy instead of a negative-stride view.
    """
    if not isinstance(x, torch.Tensor):
        return x[..., ::-1].copy()
    return torch.flip(x, [-1])
  
def trainable_params(model):
    """Map each parameter name of `model` to its parameter, keeping only those with requires_grad."""
    return {name: param for name, param in model.named_parameters() if param.requires_grad}

#####################
## Optimisers
##################### 

from functools import partial

def nesterov_update(w, dw, v, lr, weight_decay, momentum):
    """In-place SGD step with Nesterov momentum and L2 weight decay.

    All three tensors are modified in place:
        dw <- -lr * (dw + weight_decay * w)
        v  <- momentum * v + dw
        w  <- w + dw + momentum * v      (dw ends up holding this applied step)

    Args:
        w: weight tensor (updated).
        dw: its gradient (overwritten).
        v: momentum buffer (updated).
        lr: learning rate, a number or a tensor broadcastable to w.
        weight_decay: L2 coefficient (a number).
        momentum: momentum coefficient (a number).
    """
    # The keyword `alpha=` form replaces the deprecated positional
    # `add_(scalar, tensor)` overload. The arithmetic is identical.
    dw.add_(w, alpha=weight_decay).mul_(-lr)
    v.mul_(momentum).add_(dw)
    w.add_(dw.add_(v, alpha=momentum))

def norm(x):
    """Per-row L2 norm (over everything but dim 0, in float32), shaped (N, 1, 1, 1) for broadcasting."""
    flat = x.reshape(x.size(0), -1).float()
    return torch.norm(flat, dim=1)[:, None, None, None]

def LARS_update(w, dw, v, lr, weight_decay, momentum):
    """Nesterov step whose learning rate is scaled per row by ||w|| / (||dw|| + 1e-2)."""
    # The ratio is cast to w's dtype before multiplying by lr, matching the original precedence.
    trust_ratio = (norm(w) / (norm(dw) + 1e-2)).to(w.dtype)
    nesterov_update(w, dw, v, lr * trust_ratio, weight_decay, momentum)

def zeros_like(weights):
    """List of zero tensors shaped like each weight (initial optimiser state)."""
    return list(map(torch.zeros_like, weights))

def optimiser(weights, param_schedule, update, state_init):
    """Create an optimiser state dict at step 0.

    `weights` is materialised as a list and `state_init(weights)` supplies the
    per-weight state. The dict's keys match opt_step's parameters.
    """
    weight_list = list(weights)
    return dict(update=update, param_schedule=param_schedule, step_number=0,
                weights=weight_list, opt_state=state_init(weight_list))

def opt_step(update, param_schedule, step_number, weights, opt_state):
    """Advance one step: evaluate every schedule at the new step number, update each trainable weight.

    Returns a new optimiser dict with the incremented step number. Weights and
    state are updated in place by `update`.
    """
    step_number = step_number + 1
    current = {name: schedule(step_number) for name, schedule in param_schedule.items()}
    for weight, state in zip(weights, opt_state):
        if not weight.requires_grad:
            continue
        update(weight.data, weight.grad.data, state, **current)
    return dict(update=update, param_schedule=param_schedule, step_number=step_number,
                weights=weights, opt_state=opt_state)

# Optimiser factories: call as e.g. SGD(weights, param_schedule) to get an
# optimiser dict whose per-weight momentum buffers start at zero.
LARS = partial(optimiser, update=LARS_update, state_init=zeros_like)
SGD = partial(optimiser, update=nesterov_update, state_init=zeros_like)
  
class PiecewiseLinear(namedtuple('PiecewiseLinear', ('knots', 'vals'))):
    """Schedule interpolating linearly between (knot, val) points, constant beyond the ends."""
    def __call__(self, t):
        (value,) = np.interp([t], self.knots, self.vals)
        return value
     
class Const(namedtuple('Const', ['val'])):
    """Schedule that ignores its argument and always returns `val`."""
    def __call__(self, x):
        value = self.val
        return value

#####################
## DATA
##################### 

import torchvision
from functools import lru_cache as cache

@cache(None)
def cifar10(root='./data'):
    """Download (if needed) and load CIFAR-10 into torch tensors, memoised per `root`.

    Returns {'train': {...}, 'valid': {...}}; each split is a dict with 'data'
    and 'targets' tensors.
    """
    def _load(train):
        return torchvision.datasets.CIFAR10(root=root, train=train, download=True)
    loaded = [('train', _load(True)), ('valid', _load(False))]
    return {
        split: {'data': torch.tensor(ds.data), 'targets': torch.tensor(ds.targets)}
        for split, ds in loaded
    }
  
# Per-channel mean and standard deviation of the CIFAR-10 training images
# (channel order presumably R, G, B). The magnitudes indicate raw 0-255 pixel units.
cifar10_mean, cifar10_std = [
    (125.31, 122.95, 113.87), # equals np.mean(cifar10()['train']['data'], axis=(0,1,2)) 
    (62.99, 62.09, 66.70), # equals np.std(cifar10()['train']['data'], axis=(0,1,2))
]
# Class names, presumably in label-index order (label i -> cifar10_classes[i]).
cifar10_classes= 'airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck'.split(', ')

#####################
## data preprocessing
#####################
# CIFAR-10 channel statistics as float16 tensors on `device`; used as the default mean/std by normalise/unnormalise.
mean, std = [torch.tensor(x, device=device, dtype=torch.float16) for x in (cifar10_mean, cifar10_std)]

def normalise(data, mean=mean, std=std):
    """Standardise `data` per channel: (data - mean) / std."""
    return (data - mean) / std

def unnormalise(data, mean=mean, std=std):
    """Inverse of normalise: data * std + mean."""
    return data * std + mean

def pad(data, border):
    """Reflection-pad the last two dimensions of `data` by `border` on each side."""
    return nn.ReflectionPad2d(border)(data)

def transpose(x, source='NHWC', target='NCHW'):
    """Permute axes of `x` from the `source` layout letters to the `target` layout."""
    order = [source.index(axis) for axis in target]
    return x.permute(order)

def to(*args, **kwargs):
    """Return a function that calls `.to(*args, **kwargs)` on its argument."""
    def _convert(x):
        return x.to(*args, **kwargs)
    return _convert

def preprocess(dataset, transforms):
    """Apply `transforms` to dataset['data'] and return a shallow copy of `dataset`.

    Transforms run last-to-first, so the list reads like function composition.
    The input dict itself is not modified.
    """
    result = copy.copy(dataset)
    for fn in reversed(transforms):
        result['data'] = fn(result['data'])
    return result

#####################
## Data augmentation
#####################

def chunks(data, splits):
    """Lazily yield data[lo:hi] for each consecutive pair of boundaries in `splits`."""
    return (data[lo:hi] for lo, hi in zip(splits, splits[1:]))

def even_splits(N, num_chunks):
    """Boundaries (starting at 0, ending at N) splitting N items into num_chunks near-equal chunks, larger chunks first."""
    base, extra = divmod(N, num_chunks)
    sizes = [base + 1] * extra + [base] * (num_chunks - extra)
    return np.cumsum([0] + sizes)

def shuffled(xs, inplace=False):
    """Shuffle with NumPy's global RNG; a shallow copy is shuffled unless `inplace`. Returns the shuffled object."""
    target = xs if inplace else copy.copy(xs)
    np.random.shuffle(target)
    return target

def transformed(data, targets, transform, max_options=None, unshuffle=False):
    """Apply `transform` with one randomly chosen option per example.

    The data is randomly permuted. The transform's options are shuffled and
    truncated to `max_options`. The permuted data is then cut into
    len(options) near-equal chunks, and each chunk gets one option.

    Args:
        transform: object with .options(shape) -> list of kwargs dicts and
            .apply(x, **kwargs) (as Crop/FlipLR/Cutout provide).
        unshuffle: if True, restore the original example order and return
            `targets` unchanged; otherwise return the permuted data with
            targets permuted to match.
    Returns:
        (data, targets)
    """
    i = torch.randperm(len(data), device=device)
    data = data[i]
    options = shuffled(transform.options(data.shape), inplace=True)[:max_options]
    # One chunk per option; chunk sizes differ by at most one example.
    data = torch.cat([transform.apply(x, **choice) for choice, x in zip(options, chunks(data, even_splits(len(data), len(options))))])
    # argsort(i) is the inverse permutation of i.
    return (data[torch.argsort(i)], targets) if unshuffle else (data, targets[i])

class Batches():
    """Iterable of {'input', 'target'} batches over a dataset dict with 'data' and 'targets'.

    On each iteration every transform in `transforms` is applied to the whole
    dataset (see transformed). The data is then optionally shuffled and sliced
    at `self.splits`. A final partial batch is kept unless `drop_last`.
    """
    def __init__(self, batch_size, transforms=(), dataset=None, shuffle=True, drop_last=False, max_options=None):
        self.dataset = dataset
        self.transforms = transforms
        self.shuffle = shuffle
        self.max_options = max_options
        n = len(dataset['data'])
        boundaries = list(range(0, n + 1, batch_size))
        if boundaries[-1] != n and not drop_last:
            boundaries.append(n)  # keep the final partial batch
        self.splits = boundaries

    def __iter__(self):
        data = self.dataset['data']
        targets = self.dataset['targets']
        for t in self.transforms:
            data, targets = transformed(data, targets, t, max_options=self.max_options,
                                        unshuffle=not self.shuffle)
        if self.shuffle:
            perm = torch.randperm(len(data), device=device)
            data = data[perm]
            targets = targets[perm]
        paired = zip(chunks(data, self.splits), chunks(targets, self.splits))
        return ({'input': inputs.clone(), 'target': labels} for inputs, labels in paired)

    def __len__(self):
        return len(self.splits) - 1
    
#####################
## Augmentations
#####################

class Crop(namedtuple('Crop', ('h', 'w'))):
    """Crop the last two dims of x to (h, w) starting at row y0, column x0."""
    def apply(self, x, x0, y0):
        rows = slice(y0, y0 + self.h)
        cols = slice(x0, x0 + self.w)
        return x[..., rows, cols]

    def options(self, shape):
        """Every valid (x0, y0) offset, x0-major."""
        height, width = shape[-2], shape[-1]
        return [dict(x0=x0, y0=y0)
                for x0 in range(width - self.w + 1)
                for y0 in range(height - self.h + 1)]
    
class FlipLR(namedtuple('FlipLR', ())):
    """Horizontal flip applied when `choice` is truthy."""
    def apply(self, x, choice):
        if not choice:
            return x
        return flip_lr(x)

    def options(self, shape):
        """Flip or not, independent of shape."""
        return [{'choice': True}, {'choice': False}]

class Cutout(namedtuple('Cutout', ('h', 'w'))):
    """Zero an (h, w) patch of the last two dims of x, in place, at row y0 and column x0."""
    def apply(self, x, x0, y0):
        rows = slice(y0, y0 + self.h)
        cols = slice(x0, x0 + self.w)
        x[..., rows, cols] = 0.0
        return x

    def options(self, shape):
        """Every valid (x0, y0) offset, x0-major."""
        height, width = shape[-2], shape[-1]
        return [dict(x0=x0, y0=y0)
                for x0 in range(width - self.w + 1)
                for y0 in range(height - self.h + 1)]

#####################
## TRAINING
#####################

import time

class Timer():
    """Lap timer built on perf_counter that calls `synch` before every reading.

    Each call returns the time since the previous reading and, by default,
    adds it to `total_time`.
    """
    def __init__(self, synch=None):
        self.synch = synch if synch else (lambda: None)
        self.synch()
        self.times = [time.perf_counter()]
        self.total_time = 0.0

    def __call__(self, update_total=True):
        self.synch()
        now = time.perf_counter()
        elapsed = now - self.times[-1]
        self.times.append(now)
        if update_total:
            self.total_time += elapsed
        return elapsed

# Format spec per value type for table_formatter ('title' for header cells, 'default' for unlisted types); `w` is the column width.
default_table_formats = {float: '{:{w}.4f}', str: '{:>{w}s}', 'default': '{:{w}}', 'title': '{:>{w}s}'}

def table_formatter(val, is_title=False, col_width=12, formats=None):
    """Format one table cell as a fixed-width string.

    Args:
        val: cell value (or the header text when `is_title`).
        is_title: use the 'title' format instead of a per-type one.
        col_width: column width substituted for `w` in the format spec.
        formats: type -> format-string mapping with 'default' and 'title'
            entries; defaults to default_table_formats.
    Returns:
        The formatted string.
    """
    formats = formats or default_table_formats
    if is_title:
        fmt = formats['title']
    else:
        # `np.float` was only an alias of builtin `float` and was removed in
        # NumPy 1.24 (AttributeError). `np.floating` also covers float16/32 scalars.
        kind = float if isinstance(val, (float, np.floating)) else type(val)
        fmt = formats.get(kind, formats['default'])
    return fmt.format(val, w=col_width)

def every(n, col):
    """Predicate factory: true when data[col] is a multiple of n."""
    def _is_due(data):
        return data[col] % n == 0
    return _is_due

class Table():
    """Accumulate result rows (possibly nested dicts) and print them as a fixed-width table.

    Args:
        keys: flattened column names to print; defaults to the keys of the first row.
        report: predicate on the flattened row deciding whether it is printed.
        formatter: cell formatter, called as formatter(value) for cells and
            formatter(key, True) for headers.
    """
    def __init__(self, keys=None, report=(lambda data: True), formatter=table_formatter):
        self.keys, self.report, self.formatter = keys, report, formatter
        self.log = []
        
    def append(self, data):
        """Log `data` and print it if `report` accepts it; the very first row also prints the header."""
        self.log.append(data)
        # Flatten to space-joined key paths (assuming path_iter yields (key_path, value) pairs).
        data = {' '.join(p): v for p,v in path_iter(data)}
        self.keys = self.keys or data.keys()
        # `==`, not `is`: identity comparison with an int literal relies on
        # interpreter int caching and is a SyntaxWarning since Python 3.8.
        if len(self.log) == 1:
            print(*(self.formatter(k, True) for k in self.keys))
        if self.report(data):
            print(*(self.formatter(data[k]) for k in self.keys))
            
    def df(self):
        """All logged rows as a pandas DataFrame, nested key paths joined with '_'."""
        return pd.DataFrame([{'_'.join(p): v for p,v in path_iter(row)} for row in self.log])
            
def reduce(batches, state, steps):
    """Run every step on every batch, folding each step's returned updates into `state`.

    Args:
        batches: iterable of batches.
        state: dict shared by all steps (mutated and returned).
        steps: callables step(batch, state) -> dict of state updates or None.
    A final batch=None is sent after the real batches so that steps can tidy
    up (e.g. log_activations emits its collected logs then).
    """
    for batch in chain(batches, [None]):
        for step in steps:
            updates = step(batch, state)
            if not updates:
                continue
            for key, value in updates.items():
                state[key] = value
    return state
  
#define keys in the state dict as constants
MODEL = 'model'             # model being trained
VALID_MODEL = 'valid_model' # optional separate model for evaluation (e.g. the EMA copy updated by update_ema)
OUTPUT = 'output'           # dict of node outputs from the latest forward step
OPTS = 'optimisers'         # list of optimiser dicts consumed by opt_steps
ACT_LOG = 'activation_log'  # per-epoch logged activations written by log_activations
WEIGHT_LOG = 'weight_log'   # weight snapshots appended by log_weights

#step definitions
def forward(training_mode):
    """Build a step that runs the model (train or eval mode) and stores model.loss(outputs) under OUTPUT.

    In eval mode state[VALID_MODEL] is used when present, otherwise state[MODEL].
    """
    def step(batch, state):
        if not batch:
            return None
        if training_mode or VALID_MODEL not in state:
            model = state[MODEL]
        else:
            model = state[VALID_MODEL]
        if model.training != training_mode:  # only switch when needed: calling train() every batch is slow
            model.train(training_mode)
        return {OUTPUT: model.loss(model(batch))}
    return step

def forward_tta(tta_transforms):
    """Build an eval step with test-time augmentation.

    The logits are averaged over `tta_transforms` applied to copies of the
    input, then fed to model.loss together with the batch.
    """
    def step(batch, state):
        if not batch:
            return None
        model = state[VALID_MODEL] if VALID_MODEL in state else state[MODEL]
        if model.training:
            model.train(False)
        per_transform = [
            model({'input': transform(batch['input'].clone())})['logits'].detach()
            for transform in tta_transforms
        ]
        logits = torch.stack(per_transform, dim=0).mean(dim=0)
        return {OUTPUT: model.loss(dict(batch, logits=logits))}
    return step

def backward(dtype=torch.float16):
    """Build a step that zeroes grads and backpropagates the summed per-example loss (cast to `dtype`)."""
    def step(batch, state):
        state[MODEL].zero_grad()  # grads are cleared even on the final batch=None call
        if batch:
            total_loss = state[OUTPUT]['loss'].to(dtype).sum()
            total_loss.backward()
        return None
    return step

def opt_steps(batch, state):
    """Step every optimiser in state[OPTS] and store the updated optimiser dicts back."""
    if not batch:
        return None
    stepped = [opt_step(**opt) for opt in state[OPTS]]
    return {OPTS: stepped}

def log_activations(node_names=('loss', 'acc')):
    """Build a step that collects the named outputs of every batch.

    Each real batch appends (name, detached tensor) pairs from state[OUTPUT].
    On the final batch=None call, the tensors are concatenated per name
    (grouping presumably done by group_by_key), converted to float64 NumPy
    arrays and returned under ACT_LOG. The buffer is then cleared.
    """
    logs = []
    def step(batch, state):
        if batch:
            logs.extend((k, state[OUTPUT][k].detach()) for k in node_names)
        else:
            # np.float64 replaces the `np.float` alias (removed in NumPy 1.24).
            res = map_values((lambda xs: to_numpy(torch.cat(xs)).astype(np.float64)), group_by_key(logs))
            logs.clear()
            return {ACT_LOG: res}
    return step

def update_ema(momentum, update_freq=1):
    """Build a step that keeps state[VALID_MODEL] as an exponential moving average of state[MODEL].

    Every `update_freq` batches each floating-point state_dict entry is blended
    as ema <- rho * ema + (1 - rho) * current, with rho = momentum**update_freq,
    in place. Non-floating entries (e.g. BatchNorm's integer
    num_batches_tracked) are skipped: scaling an integer tensor in place by a
    float raises a RuntimeError.
    """
    from itertools import count  # local import so this block does not rely on a notebook-level import
    n = iter(count())
    rho = momentum**update_freq
    def step(batch, state):
        if not batch: return
        if (next(n) % update_freq) != 0: return
        for v, ema_v in zip(state[MODEL].state_dict().values(), state[VALID_MODEL].state_dict().values()):
            if not v.dtype.is_floating_point:
                continue
            ema_v *= rho
            ema_v += (1-rho)*v
    return step

# Step pipelines for `reduce`: training runs forward -> log -> backward ->
# optimiser step on each batch; validation only runs forward and logs loss/acc.
train_steps = (forward(training_mode=True), log_activations(('loss', 'acc')), backward(), opt_steps)
valid_steps = (forward(training_mode=False), log_activations(('loss', 'acc')))

# Epoch summary: mean of each logged activation array in state[ACT_LOG].
epoch_stats = lambda state: {k: np.mean(v) for k, v in state[ACT_LOG].items()}

def train_epoch(state, timer, train_batches, valid_batches, train_steps=train_steps, valid_steps=valid_steps, on_epoch_end=identity):
    """Run one training pass and one validation pass.

    Returns {'train': {...}, 'valid': {...}, 'total time': ...}, where each
    summary holds its 'time' plus the epoch_stats means. on_epoch_end is
    applied to the state after training. Validation time is excluded from
    timer.total_time (DAWNBench rules).
    """
    trained_state = on_epoch_end(reduce(train_batches, state, train_steps))
    train_summary = epoch_stats(trained_state)
    train_time = timer()
    valid_summary = epoch_stats(reduce(valid_batches, state, valid_steps))
    valid_time = timer(update_total=False) #DAWNBench rules
    return {
        'train': union({'time': train_time}, train_summary),
        'valid': union({'time': valid_time}, valid_summary),
        'total time': timer.total_time,
    }

def summary(logs, cols=['valid_acc']):
    """Count/mean/min/max/std of `cols` over the rows from the final epoch of `logs.df()` (one row per column)."""
    final_epoch = logs.df().query('epoch==epoch.max()')
    stats = final_epoch[cols].describe().transpose().astype({'count': int})
    return stats[['count', 'mean', 'min', 'max', 'std']]

#on_epoch_end
def log_weights(state, weights):
    """on_epoch_end hook: append a NumPy snapshot of `weights` (name -> tensor) to state[WEIGHT_LOG]."""
    history = state.setdefault(WEIGHT_LOG, [])
    history.append({name: to_numpy(w.data) for name, w in weights.items()})
    return state

def fine_tune_bn_stats(state, batches, model_key=VALID_MODEL):
    """on_epoch_end hook: run state[model_key] over `batches` in training mode, with no backward pass.

    This refreshes running statistics such as BatchNorm's. The model is left in
    training mode. Returns `state` unchanged.
    """
    scratch_state = {MODEL: state[model_key]}
    reduce(batches, scratch_state, [forward(True)])
    return state

#misc
def warmup_cudnn(model, batch):
    """Run one forward and backward pass on `batch`.

    This lets cuDNN benchmark its kernels before any timing starts. The model
    is left in training mode.
    """
    reduce([batch], {MODEL: model}, [forward(True), backward()])
    # `device` falls back to the CPU when no GPU exists, and
    # torch.cuda.synchronize() raises without a CUDA runtime, so only sync when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


#####################
## Plotting
#####################

import altair as alt
# Render Altair charts with the Colab renderer.
# NOTE(review): available renderer names vary across Altair versions; confirm 'colab' exists in the installed one.
alt.renderers.enable('colab')
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import SVG

def empty_plot(ax, **_ignored):
    """Blank panel for `layout`: hide the axes and return them. Extra kwargs are ignored."""
    ax.axis('off')
    return ax

def image_plot(ax, img, title):
    """Show a normalised channels-first (CHW) image tensor on `ax` with `title` and no axis decorations."""
    # Undo the normalisation and move channels last for imshow. The builtin
    # `int` replaces the `np.int` alias (removed in NumPy 1.24).
    pixels = to_numpy(unnormalise(transpose(img, 'CHW', 'HWC'))).astype(int)
    ax.imshow(pixels)
    ax.set_title(title)
    ax.axis('off')

def layout(figures, sharex=False, sharey=False, figure_title=None, col_width=4, row_height = 3.25, **kw):
    """Lay out panel-drawing callables on a grid of subplots.

    `figures` is a rectangular nested list of callables figure(ax, **kw). Each
    is called with its own axes. Returns (fig, results) with the callables'
    return values in row-major order.
    """
    nrows, ncols = np.array(figures).shape
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey,
                            figsize=(col_width * ncols, row_height * nrows))
    ax_grid = np.array(axs).reshape(nrows, ncols)
    results = []
    for ax_row, figure_row in zip(ax_grid, figures):
        for ax, draw in zip(ax_row, figure_row):
            results.append(draw(ax, **kw))
    fig.suptitle(figure_title)
    return fig, results

#####################
## Network
#####################

def conv_block(c_in, c_out):
    """Node spec: 3x3 conv (stride 1, padding 1, no bias) -> batch norm -> ReLU."""
    return {
        'conv': conv(in_channels=c_in, out_channels=c_out, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
        'norm': batch_norm(c_out),
        'act': relu(),
    }

def conv_pool_block(c_in, c_out):
    """conv_block followed by a 2x2 max-pool (pool appended as the last node)."""
    block = dict(conv_block(c_in, c_out))
    block['pool'] = pool(2)
    return block

def conv_pool_block_pre(c_in, c_out):
    """Same nodes as conv_pool_block but ordered conv -> pool -> norm -> act."""
    return reorder(conv_pool_block(c_in, c_out), ('conv', 'pool', 'norm', 'act'))

def residual(c, conv_block):
    """Residual unit: two conv_block(c, c) stages, with an identity skip summed into 'add'."""
    return {
        'in': (Identity, {}),
        'res1': conv_block(c, c),
        'res2': conv_block(c, c),
        'out': (Identity, {}),
        'add': (Add, {}, ['in', 'out']),
    }

def build_network(channels, extra_layers, res_layers, scale, conv_block=conv_block, 
                  prep_block=conv_block, conv_pool_block=conv_pool_block, types=None): 
    """Nested node description of the ResNet-style CIFAR-10 classifier.

    Stages: prep -> layer1..layer3 (conv + pool) -> 4x4 max-pool ->
    classifier (flatten, bias-free linear to 10 classes, multiply by `scale`)
    -> logits. `res_layers` get a residual unit and `extra_layers` an extra
    conv block. `types` optionally remaps node types via map_types.
    """
    net = {'prep': prep_block(3, channels['prep'])}
    previous = 'prep'
    for stage in ('layer1', 'layer2', 'layer3'):
        net[stage] = conv_pool_block(channels[previous], channels[stage])
        previous = stage
    net['pool'] = pool(4)
    net['classifier'] = {
        'flatten': (Flatten, {}),
        'conv': linear(channels['layer3'], 10, bias=False),
        'scale': (Mul, {'weight': scale}),
    }
    net['logits'] = (Identity, {})
    for layer in res_layers:
        net[layer]['residual'] = residual(channels[layer], conv_block)
    for layer in extra_layers:
        net[layer]['extra'] = conv_block(channels[layer], channels[layer])
    if types:
        net = map_types(types, net)
    return net

# Output channels per stage of the default network.
channels={'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
# Default network: residual units on layer1 and layer3, no extra conv blocks, logits scaled by 1/8.
network = partial(build_network, channels=channels, extra_layers=(), res_layers=('layer1', 'layer3'), scale=1/8)   

# Loss network: per-example cross-entropy of 'logits' vs 'target' (reduction='none') plus per-example correctness.
x_ent_loss = Network({
  'loss':  (nn.CrossEntropyLoss, {'reduction': 'none'}, ['logits', 'target']),
  'acc': (Correct, {}, ['logits', 'target'])
})

def label_smoothing_loss(alpha):
    """Loss network: (1 - alpha) * cross-entropy + alpha * mean over classes of -log_probs, plus accuracy."""
    return Network({
        'logprobs': (LogSoftmax, {'dim': 1}, ['logits']),
        'KL': (KLLoss, {}, ['logprobs']),
        'xent': (CrossEntropyLoss, {}, ['logprobs', 'target']),
        'loss': (AddWeighted, {'wx': 1 - alpha, 'wy': alpha}, ['xent', 'KL']),
        'acc': (Correct, {}, ['logits', 'target']),
    })

#####################
## Misc
#####################

def lr_schedule(knots, vals, batch_size):
    """Piecewise-linear per-step schedule.

    Knots are scaled by len(train_batches(batch_size)), presumably batches per
    epoch, so they are given in epochs. Values are divided by batch_size.
    """
    steps_per_epoch = len(train_batches(batch_size))
    return PiecewiseLinear(np.array(knots) * steps_per_epoch, np.array(vals) / batch_size)

In [0]:
#####################
## Config
#####################

# Repeat count; repeated runs can be aggregated (e.g. `summary` reports count/mean/min/max/std).
N_RUNS = 1 # number of times to run each experiment

In [0]:
# Fetch the cifar10-fast repository and run its `dawn` module (presumably the DAWNBench training script).
!git clone -q https://github.com/davidcpage/cifar10-fast.git
!cd cifar10-fast && python -m dawn --data_dir=~/data

In [0]:
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

def plot_images(dataset, n_images, samples_per_image):
    """Plot a grid of augmented samples.

    The dataset is repeated `samples_per_image` times and batched by
    `n_images`. Each batch is stacked vertically into one 32-pixel-wide column,
    and the columns sit side by side. Assumes 32x32x3 images and (TODO confirm)
    a dataset of exactly `n_images` elements; otherwise batches straddle
    passes or come out short.
    """
    tile = 32
    grid = np.zeros((tile * n_images, tile * samples_per_image, 3))
    batches = dataset.repeat(samples_per_image).batch(n_images)
    for col, images in enumerate(batches):
        grid[:, col * tile:(col + 1) * tile] = np.vstack(images.numpy())
    plt.figure(figsize=(10, 10))
    plt.imshow(grid)
    plt.show()

def flip(x: tf.Tensor) -> tf.Tensor:
    """Flip augmentation: random left-right flip, then random up-down flip.

    Args:
        x: Image to flip

    Returns:
        Augmented image
    """
    horizontally_flipped = tf.image.random_flip_left_right(x)
    return tf.image.random_flip_up_down(horizontally_flipped)

def color(x: tf.Tensor) -> tf.Tensor:
    """Color augmentation: random hue, saturation, brightness and contrast, in that order.

    Args:
        x: Image

    Returns:
        Augmented image
    """
    pipeline = (
        (tf.image.random_hue, (0.08,)),
        (tf.image.random_saturation, (0.6, 1.6)),
        (tf.image.random_brightness, (0.05,)),
        (tf.image.random_contrast, (0.7, 1.3)),
    )
    for op, args in pipeline:
        x = op(x, *args)
    return x

def rotate(x: tf.Tensor) -> tf.Tensor:
    """Rotation augmentation: rotate by a random multiple of 90 degrees (0-3 quarter turns).

    Args:
        x: Image

    Returns:
        Augmented image
    """
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(x, quarter_turns)

def zoom(x: tf.Tensor) -> tf.Tensor:
    """Zoom augmentation: with probability 0.5, take a random centred crop and resize it to 32x32.

    Args:
        x: Image

    Returns:
        Augmented image
    """
    # 20 centred boxes keeping 80%..99% of each side (i.e. 20% down to 1% cropped away).
    scales = list(np.arange(0.8, 1.0, 0.01))
    half_sides = 0.5 * np.array(scales)
    lower = 0.5 - half_sides
    upper = 0.5 + half_sides
    boxes = np.stack([lower, lower, upper, upper], axis=1)

    def random_crop(img):
        # Resize every box of `img` to 32x32, then keep one chosen at random.
        crops = tf.image.crop_and_resize([img], boxes=boxes, box_indices=np.zeros(len(scales)), crop_size=(32, 32))
        pick = tf.random.uniform(shape=[], minval=0, maxval=len(scales), dtype=tf.int32)
        return crops[pick]

    choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)

    # Only apply cropping 50% of the time
    return tf.cond(choice < 0.5, lambda: x, lambda: random_crop(x))

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# First 8 training images, scaled to [0, 1] float32 for the tf.image ops.
data = (x_train[0:8] / 255).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data)

# Add augmentations
augmentations = [flip, color, zoom, rotate]
# Probability that each augmentation is applied to a given image.
AUGMENT_PROB = 0.25

def randomly_apply(f, prob=AUGMENT_PROB):
    """Return a map function that applies augmentation `f` to an image with probability `prob`.

    Binding `f` through this factory avoids Python's late-binding closure
    pitfall. A lambda written directly in the loop refers to the loop variable
    `f`, which is only correct as long as Dataset.map traces it before the
    next iteration.
    """
    return lambda x: tf.cond(tf.random.uniform([], 0, 1) > 1 - prob, lambda: f(x), lambda: x)

for f in augmentations:
    dataset = dataset.map(randomly_apply(f), num_parallel_calls=4)
# Clip back to [0, 1] after the augmentations (e.g. brightness shifts).
dataset = dataset.map(lambda x: tf.clip_by_value(x, 0, 1))

plot_images(dataset, n_images=8, samples_per_image=10)
