In [5]:
import torch
import random
import torch.nn.functional as F
from torch import nn

In [8]:
class Conv2DFunc(torch.autograd.Function):
    """
    2D cross-correlation ("convolution" in deep-learning terms) implemented as
    a custom autograd Function on top of ``F.unfold`` / ``F.fold``.

    The forward pass turns every sliding window of the input into a column
    (im2col), so the convolution becomes one batched matrix multiplication.
    The backward pass applies the transposed products to obtain the gradients
    with respect to the input and the kernel.
    """

    @staticmethod
    def forward(ctx, input_batch, kernel, stride=1, padding=1):
        """
        Compute the convolution of ``input_batch`` with ``kernel``.

        Args:
            ctx: autograd context used to pass state to ``backward``.
            input_batch: tensor of shape (batch, in_channels, H, W).
            kernel: tensor of shape (out_channels, in_channels, kH, kW).
            stride: integer step of the sliding window (same in both
                spatial directions).
            padding: accepted for signature compatibility but NOT applied;
                the result is always an unpadded ("valid") convolution, so
                the output shape of existing default calls is preserved.
                Pad the input explicitly (e.g. with ``F.pad``) if padding is
                required.

        Returns:
            Tensor of shape (batch, out_channels, oH, oW) with
            ``oH = (H - kH) // stride + 1`` and ``oW = (W - kW) // stride + 1``.
        """
        # Both tensors must be saved in ONE call: every call to
        # save_for_backward replaces whatever was stored before.
        ctx.save_for_backward(input_batch, kernel)
        ctx.stride = stride

        batch_size, _, in_height, in_width = input_batch.size()
        out_channels, _, kernel_height, kernel_width = kernel.size()

        # Spatial size of the output (no padding, no dilation).
        out_height = (in_height - kernel_height) // stride + 1
        out_width = (in_width - kernel_width) // stride + 1

        # im2col: (batch, in_channels * kH * kW, oH * oW)
        unfolded = F.unfold(
            input_batch, (kernel_height, kernel_width), stride=stride
        )

        # One row per output channel: (out_channels, in_channels * kH * kW)
        kernel_matrix = kernel.reshape(out_channels, -1)

        # Broadcast over the batch: (batch, out_channels, oH * oW)
        conv_output = kernel_matrix.matmul(unfolded)

        # Every column corresponds to one output pixel, so "folding back"
        # is a plain reshape.
        return conv_output.reshape(
            batch_size, out_channels, out_height, out_width
        )

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute the gradients of the loss with respect to the inputs of
        ``forward``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient of the loss with respect to the output,
                shape (batch, out_channels, oH, oW).

        Returns:
            Tuple ``(input_batch_grad, kernel_grad, None, None)``. The two
            ``None`` entries belong to ``stride`` and ``padding``, which are
            not differentiable. A tensor gradient is ``None`` when autograd
            does not need it.
        """
        # retrieve stored objects
        input_batch, kernel = ctx.saved_tensors
        stride = ctx.stride

        batch_size = input_batch.size(0)
        out_channels, _, kernel_height, kernel_width = kernel.size()
        window = (kernel_height, kernel_width)

        # (batch, out_channels, oH * oW)
        grad_flat = grad_output.reshape(batch_size, out_channels, -1)

        input_batch_grad = None
        kernel_grad = None

        if ctx.needs_input_grad[0]:
            # Gradient w.r.t. the unfolded columns:
            # (K, out_channels) @ (batch, out_channels, L) -> (batch, K, L)
            kernel_matrix = kernel.reshape(out_channels, -1)
            grad_columns = kernel_matrix.t().matmul(grad_flat)
            # fold sums the contributions of overlapping windows, which is
            # exactly the adjoint of unfold.
            input_batch_grad = F.fold(
                grad_columns, input_batch.shape[2:], window, stride=stride
            )

        if ctx.needs_input_grad[1]:
            # Same columns as in forward: (batch, K, L)
            unfolded = F.unfold(input_batch, window, stride=stride)
            # Sum over batch and output positions: (out_channels, K)
            kernel_grad = torch.einsum(
                "bol,bkl->ok", grad_flat, unfolded
            ).reshape(kernel.shape)

        # stride and padding have no gradient, so None is returned for them.
        return input_batch_grad, kernel_grad, None, None


# Smoke test: a random batch of 16 RGB 32x32 images convolved with two
# 3-channel 2x2 filters (default stride), giving a (16, 2, 31, 31) output.
input_batch = torch.randn((16, 3, 32, 32))
kernel = torch.randn((2, 3, 2, 2))
Conv2DFunc.apply(input_batch, kernel)

tensor([[[[  2.5842,  -0.1989,  -4.9084,  ...,   0.4085,  -1.1752,  -2.4602],
          [ -3.7182,   3.7164,   6.1937,  ...,   1.2530,   0.4965,  -1.0982],
          [ -0.3931,  -1.2821,   0.7987,  ...,   0.0476,   4.4236,   1.9649],
          ...,
          [ -5.7531,  -1.4558,   3.8175,  ...,  -2.3677,   1.2201,   0.9530],
          [  1.0636,  -0.9391,   4.5892,  ...,  -1.3669,  -0.1630,   0.5820],
          [ -1.8402,  -2.0190,  -0.8690,  ...,  -2.2354,  -2.5158,   4.1031]],

         [[ -3.4498,  -3.7207,  -5.3942,  ...,   4.5733,   4.6192,  -3.0451],
          [ -1.7342,   4.7661,  -2.8703,  ...,  -3.2667,  -5.5867,   4.5160],
          [ -0.1363,   5.7277,   7.4695,  ...,  -4.5361, -11.3935, -10.6704],
          ...,
          [  1.9613,   0.8342,  -5.4449,  ...,  -1.5805,  -5.8310,   0.1576],
          [  0.9721,  -0.6311,   4.0244,  ...,  -2.9763,   3.2695,  -2.6310],
          [  5.3097,  -1.3595,   5.0795,  ...,   9.1282,   7.6029,   5.1334]]],


        [[[ -3.6233,   0.264