In [None]:
# Multi-Head Attention:
# naive ("non-efficient") implementation: every head is a separate module and
# the heads' outputs are concatenated.
import torch
from torch import nn

In [2]:
# Six toy input embeddings of size 3, one row per token.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

# Two identical copies stacked along a new leading (batch) dimension.
batch = torch.stack((inputs,) * 2, dim=0)
batch.shape

torch.Size([2, 6, 3])

In [3]:
class CausalAttention(torch.nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each token may attend only to itself and to earlier tokens: scores above
    the diagonal are set to -inf before the softmax.

    Args:
        d_in: size of each input embedding (last dim of ``x``).
        d_out: size of each output context vector.
        dropout: dropout probability applied to the attention weights.
        context_length: maximum number of tokens supported; sizes the mask.
        bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, dropout, context_length, bias=False):
        super().__init__()
        self.W_q = torch.nn.Linear(d_in, d_out, bias=bias)
        self.W_k = torch.nn.Linear(d_in, d_out, bias=bias)
        self.W_v = torch.nn.Linear(d_in, d_out, bias=bias)
        self.dropout = torch.nn.Dropout(dropout)
        # True strictly above the diagonal = "future" positions to hide.
        # A buffer moves with .to(device) but is not a trainable parameter.
        self.register_buffer(
            'mask',
            torch.ones(context_length, context_length).triu(diagonal=1).bool(),
        )

    def forward(self, x):
        """Compute causal context vectors.

        Args:
            x: tensor of shape (..., num_tokens, d_in), e.g.
               (batch, num_tokens, d_in); num_tokens must not exceed
               ``context_length`` or the mask slice will not match.

        Returns:
            Tensor of shape (..., num_tokens, d_out).
        """
        num_tokens = x.shape[-2]
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        # Scale by sqrt(d_out) so the softmax inputs stay well-conditioned.
        scores = queries @ keys.transpose(-1, -2) / keys.shape[-1] ** 0.5
        # Out-of-place fill: does not mutate an autograd intermediate.
        causal_mask = self.mask[:num_tokens, :num_tokens]
        scores = scores.masked_fill(causal_mask, -torch.inf)
        weights = torch.softmax(scores, dim=-1)
        # Apply the configured dropout to the attention weights (active only
        # in training mode; a no-op after .eval()).
        weights = self.dropout(weights)
        return weights @ values

# Hyper-parameters for a single causal attention head.
d_in = batch.shape[-1]  # embedding size of the toy inputs
d_out = 2
context_length = 10  # maximum number of tokens the model can handle
causalAttention_01 = CausalAttention(
    d_in, d_out, dropout=0.5, context_length=context_length
)

In [4]:
# Shape check: (batch, num_tokens, d_out) -> expected torch.Size([2, 6, 2]).
causalAttention_01(batch).shape

torch.Size([2, 6, 2])

In [None]:
# MultiHead-Attention:
class MultiheadAttention(nn.Module):
    """Multi-head causal attention built by stacking independent heads.

    Straightforward ("non-efficient") variant: every head is a separate
    ``CausalAttention`` module with its own projections, the heads are run one
    after another, and their outputs are concatenated on the last dimension.

    Args:
        d_in: size of each input embedding.
        d_out: output size of *each* head; the result has
            ``num_heads * d_out`` features.
        dropout: dropout probability passed to every head.
        context_length: maximum sequence length supported by each head's mask.
        num_heads: number of independent heads.
        qkv_bias: whether the heads' query/key/value projections use a bias
            term (forwarded to ``CausalAttention``'s ``bias`` argument).
    """

    def __init__(self, d_in, d_out, dropout, context_length, num_heads=2,
                 qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(d_in, d_out, dropout, context_length,
                                bias=qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last dim."""
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [60]:
torch.manual_seed(0)  # fixes the heads' weight initialisation

# batch.shape == (2, 6, 3)
d_in = batch.shape[-1]
d_out = 2
context_length = 20
mha = MultiheadAttention(d_in, d_out, 0.5, context_length, num_heads=2)
# eval() disables dropout so this demo forward pass is deterministic and
# reproducible under Restart & Run All.
mha.eval()
mha(batch)

tensor([[[-0.5063,  0.3518, -0.3550, -0.6560],
         [-0.6503,  0.3955, -0.1536, -0.7514],
         [-0.6976,  0.4064, -0.0853, -0.7803],
         [-0.6289,  0.3677, -0.0297, -0.7015],
         [-0.6131,  0.3179, -0.0417, -0.6247],
         [-0.5870,  0.3259, -0.0040, -0.6322]],

        [[-0.5063,  0.3518, -0.3550, -0.6560],
         [-0.6503,  0.3955, -0.1536, -0.7514],
         [-0.6976,  0.4064, -0.0853, -0.7803],
         [-0.6289,  0.3677, -0.0297, -0.7015],
         [-0.6131,  0.3179, -0.0417, -0.6247],
         [-0.5870,  0.3259, -0.0040, -0.6322]]], grad_fn=<CatBackward0>)