# Weight Initialization

Author: YinTaiChen

## Xavier Initialization

“Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)

In [1]:
import torch
import torch.nn as nn

In [2]:
# Allocate a 3x5 float tensor without initializing its memory; the values are
# arbitrary until an init function fills them in.
w = torch.empty(3, 5)

In [3]:
# Show the tensor before any initialization; it has not been assigned values yet.
print(w)


-9.9332e+12  4.5663e-41 -9.9332e+12  4.5663e-41  0.0000e+00
 0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00
 0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00  0.0000e+00
[torch.FloatTensor of size 3x5]



In [4]:
# Fill w in place with Xavier (Glorot) uniform values, scaled by the gain
# recommended for ReLU. The trailing-underscore function is the current
# in-place API; the un-suffixed `xavier_uniform` is a deprecated alias.
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))


 0.5650  0.8589  1.1236  1.1099 -0.2263
 0.6843 -1.1411 -0.9068  0.3745  0.5870
 0.4829 -0.3405 -0.4201 -0.6438  0.0388
[torch.FloatTensor of size 3x5]

## Initialize AlexNet

In [5]:
from torchvision import models

In [6]:
# Build torchvision's AlexNet, requesting the pretrained weights.
# NOTE(review): newer torchvision releases reportedly prefer a `weights=`
# argument over `pretrained=` — confirm against the installed version.
alexnet = models.alexnet(pretrained=True)

In [7]:
# Grab the model's state dict: an ordered mapping from entry name
# (e.g. 'features.0.weight') to the corresponding tensor.
parameters = alexnet.state_dict()

In [8]:
# Check which container type state_dict() returned.
print(type(parameters))

<class 'collections.OrderedDict'>


In [9]:
# List the name of every entry stored in the state dict, one per line.
for name in parameters.keys():
    print(name)

features.0.weight
features.0.bias
features.3.weight
features.3.bias
features.6.weight
features.6.bias
features.8.weight
features.8.bias
features.10.weight
features.10.bias
classifier.1.weight
classifier.1.bias
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias


In [10]:
# Inspect the weights currently stored for the `features.0` layer.
parameters['features.0.weight']


(0 ,0 ,.,.) = 
  0.1186  0.0941  0.0954  ...   0.0558  0.0216  0.0500
  0.0749  0.0389  0.0530  ...   0.0257 -0.0113  0.0042
  0.0754  0.0388  0.0549  ...   0.0436  0.0102  0.0133
           ...             ⋱             ...          
  0.0932  0.1037  0.0675  ...  -0.2028 -0.1284 -0.1122
  0.0435  0.0649  0.0362  ...  -0.2025 -0.1138 -0.1072
  0.0474  0.0625  0.0248  ...  -0.1184 -0.0956 -0.0839

(0 ,1 ,.,.) = 
 -0.0726 -0.0580 -0.0807  ...  -0.0006 -0.0253  0.0255
 -0.0690 -0.0676 -0.0764  ...  -0.0040 -0.0304  0.0105
 -0.0995 -0.0856 -0.1052  ...  -0.0266 -0.0228  0.0066
           ...             ⋱             ...          
 -0.1512 -0.0887 -0.0967  ...   0.3085  0.1810  0.0843
 -0.1431 -0.0757 -0.0722  ...   0.2042  0.1645  0.0952
 -0.0859 -0.0401 -0.0515  ...   0.1635  0.1482  0.1020

(0 ,2 ,.,.) = 
 -0.0236 -0.0021 -0.0278  ...   0.0399 -0.0071  0.0322
  0.0003  0.0225  0.0089  ...   0.0188 -0.0142  0.0183
  0.0054  0.0294  0.0003  ...   0.0121 -0.0025  0.0084
           ...   

In [11]:
# Re-initialize the `features.0` weights with Xavier (Glorot) uniform values.
# The init function overwrites the tensor it is given in place, and the
# state-dict tensor refers to the model's own weight data, so the model is
# updated directly. Assigning the result into `alexnet.state_dict()[...]` is
# unnecessary: every state_dict() call builds a new dict, so that assignment
# would be discarded. Binding the result to a name keeps the cell from echoing it.
first_conv_weight = nn.init.xavier_uniform_(parameters['features.0.weight'], gain=nn.init.calculate_gain('relu'))

In [12]:
# Read the weights back through a fresh state dict to see the values the model now holds.
alexnet.state_dict()['features.0.weight']


(0 ,0 ,.,.) = 
1.00000e-02 *
  0.3106  0.1652  2.4613  ...   2.6552 -0.6437  1.8675
 -0.7930  0.0835 -2.6740  ...  -3.6523  0.8647  1.1658
 -2.2559 -0.7712 -3.1179  ...   2.8873  1.0292  2.2368
           ...             ⋱             ...          
 -0.8628  0.4459 -0.2512  ...   3.0699 -2.4046 -3.2686
  2.5632  2.9395 -1.2826  ...  -1.1968  0.7460  2.9725
 -0.3657 -0.4645 -2.7041  ...  -1.3255 -2.5543 -2.4107

(0 ,1 ,.,.) = 
1.00000e-02 *
 -0.5254 -1.3041  0.4437  ...   1.7805 -0.4303 -1.4664
  2.1706  0.4799 -1.0999  ...  -3.2584  3.5179 -1.0486
  2.2597  0.6343  1.0798  ...   3.4110  0.7957 -2.7100
           ...             ⋱             ...          
  1.8909 -2.5138 -2.2303  ...  -3.5146 -3.4340 -1.8039
 -0.2363 -2.6465 -2.1034  ...  -1.9407  3.5831 -0.5134
 -2.2667  3.6254  1.2560  ...   0.2011  2.8297 -2.9327

(0 ,2 ,.,.) = 
1.00000e-02 *
 -3.4080  3.0724  2.7311  ...   2.7429 -2.9930  0.4181
 -0.7619  1.8450  2.8201  ...   1.5895  2.5558 -1.5208
  0.6302 -2.3551  1.4368  ... 