In [1]:
# %pip install torch numpy scikit-learn pandas matplotlib seaborn

In [2]:
import torch
import math
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# from imports import *

# Self-Attention: the bits — inspect and visualize them line by line

Let's first set up a small dummy input for testing and visualisation.

In [3]:
B = 2
S = 8
E = 16

z = torch.rand(B,S,E)

Now we construct query, key and value from the input

In [4]:
query_proj = nn.Linear(E,E)
key_proj = nn.Linear(E,E)
value_proj = nn.Linear(E,E)

In [5]:
query = query_proj(z)
key = key_proj(z)
value = value_proj(z)

In [6]:
query.shape, key.shape, value.shape

(torch.Size([2, 8, 16]), torch.Size([2, 8, 16]), torch.Size([2, 8, 16]))

Calculate attention logits

In [7]:
key_t = key.transpose(1,2)
key_t.shape # Q.K' : [B,S,E] @ [B,E,S]

torch.Size([2, 16, 8])

In [8]:
unscaled_attention_logits = (query@key_t) # Q.K' : [B,S,E] @ [B,E,S] => [B,S,S]
unscaled_attention_logits.shape

torch.Size([2, 8, 8])

Scale our attention logits

In [9]:
B,S,E = query.shape
scale = math.sqrt(E)
attention_logits = unscaled_attention_logits/scale

In [10]:
attention_logits.shape

torch.Size([2, 8, 8])

Causal masking

In [11]:
mask_value = -torch.inf
# Lower-triangular matrix of 1s: position i may only look at positions j <= i.
causal_mask = torch.tril(torch.ones(B,S,S))
# Replace the upper triangle with -inf (or a very large negative number).
# Out-of-place `masked_fill` leaves `attention_logits` untouched, so this cell
# can be re-run and the unmasked logits stay available for inspection.
masked_attention_logits = attention_logits.masked_fill(causal_mask == 0, mask_value)

In [12]:
masked_attention_logits[0]

tensor([[-0.1849,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1462, -0.1943,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.2337, -0.2623, -0.1436,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1217, -0.2063, -0.1253, -0.1600,    -inf,    -inf,    -inf,    -inf],
        [-0.2344, -0.2627, -0.3572, -0.2270, -0.1346,    -inf,    -inf,    -inf],
        [-0.1876, -0.1997, -0.1351, -0.1944, -0.1529, -0.1729,    -inf,    -inf],
        [-0.2344, -0.2535, -0.2790, -0.2311, -0.1611, -0.2323, -0.1757,    -inf],
        [-0.2373, -0.2316, -0.2593, -0.1755, -0.1306, -0.1911, -0.1500, -0.1635]],
       grad_fn=<SelectBackward0>)

In [13]:
causal_attention_weights = torch.softmax(masked_attention_logits, dim=-1)

In [14]:
causal_attention_weights # Observe lower triangular! Causal attention weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5120, 0.4880, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3261, 0.3170, 0.3569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2579, 0.2370, 0.2570, 0.2482, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2013, 0.1956, 0.1780, 0.2027, 0.2224, 0.0000, 0.0000, 0.0000],
         [0.1643, 0.1624, 0.1732, 0.1632, 0.1701, 0.1668, 0.0000, 0.0000],
         [0.1413, 0.1386, 0.1351, 0.1417, 0.1520, 0.1416, 0.1498, 0.0000],
         [0.1194, 0.1201, 0.1168, 0.1270, 0.1328, 0.1250, 0.1303, 0.1285]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5205, 0.4795, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3454, 0.3293, 0.3253, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2427, 0.2459, 0.2474, 0.2640, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2016, 0.2007, 0.2048, 0.2024, 0.1905, 0.0000, 0.0000, 0.0000],
         [0.1580, 0.161

In [15]:
attention_output = causal_attention_weights@value

# What about padding tokens?
* In the base implementation we applied self-attention to every token, including padding tokens if present!
* Now we extend the base implementation so that tokens such as the pad token are ignored by attention.

## Quick look at a sample tokenizer

In [16]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # Important!,usually decoder only models require left padding
texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)

In [17]:
encoded_input.keys()

dict_keys(['input_ids', 'attention_mask'])

In [18]:
tokenizer.pad_token_id # pad token id for reference

2

In [19]:
encoded_input['input_ids'][0] # Observe left padding!, 2 is the pad token
# Refer: https://huggingface.co/docs/transformers/llm_tutorial#wrong-padding-side

tensor([    2,     2,     2,     2,     2,     2,     1, 13264,  5509])

In [20]:
encoded_input['attention_mask'][0] # Observe attention mask is zero for pad tokens!

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1])

In [21]:
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

## Handle attention mask in the attention logits
Also look at https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L800

In [22]:
attention_logits = torch.rand(B, S, S) # dummy for demo
attention_logits = attention_logits.masked_fill_(attention_mask[:,None,:] == 0, torch.finfo(attention_logits.dtype).min)
attention_logits

