In [1]:
# %pip install torch numpy scikit-learn pandas matplotlib seaborn

In [2]:
import torch
import math
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# from imports import *

# Self-Attention: bits and pieces to inspect and visualize line by line

Let's first set up a small dummy input for testing and visualisation.

In [3]:
# Toy dimensions: batch size, sequence length, embedding size
B, S, E = 2, 8, 16

# Random stand-in for a batch of embedded sequences, shape (B, S, E)
z = torch.rand((B, S, E))

Now we construct query, key and value from the input

In [4]:
# One learnable E -> E linear map per role (created in query, key, value order)
query_proj = nn.Linear(in_features=E, out_features=E)
key_proj = nn.Linear(in_features=E, out_features=E)
value_proj = nn.Linear(in_features=E, out_features=E)

In [5]:
# Every projection maps the same input z, so all three results keep z's shape
query, key, value = query_proj(z), key_proj(z), value_proj(z)

In [6]:
query.shape, key.shape, value.shape

(torch.Size([2, 8, 16]), torch.Size([2, 8, 16]), torch.Size([2, 8, 16]))

Calculate attention logits

In [7]:
# Swap the last two axes of the key ([B,S,E] -> [B,E,S]) so that Q @ K' is defined
key_t = torch.transpose(key, 1, 2)
key_t.shape

torch.Size([2, 16, 8])

In [8]:
# Q.K' : [B,S,E] @ [B,E,S] => [B,S,S], one score per (query position, key position) pair
unscaled_attention_logits = torch.matmul(query, key_t)
unscaled_attention_logits.shape

torch.Size([2, 8, 8])

Scale our attention logits

In [9]:
# Read the dimensions back from the projected query and divide the logits by sqrt(E)
B, S, E = query.size()
scale = math.sqrt(E)
attention_logits = torch.div(unscaled_attention_logits, scale)

In [10]:
attention_logits.shape

torch.Size([2, 8, 8])

Causal masking

In [11]:
mask_value = -torch.inf
# Lower triangular matrix with 1s, created on the same device as the logits:
# row i keeps columns 0..i, i.e. a position may only look at itself and the past.
causal_mask = torch.tril(torch.ones(B, S, S, device=attention_logits.device))
# Replace the upper triangle with -inf (or a very large negative number).
# Out-of-place `masked_fill` is used on purpose: the in-place `masked_fill_` would
# also overwrite `attention_logits`, so the "unmasked" variable inspected in the
# previous cell would silently become masked.
masked_attention_logits = attention_logits.masked_fill(causal_mask == 0, mask_value)

In [12]:
masked_attention_logits[0]

tensor([[0.0864,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1881, 0.1825,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1470, 0.0837, 0.0767,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.0975, 0.0718, 0.0193, 0.0382,   -inf,   -inf,   -inf,   -inf],
        [0.1623, 0.1190, 0.0545, 0.0852, 0.1564,   -inf,   -inf,   -inf],
        [0.1410, 0.1174, 0.1031, 0.0749, 0.1418, 0.0471,   -inf,   -inf],
        [0.2412, 0.2278, 0.1846, 0.1734, 0.1910, 0.1089, 0.1918,   -inf],
        [0.1133, 0.0457, 0.0184, 0.0343, 0.0964, 0.0371, 0.0608, 0.0761]],
       grad_fn=<SelectBackward0>)

In [13]:
# Normalise every query row over the key axis; -inf entries come out as weight 0
causal_attention_weights = masked_attention_logits.softmax(dim=-1)

