# Convolution

In [15]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
#export
import sys
sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))

from functools import reduce
import torch.nn as nn # imported for testing my convolution layer
from torch.nn.functional import pad as torch_pad
from linear import *

In [17]:
#export
def pad_tensor(inp, pad, value=0):
    '''Pad the last two dimensions of `inp` by `pad` cells on every side.

    inp:   tensor to pad
    pad:   number of cells added on each of the four sides
    value: constant used to fill the new cells (default 0)
    '''
    # F.pad reads the widths starting from the last dimension:
    # (left, right, top, bottom)
    widths = [pad, pad, pad, pad]
    return torch_pad(inp, widths, mode='constant', value=value)

In [18]:
# Load MNIST and keep only the first 100 training samples, viewed as
# single-channel 28x28 images so they can be fed to a 2D convolution.
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train = x_train[:100].view(-1, 1, 28, 28)
# Reference PyTorch convolution (1 -> 4 channels, 3x3 kernel) used in the next cell.
torch_conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)

In [19]:
# Average the mean/std of the clamped (ReLU-like) conv output over several
# fresh weight initialisations to sanity-check init_4d_weight.
n_trials = 100
torch_conv_out_mean = 0.
torch_conv_out_std = 0.
# Only statistics are needed here: disable autograd so the running sums do
# not keep one computation graph per trial alive.
with torch.no_grad():
    for _ in range(n_trials):
        torch_conv.weight = nn.Parameter(init_4d_weight((4, 1, 3, 3), 0.))
        torch_conv_out = torch_conv(x_train).clamp_min(0.)
        torch_conv_out_mean += torch_conv_out.mean()
        torch_conv_out_std += torch_conv_out.std()
torch_conv_out_mean /= n_trials
torch_conv_out_std /= n_trials
print(f'mean: {torch_conv_out_mean}, std: {torch_conv_out_std}')

mean: 0.500054657459259, std: 0.8764498829841614


In [20]:
#export
class Reshape(Module):
    '''Reshape layer.

    Views every sample of the incoming batch as `shape`; the leading
    (batch) dimension is inferred from the number of elements.
    '''
    def __init__(self, shape):
        super().__init__()
        # per-sample target shape, without the batch dimension
        self.shape = shape

    def fwd(self, inp):
        return inp.view(-1, *self.shape)

    def bwd(self, out, inp):
        # Reverse the fwd: the gradient has as many elements as the input,
        # so hand it back in the input's own shape. This works for inputs of
        # any rank, not only already-flattened (batch, features) tensors.
        inp.g = out.g.reshape(inp.shape)

    def __repr__(self, t=''):
        return f"{t+'    '}Reshape{self.shape}"

In [21]:
#export
class Flatten(Module):
    '''Flatten layer: collapses every non-batch dimension into a single one.'''
    def __init__(self):
        super().__init__()

    def fwd(self, inp):
        # remember the per-sample shape so bwd can restore it
        batch, *per_sample = inp.shape
        self.batch_size, self.shape = batch, per_sample
        return inp.view(batch, -1)

    def bwd(self, out, inp):
        # give the gradient the shape the input had before flattening
        restored = (-1, *self.shape)
        inp.g = out.g.view(restored)

    def __repr__(self, t=''):
        return f"{t+'    '}Flatten()"

