In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionHead(nn.Module):
    """Single-head scaled dot-product self-attention with an output projection.

    The input is projected (without bias) into query/key/value spaces of size
    ``head_dim``, attended with ``softmax(Q K^T / sqrt(d_k)) V``, and the result
    is projected back to ``embed_dim``.

    Args:
        embed_dim: Size of the last dimension of the input ``x``.
        head_dim: Size of the query/key/value projections.
    """

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.embed_dim = embed_dim  # Input embedding dimension
        self.head_dim = head_dim    # Size of the query/key/value projections

        # Bias-free linear projections for query, key, and value
        self.query = nn.Linear(embed_dim, head_dim, bias=False)
        self.key = nn.Linear(embed_dim, head_dim, bias=False)
        self.value = nn.Linear(embed_dim, head_dim, bias=False)

        # Maps the attended head_dim features back to embed_dim
        self.out_proj = nn.Linear(head_dim, embed_dim, bias=False)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            Q, K, V: Tensors with at least two dims; the last two are treated as
                (sequence, feature). Q and K must share the feature size d_k.
            mask: Optional tensor broadcastable to the score shape
                (..., len_q, len_k). Positions where ``mask == 0`` are excluded
                from attention (e.g. ``torch.tril(ones)`` for causal attention).

        Returns:
            ``(output, attention_weights)``. A query row whose keys are all
            masked gets all-zero weights (hence an all-zero output row) rather
            than NaN.
        """
        d_k = K.size(-1)  # Feature size used for scaling
        # A plain Python scale avoids allocating a CPU tensor on every call and
        # leaves the scores' dtype/device untouched.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

        masked_out = None
        if mask is not None:
            masked_out = mask == 0
            scores = scores.masked_fill(masked_out, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)

        if masked_out is not None:
            # Softmax over a row that is entirely -inf yields NaN. Zeroing the
            # masked positions fixes those rows; partially masked rows are
            # unchanged because their masked weights are already exactly 0.
            attention_weights = attention_weights.masked_fill(masked_out, 0.0)

        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor with at least two dims whose last dimension is
                ``embed_dim``; the example usage passes (batch, seq_len, embed_dim).
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            ``(output, attention_weights)``: ``output`` has the same leading
            dims as ``x`` with last dim ``embed_dim``; ``attention_weights`` has
            shape (..., seq_len, seq_len).
        """
        # Self-attention: queries, keys, and values all come from x
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Project the attended head_dim features back to embed_dim
        output = self.out_proj(attention_output)

        return output, attention_weights

# Example usage:
if __name__ == "__main__":
    # Define input parameters
    SEED = 0        # Fixed seed so layer weights, input, and printed output are reproducible
    batch_size = 2  # Number of sequences in a batch
    seq_len = 5     # Length of each sequence
    embed_dim = 16  # Input embedding dimension
    head_dim = 8    # Size of the query/key/value projections

    # Seed before creating any random tensors or layer weights
    torch.manual_seed(SEED)

    # Create a random input tensor (batch_size, seq_len, embed_dim)
    x = torch.randn(batch_size, seq_len, embed_dim)

    # Initialize the attention head
    attention_head = AttentionHead(embed_dim, head_dim)

    # Run the input through the attention head
    output, attention_weights = attention_head(x)

    # Sanity-check shapes: output matches the input, weights are seq_len x seq_len
    assert output.shape == (batch_size, seq_len, embed_dim)
    assert attention_weights.shape == (batch_size, seq_len, seq_len)

    print("Attention Output:")
    print(output)
    print("\nAttention Weights:")
    print(attention_weights)


Attention Output:
tensor([[[ 0.1446, -0.0710, -0.0724,  0.0838, -0.0739,  0.0579,  0.0374,
          -0.0342,  0.1297,  0.0644,  0.0385,  0.0452,  0.0625,  0.0768,
          -0.1009, -0.1196],
         [ 0.1863, -0.0836, -0.0405,  0.0899, -0.0179,  0.0057,  0.0133,
          -0.0182,  0.1213,  0.1319,  0.0022,  0.0786,  0.0830,  0.1149,
          -0.1352, -0.0776],
         [ 0.2918, -0.1325, -0.0168,  0.0887,  0.0348, -0.0910,  0.0087,
          -0.0142,  0.1455,  0.2207, -0.1024,  0.1624,  0.1736,  0.1900,
          -0.2143, -0.0513],
         [ 0.2090, -0.0865, -0.0646,  0.0843, -0.0450,  0.0272,  0.0458,
          -0.0128,  0.1527,  0.1152,  0.0213,  0.1090,  0.1222,  0.1302,
          -0.1688, -0.1232],
         [ 0.1120, -0.0535, -0.0451,  0.0841, -0.0424,  0.0626,  0.0174,
          -0.0249,  0.0942,  0.0770,  0.0624,  0.0242,  0.0177,  0.0758,
          -0.0847, -0.0824]],

        [[ 0.0141, -0.0130,  0.0492, -0.0967,  0.0377, -0.0373,  0.0504,
          -0.1189, -0.0421, -0.0

In [3]:
# Inspect the bias-free query projection (embed_dim -> head_dim) of the example head
attention_head.query

Linear(in_features=16, out_features=8, bias=False)