<a href="https://colab.research.google.com/github/MichalRyszardWojcik/transformer-language-model/blob/main/TransformerLM_mathematical_definition_with_code_and_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The transformer language model technology is used in the following projects:
<ul>
<li><a href="https://en.wikipedia.org/wiki/GPT-3">GPT-3</a>: an AI system that can create articles, poetry, stories, news reports and dialogue using just a small amount of input text
<li><a href='https://openai.com/blog/dall-e/'>DALL·E</a>: a neural network that creates images from text captions
<li><a href='https://openai.com/blog/openai-codex/'>OpenAI Codex</a>: an AI system that translates natural language to code
</ul>

It was introduced in the seminal article <a href='https://arxiv.org/abs/1706.03762'>Attention Is All You Need</a> coauthored by Lukasz Kaiser.

<p>We aim to explain the computer science theory behind this technology
and to present it from basic principles without any use of AI jargon
so that any mathematician can understand how it works and any programmer can implement it on their own.
I studied the original article and the Trax implementation and produced this mathematical documentation.
</p>

<div>
This article was written in 2021 by Michal Ryszard Wojcik (PhD in mathematics) under the guidance of Lukasz Kaiser (PhD in computer science) based on two sources:
<ul><li>the article
<a href='https://arxiv.org/abs/1706.03762'>Attention Is All You Need</a> coauthored by Lukasz Kaiser
<li style='margin-top:0.5em;'>the source code
<a href='https://github.com/google/trax/blob/master/trax/models/transformer.py#L194'>TransformerLM in Trax</a>
supervised by Lukasz Kaiser
</ul>
</div>

<p>Later I wrote a Python implementation of this mathematical specification, with some fully working demo examples:
<a href='https://github.com/MichalRyszardWojcik/transformer-language-model/blob/main/TransformerLM_mathematical_definition_with_code_and_examples.ipynb'>github.com/MichalRyszardWojcik/transformer-language-model</a>.
It can be executed in Google Colab directly in the browser without any installation.</p>

<p>The official web page for this project is <a href='https://www.apronus.com/math/transformer-language-model-definition'>https://www.apronus.com/math/transformer-language-model-definition</a></p>

In [1]:
import numpy as np
!pip install -q -U trax
from trax import fastmath
from trax.fastmath import numpy as jnp


https://github.com/google/trax/blob/master/trax/models/transformer.py#L182

# Rowwise and autoregressive matrix operations

Let $T(X)=Y$ represent an operation which takes a matrix $X$ and returns a matrix $Y$ with the same number of rows as $X$, the number of columns being fixed for both input and output, say $m$ for input and $m'$ for output.

We say that $T$ is **rowwise** iff for each matrix $X$ with $n$ rows and $m$ columns we have $T(X_i)=T(X)_i$ whenever $i\leq n$, where $X_i$ is the i-th row of $X$ and $T(X)_i$ is the i-th row of $T(X)$.

We say that $T$ is **autoregressive** iff for each matrix $X$ with $n$ rows and $m$ columns we have $T(X|i)=T(X)|i$ whenever $i\leq n$, where the symbol $A|i$ denotes the matrix $A$ reduced to the first $i$ rows (e.g. $A|n=A$ and $A|1$ is the first row).

Note that rowwise operations are autoregressive and compositions of autoregressive operations are autoregressive.

Let $\unicode{0x25FA}A$ denote the matrix $A$ modified so that it has $-\infty$ above the main diagonal and is unchanged otherwise. Note that for a square matrix $A$ the rowwise operation softmax($\unicode{0x25FA}A$) always returns a lower triangular square matrix.

**Claim.** If $A,B$ are autoregressive operations and $A$ always returns lower triangular square matrices then the matrix multiplication operation $X\mapsto A(X)\times B(X)$ is autoregressive.

Below we are going to define the Transformer Language Model and argue that it is autoregressive by appropriately applying the arguments collected above.
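
The definitions and the Claim above can be checked numerically on a concrete input. The following is a minimal NumPy sketch, not part of the original notebook: the helper names `masked_softmax` and `is_autoregressive_on` and the random matrices `W1, W2, W3` are assumptions made for illustration. A passing check on one input is evidence, not a proof.

```python
import numpy as np

def masked_softmax(A):
    """Rowwise softmax of a square matrix A with -inf placed above the main diagonal."""
    keep = np.tril(np.ones_like(A, dtype=bool))
    A = np.where(keep, A, -np.inf)
    A = A - A.max(axis=1, keepdims=True)   # each row keeps its finite diagonal entry
    E = np.exp(A)
    return E / E.sum(axis=1, keepdims=True)

def is_autoregressive_on(T, X):
    """Check T(X|i) == T(X)|i for every i = 1..n on one concrete input X."""
    Y = T(X)
    return all(np.allclose(T(X[:i]), Y[:i]) for i in range(1, X.shape[0] + 1))

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
W1, W2, W3 = (rng.normal(size=(4, 4)) for _ in range(3))

rowwise = lambda Z: np.tanh(Z @ W1)                                   # acts on each row separately
masked = lambda Z: masked_softmax((Z @ W1) @ (Z @ W2).T) @ (Z @ W3)   # the situation of the Claim
unmasked = lambda Z: (Z @ W1) @ (Z @ W2).T @ (Z @ W3)                 # no mask: later rows leak in

print(is_autoregressive_on(rowwise, X))
print(is_autoregressive_on(masked, X))
print(is_autoregressive_on(unmasked, X))
```

The third operation is included as a counterexample: without the lower triangular factor, each output row is a sum over all input rows.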

In [2]:
def softmax_L(A):
  X = jnp.tril(jnp.exp(A))
  return X / jnp.sum(X,axis=1,keepdims=True)

# Embedding

The Embedding layer is the first layer acting directly on the input composed of a sequence of tokens.

Each token is an integer from $\{0,1,\ldots,\mathrm{vocab\_size}-1\}$, where $\mathrm{vocab\_size}$ is one of the transformer language model's parameters.

This layer translates tokens into "vectorwords" which are vectors in $\mathbb R^{d_{model}}$, where $d_{model}$ is a parameter.

There is an embedded array of weights of shape $\mathrm{vocab\_size}\times d_{model}$. The layer simply assigns a row to each token.

The output is a matrix of dimension $n\times d_{model}$, where $n$ is the number of tokens in the input.
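
As a small illustration of "assigning a row to each token", the sketch below (plain NumPy, with made-up sizes and a made-up weight array) shows that the lookup is the same as multiplying a one-hot matrix by the weight array.

```python
import numpy as np

vocab_size, d_model = 5, 3
# Illustrative weights: row t is (3t, 3t+1, 3t+2)
weights = np.arange(vocab_size * d_model, dtype=float).reshape(vocab_size, d_model)
tokens = np.array([2, 0, 2, 4])             # n = 4 input tokens

by_lookup = weights[tokens]                 # row tokens[i] of the weight array for each i
one_hot = np.eye(vocab_size)[tokens]        # shape (n, vocab_size)
by_matmul = one_hot @ weights               # the same rows written as a matrix product

print(by_lookup.shape)                      # (4, 3)
print(np.array_equal(by_lookup, by_matmul)) # True
```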



In [3]:
def embedding(weights, input):
  n = input.shape[0]
  rows = []
  for i in range(n):
    row = weights[input[i]]
    rows.append( jnp.array([row]) )
  return jnp.concatenate(tuple(rows), axis=0)

# Positional Encoding

Let us define the Positional Encoding layer in a transformer language model with maximum input length $M$ and vectorword length $d_{model}$.

Let $X$ be an input matrix with rows containing vectorwords of length $d_{model}$ with the number of rows between $1$ and $M$.

Then $PE(X)=X+PE|n$ where $n$ is the number of rows in $X$ and $PE|n$ is the layer's embedded $PE$ weight matrix reduced to the first $n$ rows.

Note that this layer simply adds its weight matrix to its input.

Note that it is a rowwise operation.

In [4]:
def positionalEncoding(weights, X):
  n = X.shape[0]
  return X + weights[:n,:]

The weight matrix can be anything that results from the training process. But it is initialized in the following way.

$M$ is the maximum number of tokens in the input

$$PE\colon\{1,\ldots,M\}\times\{0,\ldots,d_{model}-1\}\to[-1,1]$$

$$PE(pos,2i)=\sin\Big(\cfrac{pos}{10000^{2i/d_{model}}}\Big)$$

$$PE(pos,2i+1)=\cos\Big(\cfrac{pos}{10000^{2i/d_{model}}}\Big)$$

$PE$ is defined here as a function of two variables where the first variable is interpreted as the position in the input sequence (e.g. first token, second token, third token) and the second variable is interpreted as one of the $d_{model}$ axes of the vectorwords created from the input tokens.

But we will also treat $PE$ as a $M\times d_{model}$ matrix of weights embedded in the Positional Encoding layer.
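
The initial $PE$ matrix can also be computed numerically. The sketch below is an illustration in plain NumPy under the conventions of this text: positions are numbered from 1 and the column index in the formulas is counted from 0 (an implementation may well count positions from 0 instead). The function name `sinusoidal_pe` is an assumption, not a name from Trax.

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """M x d_model matrix; row r holds position pos = r + 1, columns are counted from 0."""
    pos = np.arange(1, max_len + 1)[:, None]          # positions 1..M, as in the text
    col = np.arange(d_model)[None, :]                 # column index 0..d_model-1
    two_i = col - (col % 2)                           # the exponent 2i for columns 2i and 2i+1
    angle = pos / np.power(10000.0, two_i / d_model)
    return np.where(col % 2 == 0, np.sin(angle), np.cos(angle))

PE_matrix = sinusoidal_pe(max_len=3, d_model=6)
print(PE_matrix.shape)   # (3, 6)
print(np.round(PE_matrix, 4))
```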



In [5]:
def PE(pos,k):
  # Returns the LaTeX source of PE(pos,k); raw strings keep the backslashes literal.
  if k % 2 == 0:
    trig = r'\mathrm{sin}'
    i2 = str(k)
  else:
    trig = r'\mathrm{cos}'
    i2 = str(k-1)
  re = trig + r"\Big(\cfrac{" + str(pos) + "}{10000^{" + i2 + r"/d_{model}}}\Big)"
  return r'$\small{' + re + '}$ '

def PEpos(pos,d_model):
  ret = ''
  for k in range(d_model):
    ret += (PE(pos,k))
  return ret + '\n\n'

d_model = 6
max_len = 3
ret = ''
for p in range(max_len):
  ret += PEpos(p+1,d_model)
print(ret)

$\small{\mathrm{sin}\Big(\cfrac{1}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{1}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{1}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{1}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{1}{10000^{4/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{1}{10000^{4/d_{model}}}\Big)}$ 

$\small{\mathrm{sin}\Big(\cfrac{2}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{2}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{2}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{2}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{2}{10000^{4/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{2}{10000^{4/d_{model}}}\Big)}$ 

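$\small{\mathrm{sin}\Big(\cfrac{3}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{3}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{3}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{3}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{3}{10000^{4/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{3}{10000^{4/d_{model}}}\Big)}$ 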

The following expression represents a vector of $d_{model}=6$ real numbers (a vectorword), namely the row of $PE$ for position $pos$.

$\Large[$$\small{\mathrm{sin}\Big(\cfrac{pos}{10000^{0/d_{model}}}\Big)}$,
$\small{\mathrm{cos}\Big(\cfrac{pos}{10000^{0/d_{model}}}\Big)}$,
$\small{\mathrm{sin}\Big(\cfrac{pos}{10000^{2/d_{model}}}\Big)}$,
$\small{\mathrm{cos}\Big(\cfrac{pos}{10000^{2/d_{model}}}\Big)}$,
$\small{\mathrm{sin}\Big(\cfrac{pos}{10000^{4/d_{model}}}\Big)}$,
$\small{\mathrm{cos}\Big(\cfrac{pos}{10000^{4/d_{model}}}\Big)}$$\Large]$

Let us assume maximum input length $M=3$ and $d_{model}=6$ for simplicity of presentation.

The following $M\times d_{model}$ matrix is populated with numbers from the $PE$ function.

$\small{\mathrm{sin}\Big(\cfrac{1}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{1}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{1}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{1}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{1}{10000^{4/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{1}{10000^{4/d_{model}}}\Big)}$ 

$\small{\mathrm{sin}\Big(\cfrac{2}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{2}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{2}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{2}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{2}{10000^{4/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{2}{10000^{4/d_{model}}}\Big)}$ 

$\small{\mathrm{sin}\Big(\cfrac{3}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{3}{10000^{0/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{3}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{3}{10000^{2/d_{model}}}\Big)}$ $\small{\mathrm{sin}\Big(\cfrac{3}{10000^{4/d_{model}}}\Big)}$ $\small{\mathrm{cos}\Big(\cfrac{3}{10000^{4/d_{model}}}\Big)}$ 


From the paper *Attention Is All You Need*:
We chose this function because we hypothesized it would allow the model to easily attend by relative position, since for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.

The following articles help to understand the linear representation claim:

https://kazemnejad.com/blog/transformer_architecture_positional_encoding/

https://timodenk.com/blog/linear-relationships-in-the-transformers-positional-encoding/

For any fixed offset $k$, there is a linear map $T_k\colon\mathbb R^{d_{model}}\to\mathbb R^{d_{model}}$ such that
$$T_k(PE_{pos})=PE_{pos+k}$$
for any $pos\in\{1,2,\ldots,M-k\}$, where $PE_{pos}\in\mathbb R^{d_{model}}$ denotes the row of the $PE$ matrix for position $pos$. The map $T_k$ depends on $k$ but not on $pos$.
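
One way to see why such a linear map $T_k$ exists is to write it down: by the angle addition formulas, each (sin, cos) pair of columns is rotated by an angle that depends only on $k$. The sketch below builds this block diagonal matrix in NumPy and checks it on a few positions. It assumes an even $d_{model}$, and the helper names `pe_row` and `shift_matrix` are made up for this illustration.

```python
import numpy as np

def frequencies(d_model):
    i = np.arange(d_model // 2)
    return 1.0 / np.power(10000.0, 2 * i / d_model)

def pe_row(pos, d_model):
    """The vector (sin, cos, sin, cos, ...) for one position, as in the PE formulas."""
    omega = frequencies(d_model)
    row = np.empty(d_model)
    row[0::2] = np.sin(omega * pos)
    row[1::2] = np.cos(omega * pos)
    return row

def shift_matrix(k, d_model):
    """Block diagonal matrix of 2x2 rotations; it depends on k only, not on pos."""
    T = np.zeros((d_model, d_model))
    for j, w in enumerate(frequencies(d_model)):
        c, s = np.cos(w * k), np.sin(w * k)
        T[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[c, s], [-s, c]]
    return T

d_model, k = 6, 2
T_k = shift_matrix(k, d_model)
ok = all(np.allclose(T_k @ pe_row(pos, d_model), pe_row(pos + k, d_model))
         for pos in range(1, 10))
print(ok)   # True
```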


# Normalization layers

Let us say that to standardize a vectorword $x\in\mathbb R^{d_{model}}$ means to subtract its arithmetic mean $\overline{x}$ from each coordinate and divide each coordinate by $\sigma(x)+\epsilon$, where $\sigma(x)$ is its standard deviation and $\epsilon=10^{-6}$ is a constant embedded in the code.

Let us say that to normalize a vectorword $x\in\mathbb R^{d_{model}}$ means to standardize it and then multiply it coordinatewise by a weight vector of the same length $a\in\mathbb R^{d_{model}}$ and then add a weight vector $b\in\mathbb R^{d_{model}}$.

Each normalization layer has two weight vectors $a,b\in\mathbb R^{d_{model}}$.

The input is as usual a matrix $X$ with $n$ rows, where $n$ is the number of words in the input sentence (the number of tokens in the input). Each row of $X$ is a vectorword from $\mathbb R^{d_{model}}$.

A normalization layer is a rowwise operation that normalizes each row using the same weights $a,b$.
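
A quick numerical sanity check of these definitions, written as a NumPy sketch with made-up data: after standardization each row has mean 0 and standard deviation very close to 1 (not exactly 1 because of $\epsilon$), and normalizing a single row gives the same result as normalizing the whole matrix and taking that row.

```python
import numpy as np

def standardize(X, eps=1e-6):
    centered = X - X.mean(axis=1, keepdims=True)
    return centered / (centered.std(axis=1, keepdims=True) + eps)

def normalize_rows(a, b, X, eps=1e-6):
    return standardize(X, eps) * a + b

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=5.0, size=(4, 8))    # n = 4 rows, d_model = 8
S = standardize(X)
print(np.allclose(S.mean(axis=1), 0.0))            # True
print(np.allclose(S.std(axis=1), 1.0, atol=1e-5))  # True

a, b = rng.normal(size=8), rng.normal(size=8)
full = normalize_rows(a, b, X)
rowwise = all(np.allclose(normalize_rows(a, b, X[i:i + 1]), full[i:i + 1]) for i in range(4))
print(rowwise)                                     # True
```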


In [6]:
def normalize(weights,input,eps = 0.000001):
  (a,b) = weights
  mean = jnp.mean(input, axis=1, keepdims=True)
  output = input - mean
  std = jnp.std(output, axis=1, keepdims=True)
  output = output / (std+eps)
  return (output * a) + b

# DecoderBlock

**input**

$X$ is a matrix with $n$ rows, where $n$ is the number of words in the input sentence (the number of tokens in the input).

Each row of $X$ is a vectorword from $\mathbb R^{d_{model}}$, that is, a vector of $d_{model}$ real numbers.

Let us call this the usual shape for a matrix because such are the tensors accepted and returned by decoder blocks and their sublayers. We simplify the exposition by assuming one sentence per batch and disregarding the batch size axis.


## Causal Attention with one head

This section may be skipped: the Multi-Head Causal Attention section is self-contained, and the definition of the Decoder Block refers only to that section. The single-head case is presented here as a gentle introduction to the multi-head version.

**weights**

$W^Q,W^K,W^V$ are square matrices of size ${d_{model}}\times{d_{model}}$ (e.g. 512)

They are the weights embedded in the Causal Attention layer CA.

**definition**

Let

$Q(X)=X\times W^Q$,

$K(X)=X\times W^K$,

$V(X)=X\times W^V$.

Then the Causal Attention layer is defined as

$$CA(X)=\mathrm{softmax}\Big(\cfrac{\unicode{0x25FA}\big(Q(X)\times K(X)^T\big)}{\sqrt{d_{model}}}\Big)\times V(X)$$

Note that $CA$ is an autoregressive operation.
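
The single-head formula can be sketched in a few lines of NumPy. This is an illustration with random weights, not the Trax implementation. It also makes one consequence of the mask visible: the matrix of attention weights is lower triangular with rows summing to 1, so the first output row is exactly the first row of $V(X)$.

```python
import numpy as np

def causal_attention(WQ, WK, WV, X):
    """Single-head causal attention; returns the output and the attention weights."""
    n, d_model = X.shape
    scores = (X @ WQ) @ (X @ WK).T
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    scores = scores / np.sqrt(d_model)
    scores = scores - scores.max(axis=1, keepdims=True)   # stabilise the exponentials
    P = np.exp(scores)
    P = P / P.sum(axis=1, keepdims=True)                  # rowwise softmax
    return P @ (X @ WV), P

rng = np.random.default_rng(0)
n, d_model = 4, 6
X = rng.normal(size=(n, d_model))
WQ, WK, WV = (rng.normal(size=(d_model, d_model)) for _ in range(3))

out, P = causal_attention(WQ, WK, WV, X)
print(out.shape)                              # (4, 6)
print(np.allclose(out[0], (X @ WV)[0]))       # True
```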

## Multi-Head Causal Attention

**parameters**

$d_{model}$ is the length of the vectorwords

$h$ is the number of attention heads

$d_{head}=\cfrac{d_{model}}{h}$ is the width of the weight matrices associated with individual heads

**weights**

$W_i^Q,W_i^K,W_i^V$ are matrices of dimension ${d_{model}}\times{d_{head}}$ for $i=1,\ldots,h$

$W^O$ is a matrix of dimension ${d_{model}}\times{d_{model}}$

$B$ is a matrix of dimension $n\times{d_{model}}$ whose rows are identical

Fixing $i$, let

$Q_i(X)=X\times W_i^Q$,

$K_i(X)=X\times W_i^K$,

$V_i(X)=X\times W_i^V$.

The dimension of these matrices is $n\times d_{head}$.

The calculation for an individual head is given by

$$H_i(X)=\mathrm{softmax}\Big(\unicode{0x25FA}\Big(\cfrac{Q_i(X)\times K_i(X)^T}{\sqrt{d_{head}}}\Big)\Big)\times V_i(X)$$

Note that each $H_i$ is an autoregressive operation.

Note that the output matrices $H_i(X)$ have dimension $n\times d_{head}$.

Let us concatenate them to obtain a matrix of dimension $n\times d_{model}$

$$H(X)=\mathrm{Concat}\big(H_1(X),\ldots,H_h(X)\big)$$

Note that $H$ is autoregressive.

Finally, the multi-head causal attention layer is defined by
$$CA(X)=H(X)\times W^O+B.$$
Note that $CA$ is autoregressive.
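
It may help to see that concatenating the heads and multiplying by $W^O$ is the same as giving each head its own block of $d_{head}$ rows of $W^O$ and adding up the results. The NumPy sketch below checks this identity with random stand-ins for the head outputs $H_i(X)$; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, d_head = 4, 2, 3
d_model = h * d_head
heads = [rng.normal(size=(n, d_head)) for _ in range(h)]   # stand-ins for H_1(X), ..., H_h(X)
WO = rng.normal(size=(d_model, d_model))

H = np.concatenate(heads, axis=1)                          # shape (n, d_model)
lhs = H @ WO
rhs = sum(heads[i] @ WO[i * d_head:(i + 1) * d_head, :] for i in range(h))
print(H.shape)                 # (4, 6)
print(np.allclose(lhs, rhs))   # True
```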

In [7]:
def causalAttention(h,weights,X):
  WQ, WK, WV, W0, B = weights
  d_model = W0.shape[0]
  d_head = d_model // h
  denominator = jnp.sqrt(d_head)
  def W_i(W,i):
    c1 = d_head * i
    c2 = d_head * (i+1)
    return W[:,c1:c2]
  def Q_i(i): return jnp.matmul(X,W_i(WQ,i))
  def K_i(i): return jnp.matmul(X,W_i(WK,i))
  def V_i(i): return jnp.matmul(X,W_i(WV,i))
  def H_i(i):  
    A = jnp.matmul(Q_i(i),jnp.transpose(K_i(i))) / denominator
    return jnp.matmul(softmax_L(A),V_i(i))
  H = jnp.concatenate(tuple( H_i(i) for i in range(h) ), axis=1)
  return jnp.matmul(H,W0) + B

## The Feed Forward layer

The feed forward layer is defined by

$$FF(X) = \mathrm{ReLU}\big(N_{ff}(X)\times A+K\big)\times B+L$$

where $N_{ff}$ is a normalization layer with its two weight vectors $a,b\in\mathbb R^{d_{model}}$,

$A$ is a weight matrix of $d_{model}$ rows and $d_{ff}$ columns (e.g. $d_{ff} = 2048$),

$K$ is a weight matrix of dimension $n\times d_{ff}$ with identical rows,

$B$ is a weight matrix of $d_{ff}$ rows and $d_{model}$ columns,

$L$ is a weight matrix of dimension $n\times d_{model}$ with identical rows.

Note that $a,b,A,K,B,L$ are weights embedded in $FF$.

Note that the input matrix $X$ and the output matrix $FF(X)$ have the same usual shape $n\times d_{model}$.

Moreover, note that $FF$ is a rowwise operation.
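
The matrices $K$ and $L$ with identical rows are naturally stored as single rows and broadcast over the $n$ input rows. The NumPy sketch below, with random weights and small made-up sizes, checks that broadcasting agrees with adding the full matrix and that $FF$ acts on each row separately. The lowercase names `k_row` and `l_row` are introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_ff = 3, 4, 8
a, b = rng.normal(size=d_model), rng.normal(size=d_model)
A, k_row = rng.normal(size=(d_model, d_ff)), rng.normal(size=(1, d_ff))
B, l_row = rng.normal(size=(d_ff, d_model)), rng.normal(size=(1, d_model))

def N_ff(X, eps=1e-6):
    c = X - X.mean(axis=1, keepdims=True)
    return c / (c.std(axis=1, keepdims=True) + eps) * a + b

def FF(X):
    hidden = np.maximum(0.0, N_ff(X) @ A + k_row)   # k_row broadcasts to the matrix K
    return hidden @ B + l_row                        # l_row broadcasts to the matrix L

X = rng.normal(size=(n, d_model))
K_full = np.repeat(k_row, n, axis=0)                 # K written out with n identical rows
print(FF(X).shape)                                                   # (3, 4)
print(np.allclose(N_ff(X) @ A + k_row, N_ff(X) @ A + K_full))        # True
print(all(np.allclose(FF(X[i:i + 1]), FF(X)[i:i + 1]) for i in range(n)))  # True
```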


In [8]:
def relu(X): return jnp.maximum(0,X)

def feedforward(weights, X):
  a, b, A, K, B, L = weights
  q = jnp.matmul(normalize((a,b),X),A) + K
  return jnp.matmul(relu(q),B) + L

## DecoderBlock Definition

In order to match the Trax implementation we will write the definition using the residual layer notation $$Res(f)(x) = x+f(x),$$
where $f$ is a layer and $x$ is an input tensor.

The DecoderBlock is a composition of a normalization layer $N_{ca}(X)$ and the layers previously defined:


$$DB(X)=Res(FF)\big(Res(CA\circ N_{ca})(X)\big)$$

Without the residual layer notation it is simply

$$DB(X)=X+CA(N_{ca}(X))+FF(X+CA(N_{ca}(X)))$$

Note that the input matrix $X$ and the output matrix $DB(X)$ have the same  usual shape $n\times d_{model}$.

Moreover, note that $DB$ is an autoregressive operation.
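
The equality of the two ways of writing $DB$ can be checked mechanically. In the sketch below `Res` and `compose` are small helpers written for this illustration, and `N_ca`, `CA`, `FF` are arbitrary stand-in functions rather than the real layers, since the identity holds for any choice of them.

```python
import numpy as np

def Res(f):
    """Residual wrapper: Res(f)(x) = x + f(x)."""
    return lambda x: x + f(x)

def compose(*fs):
    """compose(f, g)(x) == f(g(x))."""
    def composed(x):
        for f in reversed(fs):
            x = f(x)
        return x
    return composed

# Arbitrary stand-ins with the usual input and output shape
N_ca = lambda X: X - X.mean(axis=1, keepdims=True)
CA = lambda X: np.tril(np.ones((X.shape[0], X.shape[0]))) @ X
FF = lambda X: np.maximum(0.0, X)

DB = compose(Res(FF), Res(compose(CA, N_ca)))

X = np.arange(12, dtype=float).reshape(3, 4)
inner = X + CA(N_ca(X))
expanded = inner + FF(inner)          # X + CA(N(X)) + FF(X + CA(N(X)))
print(np.allclose(DB(X), expanded))   # True
```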

In [9]:
def res(f,x): return f'{x}+{f}({x})'
ca = res('CAoN','x')
ff = res('FF',ca)
print(ff)

x+CAoN(x)+FF(x+CAoN(x))


**inspirational idea**

Perhaps the decoder block would be a little more efficient if the formula were simplified to $DB(X)=X+FF(X+CA(N_{ca}(X)))$.

In [10]:
def decoderBlock(n_heads, weights, X):
  Nweights, CAweights, FFweights = weights
  CA = causalAttention(n_heads, CAweights, normalize(Nweights,X))
  FF = feedforward(FFweights, X + CA)
  # Deliberately follows the "inspirational idea" above, not the definition.
  return X + FF
  # According to the definition it would be: return X + CA + FF

# Final Layer

The input to the final layer is the output of a decoder block, which is a matrix $X$ of the usual shape $n\times d_{model}$, where $n$ is the number of tokens in the original input to the transformer language model.

The final layer's weight matrix $Y$ has dimension $d_{model}\times\mathrm{vocab\_size}$.

The final layer is a matrix multiplication defined as $\Phi(X)=X\times Y+B$,

where $B$ is a weight matrix of dimension $n\times\mathrm{vocab\_size}$ with identical rows.

The output is a matrix of dimension $n\times\mathrm{vocab\_size}$.

Note that it is a rowwise operation.

In [11]:
def finalLayer(weights, X):
  Y, B = weights
  return jnp.matmul(X,Y) + B

# Definition of TransformerLM

**parameters**

$\mathrm{vocab\_size}$ is the size of the set of input tokens

$d_{model}=512$ is the length of the vectorwords passed from sublayer to sublayer

$d_{ff}=2048$ is the width of certain weight matrices inside the feed forward layer in each decoder block

$n\_layers=6$ is the number of decoder blocks

$n\_heads=8$ is the number of attention heads in each decoder block

$\mathrm{max\_len}=2048$ is the maximum number of tokens in the input sequence



$$T_{LM}=\Phi\circ N\circ DB_{n\_layers}\circ\ldots\circ DB_1\circ PE\circ Em$$

where $N$ is a normalization layer

The input is a sequence of tokens.

Each token is an integer from $\{0,1,\ldots,\mathrm{vocab\_size}-1\}$.

The output is a matrix of dimension $n\times\mathrm{vocab\_size}$ containing real numbers, where $n$ is the number of tokens in the input sequence.

Note that it is an autoregressive operation if we interpret the input as a matrix of dimension $n\times 1$.
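
As a piece of arithmetic, the number of scalar weights implied by the layer definitions in this document can be counted from the shapes alone. The sketch below counts each bias matrix with identical rows as a single row; it follows this text's specification and is not claimed to match the parameter count of the Trax model. The number of heads does not enter because $h\cdot d_{head}=d_{model}$.

```python
def count_weights(vocab_size, d_model=512, d_ff=2048, n_layers=6, max_len=2048):
    embedding = vocab_size * d_model
    positional = max_len * d_model
    norm = 2 * d_model                                   # the vectors a and b
    attention = 4 * d_model * d_model + d_model          # W^Q, W^K, W^V, W^O and one bias row
    feed_forward = norm + d_model * d_ff + d_ff + d_ff * d_model + d_model
    block = norm + attention + feed_forward              # N_ca, CA and FF
    final = d_model * vocab_size + vocab_size            # Y and one bias row
    return embedding + positional + n_layers * block + norm + final

print(count_weights(vocab_size=10, d_model=4, d_ff=8, n_layers=2, max_len=5))  # 438
print(count_weights(vocab_size=256))
```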

In [12]:
def transformerLM(n_heads, weights, input):
  EMweights, PEweights, DBweights, Nweights, Fweights = weights
  EM = embedding(EMweights, input)
  PE = positionalEncoding(PEweights, EM)
  n_layers = len(DBweights)
  X = PE
  for i in range(n_layers):
    X = decoderBlock(n_heads, DBweights[i], X)
  X = normalize(Nweights, X)
  return finalLayer(Fweights, X)

In [13]:
def generateRandomWeights(vocab_size, d_model = 512, d_ff = 2048, n_layers = 6, n_heads = 8, max_len = 2048, real = float):
  rng = np.random.default_rng()
  def randomweights(shape):
    return (rng.random(shape) - 0.5).astype(real)
  def generateRandomDBweights():
    Nweights = (randomweights((d_model,)), randomweights((d_model,)))
    shape = (d_model, d_model)
    WQ = randomweights(shape)
    WK = randomweights(shape)
    WV = randomweights(shape)
    W0 = randomweights(shape)
    B = randomweights((1,d_model))
    CAweights = (WQ, WK, WV, W0, B)
    a, b = randomweights((d_model,)), randomweights((d_model,))
    A = randomweights((d_model, d_ff))
    K = randomweights((1,d_ff))
    B = randomweights((d_ff,d_model))
    L = randomweights((1,d_model))
    FFweights = (a, b, A, K, B, L)
    return (Nweights, CAweights, FFweights)
  EMshape = (vocab_size, d_model)
  EMweights = randomweights(EMshape)
  PEshape = (max_len, d_model)
  PEweights = randomweights(PEshape)
  DBweights = []
  for i in range(n_layers):
    DBweights.append(generateRandomDBweights())
  DBweights = tuple(DBweights)
  Nshape = (d_model,)
  Nweights = (randomweights(Nshape),randomweights(Nshape))
  Yshape = (d_model, vocab_size)
  Bshape = (1, vocab_size)
  Fweights = (randomweights(Yshape), randomweights(Bshape))

  weights = (EMweights, PEweights, DBweights, Nweights, Fweights)
  return weights

In [14]:
def generateRandomInput(vocab_size, n):
  rng = np.random.default_rng()
  return rng.integers(0,vocab_size,n)

# Evaluation (the loss function)

## The essence of the loss function &mdash; single sequence input with no loss weights

The input sequences are elements of $\{0,1,\ldots,\rm{vocab\_size-1}\}^n$, where $n$ is the sequence length and $\rm{vocab\_size}$ is the number of possible tokens from which the sequences are built.

We are given an autoregressive operation $T$, which takes an input sequence (interpreted as a matrix of dimension $n\times 1$) and outputs a real-valued matrix of dimension $n\times\rm{vocab\_size}$.

In this context, we are going to define a loss function which returns a positive real number for each input sequence.


Let $(x_1,\ldots,x_n)$ be the input sequence.

In the first step of the definition,
let $X=(0,x_1,\ldots,x_n)$ be viewed as the $(n+1)\times 1$ dimensional input matrix for the autoregressive operation $T$. Then $T(X)$ is a real-valued matrix of dimension $(n+1)\times\rm{vocab\_size}$, but we will be interested only in the first $n$ rows $T(X)|n$.





Note that each row of $T(X)|n$ is a sequence of $\rm{vocab\_size}$ arbitrary real numbers produced by the operation $T$. In the second step, we apply the LogSoftmax function to each row separately to obtain $Y=\rm{LogSoftmax}(T(X)|n)$, which is now a matrix of $n$ rows, each row serving as a log probability distribution over the set of tokens.

At this point we have created a function $k\mapsto Y[k]$, which assigns a log probability distribution over the set of tokens for each position in the original input sequence $(x_1,\ldots,x_n)$.

Note that --- due to the autoregressive nature of $T$ and the insertion of the $0$ token in the first step --- we have guaranteed that:
* the log distribution $Y[1]$ does not depend on the input sequence at all
* the log distribution $Y[2]$ depends only on $x_1$
* the log distribution $Y[3]$ is a function of $(x_1,x_2)$
* the log distribution $Y[k]$ is a function of $(x_1,x_2,\ldots,x_{k-1})$ for each $k=4,5,\ldots,n$.

Therefore it makes sense to evaluate the outputted log distribution $Y[k]$ against the actual token $x_k$ in the input.

Let $y_k = Y[k]_i$, where $i=x_k+1$, for each $k=1,\ldots,n$.

Note that in our setup $y_k\lt 0$ is interpreted as the log probability of the token $x_k$ appearing on the $k$th position in the input string after $(x_1,\ldots,x_{k-1})$. The closer to zero the better the evaluation.

Finally, the loss function is defined as
$$-\cfrac{1}{n}\sum_{k=1}^n y_k>0.$$
The closer to zero the better.
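
The steps of this definition can be sketched in NumPy for a single sequence. The model `uniform_T` below is a hypothetical stand-in that outputs the same number for every token; for such a model every log probability is $-\log(\rm{vocab\_size})$, so the loss must equal $\log(\rm{vocab\_size})$. Note that array indices in the code are counted from 0, while the text counts from 1.

```python
import numpy as np

def log_softmax(Z):
    Z = Z - Z.max(axis=-1, keepdims=True)
    return Z - np.log(np.exp(Z).sum(axis=-1, keepdims=True))

def sequence_loss(T, x):
    """x is a 1-D integer array (x_1, ..., x_n); T maps m tokens to an (m, vocab_size) array."""
    n = len(x)
    X = np.concatenate(([0], x))     # step 1: prepend the 0 token
    Y = log_softmax(T(X)[:n])        # step 2: first n rows, then rowwise LogSoftmax
    y = Y[np.arange(n), x]           # y_k: log probability given to the actual token x_k
    return -y.mean()

vocab_size = 7
uniform_T = lambda X: np.zeros((len(X), vocab_size))   # hypothetical model that knows nothing
x = np.array([3, 1, 4, 1, 5])
print(np.isclose(sequence_loss(uniform_T, x), np.log(vocab_size)))  # True
```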

## Batch evaluation with loss weights

The input sequences are elements of $\{0,1,\ldots,\rm{vocab\_size-1}\}^n$, where $n$ is the sequence length and $\rm{vocab\_size}$ is the number of possible tokens from which the sequences are built.

We will be dealing with a batch of input sequences &mdash; formally a matrix of dimension $\rm{batch\_size}\times n$, whose rows are input sequences. Additionally, there's a matrix of the same dimension whose elements are real numbers from the closed unit interval $[0,1]$, which is called the matrix of loss weights for the input batch.

We are given an autoregressive operation $T$, which takes an input sequence (interpreted as a matrix of dimension $n\times 1$) and outputs a real-valued matrix of dimension $n\times\rm{vocab\_size}$. This time the operation $T$ is going to act simultaneously and independently on each input sequence from the batch so that formally it is a function which takes an array of dimension $\rm{batch\_size}\times n$ and returns an array of dimension $\rm{batch\_size}\times n\times\rm{vocab\_size}$.

In this context, we are going to define a loss function which returns a positive real number for each batch of input sequences with loss weights.


Let us fix notation so that the $j$th row in the input batch is denoted as $(x_{j1},\ldots,x_{jn})$ for $j=1,\ldots,\rm{batch\_size}$.

Similarly, let the $j$th row in the loss weights matrix be denoted as $(\omega_{j1},\ldots,\omega_{jn})$ for $j=1,\ldots,\rm{batch\_size}$.

Let $(y_{j1},\ldots,y_{jn})$ be the log probabilities for the $j$th input sequence as computed according to the formula in the previous section, for $j=1,\ldots,\rm{batch\_size}$.

The loss function is defined by
$$-\cfrac{\sum_{(j,i)}\omega_{ji}y_{ji}}{\sum_{(j,i)}\omega_{ji}}$$
where the summation index $(j,i)$ ranges over
$\{1,\ldots,\rm{batch\_size}\}\times\{1,\ldots,n\}$.

Note that this is a generalization of the loss function from the previous section, which can be seen by setting the weights to zeros on all rows except on a single row where they should be all ones.
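
The weighted formula and the remark about recovering the single-sequence loss can be sketched as follows. This is a NumPy illustration on random log probabilities, with the prepending of the $0$ token and the application of $T$ assumed to have happened already. The function name `weighted_loss_from_logprobs` is made up for this example.

```python
import numpy as np

def weighted_loss_from_logprobs(Y, tokens, w):
    """Y: (batch_size, n, vocab_size) log probabilities; tokens, w: (batch_size, n)."""
    # y[j, i] = Y[j, i, tokens[j, i]]
    y = np.take_along_axis(Y, tokens[..., None], axis=-1)[..., 0]
    return -(w * y).sum() / w.sum()

rng = np.random.default_rng(0)
batch_size, n, vocab_size = 3, 5, 4
logits = rng.normal(size=(batch_size, n, vocab_size))
Y = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))   # rowwise LogSoftmax
tokens = rng.integers(0, vocab_size, size=(batch_size, n))

w = np.zeros((batch_size, n))
w[1] = 1.0                                         # all ones on one row, zeros elsewhere
single = -Y[1, np.arange(n), tokens[1]].mean()     # single-sequence loss of that row
print(np.isclose(weighted_loss_from_logprobs(Y, tokens, w), single))   # True
```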

### not the same as average over rows

Note that this is <b>not the same</b> thing as the average loss function of a single input sequence with loss weights taken over all the rows:
$$\cfrac{1}{s}\sum_{j=1}^s\Big(-\cfrac{\sum_{i=1}^n\omega_{ji}y_{ji}}{\sum_{i=1}^n\omega_{ji}}\Big)$$
The two formulas give the same result if the sums $\sum_{i=1}^n\omega_{ji}$ are the same for each row, but they are slightly different otherwise. The following code illustrates these points.

In [15]:
import numpy as np
s = np.random.randint(1,89)
n = np.random.randint(1,86)
y = np.random.random_sample(size=(s,n))
w = np.random.random_sample(size=(s,n))
def formula1(y,w):
  a = 0.0
  b = 0.0
  for j in range(s):
    for i in range(n):
      a += w[j][i]*y[j][i]
      b += w[j][i]
  return a/b
def formula2(y,w):
  c = 0.0
  for j in range(s):
    a = 0.0
    b = 0.0
    for i in range(n):
      a += w[j][i]*y[j][i]
      b += w[j][i]
    c += a/b
  return c/s
print(formula1(y,w))
print(formula2(y,w))
# expected to be slightly different

0.5006191316067764
0.5015580101356863


In [16]:
import numpy as np
s = np.random.randint(1,89)
n = np.random.randint(1,86)
_y = np.random.random_sample(size=(s,n))
_w = np.random.random_sample(size=(1,n))
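def w(j,i):
  # Row j uses the same weights cyclically shifted by j, so every row has the same sum.
  # (Using identical rows, _w[0][i], would work as well.)
  return _w[0][(i+j)%n]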
def y(j,i):
  return _y[j][i]
def formula1():
  a = 0.0
  b = 0.0
  for j in range(s):
    for i in range(n):
      a += w(j,i)*y(j,i)
      b += w(j,i)
  return a/b
def formula2():
  c = 0.0
  for j in range(s):
    a = 0.0
    b = 0.0
    for i in range(n):
      a += w(j,i)*y(j,i)
      b += w(j,i)
    c += a/b
  return c/s
print(formula1())
print(formula2())
# expected to be equal

0.48544311762896336
0.48544311762896436


## The code for the loss function

In [17]:
def LogSoftmax(x):
  return x - fastmath.logsumexp(x, -1, keepdims=True)

In [18]:
def weighted_loss(T, input, mask):
  '''
  T is a function
         with input shape (batch_size, n)
          and output shape (batch_size, n, vocab_size)
  input, mask both have shape (batch_size, n)
  '''
  batch_size = input.shape[0]
  n = input.shape[1] # the length of input sequences
  zero = jnp.zeros((batch_size,1), np.int32)
  X = jnp.concatenate((zero,input), axis=1)
  # X.shape == (batch_size, n+1)
  # Keep the first n positions of every sequence (axis 1), not the first n batch rows.
  Y = LogSoftmax( T(X)[:,:n,:] )
  # Y.shape[0] == batch_size
  # Y.shape[1] == n
  # Y.shape[2] == vocab_size
  y = []
  for row_number in range(batch_size):
    row = []
    for token_position in range(n):
      logprob = Y[row_number, token_position, input[row_number, token_position]]
      row.append(logprob)
    y.append(row)
  y = jnp.array(y)
  # y.shape == mask.shape
  return - jnp.sum(mask*y) / jnp.sum(mask)

## Gradient-Descent-based minimization of the loss function

Our loss function is a positive real-valued function but we need to take a closer look at its domain and its differentiability in order to use it within a gradient-descent-based minimization algorithm.

The loss function depends on the batch of input sequences on the one hand and on the collection of weights of the $T$ operation (which is a TransformerLM) on the other. We may think of a batch space and a weight space so that the loss function takes a point from the batch space and a point from the weight space to return the loss value.

Note that the loss function is suitable for a mini-batch (stochastic) gradient descent setup because it works for any point from the batch space.

<b>Claim.</b>
Fixing a batch space point, this function is differentiable almost everywhere as a multivariate real-valued function from the weight space.
<br><i>Proof.</i>
Apart from ReLU, it is a composition of the basic arithmetic operations (addition, subtraction, multiplication, division), the exp and log elementary functions from softmax and LogSoftmax, and the square root used in the calculation of the standard deviation.

In order to calculate the gradient by automatic differentiation it remains to arbitrarily define the derivative of ReLU at 0.
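
To illustrate gradient descent on a loss of this kind in the smallest possible setting, the sketch below minimizes the negative log probability of one target token with respect to a vector `z` of logits, which here plays the role of the weight space. It is a toy problem, not the transformer. The gradient is written by hand and checked against central finite differences, whereas the notebook obtains gradients by automatic differentiation. The step size 0.1 is an arbitrary choice for this example.

```python
import numpy as np

def loss(z, target):
    """Negative log probability of `target` under softmax(z)."""
    z = z - z.max()
    return -(z[target] - np.log(np.exp(z).sum()))

def grad(z, target):
    p = np.exp(z - z.max())
    p = p / p.sum()
    p[target] -= 1.0          # d loss / d z = softmax(z) - one_hot(target)
    return p

z = np.array([0.5, -1.0, 2.0, 0.0])
target = 1

# Central finite differences as an independent check of the gradient
step = 1e-6
numeric = np.array([(loss(z + step * e, target) - loss(z - step * e, target)) / (2 * step)
                    for e in np.eye(len(z))])
gradient_ok = np.allclose(numeric, grad(z, target), atol=1e-6)
print(gradient_ok)            # True

before = loss(z, target)
for _ in range(100):
    z = z - 0.1 * grad(z, target)    # plain gradient descent step
after = loss(z, target)
print(after < before)         # True
```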


Useful links for this subsection:
* [Gradient Descent For Machine Learning
by Jason Brownlee](https://machinelearningmastery.com/gradient-descent-for-machine-learning/)
* [Batch, Mini Batch & Stochastic Gradient Descent by Sushant Patrikar](https://towardsdatascience.com/batch-mini-batch-stochastic-gradient-descent-7a62ecba642a)
* [An overview of gradient descent optimization algorithms by Sebastian Ruder](https://ruder.io/optimizing-gradient-descent/)

# Fully working demo

## The dataset for demo purposes

In [19]:
import random

def generate_batch(batch_size, n_tokens, n):
  shape = (batch_size, 2*n+2)
  a = np.zeros(shape,np.int32)
  for y in range(batch_size):
    for i in range(n):
      x = random.randint(1,n_tokens)
      a[y][i] = x
      a[y][2*n-i] = x
  return a

def train_stream(batch_size,
                 n_tokens,
                 input_length_min,
                 input_length_max):
  n = random.randint(input_length_min, input_length_max)
  batch = generate_batch(batch_size, n_tokens, n)
  inputs = batch
  shape = (batch_size, 2*n+2)
  loss_weights = np.zeros(shape,np.float32)
  for y in range(batch_size):
    for i in range(n+1,2*n+2):
      loss_weights[y][i] = 1.0
  return (inputs,loss_weights)

def data_batches_yielder(batch_size, input_params):
  n_tokens, input_length_min, input_length_max = input_params
  while True:
    input, loss_weights = train_stream(batch_size, n_tokens, input_length_min, input_length_max)
    yield (input, loss_weights)
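
To make the task concrete: each row produced by `generate_batch` consists of $n$ random nonzero tokens, a $0$ separator, the same tokens in reverse order, and a final $0$. The loss weights are 1 only on the reversed part and the final $0$. The compact NumPy sketch below reproduces this layout for one hand-picked word; the helper names `make_example` and `make_loss_weights` are made up for this illustration.

```python
import numpy as np

def make_example(word):
    """word, then 0, then word reversed, then 0: length 2n + 2."""
    word = np.asarray(word, dtype=np.int32)
    return np.concatenate([word, [0], word[::-1], [0]]).astype(np.int32)

def make_loss_weights(n):
    """Zeros on the first n + 1 positions, ones on the remaining n + 1."""
    w = np.zeros(2 * n + 2, dtype=np.float32)
    w[n + 1:] = 1.0
    return w

row = make_example([5, 2, 7])
print(row.tolist())                     # [5, 2, 7, 0, 7, 2, 5, 0]
print(make_loss_weights(3).tolist())    # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```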

## The training process

In [20]:
def initWeights(model_params):
  (vocab_size, d_model, d_ff, n_layers, n_heads, max_len) = model_params
  real = float
  return generateRandomWeights(vocab_size, d_model, d_ff, n_layers, n_heads, max_len, real)

In [21]:
import time

def train(batch_size, n_steps, input_params, model_params, weights):
  start_time = time.time()
  data_batches = data_batches_yielder(batch_size, input_params)
  n_heads = model_params[4]

  def network(data_batch, weights):
    def T(X):
      rows = []
      for j in range(batch_size):
        row = transformerLM(n_heads, weights, X[j])
        rows.append(row)
      return jnp.array(rows)
    input, mask = data_batch
    return weighted_loss(T, input, mask)

  value_and_grad = fastmath.value_and_grad(network, argnums=1)
  value_and_grad = fastmath.jit(value_and_grad)

  print(f'Starting the {n_steps}-steps training process...')
  acc_loss = 0.0
  for i in range(n_steps):
    data = next(data_batches)
    loss, grad = value_and_grad(data, weights)
    acc_loss += loss
    weights = fastmath.nested_map_multiarg(lambda w, g: w - 0.001 * g, weights, grad)
    if i % 100 == 99:
      print(f'avg loss at {i+1} steps: {acc_loss / 100:.3f}')
      acc_loss = 0.0
  end_time = time.time()
  seconds = int(end_time - start_time)
  minutes, seconds = divmod(seconds, 60)
  print(f'The training process took {minutes}:{seconds:02d}.')
  return weights
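
The weight update inside the training loop is plain gradient descent applied leaf by leaf to a nested structure of weights. The following is a minimal pure-Python sketch of that idea, assuming the weights and gradients are identically shaped nested lists or tuples of numbers. The helpers `nested_map2` and `sgd_step` are hypothetical stand-ins written for this illustration, not the Trax functions used above.

```python
def nested_map2(f, a, b):
    """Apply f to matching leaves of two identically shaped nested lists/tuples."""
    if isinstance(a, (list, tuple)):
        return type(a)(nested_map2(f, x, y) for x, y in zip(a, b))
    return f(a, b)

def sgd_step(weights, grad, learning_rate):
    """One gradient descent step: w - learning_rate * g on every leaf."""
    return nested_map2(lambda w, g: w - learning_rate * g, weights, grad)

weights = [1.0, (2.0, [3.0])]
grad = [1.0, (2.0, [4.0])]
print(sgd_step(weights, grad, learning_rate=0.5))  # [0.5, (1.0, [1.0])]
```

In the training loop above the learning rate is fixed at 0.001 and the leaves are arrays rather than single numbers, but the update rule has the same form.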

## Testing the trained model

In [22]:
def transformerLM_output(model_params, weights, input, length=0):
  (vocab_size, d_model, d_ff, n_layers, n_heads, max_len) = model_params
  if length == 0: length = max_len
  X = [0] + list(input)
  while len(X) < length:
    Y = transformerLM(n_heads, weights, np.array(X))[-1,:]
    token = int(np.argmax(Y))
    X.append(token)
    if token == 0: break
  return X[1:]

def onesequencetest(input_params, model_params, weights):
  n_tokens, input_length_min, input_length_max = input_params
  n = random.randint(input_length_min, input_length_max)
  perfect = generate_batch(1, n_tokens, n)[0]
  input = perfect[:n+1]
  output = transformerLM_output(model_params, weights, input)
  if output != list(perfect):
    print('wrong')
    print('should be:', list(perfect))
    print('instead:  ', output)
    return 0
  return 1

def runtest(input_params, model_params, weights, n):
  print(f'Running a test on {n} random sequences...')
  r = []
  for i in range(n):
    r.append(onesequencetest(input_params, model_params, weights))
  print(f'{100*sum(r)/n}% success rate')
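
The generation loop in `transformerLM_output` is greedy decoding: at each step it appends the highest-scoring next token and stops at token 0 or at the length limit. The sketch below reproduces that loop in pure Python with a hypothetical stand-in for the trained model, `oracle_scores`, which always gives the correct next token of the reversal task a score of 1.0. Both function names are assumptions made for this illustration.

```python
def oracle_scores(prefix, vocab_size):
    """Hypothetical stand-in for the model: scores for the next token.

    prefix is the start marker 0, the input tokens, the 0 separator,
    and any output tokens produced so far.
    """
    body = prefix[1:]
    n = body.index(0)            # number of input tokens before the separator
    k = len(body) - (n + 1)      # number of output tokens produced so far
    target = body[n - 1 - k] if k < n else 0
    return [1.0 if t == target else 0.0 for t in range(vocab_size)]

def greedy_decode(next_scores, vocab_size, input_tokens, max_len):
    """Append the arg-max token until token 0 is produced or max_len is reached."""
    X = [0] + list(input_tokens)
    while len(X) < max_len:
        scores = next_scores(X, vocab_size)
        token = max(range(vocab_size), key=lambda t: scores[t])
        X.append(token)
        if token == 0:
            break
    return X[1:]

print(greedy_decode(oracle_scores, 4, [3, 1, 2, 0], 9))
# [3, 1, 2, 0, 2, 1, 3, 0]
```

Here `max_len` is 9, matching the formula `2 * max_length + 2 + 1` used in the examples for sequences of 3 tokens: the extra 1 accounts for the prepended start marker.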

## Demo examples

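In these examples the task is to reverse the order of tokens in the input sequence: given the input tokens followed by the separator 0, the model must output the same tokens in reverse order, followed by a final 0.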
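
Each example below carries a comment giving the number of elements in its training dataset. That number is the count of distinct token sequences with lengths between `min_length` and `max_length`. A short sketch of the arithmetic follows; `dataset_size` is a hypothetical helper introduced only for this check.

```python
def dataset_size(n_tokens, min_length, max_length):
    """Count distinct input sequences: n_tokens**n for each allowed length n."""
    return sum(n_tokens ** n for n in range(min_length, max_length + 1))

print(dataset_size(10, 2, 2))  # 100  (Example 0)
print(dataset_size(3, 2, 3))   # 36   (Example 1)
print(dataset_size(4, 2, 4))   # 336  (Example 2)
```

Because these sets are small and the test sequences are drawn from the same random generator as the training batches, the test most likely sees sequences that already appeared during training.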

### Example 0

In [23]:
def example0():
  n_tokens = 10
  min_length = 2
  max_length = 2
  # 100 elements in the training dataset
  input_params = (n_tokens, min_length, max_length)

  vocab_size = n_tokens + 1
  d_model = 128
  d_ff = 256
  n_layers = 2
  n_heads = 2
  max_len = 2 * max_length + 2 + 1
  model_params = (vocab_size, d_model, d_ff, n_layers, n_heads, max_len)

  weights = initWeights(model_params)
  batch_size = 4
  n_steps = 6000
  weights = train(batch_size,n_steps,input_params, model_params, weights)

  runtest(input_params, model_params, weights, 100)

In [24]:
# example0() # less than 2 minutes

### Example 1

In [25]:
def example1():
  n_tokens = 3
  min_length = 2
  max_length = 3
  # 36 elements in the training dataset
  input_params = (n_tokens, min_length, max_length)

  vocab_size = n_tokens + 1
  d_model = 128
  d_ff = 256
  n_layers = 2
  n_heads = 2
  max_len = 2 * max_length + 2 + 1
  model_params = (vocab_size, d_model, d_ff, n_layers, n_heads, max_len)

  weights = initWeights(model_params)
  batch_size = 4
  n_steps = 6000
  weights = train(batch_size, n_steps, input_params, model_params, weights)

  runtest(input_params, model_params, weights, 100)

In [26]:
# example1() # less than 2 minutes

### Example 2

In [27]:
def example2():
  n_tokens = 4
  min_length = 2
  max_length = 4
  # 336 elements in the training dataset
  input_params = (n_tokens, min_length, max_length)

  vocab_size = n_tokens + 1
  d_model = 128
  d_ff = 256
  n_layers = 2
  n_heads = 2
  max_len = 2 * max_length + 2 + 1
  model_params = (vocab_size, d_model, d_ff, n_layers, n_heads, max_len)

  weights = initWeights(model_params)

  batch_size = 4
  n_steps = 16_000
  weights = train(batch_size, n_steps, input_params, model_params, weights)

  runtest(input_params, model_params, weights, 100)

In [None]:
# example2() # less than 5 minutes