<a href="https://colab.research.google.com/github/Shakib-IO/Python_Practice/blob/main/Multi_Head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Medium](https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a)

In [48]:
import math
import torch
import torch.nn as nn

#### Model Input
- X has a size of **(batch_size, seq_length, d_model)**. For example, a batch of 32 sequences of length 10 with an embedding dimension of 512 would have a shape of **(32, 10, 512)**.
- Wq, Wk, and Wv have a size of **(d_model, d_model)**. Following the example above, they would have a shape of **(512, 512)**.
- **Q = XWq** | (batch_size, seq_length, d_model) x (d_model, d_model) = **(batch_size, seq_length, d_model)**
- **K = XWk** | (batch_size, seq_length, d_model) x (d_model, d_model) = **(batch_size, seq_length, d_model)**
- **V = XWv** | (batch_size, seq_length, d_model) x (d_model, d_model) = **(batch_size, seq_length, d_model)**

#### Creating Sample Input

In [18]:
# Sample corpus: three short sentences used as input throughout the notebook
sequences = [
    "I wonder what will come next!",
    "This is a basic example paragraph.",
    "Hello, what is a basic split?",
]

In [3]:
def tokenize(sequence, punctuation="!.?"):
  """Split a text sequence into lowercase word tokens.

  Args:
    sequence:     raw text to tokenize
    punctuation:  characters (a string or an iterable of single characters)
                  deleted before splitting. The default removes only "!", "."
                  and "?", so e.g. commas stay attached to their word.

  Returns:
    list of lowercase tokens obtained by splitting on single spaces; empty
    tokens (from repeated spaces or an empty input) are dropped so they never
    reach the vocabulary
  """
  # delete every punctuation character in a single pass
  table = str.maketrans("", "", "".join(punctuation))
  cleaned = sequence.translate(table)

  # split on spaces, skip empty pieces, and lowercase each token
  return [token.lower() for token in cleaned.split(" ") if token]

In [30]:
def build_vocab(data):
  """Assign an integer id to every distinct token in a collection of sequences.

  Args:
    data: iterable of raw text sequences

  Returns:
    dict mapping each token to its position in the alphabetically sorted
    list of unique tokens
  """
  # tokenizing the space-joined corpus once yields the tokens of every sequence
  corpus = " ".join(data)
  unique_tokens = sorted(set(tokenize(corpus)))

  return dict(zip(unique_tokens, range(len(unique_tokens))))

# vocabulary (token -> id) for the sample `sequences`
stoi = build_vocab(sequences)

