In [1]:
import torch
from fft_conv_pytorch import fft_conv, FFTConv1d, FFTConv2d
import torch.nn.functional as F
import time
import numpy as np

#### Faster than direct convolution for large kernels.
#### Much slower than direct convolution for small kernels

### 0. Play with 1D cases

In [2]:
# Create dummy data.  
#     Data shape: (batch, channels, length)
#     Kernel shape: (out_channels, in_channels, kernel_size)
#     Bias shape: (out channels, )
# For ordinary 1D convolution, simply set batch=1.
kernal_size = 3
signal = torch.randn(3, 3, 1000)
kernel = torch.randn(2, 3, kernal_size)
bias = torch.randn(2)

In [3]:
# Functional execution.  (Easiest for generic use cases.)
# out = fft_conv(signal, kernel, bias=bias)
start_time1 = time.time()
# Object-oriented execution.  (Requires some extra work, since the 
# defined classes were designed for use in neural networks.)
my_fft_conv = FFTConv1d(3, 2, kernal_size, bias=True)
my_fft_conv.weight = torch.nn.Parameter(kernel)
my_fft_conv.bias = torch.nn.Parameter(bias)
out = my_fft_conv(signal)
print(out.shape)
end_time1 = time.time()
print(end_time1-start_time1)

torch.Size([3, 2, 998])
0.046999454498291016


In [4]:
start_time2 = time.time()
conv = torch.nn.Conv1d(3, 2, kernal_size, bias=True)
conv.weight = torch.nn.Parameter(kernel)
conv.bias = torch.nn.Parameter(bias)
out = conv(signal)
end_time2 = time.time()
print(end_time2-start_time2)

0.024333715438842773


### 1. Brute Forced FFT-OVA

In [5]:
n = signal.shape[2] // kernal_size
L = signal.shape[2] // n
chunks = []
for i in range(n-1):
    chunks.append(signal[:, :, i*L : (i+1)*L])
chunks.append(signal[:, :, (n-1)*L:])

start_time1 = time.time()
# Object-oriented execution.  (Requires some extra work, since the 
# defined classes were designed for use in neural networks.)
outs = [None] * n
fft_conv = [None] * n
for i in range(n):
    fft_conv[i] = FFTConv1d(3, 2, kernal_size, bias=True)
    fft_conv[i].weight = torch.nn.Parameter(kernel)
    fft_conv[i].bias = torch.nn.Parameter(bias)
    outs[i] = fft_conv[i](chunks[i])


end_time1 = time.time()
out = torch.concat(outs, dim=2)
print(out.shape)

print(end_time1-start_time1)

torch.Size([3, 2, 334])
0.2180337905883789


### 2. 2D Blocked FFT-CONV

In [6]:
# Create 2d dummy data.  
#     Data shape: (batch, channels, height, width)
#     Kernel shape: (out_channels, in_channels, kernel_size)
#     Bias shape: (out channels, )
# For ordinary 1D convolution, simply set batch=1.
kernal_size = 3
signal2d = torch.randn(3, 3, 32, 64)
kernel2d = torch.randn(2, 3, kernal_size, kernal_size)
bias = torch.randn(2)

#### 2.1. original 2d fft-conv

In [7]:
# original fft-conv
start_time1 = time.time()
# Object-oriented execution.  (Requires some extra work, since the 
# defined classes were designed for use in neural networks.)
my_fft_conv2d = FFTConv2d(3, 2, kernal_size, bias=True)
my_fft_conv2d.weight = torch.nn.Parameter(kernel2d)
my_fft_conv2d.bias = torch.nn.Parameter(bias)
out = my_fft_conv2d(signal2d)
print(out.shape)
end_time1 = time.time()
print(end_time1-start_time1)

torch.Size([3, 2, 30, 62])
0.007521152496337891


#### fft-ova-conv is the most popular fast convolution
#### however, fft-split-conv, outperforms fft-ova-conv in computation complexity

#### 2.2. fft-split-conv

