# **Demo:** Net2WiderNet on MNIST with LeNet5

The following demo shows how to apply Net2WiderNet to LeNet5 in order to increase the number of output filters of a convolutional layer. The input image shape is that of MNIST, but the network and the Net2WiderNet algorithm can be applied to any other image size.
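For reference, the widening rule of Net2WiderNet, as described in the Net2Net paper (Chen, Goodfellow and Shlens, 2016), can be written as follows. Layer $i$ has weights $W^{(i)}$ and biases $b^{(i)}$, is followed by layer $i+1$ with weights $W^{(i+1)}$, and is widened from $n$ to $q > n$ units. For a convolutional layer, the "units" are the output filters of layer $i$ and the input channels of layer $i+1$. The `net2net` package used in this demo may differ in implementation details.

$$
g(j) = \begin{cases} j, & j \le n \\ \text{random sample from } \{1, \dots, n\}, & j > n \end{cases}
$$

$$
U^{(i)}_{k,j} = W^{(i)}_{k,g(j)}, \qquad b'^{(i)}_{j} = b^{(i)}_{g(j)}, \qquad U^{(i+1)}_{j,h} = \frac{W^{(i+1)}_{g(j),h}}{\left|\{x : g(x) = g(j)\}\right|}
$$

Since every new unit $j$ computes the same activation $a_{g(j)}$ as the original unit it copies, the input of unit $h$ in layer $i+1$ is unchanged:

$$
\sum_{j=1}^{q} U^{(i+1)}_{j,h}\, a_{g(j)} = \sum_{c=1}^{n} \sum_{j : g(j) = c} \frac{W^{(i+1)}_{c,h}}{\left|\{x : g(x) = c\}\right|}\, a_c = \sum_{c=1}^{n} W^{(i+1)}_{c,h}\, a_c .
$$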

In [None]:
# Import libraries
import torch
import torchinfo

# Import custom modules and packages
from lenet import LeNet
import params.lenet_mnist
import net2net.net2net_wider

### 1. Create a LeNet5 model

We start by creating the standard LeNet5 model.

In [None]:
# Create a LeNet model
model = LeNet(nb_classes=params.lenet_mnist.NB_CLASSES)

# Create a random input
x = torch.randn(1,
                params.lenet_mnist.NB_CHANNELS,
                *params.lenet_mnist.IMAGE_SHAPE)

# Compute the output of the teacher network
y_teacher = model(x)
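The `lenet` module is not shown in this demo. The layer names used in the widening configuration of step 2 (`"layer1.0"`, `"layer1.1"`, `"layer2.0"`) look like PyTorch's dotted submodule names, as returned by `named_modules()`, for a model built from `nn.Sequential` blocks. The class below is a hypothetical LeNet5 with that layout (conv, batch norm, ReLU and max-pooling per block, 28×28 single-channel input assumed), meant only to show how such names resolve; the actual `LeNet` class may be organized differently.

```python
import torch
import torch.nn as nn


class HypotheticalLeNet5(nn.Module):
    """Hypothetical LeNet5 layout; the real `lenet.LeNet` may differ."""

    def __init__(self, nb_classes: int = 10, nb_channels: int = 1):
        super().__init__()
        # Submodules are named "layer1.0" (conv), "layer1.1" (batch norm), ...
        self.layer1 = nn.Sequential(
            nn.Conv2d(nb_channels, 6, kernel_size=5, padding=2),  # 28x28 -> 28x28
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                # -> 14x14
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),                      # -> 10x10
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                # -> 5x5
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, nb_classes),
        )

    def forward(self, x):
        return self.classifier(self.layer2(self.layer1(x)))


model_sketch = HypotheticalLeNet5()
# Map dotted names to modules, as a Net2Net implementation could do
modules = dict(model_sketch.named_modules())
print(type(modules["layer1.0"]).__name__)  # Conv2d
print(type(modules["layer1.1"]).__name__)  # BatchNorm2d
print(type(modules["layer2.0"]).__name__)  # Conv2d
print(model_sketch(torch.randn(1, 1, 28, 28)).shape)  # torch.Size([1, 10])
```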

### 2. Create a wider version of LeNet5

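We then apply the Net2WiderNet algorithm to the standard LeNet5 model to increase the number of output filters of its first convolutional layer. The weights and biases of the student model are initialized from those of the teacher model so that, at initialization, the student produces the same output as the teacher for the same input. In Net2WiderNet, the new filters are copies of randomly chosen existing filters, and the weights of the next layer that read from a replicated filter are divided by the number of copies of that filter, so the next layer receives the same sum as before.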
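A minimal sketch of this widening step for two consecutive `nn.Conv2d` layers, assuming no batch norm in between, default dilation and groups, and biases present. `net2wider_conv` is a hypothetical helper, not the API of the `net2net` package. It copies filters along a random mapping and divides the matching input channels of the next convolution by the replication count; element-wise operations in between (here a ReLU) keep the copies identical, so the output is preserved.

```python
import torch
import torch.nn as nn


def net2wider_conv(conv, next_conv, new_width):
    """Hypothetical Net2WiderNet step for conv -> (element-wise ops) -> next_conv."""
    old_width = conv.out_channels
    assert new_width > old_width
    # Mapping g: keep the original filters, then pick random filters to replicate
    extra = torch.randint(0, old_width, (new_width - old_width,))
    mapping = torch.cat([torch.arange(old_width), extra])
    # Number of copies of each original filter in the wider layer
    counts = torch.bincount(mapping, minlength=old_width).float()

    wide_conv = nn.Conv2d(conv.in_channels, new_width, conv.kernel_size,
                          stride=conv.stride, padding=conv.padding)
    wide_next = nn.Conv2d(new_width, next_conv.out_channels, next_conv.kernel_size,
                          stride=next_conv.stride, padding=next_conv.padding)
    with torch.no_grad():
        # Copy filters and biases of the target layer according to g
        wide_conv.weight.copy_(conv.weight[mapping])
        wide_conv.bias.copy_(conv.bias[mapping])
        # Copy input channels of the next layer, divided by the replication count
        wide_next.weight.copy_(next_conv.weight[:, mapping]
                               / counts[mapping].view(1, -1, 1, 1))
        wide_next.bias.copy_(next_conv.bias)
    return wide_conv, wide_next


torch.manual_seed(0)
conv1 = nn.Conv2d(1, 6, 5, padding=2)
conv2 = nn.Conv2d(6, 16, 5)
teacher = nn.Sequential(conv1, nn.ReLU(), conv2)

wide1, wide2 = net2wider_conv(conv1, conv2, new_width=10)
student = nn.Sequential(wide1, nn.ReLU(), wide2)

x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    diff = (teacher(x) - student(x)).abs().max().item()
print(wide1.out_channels, wide2.in_channels)  # 10 10
print(diff < 1e-4)  # True
```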

In [None]:
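# Instantiate a Net2Net object from a (pre-trained) model
# The class is imported directly because the name `net2net` is rebound to the
# Net2Net object below, which would otherwise make this cell fail when re-run
from net2net.net2net_wider import Net2Net
net2net = Net2Net(teacher_network=model)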

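# Set the widening operations to be performed
# Here we only increase the width of the first convolutional layer
wider_operations = {
    "operation1": {
        "target_conv_layers": ["layer1.0"],  # convolution(s) to widen
        "width": [10],                       # requested number of output filters
        "batch_norm_layers": ["layer1.1"],   # batch norm following the widened convolution
        "next_layers": ["layer2.0"],         # layer(s) reading the widened output
    }
}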

# Standard deviation of the noise added to the copied weights (optional)
# With sigma = 0., no noise is added and the student computes the same function as the teacher
sigma = 0.


# Go through the list of widening operations
for key, operation in wider_operations.items():

    print("Widening operation:", key)

    # Widen the target layer(s) with the parameters of this operation
    net2net.net2wider(target_conv_layers=operation["target_conv_layers"],
                      next_layers=operation["next_layers"],
                      width=operation["width"],
                      batch_norm_layers=operation["batch_norm_layers"],
                      sigma=sigma)


# Compute the output of the student network
y_student = net2net.student_network(x)
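The `batch_norm_layers` entry of the configuration tells the widening operation which batch-normalization layer follows the target convolution: once the convolution has more output filters, the batch norm needs one set of parameters per new channel. A natural way to keep the function unchanged, assumed here rather than taken from the `net2net` package, is to copy the affine parameters and running statistics along the same filter mapping used for the convolution. `widen_batch_norm` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def widen_batch_norm(bn, mapping):
    """Hypothetical helper: widen a BatchNorm2d by copying per-channel values along `mapping`."""
    wide_bn = nn.BatchNorm2d(len(mapping), eps=bn.eps, momentum=bn.momentum)
    with torch.no_grad():
        # Affine parameters (gamma, beta) of each copied channel
        wide_bn.weight.copy_(bn.weight[mapping])
        wide_bn.bias.copy_(bn.bias[mapping])
        # Running statistics, so that eval-mode outputs match as well
        wide_bn.running_mean.copy_(bn.running_mean[mapping])
        wide_bn.running_var.copy_(bn.running_var[mapping])
    return wide_bn


torch.manual_seed(0)
bn = nn.BatchNorm2d(6)
# Give the layer non-trivial parameters and statistics
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

mapping = torch.tensor([0, 1, 2, 3, 4, 5, 2, 4, 0, 5])  # 6 originals + 4 copies
wide_bn = widen_batch_norm(bn, mapping).eval()
bn.eval()

features = torch.randn(2, 6, 14, 14)
with torch.no_grad():
    # Each channel of the wide BN must be normalized like the original it copies
    expected = bn(features)[:, mapping]
    actual = wide_bn(features[:, mapping])
print(torch.allclose(expected, actual))  # True
```

In training mode, a copied channel has the same batch statistics as its original, so the same copy rule also preserves the output there.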

### 3. Check that the student and teacher models have the same output for the same input

We check that, at initialization, the student model produces the same output as the teacher model for the same input. With `sigma = 0.`, the two outputs should agree up to floating-point rounding; if noise was added to the copied weights (`sigma > 0`), they can differ slightly.
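Even with `sigma = 0.`, the outputs are generally equal only up to rounding, because the contribution of a replicated filter is split into divided weights that are summed again. Comparing with a tolerance (for example `torch.allclose`) is therefore more robust than comparing printed values. The sketch below illustrates both effects on a toy element-wise product rather than on LeNet5; `compare_outputs` is a hypothetical helper, and the tolerances are illustrative choices.

```python
import torch


def compare_outputs(y_teacher, y_student, atol=1e-5, rtol=1e-4):
    """Return the largest absolute difference and whether the outputs match within tolerance."""
    max_diff = (y_teacher - y_student).abs().max().item()
    return max_diff, torch.allclose(y_teacher, y_student, rtol=rtol, atol=atol)


torch.manual_seed(0)
w = torch.randn(1000)  # outgoing weights of a filter
a = torch.randn(1000)  # activations of that filter

# The filter is replicated 3 times: each copy's outgoing weight becomes w / 3
y_teacher = w * a
y_student = (w / 3) * a + (w / 3) * a + (w / 3) * a
print(compare_outputs(y_teacher, y_student))  # rounding-level difference, match

# Noise on the copied weights (sigma > 0) gives a larger, visible difference
sigma = 0.01
y_noisy = sum((w / 3 + sigma * torch.randn_like(w)) * a for _ in range(3))
print(compare_outputs(y_teacher, y_noisy))
```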

In [None]:
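# The outputs should be the same (up to floating-point rounding when sigma = 0.)
print("Teacher output: ", y_teacher)
print("Student output: ", y_student, "\n")
print("Max absolute difference:", (y_teacher - y_student).abs().max().item())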

### 4. Have a look at the student and teacher architectures

We display the teacher and student architectures to check that the student model has more filters than the teacher model in its first convolutional layer.

In [None]:
# Display the architecture of the teacher network
torchinfo.summary(model, input_size=(1,
                                     params.lenet_mnist.NB_CHANNELS,
                                     *params.lenet_mnist.IMAGE_SHAPE))

In [None]:
# Display the architecture of the student network
torchinfo.summary(net2net.student_network, input_size=(1,
                                                       params.lenet_mnist.NB_CHANNELS,
                                                       *params.lenet_mnist.IMAGE_SHAPE))
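Widening the first convolution should only change the parameter counts of that convolution, of its batch norm, and of the next convolution (through its input channels). The arithmetic below is a sketch under stated assumptions: a single-channel input, 5×5 kernels, 6 filters in `layer1.0` widened to the `width` of 10, and 16 filters in `layer2.0`, as in the classic LeNet5. The actual `LeNet` class may use other sizes, and the helper functions are hypothetical.

```python
def conv_params(in_ch, out_ch, k):
    # Weights plus one bias per output filter
    return in_ch * out_ch * k * k + out_ch


def bn_params(ch):
    # gamma and beta; running statistics are buffers, not parameters
    return 2 * ch


def widened_block_params(width, in_ch=1, next_out=16, k=5):
    # First conv + its batch norm + next conv (whose input channels = width)
    return conv_params(in_ch, width, k) + bn_params(width) + conv_params(width, next_out, k)


teacher_params = widened_block_params(6)
student_params = widened_block_params(10)
print(teacher_params, student_params, student_params - teacher_params)  # 2584 4296 1712
```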