# Positional Encoding

This notebook demonstrates sinusoidal positional encoding: a method for encoding each token's position in a sequence as a tensor, so that positional information can be injected into a model.

![alt text](embedding_flow.png "Embedding Flow")

To include positional information, we need to encode the position of each token as a vector with the same dimension as the token embedding, so that the two tensors can be combined (for example, by element-wise addition).

For given inputs

- <b>pos</b>: position of the token in the sequence
- <b>d</b>: dimension of the encoded output
- <b>i</b>: index of a sine/cosine pair, with 0 <= i < d/2

the positional encoding is computed as follows:

![alt text](formula.png "Positional Encoding Formula")
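
For readers who cannot see the image, the formula from the cited paper can be written in LaTeX as below. Here <b>d</b> plays the role of the paper's model dimension, and 10000 is the constant that the code in this notebook stores as `n`.

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)
$$

Each sine/cosine pair shares the same exponent 2i/d, so output indices 2i and 2i+1 use the same frequency.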

Source: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
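
Before turning to tensors, the formula can be sketched as a plain Python loop for a single position. This is only a minimal reference under stated assumptions: the name `reference_encoding` is hypothetical (not part of this notebook or of PyTorch), `d` is assumed to be even, and `n` defaults to the 10000 used in the paper.

```python
import math


def reference_encoding(pos: int, d: int, n: float = 10000.0) -> list:
    """Loop-based positional encoding for one position (assumes even d)."""
    out = [0.0] * d
    for i in range(d // 2):
        # Both members of pair i share the exponent 2i/d.
        angle = pos / n ** (2 * i / d)
        out[2 * i] = math.sin(angle)       # PE(pos, 2i)
        out[2 * i + 1] = math.cos(angle)   # PE(pos, 2i + 1)
    return out


print(reference_encoding(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```

At position 0 every angle is zero, so the encoding alternates between sin(0) = 0 and cos(0) = 1. A loop like this is slow, but it can be handy for spot-checking a vectorized implementation on a few positions.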


In [None]:
_ = !pip install torch
_ = !pip install matplotlib

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [None]:
class PositionalEncoding(nn.Module):

    def __init__(self, dim_out: int):
        super().__init__()
        self.dim_out = dim_out
        self.n = 10000

    def forward(self, pos):
        # Even indices 2i = 0, 2, 4, ...; each sin/cos pair shares one exponent 2i/d.
        # Assumes dim_out is even.
        two_i = torch.arange(0, self.dim_out, 2, dtype=torch.float32)
        # Column vector of positions so it broadcasts against the per-pair exponents.
        pos = pos.view(-1, 1).float()
        angles = pos / torch.pow(self.n, two_i / self.dim_out)
        out = torch.zeros(pos.shape[0], self.dim_out)
        out[:, 0::2] = torch.sin(angles)  # PE(pos, 2i)
        out[:, 1::2] = torch.cos(angles)  # PE(pos, 2i + 1), same exponent as its sin partner
        return out

### Example

In [None]:
d = 100
sequence_length = 200

pe = PositionalEncoding(d)
x = torch.arange(sequence_length)
y = pe(x)
print(y.shape)
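
As noted in the introduction, the encoding has the same dimension as the token embedding so that the two tensors can be combined. The sketch below illustrates one way to do that, by element-wise addition. It is self-contained and makes several assumptions: `sinusoidal_encoding` is a hypothetical compact helper (not the class above), the sizes are illustrative, and the `nn.Embedding` layer is randomly initialized rather than trained.

```python
import torch
import torch.nn as nn


def sinusoidal_encoding(seq_len: int, d: int, n: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: sinusoidal encodings for positions 0..seq_len-1 (assumes even d)."""
    pos = torch.arange(seq_len, dtype=torch.float32).view(-1, 1)
    two_i = torch.arange(0, d, 2, dtype=torch.float32)
    angles = pos / torch.pow(torch.tensor(n), two_i / d)
    out = torch.zeros(seq_len, d)
    out[:, 0::2] = torch.sin(angles)
    out[:, 1::2] = torch.cos(angles)
    return out


vocab_size, d, seq_len = 1000, 16, 10  # illustrative sizes, not taken from the notebook

embedding = nn.Embedding(vocab_size, d)  # randomly initialized, for shape illustration only
token_ids = torch.randint(0, vocab_size, (seq_len,))

token_emb = embedding(token_ids)           # shape (seq_len, d)
pos_enc = sinusoidal_encoding(seq_len, d)  # shape (seq_len, d)
combined = token_emb + pos_enc             # element-wise sum keeps the shape
print(combined.shape)
```

Because both tensors have shape `(seq_len, d)`, the sum needs no reshaping or projection, which is the reason for encoding positions in the embedding dimension.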

### Visualization

In [None]:
plt.figure()
plt.plot(y[5].numpy(), label="5")
plt.plot(y[6].numpy(), label="6")
plt.plot(y[50].numpy(), label="50")
plt.xlabel('encoding dimension')
plt.ylabel('encoding value')
plt.legend(title="position")

In [None]:
cax = plt.matshow(y.numpy().transpose(), cmap='Purples')
plt.gcf().colorbar(cax)
plt.xlabel('token position')
plt.ylabel('encoding dimension')