In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_07 import *

In [3]:
# Load the raw train/validation inputs and labels (presumably MNIST — see
# mnist_view below; get_data comes from an earlier notebook via the star import).
x_train, y_train, x_valid, y_valid = get_data()

In [4]:
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

In [5]:
# nh: not used in the visible cells. bs: batch size handed to get_dls.
nh, bs = 50, 512
# Number of classes, assuming integer class labels 0..c-1.
c = y_train.max().item() + 1
loss_func = F.cross_entropy

In [6]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [7]:
# Batch transform presumably reshaping each flat input row into a 1x28x28
# image — TODO confirm against view_tfm.
mnist_view = view_tfm(1,28,28)
# Callback classes/factories handed to get_learn_run; partial() pre-binds
# the extra constructor argument each one needs.
cbfs = [Recorder,
       CudaCallback,
       partial(AvgStatsCallback, accuracy),
       partial(BatchTransformXCallback, mnist_view)]

In [8]:
# Output channel count of each ConvLayer (matches the printed `mods` list further down).
nfs = [8, 16, 32, 64, 64]

In [10]:
class ConvLayer(nn.Module):
    """A Conv2d followed by a GeneralRelu, exposing `bias` and `weight` for LSUV.

    Note that `bias` is NOT the convolution's bias: it is the negation of the
    activation's `sub` attribute, so `layer.bias -= delta` increases
    `self.relu.sub` by `delta`. `weight` is the convolution's weight tensor.
    """
    def __init__(self, ni, nf, ks=3, stride=2, sub=0., **kwargs):
        super().__init__()
        half_kernel = ks // 2  # padding of ks//2 on each side
        self.conv = nn.Conv2d(ni, nf, ks, stride, padding=half_kernel, bias=True)
        self.relu = GeneralRelu(sub=sub, **kwargs)

    def forward(self, x):
        pre_activation = self.conv(x)
        return self.relu(pre_activation)

    @property
    def bias(self):
        return -self.relu.sub

    @bias.setter
    def bias(self, v):
        self.relu.sub = -v

    @property
    def weight(self):
        return self.conv.weight

In [43]:
# Baseline: build a learner/runner for the ConvLayer model (0.6 is presumably
# the learning rate — see get_learn_run) and train it for 2 epochs, for
# comparison with the LSUV-initialised run later on.
learn, run = get_learn_run(nfs, data, 0.6, ConvLayer, cbfs)

In [44]:
run.fit(2, learn)

train: [2.07595625, tensor(0.2970, device='cuda:0')]
valid: [2.801163671875, tensor(0.3647, device='cuda:0')]
train: [0.8409492838541667, tensor(0.7226, device='cuda:0')]
valid: [0.33439150390625, tensor(0.9018, device='cuda:0')]


In [45]:
# Re-create the learner so the LSUV steps that follow start from a
# (presumably freshly initialised) model instead of the one trained above.
learn, run = get_learn_run(nfs, data, 0.6, ConvLayer, cbfs)

In [46]:
#export
def get_batch(dl, run):
    """Fetch one batch from `dl` and run the runner's 'begin_batch' callbacks on it.

    The batch is the first one yielded by a fresh iterator over `dl`. It is
    stored on `run` as `xb`/`yb`, every callback is attached to `run`, and the
    'begin_batch' event is fired. Returns `(run.xb, run.yb)` as left by the callbacks.
    """
    first_batch = next(iter(dl))
    run.xb, run.yb = first_batch
    for callback in run.cbs:
        callback.set_runner(run)
    run('begin_batch')
    return run.xb, run.yb

In [47]:
# Grab one batch (after the runner's 'begin_batch' callbacks) to drive the
# LSUV forward passes.
xb, yb = get_batch(data.train_dl, run)

In [48]:
# Module.cuda() moves parameters in place and returns the module, so `mdl` is
# learn.model itself. lsuv_model reads this `mdl` as a notebook global.
mdl = learn.model.cuda()

In [49]:
#export
def find_modules(m, cond):
    """Return the outermost modules in `m`'s tree (including `m`) for which `cond` is true.

    The search is depth-first in `children()` order. Once a module matches, its
    own children are not searched. Returns an empty list if nothing matches.
    """
    if cond(m): return [m]
    found = []
    # extend() is linear overall; the previous sum(lists, []) re-copied the
    # accumulated list at every step (quadratic in the number of matches).
    for child in m.children():
        found.extend(find_modules(child, cond))
    return found

In [50]:
#export
def is_lin_layer(l):
    """Return True if `l` is an nn.Conv1d/Conv2d/Conv3d, nn.Linear or nn.ReLU module.

    NOTE(review): nn.ReLU is not a linear layer despite this function's name.
    It is kept in the set for compatibility with existing callers — confirm intent.
    """
    return isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear, nn.ReLU))

