In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# helper scripts are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
#export
import sys
from os.path import join

sys.path.insert(0, '/'.join(sys.path[0].split('/')[:-1] + ['scripts']))
from convolution import *

In [3]:
#export
class MaxPool2d(Module):
    """Max pooling over the last two dimensions of a 4-D input.

    Parameters
    ----------
    k_s : int
        Side length of the square pooling window.
    stride : int
        Step between consecutive windows along both spatial axes.
    pad : int
        Amount of padding handed to `pad_tensor`. The output-size formula
        assumes `pad` cells are added on every side of both spatial axes
        (TODO confirm against `pad_tensor`). Padding is filled with the
        input's own minimum, so a padded cell can never exceed a real value.

    Notes
    -----
    In `bwd`, every position tied for a window's maximum receives the full
    upstream gradient, whereas `torch.nn.MaxPool2d` routes it to one index.
    The two agree whenever each window has a unique maximum.
    """

    def __init__(self, k_s=3, stride=1, pad=0):
        super().__init__()
        self.k_s = k_s
        self.stride = stride
        self.pad = pad

    def _pad_input(self, inp):
        """Return `inp` padded with its minimum value, or `inp` itself when pad == 0."""
        return pad_tensor(inp, self.pad, inp.min()) if self.pad > 0 else inp

    def fwd(self, inp):
        """Return the window-wise maximum of `inp`.

        `inp` is unpacked as (batch, channels, height, width); the result has
        shape (batch, channels, out_h, out_w) with
        out = (d + 2 * pad - k_s) // stride + 1 for each spatial size d.
        """
        batch_size, in_c, in_h, in_w = inp.shape
        out_dim = lambda d: (d + 2 * self.pad - self.k_s) // self.stride + 1
        out_c, out_h, out_w = in_c, out_dim(in_h), out_dim(in_w)

        padded = self._pad_input(inp)
        # Allocate like the input so dtype/device follow the data.
        out = torch.zeros(batch_size, out_c, out_h, out_w,
                          dtype=inp.dtype, device=inp.device)
        for i in range(out_h):
            for j in range(out_w):
                i_s, j_s = i * self.stride, j * self.stride
                in_window = padded[:, :, i_s: i_s+self.k_s, j_s: j_s+self.k_s]
                # Reduce width, then height, leaving one value per (sample, channel).
                out[:, :, i, j] = in_window.max(-1)[0].max(-1)[0]
        return out

    def bwd(self, out, inp):
        """Store the gradient with respect to `inp` in `inp.g`.

        Parameters
        ----------
        out : tensor
            Output of `fwd`; its `.g` attribute holds the upstream gradient.
        inp : tensor
            The tensor that was passed to `fwd`.
        """
        dL = out.g
        _, _, out_h, out_w = dL.shape
        in_h, in_w = inp.shape[-2:]
        # Always use this layer's own hyper-parameters, never notebook globals.
        k_s, stride, pad = self.k_s, self.stride, self.pad

        padded = self._pad_input(inp)
        dX = torch.zeros_like(padded)
        for i in range(out_h):
            for j in range(out_w):
                i_s, j_s = i * stride, j * stride
                window = padded[:, :, i_s:i_s+k_s, j_s:j_s+k_s]
                # Per-(sample, channel) window maximum, kept broadcastable.
                win_max = window.max(-1, keepdim=True)[0].max(-2, keepdim=True)[0]
                mask = (window == win_max).to(dX.dtype)
                dX[:, :, i_s:i_s+k_s, j_s:j_s+k_s] += dL[:, :, i, j][..., None, None] * mask

        if pad > 0:
            # `padded` is a temporary copy: drop the border and attach the
            # gradient to the real input so upstream layers can read it.
            inp.g = dX[:, :, pad:pad+in_h, pad:pad+in_w].contiguous()
        else:
            inp.g = dX

    def __repr__(self):
        return f'MaxPool2d(kernel_size: {self.k_s}, stride: {self.stride}, pad: {self.pad})'

