In [None]:
%load_ext autoreload
%autoreload 2

<img src='../img/boromir.png'>

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
%matplotlib inline

In [None]:
import sys
sys.path.append('../../modules/')
import mnist

## MNIST

In [None]:
mnist_set = mnist.MNIST('../data/')

In [None]:
data, labels = mnist_set.random_data(12000)
data = 2*data-1

In [None]:
# Preview: the first eight samples side by side, each titled with its label.
fig_mnist, ax = plt.subplots(1, 8, figsize=(8*4, 4))
for i, panel in enumerate(ax):
    panel.imshow(data[i].numpy(), cmap='Greys')
    panel.set_title(labels[i].item(), fontsize=16)

In [None]:
dataset = torch.utils.data.TensorDataset(data.view(-1,28*28), labels) 

In [None]:
train_dataset, test_dataset = torch.utils.data.random_split(dataset, (10000,2000))

In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100,drop_last=True)

## Model

In [None]:
# Fully connected classifier: 28*28 inputs -> 1200 -> 600 -> 300 hidden units
# (ReLU after each hidden layer) -> 10 output scores.
model = nn.Sequential(
    nn.Linear(28 * 28, 1200),
    nn.ReLU(),
    nn.Linear(1200, 600),
    nn.ReLU(),
    nn.Linear(600, 300),
    nn.ReLU(),
    nn.Linear(300, 10),
)

In [None]:
def accuracy(pred, labels):
    """Return the fraction of rows of ``pred`` whose arg-max matches ``labels``.

    Parameters
    ----------
    pred : torch.Tensor
        Per-class scores; the predicted class is the arg-max along dimension 1.
    labels : torch.Tensor
        Target class indices, compared element-wise with the predicted classes.

    Returns
    -------
    float
        Number of matching entries divided by ``len(labels)``.
    """
    predicted_classes = pred.argmax(dim=1)
    n_correct = (predicted_classes == labels).sum().to(dtype=torch.float32).item()
    return n_correct / len(labels)

def model_accuracy(model, dataset):
    """Evaluate ``model`` on every sample of ``dataset`` and return its accuracy.

    The whole dataset is fetched at once with ``dataset[:]`` and pushed through
    the model in a single forward pass with gradient tracking disabled.

    Parameters
    ----------
    model : callable
        Maps the feature tensor to per-class scores.
    dataset
        Object whose ``[:]`` slice yields a ``(features, labels)`` pair.

    Returns
    -------
    float
        The value of ``accuracy`` for the model's predictions on the dataset.
    """
    inputs, targets = dataset[:]
    with torch.no_grad():
        scores = model(inputs)
    return accuracy(scores, targets)

In [None]:
loss_f = nn.CrossEntropyLoss()

In [None]:
optim = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
errors = []
batches = 0
epochs = 0

In [None]:
%%time
# Train for 5 passes over `train_loader` with plain mini-batch SGD.
# `errors` collects the loss of every mini-batch; `batches` and `epochs` count
# how many have been processed (used for the x-axis of the loss plot).
for e in range(5):
    # f: mini-batch of flattened images, l: the matching labels
    for f,l in train_loader:        
        # Gradients accumulate in PyTorch, so clear them before each step.
        optim.zero_grad()
        pred = model(f)
        loss = loss_f(pred, l)
        errors.append(loss.item())
        loss.backward()
        optim.step()
        batches += 1
    epochs += 1   
    # Loss of the last mini-batch of this epoch (not an epoch average).
    print(loss.item())        

In [None]:
plt.plot(np.linspace(0,epochs, batches),errors);

In [None]:
model_accuracy(model, train_dataset)

In [None]:
model_accuracy(model, test_dataset)

## How are the weights initialized?

