# Why use positional encoding

If you're anything like me, while trying to implement transformers you've read the original [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper, [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html), [its updated version](http://nlp.seas.harvard.edu/annotated-transformer/), and [d2l.ai](https://d2l.ai/chapter_attention-mechanisms-and-transformers/transformer.html), and then had to cobble them all together to get something going. This post is an attempt to make that process easier for people like me, in a short, to-the-point style. You can think of it as a bare-bones implementation with a whole lot of documentation.

Notes:
1. This post is about *how* transformers are implemented, not *why* they're implemented the way they are.
2. This post pretty much takes you from just the basics of ML through understanding how transformers work. Feel free to skip whatever sections you're familiar with.

# Overview

The following is the transformer architecture diagram taken from the original paper. We'll be referring back to it often in this post.

![image.png](./assets/tformer.png)
Transformer architecture diagram

At a high level, the transformer is an encoder-decoder model; it takes a **sequence** of **tokens** from a source (e.g. English words) and learns to translate that into a destination sequence (e.g. French words). The encoder is the left side and the decoder is the right side of the architecture diagram.

For the rest of the post, when we refer to "tokens" or "sequences", feel free to replace them with "words" and "sentences" (or "paragraphs", "pages", "books", etc.) respectively in your head.

# Preamble

We'll be using [jax](https://jax.readthedocs.io/en/latest/index.html) and [flax](https://flax.readthedocs.io/en/latest/index.html) for this implementation. The exploratory cells in this section need only NumPy, pandas, and Altair.

In [1]:
import altair as alt
import numpy as np
import pandas as pd

In [2]:
vocab = []
embeddings = []

with open('/home/amniskin/code/daggerml/transformers/data/glove.6B.50d.txt', 'r') as f:
    for l in f:
        word, *vals = l.split(' ')
        vocab.append(word)
        embeddings.append(np.array([float(x) for x in vals]))
embeddings = np.stack(embeddings)
tok2id = {k: i for i, k in enumerate(vocab)}
embeddings.shape

(400000, 50)

In [3]:
sentence = 'the dog is is playing with the cat'

In [4]:
sentence_tokens = sentence.split()
token_ids = np.stack([tok2id[x] for x in sentence_tokens])
token_ids.shape

(8,)

In [5]:
embed = np.stack([embeddings[i] for i in token_ids])
embed.shape

(8, 50)

In [6]:
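def softmax(x):
    # Subtract the row-wise max first so np.exp cannot overflow; the result is unchanged.
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)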

In [7]:
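def dot_prod_attn(q, k, v, dropout=lambda x: x, mask=None):
    # q: (..., N, D), k: (..., M, D), v: (..., M, Dv); leading dims are batch (and heads)
    # scaled scores: (..., N, D) @ (..., D, M) => (..., N, M)
    attn_scores = q @ k.swapaxes(-2, -1) / np.sqrt(q.shape[-1])
    if mask is not None:
        # assumed convention: boolean mask, True where a query may attend to a key
        attn_scores = np.where(mask, attn_scores, -1e9)
    attn_weights = softmax(attn_scores)
    # weighted sum of values: (..., N, M) @ (..., M, Dv) => (..., N, Dv)
    out = dropout(attn_weights) @ v
    return out, attn_weights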
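
The shape comments above suggest the function is meant to work on batched, multi-head inputs too. Here is a minimal standalone sketch of that, with a causal mask thrown in. The names `masked_attn` and `causal`, and the convention that `True` means "may attend", are assumptions made for this example and not something the original cell defines.

```python
import numpy as np


def softmax(x):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def masked_attn(q, k, v, mask=None):
    # Standalone copy of scaled dot-product attention with the mask applied.
    scores = q @ k.swapaxes(-2, -1) / np.sqrt(q.shape[-1])
    if mask is not None:
        # Disallowed positions get a very negative score, so softmax gives them zero weight.
        scores = np.where(mask, scores, -1e9)
    weights = softmax(scores)
    return weights @ v, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8, 16))              # (batch, heads, tokens, dim)
causal = np.tril(np.ones((8, 8), dtype=bool))   # query i may attend to keys 0..i

out, weights = masked_attn(x, x, x, mask=causal)
print(out.shape, weights.shape)
```

The `(8, 8)` mask broadcasts across the batch and head dimensions, so one mask serves every example in the batch.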

In [8]:
def self_attn(x, **kw):
    return dot_prod_attn(x, x, x, **kw)

In [9]:
def heatmap(matrix, query_ids, key_ids, title='attention weights'):
    # matrix[i, j] is the weight query i puts on key j: rows are queries, columns are keys
    idx = [str(a) + ':' + vocab[j] for a, j in enumerate(query_ids)]
    idy = [str(a) + ':' + vocab[j] for a, j in enumerate(key_ids)]
    df = (
        pd.DataFrame(
            matrix,
            index=pd.Index(idx, name='query_word'),
            columns=pd.Index(idy, name='key_word')
        )
        .stack()
        .to_frame('value')
        .reset_index()
    )
    # split only on the first ':' so tokens that contain ':' survive intact
    df['key'] = df['key_word'].apply(lambda x: x.split(':', 1)[1])
    df['query'] = df['query_word'].apply(lambda x: x.split(':', 1)[1])
    return (
        alt.Chart(df, title=title)
        .mark_rect()
        .encode(
            x=alt.X('key_word:N'),
            y=alt.Y('query_word:N'),
            color=alt.Color('value', title='softmax').scale(scheme='greenblue'),
            tooltip=[
                'key:N',
                'query:N',
                alt.Tooltip('value', title='val', format='.02'),
            ],
        )
    )

heatmap(self_attn(embed)[-1], token_ids, token_ids)
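
One way to see why positional encoding is needed at all: plain self-attention has no notion of word order. The sketch below checks this numerically. It assumes random vectors as a stand-in for the GloVe embeddings, and `self_attn_weights` is a helper written just for this example.

```python
import numpy as np


def softmax(x):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def self_attn_weights(x):
    # Scaled dot-product self-attention weights, shape (N, N).
    return softmax(x @ x.swapaxes(-2, -1) / np.sqrt(x.shape[-1]))


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 50))   # stand-in for the (8, 50) sentence embeddings
perm = rng.permutation(8)      # a random reordering of the 8 tokens

w = self_attn_weights(x)
w_perm = self_attn_weights(x[perm])

# Shuffling the tokens only shuffles the rows and columns of the weight matrix.
print(np.allclose(w_perm, w[np.ix_(perm, perm)]))
```

In other words, each token attends to the same tokens with the same weights no matter where they sit in the sentence, so "the dog chased the cat" and "the cat chased the dog" look alike to this layer.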

In [10]:
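def sin_pos_enc(sequence_length, embed_dim):
    # angles[pos, i] = pos / 10000^(2i / embed_dim)
    angles = np.expand_dims(np.arange(sequence_length), 1) / \
        np.power(10000, np.arange(0, embed_dim, 2) / embed_dim)
    out = np.empty((sequence_length, embed_dim))
    # even columns get sines, odd columns get cosines of the same angles
    out[:, 0::2] = np.sin(angles)
    # the slice keeps the shapes aligned when embed_dim is odd
    out[:, 1::2] = np.cos(angles[:, :embed_dim // 2])
    return out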

pos_enc = sin_pos_enc(*embed.shape)
pos_enc.shape

(8, 50)
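
A handy property of this sinusoidal scheme: the dot product between the encodings of two positions depends only on how far apart they are, not on where they sit. Here is a small standalone check, assuming an even `embed_dim` (the function is repeated so the snippet runs on its own).

```python
import numpy as np


def sin_pos_enc(sequence_length, embed_dim):
    # Same sinusoidal scheme as the cell above; this copy assumes an even embed_dim.
    angles = np.arange(sequence_length)[:, None] / np.power(
        10000, np.arange(0, embed_dim, 2) / embed_dim
    )
    out = np.empty((sequence_length, embed_dim))
    out[:, 0::2] = np.sin(angles)
    out[:, 1::2] = np.cos(angles)
    return out


pe = sin_pos_enc(8, 50)
gram = pe @ pe.T  # gram[p, q] is the dot product of the encodings at positions p and q

# Every diagonal of the Gram matrix is constant, i.e. gram[p, q] depends on p - q only.
for offset in range(8):
    diag = np.diagonal(gram, offset=offset)
    assert np.allclose(diag, diag[0])

print(gram[0, 0], gram[0, 1], gram[3, 4])
```

This follows from `sin(a)sin(b) + cos(a)cos(b) = cos(a - b)` applied to each sine/cosine pair. It also means each position's dot product with itself is `embed_dim / 2`, since every pair contributes `sin² + cos² = 1`.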

In [11]:
heatmap(self_attn(pos_enc)[-1], token_ids, token_ids, title='attention weights from positional encoding alone')

In [12]:
_, w1 = self_attn(embed)
_, w2 = self_attn(embed + pos_enc)

In [13]:
alt.hconcat(
    heatmap(w1, token_ids, token_ids, title='without positional encoding'),
    heatmap(w2, token_ids, token_ids, title='with positional encoding'),
)

In [14]:
heatmap(w2 - w1, token_ids, token_ids, title='difference in attention weights')

In [15]:
heatmap(w2 - w1 > 0, token_ids, token_ids, title='Difference is positive?')
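
The example sentence repeats "the" and "is", which gives a direct numeric check of what the heatmaps show. The sketch below assumes a small random embedding table in place of GloVe; `table`, `ids`, and `self_attn_weights` are hypothetical names used only here.

```python
import numpy as np


def softmax(x):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)


def self_attn_weights(x):
    # Scaled dot-product self-attention weights, shape (N, N).
    return softmax(x @ x.swapaxes(-2, -1) / np.sqrt(x.shape[-1]))


def sin_pos_enc(sequence_length, embed_dim):
    # Sinusoidal positional encoding; this copy assumes an even embed_dim.
    angles = np.arange(sequence_length)[:, None] / np.power(
        10000, np.arange(0, embed_dim, 2) / embed_dim
    )
    out = np.empty((sequence_length, embed_dim))
    out[:, 0::2] = np.sin(angles)
    out[:, 1::2] = np.cos(angles)
    return out


rng = np.random.default_rng(0)
table = rng.normal(size=(6, 50))          # random stand-in embedding table, 6 distinct words
# the=0 dog=1 is=2 playing=3 with=4 cat=5, in the order of the example sentence
ids = np.array([0, 1, 2, 2, 3, 4, 0, 5])
embed = table[ids]

w1 = self_attn_weights(embed)
w2 = self_attn_weights(embed + sin_pos_enc(*embed.shape))

# Without positions, the two "is" tokens (rows 2 and 3) have identical attention rows.
print(np.allclose(w1[2], w1[3]))
# With positions added, the same two rows are compared again.
print(np.allclose(w2[2], w2[3]))
```

Without positional encoding the two "is" tokens are the same vector, so the model cannot tell them apart. Adding the encoding gives each occurrence its own query and key.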