In [4]:
#export
class AvgPool2d(Module):
    """Average pooling over the last two dimensions of a 4-D input.

    Parameters
    ----------
    k_s : int
        Side length of the square pooling window.
    stride : int
        Step between consecutive windows along both spatial axes.
    pad : int
        Amount of padding handed to `pad_tensor` (called without a fill
        value). The output-size formula assumes `pad` cells are added on
        every side of both spatial axes (TODO confirm against `pad_tensor`).
        Padded cells are included in each window's mean.
    """

    def __init__(self, k_s=3, stride=1, pad=0):
        super().__init__()
        self.k_s = k_s
        self.stride = stride
        self.pad = pad

    def fwd(self, inp):
        """Return the window-wise mean of `inp`.

        `inp` is unpacked as (batch, channels, height, width); the result has
        shape (batch, channels, out_h, out_w) with
        out = (d + 2 * pad - k_s) // stride + 1 for each spatial size d.
        """
        batch_size, in_c, in_h, in_w = inp.shape
        out_dim = lambda d: (d + 2 * self.pad - self.k_s) // self.stride + 1
        out_c, out_h, out_w = in_c, out_dim(in_h), out_dim(in_w)

        padded = pad_tensor(inp, self.pad) if self.pad > 0 else inp
        # Allocate like the input so dtype/device follow the data.
        out = torch.zeros(batch_size, out_c, out_h, out_w,
                          dtype=inp.dtype, device=inp.device)
        for i in range(out_h):
            for j in range(out_w):
                i_s, j_s = i * self.stride, j * self.stride
                in_window = padded[:, :, i_s: i_s+self.k_s, j_s: j_s+self.k_s]
                out[:, :, i, j] = torch.mean(in_window, (-1, -2))
        return out

    def bwd(self, out, inp):
        """Store the gradient with respect to `inp` in `inp.g`.

        Each upstream gradient value is spread evenly over the k_s * k_s
        cells of the window it was averaged from.

        Parameters
        ----------
        out : tensor
            Output of `fwd`; its `.g` attribute holds the upstream gradient.
        inp : tensor
            The tensor that was passed to `fwd`.
        """
        dL = out.g
        _, _, out_h, out_w = dL.shape
        batch_size, in_c, in_h, in_w = inp.shape
        k_s, stride, pad = self.k_s, self.stride, self.pad

        # Only the padded *shape* is needed here, so skip re-padding the input.
        dX = torch.zeros(batch_size, in_c, in_h + 2 * pad, in_w + 2 * pad,
                         dtype=inp.dtype, device=inp.device)
        window_area = k_s ** 2
        for i in range(out_h):
            for j in range(out_w):
                i_s, j_s = i * stride, j * stride
                dX[:, :, i_s:i_s+k_s, j_s:j_s+k_s] += dL[:, :, i, j][..., None, None] / window_area

        if pad > 0:
            # Gradient that fell on the padding border is discarded; the rest
            # is attached to the real input so upstream layers can read it.
            inp.g = dX[:, :, pad:pad+in_h, pad:pad+in_w].contiguous()
        else:
            inp.g = dX

    def __repr__(self):
        return f'AvgPool2d(kernel: {self.k_s}, stride: {self.stride}, pad: {self.pad})'

In [5]:
# Forward-pass check: run the hand-written pooling layers and their torch.nn
# counterparts on the same batch and require matching shapes and values.
x_train, _, _, _ = get_mnist_data()
x_train = x_train[:100].view(-1, 1, 28, 28)  # first 100 samples, viewed as (N, 1, 28, 28)
k_s, stride, pad = 5, 2, 0

# Max pooling: reference first, then the local implementation.
torch_max = nn.MaxPool2d(k_s, stride, pad)(x_train)
my_max = MaxPool2d(k_s, stride, pad).fwd(x_train)

