# Octave Convolution Tests

We can use this notebook to test our implementation of the OctConv module.

The OctConv module itself is defined under `modules.py`

## Setup

In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [75]:
%load_ext autoreload
%autoreload 2

from modules import OctConv2d

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
USE_GPU = True

dtype = torch.float32

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cpu


## Testing OctConv Behavior

Maybe we can move this to a `test.py` file eventually

In [85]:
# Test output shapes for 1x1 convolution

oc = OctConv2d(16, 16, (1, 1), 0.5, 0.5)
input_h = torch.randn(128, 8, 32, 32)
input_l = torch.randn(128, 8, 16, 16)

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 8, 32, 32), "Incorrect high-frequency output shape for OctConv2d"
assert output_l.shape == (128, 8, 16, 16), "Incorrect low-frequency output shape for OctConv2d"

In [86]:
# Test output shapes with alpha_in != alpha_out

oc = OctConv2d(16, 16, (1, 1), 0.5, 0.25)
input_h = torch.randn(128, 8, 32, 32)
input_l = torch.randn(128, 8, 16, 16)

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 12, 32, 32), "Incorrect high-frequency output shape for OctConv2d"
assert output_l.shape == (128, 4, 16, 16), "Incorrect low-frequency output shape for OctConv2d"

In [87]:
# Test output shapes with alpha_in != alpha_out and in_channels != out_channels

oc = OctConv2d(16, 32, (1, 1), 0.5, 0.25)
input_h = torch.randn(128, 8, 32, 32)
input_l = torch.randn(128, 8, 16, 16)

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 24, 32, 32), "Incorrect high-frequency output shape for OctConv2d"
assert output_l.shape == (128, 8, 16, 16), "Incorrect low-frequency output shape for OctConv2d"

In [88]:
# Test output shapes with alpha_in = alpha_out = 0

oc = OctConv2d(16, 32, (1, 1), 0, 0)
input_h = torch.randn(128, 16, 32, 32)
input_l = torch.randn(128, 0, 16, 16)

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 32, 32, 32), "Incorrect high-frequency output shape for OctConv2d"
assert output_l is None, "Incorrect low-frequency output shape for OctConv2d"

In [89]:
# Test output shapes with alpha_in = 0, alpha_out > 0 (imitates first layer)

oc = OctConv2d(16, 32, (1, 1), 0, 0.25)
input_h = torch.randn(128, 16, 32, 32)
input_l = None

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 24, 32, 32), "Incorrect high-frequency output shape for OctConv2d"
assert output_l.shape == (128, 8, 16, 16), "Incorrect low-frequency output shape for OctConv2d"

In [90]:
# Test output shapes with padding and stride

oc = OctConv2d(16, 32, (3, 3), 0.5, 0.5, stride=1, padding=1)
input_h = torch.randn(128, 8, 32, 32)
input_l = torch.randn(128, 8, 16, 16)

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 16, 32, 32), "Shape mismatch for stride=1, padding=1"
assert output_l.shape == (128, 16, 16, 16), "Shape mismatch for stride=1, padding=1"

In [91]:
# Test output shapes with stride to downsample

oc = OctConv2d(16, 32, (2, 2), 0.5, 0.5, stride=2, padding=0)
input_h = torch.randn(128, 8, 32, 32)
input_l = torch.randn(128, 8, 16, 16)

output_h, output_l = oc(input_h, input_l)
assert output_h.shape == (128, 16, 16, 16), "Shape mismatch for stride=2, padding=0"
assert output_l.shape == (128, 16, 8, 8), "Shape mismatch for stride=2, padding=0"

In [92]:
# Test that OctConv2d behaves like Conv2d when alpha_in = alpha_out = 0

oc = OctConv2d(3, 32, (3, 3), 0, 0, stride=1, padding=1)
conv = nn.Conv2d(3, 32, (3, 3), stride=1, padding=1)
input_h = torch.randn(128, 3, 32, 32)
input_l = None

conv.weight = oc.conv_hh.weight
conv.bias = oc.conv_hh.bias

output_h, output_l = oc(input_h, input_l)
output_conv = conv(input_h)
assert output_h.shape == output_conv.shape, "OctConv2d and Conv2d have different output shapes"
assert torch.all(torch.eq(output_h, output_conv)), "OctConv2d and Conv2d have different outputs"

## Building an Octconv Network

Lol right now this code is really specific and not very flexible - we can write code to initialize more general Octconv networks

In [93]:
# We define a flatten method here for convenience (taken from Pytorch notebook assignment 2)
def flatten(x):
    N = x.shape[0] # read in N, C, H, W
    return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image

In [94]:
# We can't use nn.sequential because sequential only takes one input at each stage
# We have two outputs/inputs at each layer, for the low-frequency and high-frequency channels

# We could store the layers more neatly using ModuleDict or ModuleList

