## Attempting to create a custom convolutional layer

In [36]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import v2
from torch.autograd import Function

In [47]:
class SparseConv(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weights, width, height, in_channels, out_channels, padding, offset_row, offset_col):
        nz = torch.nonzero(input)
        input = torch.nn.functional.pad(input, (padding[0], padding[0], padding[1], padding[1]), 'constant', 0)
        out = torch.zeros((input.shape[0], out_channels, height, width), dtype=torch.float32)
        dx = torch.zeros_like(weights, dtype=torch.float32)
        for batch in range(input.shape[0]):
            for out_channel in range(out_channels):
                for in_channel in range(in_channels):
                    for n in range(nz.shape[0]):
                        # assuming odd kernel size
                        i = nz[n,0]
                        j = nz[n,1]
                        nz_window = torch.nonzero(input[batch, in_channel,
                                        i - offset_row + padding[0]:i + offset_row+1 + padding[0],
                                        j - offset_col + padding[1]:j + offset_col+1 + padding[1]], as_tuple=True)
                        out[batch, out_channel, i, j] = \
                            torch.sum(weights[out_channel, in_channel, :, :] * \
                                      input[batch, in_channel,
                                        i - offset_row + padding[0]:i + offset_row+1 + padding[0],
                                        j - offset_col + padding[1]:j + offset_col+1 + padding[1]])
                        
                        dx[out_channel, in_channel, nz_window[0], nz_window[1]] += weights[out_channel, in_channel, nz_window[0], nz_window[1]]
        
        return out, dx


    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, out):
        input, weights, width, height, in_channels, out_channels, padding, offset_row, offset_col = inputs
        result, dx = out
        ctx.in1 = width
        ctx.in2 = height
        ctx.in3 = in_channels
        ctx.in4 = out_channels
        ctx.in5 = padding
        ctx.in6 = offset_row
        ctx.in7 = offset_col
        ctx.save_for_backward(input, weights, dx)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        input, weights, dx = ctx.saved_tensors
        
        grad_input = dx * grad_dx

        return grad_input

# Wrap MyCube in a function so that it is clearer what the output is
def my_sparse_conv(x, weights, width, height, in_channels, out_channels, padding, offset_row, offset_col):
    result, dx = SparseConv.apply(x, weights, width, height, in_channels, out_channels, padding, offset_row, offset_col)
    return result

In [48]:
from torch.autograd import gradcheck

def calculateNewWidth(self, x):
    return (
        (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1)
        // self.stride[0]
    ) + 1

def calculateNewHeight(self, x):
    return (
        (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1)
        // self.stride[1]
    ) + 1
# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
x = torch.eye(10, requires_grad=True, dtype=torch.double).unsqueeze(0).unsqueeze(0)
weights = torch.randn((1,1,5,5), requires_grad=True, dtype=torch.double)
padding = (2,2)
dilation = (1,1)
kernel_size = (5,5)
stride = (1,1)
width = (
        (x.shape[2] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
        // stride[0]
    ) + 1
height = (
        (x.shape[3] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
        // stride[1]
    ) + 1
in_channels = 1
out_channels = 1
offset_row = kernel_size[0] // 2
offset_col = kernel_size[1] // 2
input = (x, weights, width, height, in_channels, out_channels, padding, offset_row, offset_col)
test = gradcheck(my_sparse_conv, input, eps=1e-6, atol=1e-4)
print(test)

RuntimeError: function SparseConvBackward returned an incorrect number of gradients (expected 9, got 1)

In [35]:
temp = torch.eye(9).unsqueeze(0)
print(torch.nonzero(temp[0, 3:6,3:6]))
nz = torch.nonzero(temp[0, 3:6,3:6], as_tuple=True)
print(nz[0], nz[1])
temp[0, nz[0], nz[1]] += temp[0, nz[0], nz[1]]
print(temp)

tensor([[0, 0],
        [1, 1],
        [2, 2]])
tensor([0, 1, 2]) tensor([0, 1, 2])
tensor([[[2., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 2., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 2., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1.]]])


In [8]:
class sparseConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1):
        super(sparseConv2D, self).__init__()
        
        self.kernel_size = (kernel_size, kernel_size)
        self.kernel_size_number = kernel_size * kernel_size
        self.offset_row = self.kernel_size[0] // 2
        self.offset_col = self.kernel_size[1] // 2
        self.out_channels = out_channels
        self.dilation = (dilation, dilation)
        self.padding = (padding, padding)
        self.stride = (stride, stride)
        self.in_channels = in_channels
        self.weights = nn.Parameter(torch.randn((self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])))
        self.register_parameter("kernels", self.weights)
        # self.weights.requires_grad = True

    def forward(self, x):
        # nz = torch.nonzero(x)
        # width = self.calculateNewWidth(x)
        # height = self.calculateNewHeight(x)
        # x = torch.nn.functional.pad(x, (self.padding[0], self.padding[0], self.padding[1], self.padding[1]), 'constant', 0)
        # out = torch.zeros((x.shape[0], self.out_channels, height, width), dtype=torch.float32)
        # for batch in range(x.shape[0]):
        #     for out_channel in range(self.out_channels):
        #         for in_channel in range(self.in_channels):
        #             for n in range(nz.shape[0]):
        #                 # assuming odd kernel size
        #                 i = nz[n,0]
        #                 j = nz[n,1]
        #                 out[batch, out_channel, i, j] = \
        #                     torch.sum(self.weights[out_channel, in_channel, :, :] * \
        #                               x[batch, in_channel,
        #                                 i - self.offset_row + self.padding[0]:i + self.offset_row+1 + self.padding[0],
        #                                 j - self.offset_col + self.padding[1]:j + self.offset_col+1 + self.padding[1]])
        
        # x = out
        # return x

    def calculateNewWidth(self, x):
        return (
            (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1)
            // self.stride[0]
        ) + 1

    def calculateNewHeight(self, x):
        return (
            (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1)
            // self.stride[1]
        ) + 1

