<img src="https://www.inovex.de/wp-content/uploads/inovex-logo-dunkelblau-quadrat.png" width="100px" align="left"/>


<table align="right">

  <td >
    <a target="_blank" href="https://colab.research.google.com/github/inovex/notebooks/blob/main/Neural_Fields_Tutorial.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/inovex/notebooks/blob/main/Neural_Fields_Tutorial.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>


# Neural Fields Tutorial

Hi there, I am glad you found this notebook and are eager to learn something about Neural Fields! You most likely found this notebook while reading my introductory blog article about the topic. If not, here is the link to the post. I recommend reading the blog article first, because it discusses the fundamentals of the topics you will see in this notebook. You can also work through both in parallel, as I refer back and forth between the documents. This notebook is not meant to be a standalone tutorial, so if questions arise while you read through it, you will most likely find an explanation in the blog article. If something is unclear, feel free to comment under the blog article. I am happy to answer your questions!

This notebook then covers:

  1. Training of a simple Neural Field for image representation
  2. Fourier Feature Mapping
  3. Conditional Neural Field (represented as an Autodecoder)  
  4. Global vs Local Conditioning

### Imports

In [6]:
import os
from dataclasses import dataclass
from pathlib import Path
from collections.abc import Iterable
from typing import List

import imageio
from skimage.transform import resize
from skimage.color import rgba2rgb
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from sklearn.model_selection import train_test_split

!pip install livelossplot --quiet
from livelossplot import PlotLosses

Make sure to connect to a GPU runtime. You can do so by clicking on `Runtime` > `Change Runtime type` in the menu above and selecting `GPU` in the hardware accelerator dropdown menu. The following cell prints some
information about the GPU that Colab assigned to you. It will fail if you are not connected to a GPU runtime.

In [7]:
!nvidia-smi

Wed Nov 27 09:03:22 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## 1. Neural Field for a single image


In the following, you will train a simple neural field for image representation from scratch. I recommend reading section 1 of the blog article before you continue with this notebook.

**A short TL;DR**: A field is a function that maps coordinates in space and time to a physical quantity. A neural field is a field parameterized by a Neural Network. For images, we have spatial coordinates (pixels) that map to color intensities. If we want to represent an image as a neural field, we want to learn a function $f: \mathbb{R}^2 \to \mathbb{R}^3$, which maps continuous pixel coordinates $(u, v)$ to the intensities of the three base colors red, green, and blue. This is also visualized in the image below.


<img src="https://www.inovex.de/wp-content/uploads/coordinate-based-mlp.png" width="500px" align="center"/>
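
To make the idea of a field as "a function of coordinates" concrete before any network is involved, here is a minimal sketch. The function `toy_field` is a hypothetical, hand-written field (a simple color gradient) and `sample_field` is a helper made up for this example; neither is part of the notebook. The point is only that the same function can be sampled on a pixel grid of any resolution.

```python
import numpy as np

def toy_field(uv):
    """Hypothetical analytic field: maps (u, v) in [0, 1)^2 to an RGB color."""
    u, v = uv[..., 0], uv[..., 1]
    return np.stack([u, v, 0.5 * np.ones_like(u)], axis=-1)

def sample_field(field, resolution):
    """Evaluate the field on a regular resolution x resolution pixel grid."""
    x = np.linspace(0, 1, resolution, endpoint=False)
    grid = np.stack(np.meshgrid(x, x), -1)   # shape (resolution, resolution, 2)
    return field(grid)                        # shape (resolution, resolution, 3)

low = sample_field(toy_field, 8)
high = sample_field(toy_field, 64)
print(low.shape, high.shape)   # (8, 8, 3) (64, 64, 3)
```

A neural field replaces `toy_field` with a trained network, but it is queried in exactly the same way.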


## Data
The following cell downloads an image of our inovex logo. We will use this image to train an implicit neural representation of it.

In [11]:
image = imageio.imread('https://www.inovex.de/wp-content/uploads/inovex-logo-dunkelblau-quadrat.png')[...]
print(image.shape)
image = resize(image, (512, 512))
# rgba2rgb() only accepts 4-channel (RGBA) input, so skip it for RGB images
if image.shape[-1] == 4:
  image = rgba2rgb(image)
plt.imshow(image);

(3001, 3001, 3)

You have now loaded the image. The `resize()` function converts the pixel values to floats in the range `[0, 1]`, and `rgba2rgb()` removes the alpha channel if the image has one. Since we want to represent the image as a continuous function, parameterized by an MLP, it is a good idea to also normalize the spatial coordinates to the unit square.

We then split the image into training and test data: we use every other pixel as a training sample and the rest as test data. When we want to visualize intermediate results, we pass the whole grid through the network.
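
To see what "every other pixel" means on a flattened grid, the following toy sketch may help. It uses a hypothetical 4 x 4 image instead of the 512 x 512 one, and the variable names are made up for this example. It assumes the same slicing pattern (`[::2]` for training, `[1::2]` for testing) on row-major pixel indices.

```python
import numpy as np

width = 4  # toy image width (hypothetical; the notebook uses 512)
pixel_ids = np.arange(width * width)   # row-major pixel indices
train_ids = pixel_ids[::2]             # even indices -> training samples
test_ids = pixel_ids[1::2]             # odd indices  -> test samples

# Mark the training pixels with 1 and the test pixels with 0.
mask = np.zeros(width * width, dtype=int)
mask[train_ids] = 1
print(mask.reshape(width, width))
```

Because the width is even, the split selects alternating columns rather than a checkerboard: every row of the printed mask reads `1 0 1 0`. The two subsets are disjoint and together cover all pixels.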



In [None]:
image = torch.tensor(image, dtype=torch.float32)

x = np.linspace(0, 1, image.shape[1], endpoint=False)

# sample the grid, which will be the input to the model
grid = torch.tensor(np.stack(np.meshgrid(x, x), -1), dtype=torch.float32)
X, Y = [grid.view(-1, 2), image.view(-1, 3)]
test_X, test_y = [X[1::2], Y[1::2]]
train_X, train_y = [X[::2], Y[::2]]

test_X.requires_grad = False
train_X.requires_grad = False


### Model definition

The following cell defines our model using PyTorch. We will use a very simple MLP architecture with four linear layers: an input layer, two hidden layers, and an output layer, with 256 neurons in each layer except the output.
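
As a quick back-of-the-envelope check, the sketch below counts the trainable parameters of such an MLP and compares them with the number of raw color values in a 512 x 512 RGB image. It assumes the default settings of the `NeuralField` class (two hidden layers, 256 neurons, 2D input, 3 outputs); `linear_params` is a helper made up for this example.

```python
def linear_params(n_in, n_out):
    """Number of weights plus biases in a fully connected layer."""
    return n_in * n_out + n_out

hidden_layers, width = 2, 256
n_params = (
    linear_params(2, width)                        # input layer
    + hidden_layers * linear_params(width, width)  # hidden layers
    + linear_params(width, 3)                      # output layer (RGB)
)
n_pixel_values = 512 * 512 * 3

print(n_params)        # 133123
print(n_pixel_values)  # 786432
```

Under these assumptions the network has roughly six times fewer parameters than the image has color values, so it has to learn a compact description of the image rather than memorize it pixel by pixel.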

In [None]:
class NeuralField(nn.Module):
  def __init__(self, hidden_layers=2, neurons_per_layer=256, input_dimension=2):
    super().__init__()
    self.input_layer = nn.Linear(input_dimension, neurons_per_layer)
    self.hidden_layers = nn.ModuleList([nn.Linear(neurons_per_layer, neurons_per_layer) for i in range(hidden_layers)])
    self.output_layer = nn.Linear(neurons_per_layer, 3)

  def forward(self, input):
    x = F.relu(self.input_layer(input))
    for layer in self.hidden_layers:
      x = F.relu(layer(x))
    return torch.sigmoid(self.output_layer(x))

### Model training

Alrighty, it's time to train our model! The following cell defines the loss function and a metric that we will use to assess the reconstruction quality. The cell after that contains the model training loop. You can monitor the training by looking at the loss curve and the peak signal-to-noise ratio (PSNR). We will also save a temporary image every 25 iterations to visualize the training progress afterwards.


For those of you who are not familiar with the PSNR: It quantifies the ratio between the maximum power of a signal and the power of the corrupting noise that affects its fidelity. It is widely used to quantify the reconstruction quality of images and video subject to lossy compression. Its computation involves the Mean Squared Error between the noise-free image and the lossy approximation. It is expressed as a logarithmic quantity on the decibel scale.
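
As a small worked example, the sketch below evaluates the standard PSNR definition, $10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$, for two error levels. The function name `psnr_db` and the example values are made up for illustration; the peak value of 1 is an assumption that matches pixel intensities normalized to `[0, 1]`.

```python
import math

def psnr_db(mse_value, max_value=1.0):
    """PSNR in decibels for a signal with peak value `max_value`."""
    return 10.0 * math.log10(max_value ** 2 / mse_value)

print(psnr_db(1e-2))  # roughly 20 dB
print(psnr_db(1e-3))  # roughly 30 dB
```

Every reduction of the mean squared error by a factor of ten raises the PSNR by 10 dB, which is why higher values indicate a better reconstruction.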

In [None]:
def mse(gt, pred):
  return 0.5 * torch.mean((gt - pred) ** 2., (-1, -2)).sum(-1).mean()

def psnr(gt, pred):
  # PSNR in dB for signals with peak value 1: -10 * log10(mean squared error)
  return -10 * torch.log10(torch.mean((gt - pred) ** 2.))

In [None]:
model = nn.DataParallel(NeuralField().cuda())
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
images = []

liveloss = PlotLosses()
for i in range(2000):
  model.train()
  optimizer.zero_grad(set_to_none=True)
  prediction = model(train_X)
  loss = mse(train_y.to('cuda'), prediction)
  loss.backward()
  optimizer.step()

  if i % 25 == 0:
    with torch.no_grad():
      model.eval()
      reconstruction = model(X).detach().cpu()

    liveloss.update({'PSNR train': psnr(train_y, prediction.detach().cpu()),
                     'Loss train': mse(train_y, prediction.detach().cpu()),
                     'PSNR test': psnr(test_y, reconstruction[1::2]),
                     'Loss test': mse(test_y, reconstruction[1::2])},
                    current_step=i)
    liveloss.send()
    images.append(reconstruction.numpy().reshape(512, 512, 3))

Now that the training has converged, you are probably eager to see the results. The following cell concatenates all the temporary images into a video and displays it to you.

In [None]:
all_images = np.stack(images)
data8 = (255*np.clip(all_images,0,1)).astype(np.uint8)
f = os.path.join('training_convergence_no_ff.mp4')
imageio.mimwrite(f, data8, fps=20)

# Display video inline
from IPython.display import HTML
from base64 import b64encode
mp4 = open(f,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML(f'''
<video width=500 controls autoplay loop>
      <source src="{data_url}" type="video/mp4">
</video>
''')

You are probably very disappointed by now and are asking yourself where the problem lies. It turns out that Neural Networks have a hard time learning high-frequency details in low-dimensional problem domains. For our task at hand, this means that our network struggles to learn the sharp edges of the inovex logo (and of any other image, for that matter).

There is actually a theoretical explanation from [Tancik et al. (2020)](https://bmild.github.io/fourfeat/) for that, which I explain in the blog article. So this is a good time to switch tabs and read section 2 of the article ;)

In [None]:
model.eval()
pred = model(X.to('cuda')).cpu().detach().numpy().reshape(512, 512, 3)
f, (ax0, ax1) = plt.subplots(1, 2)

ax0.imshow(pred)
ax1.imshow(image);

## 2. Fourier Features for Neural Fields

In the following cell, we implement the random Fourier feature mapping as proposed by [Rahimi & Recht (2007)](https://proceedings.neurips.cc/paper/2007/file/013a006f03dbc5392effeb8f18fda755-Paper.pdf), which is defined as:

$$\gamma(\boldsymbol{v}) = \left[\cos(2\pi\boldsymbol{B}^T\boldsymbol{v}), \sin(2\pi\boldsymbol{B}^T\boldsymbol{v})\right]^T$$

where $\boldsymbol{v} \in \mathbb{R}^2$ are the raw input pixel coordinates, and $\boldsymbol{B} \in \mathbb{R}^{2 \times D_{FF}}$ is a wide matrix that projects the low-dimensional pixel coordinates into a higher-dimensional space. The dimensionality $D_{FF}$ of this space is a hyperparameter that you can choose freely. Tancik et al. recommend setting it as high as your available memory allows.

The entries of $\boldsymbol{B}$ follow a normal distribution. It turns out that the mean of the normal distribution used to sample the entries of $\boldsymbol{B}$ does not have a large influence on the results. The standard deviation, however, does, and it is the single hyperparameter we need to tune. I set it to 10 for you in the cell below, because this value works quite well for our model and image. If you are curious, you can try out different values for the standard deviation.

In [None]:
FOURIER_DIM = 256
FOURIER_SCALE = 10.
INPUT_DIMS = 2 * FOURIER_DIM

B = FOURIER_SCALE * torch.randn(size=(2, FOURIER_DIM), requires_grad=False)

In [None]:
def apply_fourier_features(x, B):
  projection = (2 * np.pi * x) @ B
  transformed = torch.cat([torch.sin(projection), torch.cos(projection)], dim=-1)
  return transformed

Let us now use the same model architecture as above, but apply the Fourier features to the input coordinates before we pass them to the model. Again, we will save a snapshot of the training progress every 25 iterations. Notice how we now reach a much higher PSNR and a much lower loss value than before after only a few iterations!

In [None]:
model = nn.DataParallel(NeuralField(input_dimension=INPUT_DIMS).cuda())
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ff_images = []

liveloss = PlotLosses()
for i in range(750):
  optimizer.zero_grad(set_to_none=True)
  prediction = model(apply_fourier_features(train_X,  B))
  loss = mse(train_y.to('cuda'), prediction)
  loss.backward()
  optimizer.step()

  if i % 25 == 0:
    with torch.no_grad():
      optimizer.zero_grad(set_to_none=True)
      reconstruction = model(apply_fourier_features(X, B)).detach().cpu()

    liveloss.update({'PSNR train': psnr(train_y, prediction.detach().cpu()),
                     'Loss train': mse(train_y, prediction.detach().cpu()),
                     'PSNR test': psnr(test_y, reconstruction[1::2]),
                     'Loss test': mse(test_y, reconstruction[1::2])},
                    current_step=i)
    liveloss.send()
    ff_images.append(reconstruction.cpu().detach().numpy().reshape(512, 512, 3))

Let's look again at the result that our model now produces. It looks almost identical to the original image! We only trained for 750 iterations instead of 2000 and the results look much better, all just because of the Fourier feature mapping!

In [None]:
model.eval()
predicted_image = model(apply_fourier_features(X, B).to('cuda')).cpu().detach().numpy()

predicted_image = predicted_image.reshape(image.shape)
plt.imshow(predicted_image);

Also, when you look at the video below, you can see that the model picks up the high frequency components of the inovex logo early. This is exactly the impact of the Fourier feature mapping on the Neural Tangent Kernel that I talk about in the blog article. The scale of the normal distribution that we use to sample the matrix $\boldsymbol{B}$ directly influences the width of the kernel. Hence, if you play around with this hyperparameter you get different results. If the scale is too low, the kernel gets too wide and the result will look blurry. On the other hand, if the kernel gets too narrow (the scale is too large), the result will look grainy.
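
To get a feeling for how the scale changes the kernel width, the following sketch may help. It is a NumPy re-implementation of the mapping (the names `fourier_features` and `kernel_profile` are made up for this example) and evaluates the normalized dot product between the features of the origin and the features of a point at a small offset. It assumes that the entries of $\boldsymbol{B}$ are drawn from a zero-mean normal distribution, as in the cells above.

```python
import numpy as np

def fourier_features(x, B):
    """NumPy version of the mapping: x has shape (N, 2), B has shape (2, D)."""
    projection = (2 * np.pi * x) @ B
    return np.concatenate([np.sin(projection), np.cos(projection)], axis=-1)

def kernel_profile(scale, offsets, dim=256, seed=0):
    """Normalized dot product gamma(0) . gamma(offset) for a given scale."""
    rng = np.random.default_rng(seed)
    B = scale * rng.standard_normal((2, dim))
    origin = np.zeros((1, 2))
    # Move away from the origin along the first coordinate only.
    points = np.stack([offsets, np.zeros_like(offsets)], axis=-1)
    gram = fourier_features(points, B) @ fourier_features(origin, B).T
    return gram[:, 0] / dim   # equals 1 at offset 0

offsets = np.array([0.0, 0.01, 0.05])
for scale in (1.0, 10.0):
    print(scale, np.round(kernel_profile(scale, offsets), 3))
```

With a scale of 1 the similarity stays close to one even at an offset of 0.05 (a wide kernel), while with a scale of 10 it has almost vanished at the same offset (a narrow kernel). For a zero-mean Gaussian $\boldsymbol{B}$ with standard deviation $\sigma$, the expected value of this dot product is $\exp(-2\pi^2\sigma^2\lVert\Delta\rVert^2)$ for an offset $\Delta$, which matches this behavior.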

In [None]:
all_images = np.stack(ff_images)
data8 = (255*np.clip(all_images,0,1)).astype(np.uint8)
f = os.path.join('training_convergence_ff.mp4')  # separate file, so the first video is not overwritten
imageio.mimwrite(f, data8, fps=10)

# Display video inline
from IPython.display import HTML
from base64 import b64encode
mp4 = open(f,'rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML(f'''
<video width=500 controls autoplay loop>
      <source src="{data_url}" type="video/mp4">
</video>
''')

## 3. Conditional Neural Fields

Now that we have a Neural Field that works on single images, it is time to extend the model towards learning a whole family of images. To this end, we will build a conditional neural field as an autodecoder and train it to reconstruct grayscale images of faces from the Yale Face Database [[Belhumeur et al.,  1997](http://vision.ucsd.edu/content/yale-face-database)].

In order to speed things up and see results within minutes, let us just use one subject.

## Data

The following few cells preprocess the data and create a PyTorch Dataset that we can use to train our model. Here are a few facts about the dataset:

  - The dataset comprises 165 grayscale images of 15 individuals. Hence we have 11 images per subject in different configurations.
  - Each image is of resolution $243 \times 320$. We will take a $240 \times 240$ crop to train our model.
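
The crop arithmetic can be checked quickly with a stand-in array. This is only a sketch: `dummy` is a hypothetical placeholder with the stated resolution, and the slice is the one used in the cells of this section.

```python
import numpy as np

dummy = np.zeros((243, 320))   # stand-in with the dataset's image resolution
crop = dummy[2:-1, 40:-40]     # drop 2 + 1 rows and 40 + 40 columns
print(crop.shape)              # (240, 240)
```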

In [None]:
!wget -qN http://vision.ucsd.edu/datasets/yale_face_dataset_original/yalefaces.zip
!unzip -q yalefaces.zip -d yalefaces


In [None]:
example_face = imageio.imread('yalefaces/yalefaces/subject01.normal')

print(example_face.shape)
plt.imshow(example_face[2:-1, 40:-40], cmap='gray');

In [None]:
# sample the grid, which will be the input to the model
x = np.linspace(0, 1, 240)
grid = torch.tensor(np.stack(np.meshgrid(x,x), -1), dtype=torch.float32, requires_grad=False)

In [None]:
YF_PATH = Path('./yalefaces/yalefaces')

all_files = np.array(list(YF_PATH.rglob('subject01*'))) # only use subject1 for this demo
np.random.shuffle(all_files)

The following cell creates a `torch.utils.data.Dataset` for our data. For simplicity, we will not use a train/test split here. In practice, however, you would also split at the pixel level here in order to spot overfitting or other problems that may occur during training!

In [None]:
@dataclass(eq=False)
class YaleFacesDataset(torch.utils.data.Dataset):
  directory: Path
  split: List[Path]
  grid: torch.Tensor

  def __len__(self):
    return len(self.split)

  def __getitem__(self, idx):
    path = self.split[idx]
    img = imageio.imread(path) / 255.
    img = img[2:-1, 40:-40]

    return self.grid, torch.tensor(img, dtype=torch.float32).view(1, *img.shape), idx

In [None]:
ds = YaleFacesDataset(YF_PATH, split=all_files, grid=grid)
loader = torch.utils.data.DataLoader(ds, shuffle=True, batch_size=4, num_workers=0)

Let's look at a few examples and see if everything looks like we would expect.

In [None]:
# functions to show an image
def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(loader)
grids, images, indices = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))


### Model definition

The following few cells contain different implementations of the conditioning
techniques discussed in the blog article.

We use a slightly deeper architecture than in the examples above. Namely, we will create a 4 Layer MLP.


### Concatenation-based conditioning

The following model concatenates the coordinate inputs (which we will of course pass through the Fourier feature mapping prior to passing them into the model) with the latent code for this sample along the `-1` axis. Hence, our coordinates are expected to be of shape `(N0, N1, ..., FOURIER_FEATURE_DIM)`, where `N0, N1, ...` are arbitrarily many batch dimensions, and the latent tensor should have shape `(N0, N1, ..., LATENT_DIM)`, where the batch dimensions match the ones of the coordinates tensor.

The image below depicts this conditioning method.

<img src="https://www.inovex.de/wp-content/uploads/concat-field.drawio.png" width="500px" align="center"/>
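
The shape bookkeeping for this conditioning method can be sketched with plain NumPy arrays. All sizes below are hypothetical toy values (a batch of four 8 x 8 images instead of 240 x 240 ones), and the variable names are made up for this example. It assumes one latent code per image that is repeated for every pixel before the concatenation.

```python
import numpy as np

batch, height, width = 4, 8, 8       # hypothetical toy sizes
feature_dim, latent_dim = 512, 128   # e.g. 2 * FOURIER_DIM and the latent size

coords = np.zeros((batch, height, width, feature_dim), dtype=np.float32)
codes = np.zeros((batch, latent_dim), dtype=np.float32)   # one code per image

# Give the codes singleton spatial axes, then repeat them for every pixel.
codes_per_pixel = np.broadcast_to(
    codes[:, None, None, :], (batch, height, width, latent_dim)
)

# Concatenate along the last axis, as the model does with torch.cat.
model_input = np.concatenate([coords, codes_per_pixel], axis=-1)
print(model_input.shape)   # (4, 8, 8, 640)
```

The batch dimensions of both tensors match, and only the last axis grows: the first linear layer therefore needs `feature_dim + latent_dim` input features.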

In [None]:
class Encoder(nn.Module):
    def __init__(self, num_channels, latent_size, image_shape=(240, 240)):
        super().__init__()

        # conv_theta is input convolution
        self.conv_theta = nn.Conv2d(num_channels, 128, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

        self.cnn = nn.Sequential(
            nn.Conv2d(128, 256, 5, 2, 0),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 2, 0),
            nn.ReLU(),
            nn.Conv2d(256, latent_size, 1, 1, 0)
        )

        self.latent_size = latent_size
        self.fc = nn.Linear(58*58, 1)


    def forward(self, I):
        o = self.relu(self.conv_theta(I))
        o = self.cnn(o)
        o = self.fc(torch.relu(o).view(o.shape[0], self.latent_size, -1))
        return o.squeeze(-1)


In [None]:
class ConditionalFieldConcat(nn.Module):
  def __init__(self, num_layers=4, layer_size=128, input_dimensions=(2, 32)):
    super().__init__()

    self.num_layers = num_layers
    self.hidden_layers = nn.ModuleList()
    self.input_layer = nn.Linear(sum(list(input_dimensions)), layer_size)
    for i in range(num_layers-1):
        self.hidden_layers.append(nn.Linear(layer_size, layer_size))

    self.output_layer = nn.Linear(layer_size, 1)

  def forward(self, coordinate, latent):
    cat = torch.cat([coordinate, latent], -1)
    input = F.relu(self.input_layer(cat))
    x = input.clone()
    for i in range(self.num_layers -1):
      if i == self.num_layers // 2:
        x = x + input
      x = F.relu(self.hidden_layers[i](x))

    return torch.sigmoid(self.output_layer(x))

Another way to incorporate the conditioning variable is to project the coordinate inputs and the latent code to the same dimension using two separate linear layers and then simply add the projections. See the image below for reference.
<img src="https://www.inovex.de/wp-content/uploads/Conditional-Field.drawio.png" width="500px" align="center"/>

In [None]:
class ConditionalField(nn.Module):
  def __init__(self, num_layers=4, layer_size=128, input_dimensions=(2, 32)):
    super().__init__()

    self.num_layers = num_layers
    self.coordinate_input = nn.Linear(input_dimensions[0], layer_size)
    self.latent_input = nn.Linear(input_dimensions[1], layer_size)
    self.hidden_layers = nn.ModuleList()
    for i in range(num_layers-1):
        self.hidden_layers.append(nn.Linear(layer_size, layer_size))

    self.output_layer = nn.Linear(layer_size, 1)

  def forward(self, coordinate, latent):
    x_c = F.relu(self.coordinate_input(coordinate))
    x_l = F.relu(self.latent_input(latent))
    x = x_c + x_l
    for i in range(self.num_layers -1):
      if i == self.num_layers // 2:
        x = x + x_c
      x = F.relu(self.hidden_layers[i](x))

    return torch.sigmoid(self.output_layer(x))

Last, but not least, we can use the Feature-wise Linear Modulation mapping, which I implemented for you in a basic version below. Note that this mechanism is slower than the other two.

To explain it again, have a look at the image below. The latent vector is processed by two linear layers $\boldsymbol{\beta}$ and $\boldsymbol{\gamma}$. The two resulting feature tensors are used in an affine projection with the positional information $x$, which itself is processed by a linear layer. The output of one layer can thus be described by the formula:

$\hat{x} = L(x) \odot \gamma(z) + \beta(z) $

<img src="https://www.inovex.de/wp-content/uploads/FiLM_CBN-768x597.png" width="500px" align="center"/>

In [None]:
class FiLM(nn.Module):
  def __init__(self, num_features, latent_size):
    super().__init__()
    self.num_features = num_features
    self.latent_size = latent_size

    self.beta = nn.Linear(latent_size, num_features)
    self.gamma = nn.Linear(latent_size, num_features)
    self.affine = nn.Linear(num_features, num_features)


  def forward(self, x, z):
    beta = self.beta(z)
    gamma = self.gamma(z)

    out = self.affine(x)
    return gamma * out + beta

class FiLMConditionedField(nn.Module):
  def __init__(self, num_layers=4, layer_size=128, input_dimensions=(2, 32)):
    super().__init__()

    self.num_layers = num_layers
    self.hidden_layers = nn.ModuleList()
    self.input_layer = nn.Linear(input_dimensions[0], layer_size)
    for i in range(num_layers-1):
        self.hidden_layers.append(FiLM(layer_size, input_dimensions[1]))

    self.output_layer = nn.Linear(layer_size, 1)

  def forward(self, coordinate, latent):
    coord = self.input_layer(coordinate)
    x = coord.clone()
    for i in range(self.num_layers -1):
      if i == self.num_layers // 2:
        x = x + coord
      x = F.relu(self.hidden_layers[i](x, latent))

    return torch.sigmoid(self.output_layer(x))

Now, let's put everything together by combining the encoder with the decoder of your choice. Feel free to adjust the code. You may need to adjust some of the hyperparameters too, if you choose another conditioning method.

In [None]:
class ConditionalNeuralField(nn.Module):

  def __init__(self, encoder: nn.Module, decoder: nn.Module):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def encode(self, image):
    return self.encoder(image)

  def decode(self, coordinate, code):
    return self.decoder(coordinate, code)

  def forward(self, coordinate, image):
    encoding = self.encode(image)
    batch_dims = coordinate.shape[:-1]
    batched_encoding = encoding.view(batch_dims[0], 1, 1, -1)
    batched_encoding = batched_encoding.repeat(1, *batch_dims[1:], 1)

    return self.decode(coordinate, batched_encoding), encoding

### Training preparation

Next, we have to instantiate our model and create an optimizer that keeps track of the model parameters. Because the latent vectors are produced by the encoder in this section, the optimizer only needs the parameters of the encoder and the decoder. We will also define the loss function here. We use the per-pixel mean squared error like in the example above and add an L2 regularization term over the latent vectors.
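
Written out, the objective can be read as follows. This is a sketch based on the `mse` helper from section 1 and the `loss_fn` and `REG_WEIGHT` definitions in this section; the symbols $B$ (batch size), $H \times W$ (image resolution), $D$ (latent size), $\boldsymbol{z}_b$ (latent code of image $b$), and $\lambda$ (the regularization weight) are notation introduced here for illustration:

$$\mathcal{L} = \frac{1}{2}\sum_{b=1}^{B}\frac{1}{HW}\sum_{p=1}^{HW}\left(y_{b,p} - \hat{y}_{b,p}\right)^2 + \lambda \, \frac{1}{BD}\sum_{b=1}^{B}\sum_{d=1}^{D} z_{b,d}^2$$

The first term is the data loss (the per-image mean squared error, summed over the batch), and the second term is the mean squared magnitude of the latent entries, which keeps the codes small.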

In [None]:
# Define some variables that we will use in the following cells
NUM_EPOCHS = 250
LATENT_VECTOR_SIZE =128
LAYER_SIZE = 128
MODEL_LR = 1e-3
REG_WEIGHT = 1e-3
FOURIER_DIM = 256
FOURIER_SCALE = 10.
INPUT_DIMS = (2 * FOURIER_DIM, LATENT_VECTOR_SIZE)

In [None]:
def loss_fn(gt, pred, vectors):
  data_loss = mse(gt, pred)
  regularization = torch.mean(vectors ** 2.)

  return data_loss, regularization

In [None]:
# feel free to change the model type here.
encoder = Encoder(1, LATENT_VECTOR_SIZE)
decoder = ConditionalField(input_dimensions=INPUT_DIMS, layer_size=LAYER_SIZE)

model = nn.DataParallel(ConditionalNeuralField(encoder, decoder).cuda())

optimizer = torch.optim.Adam(
    [
     {
      'params': model.parameters(),
      'lr': MODEL_LR,
     },
])

# I found the learning rate schedule to be helpful with the FiLM Mapping.
schedule = torch.optim.lr_scheduler.StepLR(optimizer, 25)

B = FOURIER_SCALE * torch.randn(size=(2, FOURIER_DIM), requires_grad=False)

### Train Loop

In [None]:
del encoder, decoder
torch.cuda.empty_cache()

In [None]:
liveloss = PlotLosses()

for epoch in range(NUM_EPOCHS):
  losses = []
  regs = []
  psnrs = []
  for X, y, idx in loader:
    X.requires_grad = False
    batch_size = len(idx)

    optimizer.zero_grad(set_to_none=True)

    ff = apply_fourier_features(X, B)
    ff.requires_grad = False

    prediction, encoding = model(ff, y.reshape(-1, 1, 240, 240))

    dloss, reg = loss_fn(y.squeeze(), prediction.squeeze().cpu(), encoding)
    loss = dloss + REG_WEIGHT * reg
    psnrs.append(psnr(y.squeeze(), prediction.squeeze().detach().cpu()))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    regs.append(reg.item())

  schedule.step()


  liveloss.update({'Loss': np.mean(losses), 'L2': np.mean(regs)}, current_step=epoch)
  liveloss.update({'PSNR': np.mean(psnrs)}, current_step=epoch)
  liveloss.send()

In [None]:
import torchvision as tv
model.eval()


X,y,_ = next(iter(loader))

ff = apply_fourier_features(X, B)
prediction, encoding = model(ff, y.reshape(-1, 1, 240, 240))

pred = prediction.detach().cpu().reshape(-1, 1, 240, 240)

gt = tv.utils.make_grid(y.reshape(-1, 1, 240, 240)).permute(1, 2, 0).numpy()
pred = tv.utils.make_grid(pred).permute(1, 2, 0).numpy()

stacked = np.vstack([gt, pred])

fig = plt.figure(figsize=(12, 6))
plt.imshow(stacked);

The results do not look as clear as the reconstruction of the single image. There are multiple reasons for that. Generally, we have very little data, so it may be that the architecture struggles a bit to encode common features in the decoder and specific features in the latent codes that are generated by the encoder. Let me show you one last cool idea to condition the neural field.

### Local Conditioning

We can also use multiple latent codes to encode information. In this scenario, each code is tied to a spatial area on the input domain.

In the following, we create a $4 \times 4$ embedding grid with a soft assignment. We will also use the autodecoder architecture here. Therefore, we create an embedding, which holds the latent vectors for each spatial subregion, and define a function that maps the vectors to the subregions for which they are responsible. For the interpolation, we build an `emb_weights` matrix of shape `(num_pixels, 16)`, which contains the weights that are used later to soft-assign local latent codes to pixels. The rows of the matrix sum to one, and with the code below a single pixel is assigned to at most four latent codes (those whose grid vertices lie within one grid step of the pixel along both axes).

Just in case you wonder while reading the code below: We are using a slightly different architecture. Below, we will train the Neural Field as an autodecoder, meaning that we do not use an encoder to create the latent encoding for us, but rather store multiple vectors in a tensor and access them by their indices. If you want a refresher on this, head over to the blog article accompanying this notebook.

In [None]:
NUM_SPATIAL_CODES_PER_SIDE = 4 # 4 by 4 embedding grid
NUM_CODES = NUM_SPATIAL_CODES_PER_SIDE ** 2
TILED_LATENT_VECTOR_SIZE = LATENT_VECTOR_SIZE // NUM_CODES
local_embeddings = torch.nn.Embedding(len(ds) * NUM_CODES, TILED_LATENT_VECTOR_SIZE, max_norm=1.)

emb_vertices, step = np.linspace(0, 1, NUM_SPATIAL_CODES_PER_SIDE, retstep=True, endpoint=False)
emb_vertices += step/2

embedding_grid = torch.tensor(
    np.stack(np.meshgrid(emb_vertices, emb_vertices), -1),
    dtype=torch.float32, requires_grad=False
    )


grid_flat = grid.reshape(-1, 2)
emb_grid_flat = embedding_grid.reshape(-1, 2)

emb_weights = torch.zeros(len(grid_flat), NUM_SPATIAL_CODES_PER_SIDE **2)

for i, x in enumerate(grid_flat):
  dists, _ = torch.max(torch.abs(x -  emb_grid_flat), -1)
  dists[dists > step] = 0


  emb_weights[i, :] = dists / torch.sum(dists)

Again, we create the model and optimizer. As we are in the autodecoder setting now, we also need to pass the parameters of `local_embeddings` to our optimizer, so that the latent vectors get updated during training.

In [None]:
INPUT_DIMS = (2 * FOURIER_DIM, TILED_LATENT_VECTOR_SIZE)
LATENT_LR = 1e-4

model = nn.DataParallel(ConditionalFieldConcat(input_dimensions=INPUT_DIMS, layer_size=LAYER_SIZE).cuda())

optimizer = torch.optim.Adam(
    [
     {
      'params': model.parameters(),
      'lr': MODEL_LR,
     },
     {
      'params': local_embeddings.parameters(),
      'lr': LATENT_LR
     }
])

# I found the learning rate schedule to be helpful with the FiLM Mapping.
schedule = torch.optim.lr_scheduler.StepLR(optimizer, 40)

B = FOURIER_SCALE * torch.randn(size=(2, FOURIER_DIM), requires_grad=False)

The following function builds the vector tensor that we pass to our model. Instead of using a single, global vector for each pixel coordinate, we will use the `emb_weights` matrix to perform the soft assignment.

In [None]:
def build_vectors(indices):
  vectors = []
  raw = []
  for idx in indices:
    embs = local_embeddings(torch.arange(idx * NUM_CODES, idx * NUM_CODES + NUM_CODES, dtype=torch.int32))
    raw.append(embs)
    vectors.append(emb_weights @ embs)

  vectors = torch.stack(vectors)
  raw = torch.stack(raw)
  return vectors.reshape(len(indices), *grid.shape[:2], -1), raw



### Model Training

It's time to train this architecture. It takes a few more iterations than the other methods and thus a few minutes more. Just so that you are aware ;)

In [None]:

liveloss = PlotLosses()

for epoch in range(500):
  losses = []
  psnrs = []
  for X, y, idx in loader:
    X.requires_grad = False
    batch_size = len(idx)

    optimizer.zero_grad(set_to_none=True)

    ff = apply_fourier_features(X, B)
    ff.requires_grad = False

    vectors, vraw = build_vectors(idx)
    prediction = model(ff, vectors)

    data_loss, reg = loss_fn(y.squeeze(), prediction.squeeze().cpu(), vraw)
    loss = data_loss + REG_WEIGHT * reg
    psnrs.append(psnr(y.squeeze(), prediction.squeeze().detach().cpu()))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

  #schedule.step()

  liveloss.update({'Loss': np.mean(losses)}, current_step=epoch)
  liveloss.update({'PSNR': np.mean(psnrs)}, current_step=epoch)
  liveloss.send()




Uff, that took a while! But you should see from the loss curve that the model indeed did learn something and is able to create images of faces from the local conditioning. Let us have a look at some of the results.

In [None]:
model.eval()
random_idx = torch.tensor(np.random.choice(len(ds), 4))
images = []
for idx in random_idx:
  embs = local_embeddings(torch.arange(idx * NUM_CODES, idx * NUM_CODES + NUM_CODES, dtype=torch.int32))
  emb = (emb_weights @ embs).reshape(1, *grid.shape[:2], -1)

  ff = apply_fourier_features(grid, B)

  pred = model(ff.unsqueeze(0), emb).squeeze().detach().cpu().numpy()
  images.append(pred)

concat = np.concatenate(images, axis=1)
fig = plt.figure(figsize=(12,3))
plt.imshow(concat, cmap='gray');

Even after 500 epochs of training, you can see some of the boundaries between the patches. This indicates that our interpolation between the patches is maybe not ideal. If you want to, you can use the code above and play with it a little bit to see whether you can create a better result.
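
If you are looking for a starting point, one possible alternative is plain bilinear interpolation between the four nearest grid vertices. The sketch below is not the method used in this notebook; it is a NumPy illustration under the assumption that the vertices are the cell centers of the $4 \times 4$ grid and that the flattened weight order should match a grid built with `np.meshgrid`. The function name `bilinear_weights` is made up for this example.

```python
import numpy as np

def bilinear_weights(coords, vertices):
    """Soft-assign points to a regular vertex grid with bilinear interpolation.

    coords:   (N, 2) points in [0, 1]^2
    vertices: (K,) sorted, evenly spaced vertex positions along one axis
    returns:  (N, K * K) weights; each row sums to one, at most four are non-zero
    """
    step = vertices[1] - vertices[0]

    def axis_weights(t):
        # Triangular ("hat") weights; clipping t to the vertex range makes
        # border pixels fall back to their nearest vertex.
        t = np.clip(t, vertices[0], vertices[-1])
        hat = 1.0 - np.abs(t[:, None] - vertices[None, :]) / step
        return np.clip(hat, 0.0, None)

    w_u = axis_weights(coords[:, 0])   # (N, K), first coordinate
    w_v = axis_weights(coords[:, 1])   # (N, K), second coordinate
    # Outer product per point; flat index = i_v * K + i_u, which is the order
    # of np.stack(np.meshgrid(vertices, vertices), -1).reshape(-1, 2)
    return (w_v[:, :, None] * w_u[:, None, :]).reshape(len(coords), -1)

vertices, step = np.linspace(0, 1, 4, retstep=True, endpoint=False)
vertices = vertices + step / 2     # cell centers of the 4 x 4 grid
x = np.linspace(0, 1, 240)
pixels = np.stack(np.meshgrid(x, x), -1).reshape(-1, 2)
weights = bilinear_weights(pixels, vertices)
print(weights.shape, np.allclose(weights.sum(axis=1), 1.0))
```

In contrast to distance-proportional weights, the weight of a vertex here decreases linearly with the distance to the pixel and reaches one exactly at the vertex. To try it out, you could convert `weights` to a float32 tensor and use it in place of `emb_weights`; whether this removes the visible patch boundaries is something you would have to check by retraining.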

## Final remark

So you worked your way through this quite lengthy notebook, congratulations! I hope you learned something and enjoyed this content. At least I did while creating this notebook, and I will definitely continue working on and with Neural Fields, as I truly believe they have the potential to do amazing things at the intersection of Deep Learning and Physics, or Deep Learning and Arts. If you enjoyed the notebook, feel free to share it with your peers.

---

<img src="https://www.inovex.de/wp-content/uploads/inovex-logo-dunkelblau-quadrat.png" width="100px" align="left">
<br /><br /><br /><br /><br />


Find more notebooks like this on our <a href="https://github.com/inovex/notebooks">Github</a>.

For in-depth articles have a look at our [blog](https://www.inovex.de/blog).

Shared with 💙 by [inovex.de](https://www.inovex.de).