tensor([[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  3.8209e-01,  4.0718e-01,  1.3385e-02],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  1.8255e-01,  3.2008e-01,  3.2200e-02],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  8.1127e-01,  4.1283e-02,  8.8291e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  7.6152e-01,  9.2479e-01,  8.3445e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  5.4773e-01,  2.6998e-01,  5.0050e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  4.7296e-01,  4.2571e-01,  8.8768e-01],
         [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
          -3.4028e+38,  7.3142e-01,  3.2486e-01,  4.0470e-01],
         [-3.4028e+38, -3.4028e+38

# Putting the code bits together

In [23]:
class BaseOutput:
    """Simple attribute container: each keyword argument becomes an attribute of the instance."""
    def __init__(self,**kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

In [24]:
class BaseSelfAttention(nn.Module):
    """
    Minimal single-head causal self-attention.

    `forward` projects the input to Q/K/V, computes scaled dot-product logits,
    applies the causal mask, optionally applies a padding mask, takes the
    softmax, weights the values and applies the output projection.

    Not Implemented:
    * KV caching
    """
    def __init__(self,embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the four square (embed_dim -> embed_dim) projections."""
        width = self.embed_dim
        # Created in this order so the modules are registered as query, key, value, output.
        for proj_name in ("query_proj", "key_proj", "value_proj", "output_proj"):
            setattr(self, proj_name, nn.Linear(width, width))

    def construct_query_key_value(self,x):
        """Return the (query, key, value) projections of `x`."""
        return self.query_proj(x), self.key_proj(x), self.value_proj(x)

    def calculate_unmasked_attention_logits(self,query,key):
        """Return Q.K' / sqrt(embed_dim): [B,S,E] @ [B,E,S] => [B,S,S]."""
        scores = torch.matmul(query, key.transpose(1,2))
        return scores / math.sqrt(self.embed_dim)

    def apply_causal_mask(self,attention_logits,mask_value=None):
        """
        Fill the strictly-upper triangle of `attention_logits` with `mask_value`.

        The fill is done in place: the tensor passed in is modified and returned.
        `mask_value` defaults to the most negative finite value of the logits' dtype.
        The mask is always square, sized by the LAST dimension of the logits.
        """
        batch_size, _, seq_len = attention_logits.shape
        target_device = attention_logits.device

        # lower triangular matrix with 1s
        lower_triangle = torch.tril(torch.ones(batch_size, seq_len, seq_len)).to(target_device)

        fill_value = torch.finfo(attention_logits.dtype).min if mask_value is None else mask_value

        # causal masking: everything above the diagonal is replaced
        return attention_logits.masked_fill_(lower_triangle == 0, fill_value)

    def apply_attention_mask(self,attention_mask,masked_attention_logits,mask_value=None):
        """
        Fill, in place, every key column whose `attention_mask` entry is 0.

        `attention_mask` is indexed as [:, None, :], so it is broadcast over the query dimension.
        """
        fill_value = torch.finfo(masked_attention_logits.dtype).min if mask_value is None else mask_value
        ignored_keys = attention_mask[:,None,:] == 0
        return masked_attention_logits.masked_fill_(ignored_keys, fill_value)

    def calculate_attention_weights(self,masked_attention_logits):
        """Softmax over the last (key) dimension."""
        return masked_attention_logits.softmax(dim=-1)

    def calculate_final_output(self,attention_weights,value):
        """Weight the values by the attention weights, then apply the output projection."""
        return self.output_proj(torch.matmul(attention_weights, value))

    def forward(self, x,attention_mask=None):
        """
        Run causal self-attention on `x`.

        Returns a `BaseOutput` with `attention_output`, `attention_weights`,
        `key` and `value`.
        """
        query, key, value = self.construct_query_key_value(x)

        # Logits -> causal mask -> (optional) padding mask
        logits = self.apply_causal_mask(self.calculate_unmasked_attention_logits(query,key))
        if attention_mask is not None:
            logits = self.apply_attention_mask(attention_mask,logits)

        attention_weights = self.calculate_attention_weights(logits)
        attention_output = self.calculate_final_output(attention_weights,value)

        return BaseOutput(attention_output=attention_output,
                          attention_weights=attention_weights,
                          key=key,
                          value=value)


In [25]:
E = 16

texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

z = torch.rand(B,S,E)



self_atten = BaseSelfAttention(E)
output = self_atten(z,attention_mask)
assert output.attention_output.shape == (B,S,E)

# K-V caching

In [26]:
"""\

With KV caching:
Input: (B,1,E), kv_cache
Output: (B,1,E)

"""

'\nWith KV caching:\nInput: (B,1,E), kv_cache\nOutput: (B,1,E)\n\n'

In [27]:
# Dummy input for testing
B = 2
S = 4
E = 16

z_old = torch.rand(B,S,E)
z = torch.rand(B,1,E)
z_new = torch.concat([z_old,z],dim=1)

self_atten = BaseSelfAttention(E)
output_old = self_atten(z_old,attention_mask=None)
output_new = self_atten(z_new,attention_mask=None)

print(z.shape)
print(z_old.shape)
print(z_new.shape)

torch.Size([2, 1, 16])
torch.Size([2, 4, 16])
torch.Size([2, 5, 16])


In [28]:
class KVCache:
    """Stores `key` and `value`, plus any extra keyword arguments, as instance attributes."""
    def __init__(self,key,value,**kwargs):
        self.key, self.value = key, value
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

In [29]:
# store previous K,V in the cache
cache = KVCache(output_old.key,
                output_old.value)

In [30]:
# Construct K,V from cache

new_key = self_atten.key_proj(z)
new_value = self_atten.value_proj(z)
query = self_atten.query_proj(z)

In [31]:
key = torch.concat([cache.key,
                    new_key],dim=1)
value = torch.concat([cache.value,
                      new_value],dim=1)

In [32]:
query.shape

torch.Size([2, 1, 16])

In [33]:
key.shape

torch.Size([2, 5, 16])

In [34]:
# remaining operations remains the same

# Calculate logits
unmasked_attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

# Apply causal masking
masked_attention_logits = self_atten.apply_causal_mask(unmasked_attention_logits)

# Calculate attention weights
attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)

# And finally, Calculate the final output
attention_output = self_atten.calculate_final_output(attention_weights,value)

output_with_caching = BaseOutput(attention_output=attention_output,
                    attention_weights=attention_weights,
                    key=key,
                    value=value)

RuntimeError: output with shape [2, 1, 5] doesn't match the broadcast shape [2, 5, 5]

Oh no, causal mask!

In [35]:
# let's replicate it

attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

B,Sq,Sk = attention_logits.shape # <---------Change 1

device = attention_logits.device
causal_mask = torch.tril(torch.ones(B,Sk,Sk)).to(device) # <------------Change 2

print(causal_mask.shape)
# Keep only the last Sq rows: the queries are the newest Sq positions of the Sk-long sequence.
causal_mask = causal_mask[:,-Sq:,-Sk:] # <------------Change 3
print(causal_mask.shape)

# Defined here explicitly so the cell does not depend on a variable left over from an earlier cell.
mask_value = -torch.inf

# replace upper triangular with -inf or a very large negative number, causal masking!
# Out-of-place fill keeps `attention_logits` unmasked, so the cell is safe to re-run.
masked_attention_logits = attention_logits.masked_fill(causal_mask == 0, mask_value)

torch.Size([2, 5, 5])
torch.Size([2, 1, 5])


In [36]:
# continue from that

# Calculate attention weights
attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)

# And finally, Calculate the final output
attention_output = self_atten.calculate_final_output(attention_weights,value)

output_with_caching = BaseOutput(attention_output=attention_output,
                    attention_weights=attention_weights,
                    key=key,
                    value=value)

In [37]:
output_new.attention_output[:,-1:,:]

tensor([[[-0.1127,  0.2380,  0.1639,  0.5565,  0.2931, -0.0349, -0.0777,
          -0.0873,  0.1724, -0.1370,  0.4725,  0.0513,  0.0438,  0.1115,
           0.0340, -0.1261]],

        [[-0.0409,  0.2507,  0.0951,  0.5170,  0.3235, -0.0853, -0.1103,
          -0.0243,  0.0784, -0.1308,  0.4732, -0.0428,  0.1069,  0.0872,
          -0.0407, -0.2352]]], grad_fn=<SliceBackward0>)

In [38]:
output_with_caching.attention_output

tensor([[[-0.1127,  0.2380,  0.1639,  0.5565,  0.2931, -0.0349, -0.0777,
          -0.0873,  0.1724, -0.1370,  0.4725,  0.0513,  0.0438,  0.1115,
           0.0340, -0.1261]],

        [[-0.0409,  0.2507,  0.0951,  0.5170,  0.3235, -0.0853, -0.1103,
          -0.0243,  0.0784, -0.1308,  0.4732, -0.0428,  0.1069,  0.0872,
          -0.0407, -0.2352]]], grad_fn=<ViewBackward0>)

Yay! matches!

# Merge it with the baseline

In [39]:
class BaseSelfAttention(nn.Module):
    """
    Minimal single-head causal self-attention with optional padding mask and KV cache.

    Shape legend used in comments: B batch, Sq query length, Sk key length
    (Sk = cached length + Sq when a cache is given), E embed_dim.

    Args:
        embed_dim: width E of the input and of every projection.
    """
    def __init__(self,embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the square (E -> E) query, key, value and output projections."""
        self.query_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim,self.embed_dim)

    def construct_query_key_value(self,x,kv_cache=None):
        """
        Project `x` to query/key/value and prepend cached keys/values if given.

        Args:
            x: input of shape (B, Sq, E).
            kv_cache: optional object exposing `.key` and `.value`; they are
                concatenated in front of the new key/value along dim 1.

        Returns:
            (query, key, value); query keeps length Sq, key/value have length Sk.
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)
        if kv_cache is not None:
            key = torch.concat([kv_cache.key, key],dim=1)
            value = torch.concat([kv_cache.value, value],dim=1)

        return query, key, value

    def calculate_unmasked_attention_logits(self,query,key):
        """Return Q.K' / sqrt(E): [B,Sq,E] @ [B,E,Sk] => [B,Sq,Sk]."""
        key_t = key.transpose(1,2) # Transpose to [B,E,Sk] by exchanging dim 1 and 2
        scale = math.sqrt(self.embed_dim)
        return (query@key_t)/scale

    def apply_causal_mask(self,attention_logits,mask_value=None):
        """
        Return a copy of `attention_logits` with future key positions masked.

        The Sq queries are treated as the LAST Sq positions of the Sk-long key
        sequence, so the same code handles both the full pass (Sq == Sk) and
        the KV-cache pass (Sq < Sk). The input tensor is not modified.

        Args:
            attention_logits: tensor whose last two dims are (Sq, Sk).
            mask_value: fill value; defaults to the most negative finite value
                of the logits' dtype, which keeps the softmax finite.
        """
        Sq, Sk = attention_logits.shape[-2:]
        device = attention_logits.device

        if mask_value is None:
            mask_value = torch.finfo(attention_logits.dtype).min

        # A single (Sq, Sk) boolean mask built on the target device; it
        # broadcasts over the batch dimension, so no B x Sk x Sk tensor is needed.
        query_positions = torch.arange(Sk - Sq, Sk, device=device)[:, None]
        key_positions = torch.arange(Sk, device=device)[None, :]
        is_future = key_positions > query_positions

        return attention_logits.masked_fill(is_future, mask_value)

    def apply_attention_mask(self,attention_mask,masked_attention_logits,mask_value=None):
        """
        Return a copy of the logits with every key column masked where `attention_mask == 0`.

        `attention_mask` is indexed as [:, None, :], so its last dimension must
        line up with the key dimension of the logits. The input is not modified.
        """
        if mask_value is None:
            mask_value = torch.finfo(masked_attention_logits.dtype).min
        return masked_attention_logits.masked_fill(attention_mask[:,None,:] == 0, mask_value)

    def calculate_attention_weights(self,masked_attention_logits):
        """Softmax over the key dimension."""
        return torch.softmax(masked_attention_logits, dim=-1)

    def calculate_final_output(self,attention_weights,value):
        """Weight the values by the attention weights, then apply the output projection."""
        attention_output = attention_weights@value
        return self.output_proj(attention_output)

    def forward(self, x,attention_mask=None,kv_cache=None):
        """
        Run causal self-attention on `x`.

        Args:
            x: input of shape (B, Sq, E).
            attention_mask: optional 0/1 mask over key positions (0 = ignore).
            kv_cache: optional cache of previously computed keys/values.

        Returns:
            `BaseOutput` with `attention_output`, `attention_weights`, and the
            full `key` / `value` (cache included) for building the next cache.
        """
        # Construct Q, K, V
        query, key, value = self.construct_query_key_value(x,kv_cache=kv_cache)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query,key)

        # Apply causal masking
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply attention mask
            masked_attention_logits = self.apply_attention_mask(attention_mask,masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, Calculate the final output
        attention_output = self.calculate_final_output(attention_weights,value)

        output = BaseOutput(attention_output=attention_output,
                            attention_weights=attention_weights,
                            key=key,
                            value=value)
        return output
    


class KVCache:
    """Stores `key` and `value`, plus any extra keyword arguments, as instance attributes."""
    def __init__(self,key,value,**kwargs):
        self.key, self.value = key, value
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])