class FourLayerOctConvNet(nn.Module):
    """
    Four layer octconv net for testing. Assumes inputs are of size 3 x 32 x 32.
    
    Architecture: [Octconv -> ReLU -> OctConv -> ReLU -> Pool]*2 -> FC
    """
    
    def __init__(self, alpha, F, D_out):
        """
        Initialize a four-layer Octconv network.
        
        alpha: float between 0 and 1 representing our alpha parameter
        F: integer representing the number of filters in each hidden layer
        D_out: the length of the output vector
        """
        super().__init__()
        self.alpha = alpha
        self.oc1 = OctConv2d(3, F, (3, 3), 0, self.alpha, stride=1, padding=1)
        self.oc2 = OctConv2d(F, F, (3, 3), self.alpha, self.alpha, stride=1, padding=1)
        self.oc3 = OctConv2d(F, F, (3, 3), self.alpha, self.alpha, stride=1, padding=1)
        self.oc4 = OctConv2d(F, F, (3, 3), self.alpha, 0, stride=1, padding=1)
        self.fc1 = torch.nn.Linear(F * 8 * 8, D_out)
        # TODO: Do we need to initialize the weights of oc layers under the OctConv module?
    
    def forward(self, x):
        x_h, x_l = self.oc1(x, None) # alpha_in = 0
        x_h, x_l = F.relu(x_h), F.relu(x_l)
        x_h, x_l = self.oc2(x_h, x_l)
        x_h, x_l = F.relu(x_h), F.relu(x_l)
        x_h, x_l = F.max_pool2d(x_h, (2, 2), stride=2), F.max_pool2d(x_l, (2, 2), stride=2)
        
        x_h, x_l = self.oc3(x_h, x_l)
        x_h, x_l = F.relu(x_h), F.relu(x_l)
        x_h, _ = self.oc4(x_h, x_l) # alpha_out = 0
        x_h = F.relu(x_h)
        x_h = F.max_pool2d(x_h, (2, 2), stride=2)
        x_h = flatten(x_h)
        
        out = self.fc1(x_h)
        return out

In [119]:
# Initialize random training data
N, C, H, W, D_out = 10, 3, 32, 32, 10
x = torch.randn(N, C, H, W, dtype=dtype, device=device)
y = torch.randint(0, D_out, (D_out, ), dtype=dtype, device=device) # Random correct indices

In [115]:
# Create our model
model = FourLayerOctConvNet(0.25, 32, 10)
list(model.modules())[:1]

[FourLayerOctConvNet(
   (oc1): OctConv2d(
     (conv_hh): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv_hl): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (pool): AvgPool2d(kernel_size=(2, 2), stride=2, padding=0)
   )
   (oc2): OctConv2d(
     (conv_hh): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv_ll): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv_lh): Conv2d(8, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (upsample): Upsample(scale_factor=2, mode=nearest)
     (conv_hl): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (pool): AvgPool2d(kernel_size=(2, 2), stride=2, padding=0)
   )
   (oc3): OctConv2d(
     (conv_hh): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv_ll): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv_lh): Conv2d(8, 24, kernel_size=(3, 3), stride=(1, 1), padding=(

In [116]:
# Try one forward pass
y_pred = model(x)
y_pred[:1]

tensor([[-0.0476,  0.0253,  0.0508, -0.0035,  0.1010,  0.0477, -0.0843, -0.0289,
         -0.0328,  0.0324]], grad_fn=<SliceBackward>)

In [118]:
# Overfit on our fake dataset
# This training code shamelessy adapted from Justin Johnson's Pytorch examples
model = model.to(device=device)
x = x.to(device=device, dtype=dtype)
y = y.to(device=device, dtype=torch.long)

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    y_pred = model(x)
    
    loss = F.cross_entropy(y_pred, y)
    print(t, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 0.03386716917157173
1 0.030994605273008347
2 0.028363799676299095
3 0.02621612511575222
4 0.024231623858213425
5 0.022223185747861862
6 0.020702075213193893
7 0.019034957513213158
8 0.017602253705263138
9 0.016336917877197266
10 0.015036201104521751
11 0.01393346767872572
12 0.01289587002247572
13 0.011908340267837048
14 0.011047649197280407
15 0.01022186316549778
16 0.009447669610381126
17 0.008777523413300514
18 0.008153152652084827
19 0.007551670074462891
20 0.00701484689489007
21 0.0065367696806788445
22 0.006086921785026789
23 0.005668449215590954
24 0.005291747860610485
25 0.0049461363814771175
26 0.0046257018111646175
27 0.0043357848189771175
28 0.004071235656738281
29 0.0038236617110669613
30 0.0035923004616051912
31 0.00338325509801507
32 0.00319499964825809
33 0.0030204772483557463
34 0.002855491591617465
35 0.00270423898473382
36 0.00256519322283566
37 0.00243721017614007
38 0.0023181915748864412
39 0.00220661167986691
40 0.0021038055419921875
41 0.0020080567337572575
42 0

321 5.073547436040826e-05
322 5.035400317865424e-05
323 4.978180004400201e-05
324 4.9591064453125e-05
325 4.882812572759576e-05
326 4.863739013671875e-05
327 4.863739013671875e-05
328 4.844665454584174e-05
329 4.825591895496473e-05
330 4.787445141118951e-05
331 4.76837158203125e-05
332 4.730224463855848e-05
333 4.653930591302924e-05
334 4.653930591302924e-05
335 4.653930591302924e-05
336 4.615783836925402e-05
337 4.596710277837701e-05
338 4.539489600574598e-05
339 4.539489600574598e-05
340 4.520416405284777e-05
341 4.520416405284777e-05
342 4.444122168933973e-05
343 4.405975414556451e-05
344 4.38690185546875e-05
345 4.38690185546875e-05
346 4.367828296381049e-05
347 4.329681542003527e-05
348 4.291534423828125e-05
349 4.291534423828125e-05
350 4.272460864740424e-05
351 4.253387305652723e-05
352 4.215240551275201e-05
353 4.1961669921875e-05
354 4.1961669921875e-05
355 4.158019874012098e-05
356 4.119873119634576e-05
357 4.100799560546875e-05
358 4.100799560546875e-05
359 4.062652442371473