In [1]:
# "Multihead Attention in a Efficient Way"
import torch
import torch.nn as nn

In [2]:
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55]
])

batch = torch.stack([inputs, inputs])
batch.shape

torch.Size([2, 6, 3])

In [None]:
torch.manual_seed(42)

context_length = 10 # max of tokens the model can process:

# batch.shape # (b, num_tokens, token_dim)
b, num_tokens, token_dim= batch.shape
d_in = 3
d_out = 8
num_heads = 4
d_head = d_out // num_heads

W_q = nn.Linear(in_features= d_in, out_features= d_out)
W_k = nn.Linear(in_features= d_in, out_features= d_out)
W_v = nn.Linear(in_features= d_in, out_features= d_out)

queries = W_q(batch) # (2, 6, 8) = (b, num_tokens, d_out)
keys = W_k(batch) 
values = W_v(batch) 

# Reshaping :
queries = queries.view(b, num_tokens, num_heads, d_head)
keys = keys.view(b, num_tokens, num_heads, d_head)
values = values.view(b, num_tokens, num_heads, d_head)

queries.transpose_(1, 2) # (b, num_heads, num_tokens, d_head)
keys.transpose_(1, 2)
values.transpose_(1, 2)

attn_scores = queries @ keys.transpose(-1, -2)
mask = torch.ones((context_length, context_length)).triu(diagonal=1).bool()
mask_attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
# mask_attn_scores[0][0].softmax(dim=-1).sum(dim=-1, keepdims=True)
mask_attn_weights = torch.softmax(mask_attn_scores / d_head**0.5, dim=-1)
# mask_attn_weights.shape # (2, 4, 6, 6)
# values.shape # (2, 4, 6, 2)
# mask_attn_weights = dropout(mask_attn_weights) # ****
context_vectors = mask_attn_weights @ values
context_vectors.shape # (b, num_heads, num_tokens, d_head)
context_vectors.transpose_(1, 2).shape # (b, num_tokens, num_heads,  d_head)
context_vectors = context_vectors.contiguous().view(b, num_tokens, d_out)
# context_vectors.shape # (2, 6, 8)

out_proj = nn.Linear(d_out, d_out)

out_proj(context_vectors)

torch.Size([2, 6, 8])

In [15]:
# context_length: "Maximum number of tokens that the model can handle."
class MultiheadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, qkv_bias=False, dropout_p=0.5):
        super().__init__()
        assert (d_out % num_heads == 0), "`d_out` must be divisible by `num_heads`"
        self.W_q = nn.Linear(d_in, d_out)
        self.W_k = nn.Linear(d_in, d_out)
        self.W_v = nn.Linear(d_in, d_out)
        self.d_out = d_out
        self.num_heads = num_heads
        self.d_head = (d_out // num_heads)
        self.dropout = nn.Dropout(dropout_p)
        self.register_buffer('mask', torch.ones(context_length, context_length).triu(diagonal=1).bool())

        # out-projection:
        self.out_proj = nn.Linear(d_out, d_out)
    
    def forward(self, x):
        '''x: (batch_size, num_tokens, token_dim)'''
        b, num_tokens, token_dim = x.shape

        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x) # (b, num_tokens, [d_out])
        # Reshaping:
        queries = queries.view((b, num_tokens, self.num_heads, self.d_head))
        keys = keys.view((b, num_tokens, self.num_heads, self.d_head))
        values = values.view((b, num_tokens, self.num_heads, self.d_head))

        # Transposing, for making total sense of the Matrix
        queries = queries.transpose(1, 2) # (b, num_heads, num_tokens, d_head)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(-1, -2)
        masked_attn_scores = attn_scores.masked_fill(self.mask[:num_tokens, :num_tokens], -torch.inf)
        masked_attn_weights = torch.softmax(masked_attn_scores / self.d_head**0.5, dim=-1)
        masked_attn_weights = self.dropout(masked_attn_weights)
        context_vectors = (masked_attn_weights @ values).transpose(1, 2) # (b, num_heads, num_tokens, d_head)
        # context_vectors = (b,  num_tokens, num_heads, d_head)
        context_vectors = context_vectors.contiguous().view(b, num_tokens, self.d_out)
        context_vectors = self.out_proj(context_vectors)
        return context_vectors #.shape 

mha = MultiheadAttention(d_in=batch.shape[-1], d_out=8, num_heads=4, context_length=10)
mha(batch)

tensor([[[-0.4778,  0.6357,  0.9619,  0.0515,  0.5443,  0.1522,  0.1753,
          -0.9143],
         [-0.4642,  0.4034,  0.5549, -0.1379,  0.4325,  0.2282,  0.2022,
          -0.4690],
         [-0.4756,  0.4119,  0.5055, -0.4186,  0.5354,  0.3641,  0.3838,
          -0.2982],
         [-0.1830,  0.2517,  0.5895,  0.0918,  0.0767,  0.1906, -0.1340,
          -0.3514],
         [-0.3252,  0.3469,  0.4782, -0.0962,  0.2970,  0.3436,  0.0971,
          -0.3144],
         [-0.2796,  0.2638,  0.4168, -0.0473,  0.2336,  0.3241,  0.0175,
          -0.2542]],

        [[-0.3452,  0.3062,  0.5174,  0.3615,  0.0507, -0.0598, -0.3246,
          -0.6409],
         [-0.2852,  0.4073,  0.7233, -0.1198,  0.2182,  0.3254, -0.0067,
          -0.4223],
         [-0.5523,  0.3772,  0.3209, -0.3505,  0.4984,  0.2411,  0.3291,
          -0.2988],
         [-0.3549,  0.4129,  0.6398, -0.2399,  0.5585,  0.4019,  0.4169,
          -0.3652],
         [-0.3187,  0.3414,  0.5716, -0.0411,  0.4280,  0.3113,  0.2

# **Revising:**

In [1]:
import torch
import torch.nn as nn

inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55]
])

batch = torch.stack([inputs, inputs, inputs])
batch.shape

torch.Size([3, 6, 3])

In [48]:
# Multihead-Attention:
torch.manual_seed(42)

b, num_tokens, token_embed = batch.shape
d_in  = batch.shape[-1] #3
d_out = 4
num_heads = 2
d_head = d_out // num_heads

W_q = nn.Linear(in_features= d_in, out_features= d_out)
W_k = nn.Linear(in_features= d_in, out_features= d_out)
W_v = nn.Linear(in_features= d_in, out_features= d_out)

queries = W_q(batch) # (b, num_tokens, d_out)
keys = W_k(batch)
values = W_v(batch)

queries = queries.view(b, num_tokens, num_heads, d_head)
keys = keys.view(b, num_tokens, num_heads, d_head)
values = values.view(b, num_tokens, num_heads, d_head)

queries = queries.transpose(1, 2) # (b, num_heads, num_tokens, d_head)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

attn_scores = queries @ keys.transpose(-1, -2)
context_length =  10 # Maximum number of tokens that the model can handle
mask = torch.ones(context_length, context_length).triu(diagonal=1).bool()
mask_attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
mask_attn_weights = torch.softmax(mask_attn_scores / d_head**0.5, dim=-1)

dropout = nn.Dropout()
mask_attn_weights = dropout(mask_attn_weights)
context_vectors = mask_attn_weights @ values
context_vectors = context_vectors.transpose(1, 2) # (b, num_tokens, num_heads, d_head)
context_vectors = context_vectors.contiguous().view(b, num_tokens, num_heads * d_head)

out_proj = nn.Linear(in_features= d_out, out_features=d_out)
context_vectors = out_proj(context_vectors).shape

context_vectors

torch.Size([3, 6, 4])