In [51]:
# Collect the model's ConvLayer modules, in traversal order, as the LSUV targets.
mods = find_modules(learn.model, lambda o: isinstance(o, ConvLayer))

In [52]:
mods

[ConvLayer(
   (conv): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
   (relu): GeneralRelu()
 ),
 ConvLayer(
   (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ),
 ConvLayer(
   (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ),
 ConvLayer(
   (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ),
 ConvLayer(
   (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 )]

In [53]:
def append_stat(hook, mod, inp, outp):
    """Hook callback: record the scalar mean and std of the module output `outp` on `hook`.

    Each call overwrites `hook.mean` and `hook.std`; nothing is accumulated.
    `mod` and `inp` are accepted for the hook signature but not used.
    """
    activations = outp.data
    mean_val = activations.mean().item()
    std_val = activations.std().item()
    hook.mean = mean_val
    hook.std = std_val

In [54]:
# IPython source introspection (`obj??`) left over from development: it prints
# the Hooks class source and can be removed from the finished notebook.
Hooks??

Init signature: Hooks(model, f)
Docstring:      <no docstring>
Source:        
class Hooks(ListContainer):
    def __init__(self, model, f): super().__init__([Hook(m, f) for m in model])
    def __enter__(self, *args): return self
    def __exit__(self, *args): self…

In [55]:
# Register append_stat on every ConvLayer, run one forward pass, and print each
# layer's output mean/std. The callback must be append_stat (it sets hook.mean /
# hook.std). `append_stats` is not defined in this notebook, and any star-imported
# version would not set the attributes read below.
with Hooks(mods, append_stat) as hooks:
    mdl(xb)
    for hook in hooks: print(hook.mean, hook.std)

0.5039094090461731 0.7828356027603149
0.3972770571708679 0.7780233025550842
0.2512374520301819 0.477178156375885
0.25667449831962585 0.4033251404762268
0.1685505509376526 0.27236488461494446


In [56]:
#export
def lsuv_model(m, xb, model=None, tol=1e-3):
    """Layer-sequential unit-variance (LSUV) initialisation of a single layer `m`.

    A Hook(m, append_stat) records the mean/std of `m`'s output on every forward
    pass of `model` over `xb`. The function then:
      1. repeatedly does `m.bias -= mean` until |mean| <= tol, then
      2. repeatedly does `m.weight.data /= std` until |std - 1| <= tol.

    Args:
        m: layer to initialise. It must expose a settable `bias` and a `weight`
           tensor (e.g. ConvLayer).
        xb: input batch fed to `model`.
        model: the full model to run forward. Defaults to the notebook-level
           `mdl` global, which the original implementation used implicitly.
           Pass it explicitly when this function is used outside this notebook.
        tol: convergence tolerance for both loops (default 1e-3, as before).

    Returns:
        (mean, std) of `m`'s output from the last forward pass.
    """
    import torch

    if model is None: model = mdl  # original behaviour: notebook-global model
    h = Hook(m, append_stat)
    try:
        # Only the hooked statistics are needed, so skip building autograd graphs.
        with torch.no_grad():
            while True:
                model(xb)
                # `not (x > tol)` rather than `x <= tol` keeps the original
                # behaviour of stopping when the statistic is NaN.
                if not abs(h.mean) > tol: break
                m.bias -= h.mean
            while True:
                model(xb)
                if not abs(h.std - 1) > tol: break
                m.weight.data /= h.std
    finally:
        # Always detach the hook, even if a forward pass raises.
        h.remove()
    return h.mean, h.std

In [57]:
# Apply LSUV to each ConvLayer in model order (earlier layers first, so each
# layer is measured after its predecessors were rescaled) and print the final
# (mean, std) of its output. Means are not ~0 because the weight-rescaling loop
# runs after the bias loop and shifts the mean again; only std ends at ~1.
for m in mods: print(lsuv_model(m, xb))

(0.13978813588619232, 0.9999998807907104)
(0.06221422553062439, 1.0)
(0.24237248301506042, 0.9999998807907104)
(0.15730416774749756, 0.9999999403953552)
(0.260518878698349, 0.9999999403953552)


In [58]:
# Train the LSUV-initialised model for 2 epochs and time it.
# NOTE(review): in the recorded output, validation accuracy falls in epoch 2
# (0.954 -> 0.865) while training accuracy rises. This is possibly a
# too-high 0.6 (presumably the lr) — worth investigating.
%time run.fit(2, learn)

train: [0.5086228841145833, tensor(0.8342, device='cuda:0')]
valid: [0.14613636474609376, tensor(0.9537, device='cuda:0')]
train: [0.0942843994140625, tensor(0.9705, device='cuda:0')]
valid: [0.41388642578125, tensor(0.8649, device='cuda:0')]
CPU times: user 2.58 s, sys: 376 ms, total: 2.95 s
Wall time: 2.92 s