In [14]:
causal_attention_weights # Observe lower triangular! Causal attention weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5014, 0.4986, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3483, 0.3270, 0.3247, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2603, 0.2537, 0.2407, 0.2453, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2094, 0.2005, 0.1880, 0.1939, 0.2082, 0.0000, 0.0000, 0.0000],
         [0.1728, 0.1688, 0.1664, 0.1618, 0.1729, 0.1573, 0.0000, 0.0000],
         [0.1505, 0.1485, 0.1422, 0.1406, 0.1431, 0.1318, 0.1432, 0.0000],
         [0.1318, 0.1231, 0.1198, 0.1217, 0.1295, 0.1221, 0.1250, 0.1269]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5246, 0.4754, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3351, 0.3284, 0.3365, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2513, 0.2301, 0.2582, 0.2604, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1985, 0.1871, 0.2092, 0.1945, 0.2107, 0.0000, 0.0000, 0.0000],
         [0.1678, 0.159

In [15]:
# Weighted sum of the value vectors: [B,S,S] @ [B,S,E] => [B,S,E]
attention_output = torch.matmul(causal_attention_weights, value)

# What about padding tokens?
* In the base implementation we applied self-attention to every token, including padding tokens if present!
* Now we extend the base implementation so that tokens such as the pad token are ignored by attention.

## Quick look at a sample tokenizer

In [16]:
from transformers import AutoTokenizer
# Load the tokenizer published under the "mistralai/Mistral-7B-v0.1" identifier
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# Reuse the end-of-sequence token as the pad token (its id is inspected in a later cell)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # Important! Decoder-only models usually require left padding
# Three texts of different lengths, so the shorter ones have to be padded
texts = ["Something random","A bit longer text","something even longer than the two before!"]
# Returns PyTorch tensors; the cells below read 'input_ids' and 'attention_mask' from the result
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
encoded_input.keys()

dict_keys(['input_ids', 'attention_mask'])

In [18]:
tokenizer.pad_token_id # pad token id for reference

2

In [19]:
encoded_input['input_ids'][0] # Observe left padding!, 2 is the pad token
# Refer: https://huggingface.co/docs/transformers/llm_tutorial#wrong-padding-side

tensor([    2,     2,     2,     2,     2,     2,     1, 13264,  5509])

In [20]:
encoded_input['attention_mask'][0] # Observe attention mask is zero for pad tokens!

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1])

In [21]:
# Padding mask produced by the tokenizer (0 marks pad tokens); its two dimensions become B and S
attention_mask = encoded_input["attention_mask"]
B, S = attention_mask.size()

## Handle attention mask in the attention logits
Also look at https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L800

In [22]:
# Dummy logits for the demo. The [B,S] padding mask gets a middle axis ([B,1,S]) so it
# broadcasts over the query axis: every column that belongs to a pad token becomes -inf.
attention_logits = torch.rand(B, S, S).masked_fill(attention_mask[:, None, :].eq(0), -torch.inf)
attention_logits

