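# Loading demo data

Download a demo photograph, take a 512×512 center crop, and build the pixel-coordinate grids. The full grid is the test set, and every other pixel (256×256) is the training set.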

In [None]:
# Image taken from Fourier Feature Network authors' colab demo: https://colab.research.google.com/github/tancik/fourier-feature-networks/blob/master/Demo.ipynb
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm
import os, imageio

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Download the image and take a square (2r x 2r) crop around its center.
image_url = 'https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg'
img = imageio.imread(image_url)[..., :3] / 255.
c = [img.shape[0]//2, img.shape[1]//2]
r = 256
# Guard the crop: with a smaller image, c - r would go negative and negative
# slice starts wrap around, silently yielding a non-square / wrong crop. A
# grayscale (2-D) image would also have had its *columns* sliced by [..., :3].
if img.ndim != 3 or min(img.shape[:2]) < 2 * r:
    raise ValueError(
        f'expected an RGB image of at least {2 * r}x{2 * r} pixels, got shape {img.shape}'
    )
img = img[c[0]-r:c[0]+r, c[1]-r:c[1]+r]

plt.imshow(img)
plt.title('Ground-truth crop')
plt.axis('off')
plt.show()

# Create input pixel coordinates in the unit square [0, 1): the full grid is the
# test set, every other pixel in each direction is the training set.
coords = np.linspace(0, 1, img.shape[0], endpoint=False)
x_test = np.stack(np.meshgrid(coords, coords), -1)
test_data = [x_test, img]
train_data = [x_test[::2, ::2], img[::2, ::2]]

In [None]:
def first_layer_sine_init(m):
    """Initialise the first SIREN layer in place and hand the module back.

    Weights are redrawn from U(-1/fan_in, 1/fan_in), where fan_in is the size
    of the last dimension of ``m.weight``. A module without a ``weight``
    attribute is returned untouched. See paper sec. 3.2, final paragraph, and
    supplement Sec. 1.5 for how this interacts with the factor 30 used by the
    sine activation.
    """
    if not hasattr(m, 'weight'):
        return m
    bound = 1 / m.weight.size(-1)
    # In-place re-initialisation of a leaf parameter must not be tracked by autograd.
    with torch.no_grad():
        m.weight.uniform_(-bound, bound)
    return m

def sine_init(m):
    """Initialise a hidden SIREN layer in place and hand the module back.

    Weights are redrawn from U(-limit, limit) with
    limit = sqrt(6 / fan_in) / 30, where fan_in is the size of the last
    dimension of ``m.weight``. The division by 30 pairs with the factor 30
    inside the sine activation (see supplement Sec. 1.5). A module without a
    ``weight`` attribute passes through untouched.
    """
    if hasattr(m, 'weight'):
        fan_in = m.weight.size(-1)
        limit = np.sqrt(6 / fan_in) / 30
        # Re-initialise without recording the in-place op in the autograd graph.
        with torch.no_grad():
            m.weight.uniform_(-limit, limit)
    return m

class Sine(nn.Module):
    """Element-wise sine activation ``sin(w0 * x)`` used between SIREN layers.

    Args:
        w0: frequency factor applied to the input before the sine. The default
            of 30 reproduces the original behaviour (see paper sec. 3.2, final
            paragraph, and supplement Sec. 1.5 for discussion of factor 30).
    """

    def __init__(self, w0=30):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)

    def extra_repr(self):
        # Shown inside repr(model) so the frequency factor is visible when printing.
        return f'w0={self.w0}'

class Siren(nn.Module):
    """SIREN MLP mapping input coordinates to output channels (RGB by default).

    Layout: ``Linear(mapping_size, hidden_size)`` + ``Sine``, then ``depth - 2``
    blocks of ``Linear(hidden_size, hidden_size)`` + ``Sine``, then a plain
    ``Linear(hidden_size, out_size)`` head. The first and hidden layers use the
    SIREN initialisers ``first_layer_sine_init`` / ``sine_init``; the head keeps
    PyTorch's default ``nn.Linear`` initialisation.

    Args:
        depth: total number of Linear layers; values below 2 behave like 2.
        mapping_size: number of input features per sample (2 for x/y pixel
            coordinates in this notebook).
        hidden_size: width of every hidden layer.
        out_size: number of output channels (defaults to 3, i.e. RGB).
    """

    _MODES = (None, 'sigmoid', 'linear')

    def __init__(self, depth=5, mapping_size=2, hidden_size=256, out_size=3):
        super().__init__()
        layers = [first_layer_sine_init(nn.Linear(mapping_size, hidden_size)), Sine()]
        for _ in range(depth - 2):
            layers += [sine_init(nn.Linear(hidden_size, hidden_size)), Sine()]
        layers.append(nn.Linear(hidden_size, out_size))
        self.layers = nn.Sequential(*layers)

    def forward(self, x, mode=None):
        """Run the network on ``x`` (last dimension must equal ``mapping_size``).

        Args:
            x: input tensor.
            mode: output head applied to the final linear layer's output:
                ``None`` -> ``sin(30 * out)``, values in [-1, 1];
                ``'sigmoid'`` -> ``sigmoid(out)``, values in (0, 1);
                ``'linear'`` -> ``out`` unchanged.

        Raises:
            ValueError: if ``mode`` is not one of the values listed above.
        """
        if mode not in self._MODES:
            # Fail loudly instead of silently returning None for a typo'd mode.
            raise ValueError(
                f"unknown mode {mode!r}; expected one of None, 'sigmoid', 'linear'"
            )
        out = self.layers(x)
        if mode == 'sigmoid':
            return torch.sigmoid(out)
        if mode == 'linear':
            return out
        return torch.sin(30 * out)
# Smoke test: a batch of 100 random 2-D points should map to 100 RGB triples.
model = Siren()
with torch.no_grad():
    smoke_out = model(torch.randn(100, 2))
assert smoke_out.shape == (100, 3), smoke_out.shape
smoke_out.shape

In [None]:
# Flatten the coordinate grids and images into one row per pixel:
# inputs become (N, 2) and RGB targets (N, 3), all float32.
# Use the GPU when one is available instead of hard-coding .cuda(), which
# crashes on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
xb = torch.tensor(train_data[0], dtype=torch.float32).reshape(-1, 2).to(device)
yb = torch.tensor(train_data[1], dtype=torch.float32).reshape(-1, 3).to(device)
# The test tensors deliberately stay on the CPU; the evaluation cell moves the
# model to the CPU before predicting on them.
x_test = torch.tensor(test_data[0], dtype=torch.float32).reshape(-1, 2)
y_test = torch.tensor(test_data[1], dtype=torch.float32).reshape(-1, 3)

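# Default SIREN model

Train the default `Siren` (5 layers, width 256) with its sigmoid output head on the training pixels, then predict the full-resolution image.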

In [None]:
# Fit the SIREN to the subsampled training pixels with the sigmoid output head.
# The model follows `xb` onto whatever device it lives on, so this cell works
# with or without a GPU.
torch.manual_seed(0)  # reproducible weight initialisation
n_steps = 10_000
model = Siren().to(xb.device)
model.train()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
loss = nn.MSELoss()
pbar = tqdm(range(n_steps))
for i in pbar:
    ypred = model(xb, 'sigmoid')
    l = loss(ypred, yb)
    opt.zero_grad()
    l.backward()
    opt.step()
    # .item() forces a device sync, so only report the loss occasionally.
    if i % 500 == 0 or i == n_steps - 1:
        pbar.set_postfix(mse=f'{l.item():.2e}')

In [None]:
# Predict the full-resolution image. x_test lives on the CPU, so move the model there.
model.cpu().eval()
with torch.no_grad():
    ypreds = model(x_test, 'sigmoid')
    # Restore the layout of the ground-truth crop rather than hard-coding
    # 512x512x3, which would break if the crop radius changes.
    ypreds = ypreds.reshape(test_data[1].shape)

In [None]:
# Side-by-side comparison of the ground-truth crop and the network's reconstruction.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = [(test_data[1], 'Ground truth'), (ypreds.numpy(), 'SIREN prediction (sigmoid head)')]
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis('off')
plt.show()