In [None]:
import math
from typing import Optional, List

import torch
from torch import nn

from labml import tracker

In [None]:
class PrepareForMultiHeadAttention(nn.Module):
    '''
    Projects the input and splits the projected features into attention heads.

    A single linear layer maps the last dimension from `d_model` to
    `heads * d_k`; that last dimension is then split into two dimensions,
    `(heads, d_k)`. All leading dimensions are preserved in their original
    order, so the heads axis is inserted just before the final axis. No
    transpose is performed: a caller that needs a heads-first layout such as
    (batch_size, heads, seq_len, d_k) has to transpose the result itself.

    Args:
    - d_model (int): Size of the last dimension of the input.
    - heads (int): The number of attention heads.
    - d_k (int): The dimension of each head.
    - bias (bool): Whether to include a bias term in the linear transformation.

    Example:
    >>> prepare = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
    >>> input_tensor = torch.randn(10, 20, 512)  # (batch_size, seq_len, d_model)
    >>> output_tensor = prepare(input_tensor)
    >>> output_tensor.shape  # (batch_size, seq_len, heads, d_k)
    torch.Size([10, 20, 8, 64])
    '''

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # One projection produces the features of all heads at once
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Dimension of each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Apply the linear projection and split the last dimension into heads.

        Args:
        - x (torch.Tensor): Input tensor of shape (..., d_model), for example
          (batch_size, seq_len, d_model).

        Returns:
        - torch.Tensor: Tensor of shape (..., heads, d_k), for example
          (batch_size, seq_len, heads, d_k). The leading dimensions are the
          same as those of the input.
        '''
        # Every dimension except the feature dimension is carried over unchanged
        leading_shape = x.shape[:-1]

        # Project d_model -> heads * d_k
        x = self.linear(x)

        # Split the last dimension into (heads, d_k); the heads axis lands
        # after the leading dimensions, it is NOT moved in front of them
        return x.view(*leading_shape, self.heads, self.d_k)