In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
import sys
from os.path import join

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from pooling import *

In [3]:
#export
def weighted_sum(t1, t2, ratio):
    assert(0 <= ratio <= 1)
    return t1 * ratio + t2 * (1 - ratio) 

\begin{equation}\label{eq:bnorm}
    \hat{x}_{i} =  \frac{x_{i} - \mu_{\beta}}{\sqrt{\sigma_{\beta}^2 + \epsilon}}
    \\
    y_{i} = \gamma\hat{x}_{i} + \beta
\end{equation}

In [4]:
#export
class BatchNorm(Module):
    def __init__(self, c, momentum=0.1, epsilon=1e-6):
        super().__init__()
        self.momentum = momentum
        self.epsilon = epsilon
        self.mean = torch.zeros(1,c,1,1)
        self.var  = torch.ones (1,c,1,1)
        # trainable linear transformation
        self.gamma = Parameter(torch.ones (1,c,1,1))
        self.beta  = Parameter(torch.zeros(1,c,1,1))
    
    def update_stats(self, inp):
        mean = inp.mean((0,2,3), keepdim=True)
        var =  inp.var ((0,2,3), keepdim=True) 
        self.mean = weighted_sum(self.mean, mean, self.momentum)
        self.var = weighted_sum(self.var, var, self.momentum)
        return mean, var
    
    def fwd(self, inp): 
        mean, var = self.update_stats(inp)
        self.x_hat = (inp - mean) / (var + self.epsilon).sqrt()
        return self.gamma.data * self.x_hat + self.beta.data
        
    def bwd(self, out, inp):
        # learned from: https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
        dL = out.g
        dLdg = (dL * self.x_hat).sum((0,2,3), keepdim=True)
        dLdb = dL.sum((0,2,3), keepdim=True)
        self.gamma.update(dLdg)
        self.beta.update(dLdb)
        
        n = dL.shape[0]
        dLdx = dL * self.gamma.data
        
        denom = (self.var + self.epsilon).sqrt()
        numer = inp - self.mean
        
        dv = (-1/2)*dLdx*numer / (self.var+self.epsilon)**1.5
        dm = ( 2/n)*dLdx/denom + numer*dv
        
        inp.g = (2*dv*numer + dm)/n + dLdx/denom
    
    def __repr__(self): return f'BatchNorm()'

In [5]:
#export
def get_conv_final_model(data_bunch):
    return Sequential(Reshape((1, 28, 28)),
                      Conv2d(c_in=1, c_out=4, k_s=5, stride=2, pad=1), # 4, 13, 13
                      AvgPool2d(k_s=2, pad=0), # 4, 12, 12
                      BatchNorm(4),
                      Conv2d(c_in=4, c_out=16, stride=2, leak=1.), # 16, 5, 5
                      BatchNorm(16),
                      Flatten(),
                      Linear(400, 64),
                      ReLU(),
                      Linear(64, 10, True))

In [6]:
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, y_train, x_valid, y_valid = x_train[:8000], y_train[:8000], x_valid[:2000], y_valid[:2000]

data_bunch = get_data_bunch(x_train, y_train, x_valid, y_valid, batch_size=64)
model = get_conv_final_model(data_bunch)
optimizer = DynamicOpt(list(model.parameters()), learning_rate=0.1)
loss_fn = CrossEntropy()
callbacks = [StatsLogging()]

In [7]:
learner = Learner(data_bunch, model, loss_fn, optimizer, callbacks)
print(learner)

(DataBunch) 
	(DataLoader) 
		(Dataset) x: (8000, 784), y: (8000,)
		(Sampler) total: 8000, batch_size: 64, shuffle: True
	(DataLoader) 
		(Dataset) x: (2000, 784), y: (2000,)
		(Sampler) total: 2000, batch_size: 128, shuffle: False
(Sequential)
	(Layer1) Reshape(1, 28, 28)
	(Layer2) Conv2D(in: 1, out: 4, kernel: 5, stride: 2, pad: 1)
	(Layer3) AvgPool2d(kernel: 2, stride: 1, pad: 0)
	(Layer4) BatchNorm()
	(Layer5) Conv2D(in: 4, out: 16, kernel: 3, stride: 2, pad: 0)
	(Layer6) BatchNorm()
	(Layer7) Flatten()
	(Layer8) Linear(400, 64)
	(Layer9) ReLU()
	(Layer10) Linear(64, 10)
(CrossEntropy)
(DynamicOpt) num_params: 12, hyper_params: ['learning_rate']
(Callbacks) ['TrainEval', 'StatsLogging']


In [8]:
learner.fit(5)

Epoch - 1
train metrics - [0.0011055092811584472, 0.85675]
valid metrics - [0.018804485321044923, 0.8905]

Epoch - 2
train metrics - [0.0009419311881065369, 0.932375]
valid metrics - [0.013131357192993165, 0.926]

Epoch - 3
train metrics - [0.0014515787363052369, 0.943]
valid metrics - [0.0122059907913208, 0.9335]

Epoch - 4
train metrics - [0.0015083916187286378, 0.954875]
valid metrics - [0.008289137840270995, 0.935]

Epoch - 5
train metrics - [0.0008319444060325622, 0.96]
valid metrics - [0.007107770442962647, 0.933]

