<a href="https://colab.research.google.com/github/afaale/ML/blob/ML/Transformers_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Source:
https://peterbloem.nl/blog/transformers

# Transformers

# Self Attention

$$y_i = \sum_{j} w_{ij}x_j$$

$$w_{ij}^{'} = x_i^Tx_j$$

$$w_{ij} = \frac{\exp(w_{ij}^{'})}{\sum_{j'} \exp(w_{ij'}^{'})}$$

self-attention.svg

For example, for $i = 2$ with four input vectors: $$y_2 = \sum_j w_{2j}{x_j} = w_{21}x_1+w_{22}x_2+w_{23}x_3+w_{24}x_4$$


$$w_{2j}^{'} = x_2^Tx_j, \quad \text{i.e. } w_{21}^{'} = x_2^Tx_1,\; w_{22}^{'} = x_2^Tx_2,\; w_{23}^{'} = x_2^Tx_3,\; w_{24}^{'} = x_2^Tx_4$$


$$w_{2j} = \frac{\exp(w_{2j}^{'})}{\sum_{j'} \exp(w^{'}_{2j'})} = \frac{\exp(w_{2j}^{'})}{\exp(w_{21}^{'})+\exp(w_{22}^{'})+\exp(w_{23}^{'})+\exp(w_{24}^{'})}$$

In [None]:
import torch
import torch.nn.functional as F

# Basic (parameter-free) self-attention on a small example input.
# x has size (b, t, k): a batch of b sequences, each of t vectors of
# dimension k. Random values are used here so the cell runs; substitute
# real data as needed.
b, t, k = 2, 5, 4
x = torch.randn(b, t, k)

# Raw weights w'_ij = x_i^T x_j for every pair of positions: size (b, t, t).
# - torch.bmm is a batched matrix multiplication. It
#   applies matrix multiplication over batches of
#   matrices.
raw_weights = torch.bmm(x, x.transpose(1, 2))

# Row-wise softmax: for each position i the weights w_ij sum to 1 over j.
weights = F.softmax(raw_weights, dim=2)

# y_i = sum_j w_ij x_j: size (b, t, k), same as x.
y = torch.bmm(weights, x)

# Query, key,value; scaling; multihead attention

$q_i=W_qx_i \quad$ $k_i=W_kx_i \quad$ $v_i=W_vx_i$
$$w^{'}_{ij}=q^{T}_ik_j$$
$$w_{ij}=\text{softmax}_j(w^{'}_{ij})$$
$$y_i=\sum_j w_{ij}v_j$$

qkv_diag.svg

**Scaling the dot product**

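$$w^{'}_{ij}=\frac{q_i^{T}k_j}{\sqrt{k}}$$

Here $k$ is the dimension of the query and key vectors being compared. In multi-head attention, each head compares vectors of dimension $k/h$.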

**Multihead attention**
multi-head.svg

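Requires $3h$ matrices of size $k \times \frac{k}{h}$: one query, one key, and one value matrix for each of the $h$ heads. In total, this gives $3h \cdot k \cdot \frac{k}{h}=3k^2$ parameters to compute the inputs to the multi-head self-attention.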

# In PyTorch: complete self-attention

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Multi-head self-attention with optional causal masking.

    The embedding dimension ``k`` is split evenly across ``heads`` heads of
    size ``s = k // heads``. The query, key and value projections for all
    heads are computed by single ``k x k`` linear layers, and the
    concatenated head outputs are mixed by ``unifyheads``.

    Args:
        k: Embedding dimension of the input vectors. Must be divisible by
            ``heads``.
        heads: Number of attention heads.
        mask: If True, position ``i`` may only attend to positions
            ``j <= i`` (causal masking, as used for text generation).
    """

    def __init__(self, k, heads=4, mask=False):
        # Initialize the nn.Module machinery before registering submodules.
        super().__init__()

        assert k % heads == 0, "embedding dimension k must be divisible by heads"

        self.k, self.heads, self.mask = k, heads, mask

        # These compute the queries, keys and values for all heads at once;
        # each output is later split into `heads` chunks of size k // heads.
        self.tokeys    = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues  = nn.Linear(k, k, bias=False)

        # Applied after the multi-head self-attention operation to combine the
        # concatenated head outputs into a single k-dimensional vector.
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Tensor of size ``(b, t, k)``.

        Returns:
            Tensor of size ``(b, t, k)``.
        """
        b, t, k = x.size()
        h = self.heads
        s = k // h  # per-head dimension (see reshape.svg)

        queries = self.toqueries(x)
        keys = self.tokeys(x)
        values = self.tovalues(x)

        # Split the last dimension into h heads of size s ...
        keys    = keys.view(b, t, h, s)
        queries = queries.view(b, t, h, s)
        values  = values.view(b, t, h, s)

        # ... and fold the heads into the batch dimension so one bmm call
        # processes every head of every sequence.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # Dot product of queries and keys
        dot = torch.bmm(queries, keys.transpose(1, 2))
        # -- dot has size (b*h, t, t) containing raw weights

        # Scale by the square root of the dimension of the vectors whose dot
        # product was taken. Per head that is s, not the full k.
        dot = dot / (s ** (1/2))

        if self.mask:
            # Causal mask: hide every key j > i from query i. A raw weight of
            # -inf becomes 0 after the softmax. The mask is built on x's device
            # so this also works on GPU.
            causal = torch.ones(t, t, device=x.device).triu(diagonal=1).bool()
            dot = dot.masked_fill(causal, float('-inf'))

        # normalize
        dot = F.softmax(dot, dim=2)
        # - dot now contains row-wise normalized weights

        # apply the self attention to the values
        out = torch.bmm(dot, values).view(b, h, t, s)

        # swap h, t back, unify heads
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)

        return self.unifyheads(out)

reshape.svg

# The transformer block

transformer-block.svg

In [None]:
class TransformerBlock(nn.Module):
    """One transformer block.

    Applies self-attention and then a two-layer feed-forward network. Each
    sub-layer is wrapped in a residual connection followed by layer
    normalization.

    Args:
        k: Embedding dimension.
        heads: Number of attention heads.
        mask: Forwarded to ``SelfAttention``'s ``mask`` argument.
        ff_hidden_mult: Width of the feed-forward hidden layer as a multiple
            of ``k`` (defaults to 4).
    """

    def __init__(self, k, heads, mask=False, ff_hidden_mult=4):
        super().__init__()

        self.attention = SelfAttention(k, heads=heads, mask=mask)

        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        self.ff = nn.Sequential(
            nn.Linear(k, ff_hidden_mult * k),
            nn.ReLU(),
            nn.Linear(ff_hidden_mult * k, k),
        )

    def forward(self, x):
        """Map a ``(b, t, k)`` tensor to a tensor of the same size."""
        # Residual connection around attention, then normalize.
        attended = self.attention(x)
        x = self.norm1(attended + x)

        # Residual connection around the feed-forward layer, then normalize.
        fedforward = self.ff(x)
        return self.norm2(fedforward + x)

In [None]:
class Transformer(nn.Module):
    """Transformer for sequence classification.

    Embeds the tokens, adds learned position embeddings, and runs ``depth``
    transformer blocks. It then average-pools over the sequence and projects
    the result to class log-probabilities.

    Args:
        k: Embedding dimension.
        heads: Number of attention heads per block.
        depth: Number of transformer blocks.
        seq_length: Size of the position-embedding table. Input sequences
            must have length ``t <= seq_length``.
        num_tokens: Vocabulary size.
        num_classes: Number of output classes.
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        # The sequence of transformer blocks that does all the
        # heavy lifting
        self.tblocks = nn.Sequential(
            *[TransformerBlock(k=k, heads=heads) for _ in range(depth)]
        )

        # Maps the pooled output to class logits
        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer values representing
                  words (in some predetermined vocabulary).
        :return: A (b, c) tensor of log-probabilities over the
                 classes (where c is the nr. of classes).
        """
        # generate token embeddings
        tokens = self.token_emb(x)
        b, t, k = tokens.size()

        # generate position embeddings; the index tensor is created on the
        # same device as the input so the model also works on GPU
        positions = torch.arange(t, device=x.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        x = self.tblocks(x)

        # Average-pool over the t dimension, project to class logits, and
        # normalize to log-probabilities
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)

# Text generation transformer

generator.svg

Masked attention
masked-attention.svg

In [None]:
# Standalone demo of causal masking. Inside SelfAttention.forward, queries
# and keys have size (b*h, t, s); small random tensors of that form are used
# here so the cell runs on its own.
bh, t, s = 2, 4, 3
queries = torch.randn(bh, t, s)
keys = torch.randn(bh, t, s)

dot = torch.bmm(queries, keys.transpose(1, 2))

# (row, col) indices of the strictly upper triangle, i.e. all pairs j > i.
# Setting those raw weights to -inf makes the softmax give them weight 0,
# so position i cannot attend to later positions.
indices = torch.triu_indices(t, t, offset=1)
dot[:, indices[0], indices[1]] = float('-inf')

dot = F.softmax(dot, dim=2)