# Average pooling: same pairing.
torch_avg = nn.AvgPool2d(k_s, stride, pad)(x_train)
my_avg = AvgPool2d(k_s, stride, pad).fwd(x_train)

assert torch_max.shape == my_max.shape
assert torch_avg.shape == my_avg.shape
test_near(torch_max, my_max)
test_near(torch_avg, my_avg)

In [6]:
def get_conv_pool_model(data_bunch):
    """Assemble the conv -> avg-pool -> conv -> two-layer MLP network.

    `data_bunch` is accepted for call-site compatibility but is not read here.
    The trailing comments give the per-sample activation shape after each
    spatial layer.
    """
    layers = [
        Reshape((1, 28, 28)),
        Conv2d(c_in=1, c_out=4, k_s=5, stride=2, pad=1),  # -> 4, 13, 13
        AvgPool2d(k_s=2, pad=0),                          # -> 4, 12, 12
        Conv2d(c_in=4, c_out=16, stride=2, leak=1.),      # -> 16, 5, 5
        Flatten(),
        Linear(400, 64),
        ReLU(),
        Linear(64, 10, True),
    ]
    return Sequential(*layers)

In [7]:
# Keep only a handful of samples: 8 for training and 2 for validation.
n_train, n_valid = 8, 2
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, y_train = x_train[:n_train], y_train[:n_train]
x_valid, y_valid = x_valid[:n_valid], y_valid[:n_valid]

data_bunch = get_data_bunch(x_train, y_train, x_valid, y_valid, batch_size=64)
model = get_conv_pool_model(data_bunch)
loss_fn = CrossEntropy()

In [8]:
model

(Sequential)
	(Layer1) Reshape(1, 28, 28)
	(Layer2) Conv2D(in: 1, out: 4, kernel: 5, stride: 2, pad: 1)
	(Layer3) AvgPool2d(kernel: 2, stride: 1, pad: 0)
	(Layer4) Conv2D(in: 4, out: 16, kernel: 3, stride: 2, pad: 0)
	(Layer5) Flatten()
	(Layer6) Linear(400, 64)
	(Layer7) ReLU()
	(Layer8) Linear(64, 10)

In [9]:
# Gradient check: run the hand-written backward pass, snapshot its results,
# then redo the forward pass with requires_grad enabled and compare against
# the gradients produced by loss.backward().

# 1) Forward pass followed by the manual backward pass of the loss and model.
loss = loss_fn(model(x_train), y_train)
loss_fn.backward()
model.backward()

# 2) Snapshot the manually computed gradients (input, then weights and biases
#    of layers 1, 3, 5 and 7) before anything below can overwrite them.
# NOTE(review): xtg is captured but never compared with x_train2.grad, so the
# input gradient is not actually verified by this cell.
xtg = x_train.g.clone()
w1g = model.layers[1].w.grad.clone()
b1g = model.layers[1].b.grad.clone()
w2g = model.layers[3].w.grad.clone()
b2g = model.layers[3].b.grad.clone()
w3g = model.layers[5].w.grad.clone()
b3g = model.layers[5].b.grad.clone()
w4g = model.layers[7].w.grad.clone()
b4g = model.layers[7].b.grad.clone()

# 3) Turn on gradient tracking for a copy of the input and for the `.data`
#    tensor behind every weight and bias.
x_train2 = x_train.clone().requires_grad_(True)
model.layers[1].w.data.requires_grad_(True)
model.layers[1].b.data.requires_grad_(True)
model.layers[3].w.data.requires_grad_(True)
model.layers[3].b.data.requires_grad_(True)
model.layers[5].w.data.requires_grad_(True)
model.layers[5].b.data.requires_grad_(True)
model.layers[7].w.data.requires_grad_(True)
model.layers[7].b.data.requires_grad_(True)

# 4) Second forward pass; this time the loss tensor's own backward() fills
#    the `.grad` of the tracked tensors.
loss = loss_fn(model(x_train2), y_train)
loss.backward()

