In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

from fastai import *
from fastai.vision import *
from ipyexperiments import *
from fastai.basic_train import *

In [3]:
import sys
sys.path.append("../dev")
from data_utils import seed_everything
from fastai.train import *

In [26]:
def get_bn_layers(m:nn.Module)->None:
    "Set bn layers in eval mode for all recursive children of `m`."
    for l in m.children():
        if isinstance(l, bn_types):
            return l
        find_active_bn(l)

In [19]:
class AccumulateBatchNorm(nn.Module):
    
    def __init__(self, bn_class, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.track_running_stats,self.momentum = track_running_stats,momentum
        self.bn = bn_class(num_features, eps=eps, momentum=momentum, affine=affine,
                 track_running_stats=track_running_stats)
        self.running_mean,self.running_square,self.iterations = None,None,None
    
    def reset_running_stats(self):
        self.running_mean,self.running_square,self.iterations = None,None,None
        self.bn.reset_running_stats()
    
    def update_stats(self):
        if self.training and self.track_running_stats:
            self.bn.num_batches_tracked += 1
            eaf = 1.0 / float(self.bn.num_batches_tracked) if self.bn.momentum is None else self.bn.momentum
            self.bn.running_mean = self.bn.running_mean * (1-eaf) + self.running_mean * eaf / self.iterations
            var = self.running_square/self.iterations - (self.running_mean/self.iterations).pow(2)
            self.bn.running_var  = self.bn.running_var  * (1-eaf) + var  * eaf
            self.running_mean,self.running_square,self.iterations = None,None,None
    
    def reset_parameters(self):
        self.bn.reset_parameters()
        
    def forward(self, input):
        self.bn._check_input_dim(input)
        if self.track_running_stats:
            if self.iterations is None:
                self.running_mean   = self.bn.weight.new_zeros(self.num_features)
                self.running_square = self.bn.weight.new_zeros(self.num_features)
                self.iterations   = 0
            self.running_mean += input.view(input.size(0), input.size(1), -1).mean(2).sum(0)
            self.running_square += input.view(input.size(0), input.size(1), -1).pow(2).mean(2).sum(0)
            self.iterations += input.size(0)
        return torch.batch_norm(input, self.bn.weight, self.bn.bias, self.bn.running_mean, self.bn.running_var, 
            False, 0., self.bn.eps, torch.backends.cudnn.enabled)  

In [20]:
class AccumulateStepper(LearnerCallback):
    "Does accumlated step every nth step by accumulating gradients"

    def __init__(self, learn:Learner, n_step:int = 1, drop_last:bool = False):
        super().__init__(learn)
        self.n_step,self.drop_last = n_step,drop_last
        # wrap all BN layers
        
 
    def on_train_begin(self, **kwargs):
        "check if loss is reduction"
        if hasattr(self.loss_func, "reduction") and (self.loss_func.reduction != "sum"):
             warn("For better gradients consider 'reduction=sum'")
        
    def on_epoch_begin(self, **kwargs):
        "init samples and batches, change optimizer"
        self.acc_samples, self.acc_batches = 0., 0. 
        
    def on_batch_begin(self, last_input, last_target, **kwargs):
        "accumulate samples and batches"
        self.acc_samples += last_input.shape[0]
        self.acc_batches += 1
        
    def on_backward_end(self, **kwargs):
        "accumulated step and reset samples, True will result in no stepping"
        if (self.acc_batches % self.n_step) == 0:
            for p in (self.learn.model.parameters()):
                if p.requires_grad: p.grad.div_(self.acc_samples)
            self.acc_samples = 0
        else: return True
    
    def on_step_end(self, **kwargs):
        "zero gradients after stepping, True will result in no zeroing"
        return (self.acc_batches % self.n_step) != 0
    
    def on_epoch_end(self, **kwargs):
        "step the rest of the accumulated grads if not perfectly divisible"
        for p in (self.learn.model.parameters()):
                if p.requires_grad: p.grad.div_(self.acc_samples)
        if not self.drop_last: self.learn.opt.step()
        self.learn.opt.zero_grad()


Maybe use `loss(reduction="sum")` and average before `real_step()`

### MNIST batch size = 32, no accumulation

`effective batch size = 32`

In [21]:
import torchvision

In [22]:
# seed everything for reproducibility
seed_everything(42)

In [23]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=32)

