In [1]:
import torch
import torch.nn as nn

from superonn_final import SuperONN2d

In [38]:
# Sweep SuperONN2d configurations and, for every one where an nn.Conv2d with the
# same in/out channels and `groups` can also be built, check that:
#   * both layers accept the same input and return outputs of the same shape;
#   * SuperONN2d's weight shape, after rescaling dim 1, equals Conv2d's weight shape.
# NOTE(review): each `range(...)` excludes its upper bound, so e.g.
# groups == out_c and shift_groups == in_c are never exercised — confirm intent.
for in_c in [3, 8, 16, 64]:
    # The input depends only on in_c, so build it once per in_c rather than on
    # every innermost iteration.
    x = torch.randn(1, in_c, 5, 5)
    for out_c in [3, 8, 16, 64]:
        for q in [2, 3, 4, 5, 6, 7]:
            for full_groups in range(out_c):
                for groups in range(min(out_c, in_c * q * full_groups)):
                    # nn.Conv2d raises for groups < 1 or when groups does not divide
                    # both channel counts. Such configs were already skipped by the
                    # `except Exception: continue` below, so reject them up front
                    # instead of building a SuperONN2d for every shift_groups value.
                    if groups < 1 or in_c % groups or out_c % groups:
                        continue
                    for shift_groups in range(in_c):
                        config = (
                            f"in_c={in_c}, out_c={out_c}, q={q}, full_groups={full_groups}, "
                            f"groups={groups}, shift_groups={shift_groups}"
                        )

                        try:
                            sonn = SuperONN2d(
                                in_c,
                                out_c,
                                kernel_size=3,
                                padding=1,
                                q=q,
                                full_groups=full_groups,
                                groups=groups,
                                shift_groups=shift_groups,
                                max_shift=10,
                                learnable=True,
                            )
                            conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, groups=groups)
                        except Exception:
                            # Configuration rejected by SuperONN2d (or Conv2d): not testable.
                            continue

                        try:
                            outconv = conv(x)
                            outsonn = sonn(x)
                        except Exception as err:
                            # Catch Exception rather than everything: a bare `except:` also
                            # swallowed KeyboardInterrupt and reported it as a layer failure.
                            print(config)
                            raise Exception("Passing through conv/sonn failed") from err

                        if outconv.shape != outsonn.shape:
                            print(f"outconv={tuple(outconv.shape)}, outsonn={tuple(outsonn.shape)}")
                            print(config)
                            raise Exception("Shapes don't match")

                        sonnw = torch.tensor(sonn.weight.shape)
                        convw = torch.tensor(conv.weight.shape)

                        # Rescale SuperONN2d's weight dim 1 so it is comparable with Conv2d's.
                        # NOTE(review): formula kept as-is; it assumes weight.shape[1] scales
                        # with q and with groups / full_groups — confirm against SuperONN2d.
                        sonnw[1] = sonnw[1] * (sonn.groups // sonn.full_groups) // sonn.q // groups

                        if not torch.equal(sonnw, convw):
                            print(f"sonnw={sonnw}, convw={convw}")
                            print(config)
                            raise Exception("Weights don't match")

KeyboardInterrupt: 

In [53]:
# Layer used by the %%timeit cells below: 3 -> 54 (= 18 * 3) channels, q=3, 9 groups.
sonn = SuperONN2d(
    3,
    18 * 3,
    kernel_size=3,
    padding=1,
    q=3,
    groups=9,
    full_groups=1,
    max_shift=10,
    learnable=True,
)

In [54]:
%%timeit
sonn(torch.randn(1, 3, 25, 25))

453 µs ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [52]:
%%timeit
sonn(torch.randn(1, 3, 25, 25))

733 µs ± 285 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [50]:
%%timeit
sonn(torch.randn(1, 3, 25, 25))

452 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [47]:
%%timeit
sonn(torch.randn(1, 3, 25, 25))

2.25 ms ± 721 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [45]:
%%timeit
sonn(torch.randn(1, 3, 25, 25))

1.74 ms ± 31.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [27]:
# Shapes of the learnable weight and shift tensors of the current `sonn`.
# NOTE(review): the recorded output below (16 output channels) does not match the
# 54-channel `sonn` built earlier in this notebook — it is stale; re-run to refresh.
tuple(param.shape for param in (sonn.weight, sonn.shifts))

(torch.Size([16, 6, 3, 3]), torch.Size([16, 3, 2]))

In [58]:
# repeat_interleave along dim 0 repeats each slice consecutively (a, a, a, b, b, b, ...).
# The cell builds its own input instead of relying on whatever `x` is left in the
# kernel; shape (6, 4, 4) reproduces the recorded output below.
interleave_demo = torch.arange(6 * 4 * 4).reshape(6, 4, 4)
torch.repeat_interleave(interleave_demo, 3, 0).shape

torch.Size([18, 4, 4])

In [21]:
# Shift repeat test: each of the `shift_groups` 2-component shift rows must be
# repeated once for every input channel in its group, giving consecutive runs of
# in_channels // shift_groups identical rows.
shift_groups = 3
in_channels = 12
shifts = torch.tensor([[0, 1, 2], [0, 1, 2]]).T
channels_per_group = in_channels // shift_groups
new_shifts = torch.repeat_interleave(shifts, channels_per_group, 0)

assert shifts.shape == (shift_groups, 2)
assert new_shifts.shape == (in_channels, 2)
# Check every group, not just the first: rows g*k .. (g+1)*k - 1 must all equal shifts[g].
# torch.equal gives an exact check for these integer tensors.
for g in range(shift_groups):
    group_rows = new_shifts[g * channels_per_group:(g + 1) * channels_per_group]
    assert torch.equal(group_rows, shifts[g].expand_as(group_rows))
print(shifts)
print(new_shifts)

tensor([[0, 0],
        [1, 1],
        [2, 2]])
tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2]])


