In this notebook, we continue our investigation of the **Si**nusoidal **Re**presentation **N**etworks (SIREN) presented in __*Sitzmann, V., Martel, J., Bergman, A., Lindell, D., & Wetzstein, G. (2020). Implicit neural representations with periodic activation functions. Advances in Neural Information Processing Systems, 33, 7462-7473*.__

Here we implement a frequency encoding of the input: each input coordinate $x$ is mapped to the features $\sin(2^l x)$ and $\cos(2^l x)$ for $l = 0, \dots, L-1$ before being fed to the network. We will eventually work towards the multiresolution strategy discussed in __Müller, T., Evans, A., Schied, C., & Keller, A. (2022). Instant neural graphics primitives with a multiresolution hash encoding. ACM Transactions on Graphics (ToG), 41(4), 1-15.__


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

First, we will get an image from the web and construct a dataset from it.

In [None]:
# Download the test image; wget names the file after the last URL path component (original.jpg), which the next cell reads.
!wget https://cdn.theatlantic.com/thumbor/ISgRyw-VeYqYCE38o7HkVAiz90c=/900x626/media/img/photo/2020/02/photos-superb-owl-sunday-iv/s01_1103328920/original.jpg

In [None]:
# Load the downloaded image as an (H, W, C) NumPy array and display it.
# matplotlib.image.imread returns uint8 data for JPEG files (floats in [0, 1]
# only for PNG), which the normalization cell below takes into account.
from matplotlib import image
im = image.imread('original.jpg')
plt.imshow(im)
plt.show()

In [None]:
# Normalize the image and build a (pixel coordinate -> colour) dataset.
im = torch.tensor(im)
nrows, ncols, nchannels = im.shape

# imread yields uint8 for JPEGs but floats already in [0, 1] for PNGs; only
# rescale integer data so that im ends up in [0, 1] either way.
im = im.float() if im.is_floating_point() else im / 255

# Pixel indices along each axis, mapped affinely onto [-1, 1].
rows = torch.arange(0, nrows)
cols = torch.arange(0, ncols)
rows = 2/(rows[-1] - rows[0]) * (rows - (rows[-1] + rows[0])/2)
cols = 2/(cols[-1] - cols[0]) * (cols - (cols[-1] + cols[0])/2)

grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing='ij')

# X: (nrows*ncols, 2) coordinates in row-major pixel order;
# Y: (nrows*ncols, nchannels) colours in the same order.
X = torch.stack((grid_rows.reshape(-1), grid_cols.reshape(-1)), dim=1)
Y = im.reshape(-1, nchannels)

# Hard-wired random 80%-20% training/validation split of the pixels.
n1 = int(0.8*nrows*ncols)
idx = torch.randperm(nrows*ncols)
Xtr, Ytr = X[idx[:n1]], Y[idx[:n1]]
# Validation inputs and targets must use the same held-out indices idx[n1:].
Xval, Yval = X[idx[n1:]], Y[idx[n1:]]

train_data = (Xtr, Ytr)
val_data = (Xval, Yval)

In [None]:
# Sanity check: fold the flattened targets Y back into an image and show it.
pred_im = Y.view(nrows, ncols, nchannels)
plt.imshow(pred_im.detach().numpy())
plt.show()

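# Frequency Encoding
Maps each input coordinate $x$ to the features $\sin(2^l x)$ and $\cos(2^l x)$ for $l = 0, \dots, L-1$. This layer contains no trainable parameters.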

In [None]:
class InputEncoding(nn.Module):
    """Frequency encoding of input coordinates (no trainable parameters).

    Each input feature x_d is mapped to sin(2**l * x_d) and cos(2**l * x_d)
    for l = 0, ..., L-1.

    Args:
        L: number of frequency octaves.

    Shape:
        input:  (B, D)
        output: (B, 2*D*L) -- the D*L sines followed by the D*L cosines;
            within each half, column l*D + d holds 2**l * x[:, d].
    """

    def __init__(self, L):
        super().__init__()
        self.L = L
        # Registered as a buffer so that `module.to(device)` moves it together
        # with the rest of the model (a plain tensor attribute stays on the
        # CPU and torch.kron then fails on CUDA inputs). Non-persistent, so
        # the module's state_dict is unchanged.
        self.register_buffer(
            'scale',
            torch.tensor([2.0 ** l for l in range(L)]),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # kron of the (L,) scale with the (B, D) input gives (B, L*D) with
        # z[:, l*D + d] = 2**l * x[:, d].
        z = torch.kron(self.scale, x)
        return torch.hstack((torch.sin(z), torch.cos(z)))

## Building the Neural Network

In [None]:
class SinusoidalBlock(nn.Module):
    """Linear layer followed by a sine activation: x -> sin(W x + b).

    Initialisation: nn.Linear's default weights and bias (uniform in
    [-1/sqrt(fan_in), 1/sqrt(fan_in)]) are multiplied by sqrt(6). For the
    first block of a network the weights are additionally multiplied by w0.

    Args:
        fan_in: input width.
        fan_out: output width.
        w0: frequency factor applied to the first block's weights.
        is_first: whether this is the network's first block. ``None`` (the
            default) keeps the original heuristic ``fan_in == 2``, i.e. a
            block fed raw 2-D coordinates.
    """

    def __init__(self, fan_in, fan_out, w0, is_first=None):
        super().__init__()
        self.linear = nn.Linear(fan_in, fan_out, bias=True)
        self.sin = torch.sin
        self.w0 = w0
        if is_first is None:
            is_first = fan_in == 2
        self.is_first = is_first

        # adjust inits
        c = torch.sqrt(torch.tensor(6))
        with torch.no_grad():
            if is_first:
                # The layer attribute is `linear`; referring to `lin` here
                # raised AttributeError for first blocks.
                self.linear.weight *= self.w0 * c
            else:
                self.linear.weight *= c
            self.linear.bias *= c

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sin(self.linear(x))


class Siren(nn.Module):
    """
    SIREN: an MLP with sinusoidal activations on frequency-encoded inputs.

    Args:
        arch: layer widths [in, hidden_1, ..., hidden_k, out]; arch[0] must
            equal the encoding width 2*D*L for D-dimensional inputs.
        w0: frequency factor passed to every SinusoidalBlock.
        L: number of octaves in the input encoding.

    NOTE(review): SinusoidalBlock only applies w0 when its fan_in == 2, so
    with the encoding in front (arch[0] = 2*D*L) w0 appears to have no effect
    on the initialisation -- confirm this is intended.
    """

    def __init__(self, arch, w0, L) -> None:
        super().__init__()
        self.L = L
        self.encoding = InputEncoding(L)
        # One sinusoidal block per consecutive pair of widths, excluding the
        # final (linear) output layer.
        blocks = [
            SinusoidalBlock(n_in, n_out, w0)
            for n_in, n_out in zip(arch[:-2], arch[1:-1])
        ]
        self.body = nn.Sequential(*blocks)
        self.head = nn.Linear(arch[-2], arch[-1], bias=True)

        # Scale the head's default weights by sqrt(6).
        with torch.no_grad():
            self.head.weight.mul_(torch.tensor(6).sqrt())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, D) -> (B, 2*D*L) -> hidden features -> (B, arch[-1])
        hidden = self.body(self.encoding(x))
        return self.head(hidden)

In [None]:
L = 20  # number of frequency octaves in the input encoding
D = 2   # input dimension: (row, col) pixel coordinates
arch = [2*D*L] + [128]*4 + [3]  # encoding width -> 4 hidden layers of 128 -> RGB
# SIREN frequency factor. NOTE(review): SinusoidalBlock only uses it when its
# fan_in == 2, so with arch[0] = 2*D*L it may not influence the prior here.
w0 = 30
model = Siren(arch, w0, L)

In [None]:
# Prediction of the untrained SIREN over the full pixel grid.
with torch.no_grad():
    z = model(X)
pred_im = z.view(nrows, ncols, nchannels)
plt.imshow(pred_im.detach().numpy())
plt.show()

In [None]:
# Default training hyperparameters.
batch_size, max_iters, eval_iters = 1024, 20000, 100  # pixels/step, steps, eval batches per split
learning_rate, weight_decay = 1e-3, 0.0
out_freq = 100  # print the training loss every out_freq steps
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

def loss_func(ypred: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
    """Per-pixel squared colour error, averaged over the batch.

    Args:
        ypred: predictions, (B, C).
        yb: targets, same shape as ypred.

    Returns:
        Scalar tensor: mean over the B rows of sum over the C channels of
        (ypred - yb)**2.
    """
    # Sum over channels first (the per-pixel error), then average over pixels.
    return ((ypred - yb) ** 2).sum(dim=1).mean()

 # Helper function

def get_batch(opt):
    """Sample a random batch (with replacement) from the train or val split.

    Args:
        opt: 'train' selects train_data; any other value selects val_data.

    Returns:
        (x, y): batch_size rows of inputs and matching targets, on `device`.
    """
    inputs, targets = train_data if opt == 'train' else val_data
    picks = torch.randint(low=0, high=inputs.shape[0], size=(batch_size,))
    return inputs[picks].to(device), targets[picks].to(device)

@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss on the 'train' and 'val' splits.

    For each split, averages loss_func over eval_iters random batches drawn
    with get_batch. Runs without gradients and in eval mode; the model's
    previous train/eval mode is restored afterwards.

    Returns:
        dict mapping 'train' and 'val' to a scalar tensor.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                xb, yb = get_batch(split)
                # Must call loss_func: the global `loss` is the training-loop
                # tensor, not a callable.
                losses[k] = loss_func(model(xb), yb).item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

In [None]:
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Start from a clean slate: no stale gradients, every parameter trainable.
for param in model.parameters():
    param.grad = None
    param.requires_grad_(True)

lossi = []  # per-iteration log10 of the training loss

In [None]:
# Train the SIREN on random mini-batches of pixels.
# The loop variable is `step` rather than `iter` so that the builtin iter()
# is not shadowed at notebook scope.
model.to(device)

model.train()
for step in range(max_iters):
    xb, yb = get_batch('train')

    pred = model(xb)
    loss = loss_func(pred, yb)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # tracking: log10 of the mini-batch training loss
    lossi.append(loss.log10().item())

    # outputs
    if step % out_freq == 0:
        print(f'iter {step:7d} | loss  {loss.item():.12f}')


In [None]:
# Training curve: log10 of the training loss averaged over consecutive windows
# of 100 iterations. An incomplete trailing window is dropped so the reshape
# also works if training was interrupted part-way.
window = 100
n_full = len(lossi) // window * window
plt.plot(torch.tensor(lossi[:n_full]).view(-1, window).mean(dim=1))

In [None]:
# Evaluate the trained SIREN on every pixel and display the reconstruction.
model.eval()
with torch.no_grad():
    # The model was moved to `device` but X is still on the CPU: send the
    # input over and bring the prediction back for the loss and plotting.
    z = model(X.to(device)).cpu()

total_loss = loss_func(z, Y)
print(f'total loss: {total_loss}')
pred_im = z.view((nrows, ncols, nchannels))
# Clamp to the valid RGB range; imshow would otherwise clip and warn.
plt.imshow(pred_im.clamp(0, 1).numpy())
plt.show()

# the predictions don't quite capture the blur in the background of the original image

# ReLU network
For comparison, we train an MLP with ReLU activations on the same frequency-encoded inputs.

In [None]:
class ReluBlock(nn.Module):
    """Linear layer followed by a ReLU: x -> relu(W x + b).

    Initialisation: nn.Linear's default weights are multiplied by 0.01 and
    0.5 is added to its default bias. (Presumably this keeps units active at
    initialisation -- NOTE(review): confirm the intent.)
    """

    def __init__(self, fan_in, fan_out):
        super().__init__()
        self.lin = nn.Linear(fan_in, fan_out, bias=True)
        self.relu = torch.relu

        with torch.no_grad():
            self.lin.weight.mul_(0.01)
            self.lin.bias.add_(0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.lin(x))


class ReluNet(nn.Module):
    """
    An MLP with ReLU activations on frequency-encoded inputs.

    Args:
        arch: layer widths [in, hidden_1, ..., hidden_k, out]; arch[0] must
            equal the encoding width 2*D*L for D-dimensional inputs.
        L: number of octaves in the input encoding.
    """

    def __init__(self, arch, L) -> None:
        super().__init__()
        self.L = L
        self.encoding = InputEncoding(L)
        # One ReLU block per consecutive pair of widths, excluding the final
        # (linear) output layer.
        blocks = [ReluBlock(n_in, n_out) for n_in, n_out in zip(arch[:-2], arch[1:-1])]
        self.body = nn.Sequential(*blocks)
        self.head = nn.Linear(arch[-2], arch[-1], bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, D) -> (B, 2*D*L) -> hidden features -> (B, arch[-1])
        hidden = self.body(self.encoding(x))
        return self.head(hidden)

In [None]:
model2 = ReluNet(arch=arch, L=L)  # same widths and encoding as the SIREN above

In [None]:
# Prediction of the untrained ReLU network over the full pixel grid.
with torch.no_grad():
    z = model2(X)
pred_im = z.view(nrows, ncols, nchannels)
plt.imshow(pred_im.detach().numpy())
plt.show()

In [None]:
optimizer = torch.optim.Adam(
    model2.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Start from a clean slate: no stale gradients, every parameter trainable.
for param in model2.parameters():
    param.grad = None
    param.requires_grad_(True)

loss2i = []  # per-iteration log10 of the training loss

In [None]:
# Train the ReLU network on random mini-batches of pixels.
# The loop variable is `step` rather than `iter` so that the builtin iter()
# is not shadowed at notebook scope.
model2.to(device)

model2.train()
for step in range(max_iters):
    xb, yb = get_batch('train')

    pred = model2(xb)
    loss = loss_func(pred, yb)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # tracking: log10 of the mini-batch training loss
    loss2i.append(loss.log10().item())

    # outputs
    if step % out_freq == 0:
        print(f'iter {step:7d} | loss  {loss.item():.12f}')


In [None]:
# Evaluate the trained ReLU network on every pixel and display the result.
model2.eval()
with torch.no_grad():
    # The model was moved to `device` but X is still on the CPU: send the
    # input over and bring the prediction back for the loss and plotting.
    z = model2(X.to(device)).cpu()

total_loss = loss_func(z, Y)
print(f'total loss: {total_loss}')
pred_im = z.view((nrows, ncols, nchannels))
# Clamp to the valid RGB range; imshow would otherwise clip and warn.
plt.imshow(pred_im.clamp(0, 1).numpy())
plt.show()