In [1]:
import torch

In [2]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute rotary-embedding rotation matrices for positions ``0 .. end - 1``.

    Despite the ``cis`` name (inherited from the complex-exponential formulation),
    this function returns *real* 2x2 rotation matrices, not complex numbers.
    For position ``p`` and frequency index ``i`` the matrix is::

        [[cos(p * f_i), -sin(p * f_i)],
         [sin(p * f_i),  cos(p * f_i)]]

    where ``f_i = theta ** (-2 * i / dim)`` for ``i = 0 .. dim // 2 - 1``.

    Args:
        dim (int): Feature dimension being rotated; ``dim // 2`` frequencies are produced.
        end (int): Number of positions to precompute (positions ``0 .. end - 1``).
        theta (float, optional): Base of the geometric frequency progression. Defaults to 10000.0.

    Returns:
        torch.Tensor: float32 tensor of shape ``(end, dim // 2, 2, 2)`` holding one
        rotation matrix per (position, frequency) pair.
    """
    # Inverse frequencies theta^(-2i/dim); the slice keeps exactly dim // 2 entries for odd dim.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Positions as float32 explicitly, rather than relying on int64/float32 promotion in outer().
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    # angles[p, i] = p * f_i
    angles = torch.outer(t, freqs)

    cos, sin = angles.cos(), angles.sin()

    # Stack row-major entries [cos, -sin, sin, cos], then fold the last axis into 2x2.
    return torch.stack((cos, -sin, sin, cos), dim=-1).view(*angles.size(), 2, 2)


In [3]:
# Tiny demo: dim=2 yields a single frequency (1.0), so position p is rotated by p radians.
freq_tensor = precompute_freqs_cis(dim=2, end=10)
assert freq_tensor.shape == (10, 1, 2, 2)  # (end, dim // 2, 2, 2)

In [4]:
# One 2x2 rotation matrix per position; with dim=2 the only frequency is 1.0,
# so row p is [[cos p, -sin p], [sin p, cos p]] (e.g. row 1 -> cos(1) = 0.5403).
freq_tensor

tensor([[[[ 1.0000, -0.0000],
          [ 0.0000,  1.0000]]],


        [[[ 0.5403, -0.8415],
          [ 0.8415,  0.5403]]],


        [[[-0.4161, -0.9093],
          [ 0.9093, -0.4161]]],


        [[[-0.9900, -0.1411],
          [ 0.1411, -0.9900]]],


        [[[-0.6536,  0.7568],
          [-0.7568, -0.6536]]],


        [[[ 0.2837,  0.9589],
          [-0.9589,  0.2837]]],


        [[[ 0.9602,  0.2794],
          [-0.2794,  0.9602]]],


        [[[ 0.7539, -0.6570],
          [ 0.6570,  0.7539]]],


        [[[-0.1455, -0.9894],
          [ 0.9894, -0.1455]]],


        [[[-0.9111, -0.4121],
          [ 0.4121, -0.9111]]]])