class BaseOutput:
    """Simple attribute container: each keyword argument becomes an attribute of the instance."""
    def __init__(self,**kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])


In [40]:
E = 16

texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

z = torch.rand(B,S,E)



self_atten = BaseSelfAttention(E)
output = self_atten(z,attention_mask)
assert output.attention_output.shape == (B,S,E)

In [41]:
# Dummy input for kv-cache testing
B = 2
S = 4
E = 16

z = torch.rand(B,S,E)
self_atten = BaseSelfAttention(E)

output = self_atten(z,attention_mask=None)
cache = KVCache(output.key,output.value) # Construct kv cache

# new token
_z = torch.rand(B,1,E)

# Without KV Caching
z_new = torch.concat([z,_z],dim=1)
output_new = self_atten(z_new,attention_mask=None)

# With KV Caching
output_with_caching = self_atten(_z,attention_mask=None,kv_cache=cache)

In [42]:
output_new.attention_output[:,-1,:]

tensor([[ 0.3471,  0.4393, -0.5275,  0.5012, -0.3518,  0.0159, -0.0625, -0.2379,
          0.0965, -0.1188, -0.2035,  0.1183, -0.1810,  0.1572,  0.1148, -0.1510],
        [ 0.4791,  0.4694, -0.5979,  0.5361, -0.4261,  0.0571, -0.0666, -0.2384,
          0.0965, -0.0262, -0.1194,  0.1316, -0.3341,  0.1024,  0.1217, -0.3082]],
       grad_fn=<SliceBackward0>)

