In [1]:
import torch
from fft_conv_pytorch import fft_conv, FFTConv1d, FFTConv2d
import time
import matplotlib.pyplot as plt
from torchinfo import summary

In [2]:
myNet = torch.nn.Sequential(
    torch.nn.Conv2d(3,2,5),
    torch.nn.ReLU(),
    torch.nn.Conv2d(2,1,3),
    torch.nn.ReLU()
)

In [3]:
print(myNet)

Sequential(
  (0): Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
)


In [4]:
mySummary = summary(myNet, (3, 50, 50))
print(mySummary)

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 44, 44]               --
├─Conv2d: 1-1                            [2, 46, 46]               152
├─ReLU: 1-2                              [2, 46, 46]               --
├─Conv2d: 1-3                            [1, 44, 44]               19
├─ReLU: 1-4                              [1, 44, 44]               --
Total params: 171
Trainable params: 171
Non-trainable params: 0
Total mult-adds (M): 0.01
Input size (MB): 0.03
Forward/backward pass size (MB): 0.05
Params size (MB): 0.00
Estimated Total Size (MB): 0.08


In [5]:
conv = myNet[0]
print(conv)

Conv2d(3, 2, kernel_size=(5, 5), stride=(1, 1))


In [6]:
mySummary.summary_list[0].output_size[-2:]

[44, 44]

In [7]:
torch.nn.SyncBatchNorm.convert_sync_batchnorm()

TypeError: convert_sync_batchnorm() missing 1 required positional argument: 'module'

In [None]:
def hard_convert_opt_conv2d(module, threshold):
    module_output = module
    for i, layer in enumerate(list(myNet.children())):
        if isinstance(layer, torch.nn.Conv2d):
            if layer.kernel_size[0] > threshold:
                module_output[i] = FFTConv2d(in_channels=layer.in_channels, 
                                             out_channels=layer.out_channels,
                                             kernel_size=layer.kernel_size, 
                                             stride=layer.stride,
                                             padding=layer.padding,
                                             padding_mode=layer.padding_mode,
                                             dilation=layer.dilation,
                                             groups=layer.groups,
                                             bias=True if layer.bias is not None else False)
    del module
    return module_output

In [None]:
myNet_opt = hard_convert_opt_conv2d(myNet, 4)
print(myNet_opt)

Sequential(
  (0): _FFTConv()
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
)


### More candidates for convolution and matmul

In [8]:
import torch
from fft_conv_pytorch import fft_conv, FFTConv1d, FFTConv2d
import time
import matplotlib.pyplot as plt
from torchinfo import summary

In [9]:
### dynamically decide the better operator for convolution, given parameters
def dynamic_fit_conv(kernel_size, input_width, input_height, input_channel=3, output_channel=2):
    signal = torch.randn(3, input_channel, input_width, input_height)
    kernel = torch.randn(output_channel, input_channel, kernel_size, kernel_size)
    bias = torch.randn(output_channel)
    
    test_torch_conv = torch.nn.Conv2d(input_channel, output_channel, kernel_size, bias=True)
    test_torch_conv.weight = torch.nn.Parameter(kernel)
    test_torch_conv.bias = torch.nn.Parameter(bias)
    
    test_fft_conv = FFTConv2d(input_channel, output_channel, kernel_size, bias=True)
    test_fft_conv.weight = torch.nn.Parameter(kernel)
    test_fft_conv.bias = torch.nn.Parameter(bias)
    
    iters = 16
    time0 = time.time()
    for _ in range(iters):
        out = test_torch_conv(signal)
    time1 = time.time()
    
    for _ in range(iters):
        out = test_fft_conv(signal)
    time2 = time.time()
    
    torch_time = (time1 - time0) / iters * 1000
    fft_time = (time2 - time1) / iters * 1000
    
    # 0: torch conv is better; 1: fft conv is better
    return 0 if torch_time < fft_time else 1

### dynamically convert current model to optimized model, accelerating convolution
def dynamic_convert_opt_conv2d(module, org_in_channels, org_in_width, org_in_height):
    # compute input width, height per layer
    ####
    mySummary = summary(module, (org_in_channels, org_in_width, org_in_height)).summary_list
    module_output = module
    for i, layer in enumerate(list(module.children())):
        if isinstance(layer, torch.nn.Conv2d):
            ### fetch internal data sizes
            ### we must feed correct kernel_size, input_size to the 'dynamic_fit_conv' function
            if i == 0:
                input_width, input_height = org_in_width, org_in_height
            else:
                input_width, input_height = mySummary[i-1].output_size[-2:]
            print(layer.kernel_size[0], input_width, input_height, layer.in_channels, layer.out_channels)
            best_candidate = dynamic_fit_conv(kernel_size=layer.kernel_size[0], 
                                         input_width=input_width, input_height=input_height, 
                                         input_channel=layer.in_channels, output_channel=layer.out_channels)
            if best_candidate == 1:
                # org_weight, org_bias = module_output[i].weight, module_output[i].bias
                module_output[i] = FFTConv2d(in_channels=layer.in_channels, 
                                             out_channels=layer.out_channels,
                                             kernel_size=layer.kernel_size, 
                                             stride=layer.stride,
                                             padding=layer.padding,
                                             padding_mode=layer.padding_mode,
                                             dilation=layer.dilation,
                                             groups=layer.groups,
                                             bias=True if layer.bias is not None else False)
                # copy weights and bias
                # module_output[i].weight, module_output[i].bias = org_weight, org_bias
    del module
    return module_output

In [10]:
myNet = torch.nn.Sequential(
    torch.nn.Conv2d(3, 2, 30),
    torch.nn.Conv2d(2, 2, 10)
)
print(myNet)

Sequential(
  (0): Conv2d(3, 2, kernel_size=(30, 30), stride=(1, 1))
  (1): Conv2d(2, 2, kernel_size=(10, 10), stride=(1, 1))
)


In [11]:
image = torch.randn(3, 3, 512, 512)
out = myNet(image)
out.shape

torch.Size([3, 2, 474, 474])

In [12]:
opt_myNet = dynamic_convert_opt_conv2d(myNet, image.shape[1], image.shape[2], image.shape[3])
opt_myNet

30 512 512 3 2
10 474 474 2 2


Sequential(
  (0): _FFTConv()
  (1): Conv2d(2, 2, kernel_size=(10, 10), stride=(1, 1))
)