In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

# Import and prepare Data for Conv

In [4]:
#export
from exp.nb_02 import *

def get_data():
    """Download the pickled MNIST archive and return its train/valid arrays as tensors.

    Returns a lazy ``map`` that yields, in order, ``x_train, y_train, x_valid,
    y_valid``, each passed through ``tensor``. The third split stored in the
    archive is discarded.
    """
    archive_path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(archive_path, 'rb') as archive:
        # NOTE(review): pickle.load can execute arbitrary code, so MNIST_URL must be trusted.
        train_split, valid_split, _ = pickle.load(archive, encoding='latin-1')
    (x_train, y_train), (x_valid, y_valid) = train_split, valid_split
    return map(tensor, (x_train, y_train, x_valid, y_valid))

def normalize(x, m, s):
    """Shift `x` by `m` and then divide by `s`, i.e. ``(x - m) / s``."""
    centered = x - m
    return centered / s

In [30]:
def stats(x):
    """Return ``(mean, std)`` of `x` as Python floats rounded to 5 decimal places."""
    mean, std = float(x.mean()), float(x.std())
    return round(mean, 5), round(std, 5)

In [31]:
x_train,y_train,x_valid,y_valid = get_data()

In [32]:
# Compute mean/std on the training set only and reuse them for the validation
# set, so both splits are normalised with the same statistics.
train_mean,train_std = x_train.mean(),x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

stats(x_train), stats(x_valid)

((-0.0, 1.0), (-0.00585, 0.99243))

In [33]:
x_train.shape

torch.Size([50000, 784])

Reshape tensors into 2D images

In [34]:
x_train = x_train.view(-1,1,28,28)
x_valid = x_valid.view(-1,1,28,28)
x_train.shape

torch.Size([50000, 1, 28, 28])

In [35]:
n,*_ = x_train.shape
c = y_train.max()+1
nh = 32
n,c

(50000, tensor(10))

# Examine behavior of one conv layer:

In [36]:
l1 = nn.Conv2d(in_channels= 1, out_channels= nh, kernel_size= 5)

In [40]:
l1.weight.shape, stats(l1.weight)

(torch.Size([32, 1, 5, 5]), (0.00601, 0.11478))

In [41]:
l1.bias.shape, stats(l1.bias)

(torch.Size([32]), (-0.00225, 0.11638))

In [42]:
t = l1(x_valid[:200])

In [43]:
stats(t)

(0.01574, 0.67049)

In [44]:
init.kaiming_normal_(l1.weight, a=1.) # a=1 because l1 is a plain conv with no ReLU applied after it
stats(l1.weight)

(0.00036, 0.19544)

In [45]:
stats(l1(x_valid[:200]))

(-0.00155, 1.09674)

## Conv + Relu

In [46]:
import torch.nn.functional as F

In [47]:
def f1(x, a=0):
    """Run `x` through the module-level conv layer `l1`, then a leaky ReLU with slope `a`.

    With the default ``a=0`` the activation is a plain ReLU.
    """
    conv_out = l1(x)
    return F.leaky_relu(conv_out, a)

In [48]:
x = x_valid[:100]

In [53]:
init.kaiming_normal_(l1.weight, a=0) # a=0 means a regular ReLU
stats(l1.weight)

(-0.00273, 0.28005)

In [54]:
stats(f1(x))

(0.50809, 0.9277)

So with kaiming_normal_ init, the standard deviation of the output stays close to 1.
But if we reset the Conv2d layer without Kaiming init:

In [55]:
l1 = nn.Conv2d(1, nh, 5)

In [56]:
stats(f1(x))

(0.21904, 0.38894)

These stats are bad: the standard deviation has dropped well below 1.

# How does Kaiming init work?

In [58]:
l1.weight.shape

torch.Size([32, 1, 5, 5])

## Get the "receptive field" of the Conv layer

In [60]:
l1.weight[0,0].shape

torch.Size([5, 5])

In [61]:
rec_fs = l1.weight[0,0].numel(); rec_fs

25

In [62]:
nf, ni, *_ = l1.weight.shape
nf, ni

(32, 1)

## Process the gain

The gain tries to compensate for the loss of standard deviation after the ReLU.

In [63]:
fan_in = ni * rec_fs
fan_out = nf * rec_fs
fan_in, fan_out

(25, 800)

In [64]:
def gain(a):
    """Return the Kaiming gain ``sqrt(2 / (1 + a**2))`` for a leaky ReLU of slope `a`."""
    slope_sq = a**2
    return math.sqrt(2.0 / (1 + slope_sq))

In [65]:
gain(1),gain(0),gain(0.01),gain(0.1),gain(math.sqrt(5.))

(1.0,
 1.4142135623730951,
 1.4141428569978354,
 1.4071950894605838,
 0.5773502691896257)

PyTorch uses `a = sqrt(5)`, which gives a gain of about 0.577. That is the same value as the standard deviation of a uniform distribution on [-1, 1]:

In [69]:
float(torch.zeros(100000).uniform_(-1,1).std())  ,  1/math.sqrt(3)