In [43]:
output_with_caching.attention_output

tensor([[[ 0.3471,  0.4393, -0.5275,  0.5012, -0.3518,  0.0159, -0.0625,
          -0.2379,  0.0965, -0.1188, -0.2035,  0.1183, -0.1810,  0.1572,
           0.1148, -0.1510]],

        [[ 0.4791,  0.4694, -0.5979,  0.5361, -0.4261,  0.0571, -0.0666,
          -0.2384,  0.0965, -0.0262, -0.1194,  0.1316, -0.3341,  0.1024,
           0.1217, -0.3082]]], grad_fn=<ViewBackward0>)

In [44]:
output_new.attention_output[:,-1:,:] - output_with_caching.attention_output
# Just to verify, some numerical errors are expected only

tensor([[[ 2.9802e-08,  0.0000e+00,  5.9605e-08,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00, -1.4901e-08,  1.4901e-08,  7.4506e-09,
           0.0000e+00,  1.4901e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          -1.4901e-08]],

        [[ 0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00,  2.9802e-08,
           1.4901e-08,  0.0000e+00,  1.4901e-08,  2.9802e-08, -1.4901e-08,
          -7.4506e-09,  0.0000e+00,  0.0000e+00, -2.9802e-08,  0.0000e+00,
           0.0000e+00]]], grad_fn=<SubBackward0>)

