In [76]:
import gzip
import inspect
import math
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import init
from fastai import datasets

# Base URL of the pickled MNIST dataset. get_data() passes it to fastai's
# datasets.download_data with ext='.gz' (presumably fetching mnist.pkl.gz --
# confirm). NOTE(review): deeplearning.net may no longer host this file.
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

def get_data():
    """Download MNIST and return ``x_train, y_train, x_valid, y_valid`` as tensors.

    The downloaded file is a gzipped pickle holding three ``(x, y)`` pairs;
    the first two are returned and the third (presumably the test split) is
    discarded.

    Returns:
        An iterator of four ``torch.Tensor`` objects, meant to be unpacked
        immediately: ``x_train, y_train, x_valid, y_valid = get_data()``.
    """
    path = datasets.download_data(MNIST_URL, ext='.gz')
    # NOTE(security): pickle.load can execute arbitrary code -- only use this on
    # a trusted copy of the dataset.
    with gzip.open(path, 'rb') as f:
        # latin-1 decoding is presumably needed because the pickle was written
        # by Python 2.
        (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')
    # `tensor` was never imported in this notebook; use torch.tensor explicitly.
    return map(torch.tensor, (x_train, y_train, x_valid, y_valid))

def normalize(x, m, s):
    """Standardize ``x``: shift by the mean ``m`` and scale by the std ``s``."""
    centered = x - m
    return centered / s

def stats(x):
    """Return the ``(mean, std)`` pair of tensor ``x``."""
    mean = x.mean()
    std = x.std()
    return mean, std

# Kaiming gain for a leaky ReLU with negative slope `a`: sqrt(2 / (1 + a**2)).
#   a = 0       -> sqrt(2)               (plain ReLU)
#   a = 1       -> 1                     (linear layer, no activation correction)
#   a = sqrt(5) -> sqrt(1/3) ~= 0.577    (PyTorch's default `a` for conv layers)
# The gain shrinks as |a| grows.
def gain(a):
    """Return the Kaiming (He) init gain for a leaky ReLU with negative slope ``a``."""
    return math.sqrt(2.0 / (1+a**2)) #2.0 from Kaiming Init Paper

In [27]:
# Download MNIST and unpack the four splits as tensors.
x_train, y_train, x_valid, y_valid = get_data()

# Standardize both splits with the *training* mean and std, so the validation
# data goes through exactly the same transform as the training data.
training_mean = x_train.mean()
training_std = x_train.std()
x_train, x_valid = [normalize(split, training_mean, training_std)
                    for split in (x_train, x_valid)]

# Reshape into (N, 1, 28, 28) batches of single-channel 28x28 images, the
# (N, C, H, W) layout nn.Conv2d expects.
x_train = x_train.view(-1, 1, 28, 28)
x_valid = x_valid.view(-1, 1, 28, 28)

# Sizes used when defining the model.
n, *_ = x_train.shape   # number of training images
c = y_train.max() + 1   # number of classes (assumes labels run 0..c-1)
nh = 32                 # number of filters in the first conv layer
n, c

(50000, tensor(10))

In [65]:
# Seed torch's RNG so the random weights below (and every statistic printed in
# the following cells) are reproducible under Restart & Run All.
torch.manual_seed(42)

# layer1 already carries PyTorch's default init (kaiming_uniform_ with
# a=sqrt(5), as the reset_parameters source further down shows). The output
# std is noticeably below 1, which is why we re-initialize by hand next.
layer1 = nn.Conv2d(1, nh, 5)
x = x_valid[:100]   # small validation batch used for all the stats below
t = layer1(x)
stats(t)

(tensor(0.0057, grad_fn=<MeanBackward0>),
 tensor(0.6488, grad_fn=<StdBackward0>))

In [64]:
# Kaiming init for a layer with no nonlinearity: a=1 makes the gain
# sqrt(2 / (1 + 1**2)) = 1, i.e. no correction for an activation function.
init.kaiming_normal_(layer1.weight, a=1.)
print("No Relu:\n",stats(layer1(x)))

# Kaiming init for a layer followed by a leaky ReLU with negative slope `leak`.
# This must not be named `gain`: that would shadow the gain() helper from the
# first cell, and any later gain(a) call would then fail with
# "TypeError: 'int' object is not callable" on Restart & Run All.
leak = 0  # negative slope a; 0 turns leaky_relu into a plain ReLU

def layer1_func(x, a=0):
    """Run layer1 on ``x`` and apply a leaky ReLU with negative slope ``a``."""
    return F.leaky_relu(layer1(x), a)

init.kaiming_normal_(layer1.weight, a=leak)
print("Leaky Relu, a = %g:\n" % leak, stats(layer1_func(x, leak)))

No Relu:
 (tensor(0.0129, grad_fn=<MeanBackward0>), tensor(1.0585, grad_fn=<StdBackward0>))
Leaky Relu, Gain = 0:
 (tensor(0.5698, grad_fn=<MeanBackward0>), tensor(1.0390, grad_fn=<StdBackward0>))


In [74]:
# layer1.weight is laid out as (filters_out, filters_in, kh, kw); read the
# pieces Kaiming init uses to compute its fan.
num_filters_out, num_filters_in, *_ = layer1.weight.shape
receptive_field_size = layer1.weight[0, 0].numel()   # kh * kw elements per kernel

# fan_in = filters_in * kh * kw ; fan_out = filters_out * kh * kw
fan_in, fan_out = (k * receptive_field_size for k in (num_filters_in, num_filters_out))

print("Receptive Field Size: %d\n# of Filters Out: %d\n# of Filters In: %d"%(receptive_field_size, num_filters_out, num_filters_in))
print("Kaiming Init Fan In, Out [%d,%d]"%(fan_in,fan_out))

Receptive Field Size: 25
# of Filters Out: 32
# of Filters In: 1
Kaiming Init Fan In, Out [25,800]


In [78]:
# Print the source of PyTorch's own conv-layer init. inspect.getsource gives
# clean text that survives notebook export, unlike the `??` pager, whose
# ANSI-coloured output gets garbled. It shows the default weight init is
# kaiming_uniform_ with a=sqrt(5), i.e. a *uniform* distribution.
print(inspect.getsource(torch.nn.modules.conv._ConvNd.reset_parameters))

Signature: torch.nn.modules.conv._ConvNd.reset_parameters(self)
Docstring: <no docstring>
Source:   
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = … [output truncated in export]

In [89]:
# Hand-written Kaiming (He) uniform init, to understand what
# torch.nn.init.kaiming_uniform_ does under the hood.
def kaiming_initalization(x, a, fan_dir_out = False):
    """Fill weight tensor ``x`` in place with Kaiming-uniform values.

    Args:
        x: weight tensor shaped (n_filters_out, n_filters_in, *kernel_dims),
           the layout of layer1.weight. For a 2-D weight the receptive field
           is 1.
        a: negative slope of the leaky ReLU that follows the layer
           (0 = plain ReLU, 1 = linear, sqrt(5) = PyTorch's conv default).
        fan_dir_out: if True scale by fan_out (filters_out * receptive field),
           which preserves variance in the backward pass; otherwise by fan_in
           (filters_in * receptive field), which preserves it in the forward pass.
    """
    nf, ni, *_ = x.shape
    # Elements in one kernel slice (kh * kw), computed from `x` itself rather
    # than from a global, so any kernel size works.
    receptive_field_sz = x[0, 0].numel()
    fan = (nf if fan_dir_out else ni) * receptive_field_sz
    # Leaky-ReLU gain sqrt(2 / (1 + a**2)) -- same formula as gain(); computed
    # inline so this function does not depend on that global name.
    std = math.sqrt(2.0 / (1 + a**2)) / math.sqrt(fan)
    # U(-b, b) has std b / sqrt(3), so b = sqrt(3) * std yields the target std.
    bound = math.sqrt(3) * std
    with torch.no_grad():  # in-place init must not be tracked by autograd
        x.uniform_(-bound, bound)

# Compare post-ReLU activations (layer1_func's default slope is 0) for the
# ReLU-appropriate a=0 and PyTorch's default a=sqrt(5). `a` is the leaky-ReLU
# slope passed to the init, not the gain itself; %.3f keeps sqrt(5) from being
# truncated to 2.
for a in (0, math.sqrt(5.)):
    kaiming_initalization(layer1.weight, a)
    print("My Kaiming, a=%.3f\n" % a, stats(layer1_func(x)))
# With a=sqrt(5) the post-ReLU std drops well below 1 -- much worse.

My Kaiming, Gain=0
 (tensor(0.5371, grad_fn=<MeanBackward0>), tensor(0.9381, grad_fn=<StdBackward0>))
My Kaiming, Gain=2
 (tensor(0.2348, grad_fn=<MeanBackward0>), tensor(0.4143, grad_fn=<StdBackward0>))
