# Installs

## Open notebook in:
| Colab                                 |  Gradient                                                                                                                                         |
|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/https://github.com/Nicolepcx/transformers-the-definitive-guide/blob/main/CH01/ch01_attention_mechanism_variations.ipynb)                                              | [![Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com//github.com/Nicolepcx/transformers-the-definitive-guide/blob/main/CH01/ch01_attention_mechanism_variations.ipynb)|             

# About this notebook


In this notebook you can run and try out the different attention mechanism variations as introduced about in the book.


In [1]:
!pip install einops -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25h

# Imports

In [2]:
import torch
import torch.nn.functional as F
from einops import rearrange

# Multi-head attention

In [3]:
def MultiheadAttention(x, M, W_query, W_key, W_value, P_o):
    """
    Multi-head Attention on one query.
    Args:
    x: a vector with shape [d]
    M: a matrix with shape [m, d]
    W_query: a tensor with shape [h, d, k]
    W_key: a tensor with shape [h, d, k]
    W_value: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
    Returns:
    y: a vector with shape [d]
    """
    scaling_factor = W_key.shape[1]**0.5

    Q = torch.einsum('d,hdk->hk', x, W_query)
    print("Q shape: ",Q.shape)
    K = torch.einsum('md,hdk->hmk', M, W_key)
    print("K shape: ",K.shape)
    V = torch.einsum('md,hdv->hmv', M, W_value)
    print("V shape: ",V.shape)
    attn_scores = torch.einsum('hk,hmk->hm', Q, K) / scaling_factor
    attn_weights = F.softmax(attn_scores, dim=-1)

    o = torch.einsum('hm,hmv->hv', attn_weights, V)
    y = torch.einsum('hv,hdv->d', o, P_o)
    return y


## Running multi-head attention with parameters

In [4]:
# Dimensions for tensors
d, m, h, k, v = 4, 3, 2, 5, 5

# Input tensors
x = torch.randn(d)
M = torch.randn(m, d)
W_query = torch.randn(h, d, k)
W_key = torch.randn(h, d, k)
W_value = torch.randn(h, d, v)
P_o = torch.randn(h, d, v)

# Call the function
y = MultiheadAttention(x, M, W_query, W_key, W_value, P_o)
y.shape

Q shape:  torch.Size([2, 5])
K shape:  torch.Size([2, 3, 5])
V shape:  torch.Size([2, 3, 5])


torch.Size([4])

# Multi-query attention

In [5]:
def MultiqueryAttention(X, M, mask, W_query, W_key, W_value, P_o):
    """
    Multi-Query Attention.
    Args:
    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    W_query: a tensor with shape [h, d, k]
    W_key: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
    Returns:
    Y: a tensor with shape [b, n, d]
    """
    scaling_factor = W_key.shape[1]**0.5

    # Apply projections using einsum
    Q = torch.einsum('bnd,hdk->bhnk', X, W_query)
    print("Q shape: ",Q.shape)
    K = torch.einsum('bmd,dk->bmk', M, W_key)
    print("K shape: ",K.shape)
    V = torch.einsum('bmd,dv->bmv', M, W_value)
    print("V shape: ",V.shape)
    attn_scores = torch.einsum('bhnk,bmk->bhnm', Q, K) / scaling_factor
    attn_weights = F.softmax(attn_scores + mask, dim=-1)

    O = torch.einsum('bhnm,bmv->bhnv', attn_weights, V)
    Y = torch.einsum('bhnv,hdv->bnd', O, P_o)
    return Y

## Running multi-query attention with parameters

In [6]:
# Setup parameters
d, m, h, k, v = 4, 3, 2, 5, 5  # dimensions for d: vector dimension, m: number of keys/values, h: number of heads, k and v: key/value dimensions
b, n = 2, 3  # Batch size, number of queries

# Input tensors
X = torch.randn(b, n, d)
M = torch.randn(b, m, d)
mask = torch.zeros(b, h, n, m)  # Attention mask
W_query = torch.randn(h, d, k)
W_key = torch.randn(d, k)
W_value = torch.randn(d, v)
P_o = torch.randn(h, d, v)