In [24]:
learn = create_cnn(data, models.resnet18, metrics=accuracy)

In [27]:
[find_bn_layers(learn.model)]

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(256, eps=1e-05, momen

[None]

In [11]:
learn.loss_func = CrossEntropyFlat(reduction='mean')

In [12]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy
1,0.141535,0.079934,0.972031


### MNIST batch size=4, accumulate every n_step=8, Default

`effective batch size = 32 (bs x n_step)`

In [13]:
# seed everything for reproducibility
seed_everything(42)

In [14]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=2)

In [15]:
# freeze_to(-1) except for BN layers
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=16)])

In [16]:
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [17]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy
1,1.153869,2.707523,0.507851


### MNIST batch size=4, accumulate every n_step=8, BNFreeze

`effective batch size = 32 (bs x n_step)`

In [86]:
from fastai.train import BnFreeze

In [87]:
class BnFreeze(LearnerCallback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    def on_train_begin(self, **kwargs:Any)->None:
        "Put bn layers in eval mode just after `model.train()`."
        set_bn_eval(self.learn.model)

In [88]:
# seed everything for reproducibility
seed_everything(42)

In [89]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=2)

In [90]:
# freeze_to(-1) including BN layers
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=16), BnFreeze])

In [101]:
# freeze bn layers
for g in learn.layer_groups[:-1]:
    for l in g:
        if isinstance(l, bn_types): requires_grad(l, False)

In [103]:
find_active_bn(learn.model)

BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


In [104]:
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [105]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy
1,1.140334,2.693619,0.601570


### MNIST batch size=4, accumulate every n_step=8, BNFreeze + More momentum

`effective batch size = 32 (bs x n_step)`

In [122]:
from fastai.train import BnFreeze

In [123]:
class BnFreeze(LearnerCallback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    def on_train_begin(self, **kwargs:Any)->None:
        "Put bn layers in eval mode just after `model.train()`."
        set_bn_eval(self.learn.model)

In [124]:
# seed everything for reproducibility
seed_everything(42)

In [125]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=2)

In [126]:
# freeze_to(-1) including BN layers
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=16), BnFreeze])

In [127]:
# # freeze bn layers
# for g in learn.layer_groups[:-1]:
#     for l in g:
#         if isinstance(l, bn_types): requires_grad(l, False)
for g in learn.layer_groups:
    for l in g:
        if isinstance(l, bn_types): l.momentum = 0.9

In [128]:
find_active_bn(learn.model)

BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
BatchNorm2d(256, eps=1e-05, momen

In [129]:
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [130]:
learn.fit(1)

epoch,train_loss,valid_loss,accuracy
1,1.153869,411.230682,0.479392


### MNIST batch size=4, accumulate every n_step=8, BNFreeze + Custom Running

`effective batch size = 32 (bs x n_step)`

In [104]:
# seed everything for reproducibility
seed_everything(42)

In [105]:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, bs=2)

In [112]:
learn = create_cnn(data, models.resnet18, metrics=accuracy,
                   callback_fns=[partial(AccumulateStepper, n_step=16)])

In [113]:
learn.loss_func = CrossEntropyFlat(reduction='sum')

In [114]:
learn.fit(4)

epoch,train_loss,valid_loss,accuracy
1,1.543319,1.134782,0.743867
2,1.379197,1.387173,0.564769
3,1.277315,1.233230,0.696762
4,1.208992,1.130871,0.759078


### BatchNorm