In [None]:
# Fresh, untrained network with the same architecture as before, so that the
# default PyTorch weight initialization can be inspected.
model = nn.Sequential(
    nn.Linear(28 * 28, 1200),
    nn.ReLU(),
    nn.Linear(1200, 600),
    nn.ReLU(),
    nn.Linear(600, 300),
    nn.ReLU(),
    nn.Linear(300, 10),
)

In [None]:
# Print (std, mean) of the weight matrix of every Linear layer in the model.
for layer in model.modules():
    if not isinstance(layer, torch.nn.Linear):
        continue
    print(torch.std_mean(layer.weight))
    

In [None]:
# Pick the first Linear layer of the network for a closer look.
# `model.modules()` yields the container itself first and then its children,
# so the first item is skipped and the second one is kept as `layer`.
modules = model.modules()
next(modules)
layer = next(modules)
print(layer)

In [None]:
plt.hist(layer.weight.detach().numpy().ravel(),bins=100);

In [None]:
1/layer.weight.max().item()**2

PyTorch `Linear` layer weights are initialized from the uniform distribution

$$\mathcal{U}(-\sqrt{k},\sqrt{k}),\quad k =\frac{1}{n_{in}}$$

## Vanishing or exploding gradients

In [None]:
def init_layer_uniform(sigma):
    """Build a layer initializer suitable for ``Module.apply``.

    Parameters
    ----------
    sigma : float
        Half-width of the interval: weights are drawn uniformly from
        ``(-sigma, sigma)``.

    Returns
    -------
    callable
        ``init(layer)``; for ``torch.nn.Linear`` modules it redraws the weights
        in place and zeroes the bias (if there is one). Any other module is
        left untouched.
    """
    def init(layer):
        # `apply` visits every sub-module; only Linear layers are of interest.
        if not isinstance(layer, torch.nn.Linear):
            return
        torch.nn.init.uniform_(layer.weight, -sigma, sigma)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)

    return init

def init_layer_gauss(sigma):
    """Build a layer initializer suitable for ``Module.apply``.

    Parameters
    ----------
    sigma : float
        Standard deviation of the zero-mean normal distribution the weights
        are drawn from.

    Returns
    -------
    callable
        ``init(layer)``; for ``torch.nn.Linear`` modules it redraws the weights
        in place from ``N(0, sigma**2)`` and zeroes the bias (if there is one).
        Any other module is left untouched.
    """
    def init(layer):
        if isinstance(layer, torch.nn.modules.linear.Linear):
            # Fixed: the in-place initializer is `normal_` (single trailing
            # underscore); `normal__` does not exist and raised AttributeError.
            torch.nn.init.normal_(layer.weight, 0, sigma)
            if layer.bias is not None:
                torch.nn.init.zeros_(layer.bias)
    return init

In [None]:
# Re-initialize every Linear layer with weights from U(-0.12, 0.12) and zero
# biases, then repeat the same 5-epoch SGD training as before.
model.apply(init_layer_uniform(0.12));
# Fresh optimizer and fresh bookkeeping for this run.
optim = torch.optim.SGD(model.parameters(), lr=0.1)
errors = []
batches = 0
epochs = 0
for e in range(5):
    # f: mini-batch of flattened images, l: the matching labels
    for f,l in train_loader:        
        # Gradients accumulate in PyTorch, so clear them before each step.
        optim.zero_grad()
        pred = model(f)
        loss = loss_f(pred, l)
        errors.append(loss.item())
        loss.backward()
        optim.step()
        batches += 1
    epochs += 1   
    # Loss of the last mini-batch of this epoch (not an epoch average).
    print(loss.item())        

In [None]:
plt.plot(np.linspace(0,epochs, batches),errors);

In [None]:
model_accuracy(model, train_dataset)

In [None]:
model_accuracy(model, test_dataset)

## Input variance

In [None]:
torch.var(train_dataset[:][0])

