NLP Training 6: Attention from scratch
--- 


In [2]:
import torch.nn as nn
import torch
import os

os.chdir('..')
print(f'Setting working dir to: {os.getcwd()}')

Setting working dir to: /Users/ingomarquart/Documents/GitHub


## Writing an attention layer from scratch with PyTorch

### Overview

We will write one attention layer with a single attention head from scratch, to illustrate how a fairly complex model can be written and composed in PyTorch.


For the attention head, we need the following components

1.   A scaled dot product
2.   Query, Key and Value operations

and to write a single-head attention layer, we also need

3.   Layer Norm
4.   Feedforward Network
5.   An attention module that implements the above, plus skip connections


We will implement all of these as PyTorch modules to get the hang of it.
In practice, we can instead define functions and use only one or two modules.


First, let's write our first PyTorch Module



### Exercise 1: Implementing a PyTorch Module

PyTorch modules need to implement a couple of methods:    
- First, `__init__()`, which also needs to call the superclass's initialization.    
- Second, a `forward` method that tells PyTorch what takes the module from input to output.

Create a new class called `TestModule` that inherits from `nn.Module` and implements an `__init__` and a `forward` method. The latter should just print `Hello World`.

In [4]:
class TestModule(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self):
    print("Hello World")

Different from other ML frameworks, PyTorch modules act like simple Python classes. Test it by instantiating your class and calling it just like any other Python object.

In [6]:
# Instantiate the class
test_module = TestModule()

# We call the forward pass by calling the module directly!
test_module()

Hello World
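To make the "call the module directly" point concrete, here is a minimal sketch with a module that takes an input. The class name `Doubler` is made up for illustration and is not part of the exercise. Calling the instance goes through `nn.Module.__call__`, which runs any registered hooks and then our `forward`.

```python
import torch
import torch.nn as nn


class Doubler(nn.Module):
    """Hypothetical toy module: multiplies its input by two."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 2 * x


doubler = Doubler()
x = torch.tensor([1.0, 2.0])

# Preferred: call the module itself (this dispatches to forward)
out_call = doubler(x)
# Works too, but bypasses hooks - usually not what we want
out_forward = doubler.forward(x)

print(out_call)
```

Both calls return the same tensor here, since we registered no hooks.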


### Exercise 2:  Query, Key, and Value Operation

Let's start with our first real module. We will code the operation that duplicates our input embedding vector `X` three times and transforms them separately into key, query and value vectors.

The Query, Key and Value operation is a simple matrix multiplication and doesn't need its own module. 

But since it is simple, we will write this as a module (`torch.nn.Module`). On top of that, we will even write two of the three multiplications (for Query and Key) by hand. This goes to show that PyTorch can keep arbitrary variables in the computational graph (not just modules or predefined functions) and optimize them - provided that we register them correctly. For the Values we will use `nn.Linear` to do the multiplication.

We already created a model skeleton for you, try to fill it!

In [8]:
class QueryKeyValue(nn.Module):
  def __init__(self, dim_h, dim_k):
    """
    dim_h: Hidden embedding dimensions of the inputs
    dim_k: Embedding dimension of outputs going into the attention mechanism
    """
    super().__init__()

    # Here we write the parameter matrices ourselves for Query and Key
    # Since we use torch.ones, they are automatically filled with 1s
    # For real use cases we would use torch.empty and a specific initialization function
    self.W_q = nn.Parameter(torch.ones([dim_h, dim_k], requires_grad = True))
    self.W_k = nn.Parameter(torch.ones([dim_h, dim_k], requires_grad = True))
    # Assigning an nn.Parameter as an attribute already registers it with the module,
    # so PyTorch will optimize it. The explicit calls below are redundant and only
    # make the registration visible.
    self.register_parameter('W_q', self.W_q)
    self.register_parameter('W_k', self.W_k)

    # For value, we do it using a PyTorch Module (a linear layer)
    # This is the same multiplication as above (plus a bias term), but much shorter
    # We do not need to register the parameters, since they are already registered
    # The same goes for the initialization of weights
    self.value = nn.Linear(dim_h, dim_k)

  def forward(self, X):
    # Our custom linear functions
    queries = torch.matmul(X, self.W_q)
    keys = torch.matmul(X, self.W_k)
    # Torch Module
    values = self.value(X)
    # And we return
    return queries, keys, values
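A quick way to check what PyTorch has actually registered is `named_parameters()`. The sketch below uses a made-up `TinyModule` (not part of the exercise) to contrast an `nn.Parameter`, a plain tensor attribute, and a submodule.

```python
import torch
import torch.nn as nn


class TinyModule(nn.Module):
    """Hypothetical module used only to inspect parameter registration."""

    def __init__(self):
        super().__init__()
        # nn.Parameter assigned as an attribute: registered automatically
        self.weight = nn.Parameter(torch.ones(5, 3))
        # Plain tensor attribute: NOT registered, even with requires_grad=True
        self.plain = torch.ones(5, 3, requires_grad=True)
        # Submodule: its weight and bias are registered under its name
        self.linear = nn.Linear(5, 3)


tiny = TinyModule()
param_names = [name for name, _ in tiny.named_parameters()]
print(param_names)  # ['weight', 'linear.weight', 'linear.bias']
```

Anything missing from this list will not be passed to an optimizer built from `tiny.parameters()`.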


### Exercise 3 - Test your Module

After the implementation, we need to test the Module:

- Let's create a test tensor with two observations and the embedding dimension 5. Since our parameters are `torch.float32`, we need to specify the same datatype, otherwise `matmul` will throw an error.
- Let us initialize a `QueryKeyValue` instance. `dim_h` needs to be 5, which is the dimension of our tensor. By contrast, we can choose the dimension for the attention embeddings. In BERT, it will be `dim_h` divided by the number of heads. Here, let's do 3.
- If we plug in our test tensor, we should get three outputs, all with the same size. Can you confirm?


In [10]:
# A simple tensor for test purposes
test_tensor = torch.tensor([[1, 2, 3, 4, 5], 
                            [5, 6, 7, 8, 10]], dtype=torch.float32)

# Init our Module
qkv_mechanism = QueryKeyValue(5, 3)

# Run a pass through the model
queries,keys,values = qkv_mechanism(test_tensor)
print(queries.shape)
print(keys.shape)
print(values.shape)

torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 3])


Since we have initialized our own parameters (for queries and keys) to one, we can also reason about what the output should be.
First, the outputs should be equal for queries and keys (both weight matrices are all ones).

Second, since the weights are one, each entry of the first row should be
$$
[1,2,3,4,5] \cdot [1,1,1,1,1]^T
$$
that is
$$
1 \cdot 1+1 \cdot 2+1 \cdot 3+1 \cdot 4+1 \cdot 5 = 15
$$
By the same logic, each entry of the second row should be $5+6+7+8+10 = 36$.

In [11]:
print(queries)
print(keys)

tensor([[15., 15., 15.],
        [36., 36., 36.]], grad_fn=<MmBackward0>)
tensor([[15., 15., 15.],
        [36., 36., 36.]], grad_fn=<MmBackward0>)
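The same arithmetic can be checked without the module. This is just a sketch of the reasoning above: multiplying by an all-ones matrix sums each input row and copies that sum into every output column.

```python
import torch

X = torch.tensor([[1, 2, 3, 4, 5],
                  [5, 6, 7, 8, 10]], dtype=torch.float32)
W = torch.ones(5, 3)

# Matrix product with an all-ones weight matrix
product = X @ W
# Row sums, repeated once per output column
row_sums = X.sum(dim=-1, keepdim=True).expand(-1, 3)

print(product)
```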


By contrast, the values module is initialized by PyTorch - it has random weights (and a bias).

In [12]:
print(values)

tensor([[ 0.9126,  0.5544, -3.0775],
        [ 2.5894,  2.3251, -8.2189]], grad_fn=<AddmmBackward0>)


### Exercise 4: Scaled Dot Product

For our scaled dot product, we take three matrices and compute attention scores and output embeddings. This time, we will use PyTorch's built-in modules - specifically, the softmax module.

A second option is to use PyTorch's functional API: softmax is a simple function, and PyTorch can keep track of its impact on the gradients. In this case, we import `torch.nn.functional` as `F` and employ `F.softmax()`.
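As a small sanity check, the module and the functional form give the same result. The `scores` tensor below is made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [0.0, 0.0, 0.0]])

# Module form: create once (typically in __init__), call in forward
module_out = nn.Softmax(dim=-1)(scores)
# Functional form: call directly, nothing to store on the module
functional_out = F.softmax(scores, dim=-1)

# Each row is a probability distribution and sums to one
row_sums = functional_out.sum(dim=-1)
```

Softmax has no learnable parameters, so the choice between the two forms is mostly a matter of style.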

As before, try to fill in the missing parts in the Module skeleton.

In [14]:
class ScaledDotProduct(nn.Module):
  def __init__(self):
    super().__init__()
    # Specify dim explicitly - omitting it is deprecated, and we want the softmax over the last dimension.
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, K, V):
    """
    Q,K and V we have computed earlier
    """
    # Dot Product - K.T transposes a 2D matrix, so this version has no batch dimension yet
    attention = torch.matmul(Q, K.T)
    # Get the normalization constant. We need the embedding size of the attention layer
    # aka the last dimension of each of the matrices
    dim_k = Q.shape[-1]
    # attention scores correspond to the softmax of the scaled matrix (divide by sqrt(dim_k))
    attention = self.softmax(attention / dim_k ** 0.5)
    return torch.matmul(attention, V)


### Exercise 4:  Attention Head

That's all we need for a single attention head.

We will also encode this as a module, to show how nicely PyTorch composes modules.

In [15]:
class AttentionHead(nn.Module):
  def __init__(self, dim_h, dim_k):
    super().__init__()

    # Here we simply instantiate our two classes
    self.scaled_dot_product = ScaledDotProduct()
    self.qkv_mechanism = QueryKeyValue(dim_h, dim_k)

  def forward(self,X):
    queries, keys, values = self.qkv_mechanism(X)
    X_next_layer = self.scaled_dot_product(queries, keys, values)

    return X_next_layer
  

Let's test it!

Since we only have one head, and no concat operation, we would want to set `dim_h=dim_k`. After all, the input and output dimension of a transformer layer should be the same!

In [16]:
test_tensor = torch.tensor([[1, 2, 3, 4, 5], 
                            [5, 6, 7, 8, 10]], dtype=torch.float32)

# Add your solution here:
# ...

In [17]:
attention_head = AttentionHead(5, 5)
attention_head(test_tensor)

tensor([[ 0.8646, -1.3807, -0.7563, -0.9157, -2.6003],
        [ 0.8646, -1.3807, -0.7563, -0.9157, -2.6003]], grad_fn=<MmBackward0>)
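Note that both output rows are identical. A likely explanation, sketched below under the assumption that queries and keys come from the all-ones weights used above: the raw scores are so large that the softmax saturates, and every token puts all its attention on the second token. The output is then just the value vector of that token, repeated.

```python
import torch

X = torch.tensor([[1, 2, 3, 4, 5],
                  [5, 6, 7, 8, 10]], dtype=torch.float32)

# All-ones weights, as in our hand-written Query and Key matrices
Q = X @ torch.ones(5, 5)
K = X @ torch.ones(5, 5)

# Scaled dot product scores and the resulting attention weights
scores = Q @ K.T / Q.shape[-1] ** 0.5
weights = torch.softmax(scores, dim=-1)

print(scores)
print(weights)
```

With a proper random initialization and normalized inputs, the scores would be much smaller and the attention weights less extreme.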

### Exercise 5:  Recap - and Batching

We have so far supplied a test tensor with two dimensions. The first dimension indexes the tokens, the second is the embedding dimension.
Thus, we can only process one sentence at a time.

This would take a long time... the real benefit of a transformer is that we can train many sequences (a batch) in parallel.

To test this, we just add another "batch" to our test tensor - the same sequence duplicated.

In [18]:
test_tensor = torch.tensor([[[1, 2, 3, 4, 5], [5, 6, 7, 8, 10]], 
                            [[1, 2, 3, 4, 5], [5, 6, 7, 8, 10]]], dtype=torch.float32)
print(test_tensor.shape)
print(test_tensor)

torch.Size([2, 2, 5])
tensor([[[ 1.,  2.,  3.,  4.,  5.],
         [ 5.,  6.,  7.,  8., 10.]],

        [[ 1.,  2.,  3.,  4.,  5.],
         [ 5.,  6.,  7.,  8., 10.]]])


If we were to run this on our above attention layer, it would fail. We need to change two things:

In the QKV mechanism we unsqueeze our weight matrices to give them a leading batch dimension of size 1. `matmul` will then broadcast the matrices and apply them to each batch. (`matmul` would also broadcast a plain 2D weight matrix; the unsqueeze just makes the batch dimension explicit.) For the linear layer (`nn.Linear`), we of course don't need to worry about it.

In addition, our Scaled Dot Product now works across batches. We have the same number of batches for all matrices - hence no need to broadcast.
In that case, `torch.bmm` (batched matrix multiplication) is faster!

Previously, we transposed the Key matrix with `K.T`. Of course, we cannot transpose the batch dimension. Instead of `K.T`, we will specify that we only want to transpose the last two dimensions, using the `.transpose(-1,-2)` method.

The two places are marked by `# !!`
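The shapes involved are easy to check in isolation. The sketch below uses random tensors with made-up sizes (batch 2, 4 tokens, hidden size 5, attention size 3) rather than our modules.

```python
import torch

torch.manual_seed(0)
X = torch.randn(2, 4, 5)  # batch, tokens, hidden
W = torch.randn(5, 3)     # hidden -> attention dimension

# matmul broadcasts a 2D weight over the batch dimension...
broadcast = torch.matmul(X, W)
# ...which gives the same result as adding the batch dimension by hand
explicit = torch.matmul(X, W.unsqueeze(0))

# Transpose only the last two dimensions, leaving the batch dimension alone
K = explicit
scores_bmm = torch.bmm(K, K.transpose(-1, -2))
scores_matmul = torch.matmul(K, K.transpose(-1, -2))

print(broadcast.shape)   # torch.Size([2, 4, 3])
print(scores_bmm.shape)  # torch.Size([2, 4, 4])
```

Note that `torch.bmm` only accepts 3D tensors with matching batch sizes, whereas `torch.matmul` also handles extra leading dimensions.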

In [19]:
class QueryKeyValue(nn.Module):
  def __init__(self, dim_h, dim_k):
    """
    dim_h: Hidden embedding dimensions of the inputs
    dim_k: Embedding dimension of outputs going into the attention mechanism
    """
    super().__init__()

    # Here we write the parameter matrices ourselves for Query And Key
    self.W_q = nn.Parameter(torch.ones([dim_h, dim_k], requires_grad = True))
    self.W_k = nn.Parameter(torch.ones([dim_h, dim_k], requires_grad = True))
    # Assigning an nn.Parameter as an attribute already registers it with the module;
    # the explicit calls below are redundant and only make the registration visible
    self.register_parameter('W_q',self.W_q)
    self.register_parameter('W_k',self.W_k)

    # For value, we do it using a PyTorch Module
    self.value = nn.Linear(dim_h,dim_k)

  def forward(self, X):
    # Our custom linear functions
    # !!
    queries = torch.matmul(X,self.W_q.unsqueeze(0))
    keys = torch.matmul(X,self.W_k.unsqueeze(0))
    # Torch Module
    values = self.value(X)
    # And we return
    return queries, keys, values

class ScaledDotProduct(nn.Module):
  def __init__(self):
    super().__init__()
    # Specify dim explicitly - omitting it is deprecated, and we want the softmax over the last dimension.
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, K, V):
    """
    Q,K and V we have computed earlier
    """
    # Dot Product - allowing for a batch dimension - here we used batched matrix multiply with transpose
    # !!
    attention = torch.bmm(Q,K.transpose(-1,-2))
    # Get the normalization constant. We need the embedding size of the attention layer
    # aka the last dimension of each of the matrices
    dim_k = Q.shape[-1]
    # attention scores correspond to the softmax of the scaled matrix (divide by sqrt(dim_k))
    attention = self.softmax(attention / dim_k ** 0.5)
    return torch.matmul(attention, V)

class AttentionHead(nn.Module):
  def __init__(self, dim_h, dim_k):
    super().__init__()

    # Here we simply instantiate our two classes
    self.scaled_dot_product = ScaledDotProduct()
    self.qkv_mechanism = QueryKeyValue(dim_h, dim_k)

  def forward(self,X):
    queries, keys, values = self.qkv_mechanism(X)
    X_next_layer = self.scaled_dot_product(queries, keys, values)

    return X_next_layer
  

Let's test it

In [20]:
test_tensor = torch.tensor([[[1, 2, 3, 4, 5], [5, 6, 7, 8, 10]],
                            [[1, 2, 3, 4, 5], [5, 6, 7, 8, 10]]], dtype=torch.float32)

# Add your solution here:
# ...

In [21]:
attention_head = AttentionHead(5,5)
attention_head(test_tensor)

tensor([[[-3.2694, -1.4893, -1.7046, -1.1541,  0.5490],
         [-3.2694, -1.4893, -1.7046, -1.1541,  0.5490]],

        [[-3.2694, -1.4893, -1.7046, -1.1541,  0.5490],
         [-3.2694, -1.4893, -1.7046, -1.1541,  0.5490]]],
       grad_fn=<UnsafeViewBackward0>)

### Exercise 6:  LayerNorm (and BatchNorm)

We will not code this ourselves, but we can use PyTorch to understand LayerNorm and BatchNorm better.

LayerNorm computes
$$
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
$$

The important point here is over which dimensions the expected value and variance are computed.

In our application, the first dimension of our tensor is the batch dimension, the second the dimension of observations (tokens), and the final one the hidden dimension.

LayerNorm as applied in BERT normalizes across the hidden dimension.

Let's see this on our test tensor.

In [22]:
ln = nn.LayerNorm(5)
test_tensor = torch.tensor([[[1, 2, 3, 4, 5], [5, 6, 7, 8222, 10]], 
                            [[1, 2, 22223, 4, 5], [5, 6, 7, 8, 10]]], dtype=torch.float32)
# Consider the first batch
ln(test_tensor[0,...])

tensor([[-1.4142, -0.7071,  0.0000,  0.7071,  1.4142],
        [-0.5006, -0.5003, -0.5000,  2.0000, -0.4991]],
       grad_fn=<NativeLayerNormBackward0>)

It is a bit hard to see, but if you sum over the hidden dimension for each observation, you get roughly 0.


In [23]:
ln(test_tensor[0,...]).sum(-1, keepdims=True)

tensor([[-3.5763e-07],
        [-2.3842e-07]], grad_fn=<SumBackward1>)

If we instead sum up across observations (or even batches), this is not true.

In [24]:
ln(test_tensor[0,...]).sum(0, keepdims=True)

tensor([[-1.9148, -1.2074, -0.5000,  2.7071,  0.9151]], grad_fn=<SumBackward1>)
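To connect this back to the formula, here is a minimal sketch that reproduces `nn.LayerNorm` by hand for the first sequence. It relies on the default initialization ($\gamma = 1$, $\beta = 0$) and on LayerNorm using the biased variance.

```python
import torch
import torch.nn as nn

x = torch.tensor([[1., 2., 3., 4., 5.],
                  [5., 6., 7., 8222., 10.]])
ln = nn.LayerNorm(5)

# Mean and (biased) variance over the hidden dimension, per observation
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)

# The LayerNorm formula; ln.weight is gamma, ln.bias is beta
manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias

print(manual)
print(ln(x))
```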

We can compare this to BatchNorm. Of course, we have a sequence model, so the "batch" that BatchNorm sees here is one sequence, and its observations are the tokens.

So, for the first sequence, BatchNorm gives


In [25]:
bn = nn.BatchNorm1d(5)
bn(test_tensor[0,...])

tensor([[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
        [ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000]],
       grad_fn=<NativeBatchNormBackward0>)

Now, the hidden dimensions of each observation do not sum up to 0.

In [26]:
bn(test_tensor[0,...]).sum(-1, keepdims=True)

tensor([[-5.0000],
        [ 5.0000]], grad_fn=<SumBackward1>)

But, across observations ("batches"), each dimension comes out to zero

In [27]:
bn(test_tensor[0,...]).sum(0, keepdims=True)

tensor([[ 1.1921e-07, -5.9605e-08,  0.0000e+00,  0.0000e+00, -2.3842e-07]],
       grad_fn=<SumBackward1>)

But recall that we do not wish to add skewed relationships between observations (that is, tokens) in the transformer just due to normalization. Hence, we choose LayerNorm for our Attention Layer.
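The BatchNorm result can also be reproduced by hand: the only change relative to LayerNorm is the axis. This sketch assumes a freshly created `nn.BatchNorm1d` in training mode (the default), where the batch statistics are used.

```python
import torch
import torch.nn as nn

x = torch.tensor([[1., 2., 3., 4., 5.],
                  [5., 6., 7., 8222., 10.]])
bn = nn.BatchNorm1d(5)  # training mode by default

# Mean and (biased) variance over the observations, per hidden dimension
mean = x.mean(dim=0, keepdim=True)
var = x.var(dim=0, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(manual)
```

With only two observations, every column is normalized to roughly -1 and +1, no matter how large the original values were. The outlier 8222 becomes indistinguishable from 8, which illustrates how the result for one token depends on the other tokens.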

### Exercise 7: Feedforward Layer


The attention mechanism is responsible for encoding complex relationships between tokens (in ML lingo we say it learns a dyadic approximation to any permutation invariant function).

The Feedforward Layer is a far less exciting module of the Transformer - it is simply a set of linear regressions (`nn.Linear`) followed by a non-linear activation function - also called a multilayer perceptron (MLP). 

Nevertheless, this part of the model has been identified as a crucial component, as it provides expressivity to process the information from the attention layer. To do this well, the Feedforward layer should first increase the dimensionality of the embeddings, before squashing them back down to the original size.

And to have a useful composition of the two layers, instead of a plain ol' matrix multiplication, we use a non-linear activation (in this case, a GELU, `nn.GELU`).

For this part, we will use standard PyTorch functions and modules, having done the extra legwork above.   

See if you can fill the missing parts in the module skeleton.

In [None]:
class FeedForwardLayer(torch.nn.Module):
  def __init__(self, dim_h, dim_expanded):
    """
    dim_h: Hidden embedding dimensions of the inputs
    dim_expanded: Dimension of the expanded (intermediate) embeddings
    """
    super().__init__()

    # from dim_h -> dim_expanded
    self.firstlayer = nn.Linear(in_features=dim_h, out_features=dim_expanded)
    # and back
    self.secondlayer = nn.Linear(in_features = dim_expanded, out_features = dim_h)
    self.activation = nn.GELU()

  def forward(self, x):
    x = self.firstlayer(x)
    x = self.activation(x)
    # Dropout here
    x = self.secondlayer(x)
    return x

Let's confirm it works with our batched tensor.

In [None]:
test_tensor = torch.tensor([[[1, 2, 3, 4, 5], [5, 6, 7, 8222, 10]],
                            [[1, 2, 22223, 4, 5], [5, 6, 7, 8, 10]]], dtype=torch.float32)
ffl = FeedForwardLayer(5, 15)
ffl(test_tensor)
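Since the cell above shows no output, here is a self-contained sketch of what to expect. It rebuilds the same expand-then-contract structure with `nn.Sequential` (a stand-in for illustration, not the class above) and checks the output shape and the parameter count.

```python
import torch
import torch.nn as nn

# Stand-in for FeedForwardLayer(5, 15): expand, activate, contract
ffn = nn.Sequential(
    nn.Linear(5, 15),
    nn.GELU(),
    nn.Linear(15, 5),
)

x = torch.randn(2, 2, 5)  # batch, tokens, hidden
out = ffn(x)

# (5*15 + 15) weights and biases in the first layer, (15*5 + 5) in the second
n_params = sum(p.numel() for p in ffn.parameters())

print(out.shape)  # torch.Size([2, 2, 5])
print(n_params)   # 170
```

`nn.Linear` acts on the last dimension only, so the batch and token dimensions pass through untouched.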

### Exercise 8: Building the (Encoder) Attention Layer

With all our components in place, we can build a single-head attention layer.

We need two more things:


*   First, our attention head's embedding dimension is `dim_k`, whereas the rest of our transformer uses `dim_h`. This is because we might want to have several parallel attention heads. We are thus missing the aggregation step, with a final linear layer ensuring the correct dimension.

*   Second, recall that we want to add a skip connection and - depending on how we do it - one or two layernorms.
Here, we will implement the Pre-Normalization skip connections, which have been shown to make training easier.

See if you can fill the missing parts in the module skeleton.

In [None]:
class SimpleAttentionLayer(nn.Module):
  def __init__(self, dim_h, dim_k, dim_expanded):
    super().__init__()

    # Attention head
    self.qkv = QueryKeyValue(dim_h, dim_k)
    self.dot_product_attn = ScaledDotProduct()
    self.aggregation_step = nn.Linear(dim_k, dim_h)

    # Layer Norm(s)
    self.layernorm1 = nn.LayerNorm(dim_h)
    self.layernorm2 = nn.LayerNorm(dim_h)
    
    # Feedforward Layer
    self.ffl = FeedForwardLayer(dim_h,dim_expanded)

  def forward(self, x):
    # Branch off from the skip connection before the initial layernorm
    # - that is: we keep x as is
    h = self.layernorm1(x)
    
    # Triplicate our input matrix
    query, key, value = self.qkv(h)
    
    # Run the scaled dot product attention operation
    h = self.dot_product_attn(query, key, value)
    h = self.aggregation_step(h)
    
    # Our first skip connection returns to the main path
    x = h + x
    
    # Branch off again with the second skip leading into the feedforward layer
    h = self.layernorm2(x)
    h = self.ffl(h)
    
    # Finally, return the skip connection
    x = h + x
    return x

In [None]:
# Test the attention layer
attn_layer = SimpleAttentionLayer(5, 5, 15)
attn_layer(test_tensor)
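The difference between Pre- and Post-Normalization is only where the LayerNorm sits relative to the skip connection. A minimal sketch, with hypothetical wrapper classes `PreNormResidual` and `PostNormResidual` and a plain linear layer standing in for attention or the feedforward network:

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Hypothetical wrapper: normalize, apply sublayer, then add the skip."""

    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))


class PostNormResidual(nn.Module):
    """Hypothetical wrapper: apply sublayer, add the skip, then normalize."""

    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return self.norm(x + self.sublayer(x))


torch.manual_seed(0)
x = torch.randn(2, 2, 5)
pre = PreNormResidual(5, nn.Linear(5, 5))(x)
post = PostNormResidual(5, nn.Linear(5, 5))(x)

print(pre.shape, post.shape)
```

In the pre-norm variant, `x` travels along the main path without ever being normalized, which is what the `forward` method above does twice.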

### What is missing?

Actually - not much. With the addition of a head - say a softmax classification head, and token+positional embeddings, this is a working transformer.

We have not implemented a dropout layer, which would go at the end - after the feedforward layer. If you use it, you'll want to differentiate between evaluation and training and disable the dropout for the former (calling `.eval()` on the model does this for `nn.Dropout`).
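A short sketch of the training/evaluation switch, using a standalone `nn.Dropout` with a made-up dropout probability of 0.5:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: entries are zeroed at random, survivors are scaled by 1/(1-p)
drop.train()
train_out = drop(x)

# Evaluation mode: dropout is the identity
drop.eval()
eval_out = drop(x)

print(train_out.unique())
print(torch.equal(eval_out, x))
```

Calling `.train()` or `.eval()` on a parent module switches all of its submodules at once.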

Second, we are missing multi-headed attention. But: Having specified a single attention head, that's easy to do. Try it yourself!

Hint: Use a PyTorch module list (`nn.ModuleList`)
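One possible way to follow the hint, as a sketch rather than a reference solution. The classes `Head` and `MultiHead` are hypothetical; to keep the block self-contained, the head uses `nn.Linear` for all three projections instead of our hand-written parameters.

```python
import torch
import torch.nn as nn


class Head(nn.Module):
    """Hypothetical single attention head built from nn.Linear projections."""

    def __init__(self, dim_h, dim_k):
        super().__init__()
        self.q = nn.Linear(dim_h, dim_k, bias=False)
        self.k = nn.Linear(dim_h, dim_k, bias=False)
        self.v = nn.Linear(dim_h, dim_k)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ v


class MultiHead(nn.Module):
    """Hypothetical multi-head wrapper: run heads, concatenate, project back."""

    def __init__(self, dim_h, dim_k, n_heads):
        super().__init__()
        # nn.ModuleList registers every head, so their parameters are trained
        self.heads = nn.ModuleList([Head(dim_h, dim_k) for _ in range(n_heads)])
        self.out = nn.Linear(n_heads * dim_k, dim_h)

    def forward(self, x):
        concat = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out(concat)


mha = MultiHead(dim_h=8, dim_k=4, n_heads=2)
out = mha(torch.randn(3, 6, 8))
print(out.shape)  # torch.Size([3, 6, 8])
```

A plain Python list of heads would run as well, but its parameters would be invisible to `mha.parameters()` and therefore to the optimizer.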


## Pro Version

If you code neural networks in Python, you should be keenly aware of `torch.einsum` and `einops`.
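As a warm-up for the notation, here is a small sketch showing that an `einsum` over named axes computes the same similarity scores as a `matmul` with an explicit transpose. The tensor sizes are made up.

```python
import torch

torch.manual_seed(0)
q = torch.randn(2, 3, 4, 5)  # batch, heads, tokens, head dimension
k = torch.randn(2, 3, 4, 5)

# Sum over the shared axis d; i and j index query and key tokens
sim_einsum = torch.einsum("bhid,bhjd->bhij", q, k)
# The same computation with an explicit transpose of the last two axes
sim_matmul = q @ k.transpose(-1, -2)

print(sim_einsum.shape)  # torch.Size([2, 3, 4, 4])
```

The einsum string documents which axis is contracted, which tends to be easier to read than keeping track of transposes.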

In [1]:
%pip install einops

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
from einops import rearrange
class MultiHeadAttention(torch.nn.Module):
    def __init__(
        self,
        dim_input: int,
        nr_heads: int = 8,
        dim_head: int = 16,
        dropout_p: float = 0.0,
        scale_factor: float = 0.5,
    ):
        """Ye olde Multihead Attention, implemented with Einstein Notation.
        Note: There ain't no masking here

        Args:
            dim_input (int): The input dimension
            nr_heads (int, optional): Number of heads. Defaults to 8.
            dim_head (int, optional): Dimension of heads. Defaults to 16.
            dropout_p (float, optional): Dropout. Defaults to 0.0.
            scale_factor (float, optional): Exponent of the scaling division - default is square root. Defaults to 0.5.
        """
        super().__init__()
        self.nr_heads = nr_heads
        self.scale = dim_head**-scale_factor

        self.to_qkv = torch.nn.Linear(dim_input, dim_head * nr_heads * 3, bias=False)
        self.to_out = torch.nn.Linear(dim_head * nr_heads, dim_input)
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, x):
        h = self.nr_heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        return self.to_out(out)