tensor([[[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.8878, 0.6595,
          0.8409],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.2274, 0.0165,
          0.2792],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.7158, 0.7744,
          0.5028],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.6947, 0.7066,
          0.5426],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.6911, 0.4306,
          0.5193],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.2244, 0.3229,
          0.6242],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.8498, 0.7114,
          0.7502],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.3790, 0.9751,
          0.5675],
         [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 0.6106, 0.8722,
          0.1585]],

        [[  -inf,   -inf,   -inf,   -inf, 0.1535, 0.9526, 0.8391, 0.5240,
          0.9733],
         [  -inf,   -inf,   -inf,   -inf, 0.5833, 0.0900, 0.5204, 0.

# Putting the code bits together

In [23]:
class BaseOutput:
    """Plain result container: every keyword argument becomes an attribute of the same name."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [24]:
class BaseSelfAttention(nn.Module):
    """
    Minimal Base Self Attention (single head, causal).

    ``forward`` projects the input to query/key/value, computes scaled
    dot-product logits, applies the causal mask (plus an optional padding
    mask), normalises with softmax and projects the weighted values.

    Not Implemented:
    * KV cacheing (the causal mask only supports square [B,S,S] logits)
    """
    def __init__(self, embed_dim):
        """Create the Q/K/V/output projections for inputs whose last dimension is ``embed_dim``."""
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Register the four ``embed_dim -> embed_dim`` linear projections."""
        self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def construct_query_key_value(self, x):
        """Project ``x`` into query, key and value; each result has the shape of ``x``."""
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)
        return query, key, value

    def calculate_unmasked_attention_logits(self, query, key):
        """
        Return the scaled dot-product logits ``Q @ K' / sqrt(embed_dim)``.

        Q.K' : [B,S,E] @ [B,E,S] => [B,S,S]
        """
        key_t = key.transpose(1, 2)  # Transpose to [B,E,S] by exchanging dim 1 and 2

        # scaling factor
        scale = math.sqrt(self.embed_dim)

        return (query @ key_t) / scale

    def apply_causal_mask(self, attention_logits, mask_value=None):
        """
        Hide future positions: entry [..., i, j] with j > i is replaced by ``mask_value``.

        Args:
            attention_logits: square logits, e.g. [B,S,S].
            mask_value: fill value for the masked entries; defaults to ``-inf``.

        Returns:
            A new tensor; ``attention_logits`` itself is not modified.

        Raises:
            RuntimeError: if the last two dimensions of the logits do not match
                the square mask (e.g. a single query against several keys).
        """
        seq_len = attention_logits.shape[-1]

        # lower triangular matrix with 1s, built directly on the logits' device;
        # a single [S,S] matrix is enough because it broadcasts over the batch
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=attention_logits.device))

        if mask_value is None:
            mask_value = -torch.inf

        # replace upper triangular with -inf or a very large negative number, causal masking!
        # The fill runs in place on a *copy*: the caller's tensor stays untouched, and
        # because an in-place fill cannot broadcast its target, non-square logits fail
        # loudly instead of being silently expanded to [B,S,S].
        masked_attention_logits = attention_logits.clone()
        masked_attention_logits.masked_fill_(causal_mask == 0, mask_value)
        return masked_attention_logits

    def apply_attention_mask(self, attention_mask, masked_attention_logits):
        """
        Hide the key positions whose ``attention_mask`` entry is 0 (e.g. pad tokens).

        Args:
            attention_mask: [B,S] tensor, 0 for positions to ignore.
            masked_attention_logits: [B,S,S] logits (usually already causally masked).

        Returns:
            A new tensor with the ignored key columns set to ``-inf``.
        """
        # [B,S] -> [B,1,S] so that the mask broadcasts over the query axis
        ignored_keys = attention_mask[:, None, :].to(masked_attention_logits.device) == 0
        return masked_attention_logits.masked_fill(ignored_keys, -torch.inf)

    def calculate_attention_weights(self, masked_attention_logits):
        """
        Softmax over the key axis.

        A query row in which *every* key is ``-inf`` (this happens for left-padded
        positions: the causal mask hides the future and the padding mask hides the
        past) would make softmax return NaN. Such rows get all-zero weights instead.
        """
        fully_masked = torch.isneginf(masked_attention_logits).all(dim=-1, keepdim=True)
        # Give those rows finite logits so that softmax (and its gradient) stay finite ...
        safe_logits = masked_attention_logits.masked_fill(fully_masked, 0.0)
        attention_weights = torch.softmax(safe_logits, dim=-1)
        # ... then zero them out: a fully masked query attends to nothing.
        return attention_weights.masked_fill(fully_masked, 0.0)

    def calculate_final_output(self, attention_weights, value):
        """Weight the values ([B,S,S] @ [B,S,E]) and apply the output projection."""
        attention_output = attention_weights @ value
        final_output = self.output_proj(attention_output)
        return final_output

    def forward(self, x, attention_mask=None):
        """
        Run causal self-attention on ``x``.

        Args:
            x: input of shape [B,S,E] with E == ``embed_dim``.
            attention_mask: optional [B,S] tensor, 0 for positions to ignore.

        Returns:
            BaseOutput with ``attention_output``, ``attention_weights``, ``key`` and ``value``.
        """
        # Construct Q, K, V
        query, key, value = self.construct_query_key_value(x)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query, key)

        # Apply causal masking
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply attention mask
            masked_attention_logits = self.apply_attention_mask(attention_mask, masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, Calculate the final output
        attention_output = self.calculate_final_output(attention_weights, value)

        output = BaseOutput(attention_output=attention_output,
                            attention_weights=attention_weights,
                            key=key,
                            value=value)
        return output


In [25]:
# Embedding size used for the dummy input and the attention module
E = 16

# Tokenize the same three texts again; the padding mask defines B and S for this test
texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

# Random stand-in for the embedded tokens, shape (B,S,E)
z = torch.rand(B,S,E)



# Smoke test: run attention with the padding mask.
# Only the output shape is asserted; the values are not checked.
self_atten = BaseSelfAttention(E)
output = self_atten(z,attention_mask)
assert output.attention_output.shape == (B,S,E)

# K-V caching

In [26]:
"""\

With KV caching:
Input: (B,1,E), kv_cache
Output: (B,1,E)

"""

'\nWith KV caching:\nInput: (B,1,E), kv_cache\nOutput: (B,1,E)\n\n'

In [27]:
# Dummy input for testing
B = 2
S = 4
E = 16

z_old = torch.rand(B,S,E)
z = torch.rand(B,1,E)
z_new = torch.concat([z_old,z],dim=1)

self_atten = BaseSelfAttention(E)
output_old = self_atten(z_old,attention_mask=None)
output_new = self_atten(z_new,attention_mask=None)

print(z.shape)
print(z_old.shape)
print(z_new.shape)

torch.Size([2, 1, 16])
torch.Size([2, 4, 16])
torch.Size([2, 5, 16])


In [28]:
class KVCache:
    """Holds previously computed ``key`` and ``value``; extra keyword arguments become attributes too."""

    def __init__(self, key, value, **kwargs):
        self.key, self.value = key, value
        for name, extra in kwargs.items():
            setattr(self, name, extra)

In [29]:
# store previous K,V in the cache
cache = KVCache(output_old.key,
                output_old.value)

In [30]:
# Construct K,V from cache

new_key = self_atten.key_proj(z)
new_value = self_atten.value_proj(z)
query = self_atten.query_proj(z)

In [31]:
# Append the new position along the sequence axis (dim=1): cached entries first, new entry last
key = torch.cat((cache.key, new_key), dim=1)
value = torch.cat((cache.value, new_value), dim=1)

In [32]:
query.shape

torch.Size([2, 1, 16])

In [33]:
key.shape

torch.Size([2, 5, 16])

In [34]:
# remaining operations remains the same

# NOTE: this attempt is expected to fail. The causal mask is built for square
# [B,S,S] logits, but one query against S cached keys gives [B,1,S] logits.
# The error is caught and displayed so that the notebook still runs top to bottom.
try:
    # Calculate logits
    unmasked_attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

    # Apply causal masking
    masked_attention_logits = self_atten.apply_causal_mask(unmasked_attention_logits)

    # Calculate attention weights
    attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)

    # And finally, Calculate the final output
    attention_output = self_atten.calculate_final_output(attention_weights,value)

    output_with_caching = BaseOutput(attention_output=attention_output,
                        attention_weights=attention_weights,
                        key=key,
                        value=value)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: output with shape [2, 1, 5] doesn't match the broadcast shape [2, 5, 5]