# Extend to multiple-heads | Multi-Headed Attention
Functionality-wise, most of the code remains the same. A hacky way to reuse the self-attention code is to reshape q, k, v to shape `(batch*num_heads, S, E//num_heads)`.

We will instead implement it from scratch with a separate dimension for heads, for better flexibility later on!

In [45]:
# Split Q,K,V into heads

In [46]:
B = 4
S = 8
E = 32
z = torch.rand(B,S,E)

self_attn = BaseSelfAttention(E)
output = self_attn(z)

In [47]:
# Reuse the self-attention implementation

query,key,value = self_attn.construct_query_key_value(z,kv_cache=None)

In [48]:
num_heads = 4


B,S,E = query.shape
assert E%num_heads==0, "embed_dim must be divisible by num_heads"

In [49]:
# (B,S,E) -> (B,S,H,D) -> (B,H,S,D).
# Viewing (B,S,E) directly as (B,H,S,D) would only reinterpret memory and mix
# sequence positions with features; the head axis must be split off the
# feature axis first and then moved in front of the sequence axis.
query = query.view(B,S,num_heads,E//num_heads).transpose(1,2)
key = key.view(B,S,num_heads,E//num_heads).transpose(1,2)
value = value.view(B,S,num_heads,E//num_heads).transpose(1,2)

In [50]:
query.shape # now (B,num_heads,S,E//num_heads)

torch.Size([4, 4, 8, 8])

In [51]:
# After this, the operations are same as the original self-attention.
# however,Some changes are required to incoperate an additional head dimension
# After all operations, we merge back the heads before final output projection

In [52]:
attention_output = torch.rand(B,num_heads,S,E//num_heads) # dummy for testing

# go back to (B,S,E): (B,H,S,D) -> (B,S,H,D) -> (B,S,H*D).
# The transpose is required; a bare view(B,S,E) would interleave different
# sequence positions into one token's features.
attention_output = attention_output.transpose(1,2).reshape(B,S,E)

# Extend self-attention to multi-headed attention

In [53]:
z = torch.arange(6*32).view(2,3,32)

In [54]:
B,S,E = z.shape

In [55]:
num_heads = 8


In [56]:
class BaseMultiHeadedAttention(nn.Module):
    """
    Minimal causal multi-headed self-attention with optional padding mask and KV cache.

    Shape legend used in comments: B batch, Sq query length, Sk key length,
    E embed_dim, H num_heads, D = E // H (per-head width).

    Args:
        embed_dim: model width E; must be divisible by `num_heads`.
        num_heads: number of attention heads H.
    """
    def __init__(self,embed_dim,num_heads):
        super().__init__()
        # Fail at construction time rather than at the first forward pass.
        assert embed_dim%num_heads==0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the square (E -> E) query, key, value and output projections."""
        self.query_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim,self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim,self.embed_dim)

    def _split_head(self,x,num_heads):
        """
        Split the feature axis of `x` (B,S,E) into heads, returning (B,H,S,D).

        Heads are interleaved along the feature axis: head h receives features
        h, h+H, h+2H, ... of the projected vector.
        """
        B,S,E = x.shape
        assert E%num_heads==0, "embed_dim must be divisible by num_heads"
        x = x.reshape(B,S,E//num_heads,num_heads) # B,S,D,H
        x = x.permute([0,3,1,2]) # B,H,S,D
        return x

    def _merge_head(self,x):
        """Flatten (B,H,S,D) back to (B,S,H*D), placing each head's features contiguously."""
        B,H,S,D = x.shape
        x = x.permute([0,2,1,3]) # B,S,H,D
        x = x.reshape(B,S,H*D) # B,S,H*D
        return x

    def create_heads(self,query,key,value):
        """Split query, key and value (each B,S,E) into `num_heads` heads (each B,H,S,D)."""
        E = query.shape[-1]
        assert E%self.num_heads==0, "embed_dim must be divisible by num_heads"
        query = self._split_head(query,self.num_heads)
        key = self._split_head(key,self.num_heads)
        value = self._split_head(value,self.num_heads)
        return query,key,value

    def construct_query_key_value(self,x,kv_cache=None):
        """
        Project `x` (B,Sq,E) to per-head query/key/value and prepend the cache if given.

        Args:
            x: input of shape (B, Sq, E).
            kv_cache: optional object exposing `.key` and `.value`, each shaped
                like the key/value this method returns; they are concatenated
                in front of the new key/value along the sequence axis (dim 2).

        Returns:
            query (B,H,Sq,D), key (B,H,Sk,D), value (B,H,Sk,D).
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        query,key,value = self.create_heads(query,key,value)

        if kv_cache is not None:
            # B,H,S,D: the sequence axis is dim 2
            key = torch.concat([kv_cache.key, key],dim=2)
            value = torch.concat([kv_cache.value, value],dim=2)

        return query,key,value

    def calculate_unmasked_attention_logits(self,query,key):
        """
        Return scaled dot-product logits: [B,H,Sq,D] @ [B,H,D,Sk] => [B,H,Sq,Sk].

        The scale is sqrt(D), the width of the vectors actually being dotted
        (the per-head dimension), not sqrt(embed_dim).
        """
        key_t = key.transpose(2,3) # Transpose to [B,H,D,Sk] by exchanging dim 2 and 3

        # Each dot product sums over D terms, so D is what the variance grows with.
        scale = math.sqrt(query.size(-1))

        return (query@key_t)/scale

    def apply_causal_mask(self,attention_logits,mask_value=None):
        """
        Return a copy of `attention_logits` with future key positions masked.

        The Sq queries are treated as the LAST Sq positions of the Sk-long key
        sequence, which covers both the full pass (Sq == Sk) and the KV-cache
        pass (Sq < Sk). The input tensor is not modified.

        Args:
            attention_logits: tensor whose last two dims are (Sq, Sk).
            mask_value: fill value; defaults to the most negative finite value
                of the logits' dtype, which keeps the softmax finite.
        """
        Sq, Sk = attention_logits.shape[-2:]
        device = attention_logits.device

        if mask_value is None:
            mask_value = torch.finfo(attention_logits.dtype).min

        # One (Sq, Sk) boolean mask built directly on the target device; it
        # broadcasts over batch and heads.
        query_positions = torch.arange(Sk - Sq, Sk, device=device)[:, None]
        key_positions = torch.arange(Sk, device=device)[None, :]
        is_future = key_positions > query_positions

        return attention_logits.masked_fill(is_future, mask_value)

    def apply_attention_mask(self,attention_mask,masked_attention_logits,mask_value=None):
        """
        Return a copy of the logits with every key column masked where `attention_mask == 0`.

        `attention_mask` is indexed as [:, None, None, :] (extra axes for heads
        and queries), so its last dimension must line up with the key dimension
        of the logits. The input is not modified.
        """
        if mask_value is None:
            mask_value = torch.finfo(masked_attention_logits.dtype).min
        return masked_attention_logits.masked_fill(
            attention_mask[:,None,None,:] == 0, # Additional dimension for heads
            mask_value)

    def calculate_attention_weights(self,masked_attention_logits):
        """Softmax over the key dimension."""
        return torch.softmax(masked_attention_logits, dim=-1)

    def calculate_final_output(self,attention_weights,value):
        """Weight the per-head values, merge the heads back to (B,Sq,E) and apply the output projection."""
        attention_output = attention_weights@value # B,H,Sq,D

        attention_output = self._merge_head(attention_output) # flatten back to (B,Sq,E)

        return self.output_proj(attention_output)

    def forward(self, x,attention_mask=None,kv_cache=None):
        """
        Run causal multi-headed self-attention on `x`.

        Args:
            x: input of shape (B, Sq, E).
            attention_mask: optional 0/1 mask over key positions (0 = ignore).
            kv_cache: optional cache of previously computed per-head keys/values.

        Returns:
            `BaseOutput` with `attention_output` (B,Sq,E), `attention_weights`
            (B,H,Sq,Sk) and the full per-head `key` / `value` (cache included).
        """
        # Construct Q, K, V
        query, key, value = self.construct_query_key_value(x,kv_cache)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query,key)

        # Apply causal masking
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply attention mask
            masked_attention_logits = self.apply_attention_mask(attention_mask,masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, Calculate the final output
        attention_output = self.calculate_final_output(attention_weights,value)

        output = BaseOutput(attention_output=attention_output,
                            attention_weights=attention_weights,
                            key=key,
                            value=value)
        return output

In [57]:
B = 4
S = 3
E = 32
H = 8
z = torch.rand(B,S,E)

multi_headed_attn = BaseMultiHeadedAttention(E,H)
output = multi_headed_attn(z)

In [58]:
output.attention_output.shape

torch.Size([4, 3, 32])

In [59]:
# Dummy input for kv-cache testing
B = 2
S = 4
E = 16
H = 8

z = torch.rand(B,S,E)
multi_headed_attn = BaseMultiHeadedAttention(E,H)

output = multi_headed_attn(z,attention_mask=None)
cache = KVCache(output.key,output.value) # Construct kv cache

# new token
_z = torch.rand(B,1,E)

# Without KV Caching
z_new = torch.concat([z,_z],dim=1)
output_new = multi_headed_attn(z_new,attention_mask=None)

# With KV Caching
output_with_caching = multi_headed_attn(_z,attention_mask=None,kv_cache=cache)

In [60]:
output_new.attention_output[:,-1:,:]

tensor([[[ 0.0735,  0.2351, -0.2685,  0.0621,  0.4552, -0.2058, -0.1648,
          -0.0123,  0.1202,  0.1113, -0.2080, -0.2728, -0.0924,  0.0647,
          -0.0148, -0.0551]],

        [[ 0.2660,  0.3878, -0.1069,  0.0234,  0.4454, -0.0942, -0.1740,
          -0.1417,  0.1913,  0.1965, -0.2677, -0.1547, -0.0307,  0.0504,
          -0.0111,  0.0865]]], grad_fn=<SliceBackward0>)

In [61]:
output_with_caching.attention_output

tensor([[[ 0.0735,  0.2351, -0.2685,  0.0621,  0.4552, -0.2058, -0.1648,
          -0.0123,  0.1202,  0.1113, -0.2080, -0.2728, -0.0924,  0.0647,
          -0.0148, -0.0551]],

        [[ 0.2660,  0.3878, -0.1069,  0.0234,  0.4454, -0.0942, -0.1740,
          -0.1417,  0.1913,  0.1965, -0.2677, -0.1547, -0.0307,  0.0504,
          -0.0111,  0.0865]]], grad_fn=<ViewBackward0>)

In [62]:
# Test attention mask
E = 16
H = 4

texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

z = torch.rand(B,S,E)



multi_headed_attn = BaseMultiHeadedAttention(E,H)
output = multi_headed_attn(z,attention_mask)
assert output.attention_output.shape == (B,S,E)

# Making a unit test, since the test cases are being repeated

In [63]:

class AttentionTestCase:
    """
    Shared smoke tests for the attention modules in this notebook.

    Args:
        Module: attention class under test; instantiated as `Module(**model_kwargs)`.
        B: batch size of the random inputs.
        S: sequence length of the random inputs.
        E: embedding size of the random inputs.
        model_kwargs: keyword arguments forwarded to `Module`.
    """
    def __init__(self,Module,B,S,E,model_kwargs):
        self.Module = Module
        self.B = B
        self.S = S
        self.E = E
        self.model_kwargs = model_kwargs

    def _left_padded_attention_mask(self):
        """
        Build a (B, S) 0/1 attention mask with left padding, without any tokenizer.

        Row i keeps its last max(1, S - i) positions (1 = real token) and marks
        the positions in front of them as padding (0).
        """
        positions = torch.arange(self.S)
        lengths = torch.clamp(self.S - torch.arange(self.B), min=1)
        first_real = self.S - lengths
        return (positions[None, :] >= first_real[:, None]).long()

    def test_forward(self):
        """The output must have the same (B, S, E) shape as the input."""
        z = torch.rand(self.B,self.S,self.E)

        module = self.Module(**self.model_kwargs)
        output = module(z)
        assert output.attention_output.shape == (self.B,self.S,self.E)
        print('Forward test passed!')

    def test_attention_mask(self):
        """
        With a left-padded mask the output shape is preserved and real
        (non-pad) queries put no attention weight on padded keys.
        """
        attention_mask = self._left_padded_attention_mask()
        B,S = attention_mask.shape

        z = torch.rand(B,S,self.E)

        module = self.Module(**self.model_kwargs)
        output = module(z,attention_mask)
        assert output.attention_output.shape == (B,S,self.E)

        # attention_weights is (B, Sq, Sk) or (B, H, Sq, Sk); insert singleton
        # axes between batch and the last two dims so the masks broadcast.
        weights = output.attention_weights
        extra_dims = (1,) * (weights.dim() - 3)
        real_query = (attention_mask == 1).view(B, *extra_dims, S, 1)
        pad_key = (attention_mask == 0).view(B, *extra_dims, 1, S)
        leaked = (weights * (real_query & pad_key)).abs().max().item()
        assert leaked < 1e-6, "Padded keys receive attention from real tokens"

        print('Attention mask test passed!')

    def test_kv_cache(self):
        """The last-token output with a KV cache must match the full forward pass."""
        z = torch.rand(self.B,self.S,self.E)
        module = self.Module(**self.model_kwargs)

        output = module(z,attention_mask=None)
        cache = KVCache(output.key,output.value) # Construct kv cache

        # new token
        _z = torch.rand(self.B,1,self.E)

        # Without KV Caching
        z_new = torch.concat([z,_z],dim=1)
        output_new = module(z_new,attention_mask=None)

        # With KV Caching
        output_with_caching = module(_z,attention_mask=None,kv_cache=cache)

        # Only floating-point noise is tolerated between the two paths.
        max_abs_diff = (output_new.attention_output[:,-1:,:] - output_with_caching.attention_output).abs().max().item()
        assert max_abs_diff < 1e-5, "Attention output does not match: kv-caching"

        print('KV cache test passed!')

    def run(self):
        """Run all three tests in order."""
        self.test_forward()
        self.test_attention_mask()
        self.test_kv_cache()

In [64]:
testing = AttentionTestCase(BaseMultiHeadedAttention,B=2,S=3,E=32,model_kwargs={'embed_dim':32,'num_heads':8})
testing.run()

Forward test passed!
Attention mask test passed!
KV cache test passed!


# Grouped Query Attention

In [65]:
class BasePreTrainedGroupedQueryAttention(BaseMultiHeadedAttention):
    """
    Modifies *trained* Multi-Headed Attention to Grouped Query Attention as done in https://arxiv.org/pdf/2305.13245v3.pdf

    Keys and values are projected with the full multi-head projections, then
    the heads inside each group are mean-pooled and shared by that group.

    Args:
        embed_dim: model width.
        num_heads: number of query heads H.
        num_groups: number of key/value groups G; must divide `num_heads`.
    """
    def __init__(self,embed_dim,num_heads,num_groups):
        assert num_heads%num_groups==0, "num_heads must be divisible by num_groups"
        super().__init__(embed_dim,num_heads)
        self.num_groups = num_groups

    def _group(self,x):
        """
        Grouping with mean pooling as suggested by https://arxiv.org/pdf/2305.13245v3.pdf

        x: B,H,S,D ==split head axis==> B,G,H//G,S,D ==mean over H//G==> B,G,S,D
        Then interleave-repeat each group H//G times to get back to B,H,S,D,
        so consecutive heads share one pooled key/value.
        """
        B,H,S,D = x.shape
        G = self.num_groups

        # Heads g*(H//G) ... (g+1)*(H//G)-1 form group g.
        x = x.reshape(B,G,H//G,S,D).mean(dim=2) # B,G,S,D
        x = torch.repeat_interleave(x,H//G,dim=1) # B,H,S,D
        return x

    def construct_query_key_value(self, x,kv_cache=None):
        """
        Build per-head query and group-pooled key/value, honouring `kv_cache`.

        Only the keys/values of the new tokens are pooled; the cached entries
        are prepended afterwards along the sequence axis (dim 2).
        NOTE(review): assumes the cache holds the key/value returned by an
        earlier call of this module (already pooled, shape B,H,S_past,D).
        """
        # Project and split the NEW tokens only; the cache is merged after pooling.
        query,key,value =  super().construct_query_key_value(x)

        key = self._group(key)
        value = self._group(value)

        if kv_cache is not None:
            key = torch.concat([kv_cache.key, key],dim=2)
            value = torch.concat([kv_cache.value, value],dim=2)

        return query,key,value

In [66]:
testing = AttentionTestCase(BasePreTrainedGroupedQueryAttention,B=2,S=3,E=32
                            ,model_kwargs={'embed_dim':32,'num_groups':2,'num_heads':8})

testing.test_forward()
testing.test_attention_mask()

Forward test passed!
Attention mask test passed!


In [67]:
class BaseGroupedQueryAttention(BaseMultiHeadedAttention):
    """
    Implementation for training from scratch:
    keys and values are projected directly to `num_groups` heads (as done in
    Mistral) instead of mean pooling full multi-head projections.

    Args:
        embed_dim: model width.
        num_heads: number of query heads H.
        num_groups: number of key/value heads G; must divide `num_heads`.
            `num_groups=1` is multi-query attention.
    """
    def __init__(self,embed_dim,num_heads,num_groups):
        # Validate before any projection is built.
        assert num_heads%num_groups==0, "num_heads must be divisible by num_groups"
        # Must be set before super().__init__(), which calls init_qkvo_proj()
        # and that reads self.num_groups.
        self.num_groups = num_groups
        super().__init__(embed_dim,num_heads)

    def init_qkvo_proj(self):
        """Create full-width query/output projections and narrower (G * head_dim) key/value projections."""
        kv_head_embed_dim = self.num_groups * (self.embed_dim//self.num_heads)
        self.query_proj = nn.Linear(self.embed_dim,self.embed_dim)

        self.key_proj = nn.Linear(self.embed_dim,kv_head_embed_dim)
        self.value_proj = nn.Linear(self.embed_dim,kv_head_embed_dim)

        self.output_proj = nn.Linear(self.embed_dim,self.embed_dim)


    def construct_query_key_value(self, x,kv_cache=None):
        """
        Build per-head query and grouped key/value, honouring `kv_cache`.

        Keys/values are split into G heads and each is repeated H//G times so
        they line up with the H query heads; cached entries are then prepended
        along the sequence axis (dim 2).
        NOTE(review): assumes the cache holds the key/value returned by an
        earlier call of this module (already repeated to H heads).
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)

        query = self._split_head(query,self.num_heads)
        key = self._split_head(key,self.num_groups)
        value = self._split_head(value,self.num_groups)

        repeats = self.num_heads//self.num_groups
        key = torch.repeat_interleave(key,repeats,dim=1)
        value = torch.repeat_interleave(value,repeats,dim=1)

        if kv_cache is not None:
            key = torch.concat([kv_cache.key, key],dim=2)
            value = torch.concat([kv_cache.value, value],dim=2)

        return query,key,value

In [68]:
# Same as multi-query
testing = AttentionTestCase(BaseGroupedQueryAttention,B=2,S=3,E=32
                            ,model_kwargs={'embed_dim':32,'num_groups':1,'num_heads':8})

testing.test_attention_mask()
testing.test_forward()
print()
# Same as grouped query
testing = AttentionTestCase(BaseGroupedQueryAttention,B=2,S=3,E=32
                            ,model_kwargs={'embed_dim':32,'num_groups':4,'num_heads':8})

testing.test_attention_mask()
testing.test_forward()

Attention mask test passed!
Forward test passed!

Attention mask test passed!
Forward test passed!
