# In [2]:
import torch
from torch import Tensor
from torch.nn.functional import fold, unfold

class Module:
    """Common interface for layers.

    Subclasses are expected to override ``forward`` and ``backward``; the
    versions here only raise ``NotImplementedError``.
    """

    def forward(self, *input):
        """Forward pass; must be provided by the subclass."""
        raise NotImplementedError()

    def backward(self, *gradwrtoutput):
        """Backward pass; must be provided by the subclass."""
        raise NotImplementedError()

    def param(self):
        """Return the module's parameters; the base module has none."""
        return list()

class Conv2d(Module):
    """2-D convolution layer implemented with ``unfold`` and a batched matmul.

    Attributes:
        stride, padding, dilation, kernel: 2-tuples of ints (height, width).
        groups: number of channel groups.
        in_channel, out_channel: channel counts given to the constructor.
        weight: tensor of shape (out_channel, in_channel // groups, kH, kW).
        bias: tensor of shape (out_channel,).
        weight_grad, bias_grad: gradient buffers with the same shapes as
            ``weight`` and ``bias``; overwritten by ``backward``.
        input: the last tensor passed to ``forward``, stored as 4-D.
    """

    def __init__(self,in_channel, out_channel, kernel_size = (2,2),stride=1, padding=0, dilation=1, groups=1,weight = None, bias = None):
        """Store the hyper-parameters and create (or adopt) weight and bias.

        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size, stride, padding, dilation: an int (used for both
                spatial dimensions) or a pair (height, width).
            groups: number of channel groups; must divide both channel counts.
            weight: optional pre-built weight tensor; randomly initialised
                when ``None``.
            bias: optional pre-built bias tensor; randomly initialised when
                ``None``.

        Raises:
            ValueError: if a size argument is not an int or a pair, or if the
                channel counts are not divisible by ``groups``.
        """
        self.stride = self._pair(stride)
        self.padding = self._pair(padding)
        self.dilation = self._pair(dilation)
        self.kernel = self._pair(kernel_size)

        if in_channel % groups != 0 or out_channel % groups != 0:
            raise ValueError("in_channel and out_channel must both be divisible by groups")

        self.groups = groups
        self.in_channel = in_channel
        self.out_channel = out_channel

        # Same scheme as torch.nn.Conv2d: U(-sqrt(k), sqrt(k)) with
        # k = groups / (C_in * kH * kW).
        k = self.groups / (self.in_channel * self.kernel[0] * self.kernel[1])
        bound = k ** 0.5

        # `is None` rather than `== None`: comparing a tensor with `==` is an
        # element-wise operation, not an identity test.
        if weight is None:
            weight = torch.empty((self.out_channel, self.in_channel // self.groups,
                                  self.kernel[0], self.kernel[1])).uniform_(-bound, bound)
        if bias is None:
            bias = torch.empty(self.out_channel).uniform_(-bound, bound)
        self.weight = weight
        self.bias = bias

        # Gradient buffers mirror the parameter shapes.
        self.weight_grad = torch.zeros_like(self.weight)
        self.bias_grad = torch.zeros_like(self.bias)

    @staticmethod
    def _pair(value):
        """Return ``value`` as a 2-tuple of ints (an int is duplicated)."""
        if isinstance(value, int):
            return (value, value)
        pair = tuple(int(v) for v in value)
        if len(pair) != 2:
            raise ValueError("expected an int or a pair, got {!r}".format(value))
        return pair

    @staticmethod
    def _out_dim(size, kernel, stride, padding, dilation):
        """Output length of a convolution along one spatial dimension."""
        return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    @staticmethod
    def conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor:
        """Convolve a 4-D ``input`` with ``weight`` and optionally add ``bias``.

        Args:
            input: tensor of shape (N, C_in, H_in, W_in).
            weight: tensor of shape (C_out, C_in // groups, kH, kW).
            bias: optional tensor with C_out elements.
            stride, padding, dilation: an int or a pair (height, width).
            groups: number of channel groups.

        Returns:
            Tensor of shape (N, C_out, H_out, W_out).

        Raises:
            ValueError: if the channel counts of ``input`` and ``weight`` are
                inconsistent with ``groups``.
        """
        stride = Conv2d._pair(stride)
        padding = Conv2d._pair(padding)
        dilation = Conv2d._pair(dilation)

        N, C_in, H_in, W_in = input.shape
        C_out = weight.size(0)
        kernel_size = (weight.size(-2), weight.size(-1))

        if C_in % groups != 0 or C_out % groups != 0 or weight.size(1) * groups != C_in:
            raise ValueError("input/weight channel counts are inconsistent with groups")

        H_out = Conv2d._out_dim(H_in, kernel_size[0], stride[0], padding[0], dilation[0])
        W_out = Conv2d._out_dim(W_in, kernel_size[1], stride[1], padding[1], dilation[1])

        # unfold yields (N, C_in*kH*kW, L) with the channel index varying
        # slowest, so consecutive row chunks correspond to channel groups.
        inp_unf = unfold(input, kernel_size, dilation, padding, stride)
        rows_per_group = (C_in // groups) * kernel_size[0] * kernel_size[1]
        cols = inp_unf.reshape(N, groups, rows_per_group, H_out * W_out)
        w = weight.reshape(groups, C_out // groups, rows_per_group)

        # (groups, C_out/g, K) @ (N, groups, K, L) -> (N, groups, C_out/g, L)
        out = torch.matmul(w, cols).reshape(N, C_out, H_out, W_out)
        if bias is not None:
            out = out + bias.reshape(1, -1, 1, 1)
        return out

    def forward(self, input):
        """Apply the layer to ``input`` and remember it for ``backward``.

        Args:
            input: tensor of shape (N, C_in, H, W), or (C_in, H, W) which is
                treated as a batch of one.

        Returns:
            The convolution output, with the same number of dimensions as
            ``input``.

        Raises:
            ValueError: if ``input`` is neither 3-D nor 4-D.
        """
        unbatched = input.dim() == 3
        if unbatched:
            input = input.unsqueeze(0)
        elif input.dim() != 4:
            raise ValueError("expected a 3-D or 4-D input, got {}-D".format(input.dim()))

        self.input = input
        out = self.conv2d(self.input, weight=self.weight, bias=self.bias, stride=self.stride,
                          padding=self.padding, dilation=self.dilation, groups=self.groups)
        return out.squeeze(0) if unbatched else out

    def grad_conv2d_weight(self,input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
        """Gradient of the loss with respect to the convolution weight.

        Args:
            input: the 4-D tensor (N, C_in, H_in, W_in) that was convolved.
            weight_size: shape of the weight, (C_out, C_in // groups, kH, kW).
            grad_output: gradient with respect to the convolution output; any
                shape holding N * C_out * H_out * W_out elements in that order
                (so a 3-D gradient is accepted for a batch of one).
            stride, padding, dilation: an int or a pair (height, width).
            groups: number of channel groups.

        Returns:
            Tensor of shape ``weight_size``.
        """
        stride = self._pair(stride)
        padding = self._pair(padding)
        dilation = self._pair(dilation)

        N = input.size(0)
        C_in = input.size(1)
        C_out = weight_size[0]
        kernel_size = (weight_size[-2], weight_size[-1])

        # Each column of the unfolded input is the patch that produced one
        # output position, so dW = sum_n grad_out[n] @ patches[n]^T.
        inp_unf = unfold(input, kernel_size, dilation, padding, stride)
        positions = inp_unf.size(-1)
        rows_per_group = (C_in // groups) * kernel_size[0] * kernel_size[1]
        cols = inp_unf.reshape(N, groups, rows_per_group, positions)
        grad = grad_output.reshape(N, groups, C_out // groups, positions)

        # (N, g, C_out/g, L) @ (N, g, L, K) -> (N, g, C_out/g, K); sum the batch.
        grad_w = torch.matmul(grad, cols.transpose(-1, -2)).sum(dim=0)
        return grad_w.reshape(tuple(weight_size))

    def backward(self,gradwrtouput):
        """Compute and store the parameter gradients for the last forward pass.

        Args:
            gradwrtouput: gradient of the loss with respect to the output of
                the last ``forward`` call.

        Returns:
            The gradient with respect to the weight (also stored in
            ``self.weight_grad``; the bias gradient goes to ``self.bias_grad``).
        """
        self.weight_grad = self.grad_conv2d_weight(self.input, self.weight.shape, gradwrtouput,
                                                   self.stride, self.padding, self.dilation, self.groups)
        # The bias is added to every sample and spatial position, so its
        # gradient is the sum over those axes.
        self.bias_grad = gradwrtouput.reshape(self.input.size(0), self.weight.size(0), -1).sum(dim=(0, 2))
        # NOTE: a complete backward would return the gradient with respect to
        # the input; the return value is kept as the weight gradient so that
        # existing callers are unaffected.
        return self.weight_grad
    