Oh no, causal mask!

In [35]:
# let's replicate it

attention_logits = self_atten.calculate_unmasked_attention_logits(query,key)

B,Sq,Sk = attention_logits.shape # <---------Change 1

device = attention_logits.device
causal_mask = torch.tril(torch.ones(B,Sk,Sk,device=device)) # <------------Change 2

print(causal_mask.shape)
# Keep only the last Sq rows: those are the mask rows of the newest query positions
causal_mask = causal_mask[:,-Sq:,-Sk:] # <------------Change 3
print(causal_mask.shape)

# Set explicitly here instead of relying on a `mask_value` left over from an earlier cell
mask_value = -torch.inf

# replace upper triangular with -inf or a very large negative number, causal masking!
# (out-of-place, so `attention_logits` keeps the unmasked values)
masked_attention_logits = attention_logits.masked_fill(causal_mask == 0, mask_value)

torch.Size([2, 5, 5])
torch.Size([2, 1, 5])


In [36]:
# continue from that

# Softmax over the masked logits gives the attention weights
attention_weights = self_atten.calculate_attention_weights(masked_attention_logits)

# Weighted sum of the values, followed by the output projection
attention_output = self_atten.calculate_final_output(attention_weights, value)

# Bundle the results in the same container that forward() returns
output_with_caching = BaseOutput(
    attention_output=attention_output,
    attention_weights=attention_weights,
    key=key,
    value=value,
)

In [37]:
output_new.attention_output[:,-1:,:]

tensor([[[-0.0490, -0.1678,  0.1818, -0.1398,  0.0972,  0.2274, -0.1216,
          -0.0604, -0.0800,  0.1481,  0.2079, -0.1148,  0.0112, -0.0232,
           0.2457, -0.1065]],

        [[-0.0963, -0.2137,  0.1641, -0.1136,  0.0437,  0.2227, -0.0582,
          -0.0935, -0.0542,  0.2538,  0.1606, -0.0999,  0.0909,  0.0347,
           0.1008, -0.1780]]], grad_fn=<SliceBackward0>)

In [38]:
output_with_caching.attention_output

tensor([[[-0.0490, -0.1678,  0.1818, -0.1398,  0.0972,  0.2274, -0.1216,
          -0.0604, -0.0800,  0.1481,  0.2079, -0.1148,  0.0112, -0.0232,
           0.2457, -0.1065]],

        [[-0.0963, -0.2137,  0.1641, -0.1136,  0.0437,  0.2227, -0.0582,
          -0.0935, -0.0542,  0.2538,  0.1606, -0.0999,  0.0909,  0.0347,
           0.1008, -0.1780]]], grad_fn=<ViewBackward0>)

Yay! matches!

# Merge it with the baseline