In [14]:
# fft-split-conv
# data preparation
blocksize = 6
m, n = signal2d.shape[2:]
blockdims = [m // blocksize + 1, n // blocksize + 1]
mpad, npad = blockdims[0] * blocksize, blockdims[1] * blocksize

# pad to proper size
signal2d_pad = F.pad(signal2d, pad=(0, npad - n, 0, mpad - m))
print(signal2d_pad.shape)
# cut padded data into sub matrices
# parallel to optimize
chunks = [[None] * blockdims[1] for _ in range(blockdims[0])]
for i in range(blockdims[0]):
    x1, x2 = i * blocksize, (i+1) * blocksize
    for j in range(blockdims[1]):
        y1, y2 = j * blocksize, (j+1) * blocksize
        chunks[i][j] = (signal2d_pad[:, :, x1 : x2, y1 : y2])
        
# compute fft conv per block
# parallel to optimize
outs = None
fft_conv_2d = [[None] * blockdims[1] for _ in range(blockdims[0])]
for i in range(blockdims[0]):
    for j in range(blockdims[1]):
        fft_conv_2d[i][j] = FFTConv2d(3, 2, kernal_size, bias=True)
        fft_conv_2d[i][j].weight = torch.nn.Parameter(kernel2d)
        fft_conv_2d[i][j].bias = torch.nn.Parameter(bias)
        res = fft_conv_2d[i][j](chunks[i][j])
        if outs == None:
            outs = torch.empty((blockdims[0], blockdims[1], 3, 2, res.shape[-2], res.shape[-1]))
        outs[i, j, :, :, :, :] = res

out.shape

torch.Size([3, 3, 36, 66])


torch.Size([3, 2, 30, 62])

#### 2.3 fft-ova-conv toy model

In [None]:
# fft-ova-conv
# data preparation
m, n = signal2d.shape[2:]
blockdims = [m // kernal_size + 1, n // kernal_size + 1]
mpad, npad = blockdims[0] * kernal_size, blockdims[1] * kernal_size

# pad to proper size
signal2d_pad = F.pad(signal2d, pad=(0, npad - n, 0, mpad - m))

# cut padded data into sub matrices
# parallel to optimize
chunks = [[None] * blockdims[1] for _ in range(blockdims[0])]
for i in range(blockdims[0]):
    x1, x2 = i * kernal_size, (i+1) * kernal_size
    for j in range(blockdims[1]):
        y1, y2 = j * kernal_size, (j+1) * kernal_size
        chunks[i][j] = (signal2d_pad[:, :, x1 : x2, y1 : y2])


In [16]:
# toy example: 4x4 signal
sig = torch.randn(4, 4)
filter = torch.randn(2, 2)

blocks = torch.empty(2, 2, 2, 2)
blocks[0, 0, :, :] = sig[:2, :2]
blocks[1, 0, :, :] = sig[2:, :2]
blocks[0, 1, :, :] = sig[:2, 2:]
blocks[1, 1, :, :] = sig[2:, 2:]

outs = torch.empty(2, 2, 2, 2)
outs[0, 0, :, :] = blocks[0, 0, :, :] * filter
outs[1, 0, :, :] = blocks[1, 0, :, :] * filter
outs[0, 1, :, :] = blocks[0, 1, :, :] * filter
outs[1, 1, :, :] = blocks[1, 1, :, :] * filter

print(outs)

tensor([[[[ 0.2886, -0.3243],
          [-0.3126, -0.3285]],

         [[-0.9751, -0.2347],
          [ 0.8087, -3.8556]]],


        [[[-0.7202, -0.1184],
          [ 0.5144,  0.7840]],

         [[ 0.6198,  0.1922],
          [-0.0352,  2.5967]]]])


In [17]:
import copy
b00 = copy.deepcopy(outs[0, 0, :, :])
b10 = copy.deepcopy(outs[1, 0, :, :])
b01 = copy.deepcopy(outs[0, 1, :, :])
b11 = copy.deepcopy(outs[1, 1, :, :])

In [24]:
uf_row1 = torch.cat((b00.view(1, 4, 1), b01.view(1, 4, 1)), dim=2)
row1 = F.fold(uf_row1, (2,3), kernel_size=(2,2), stride=1)
uf_row2 = torch.cat((b10.view(1, 4, 1), b11.view(1, 4, 1)), dim=2)
row2 = F.fold(uf_row2, (2,3), kernel_size=(2,2), stride=1)
print(row1)
print(row2)

tensor([[[[ 0.2886, -1.2995, -0.2347],
          [-0.3126,  0.4802, -3.8556]]]])
tensor([[[[-0.7202,  0.5014,  0.1922],
          [ 0.5144,  0.7488,  2.5967]]]])


In [21]:
uf_res = torch.cat((row1.view(1, 6, 1), row2.view(1, 6, 1)), dim=2)
res = F.fold(uf_res, (3,3), kernel_size=(2,3), stride=1)
res

tensor([[[[ 0.2886, -1.2995, -0.2347],
          [-1.0328,  0.9816, -3.6635],
          [ 0.5144,  0.7488,  2.5967]]]])