In [None]:
# Probe how the width of the initial weight distribution affects the variance
# of the network output on the training inputs.
# NOTE: each iteration re-initializes `model` in place, so after this cell the
# model is left with the sigma = 0.5 initialization.
train_features = train_dataset[:][0]  # loop-invariant: gather the inputs once
for sigma in [0.01, 0.05, 0.1, 0.15, 0.2, 0.5 ]:
    model.apply(init_layer_uniform(sigma))
    # Only output statistics are needed, so do not record an autograd graph
    # for the forward pass over the whole training set.
    with torch.no_grad():
        mnist_out = model(train_features)
    print(f"{sigma:4.2f}, {torch.var(mnist_out):.6g}" )

$$\newcommand{\var}{\operatorname{var}}
\renewcommand{\E}{\operatorname{E}}$$
$$x^{l}_i = f(y^{l}_i), 
\qquad y^{l}_i =\sum_{j=1}^{n_{l-1}} w^{l}_{ij}x^{l-1}_j$$

$$w^l_{ij} \sim \text{i.i.d}\qquad \E[w]=0$$

$$\E[y^{l}_i] = \sum_{j=1}^{n_{l-1}} \E[w^{l}_{ij}x^{l-1}_j] = 
\sum_{j=1}^{n_{l-1}} \E[w^{l}_{ij}]\E[x^{l-1}_j]=0$$

$$\var[y^{l}_i ]= \E[(y^{l}_i-\E[y^{l}_i ])^2] = \E[(y^{l}_i)^2]$$ 

$$\var[y^{l}_i] = 
\E\left[\left(
\sum_{j=1}^{n_{l-1}} w^{l}_{ij}x^{l-1}_j
\right)^2\right]
$$

$$\var[y^{l}_i] = 
\E\left[\left(
\sum_{j=1}^{n_{l-1}} w^{l}_{ij}x^{l-1}_j
\right)
\left(
\sum_{k=1}^{n_{l-1}} w^{l}_{ik}x^{l-1}_k
\right)
\right]
$$

$$\var[y^{l}_i] = 
\E\left[
\sum_{j,k=1}^{n_{l-1}} w^{l}_{ij}x^{l-1}_j
 w^{l}_{ik}x^{l-1}_k
\right]= 
\sum_{j,k=1}^{n_{l-1}} \E\left[
 w^{l}_{ij}x^{l-1}_j
 w^{l}_{ik}x^{l-1}_k
\right]
$$

$$\var[y^{l}_i] = 
\sum_{j,k=1}^{n_{l-1}} 
E\left[w^{l}_{ij} w^{l}_{ik}\right]
E\left[
x^{l-1}_k x^{l-1}_j
\right]
$$

$$E\left[w^{l}_{ij} w^{l}_{ik}\right]=\delta_{jk}E[(w^l_{ij})^2]= \delta_{jk} Var [w^l]$$

$$\var[y^{l}_i] = 
\sum_{j,k=1}^{n_{l-1}} 
\delta_{jk}\var[w^l]
\E\left[
x^{l-1}_k x^{l-1}_j
\right] = 
\sum_{j=1}^{n_{l-1}} 
\var[w^l]
\E\left[
(x^{l-1}_j)^2
\right] = 
n_{l-1} \var [w^{l}]\E[(x^{l-1})^2]
$$

$$n_{l-1} \var [w^{l}] = 1\qquad \sigma_w = \frac{1}{\sqrt{n_{l-1}}} $$

### Uniform distribution

$$w\in (-a, a),\quad P(w)=\frac{1}{2 a}$$

$$\sigma_w^2 = \var[w]=\frac{1}{2a}\int\limits_{-a}^a w^2\,\text{d}w 
=\frac{1}{2 a}\frac{1}{3} 2 a^3 =\frac{1}{3}a^2
$$

$$a = \sqrt{3} \sigma_w$$

$$a = \sqrt{\frac{3}{n_{l-1}}}$$

## Xavier initialization

$$w\in (-a, a),\qquad a=\sqrt{\frac{6}{n_l+n_{l-1}}}$$

