In [1]:
import torch

In [2]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the per-position 2x2 rotation matrices used by rotary position embeddings.

    For frequency index i the angular rate is w_i = theta ** (-2 * i / dim), i in [0, dim // 2).
    For position t in [0, end) the entry [t, i] is the real rotation matrix
        [[cos(t * w_i), -sin(t * w_i)],
         [sin(t * w_i),  cos(t * w_i)]].
    The result is a real float32 tensor, not a complex tensor: the rotation is stored
    explicitly as a 2x2 matrix instead of as a complex exponential.

    Args:
        dim (int): Feature dimension being rotated; dim // 2 frequencies are produced.
        end (int): Number of positions to precompute (positions 0 .. end - 1).
        theta (float, optional): Base of the geometric frequency progression. Defaults to 10000.0.

    Returns:
        torch.Tensor: float32 tensor of shape (end, dim // 2, 2, 2).
    """
    # Even indices 0, 2, 4, ... divided by dim give the exponents 2i / dim.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    # angles[t, i] = t * w_i
    freqs = torch.outer(t, freqs).float()

    cos, sin = freqs.cos(), freqs.sin()

    # Stacking (cos, -sin, sin, cos) on the last axis and viewing it as (2, 2) lays the
    # values out row-major as [[cos, -sin], [sin, cos]].
    return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)


In [3]:
freq_tensor = precompute_freqs_cis(dim=2, end=10)

In [4]:
# Display the table: shape (10, 1, 2, 2), i.e. one 2x2 rotation matrix per position.
freq_tensor

tensor([[[[ 1.0000, -0.0000],
          [ 0.0000,  1.0000]]],


        [[[ 0.5403, -0.8415],
          [ 0.8415,  0.5403]]],


        [[[-0.4161, -0.9093],
          [ 0.9093, -0.4161]]],


        [[[-0.9900, -0.1411],
          [ 0.1411, -0.9900]]],


        [[[-0.6536,  0.7568],
          [-0.7568, -0.6536]]],


        [[[ 0.2837,  0.9589],
          [-0.9589,  0.2837]]],


        [[[ 0.9602,  0.2794],
          [-0.2794,  0.9602]]],


        [[[ 0.7539, -0.6570],
          [ 0.6570,  0.7539]]],


        [[[-0.1455, -0.9894],
          [ 0.9894, -0.1455]]],


        [[[-0.9111, -0.4121],
          [ 0.4121, -0.9111]]]])

In [4]:
import torch
import torch.nn as nn
from typing import Optional

In [8]:
def precompute_freq_cis(dim, end, theta):
    """
    Build the RoPE rotation-matrix table.

    Returns a float32 tensor of shape (end, dim // 2, 2, 2). Entry [t, i] is the
    2x2 matrix [[cos a, -sin a], [sin a, cos a]] with a = t * theta ** (-2i / dim).
    """
    n_freqs = dim // 2
    exponents = torch.arange(0, dim, 2)[:n_freqs].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    cos_a = angles.cos()
    sin_a = angles.sin()
    # Row-major layout of the last axis becomes [[cos, -sin], [sin, cos]] after the view.
    table = torch.stack((cos_a, -sin_a, sin_a, cos_a), dim=-1)
    return table.view(*angles.size(), 2, 2)

def reshape_freq_tensor(freq_cis, x, seq_dim):
    """
    View the RoPE table ``freq_cis`` so it broadcasts against ``x``.

    ``x`` is a query/key tensor already split into pairs, i.e. shape
    (..., n_pairs, 1, 2) as produced by ``xq.reshape(*xq.shape[:-1], -1, 1, 2)``.
    The result keeps the sequence axis (``seq_dim``) and the pair axis
    (``x.ndim - 3``), sets every other leading axis to 1, and appends the
    trailing (2, 2) rotation-matrix axes.

    Args:
        freq_cis: table of shape (x.shape[seq_dim], x.shape[-3], 2, 2).
        x: pair-split tensor the result must broadcast with.
        seq_dim: index of the sequence axis in ``x`` (before the pair axis).

    Returns:
        A view of ``freq_cis`` with ``x.ndim`` dimensions.

    Raises:
        AssertionError: if ``seq_dim`` is out of range or the shapes disagree.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim - 3, f"seq_dim={seq_dim} out of range for x.ndim={ndim}"
    assert freq_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freq_cis vs x: {(freq_cis.shape, x.shape)}"
    # The trailing [2, 2] is required: without it the target shape drops the matrix
    # axes and `view` fails with a size mismatch.
    shape = [
        d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
    ] + [2, 2]
    return freq_cis.view(*shape)

def apply_rotary_embedding(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freq_cis: torch.Tensor,
):
    """
    Rotate query/key features with precomputed RoPE rotation matrices.

    The last axis of ``xq``/``xk`` is split into consecutive pairs (x0, x1). The pair
    at position t and pair index i is multiplied by the 2x2 matrix ``freq_cis[t, i]``.
    For tables built by ``precompute_freq_cis`` this gives
    (cos*x0 - sin*x1, sin*x0 + cos*x1).

    Args:
        xq: query tensor whose last axis has even size.
        xk: key tensor; the table is shaped against ``xq``, so ``xk`` must share its
            sequence length and pair count.
        seq_dim: index of the sequence axis in ``xq``/``xk``.
        freq_cis: (seqlen, last_dim // 2, 2, 2) rotation-matrix table.

    Returns:
        (xq_out, xk_out) with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # (..., D) -> (..., D/2, 1, 2): the singleton axis lets each pair broadcast
    # against both rows of its 2x2 matrix.
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    freq_cis = reshape_freq_tensor(
        freq_cis, xq_, seq_dim
    ).float()
    # Summing the last axis is the matrix-vector product R @ (x0, x1). Flattening the
    # trailing (D/2, 2) axes restores the feature axis. Negative axes avoid hard-coding
    # a 4-D input; for 4-D inputs this is identical to .sum(5).flatten(3).
    xq_out = (xq_ * freq_cis).sum(-1).flatten(-2)
    xk_out = (xk_ * freq_cis).sum(-1).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(nn.Module):
    """
    Module holding a precomputed RoPE rotation-matrix table.

    ``freq_cis`` is registered as a non-persistent buffer built by
    ``precompute_freq_cis(dim=head_dim, end=max_seqlen, theta=theta)``.
    """

    def __init__(self, theta, head_dim, max_seqlen):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen

        # persistent=False: the table is derived from the hyper-parameters above,
        # so it is kept out of the state dict.
        self.register_buffer(
            "freq_cis",
            precompute_freq_cis(
                dim=self.head_dim,
                end=self.max_seqlen,
                theta=self.theta
            ),
            persistent=False
        )

    def reset_parameters(self):
        """Recompute the table and write it into the existing buffer in place."""
        self.freq_cis[...] = precompute_freq_cis(
            dim=self.head_dim,
            end=self.max_seqlen,
            theta=self.theta
        )

    def forward(self, seqlen: Optional[int] = None, token_id: Optional[torch.Tensor] = None):
        """
        Return rotation matrices for a prefix of positions or for explicit positions.

        Exactly one argument must be given:
            seqlen: return ``freq_cis[:seqlen]``; must not exceed ``max_seqlen``.
            token_id: return ``freq_cis[token_id]`` (indexing by position).

        Raises:
            AssertionError: if neither or both arguments are given, or if ``seqlen``
                exceeds ``max_seqlen``.
        """
        # `!=` enforces exactly one argument. With both None, neither branch below
        # would run and the method would silently return None.
        assert (seqlen is None) != (token_id is None), "Either seqlen or token_id must be provided."
        if token_id is not None:
            return self.freq_cis[token_id]
        # Slicing past the end would silently return fewer rows than requested.
        assert seqlen <= self.max_seqlen, (
            f"seqlen={seqlen} exceeds max_seqlen={self.max_seqlen}"
        )
        return self.freq_cis[:seqlen]

In [9]:
# max_seqlen must cover every position requested later: the next cell asks for 4096
# positions and the test tensors below use sequence length 4096. forward slices
# freq_cis[:seqlen], so a shorter table cannot serve those requests.
rotary = RotaryEmbedding(
    theta=10000,
    head_dim=64,
    max_seqlen=4096
)

In [10]:
# Rotation matrices for the first 4096 positions (freq_cis[:4096]).
# NOTE(review): the table only has max_seqlen rows, so this request must not
# exceed the max_seqlen passed to RotaryEmbedding above.
rotary(
    seqlen=4096
)

tensor([[[[ 1.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  1.0000e+00]],

         [[ 1.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  1.0000e+00]],

         [[ 1.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  1.0000e+00]],

         ...,

         [[ 1.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  1.0000e+00]],

         [[ 1.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  1.0000e+00]],

         [[ 1.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  1.0000e+00]]],


        [[[ 5.4030e-01, -8.4147e-01],
          [ 8.4147e-01,  5.4030e-01]],

         [[ 7.3176e-01, -6.8156e-01],
          [ 6.8156e-01,  7.3176e-01]],

         [[ 8.4601e-01, -5.3317e-01],
          [ 5.3317e-01,  8.4601e-01]],

         ...,

         [[ 1.0000e+00, -2.3714e-04],
          [ 2.3714e-04,  1.0000e+00]],

         [[ 1.0000e+00, -1.7783e-04],
          [ 1.7783e-04,  1.0000e+00]],

         [[ 1.0000e+00, -1.3335e-04],
          [ 1.3335e-04,  1.0000e+00]]],


        [[[-4.1615e-01, -9.093

In [22]:
def precompute_freq_cis(dim, end, theta):
    """
    Return the (end, dim // 2, 2, 2) float32 table of RoPE rotation matrices.

    Row t, column i holds [[cos(t * w_i), -sin(t * w_i)], [sin(t * w_i), cos(t * w_i)]],
    where w_i = 1 / theta ** (2 * i / dim).
    """
    even_idx = torch.arange(0, dim, 2)[: dim // 2]
    w = 1.0 / (theta ** (even_idx.float() / dim))
    angle = torch.outer(torch.arange(end, device=w.device), w).float()
    c, s = angle.cos(), angle.sin()
    flat = torch.stack([c, -s, s, c], dim=-1)
    return flat.view(angle.shape[0], angle.shape[1], 2, 2)

def reshape_freq_tensor(freq_cis, x, seq_dim):
    """
    View ``freq_cis`` so it broadcasts against the pair-split tensor ``x``.

    Leading axes of ``x`` other than ``seq_dim`` and ``x.ndim - 3`` become 1, and the
    trailing (2, 2) matrix axes are appended. No shape validation is performed.
    """
    keep = {seq_dim, x.ndim - 3}
    lead = [size if axis in keep else 1 for axis, size in enumerate(x.shape[:-2])]
    return freq_cis.view(*lead, 2, 2)


In [29]:
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    View ``freqs_cis`` so it broadcasts against ``x`` in an element-wise product.

    Only the sequence axis (``seq_dim``) and the pair axis (``x.ndim - 3``) keep their
    sizes. Every other leading axis of ``x`` becomes 1, and the trailing (2, 2)
    rotation-matrix axes are appended.

    Args:
        freqs_cis (torch.Tensor): Table of shape (x.shape[seq_dim], x.shape[-3], 2, 2).
        x (torch.Tensor): Tensor the result must broadcast with.
        seq_dim (int): Index of the sequence axis in ``x``.

    Returns:
        torch.Tensor: A view of ``freqs_cis`` with ``x.ndim`` dimensions.

    Raises:
        AssertionError: if ``seq_dim`` is out of range or the shapes disagree.
    """
    rank = x.ndim
    assert 0 <= seq_dim < rank
    assert freqs_cis.shape == (x.shape[seq_dim], x.shape[-3], 2, 2), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    pair_axis = rank - 3
    target = []
    for axis, size in enumerate(x.shape[:-2]):
        target.append(size if axis in (seq_dim, pair_axis) else 1)
    target.extend((2, 2))
    return freqs_cis.view(*target)

In [36]:
def apply_rotary_embedding(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freq_cis: torch.Tensor,
):
    """
    Rotate query/key features with precomputed RoPE rotation matrices.

    The last axis of ``xq``/``xk`` is split into consecutive pairs (x0, x1). The pair
    at position t and pair index i is multiplied by the 2x2 matrix ``freq_cis[t, i]``.
    For tables built by ``precompute_freq_cis`` this gives
    (cos*x0 - sin*x1, sin*x0 + cos*x1).

    Args:
        xq: query tensor whose last axis has even size.
        xk: key tensor; the table is shaped against ``xq``, so ``xk`` must share its
            sequence length and pair count.
        seq_dim: index of the sequence axis in ``xq``/``xk``.
        freq_cis: (seqlen, last_dim // 2, 2, 2) rotation-matrix table.

    Returns:
        (xq_out, xk_out) with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # (..., D) -> (..., D/2, 1, 2): the singleton axis lets each pair broadcast
    # against both rows of its 2x2 matrix.
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    freq_cis = reshape_freq_tensor(
        freq_cis, xq_, seq_dim
    ).float()
    # Summing the last axis is the matrix-vector product R @ (x0, x1). Flattening the
    # trailing (D/2, 2) axes restores the feature axis. Negative axes avoid hard-coding
    # a 4-D input; for 4-D inputs this is identical to .sum(5).flatten(3).
    xq_out = (xq_ * freq_cis).sum(-1).flatten(-2)
    xk_out = (xk_ * freq_cis).sum(-1).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [56]:
frq = precompute_freq_cis(dim=64, end=4096, theta=10000)

In [57]:
# Layout presumably (batch, seq, heads, head_dim): seq_dim=1 and head_dim=64 are used below. TODO confirm.
x = torch.randn((4, 4096, 12, 64))
x.dim()

4

In [58]:
# Split the last axis into (pairs, 1, 2) for the rotation-matrix product.
# NOTE(review): new_y is an exact copy of new_x; presumably it was meant to stand in for keys.
pair_shape = (*x.shape[:-1], -1, 1, 2)
new_x = x.reshape(pair_shape)
new_y = x.reshape(pair_shape)

In [59]:
freq_cis = reshape_for_broadcast(frq, new_x, seq_dim=1)

In [64]:
# Summing the last axis applies each 2x2 matrix to its (x0, x1) pair.
rotated_pairs = (new_x * freq_cis).sum(dim=5)
newx_out = rotated_pairs.flatten(start_dim=3)

In [61]:
# Pair-split view of x: (4, 4096, 12, 64 // 2, 1, 2).
new_x.shape

torch.Size([4, 4096, 12, 32, 1, 2])

In [62]:
# With newx_out = (new_x * freq_cis).sum(5).flatten(3) this should be (4, 4096, 12, 64).
# NOTE(review): the saved output shows (..., 32, 2, 2), the shape before .sum/.flatten.
# It was recorded (In [62]) before the current In [64] ran; re-run to refresh it.
newx_out.shape

torch.Size([4, 4096, 12, 32, 2, 2])

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [36]:
p = torch.randn(size=(1, 10))  # one row of 10 random logits (unseeded)

In [37]:
# Display the raw logits.
p

tensor([[ 2.4172, -0.2354,  0.3734, -1.0634, -2.0294,  0.5437, -1.0728, -0.8370,
          0.1716,  0.6862]])

In [38]:
# NOTE(review): duplicates the previous display cell; can be removed.
p

tensor([[ 2.4172, -0.2354,  0.3734, -1.0634, -2.0294,  0.5437, -1.0728, -0.8370,
          0.1716,  0.6862]])

In [43]:
def entropy(scores):
    """
    Compute the softmax entropy of ``scores`` over the last axis (natural log).

    scores: [bs, seq_len, vocab] logits (any leading shape works)
    returns [bs, seq_len] (the last axis is reduced)

    Logits equal to -inf (masked entries) contribute 0 instead of producing nan,
    matching ``torch.distributions.Categorical.entropy``.
    """
    log_probs = F.log_softmax(scores, dim=-1)
    probs = torch.exp(log_probs)
    # Masked logits give log_probs = -inf and probs = 0, and 0 * -inf = nan.
    # Clamping to the most negative finite value turns those terms into 0 and
    # leaves every finite log-probability unchanged.
    min_real = torch.finfo(log_probs.dtype).min
    p_log_p = probs * log_probs.clamp(min=min_real)
    return -p_log_p.sum(dim=-1)

In [44]:
# Natural-log entropy of softmax(p) over the last axis; one value per row of p.
entropy(p)

tensor([1.5168])

In [45]:
from torch.distributions import Categorical

# Cross-check the hand-written `entropy` against torch's reference implementation.
# Bind the result to a new name: assigning to `entropy` would replace the function
# defined above with a tensor and break any later `entropy(...)` call.
entropy_ref = Categorical(logits=p).entropy()
assert torch.allclose(entropy(p), entropy_ref)
print(entropy_ref)

tensor([1.5168])


In [46]:
bs, trunc_seqlen = 4, 384
# Every entry starts at trunc_seqlen, one past the last index 0..trunc_seqlen-1.
# Presumably a "not set" sentinel; TODO confirm.
patchstart = torch.full((bs, trunc_seqlen), trunc_seqlen, dtype=torch.long)

In [47]:
# Display the (bs, trunc_seqlen) tensor filled with trunc_seqlen.
patchstart

tensor([[384, 384, 384,  ..., 384, 384, 384],
        [384, 384, 384,  ..., 384, 384, 384],
        [384, 384, 384,  ..., 384, 384, 384],
        [384, 384, 384,  ..., 384, 384, 384]])