In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [39]:
class MultiHeadAttention(nn.Module) : 
    def __init__(self,d_in,d_out,context_length,dropout,num_heads,bias=False):
        super().__init__()
        assert(d_out % num_heads ==0 ) ,"d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        
        self.head_dim = d_out // num_heads
        
        self.w_q = nn.Linear(d_in,d_out,bias=bias)
        self.w_k = nn.Linear(d_in,d_out,bias=bias)
        self.w_v = nn.Linear(d_in,d_out,bias=bias)
        self.out_proj = nn.Linear(d_out,d_out)
        self.dropout  = nn.Dropout(dropout) 
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length,context_length),diagonal=1)
        )
    
    def forward(self,x):
        b,num_tokens,d_embed= x.shape
        
        queries = self.w_q(x) ## (b , num_tokens , d_out)
        print(f"queries :\n {queries}")
        keys    = self.w_k(x) ## (b , num_tokens , d_out)
        print(f"keys    : \n {keys}")
        values  = self.w_v(x) ## (b , num_tokens , d_out)
        print(f"values  : \n {values}")
        ## (b , num_tokens , d_out) -------->  (b , num_tokens , num_heads , head_dim)
        queries = queries.view(b,num_tokens,self.num_heads,self.head_dim)
        keys    = keys.view(b,num_tokens,self.num_heads,self.head_dim)
        values  = values.view(b,num_tokens,self.num_heads,self.head_dim)
        print(f"queries after view :\n {queries}")
        print(f"keys after view    :\n {keys}")
        print(f"values after view  :\n {values}")
        # (b , num_tokens , num_heads , head_dim) ------> (b , num_heads , num_tokens , head_dim)
        ## in this case each head should be able to access all the tokens but with different embeddings (keys values splitted over the different heads)
        keys    = keys.transpose(1,2) 
        queries = queries.transpose(1,2) 
        values  = values.transpose(1,2) 
        print(f"queries after transpose :\n {queries}")
        print(f"keys after transpose    :\n {keys}")
        print(f"values after transpose  :\n {values}")
        attn_scores = queries @ keys.transpose(-1,-2)
        raw_scores = attn_scores
        print(f"Attention Scores : \n {attn_scores}")
        mask_bool = self.mask.bool()[:num_tokens,:num_tokens]
        attn_scores_masked = attn_scores.masked_fill_(mask_bool,-torch.inf)
        print(f"Attention Scores masked : \n {attn_scores_masked}")
        attn_weights = F.softmax(attn_scores_masked / keys.shape[-1]**0.5,dim=-1)
        print(f"Attention Weights : \n {attn_weights}")
        attn_weights = self.dropout(attn_weights)
        
        context = ( attn_weights @ values ).transpose(1,2)# (b , num_heads , num_tokens , head_dim) ------> (b , num_tokens , num_heads , head_dim) 
        
        context = context.contiguous().view(b,num_tokens,self.d_out)
        
        context = self.out_proj(context) # optional linear projection layer 
        
        return context , raw_scores

In [55]:
sentence = "Attention weights represent how much each input token attends to others after softmax normalization."
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2") ## BPE
tokens = tokenizer.encode(sentence)
print(tokens)

[8086, 1463, 19590, 2380, 703, 881, 1123, 5128, 11241, 32743, 284, 1854, 706, 2705, 9806, 3487, 1634, 13]


In [62]:
emb = nn.Embedding(50257,64)
sentence_emb = emb(torch.tensor(tokens))
sentence_emb.shape

torch.Size([18, 64])

In [65]:
torch.manual_seed(123)
batch_size,context_length,d_in = sentence_emb.unsqueeze(0).shape
d_out=18
mha = MultiHeadAttention(d_in,d_out,context_length,0.2,num_heads=6)

context , attn_scores = mha(sentence_emb.unsqueeze(0)) ## unsqueeze(0) to 

print(f"context.shape : ",context.shape)

print(f"Context after projection layer : ",context)

queries :
 tensor([[[ 0.4309,  0.1749, -0.6359,  0.2238, -0.2205,  0.9382,  0.7738,
          -0.9971, -0.3518, -0.1009,  0.2387, -0.3538,  0.2927, -0.6940,
           1.1900,  0.0804,  0.5503,  1.1136],
         [ 0.7838,  0.8262,  0.6356, -0.4975,  0.0027,  0.9358, -0.6377,
          -0.0601, -0.3072, -0.4414, -0.9806, -0.1052,  0.3955,  0.1093,
          -0.9104,  0.2250,  0.9217, -0.2534],
         [-0.2281,  0.5944, -0.2324,  0.5608, -0.5354, -0.1085,  0.3186,
           0.0559,  0.1822,  0.5233, -0.6498, -0.4942, -0.7795, -0.3642,
           0.8058, -0.1098,  0.8982,  0.5762],
         [-0.2947,  0.2223,  0.3071,  0.0801,  0.9876, -0.8423,  0.4898,
           0.3534, -0.6473, -0.7112,  0.1846,  0.3498,  0.2004,  0.2259,
           0.0214, -0.3832,  0.2067, -0.6545],
         [-0.8139, -1.0427, -0.6891, -0.5236,  0.1778,  0.5880, -0.2531,
           0.1877,  0.4703,  0.0518, -0.1646, -0.5282, -0.3228, -0.0336,
           0.5984,  0.4412,  0.5519,  0.4385],
         [ 0.8522, -0.17

In [70]:
import pandas as pd
import altair as alt
tokens = [8086, 1463, 19590, 2380, 703, 881, 1123, 5128, 11241, 32743, 284, 1854, 706, 2705, 9806, 3487, 1634, 13]
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.3d %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.3d %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

In [74]:
all_charts = []
for head in range(attn_scores.shape[1]):
    df = mtx2df(attn_scores[0, head].detach(), 18, 18, tokens, tokens)
    chart = alt.Chart(data=df).mark_rect().encode(
        x=alt.X("col_token", axis=alt.Axis(title="Key")),
        y=alt.Y("row_token", axis=alt.Axis(title="Query")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    ).properties(
        title=f"Head {head}", width=250, height=250
    ).interactive()
    all_charts.append(chart)

# Split the charts into rows of 3 charts each
rows = [alt.hconcat(*all_charts[i:i+3]) for i in range(0, len(all_charts), 3)]

# Vertically concatenate the rows
final_chart = alt.vconcat(*rows)

final_chart
