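A reimplementation of "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains" (Tancik et al., NeurIPS 2020), following the PyTorch implementation at https://github.com/ndahlquist/pytorch-fourier-feature-networks/tree/master. We represent a 2D image with a coordinate-based MLP (pixel coordinates in, RGB colors out) and verify the effectiveness of Fourier features by comparing it against the same MLP fed raw coordinates.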

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms.functional import to_pil_image

from tqdm import tqdm
from einops import rearrange

def viz_image(pt_img: torch.Tensor):
    """Display an image tensor inline in the notebook.

    Args:
        pt_img: image tensor accepted by ``to_pil_image``. Every caller in
            this notebook passes a (C, H, W) float tensor with values expected
            in [0, 1]. It may live on any device and may require grad (e.g. a
            model output taken straight from the training loop).
    """
    # Detach from autograd and move to the CPU so GPU / grad-tracking model
    # outputs can be shown without extra work at the call site.
    img = pt_img.detach().cpu()
    if img.is_floating_point():
        # Saturate out-of-range floats so the 8-bit conversion inside
        # to_pil_image does not depend on how out-of-range values are cast.
        # Integer (e.g. uint8) images are passed through untouched.
        img = img.clamp(0, 1)
    # `display` is not defined in this file; presumably the IPython/Jupyter
    # builtin, so this helper only works inside a notebook.
    display(to_pil_image(img))

    
# Load the target image as 3-channel RGB, convert to float32 scaled by 1/255,
# add a leading batch dimension and resize it to 256x256 (bilinear).
input_image = read_image('misuzu.png', ImageReadMode.RGB).to(torch.float32).div(255)
input_image = F.interpolate(input_image[None], size=(256, 256), mode='bilinear')
viz_image(input_image[0])


In [None]:
class MLP(nn.Module):
    """Per-pixel MLP implemented with 1x1 convolutions.

    Each pixel of an (N, in_c, H, W) input is mapped to ``out_c`` channels
    squashed by a sigmoid. The network has three hidden stages of
    Conv(1x1) -> ReLU -> BatchNorm2d, each ``hiden_states`` wide, followed by
    a 1x1 output convolution and a sigmoid.

    Note: BatchNorm2d normalizes over the batch and spatial dimensions, so in
    training mode each pixel's output depends on statistics of the whole
    input grid, not only on that pixel.

    Args:
        in_c: number of input channels (2 for raw (h, w) coordinates).
        out_c: number of output channels (3 for RGB).
        hiden_states: hidden width (misspelled name kept for compatibility).
    """

    def __init__(self, in_c, out_c=3, hiden_states=256):
        super().__init__()
        layers = []
        width = in_c
        for _ in range(3):
            layers.extend([
                nn.Conv2d(width, hiden_states, 1),
                nn.ReLU(),
                nn.BatchNorm2d(hiden_states),
            ])
            width = hiden_states
        layers.extend([nn.Conv2d(width, out_c, 1), nn.Sigmoid()])
        # Same attribute name and layer order as before, so state_dict keys
        # (mlp.0.weight ... mlp.9.bias) are unchanged.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the per-pixel MLP to ``x``."""
        return self.mlp(x)

In [None]:
H, W = input_image.shape[2:]

# Normalized pixel coordinates in [0, 1], endpoints included.
h_coord = torch.linspace(0, 1, H)
w_coord = torch.linspace(0, 1, W)
# Matrix ('ij') indexing: channel 0 holds the row (h) coordinate and channel 1
# the column (w) coordinate. Passing `indexing` explicitly keeps the previous
# default behaviour while avoiding torch.meshgrid's deprecation warning.
grid = torch.stack(torch.meshgrid(h_coord, w_coord, indexing='ij'), dim=-1)
grid = grid.permute(2, 0, 1).unsqueeze(0)  # (1, 2, H, W)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Baseline: fit the image with an MLP fed the raw 2-channel coordinate grid.
model = MLP(2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
n_loops = 400
input_image, grid = input_image.to(device), grid.to(device)
for epoch in tqdm(range(n_loops)):
    output = model(grid)
    loss = F.l1_loss(output, input_image)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Snapshot (taken before this iteration's update) every 100 iterations
    # and after the final iteration.
    if epoch == n_loops - 1 or not epoch % 100:
        viz_image(output[0])
        print(loss.item())

In [None]:
class FourierFeature(nn.Module):
    """Random Fourier feature encoding of per-pixel coordinates.

    A fixed (non-trainable) matrix ``B = scale * randn(in_c, out_c // 2)``
    projects each pixel's coordinate vector ``v``. The output stacks
    ``sin(2*pi*v@B)`` followed by ``cos(2*pi*v@B)`` along the channel axis,
    giving ``out_c`` channels.

    Args:
        in_c: number of coordinate channels in the input.
        out_c: number of output channels; must be even because half are
            sines and half are cosines.
        scale: multiplier on the standard-normal frequencies (their standard
            deviation); larger values give higher frequencies.

    Raises:
        ValueError: if ``out_c`` is odd. Without this check the encoder would
            silently produce ``out_c - 1`` channels, which then mismatches a
            model built for ``out_c`` inputs.
    """

    def __init__(self, in_c, out_c, scale):
        super().__init__()
        if out_c % 2:
            raise ValueError(f"out_c must be even (sin + cos halves), got {out_c}")
        fourier_basis = torch.randn(in_c, out_c // 2) * scale
        # A buffer, not a Parameter: it follows .to(device) and is saved in the
        # state_dict, but is not handed to the optimizer via parameters().
        self.register_buffer('_fourier_basis', fourier_basis)

    def forward(self, x):
        """Encode ``x`` of shape (N, in_c, H, W) into (N, out_c, H, W)."""
        # Channels-last so the projection is a plain matmul over the
        # coordinate axis, then back to channels-first.
        proj = x.permute(0, 2, 3, 1) @ self._fourier_basis
        proj = (2 * torch.pi * proj).permute(0, 3, 1, 2)
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=1)
        
feature_length = 256
model = MLP(feature_length).to(device)
fourier_feature = FourierFeature(2, feature_length, 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
n_loops = 400

# The encoder's basis is a fixed buffer (only `model` is optimized) and the
# grid never changes, so the encoded input is identical every iteration:
# compute it once instead of inside the loop.
x = fourier_feature(grid)
for epoch in tqdm(range(n_loops)):
    output = model(x)
    loss = F.l1_loss(output, input_image)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0 or epoch == n_loops - 1:
        viz_image(output[0])
        print(loss.item())

# Keep the final reconstruction as a fixed reference for later comparison;
# detach it so it does not keep a reference to the autograd graph.
prev_output = output.detach()

In [None]:
N, C, H, W = grid.shape

# Homogeneous 2D affine transform: rotate by pi/8 about the coordinate origin
# (the (0, 0) corner of the grid), then translate by 50 / H (roughly 50
# pixels). Channel 0 of `grid` is the row (h) coordinate, so `tx` actually
# shifts along the image's vertical axis.
tx = 50 / H
ty = 0
theta = torch.tensor(torch.pi * 1 / 8)
affine_matrix = torch.tensor([
    [torch.cos(theta), -torch.sin(theta), tx],
    [torch.sin(theta), torch.cos(theta), ty],
    [0, 0, 1],
]).to(device)

# Append a ones channel to make every pixel's coordinates homogeneous, apply
# the transform to each (h, w, 1) row vector, then drop the ones channel.
grid_ext = torch.ones(N, 3, H, W, device=device)
grid_ext[:, :2] = grid  # slice assignment already copies; no clone needed
grid_ext = grid_ext.permute(0, 2, 3, 1) @ affine_matrix.T
grid_ext = grid_ext.permute(0, 3, 1, 2)[:, :2]

# Inference only: don't build an autograd graph. The model is still in
# training mode, so BatchNorm normalizes with this batch's statistics.
with torch.no_grad():
    x = fourier_feature(grid_ext)
    output = model(x)
viz_image(output[0])

In [None]:
def aff_transform(model, x, aff):
    """Fourier-encode coordinates as if an affine transform had been applied.

    Instead of transforming the coordinates ``v`` and then projecting them,
    the transform is folded into the Fourier basis ``B``:
    ``B^T (A v + t) = (B^T A) v + B^T t``. The linear part ``A`` transforms
    the frequencies, and the translation ``t`` becomes a per-frequency phase
    offset.

    Args:
        model: Fourier feature encoder exposing ``_fourier_basis`` of shape
            (in_c, n_freq); the output layout matches the sin/cos encoder.
        x: coordinates of shape (N, in_c, H, W).
        aff: matrix whose top ``in_c`` rows are ``[A | t]`` (e.g. a 3x3
            homogeneous or 2x3 matrix for 2D coordinates), or ``None`` for no
            transform.

    Returns:
        Tensor of shape (N, 2 * n_freq, H, W): sines followed by cosines.
    """
    basis = model._fourier_basis          # (in_c, n_freq)
    in_c = basis.shape[0]
    coords = x.permute(0, 2, 3, 1)        # (N, H, W, in_c)

    if aff is None:
        proj = coords @ basis
    else:
        linear = aff[:in_c, :in_c]             # A
        shift = aff[:in_c, in_c:in_c + 1]      # t as a column vector
        freq = linear.T @ basis                # A^T B == (B^T A)^T
        phases = shift.T @ basis               # (1, n_freq) == (B^T t)^T
        proj = coords @ freq + phases

    proj = (2 * torch.pi * proj).permute(0, 3, 1, 2)
    return torch.cat([torch.sin(proj), torch.cos(proj)], dim=1)

# Apply the same affine transform by folding it into the Fourier basis rather
# than transforming the grid. The printed L1 distance to `output` measures how
# closely the two approaches agree (expected to be near 0, up to float
# rounding). Inference only, so no autograd graph is built.
with torch.no_grad():
    x = aff_transform(fourier_feature, grid, affine_matrix)
    output2 = model(x)
viz_image(output2[0])
print(F.l1_loss(output2, output).item())

In [None]:
class FourierFeature(nn.Module):
    """Sine-only random Fourier feature encoding.

    Holds a fixed (non-trainable) basis ``B = scale * randn(in_c, out_c)``
    and maps each pixel's coordinate vector ``v`` to ``sin(2*pi*v@B)``. There
    is no cosine half, so all ``out_c`` output channels are sines.
    """

    def __init__(self, in_c, out_c, scale):
        super().__init__()
        self.register_buffer('_fourier_basis', scale * torch.randn(in_c, out_c))

    def forward(self, x):
        """Map (N, in_c, H, W) coordinates to (N, out_c, H, W) features."""
        channels_last = x.permute(0, 2, 3, 1)
        phase = 2 * torch.pi * (channels_last @ self._fourier_basis)
        return phase.sin().permute(0, 3, 1, 2)

feature_length = 256
model2 = MLP(feature_length).to(device)
fourier_feature = FourierFeature(2, feature_length, 10).to(device)
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-4)
n_loops = 500

# The grid is shifted by a constant 10 before encoding, which offsets each
# sine feature's phase by 2*pi*10*(sum of that feature's basis column). The
# encoder basis is a fixed buffer and the shifted grid never changes, so the
# encoded input is computed once instead of every iteration.
x = fourier_feature(grid + 10)
for epoch in tqdm(range(n_loops)):
    output = model2(x)
    loss = F.l1_loss(output, input_image)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0 or epoch == n_loops - 1:
        viz_image(output[0])
        print(loss.item())
        # L1 distance to the reconstruction saved from the sin/cos model.
        print(F.l1_loss(output, prev_output).item())