In [1]:
%%capture
!pip install --upgrade trax

# Introduction

Prior to the introduction of [Wide Residual Networks](https://arxiv.org/pdf/1605.07146.pdf) (WRNs) by Sergey Zagoruyko and Nikos Komodakis, each fraction of a percent of accuracy gained by deep residual networks came at the cost of nearly **doubling** the number of layers. This led to the problem of diminishing feature reuse and made the models slow to train. WRNs showed that a shallower but wider residual network performs better, and they improved on the then state-of-the-art results on CIFAR, SVHN and COCO.
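To make "wider" concrete, the short sketch below counts the weights of a single 3x3 convolution as its channel count is multiplied by a widening factor `k`. The base width of 16 channels and the helper name `conv3x3_params` are illustrative assumptions, not values taken from this notebook; the point is only that the weight count grows with `k**2` while the depth stays the same.

```python
def conv3x3_params(c_in, c_out):
    """Weights in a 3x3 convolution (bias ignored): 3 * 3 * c_in * c_out."""
    return 3 * 3 * c_in * c_out

base_width = 16  # assumed base channel count, for illustration only

for k in (1, 2, 4, 8):
    width = base_width * k
    params = conv3x3_params(width, width)
    print(f"k={k}: {width} channels, {params:,} weights per 3x3 conv")
```

Doubling `k` quadruples the weights of each convolution, so a wide network can match the capacity of a much deeper thin one with far fewer sequential layers.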

In this notebook we run through a simple demonstration of training a WideResnet on the `cifar10` dataset using the [Trax](https://github.com/google/trax) framework. Trax is an end-to-end library for deep learning that focuses on **clear code and speed**. It is actively used and maintained in the *Google Brain team*.

# Issues with Traditional Residual Networks

## Diminishing Feature Reuse

A residual block with an identity mapping, which is what allows us to train very deep networks, is also a weakness. As the gradient flows through the network, nothing forces it to pass through the residual block weights, so a block can avoid learning anything during training. As a result, either only a few blocks learn useful representations, or many blocks share very little information and each contributes little to the final goal. One attempt to address this problem randomly disables residual blocks during training. This can be viewed as a special case of dropout in which each residual block has an identity scalar weight on which dropout is applied.
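The idea can be sketched in a few lines of NumPy. This is a toy illustration, not the Trax implementation: the residual branch is a single linear map with a ReLU, and the names `residual_branch`, `residual_block` and `keep_prob` are hypothetical. The block computes `y = x + gate * F(x)`, where the scalar `gate` is the identity weight on which dropout is applied.

```python
import numpy as np

rng = np.random.default_rng(0)

def residual_branch(x, w):
    """Toy residual function F(x): a linear map followed by ReLU."""
    return np.maximum(x @ w, 0.0)

def residual_block(x, w, keep_prob=1.0, training=False):
    """Compute y = x + gate * F(x).

    During training the gate is a Bernoulli(keep_prob) sample, so the whole
    branch is sometimes dropped and only the identity path remains.
    At inference the gate is replaced by its expectation, keep_prob.
    """
    if training:
        gate = float(rng.random() < keep_prob)
    else:
        gate = keep_prob
    return x + gate * residual_branch(x, w)

x = rng.normal(size=(2, 4))
w = rng.normal(size=(4, 4))

y_dropped = residual_block(x, w, keep_prob=0.0, training=True)  # branch always dropped
y_full = residual_block(x, w, keep_prob=1.0, training=True)     # branch always kept

print(np.array_equal(y_dropped, x))  # the identity path alone reproduces the input
```

When the branch is dropped the block is exactly the identity, which is why the network still trains even if a block's weights contribute nothing. That same property is what lets blocks get away with learning very little.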

# Importing Libraries

In [2]:
import trax
from trax import layers as tl
from trax.supervised import training
from trax.models.resnet import WideResnet

trax.fastmath.set_backend('tensorflow-numpy')

# Downloading Dataset

In [3]:
%%capture
train_stream = trax.data.TFDS('cifar10', keys=('image', 'label'), train=True)()
eval_stream = trax.data.TFDS('cifar10', keys=('image', 'label'), train=False)()

# Batch Generator

In [4]:
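# Training pipeline: shuffle examples, group them into batches of 64,
# and append per-example loss weights.
train_data_pipeline = trax.data.Serial(
    trax.data.Shuffle(),
    trax.data.Batch(64),
    trax.data.AddLossWeights(),
)

train_batches_stream = train_data_pipeline(train_stream)

# Evaluation pipeline: same batching and weights, but no shuffling.
eval_data_pipeline = trax.data.Serial(
    trax.data.Batch(64),
    trax.data.AddLossWeights(),
)

eval_batches_stream = eval_data_pipeline(eval_stream)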
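To see what the three pipeline stages do to the stream, here is a rough plain-Python equivalent built from generators. It is a sketch of the behaviour, not Trax's code: the function names, the buffer size and the fake all-zero images are assumptions made for illustration.

```python
import random
import numpy as np

def shuffle(stream, queue_size=1024, seed=0):
    """Buffer-based shuffle: fill a buffer, then emit random elements from it."""
    rng = random.Random(seed)
    buffer = []
    for example in stream:
        buffer.append(example)
        if len(buffer) >= queue_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    while buffer:  # drain whatever is left once the input is exhausted
        yield buffer.pop(rng.randrange(len(buffer)))

def batch(stream, batch_size):
    """Stack consecutive (image, label) pairs into arrays with a leading batch axis."""
    images, labels = [], []
    for image, label in stream:
        images.append(image)
        labels.append(label)
        if len(images) == batch_size:
            yield np.stack(images), np.stack(labels)
            images, labels = [], []

def add_loss_weights(stream):
    """Append a weight of 1.0 per example, giving (inputs, targets, weights)."""
    for images, labels in stream:
        yield images, labels, np.ones_like(labels, dtype=np.float32)

# Fake CIFAR-shaped examples: 32x32 RGB images with integer labels.
fake_examples = (
    (np.zeros((32, 32, 3), dtype=np.uint8), np.int64(i % 10)) for i in range(200)
)
images, labels, weights = next(add_loss_weights(batch(shuffle(fake_examples), 64)))
print(images.shape, labels.shape, weights.shape)  # (64, 32, 32, 3) (64,) (64,)
```

Each item leaving the pipeline is an `(inputs, targets, weights)` triple, which is the form the training and evaluation tasks consume.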

# Training and Evaluation Tasks

In [8]:
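# Training task: minimize cross-entropy with Adam (learning rate 0.01),
# checkpointing every 1000 steps.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=1000,
)

# Evaluation task: report loss and accuracy averaged over 20 batches.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20,
)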
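The two metrics reported during evaluation can be written out directly. The NumPy sketch below shows the standard definitions of multiclass cross-entropy and accuracy on a tiny hand-made batch; it illustrates what the numbers mean and is not a copy of the Trax layers, and the helper names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def cross_entropy(logits, targets):
    """Mean negative log-probability assigned to the correct class."""
    log_probs = log_softmax(logits)
    return -log_probs[np.arange(len(targets)), targets].mean()

def accuracy(logits, targets):
    """Fraction of examples whose highest-scoring class matches the target."""
    return (logits.argmax(axis=-1) == targets).mean()

# Three examples, three classes; the third example is misclassified.
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.0, 1.5, 0.0]])
targets = np.array([0, 2, 0])

print("accuracy:", accuracy(logits, targets))
print("cross-entropy:", cross_entropy(logits, targets))

# Reference point: uniform guessing over 10 classes gives a loss of ln(10).
print("chance-level loss for 10 classes:", np.log(10))
```

A useful reference when reading the logs is the chance-level loss: a model that assigns equal probability to all 10 CIFAR-10 classes scores ln(10), about 2.30, and an accuracy of 0.10.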

# Model and Training Loop

In [9]:
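# WideResnet with its default settings, followed by LogSoftmax so the
# model outputs log-probabilities.
model = tl.Serial(
    WideResnet(),
    tl.LogSoftmax(),
)

# The loop ties the model to the tasks; checkpoints are written to output_dir.
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir='./cnn_model',
)

# Train for 5000 steps.
training_loop.run(5000)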


Step   1000: Ran 200 train steps in 141.04 secs
Step   1000: train CrossEntropyLoss |  1.49697816
Step   1000: eval  CrossEntropyLoss |  1.31979529
Step   1000: eval          Accuracy |  0.52265625

Step   2000: Ran 1000 train steps in 572.22 secs
Step   2000: train CrossEntropyLoss |  1.16957819
Step   2000: eval  CrossEntropyLoss |  1.08194559
Step   2000: eval          Accuracy |  0.60468750

Step   3000: Ran 1000 train steps in 565.36 secs
Step   3000: train CrossEntropyLoss |  1.00960529
Step   3000: eval  CrossEntropyLoss |  0.94915995
Step   3000: eval          Accuracy |  0.64921875

Step   4000: Ran 1000 train steps in 569.97 secs
Step   4000: train CrossEntropyLoss |  0.90725946
Step   4000: eval  CrossEntropyLoss |  0.89641787
Step   4000: eval          Accuracy |  0.68984375

Step   5000: Ran 1000 train steps in 570.14 secs
Step   5000: train CrossEntropyLoss |  0.83230388
Step   5000: eval  CrossEntropyLoss |  0.87516099
Step   5000: eval          Accuracy |  0.69531250
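A little arithmetic helps put these numbers in context. The sketch below uses only settings from the cells above plus the size of the standard CIFAR-10 training split (50,000 images). It assumes every batch is full and that the reported accuracy is an unweighted mean over the evaluated examples.

```python
batch_size = 64              # from trax.data.Batch(64)
n_steps = 5000               # from training_loop.run(5000)
n_eval_batches = 20          # from the EvalTask
cifar10_train_size = 50_000  # standard CIFAR-10 training split

examples_seen = n_steps * batch_size
epochs = examples_seen / cifar10_train_size
eval_examples = n_eval_batches * batch_size

final_accuracy = 0.69531250  # eval Accuracy reported at step 5000
correct = round(final_accuracy * eval_examples)

print(f"training examples processed: {examples_seen} (~{epochs:.1f} epochs)")
print(f"examples per evaluation: {eval_examples}")
print(f"correct at step 5000: {correct} of {eval_examples}")
```

Under these assumptions the run processes 320,000 training examples, about 6.4 passes over the training set, and each evaluation scores 1,280 images rather than the full 10,000-image test split. The final accuracy of 0.6953 therefore corresponds to 890 correct predictions out of 1,280.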
