In [75]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    
    Args:
        dim: Dimension of the embedding.
        end: Maximum sequence length.
        theta: Base value for the frequency calculation.
    
    Returns:
        Complex tensor with shape [end, dim // 2] for efficient computation.
    """
    # Ensure dim is even
    if dim % 2 != 0:
        raise ValueError(f"Dimension {dim} must be even")
    
    # Create frequencies for each dimension
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    # Create position indices
    t = torch.arange(end, device=freqs.device)
    
    # Outer product of position indices and frequencies
    freqs = torch.outer(t, freqs)
    
    # Compute complex exponentials: cos(x) + i*sin(x)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.
    
    Args:
        xq: Query states tensor of shape [batch_size, seq_len, n_heads, head_dim]
        xk: Key states tensor of shape [batch_size, seq_len, n_heads, head_dim]
        freqs_cis: Complex tensor of shape [seq_len, head_dim/2]
        
    Returns:
        Tuple of (xq_out, xk_out) with the same shape as the input tensors.
    """
    # Extract shapes
    batch, seq_len, n_heads, head_dim = xq.shape
    
    # Ensure head_dim is even
    if head_dim % 2 != 0:
        raise ValueError(f"Head dimension {head_dim} must be even")
    
    # Reshape inputs to complex-valued tensors
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # Extend frequency tensor to match the batch and heads dimensions
    freqs_cis = freqs_cis[:seq_len]
    
    # Apply rotation using complex multiplication
    xq_out = torch.view_as_real(xq_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).flatten(-2)
    
    # Return the rotated tensors with original dtype
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(nn.Module):
    """
    Rotary positional embedding implementation as a PyTorch module.
    """
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(self.dim, self.max_seq_len, self.base)
        )
        
    def forward(self, q, k):
        """
        Apply rotary embeddings to query and key tensors.
        
        Args:
            q: Query tensor
            k: Key tensor
            
        Returns:
            Tuple of (q, k) with rotary embeddings applied
        """
        return apply_rotary_emb(q, k, self.freqs_cis)


# Example usage in a self-attention layer
class SelfAttentionWithRoPE(nn.Module):
    def __init__(self, hidden_size, num_heads, max_seq_len=2048):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        # Query, Key, Value projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        
        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len)
        
    def forward(self, hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:
        batch_size, seq_length = hidden_states.shape[:2]
        
        # Project to query, key, value
        q:torch.Tensor = self.q_proj(hidden_states)
        k:torch.Tensor = self.k_proj(hidden_states)
        v:torch.Tensor = self.v_proj(hidden_states)
        
        # Reshape for multi-head attention
        q = q.view(batch_size, seq_length, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_length, self.num_heads, self.head_dim)
        v = v.view(batch_size, seq_length, self.num_heads, self.head_dim)
        
        print(q.shape)
        
        # Apply rotary embeddings
        q, k = self.rotary_emb(q, k)
        
        print(q.shape)
        
        # Transpose for attention calculation
        q = q.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # Calculate attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        
        # Apply attention mask if provided
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        
        # Softmax and dropout
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        # Get weighted sum
        attn_output = torch.matmul(attn_weights, v)
        
        # Transpose and reshape
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_length, self.hidden_size
        )
        
        # Output projection
        attn_output = self.out_proj(attn_output)
        
        return attn_output


# Example usage
if __name__ == "__main__":
    # Model parameters
    batch_size = 2
    seq_length = 10
    hidden_size = 512
    num_heads = 8
    
    # Create random input
    hidden_states = torch.rand(batch_size, seq_length, hidden_size)
    
    # Initialize the self-attention layer with RoPE
    self_attn = SelfAttentionWithRoPE(hidden_size, num_heads)
    
    # Forward pass
    output = self_attn(hidden_states)
    print(f"Input shape: {hidden_states.shape}")
    print(f"Output shape: {output.shape}")
    
    # We can verify the output shape matches the input shape
    assert output.shape == hidden_states.shape

torch.Size([2, 10, 8, 64])
torch.Size([2, 10, 8, 64])
Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])