## Kaiming/He initialization

$$\newcommand{\relu}{\operatorname{relu}}$$
$$x^{l-1} = \relu(y^{l-1})=\max(0,y^{l-1}),\qquad P(y^{l-1})=P(-y^{l-1})$$

$$\E[(x^{l-1})^2]=\int\limits_{-\infty}^\infty\!\!\text{d}y^{l-1}\,P(y^{l-1})\,\max(0,y^{l-1})^2=\int\limits_0^\infty\!\!\text{d}y^{l-1}\,P(y^{l-1})\,(y^{l-1})^2=
\frac{1}{2}\int\limits_{-\infty}^\infty\!\!\text{d}y^{l-1}\,P(y^{l-1})\,(y^{l-1})^2=\frac{1}{2}\var[y^{l-1}]$$

$$\var[y^{l}] = n_{l-1}\var[w^{l}]\E[(x^{l-1})^2]=\frac{1}{2}n_{l-1} \var[w^l]\var[y^{l-1}]
$$

$$\var[y^L]=
\var[y^1]\left(\prod_{l=2}^L\frac{1}{2} n_{l-1} \var[w^{l}]\right)
=\sum_{j=1}^{n_0}\E[(x^0_j)^2] \var[w^1] \left(\prod_{l=2}^L\frac{1}{2} n_{l-1} \var[w^{l}]\right)$$

$$\left(\prod_{l=2}^L\frac{1}{2} n_{l-1} \var[w^{l}]\right) = 1$$

$$\var[w^{l}] = \frac{2}{n_{l-1}}$$

$$\sigma_w=\sqrt{\frac{2}{n_{l-1}}},\qquad a = \sqrt{\frac{6}{n_{l-1}}}$$

In [None]:
torch.nn.init.calculate_gain('relu')

In [None]:
def xavier_init_gauss(sigma=1):
    """Build a fan-in scaled Gaussian initializer for ``Module.apply``.

    For every ``torch.nn.Linear`` module the weights are redrawn in place from
    a zero-mean normal distribution with standard deviation
    ``sigma * sqrt(1 / fan_in)`` and the bias (if present) is set to zero.
    Other modules are left untouched.

    Note that only the number of inputs (``fan_in``) enters the scale; the
    number of outputs is not used.

    Parameters
    ----------
    sigma : float, default 1
        Extra multiplicative factor applied to the standard deviation.

    Returns
    -------
    callable
        ``init(layer)`` to be passed to ``model.apply``.
    """
    def init(layer):
        if not isinstance(layer, torch.nn.Linear):
            return
        # Linear weights are stored as (out_features, in_features).
        n_inputs = layer.weight.size(1)
        std = np.sqrt(1/n_inputs) * sigma
        torch.nn.init.normal_(layer.weight, 0, std)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)

    return init

In [None]:
def xavier_init_uniform(sigma=1):
    """Build a fan-in scaled uniform initializer for ``Module.apply``.

    For every ``torch.nn.Linear`` module the weights are redrawn in place from
    the uniform distribution on ``(-a, a)`` with
    ``a = sigma * sqrt(3 / fan_in)`` and the bias (if present) is set to zero.
    Other modules are left untouched.

    Note that only the number of inputs (``fan_in``) enters the bound; the
    number of outputs is not used.

    Parameters
    ----------
    sigma : float, default 1
        Extra multiplicative factor applied to the bound.

    Returns
    -------
    callable
        ``init(layer)`` to be passed to ``model.apply``.
    """
    def init(layer):
        if not isinstance(layer, torch.nn.Linear):
            return
        # Linear weights are stored as (out_features, in_features).
        n_inputs = layer.weight.size(1)
        bound = sigma * np.sqrt(3/n_inputs)
        torch.nn.init.uniform_(layer.weight, -bound, bound)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)

    return init

