### Comes from [paper](https://arxiv.org/pdf/1511.06422.pdf): "All you need is a good init" 
- a good initalization technique for complex and/or deep architectures, 
    <br />-> where is can be difficult to get unit variance at the last layer

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
from utilities.imports import *

In [6]:
#typical example setup
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
x_train, y_train, x_valid, y_valid = get_data(MNIST_URL)
x_train, x_valid = normalize_to(x_train, x_valid)

number_hidden = 50
batch_size = 512
num_categories = y_train.max().item()+1

loss_function = F.cross_entropy

#data 
training_ds = Dataset(x_train, y_train)
validation_ds = Dataset(x_valid, y_valid)
train_dl = DataLoader(training_ds, batch_size, shuffle = True) #random sampler
valid_dl = DataLoader(validation_ds, batch_size*2, shuffle = False) #sequential sampler
number_features = [8,16,32,64, 64]
#callbacks
mnist_view = view_tfm(1,28,28)
cbfs = [Recorder,
        partial(AvgStatsCallback, accuracy),
        CudaCallback,
        partial(IndependentVarBatchTransformCallback,mnist_view)]

[1;31mInit signature:[0m
[0mDataLoader[0m[1;33m([0m[1;33m
[0m    [0mdataset[0m[1;33m,[0m[1;33m
[0m    [0mbatch_size[0m[1;33m=[0m[1;36m1[0m[1;33m,[0m[1;33m
[0m    [0mshuffle[0m[1;33m=[0m[1;32mFalse[0m[1;33m,[0m[1;33m
[0m    [0msampler[0m[1;33m=[0m[1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mbatch_sampler[0m[1;33m=[0m[1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mnum_workers[0m[1;33m=[0m[1;36m0[0m[1;33m,[0m[1;33m
[0m    [0mcollate_fn[0m[1;33m=[0m[1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mpin_memory[0m[1;33m=[0m[1;32mFalse[0m[1;33m,[0m[1;33m
[0m    [0mdrop_last[0m[1;33m=[0m[1;32mFalse[0m[1;33m,[0m[1;33m
[0m    [0mtimeout[0m[1;33m=[0m[1;36m0[0m[1;33m,[0m[1;33m
[0m    [0mworker_init_fn[0m[1;33m=[0m[1;32mNone[0m[1;33m,[0m[1;33m
[0m    [0mmultiprocessing_context[0m[1;33m=[0m[1;32mNone[0m[1;33m,[0m[1;33m
[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m     
Data loader. C

In [5]:
x, y= next(iter(train_dl))
x

tensor([[-0.4245, -0.4245, -0.4245,  ..., -0.4245, -0.4245, -0.4245],
        [-0.4245, -0.4245, -0.4245,  ..., -0.4245, -0.4245, -0.4245],
        [-0.4245, -0.4245, -0.4245,  ..., -0.4245, -0.4245, -0.4245],
        ...,
        [-0.4245, -0.4245, -0.4245,  ..., -0.4245, -0.4245, -0.4245],
        [-0.4245, -0.4245, -0.4245,  ..., -0.4245, -0.4245, -0.4245],
        [-0.4245, -0.4245, -0.4245,  ..., -0.4245, -0.4245, -0.4245]])

In [4]:
model = get_cnn_model(num_categories, number_features, ConvLayer2D)
init_cnn(model)
opt = optim.SGD(model.parameters(), lr=0.9)
runner = Runner(cb_funcs=cbfs)
#runner.fit(2, model, opt, loss_function, train_dl, valid_dl) 

In [5]:
#get one batch
xb, yb = get_one_batch(train_dl, runner)

In [6]:
all_modules = get_all_modules(model, lambda o : isinstance(o, ConvLayer2D))
all_modules

[ConvLayer2D(
   (conv): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
   (relu): GeneralReLU()
 ), ConvLayer2D(
   (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralReLU()
 ), ConvLayer2D(
   (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralReLU()
 ), ConvLayer2D(
   (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralReLU()
 ), ConvLayer2D(
   (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralReLU()
 )]

### showing the means are too high, standard deviations are too low

In [8]:
model.cuda() #put model on the GPu
with Hooks(all_modules, append_stat) as hooks:
    model(xb)
    for hook in hooks: print(hook.mean, hook.std)

0.42125365138053894 0.9243180751800537
0.35774943232536316 0.878524661064148
0.3686355650424957 0.7132577300071716
0.3702735900878906 0.5816133618354797
0.2686123251914978 0.3781846761703491


## adding LSUV to initalize the convolution layers

In [13]:
def lsuv_module(model, module, x_mb):
    error_ceiling = 1e-3
    hook = ForwardHook(module, append_stat)
    while model(x_mb) is not None and abs(hook.mean) > error_ceiling: #correct the means
        module.bias -= hook.mean
    while model(x_mb) is not None and abs(hook.std - 1) > error_ceiling: #correct the standard deviations
        module.weight.data /= hook.std
    hook.remove()
    return hook.mean, hook.std

### showing the improved means and stds, with a simple LSUV loop

In [15]:
for mod in all_modules: print(lsuv_module(model, mod, xb))

(8.971107412492074e-09, 1.0)
(0.001057819346897304, 0.9999999403953552)
(0.004357390571385622, 1.0)
(0.00654597207903862, 1.0)
(0.008741334080696106, 1.0)
