In [1]:
import torch
from torch.nn.functional import fold , unfold
import matplotlib.pyplot as plt


In [2]:
class convolution(object):
    """2-D convolution layer with hand-written forward and backward passes.

    The forward pass is implemented as ``unfold`` (im2col) followed by a
    matrix product with the flattened kernel.  The backward pass computes the
    gradients with respect to the input, the kernel and the bias using the
    same unfold/matmul machinery, without relying on autograd.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: ``(k_h, k_w)`` tuple, or a single int for a square kernel.
        padding: zero padding (int) applied on every side of both spatial axes.
        stride: stride (int) applied on both spatial axes.
        use_bias: if True a randomly initialised bias is added to the output;
            otherwise the bias is kept at zero and never applied.

    Note:
        ``backward`` pads the output gradient by ``k - 1 - padding`` on each
        axis, so it requires ``padding <= min(k_h, k_w) - 1``.
    """

    def __init__(self, in_ch, out_ch, kernel_size = (3,3), padding = 0, stride = 1, use_bias = False):
        # Accept a bare int as shorthand for a square kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.k_1 = self.kernel_size[0]
        self.k_2 = self.kernel_size[1]
        self.use_bias = use_bias
        self.stride = stride
        self.padding = padding
        self.kernel = torch.empty(out_ch, in_ch, self.k_1, self.k_2).normal_()
        self.bias = torch.empty(out_ch).normal_() if use_bias else torch.zeros(out_ch)
        # Input cached by forward() for use in backward().
        self.x = None

    def forward(self, x):
        """Apply the convolution to ``x``.

        Args:
            x: tensor of shape ``(batch, in_ch, H, W)``.

        Returns:
            Tensor of shape ``(batch, out_ch, H_out, W_out)`` where
            ``H_out = (H + 2*padding - k_h) // stride + 1`` (same for W).
        """
        self.x = x
        self.batch_size = x.size(0)
        # (batch, in_ch*k_h*k_w, H_out*W_out): every column is one flattened patch.
        X_unf = unfold(x, kernel_size=(self.k_1, self.k_2), padding=self.padding, stride=self.stride)
        K_expand = self.kernel.view(self.out_ch, -1)
        O_expand = K_expand @ X_unf
        # Output spatial size as plain Python ints.
        out_h = (x.size(-2) + 2 * self.padding - self.k_1) // self.stride + 1
        out_w = (x.size(-1) + 2 * self.padding - self.k_2) // self.stride + 1

        O = O_expand.view(self.batch_size, self.out_ch, out_h, out_w)
        return O + self.bias.view(1, -1, 1, 1) if self.use_bias else O

    def backward(self, gradwrtoutput):
        """Back-propagate ``gradwrtoutput`` through the layer.

        Args:
            gradwrtoutput: gradient of the loss with respect to the output of
                ``forward``, shape ``(batch, out_ch, H_out, W_out)``.

        Returns:
            Tuple ``(dL_dX, dL_dF, dL_dB)``: gradients with respect to the
            input, the kernel and the bias.  ``dL_dB`` is ``None`` when the
            layer was built with ``use_bias=False``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if getattr(self, 'x', None) is None:
            raise RuntimeError('forward() must be called before backward()')

        in_h, in_w = self.x.size(-2), self.x.size(-1)
        # Spatial size the output would have with stride 1.
        dense_h = in_h - self.k_1 + 1 + self.padding * 2
        dense_w = in_w - self.k_2 + 1 + self.padding * 2

        # Undo the stride: scatter the gradient onto the stride-1 grid, leaving
        # zeros at the positions the strided forward pass skipped.  One
        # selection matrix per axis so that non-square shapes are supported.
        M_h = self.get_M(dense_h).to(gradwrtoutput)
        M_w = self.get_M(dense_w).to(gradwrtoutput)
        dL_dO = (M_h.transpose(0, 1) @ gradwrtoutput) @ M_w

        # backward wrt input: full correlation of the dilated gradient with the
        # spatially flipped kernel (in/out channel axes swapped).
        kernel_back = self.kernel.flip(-2, -1).transpose(0, 1)
        dL_dO_unf = unfold(
            dL_dO,
            kernel_size=(self.k_1, self.k_2),
            padding=(self.k_1 - 1 - self.padding, self.k_2 - 1 - self.padding),
            stride=1,
        )
        dO_dX_exp = kernel_back.reshape(self.in_ch, -1)
        dL_dX_exp = dO_dX_exp @ dL_dO_unf
        dL_dX = dL_dX_exp.view(self.batch_size, self.in_ch, in_h, in_w)

        # Swap batch and channel axes so the batch is summed over by the matmul.
        self.dL_dO = dL_dO.transpose(0, 1)   # (out_ch, batch, dense_h, dense_w)
        self.dO_dF = self.x.transpose(0, 1)  # (in_ch, batch, H, W)

        # backward wrt weights: correlate the input with the dilated gradient,
        # which acts as a (dense_h, dense_w) kernel.
        dL_dO_unf_F = self.dL_dO.reshape(self.out_ch, -1)
        dO_dF_exp = unfold(self.dO_dF, kernel_size=(dense_h, dense_w), padding=self.padding, stride=1)
        dL_dF_exp = dL_dO_unf_F @ dO_dF_exp  # (in_ch, out_ch, k_h*k_w)
        dL_dF = dL_dF_exp.transpose(0, 1).reshape(self.kernel.size())

        # backward wrt bias: the bias is broadcast over batch and space, so its
        # gradient is the sum of the output gradient over those axes.
        dL_dB = gradwrtoutput.sum(dim=(0, 2, 3)) if self.use_bias else None
        return dL_dX, dL_dF, dL_dB

    def get_M(self, N):
        """Return the rows ``0, stride, 2*stride, ...`` of the ``N x N`` identity."""
        return torch.eye(N)[::self.stride]

    def param(self) :
        """Return the trainable tensors as ``[kernel, bias]``."""
        return [self.kernel, self.bias]

        

In [3]:
# Sanity check: compare the hand-written convolution (forward + backward)
# against torch.nn.functional.conv2d and autograd on a small random problem.

# Fix the RNG so that a fresh "Restart & Run All" reproduces the same numbers
torch.manual_seed(0)

# Initial parameters
s_1, s_2 = 7,7
k_1, k_2 = 3,3
bs = 2
ch_in, ch_out = 2, 4
stride = 2
padding = 1
# input tensor
X = torch.empty(bs, ch_in, s_1, s_2).normal_().requires_grad_()
X_copy = X.clone().detach().requires_grad_()

# initialize convolution module
conv = convolution(ch_in, ch_out, kernel_size = (k_1, k_2), padding = padding, use_bias=True, stride = stride)

# get weights and bias
F = conv.kernel
B = conv.bias
F.requires_grad_()
B.requires_grad_()

# forward
out = conv.forward(X)
out_compare = torch.nn.functional.conv2d(X_copy, F, bias = B, stride = stride, padding=padding)

# backward: use an all-ones upstream gradient.  torch.ones_like avoids the
# 0/0 = NaN that `out/out` produces for an output element that is exactly 0.
dL_dX, dL_dF, dL_dB = conv.backward(torch.ones_like(out))
out_compare.backward(torch.ones_like(out_compare))



print('same output of conv: ', (out_compare - out).abs().sum()) 
print('same input gradient: ', (X_copy.grad - dL_dX).abs().sum())
print('same weigth gradient: ',(F.grad-dL_dF).abs().sum() )
print('same bias gradient: ',(B.grad-dL_dB).abs().sum() )

# Fail loudly if the manual implementation drifts from the reference
# (tolerance absorbs float32 round-off).
assert torch.allclose(out, out_compare, rtol=1e-4, atol=1e-4)
assert torch.allclose(dL_dX, X_copy.grad, rtol=1e-4, atol=1e-4)
assert torch.allclose(dL_dF, F.grad, rtol=1e-4, atol=1e-4)
assert torch.allclose(dL_dB, B.grad, rtol=1e-4, atol=1e-4)


same output of conv:  tensor(2.8983e-05, grad_fn=<SumBackward0>)
same input gradient:  tensor(4.3720e-05, grad_fn=<SumBackward0>)
same weigth gradient:  tensor(0., grad_fn=<SumBackward0>)
same bias gradient:  tensor(0., grad_fn=<SumBackward0>)