In [None]:
def kaiming_init_gauss(sigma=1):
    """Build a Kaiming/He-style Gaussian initializer for ``Module.apply``.

    For every ``torch.nn.Linear`` module the weights are redrawn in place from
    a zero-mean normal distribution with standard deviation
    ``sigma * sqrt(2 / fan_in)`` and the bias (if present) is set to zero.
    Other modules are left untouched.

    Parameters
    ----------
    sigma : float, default 1
        Extra multiplicative factor applied to the standard deviation.

    Returns
    -------
    callable
        ``init(layer)`` to be passed to ``model.apply``.
    """
    def init(layer):
        if not isinstance(layer, torch.nn.Linear):
            return
        # Linear weights are stored as (out_features, in_features).
        n_inputs = layer.weight.size(1)
        std = np.sqrt(2/n_inputs) * sigma
        torch.nn.init.normal_(layer.weight, 0, std)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)

    return init

In [None]:
def kaiming_init_uniform(sigma=1):
    """Build a Kaiming/He-style uniform initializer for ``Module.apply``.

    For every ``torch.nn.Linear`` module the weights are redrawn in place from
    the uniform distribution on ``(-a, a)`` with
    ``a = sigma * sqrt(6 / fan_in)`` and the bias (if present) is set to zero.
    Other modules are left untouched.

    Parameters
    ----------
    sigma : float, default 1
        Extra multiplicative factor applied to the bound.

    Returns
    -------
    callable
        ``init(layer)`` to be passed to ``model.apply``.
    """
    def init(layer):
        if not isinstance(layer, torch.nn.Linear):
            return
        # Linear weights are stored as (out_features, in_features).
        n_inputs = layer.weight.size(1)
        bound = sigma * np.sqrt(6/n_inputs)
        torch.nn.init.uniform_(layer.weight, -bound, bound)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)

    return init

In [None]:
model.apply(kaiming_init_uniform())

In [None]:
optim = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
errors = []
batches = 0
epochs = 0

In [None]:
%%time
# Train for 5 passes over `train_loader` with plain mini-batch SGD.
# `errors` collects the loss of every mini-batch; `batches` and `epochs` count
# how many have been processed (used for the x-axis of the loss plot).
for e in range(5):
    # f: mini-batch of flattened images, l: the matching labels
    for f,l in train_loader:        
        # Gradients accumulate in PyTorch, so clear them before each step.
        optim.zero_grad()
        pred = model(f)
        loss = loss_f(pred, l)
        errors.append(loss.item())
        loss.backward()
        optim.step()
        batches += 1
    epochs += 1   
    # Loss of the last mini-batch of this epoch (not an epoch average).
    print(loss.item())        

In [None]:
plt.plot(np.linspace(0,epochs, batches),errors);

In [None]:
model_accuracy(model, train_dataset)

In [None]:
model_accuracy(model, test_dataset)

In [None]:
# Compare the four initializers: for each one, re-initialize the model in
# place, train it for 5 epochs with a fresh SGD optimizer and print the
# (train, test) accuracy. Output lines appear in the order of the list below.
# NOTE(review): `errors`, `batches` and `epochs` are not reset here, so they
# keep accumulating across all four runs (and on top of whatever the previous
# training cell left in them) -- confirm this is intended before plotting them.
for init in [xavier_init_uniform(), xavier_init_gauss(), kaiming_init_uniform(), kaiming_init_gauss()]:
    model.apply(init)
    # New optimizer for every run.
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    for e in range(5):
        # f: mini-batch of flattened images, l: the matching labels
        for f,l in train_loader:        
            optim.zero_grad()
            pred = model(f)
            loss = loss_f(pred, l)
            errors.append(loss.item())
            loss.backward()
            optim.step()
            batches += 1
        epochs += 1   
    # Accuracy on the training split, then on the held-out split.
    print( model_accuracy(model, train_dataset), model_accuracy(model, test_dataset))