# Why use positional encoding

If you're anything like me, while trying to implement transformers, you've read the original [attention is all you need](https://arxiv.org/abs/1706.03762) paper, [the annotated transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html), [the updated version](http://nlp.seas.harvard.edu/annotated-transformer/), [d2l.ai](https://d2l.ai/chapter_attention-mechanisms-and-transformers/transformer.html), and had to cobble them all together to get something going. This post is an attempt to make that proccess easier for people like me in a short and to-the-point style. You can think of this as a bare-bones implementation with a whole lot of documentation.

Notes:
1. this post is about *how* transformers are implemented, not *why* they're implemented the way they are.
2. This post pretty much takes you from just the basics of ML through understanding how transformers work. Feel free to skip whatever sections you're familiar with.

# Overview

The following is the transformer architecture diagram taken from the original paper. We'll be referring back to it often in this post.

![image.png](./assets/tformer.png)
Transformer architecture diagram

At a high level, the transformer is an encoder-decoder model; it takes a **sequence** of **tokens** from a source (e.g. English words) and learns to translate that into a destination sequence (e.g. French words). The encoder is the left side and the decoder is the right side of the architecture diagram.

For the rest of the post, when we refer to "tokens" or "sequences", feel free to replace them with "words" and "sentences" (or "paragraphs", "pages", "books", etc.) respectively in your head.

# Preamble

We'll be using [jax](https://jax.readthedocs.io/en/latest/index.html) and [flax](https://flax.readthedocs.io/en/latest/index.html) for this implementation.

[embeddings](http://wikipedia2vec.s3.amazonaws.com/models/en/2018-04-20/enwiki_20180420_100d.pkl.bz2)

In [9]:
import numpy as np
import pandas as pd

In [13]:
corrs = (
    pd.DataFrame(np.corrcoef(np.random.uniform(size=(5, 30))),
                 index=pd.RangeIndex(stop=5, name='a'),
                 columns=pd.RangeIndex(stop=5, name='b'))
    .stack()
    .to_frame('$\rho$')
    .reset_index()
)
corrs

Unnamed: 0,a,b,$\rho$
0,0,0,1.0
1,0,1,-0.046218
2,0,2,-0.006698
3,0,3,0.249767
4,0,4,-0.240984
5,1,0,-0.046218
6,1,1,1.0
7,1,2,0.368624
8,1,3,0.041882
9,1,4,-0.074076


In [18]:
alt.Chart(corrs, title="Daily Max Temperatures (C) in Seattle, WA").mark_rect().encode(
    alt.X("a:O", title="Day"),
    alt.Y("b:O", title="Month"),
    alt.Color("$\rho$", title='rho'),
    tooltip=[
        alt.Tooltip("a", title="Date"),
        alt.Tooltip("b", title="Max Temp"),
    ],
).configure_view(
    step=13 * 2,
    strokeWidth=0
).configure_axis(
    domain=False
)

In [8]:
import altair as alt
from vega_datasets import data

source = data.seattle_weather()

alt.Chart(source, title="Daily Max Temperatures (C) in Seattle, WA").mark_rect().encode(
    alt.X("date(date):O", title="Day"),
    alt.Y("month(date):O", title="Month"),
    alt.Color("max(temp_max)", title='temp'),
    tooltip=[
        alt.Tooltip("monthdate(date)", title="Date"),
        alt.Tooltip("max(temp_max)", title="Max Temp"),
    ],
).configure_view(
    step=13,
    strokeWidth=0
).configure_axis(
    domain=False
)

In [3]:
import pandas as pd

In [10]:
tmp = pd.read_csv('/home/amniskin/code/daggerml/transformers/data/glove.6B.50d.txt', delim_whitespace=True, header=None)
tmp.head()

ParserError: Error tokenizing data. C error: EOF inside string starting at row 8

In [4]:
source.dtypes

date             datetime64[ns]
precipitation           float64
temp_max                float64
temp_min                float64
wind                    float64
weather                  object
dtype: object

In [1]:
import chex
from flax import linen as nn, core
import jax.numpy as jnp
import jax.random as jran
import jax.tree_util
from math import prod
from typing import Callable, Sequence
import matplotlib.pyplot as plt

key = jran.PRNGKey(0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


# Input (output) embeddings
![image.png](../_static/post/another-annotated-transformer/embed-circled.png)

Let's start with the beginning of the data flow: the Input layer.

## Tokenization
The vocabulary we'll be using is digits 0-9. We then add a **pad** token (`"p"`) and a **start** token (`"s"`), both of which we'll explain later. So our tokens are `["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "p", "s"]`. Our first task is to turn this into something numeric.

In [2]:
vocab = {str(x): i for i, x in enumerate(list(range(10)) + ['p', 's'])}
vocab

{'0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 'p': 10,
 's': 11}

Now we can turn a number into a sequence of token indices like so:

In [3]:
def num2tokens(num):
    return jnp.array([vocab.get(x) for x in str(num)])

num2tokens(232)

Array([2, 3, 2], dtype=int32)

## Input embedding

An embedding is a learned mapping from tokens to (hopefully) more densly populated vectors. Under the hood it's stored as a matrix but that's an implementation detail. Below is a model that embeds a vocabulary with 5 unique tokens to 2 dimensions.

In [4]:
model = nn.Embed(len(vocab), 2)

Note that we can easily inspect the parameters of a model, and that the **embedding** parameter below is a $5\times 2$ matrix. It's $5\times2$ because it's storing 5 independent 2-dimensional embeddings.

In [5]:
params = model.init(key, jnp.array([1]))
params

FrozenDict({
    params: {
        embedding: Array([[-0.07138442,  0.31376272],
               [ 0.4053369 ,  0.9015168 ],
               [-0.5195726 ,  0.2808921 ],
               [-0.9512113 ,  0.35392332],
               [-1.8304102 , -0.332351  ],
               [ 0.40043908,  0.08431281],
               [-0.4319639 , -0.7836565 ],
               [ 0.54686207,  0.35685003],
               [ 0.07191442, -0.1766023 ],
               [-0.48776352,  0.33678803],
               [ 1.6915971 ,  0.6755828 ],
               [ 0.07079946,  0.05526797]], dtype=float32),
    },
})

To get embeddings from tokens, we can either do it ourselves:

In [6]:
nn.one_hot(num2tokens(2234), len(vocab)) @ params['params']['embedding']

Array([[-0.5195726 ,  0.2808921 ],
       [-0.5195726 ,  0.2808921 ],
       [-0.9512113 ,  0.35392332],
       [-1.8304102 , -0.332351  ]], dtype=float32)

Or use the flax api (which is going to be necessary once our model gets more complicated):

In [7]:
model.apply(params, num2tokens(2234))

Array([[-0.5195726 ,  0.2808921 ],
       [-0.5195726 ,  0.2808921 ],
       [-0.9512113 ,  0.35392332],
       [-1.8304102 , -0.332351  ]], dtype=float32)

# Sequence masking
Transformers have a fixed context window. This means there's a maximum sequence length and every sequence passed to the model must be of that length. We handle this issue in two ways:
1. padding
2. masking

## Padding
Let's say our max sequence length is 5 and we're constructing a training batch from the following 4 numbers: `[123, 42451, 0, 12]`. Our sequences look like this:

In [8]:
[num2tokens(x) for x in [123, 42451, 0, 12]]

[Array([1, 2, 3], dtype=int32),
 Array([4, 2, 4, 5, 1], dtype=int32),
 Array([0], dtype=int32),
 Array([1, 2], dtype=int32)]

They're all of different lengths so we can't put them into a matrix, and we can't feed them to our model. We need to pad.

In [9]:
tokens = [num2tokens(x) for x in [123, 42451, 0, 12]]
tokens = jnp.array([jnp.pad(x, (0, 5 - len(x)), 'constant', constant_values=(0, vocab['p']))
                    for x in tokens])
tokens

Array([[ 1,  2,  3, 10, 10],
       [ 4,  2,  4,  5,  1],
       [ 0, 10, 10, 10, 10],
       [ 1,  2, 10, 10, 10]], dtype=int32)

In [10]:
model.apply(params, tokens).shape

(4, 5, 2)

## Masking
Now we can fit these sequences into a batch, but we don't want the model to learn from the padded characters. So for that, we [mask](https://www.ml-science.com/masking). The idea being that we identify which tokens we don't want feeding into gradients so that we can later ensure they don't. There are several ways to do it. Here's one:

In [11]:
mask = (tokens == vocab['p']).astype(jnp.float32)
mask

Array([[0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.]], dtype=float32)

More on this in a bit

# Positional encoding

![image.png](assets/posenc-circled.png)

The positional encoding is the same size as a single observation fed to the model and added to each observation in the batch. We use the same function as they used in the original paper. Let $X\in\mathbb{R}^{s\times d}$ where $s$ is the max sequence length, and $d$ is the embedding dimension.

$$f(X_{i,j}) = \begin{cases}
\sin\left(i/\left(10000^{j/d}\right)\right) & \text{if } j\equiv 0\pmod{2} \\
\cos\left(i/\left(10000^{(j-1)/d}\right)\right) & \text{if } j\equiv 1\pmod{2}
\end{cases}$$

In [12]:
def sin_pos_enc(sequence_length, embed_dim):
    """create sin/cos positional encodings

    Paramters
    =========
    sequence_length : int
        The max length of the input sequences for this model
    embed_dim : int
        the embedding dimension

    Returns
    =======
    a matrix of shape: (sequence_length, embed_dim)
    """
    chex.assert_is_divisible(embed_dim, 2)
    X = jnp.expand_dims(jnp.arange(sequence_length), 1) / \
        jnp.power(10000, jnp.arange(embed_dim, step=2) / embed_dim)
    out = jnp.empty((sequence_length, embed_dim))
    out = out.at[:, 0::2].set(jnp.sin(X))
    out = out.at[:, 1::2].set(jnp.cos(X))
    return out

sin_pos_enc(5, 2)

Array([[ 0.        ,  1.        ],
       [ 0.84147096,  0.5403023 ],
       [ 0.9092974 , -0.41614684],
       [ 0.14112   , -0.9899925 ],
       [-0.7568025 , -0.6536436 ]], dtype=float32)

So when we implement the `Encoder` and the `Decoder`, we'll sum the embeddings and the positional encodings (in practice, we scale the embeddings beforehand).

In [13]:
params['params']['embedding'] * sin_pos_enc(*params['params']['embedding'].shape)

Array([[-0.0000000e+00,  3.1376272e-01],
       [ 3.4107921e-01,  4.8709157e-01],
       [-4.7244602e-01, -1.1689236e-01],
       [-1.3423494e-01, -3.5038143e-01],
       [ 1.3852590e+00,  2.1723911e-01],
       [-3.8399076e-01,  2.3916358e-02],
       [ 1.2069740e-01, -7.5244367e-01],
       [ 3.5928103e-01,  2.6903003e-01],
       [ 7.1149126e-02,  2.5695641e-02],
       [-2.0101637e-01, -3.0685776e-01],
       [-9.2026454e-01, -5.6686229e-01],
       [-7.0798770e-02,  2.4459933e-04]], dtype=float32)

# Attention
![image.png](assets/attn-circled.png)

Transformers are built around this **attention** mechanism, so this warrants its own section. Multi-head attention involves stacking a collection of attention "heads" and adding some learned weights in the mix. As such, we'll start with attention heads and progress to multi-head attention.

Attention is just a function that takes 3 arguments (key, value, and query) and aggregates them to a vector. There are a few forms of attention but we'll focus on the one used in the seminal paper: **scaled dot product** attention.

## Scaled dot product attention

Let $Q\in\mathbb{R}^{n\times d},K\in\mathbb{R}^{m\times d},V\in\mathbb{R}^{m\times v}$ be the **query**, **key**, and **value**. Basically we just need the shapes to be fit for the matrix multiplication below. A good reference for this is [d2l.ai](https://d2l.ai/chapter_attention-mechanisms-and-transformers/attention-scoring-functions.html).

$$\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\in\mathbb{R}^{n\times v}$$

The $\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)$ part is called the **attention weights**.

It's worthwhile to note that there are no learnable weights in this formula.

This formula is deceptive in 2 ways:
1. The softmax is actually a masked softmax
2. There's generally some dropout on the attention weights

Let's take this piece by piece.

### Masked softmax

Let $X\in\mathbb{R}^k$ a vector, then $\text{softmax}(X)\in\mathbb{R}^k$.

$$\text{softmax}(X)_i = \frac{e^{X_i}}{\sum_{j=0}^{k-1}e^{X_j}}$$

It's just normalization with a monotonic function applied, meaning the relative ranking of the elements of $X$ aren't changed. For more on this, see [this](https://charlielehman.github.io/post/visualizing-tempscaling/) post.

For masked softmax, we'll be taking the approximate approach. Because of the sum in the denominator and the exponentiation, it's unwise to mask with 0 ($e^0 = 1$). Instead we'll mask with a very large negative number before we exponentiate so that the result is close to 0 ($e^{-\infty} \approx 0$).

Let's recall our tokens and mask from earlier:

In [14]:
tokens

Array([[ 1,  2,  3, 10, 10],
       [ 4,  2,  4,  5,  1],
       [ 0, 10, 10, 10, 10],
       [ 1,  2, 10, 10, 10]], dtype=int32)

In [15]:
mask

Array([[0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0.],
       [0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.]], dtype=float32)

In [16]:
def masked_softmax(args, mask):
    if mask is not None:
        args = args + (mask.astype(args.dtype) * -10_000.0)
    return nn.softmax(args)

masked_softmax(tokens, mask)

Array([[0.09003057, 0.24472846, 0.6652409 , 0.        , 0.        ],
       [0.20393994, 0.02760027, 0.20393994, 0.55436623, 0.01015357],
       [1.        , 0.        , 0.        , 0.        , 0.        ],
       [0.26894143, 0.7310586 , 0.        , 0.        , 0.        ]],      dtype=float32)

### The meat

Let's start with a few simplifications and then add pieces in.

1. assume that $K=V$,
2. assume the rows of $V$ all have unit norm
3. ignore the softmax and the $1/\sqrt{d}$ factor.
    
Under these assumptions, the attention weights become $(QV^T)V$, which is the [projection](https://en.wikipedia.org/wiki/Vector_projection) of $Q$ onto the rows of $V$. Things are pretty interpretable under those simplifying assumptions, so let's start justifying them.

1. Although it's not true that $K=V$ in general, it is true when it comes to transformers. You can see in the diagram that the key and value (although not labeled) are always the same in all attention heads.
2. We have normalization layers after every step, so this shouldn't be too off.
3. Although this does break the interpretation, it's not so bad still because it's at least a monotonically increasing.

In [17]:
Q = jran.normal(key, (3, 7))
K = jran.normal(key, (5, 7))
V = jran.normal(key, (5, 11))
(masked_softmax(Q @ K.T / jnp.sqrt(Q.shape[-1]), None) @ V).shape

(3, 11)

When we add a batch dimension (and later a `num_heads` dimension) we'll want to broadcast the multiplication over those dimensions, so we won't be able to use `K.T`. Instead, we'll use `K.swapaxes`, which will allow us to swap only the last two dimensions.

In [18]:
Q = jran.normal(key, (11, 3, 7))
K = jran.normal(key, (11, 5, 7))
V = jran.normal(key, (11, 5, 11))
(masked_softmax(Q @ K.swapaxes(-1, -2) / jnp.sqrt(Q.shape[-1]), None) @ V).shape

(11, 3, 11)

Bringing it all together and adding some dropout to the attention weights, we end up with:

In [19]:
def dot_prod_attn(q, k, v, dropout=lambda x: x, mask=None):
    # NxD @ DxM => NxM
    # (B[, H], N, M)
    attn_scores = q @ k.swapaxes(-2, -1) / jnp.sqrt(q.shape[-1])
    attn_weights = masked_softmax(attn_scores, mask)
    # (B[, H], N, D)
    out = dropout(attn_weights) @ v
    return out, attn_weights

## Multihead Attention

![image.png](./assets/multi-head-attention.png)

At a high level, mutli-head attention is a bunch of stacked attention layers. But given that there are no learnable weights in the attention heads (they query, key, and values are all arguments), each would yield the same result -- not so useful. So instead, we train a linear layer per attention head, and then combine the results.

Our interpretation of attention being the projection needs some adjustment now that we're applying these linear layers (it's no longer true that $K=V$). But we'll leave this for the reader (or maybe another post).

### One linear vs stacked linears

In practice, most implementations use one linear layer and reshape the output rather than storing a collection of linear models. At first that might not seem kosher, but it is. Here's a visual interpretation.

![stacked-linears](../_static/post/another-annotated-transformer/nn.graph.png)

To show this in action, we'll approach this both ways. We'll first instantiate a dense layer much like the full diagram above, then we'll split out the red and blue arrows, and compare the two outputs.

In [20]:
batch_size = 2
sequence_length = 5
embed_dim = 3
n_heads = 2
size_per_head = 2

In [21]:
X = jnp.arange(batch_size * sequence_length * embed_dim)
X = X.reshape((batch_size, sequence_length, embed_dim))
X.shape

(2, 5, 3)

In [22]:
params = nn.Dense(n_heads * size_per_head).init(key, X)
params

FrozenDict({
    params: {
        kernel: Array([[ 0.4087802 ,  0.43891278, -0.23872387, -0.8494273 ],
               [ 0.41122693, -0.5888459 , -0.55229884,  0.49776074],
               [ 0.3480036 , -0.7046275 , -0.30813402, -1.21659   ]],      dtype=float32),
        bias: Array([0., 0., 0., 0.], dtype=float32),
    },
})

In [23]:
param_list = []
for i in range(n_heads):
    p = params.unfreeze()
    p['params'] = {
        'kernel': p['params']['kernel'][:, i * size_per_head:(i + 1) * size_per_head],
        'bias': p['params']['bias'][i * size_per_head:(i + 1) * size_per_head]
    }
    param_list.append(core.freeze(p))

In [24]:
display(jnp.stack([nn.Dense(size_per_head).apply(p, X) for p in param_list]).swapaxes(0, 1)[0])
print('=' * 50)
display(nn.Dense(n_heads * size_per_head).apply(params, X)\
        .reshape((batch_size, sequence_length, n_heads, size_per_head))\
        .swapaxes(1, 2)[0])

Array([[[  1.1072341,  -1.998101 ],
        [  4.611266 ,  -4.561783 ],
        [  8.115298 ,  -7.1254644],
        [ 11.61933  ,  -9.689147 ],
        [ 15.123363 , -12.252829 ]],

       [[ -1.168567 ,  -1.9354193],
        [ -4.466037 ,  -6.640189 ],
        [ -7.7635074, -11.344959 ],
        [-11.060977 , -16.049728 ],
        [-14.358448 , -20.7545   ]]], dtype=float32)



Array([[[  1.1072341,  -1.998101 ],
        [  4.611266 ,  -4.561783 ],
        [  8.115298 ,  -7.1254644],
        [ 11.61933  ,  -9.689147 ],
        [ 15.123363 , -12.252829 ]],

       [[ -1.168567 ,  -1.9354193],
        [ -4.466037 ,  -6.640189 ],
        [ -7.7635074, -11.344959 ],
        [-11.060977 , -16.049728 ],
        [-14.358448 , -20.7545   ]]], dtype=float32)

We're ready to implement `Multi-Head Attention` layer. You'll notice there's some weird stuff going on with the mask. This is intentional and it'll make more sense once we implement the `DecoderLayer`[^mha_mask].

[^mha_mask]: You have OCD, huh? Yeah, me too... If check out the diagram, notice the `Masked Multi-Head Attention` layer. We'll be masking the attention weights for the self attention layer. That mask is of a different shape from the sequence masking, so these counterintuitive lines dealing with the mask are to facilitate that later. It seemed easier than having to edit the classes and whatnot just for that one issue.

In [25]:
class MultiHeadAttention(nn.Module):
    n_heads: int
    size_per_head: int
    attn_dropout: float
    fc_dropout: float
    attn_fn: Callable = dot_prod_attn

    @nn.compact
    def __call__(self, q, k, v, mask=None, training=False):
        "expected shape: Batch, [N|M], Dim"
        B, N, D = q.shape
        _, M, _ = k.shape

        def qkv_layer(x, name):
            x = nn.Dense(self.n_heads * self.size_per_head, name=name)(x)
            x = x.reshape((B, -1, self.n_heads, self.size_per_head)).swapaxes(1, 2)
            return x
        # BxNxD => BxHxNxP
        q = qkv_layer(q, 'query_linear')
        # BxMxD => BxHxMxP
        k = qkv_layer(k, 'key_linear')
        # BxMxD => BxHxMxP
        v = qkv_layer(v, 'value_linear')
        if mask is not None:
            # accounting for reshape in qkv_layer
            # B[xN]xN   => Bx1[xN]xN
            mask = jnp.expand_dims(mask, 1)
            if mask.ndim < q.ndim:
                # softmax is applied to dim -1
                # Bx1xN => Bx1x1xN
                mask = jnp.expand_dims(mask, -2)
        attn_do = nn.Dropout(self.attn_dropout, deterministic=not training, name='attn_dropout')
        out, attn_weights = self.attn_fn(q, k, v, attn_do, mask=mask)
        # uncomment to keep attention weights in state
        # self.sow('intermediates', 'weights', attn_weights)
        out = out.swapaxes(1, 2).reshape((B, N, -1))
        out = nn.Dense(D, name='output_linear')(out)
        out = nn.Dropout(self.fc_dropout, deterministic=not training, name='fc_dropout')(out)
        return out

To better understand this model, we'll calculate the number of parameters and then see if we're right.

Each of the `query`, `key`, and `value` linear layers has `embed_dim * n_heads * size_per_head + size_per_head` many parameters (the kernel and the bias terms). The final linear layer brings us back to `embed_dim` size, so we have `n_heads * size_per_head * embed_dim + embed_dim`.

All together, our formula is:

In [26]:
3 * (n_heads * (size_per_head * embed_dim + size_per_head)) + (n_heads * size_per_head * embed_dim + embed_dim)

63

Let's check our work.

In [27]:
n_heads, size_per_head

(2, 2)

In [28]:
X = jran.uniform(key, (batch_size, sequence_length, embed_dim))
X.shape

(2, 5, 3)

In [29]:
mdl = MultiHeadAttention(n_heads, size_per_head, attn_dropout=0.2, fc_dropout=0.3)
params = mdl.init(key, X, X, X, mask=(jnp.max(X, axis=-1) < 0.8).astype(jnp.float32))

jax.tree_map(jnp.shape, params)

FrozenDict({
    params: {
        key_linear: {
            bias: (4,),
            kernel: (3, 4),
        },
        output_linear: {
            bias: (3,),
            kernel: (4, 3),
        },
        query_linear: {
            bias: (4,),
            kernel: (3, 4),
        },
        value_linear: {
            bias: (4,),
            kernel: (3, 4),
        },
    },
})

In [30]:
nn.tabulate(mdl, key, console_kwargs=dict(force_jupyter=True))(X, X, X);

Since we'll want to see this a few times, let's write a function for it.

In [31]:
def num_params(params):
    return jnp.sum(jnp.array(jax.tree_util.tree_flatten(jax.tree_map(lambda x: jnp.prod(jnp.array(jnp.shape(x))), params))[0])).item()

num_params(params)

63

# Add & Norm

We'll implement an `AddAndNorm` layer just so our code looks like the diagram. The layer is so simple that you're likely to see implementations that don't implement this and just do it in the `EncoderLayer` or `DecoderLayer`.

In [32]:
class AddAndNorm(nn.Module):
    """The add and norm."""

    @nn.compact
    def __call__(self, X, X_out):
        return nn.LayerNorm()(X + X_out)

# Feed forward

Same deal as `AddAndNorm`.

In [33]:
class FeedForward(nn.Module):
    """a 2-layer feed-forward network."""
    hidden_dim: int

    @nn.compact
    def __call__(self, X):
        D = X.shape[-1]
        X = nn.Dense(self.hidden_dim)(X)
        X = nn.relu(X)
        X = nn.Dense(D)(X)
        return X

# Encoder

## EncoderLayer

![image.png](../_static/post/another-annotated-transformer/encoder-layer-circled.png)

The `Encoder` is a combination of the various layers we've already built up along with several `EncoderLayer`s (which are themselves just combinations of previously defined layers). This section is going to be short.

Note the `EncoderLayer` takes one argument (neglecting the mask) and feeds that one argument as the `query`, `key`, and `value` in the `Multi-Head Attention` layer. This can be seen by following the arrows in the diagram.

In [34]:
class EncoderLayer(nn.Module):
    hidden_dim: int
    n_heads: int
    size_per_head: int
    attn_dropout: float
    fc_dropout: float

    def setup(self):
        self.attn = MultiHeadAttention(n_heads=self.n_heads,
                                       size_per_head=self.size_per_head,
                                       attn_dropout=self.attn_dropout,
                                       fc_dropout=self.fc_dropout)
        self.aan_0 = AddAndNorm()
        self.ff = FeedForward(hidden_dim=self.hidden_dim)
        self.aan_1 = AddAndNorm()

    def __call__(self, X, mask=None, training=False):
        X1 = self.attn(X, X, X, mask=mask, training=training)
        X = self.aan_0(X, X1)
        X1 = self.ff(X)
        X = self.aan_1(X, X1)
        return X

![image.png](../_static/post/another-annotated-transformer/encoder-circled.png)

In [35]:
class Encoder(nn.Module):
    pos_encoding: Callable[[int, int], jnp.array]
    vocab_size: int
    embed_dim: int
    layers: Sequence[EncoderLayer]

    @nn.compact
    def __call__(self, X, mask=None, training=False):
        B, N = X.shape
        if mask is not None:
            chex.assert_shape(mask, (B, N))
        X = nn.Embed(self.vocab_size, self.embed_dim, name='embed')(X)
        X = X * jnp.sqrt(self.embed_dim)
        # X.shape[-2] is the sequence length
        X = X + self.pos_encoding(X.shape[-2], self.embed_dim)
        for layer in self.layers:
            X = layer(X, mask=mask, training=training)
        return X

Just for fun, let's check out 

In [36]:
def layer_fn():
    return EncoderLayer(hidden_dim=13,
                        attn_dropout=0.1,
                        fc_dropout=0.1,
                        n_heads=7,
                        size_per_head=17)
mdl = Encoder(pos_encoding=sin_pos_enc, vocab_size=len(vocab),
              embed_dim=2 * 3 * 5,
              layers=[layer_fn() for _ in range(3)])
batch = [num2tokens(x) for x in jran.randint(key, (3,), 0, 1e5)]
batch = jnp.stack([jnp.pad(x, (0, 6 - len(x)), 'constant', constant_values=vocab['p']) for x in batch])
mask = (batch == vocab['p'])
params = mdl.init(key, batch)
resp = mdl.apply(params, batch, mask=mask, training=True, rngs={'dropout': key})
resp.shape

(3, 6, 30)

In [37]:
num_params(params['params'])

47190

## Interpretation

There are a subset of transformers called `encoder-only` transformers. Now you know what that means. They take in a sequence of tokens, train the non-contextual embeddings (the `Embedding` layer) and output contextual embeddings. The output embeddings are a function of the word itself, but also the context that the word appeared in.

The quintessential example of an encoder-only transformer is [Bert](https://arxiv.org/pdf/1810.04805.pdf) from Google research.

# Decoder

## DecoderLayer

![image.png](../_static/post/another-annotated-transformer/decoder-layer-circled.png)

There's one last piece to implement: The `Masked Multi-Head Attention`. We've implemented regular `Multi-Head Attention`, but not the masked part. The idea with the masked attention is to feed the whole sequence in at once, but still train the model as if we hadn't. To that end, we restrict the one place in the `Decoder` where information is shared across elements of the output sequence (the self-attention layer) so that a given position can only use information from previous positions. This is commonly described as restricting the positions that a given output can attend to.

In [38]:
def causal_mask(shape):
    return jnp.triu(jnp.ones(shape, dtype=jnp.bool_), k=1)

causal_mask((1, 5, 5))

Array([[[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]]], dtype=bool)

In [39]:
class DecoderLayer(nn.Module):
    hidden_dim: int
    n_heads: int
    size_per_head: int
    attn_dropout: float
    fc_dropout: float

    @nn.compact
    def __call__(self, X_enc, X_dec, enc_mask, dec_mask, training=False):

        def attn(q, kv, mask, training, name):
            mdl = MultiHeadAttention(n_heads=self.n_heads,
                                     size_per_head=self.size_per_head,
                                     attn_dropout=self.attn_dropout,
                                     fc_dropout=self.fc_dropout,
                                     name=f'{name}_attn')
            out = mdl(q, kv, kv, mask=mask, training=training)
            aan = AddAndNorm(name=f'{name}_addnorm')
            return aan(q, out)
        X_dec = attn(X_dec, X_dec, dec_mask, training, 'self')
        X_dec = attn(X_dec, X_enc, enc_mask, training, 'src')
        X1 = FeedForward(hidden_dim=self.hidden_dim)(X_dec)
        X_dec = AddAndNorm()(X_dec, X1)
        return X_dec

![image.png](../_static/post/another-annotated-transformer/decoder-circled.png)

In [40]:
class Decoder(nn.Module):
    pos_encoding: Callable[[int, int], jnp.array]
    vocab_size: int
    embed_dim: int
    layers: Sequence[DecoderLayer]

    @nn.compact
    def __call__(self, X_enc, X_dec, enc_mask, training=False):
        B, N = X_dec.shape[:2]
        dec_mask = causal_mask((1, N, N))
        X_dec = nn.Embed(self.vocab_size, self.embed_dim, name='embed')(X_dec)
        X_dec = X_dec * jnp.sqrt(self.embed_dim)
        # X.shape[-2] is the sequence length
        X_dec = X_dec + self.pos_encoding(X_dec.shape[-2], self.embed_dim)
        for layer in self.layers:
            X_dec = layer(X_enc, X_dec, enc_mask, dec_mask, training=training)
        X_dec = nn.Dense(self.vocab_size, name='final')(X_dec)
        return X_dec

In [41]:
def layer_fn():
    return DecoderLayer(hidden_dim=13,
                        attn_dropout=0.1,
                        fc_dropout=0.1,
                        n_heads=7,
                        size_per_head=17)
mdl = Decoder(pos_encoding=sin_pos_enc,
              vocab_size=len(vocab),
              embed_dim=2 * 3 * 5,
              layers=[layer_fn() for _ in range(3)])
batch = [num2tokens(x) for x in jran.randint(key, (3,), 0, 1e5)]
batch = jnp.stack([jnp.pad(x, (0, 6 - len(x)), 'constant', constant_values=vocab['p']) for x in batch])
kv = [num2tokens(x) for x in jran.randint(key + 1, (3,), 0, 1e6)]
kv = jnp.stack([jnp.pad(x, (0, 6 - len(x)), 'constant', constant_values=vocab['p']) for x in kv])
enc_mask = (kv == vocab['p'])
kv = nn.one_hot(kv, len(vocab))
params = mdl.init(key, kv, batch, enc_mask)
resp = mdl.apply(params, kv, batch, enc_mask, training=True, rngs={'dropout': key})
resp.shape

(3, 6, 12)

In [42]:
num_params(params['params'])

78891

# Transformers

Transformers come in three main flavors.

## Flavors

### Encoder-decoder

![image.png](../_static/post/another-annotated-transformer/tformer.png)

* These models are officially just this diagram.
* They're of a class of models called [seq2seq](https://en.wikipedia.org/wiki/Seq2seq) models.
* They take sequence inputs, generate some state features (via the encoder), and generate a sequence output (via the decoder).
* As such, they're typically used as translation models.

In [43]:
class EncoderDecoderTransformer(nn.Module):
    pos_encoding: Callable[[int, int], jnp.array]
    in_vocab_size: int
    out_vocab_size: int
    embed_dim: int
    n_layers: int
    hidden_dim: int
    attn_dropout: float
    fc_dropout: float
    n_heads: int
    size_per_head: int

    def setup(self):
        self.encoder = Encoder(
            pos_encoding=self.pos_encoding,
            vocab_size=self.in_vocab_size,
            embed_dim=self.embed_dim,
            layers=[EncoderLayer(hidden_dim=self.hidden_dim,
                                 attn_dropout=self.attn_dropout,
                                 fc_dropout=self.fc_dropout,
                                 n_heads=self.n_heads,
                                 size_per_head=self.size_per_head,
                                 name=f'encoder_{i}')
                    for i in range(self.n_layers)])
        self.decoder = Decoder(
            pos_encoding=self.pos_encoding,
            vocab_size=self.out_vocab_size,
            embed_dim=self.embed_dim,
            layers=[DecoderLayer(hidden_dim=self.hidden_dim,
                                 attn_dropout=self.attn_dropout,
                                 fc_dropout=self.fc_dropout,
                                 n_heads=self.n_heads,
                                 size_per_head=self.size_per_head,
                                 name=f'decoder_{i}')
                    for i in range(self.n_layers)])

    def __call__(self, X, Y, source_mask, training=False):
        # required for dot product attention
        chex.assert_equal(self.encoder.embed_dim, self.decoder.embed_dim)
        encodings = self.encoder(X, source_mask, training=training)
        self.sow('intermediates', 'encodings', encodings)
        return self.decoder(encodings, Y, source_mask, training=training)

### Encoder-only

![image.png](../_static/post/another-annotated-transformer/encoder-circled.png)

* These take in a sequence and output state features.
* It's mostly useful for tasks like text classification, sentiment analysis, stuff like that.
* One notable example is Google's [bert](https://en.wikipedia.org/wiki/BERT_(language_model)).

In [44]:
class EncoderOnlyTransformer(nn.Module):
    pos_encoding: Callable[[int, int], jnp.array]
    vocab_size: int
    embed_dim: int
    n_layers: int
    hidden_dim: int
    attn_dropout: float
    fc_dropout: float
    n_heads: int
    size_per_head: int

    def setup(self):
        self.encoder = Encoder(
            pos_encoding=self.pos_encoding,
            vocab_size=self.vocab_size,
            embed_dim=self.embed_dim,
            layers=[EncoderLayer(hidden_dim=self.hidden_dim,
                                 attn_dropout=self.attn_dropout,
                                 fc_dropout=self.fc_dropout,
                                 n_heads=self.n_heads,
                                 size_per_head=self.size_per_head,
                                 name=f'encoder_{i}')
                    for i in range(self.n_layers)])

    def __call__(self, X, mask, training=False):
        return self.encoder(X, mask, training=training)

### Decoder-only

![image.png](../_static/post/another-annotated-transformer/decoder-circled.png)

* These are called [generative model](https://en.wikipedia.org/wiki/Generative_model)s.
* They take a static state and generate a sequence iteratively.
* Mostly useful for text (media) generation[^transformer_uses].
* One notable example: [GPT](https://en.wikipedia.org/wiki/Generative_pre-trained_transformer).

[^transformer_uses]: This fact is very quickly becoming outdated.

In [45]:
class DecoderOnlyTransformer(nn.Module):
    pos_encoding: Callable[[int, int], jnp.array]
    vocab_size: int
    embed_dim: int
    n_layers: int
    hidden_dim: int
    attn_dropout: float
    fc_dropout: float
    n_heads: int
    size_per_head: int

    def setup(self):
        self.embed = nn.Embed(self.vocab_size, self.embed_dim)
        self.decoder = Decoder(
            pos_encoding=self.pos_encoding,
            vocab_size=self.out_vocab_size,
            embed_dim=self.embed_dim,
            layers=[DecoderLayer(hidden_dim=self.hidden_dim,
                                 attn_dropout=self.attn_dropout,
                                 fc_dropout=self.fc_dropout,
                                 n_heads=self.n_heads,
                                 size_per_head=self.size_per_head,
                                 name=f'decoder_{i}')
                    for i in range(self.n_layers)])

    def __call__(self, static, X, source_mask, training=False):
        encodings = self.embed(static)
        return self.decoder(encodings, X, source_mask, training=training)

# Example
As an example, we'll train a model to perform rot13. This isn't intended to be an example how how these models can be useful, but rather just an example of how this model in training. There are a few reasons why this task is not appropriate, but the biggest one is probably that from the perspective of the model, this is not much different from a copy task[^identity_fn].

[^identity_fn]: The copy task is the identity function. You train a model to copy the input ($f(X) = X$).

## Setup
Jax requires a bit of setup, so we'll do that.

In [46]:
import optax
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [47]:
vocab = {chr(97 + i): i for i in range(26)}
vocab['<start>'] = len(vocab)
vocab['<pad>'] = len(vocab)
vocab

{'a': 0,
 'b': 1,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'j': 9,
 'k': 10,
 'l': 11,
 'm': 12,
 'n': 13,
 'o': 14,
 'p': 15,
 'q': 16,
 'r': 17,
 's': 18,
 't': 19,
 'u': 20,
 'v': 21,
 'w': 22,
 'x': 23,
 'y': 24,
 'z': 25,
 '<start>': 26,
 '<pad>': 27}

Throughout this example, we'll be using the fact that the token IDs correspond to the sorted alphabet with two tokens added at the end[^alphabet_index]. So we do things like generate random strings and compute the target via modulo arithmetic.

[^alphabet_index]: Meaning `vocab['a'] == 0`, and `vocab['b'] == 1`, etc.

In [48]:
def get_data(key):
    k0, k1 = jran.split(key, 2)
    max_len = 15
    X = jran.randint(k0, (10, max_len), 0, len(vocab) - 2)
    mask = jnp.stack([jnp.arange(max_len) >= i for i in jran.randint(k1, (10,), 1, max_len)])
    X = X * (1 - mask) + (mask * vocab['<pad>'])
    Y = ((X + 13) % (len(vocab) - 2)) * (1 - mask) + mask * vocab['<pad>']
    Ys = (
        jnp.ones_like(Y, dtype=jnp.int32)
        .at[:, 1:].set(Y[:, :-1])
        .at[:, 0].set(vocab['<start>'])
    )
    return (X, Ys, mask.astype(jnp.float32)), Y

In [49]:
mdl = EncoderDecoderTransformer(pos_encoding=sin_pos_enc,
                                in_vocab_size=len(vocab),
                                out_vocab_size=len(vocab),
                                embed_dim=8,
                                n_layers=1,
                                hidden_dim=5,
                                attn_dropout=0.0,
                                fc_dropout=0.0,
                                n_heads=7,
                                size_per_head=5)

opt = optax.chain(
    optax.clip_by_global_norm(1),
    optax.sgd(
        learning_rate=optax.warmup_exponential_decay_schedule(
            init_value=0.5, peak_value=0.8, warmup_steps=100,
            transition_steps=200, decay_rate=0.5,
            transition_begin=100, staircase=False, end_value=1e-3
        )
    )
)

params = mdl.init(key, *get_data(key)[0])
print('num_params: ', num_params(params))
opt_state = opt.init(params)

num_params:  4665


Notice that we have almost 5 thousand parameters and this is just about the minimal example I could come up with. This architecture gets big real quick. Our `train_step` function is our main training code.

In [50]:
@jax.jit
def train_step(params, opt_state, step, key):
    """Train for a single step."""
    k0, k1 = jran.split(jran.fold_in(key, step))
    args, y = get_data(k0)

    @jax.grad
    def grad_fn(params):
        logits = mdl.apply(params, *args,
                           training=True, rngs={'dropout': k1})
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, y
        ).mean()
        return loss
    grads = grad_fn(params)
    updates, opt_state = opt.update(
        grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

In [51]:
for step in tqdm(range(10_000)):
    params, opt_params = train_step(params, opt_state, step, key)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:26<00:00, 371.81it/s]


In [52]:
args, y = get_data(key + 500)
yh = mdl.apply(params, *args)

In [53]:
(y[:, :-1] == jnp.argmax(yh[:, :-1], axis=-1)).all()

Array(True, dtype=bool)

Now let's check it out on our own words. To do that we'll write our own `rot13` and `rot13_inv` functions[^bad_idea].

[^bad_idea]: Yet another reason why this is a bad idea (we have easier to reason about alternatives).

In [54]:
def rot13(input_string):
    return ''.join([chr(((vocab[x] + 13) % 26) + 97) for x in input_string])
def rot13_inv(input_string):
    return ''.join([chr(((vocab[x] - 13) % 26) + 97) for x in input_string])

a = 'asdfqwerz'
b = rot13(a)
c = rot13_inv(b)
print(a, '=>', b, '=>', c)

asdfqwerz => nfqsdjrem => asdfqwerz


In [55]:
def str2ids(txt):
    return [vocab[x] for x in txt]


def strs2ids(*txts):
    ids = [str2ids(x) for x in txts]
    maxlen = max([len(x) for x in ids])
    return jnp.stack([jnp.pad(jnp.array(x), (0, maxlen - len(x)), 'constant', constant_values=vocab['<pad>'])
                      for x in ids])

strs2ids('asdf', 'qwer', 'zxcvbu')

Array([[ 0, 18,  3,  5, 27, 27],
       [16, 22,  4, 17, 27, 27],
       [25, 23,  2, 21,  1, 20]], dtype=int32)

In [56]:
def ids2str(ids):
    x = [list(vocab)[x] for x in ids]
    x = [y if y != '<pad>' else '~' for y in x]
    return ''.join(x).rstrip('~')

def ids2strs(ids):
    return [ids2str(x) for x in ids]

Now let's run the test.

In [57]:
X = jnp.array(strs2ids('hey', 'there', 'ma', 'dood'))
start = jnp.array([[vocab['<start>']]] * X.shape[0], dtype=jnp.int32)
Y = start
while (Y[:, -1] != vocab['<pad>']).any():
    Y = jnp.argmax(mdl.apply(params, X, jnp.concatenate([start, Y], axis=-1), X == vocab['<pad>']), axis=-1)

In [58]:
ids2strs(list(Y))

['url', 'gurer', 'zn', 'qbbq']

In [59]:
[rot13_inv(x) for x in ids2strs(list(Y))]

['hey', 'there', 'ma', 'dood']

Yay! It works! We can see that the encoder transforms the input sequence into input embeddings, whereas the decoder transforms a state vector (attending to its own history) to the target.

Stay tuned for more on transformers. Future topics include:

1. Ablation study to see what each part contributes
2. Model explainability
3. How to implement many tasks as either encoders, decoders, encoder-decoders