In [1]:
import torch
import torch.nn as nn
from config import SimpleDecoderOnlyTransformerConfig
from einops import rearrange

# Build the project config with no arguments (i.e. whatever defaults the class defines).
# The bare `config` on the last line makes the notebook display its repr as the cell output.
config = SimpleDecoderOnlyTransformerConfig()
config

SimpleDecoderOnlyTransformerConfig {
  "device": "cuda",
  "dropout": 0.1,
  "eps": 1e-06,
  "flash_attn": false,
  "hidden_size": 768,
  "intermediate_size": 3072,
  "max_seq_len": 64,
  "model_type": "simple_decoder_only_transformer",
  "n_layers": 12,
  "num_attention_heads": 12,
  "transformers_version": "4.48.1",
  "vocab_size": 999999
}

In [2]:
# Random sample query/key tensors laid out as [batch=64, seq_len, n_heads, head_dim],
# where head_dim = hidden_size // num_attention_heads. Each tensor is drawn on the CPU
# and then moved to config.device; q is drawn first, then k.
q, k = (
    torch.rand(
        [
            64,
            config.max_seq_len,
            config.num_attention_heads,
            config.hidden_size // config.num_attention_heads,
        ]
    ).to(config.device)
    for _ in range(2)
)

In [10]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for per-head query/key tensors.

    Adjacent channel pairs ``(x[2i], x[2i+1])`` of every head are treated as one
    complex number and multiplied by ``exp(i * t * theta_i)``, where ``t`` is the
    token position and ``theta_i = 10000 ** (-2i / head_size)``.

    Args:
        config: object exposing ``hidden_size``, ``num_attention_heads``,
            ``max_seq_len`` and ``device``.

    Buffers:
        pos_cis: complex tensor of shape ``[1, max_seq_len, 1, head_size // 2]``
            holding the unit-modulus rotation factors for every position.

    NOTE(review): ``pos_cis`` is a complex buffer; converting the whole module to a
    real dtype (e.g. ``module.half()``) may affect it -- confirm before doing so.
    """

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            "hidden_size must be divisible by num_attention_heads"
        )
        head_size = config.hidden_size // config.num_attention_heads
        # Channels are rotated in pairs *within a head*, so it is the head size
        # (not merely hidden_size) that has to be even.
        assert head_size % 2 == 0, "head size must be even to form (real, imag) pairs"

        # One inverse frequency per channel pair: 10000 ** (-2i / head_size).
        inv_freq = 1 / (10000 ** (torch.arange(0, head_size, 2) / head_size))
        positions = torch.arange(0, config.max_seq_len).float()
        # angles[t, i] = t * inv_freq[i]  ->  [max_seq_len, head_size // 2]
        angles = torch.outer(positions, inv_freq)
        # Unit-modulus complex numbers exp(i * angle).
        pos_cis = torch.polar(torch.ones_like(angles), angles).to(config.device)
        # Insert broadcast axes for batch and heads:
        # [max_seq_len, head_size // 2] -> [1, max_seq_len, 1, head_size // 2]
        pos_cis = pos_cis[None, :, None, :]
        self.register_buffer('pos_cis', pos_cis)

    @staticmethod
    def _rotate(x, pos_cis):
        """Rotate ``x`` ([b, l, h, k]) by ``pos_cis`` ([1, l, 1, k // 2]); keeps x's dtype."""
        # [b l h k] -> [b l h k//2 2] -> complex [b l h k//2]
        x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2).float())
        # complex [b l h k//2] -> [b l h k//2 2] -> [b l h k]
        rotated = torch.view_as_real(x_complex * pos_cis).flatten(3)
        # The math runs in float32; hand back the caller's dtype.
        return rotated.type_as(x)

    def forward(self, q, k):
        """Apply the rotation to ``q`` and ``k``.

        Args:
            q: tensor of shape ``[batch, seq_len, n_heads, head_size]``.
            k: tensor with the same sequence length as ``q`` (positions are taken
                from ``q.shape[1]`` for both).

        Returns:
            Tuple ``(q_out, k_out)`` with the same shapes and dtypes as the inputs.

        Raises:
            ValueError: if the sequence length exceeds the precomputed ``max_seq_len``.
        """
        seq_len = q.shape[1]
        max_seq_len = self.pos_cis.shape[1]
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        # The buffer is [1, max_seq_len, 1, d]: the sequence axis is dim 1, so
        # slice there (slicing dim 0 would be a no-op on the size-1 batch axis).
        pos_cis = self.pos_cis[:, :seq_len]
        return self._rotate(q, pos_cis), self._rotate(k, pos_cis)

# Build the rotary-embedding module from the notebook config.
rope = RotaryEmbedding(config)
# NOTE(review): this rebinds q and k to the rotated outputs, so re-running this cell
# rotates the already-rotated tensors again. Re-run the cell that creates q/k first
# if a fresh result is needed.
q, k = rope(q, k)
# Trailing expression: display both output shapes as the cell result.
q.shape, k.shape

(torch.Size([64, 64, 12, 64]), torch.Size([64, 64, 12, 64]))