# Octave Convolution Tests

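This notebook tests our implementation of the Octave Convolution module, `OctConv2d`. It checks the output shapes for several channel splits (`alpha_in`, `alpha_out`). It also verifies that, with no low-frequency channels, `OctConv2d` matches a plain `nn.Conv2d`.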

## Setup

In [4]:
import torch
import torch.nn as nn

In [6]:
%load_ext autoreload
%autoreload 2

from modules import OctConv2d

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Testing OctConv Behavior

In [7]:
# Test output shapes

# alpha_in = alpha_out = 0.5: the 16 channels split evenly (8 high, 8 low) on both
# sides, and the low-frequency tensors are at half the spatial resolution.
# The expected sizes (32 -> 31, 16 -> 15) correspond to a 2x2 kernel applied with
# stride 1 and no padding.
octconv = OctConv2d(16, 16, (2, 2), 0.5, 0.5)
x_h = torch.randn(128, 8, 32, 32)
x_l = torch.randn(128, 8, 16, 16)

y_h, y_l = octconv(x_h, x_l)
assert (y_h.shape, y_l.shape) == ((128, 8, 31, 31), (128, 8, 15, 15))

In [8]:
# Test output shapes with alpha_in != alpha_out

# The 16 inputs are split 8 high / 8 low (alpha_in = 0.5). Only a quarter of the
# 16 outputs go to the low-frequency branch (alpha_out = 0.25): 12 high, 4 low.
octconv = OctConv2d(16, 16, (2, 2), 0.5, 0.25)
x_h = torch.randn(128, 8, 32, 32)
x_l = torch.randn(128, 8, 16, 16)

y_h, y_l = octconv(x_h, x_l)
assert (y_h.shape, y_l.shape) == ((128, 12, 31, 31), (128, 4, 15, 15))

In [9]:
# Test output shapes with alpha_in != alpha_out and in_channels != out_channels

# The 16 inputs are split 8 high / 8 low (alpha_in = 0.5). The 32 outputs are
# split 24 high / 8 low (alpha_out = 0.25).
octconv = OctConv2d(16, 32, (2, 2), 0.5, 0.25)
x_h = torch.randn(128, 8, 32, 32)
x_l = torch.randn(128, 8, 16, 16)

y_h, y_l = octconv(x_h, x_l)
assert (y_h.shape, y_l.shape) == ((128, 24, 31, 31), (128, 8, 15, 15))

In [13]:
# Test output shapes with alpha_in = alpha_out = 0

# With no low-frequency channels on either side, all 16 inputs are high-frequency.
# The low-frequency input is an empty (0-channel) tensor, and the layer is
# expected to return None in place of a low-frequency output.
octconv = OctConv2d(16, 32, (2, 2), 0, 0)
x_h = torch.randn(128, 16, 32, 32)
x_l = torch.randn(128, 0, 16, 16)

y_h, y_l = octconv(x_h, x_l)
assert (y_h.shape, y_l) == ((128, 32, 31, 31), None)

In [24]:
# Test that OctConv2d behaves like Conv2d when alpha_in = alpha_out = 0

# With no low-frequency channels, OctConv2d is expected to reduce to its
# high->high convolution (`conv_hh`). Give a plain Conv2d the very same parameter
# objects (shared, not copied), so the two layers should compute the same function.
oc = OctConv2d(3, 32, (2, 2), 0, 0)
conv = nn.Conv2d(3, 32, (2, 2))
conv.weight = oc.conv_hh.weight
conv.bias = oc.conv_hh.bias

input_h = torch.randn(128, 3, 32, 32)
input_l = None  # no low-frequency input when alpha_in = 0

# Only forward values are compared, so skip building autograd graphs.
with torch.no_grad():
    output_h, output_l = oc(input_h, input_l)
    output_conv = conv(input_h)

assert output_h.shape == output_conv.shape, "OctConv and Conv have different output shapes"
# Compare with a floating-point tolerance rather than exact equality. An
# equivalent but differently ordered computation (e.g. summing zero-sized
# branches) must not make the test fail spuriously.
assert torch.allclose(output_h, output_conv), "OctConv and Conv have different outputs"