The following additional libraries are needed to run this
notebook. Note that running on Colab is experimental; please report a GitHub
issue if you have any problems.

In [1]:
%%capture

import sys
sys.path.append('..')
import mock_d2l_jax as d2l

# Network in Network (NiN)
:label:`sec_nin`

LeNet, AlexNet, and VGG all share a common design pattern:
extract features exploiting *spatial* structure
via a sequence of convolutions and pooling layers
and post-process the representations via fully connected layers.
The improvements upon LeNet by AlexNet and VGG mainly lie
in how these later networks widen and deepen these two modules.

This design poses two major challenges. 
First, the fully connected layers at the end
of the architecture consume tremendous numbers of parameters. For instance, even a simple
model such as VGG-11 requires a monstrous $25088 \times 4096$ matrix, occupying almost
400 MB of RAM in single precision (FP32). This is a significant impediment to speedy computation, in particular on
mobile and embedded devices. Second, it is equally impossible to add fully connected layers
earlier in the network to increase the degree of nonlinearity: doing so would destroy the
spatial structure and require potentially even more memory.
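
The memory figure for the first challenge can be checked with a back-of-the-envelope calculation. The sketch below assumes 32-bit floating point weights (4 bytes per parameter) and ignores the bias; the variable names are illustrative only.

```python
# Size of the first fully connected layer in VGG-11 (weights only, bias ignored).
in_features, out_features = 25088, 4096
bytes_per_param = 4  # assumption: 32-bit floating point weights

num_params = in_features * out_features
size_mib = num_params * bytes_per_param / 2**20

print(f'{num_params:,} parameters, {size_mib:.0f} MiB')
# 102,760,448 parameters, 392 MiB
```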

The *network in network* (*NiN*) blocks of :cite:`Lin.Chen.Yan.2013` offer an alternative,
capable of solving both problems in one simple strategy.
They were proposed based on a very simple insight: (i) use $1 \times 1$ convolutions to add
local nonlinearities across the channel activations and (ii) use global average pooling to integrate
across all locations in the last representation layer. Note that global average pooling would not
be effective, were it not for the added nonlinearities. Let's dive into this in detail.


## (**NiN Blocks**)

Recall :numref:`subsec_1x1`. In it we discussed that the inputs and outputs of convolutional layers
consist of four-dimensional tensors with axes
corresponding to the example, channel, height, and width.
Also recall that the inputs and outputs of fully connected layers
are typically two-dimensional tensors corresponding to the example and feature.
The idea behind NiN is to apply a fully connected layer
at each pixel location (for each height and width).
The resulting $1 \times 1$ convolution can be thought of as
a fully connected layer acting independently on each pixel location.
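
To illustrate this equivalence, the following minimal sketch applies the same weight matrix once as a per-pixel fully connected layer and once as a $1 \times 1$ convolution. The tensor shapes and the choice of 7 output channels are arbitrary assumptions, and `jax.lax.conv_general_dilated` is used here instead of `nn.Conv` only so that the weights stay explicit.

```python
import jax
from jax import numpy as jnp, random

key_x, key_w = random.split(random.PRNGKey(0))
x = random.normal(key_x, (2, 4, 5, 3))  # (example, height, width, channel)
w = random.normal(key_w, (3, 7))        # dense weights: 3 input -> 7 output channels

# Fully connected layer applied independently at every pixel location.
dense_per_pixel = jnp.einsum('nhwc,cd->nhwd', x, w)

# The same weights viewed as a 1x1 convolution kernel of shape (1, 1, 3, 7).
conv_1x1 = jax.lax.conv_general_dilated(
    x, w.reshape(1, 1, 3, 7), window_strides=(1, 1), padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

# The two results should agree up to floating point error.
max_abs_diff = float(jnp.max(jnp.abs(dense_per_pixel - conv_1x1)))
print(conv_1x1.shape, max_abs_diff)
```

Both paths mix information across channels only, so the spatial layout of the input is left untouched.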

:numref:`fig_nin` illustrates the main structural
differences between VGG and NiN, and their blocks.
Note both the difference in the NiN blocks (the initial convolution is followed by $1 \times 1$ convolutions, whereas VGG retains $3 \times 3$ convolutions) and in the end where we no longer require a giant fully connected layer.

![Comparing architectures of VGG and NiN, and their blocks.](http://d2l.ai/_images/nin.svg)
:width:`600px`
:label:`fig_nin`


In [19]:
import jax
from jax import numpy as jnp, random, grad, vmap, jit
from flax import linen as nn
import optax
# from d2l import jax as d2l


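def nin_block(out_channels, kernel_size, strides, padding):
    # One spatial convolution followed by two 1x1 convolutions, which act as
    # per-pixel fully connected layers; each convolution is followed by a ReLU
    return nn.Sequential([
        nn.Conv(out_channels, kernel_size, strides, padding), nn.relu,
        nn.Conv(out_channels, kernel_size=(1, 1)), nn.relu,
        nn.Conv(out_channels, kernel_size=(1, 1)), nn.relu])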

## [**NiN Model**]

NiN uses the same initial convolution sizes as AlexNet (it was proposed shortly thereafter).
The kernel sizes are $11\times 11$, $5\times 5$, and $3\times 3$, respectively,
and the numbers of output channels match those of AlexNet. Each NiN block is followed by a max-pooling layer
with a stride of 2 and a window shape of $3\times 3$.
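
The effect of these layers on the spatial resolution can be traced with simple arithmetic. The sketch below is a hypothetical helper, not part of the notebook: the convolution strides and padding values are taken from the model definition later in this section, the pooling layers are assumed to use no padding, and the $224 \times 224$ input size is an assumption matching the resized Fashion-MNIST images used for training. The $1 \times 1$ convolutions are omitted because they do not change the height or width.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length along one spatial axis for a convolution or pooling window."""
    return (n + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding): values follow the model code in this section
layers = [
    ('conv 11x11', 11, 4, 0),
    ('max-pool 3x3', 3, 2, 0),
    ('conv 5x5', 5, 1, 2),
    ('max-pool 3x3', 3, 2, 0),
    ('conv 3x3', 3, 1, 1),
    ('max-pool 3x3', 3, 2, 0),
]

size = 224  # assumed input height and width
sizes = []
for name, kernel, stride, padding in layers:
    size = out_size(size, kernel, stride, padding)
    sizes.append(size)
    print(f'{name:>12}: {size} x {size}')
# The final line printed is "max-pool 3x3: 5 x 5"
```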

The second significant difference between NiN and both AlexNet and VGG
is that NiN avoids fully connected layers altogether.
Instead, NiN uses a NiN block with a number of output channels equal to the number of label classes, followed by a *global* average pooling layer,
yielding a vector of logits.
This design significantly reduces the number of required model parameters, albeit at the expense of a potential increase in training time.
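
A rough count shows how large the saving is. The sketch below compares the final NiN block for 10 classes (a $3\times 3$ convolution from 384 to 10 channels plus two $1\times 1$ convolutions, as in the model code in this section) with the first fully connected layer of VGG-11. The helper `conv_params` is a hypothetical name introduced only for this illustration; global average pooling itself has no parameters.

```python
def conv_params(kernel_h, kernel_w, in_channels, out_channels):
    """Number of weights plus biases in a single convolutional layer."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels

num_classes = 10
# Final NiN block: one 3x3 convolution (384 -> 10 channels) and two 1x1 convolutions.
nin_head = (conv_params(3, 3, 384, num_classes)
            + 2 * conv_params(1, 1, num_classes, num_classes))

# First fully connected layer of VGG-11 (weights plus biases), for comparison.
vgg_fc = 25088 * 4096 + 4096

print(nin_head, vgg_fc)
# 34790 102764544
```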


In [20]:
input_shape = (2, 4, 5, 3)
x = random.normal(random.PRNGKey(0), input_shape)
y = nn.avg_pool(x, (4, 5))
print(y.shape)

(2, 1, 1, 3)


In [26]:
class NiN(d2l.Classifier):
    lr: float = 0.1
    num_classes = 10
    
    def setup(self):
        max_pool = lambda x: nn.max_pool(x, (3, 3), strides=(2, 2))

        self.net = nn.Sequential([
            nin_block(96, kernel_size=(11, 11), strides=(4, 4), padding=(0, 0)),
            max_pool,
            nin_block(256, kernel_size=(5, 5), strides=(1, 1), padding=(2, 2)),
            max_pool,
            nin_block(384, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1)),
            max_pool,
            nn.Dropout(0.5, deterministic=False),
            nin_block(self.num_classes, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1)),
            # Global average pooling: the window spans the full height and width
            lambda x: nn.avg_pool(x, x.shape[1:3]),
            d2l.flatten])
        # self.net.apply(d2l.init_cnn)

We create a data example, initialize the model with it, and inspect [**the shape of the parameters in each block**].


In [63]:
rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}

model = NiN()
X = random.normal(random.PRNGKey(0), (1, 224, 224, 1))
params = model.init(rngs, X)
jax.tree_util.tree_map(lambda x: x.shape, params)

FrozenDict({
    params: {
        net: {
            layers_0: {
                layers_0: {
                    bias: (96,),
                    kernel: (11, 11, 1, 96),
                },
                layers_2: {
                    bias: (96,),
                    kernel: (1, 1, 96, 96),
                },
                layers_4: {
                    bias: (96,),
                    kernel: (1, 1, 96, 96),
                },
            },
            layers_2: {
                layers_0: {
                    bias: (256,),
                    kernel: (5, 5, 96, 256),
                },
                layers_2: {
                    bias: (256,),
                    kernel: (1, 1, 256, 256),
                },
                layers_4: {
                    bias: (256,),
                    kernel: (1, 1, 256, 256),
                },
            },
            layers_4: {
                layers_0: {
                    bias: (384,),
                    kernel: (3, 3, 256, 3

## [**Training**]

As before, we use Fashion-MNIST to train the model.
NiN's training is similar to that for AlexNet and VGG.


In [None]:
model = NiN(lr=0.05)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
# model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data, rngs=rngs)

## Summary

NiN has dramatically fewer parameters than AlexNet and VGG. This is because it needs no giant fully connected layers and uses fewer convolutions with wide kernels. Instead, it uses local $1 \times 1$ convolutions and global average pooling. These design choices influenced many subsequent CNN designs.

## Exercises

1. Why are there two $1\times 1$ convolutional layers per NiN block? What happens if you add a third one? What happens if you reduce their number to one?
1. What happens if you replace the global average pooling by a fully connected layer (speed, accuracy, number of parameters)?
1. Calculate the resource usage for NiN.
    1. What is the number of parameters?
    1. What is the amount of computation?
    1. What is the amount of memory needed during training?
    1. What is the amount of memory needed during prediction?
1. What are possible problems with reducing the $384 \times 5 \times 5$ representation to a $10 \times 5 \times 5$ representation in one step?
1. Use the structural design decisions in VGG that led to VGG-11, VGG-16, and VGG-19 to design a family of NiN-like networks.


[Discussions](https://discuss.d2l.ai/t/80)
