## Attempting to create a custom convolutional layer

In [36]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import v2


class sparseConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1):
        super(sparseConv2D, self).__init__()
        
        self.kernel_size = (kernel_size, kernel_size)
        self.kernel_size_number = kernel_size * kernel_size
        self.offset_row = self.kernel_size[0] // 2
        self.offset_col = self.kernel_size[1] // 2
        self.out_channels = out_channels
        self.dilation = (dilation, dilation)
        self.padding = (padding, padding)
        self.stride = (stride, stride)
        self.in_channels = in_channels
        self.weights = nn.Parameter(torch.Tensor(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])).data.uniform_(0, 1)

    def forward(self, x):
        nz = torch.nonzero(x)
        width = self.calculateNewWidth(x)
        height = self.calculateNewHeight(x)
        x = torch.nn.functional.pad(x, (self.padding[0], self.padding[0], self.padding[1], self.padding[1]), 'constant', 0)
        out = torch.zeros((x.shape[0], self.out_channels, height, width))
        # for channel in output imaage
            # for channel in input image
                # for coord in coords
                    # calculate values at coords
                    # add to respective location in output
        for batch in range(x.shape[0]):
            for out_channel in range(self.out_channels):
                for in_channel in range(self.in_channels):
                    for n in range(nz.shape[0]):
                        # assuming odd kernel size
                        i = nz[n,0]
                        j = nz[n,1]
                        out[batch, out_channel, i, j] = \
                            torch.sum(self.weights[out_channel, in_channel, :, :] * \
                                      x[batch, in_channel,
                                        i - self.offset_row + self.padding[0]:i + self.offset_row+1 + self.padding[0],
                                        j - self.offset_col + self.padding[1]:j + self.offset_col+1 + self.padding[1]])
        
        return out

    def calculateNewWidth(self, x):
        return (
            (x.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1)
            // self.stride[0]
        ) + 1

    def calculateNewHeight(self, x):
        return (
            (x.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1)
            // self.stride[1]
        ) + 1

In [45]:
class sparseCNN(nn.Module):
    def __init__(self):
        super(sparseCNN, self).__init__()
        self.conv1 = sparseConv2D(in_channels=1, 
                                  out_channels=10, 
                                  kernel_size=5, 
                                  dilation=1, 
                                  padding=2, 
                                  stride=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(10*14*14, 10)
        self.out = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.linear(x)
        x = self.out(x)

        return x

In [46]:
test = sparseCNN()
val = torch.zeros((1,1,14,14))
val[0,0,7:11,7:11] = 1
y = test.forward(val)
print(y)

tensor([[0.0976, 0.0981, 0.1009, 0.0989, 0.1013, 0.1014, 0.0996, 0.1013, 0.0993,
         0.1016]], grad_fn=<SoftmaxBackward0>)


In [25]:
t = torch.arange(3*4*5*6).view(3,4,5,6)
print(t.sum(dim=(2,3), keepdim=True).shape)
print(t[:,0,:,:].shape)
idx = torch.tensor([[1,1],[2,2]])
print(t[0,0,1,1])

from torch.nn.functional import pad
o = torch.ones((4,4))
print(pad(o, (1,1,1,1), "constant", 0))

torch.Size([3, 4, 1, 1])
torch.Size([3, 5, 6])
tensor(7)
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.]])
