In [None]:
import torch
import torch.nn as nn

# Stand-in for a LLaMA model configuration
class DummyConfig:
    """Minimal configuration object carrying the fields this demo reads."""

    # No RoPE scaling variant is configured.
    rope_scaling = None
    # Maximum number of positions declared by the config.
    max_position_embeddings = 128
    # Width of the embedding; the RoPE init function derives its dimension from it.
    hidden_size = 16

# Simplified RoPE initialization function
def simple_rope_init_fn(config, device=None):
    """Compute the inverse frequencies for rotary position embeddings.

    Args:
        config: Object exposing ``hidden_size``. An optional ``rope_theta``
            attribute overrides the frequency base; when it is missing or
            ``None`` the base is 10000.
        device: Device on which the frequency tensor is created
            (``None`` uses the PyTorch default).

    Returns:
        A tuple ``(inv_freq, attention_scaling)`` where ``inv_freq`` is a
        float32 tensor of length ``hidden_size // 2`` holding
        ``base ** (-i / dim)`` for ``i = 0 .. dim - 1``, and
        ``attention_scaling`` is always ``1.0``.
    """
    base = getattr(config, "rope_theta", None)
    if base is None:
        base = 10000.0

    # One frequency per rotated pair of channels.
    dim = config.hidden_size // 2
    exponents = torch.arange(dim, dtype=torch.float32, device=device) / dim
    inv_freq = 1.0 / (base ** exponents)
    return inv_freq, 1.0

# Registry mapping a RoPE type name to the function that builds its frequencies.
ROPE_INIT_FUNCTIONS = {
    "default": simple_rope_init_fn,
}


class LlamaRotaryEmbedding(nn.Module):
    """Produce the cos/sin tables used by rotary position embeddings (RoPE).

    The inverse frequencies are computed once at construction time by the
    ``"default"`` entry of ``ROPE_INIT_FUNCTIONS`` and stored as a
    non-persistent buffer, so they follow the module across devices but are
    not written to the state dict.
    """

    def __init__(self, config, device=None):
        """Build the inverse-frequency buffer.

        Args:
            config: Configuration object forwarded to the RoPE init function.
            device: Device on which the inverse frequencies are created.
        """
        super().__init__()
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS["default"]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """Return the cos and sin tables for the given positions.

        Args:
            x: Tensor used only for its dtype; the outputs are cast to it.
            position_ids: 2-D tensor of positions, indexed as (batch, seq_len).

        Returns:
            A tuple ``(cos, sin)``, each of shape
            ``(batch, seq_len, 2 * len(inv_freq))`` and dtype ``x.dtype``,
            multiplied by ``self.attention_scaling``.
        """
        # Compute the angles in float32 on the device of the positions, even if
        # the module was cast to a lower precision or lives on another device.
        inv_freq = self.inv_freq.to(device=position_ids.device, dtype=torch.float32)
        inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].to(torch.float32)

        # (batch, dim, 1) @ (batch, 1, seq) -> (batch, dim, seq) -> (batch, seq, dim)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        # Duplicate the angles so the table spans both halves of the last axis.
        emb = torch.cat((freqs, freqs), dim=-1)

        # Apply the scaling factor returned by the init function (1.0 by default).
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# ------------------------
# Smoke test
# ------------------------
config = DummyConfig()
rope = LlamaRotaryEmbedding(config)

batch, seq_len, hidden = 2, 10, config.hidden_size
x = torch.randn(batch, seq_len, hidden)  # dummy input tensor
position_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)  # positions [0, 1, 2, ...] for every sequence

cos, sin = rope(x, position_ids)

print("cos shape:", cos.shape)  # [batch, seq_len, hidden]
print("sin shape:", sin.shape)
# Show the cos table for positions 1..3 of the first sequence; the label
# matches the slice that is actually printed.
print("cos[0,1:4]:", cos[0, 1:4])


cos shape: torch.Size([2, 10, 16])
sin shape: torch.Size([2, 10, 16])
cos[0,0]: tensor([[ 0.5403,  0.9504,  0.9950,  0.9995,  0.9999,  1.0000,  1.0000,  1.0000,
          0.5403,  0.9504,  0.9950,  0.9995,  0.9999,  1.0000,  1.0000,  1.0000],
        [-0.4161,  0.8066,  0.9801,  0.9980,  0.9998,  1.0000,  1.0000,  1.0000,
         -0.4161,  0.8066,  0.9801,  0.9980,  0.9998,  1.0000,  1.0000,  1.0000],
        [-0.9900,  0.5828,  0.9553,  0.9955,  0.9996,  1.0000,  1.0000,  1.0000,
         -0.9900,  0.5828,  0.9553,  0.9955,  0.9996,  1.0000,  1.0000,  1.0000]])
