# Exercise 11 Part 1: Self-Attention
**Summer Semester 2024**

**Author**: Stefan Baumann (stefan.baumann@lmu.de)

### Task: Implement Self-Attention
In this exercise, you will implement multi-head self-attention for a 2D grid of tokens (shape `B D H W`) yourself, using **only basic functions (no pre-made attention implementations!)**. You may use simple functions such as `torch.bmm()`, `torch.nn.functional.softmax()`, ..., and simple modules such as `torch.nn.Linear`.

Usage of functions provided by the `einops` library (such as `einops.rearrange()`) is also allowed and encouraged (but completely optional!), as it allows writing the code in a nice and concise way by specifying operations across axes of tensors as strings instead of relying on dimension indices.<br>
A short introduction to einops is available at https://nbviewer.org/github/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb.

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional
import einops

# Pick the compute device: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device "{device}".')

Using device "cuda".


In [9]:
class SelfAttention2d(nn.Module):
    """Multi-head self-attention over a 2D grid of tokens.

    An input feature map of shape (B, D, H, W) is treated as a sequence of
    H * W tokens of dimension D (== embed_dim). Every token attends to every
    token (Vaswani et al., 2017, Eq. 1 for one head, Eq. 2 for several heads),
    and the result is reshaped back to (B, D, H, W).

    All projections are bias-free ``nn.Linear`` layers named ``q``, ``k``,
    ``v`` and ``out``. Inside the projected features each head's slice is
    contiguous: head ``i`` of the query lives at
    ``q[..., i * head_dim:(i + 1) * head_dim]`` (likewise for ``k``, and for
    ``v`` with ``value_dim``), matching the layout the reference
    multi-head implementation in the unit tests expects.
    """

    def __init__(
        self,
        embed_dim: int = 256,
        head_dim: int = 32,
        value_dim: int = 32,
        num_heads: int = 8,
    ):
        """Multi-Head Self-Attention Module with 2d token input & output

        Args:
            embed_dim (int, optional): Dimension of the tokens at the input & output. Defaults to 256.
            head_dim (int, optional): Per-head dimension of query & key. Defaults to 32.
            value_dim (int, optional): Per-head dimension of values. Defaults to 32.
            num_heads (int, optional): Number of attention heads. Defaults to 8.

        Note:
            ``head_dim * num_heads`` does not have to equal ``embed_dim``: the
            output projection maps ``num_heads * value_dim`` features back to
            ``embed_dim``.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.head_dim = head_dim
        self.value_dim = value_dim
        self.num_heads = num_heads

        # One bias-free linear layer per projection. The names q/k/v/out are
        # relied upon by the unit tests, which copy these weights into
        # nn.MultiheadAttention.
        self.q = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
        self.k = nn.Linear(embed_dim, num_heads * head_dim, bias=False)
        self.v = nn.Linear(embed_dim, num_heads * value_dim, bias=False)
        self.out = nn.Linear(num_heads * value_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward of multi-head self-attention

        Args:
            x (torch.Tensor): Input tensor of shape (B, D, H, W) (batch, embedding dimension, height, width),
                with D == embed_dim.

        Returns:
            torch.Tensor: Output tensor of shape (B, D, H, W) (batch, embedding dimension, height, width)
        """
        B, _, H, W = x.shape
        num_tokens = H * W

        # (B, D, H, W) -> (B, H*W, D): one token per spatial position, row-major.
        tokens = x.flatten(2).transpose(1, 2)

        # Project, then split the feature axis into (num_heads, per-head dim).
        # Each head's features are contiguous in the projection output.
        q = self.q(tokens).reshape(B, num_tokens, self.num_heads, self.head_dim)
        k = self.k(tokens).reshape(B, num_tokens, self.num_heads, self.head_dim)
        v = self.v(tokens).reshape(B, num_tokens, self.num_heads, self.value_dim)

        # Scaled dot-product logits: (B, num_heads, q_len, k_len).
        logits = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(self.head_dim)
        # Normalize over the key axis so each query's weights sum to 1.
        attn = torch.softmax(logits, dim=-1)

        # Attention-weighted sum of values: (B, q_len, num_heads, value_dim).
        heads = torch.einsum('bhqk,bkhd->bqhd', attn, v)
        # Concatenate the heads into one feature axis: (B, q_len, num_heads * value_dim).
        heads = heads.reshape(B, num_tokens, self.num_heads * self.value_dim)

        # Output projection: (B, q_len, embed_dim).
        out = self.out(heads)

        # (B, H*W, embed_dim) -> (B, embed_dim, H, W), undoing the token flattening.
        return out.transpose(1, 2).reshape(B, self.embed_dim, H, W)

# Unit Test (single-head) DO NOT CHANGE!
# Compares a one-head SelfAttention2d (head_dim == value_dim == embed_dim == 256)
# against torch.nn.MultiheadAttention carrying the same q/k/v/out weights.
with torch.no_grad():
    layer = SelfAttention2d(embed_dim=256, head_dim=256, value_dim=256, num_heads=1).to(device)
    x = torch.randn((4, 256, 24, 24), device=device)
    res_layer = layer(x)

    layer_ref = nn.MultiheadAttention(layer.embed_dim, layer.num_heads).to(device)
    # Stack the q, k, v weights into the reference's fused input projection and copy the
    # output projection. strict=False leaves any reference parameter not in this dict at
    # its initial value; since SelfAttention2d's projections use bias=False, the comparison
    # presumably relies on the reference's biases starting at zero.
    layer_ref.load_state_dict({ 'in_proj_weight': torch.cat([layer.q.weight, layer.k.weight, layer.v.weight]), 'out_proj.weight': layer.out.weight }, strict=False)
    # Reshape x (B, D, H, W) -> (H*W, B, D), pass it as query, key and value,
    # then map the first returned tensor back to x's (B, D, H, W) shape.
    res_ref = layer_ref(*[x.view(*x.shape[:2], -1).permute(2, 0, 1)] * 3)[0].permute(1, 2, 0).view(*x.shape)
    assert torch.allclose(res_layer, res_ref, rtol=1e-2, atol=1e-5), 'Single-head attention result incorrect.'

# Unit Test (multi-head) DO NOT CHANGE!
# Same comparison for the default layer (embed_dim=256, num_heads=8, head_dim=value_dim=32,
# so head_dim * num_heads == embed_dim).
with torch.no_grad():
    layer = SelfAttention2d().to(device)
    x = torch.randn((4, 256, 24, 24), device=device)
    res_layer = layer(x)

    layer_ref = nn.MultiheadAttention(layer.embed_dim, layer.num_heads).to(device)
    # Copy the weights as in the single-head test.
    layer_ref.load_state_dict({ 'in_proj_weight': torch.cat([layer.q.weight, layer.k.weight, layer.v.weight]), 'out_proj.weight': layer.out.weight }, strict=False)
    # Same (B, D, H, W) <-> (H*W, B, D) round trip as in the single-head test.
    res_ref = layer_ref(*[x.view(*x.shape[:2], -1).permute(2, 0, 1)] * 3)[0].permute(1, 2, 0).view(*x.shape)
    assert torch.allclose(res_layer, res_ref, rtol=1e-2, atol=1e-5), 'Multi-head attention result incorrect.'

# Only reached when neither assert above raised.
print('All tests passed.')

All tests passed.
