In [1]:
# %pip install torch numpy scikit-learn pandas matplotlib seaborn

In [1]:
import torch
import math
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# from imports import *

# Self-Attention, Bit by Bit: inspect and visualize it line by line

Let's first set up a small dummy input for testing and visualization.

In [2]:
# batch size, sequence length, embedding dim
B, S, E = 2, 8, 16

z = torch.rand(B, S, E)  # uniform [0, 1) dummy embeddings, shape (B, S, E)

Now we construct query, key and value from the input

In [3]:
# Three independent E -> E projections, created in query, key, value order.
query_proj, key_proj, value_proj = (nn.Linear(E, E) for _ in range(3))

In [4]:
# Project the same input three ways; each result keeps shape (B, S, E).
query, key, value = query_proj(z), key_proj(z), value_proj(z)

In [5]:
tuple(t.shape for t in (query, key, value))

(torch.Size([2, 8, 16]), torch.Size([2, 8, 16]), torch.Size([2, 8, 16]))

Calculate attention logits

In [6]:
key_t = key.transpose(-2, -1)  # swap the last two dims: [B,S,E] -> [B,E,S], ready for Q @ K'
key_t.size()

torch.Size([2, 16, 8])

In [7]:
unscaled_attention_logits = torch.matmul(query, key_t)  # Q.K' : [B,S,E] @ [B,E,S] => [B,S,S]
unscaled_attention_logits.size()

torch.Size([2, 8, 8])

Scale the attention logits by $1/\sqrt{E}$ (scaled dot-product attention).

In [8]:
B, S, E = query.size()
scale = math.sqrt(E)  # scaled dot-product attention divides the logits by sqrt(E)
attention_logits = unscaled_attention_logits.div(scale)

In [9]:
attention_logits.size()

torch.Size([2, 8, 8])

Causal masking

In [10]:
mask_value = -torch.inf
causal_mask = torch.tril(torch.ones(B,S,S)) # lower triangular matrix of 1s: query i may attend to keys 0..i
# Out-of-place masked_fill: `attention_logits` keeps the scaled, *unmasked* logits, so the
# cell is idempotent and earlier variables still mean what their names say.
masked_attention_logits = attention_logits.masked_fill(causal_mask == 0, mask_value) # upper triangle -> -inf

In [11]:
masked_attention_logits.select(0, 0)  # first batch element, shape (S, S)

tensor([[ 0.1104,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0592,  0.1330,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0450,  0.1087,  0.0230,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1133,  0.1776, -0.0067,  0.0827,    -inf,    -inf,    -inf,    -inf],
        [ 0.1373,  0.2221,  0.0218,  0.0863,  0.0739,    -inf,    -inf,    -inf],
        [ 0.1535,  0.2603,  0.0694,  0.1739,  0.1543, -0.0314,    -inf,    -inf],
        [ 0.0565,  0.1109, -0.0508,  0.0451,  0.0285, -0.1157, -0.0383,    -inf],
        [ 0.1122,  0.2090,  0.0232,  0.0874,  0.1077, -0.0731,  0.0905,  0.0062]],
       grad_fn=<SelectBackward0>)

In [12]:
causal_attention_weights = masked_attention_logits.softmax(dim=-1)  # normalize over the key axis

In [13]:
causal_attention_weights # Observe: lower triangular (no weight on later positions), and each row sums to 1 (softmax over the last dim)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4816, 0.5184, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3285, 0.3501, 0.3214, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2549, 0.2718, 0.2261, 0.2472, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2054, 0.2236, 0.1830, 0.1952, 0.1928, 0.0000, 0.0000, 0.0000],
         [0.1699, 0.1891, 0.1562, 0.1734, 0.1701, 0.1412, 0.0000, 0.0000],
         [0.1500, 0.1584, 0.1347, 0.1483, 0.1459, 0.1263, 0.1364, 0.0000],
         [0.1299, 0.1431, 0.1189, 0.1267, 0.1293, 0.1080, 0.1271, 0.1169]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4380, 0.5620, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2925, 0.3659, 0.3416, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2197, 0.2767, 0.2443, 0.2593, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1777, 0.2216, 0.2190, 0.1891, 0.1927, 0.0000, 0.0000, 0.0000],
         [0.1471, 0.180

In [14]:
attention_output = torch.matmul(causal_attention_weights, value)  # [B,S,S] @ [B,S,E] -> [B,S,E]

# What about padding tokens?
* In the base implementation, every token attends to all earlier tokens, including padding tokens if present!
* Now we extend the base implementation so that padding tokens are ignored as attention targets (keys).

## Quick look at a sample tokenizer

In [15]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Important: decoder-only models usually need LEFT padding (see the HF link below).
tokenizer.padding_side = "left"

texts = [
    "Something random",
    "A bit longer text",
    "something even longer than the two before!",
]
encoded_input = tokenizer(texts, padding=True, return_tensors='pt')

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
encoded_input.keys() # token ids plus the padding mask produced by the tokenizer

dict_keys(['input_ids', 'attention_mask'])

In [17]:
tokenizer.pad_token_id # pad token id for reference; pad_token was set to eos_token, so this is the EOS id

2

In [18]:
encoded_input['input_ids'][0] # Observe left padding: the leading 2s are pad tokens (pad id == 2)
# Refer: https://huggingface.co/docs/transformers/llm_tutorial#wrong-padding-side

tensor([    2,     2,     2,     2,     2,     2,     1, 13264,  5509])

In [19]:
encoded_input['attention_mask'][0] # Observe: the mask is 0 at the pad positions and 1 for real tokens

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1])

In [20]:
attention_mask = encoded_input['attention_mask']  # (B, S): 1 = real token, 0 = padding
B, S = attention_mask.size()

## Handle attention mask in the attention logits
Also look at https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L800

In [27]:
attention_logits = torch.rand(B, S, S)  # dummy logits for the demo
# Mask padded *keys*: the (B, S) mask is viewed as (B, 1, S) so it applies to every query row.
attention_logits.masked_fill_(attention_mask[:, None, :] == 0, torch.finfo(attention_logits.dtype).min)
attention_logits

tensor([[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  9.5617e-01,  3.2460e-02,  1.0315e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  6.8204e-01,  9.1030e-01,  9.7699e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  8.2957e-01,  7.8932e-01,  9.3458e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  1.9415e-01,  1.4572e-02,  1.9848e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  2.4337e-01,  3.5080e-01,  2.0121e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  1.0501e-01,  9.4756e-01,  4.5054e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  3.7606e-01,  4.1200e-01,  5.6446e-01],
         [-3.4028e+38, -3.4028e+38

# Putting the code bits together

In [22]:
class BaseOutput:
    """Simple attribute bag: every keyword argument becomes an attribute of the instance."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [28]:
class BaseSelfAttention(nn.Module):
    """
    Minimal single-head causal self-attention (first version).

    Not implemented yet:
    * KV caching -- apply_causal_mask assumes square (S x S) logits.
    """
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the query, key, value and output projections (all E -> E)."""
        dim = self.embed_dim
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        self.output_proj = nn.Linear(dim, dim)

    def construct_query_key_value(self, x):
        """Project x into (query, key, value)."""
        return self.query_proj(x), self.key_proj(x), self.value_proj(x)

    def calculate_unmasked_attention_logits(self, query, key):
        """Scaled logits Q K' / sqrt(E): [B,S,E] @ [B,E,S] -> [B,S,S]."""
        denominator = math.sqrt(self.embed_dim)
        return torch.matmul(query, key.transpose(1, 2)) / denominator

    def apply_causal_mask(self, attention_logits, mask_value=None):
        """
        In place, fill the strictly-upper triangle of the logits with ``mask_value``.

        The mask is built from the LAST dim as (B, S, S), so the logits must be square;
        ``mask_value`` defaults to the dtype's most negative finite value.
        """
        batch, _, seq = attention_logits.shape  # last dim wins, like unpacking into B,S,S
        keep = torch.ones(batch, seq, seq, device=attention_logits.device).tril()
        fill = torch.finfo(attention_logits.dtype).min if mask_value is None else mask_value
        return attention_logits.masked_fill_(keep == 0, fill)

    def apply_attention_mask(self, attention_mask, masked_attention_logits, mask_value=None):
        """In place, fill logits of keys whose attention_mask entry is 0 (padding)."""
        fill = torch.finfo(masked_attention_logits.dtype).min if mask_value is None else mask_value
        padded_keys = attention_mask[:, None, :] == 0  # broadcast over the query dim
        return masked_attention_logits.masked_fill_(padded_keys, fill)

    def calculate_attention_weights(self, masked_attention_logits):
        """Softmax over the key axis."""
        return masked_attention_logits.softmax(dim=-1)

    def calculate_final_output(self, attention_weights, value):
        """Weighted sum of values followed by the output projection."""
        return self.output_proj(torch.matmul(attention_weights, value))

    def forward(self, x, attention_mask=None):
        query, key, value = self.construct_query_key_value(x)
        logits = self.apply_causal_mask(self.calculate_unmasked_attention_logits(query, key))
        if attention_mask is not None:
            logits = self.apply_attention_mask(attention_mask, logits)
        weights = self.calculate_attention_weights(logits)
        return BaseOutput(
            attention_output=self.calculate_final_output(weights, value),
            attention_weights=weights,
            key=key,
            value=value,
        )


In [29]:
texts = [
    "Something random",
    "A bit longer text",
    "something even longer than the two before!",
]
encoded_input = tokenizer(texts, padding=True, return_tensors='pt')
attention_mask = encoded_input['attention_mask']
B, S = attention_mask.shape
E = 16

z = torch.rand(B, S, E)

# One forward pass over the left-padded batch; the output must keep shape (B, S, E).
self_atten = BaseSelfAttention(E)
output = self_atten(z, attention_mask)
assert tuple(output.attention_output.shape) == (B, S, E)

# K-V caching

In [30]:
"""\

With KV caching:
Input: (B,1,E), kv_cache
Output: (B,1,E)

"""

'\nWith KV caching:\nInput: (B,1,E), kv_cache\nOutput: (B,1,E)\n\n'

In [31]:
# Dummy input for testing
B, S, E = 2, 4, 16

z_old = torch.rand(B, S, E)           # tokens seen so far
z = torch.rand(B, 1, E)               # one new token
z_new = torch.cat([z_old, z], dim=1)  # old + new, processed without a cache

self_atten = BaseSelfAttention(E)
output_old = self_atten(z_old, attention_mask=None)
output_new = self_atten(z_new, attention_mask=None)

print(z.shape, z_old.shape, z_new.shape, sep="\n")

torch.Size([2, 1, 16])
torch.Size([2, 4, 16])
torch.Size([2, 5, 16])


In [32]:
class KVCache:
    """Holds previously computed keys/values; any extra keyword args become attributes."""
    def __init__(self, key, value, **kwargs):
        self.key, self.value = key, value
        vars(self).update(kwargs)

In [33]:
# store previous K,V in the cache
cache = KVCache(key=output_old.key, value=output_old.value)

In [34]:
# Project only the newest token; the past K,V come from the cache.
query = self_atten.query_proj(z)
new_key, new_value = self_atten.key_proj(z), self_atten.value_proj(z)

In [35]:
# Append the new token's K,V after the cached ones along the sequence dim.
key = torch.cat((cache.key, new_key), dim=1)
value = torch.cat((cache.value, new_value), dim=1)

In [36]:
query.size()

torch.Size([2, 1, 16])

In [37]:
key.size()

torch.Size([2, 5, 16])

In [38]:
# The remaining operations should stay the same... or do they?
# Query is now a single token (B, 1, E), while the keys cover all 5 positions (B, 5, E).

# Calculate logits: (B, 1, 5)
unmasked_attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

try:
    # Apply causal masking -- expected to fail: the mask is built as (B, 5, 5) from the last dim
    masked_attention_logits = self_atten.apply_causal_mask(unmasked_attention_logits)
except RuntimeError as err:
    # Show the error instead of aborting "Run All"; the next cells explain and fix it.
    print(f"RuntimeError: {err}")
else:
    # Calculate attention weights
    attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)

    # And finally, Calculate the final output
    attention_output = self_atten.calculate_final_output(attention_weights,value)

    output_with_caching = BaseOutput(attention_output=attention_output,
                        attention_weights=attention_weights,
                        key=key,
                        value=value)

RuntimeError: output with shape [2, 1, 5] doesn't match the broadcast shape [2, 5, 5]

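Oh no, the causal mask! `apply_causal_mask` builds an $S_k \times S_k$ mask, but with a KV cache the logits are $S_q \times S_k$ with $S_q = 1$, so the shapes no longer match.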

In [39]:
# let's replicate apply_causal_mask and fix it for a query shorter than the keys

attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

B,Sq,Sk = attention_logits.shape # <---------Change 1: query and key lengths can differ

device = attention_logits.device
causal_mask = torch.tril(torch.ones(B,Sk,Sk, device=device)) # <------------Change 2: full (Sk, Sk) mask, built on the logits' device

print(causal_mask.shape)
causal_mask = causal_mask[:,-Sq:,-Sk:] # <------------Change 3: keep only the rows of the last Sq (newest) queries
print(causal_mask.shape)

# Set explicitly rather than relying on a `mask_value` left over from an earlier cell.
mask_value = -torch.inf

# replace upper triangular with -inf, causal masking!
masked_attention_logits = attention_logits.masked_fill_(causal_mask == 0, mask_value)

torch.Size([2, 5, 5])
torch.Size([2, 1, 5])


In [40]:
# Continue from the fixed masking: softmax over keys, weighted sum of values, output projection.
attention_weights = torch.softmax(masked_attention_logits, dim=-1)
attention_output = self_atten.output_proj(attention_weights @ value)

output_with_caching = BaseOutput(
    attention_output=attention_output,
    attention_weights=attention_weights,
    key=key,
    value=value,
)

In [41]:
output_new.attention_output[:,-1:,:] # last position of the uncached run, kept as (B, 1, E) to compare with the cached result

tensor([[[ 0.0332,  0.1659, -0.0441,  0.4275, -0.4797, -0.0975, -0.0763,
          -0.1674, -0.3293, -0.1751, -0.1188, -0.1235, -0.2550,  0.3202,
           0.0678,  0.0646]],

        [[ 0.0223,  0.1878, -0.0288,  0.4413, -0.5024, -0.1070, -0.0152,
          -0.1167, -0.3277, -0.1006, -0.1549, -0.1229, -0.3062,  0.2441,
           0.0094,  0.0572]]], grad_fn=<SliceBackward0>)

In [42]:
output_with_caching.attention_output # (B, 1, E) output computed from the cached K,V; should equal the previous cell

tensor([[[ 0.0332,  0.1659, -0.0441,  0.4275, -0.4797, -0.0975, -0.0763,
          -0.1674, -0.3293, -0.1751, -0.1188, -0.1235, -0.2550,  0.3202,
           0.0678,  0.0646]],

        [[ 0.0223,  0.1878, -0.0288,  0.4413, -0.5024, -0.1070, -0.0152,
          -0.1167, -0.3277, -0.1006, -0.1549, -0.1229, -0.3062,  0.2441,
           0.0094,  0.0572]]], grad_fn=<ViewBackward0>)

Yay, the cached and uncached outputs match!

# Merge it with the baseline

In [43]:
class BaseSelfAttention(nn.Module):
    """
    Minimal single-head causal self-attention with padding-mask and KV-cache support.

    forward(x, attention_mask=None, kv_cache=None) returns a BaseOutput with:
    * attention_output  (B, S_q, E)
    * attention_weights (B, S_q, S_k)
    * key, value        (B, S_k, E) -- the full keys/values including any cached prefix,
                                       ready to be wrapped in a KVCache for the next step.
    """
    def __init__(self,embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the four E -> E projections: query, key, value and output."""
        self.query_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim,self.embed_dim)

    def construct_query_key_value(self,x,kv_cache=None):
        """
        Project x (B, S_new, E) to Q, K, V.

        If ``kv_cache`` is given, its ``key``/``value`` (B, S_past, E) are prepended along
        the sequence dim, so K and V cover S_past + S_new positions while Q covers only the
        S_new new ones.
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)
        if kv_cache is not None:
            key = torch.concat([kv_cache.key, key], dim=1)
            value = torch.concat([kv_cache.value, value], dim=1)
        return query, key, value

    def calculate_unmasked_attention_logits(self,query,key):
        """Scaled dot-product logits: [B,Sq,E] @ [B,E,Sk] / sqrt(E) -> [B,Sq,Sk]."""
        key_t = key.transpose(1,2) # [B,Sk,E] -> [B,E,Sk]
        scale = math.sqrt(self.embed_dim)
        return (query@key_t)/scale

    def apply_causal_mask(self,attention_logits,mask_value=None):
        """
        In place, fill logits where a query would attend to a later key.

        The Sq queries are the LAST Sq of the Sk positions, so query i uses row
        (Sk - Sq + i) of the Sk x Sk lower-triangular mask. This makes the mask work
        with a KV cache (Sq < Sk); without a cache Sq == Sk and the slice is a no-op.
        ``mask_value`` defaults to the dtype's most negative finite value.
        """
        _,Sq,Sk = attention_logits.shape
        # One (Sk, Sk) boolean mask created directly on the logits' device; masked_fill_
        # broadcasts it over the batch, so there is no per-batch copy and no host->device copy.
        causal_mask = torch.ones(Sk,Sk,dtype=torch.bool,device=attention_logits.device).tril()
        causal_mask = causal_mask[-Sq:,:] # rows of the last Sq (newest) queries -> (Sq, Sk)

        if mask_value is None:
            mask_value = torch.finfo(attention_logits.dtype).min

        # replace the upper triangle (future keys) with a very large negative number
        return attention_logits.masked_fill_(~causal_mask, mask_value)

    def apply_attention_mask(self,attention_mask,masked_attention_logits,mask_value=None):
        """
        In place, fill logits of padded keys (``attention_mask == 0``).

        ``attention_mask`` is (B, Sk); it is broadcast over the query dim as (B, 1, Sk) and
        moved to the logits' device, so a CPU tokenizer mask also works with GPU logits.
        """
        if mask_value is None:
            mask_value = torch.finfo(masked_attention_logits.dtype).min
        pad_keys = attention_mask.to(masked_attention_logits.device)[:,None,:] == 0
        return masked_attention_logits.masked_fill_(pad_keys, mask_value)

    def calculate_attention_weights(self,masked_attention_logits):
        """Softmax over the key axis."""
        return torch.softmax(masked_attention_logits, dim=-1)

    def calculate_final_output(self,attention_weights,value):
        """Weighted sum of values, then the output projection -> (B, Sq, E)."""
        attention_output = attention_weights@value
        return self.output_proj(attention_output)

    def forward(self, x,attention_mask=None,kv_cache=None):
        # Construct Q, K, V (K, V include the cached prefix, if any)
        query, key, value = self.construct_query_key_value(x,kv_cache=kv_cache)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query,key)

        # Apply causal masking
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply the padding mask
            masked_attention_logits = self.apply_attention_mask(attention_mask,masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, calculate the final output
        attention_output = self.calculate_final_output(attention_weights,value)

        output = BaseOutput(attention_output=attention_output,
                            attention_weights=attention_weights,
                            key=key,
                            value=value)
        return output



class KVCache:
    """Holds previously computed keys/values; extra keyword args become attributes."""
    def __init__(self,key,value,**kwargs):
        self.key = key
        self.value = value
        for k,v in kwargs.items():
            setattr(self,k,v)


class BaseOutput:
    """Simple attribute bag: every keyword argument becomes an attribute."""
    def __init__(self,**kwargs):
        for k,v in kwargs.items():
            setattr(self,k,v)


In [44]:
texts = [
    "Something random",
    "A bit longer text",
    "something even longer than the two before!",
]
encoded_input = tokenizer(texts, padding=True, return_tensors='pt')
attention_mask = encoded_input['attention_mask']
B, S = attention_mask.shape
E = 16

z = torch.rand(B, S, E)

# The merged class must still handle a padded batch without a cache.
self_atten = BaseSelfAttention(E)
output = self_atten(z, attention_mask)
assert tuple(output.attention_output.shape) == (B, S, E)

In [46]:
# Dummy input for kv-cache testing
B, S, E = 2, 4, 16

z = torch.rand(B, S, E)
self_atten = BaseSelfAttention(E)

# Prefill: run the prompt once and keep its keys/values.
output = self_atten(z, attention_mask=None)
cache = KVCache(output.key, output.value)

# One new token to decode.
_z = torch.rand(B, 1, E)

# Reference: recompute attention over the whole sequence (no cache).
z_new = torch.cat((z, _z), dim=1)
output_new = self_atten(z_new, attention_mask=None)

# Incremental: feed only the new token plus the cached K/V.
output_with_caching = self_atten(_z, attention_mask=None, kv_cache=cache)

In [47]:
output_new.attention_output[:,-1,:] # integer index drops the seq dim -> (B, E); the cached output below is (B, 1, E)

tensor([[ 0.2702, -0.0455, -0.0794,  0.0350, -0.1609,  0.1403,  0.1773, -0.4252,
          0.0762, -0.2927,  0.3254, -0.1199,  0.0784, -0.3484, -0.4058,  0.0879],
        [ 0.2043, -0.0176, -0.0551, -0.0361, -0.1820,  0.0297,  0.1552, -0.3377,
          0.1246, -0.3020,  0.2412, -0.1095,  0.1516, -0.2699, -0.3367,  0.1868]],
       grad_fn=<SliceBackward0>)

In [48]:
output_with_caching.attention_output # (B, 1, E): same numbers as the last row of the uncached run

tensor([[[ 0.2702, -0.0455, -0.0794,  0.0350, -0.1609,  0.1403,  0.1773,
          -0.4252,  0.0762, -0.2927,  0.3254, -0.1199,  0.0784, -0.3484,
          -0.4058,  0.0879]],

        [[ 0.2043, -0.0176, -0.0551, -0.0361, -0.1820,  0.0297,  0.1552,
          -0.3377,  0.1246, -0.3020,  0.2412, -0.1095,  0.1516, -0.2699,
          -0.3367,  0.1868]]], grad_fn=<ViewBackward0>)

In [49]:
# Check it rather than eyeballing it: only float32 round-off differences are allowed.
torch.testing.assert_close(output_new.attention_output[:,-1:,:], output_with_caching.attention_output)
output_new.attention_output[:,-1:,:] - output_with_caching.attention_output

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.4901e-08,
           2.9802e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00, -1.4901e-08,  2.9802e-08,  0.0000e+00,
          -7.4506e-09]],

        [[-1.4901e-08, -2.2352e-08,  0.0000e+00, -7.4506e-09,  1.4901e-08,
          -3.7253e-09,  0.0000e+00,  0.0000e+00,  2.2352e-08, -2.9802e-08,
           2.9802e-08, -7.4506e-09,  0.0000e+00,  0.0000e+00,  2.9802e-08,
           0.0000e+00]]], grad_fn=<SubBackward0>)

# Extend to multiple-heads | Multi-Headed Attention
Functionality-wise, most of the code remains the same. A hacky way to reuse the self-attention code is to reshape q, k, v to `(batch*num_heads, S, E//num_heads)`: split each token's embedding into heads, move the head dimension next to the batch dimension, then flatten the two.

We will implement it from scratch with a separate dimension for heads, for better flexibility later on!

In [50]:
# Split Q,K,V into heads: each token's E channels are divided into num_heads chunks of E//num_heads

In [51]:
B, S, E = 4, 8, 32
z = torch.rand(B, S, E)

# Single-head attention module whose projections are reused below.
self_attn = BaseSelfAttention(E)
output = self_attn(z)

In [52]:
# Reuse the self-attention projections to get Q, K, V of shape (B, S, E) (no cache).
query, key, value = self_attn.construct_query_key_value(z, None)

In [53]:
num_heads = 4

B, S, E = query.size()
assert not E % num_heads, "embed_dim must be divisible by num_heads"

In [54]:
# Split each token's embedding (channel) dim into heads:
# (B, S, E) -> (B, S, H, E//H) -> (B, H, S, E//H).
# A plain .view(B, H, S, E//H) would regroup the flat memory instead, mixing
# different sequence positions into the same head.
head_dim = E // num_heads
query = query.view(B, S, num_heads, head_dim).transpose(1, 2)
key = key.view(B, S, num_heads, head_dim).transpose(1, 2)
value = value.view(B, S, num_heads, head_dim).transpose(1, 2)

In [55]:
query.size()  # now (B, num_heads, S, E//num_heads)

torch.Size([4, 4, 8, 8])

In [56]:
# After this, the operations are the same as in the original self-attention;
# however, some changes are required to carry the additional head dimension through.
# After all operations, we merge the heads back before the final output projection.

In [57]:
attention_output = torch.rand(B,num_heads,S,E//num_heads) # dummy per-head output for testing

# Merge heads back: (B, H, S, D) -> (B, S, H, D) -> (B, S, H*D) = (B, S, E).
# The transpose is required so each token's row gathers its own head outputs;
# calling .view(B, S, E) directly on (B, H, S, D) would mix positions.
attention_output = attention_output.transpose(1, 2).reshape(B, S, E)

# Extend self-attention to multi-headed attention

In [58]:
z = torch.arange(6 * 32).reshape(2, 3, 32)  # 0..191 laid out as (B=2, S=3, E=32)

In [59]:
B, S, E = z.size()

In [60]:
num_heads = 8 # toy head count; E must be divisible by it


In [69]:
class BaseMultiHeadedAttention(nn.Module):
    """
    Minimal multi-headed causal self-attention.

    The input (B, S, E) is projected to Q, K, V, and each is split into ``num_heads`` heads
    of size D = E // num_heads -> (B, H, S, D). Attention runs independently per head, and
    the head outputs are merged back to (B, S, E) before the output projection.

    Optional ``forward`` inputs:
    * ``attention_mask`` (B, S_k): 0 marks padding keys that must not be attended to.
    * ``kv_cache``: object with ``key``/``value`` of shape (B, H, S_past, D), e.g. a KVCache
      built from a previous call's output; the new keys/values are appended to it.
    """
    def __init__(self,embed_dim,num_heads):
        super().__init__()
        # Fail at construction time rather than on the first forward pass.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the four E -> E projections: query, key, value and output."""
        self.query_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim,self.embed_dim)

    def _split_head(self,x,num_heads):
        """(B, S, E) -> (B, H, S, D): head h gets channels [h*D, (h+1)*D) of every token."""
        B,S,E = x.shape
        assert E%num_heads==0, "embed_dim must be divisible by num_heads"
        x = x.reshape(B,S,num_heads,E//num_heads) # B,S,H,D -- heads are contiguous channel blocks
        return x.permute(0,2,1,3) # B,H,S,D

    def _merge_head(self,x):
        """Exact inverse of ``_split_head``: (B, H, S, D) -> (B, S, H*D)."""
        B,H,S,D = x.shape
        x = x.permute(0,2,1,3) # B,S,H,D
        return x.reshape(B,S,H*D) # B,S,E

    def create_heads(self,query,key,value):
        """Split each of query/key/value from (B, S, E) into (B, H, S, D)."""
        query = self._split_head(query,self.num_heads)
        key = self._split_head(key,self.num_heads)
        value = self._split_head(value,self.num_heads)
        return query,key,value

    def construct_query_key_value(self,x,kv_cache=None):
        """
        Project x (B, S_new, E) into per-head Q, K, V of shape (B, H, S, D).

        With ``kv_cache``, the cached per-head keys/values are prepended along the sequence
        dim (dim 2), so K and V cover S_past + S_new positions and Q covers S_new.
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        query,key,value = self.create_heads(query,key,value)

        if kv_cache is not None:
            key = torch.concat([kv_cache.key, key], dim=2) # B,H,S,D: sequence is dim 2
            value = torch.concat([kv_cache.value, value], dim=2)

        return query,key,value

    def calculate_unmasked_attention_logits(self,query,key):
        """Scaled logits per head: [B,H,Sq,D] @ [B,H,D,Sk] / sqrt(D) -> [B,H,Sq,Sk]."""
        key_t = key.transpose(-2,-1) # [B,H,Sk,D] -> [B,H,D,Sk]

        # Scaled dot-product attention divides by sqrt of the *per-head* dim D (= E // H):
        # each logit is a dot product over D channels only, not over E.
        scale = math.sqrt(query.shape[-1])

        return (query@key_t)/scale

    def apply_causal_mask(self,attention_logits,mask_value=None):
        """
        In place, fill logits where a query would attend to a later key.

        The Sq queries are the LAST Sq of the Sk positions (KV cache), so query i uses row
        (Sk - Sq + i) of the Sk x Sk lower-triangular mask; without a cache Sq == Sk.
        ``mask_value`` defaults to the dtype's most negative finite value.
        """
        Sq,Sk = attention_logits.shape[-2:]

        # One (Sk, Sk) boolean mask on the logits' device, broadcast over batch and heads
        # by masked_fill_ -- no B*H copies and no host->device transfer.
        causal_mask = torch.ones(Sk,Sk,dtype=torch.bool,device=attention_logits.device).tril()
        causal_mask = causal_mask[-Sq:,:] # (Sq, Sk): rows of the newest Sq queries

        if mask_value is None:
            mask_value = torch.finfo(attention_logits.dtype).min

        # replace the upper triangle (future keys) with a very large negative number
        return attention_logits.masked_fill_(~causal_mask, mask_value)

    def apply_attention_mask(self,attention_mask,masked_attention_logits,mask_value=None):
        """
        In place, fill logits of padded keys (``attention_mask == 0``).

        ``attention_mask`` is (B, Sk), viewed as (B, 1, 1, Sk) to broadcast over heads and
        queries, and moved to the logits' device first.
        """
        if mask_value is None:
            mask_value = torch.finfo(masked_attention_logits.dtype).min
        pad_keys = attention_mask.to(masked_attention_logits.device)[:,None,None,:] == 0
        return masked_attention_logits.masked_fill_(pad_keys, mask_value)

    def calculate_attention_weights(self,masked_attention_logits):
        """Softmax over the key axis."""
        return torch.softmax(masked_attention_logits, dim=-1)

    def calculate_final_output(self,attention_weights,value):
        """Per-head weighted sum of values, merge heads to (B, Sq, E), then output projection."""
        attention_output = attention_weights@value # B,H,Sq,D
        attention_output = self._merge_head(attention_output) # B,Sq,E
        return self.output_proj(attention_output)

    def forward(self, x,attention_mask=None,kv_cache=None):
        # Construct per-head Q, K, V (K, V include the cached prefix, if any)
        query, key, value = self.construct_query_key_value(x,kv_cache=kv_cache)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query,key)

        # Apply causal masking
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply the padding mask
            masked_attention_logits = self.apply_attention_mask(attention_mask,masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, calculate the final output
        attention_output = self.calculate_final_output(attention_weights,value)

        output = BaseOutput(attention_output=attention_output,
                            attention_weights=attention_weights,
                            key=key,
                            value=value)
        return output

In [70]:
B, S, E, H = 4, 3, 32, 8
z = torch.rand(B, S, E)

# Plain forward pass: no padding mask, no cache.
multi_headed_attn = BaseMultiHeadedAttention(E, H)
output = multi_headed_attn(z)

In [71]:
output.attention_output.size()

torch.Size([4, 3, 32])

In [72]:
# Dummy input for kv-cache testing
B, S, E, H = 2, 4, 16, 8

z = torch.rand(B, S, E)
multi_headed_attn = BaseMultiHeadedAttention(E, H)

# Prefill: keep the per-head keys/values of the prompt.
output = multi_headed_attn(z, attention_mask=None)
cache = KVCache(output.key, output.value)

# One new token.
_z = torch.rand(B, 1, E)

# Reference without caching: recompute over prompt + new token.
z_new = torch.cat((z, _z), dim=1)
output_new = multi_headed_attn(z_new, attention_mask=None)

# With KV caching: only the new token goes through the projections.
output_with_caching = multi_headed_attn(_z, attention_mask=None, kv_cache=cache)

In [73]:
output_new.attention_output[:,-1:,:] # last position of the uncached multi-head run, shape (B, 1, E)

tensor([[[ 0.0607, -0.0614, -0.2348, -0.4526, -0.0983,  0.3194,  0.0798,
           0.0944,  0.2969, -0.3214,  0.0311,  0.3004, -0.1775, -0.5601,
           0.1528,  0.1161]],

        [[ 0.0357, -0.0586, -0.2431, -0.4274, -0.1170,  0.3940,  0.0592,
           0.1295,  0.4053, -0.3705,  0.0928,  0.3760, -0.1673, -0.5725,
           0.0668,  0.0803]]], grad_fn=<SliceBackward0>)

In [74]:
output_with_caching.attention_output # (B, 1, E) from the cached per-head K,V; should match the previous cell

tensor([[[ 0.0607, -0.0614, -0.2348, -0.4526, -0.0983,  0.3194,  0.0798,
           0.0944,  0.2969, -0.3214,  0.0311,  0.3004, -0.1775, -0.5601,
           0.1528,  0.1161]],

        [[ 0.0357, -0.0586, -0.2431, -0.4274, -0.1170,  0.3940,  0.0592,
           0.1295,  0.4053, -0.3705,  0.0928,  0.3760, -0.1673, -0.5725,
           0.0668,  0.0803]]], grad_fn=<ViewBackward0>)

In [75]:
# Test attention mask
texts = [
    "Something random",
    "A bit longer text",
    "something even longer than the two before!",
]
encoded_input = tokenizer(texts, padding=True, return_tensors='pt')
attention_mask = encoded_input['attention_mask']
B, S = attention_mask.shape
E, H = 16, 4

z = torch.rand(B, S, E)

# Multi-head forward over the left-padded batch; shape must be preserved.
multi_headed_attn = BaseMultiHeadedAttention(E, H)
output = multi_headed_attn(z, attention_mask)
assert tuple(output.attention_output.shape) == (B, S, E)