References
* [how-do-vits-work](https://colab.research.google.com/github/xxxnell/how-do-vits-work/blob/transformer/fourier_analysis.ipynb#scrollTo=a7350a20)
* [Fourier Convolutions with Kernel Sizes of 1024x1024 and Larger](https://towardsdatascience.com/fourier-cnns-with-kernel-sizes-of-1024x1024-and-larger-29f513fd6120)
* [Fast-CNN](https://github.com/pushkar-khetrapal/Fast-CNN/blob/master/Convolution%20vs%20FFT%20explaination.ipynb)

In [1]:
import numpy as np 
from PIL import Image 

import torch 
import torch.nn as nn 
import torchvision.transforms as T 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def fourier(x):
    """Return the log-magnitude of the 2-D FFT of ``x``.

    The FFT is taken over the last two dimensions (``torch.fft.fft2``
    default). A constant 1e-6 is added to the magnitude before the log so
    that zero-magnitude frequency bins give a finite value instead of -inf.
    """
    magnitude = torch.abs(torch.fft.fft2(x))
    return torch.log(magnitude + 1e-6)

def shift(x):
    """Centre the zero-frequency term of a Fourier-transformed feature map.

    Rolls the last two axes of a 4-D tensor (e.g. batch, channel, height,
    width) by half their length. This moves index 0 to the middle, like
    ``fftshift`` restricted to dims 2 and 3. Passing a tensor that is not
    4-D raises ``ValueError`` from the shape unpacking.
    """
    _, _, rows, cols = x.shape
    return x.roll(shifts=(rows // 2, cols // 2), dims=(2, 3))

In [3]:
# Load the test image and convert it to 8-bit grayscale ("L" mode), so the
# rest of the notebook works with a single intensity channel.
# The context manager closes the source file once the converted copy exists;
# `img` is an independent image, so it stays usable after the block.
with Image.open("Lenna.png") as src:
    img = src.convert('L')
img.show()

# Spatial Convolution 

In [4]:
def convolution(img_slice, kernel):
    """Return the weighted sum of one image window under ``kernel``.

    The window and the kernel are multiplied element by element and summed.
    The kernel is *not* flipped, so strictly this is one step of a
    cross-correlation. For a symmetric kernel the two coincide.

    Args:
        img_slice: Image patch; must have the same shape as ``kernel``.
        kernel: Filter weights.

    Returns:
        The NumPy scalar ``sum(img_slice * kernel)``.

    Raises:
        ValueError: If ``img_slice`` and ``kernel`` differ in shape. Without
            this check, NumPy broadcasting would silently accept e.g. a
            ``(1, 3)`` patch against a ``(3, 3)`` kernel.
    """
    if np.shape(img_slice) != np.shape(kernel):
        raise ValueError(
            f"image slice shape {np.shape(img_slice)} does not match "
            f"kernel shape {np.shape(kernel)}"
        )
    return np.sum(np.multiply(img_slice, kernel))

In [5]:
# Copy the grayscale pixels into a NumPy array. Unpacking the shape into
# exactly two names fails loudly if the array is not 2-D.
img_np = np.array(img)
H, W = np.shape(img_np)

In [6]:
# 3x3 edge-detection kernel: eight -1 neighbours around a centre weight of 8.
# The weights sum to zero, so uniform regions produce a zero response.
# Reference: https://medium.com/@kgerding/image-kernels-2f8a36087b75
# NOTE: this name shadows the built-in `filter`. It is kept because other
# cells in this notebook refer to it.
filter = np.full((3, 3), -1)
filter[1, 1] = 8

In [7]:
# Valid-mode sliding-window pass over the grayscale image. The window never
# leaves the image, so the output shrinks by (kernel size - 1) along each axis.
# The window size is read from the kernel instead of being hard-coded to 3,
# so a differently sized `filter` works without editing this loop.
# The kernel is applied without flipping (cross-correlation); this matches
# true convolution here because the edge kernel is symmetric.
kh, kw = filter.shape
convoled_img = np.zeros(shape=(H - kh + 1, W - kw + 1))

for i in range(H - kh + 1):
    for j in range(W - kw + 1):
        img_stem = img_np[i:i + kh, j:j + kw]
        convoled_img[i, j] = convolution(img_stem, filter)

In [8]:
# Wrap the float64 spatial-convolution result in a PIL image and open it in
# the default viewer.
# NOTE(review): the edge kernel has negative weights, so the array can hold
# negative (and >255) responses. How these map to displayable 8-bit values
# is left to Pillow's display conversion — consider clipping/rescaling to
# uint8 explicitly if the picture looks wrong.
convoled_img_pil = Image.fromarray(convoled_img)
convoled_img_pil.show()

# FFT Convolution 

In [9]:
def fourier2d(img_np, eps=1e-6):
    """Return the centred log-magnitude spectrum of a 2-D array.

    The 2-D FFT is computed with ``np.fft.fft2``. ``np.fft.fftshift`` then
    moves the low-frequency components to the centre, and the magnitude is
    mapped to ``20 * ln(|F| + eps)``. This is the natural log, so the values
    are a scaled log-magnitude rather than decibels.

    Args:
        img_np: 2-D array (e.g. a grayscale image).
        eps: Small offset added to the magnitude before the log. Without it,
            zero-magnitude bins give ``-inf`` (and a RuntimeWarning). The
            default 1e-6 matches the offset in the torch-based ``fourier``.

    Returns:
        Float array with the same shape as ``img_np``.
    """
    f = np.fft.fft2(img_np)
    fshift = np.fft.fftshift(f)  # low-frequency components are at the centre
    magnitude_spectrum = 20 * np.log(np.abs(fshift) + eps)

    return magnitude_spectrum

In [10]:
# Centred log-magnitude spectrum of the grayscale image (see fourier2d).
out = fourier2d(img_np)

# Display the spectrum as a PIL image in the default viewer.
# NOTE(review): `out` is a float array, not 0-255 uint8; how it is rendered
# depends on Pillow's display conversion — rescale to uint8 if needed.
Image.fromarray(out).show()

In [11]:
# Zero-pad the 3x3 kernel by 111 on every side of both axes (-> 225x225)
# and take its 2-D FFT.
# NOTE(review): 225x225 is not derived from the image size (H, W). For FFT
# convolution, the kernel spectrum is normally multiplied element-wise with
# the image spectrum, which requires matching (padded) shapes — confirm the
# intended padding.
f2 = np.pad(filter, 111)
f21 = np.fft.fftn(f2)
print(f21.shape)

(225, 225)