In [65]:
# Reshape test
n = 1
in_c = 2
out_c = 3
h = 2
w = 2

# A 2x2 image with 2 channels, one of them is all 1s and the other is all 2s.
x = torch.tensor([1,1,1,1,2,2,2,2]).reshape(n, in_c, h, w)
print(x, x.shape)

# Concat the channels out_c times.
y = torch.cat([x for _ in range(out_c)], 1)
print(y, y.shape)

# Now, should we reshape it to (n, in_c, out_c, h, w) or (n, out_c, in_c, h, w)?
# (n, out_c, in_c, h, w) is the correct shape, as we have (n x out_c) images with in_c channels, not (n x in_c) images with out_c channels.
z = y.reshape(n, out_c, in_c, h, w)
print(z)

# Verify the claim above: with this layout, every out_c slice is an intact copy of x.
assert all(torch.equal(z[:, k], x) for k in range(out_c))

tensor([[[[1, 1],
          [1, 1]],

         [[2, 2],
          [2, 2]]]]) torch.Size([1, 2, 2, 2])
tensor([[[[1, 1],
          [1, 1]],

         [[2, 2],
          [2, 2]],

         [[1, 1],
          [1, 1]],

         [[2, 2],
          [2, 2]],

         [[1, 1],
          [1, 1]],

         [[2, 2],
          [2, 2]]]]) torch.Size([1, 6, 2, 2])
tensor([[[[[1, 1],
           [1, 1]],

          [[2, 2],
           [2, 2]]],


         [[[1, 1],
           [1, 1]],

          [[2, 2],
           [2, 2]]],


         [[[1, 1],
           [1, 1]],

          [[2, 2],
           [2, 2]]]]])


In [76]:
# Shift reshape test: shifts are laid out as (full_groups, shift_groups, 2). Each
# shift-group row is repeated for the in_channels // shift_groups channels it covers,
# then flattened to one row per (full group, input channel).
in_channels = 4
full_groups = 3
shift_groups = 2
# Built as (2, 2, 3) and dimension-reversed to (3, 2, 2). permute(2, 1, 0) is used
# instead of `.T`, which is deprecated for tensors that are not 2-D.
shifts = torch.tensor([[[0, 1, 2], [0, 1, 2]], [[0, 1, 2], [0, 1, 2]]]).permute(2, 1, 0)
print(shifts.shape)

channels_per_group = in_channels // shift_groups
new_shifts = torch.repeat_interleave(shifts, channels_per_group, 1)
print(new_shifts.shape)

assert shifts.shape == (full_groups, shift_groups, 2)
assert new_shifts.shape == (full_groups, in_channels, 2)
# Every channel j of shift group g (channel index g * channels_per_group + j) must
# carry shifts[:, g].
assert torch.equal(
    new_shifts.reshape(full_groups, shift_groups, channels_per_group, 2),
    shifts.unsqueeze(2).expand(-1, -1, channels_per_group, -1),
)

print(shifts.reshape(full_groups * shift_groups, 2))
print(new_shifts.reshape(full_groups * in_channels, 2))

torch.Size([3, 2, 2])
torch.Size([3, 4, 2])
tensor([[0, 0],
        [0, 0],
        [1, 1],
        [1, 1],
        [2, 2],
        [2, 2]])
tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [2, 2],
        [2, 2],
        [2, 2],
        [2, 2]])


In [32]:
# Input is defined here explicitly: the recorded output (channels holding 0..47)
# does not match the `x` left in the kernel by the cells above.
x_tile_demo = torch.arange(48).reshape(1, 3, 4, 4)
# tile(4) repeats the last dim 4x, chunk(4, -1) splits it back into 4 identical
# copies, and cat(..., 1) stacks them on dim 1. Net effect: the whole channel block
# is repeated 4 times, i.e. the same as x_tile_demo.repeat(1, 4, 1, 1).
torch.cat(x_tile_demo.tile(4).chunk(4, -1), 1)

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31]],

         [[32, 33, 34, 35],
          [36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]],

         [[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31]],

         [[32, 33, 34, 35],
          [36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]],

         [[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31]],

         [[32, 33, 34, 35],
          [36, 37, 38, 39],
          [40, 41, 42, 43],
    