In [None]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    
    Args:
        dim: Dimension of the embedding.
        end: Maximum sequence length.
        theta: Base value for the frequency calculation.
    
    Returns:
        Complex tensor with shape [end, dim // 2] for efficient computation.
    """
    # Ensure dim is even
    if dim % 2 != 0:
        raise ValueError(f"Dimension {dim} must be even")
    
    # Create frequencies for each dimension
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    # Create position indices
    t = torch.arange(end, device=freqs.device)
    
    # Outer product of position indices and frequencies
    freqs = torch.outer(t, freqs)
    
    # Compute complex exponentials: cos(x) + i*sin(x)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    
    return freqs_cis

In [None]:
import torch

batch_size = 2
seq_length = 10
hidden_size = 512

num_heads = 8

# Create random input
hidden_states = torch.rand(batch_size, seq_length, hidden_size)

In [None]:
import torch

magnitude = torch.tensor([1.0])  # abs = 1
angle = torch.tensor([torch.pi / 4])  # angle = π/4

complex_number = torch.polar(magnitude, angle)
print(complex_number)  # Output: tensor([0.7071+0.7071j])


In [None]:
import torch

# Define frequency matrix
freqs = torch.tensor([[0.0, torch.pi / 2], [torch.pi, 3*torch.pi / 2]])  # Angles
magnitudes = torch.ones_like(freqs)  # Unit circle (magnitude = 1)

# Compute complex exponentials
freqs_cis = torch.polar(magnitudes, freqs)

print(freqs_cis)


In [None]:
def custom_polar(magnitude, angle):
    real = magnitude * torch.cos(angle)
    imag = magnitude * torch.sin(angle)
    
    return torch.stack([real, imag], dim=-1)

magnitude = torch.tensor([1.0])  # abs = 1
angle = torch.tensor([torch.pi / 4])  # angle = π/4

complex_number = custom_polar(magnitude, angle)


print(complex_number)  # Output: tensor([0.7071, 0.7071])

print(torch.view_as_complex(complex_number))

In [11]:
dim = 512
end = 10

torch.arange(0, dim, 2)[: (dim // 2)].shape

torch.Size([256])

In [12]:
torch.arange(0, dim, 2) == torch.arange(0, dim, 2)[: (dim // 2)]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr

In [13]:
torch.arange(0, dim, 2)[: (dim // 2)].float()/ dim

tensor([0.0000, 0.0039, 0.0078, 0.0117, 0.0156, 0.0195, 0.0234, 0.0273, 0.0312,
        0.0352, 0.0391, 0.0430, 0.0469, 0.0508, 0.0547, 0.0586, 0.0625, 0.0664,
        0.0703, 0.0742, 0.0781, 0.0820, 0.0859, 0.0898, 0.0938, 0.0977, 0.1016,
        0.1055, 0.1094, 0.1133, 0.1172, 0.1211, 0.1250, 0.1289, 0.1328, 0.1367,
        0.1406, 0.1445, 0.1484, 0.1523, 0.1562, 0.1602, 0.1641, 0.1680, 0.1719,
        0.1758, 0.1797, 0.1836, 0.1875, 0.1914, 0.1953, 0.1992, 0.2031, 0.2070,
        0.2109, 0.2148, 0.2188, 0.2227, 0.2266, 0.2305, 0.2344, 0.2383, 0.2422,
        0.2461, 0.2500, 0.2539, 0.2578, 0.2617, 0.2656, 0.2695, 0.2734, 0.2773,
        0.2812, 0.2852, 0.2891, 0.2930, 0.2969, 0.3008, 0.3047, 0.3086, 0.3125,
        0.3164, 0.3203, 0.3242, 0.3281, 0.3320, 0.3359, 0.3398, 0.3438, 0.3477,
        0.3516, 0.3555, 0.3594, 0.3633, 0.3672, 0.3711, 0.3750, 0.3789, 0.3828,
        0.3867, 0.3906, 0.3945, 0.3984, 0.4023, 0.4062, 0.4102, 0.4141, 0.4180,
        0.4219, 0.4258, 0.4297, 0.4336, 

In [14]:
dim = 512
theta = 10000.0

if dim % 2 != 0:
    raise ValueError(f"Dimension {dim} must be even")

# Create frequencies for each dimension
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

freqs.shape

torch.Size([256])

In [15]:
1.0 / (10000 ** (torch.arange(0, dim, 2)[: (dim // 2)].float()/ dim))

tensor([1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
        8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
        6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
        5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
        4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
        3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
        2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
        2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
        1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
        1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
        1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
        9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7737e-02,
        7.4989e-02, 7.2339e-02, 6.9783e-

In [16]:
freqs

tensor([1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
        8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
        6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
        5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
        4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
        3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
        2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
        2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
        1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
        1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
        1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
        9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7737e-02,
        7.4989e-02, 7.2339e-02, 6.9783e-

In [17]:
seq_length = 10

t = torch.arange(seq_length)

t

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [18]:
torch.outer(t, freqs).shape

torch.Size([10, 256])

In [19]:
v1 = torch.arange(1., 5.)
v2 = torch.arange(1., 4.)
torch.outer(v1, v2)

tensor([[ 1.,  2.,  3.],
        [ 2.,  4.,  6.],
        [ 3.,  6.,  9.],
        [ 4.,  8., 12.]])

In [20]:
torch.outer(v1, v2).T

tensor([[ 1.,  2.,  3.,  4.],
        [ 2.,  4.,  6.,  8.],
        [ 3.,  6.,  9., 12.]])

In [21]:
v1, v2

(tensor([1., 2., 3., 4.]), tensor([1., 2., 3.]))

In [22]:
torch.matmul(v1.unsqueeze(-1), v2.unsqueeze(0))

tensor([[ 1.,  2.,  3.],
        [ 2.,  4.,  6.],
        [ 3.,  6.,  9.],
        [ 4.,  8., 12.]])

In [23]:
v1 , v1.T

  v1 , v1.T


(tensor([1., 2., 3., 4.]), tensor([1., 2., 3., 4.]))

In [24]:
v1.unsqueeze(-1), v2.unsqueeze(0)

(tensor([[1.],
         [2.],
         [3.],
         [4.]]),
 tensor([[1., 2., 3.]]))

In [25]:
# torch.outer is works only on 1D tensors

In [26]:
torch.ones_like(torch.outer(t, freqs))

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

In [27]:
t.unsqueeze(1).shape, freqs.unsqueeze(0).shape

(torch.Size([10, 1]), torch.Size([1, 256]))

In [28]:
freqs.unsqueeze(0)

tensor([[1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
         8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
         6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
         5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
         4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
         3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-01, 2.8387e-01,
         2.7384e-01, 2.6416e-01, 2.5483e-01, 2.4582e-01, 2.3714e-01, 2.2876e-01,
         2.2067e-01, 2.1288e-01, 2.0535e-01, 1.9810e-01, 1.9110e-01, 1.8434e-01,
         1.7783e-01, 1.7154e-01, 1.6548e-01, 1.5963e-01, 1.5399e-01, 1.4855e-01,
         1.4330e-01, 1.3824e-01, 1.3335e-01, 1.2864e-01, 1.2409e-01, 1.1971e-01,
         1.1548e-01, 1.1140e-01, 1.0746e-01, 1.0366e-01, 1.0000e-01, 9.6466e-02,
         9.3057e-02, 8.9769e-02, 8.6596e-02, 8.3536e-02, 8.0584e-02, 7.7737e-02,
         7.4989e-02, 7.2339e

In [None]:
t.unsqueeze(1)

In [None]:
torch.outer(t, freqs)

## Day 3


In [52]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
    
    Args:
        dim: Dimension of the embedding.
        end: Maximum sequence length.
        theta: Base value for the frequency calculation.
    
    Returns:
        Complex tensor with shape [end, dim // 2] for efficient computation.
    """
    # Ensure dim is even
    if dim % 2 != 0:
        raise ValueError(f"Dimension {dim} must be even")
    
    # Create frequencies for each dimension
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    # Create position indices
    t = torch.arange(end, device=freqs.device)
    
    # Outer product of position indices and frequencies
    freqs = torch.outer(t, freqs)
    
    # Compute complex exponentials: cos(x) + i*sin(x)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    
    return freqs_cis

In [53]:
freqs_cis = precompute_freqs_cis(512, 10)
freqs_cis , freqs_cis.shape

(tensor([[ 1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
           ...,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
           1.0000+0.0000e+00j],
         [ 0.5403+8.4147e-01j,  0.5697+8.2186e-01j,  0.5974+8.0196e-01j,
           ...,  1.0000+1.1140e-04j,  1.0000+1.0746e-04j,
           1.0000+1.0366e-04j],
         [-0.4161+9.0930e-01j, -0.3509+9.3641e-01j, -0.2863+9.5814e-01j,
           ...,  1.0000+2.2279e-04j,  1.0000+2.1492e-04j,
           1.0000+2.0733e-04j],
         ...,
         [ 0.7539+6.5699e-01j,  0.8918+4.5239e-01j,  0.9735+2.2877e-01j,
           ...,  1.0000+7.7978e-04j,  1.0000+7.5223e-04j,
           1.0000+7.2564e-04j],
         [-0.1455+9.8936e-01j,  0.1363+9.9067e-01j,  0.3981+9.1736e-01j,
           ...,  1.0000+8.9118e-04j,  1.0000+8.5969e-04j,
           1.0000+8.2931e-04j],
         [-0.9111+4.1212e-01j, -0.7366+6.7637e-01j, -0.4979+8.6724e-01j,
           ...,  1.0000+1.0026e-03j,  1.0000+9.6715e-04j,
           1.0000+9.3297e-04j]]),
 torch

In [29]:
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.
    
    Args:
        xq: Query states tensor of shape [batch_size, seq_len, n_heads, head_dim]
        xk: Key states tensor of shape [batch_size, seq_len, n_heads, head_dim]
        freqs_cis: Complex tensor of shape [seq_len, head_dim/2]
        
    Returns:
        Tuple of (xq_out, xk_out) with the same shape as the input tensors.
    """
    # Extract shapes
    batch, seq_len, n_heads, head_dim = xq.shape
    
    # Ensure head_dim is even
    if head_dim % 2 != 0:
        raise ValueError(f"Head dimension {head_dim} must be even")
    
    # Reshape inputs to complex-valued tensors
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # Extend frequency tensor to match the batch and heads dimensions
    freqs_cis = freqs_cis[:seq_len]
    
    # Apply rotation using complex multiplication
    xq_out = torch.view_as_real(xq_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).flatten(-2)
    xk_out = torch.view_as_real(xk_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).flatten(-2)
    
    # Return the rotated tensors with original dtype
    return xq_out.type_as(xq), xk_out.type_as(xk)


In [77]:
# random states

torch.manual_seed(42)
torch.cuda.manual_seed(42)

q = torch.rand(2, 10, 8, 64)
k = torch.rand(2, 10, 8, 64)

freqs_cis = precompute_freqs_cis(64, 10)

q_ , k_ = apply_rotary_emb(q, k, freqs_cis)

q_.shape, k_.shape

(torch.Size([2, 10, 8, 64]), torch.Size([2, 10, 8, 64]))

In [36]:
q , q_

(tensor([[[[0.8823, 0.9150, 0.3829,  ..., 0.1587, 0.6542, 0.3278],
           [0.6532, 0.3958, 0.9147,  ..., 0.2083, 0.3289, 0.1054],
           [0.9192, 0.4008, 0.9302,  ..., 0.5535, 0.4117, 0.3510],
           ...,
           [0.1525, 0.3970, 0.8703,  ..., 0.1474, 0.6872, 0.9231],
           [0.5070, 0.9549, 0.0740,  ..., 0.2564, 0.1352, 0.9012],
           [0.8918, 0.1182, 0.4613,  ..., 0.4078, 0.5411, 0.0410]],
 
          [[0.6556, 0.1186, 0.1836,  ..., 0.6907, 0.9170, 0.3513],
           [0.3546, 0.7670, 0.2533,  ..., 0.2422, 0.0622, 0.3856],
           [0.6020, 0.0316, 0.9366,  ..., 0.2662, 0.2614, 0.0806],
           ...,
           [0.0620, 0.2249, 0.1381,  ..., 0.5235, 0.8648, 0.6559],
           [0.3225, 0.2944, 0.3762,  ..., 0.3112, 0.9130, 0.5512],
           [0.1261, 0.5031, 0.1117,  ..., 0.3092, 0.0702, 0.1836]],
 
          [[0.7785, 0.4253, 0.7124,  ..., 0.8245, 0.9554, 0.7918],
           [0.2408, 0.0055, 0.6897,  ..., 0.5963, 0.0773, 0.8968],
           [0.6508, 0.59

In [64]:
seq_len , head_dim = 10, 64

In [78]:

x_q_complex = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))

x_q_complex.shape, q.shape

(torch.Size([2, 10, 8, 32]), torch.Size([2, 10, 8, 64]))

In [45]:
q.shape[:-1]

torch.Size([2, 10, 8])

In [44]:
q.float().reshape(*q.shape[:-1], -1, 2).shape

torch.Size([2, 10, 8, 32, 2])

In [47]:
q.float().reshape(*q.shape[:-1], -1, 2)

tensor([[[[[0.8823, 0.9150],
           [0.3829, 0.9593],
           [0.3904, 0.6009],
           ...,
           [0.8913, 0.1447],
           [0.5315, 0.1587],
           [0.6542, 0.3278]],

          [[0.6532, 0.3958],
           [0.9147, 0.2036],
           [0.2018, 0.2018],
           ...,
           [0.4654, 0.1612],
           [0.1568, 0.2083],
           [0.3289, 0.1054]],

          [[0.9192, 0.4008],
           [0.9302, 0.6558],
           [0.0766, 0.8460],
           ...,
           [0.6870, 0.4121],
           [0.3676, 0.5535],
           [0.4117, 0.3510]],

          ...,

          [[0.1525, 0.3970],
           [0.8703, 0.7563],
           [0.1836, 0.0991],
           ...,
           [0.9142, 0.0409],
           [0.8343, 0.1474],
           [0.6872, 0.9231]],

          [[0.5070, 0.9549],
           [0.0740, 0.3090],
           [0.7916, 0.3911],
           ...,
           [0.4870, 0.8903],
           [0.9807, 0.2564],
           [0.1352, 0.9012]],

          [[0.8918, 0.11

In [46]:
torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))

tensor([[[[0.8823+0.9150j, 0.3829+0.9593j, 0.3904+0.6009j,  ...,
           0.8913+0.1447j, 0.5315+0.1587j, 0.6542+0.3278j],
          [0.6532+0.3958j, 0.9147+0.2036j, 0.2018+0.2018j,  ...,
           0.4654+0.1612j, 0.1568+0.2083j, 0.3289+0.1054j],
          [0.9192+0.4008j, 0.9302+0.6558j, 0.0766+0.8460j,  ...,
           0.6870+0.4121j, 0.3676+0.5535j, 0.4117+0.3510j],
          ...,
          [0.1525+0.3970j, 0.8703+0.7563j, 0.1836+0.0991j,  ...,
           0.9142+0.0409j, 0.8343+0.1474j, 0.6872+0.9231j],
          [0.5070+0.9549j, 0.0740+0.3090j, 0.7916+0.3911j,  ...,
           0.4870+0.8903j, 0.9807+0.2564j, 0.1352+0.9012j],
          [0.8918+0.1182j, 0.4613+0.0069j, 0.0907+0.5966j,  ...,
           0.1320+0.2316j, 0.3901+0.4078j, 0.5411+0.0410j]],

         [[0.6556+0.1186j, 0.1836+0.0843j, 0.9357+0.0265j,  ...,
           0.2012+0.0071j, 0.1931+0.6907j, 0.9170+0.3513j],
          [0.3546+0.7670j, 0.2533+0.2636j, 0.8081+0.0643j,  ...,
           0.3633+0.2947j, 0.0479+0.2422j, 

In [81]:
freqs_cis.shape

torch.Size([10, 32])

In [79]:
freqs_cis[:seq_len].shape

torch.Size([10, 32])

In [67]:
freqs_cis.shape

torch.Size([10, 32])

In [68]:
freqs_cis.unsqueeze(0).unsqueeze(2).shape

torch.Size([1, 10, 1, 32])

In [69]:
freqs_cis.unsqueeze(0).shape , freqs_cis.unsqueeze(0).unsqueeze(2).shape

(torch.Size([1, 10, 32]), torch.Size([1, 10, 1, 32]))

In [70]:
(x_q_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).shape

torch.Size([2, 10, 8, 32])

In [73]:
x_q_complex.shape , freqs_cis.unsqueeze(0).unsqueeze(2).shape

(torch.Size([2, 10, 8, 32]), torch.Size([1, 10, 1, 32]))

In [74]:
(x_q_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).flatten(-2).shape

torch.Size([2, 10, 256])

In [84]:
torch.view_as_real(x_q_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).shape

torch.Size([2, 10, 8, 32, 2])

In [86]:
(x_q_complex * freqs_cis.unsqueeze(0).unsqueeze(2)).shape

torch.Size([2, 10, 8, 32])