In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DotProductAttention(nn.Module):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value``, where ``d_k`` is
    the size of the last dimension of ``query``. The module holds no
    parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(..., query_len, d_k)``.
            key: Tensor of shape ``(..., key_len, d_k)``; its last dimension
                must match that of ``query`` for the matmul to succeed.
            value: Tensor of shape ``(..., key_len, d_v)``.
            mask: Optional tensor broadcastable to the score shape
                ``(..., query_len, key_len)``. Positions where ``mask == 0``
                are excluded from attention. Defaults to ``None`` (no
                masking).

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(..., query_len, d_v)`` and ``attention_weights`` has shape
            ``(..., query_len, key_len)`` with each row summing to 1.
        """
        # A plain Python float avoids allocating a tensor on every call and
        # leaves the dtype/device of the scores untouched.
        scale = query.size(-1) ** 0.5
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # Use the most negative finite value rather than -inf: masked
            # positions still receive zero weight after softmax, but a row
            # whose keys are all masked yields uniform weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalise over the key axis.
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of the values.
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

# Smoke-test DotProductAttention on small random tensors.
torch.manual_seed(0)  # fixed seed so a fresh-kernel re-run reproduces the same tensors

batch_size = 2
query_len = 3
key_len = 4
d_model = 5

# Random input tensors; key and value share the key_len axis
query = torch.randn(batch_size, query_len, d_model)
key = torch.randn(batch_size, key_len, d_model)
value = torch.randn(batch_size, key_len, d_model)

# Initialize and apply DotProductAttention
attention = DotProductAttention()
output, attention_weights = attention(query, key, value)

# Check the invariants explicitly instead of relying on eyeballing the prints
assert output.shape == (batch_size, query_len, d_model)
assert attention_weights.shape == (batch_size, query_len, key_len)
assert torch.allclose(attention_weights.sum(dim=-1), torch.ones(batch_size, query_len))

print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)


Output shape: torch.Size([2, 3, 5])
Attention weights shape: torch.Size([2, 3, 4])


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects query, key and value with separate linear layers, splits the
    ``d_model`` features into ``num_heads`` heads of size
    ``d_model // num_heads``, runs scaled dot-product attention per head, then
    concatenates the heads and applies a final linear projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
            Defaults to ``0.0`` (no dropout).

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads."

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Precomputed Python float: avoids building a tensor on every forward
        # pass and keeps the scores' dtype/device unchanged.
        self.scale = self.head_dim ** 0.5

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

        self.output_linear = nn.Linear(d_model, d_model)

        # Parameter-free; an identity when dropout == 0.0 or in eval mode.
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, head_dim)``."""
        return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: Tensor of shape ``(batch, query_len, d_model)``.
            key: Tensor of shape ``(batch, key_len, d_model)``.
            value: Tensor of shape ``(batch, key_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attention_weights)``: ``output`` has shape
            ``(batch, query_len, d_model)`` and ``attention_weights`` has
            shape ``(batch, num_heads, query_len, key_len)``. The returned
            weights are the ones used for the weighted sum, i.e. after
            dropout when dropout is active.
        """
        batch_size = query.size(0)

        # Linear projections, then split the feature axis into heads
        query = self._split_heads(self.query_linear(query), batch_size)
        key = self._split_heads(self.key_linear(key), batch_size)
        value = self._split_heads(self.value_linear(value), batch_size)

        # Scaled dot-product scores: (batch, num_heads, query_len, key_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale

        # Apply mask if provided
        if mask is not None:
            # The most negative finite value (not -inf) still gives masked
            # positions zero weight, but a row with every key masked becomes
            # uniform instead of NaN, so NaNs never reach the output.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Softmax over the key axis, followed by (optional) dropout
        attention_weights = self.dropout(F.softmax(scores, dim=-1))

        # Compute the weighted sum using attention weights
        output = torch.matmul(attention_weights, value)

        # Concatenate heads and apply output linear transformation
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.output_linear(output)

        return output, attention_weights


In [None]:
# Smoke-test MultiHeadAttention on small random tensors.
torch.manual_seed(0)  # fixed seed: reproducible inputs and layer initialisation

batch_size = 2
query_len = 3
key_len = 4
d_model = 16
num_heads = 4

# Random input tensors; key and value share the key_len axis
query = torch.randn(batch_size, query_len, d_model)
key = torch.randn(batch_size, key_len, d_model)
value = torch.randn(batch_size, key_len, d_model)

# Initialize and apply MultiHeadAttention
attention = MultiHeadAttention(d_model, num_heads)
output, attention_weights = attention(query, key, value)

# Check the invariants explicitly instead of relying on eyeballing the prints
assert output.shape == (batch_size, query_len, d_model)
assert attention_weights.shape == (batch_size, num_heads, query_len, key_len)
assert torch.allclose(
    attention_weights.sum(dim=-1), torch.ones(batch_size, num_heads, query_len)
)

print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

Output shape: torch.Size([2, 3, 16])
Attention weights shape: torch.Size([2, 4, 3, 4])
