<center><img src='https://drive.google.com/uc?id=1_utx_ZGclmCwNttSe40kYA6VHzNocdET' height="60"></center>

AI TECH - Akademia Innowacyjnych Zastosowań Technologii Cyfrowych. Program Operacyjny Polska Cyfrowa na lata 2014-2020
<hr>

<center><img src='https://drive.google.com/uc?id=1BXZ0u3562N_MqCLcekI-Ens77Kk4LpPm'></center>

<center>
Projekt współfinansowany ze środków Unii Europejskiej w ramach Europejskiego Funduszu Rozwoju Regionalnego
Program Operacyjny Polska Cyfrowa na lata 2014-2020,
Oś Priorytetowa nr 3 "Cyfrowe kompetencje społeczeństwa" Działanie  nr 3.2 "Innowacyjne rozwiązania na rzecz aktywizacji cyfrowej"
Tytuł projektu:  „Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (AI Tech)”
    </center>

Code based on https://github.com/pytorch/examples/blob/master/mnist/main.py

This exercise covers two aspects:
* In tasks 1-6 you will implement mechanisms that allow training deeper models (better initialization, batch normalization). Note that for dropout and batch norm you are expected to implement it yourself without relying on ready-made components from Pytorch. After doing each of the tasks you can look at the plots and check how your changes impact gradients of network layers.
* In task 7 you will implement a convnet using [conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html).


Tasks:
1. Check that the given implementation reaches 95% test accuracy for
   architecture input-64-64-10 in a few thousand batches.
2. Improve initialization and check that the network learns much faster
   and reaches over 97% test accuracy. A good basic initialization scheme is so-called Glorot initialization. For a set of weights going from a layer with $n_{in}$ neurons to a layer with $n_{out}$ neurons, it samples each weight from normal distribution with $0$ mean and standard deviation of $\sqrt{\frac{2}{n_{in}+n_{out}}}$.  
Check how better initialization changes distribution of gradients at the first epoch.
3. Check, that with proper initialization we can train architecture
   input-64-64-64-64-64-10, while with bad initialization it does
   not even get off the ground.
4. Add dropout implemented in pytorch (but without using torch.nn.Dropout)
5. Check that with 10 hidden layers (64 units each) even with proper
    initialization the network has a hard time to start learning.
6. Implement batch normalization (use train mode also for testing - it should perform well enough):
    * compute batch mean and variance
    * add new variables beta and gamma
    * check that the networks learns much faster for 5 layers
    * check that the network learns even for 10 hidden layers.
    * check how gradients change in comparison to network without batch norm.
7. So far we worked with a fully connected network. Design and implement in pytorch (by using pytorch functions) a simple convolutional network and achieve 99% test accuracy. The architecture is up to you, but even a few convolutional layers should be enough.

In [59]:
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
import torchvision
import torchvision.transforms as transforms
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

if 'google.colab' in sys.modules:
    from google.colab import output
    output.enable_custom_widget_manager()

In [34]:
# @title Visualize gradients

