In [22]:
import math

import torch
from torch import nn
import torch.nn.functional as F

# Self-attention
Based on [Peter Bloem's blog](https://peterbloem.nl/blog/transformers).
>Self-attention is a sequence-to-sequence operation taking in $x_1,\dots,x_t$ and returning $y_1,\dots,y_t$, where each $y_i$ is a weighted sum of all input vectors $x_j$: $y_i = \sum_j w_{ij}x_j$
>**Intuition:** 
> - The dot product expresses how related two vectors of the input sequence are. 
> - What counts as "related" is determined by the learning task.
> - The output vectors are weighted sums of the input vectors, where the weights are calculated by the dot products.
- $w_{ij}$ is not a parameter but derived from a function of $x_i$ and $x_j$.
- The simplest such function is the dot product: $w'_{ij} = x_i^Tx_j$. Because this yields values in $(-\infty, \infty)$, a softmax over $j$ is applied so that, for each $i$, the weights form a probability distribution: $w_{ij} = \frac{\exp(x_i^Tx_j)}{\sum_{j'}\exp(x_i^Tx_{j'})}$

Input: Sequence of $t$ vectors of $k$ dimensions: $X^{(t,k)}$
With minibatch dimension $b$: $X^{(b,t,k)}$

In [21]:
# Toy batch: b=2 sequences, each made of t=3 vectors with k=4 dimensions
x = torch.randn(2, 3, 4)
print("x: \n", x)

# The raw weights are the pairwise dot products x_i^T x_j. A batched matrix
# multiply computes them for every sequence at once.
x_transposed = x.transpose(1, 2)  # (b, k, t)
print("x.transpose(1, 2): \n", x_transposed)
raw_weights = torch.bmm(x, x_transposed)  # (b, t, k) @ (b, k, t) -> (b, t, t)
print("Raw weights x * x^T: \n", raw_weights)

# Softmax along each row, so that for a fixed i the weights w_ij sum to 1
weights = raw_weights.softmax(dim=2)  # (b, t, t)
print("Weights: \n", weights)

# y_i = sum_j w_ij x_j: every output vector is a weighted sum of the input vectors
y = torch.bmm(weights, x)  # (b, t, t) @ (b, t, k) -> (b, t, k)
print("y: \n", y)

x: 
 tensor([[[-0.6576, -0.0910,  0.6779,  1.7254],
         [ 0.7237, -0.8033,  0.9599, -1.4178],
         [-0.3415, -0.3925, -0.8440,  0.2096]],

        [[-0.7420, -1.5567, -2.0906, -0.9844],
         [ 1.1749,  0.9946, -0.6373,  0.4512],
         [ 0.5579,  0.8278,  1.4489, -0.2451]]])
x.transpose(1, 2): 
 tensor([[[-0.6576,  0.7237, -0.3415],
         [-0.0910, -0.8033, -0.3925],
         [ 0.6779,  0.9599, -0.8440],
         [ 1.7254, -1.4178,  0.2096]],

        [[-0.7420,  1.1749,  0.5579],
         [-1.5567,  0.9946,  0.8278],
         [-2.0906, -0.6373,  1.4489],
         [-0.9844,  0.4512, -0.2451]]])
Raw weights x * x^T: 
 tensor([[[ 3.8774, -2.1985,  0.0499],
         [-2.1985,  4.1004, -1.0393],
         [ 0.0499, -1.0393,  1.0270]],

        [[ 8.3136, -1.5319, -4.4904],
         [-1.5319,  2.9795,  0.4448],
         [-4.4904,  0.4448,  3.1559]]])
Weights: 
 tensor([[[9.7650e-01, 2.2437e-03, 2.1252e-02],
         [1.8242e-03, 9.9236e-01, 5.8146e-03],
         [2.5041e-01

## Multi-head attention
>Multi-head self-attention allows a model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this. ~[Attention is all you need](https://arxiv.org/abs/1706.03762)
 
Based on [Peter Bloem's blog](https://peterbloem.nl/blog/transformers).
- Multi-head self-attention allows the model to account for different relationships between different words in the input sequence.
- An attention head is a self-attention operation with its own parameters. 
    - Each head linearly projects the input vectors to queries, keys and values and then applies scaled dot-product attention; the outputs of all heads are concatenated and passed through a final linear projection.
    - The projection matrices are denoted by $W_q^r, W_k^r, W_v^r$ for the $r$-th head.


In [23]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention ("narrow" variant).

    The embedding dimension ``k_dimensions`` is split evenly across ``heads``.
    Each head therefore attends over query/key/value vectors of size
    ``s = k_dimensions // heads``. The per-head outputs are concatenated back
    to size ``k_dimensions`` and passed through a final linear layer.

    Args:
        k_dimensions: Embedding dimension ``k`` of the input vectors. Must be
            divisible by ``heads``.
        heads: Number of attention heads ``h``.
        mask: If True, apply a causal mask so that position ``i`` can only
            attend to positions ``j <= i``.
    """

    def __init__(self, k_dimensions: int, heads: int = 4, mask: bool = False):
        super().__init__()
        assert k_dimensions % heads == 0, "Embedding dimension must be divisible by number of heads."
        self.k_dimensions = k_dimensions
        self.heads = heads
        self.mask = mask

        # Linear transformations for Q, K, V. Each maps k -> k; the k outputs are
        # split into h chunks of size s = k // h in forward(), one chunk per head.
        self.transform_to_key = nn.Linear(k_dimensions, k_dimensions, bias=False)
        self.transform_to_query = nn.Linear(k_dimensions, k_dimensions, bias=False)
        self.transform_to_value = nn.Linear(k_dimensions, k_dimensions, bias=False)

        # Final linear transformation over the concatenated head outputs (h * s == k)
        self.unify_layer = nn.Linear(k_dimensions, k_dimensions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply multi-head self-attention to a batch of sequences.

        Args:
            x: Input tensor of shape (b, t, k) with k == self.k_dimensions.

        Returns:
            Tensor of shape (b, t, k).
        """
        b_batch_size, t_sequence_length, k_dimensions = x.size()
        h_heads = self.heads
        s = k_dimensions // h_heads  # per-head dimension

        # Linear transformation to obtain key, query and value vectors
        key_vectors = self.transform_to_key(x)  # (b, t, k)
        query_vectors = self.transform_to_query(x)  # (b, t, k)
        value_vectors = self.transform_to_value(x)  # (b, t, k)

        # Split the embedding dimension k into h heads of size s.
        # A view is a reshape that doesn't change the underlying data.
        key_vectors = key_vectors.view(b_batch_size, t_sequence_length, h_heads, s)  # (b, t, h, s)
        query_vectors = query_vectors.view(b_batch_size, t_sequence_length, h_heads, s)  # (b, t, h, s)
        value_vectors = value_vectors.view(b_batch_size, t_sequence_length, h_heads, s)  # (b, t, h, s)

        # The dot-product computation is the same for all heads, so fold the heads
        # into the batch dimension: (b, t, h, s) -> (b, h, t, s) -> (b * h, t, s)
        key_vectors = key_vectors.transpose(1, 2).contiguous().view(b_batch_size * h_heads, t_sequence_length, s)
        query_vectors = query_vectors.transpose(1, 2).contiguous().view(b_batch_size * h_heads, t_sequence_length, s)
        value_vectors = value_vectors.transpose(1, 2).contiguous().view(b_batch_size * h_heads, t_sequence_length, s)

        # Dot product of query and key vectors to obtain raw weights
        dot_products = torch.bmm(query_vectors, key_vectors.transpose(1, 2))  # (b * h, t, t)
        # Scale by sqrt of the key dimensionality. Each head's keys have size s,
        # not k (as in "Attention Is All You Need", sqrt(d_k)).
        dot_products = dot_products / math.sqrt(s)

        if self.mask:
            # Causal mask: forbid attending to future positions (j > i). The diagonal
            # stays unmasked, so every row keeps at least one finite entry and the
            # softmax never produces NaN.
            future = torch.triu(
                torch.ones(t_sequence_length, t_sequence_length, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
            dot_products = dot_products.masked_fill(future, float("-inf"))  # broadcasts over b * h

        # Normalize weights over the key positions
        normalized = F.softmax(dot_products, dim=2)  # (b * h, t, t)

        # Apply the attention weights to the values, then unfold the heads
        # (b * h, t, t) @ (b * h, t, s) -> (b * h, t, s) -> (b, h, t, s)
        weighted_sum = torch.bmm(normalized, value_vectors).view(b_batch_size, h_heads, t_sequence_length, s)

        # Put the heads back together by concatenating them: (b, h, t, s) -> (b, t, h * s)
        weighted_sum = weighted_sum.transpose(1, 2).contiguous().view(b_batch_size, t_sequence_length, h_heads * s)
        return self.unify_layer(weighted_sum)

![Reshape of tensors to iterate over heads](https://peterbloem.nl/files/transformers/reshape.svg)

### Transformer block
> The Transformer block is the basic building block of the Transformer architecture. It consists of a multi-head self-attention layer, followed by a feed-forward layer. Each of these layers has a residual connection around it, and is followed by a layer normalization. ~[Attention is all you need](https://arxiv.org/abs/1706.03762)

In [23]:
class TransformerBlock(nn.Module):
    """One post-norm Transformer block.

    Self-attention and a position-wise feed-forward network are each wrapped in a
    residual connection, and layer normalization is applied after each addition:

        h = LayerNorm(x + attention(x))
        out = LayerNorm(h + feed_forward(h))

    Args:
        k_dimensions: Embedding dimension of the input and output vectors.
        heads: Number of attention heads, forwarded to ``SelfAttention``.
        mask: Forwarded to ``SelfAttention`` and also stored on this block.
    """

    attention: SelfAttention
    mask: bool

    def __init__(self, k_dimensions: int, heads: int, mask: bool = False):
        super().__init__()
        self.attention = SelfAttention(k_dimensions, heads=heads, mask=mask)
        self.mask = mask

        # Layer norms applied after each residual addition (post-norm layout).
        self.norm1 = nn.LayerNorm(k_dimensions)
        self.norm2 = nn.LayerNorm(k_dimensions)

        # Position-wise MLP with a 4x hidden expansion; its exact form is a design choice.
        hidden_dimensions = 4 * k_dimensions
        self.feed_forward = nn.Sequential(
            nn.Linear(k_dimensions, hidden_dimensions),
            nn.ReLU(),
            nn.Linear(hidden_dimensions, k_dimensions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention and feed-forward sublayers, each with residual + LayerNorm."""
        after_attention = self.norm1(x + self.attention(x))
        return self.norm2(after_attention + self.feed_forward(after_attention))

### Text classification transformer
From [Peter Bloem's blog](https://peterbloem.nl/blog/transformers)

In [None]:
class Transformer(nn.Module):
    """Transformer for sequence classification.

    Token embeddings and learned position embeddings are summed, passed through
    ``depth`` transformer blocks, average-pooled over the sequence dimension and
    projected to class log-probabilities.

    Args:
        k: Embedding dimension.
        heads: Number of attention heads per block.
        depth: Number of stacked transformer blocks.
        seq_length: Number of rows in the position-embedding table; inputs must
            have t <= seq_length.
        num_tokens: Vocabulary size (rows of the token-embedding table).
        num_classes: Number of output classes.
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        # The sequence of transformer blocks that does all the heavy lifting.
        # TransformerBlock names its embedding-size parameter `k_dimensions`.
        self.tblocks = nn.Sequential(
            *(TransformerBlock(k_dimensions=k, heads=heads) for _ in range(depth))
        )

        # Maps the pooled output vector to class logits
        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer values representing
                  words (in some predetermined vocabulary).
        :return: A (b, c) tensor of log-probabilities over the
                 classes (where c is the nr. of classes).
        """
        # generate token embeddings
        tokens = self.token_emb(x)  # (b, t, k)
        b, t, k = tokens.size()

        # generate position embeddings; create the indices on the input's device
        # so the model also works when x lives on a GPU
        positions = torch.arange(t, device=x.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        x = self.tblocks(x)

        # Average-pool over the t dimension and project to class
        # log-probabilities
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)