In [90]:
import tensorflow as tf
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
import torch
from rotary_embedding_torch import RotaryEmbedding



In [91]:

X1 = tf.random.uniform((100,100))
X1 = tf.round(X1*10000)/10000

np_array = X1.numpy()

X1_torch = torch.from_numpy(np_array).unsqueeze(0).unsqueeze(0)



# Orcanzoi RoPE Implementation



In [None]:



def get_rope_rotation_matrix_list(d, m_vec, base=10000.0, dtype=tf.float32):
    js = np.arange(0, d, 2)  
    theta = 1.0 / (base ** (js / d))  
    rotation_matrices = []
    for m in m_vec:
        m_theta = m * theta
        cos_m_theta = tf.constant(np.cos(m_theta), dtype=dtype)
        sin_m_theta = tf.constant(np.sin(m_theta), dtype=dtype)
        I = tf.eye(d, dtype=dtype)
        #R = tf.identity(I)
        for idx, j in enumerate(js):
            e_j = tf.eye(d, dtype=dtype)[:, j:j+1]
            e_j1 = tf.eye(d, dtype=dtype)[:, j+1:j+2]
            ej_ejT = tf.matmul(e_j, e_j, transpose_b=True)
            ej1_ej1T = tf.matmul(e_j1, e_j1, transpose_b=True)
            ej_ej1T = tf.matmul(e_j, e_j1, transpose_b=True)
            ej1_ejT = tf.matmul(e_j1, e_j, transpose_b=True)
            Rj = (
                I
                + (cos_m_theta[idx] - 1) * (ej_ejT + ej1_ej1T)
                + sin_m_theta[idx] * (ej_ej1T - ej1_ejT)
            )
            #R = tf.matmul(R, Rj)
        rotation_matrices.append(Rj)
    return rotation_matrices




def apply_rope_per_row(X, base=10000.0):

    m, d = X.shape
    m_vec = np.arange(m)
    rotations = get_rope_rotation_matrix_list(d, m_vec, base=base, dtype=X.dtype)  
    rotations = tf.stack(rotations)  
    X_exp = tf.expand_dims(X, axis=1)  
    X_rotated = tf.matmul(X_exp, rotations)  
    X_rotated = tf.squeeze(X_rotated, axis=1)  
    return X_rotated




In [93]:
%%time
X_rotated = apply_rope_per_row(X1)
print(X_rotated)
X_rotated[70][99]


tf.Tensor(
[[0.9627     0.4698     0.0137     ... 0.0895     0.7386     0.1422    ]
 [0.9864     0.9629     0.3087     ... 0.3442     0.7981824  0.14639597]
 [0.318      0.67       0.7297     ... 0.0364     0.1957812  0.07824708]
 ...
 [0.5777     0.2154     0.3627     ... 0.1299     0.57167816 0.5883067 ]
 [0.7184     0.809      0.8707     ... 0.6045     0.23151873 0.16103892]
 [0.9035     0.9099     0.5241     ... 0.7485     0.8659355  0.88436913]], shape=(100, 100), dtype=float32)
CPU times: total: 8.7 s
Wall time: 8.71 s


<tf.Tensor: shape=(), dtype=float32, numpy=0.5905372500419617>

In [94]:
%%time
d = 100
rotary_emb = RotaryEmbedding(dim = d)
q = rotary_emb.rotate_queries_or_keys(X1_torch)
print(q)
q[0][0][70][99]


tensor([[[[ 0.9627,  0.4698,  0.0137,  ...,  0.0895,  0.7386,  0.1422],
          [-0.2773,  1.3503, -0.2137,  ...,  0.3443,  0.7982,  0.1464],
          [-0.7416,  0.0103, -0.5426,  ...,  0.0367,  0.1958,  0.0782],
          ...,
          [-0.6162,  0.0200,  0.3972,  ...,  0.1433,  0.5717,  0.5883],
          [-0.1247, -1.0747,  0.9443,  ...,  0.6064,  0.2315,  0.1610],
          [ 0.9452, -0.8666,  0.1652,  ...,  0.7582,  0.8659,  0.8844]]]])
CPU times: total: 125 ms
Wall time: 8.88 ms


tensor(0.5905)

In [95]:
# import tensorflow as tf
# import sonnet as snt

# class RoPEMultiheadAttention(snt.Module):
#     def __init__(self, num_heads, key_size, value_size, rope_dim=128, name=None):
#         super().__init__(name=name)
#         self.num_heads = num_heads
#         self.key_size = key_size
#         self.value_size = value_size
#         self.rope_dim = rope_dim
#         self.q_proj = snt.Linear(num_heads * key_size, name='q_proj')
#         self.k_proj = snt.Linear(num_heads * key_size, name='k_proj')
#         self.v_proj = snt.Linear(num_heads * value_size, name='v_proj')
#         self.out_proj = snt.Linear(num_heads * value_size, name='out_proj')

