In [1]:
import torch
from fft_conv_pytorch import fft_conv, FFTConv1d, FFTConv2d
import time
import matplotlib.pyplot as plt
from torchinfo import summary

In [2]:
myNet = torch.nn.Sequential(
    torch.nn.Conv2d(3,2,5),
    torch.nn.ReLU(),
    torch.nn.Conv2d(2,1,3),
    torch.nn.ReLU()
)

In [3]:
print(myNet)

Sequential(
  (0): Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
)


In [6]:
mySummary = summary(myNet, (3, 50, 50))
print(mySummary)

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 44, 44]               --
├─Conv2d: 1-1                            [2, 46, 46]               152
├─ReLU: 1-2                              [2, 46, 46]               --
├─Conv2d: 1-3                            [1, 44, 44]               19
├─ReLU: 1-4                              [1, 44, 44]               --
Total params: 171
Trainable params: 171
Non-trainable params: 0
Total mult-adds (M): 0.01
Input size (MB): 0.03
Forward/backward pass size (MB): 0.05
Params size (MB): 0.00
Estimated Total Size (MB): 0.08


In [10]:
conv = myNet[0]
print(conv)

Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1))


In [8]:
mySummary.summary_list[0].output_size[-2:]

[44, 44]

In [None]:
torch.nn.SyncBatchNorm.convert_sync_batchnorm()

In [26]:
def hard_convert_opt_conv2d(module, threshold):
    module_output = module
    for i, layer in enumerate(list(myNet.children())):
        if isinstance(layer, torch.nn.Conv2d):
            if layer.kernel_size[0] > threshold:
                module_output[i] = FFTConv2d(in_channels=layer.in_channels, 
                                             out_channels=layer.out_channels,
                                             kernel_size=layer.kernel_size, 
                                             stride=layer.stride,
                                             padding=layer.padding,
                                             padding_mode=layer.padding_mode,
                                             dilation=layer.dilation,
                                             groups=layer.groups,
                                             bias=True if layer.bias is not None else False)
    del module
    return module_output

In [34]:
myNet_opt = hard_convert_opt_conv2d(myNet, 4)
print(myNet_opt)

Sequential(
  (0): _FFTConv()
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
)


### More candidates for convolution and matmul