# In[1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# In[3]:

class SelfAttention(nn.Module):
    """
    Single head of scaled dot-product self-attention.

    The input is projected into keys, queries and values. Attention weights are
    softmax(q @ k^T / sqrt(head_size)), optionally causally masked so a position
    cannot attend to later positions. Dropout is applied to the weights, and the
    weighted sum of the values is returned. Intended as one head of a larger
    (multi-head) attention mechanism.
    """
    def __init__(self, n_embd, head_size, dropout, block_size, causal=True):
        """
        Initializes the SelfAttention layer.
        :param n_embd: Size of each input embedding vector (last dimension of x).
        :param head_size: Size of the key/query/value projections and of the output.
        :param dropout: Dropout probability applied to the attention weights.
        :param block_size: Maximum sequence length; side length of the causal mask.
        :param causal: If True (default), position t may only attend to positions <= t.
                       Pass False for unmasked (bidirectional) attention.
        """
        super().__init__()  # Initialize the superclass (nn.Module)

        # Bias-free linear projections for keys, queries and values.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Lower-triangular matrix of ones; zeros above the diagonal mark the
        # "future" positions hidden by the causal mask. Registered as a buffer
        # (not a parameter) so it follows .to(device) and is saved in the
        # state_dict. It is registered even when causal=False so that
        # state_dict keys are the same in both modes.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.causal = causal

        # Dropout on the attention weights; a no-op in eval() mode.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Defines the forward pass of the self-attention layer.
        :param x: Input tensor of shape (batch size, sequence length, n_embd).
        :return: Tensor of shape (batch size, sequence length, head_size).
        :raises ValueError: If causal and the sequence length exceeds block_size.
        """
        _, T, _ = x.shape  # T: sequence length
        block_size = self.tril.size(0)
        if self.causal and T > block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {block_size} of the causal mask"
            )

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Raw attention scores (B, T, T), scaled by 1/sqrt(head_size) to keep the
        # softmax from saturating as head_size grows.
        weight = q @ k.transpose(-2, -1) * k.shape[-1] ** (-0.5)

        if self.causal:
            # -inf scores become exactly 0 after softmax, so no probability mass
            # goes to future positions. The diagonal is never masked, so every
            # row keeps at least one finite entry (no NaNs).
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        # Normalize each query's scores into a distribution over key positions.
        weight = F.softmax(weight, dim=-1)
        weight = self.dropout(weight)

        # Weighted sum of the values: (B, T, T) @ (B, T, head_size) -> (B, T, head_size).
        return weight @ v