In [49]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

  Column 2i of the encoding table holds sin(pos / 10000^(2i / d_model)) and
  column 2i + 1 holds the matching cosine.
  """

  def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 5000):
    """
    Args:
      d_model:      dimension of embeddings (odd values are supported)
      dropout:      probability of randomly zeroing elements of the output
      max_length:   max sequence length covered by the encoding table
    """
    # inherit from Module
    super().__init__()

    # initialize dropout
    self.dropout = nn.Dropout(p=dropout)

    # encoding table, one row per position
    pe = torch.zeros(max_length, d_model)

    # position column: (max_length, 1)
    k = torch.arange(0, max_length).unsqueeze(1)

    # 1 / 10000^(2i / d_model) for every even column index 2i
    div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
    )

    # sine on even columns (ceil(d_model / 2) of them)
    pe[:, 0::2] = torch.sin(k * div_term)

    # cosine on odd columns; there are only d_model // 2 of them, which is one
    # fewer than the even columns when d_model is odd, so trim div_term to match
    pe[:, 1::2] = torch.cos(k * div_term[: d_model // 2])

    # leading batch dimension so the table broadcasts over the batch
    pe = pe.unsqueeze(0)

    # buffers are saved in state_dict but not trained by the optimizer
    self.register_buffer("pe", pe)

  def forward(self, x: torch.Tensor):
    """
    Args:
      x:        embeddings (batch_size, seq_length, d_model),
                with seq_length <= max_length

    Returns:
                embeddings + positional encodings (batch_size, seq_length, d_model)
    """
    # add the encodings for the first seq_length positions
    x = x + self.pe[:, : x.size(1)].requires_grad_(False)

    # perform dropout
    return self.dropout(x)

In [50]:
# split every sample sentence into tokens
tokenized_sequences = list(map(tokenize, sequences))

# replace each token by its vocabulary id
indexed_sequences = [[stoi[token] for token in tokens] for tokens in tokenized_sequences]

# integer id tensor: 3 sentences x 6 tokens
tensor_seq = torch.tensor(indexed_sequences, dtype=torch.long)

# vocabulary size (14 for this corpus) and embedding dimension
vocab_size = len(stoi)
d_model = 8

# learnable lookup table of shape (vocab_size, d_model) = (14, 8)
emb = nn.Embedding(vocab_size, d_model)

# sinusoidal positional encoding covering up to 10 positions
pos_embd = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)

# embed the sequences: (3, 6) -> (3, 6, 8)
lut = emb(tensor_seq)

# add positional encodings (followed by dropout); shape stays (3, 6, 8)
X = pos_embd(lut)
X  # torch.Size([3, 6, 8])

tensor([[[-0.9965,  0.0000, -0.1276, -0.1866, -0.0046,  0.1521, -1.7051,
           1.3467],
         [ 0.7666,  3.9468, -0.3194,  2.2877,  0.0000,  0.5118, -0.4217,
           1.7311],
         [ 0.6534, -0.3180,  0.0492, -1.0483, -0.0000,  1.9002, -1.1002,
           1.9462],
         [-0.8726, -1.5124, -0.1749, -0.2964,  1.4201,  0.1977, -0.4234,
           2.9098],
         [-0.8739, -1.9751, -1.6096,  1.1916, -0.2003,  0.0000,  0.9831,
          -0.4254],
         [-0.2143,  0.6313,  1.1066, -0.6754, -1.7632,  1.9809,  1.3412,
           1.7196]],

        [[-0.0000,  2.1747, -2.0839,  2.8800, -0.7512,  1.8371, -0.2812,
           0.0000],
         [ 2.9189,  0.5060, -1.7257,  0.5813,  0.0000,  2.0329,  0.7927,
           0.6572],
         [ 1.2726, -0.1309, -0.7897,  1.2685,  0.6416,  0.0000, -0.1389,
           1.3736],
         [ 1.1722, -0.2704, -0.2394,  1.1643,  0.0000, -0.0357,  0.4561,
           1.3206],
         [-1.1534, -0.2791,  0.0000, -1.3569, -0.3477, -1.4234,  1.1

At this point, the embedded sequences, **X**, have a shape of **(3, 6, 8)**. There are **3 sequences** of **6 tokens** with an **8 dimensional embedding**.

In [51]:
# Bias-free linear projections for queries, keys and values,
# each with a (d_model, d_model) = (8, 8) weight matrix
Wq = nn.Linear(d_model, d_model, bias=False)
Wk = nn.Linear(d_model, d_model, bias=False)
Wv = nn.Linear(d_model, d_model, bias=False)

# inspect the randomly initialized query weights
Wq.weight.detach()

tensor([[ 0.3407,  0.1125,  0.0422,  0.0085, -0.2184,  0.1037, -0.2194, -0.0895],
        [ 0.1059, -0.1269,  0.0787, -0.0451, -0.3095,  0.2400, -0.1946,  0.3119],
        [-0.2710, -0.0675, -0.2163, -0.0774,  0.3440,  0.0519,  0.0586, -0.0065],
        [ 0.0851,  0.0432, -0.2403,  0.0841, -0.2135, -0.1283, -0.1478,  0.1103],
        [ 0.3269,  0.2589,  0.2757,  0.0331, -0.2540,  0.0685,  0.2441,  0.1574],
        [ 0.2386,  0.0038, -0.1266, -0.0983, -0.1232,  0.1283, -0.0568,  0.3321],
        [-0.2552, -0.3156,  0.3358,  0.0693, -0.2996, -0.3025,  0.3484,  0.0366],
        [ 0.0500, -0.1178, -0.1273,  0.3405,  0.0551,  0.2097,  0.0049,  0.0223]])

Now compute Q, K, and V by projecting X with Wq, Wk, and Wv.

In [54]:
# project X into query, key and value spaces:
# (batch_size, seq_length, d_model) x (d_model, d_model) -> (batch_size, seq_length, d_model)
Q, K, V = Wq(X), Wk(X), Wv(X)

Q  # (3, 6, 8)

tensor([[[-7.6132e-02,  6.8265e-01,  2.0964e-01,  3.1220e-01, -5.5975e-01,
           3.6098e-01, -3.9090e-01, -4.3866e-02],
         [ 7.0210e-01,  1.9695e-01, -5.9135e-01,  6.9222e-01,  1.4647e+00,
           6.7790e-01, -1.6283e+00,  5.3686e-01],
         [ 4.4426e-01,  1.4378e+00, -6.3673e-02,  7.5431e-02,  2.7797e-01,
           1.2043e+00, -1.0094e+00,  1.4339e-01],
         [-9.3462e-01,  6.9696e-01,  8.5428e-01, -6.7560e-02, -7.2747e-01,
           6.7827e-01,  9.4413e-02,  2.3817e-01],
         [-7.1157e-01, -2.8423e-01,  6.1754e-01,  1.7791e-01, -9.7733e-01,
          -3.0176e-01,  7.7525e-01,  7.8374e-01],
         [ 1.8155e-01,  1.3111e+00, -6.0809e-01, -1.9994e-01,  1.5574e+00,
           8.4386e-01,  6.3969e-01, -9.2711e-02]],

        [[ 5.9778e-01,  1.5845e-01, -9.8266e-02,  8.0302e-01,  3.3197e-01,
           3.3321e-01, -1.6152e+00,  1.3321e+00],
         [ 9.6169e-01,  6.2157e-01, -3.4899e-01,  4.2831e-01,  1.0649e+00,
           1.2939e+00, -1.7586e+00,  9.4875e-01]

Splitting Q, K, and V Into Their Heads <br>
**d_key = (d_model / n_heads)** <br>
Q currently has shape (batch_size, seq_length, d_model). Its last dimension, d_model, is split into two dimensions, n_heads and d_key, so Q becomes (batch_size, seq_length, n_heads, d_key). <br>

The shape of each tensor becomes:

**(batch_size, seq_length, d_model) → (batch_size, seq_length, n_heads, d_key)**


In [56]:
n_heads = 4
batch_size = Q.size(0)
d_key = d_model // n_heads  # 8 // 4 = 2 features per head

# split the last dimension into heads for query, key and value:
# (3, 6, 8) -> (3, 6, 4, 2); -1 lets view infer each tensor's sequence length
Q, K, V = (t.view(batch_size, -1, n_heads, d_key) for t in (Q, K, V))

Q.shape  # torch.Size([3, 6, 4, 2])

torch.Size([3, 6, 4, 2])

To proceed, it would be best to **transpose seq_length and n_heads**, the second and third dimensions, to have the following shape:

**(batch_size, seq_length, n_heads, d_key) → (batch_size, n_heads, seq_length, d_key)**

In [61]:
# swap seq_length and n_heads for query, key and value: (3, 6, 4, 2) -> (3, 4, 6, 2)
Q, K, V = (t.transpose(1, 2) for t in (Q, K, V))

Q.shape  # torch.Size([3, 4, 6, 2])

torch.Size([3, 4, 6, 2])

Next, compute the scaled dot product. <br>
Moving forward, the seq_length shape of each tensor will be known by its respective tensor for clarity, **Q_length, K_length, or V_length**:

Q has a shape of **(batch_size, n_heads, Q_length, d_key)** <br>
K has a shape of **(batch_size, n_heads, K_length, d_key)** <br>
V has a shape of **(batch_size, n_heads, V_length, d_key)** <br>
The two rightmost dimensions of K must be transposed to change its shape to **(batch_size, n_heads, d_key, K_length)**.

**QK^T** will be:

**(batch_size, n_heads, Q_length, d_key) x (batch_size, n_heads, d_key, K_length) = (batch_size, n_heads, Q_length, K_length)**

In [62]:
# QK^T / sqrt(d_key): (3, 4, 6, 2) @ (3, 4, 2, 6) -> (3, 4, 6, 6)
# i.e. (batch_size, n_heads, Q_length, K_length)
scaled_dot_prod = Q @ K.transpose(-2, -1) / math.sqrt(d_key)
scaled_dot_prod.shape

torch.Size([3, 4, 6, 6])

In [65]:
# normalize each query's scores over the key positions so every row sums to 1
attn_probs = scaled_dot_prod.softmax(dim=-1)  # (batch_size, n_heads, Q_length, K_length)
attn_probs  # torch.Size([3, 4, 6, 6])

tensor([[[[0.1835, 0.1388, 0.1919, 0.2339, 0.1192, 0.1328],
          [0.1772, 0.3405, 0.1523, 0.1667, 0.0790, 0.0843],
          [0.2000, 0.2049, 0.1916, 0.2990, 0.0455, 0.0589],
          [0.1412, 0.0446, 0.1795, 0.2113, 0.2042, 0.2192],
          [0.1195, 0.0632, 0.1388, 0.1235, 0.2886, 0.2664],
          [0.1980, 0.1614, 0.2003, 0.2975, 0.0634, 0.0793]],

         [[0.1984, 0.1242, 0.1818, 0.1829, 0.1701, 0.1427],
          [0.1583, 0.0692, 0.1294, 0.2191, 0.3058, 0.1182],
          [0.1675, 0.1530, 0.1639, 0.1735, 0.1798, 0.1622],
          [0.2086, 0.1929, 0.2141, 0.1381, 0.0921, 0.1542],
          [0.2152, 0.1492, 0.2055, 0.1620, 0.1230, 0.1450],
          [0.1236, 0.1834, 0.1303, 0.1633, 0.2136, 0.1858]],

         [[0.1835, 0.1996, 0.1714, 0.1440, 0.1493, 0.1522],
          [0.1122, 0.2190, 0.1456, 0.2217, 0.1875, 0.1140],
          [0.1399, 0.2864, 0.1552, 0.1644, 0.1515, 0.1026],
          [0.1848, 0.2312, 0.1709, 0.1358, 0.1408, 0.1364],
          [0.2069, 0.1438, 0.1752, 0

In [68]:
# reweight the value vectors by the attention probabilities
A = attn_probs @ V  # (batch_size, n_heads, Q_length, d_key)
A.shape  # torch.Size([3, 4, 6, 2])

torch.Size([3, 4, 6, 2])

The concatenation reverses the split that was performed originally. The first step is to **transpose n_heads and Q_length**. The second step is to concatenate **n_heads and d_key back together to get d_model**.

Once this is complete, **A will have a shape of (batch_size, Q_length, d_model).**

In [69]:
# undo the head split: (3, 4, 6, 2) -> (3, 6, 4, 2) -> (3, 6, 8) = (batch_size, Q_length, d_model);
# contiguous() gives view() a memory layout it can reshape after the transpose
A = A.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_key)

A.shape  # torch.Size([3, 6, 8]), same as the input X

torch.Size([3, 6, 8])

#### MultiHeadAttention

In [71]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  query, key and value are linearly projected, split into ``n_heads`` heads of
  size ``d_key = d_model // n_heads``, attended independently per head,
  concatenated back to ``d_model`` and passed through an output projection.
  """

  def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
    """
    Args:
        d_model:      dimension of embeddings; must be divisible by n_heads
        n_heads:      number of attention heads
        dropout:      dropout probability applied to the attention weights
    """
    super().__init__()
    # every head gets an equal slice of d_model, so it must divide evenly
    assert d_model % n_heads == 0, (
        f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
    )
    self.d_model = d_model                   # e.g. 512
    self.n_heads = n_heads                   # e.g. 8
    self.d_key = d_model // n_heads          # per-head size; d_value assumed equal | 512/8=64

    self.Wq = nn.Linear(d_model, d_model)    # query weights
    self.Wk = nn.Linear(d_model, d_model)    # key weights
    self.Wv = nn.Linear(d_model, d_model)    # value weights
    self.Wo = nn.Linear(d_model, d_model)    # output weights

    self.dropout = nn.Dropout(p=dropout)     # applied to the attention probabilities

  def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshape (batch_size, length, d_model) -> (batch_size, n_heads, length, d_key)."""
    return x.view(batch_size, -1, self.n_heads, self.d_key).permute(0, 2, 1, 3)

  def _merge_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshape (batch_size, n_heads, length, d_key) -> (batch_size, length, d_model)."""
    # permute leaves a non-contiguous layout, so make it contiguous before view()
    return x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.n_heads * self.d_key)

  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
    """
    Args:
       query:         query vector         (batch_size, q_length, d_model)
       key:           key vector           (batch_size, k_length, d_model)
       value:         value vector         (batch_size, k_length, d_model); same length as key
       mask:          optional; score positions where mask == 0 are excluded from
                      attention. Must broadcast to (batch_size, n_heads, q_length, k_length),
                      e.g. a (q_length, k_length) lower-triangular matrix for a decoder.

    Returns:
       output:        attention values     (batch_size, q_length, d_model)
       attn_probs:    softmax scores       (batch_size, n_heads, q_length, k_length),
                      taken before dropout
    """
    batch_size = key.size(0)

    # project and split into heads: (batch_size, length, d_model) -> (batch_size, n_heads, length, d_key)
    Q = self._split_heads(self.Wq(query), batch_size)
    K = self._split_heads(self.Wk(key), batch_size)
    V = self._split_heads(self.Wv(value), batch_size)

    # QK^T / sqrt(d_key) -> (batch_size, n_heads, q_length, k_length)
    scaled_dot_prod = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_key)

    if mask is not None:
      # use the most negative finite value of the score dtype so masked
      # positions get ~0 probability; a fixed -1e10 overflows float16
      fill_value = torch.finfo(scaled_dot_prod.dtype).min
      scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, fill_value)

    attn_probs = torch.softmax(scaled_dot_prod, dim=-1)

    # dropout only affects the weights used for A; attn_probs is returned undropped
    A = torch.matmul(self.dropout(attn_probs), V)   # (batch_size, n_heads, q_length, d_key)

    # concatenate heads back: (batch_size, q_length, d_model)
    A = self._merge_heads(A, batch_size)

    # push through the final weight layer
    output = self.Wo(A)                             # (batch_size, q_length, d_model)

    return output, attn_probs                       # return attn_probs for visualization of the scores
