<a href="https://colab.research.google.com/github/ShawonAshraf/annotated-transformer-flax/blob/main/nb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

This notebook is an attempt to replicate the Transformer architecture from [Attention is all you need](https://arxiv.org/abs/1706.03762), as implemented in the [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/) notebook, but using Flax and JAX transforms.

In [1]:
!pip install flax altair einops optax chex brax


## Preliminary Imports

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import altair as alt

In [None]:
# jax master prng key
master_key = jax.random.key(2025)

## Transformer

The Transformer is a neural network architecture that models sequences using attention alone (no recurrence or convolution). This is a departure from prior sequence modelling techniques, which relied on either convolutions or recurrence.

Both convolutions and recurrent methods have their drawbacks. To keep it short: a convolution only sees a fixed number of neighbouring sequence elements at a time (its window), so relating distant tokens requires stacking many layers, and recurrent networks process the sequence one step at a time, which makes them hard to parallelise and slow to train. Self-attention removes both limitations by relating every element of the sequence to every other element in parallel, so there is no window limit either.

Originally introduced for machine translation (and now the standard architecture for language modelling), the Transformer is at its core an encoder-decoder architecture, with the following flow for a sequence of tokens (from a text sequence):

![Flow Diagram for Transformer](https://github.com/ShawonAshraf/annotated-transformer-flax/blob/main/images/flow.png?raw=1)

We'll start with the embedding part first.

## Embedding

In [None]:
import flax.linen as nn

class SequenceEmbedding(nn.Module):
    d_model: int
    vocab_size: int

    @nn.compact
    def __call__(self, x):
        embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)
        # multiply by sqrt(d_model), as in the paper
        return embeddings * jnp.sqrt(self.d_model)


The embeddings are scaled by a factor of $\sqrt{D_{model}}$, where $D_{model}$ is the dimension of the embeddings, as in the original paper and the Annotated Transformer.

In [None]:
embed_key, master_key = jax.random.split(master_key, 2)

x = jnp.arange(0, 100)
embed_layer = SequenceEmbedding(20, 100)
vars = embed_layer.init(embed_key, x)
embeddings = embed_layer.apply(vars, x)

# one d_model-sized vector per token
print(embeddings.shape)

(100, 20)


## Positional Encoding

The original Transformer paper uses an absolute (sinusoidal) positional encoding. The Transformer has no recurrence or convolution, so on its own it has no notion of token order: self-attention treats its input as an unordered set, and shuffling the tokens would simply shuffle the outputs. Adding a position-dependent vector to each token embedding gives the model that missing information.
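For reference, the paper defines the encoding for position $pos$ and dimension index $i$ as

$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/D_{model}}}\right), \qquad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/D_{model}}}\right)
$$

The paper motivates the sinusoids by noting that, for any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$. Below is a minimal sketch in plain `jax.numpy` (`sinusoidal_encoding` is a hypothetical helper, not part of this notebook) that builds the table straight from the formula and checks that property numerically: each (sin, cos) pair is rotated by an angle that depends only on $k$.

```python
import jax.numpy as jnp

def sinusoidal_encoding(max_len, d_model):
    # hypothetical helper: PE[pos, 2i] = sin(pos * w_i), PE[pos, 2i + 1] = cos(pos * w_i),
    # with angular frequency w_i = 1 / 10000^(2i / d_model)
    pos = jnp.arange(max_len)[:, None]            # (max_len, 1)
    two_i = jnp.arange(0, d_model, 2)             # even dimension indices 2i
    w = 1.0 / (10000.0 ** (two_i / d_model))      # (d_model // 2,)
    pe = jnp.zeros((max_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(pos * w))     # even dimensions
    pe = pe.at[:, 1::2].set(jnp.cos(pos * w))     # odd dimensions
    return pe, w

pe, w = sinusoidal_encoding(max_len=50, d_model=8)

# shifting by k positions rotates every (sin, cos) pair by the angle k * w_i,
# i.e. PE[p + k] is a linear function of PE[p] that depends only on k
k, p = 3, 10
c, s = jnp.cos(k * w), jnp.sin(k * w)
sin_p, cos_p = pe[p, 0::2], pe[p, 1::2]
shifted_sin = c * sin_p + s * cos_p
shifted_cos = -s * sin_p + c * cos_p

print(jnp.allclose(shifted_sin, pe[p + k, 0::2], atol=1e-5))  # True
print(jnp.allclose(shifted_cos, pe[p + k, 1::2], atol=1e-5))  # True
```

The same rotation matrix works for every position $p$, which is what lets attention pick up relative offsets from absolute encodings.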

In [None]:
from einops import rearrange


class AbsolutePositionalEncoder(nn.Module):
    d_model: int
    max_len: int
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x):
        # encoding
        encoding = np.zeros((self.max_len, self.d_model))
        position = np.arange(0, self.max_len)
        # must be in the shape max_len, 1
        position = rearrange(position, "max_len -> max_len 1")

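        # 1 / 10000^(2i / d_model) for each even dimension index 2i
        factor = np.exp(
            np.arange(0, self.d_model, 2) *
            (-np.log(1.0e4) / self.d_model)
        )

        # sin on the even dimensions, cos on the odd dimensions
        # even dimensions, 0::2
        encoding[:, 0::2] = np.sin(position * factor)
        # odd dimensions, 1::2
        encoding[:, 1::2] = np.cos(position * factor)

        # reshape to (1, max_len, d_model) so it broadcasts over the batch
        encoding = rearrange(encoding, "s dmodel -> 1 s dmodel")

        # add the encodings for the first seq_len positions
        encoded_x = x + encoding[:, : x.shape[1]]
        # apply dropout (deterministic=True, so this is a no-op for now)
        encoded_x = nn.Dropout(self.dropout, deterministic=True)(encoded_x)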

        return encoded_x

In [None]:
pos_key, master_key = jax.random.split(master_key, 2)


# input in the shape: 1, max_len, d_model
x = jnp.zeros((1, 100, 20))

encoder = AbsolutePositionalEncoder(20, 100, 0.0)
vars = encoder.init(pos_key, x)
encodings = encoder.apply(vars, x)

In [None]:
# using the same plotting function from the annotated transformer notebook

def plot_encoding(y, max_len, dim_range):
    data = pd.concat(
        [
            pd.DataFrame(
                {
                    "embedding": np.array(y)[0, :, dim],
                    "dimension": dim,
                    "position": list(range(max_len)),
                }
            )
            for dim in dim_range
        ]
    )

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive()
    )

In [None]:
chart = plot_encoding(encodings, 100, [4, 5, 6, 7])
chart.save("enc.png")

![Positional Encoding non interactive](https://github.com/ShawonAshraf/annotated-transformer-flax/blob/main/enc.png?raw=1)

In [None]:
chart.interactive()

## Encoder

The original Transformer stacks 6 encoder layers and 6 decoder layers. Each encoder layer has two sub-layers: a multi-head self-attention layer and a position-wise feed-forward layer. Each sub-layer is wrapped in a residual (additive) connection followed by layer normalisation, so its output is $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$.
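A minimal functional sketch of that residual-plus-norm wrapper in `jax.numpy`, assuming a parameter-free layer norm (no learnable scale or shift) and a toy sub-layer (a random linear map standing in for attention or the feed-forward layer). `layer_norm` and `sublayer_connection` are hypothetical helpers for illustration only:

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # normalise each token's feature vector (last axis) to zero mean, unit variance
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def sublayer_connection(x, sublayer):
    # post-norm residual from the paper: LayerNorm(x + Sublayer(x))
    return layer_norm(x + sublayer(x))

key_x, key_w = jax.random.split(jax.random.key(0))
x = jax.random.normal(key_x, (1, 100, 20))      # (batch, seq, d_model)
w = jax.random.normal(key_w, (20, 20)) * 0.1    # toy sub-layer weights

out = sublayer_connection(x, lambda h: h @ w)
print(out.shape)  # (1, 100, 20)
```

Because the input is added to the sub-layer output, every sub-layer has to map $D_{model}$ features back to $D_{model}$ features; that is why the shape is preserved.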

![Encoder flow diagram](https://github.com/ShawonAshraf/annotated-transformer-flax/blob/main/images/encoder.png?raw=1)


## Multi Head Attention

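A multi-head attention layer runs $h$ self-attention heads in parallel. Each head works on its own learned projection of the queries, keys and values (of size $D_{model}/h$), so each head can model a different pattern or aspect of the sequence; the head outputs are then concatenated and passed through a final linear layer.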

### Self-Attention

Self-attention compares a sequence against itself and tries to figure out how each element relates to every other element (learning semantic similarities, for example). As a result, for a sequence of length $N$, this operation becomes an $N \times N$ comparison, so self-attention has quadratic time and memory complexity in $N$ (hence the context length limits and high GPU memory requirements of today's large language models).
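To put a number on the quadratic cost: the attention weights of a single head form an $N \times N$ matrix, so in float32 (4 bytes per entry) that matrix alone takes $4N^2$ bytes. A back-of-the-envelope sketch (it ignores batch size, the number of heads and the activations kept for the backward pass, all of which multiply this further):

```python
import numpy as np

def attention_matrix_bytes(seq_len, dtype=np.float32):
    # one (seq_len x seq_len) attention-weight matrix for a single head
    return seq_len * seq_len * np.dtype(dtype).itemsize

for n in (1_024, 2_048, 4_096, 8_192):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"N={n:>5}: {mib:8.1f} MiB")  # 4.0, 16.0, 64.0, 256.0 MiB
```

Doubling the sequence length quadruples the memory.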

Self-attention describes its operation in terms of queries, keys and values, which can get a bit confusing without proper semantics. To keep things simpler, consider that you're searching for information in a sequence (the sequence is your search space, or context). The query is your question or search input, the keys describe what each element in the sequence has to offer (what the query gets matched against), and the values are the actual content each element returns. Comparing the query with every key gives one score per element; a softmax turns those scores into a probability distribution, and the result is the probability-weighted average of the values.
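A toy example of this "soft lookup" view, with hand-picked numbers (an illustration, not taken from the notebook): three keys, three values, and a query that points in the same direction as the second key. The softmax over the query-key scores puts almost all of the weight on that key, so the output is close to the second value:

```python
import jax
import jax.numpy as jnp

keys = jnp.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [-1.0, 0.0]])      # what each position "offers"
values = jnp.array([[10.0, 0.0],
                    [0.0, 10.0],
                    [5.0, 5.0]])     # the content each position returns
query = jnp.array([0.0, 8.0])        # "looking for" the second key

scores = keys @ query                # one similarity score per position
weights = jax.nn.softmax(scores)     # probability distribution over positions
output = weights @ values            # weighted average of the values

print(weights)  # nearly all of the weight lands on the second position
print(output)   # close to the second value, [0, 10]
```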

(Latent spaces in the neural nets are just compressed search spaces, if you see it like that!)

So for query $Q$, key $K$ and values $V$, attention $A$ is

$$
A(Q, K, V) = softmax(QK^{T})V
$$

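Why softmax? It turns each row of scores in $QK^{T}$ into a probability distribution over the sequence positions (non-negative weights that sum to 1), so each output row is a weighted average of the rows of $V$.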

There's one small problem though: the dot products in $QK^{T}$ grow with the key dimension $D_{key}$. Large scores push the softmax into regions where it is almost one-hot and its gradients are extremely small, which hurts training. So in the Transformer the scores are scaled down by $\sqrt{D_{key}}$, where $D_{key}$ is the dimension of the keys:

$$
A(Q, K, V) = softmax(\frac{QK^{T}}{\sqrt{D_{key}}})V
$$
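Why $\sqrt{D_{key}}$ in particular? If the entries of $q$ and $k$ are independent with zero mean and unit variance, $q \cdot k$ is a sum of $D_{key}$ such products, so its variance is $D_{key}$ and its standard deviation is $\sqrt{D_{key}}$; dividing by $\sqrt{D_{key}}$ brings the scores back to roughly unit scale. A quick empirical check with random vectors (the sample size of 10,000 is an arbitrary choice):

```python
import jax
import jax.numpy as jnp

key_q, key_k = jax.random.split(jax.random.key(0))
for d_key in (4, 64, 512):
    q = jax.random.normal(key_q, (10_000, d_key))
    k = jax.random.normal(key_k, (10_000, d_key))
    dots = (q * k).sum(axis=-1)      # 10k independent dot products
    # raw std grows like sqrt(d_key); the scaled std stays near 1
    print(d_key, float(dots.std()), float((dots / jnp.sqrt(d_key)).std()))
```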

In a nutshell:

![self attention block diagram](https://github.com/ShawonAshraf/annotated-transformer-flax/blob/main/images/attn.png?raw=1)


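For GPT-like causal (or generative) models, which consist only of decoders (no encoders there!), attention uses a mask: each position may only attend to itself and earlier positions, so the model can't peek at the tokens it is supposed to predict. (We'll get back to this later during the decoder implementation!)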

![self attention block diagram with masking](https://github.com/ShawonAshraf/annotated-transformer-flax/blob/main/images/attn_mask.png?raw=1)
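A small sketch of how such a causal mask can be built and applied with `jnp.where`, the same trick the `self_attention` function below uses. The lower-triangular mask from `jnp.tril` (1 = may attend, 0 = blocked) is an assumption about how the decoder's mask will look:

```python
import jax
import jax.numpy as jnp

seq_len = 4
# position i may attend to positions 0..i (1 = allowed, 0 = masked)
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

scores = jax.random.normal(jax.random.key(0), (seq_len, seq_len))
# blocked positions get a huge negative score, so softmax gives them ~0 weight
masked_scores = jnp.where(causal_mask == 0, -1e9, scores)
weights = jax.nn.softmax(masked_scores, axis=-1)

print(causal_mask)
print(jnp.round(weights, 3))  # upper triangle is all zeros, each row sums to 1
```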

In [None]:
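def self_attention(q, k, v, mask=None, dropout=None):
    # q, k, v: (batch, heads, seq, d_key)
    d_key = q.shape[-1]

    # Q K^T: (batch, heads, seq, seq)
    q_k_t = jnp.matmul(q, rearrange(k, "batch h seq k -> batch h k seq"))

    scaled = q_k_t / jnp.sqrt(d_key)

    if mask is not None:
        # jax arrays are immutable and there is no masked_fill (as in PyTorch),
        # so use jnp.where to replace the masked scores
        # https://github.com/jax-ml/jax/discussions/9363#discussioncomment-2066105
        scaled = jnp.where(mask == 0, -1e9, scaled)

    # softmax over the key axis: each row becomes a probability distribution
    proba = nn.softmax(scaled, axis=-1)
    if dropout is not None:
        proba = dropout(proba)

    attn = jnp.matmul(proba, v)
    return attn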


Also, why matrix multiplication? Computing all the query-key dot products as a single batched matmul, instead of looping over pairs, maps onto highly optimised kernels that XLA compiles for GPUs and TPUs.
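To see that the matmul is just all the pairwise dot products computed at once, here is a sketch comparing it against an explicit Python double loop on a tiny single-head example (the shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

key_q, key_k = jax.random.split(jax.random.key(0))
q = jax.random.normal(key_q, (5, 8))   # (seq, d_key)
k = jax.random.normal(key_k, (5, 8))

# one matmul computes every query-key dot product
scores_matmul = q @ k.T                # (seq, seq)

# the same thing, one (query, key) pair at a time
scores_loop = jnp.array([[jnp.dot(q[i], k[j]) for j in range(5)] for i in range(5)])

print(jnp.allclose(scores_matmul, scores_loop, atol=1e-4))  # True
```

The loop launches 25 tiny operations; the matmul is one call that XLA can fuse and parallelise.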

Okay, now that there's one attention head defined, we can have a look at Multi-Head-Attention.

### Attention Heads

![multi head attention block](https://github.com/ShawonAshraf/annotated-transformer-flax/blob/main/images/mha.png?raw=1)
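The only tensor bookkeeping multi-head attention needs is splitting the $D_{model}$ features into $h$ heads of size $D_{model}/h$ and concatenating them back afterwards. The module below splits the heads with an einops `rearrange` pattern; here is an equivalent sketch with plain `reshape` and `transpose`, showing that split-then-concatenate is lossless (shapes follow the notebook's example: batch 1, sequence 100, $D_{model} = 20$, 4 heads):

```python
import jax
import jax.numpy as jnp

batch, seq, d_model, nheads = 1, 100, 20, 4
d_k = d_model // nheads

x = jax.random.normal(jax.random.key(0), (batch, seq, d_model))

# split: (batch, seq, d_model) -> (batch, seq, h, d_k) -> (batch, h, seq, d_k)
# same as rearrange(x, "batch seq (h k) -> batch h seq k", h=nheads)
heads = x.reshape(batch, seq, nheads, d_k).transpose(0, 2, 1, 3)

# concatenate: (batch, h, seq, d_k) -> (batch, seq, h, d_k) -> (batch, seq, d_model)
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)

print(heads.shape)                 # (1, 4, 100, 5)
print(merged.shape)                # (1, 100, 20)
print(jnp.array_equal(x, merged))  # True
```

Head $i$ owns features $i \cdot d_k$ up to $(i+1) \cdot d_k - 1$ of each token, and the batch axis is never touched.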

In [None]:
class MultiHeadAttention(nn.Module):
    nheads: int
    d_model: int
    dropout_value: float = 0.1


    def setup(self):
        assert self.d_model % self.nheads == 0, "d_model must be divisible by the number of heads"

        self.d_k = self.d_model // self.nheads

        # 3 separate dense linear layers for q, k and v
        # (built in a comprehension so each one gets its own weights)
        self.linear_layers = [nn.Dense(self.d_model) for _ in range(3)]
        # deterministic=True: dropout is disabled for now (no training loop yet)
        self.dropout = nn.Dropout(self.dropout_value, deterministic=True)

        # one final dense layer
        self.final_linear = nn.Dense(self.d_model)

    def __call__(self, q, k, v, mask=None):
        if mask is not None:
            # add a head axis so the same mask is broadcast to every head
            mask = jnp.expand_dims(mask, axis=1)

        # apply the dense layers, then split d_model into nheads heads of size d_k
        q, k, v = [
            rearrange(lin(i), "batch seq (h k) -> batch h seq k", h=self.nheads, k=self.d_k)
            for lin, i in zip(self.linear_layers, (q, k, v))
        ]


        # attention, computed for all heads in parallel
        score = self_attention(q, k, v, dropout=self.dropout, mask=mask)

        # concatenate the heads back into one d_model-sized feature axis
        concatenated = rearrange(score, "batch h seq k -> batch seq (h k)")

        # pass through a linear layer
        return self.final_linear(concatenated)

In [None]:
mha_key, master_key = jax.random.split(master_key, 2)

mha = MultiHeadAttention(4, 20)
vars = mha.init(
    mha_key,
    encodings,
    encodings,
    encodings
)

mha_out = mha.apply(
    vars,
    encodings,
    encodings,
    encodings
)

mha_out.shape

(1, 100, 20)

(A note to the people who wrote the updated version of the annotated transformer notebook: were you writing for an audience or just revising for an exam? Things which can be said can always be said in easier words - same goes for code, don't overly complicate variable naming and function signatures.)

### Layer Normalisation

Sequences can be of variable length, and a batch of them is padded to a common length, so statistics computed across the batch (as batch norm does) depend on the other examples in the batch and on the padding. Layer norm avoids this: it normalises each token's feature vector on its own, using the mean $\mu$ and variance $\sigma^2$ over its $D_{model}$ features, and then applies a learnable scale $\gamma$ and shift $\beta$:

$$
\mathrm{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$

So it behaves the same regardless of batch size or sequence length. (For fixed-size inputs such as images, batch norm or group norm are the more common choices.)

In [None]:
# you know, the pytorch docs were more helpful to figure this one out
# https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

class LayerNorm(nn.Module):
    # trailing shape to normalise over, e.g. (d_model,)
    normalised_shape: tuple
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        # learnable scale (gamma) and shift (beta), initialised to 1 and 0
        gamma = self.param("gamma", nn.initializers.ones, self.normalised_shape)
        beta = self.param("beta", nn.initializers.zeros, self.normalised_shape)

        # statistics over the trailing len(normalised_shape) axes
        axes = tuple(range(-len(self.normalised_shape), 0))
        mean = x.mean(axis=axes, keepdims=True)
        var = x.var(axis=axes, keepdims=True)
        norm = (x - mean) / jnp.sqrt(var + self.eps)

        return gamma * norm + beta

### Residual + Norm Layer

In [None]:
class AddAndNormLayer(nn.Module):
    normalised_shape: tuple
    dropout_rate: float

    @nn.compact
    def __call__(self, x, sublayer_out):
        # residual connection as in the paper: LayerNorm(x + Dropout(Sublayer(x)))
        # deterministic=True: dropout is disabled for now (no training loop yet)
        sublayer_out = nn.Dropout(self.dropout_rate, deterministic=True)(sublayer_out)
        return LayerNorm(self.normalised_shape)(x + sublayer_out)

In [None]:
norm_res_key, master_key = jax.random.split(master_key, 2)

# residual around the attention sub-layer: its input was `encodings`, its output `mha_out`
add_norm_layer = AddAndNormLayer(mha_out.shape[-1:], 0.1)
vars = add_norm_layer.init(norm_res_key, encodings, mha_out)
add_norm_out = add_norm_layer.apply(vars, encodings, mha_out)
add_norm_out.shape

(1, 100, 20)

### Feed Forward Layer

In [None]:
class FeedForwardLayer(nn.Module):
    d_model: int
    hidden: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x):
        # expand to the hidden size, apply the non-linearity, project back to d_model
        x = nn.Dense(self.hidden)(x)
        # leaky_relu can be used as well!
        x = nn.relu(x)
        # deterministic=True: dropout is disabled for now (no training loop yet)
        x = nn.Dropout(self.dropout_rate, deterministic=True)(x)
        x = nn.Dense(self.d_model)(x)

        return x

In [None]:
ff_key, master_key = jax.random.split(master_key, 2)

ff_layer = FeedForwardLayer(20, 10)
vars = ff_layer.init(ff_key, add_norm_out)
ff_out = ff_layer.apply(vars, add_norm_out)

ff_out.shape

(1, 100, 20)

### Gathering the encoder layers

In [None]:
class TransformerEncoder(nn.Module):
    d_model: int
    ff_hidden: int
    dropout_rate: float = 0.1
    nheads: int = 4

    @nn.compact
    def __call__(self, x):
        # sub-layer 1: multi-head self-attention, then residual + norm
        attn = MultiHeadAttention(self.nheads, self.d_model, self.dropout_rate)(x, x, x)
        x = AddAndNormLayer((self.d_model,), self.dropout_rate)(x, attn)

        # sub-layer 2: position-wise feed-forward, then residual + norm
        ff = FeedForwardLayer(self.d_model, self.ff_hidden, self.dropout_rate)(x)
        x = AddAndNormLayer((self.d_model,), self.dropout_rate)(x, ff)

        return x

In [None]:
transformer_encoder = TransformerEncoder(d_model=20, ff_hidden=80)

enc_key, master_key = jax.random.split(master_key, 2)
vars = transformer_encoder.init(enc_key, encodings)
enc_out = transformer_encoder.apply(vars, encodings)
enc_out.shape

(1, 100, 20)
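The paper stacks $N = 6$ identical encoder layers, each with its own parameters, feeding the output of one into the next. In Flax this would amount to calling `TransformerEncoder` six times inside a parent module. Below is a minimal pure-`jax.numpy` sketch of the stacking itself, under simplifying assumptions: a single attention head, no biases, no dropout, and a parameter-free layer norm. `init_layer` and `encoder_layer` are hypothetical helpers, not the Flax modules above:

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # parameter-free layer norm over the feature axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def init_layer(key, d_model, d_ff):
    # hypothetical parameter layout for one single-head encoder layer
    kq, kk, kv, ko, k1, k2 = jax.random.split(key, 6)

    def dense(k, n_in, n_out):
        return jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)

    return {
        "wq": dense(kq, d_model, d_model),
        "wk": dense(kk, d_model, d_model),
        "wv": dense(kv, d_model, d_model),
        "wo": dense(ko, d_model, d_model),
        "w1": dense(k1, d_model, d_ff),
        "w2": dense(k2, d_ff, d_model),
    }

def encoder_layer(params, x):
    # sub-layer 1: scaled dot-product self-attention, residual, norm
    q, k, v = x @ params["wq"], x @ params["wk"], x @ params["wv"]
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
    attn = jax.nn.softmax(scores, axis=-1) @ v @ params["wo"]
    x = layer_norm(x + attn)
    # sub-layer 2: position-wise feed-forward, residual, norm
    ff = jax.nn.relu(x @ params["w1"]) @ params["w2"]
    return layer_norm(x + ff)

n_layers, d_model, d_ff = 6, 20, 80
layer_keys = jax.random.split(jax.random.key(0), n_layers)
stack = [init_layer(k, d_model, d_ff) for k in layer_keys]

x = jax.random.normal(jax.random.key(1), (1, 100, d_model))
for params in stack:
    x = encoder_layer(params, x)

print(x.shape)  # (1, 100, 20)
```

Since every layer maps $D_{model}$ features to $D_{model}$ features, the layers can be chained in any number without changing the output shape.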