In [90]:
import os

from google.colab import drive
drive.mount('/content/gdrive')

os.chdir('/content/gdrive/MyDrive/first_try_of_fastai')

print("------------------------------------------------------------------")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
------------------------------------------------------------------


In [91]:
#export
import os
os.chdir('/content/gdrive/MyDrive/first_try_of_fastai/exp')
from nb_07 import *
os.chdir('/content/gdrive/MyDrive/first_try_of_fastai')

In [92]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
#export
class AvgStats():
  def __init__(self, metrics, in_train):
    self.metrics = listify(metrics)
    self.in_train = in_train
    self.tot_loss = 0.
    self.count=0
    self.tot_mets = [0.] * len(self.metrics)

  def reset(self):
    self.tot_loss, self.count = 0.,0
    self.tot_mets = [0.] * len(self.metrics)
  
  @property
  def all_stats(self): return[self.tot_loss.item()] + self.tot_mets
  @property
  def avg_stats(self): return[o/self.count for o in self.all_stats]

  def __repr__(self):
    if not self.count: return ''
    return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

  def accumulate(self, run):
    bn = run.xb.shape[0]
    self.tot_loss += run.loss * bn
    self.count += bn
    for i,m in enumerate(self.metrics):
      self.tot_mets[i] += m(run.pred, run.yb) * bn

class AvgStatsCallback(Callback):
  def __init__(self, metrics):
    self.train_stats,self.valid_stats = AvgStats(metrics,True), AvgStats(metrics, False)

  def begin_epoch(self):
    self.train_stats.reset()
    self.valid_stats.reset()

  def after_loss(self):
    stats = self.train_stats if self.in_train else self.valid_stats
    with torch.no_grad(): stats.accumulate(self.run)

  def after_epoch(self):
    print(self.train_stats)
    print(self.valid_stats)

class Recorder(Callback):
  def begin_fit(self): self.lrs, self.losses = [],[]

  def after_batch(self):
    if not self.in_train: return
    self.lrs.append(self.opt.param_groups[-1]['lr'])
    self.losses.append(self.loss.detach().cpu())

  def plot_lr(self): plt.plot(self.lrs)
  def plot_loss(self): plt.plot(self.losses)

In [94]:
x_train,y_train,x_valid,y_valid = get_data()

x_train,x_valid = normalize_to(x_train,x_valid)
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

nh, bs = 50,512
c=y_train.max().item()+1
loss_func = F.cross_entropy

data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [95]:
mnist_view = view_tfm(1,28,28)
cbfs = [Recorder,
        partial(AvgStatsCallback,accuracy),
        CudaCallback,
        partial(BatchTransformXCallback, mnist_view)]

In [96]:
nfs = [8,16,32,64,64]

In [97]:
class ConvLayer(nn.Module):
  def __init__(self, ni, nf, ks=3, stride=2, sub=0., **kwargs):
    super().__init__()
    self.conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=True)
    self.relu = GeneralRelu(sub=sub, **kwargs)

  def forward(self, x): return self.relu(self.conv(x))

  @property
  def bias(self): return -self.relu.sub
  @bias.setter
  def bias(self,v): self.relu.sub = -v
  @property
  def weight(self): return self.conv.weight

In [98]:
learn, run = get_learn_run(nfs, data, 0.6, ConvLayer, cbs =cbfs)

In [99]:
run.fit(2, learn)

train: [1.221888515625, tensor(0.5880, device='cuda:0')]
valid: [0.1947199462890625, tensor(0.9397, device='cuda:0')]
train: [1.221888515625, tensor(0.5880, device='cuda:0')]
valid: [0.20219624720982143, tensor(0.9386, device='cuda:0')]


In [100]:
learn, run = get_learn_run(nfs, data, 0.6, ConvLayer, cbs=cbfs)

In [101]:
#export
def get_batch(dl, run):
  run.xb,run.yb = next(iter(dl))
  for cb in run.cbs: cb.set_runner(run)
  run('begin_batch')
  return run.xb,run.yb

In [102]:
xb,yb = get_batch(data.train_dl, run)

In [103]:
#export
def find_modules(m, cond):
  if cond(m): return [m]
  return sum([find_modules(o,cond) for o in m.children()], [])

def is_lin_layer(l):
  lin_layers = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear, nn.ReLU)
  return isinstance(l, lin_layers)

def children(m): return list(m.children())

class Hook():
  def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
  def remove(self): self.hook.remove()
  def __del__(self): self.remove()

def append_stats(hook, mod, inp, outp):
  if not hasattr(hook,'stats'): hook.stats = ([],[])
  means, stds = hook.stats
  means.append(outp.data.mean())
  stds.append(outp.data.std())

class Hooks(ListContainer):
  def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
  def __enter__(self, *args): return self
  def __exit__(self, *args): self.remove()
  def __del__(self): self.remove()

  def __delitem__(self, i):
    self[i].remove()
    super().__delitem__(i)

  def remove(self):
    for h in self: h.remove()

In [104]:
mods = find_modules(learn.model, lambda o: isinstance(o, ConvLayer))

In [105]:
mods

[ConvLayer(
   (conv): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 )]

In [106]:
def append_stat(hook, mod, inp, outp):
  d = outp.data
  hook.mean,hook.std = d.mean().item(), d.std().item()

In [107]:
mdl = learn.model.cuda()

In [108]:
with Hooks(mods, append_stat) as hooks:
  mdl(xb)
  for hook in hooks: print(hook.mean,hook.std)

0.4250083267688751 0.8984415531158447
0.5426522493362427 1.0302879810333252
0.47679564356803894 0.9402784705162048
0.4566575884819031 0.7585752606391907
0.3478483557701111 0.4889228940010071


In [109]:
#export
def lsuv_module(m, xb):
  h = Hook(m, append_stat)

  while mdl(xb) is not None and abs(h.mean) > 1e-3: m.bias -=h.mean
  while mdl(xb) is not None and abs(h.std-1) > 1e-3: m.weight.data /=h.std

  h.remove()
  return h.mean, h.std

In [111]:
for m in mods: print(lsuv_module(m, xb))

(-1.9462740752373975e-08, 1.0)
(0.006507309153676033, 0.9999998807907104)
(0.00389703456312418, 1.0)
(0.0057980758138000965, 0.9999999403953552)
(0.0015523573383688927, 1.0)