(0.5774717330932617, 0.5773502691896258)

So PyTorch appears to be using the gain to compensate for the standard deviation of the uniform distribution.

## Kaiming init func

In [82]:
l1 = nn.Conv2d(1, nh, 5)

In [72]:
def kaiming2(x, a, use_fan_out = False):
    """Re-initialise the weight tensor `x` in place with a Kaiming-style uniform distribution.

    Args:
        x: weight tensor indexed as (out_channels, in_channels, *kernel_dims),
           e.g. a conv weight such as ``l1.weight``.
        a: negative slope of the leaky ReLU that follows the layer; passed to ``gain``.
        use_fan_out: when True, scale by fan-out (out_channels * receptive field)
           instead of fan-in (in_channels * receptive field).

    Returns:
        None. `x` is modified in place.
    """
    # Fan = number of weights feeding one output unit (fan-in) or fed by one input unit (fan-out).
    nf, ni, *_ = x.shape
    rec_fs = x[0,0].shape.numel()
    fan = nf*rec_fs if use_fan_out else ni*rec_fs

    std = gain(a) / math.sqrt(fan)

    # uniform(-b, b) has std b/sqrt(3), so b = sqrt(3)*std yields the target std.
    bound = math.sqrt(3.) * std

    # Fill in place under no_grad instead of going through `.data`: autograd does
    # not record the op, but the tensor's version counter is still updated.
    with torch.no_grad():
        x.uniform_(-bound, bound)

In [79]:
l1.weight[0,0], stats(l1.weight)

(tensor([[ 0.0891, -0.1815,  0.1096,  0.0363,  0.0842],
         [-0.0763,  0.0612,  0.1519,  0.1893, -0.1800],
         [-0.1127,  0.1962,  0.0151, -0.0827,  0.1378],
         [ 0.1444,  0.1329,  0.0278,  0.0538, -0.0817],
         [ 0.1947, -0.0356,  0.0013,  0.0598,  0.0218]],
        grad_fn=<SelectBackward>), (-0.00416, 0.11944))

In [80]:
kaiming2(l1.weight, a=0)
l1.weight[0,0], stats(l1.weight)

(tensor([[ 0.2412, -0.2983,  0.0652, -0.2432, -0.0108],
         [ 0.4685, -0.3512, -0.0411,  0.0850, -0.1991],
         [ 0.4886,  0.2593, -0.1854, -0.0218,  0.4469],
         [-0.1797, -0.0658, -0.0864, -0.4701, -0.3045],
         [ 0.2093,  0.3543,  0.4819,  0.0385, -0.4624]],
        grad_fn=<SelectBackward>), (0.01823, 0.28548))

In [81]:
stats(f1(x))

(0.56268, 1.11624)

In [83]:
kaiming2(l1.weight, a= math.sqrt(5.))
stats(l1.weight)

(-0.0037, 0.11822)

In [84]:
stats(f1(x))

(0.20088, 0.35929)

Pretty bad stats: with `a = sqrt(5)` the output standard deviation is far below 1.

## Bad init on bigger CNN

In [85]:
class Flatten(nn.Module):
    """Module that collapses its input into a single 1-D tensor.

    Every dimension, including the batch dimension, is merged by ``view(-1)``.
    """
    def forward(self, x):
        flat = x.view(-1)
        return flat

In [95]:
torch.randn(2,2,2).view(-1)

tensor([ 1.0894,  1.1844,  0.2399, -0.3012,  0.4900,  0.1929,  1.1049, -0.0392])

In [113]:
# Small CNN: four stride-2 convolutions with ReLU between them, followed by
# global average pooling and Flatten.
m = nn.Sequential(
    nn.Conv2d(1,8, 5,stride=2,padding=2), nn.ReLU(),
    nn.Conv2d(8,16,3,stride=2,padding=1), nn.ReLU(),
    nn.Conv2d(16,32,3,stride=2,padding=1), nn.ReLU(),
    nn.Conv2d(32,1,3,stride=2,padding=1),
    nn.AdaptiveAvgPool2d(1),
    Flatten(),
)

In [114]:
x = x_valid[:100]
y = y_valid[:100].float()

In [115]:
t = m(x)
stats(t)

(-0.03178, 0.01357)

This is very bad because the outputs are collapsing towards 0 (std ≈ 0.014), so the network will struggle to learn anything.

So if we initialize those layers manually:

In [116]:
# Re-initialise every conv layer of `m`: Kaiming-uniform weights (default
# arguments) and zeroed biases. Non-conv modules are left untouched.
for l in m:
    if isinstance(l, nn.Conv2d):
        init.kaiming_uniform_(l.weight)
        # hand-written alternative from this notebook: kaiming2(l.weight, a=0)
        l.bias.data.zero_()

In [117]:
t = m(x)
stats(t)

(0.29573, 0.19726)

In [118]:
l = mse(t,y)
l.backward()
stats(m[0].weight.grad)

(-0.06808, 0.47778)

# explore uniform distrib