In [39]:
class BaseSelfAttention(nn.Module):
    """
    Minimal Base Self Attention (single head, causal) with optional KV caching.

    ``forward`` projects the input to query/key/value, prepends cached keys and
    values if a cache is given, computes scaled dot-product logits, applies the
    causal mask (plus an optional padding mask), normalises with softmax and
    projects the weighted values.
    """
    def __init__(self, embed_dim):
        """Create the Q/K/V/output projections for inputs whose last dimension is ``embed_dim``."""
        super().__init__()
        self.embed_dim = embed_dim
        self.init_qkvo_proj()

    def init_qkvo_proj(self):
        """Register the four ``embed_dim -> embed_dim`` linear projections."""
        self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.value_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.output_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def construct_query_key_value(self, x, kv_cache=None):
        """
        Project ``x`` into query, key and value.

        Args:
            x: input of shape [B,Sq,E].
            kv_cache: optional object with ``key`` and ``value`` attributes of shape
                [B,S_cached,E]; they are prepended along the sequence axis.

        Returns:
            ``(query, key, value)``; key and value have S_cached + Sq positions
            when a cache is given, Sq otherwise.
        """
        query = self.query_proj(x)
        key = self.key_proj(x)
        value = self.value_proj(x)
        if kv_cache is not None:
            # cached positions first, freshly projected positions last
            key = torch.concat([kv_cache.key, key], dim=1)
            value = torch.concat([kv_cache.value, value], dim=1)

        return query, key, value

    def calculate_unmasked_attention_logits(self, query, key):
        """
        Return the scaled dot-product logits ``Q @ K' / sqrt(embed_dim)``.

        Q.K' : [B,Sq,E] @ [B,E,Sk] => [B,Sq,Sk]
        """
        key_t = key.transpose(1, 2)  # Transpose to [B,E,Sk] by exchanging dim 1 and 2

        # scaling factor
        scale = math.sqrt(self.embed_dim)

        return (query @ key_t) / scale

    def apply_causal_mask(self, attention_logits, mask_value=None):
        """
        Hide future positions in logits of shape [..., Sq, Sk].

        The Sq queries are treated as the *last* Sq positions of the Sk keys, so
        query row r may attend to key columns 0 .. (Sk - Sq + r).

        Args:
            attention_logits: logits whose last two dimensions are (Sq, Sk).
            mask_value: fill value for the masked entries; defaults to ``-inf``.

        Returns:
            A new tensor; ``attention_logits`` itself is not modified.
        """
        num_queries, num_keys = attention_logits.shape[-2:]

        # lower triangular matrix with 1s, built directly on the logits' device;
        # a single [Sk,Sk] matrix is enough because it broadcasts over the batch
        causal_mask = torch.tril(torch.ones(num_keys, num_keys, device=attention_logits.device))
        # Trim off, for kv_cache: keep the rows of the newest queries (no-op if kv_cache is None)
        causal_mask = causal_mask[-num_queries:, :]

        if mask_value is None:
            mask_value = -torch.inf

        # replace upper triangular with -inf or a very large negative number, causal masking!
        return attention_logits.masked_fill(causal_mask == 0, mask_value)

    def apply_attention_mask(self, attention_mask, masked_attention_logits):
        """
        Hide the key positions whose ``attention_mask`` entry is 0 (e.g. pad tokens).

        Args:
            attention_mask: [B,Sk] tensor, 0 for positions to ignore. With a KV
                cache it has to cover the cached positions as well.
            masked_attention_logits: [B,Sq,Sk] logits (usually already causally masked).

        Returns:
            A new tensor with the ignored key columns set to ``-inf``.
        """
        # [B,Sk] -> [B,1,Sk] so that the mask broadcasts over the query axis
        ignored_keys = attention_mask[:, None, :].to(masked_attention_logits.device) == 0
        return masked_attention_logits.masked_fill(ignored_keys, -torch.inf)

    def calculate_attention_weights(self, masked_attention_logits):
        """
        Softmax over the key axis.

        A query row in which *every* key is ``-inf`` (this happens for left-padded
        positions: the causal mask hides the future and the padding mask hides the
        past) would make softmax return NaN. Such rows get all-zero weights instead.
        """
        fully_masked = torch.isneginf(masked_attention_logits).all(dim=-1, keepdim=True)
        # Give those rows finite logits so that softmax (and its gradient) stay finite ...
        safe_logits = masked_attention_logits.masked_fill(fully_masked, 0.0)
        attention_weights = torch.softmax(safe_logits, dim=-1)
        # ... then zero them out: a fully masked query attends to nothing.
        return attention_weights.masked_fill(fully_masked, 0.0)

    def calculate_final_output(self, attention_weights, value):
        """Weight the values ([B,Sq,Sk] @ [B,Sk,E]) and apply the output projection."""
        attention_output = attention_weights @ value
        final_output = self.output_proj(attention_output)
        return final_output

    def forward(self, x, attention_mask=None, kv_cache=None):
        """
        Run causal self-attention on ``x``.

        Args:
            x: input of shape [B,Sq,E] with E == ``embed_dim``.
            attention_mask: optional [B,Sk] tensor, 0 for key positions to ignore.
            kv_cache: optional cache with ``key``/``value`` of earlier positions.

        Returns:
            BaseOutput with ``attention_output`` ([B,Sq,E]), ``attention_weights``
            ([B,Sq,Sk]) and the full ``key`` / ``value`` (cached + new), which can
            be stored as the next cache.
        """
        # Construct Q, K, V
        query, key, value = self.construct_query_key_value(x, kv_cache=kv_cache)

        # Calculate logits
        unmasked_attention_logits = self.calculate_unmasked_attention_logits(query, key)

        # Apply causal masking
        masked_attention_logits = self.apply_causal_mask(unmasked_attention_logits)

        if attention_mask is not None:
            # Apply attention mask
            masked_attention_logits = self.apply_attention_mask(attention_mask, masked_attention_logits)

        # Calculate attention weights
        attention_weights = self.calculate_attention_weights(masked_attention_logits)

        # And finally, Calculate the final output
        attention_output = self.calculate_final_output(attention_weights, value)

        output = BaseOutput(attention_output=attention_output,
                            attention_weights=attention_weights,
                            key=key,
                            value=value)
        return output
    