#     def rope_partial(self, x):
#         # x: [batch, seq_len, num_heads, head_dim]
#         shape = tf.shape(x)
#         batch, seq_len, num_heads, head_dim = shape[0], shape[1], shape[2], shape[3]
#         rope_dim = self.rope_dim
#         rot_part = x[..., :rope_dim]  # (..., rope_dim)
#         pass_part = x[..., rope_dim:] # (..., head_dim - rope_dim)
#         inv_freq = 1.0 / (10000 ** (tf.cast(tf.range(0, rope_dim, 2), tf.float32) / tf.cast(rope_dim, tf.float32)))
#         pos = tf.cast(tf.range(seq_len), tf.float32)
#         freqs = tf.einsum('i,j->ij', pos, inv_freq)  # (seq_len, rope_dim//2)
#         cos = tf.cos(freqs)  # (seq_len, rope_dim//2)
#         sin = tf.sin(freqs)  # (seq_len, rope_dim//2)
#         cos = tf.reshape(cos, [1, seq_len, 1, -1])  # (1, seq_len, 1, rope_dim//2)
#         sin = tf.reshape(sin, [1, seq_len, 1, -1])  # (1, seq_len, 1, rope_dim//2)
#         rot_part_2d = tf.reshape(rot_part, [batch, seq_len, num_heads, -1, 2])  # (..., rope_dim//2, 2)
#         rot_0 = rot_part_2d[..., 0] * cos - rot_part_2d[..., 1] * sin
#         rot_1 = rot_part_2d[..., 0] * sin + rot_part_2d[..., 1] * cos
#         rot_part_rotated = tf.reshape(tf.stack([rot_0, rot_1], axis=-1), [batch, seq_len, num_heads, rope_dim])
#         out = tf.concat([rot_part_rotated, pass_part], axis=-1)
#         return out

#     def __call__(self, x, is_training=False):
#         # x: [batch, seq_len, d_model]
#         batch = tf.shape(x)[0]
#         seq_len = tf.shape(x)[1]
#         d_model = x.shape[-1]
#         Q = self.q_proj(x)  # (batch, seq_len, num_heads * key_size)
#         K = self.k_proj(x)
#         V = self.v_proj(x)
#         Q = tf.reshape(Q, [batch, seq_len, self.num_heads, self.key_size])
#         K = tf.reshape(K, [batch, seq_len, self.num_heads, self.key_size])
#         V = tf.reshape(V, [batch, seq_len, self.num_heads, self.value_size])
#         Q = self.rope_partial(Q)
#         K = self.rope_partial(K)
#         Q = tf.transpose(Q, [0, 2, 1, 3])  # (batch, num_heads, seq_len, key_size)
#         K = tf.transpose(K, [0, 2, 1, 3])
#         V = tf.transpose(V, [0, 2, 1, 3])
#         attn_scores = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(tf.cast(self.key_size, tf.float32))  # (batch, num_heads, seq_len, seq_len)
#         attn_weights = tf.nn.softmax(attn_scores, axis=-1)
#         attn_output = tf.matmul(attn_weights, V)  # (batch, num_heads, seq_len, value_size)
#         attn_output = tf.transpose(attn_output, [0, 2, 1, 3])  # (batch, seq_len, num_heads, value_size)
#         attn_output = tf.reshape(attn_output, [batch, seq_len, self.num_heads * self.value_size])
#         return self.out_proj(attn_output)

In [96]:
X1 = tf.expand_dims(X1,axis=0)
X1 = tf.expand_dims(X1,axis=0)


In [97]:
def rope_partial(x):
    # x: [batch, seq_len, num_heads, head_dim]
    shape = tf.shape(x)
    batch, seq_len, num_heads, head_dim = shape[0], shape[1], shape[2], shape[3]
    rope_dim = 100
    rot_part = x[..., :rope_dim]  # (..., rope_dim)
    pass_part = x[..., rope_dim:] # (..., head_dim - rope_dim)
    inv_freq = 1.0 / (10000 ** (tf.cast(tf.range(0, rope_dim, 2), tf.float32) / tf.cast(rope_dim, tf.float32)))
    pos = tf.cast(tf.range(seq_len), tf.float32)
    freqs = tf.einsum('i,j->ij', pos, inv_freq)  # (seq_len, rope_dim//2)
    cos = tf.cos(freqs)  # (seq_len, rope_dim//2)
    sin = tf.sin(freqs)  # (seq_len, rope_dim//2)
    cos = tf.reshape(cos, [1, seq_len, 1, -1])  # (1, seq_len, 1, rope_dim//2)
    sin = tf.reshape(sin, [1, seq_len, 1, -1])  # (1, seq_len, 1, rope_dim//2)
    rot_part_2d = tf.reshape(rot_part, [batch, seq_len, num_heads, -1, 2])  # (..., rope_dim//2, 2)
    rot_0 = rot_part_2d[..., 0] * cos - rot_part_2d[..., 1] * sin
    rot_1 = rot_part_2d[..., 0] * sin + rot_part_2d[..., 1] * cos
    rot_part_rotated = tf.reshape(tf.stack([rot_0, rot_1], axis=-1), [batch, seq_len, num_heads, rope_dim])
    out = tf.concat([rot_part_rotated, pass_part], axis=-1)
    return out

In [98]:
c = rope_partial(X1)
print(c)
c[0][0][70][99]

tf.Tensor(
[[[[0.9627 0.4698 0.0137 ... 0.0895 0.7386 0.1422]
   [0.9864 0.9629 0.3087 ... 0.3442 0.7982 0.1463]
   [0.318  0.67   0.7297 ... 0.0364 0.1958 0.0782]
   ...
   [0.5777 0.2154 0.3627 ... 0.1299 0.5785 0.5816]
   [0.7184 0.809  0.8707 ... 0.6045 0.2334 0.1583]
   [0.9035 0.9099 0.5241 ... 0.7485 0.8764 0.874 ]]]], shape=(1, 1, 100, 100), dtype=float32)


<tf.Tensor: shape=(), dtype=float32, numpy=0.5853999853134155>