# Embedding Images into Neural Networks Using Positional Encoding

Encoding data into a small network is useful for a variety of applications. 

One reason is compression. A model that is much smaller than the data can store all of it with minimal quality loss.

Another is interpolation. You can train the network on only the few examples you have at hand and have it generate the interpolated data in between, as NeRF and SIREN do.

# Model inputs and outputs

In this notebook, we will train a model that takes the x and y coordinates of a pixel and predicts the color value at that coordinate.

The input to our neural network will be a pair of encoded x and y coordinates, with shapes (batch_size, x_encoding_length) and (batch_size, y_encoding_length).
Each encoding length is the log base 2 of the corresponding image dimension (rounded down), plus 2. We add 2 to give the network more redundant information to use.

The output will be the RGB value, (batch_size, 3).
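As a quick check of the encoding-length rule, here is a minimal sketch for a hypothetical 384 by 512 image. The helper name `encoding_length` and the image size are assumptions made for illustration; the notebook computes the same quantity inline.

```python
import math

def encoding_length(size, extra=2):
    """floor(log2(size)) plus `extra` redundant components."""
    return int(math.log2(size)) + extra

# Hypothetical image dimensions, not taken from the notebook's image
len_384 = encoding_length(384)
len_512 = encoding_length(512)
print(len_384, len_512)  # 10 11
```

The network therefore sees only about ten numbers per axis, rather than one input per possible pixel position.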

# Library imports

We will use the PyTorch library for our model, and define everything else we need.

In [1]:
import torch
from torch import nn
import torch.optim as optim
import numpy as np
import math

from PIL import Image

# Data Setup

Drop the image that we want to embed in a network into the images folder.
We first convert the image into a NumPy array and divide by 255 so the color values range from 0 to 1.

We need the image width and height, and define each encoding length as the log base 2 of that dimension, plus 2.

We then define the positional encoding for the x and y coordinates.

# Positional Encoding

In machine learning, sometimes you need to represent coordinates, or just numbers. As an example, let's see how we can encode a number from 0 to N.

You can encode each number as a one-hot vector: the batch is a matrix with dimension (batch_size, N), and the number k is encoded by setting the k-th entry to 1 and all other entries to 0.
But this is a very wasteful use of a matrix, since the dimension grows linearly with N. We can do better.
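To make the cost concrete, here is a small sketch of one-hot encoding with NumPy. The helper `one_hot` is a hypothetical name introduced for this example and is not used elsewhere in the notebook.

```python
import numpy as np

def one_hot(values, n):
    """Return a (batch_size, n) matrix with a single 1 in each row."""
    values = np.asarray(values)
    out = np.zeros((values.size, n), dtype=np.float32)
    out[np.arange(values.size), values] = 1.0
    return out

encoded = one_hot([0, 3, 7], n=8)
print(encoded.shape)  # (3, 8)
print(encoded[1])     # only index 3 is set
```

Each row has n entries but carries a single nonzero value, which is why the representation scales poorly for large images.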

We can instead encode the number in binary, which needs only a matrix with dimension (batch_size, log2(N)). This is much better in terms of size, but in practice networks trained this way do not perform very well, because the encoding is not smooth. We want an encoding where, if a and b are close, encode(a) and encode(b) are also close. In binary, 7 is 0111 while 8 is 1000: every bit has changed, so it is hard for the network to see those numbers as close.
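The jump between 7 and 8 can be checked directly by counting differing bits. The helpers `to_bits` and `hamming` below are hypothetical names used only in this sketch.

```python
def to_bits(k, width):
    """Binary expansion of k, most significant bit first."""
    return [(k >> shift) & 1 for shift in reversed(range(width))]

def hamming(a, b, width=4):
    """Number of bit positions in which a and b differ."""
    return sum(x != y for x, y in zip(to_bits(a, width), to_bits(b, width)))

print(to_bits(7, 4), to_bits(8, 4))  # [0, 1, 1, 1] [1, 0, 0, 0]
print(hamming(6, 7))  # 1
print(hamming(7, 8))  # 4
```

Both pairs are one apart as numbers, yet their encodings differ by one bit in the first case and by all four bits in the second.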

Instead, we can use positional encoding. Positional encoding is like a smooth version of a binary expansion of numbers. 

The formula for positional encoding, where $pos$ is the number representing the position, $i$ is the index of the encoding component, and $a$ is a constant of our choosing, is

$PE_i(pos)=\sin(a^i \, pos)$

The code in this notebook uses $a = 1/2$, with the index shifted by one.
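A minimal NumPy sketch of this formula, assuming $a = 1/2$ and components indexed from 0. The function name `positional_encoding` and the sample positions are illustrative assumptions; this is not the notebook's own encoder.

```python
import numpy as np

def positional_encoding(pos, length, a=0.5):
    """Return sin(a**i * pos) for i = 0..length-1, shape (len(pos), length)."""
    pos = np.asarray(pos, dtype=np.float64).reshape(-1, 1)
    freqs = a ** np.arange(length)  # a^0, a^1, ..., a^(length-1)
    return np.sin(pos * freqs)

enc = positional_encoding([7, 8, 200], length=10)
near = np.linalg.norm(enc[0] - enc[1])  # distance between encode(7) and encode(8)
far = np.linalg.norm(enc[0] - enc[2])   # distance between encode(7) and encode(200)
print(enc.shape)   # (3, 10)
print(near < far)  # True
```

Unlike the binary case, the neighboring positions 7 and 8 end up closer to each other in encoding space than 7 and the distant position 200.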



In [2]:
imname="images/Darga117.jpeg"
image=np.array(Image.open(imname))/255

imgshape=list(image.shape[:2])

# Encoding length per axis: floor(log2(size)) + 2
datashape=(int(math.log2(imgshape[0]))+2,int(math.log2(imgshape[1]))+2)

# Positional encoding: component j of position i is sin(i * 2**-(j-1))
encode_x = np.vectorize(lambda i:np.array([math.sin(i*2**-(j-1)) for j in range(datashape[0])]), signature="()->(n)")
encode_y = np.vectorize(lambda i:np.array([math.sin(i*2**-(j-1)) for j in range(datashape[1])]), signature="()->(n)")

In [3]:
# Generate the positional encodings for every coordinate up front, so they don't need to be recomputed for every sample

def alldata():
    x = np.arange(0,imgshape[0])
    y = np.arange(0,imgshape[1])

    x=encode_x(x)
    y=encode_y(y)

    return x,y

alldata=alldata()

# Sample from positional encodings, and corresponding RGB values of the image

def data(num):
    s1=np.random.randint(0,imgshape[0],num)
    s2=np.random.randint(0,imgshape[1],num)    
    
    x1=torch.from_numpy(alldata[0][s1]).type(torch.float)
    x2=torch.from_numpy(alldata[1][s2]).type(torch.float)
    
    y=torch.from_numpy(image[s1,s2]).type(torch.float)
    
    return (x1,x2),y

# Reconstruct the image from the model by running all positional encodings through it and reshaping the output into the image shape

def test():
    x1 = torch.from_numpy(alldata[0].repeat(imgshape[1],0)).type(torch.float)
    x2 = torch.from_numpy(np.concatenate([alldata[1]]*imgshape[0],0)).type(torch.float)
    
    model.eval()  # use BatchNorm running statistics during inference
    with torch.no_grad():
        x = model(x1,x2)
    model.train()  # switch back so training can continue
    x = np.uint8(x.reshape(imgshape+[-1]).numpy()*255)

    Image.fromarray(x).save('images/out.png')

# Defining Our Model

The inputs to our model are pairs of encoded x and y coordinates, with shapes (batch_size, x_encoding_length) and (batch_size, y_encoding_length).

The output will be the RGB value, (batch_size, 3).

How do we use the input to get the output? We need to define a model that does that.

# ReZero

ReZero, or residual with zero initialization, is a residual connection in which the residual branch is multiplied by a trainable parameter initialized to 0. It tends to converge faster than a vanilla residual connection, especially in deep networks. The formula for ReZero is just $x + a y$, where $a$ is a parameter initialized to 0 before training and $y = F(x)$, where $F$ is the sub-network that the shortcut bypasses. Note that the `ReZero` class in this notebook defaults to a small nonzero initial value, 0.01, rather than exactly 0.
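The following sketch illustrates the key property of the formula: with $a = 0$ the block starts out as the identity, while $a$ itself still receives a gradient. The class name `ReZeroBlock`, the wrapped branch, and the tensor sizes are assumptions for illustration, and the structure differs from the `ReZero` class defined later in the notebook, which takes the shortcut and the branch output as two separate arguments.

```python
import torch
from torch import nn

class ReZeroBlock(nn.Module):
    """Computes x + alpha * F(x) with a trainable scalar alpha (hypothetical helper)."""
    def __init__(self, branch, alpha_init=0.0):
        super().__init__()
        self.branch = branch
        self.alpha = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x):
        return x + self.alpha * self.branch(x)

torch.manual_seed(0)
block = ReZeroBlock(nn.Sequential(nn.Linear(4, 4), nn.Tanh()))
x = torch.randn(2, 4)
out = block(x)
print(torch.equal(out, x))  # True: with alpha = 0 the block is the identity

out.sum().backward()
print(block.alpha.grad is not None)  # True: alpha can still be trained
```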

# Gating

Gating is a good way of propagating only the information that is needed to the next layer and discarding everything else. LSTM layers rely heavily on gating so the network can learn long-range dependencies.

Here, we define a layer that takes its inspiration from attention layers. It runs the input through three linear layers, Q, K, and V. The outputs of Q and K are multiplied elementwise and passed through a Tanh, and the result gates the output of V by elementwise multiplication. The gated output is then passed through batch normalization and a GELU activation.
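The core of the gating step can be sketched in a few lines. The feature sizes and the standalone `Q`, `K`, `V` layers are hypothetical, and the normalization and activation stages are left out to keep the focus on the gate itself.

```python
import torch
from torch import nn

torch.manual_seed(0)
in_features, hidden = 6, 4  # hypothetical sizes
Q, K, V = (nn.Linear(in_features, hidden) for _ in range(3))

x = torch.randn(5, in_features)
gate = torch.tanh(Q(x) * K(x))  # every entry lies in [-1, 1]
out = gate * V(x)               # the gate scales the V projection elementwise

print(gate.shape, out.shape)
```

Because the gate is bounded, it can shrink or flip the sign of each feature of V but never increase its magnitude.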

You don't need ReZero or gating to build a neural network, but they tend to work better than using only linear layers.

In [5]:
class ReZero(nn.Module):
    def __init__(self, alpha=0.01):
        super(ReZero, self).__init__()
        # Trainable scale applied to the residual branch
        self.alpha = nn.Parameter(torch.ones(1) * alpha)

    def forward(self, shortcut, x):
        return shortcut + self.alpha * x

class Gating(nn.Module):
    def __init__(self,i,m,o=-1):
        super(Gating,self).__init__()
        self.Q = nn.Linear(i,m)
        self.K = nn.Linear(i,m)
        self.V = nn.Linear(i,m)

        self.tanh = nn.Tanh()
        self.norm = nn.BatchNorm1d(m)
        self.gelu = nn.GELU()

    def forward(self,x):

        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        
        gate = self.tanh(q*k)
        x = gate*v        
        x = self.norm(x)
        x = self.gelu(x)
        
        return x    

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        size=100

        self.fcx = Gating(datashape[0],size)
        self.fcy = Gating(datashape[1],size)
        
        self.fc1 = Gating(2*size, size)
        self.fc2 = Gating(size, size)
        self.fc3 = Gating(size, size)
        self.fc4 = Gating(size, size)
        self.fc5 = Gating(size, size)
        self.fc6 = nn.Linear(size, 3)
        

        self.sc1 = ReZero()
        self.sc2 = ReZero()

    def forward(self,x,y):
        x = self.fcx(x)
        y = self.fcy(y)

        x = torch.cat((x,y),1)
        
        x = self.fc1(x)
        
        y = self.fc2(x)
        y = self.fc3(y)
        
        x = self.sc1(x,y)
        
        y = self.fc4(x)
        y = self.fc5(y)
        
        x = self.sc2(x,y)
        
        x = self.fc6(x)
        x = torch.sigmoid(x)
        
        return x
    
model=Model()
optimizer = optim.RMSprop(model.parameters(), lr=0.001)
mse_loss = nn.MSELoss()

# Train loop
Now we set up a training loop for our network. In PyTorch, there are 4 basic steps to train a model.

First, you need to prepare the data. You can either generate the inputs and outputs or load them from a dataset.
 
Next, you get the model's prediction for the data and compute the loss. The loss is a single number that measures how far the prediction is from the target. In this case, we're using an L2 loss, or mean squared error, calculated by taking the average of the squared differences between the model prediction and the target.
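As a small worked example, the mean squared error can be computed by hand and compared with `nn.MSELoss`. The prediction and target values below are made up for illustration.

```python
import torch
from torch import nn

# Two hypothetical RGB predictions and their targets
pred = torch.tensor([[0.2, 0.5, 0.9], [0.0, 1.0, 0.4]])
target = torch.tensor([[0.0, 0.5, 1.0], [0.0, 0.5, 0.4]])

# Squared differences are 0.04, 0, 0.01, 0, 0.25, 0, so the mean is 0.30 / 6 = 0.05
manual = ((pred - target) ** 2).mean()
builtin = nn.MSELoss()(pred, target)
print(manual.item(), builtin.item())
```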
 
We then take the gradient of the loss with respect to the network's parameters, which tells us how to tweak each parameter to reduce the loss, so that the predicted output moves closer to the target.
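A one-parameter sketch shows what the gradient provides. The loss function and learning rate here are assumptions chosen so the arithmetic is easy to follow.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
loss = (2.0 * w - 10.0) ** 2  # minimized at w = 5
loss.backward()
print(w.grad.item())  # -16.0, since d/dw (2w - 10)^2 = 4 * (2w - 10)

# Stepping against the gradient moves w from 3 toward the minimum at 5
lr = 0.1
with torch.no_grad():
    w_new = w - lr * w.grad
print(w_new.item())
```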
 
Finally, we update the parameters of the network using the optimizer. Different optimizers keep track of the gradients they have already seen and try to be smart about how to update the network: for example, by keeping a moving average so that the update keeps moving through a flat gradient area, ignores small noise, or speeds up when all the past updates strongly suggest the same direction.
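The four steps can be seen together on a toy problem. This is a sketch under stated assumptions: the linear regression task, the names `toy_model` and `toy_optimizer`, and the use of plain SGD instead of the notebook's RMSprop are all choices made to keep the example small and predictable.

```python
import torch
from torch import nn
import torch.optim as optim

torch.manual_seed(0)
toy_model = nn.Linear(1, 1)
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

history = []
for step in range(200):
    # 1. Prepare the data: random points on the line y = 2x + 1
    xb = torch.rand(64, 1)
    yb = 2.0 * xb + 1.0

    # 2. Run the model and compute the loss
    toy_optimizer.zero_grad()
    loss = loss_fn(toy_model(xb), yb)

    # 3. Backpropagate to get the gradient of the loss
    loss.backward()

    # 4. Let the optimizer update the parameters
    toy_optimizer.step()
    history.append(loss.item())

print(history[0], history[-1])  # the loss should shrink over training
```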

In [None]:
i = 0
losses = []
while True:  # runs until the cell is interrupted manually
    i+=1
    x,y=data(10000)    
            
    optimizer.zero_grad()

    loss = mse_loss(model(*x),y)

    loss.backward()

    optimizer.step()
    
    losses.append(loss.item())    
    
    if i%10==0:
        print(i, loss.item(), sum(losses)/len(losses))
    if i%100==0:
        test()

10 0.05127786099910736 0.051755746454000474
20 0.049994029104709625 0.051350790448486804
30 0.0490322969853878 0.05093179817001025
40 0.04837393760681152 0.05045821899548173
50 0.04791867733001709 0.05001378245651722
60 0.04659176990389824 0.04958857757349809
70 0.04690074175596237 0.049172211119106836
80 0.04644962400197983 0.048823737911880015
90 0.045071497559547424 0.04845580392413669
100 0.043894167989492416 0.04810748875141144