class GradientVisualizer:
    def __init__(self, net, num_epochs):
        self.num_epochs = num_epochs
        self.linear_layers = self.get_linear_layers(net)

        self.grad_to_weight_fig = None
        self.grads_in_layers_fig = None
        self.grads_at_epochs_fig = None
        self.init_figures()

    def get_linear_layers(self, net):
        linear_layers = []
        for field in net.__dict__['_modules'].values():
            if isinstance(field, Linear):
                linear_layers.append(field)
            if isinstance(field, nn.ModuleList):
                for module in field:
                    if isinstance(module, Linear):
                        linear_layers.append(module)

        assert linear_layers, \
        ('No linear layers found. Linear layers should be parameters of the network or they '
        'should be placed in a ModuleList which is a parameter of the network.')
        return linear_layers

    def get_epochs_for_one_layer(self):
        """
        We want to show gradient distributions from up to 7 selected epochs
        for one linear layer.
        """
        if self.num_epochs < 7:
            return list(range(self.num_epochs))
        else:
            return torch.linspace(0, self.num_epochs - 1, 7).int().tolist()

    def get_three_epochs(self):
        """
        We want to show gradients distributions from all layers at each of
        three epochs: first, middle and last.
        """
        return [0, self.num_epochs // 2, self.num_epochs - 1]

    def rgb_to_rgba(self, rgb_color, epoch):
        """
        Value of epoch parameter determines how transparent color should be
        in comparison to others.
        Colors for earlier epochs should be more transparent/less visible.
        """
        return f'rgba{rgb_color[3:-1]},{0.6 * (epoch + 1) / self.num_epochs + 0.15})'

    def init_figures(self):
        # Initialize figure with gradient to weight ratio plot
        fig = go.Figure()
        fig.update_layout(
            title='Gradient standard deviation to weight standard deviation ratio', title_x=0.5,
            xaxis_title='Epoch',
            yaxis_title='Gradient to weight ratio (log scale)',
            height=400, width=1500, margin=dict(b=10, t=60)
        )
        fig.update_yaxes(type='log')
        for i in range(len(self.linear_layers)):
            fig.add_trace(go.Scatter(
                x=[], y=[],
                mode='lines+markers', name=f'Linear layer {i}'
            ))

        self.grad_to_weight_fig = go.FigureWidget(fig)
        display(self.grad_to_weight_fig)

        # Initialize figure visualizing gradient distributions in layers
        num_rows = (len(self.linear_layers) - 1) // 3 + 1
        fig = make_subplots(
            rows=num_rows, cols=3,
            subplot_titles=[f'Linear layer {i}' for i in range(len(self.linear_layers))],
            vertical_spacing=0.2 / num_rows
        )
        fig.update_layout(
            title='Comparison between epochs of gradient distributions in layers', title_x=0.5,
            height=num_rows * 400, width=1500, margin=dict(b=10, t=60)
        )

        colors, _ = px.colors.convert_colors_to_same_type(2 * px.colors.qualitative.Plotly)
        for layer_num in range(len(self.linear_layers)):
            row = layer_num // 3 + 1
            col = layer_num % 3 + 1
            fig.update_xaxes(
                title_text='Gradient value', range=(-0.1, 0.1), row=row, col=col
            )
            fig.update_yaxes(
                title_text='Density (log scale)', type='log', row=row, col=col
            )

            # Create empty traces and update them later with actual gradient distributions.
            # Unfortunately, we cannot add new traces dynamically because Colab has problem
            # with widgets from plotly (traces added dynamically are rendered twice).
            for epoch in self.get_epochs_for_one_layer():
                fig.add_trace(
                    go.Scatter(
                        mode='lines', name=f'Epoch {epoch + 1}',
                        line=dict(color=self.rgb_to_rgba(colors[layer_num], epoch)),
                        legendgroup=layer_num
                    ),
                    row=row, col=col
                )

        self.grads_in_layers_fig = go.FigureWidget(fig)
        display(self.grads_in_layers_fig)

        # Initialize figure comparing gradient distributions between layers at the
        # first, middle and last epoch
        selected_epochs_indices = self.get_three_epochs()
        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=[f'Epoch {epoch + 1}' for epoch in selected_epochs_indices]
        )
        fig.update_layout(
            title='Comparison between layers of gradient distributions at epochs', title_x=0.5,
            height=400, width=1500, margin=dict(b=10, t=60)
        )

        for col, epoch in enumerate(selected_epochs_indices, 1):
            fig.update_yaxes(title_text='Density (log scale)', type='log', row=1, col=col)
            fig.update_xaxes(
                title_text='Gradient value',
                range=(-0.05, 0.05) if epoch != 0 else (-1, 1),
                row=1, col=col
            )

            # Create empty traces and update them later with actual gradient distributions.
            for layer_num in range(len(self.linear_layers)):
                fig.append_trace(
                    go.Scatter(
                        mode='lines', name=f'Linear layer {layer_num}',
                        line=dict(color=colors[layer_num]), showlegend=(col == 1)
                    ),
                    row=1, col=col
                )

        self.grads_at_epochs_fig = go.FigureWidget(fig)
        display(self.grads_at_epochs_fig)

    def visualize_gradients(self, lr, epoch, batch_idx):
        # It is enough to use gradients calculated for the first batch.
        if batch_idx != 0:
            return

        epoch_grads = []
        epoch_grad_to_weight_ratios = []
        for layer in self.linear_layers:
            epoch_grads.append(layer.weight.grad.flatten().detach())
            epoch_grad_to_weight_ratios.append(
                (lr * layer.weight.grad.std() / layer.weight.std()).item()
            )

        # Update figure with gradient to weight ratio plot
        for i, grad_to_weight_ratio in enumerate(epoch_grad_to_weight_ratios):
            x = self.grad_to_weight_fig.data[i].x
            next_x_val = x[-1] + 1 if x else 1
            self.grad_to_weight_fig.data[i].x += (next_x_val, )
            self.grad_to_weight_fig.data[i].y += (grad_to_weight_ratio, )

        # Update figure visualizing gradient distributions in layers
        selected_epochs = self.get_epochs_for_one_layer()
        if epoch in selected_epochs:
            epoch_idx = selected_epochs.index(epoch)
            for layer_num, layer_grad in enumerate(epoch_grads):
                trace_idx = layer_num * len(selected_epochs) + epoch_idx
                hy, hx = torch.histogram(layer_grad, bins=50, density=True)
                hy = hy / max(hy) + 0.001
                self.grads_in_layers_fig.data[trace_idx].x = hx[:-1].tolist()
                self.grads_in_layers_fig.data[trace_idx].y = hy.tolist()

        # Update figure visualizing gradient distributions at epochs
        selected_epochs = self.get_three_epochs()
        if epoch in selected_epochs:
            epoch_idx = selected_epochs.index(epoch)
            for layer_num, layer_grad in enumerate(epoch_grads):
                trace_idx = epoch_idx * len(self.linear_layers) + layer_num
                hy, hx = torch.histogram(layer_grad, bins=50, density=True)
                hy = hy / max(hy) + 0.001
                self.grads_at_epochs_fig.data[trace_idx].x = hx[:-1].tolist()
                self.grads_at_epochs_fig.data[trace_idx].y = hy.tolist()

