In [3]:
# %pip install torch numpy scikit-learn pandas matplotlib seaborn

In [4]:
import torch
import math
import torch.nn as nn

SEED = 0
# Seed torch's global RNG so the random inputs and layer initializations below
# are reproducible on "Restart & Run All".
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Self-Attention, Bit by Bit: inspect and visualize it line by line

Let's first set up a small dummy input for testing and visualization: a batch of `B` sequences, each `S` tokens long, with `E`-dimensional embeddings.

In [5]:
# Dummy batch: B sequences of S tokens, each token an E-dimensional embedding.
B, S, E = 2, 8, 16

z = torch.rand(B, S, E)  # uniform in [0, 1), shape [B, S, E]

Now we construct the query, key, and value tensors from the input, each with its own linear projection.

In [6]:
# One E -> E linear projection per role (query, key, value), created in that order.
query_proj, key_proj, value_proj = (nn.Linear(E, E) for _role in ("query", "key", "value"))

In [7]:
# Apply every projection to the same input z -> three tensors of shape [B, S, E].
query, key, value = (proj(z) for proj in (query_proj, key_proj, value_proj))

In [8]:
query.shape, key.shape, value.shape  # each projection keeps the input shape [B, S, E]

(torch.Size([2, 8, 16]), torch.Size([2, 8, 16]), torch.Size([2, 8, 16]))

Calculate the attention logits $QK^\top$, i.e. the dot product of every query with every key.

In [9]:
key_t = key.transpose(-2, -1)  # swap the last two dims: [B,S,E] -> [B,E,S], ready for Q @ K'
key_t.shape

torch.Size([2, 16, 8])

In [10]:
# Batched matrix product: [B,S,E] @ [B,E,S] -> [B,S,S] (one score per query/key pair)
unscaled_attention_logits = torch.matmul(query, key_t)
unscaled_attention_logits.shape

torch.Size([2, 8, 8])

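Scale the attention logits by $1/\sqrt{E}$, as in scaled dot-product attention: $\mathrm{softmax}\left(QK^\top/\sqrt{E}\right)V$. Without this scaling, the dot products grow with the embedding size and push the softmax towards saturation.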

In [11]:
B, S, E = query.shape
scale = math.sqrt(E)  # standard 1/sqrt(E) factor of scaled dot-product attention
attention_logits = unscaled_attention_logits.div(scale)

In [12]:
attention_logits.shape  # scaling does not change the shape: still [B, S, S]

torch.Size([2, 8, 8])

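Causal masking: position $i$ may only attend to positions $j \le i$. Every logit above the diagonal is therefore replaced with $-\infty$ before the softmax, which turns it into a weight of exactly 0.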

In [13]:
mask_value = -torch.inf  # exp(-inf) == 0, so masked positions get zero weight after the softmax

# Boolean lower-triangular [S, S] mask: True where query i may attend key j (j <= i).
# It broadcasts over the batch dimension, so no per-batch copy is needed.
causal_mask = torch.ones(S, S, dtype=torch.bool).tril()

# Out-of-place masked_fill: `attention_logits` keeps its unmasked values, instead of being
# silently overwritten as masked_fill_ would do (a hidden-state trap when re-inspecting cells).
masked_attention_logits = attention_logits.masked_fill(~causal_mask, mask_value)

In [14]:
masked_attention_logits[0]  # first sequence of the batch: everything above the diagonal is now -inf

tensor([[ 0.1272,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0857,  0.0085,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0985,  0.0315,  0.1365,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2192,  0.1274,  0.1895, -0.0568,    -inf,    -inf,    -inf,    -inf],
        [ 0.2630,  0.2194,  0.3101,  0.0008,  0.0192,    -inf,    -inf,    -inf],
        [ 0.1405,  0.0805,  0.1534, -0.0623, -0.0633,  0.0253,    -inf,    -inf],
        [ 0.0920,  0.0191,  0.1043, -0.1186, -0.1003, -0.0371,  0.0185,    -inf],
        [ 0.1840,  0.1545,  0.2070, -0.0107,  0.0201,  0.0729,  0.1638,  0.1078]],
       grad_fn=<SelectBackward0>)

In [15]:
causal_attention_weights = masked_attention_logits.softmax(dim=-1)  # normalize over the key axis

In [16]:
causal_attention_weights # Observe lower triangular! Causal attention weights: each row sums to 1, entries above the diagonal are exactly 0

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5193, 0.4807, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3363, 0.3145, 0.3493, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2746, 0.2505, 0.2666, 0.2084, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2194, 0.2100, 0.2299, 0.1688, 0.1719, 0.0000, 0.0000, 0.0000],
         [0.1826, 0.1719, 0.1849, 0.1490, 0.1489, 0.1627, 0.0000, 0.0000],
         [0.1566, 0.1456, 0.1586, 0.1269, 0.1292, 0.1376, 0.1455, 0.0000],
         [0.1339, 0.1300, 0.1370, 0.1102, 0.1137, 0.1198, 0.1312, 0.1241]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4531, 0.5469, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2794, 0.3415, 0.3791, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2106, 0.2464, 0.2905, 0.2524, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1897, 0.2116, 0.2106, 0.1865, 0.2015, 0.0000, 0.0000, 0.0000],
         [0.1405, 0.172

In [17]:
attention_output = torch.matmul(causal_attention_weights, value)  # [B,S,S] @ [B,S,E] -> [B,S,E]

# What about padding tokens?
* In the base implementation above, every token attends to every earlier token, including padding tokens if present!
* Now we extend the base implementation so that padding tokens are ignored as keys: no token attends to them.

## Quick look at a sample tokenizer

In [18]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# Reuse EOS as the pad token; presumably this tokenizer ships without a pad token (TODO confirm).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # Important! Decoder-only models usually require left padding so the last position is a real token
texts = ["Something random","A bit longer text","something even longer than the two before!"]
# padding=True pads the batch to a common length and also returns an attention_mask (see below)
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
encoded_input.keys()  # token ids plus a 0/1 attention mask (1 = real token, 0 = padding)

dict_keys(['input_ids', 'attention_mask'])

In [20]:
tokenizer.pad_token_id # pad token id for reference; equals the EOS id since pad_token was set to eos_token

2

In [21]:
encoded_input['input_ids'][0] # Observe left padding! 2 is the pad token id; the 1 right after it is presumably BOS (TODO confirm)
# Refer: https://huggingface.co/docs/transformers/llm_tutorial#wrong-padding-side

tensor([    2,     2,     2,     2,     2,     2,     1, 13264,  5509])

In [22]:
encoded_input['attention_mask'][0] # Observe attention mask is zero for pad tokens (and one for real tokens)!

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1])

In [23]:
attention_mask = encoded_input['attention_mask']  # [B, S] of 0/1, 0 marks padding
B, S = attention_mask.size()

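## Handle the attention (padding) mask in the attention logits
For comparison, see how the Hugging Face GPT-2 implementation handles it: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L800 (the line anchor points at `main` and may have drifted since this was written).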

In [24]:
# Dummy logits for the demo, then hide every *key* column that is padding.
# unsqueeze(1): [B, S] -> [B, 1, S], so the same key mask is broadcast over all query rows.
attention_logits = torch.rand(B, S, S).masked_fill(attention_mask.unsqueeze(1) == 0, -torch.inf)
attention_logits

tensor([[[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.1245, 0.6143,
          0.4404],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.7113, 0.8817,
          0.9789],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.0140, 0.7529,
          0.9597],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.2815, 0.8038,
          0.3988],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.0323, 0.8905,
          0.1319],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.2209, 0.4885,
          0.5274],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.0901, 0.3648,
          0.5164],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.3190, 0.0750,
          0.5282],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.5381, 0.0880,
          0.0667]],

        [[  -inf,   -inf,   -inf,   -inf, 0.9039, 0.7427, 0.1361, 0.9384,
          0.6286],
         [  -inf,   -inf,   -inf,   -inf, 0.1962, 0.6371, 0.3871, 0.

# Putting the code bits together

In [25]:
class BaseOutput:
    """Lightweight attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Plain class without properties/slots, so this equals one setattr() per item.
        self.__dict__.update(kwargs)

In [26]:
class BaseSelfAttention(nn.Module):
    """
    Minimal single-head causal self-attention (base version).

    ``forward`` projects the input into Q, K, V, computes scaled dot-product
    logits, applies the causal mask and the optional padding mask, takes a
    softmax over the key axis, mixes the values and applies an output projection.

    Masked logits are filled with the most negative *finite* value of the
    logits' dtype instead of ``-inf``. With left padding, the query rows of pad
    tokens have every key masked (causal mask + padding mask). With ``-inf``
    their softmax is 0/0 = NaN, which then leaks into ``attention_output``.
    With a finite fill value those rows become a uniform average (meaningless
    but finite -- pad outputs should be ignored). Rows with at least one
    unmasked key are in practice unchanged, because ``exp(finfo.min - max)``
    underflows to exactly 0.

    Not Implemented:
    * KV caching -- ``apply_causal_mask`` assumes square logits (query length
      == key length) and raises a RuntimeError for a shorter query.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the query, key, value and output projections (embed_dim -> embed_dim each)."""
        self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def construct_query_key_value(self, x):
        """Project ``x`` (assumed ``[B, S, E]``) into ``(query, key, value)``, each ``[B, S, E]``."""
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)
        return query, key, value

    def calculate_unmasked_attention_logits(self, query, key):
        """Return ``query @ key^T / sqrt(embed_dim)``: ``[B, S, E] @ [B, E, S] -> [B, S, S]``."""
        key_t = key.transpose(1, 2)  # [B, S, E] -> [B, E, S]
        scale = math.sqrt(self.embed_dim)
        return (query @ key_t) / scale

    def apply_causal_mask(self, attention_logits, mask_value=None):
        """
        Mask future positions (upper triangle) in place.

        Args:
            attention_logits: ``[B, S, S]`` logits; modified in place.
            mask_value: fill value for masked entries. Defaults to the most
                negative finite value of ``attention_logits.dtype``.

        Returns:
            ``attention_logits`` after masking (the same tensor object).

        Raises:
            RuntimeError: if the logits are not square in the last two dims
                (e.g. a length-1 query against cached keys), because the
                ``[B, S, S]`` mask cannot be broadcast into the in-place target.
        """
        B, _, S = attention_logits.shape
        # Lower-triangular boolean mask (True = may attend), created directly on the logits' device.
        causal_mask = torch.ones(B, S, S, dtype=torch.bool, device=attention_logits.device).tril()

        if mask_value is None:
            # Finite instead of -inf: fully masked rows would otherwise softmax to NaN.
            mask_value = torch.finfo(attention_logits.dtype).min

        # Replace the upper triangle, causal masking!
        return attention_logits.masked_fill_(~causal_mask, mask_value)

    def apply_attention_mask(self, attention_mask, masked_attention_logits):
        """
        Mask padding keys in place.

        Args:
            attention_mask: ``[B, Sk]`` with 0 for padding positions (as produced by the
                tokenizer) and non-zero for real tokens.
            masked_attention_logits: ``[B, Sq, Sk]`` logits; modified in place.

        Returns:
            ``masked_attention_logits`` after masking (the same tensor object).
        """
        # [B, Sk] -> [B, 1, Sk]: the same key mask applies to every query row.
        # The tokenizer returns CPU tensors, so move the mask to the logits' device.
        key_is_pad = (attention_mask == 0).to(masked_attention_logits.device)[:, None, :]
        mask_value = torch.finfo(masked_attention_logits.dtype).min
        return masked_attention_logits.masked_fill_(key_is_pad, mask_value)

    def calculate_attention_weights(self, masked_attention_logits):
        """Softmax over the key axis (last dim), so every query row sums to 1."""
        return torch.softmax(masked_attention_logits, dim=-1)

    def calculate_final_output(self, attention_weights, value):
        """Weighted sum of values (``[B, S, S] @ [B, S, E]``) followed by the output projection."""
        attention_output = attention_weights @ value
        return self.output_proj(attention_output)

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: ``[B, S, E]`` input.
            attention_mask: optional ``[B, S]`` padding mask (0 = padding).

        Returns:
            ``BaseOutput`` with ``attention_output`` ``[B, S, E]``,
            ``attention_weights`` ``[B, S, S]`` and the projected ``key``/``value``.
            Outputs at padding query positions are finite but meaningless.
        """
        # Construct Q, K, V
        query, key, value = self.construct_query_key_value(x)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query, key)

        # Apply causal masking (in place)
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply padding mask (in place)
            masked_attention_logits = self.apply_attention_mask(attention_mask, masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, calculate the final output
        attention_output = self.calculate_final_output(attention_weights, value)

        return BaseOutput(attention_output=attention_output,
                          attention_weights=attention_weights,
                          key=key,
                          value=value)


In [27]:
E = 16

texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

z = torch.rand(B,S,E)

self_atten = BaseSelfAttention(E)
output = self_atten(z,attention_mask)

# Shape checks
assert output.attention_output.shape == (B,S,E)
assert output.attention_weights.shape == (B,S,S)

# Invariants for real (non-pad) query positions. Pad query rows are excluded on purpose:
# with left padding, their causal window may contain nothing but pad keys.
is_real = attention_mask.bool()   # [B, S]
real_query = is_real[:, :, None]  # [B, S, 1]
pad_key = ~is_real[:, None, :]    # [B, 1, S]
assert torch.isfinite(output.attention_output[is_real]).all()
row_sums = output.attention_weights[is_real].sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums))         # proper distributions
assert output.attention_weights[real_query & pad_key].eq(0).all()  # padding keys get no weight
assert output.attention_output.shape == (B,S,E)

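# KV caching
During generation, the keys and values of already-processed tokens do not change. We can therefore cache them and only project the newest token.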

In [28]:
# Contract for the KV-cached forward pass (B = batch, E = embedding dim):
# only the newest token is fed in; keys/values of earlier tokens come from the cache.
"""\

With KV caching:
Input: (B,1,E), kv_cache
Output: (B,1,E)

"""

'\nWith KV caching:\nInput: (B,1,E), kv_cache\nOutput: (B,1,E)\n\n'

In [29]:
# Dummy input for testing: S already-processed tokens plus one new token per sequence
B, S, E = 2, 4, 16

z_old = torch.rand(B, S, E)           # tokens seen so far
z = torch.rand(B, 1, E)               # the next token
z_new = torch.cat((z_old, z), dim=1)  # full sequence, used as the uncached reference

self_atten = BaseSelfAttention(E)
output_old = self_atten(z_old, attention_mask=None)
output_new = self_atten(z_new, attention_mask=None)

print(*(t.shape for t in (z, z_old, z_new)), sep="\n")

torch.Size([2, 1, 16])
torch.Size([2, 4, 16])
torch.Size([2, 5, 16])


In [30]:
class KVCache:
    """Keys/values of already-processed tokens; any extra keyword arguments become attributes too."""

    def __init__(self, key, value, **kwargs):
        # key and value first, then the extras, exactly like sequential setattr() calls.
        self.__dict__.update(key=key, value=value, **kwargs)

In [31]:
# Keep the keys/values already computed for the past tokens, so they need not be recomputed
cache = KVCache(key=output_old.key, value=output_old.value)

In [32]:
# Only the new token is projected; keys/values of earlier tokens come from the cache.
query = self_atten.query_proj(z)      # [B, 1, E]
new_key = self_atten.key_proj(z)      # [B, 1, E]
new_value = self_atten.value_proj(z)  # [B, 1, E]

In [33]:
# Append the new key/value after the cached ones along the sequence axis -> [B, S+1, E]
key = torch.cat((cache.key, new_key), dim=1)
value = torch.cat((cache.value, new_value), dim=1)

In [34]:
query.shape  # a single query position: only the new token

torch.Size([2, 1, 16])

In [35]:
key.shape  # 4 cached keys + 1 new key

torch.Size([2, 5, 16])

In [36]:
# The remaining operations should stay the same -- or do they? This cell is *expected* to fail:
# the base class builds its causal mask assuming query length == key length, but here the
# query has 1 position and the keys have 5. The error is caught and printed (instead of left
# uncaught) so that "Restart & Run All" does not stop here.
try:
    # Calculate logits
    unmasked_attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

    # Apply causal masking  <- raises RuntimeError (mask shape mismatch)
    masked_attention_logits = self_atten.apply_causal_mask(unmasked_attention_logits)

    # Calculate attention weights
    attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)

    # And finally, Calculate the final output
    attention_output = self_atten.calculate_final_output(attention_weights,value)

    output_with_caching = BaseOutput(attention_output=attention_output,
                        attention_weights=attention_weights,
                        key=key,
                        value=value)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: output with shape [2, 1, 5] doesn't match the broadcast shape [2, 5, 5]

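Oh no, the causal mask! `apply_causal_mask` assumes square logits and builds a `[B, 5, 5]` mask. With a KV cache, however, the query has length 1 while the keys have length 5, so the logits are `[B, 1, 5]` and the mask no longer fits.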

In [37]:
# let's replicate apply_causal_mask step by step, now allowing query length != key length

attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

B,Sq,Sk = attention_logits.shape # <---------Change 1: query (Sq) and key (Sk) lengths may differ

# Local name, so the notebook-level `device` config string is not overwritten.
logits_device = attention_logits.device
causal_mask = torch.tril(torch.ones(B,Sk,Sk)).to(logits_device) # <------------Change 2: mask over all Sk positions

print(causal_mask.shape)
causal_mask = causal_mask[:,-Sq:,-Sk:] # <------------Change 3: keep only the rows of the Sq newest queries
print(causal_mask.shape)

# Fill value for masked (future) positions, set explicitly here rather than relying on a
# `mask_value` variable left over from an earlier cell.
mask_value = -torch.inf

# replace upper triangular with -inf or a very large negative number, causal masking!
masked_attention_logits = attention_logits.masked_fill_(causal_mask == 0, mask_value)

torch.Size([2, 5, 5])
torch.Size([2, 1, 5])


In [38]:
# Continue from the manually masked logits: the rest of the pipeline is unchanged.

attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)  # softmax over keys
attention_output = self_atten.calculate_final_output(attention_weights, value)      # mix values + project

output_with_caching = BaseOutput(
    attention_output=attention_output,
    attention_weights=attention_weights,
    key=key,
    value=value,
)

In [39]:
output_new.attention_output[:,-1:,:]  # reference: last position of the uncached pass; -1: keeps the sequence dim

tensor([[[ 0.0684, -0.0154, -0.0904, -0.2118,  0.2850,  0.2744,  0.3194,
           0.4253,  0.1098,  0.1277, -0.3167,  0.3301, -0.3709, -0.0795,
           0.1013,  0.1084]],

        [[ 0.1241, -0.0053, -0.1390, -0.1733,  0.2012,  0.3066,  0.2656,
           0.5026,  0.1187,  0.1604, -0.3240,  0.3251, -0.3033, -0.1132,
           0.0834,  0.1254]]], grad_fn=<SliceBackward0>)

In [40]:
# The cached output must equal the last position of the uncached pass (up to float round-off).
torch.testing.assert_close(output_with_caching.attention_output, output_new.attention_output[:, -1:, :])
output_with_caching.attention_output

tensor([[[ 0.0684, -0.0154, -0.0904, -0.2118,  0.2850,  0.2744,  0.3194,
           0.4253,  0.1098,  0.1277, -0.3167,  0.3301, -0.3709, -0.0795,
           0.1013,  0.1084]],

        [[ 0.1241, -0.0053, -0.1390, -0.1733,  0.2012,  0.3066,  0.2656,
           0.5026,  0.1187,  0.1604, -0.3240,  0.3251, -0.3033, -0.1132,
           0.0834,  0.1254]]], grad_fn=<ViewBackward0>)

Yay, they match!

# Merge it with the baseline

In [41]:
class BaseSelfAttention(nn.Module):
    """
    Minimal single-head causal self-attention with optional padding mask and KV cache.

    ``forward`` projects the input into Q, K, V (prepending cached K/V if given),
    computes scaled dot-product logits, applies the causal mask and the optional
    padding mask, takes a softmax over the key axis, mixes the values and
    applies an output projection.

    Masked logits are filled with the most negative *finite* value of the
    logits' dtype instead of ``-inf``. With left padding, the query rows of pad
    tokens have every key masked (causal mask + padding mask). With ``-inf``
    their softmax is 0/0 = NaN, which then leaks into ``attention_output``.
    With a finite fill value those rows become a uniform average (meaningless
    but finite -- pad outputs should be ignored). Rows with at least one
    unmasked key are in practice unchanged, because ``exp(finfo.min - max)``
    underflows to exactly 0.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Create the query, key, value and output projections (embed_dim -> embed_dim each)."""
        self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def construct_query_key_value(self, x, kv_cache=None):
        """
        Project ``x`` into query, key and value.

        Args:
            x: input of shape ``[B, Sq, E]`` (``Sq`` = number of new tokens).
            kv_cache: optional object with ``.key`` / ``.value`` of shape
                ``[B, S_past, E]``; they are prepended along the sequence dim.

        Returns:
            ``(query, key, value)`` of shapes ``[B, Sq, E]``, ``[B, Sk, E]``,
            ``[B, Sk, E]`` where ``Sk = S_past + Sq`` (``Sk == Sq`` without a cache).
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)
        if kv_cache is not None:
            key = torch.cat([kv_cache.key, key], dim=1)
            value = torch.cat([kv_cache.value, value], dim=1)
        return query, key, value

    def calculate_unmasked_attention_logits(self, query, key):
        """Return ``query @ key^T / sqrt(embed_dim)``: ``[B, Sq, E] @ [B, E, Sk] -> [B, Sq, Sk]``."""
        key_t = key.transpose(1, 2)  # [B, Sk, E] -> [B, E, Sk]
        scale = math.sqrt(self.embed_dim)
        return (query @ key_t) / scale

    def apply_causal_mask(self, attention_logits, mask_value=None):
        """
        Mask future keys in place.

        The ``Sq`` queries are the *last* ``Sq`` of the ``Sk`` positions (true both
        without a cache, where ``Sq == Sk``, and with a KV cache). Query ``i`` may
        therefore attend key ``j`` iff ``j <= i + (Sk - Sq)``. Assumes ``Sq <= Sk``.

        Args:
            attention_logits: ``[B, Sq, Sk]`` logits; modified in place.
            mask_value: fill value for masked entries. Defaults to the most
                negative finite value of ``attention_logits.dtype``.

        Returns:
            ``attention_logits`` after masking (the same tensor object).
        """
        _, Sq, Sk = attention_logits.shape
        # Same as the last Sq rows of a lower-triangular [Sk, Sk] mask, but built directly
        # as [Sq, Sk] (bool, on the logits' device); it broadcasts over the batch dim.
        allowed = torch.ones(Sq, Sk, dtype=torch.bool, device=attention_logits.device).tril(diagonal=Sk - Sq)

        if mask_value is None:
            # Finite instead of -inf: fully masked rows would otherwise softmax to NaN.
            mask_value = torch.finfo(attention_logits.dtype).min

        return attention_logits.masked_fill_(~allowed, mask_value)

    def apply_attention_mask(self, attention_mask, masked_attention_logits):
        """
        Mask padding keys in place.

        Args:
            attention_mask: ``[B, Sk]`` with 0 for padding positions (as produced by the
                tokenizer) and non-zero for real tokens. With a KV cache it must cover
                the cached *and* the new positions.
            masked_attention_logits: ``[B, Sq, Sk]`` logits; modified in place.

        Returns:
            ``masked_attention_logits`` after masking (the same tensor object).
        """
        # [B, Sk] -> [B, 1, Sk]: the same key mask applies to every query row.
        # The tokenizer returns CPU tensors, so move the mask to the logits' device.
        key_is_pad = (attention_mask == 0).to(masked_attention_logits.device)[:, None, :]
        mask_value = torch.finfo(masked_attention_logits.dtype).min
        return masked_attention_logits.masked_fill_(key_is_pad, mask_value)

    def calculate_attention_weights(self, masked_attention_logits):
        """Softmax over the key axis (last dim), so every query row sums to 1."""
        return torch.softmax(masked_attention_logits, dim=-1)

    def calculate_final_output(self, attention_weights, value):
        """Weighted sum of values (``[B, Sq, Sk] @ [B, Sk, E]``) followed by the output projection."""
        attention_output = attention_weights @ value
        return self.output_proj(attention_output)

    def forward(self, x, attention_mask=None, kv_cache=None):
        """
        Args:
            x: ``[B, Sq, E]`` input for the new tokens.
            attention_mask: optional ``[B, Sk]`` padding mask (0 = padding) over the key axis.
            kv_cache: optional ``KVCache``-like object holding past ``.key`` / ``.value``.

        Returns:
            ``BaseOutput`` with ``attention_output`` ``[B, Sq, E]``, ``attention_weights``
            ``[B, Sq, Sk]`` and the full ``key`` / ``value`` ``[B, Sk, E]`` (cache included),
            which can be wrapped into the next ``KVCache``.
        """
        # Construct Q, K, V (K, V include the cached positions, if any)
        query, key, value = self.construct_query_key_value(x, kv_cache=kv_cache)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query, key)

        # Apply causal masking (in place)
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply padding mask (in place)
            masked_attention_logits = self.apply_attention_mask(attention_mask, masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, calculate the final output
        attention_output = self.calculate_final_output(attention_weights, value)

        return BaseOutput(attention_output=attention_output,
                          attention_weights=attention_weights,
                          key=key,
                          value=value)
    


class KVCache:
    """Keys/values of already-processed tokens; any extra keyword arguments become attributes too."""

    def __init__(self, key, value, **kwargs):
        # key and value first, then the extras, exactly like sequential setattr() calls.
        self.__dict__.update(key=key, value=value, **kwargs)


class BaseOutput:
    """Lightweight attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Plain class without properties/slots, so this equals one setattr() per item.
        self.__dict__.update(kwargs)


In [42]:
E = 16

texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

z = torch.rand(B,S,E)

self_atten = BaseSelfAttention(E)
output = self_atten(z,attention_mask)

# Shape checks
assert output.attention_output.shape == (B,S,E)
assert output.attention_weights.shape == (B,S,S)

# Invariants for real (non-pad) query positions. Pad query rows are excluded on purpose:
# with left padding, their causal window may contain nothing but pad keys.
is_real = attention_mask.bool()   # [B, S]
real_query = is_real[:, :, None]  # [B, S, 1]
pad_key = ~is_real[:, None, :]    # [B, 1, S]
assert torch.isfinite(output.attention_output[is_real]).all()
row_sums = output.attention_weights[is_real].sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums))         # proper distributions
assert output.attention_weights[real_query & pad_key].eq(0).all()  # padding keys get no weight

In [43]:
# Dummy input for kv-cache testing
B, S, E = 2, 4, 16

z = torch.rand(B, S, E)
self_atten = BaseSelfAttention(E)

# Prefill: run the S known tokens once and keep their keys/values
output = self_atten(z, attention_mask=None)
cache = KVCache(key=output.key, value=output.value)

# new token
_z = torch.rand(B, 1, E)

# Reference without KV caching: recompute attention over all S+1 tokens
z_new = torch.cat((z, _z), dim=1)
output_new = self_atten(z_new, attention_mask=None)

# With KV caching: feed only the new token together with the cache
output_with_caching = self_atten(_z, attention_mask=None, kv_cache=cache)

In [44]:
output_new.attention_output[:,-1,:]  # reference: last position of the uncached pass (indexing with -1 drops the sequence dim)

tensor([[-0.0030,  0.2412,  0.3421,  0.2261,  0.1717, -0.2394,  0.0092,  0.1346,
          0.4956, -0.2014, -0.0524, -0.2519, -0.1964, -0.0980,  0.4925, -0.0917],
        [ 0.0101,  0.2290,  0.3033,  0.2238,  0.1581, -0.2247,  0.0239,  0.1448,
          0.4721, -0.1742, -0.0281, -0.2448, -0.1844, -0.0562,  0.4522, -0.0282]],
       grad_fn=<SliceBackward0>)

In [45]:
# The cached output must equal the last position of the uncached pass (up to float round-off).
torch.testing.assert_close(output_with_caching.attention_output, output_new.attention_output[:, -1:, :])
output_with_caching.attention_output

tensor([[[-0.0030,  0.2412,  0.3421,  0.2261,  0.1717, -0.2394,  0.0092,
           0.1346,  0.4956, -0.2014, -0.0524, -0.2519, -0.1964, -0.0980,
           0.4925, -0.0917]],

        [[ 0.0101,  0.2290,  0.3033,  0.2238,  0.1581, -0.2247,  0.0239,
           0.1448,  0.4721, -0.1742, -0.0281, -0.2448, -0.1844, -0.0562,
           0.4522, -0.0282]]], grad_fn=<ViewBackward0>)

In [46]:
output_new.attention_output[:,-1:,:] - output_with_caching.attention_output  # -1: keeps the [B, 1, E] shape of the cached output
# Just to verify: only tiny floating-point round-off differences are expected

tensor([[[ 1.4901e-08,  1.4901e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          -1.4901e-08,  1.1176e-08,  2.9802e-08,  2.9802e-08, -2.9802e-08,
          -7.4506e-09,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]],

        [[-7.4506e-09,  0.0000e+00,  0.0000e+00, -1.4901e-08,  0.0000e+00,
           0.0000e+00, -3.7253e-09,  0.0000e+00,  2.9802e-08, -2.9802e-08,
           0.0000e+00,  1.4901e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00]]], grad_fn=<SubBackward0>)