In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def precompute_theta_pos_frequencies(head_dim: int, seq_len:int, device:str, theta: float=10000.0):
    assert head_dim % 2 == 0, "dimension must be divisible by 2"
    theta_numerator = torch.arange(0,head_dim,2).float() # [0,2,4,6, ....]
    theta = 1.0 / (theta ** (theta_numerator /head_dim)).to(device) # 각도 정해주기
    
    m = torch.arange(seq_len, device=device) # [0,1,2,3,4, ....]
    
    freqs = torch.outer(m, theta).float() # 외적 -> m이랑 세타를 외적함
    
    freqs_complex = torch.polar(torch.ones_like(freqs),freqs)
    return freqs_complex

head_dim = 8
seq_len = 10
device = 'cpu'

freqs_complex = precompute_theta_pos_frequencies(head_dim, seq_len, device)
print(freqs_complex)


tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
        [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j],
        [-0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j],
        [-0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j],
        [ 0.2837-0.9589j,  0.8776+0.4794j,  0.9988+0.0500j,  1.0000+0.0050j],
        [ 0.9602-0.2794j,  0.8253+0.5646j,  0.9982+0.0600j,  1.0000+0.0060j],
        [ 0.7539+0.6570j,  0.7648+0.6442j,  0.9976+0.0699j,  1.0000+0.0070j],
        [-0.1455+0.9894j,  0.6967+0.7174j,  0.9968+0.0799j,  1.0000+0.0080j],
        [-0.9111+0.4121j,  0.6216+0.7833j,  0.9960+0.0899j,  1.0000+0.0090j]])


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    assert head_dim % 2 == 0, "dimension must be divisible by 2"
    theta_numerator = torch.arange(0, head_dim, 2).float()  # [0,2,4,6, ...]
    theta_vals = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)  # 각도
    m = torch.arange(seq_len, device=device)  # [0,1,2,...,seq_len-1]
    freqs = torch.outer(m, theta_vals).float()  # m과 theta의 외적
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)  # 복소수 표현 (1, angle)
    return freqs_complex

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # 마지막 차원(head_dim)을 2로 나누어 복소수 텐서로 변환합니다.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    
    # freqs_complex의 차원 맞추기: (1, seq_len, 1, head_dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    
    # 복소수 곱 연산
    x_rotated = x_complex * freqs_complex
    
    # 다시 실수 텐서로 변환 (마지막 차원이 2가 됨)
    x_out = torch.view_as_real(x_rotated)
    
    # 원래의 shape로 복원합니다.
    x_out = x_out.reshape(*x.shape)
    
    return x_out.type_as(x).to(device)

device = "cpu"  # 또는 "cuda" 사용 가능
head_dim = 8  # 반드시 2의 배수여야 합니다.
seq_len = 10
batch_size = 2  # 예시 배치 사이즈

# 임의의 데이터를 생성합니다. 
# shape: (batch_size, seq_len, head_dim)
x = torch.randn(batch_size, seq_len, head_dim)
print("원본 x:")
print(x)

# Rotary embeddings을 위한 각도 주파수 복소수 값 계산
freqs_complex = precompute_theta_pos_frequencies(head_dim, seq_len, device)
print("\n계산된 freqs_complex:")
print(freqs_complex)

# rotary embedding 적용
x_rotated = apply_rotary_embeddings(x, freqs_complex, device)
print("\nrotary embedding 적용 후 x_rotated:")
print(x_rotated)


원본 x:
tensor([[[ 1.0283,  1.3571,  0.7097,  1.4176, -0.7947,  1.0628,  0.1579,
          -1.1376],
         [ 0.2505, -0.5515, -1.4658,  0.2132, -0.8378,  0.0130,  0.0600,
           0.2982],
         [ 1.0666, -0.9124,  1.0255,  0.0243, -0.7689,  0.7426,  1.1425,
           0.9992],
         [ 1.9832,  0.9898,  2.4005,  0.4955,  0.6851,  2.2017,  1.3194,
          -0.4922],
         [-1.3652,  0.1174,  0.0676,  1.2757, -0.7534, -0.8079, -0.3021,
          -0.5468],
         [-0.1576,  0.6182,  1.5327,  1.0034,  1.0967,  1.1009,  3.2416,
           0.2178],
         [ 0.9898,  0.0973,  0.2729,  0.5641, -1.3300,  0.3752,  1.8387,
          -1.3646],
         [ 1.6059,  0.1259, -0.6587,  0.4747, -0.6078,  1.1926, -0.7239,
           0.5323],
         [ 0.1784, -0.6692, -0.7088, -0.8850, -3.2152, -0.6891, -0.8253,
           0.9278],
         [ 1.7650,  0.2146, -0.8469,  0.1316, -1.8610,  0.8300, -0.9973,
           0.2113]],

        [[ 0.9223,  1.8291,  2.8139, -0.3953,  0.2322, -0.3519

RuntimeError: The size of tensor a (2) must match the size of tensor b (10) at non-singleton dimension 1

In [10]:
import torch
import math

# 각 복소수의 크기를 모두 1로 설정
magnitudes = torch.ones(3)
print(magnitudes)

# 각 복소수의 각도를 0, π/2, π 라디안으로 설정
angles = torch.tensor([0.0, math.pi/2, math.pi])
print(angles)

# polar를 사용하여 복소수 텐서를 생성합니다.
complex_tensor = torch.polar(magnitudes, angles)
print(complex_tensor)


tensor([1., 1., 1.])
tensor([0.0000, 1.5708, 3.1416])
tensor([ 1.0000e+00+0.0000e+00j, -4.3711e-08+1.0000e+00j,
        -1.0000e+00-8.7423e-08j])