# Call the corrected function
y = MultiqueryAttention(X, M, mask, W_query, W_key, W_value, P_o)

y.shape


Q shape:  torch.Size([2, 2, 3, 5])
K shape:  torch.Size([2, 3, 5])
V shape:  torch.Size([2, 3, 5])


torch.Size([2, 3, 4])

# Cross attention

In [7]:
def CrossAttention(x_1, x_2, W_query, W_key, W_value):
    """
    Cross-Attention mechanism.
    Args:
    x_1: a tensor with shape [b1, d_in] where b1 is the batch size for x_1
    x_2: a tensor with shape [b2, d_in] where b2 is the batch size for x_2
    W_query: a tensor with shape [d_in, d_out_kq]
    W_key: a tensor with shape [d_in, d_out_kq]
    W_value: a tensor with shape [d_in, d_out_v]
    Returns:
    Y: a tensor with shape [b1, d_out_v] representing the weighted sum of values based on the attention.
    """
    scaling_factor = W_query.shape[1]**0.5

    # Obtain queries, keys, and values from x_1 and x_2
    Q = torch.einsum('bd,dk->bk', x_1, W_query) # Shape [b1, d_out_kq]
    K = torch.einsum('bd,dk->bk', x_2, W_key)  # Shape[b2, d_out_kq]
    V = torch.einsum('bd,dv->bv', x_2, W_value)  # Shape [b2, d_out_v]

    # Compute attention scores and weights
    attn_scores = torch.einsum('bk,mk->bm', Q, K) # Shape [b1, b2]
    attn_weights = F.softmax(attn_scores / scaling_factor, dim=-1)  # Softmax over n2 dimension

    # Apply attention weights to the values
    Y = torch.einsum('bm,mv->bv', attn_weights, V)  # [n1, d_out_v]

    return Y


## Running cross attention with parameters

In [8]:
d_in, d_out_kq, d_out_v, b1, b2 = 4, 5, 6, 2, 3
x_1 = torch.randn(b1, d_in)
x_2 = torch.randn(b2, d_in)
W_query = torch.randn(d_in, d_out_kq)
W_key = torch.randn(d_in, d_out_kq)
W_value = torch.randn(d_in, d_out_v)

# Call CrossAttention function
Y = CrossAttention(x_1, x_2, W_query, W_key, W_value)
print("Output shape: ", Y.shape)

Output shape:  torch.Size([2, 6])


# Grouped-query attention

In [9]:
def GroupedQueryAttention(Q, K, V, num_heads, group_size):
    batch_size, seq_len, embed_dim = Q.shape
    scaling_factor = (embed_dim // num_heads) ** 0.5

    # Reshape Q, K, V for grouped multi-head attention computation
    Q = rearrange(Q, 'b s (h d) -> (b h) s d', h=num_heads)
    K = rearrange(K, 'b s (h d) -> (b h) s d', h=num_heads)
    V = rearrange(V, 'b s (h d) -> (b h) s d', h=num_heads)

    # Compute attention scores using scaled dot-product attention
    attn_scores = torch.einsum('bid,bjd->bij', Q, K) / scaling_factor
    attn_weights = F.softmax(attn_scores, dim=-1)

    # Apply attention weights to the values
    attn_output = torch.einsum('bij,bjd->bid', attn_weights, V)

    # Reshape back to the original dimensions
    y = rearrange(attn_output, '(b h) s d -> b s (h d)', b=batch_size, h=num_heads)

    return Y



## Running multi-head attention with parameters

In [10]:
# Example dimensions and data
batch_size = 2
seq_len = 256
embed_dim = 64
num_heads = 8
group_size = 2  # This value is now not directly used but is implied in the num_heads.

Q = torch.randn(batch_size, seq_len, embed_dim)
K = torch.randn(batch_size, seq_len, embed_dim)
V = torch.randn(batch_size, seq_len, embed_dim)

# Call function
output = GroupedQueryAttention(Q, K, V, num_heads, group_size)
print("Output shape: ", output.shape)


Output shape:  torch.Size([2, 6])