In [54]:
seed_everything(42)

In [55]:
x = torch.randn((2, 5));x

tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674]])

In [56]:
bn = nn.BatchNorm1d(5)

In [57]:
mom = bn.momentum, 
mean = bn.running_mean
var = bn.running_var
w = bn.weight
b = bn.bias
eps = bn.eps

mean, var, w, b, eps

(tensor([0., 0., 0., 0., 0.]),
 tensor([1., 1., 1., 1., 1.]),
 Parameter containing:
 tensor([0.2696, 0.4414, 0.2969, 0.8317, 0.1053], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0.], requires_grad=True),
 1e-05)

In [68]:
bn(x)

tensor([[ 0.2696, -0.4414,  0.2969, -0.8314, -0.1053],
        [-0.2696,  0.4414, -0.2969,  0.8314,  0.1053]],
       grad_fn=<NativeBatchNormBackward>)

In [69]:
mom = bn.momentum, 
mean = bn.running_mean
var = bn.running_var
w = bn.weight
b = bn.bias
eps = bn.eps

mean, var, w, b, eps

(tensor([ 0.0352,  0.5475, -0.0945,  0.1621, -0.2004]),
 tensor([0.5955, 1.5444, 0.7098, 0.5440, 0.9842]),
 Parameter containing:
 tensor([0.2696, 0.4414, 0.2969, 0.8317, 0.1053], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0.], requires_grad=True),
 1e-05)

In [9]:
model = learn.model

In [19]:
model[0][:3]

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
)

In [13]:
bn = model[0][1]

In [22]:
bn.momentum, bn.running_mean, bn.running_var, bn.weight, bn.bias, bn.eps

(0.1, tensor([ 2.7681e-03, -2.5769e-02,  2.1254e-07, -8.4605e-02,  2.1121e-08,
          4.9691e-04, -2.2408e-02, -1.1582e-07, -4.8239e-03,  2.7507e-07,
          3.9582e-02,  3.1994e-02, -3.7490e-02, -1.3716e-06,  6.6002e-03,
          4.3782e-03,  6.4797e-02,  1.1176e-01,  3.6002e-02, -7.5075e-02,
         -3.8240e-02,  8.4358e-02, -5.2287e-02, -1.1799e-02,  1.3019e-03,
          3.2172e-02, -1.7784e-02, -9.1009e-02,  1.1319e-01, -4.1632e-02,
          8.7302e-03,  2.9693e-02, -7.0502e-02, -3.4847e-03,  1.0977e-01,
         -1.7341e-03, -5.9423e-08,  2.9330e-02, -7.8553e-09,  6.7320e-03,
         -3.7100e-03,  1.6028e-02, -2.7883e-02,  2.6593e-02,  2.8475e-02,
         -1.2735e-01,  4.4617e-02,  2.6329e-02,  2.1454e-08, -1.7045e-02,
         -3.5617e-03, -4.5841e-02,  6.3876e-02,  1.5220e-02, -3.8511e-02,
         -1.6428e-02, -1.6569e-02,  5.6057e-02, -8.0306e-02, -2.6646e-03,
         -4.1718e-02,  1.2611e-01, -4.9237e-02, -1.3261e-02], device='cuda:0'), tensor([1.0169e+00, 3.7167e

In [None]:
class MyBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True):
        super(MyBatchNorm, self).__init__()
        self.bn = nn.BatchNorm1d(num_features,
                                 eps=eps,
                                 momentum=momentum,
                                 affine=affine)

    def forward(self, x):        
        x = self.bn(x)
        mu = self.bn.running_mean
        var = self.bn.running_var
        gamma = self.bn.weight
        beta = self.bn.bias
        eps = self.bn.eps    
        k = gamma.data / torch.sqrt(var + eps)
        x.data = k * x.data + beta.data
        return x

mybn = MyBatchNorm(10)
x = Variable(torch.randn(16, 10))
x_ = mybn(x)