In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [28]:
def Attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, d_k: int):
    QK_T = torch.matmul(Q, torch.transpose(K, -1, -2))
    QK_T_d_k = torch.div(QK_T, torch.sqrt(d_k))
    softmax = F.softmax(QK_T_d_k, dim = -1)
    return torch.matmul(softmax, V)

In [43]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, h: int):
        super(MultiHeadAttention, self).__init__()

        self.h = h
        self.d_k_value = torch.Tensor([d_k])
        self.linear = nn.ModuleList()
        self.W_O = nn.Parameter(torch.Tensor(h*d_v, d_model))
        self.attention = Attention

        for _ in range(self.h):
            linear = nn.ModuleList([nn.Linear(d_k, d_model), nn.Linear(d_k, d_model), nn.Linear(d_v, d_model)])
            self.linear.append(linear)

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
        assert len(Q.shape) == len(K.shape) == len(V.shape), f"invalid dimensions, got Q:{Q.shape}, K: {K.shape}, V:{V.shape}"

        heads = [self.attention(layer[0](Q), layer[1](K), layer[2](V), self.d_k_value) for layer in self.linear]
        concat_heads = torch.cat(heads, dim = -1)
        return torch. matmul(concat_heads, self.W_O)

# Tests

In [44]:
multi_head_attention = MultiHeadAttention(512, 512, 512, 8)
multi_head_attention.linear

ModuleList(
  (0-7): 8 x ModuleList(
    (0-2): 3 x Linear(in_features=512, out_features=512, bias=True)
  )
)

In [45]:
batch_size = 2    # Example batch size
seq_length = 5    # Length of the sequence (number of tokens)
model_dim = 64    # Dimension of the model
d_k = 16          # Dimension of the keys (and queries)

# Generate random tensors for Q, K, V
Q = torch.randn(batch_size, seq_length, d_k)  # Queries
K = torch.randn(batch_size, seq_length, d_k)  # Keys
V = torch.randn(batch_size, seq_length, model_dim)  # Values

# Print the shapes for confirmation
print("Shape of Q:", Q.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of K:", K.shape)  # Expected: (batch_size, seq_length, d_k)
print("Shape of V:", V.shape)  # Expected: (batch_size, seq_length, model_dim)

multi_head_attention = MultiHeadAttention(model_dim, d_k, model_dim, 4)
multi_head_attention(Q, K, V)

Shape of Q: torch.Size([2, 5, 16])
Shape of K: torch.Size([2, 5, 16])
Shape of V: torch.Size([2, 5, 64])


tensor([[[ 2.4445e+34, -7.4946e+34,  3.4983e+33, -3.0524e+34,  1.9479e+34,
          -4.8222e+34,         nan,         nan, -6.0450e+32,  2.0500e+33,
           1.3955e+34, -5.6239e+34,  2.1966e+33,  1.8136e+34,  9.8687e+33,
          -5.0540e+33,  1.5466e+35,  3.4836e+33,  1.5569e+34,  1.4061e+34,
           2.2541e+34,  1.1064e+34, -4.2650e+34, -2.4469e+34,  9.3073e+33,
           2.9667e+34, -1.8876e+35, -3.0091e+34,  2.9198e+34,  4.1485e+34,
          -8.6919e+33,  5.4249e+33,  1.7572e+34, -2.3438e+34,  2.3730e+34,
           1.1159e+34,  7.0735e+34,  6.0478e+34,  1.1153e+34, -5.6251e+33,
          -5.7578e+33,  2.9731e+34, -9.1963e+33, -5.9002e+33, -2.2500e+34,
           5.6736e+33, -5.2378e+34,  2.0040e+34,  1.3603e+32,  5.9965e+34,
          -1.0792e+34,  2.2398e+33,  2.3582e+34,  1.7505e+34,  4.0903e+34,
           2.6814e+32, -7.3109e+33, -1.4660e+34,  3.3339e+34, -1.3406e+34,
           1.6280e+34,  1.5884e+34, -2.6358e+33,  2.9108e+34],
         [ 4.0159e+34, -7.0553e+34, -