class KVCache:
    """Stores ``key`` and ``value`` for later reuse; extra keyword arguments are kept as attributes."""

    def __init__(self, key, value, **kwargs):
        self.key, self.value = key, value
        for name, extra in kwargs.items():
            setattr(self, name, extra)


class BaseOutput:
    """Attribute bag: each keyword argument is stored as an attribute with the same name."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])


In [40]:
# Embedding size used for the dummy input and the attention module
E = 16

# Tokenize the three texts; the padding mask defines B and S for this test
texts = ["Something random","A bit longer text","something even longer than the two before!"]
encoded_input = tokenizer(texts, return_tensors='pt',padding=True)
attention_mask = encoded_input['attention_mask']
B,S = attention_mask.shape

# Random stand-in for the embedded tokens, shape (B,S,E)
z = torch.rand(B,S,E)



# Smoke test of the merged implementation without a cache.
# Only the output shape is asserted; the values are not checked.
self_atten = BaseSelfAttention(E)
output = self_atten(z,attention_mask)
assert output.attention_output.shape == (B,S,E)

In [41]:
# Dummy input for kv-cache testing
B, S, E = 2, 4, 16

z = torch.rand((B, S, E))
self_atten = BaseSelfAttention(E)

# Full pass over the first S positions; its K and V become the cache
output = self_atten(z, attention_mask=None)
cache = KVCache(key=output.key, value=output.value)

# new token
_z = torch.rand((B, 1, E))

# Without KV Caching: run the whole S + 1 sequence again
z_new = torch.cat((z, _z), dim=1)
output_new = self_atten(z_new, attention_mask=None)

# With KV Caching: feed only the new token and let the cache supply the rest
output_with_caching = self_atten(_z, attention_mask=None, kv_cache=cache)

In [42]:
output_new.attention_output[:,-1,:]

tensor([[-0.1764,  0.1235,  0.1896,  0.1131, -0.1038, -0.2366, -0.5676, -0.0113,
          0.2021,  0.4325, -0.1776, -0.4810,  0.0634, -0.0687, -0.3003, -0.3440],
        [-0.1879,  0.1474,  0.1990,  0.1931, -0.0349, -0.1754, -0.4873, -0.0399,
          0.2542,  0.4419, -0.2270, -0.4308, -0.0149, -0.1404, -0.3198, -0.3030]],
       grad_fn=<SliceBackward0>)

In [43]:
output_with_caching.attention_output

tensor([[[-0.1764,  0.1235,  0.1896,  0.1131, -0.1038, -0.2366, -0.5676,
          -0.0113,  0.2021,  0.4325, -0.1776, -0.4810,  0.0634, -0.0687,
          -0.3003, -0.3440]],

        [[-0.1879,  0.1474,  0.1990,  0.1931, -0.0349, -0.1754, -0.4873,
          -0.0399,  0.2542,  0.4419, -0.2270, -0.4308, -0.0149, -0.1404,
          -0.3198, -0.3030]]], grad_fn=<ViewBackward0>)

In [44]:
# Just to verify: only small numerical (float round-off) differences are expected.
# The assertion makes the check fail loudly instead of relying on reading the numbers.
assert torch.allclose(output_new.attention_output[:,-1:,:],
                      output_with_caching.attention_output,
                      atol=1e-6)
output_new.attention_output[:,-1:,:] - output_with_caching.attention_output

tensor([[[ 2.9802e-08,  7.4506e-09,  1.4901e-08,  1.4901e-08, -1.4901e-08,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  2.9802e-08,  2.9802e-08,
          -1.4901e-08,  0.0000e+00,  0.0000e+00, -1.4901e-08,  0.0000e+00,
          -5.9605e-08]],

        [[ 0.0000e+00, -1.4901e-08,  1.4901e-08,  0.0000e+00,  1.4901e-08,
          -1.4901e-08,  2.9802e-08,  0.0000e+00, -2.9802e-08,  0.0000e+00,
           2.9802e-08, -2.9802e-08,  2.9802e-08,  2.9802e-08,  0.0000e+00,
           2.9802e-08]]], grad_fn=<SubBackward0>)

# Extend to multiple-heads | Multi-Headed Attention
Functionality-wise, most of the code remains the same. A hacky way to reuse the self-attention code is to reshape q, k, v to shape `(batch*num_heads, S, E//num_heads)`. 

We will instead implement it from scratch with a separate dimension for heads, for better flexibility later on!

In [45]:
# Split Q,K,V into heads

In [46]:
# Larger toy dimensions for the multi-head walkthrough
B, S, E = 4, 8, 32
z = torch.rand((B, S, E))

# Single-head module; its projections are reused below to obtain Q, K, V
self_attn = BaseSelfAttention(E)
output = self_attn(z)

In [47]:
# Reuse the self-attention implementation: project z into Q, K and V.
# With kv_cache=None nothing is prepended, so K and V cover exactly the positions of z.

query,key,value = self_attn.construct_query_key_value(z,kv_cache=None)

In [48]:
num_heads = 4

# Dimensions of the projected query; each head will receive E // num_heads features
B, S, E = query.size()
assert E % num_heads == 0, "embed_dim must be divisible by num_heads"

In [49]:
# Split the embedding axis into heads first ([B,S,E] -> [B,S,H,E//H]) and only then
# move the head axis forward ([B,H,S,E//H]). Viewing [B,S,E] directly as [B,H,S,E//H]
# gives the right shape but the wrong content: it merely re-chunks memory, so one
# "head" would contain features of several different tokens.
query = query.view(B,S,num_heads,E//num_heads).transpose(1,2)
key = key.view(B,S,num_heads,E//num_heads).transpose(1,2)
value = value.view(B,S,num_heads,E//num_heads).transpose(1,2)

In [50]:
query.shape # now (B,num_heads,S,E//num_heads)

torch.Size([4, 4, 8, 8])

In [51]:
# After this, the operations are same as the original self-attention.
# However, some changes are required to incorporate an additional head dimension
# After all operations, we merge back the heads before final output projection

In [52]:
attention_output = torch.rand(B,num_heads,S,E//num_heads) # dummy for testing

# Undo the head split: put the sequence axis back in front of the heads
# ([B,H,S,D] -> [B,S,H,D]) and only then flatten heads and features into E.
# A plain view(B,S,E) on the [B,H,S,D] layout would mix features of different tokens.
# reshape is needed because the transposed tensor is no longer contiguous.
attention_output = attention_output.transpose(1,2).reshape(B,S,E) # go back to (B,S,E)

# Extend self-attention to multi-headed attention

In [55]:
class BaseMultiHeadedAttention(BaseSelfAttention):
    """
    Extends self attention to multi-headed attention.

    The embedding is cut into ``num_heads`` slices of ``embed_dim // num_heads``
    features. Attention runs independently per head on tensors laid out as
    [B,H,S,D], and the heads are merged back to [B,S,E] before the output
    projection. ``forward`` is inherited; the ``key`` / ``value`` it returns are
    in the [B,H,S,D] layout, which is also the layout expected from a KV cache.
    """
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: size E of the input embedding.
            num_heads: number of heads H; has to divide ``embed_dim``.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        super().__init__(embed_dim)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

    def _split_heads(self, tensor):
        """[B,S,E] -> [B,H,S,D]: cut the embedding axis into heads, then move the head axis forward."""
        batch, seq_len, _ = tensor.shape
        # Viewing [B,S,E] directly as [B,H,S,D] would only re-chunk memory and mix
        # tokens; the split has to happen on the embedding axis, followed by a transpose.
        return tensor.reshape(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def create_heads(self, query, key, value):
        """Split query, key and value (each [B,S,E], S may differ between them) into heads."""
        return self._split_heads(query), self._split_heads(key), self._split_heads(value)

    def construct_query_key_value(self, x, kv_cache):
        """
        Project ``x`` and return per-head query, key and value ([B,H,S,D]).

        Args:
            x: input of shape [B,Sq,E].
            kv_cache: ``None`` or an object whose ``key`` / ``value`` are in the
                per-head layout [B,H,S_cached,D] (as returned by ``forward``).
        """
        # Project only the new positions; the cache is handled below because it is
        # already split into heads and cannot be concatenated with [B,S,E] tensors.
        query, key, value = super().construct_query_key_value(x, None)
        query, key, value = self.create_heads(query, key, value)
        if kv_cache is not None:
            # sequence axis is dim 2 in the [B,H,S,D] layout
            key = torch.concat([kv_cache.key, key], dim=2)
            value = torch.concat([kv_cache.value, value], dim=2)
        return query, key, value

    def calculate_unmasked_attention_logits(self, query, key):
        """
        Return the per-head scaled dot-product logits.

        Q.K' : [B,H,Sq,D] @ [B,H,D,Sk] => [B,H,Sq,Sk]
        """
        key_t = key.transpose(2, 3)  # Transpose to [B,H,D,Sk] by exchanging dim 2 and 3

        # scaling factor: each dot product runs over the D features of one head,
        # so the logits are scaled by sqrt(D), not sqrt(E)
        scale = math.sqrt(self.head_dim)

        return (query @ key_t) / scale

    def apply_causal_mask(self, attention_logits, mask_value=None):
        """
        Hide future positions in logits of shape [B,H,Sq,Sk].

        The Sq queries are treated as the last Sq positions of the Sk keys.
        Returns a new tensor; ``attention_logits`` is not modified.
        """
        num_queries, num_keys = attention_logits.shape[-2:]

        # lower triangular matrix with 1s; one [Sk,Sk] matrix broadcasts over batch and heads
        causal_mask = torch.tril(torch.ones(num_keys, num_keys, device=attention_logits.device))
        # Trim off, for kv_cache: keep the rows of the newest queries (no-op if kv_cache is None)
        causal_mask = causal_mask[-num_queries:, :]

        if mask_value is None:
            mask_value = -torch.inf

        # replace upper triangular with -inf or a very large negative number, causal masking!
        return attention_logits.masked_fill(causal_mask == 0, mask_value)

    def apply_attention_mask(self, attention_mask, masked_attention_logits):
        """
        Hide the key positions whose ``attention_mask`` entry is 0.

        ``attention_mask`` is [B,Sk]; it is expanded to [B,1,1,Sk] so that it
        broadcasts over heads and queries. Returns a new tensor.
        """
        ignored_keys = attention_mask[:, None, None, :].to(masked_attention_logits.device) == 0
        return masked_attention_logits.masked_fill(ignored_keys, -torch.inf)

    def calculate_final_output(self, attention_weights, value):
        """Weight the values per head, merge the heads back to [B,Sq,E] and apply the output projection."""
        attention_output = attention_weights @ value  # [B,H,Sq,D]

        # Sizes come from the tensor itself rather than from notebook-level variables
        batch, num_heads, num_queries, head_dim = attention_output.shape
        # [B,H,Sq,D] -> [B,Sq,H,D] -> [B,Sq,E]; reshape, since the transpose breaks contiguity
        attention_output = attention_output.transpose(1, 2).reshape(batch, num_queries, num_heads * head_dim)

        final_output = self.output_proj(attention_output)
        return final_output

In [56]:
# Batch size, sequence length, embedding size and number of heads for the smoke test
B, S, E, H = 4, 8, 32, 4
z = torch.rand((B, S, E))

self_attn = BaseMultiHeadedAttention(E, H)
output = self_attn(z)

In [57]:
output.attention_output.shape

torch.Size([4, 8, 32])