In [9]:
class sparseCNN(nn.Module):
    def __init__(self):
        super(sparseCNN, self).__init__()
        self.conv1 = sparseConv2D(in_channels=1, 
                                  out_channels=1, 
                                  kernel_size=5, 
                                  dilation=1, 
                                  padding=2, 
                                  stride=1)
        # self.relu = nn.ReLU()
        # self.flatten = nn.Flatten()
        # self.linear = nn.Linear(10*14*14, 10)
        # self.out = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv1(x)
        # x = self.relu(x)
        # x = self.flatten(x)
        # x = self.linear(x)
        # x = self.out(x)

        return x

In [24]:
net = sparseCNN()
net.train()
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)
loss = nn.MSELoss()
out = torch.zeros((1,1,14,14))
out[0,0,7:11,7:11] = 10
inp = torch.zeros((1,1,14,14))
inp[0,:,7,7] = 1

print(net.conv1.weights)
for i in range(1):
    pred = net.forward(inp)
    l = loss(pred, out)
    print("Loss: ", l.item())
    # l.requires_grad = True
    l.backward()
    optimizer.step()
    print(net.conv1.weights.grad)
print(net.conv1.weights)


Parameter containing:
tensor([[[[-0.5391,  1.0129, -0.5719,  0.0184, -1.2213],
          [-1.5627, -0.2636, -1.5103, -0.2849, -0.2953],
          [-0.7283, -1.0078,  1.6991, -0.3756, -0.4482],
          [ 0.8908, -2.1303, -0.5394,  1.1742,  0.5946],
          [-0.5012, -0.2617, -0.0962,  0.3924,  0.4818]]]], requires_grad=True)
Loss:  8.163265228271484
tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]])
Parameter containing:
tensor([[[[-0.5391,  1.0129, -0.5719,  0.0184, -1.2213],
          [-1.5627, -0.2636, -1.5103, -0.2849, -0.2953],
          [-0.7283, -1.0078,  1.6991, -0.3756, -0.4482],
          [ 0.8908, -2.1303, -0.5394,  1.1742,  0.5946],
          [-0.5012, -0.2617, -0.0962,  0.3924,  0.4818]]]], requires_grad=True)


In [5]:
print(list(net.parameters()))
print(net.conv1.weights.grad)

[Parameter containing:
tensor([[[[-2.8262,  0.9700,  0.7905,  1.0629,  2.0799],
          [ 1.0917,  0.6226, -1.4012, -0.6333, -1.1871],
          [ 0.7901, -0.7102,  1.1423,  0.4486, -0.8829],
          [-0.0178,  0.1736,  0.5451, -0.0836, -0.9699],
          [-0.2646, -0.5052, -0.4009,  0.8752,  0.0299]]]], requires_grad=True)]
tensor([[[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]])


In [6]:
t = torch.arange(3*4*5*6).view(3,4,5,6)
print(t.sum(dim=(2,3), keepdim=True).shape)
print(t[:,0,:,:].shape)
idx = torch.tensor([[1,1],[2,2]])
print(t[0,0,1,1])

from torch.nn.functional import pad
o = torch.ones((4,4))
print(pad(o, (1,1,1,1), "constant", 0))

torch.Size([3, 4, 1, 1])
torch.Size([3, 5, 6])
tensor(7)
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.]])