In [22]:
#export
class Conv(Module):
    '''Convolutional layer (loop-based 2D convolution).

    Uses a square `k_s` x `k_s` kernel, the same `stride` along height and
    width, and `pad` cells of zero padding on every side.

    c_in / c_out: number of input / output channels
    k_s:          kernel size
    stride:       step between neighbouring receptive fields
    pad:          zero padding added around the input
    leak:         forwarded to init_4d_weight when creating the weights
                  (presumably the leaky-ReLU slope for the init gain - confirm)
    '''
    def __init__(self, c_in, c_out, k_s=3, stride=1, pad=0, leak=1.):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.k_s = k_s
        self.stride = stride
        self.pad = pad

        # weights: (c_out, c_in, k_s, k_s); bias: one value per output channel
        self.w = Parameter(init_4d_weight((c_out, c_in, k_s, k_s), leak))
        self.b = Parameter(torch.zeros(c_out))

    def fwd(self, inp):
        batch_size, _, in_h, in_w = inp.shape
        inp = pad_tensor(inp, self.pad)
        _, _, p_h, p_w = inp.shape

        # init output
        out_dim = lambda d: (d + 2 * self.pad - self.k_s) // self.stride + 1
        out = torch.zeros(batch_size, self.c_out, out_dim(in_h), out_dim(in_w))

        # compute output cell by cell; i walks rows (height), j walks columns (width)
        for i in range(0, p_h - self.k_s + 1, self.stride):
            for j in range(0, p_w - self.k_s + 1, self.stride):
                # (batch, 1, c_in, k_s, k_s) broadcasts against the
                # (c_out, c_in, k_s, k_s) weights
                receptive_field = inp[:, :, i:i+self.k_s, j:j+self.k_s].unsqueeze(1)
                out[:, :, i//self.stride, j//self.stride] = (receptive_field * self.w.data).sum((-1,-2,-3)) + self.b.data

        return out

    def bwd(self, out, inp):
        # source of var names and math calcs: https://medium.com/@pavisj/convolutions-and-backpropagations-46026a8f5d2c
        dL = out.g
        X, F, B = pad_tensor(inp, self.pad), self.w.data, self.b.data
        dX, dF, dB = torch.zeros_like(X), torch.zeros_like(F), torch.zeros_like(B)
        k_s = F.shape[2]
        _, _, out_h, out_w = dL.shape

        # each output cell was computed from one receptive field of the input;
        # i is the output row (height), j the output column (width), matching
        # fwd so that non-square inputs are handled correctly
        for i in range(out_h):
            for j in range(out_w):
                i_s, j_s = i * self.stride, j * self.stride
                receptive_field = X[:, :, i_s: i_s+k_s, j_s: j_s+k_s].unsqueeze(1)
                dL_cell = dL[:, :, i, j]                     # (batch, c_out)
                dL_section = dL_cell[..., None, None, None]  # (batch, c_out, 1, 1, 1)

                # scatter the cell's gradient back onto its receptive field
                dX[:, :, i_s: i_s+k_s, j_s: j_s+k_s] += (F * dL_section).sum(1)
                dF += (receptive_field * dL_section).sum(0)
                dB += dL_cell.sum(0)

        self.w.update(dF)
        self.b.update(dB)
        # drop the gradient that fell on the padding cells
        inp.g = dX if self.pad == 0 else dX[:, :, self.pad: -self.pad, self.pad: -self.pad]

    def __repr__(self, t=''):
        return f"{t+'    '}Conv({self.c_in}, {self.c_out}, {self.k_s}, {self.stride})"

In [23]:
#export
def get_conv_model(data_bunch):
    '''Build the small convolutional classifier for a data bunch of flattened 28x28 images.

    The number of output classes is taken from the largest training label.
    '''
    train_ds = data_bunch.train_ds
    n_features = train_ds.x_data.shape[1]
    n_classes = int(max(train_ds.y_data) + 1)
    # the architecture below is hard-wired to single-channel 28x28 inputs
    assert n_features == 1 * 28 * 28
    layers = [
        Reshape((1, 28, 28)),
        Conv(1, 8, 5, stride=4, pad=2, leak=0.),   # -> 8, 7, 7
        ReLU(),
        Conv(8, 16, 3, stride=2, pad=1, leak=1.),  # -> 16, 4, 4
        Flatten(),                                 # -> 16 * 4 * 4 = 256
        Linear(256, n_classes, True),
    ]
    return Sequential(*layers)

# Tests

In [24]:
# Deliberately awkward sizes (odd spatial size, stride > 1, padding > 0)
# for comparing the custom Conv layer against nn.Conv2d.
k_s, pad, stride = 5, 3, 3
c_out = 7
inp = torch.randn(12, 18, 37, 37)
batch_size, c_in = inp.shape[:2]

print(f'input shape: {tuple(inp.shape)}')
print(f'padding: {pad}')
print(f'stride: {stride}')

input shape: (12, 18, 37, 37)
padding: 3
stride: 3


In [25]:
# Reference result from PyTorch's own convolution
torch_conv = nn.Conv2d(c_in, c_out, kernel_size=k_s, stride=stride, padding=pad)
torch_res = torch_conv(inp)

# Custom layer sharing the reference weights and bias, so outputs must agree
conv_layer = Conv(c_in, c_out, k_s=k_s, stride=stride, pad=pad)
conv_layer.w = Parameter(torch_conv.weight)
conv_layer.b = Parameter(torch_conv.bias)
my_res = conv_layer.fwd(inp)

print(f'weight shape {tuple(conv_layer.w.data.shape)}')
print(f'bias shape: {tuple(conv_layer.b.data.shape)}')
print(f'output shape: {tuple(my_res.shape)}')

# same shape first, then same values
assert my_res.shape == torch_res.shape
test_near(my_res, torch_res)

weight shape (7, 18, 5, 5)
bias shape: (7,)
output shape: (12, 7, 13, 13)


In [26]:
# Reload MNIST and keep a tiny subset (8 training / 2 validation samples)
# for the end-to-end gradient check below.
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, y_train = x_train[:8], y_train[:8]
x_valid, y_valid = x_valid[:2], y_valid[:2]

data_bunch = get_data_bunch(x_train, y_train, x_valid, y_valid, batch_size=64)
model = get_conv_model(data_bunch)
loss_fn = CrossEntropy()

In [27]:
# Display the model's layer summary (built from each layer's __repr__).
model

(Model)
    Reshape(1, 28, 28)
    Conv(1, 8, 5, 4)
    ReLU()
    Conv(8, 16, 3, 2)
    Flatten()
    Linear(256, 10)

In [28]:
# 1) Backward pass through the hand-written layers.
loss = loss_fn(model(x_train), y_train)
loss_fn.backward()
model.backward()

# Snapshot the manually computed gradients before they can be overwritten.
xtg = x_train.g.clone()
w1g = model.layers[1].w.grad.clone()
b1g = model.layers[1].b.grad.clone()
w2g = model.layers[3].w.grad.clone()
b2g = model.layers[3].b.grad.clone()
w3g = model.layers[5].w.grad.clone()
b3g = model.layers[5].b.grad.clone()

# 2) Same forward pass, but let PyTorch autograd track the input and every
#    parameter tensor of the trainable layers (both convs and the linear).
x_train2 = x_train.clone().requires_grad_(True)
for layer_idx in (1, 3, 5):
    model.layers[layer_idx].w.data.requires_grad_(True)
    model.layers[layer_idx].b.data.requires_grad_(True)

loss = loss_fn(model(x_train2), y_train)
loss.backward()

# 3) Manual gradients must match autograd, including the gradient w.r.t. the
#    input, which exercises the dX path of the first Conv layer.
test_near(xtg, x_train2.grad)
test_near(w1g, model.layers[1].w.data.grad)
test_near(b1g, model.layers[1].b.data.grad)
test_near(w2g, model.layers[3].w.data.grad)
test_near(b2g, model.layers[3].b.data.grad)
test_near(w3g, model.layers[5].w.data.grad)
test_near(b3g, model.layers[5].b.data.grad)