## Introduction

This notebook is going to be an attempt to replicate the implementation of the transformer architecture : [Attention is all you need](https://arxiv.org/abs/1706.03762), as shown in the [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/) notebook but using flax and jax transforms.

## Preliminary Imports

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import altair as alt

In [2]:
# jax master prng key
master_key = jax.random.key(2025)

## Transformer

Transformer is a neural network architecture to model sequences using only self attention, which is a departure from prior sequence modelling techniques, which either relied on convolutions or recurrent methods. 

Both convolutions and recurrent methods have their drawbacks. To keep it short, convs can only attend a specific number of sequence elements at a time (window limit) and recurrent networks work sequentially, which make them slower. Self Attention removes both these limitations by parallely modelling the sequence against itself (gets rid of the window limit as well!).

Initially introduced (and still used) for language modelling, Transformer is an Encoder-Decoder architecture at the core, with the following flow for a sequence of tokens (from a text sequence):

![Flow Diagram for Transformer](./images/flow.png)

We'll start with the embedding part first.

## Embedding

In [None]:
import flax.linen as nn

class SequenceEmbedding(nn.Module):
    d_model: int
    vocab_size: int
    
    @nn.compact
    def __call__(self, x):
        embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)
        return embeddings / jnp.sqrt(self.d_model)


The embeddings are scaled by a factor of $\frac{1}{\sqrt{D_{model}}}$, where $D_{model}$ is the dimension of the embeddings.

In [8]:
embed_key, master_key = jax.random.split(master_key, 2)

x = jnp.arange(0, 100)
embed_layer = SequenceEmbedding(20, 100)
vars = embed_layer.init(embed_key, x)
embeddings = embed_layer.apply(vars, x)

print(embeddings)

[[-0.0322848   0.05868233  0.10417529 ...  0.06455109  0.04813394
   0.07360873]
 [ 0.00426957 -0.01262947  0.01007479 ... -0.09050234  0.00743644
   0.01120917]
 [ 0.00060597  0.04499757 -0.07457425 ...  0.05564927  0.05339404
   0.09623987]
 ...
 [ 0.03571852  0.04928762  0.04179859 ...  0.03035531  0.00036091
  -0.00393461]
 [-0.00223833 -0.07784591  0.06716082 ...  0.01606899  0.0080359
   0.05167094]
 [-0.01754706  0.01652121  0.01451415 ... -0.06777494  0.00777013
   0.05804292]]


## Positional Encoding

The original Transformer paper implemented an "Absolute Position Encoding". Transformer doesn't have recurrence. As a result, it lacks the sense of position in a sequence, (also self attention compares a sequence against itself and no positional awareness will make it learn random correlations between the elements or tokens).

In [9]:
from einops import rearrange


class AbsolutePositionalEncoder(nn.Module):
    d_model: int
    max_len: int
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x):
        # encoding
        encoding = np.zeros((self.max_len, self.d_model))
        position = np.arange(0, self.max_len)
        # must be in the shape max_len, 1
        position = rearrange(position, "max_len -> max_len 1")

        factor = np.exp(
            np.arange(0, self.d_model, 2) *
            (-np.log(np.array([1.0e4])) / self.d_model)
        )

        # encoding for odd and even positions
        # even, 0::2
        encoding[:, 0::2] = np.sin(position * factor)
        # odd, 1::2
        encoding[:, 1::2] = np.cos(position * factor)

        # reshape
        encoding = rearrange(encoding, "s dmodel -> 1 s dmodel")

        encoded_x = x + encoding[:, : x.shape[1]]
        # apply dropout
        encoded_x = nn.Dropout(self.dropout, deterministic=True)(encoded_x)

        return encoded_x

In [12]:
pos_key, master_key = jax.random.split(master_key, 2)


# input in the shape: 1, max_len, d_model
x = jnp.zeros((1, 100, 20))

encoder = AbsolutePositionalEncoder(20, 100, 0.0)
vars = encoder.init(pos_key, x)
encodings = encoder.apply(vars, x)

In [13]:
# using the same plotting function from the annotated transformer notebook

def plot_encoding(y, max_len, dim_range):
    data = pd.concat(
        [
            pd.DataFrame(
                {
                    "embedding": np.array(y)[0, :, dim],
                    "dimension": dim,
                    "position": list(range(max_len)),
                }
            )
            for dim in dim_range
        ]
    )

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive()
    )

In [None]:
chart = plot_encoding(encodings, 100, [4, 5, 6, 7])
# static for github viewing
chart.save("enc.png")

![Plot of positional encoding](enc.png)

In [None]:
# for an interactive version
chart.interactive()