In [1]:
# first the question is to see what the output should be
# the idea is we tile it properly so that we can simply do max reduction or whatever for pooling. so you could do a loop with a kernel, but that's not optimal
# let's test pytorch unfold function

import torch

I = torch.arange(16).reshape(1, 1, 4, 4).float()
I

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])

In [2]:
# so the easy way to do this is just directly run a kernel over it, so do a convolution with a kernel of size 2x2 and stride 2
m = torch.nn.Conv2d(1, 1, 2, 2, 0, bias=False)
m

Conv2d(1, 1, kernel_size=(2, 2), stride=(2, 2), bias=False)

In [5]:
# let's define a function that takes in a tensor and a kernel size and manually does this
# I will just be 4x4 or whatever
def tile(I, kernel_size):
    out = torch.zeros(
        I.shape[0] // kernel_size[0],
        I.shape[1] // kernel_size[1],
        kernel_size[0] * kernel_size[1],
    )
    i_steps = I.shape[0] // kernel_size[0]
    j_steps = I.shape[0] // kernel_size[1]
    for i in range(i_steps):
        for j in range(j_steps):
            out[i, j] = I[i : i + kernel_size[0], j : j + kernel_size[1]].flatten()
    return out


I = torch.arange(16).reshape(4, 4).float()
print(I)
out = tile(I, (2, 2))
print(out)

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
tensor([[[ 0.,  1.,  4.,  5.],
         [ 1.,  2.,  5.,  6.]],

        [[ 4.,  5.,  8.,  9.],
         [ 5.,  6.,  9., 10.]]])


In [7]:
for i in range(4):
    print(
        out[:, :, i]
    )  # want to shape it like this, basically like ran the 2d kernel over it... now how to reshape it

tensor([[0., 1.],
        [4., 5.]])
tensor([[1., 2.],
        [5., 6.]])
tensor([[4., 5.],
        [8., 9.]])
tensor([[ 5.,  6.],
        [ 9., 10.]])


In [9]:
# ok so the way it should stack it is like this, let's make sure my function works
out = tile(I, (1, 1))
print(out, out.shape)

tensor([[[ 0.],
         [ 1.],
         [ 2.],
         [ 3.]],

        [[ 4.],
         [ 5.],
         [ 6.],
         [ 7.]],

        [[ 8.],
         [ 9.],
         [10.],
         [11.]],

        [[12.],
         [13.],
         [14.],
         [15.]]]) torch.Size([4, 4, 1])


In [18]:
# if any pool with kernel size of 1, it won't do shit obviously!, so in this case makes it 4x4x1. So this tiling is correct
# now the question is hwo to reshape with this

# so this is what chatgpt says
I2 = I.reshape(2, 2, 2, 2)  # new height, kh, new width, kw
print(I2.shape)
I3 = I2.permute(0, 2, 1, 3).reshape(2, 2, 4)
I3  # new height, new width, kw*kh
# I3

torch.Size([2, 2, 2, 2])


tensor([[[ 0.,  1.,  4.,  5.],
         [ 2.,  3.,  6.,  7.]],

        [[ 8.,  9., 12., 13.],
         [10., 11., 14., 15.]]])

In [19]:
I3.shape

torch.Size([2, 2, 4])

In [21]:
out = tile(I, (2, 2))
out.shape

torch.Size([2, 2, 4])

In [22]:
out - I3  # yeah definitely not the same

tensor([[[ 0.,  0.,  0.,  0.],
         [-1., -1., -1., -1.]],

        [[-4., -4., -4., -4.],
         [-5., -5., -5., -5.]]])

In [25]:
for i in range(4):
    print(I3[:, :, i])

tensor([[ 0.,  2.],
        [ 8., 10.]])
tensor([[ 1.,  3.],
        [ 9., 11.]])
tensor([[ 4.,  6.],
        [12., 14.]])
tensor([[ 5.,  7.],
        [13., 15.]])


In [26]:
I3.shape

torch.Size([2, 2, 4])

In [29]:
# let's see the thing for a more complex tensor, 6x6 tensor but tile 3x3
I = torch.arange(36).reshape(6, 6)
print(I)

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])


In [30]:
kh, kw = 3, 3
new_height = I.shape[0] // kh
new_width = I.shape[1] // kw

I2 = I.reshape(new_height, kh, new_width, kw)
print(I2.shape, I2)
I3 = I2.permute(0, 2, 1, 3).reshape(new_height, new_width, kh * kw)

print(I3.shape, I3)

torch.Size([2, 3, 2, 3]) tensor([[[[ 0,  1,  2],
          [ 3,  4,  5]],

         [[ 6,  7,  8],
          [ 9, 10, 11]],

         [[12, 13, 14],
          [15, 16, 17]]],


        [[[18, 19, 20],
          [21, 22, 23]],

         [[24, 25, 26],
          [27, 28, 29]],

         [[30, 31, 32],
          [33, 34, 35]]]])
torch.Size([2, 2, 9]) tensor([[[ 0,  1,  2,  6,  7,  8, 12, 13, 14],
         [ 3,  4,  5,  9, 10, 11, 15, 16, 17]],

        [[18, 19, 20, 24, 25, 26, 30, 31, 32],
         [21, 22, 23, 27, 28, 29, 33, 34, 35]]])


In [32]:
# so the shape of this is indeed correct, let's test to see if it makes sense
out = tile(I, (3, 3))
print(out.shape, out)

torch.Size([2, 2, 9]) tensor([[[ 0.,  1.,  2.,  6.,  7.,  8., 12., 13., 14.],
         [ 1.,  2.,  3.,  7.,  8.,  9., 13., 14., 15.]],

        [[ 6.,  7.,  8., 12., 13., 14., 18., 19., 20.],
         [ 7.,  8.,  9., 13., 14., 15., 19., 20., 21.]]])


In [33]:
# wait how to interpret this 2,2,9. Basically 9 2x2 things, so after reduce each of these left with 9 elements...
# wtf is this tho? Should have 4 values actually after the reduction, so reduce along this last dimension
# the 9 is how many elements are in each of the kernels! So what matters is that we sum along that third dimension is identical, or that the values in each of the 4 are

tensor([[[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
         [ -2.,  -2.,  -2.,  -2.,  -2.,  -2.,  -2.,  -2.,  -2.]],

        [[-12., -12., -12., -12., -12., -12., -12., -12., -12.],
         [-14., -14., -14., -14., -14., -14., -14., -14., -14.]]])

In [35]:
print(
    out.sum(2), I3.sum(2), sep="\n"
)  # the values don't even add up, this is kinda weird?

tensor([[ 63.,  72.],
        [117., 126.]])
tensor([[ 63,  90],
        [225, 252]])


In [36]:
# let's manually verify
I

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])

In [37]:
a1 = 0 + 1 + 2 + 6 + 7 + 8 + 12 + 13 + 14
a2 = 3 + 4 + 5 + 9 + 10 + 11 + 15 + 16 + 17
print(a1, a2)

63 90


In [None]:
# yeah it seems my out is wrong, let's just go with what they have lol