# 5) Manual gradients (step 2) must match the ones stored on `.data.grad`.
test_near(w1g, model.layers[1].w.data.grad)
test_near(b1g, model.layers[1].b.data.grad)
test_near(w2g, model.layers[3].w.data.grad)
test_near(b2g, model.layers[3].b.data.grad)
test_near(w3g, model.layers[5].w.data.grad)
test_near(b3g, model.layers[5].b.data.grad)
test_near(w4g, model.layers[7].w.data.grad)
test_near(b4g, model.layers[7].b.data.grad)

In [10]:
# Rebuild the small data subset (8 training, 2 validation samples) and a
# freshly constructed model and loss.
n_train, n_valid = 8, 2
x_train, y_train, x_valid, y_valid = get_mnist_data()
x_train, y_train = x_train[:n_train], y_train[:n_train]
x_valid, y_valid = x_valid[:n_valid], y_valid[:n_valid]

data_bunch = get_data_bunch(x_train, y_train, x_valid, y_valid, batch_size=64)
model = get_conv_pool_model(data_bunch)
loss_fn = CrossEntropy()

In [11]:
model

(Sequential)
	(Layer1) Reshape(1, 28, 28)
	(Layer2) Conv2D(in: 1, out: 4, kernel: 5, stride: 2, pad: 1)
	(Layer3) AvgPool2d(kernel: 2, stride: 1, pad: 0)
	(Layer4) Conv2D(in: 4, out: 16, kernel: 3, stride: 2, pad: 0)
	(Layer5) Flatten()
	(Layer6) Linear(400, 64)
	(Layer7) ReLU()
	(Layer8) Linear(64, 10)

In [12]:
# Gradient check on the rebuilt model: manual backward pass first, then the
# same forward pass with requires_grad enabled, then compare the two sets of
# parameter gradients.

# 1) Forward pass followed by the manual backward pass of the loss and model.
loss = loss_fn(model(x_train), y_train)
loss_fn.backward()
model.backward()

# 2) Snapshot the manually computed gradients (input, then weights and biases
#    of layers 1, 3, 5 and 7).
# NOTE(review): xtg is captured but never compared with x_train2.grad, so the
# input gradient is not actually verified by this cell.
xtg = x_train.g.clone()
w1g = model.layers[1].w.grad.clone()
b1g = model.layers[1].b.grad.clone()
w2g = model.layers[3].w.grad.clone()
b2g = model.layers[3].b.grad.clone()
w3g = model.layers[5].w.grad.clone()
b3g = model.layers[5].b.grad.clone()
w4g = model.layers[7].w.grad.clone()
b4g = model.layers[7].b.grad.clone()

# 3) Turn on gradient tracking for a copy of the input and for the `.data`
#    tensor behind every weight and bias.
x_train2 = x_train.clone().requires_grad_(True)
model.layers[1].w.data.requires_grad_(True)
model.layers[1].b.data.requires_grad_(True)
model.layers[3].w.data.requires_grad_(True)
model.layers[3].b.data.requires_grad_(True)
model.layers[5].w.data.requires_grad_(True)
model.layers[5].b.data.requires_grad_(True)
model.layers[7].w.data.requires_grad_(True)
model.layers[7].b.data.requires_grad_(True)

# 4) Second forward pass; this time the loss tensor's own backward() fills
#    the `.grad` of the tracked tensors.
loss = loss_fn(model(x_train2), y_train)
loss.backward()

# 5) Manual gradients (step 2) must match the ones stored on `.data.grad`.
test_near(w1g, model.layers[1].w.data.grad)
test_near(b1g, model.layers[1].b.data.grad)
test_near(w2g, model.layers[3].w.data.grad)
test_near(b2g, model.layers[3].b.data.grad)
test_near(w3g, model.layers[5].w.data.grad)
test_near(b3g, model.layers[5].b.data.grad)
test_near(w4g, model.layers[7].w.data.grad)
test_near(b4g, model.layers[7].b.data.grad)