### Simple Implementation of Rotary Positional Embedding

In [49]:
import torch

In [50]:
def init_sinusoidal_positional_embedding(seq_len: int,
                                         hidden_size: int,
                                         base: int=10000,
                                         ) -> torch.Tensor:
    """Build a sinusoidal positional-embedding table.

    For position ``m`` and frequency index ``i`` the table stores
    ``sin(m * theta_i)`` at column ``2 * i`` and ``cos(m * theta_i)`` at
    column ``2 * i + 1``, where ``theta_i = base ** (-2 * i / hidden_size)``.

    Args:
        seq_len: Number of positions; positions run from 0 to ``seq_len - 1``.
        hidden_size: Width of each embedding vector. Odd widths are
            supported: the last column then holds a sine with no matching
            cosine.
        base: Base of the geometric frequency progression.

    Returns:
        Tensor of shape ``[1, seq_len, hidden_size]``.
    """
    position = torch.arange(seq_len).unsqueeze(1)  # Shape: [seq_len, 1]
    # One frequency per even column index -> ceil(hidden_size / 2) entries.
    div_term = torch.exp(torch.arange(0, hidden_size, 2) * -(torch.log(torch.tensor(base)) / hidden_size))
    angles = position * div_term  # Shape: [seq_len, ceil(hidden_size / 2)]

    # Apply sin to even indices, cos to odd indices
    pos_embedding = torch.zeros(seq_len, hidden_size)
    pos_embedding[:, 0::2] = torch.sin(angles)  # even indices
    # There are only hidden_size // 2 odd columns, so drop the surplus
    # frequency when hidden_size is odd (no-op for even sizes).
    pos_embedding[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])  # odd indices
    return pos_embedding.unsqueeze(0)  # Shape: [1,seq_len, hidden_size]

In [51]:
def rotary_embed(
    input: torch.Tensor,
    base: int = 10000,
) -> torch.Tensor:
    """Apply rotary positional embedding (RoPE) to a 4-D tensor.

    Each adjacent pair of features ``(x[2i], x[2i+1])`` at position ``m`` is
    rotated by the angle ``m * theta_i`` taken from the sinusoidal table:

        out[2i]   = x[2i]   * cos(m * theta_i) - x[2i+1] * sin(m * theta_i)
        out[2i+1] = x[2i+1] * cos(m * theta_i) + x[2i]   * sin(m * theta_i)

    Positions always start at 0 along dimension 1.

    Args:
        input: Tensor of shape ``[batch_size, seq_len, num_heads, hidden_size]``
            with an even ``hidden_size``.
        base: Frequency base forwarded to
            ``init_sinusoidal_positional_embedding``.

    Returns:
        Tensor with the same shape as ``input``. The sin/cos table is created
        with torch's default floating dtype, so the result dtype follows the
        usual type promotion between ``input`` and that table.

    Raises:
        ValueError: If ``input`` is not 4-D or ``hidden_size`` is odd.
    """
    if input.dim() != 4:
        raise ValueError(
            "rotary_embed expects a 4-D tensor "
            "[batch_size, seq_len, num_heads, hidden_size], "
            f"got {input.dim()} dimension(s)"
        )
    _, seq_len, _, hidden_size = input.shape
    if hidden_size % 2 != 0:
        # Features are rotated in pairs, so an unpaired last feature is invalid.
        raise ValueError(f"hidden_size must be even, got {hidden_size}")

    sinusoidal_pos = init_sinusoidal_positional_embedding(seq_len, hidden_size, base)  # [1, seq_len, hidden_size]
    # The table is built on the default device; move it next to the input so
    # inputs on an accelerator do not hit a device mismatch.
    sinusoidal_pos = sinusoidal_pos.to(input.device)

    # Duplicate each frequency so both members of a pair share one angle, then
    # add a head axis for broadcasting: [1, seq_len, 1, hidden_size].
    cos_pos = torch.repeat_interleave(sinusoidal_pos[..., 1::2], repeats=2, dim=-1).unsqueeze(2)  # cos(m theta)
    sin_pos = torch.repeat_interleave(sinusoidal_pos[..., ::2], repeats=2, dim=-1).unsqueeze(2)  # sin(m theta)

    # Pairwise 90-degree rotation of the features: [-x2, x1, -x4, x3, ...]
    input_rerange = torch.stack((-input[..., 1::2], input[..., ::2]), dim=-1)
    input_rerange = input_rerange.reshape(input.shape)  # [b, s, n, h]

    output = input * cos_pos + input_rerange * sin_pos

    return output  # [batch_size, seq_len, n_head, hidden_size]
    
 

In [52]:
# Demo: queries and keys may have different sequence lengths; each is rotated
# independently using its own positions starting at 0.
q = torch.randn(4, 256, 8, 512)  # [batch, query_len, n_head, hidden]
k = torch.randn(4, 128, 8, 512)  # [batch, key_len, n_head, hidden]
q_rot, k_rot = rotary_embed(q), rotary_embed(k)

# Contract over the hidden axis for every (query position, key position) pair.
attention = torch.einsum(
    "bjnh,bknh->bnjk",
    q_rot,
    k_rot,
)
print(attention.shape)  # [batch, n_head, query_len, key_len]

torch.Size([4, 8, 256, 128])
