
# Self Attention Spelled Out

In this notebook, we will review the structure of a multi head attention module that uses self attention.

The goal of this notebook is not to build a training loop for the multi head attention weights (that will be another notebook).

Rather, the goal is to see how self attention, and multi head attention by extension, look and interact in code.

This version of self attention is not the most robust, clean, or efficient implementation; instead, it aims to be a gentle introduction that prepares you for more complicated architectures.

We recommend that you understand self attention theoretically before trying to understand this code.  If you're unfamiliar with how attention works (particularly the significance of queries, keys, and values), check out our guide to self attention:

https://github.com/ArjunSohur/transformergallery/blob/main/README.md

The code is based heavily on the article:

https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51

This article had the most concise, first-principles self attention implementation we've seen.  Since it helped us in our journey to understanding multi head attention, we think it can greatly benefit you as well!

Enjoy!

### Imports

In [1]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as f

### Self Attention Head

Let's start by implementing a single self attention head.  The most common formula for self attention is scaled dot product attention: $$\text{attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_{QK}}}\right)V$$
where $Q, K \in \mathbb{R}^{n \times d_{QK}}$ and $V \in \mathbb{R}^{n \times d_V}$, with $n$ the sequence length.  If you are unsure where $Q$, $K$, and $V$ come from, check out our explanation:
https://github.com/ArjunSohur/transformergallery/blob/main/README.md

Our first task will be implementing this formula.  The biggest challenge - and one that we will encounter quite often - will be making sure that the dimensions of our matrices match up for matrix multiplication.
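
To make the dimension bookkeeping concrete, here is a minimal sketch that follows the shapes through each step of the formula. The sizes (`batch_size = 2`, `sequence_length = 4`, `d_qk = 8`, `d_v = 6`) are made up purely for illustration, and the tensors are random stand-ins rather than real queries, keys, and values.

```python
import torch
import torch.nn.functional as f

# Hypothetical small sizes, chosen only so the shapes are easy to follow
batch_size, sequence_length = 2, 4
d_qk, d_v = 8, 6

Q = torch.rand(batch_size, sequence_length, d_qk)
K = torch.rand(batch_size, sequence_length, d_qk)
V = torch.rand(batch_size, sequence_length, d_v)

# (2, 4, 8) x (2, 8, 4) -> (2, 4, 4): one score for every pair of positions
scores = Q.bmm(K.transpose(1, 2))

# Scaling and softmax keep the shape; each row becomes weights that sum to 1
weights = f.softmax(scores / d_qk ** 0.5, dim=-1)

# (2, 4, 4) x (2, 4, 6) -> (2, 4, 6): one d_v-sized vector per position
output = weights.bmm(V)

print(scores.shape, weights.shape, output.shape)
```

Note that the inner dimensions have to agree at both multiplications: `d_qk` for the first one, and `sequence_length` for the second.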

In [2]:
# PARAMS
# queries - tensor of shape (batch_size, sequence_length, number_of_features_for_queries)
# keys - tensor of shape (batch_size, sequence_length, number_of_features_for_keys)
# Note: number_of_features_for_keys = number_of_features_for_queries
# values - tensor of shape (batch_size, sequence_length, number_of_features_for_values)
def scaled_dot_product_attention(queries: Tensor, keys: Tensor, values: Tensor) -> Tensor:
    # First, we batch matrix multiply the queries tensor with the transpose of the keys tensor
    # This step scores how relevant each position is to every other position
    # Results in a sequence_length x sequence_length matrix for each batch element
    matrix_mult_of_queries_and_keys = queries.bmm(keys.transpose(1, 2))

    # We need to ensure that back propagation runs smoothly
    # We keep the values of the above matrix multiplication in check by dividing by the square root
    # of the hidden dimension of the queries (which equals the hidden dimension of the keys)
    scalar_for_gradient_stability = queries.size(-1) ** (1/2)

    # Dividing the values of the matrix multiplication by that square root
    matrix_multiplication_adjusted_by_scalar = matrix_mult_of_queries_and_keys / scalar_for_gradient_stability

    # To normalize the scores into weights that sum to 1, we apply softmax row-wise on our result
    softmax_scaled_matrix_multiplication = f.softmax(matrix_multiplication_adjusted_by_scalar, dim=-1)

    # Lastly, we perform batch matrix multiplication with the value matrices
    scaled_dot_product_attention_result = softmax_scaled_matrix_multiplication.bmm(values)

    # Results in a sequence_length x number_of_features_for_values matrix for each batch element
    return scaled_dot_product_attention_result
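
To see why the function divides by the square root of the hidden dimension, the sketch below compares the spread of raw and scaled scores. It assumes query and key entries drawn independently from a standard normal distribution, which is a simplification for illustration and not something the function itself requires; the sizes are made up.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

hidden_dimension = 512
# Hypothetical queries and keys with zero-mean, unit-variance entries
queries = torch.randn(1, 256, hidden_dimension)
keys = torch.randn(1, 256, hidden_dimension)

raw_scores = queries.bmm(keys.transpose(1, 2))
scaled_scores = raw_scores / hidden_dimension ** 0.5

raw_std = raw_scores.std().item()
scaled_std = scaled_scores.std().item()

# Under this assumption the raw spread grows like sqrt(hidden_dimension),
# while the scaled spread stays close to 1
print(raw_std, scaled_std)
```

Scores with a large spread push softmax toward putting almost all of its weight on one position, where its gradients become very small; scaling keeps the scores in a range where softmax stays responsive.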

Now that we have the formula down, we can focus on the self attention head itself.  We represent it as a class so that we can instantiate as many separate self attention heads as we need.

There are two methods in our self attention class: the initialization method and the forward method.

The initialization method defines linear, single-layer neural networks that will act as the weights for our **queries, keys, and values**.  We can view such a network as a weight matrix: each edge between an input node and an output node of a fully connected layer carries one weight, and those weights are the elements of the matrix.  Each row holds the weights connecting one input node to all of the output nodes.  Training these weights is how the attention mechanism learns what to pay attention to.
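
As a quick check of the "linear layer as a weight matrix" view, the sketch below applies an `nn.Linear` layer and compares it with an explicit matrix multiplication. The sizes are made up for illustration. One detail worth knowing: PyTorch stores the weight with shape `(out_features, in_features)`, so the explicit version multiplies by its transpose, and the layer also adds a bias term by default.

```python
import torch
import torch.nn as nn

embedding_dimension, hidden_dimension = 4, 3  # hypothetical small sizes
layer = nn.Linear(embedding_dimension, hidden_dimension)

x = torch.rand(2, 5, embedding_dimension)  # (batch_size, sequence_length, embedding_dimension)

# Applying the layer is a matrix multiplication with the transposed weight, plus the bias
manual = x @ layer.weight.T + layer.bias

print(layer.weight.shape)  # torch.Size([3, 4])
print(torch.allclose(layer(x), manual, atol=1e-6))
```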

The forward method passes the query, key, and value inputs through these linear layers, then sends the results to the scaled dot product attention formula.

In [3]:
class SelfAttentionHead(nn.Module):
    # PARAMS
    # Weight matrices are crucial for the attention model to learn what is important and what isn't
    # The parameters are chosen so that matrix multiplication with the queries, keys, and values works out
    # Since the queries, keys, and values matrices will be sequence_length x embedding_dimension,
    # each weight matrix has embedding_dimension rows and a number of hidden dimensions that we choose
    def __init__(self, embedding_dimension: int, queries_keys_hidden_dimension: int, values_hidden_dimension: int):
        # Since SelfAttentionHead is a subclass of nn.Module, we need to make a super call
        super(SelfAttentionHead, self).__init__()

        # Weights for the keys, queries, and values
        self.query_weights = nn.Linear(embedding_dimension, queries_keys_hidden_dimension)
        self.key_weights = nn.Linear(embedding_dimension, queries_keys_hidden_dimension)
        self.value_weights = nn.Linear(embedding_dimension, values_hidden_dimension)

    # Overriding nn.Module's forward method
    # PARAMS
    # query, key, and value all have size (batch_size, sequence_length, embedding_dimension)
    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        # performing matrix multiplication with the weights
        weighted_query = self.query_weights(query)
        weighted_key = self.key_weights(key)
        weighted_value = self.value_weights(value)

        # using scaled dot product attention to compute the attention output
        attention = scaled_dot_product_attention(weighted_query, weighted_key, weighted_value)

        return attention
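
The "self" in self attention refers to the query, key, and value inputs all being the same sequence. The sketch below shows that usage with a condensed, hypothetical re-statement of the head (`TinySelfAttentionHead`, written here only so the snippet runs on its own) and made-up sizes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch import Tensor


class TinySelfAttentionHead(nn.Module):
    # Condensed stand-in for the head above, not a replacement for it
    def __init__(self, embedding_dimension: int, qk_dimension: int, v_dimension: int):
        super().__init__()
        self.query_weights = nn.Linear(embedding_dimension, qk_dimension)
        self.key_weights = nn.Linear(embedding_dimension, qk_dimension)
        self.value_weights = nn.Linear(embedding_dimension, v_dimension)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        q = self.query_weights(query)
        k = self.key_weights(key)
        v = self.value_weights(value)
        weights = f.softmax(q.bmm(k.transpose(1, 2)) / q.size(-1) ** 0.5, dim=-1)
        return weights.bmm(v)


head = TinySelfAttentionHead(embedding_dimension=16, qk_dimension=8, v_dimension=4)
x = torch.rand(2, 5, 16)  # (batch_size, sequence_length, embedding_dimension)

# Self attention: the same tensor plays the role of query, key, and value
output = head(x, x, x)
print(output.shape)  # torch.Size([2, 5, 4])
```

The last dimension of the output is the values hidden dimension, not the embedding dimension.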

### Multi Head Attention

Notice how we called the above class "SelfAttentionHead".  The name "multi head attention" implies that there are multiple self attention heads, which is exactly the case!

Our multi head attention class is just multiple "heads" of self attention.

We send our inputs through multiple heads of attention and concatenate the results of each one.  To condense the jumble of concatenated head outputs, we send it through a simple one-layer neural network that outputs a tensor with the same embedding size as the input, summarizing the outputs of all the attention heads.
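
The concatenate-then-project step can be sketched on its own, before involving any attention at all. Below, random tensors stand in for the outputs of eight heads; the sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

number_of_heads, values_hidden_dimension, embedding_dimension = 8, 512, 512

# Random stand-ins for the outputs of eight heads, used only to check shapes
head_outputs = [torch.rand(3, 24, values_hidden_dimension) for _ in range(number_of_heads)]

# Concatenating along the last dimension stacks the heads' features side by side
concatenated = torch.cat(head_outputs, dim=-1)

# A single linear layer maps the wide tensor back to the embedding dimension
projection = nn.Linear(number_of_heads * values_hidden_dimension, embedding_dimension)
projected = projection(concatenated)

print(concatenated.shape)  # torch.Size([3, 24, 4096])
print(projected.shape)     # torch.Size([3, 24, 512])
```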


In [4]:
class MultiHeadAttention(nn.Module):
    # Same params as a regular self attention head, but we have to specify how many heads we want
    def __init__(self, number_of_heads: int, embedding_dimension: int, queries_keys_hidden_dimension: int,
                 values_hidden_dimension: int):
        # Since MultiHeadAttention is a subclass of nn.Module, we perform a super call to begin with
        super(MultiHeadAttention, self).__init__()

        # Creates a list of heads
        self.heads = nn.ModuleList([SelfAttentionHead(embedding_dimension, queries_keys_hidden_dimension,
                                                      values_hidden_dimension)
                                    for _ in range(number_of_heads)])

        # linear layer that maps the large concatenated tensor back to the embedding dimension
        self.feed_forward_layer = nn.Linear(number_of_heads * values_hidden_dimension, embedding_dimension)

    # forward call
    def forward(self, query: Tensor, key: Tensor, value: Tensor):
        # We basically just concatenate all the results of each head ...
        multi_head_result = torch.cat([head(query, key, value) for head in self.heads], dim=-1)

        # ... then pass it through a feed forward neural network to clean it up
        processed_multi_head_result = self.feed_forward_layer(multi_head_result)

        return processed_multi_head_result

### Running an Input Through a Multi Head Attention Layer

Ok, let's test this bad Larrold out to see whether it gives us a tensor of the correct size for our input.

In [5]:
# Made-up numbers to check that the multi-head attention mechanism works as intended
# The goal of this code is not to use or train multi-head attention, just to build it,
# so success here means the output has a valid shape
number_of_batches = 3  # i.e. the batch size
number_of_inputs = 24  # i.e. the sequence length
embedding_dimension = 512
queries_and_keys_hidden_dimension = 1024
values_hidden_dimension = 512
number_of_heads = 8

# Creating random values for keys, queries, and values
query = torch.rand([number_of_batches, number_of_inputs, embedding_dimension])
key = torch.rand([number_of_batches, number_of_inputs, embedding_dimension])
value = torch.rand([number_of_batches, number_of_inputs, embedding_dimension])

# Initializing a multi-head attention instance
multi_head = MultiHeadAttention(number_of_heads=number_of_heads,
                                queries_keys_hidden_dimension=queries_and_keys_hidden_dimension,
                                embedding_dimension=embedding_dimension,
                                values_hidden_dimension=values_hidden_dimension)

# Seeing the result of a forward pass
# (calling the module directly invokes its forward method)
print(multi_head(query, key, value).size())

torch.Size([3, 24, 512])


And it looks like, through all that jumble, we got a tensor that is the same size as our input, which is exactly what we wanted!

If we had trained this attention layer, our output would now encode important relationships, like dependencies, between different parts of the input.
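
For a sense of how much there would be to train, the sketch below counts the learnable parameters in a layer with the same sizes as the test above. `count_parameters` is a small hypothetical helper, and the count is assembled from standalone `nn.Linear` layers that mirror the ones the classes create, rather than from the classes themselves.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    # Total number of learnable values (weights and biases) in a module
    return sum(p.numel() for p in module.parameters())


embedding_dimension = 512
queries_and_keys_hidden_dimension = 1024
values_hidden_dimension = 512
number_of_heads = 8

# One head holds three linear layers: query and key share a shape, value has its own
per_head = (2 * count_parameters(nn.Linear(embedding_dimension, queries_and_keys_hidden_dimension))
            + count_parameters(nn.Linear(embedding_dimension, values_hidden_dimension)))

# The final linear layer maps the concatenated head outputs back to the embedding size
output_layer = count_parameters(nn.Linear(number_of_heads * values_hidden_dimension, embedding_dimension))

total = number_of_heads * per_head + output_layer
print(per_head, output_layer, total)  # 1313280 2097664 12603904
```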