In [74]:
class Linear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.normal_(mean=0,std=math.sqrt(2/(self.in_features + self.out_features))) # math.sqrt(2/(self.in_features + self.out_features))
        init.zeros_(self.bias)

    def forward(self, x):
        r = x.matmul(self.weight.t())
        r += self.bias
        return r


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = Linear(784, 64)
        self.fc2 = Linear(64, 64)
        self.fc3 = Linear(64, 64)
        self.fc4 = Linear(64, 64)
        self.fc5 = Linear(64, 64)
        self.fc6 = Linear(64, 10)

    def dropout(self, x):
      for i in range(len(x)):
        x[i] = 0 if np.random.random() < 1/2 else x[i]/(1-p)
      return x


    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = F.relu(self.fc4(x))
        x = self.dropout(x)
        x = F.relu(self.fc5(x))
        x = self.dropout(x)
        x = self.fc6(x)
        return x


In [76]:
class MnistTrainer(object):
    def __init__(self, batch_size):
        transform = transforms.Compose(
                [transforms.ToTensor()])
        self.trainset = torchvision.datasets.MNIST(
            root='./data',
            download=True,
            train=True,
            transform=transform)
        self.trainloader = torch.utils.data.DataLoader(
            self.trainset, batch_size=batch_size, shuffle=True, num_workers=2)

        self.testset = torchvision.datasets.MNIST(
            root='./data',
            train=False,
            download=True, transform=transform)
        self.testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=1, shuffle=False, num_workers=2)

    def train(self, net, gradient_visualizer, epochs=20, lr=0.05, momentum=0.9):
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

        for epoch in range(epochs):
            running_loss = 0.0
            for i, data in enumerate(self.trainloader):
                inputs, labels = data
                optimizer.zero_grad()

                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                gradient_visualizer.visualize_gradients(lr, epoch, i)
                optimizer.step()

                running_loss += loss.item()
                if i % 100 == 99:
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 100))
                    running_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for data in self.testloader:
                    images, labels = data
                    outputs = net(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            print('Accuracy of the network on the {} test images: {} %'.format(
                total, 100 * correct / total))

In [77]:
epochs = 20

net = Net()
gradient_visualizer = GradientVisualizer(net, epochs)

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Linear layer 0',
              'type': 'scatter',
              'uid': 'cec861e2-bd02-47dd-95fb-624873874d88',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'name': 'Linear layer 1',
              'type': 'scatter',
              'uid': '19151c06-730f-4966-a5fd-7dab9404c3d4',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'name': 'Linear layer 2',
              'type': 'scatter',
              'uid': '6f262a2f-1f0c-4cc0-b50e-88bc3115b2bf',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'name': 'Linear layer 3',
              'type': 'scatter',
              'uid': 'a50ea9b5-089a-496b-91df-6c7a8c6f7166',
              'x': [],
              'y': []},
             {'mode': 'lines+markers',
              'name': 'Linear layer 4',
              'type': 'sca

FigureWidget({
    'data': [{'legendgroup': '0',
              'line': {'color': 'rgba(99, 110, 250,0.18)'},
              'mode': 'lines',
              'name': 'Epoch 1',
              'type': 'scatter',
              'uid': '0b6d2d44-a560-4832-9439-381d2d1d8007',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'legendgroup': '0',
              'line': {'color': 'rgba(99, 110, 250,0.27)'},
              'mode': 'lines',
              'name': 'Epoch 4',
              'type': 'scatter',
              'uid': '9fe4d806-6887-40ec-8093-43047f48b6da',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'legendgroup': '0',
              'line': {'color': 'rgba(99, 110, 250,0.36)'},
              'mode': 'lines',
              'name': 'Epoch 7',
              'type': 'scatter',
              'uid': 'ccb566d6-6830-4411-9e9a-86c090b186fa',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'legendgroup': '0',
              'line': {'

FigureWidget({
    'data': [{'line': {'color': 'rgb(99, 110, 250)'},
              'mode': 'lines',
              'name': 'Linear layer 0',
              'showlegend': True,
              'type': 'scatter',
              'uid': '37511482-379e-4623-9c27-490b1686b8f6',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'line': {'color': 'rgb(239, 85, 59)'},
              'mode': 'lines',
              'name': 'Linear layer 1',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'a46b63f9-fedc-4ac3-bd5f-9086c7bdcd74',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'line': {'color': 'rgb(0, 204, 150)'},
              'mode': 'lines',
              'name': 'Linear layer 2',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'f996caed-66d2-4178-b996-2999af535ad7',
              'xaxis': 'x',
              'yaxis': 'y'},
             {'line': {'color': 'rgb(171, 99, 250)'},
   

In [None]:
trainer = MnistTrainer(batch_size=128)
trainer.train(net, gradient_visualizer, epochs=epochs)

[1,   100] loss: 1.815
[1,   200] loss: 1.369
[1,   300] loss: 1.332
[1,   400] loss: 1.291
Accuracy of the network on the 10000 test images: 51.95 %
[2,   100] loss: 1.259
[2,   200] loss: 1.267
[2,   300] loss: 1.258
[2,   400] loss: 1.237
Accuracy of the network on the 10000 test images: 52.75 %
[3,   100] loss: 1.225
[3,   200] loss: 1.227
[3,   300] loss: 1.224
[3,   400] loss: 1.228
Accuracy of the network on the 10000 test images: 51.82 %
[4,   100] loss: 1.226
[4,   200] loss: 1.202
[4,   300] loss: 1.219
[4,   400] loss: 1.224
Accuracy of the network on the 10000 test images: 52.68 %
[5,   100] loss: 1.207
[5,   200] loss: 1.204
[5,   300] loss: 1.192
[5,   400] loss: 1.207
Accuracy of the network on the 10000 test images: 52.69 %
[6,   100] loss: 1.170
[6,   200] loss: 1.212
[6,   300] loss: 1.229
[6,   400] loss: 1.207
Accuracy of the network on the 10000 test images: 53.18 %
[7,   100] loss: 1.210
[7,   200] loss: 1.206
[7,   300] loss: 1.190
[7,   400] loss: 1.214
Accuracy