Import required libraries

In [47]:
import torch
import torch.nn as nn

For simplicity, consider the following toy vector embeddings: 6 tokens, each a 3-dimensional vector.

In [48]:
# Toy input: one row per token, 6 tokens with 3 embedding features each.
embedding_rows = [
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
]
inputs = torch.tensor(embedding_rows)

Initialize variables

In [49]:
# Embedding of the second token (row index 1).
x_2 = inputs[1]

# Projection sizes: input features per token -> output features per head.
d_in = inputs.size(1)
d_out = 2

# Dropout probability for the attention weights.
dropout = 0.5

Implementing a causal attention class that handles batched inputs

Create a dummy batch

In [50]:
# Two copies of the same sequence stacked along a new leading batch axis.
batch = torch.stack([inputs] * 2)
print(batch.shape)

torch.Size([2, 6, 3])


In [51]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention over a batch of token sequences.

    Args:
        d_in: number of features in each input token embedding.
        d_out: number of features of the query/key/value projections and of the output.
        context_length: maximum number of tokens the causal mask is built for.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Strictly upper-triangular mask: entry (i, j) is 1 when j > i, i.e. token j
        # lies in the future of token i and must be hidden from it. This must be
        # `triu`, not `tril(..., diagonal=1)`. The latter would hide the current and
        # past tokens while leaking later future ones. It would also fully mask the
        # last rows, so softmax over all -inf would yield NaN.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Compute causal self-attention context vectors.

        Args:
            x: tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out). Row i depends only on tokens 0..i.

        Raises:
            ValueError: if num_tokens exceeds the context_length the mask was built for.
        """
        b, num_tokens, d_in = x.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"Input has {num_tokens} tokens but the causal mask only supports "
                f"context_length={max_tokens}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        atten_scores = queries @ keys.transpose(1, 2)
        # Slice with `:num_tokens` so shorter sequences than context_length also work.
        atten_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        # Scale by sqrt(d_k) before softmax. Masked (-inf) entries get weight 0.
        atten_weights = torch.softmax(atten_scores / keys.shape[-1] ** 0.5, dim=-1)
        atten_weights = self.dropout(atten_weights)
        context_vector = atten_weights @ values
        return context_vector


Extending single-head attention to multi-head attention by running several causal attention heads in parallel and concatenating their outputs

In [52]:
class MultiheadAttentionWrapper(nn.Module):
    """Multi-head attention made of independent CausalAttention heads.

    Every head maps the input to ``d_out`` features. The head outputs are
    concatenated along the last dimension, so the result has
    ``num_heads * d_out`` features per token.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads=2, qkvbias=False):
        super().__init__()
        head_list = []
        for _ in range(num_heads):
            head_list.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias=qkvbias)
            )
        self.heads = nn.ModuleList(head_list)

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last dim."""
        per_head = [attention_head(x) for attention_head in self.heads]
        return torch.cat(per_head, dim=-1)

In [53]:
torch.manual_seed(123)  # reproducible weight initialization and dropout masks
context_length = batch.shape[1]  # mask only needs to cover the tokens in this batch
d_in, d_out = 3, 2
num_heads = 2
# The module starts in training mode, so dropout is active on the attention weights.
mha = MultiheadAttentionWrapper(d_in, d_out, context_length, dropout, num_heads=num_heads, qkvbias=False)
context_vecs = mha(batch)
# Heads are concatenated on the last dim: (batch_size, num_tokens, num_heads * d_out)
assert context_vecs.shape == (batch.shape[0], context_length, num_heads * d_out)
print(context_vecs.shape)


torch.Size([2, 6, 4])
