In [2]:
import torch

# CPU device handle; every experiment in this notebook runs on CPU.
cpu = torch.device("cpu")


# Model Components

### RMS Normalization

### Feed Forward Network Layer

### Rotary Positional Encoding (RoPE)

### Grouped Query Attention Layer

### Self Attention Layer

### Transformer Block

In [41]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Precompute the RoPE rotation factors e^{i * m * theta_k} as complex numbers.

    Follows section 3.2.2 of the RoFormer paper, which requires an even dimension:
    theta_k = theta^(-2k / head_dim) for k = 0 .. head_dim/2 - 1, and
    m = 0 .. seq_len - 1 is the token position.

    Args:
        head_dim: Dimension of one attention head. Must be even, because
            consecutive pairs of features are rotated together.
        seq_len: Number of positions to precompute.
        device: Device on which the returned tensor is created (e.g. "cpu").
        theta: Base of the geometric frequency progression.

    Returns:
        Complex64 tensor of shape (seq_len, head_dim / 2) with unit magnitude
        and angle m * theta_k at index [m, k].

    Raises:
        AssertionError: if head_dim is odd.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    # Even feature indices 0, 2, ..., head_dim - 2 give the exponents 2k / head_dim.
    # Shape: (head_dim / 2)
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    # Named inv_freq so the `theta` base argument is not shadowed.
    inv_freq = 1.0 / (theta ** exponents)
    # Token positions m. Shape: (seq_len)
    positions = torch.arange(seq_len, device=device).float()
    # Angle m * theta_k for every (position, frequency) pair.
    # Shape: (seq_len) outer (head_dim / 2) -> (seq_len, head_dim / 2)
    angles = torch.outer(positions, inv_freq)
    # Polar form with radius 1: 1 * e^{i * angle} = cos(angle) + i * sin(angle).
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate the features of `x` by precomputed RoPE factors.

    Consecutive feature pairs (x[..., 2k], x[..., 2k+1]) are treated as the real
    and imaginary parts of one complex number and multiplied by the matching
    entry of `freqs_complex`. When those entries have unit magnitude (as produced
    by precompute_theta_pos_frequencies), this is a pure rotation of each pair.

    Args:
        x: Real tensor of shape (B, Seq_Len, H, Head_Dim); Head_Dim must be even.
        freqs_complex: Complex tensor of shape (Seq_Len, Head_Dim / 2) whose
            Seq_Len matches the second axis of `x`.
        device: Device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as `x`, on `device`.
    """
    # The arithmetic is done in float32 (x.float()) and cast back at the end.
    # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2, 2) -> complex (B, Seq_Len, H, Head_Dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Add the batch and head axes so the factors broadcast over them.
    # (Seq_Len, Head_Dim/2) -> (1, Seq_Len, 1, Head_Dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # Complex multiplication rotates each pair (Figure 1 of the paper).
    # (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2)
    x_rotated = x_complex * freqs_complex
    # Back to real (re, im) pairs, then flatten the pairs into the original layout.
    # (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
    x_out = torch.view_as_real(x_rotated).reshape(*x.shape)
    return x_out.type_as(x).to(device)


In [None]:
# Smoke test: apply RoPE to a random (batch=1, seq_len=128, n_heads=128, head_dim=128) tensor.
torch.manual_seed(0)
freq = precompute_theta_pos_frequencies(128, 128, "cpu")
x = torch.randn(1, 128, 128, 128)
r = apply_rotary_embeddings(x, freq, "cpu")
assert r.shape == x.shape
# Unit-magnitude factors rotate each pair, so every head vector keeps its length.
assert torch.allclose(r.norm(dim=-1), x.norm(dim=-1), atol=1e-4)
freq.shape, r.shape

In [None]:
# Inspect the two operands that apply_rotary_embeddings multiplies together.
x = torch.randn(1, 128, 128, 128)
c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
# Use a new name rather than re-binding `freq`, so re-running this cell does
# not keep adding dimensions to it.
freq_broadcast = freq.unsqueeze(0).unsqueeze(2)
c.shape, freq_broadcast.shape

In [None]:
# Scratch check for an odd dimension: the even indices below 127 number 64,
# but slicing to 127 // 2 keeps only the first 63 of them.
torch.arange(start=0, end=127, step=2)[: 127 // 2].shape

In [None]:
# Positions 0..127 as a float column vector (128, 1), ready to broadcast
# against a row of per-dimension frequencies.
m = torch.arange(128).float().view(-1, 1)
m.shape

In [81]:
def pre_compute_rotation_matrix(head_dim: int, seq_len: int, rope_theta: float=10000):
    """Precompute RoPE rotation factors e^{i * m * freq_k} as complex numbers.

    Args:
        head_dim: Dimension of each attention head; must be even.
        seq_len: Number of positions m = 0 .. seq_len - 1.
        rope_theta: Base of the frequency progression
            freq_k = rope_theta^(-2k / head_dim), k = 0 .. head_dim/2 - 1.

    Returns:
        Complex64 tensor of shape (seq_len, head_dim / 2) with unit magnitude
        and angle m * freq_k at index [m, k].

    Raises:
        AssertionError: if head_dim is odd.
    """
    assert head_dim % 2 == 0, "Dimension must be even"
    # Shape: (head_dim / 2)
    freqs = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Shape: (seq_len)
    m = torch.arange(seq_len).float()
    # Angle for every (position, frequency) pair. Shape: (seq_len, head_dim / 2)
    angles = torch.outer(m, freqs)
    # Radius-1 polar form: cos(angle) + i * sin(angle).
    complex_freqs = torch.polar(torch.ones_like(angles), angles)
    return complex_freqs

In [None]:
# Each adjacent pair along the last axis becomes one complex number, so the
# last dimension halves: (1, 128, 128, 128) -> (1, 128, 128, 64).
d = torch.randn(1, 128, 128, 128)

xq_ = torch.view_as_complex(
    d.float().reshape(*d.shape[:-1], d.shape[-1] // 2, 2)
)
print(xq_.shape)


In [1]:
import tiktoken