<a href="https://colab.research.google.com/github/Shiveshrane/Research_paper_implementations/blob/main/Rotary_Embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

### Theory
#### The Class Method.
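The inputs are `dim` (the dimensionality of the embeddings, which must match the `head_dim` of the tensor passed in later) and `seq_len` (the maximum sequence length).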

`theta_i = 10000**(-2*i/dim)` for `i = 0, 1, 2, ..., dim/2 - 1`

Each theta is a non-zero constant.
`i` runs from 0 to dim/2 - 1, so there is one theta per pair of embedding components.

We create a ```Position Tensor``` which contains all positional indices from 0 to seq_len-1.

The angles are the products position * theta. Their sine and cosine are calculated once and stored in a buffer (i.e. they aren't trainable).
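To make the shapes concrete, here is a minimal NumPy sketch of the theta vector and the sine/cosine tables described above. NumPy is an assumption made only so the numbers can be checked without TensorFlow, and the variable names (`angles`, `sin_table`, `cos_table`) are illustrative rather than taken from the notebook.

```python
import numpy as np

dim, seq_len = 8, 10  # illustrative sizes

# theta_i = 10000^(-2i/dim); np.arange(0, dim, 2) already holds the values 2i
theta = 10000.0 ** (-np.arange(0, dim, 2) / dim)            # shape (dim//2,)

# Position indices 0 .. seq_len-1 as a column, one row per position
pos = np.arange(seq_len, dtype=np.float64).reshape(-1, 1)   # shape (seq_len, 1)

# Broadcasting gives angles[m, i] = m * theta_i
angles = pos * theta                                        # shape (seq_len, dim//2)
sin_table = np.sin(angles)
cos_table = np.cos(angles)

print(theta)          # approximately [1, 0.1, 0.01, 0.001] for dim = 8
print(angles.shape)   # (10, 4)
```

Row 0 of the tables corresponds to position 0, where every angle is zero, so `sin_table[0]` is all zeros and `cos_table[0]` is all ones.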


``` In the forward method (call) ```
1. We take `x` and the start position `start_pos`, which defaults to 0.
2. We read the size of `x`, i.e. `batch, seq_len, n_heads, head_dim = x.shape`.
3. We reshape `x` to the shape `(batch, seq_len, n_heads, head_dim//2, 2)`.

This splits the embedding dimension into pairs of components, and each pair is rotated as a 2D vector.

4. We slice the sine and cosine values from the specified start position up to `start_pos + seq_len`.

We then add extra dimensions at positions 0 and 2 so that the values broadcast over the batch and head axes.

5. We now apply the rotation to each pair using the 2D rotation matrix, i.e.
 ```
    x[..., 0]*cos - x[..., 1]*sin
    x[..., 0]*sin + x[..., 1]*cos
    ```
  These two results are then stacked along the last dimension, i.e. `axis=-1`.

6. We then reshape it back to the original shape, i.e.:
    ``` (batch, seq_len, n_heads, dim) ```
7. Return the rotated tensor.
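The steps above can be sketched as a small NumPy reference, useful for sanity checks. This is a sketch under stated assumptions, not the notebook's implementation: `rope_reference` and its `max_len` parameter are hypothetical names, and NumPy stands in for TensorFlow.

```python
import numpy as np

def rope_reference(x, start_pos=0, max_len=32):
    """Hypothetical NumPy reference; x has shape (batch, seq_len, n_heads, head_dim)."""
    batch, seq_len, n_heads, head_dim = x.shape
    theta = 10000.0 ** (-np.arange(0, head_dim, 2) / head_dim)
    angles = np.arange(max_len).reshape(-1, 1) * theta      # (max_len, head_dim//2)
    # Step 4: keep rows start_pos .. start_pos+seq_len-1, then add batch and head axes
    cos = np.cos(angles)[start_pos:start_pos + seq_len][None, :, None, :]
    sin = np.sin(angles)[start_pos:start_pos + seq_len][None, :, None, :]
    # Step 3: split the last axis into head_dim//2 pairs
    pairs = x.reshape(batch, seq_len, n_heads, head_dim // 2, 2)
    x0, x1 = pairs[..., 0], pairs[..., 1]
    # Step 5: rotate each pair and stack on a new last axis; step 6: flatten back
    out = np.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], axis=-1)
    return out.reshape(batch, seq_len, n_heads, head_dim)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 6, 4, 8))
y = rope_reference(x)

# Position 0 has angle 0, so it passes through unchanged
print(np.allclose(y[:, 0], x[:, 0]))                                         # True
# A rotation preserves the length of every head vector
print(np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1)))   # True
# Rotating the tail separately with start_pos=3 matches the full-sequence result
print(np.allclose(rope_reference(x[:, 3:], start_pos=3), y[:, 3:]))          # True
```

The last check shows what `start_pos` is for: a chunk that begins at position 3 must be rotated with the angles of positions 3 onward, not 0 onward.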




In [2]:
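class RotaryEmbedding(tf.keras.layers.Layer):
  """Rotates pairs of features of a (batch, seq_len, n_heads, head_dim) tensor by position-dependent angles."""
  def __init__(self, dim, seq_len, **kwargs):
    super().__init__(**kwargs)
    if dim % 2 != 0:
      raise ValueError("dim must be even, since features are rotated in pairs")
    self.dim = dim
    self.seq_len = seq_len
    # theta_i = 10000^(-2i/dim); tf.range(0, dim, 2) already yields 2i, so no extra factor of 2
    self.theta=tf.pow(10000.0, -tf.range(0, dim, 2, dtype=tf.float32)/tf.cast(dim, tf.float32))
    self.pos_tensor = tf.range(seq_len, dtype=tf.float32)
    self.pos_tensor=tf.reshape(self.pos_tensor, (-1,1))
    # Precomputed (seq_len, dim//2) tables; plain tensors, so they are not trainable
    self.sin=tf.sin(self.pos_tensor*self.theta)
    self.cos=tf.cos(self.pos_tensor*self.theta)

  def call(self, x, start_pos=0):
    # cur_len is the length of this input, which may be shorter than self.seq_len
    batch, cur_len, n_heads, head_dim=x.shape
    x_reshaped=tf.reshape(x, (batch, cur_len, n_heads, head_dim//2, 2))
    # Take rows start_pos .. start_pos+cur_len-1 of the tables
    x_cos=tf.slice(self.cos, [start_pos,0], [cur_len, head_dim//2])
    x_sin=tf.slice(self.sin, [start_pos,0], [cur_len, head_dim//2])

    # (cur_len, head_dim//2) -> (1, cur_len, 1, head_dim//2) to broadcast over batch and heads
    x_cos=tf.expand_dims(tf.expand_dims(x_cos, 0), 2)
    x_sin=tf.expand_dims(tf.expand_dims(x_sin, 0), 2)

    x0=x_reshaped[...,0]
    x1=x_reshaped[...,1]
    # 2D rotation of each (x0, x1) pair
    x_op=tf.stack([x0*x_cos-x1*x_sin, x0*x_sin+x1*x_cos], axis=-1)
    print(x_op.shape)  # debug: (batch, cur_len, n_heads, head_dim//2, 2)
    x_op=tf.reshape(x_op, (batch, cur_len, n_heads, head_dim))
    return x_op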

In [None]:
x=tf.random.normal((2,10,4,8))
rotary=RotaryEmbedding(8,10)
rotary(x)

(2, 10, 4, 4, 2)


<tf.Tensor: shape=(2, 10, 4, 8), dtype=float32, numpy=
array([[[[-5.07640362e-01, -5.94620705e-01, -3.56113791e-01,
           5.67960024e-01,  2.86763519e-01, -2.37046286e-01,
           1.14617884e-01, -6.11950219e-01],
         [-1.53277564e+00, -9.75816309e-01,  1.16067670e-01,
           3.27773404e+00,  5.33493042e-01,  1.20748878e+00,
          -1.43000603e+00, -1.31717753e+00],
         [ 1.15877926e+00, -7.65947759e-01, -4.45443392e-02,
          -2.73956895e-01, -2.88327206e-02,  2.75142312e-01,
          -2.15718770e+00, -6.00409269e-01],
         [ 6.57899141e-01,  1.83286350e-02, -2.15667820e+00,
           3.86329561e-01,  4.96853113e-01, -1.48428118e+00,
          -3.28305602e-01, -1.47322416e+00]],

        [[-7.16666579e-02,  7.17665434e-01, -1.59775019e-01,
           2.33245477e-01, -7.95885473e-02,  6.81121767e-01,
          -5.24379075e-01, -3.54484916e-01],
         [ 5.10714114e-01, -3.11789393e-01, -8.69418859e-01,
           4.58465785e-01, -2.73109060e-02,  1.

## Theory: Functional method

#### Here we use two functions and complex numbers, as in the paper.
This one is much easier to implement in PyTorch, but I'll try to implement it in TensorFlow.

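We use the rotation matrix in its polar (Euler) form here: rotating by an angle `a` is the same as multiplying by the complex number `cos(a) + i*sin(a)`.
In TensorFlow we don't have a separate polar function, so we go directly to `tf.complex` after calculating the sine and cosine.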
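As a quick check of why the complex form works, the following standard-library sketch rotates one pair both ways and compares the results. The numbers `x0`, `x1` and `angle` are arbitrary illustrative values, with `angle` standing in for position * theta.

```python
import cmath
import math

x0, x1 = 3.0, 4.0   # one pair of embedding values (arbitrary)
angle = 0.7         # stands in for position * theta_i

# Matrix form of the rotation
rot0 = x0 * math.cos(angle) - x1 * math.sin(angle)
rot1 = x0 * math.sin(angle) + x1 * math.cos(angle)

# Complex form: treat the pair as x0 + i*x1 and multiply by e^{i*angle}
z = complex(x0, x1) * cmath.exp(1j * angle)

print(math.isclose(z.real, rot0), math.isclose(z.imag, rot1))   # True True

# cmath.rect(r, phi) is the polar constructor r*e^{i*phi}; it equals cos + i*sin for r = 1
print(cmath.isclose(cmath.rect(1.0, angle),
                    complex(math.cos(angle), math.sin(angle))))  # True
```

The real part of the product carries the first rotated component and the imaginary part carries the second, which is why cosine has to be the real part of the multiplier.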

So the method is divided into 2 parts.
1. Precompute the frequencies
- First check that `head_dim % 2 == 0`.
- Create the index array `0, 2, ..., head_dim-2` (`head_dim//2` values) and make it float.
- Calculate theta using the formula ```theta_i = 10000**(-2*i/dim) for i = 0, 1, 2, ..., dim/2 - 1```
- Create a range sequence of positions from 0 to seq_len-1.
- Take the outer product of the position sequence and theta, which gives one angle per (position, pair).
- Calculate the sine and cosine of these angles.
- Combine them into a complex number with `tf.complex`, with the cosine as the real part and the sine as the imaginary part.

2. Apply the rotary embeddings
- Get `batch, seq_len, heads, head_dim` from `x.shape`.
- Reshape the array using ```tf.concat([tf.shape(x)[:-1], [tf.shape(x)[-1]//2,2 ]], axis=0)```
- Convert each pair to a complex number: the first element of the pair is the real part and the second is the imaginary part.
- Expand the frequencies on the 0th and 2nd dimensions, then multiply them with the complex version of `x`.
- Stack the result in the form (real, imag) along a new last dimension.
- Reshape back to the original shape and return.
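The two parts above can be sketched in NumPy as follows, with a cross-check against the real-valued rotation formula. The function names `precompute_freqs` and `apply_rotary` are hypothetical, and NumPy's complex arrays stand in for `tf.complex`; this is a sketch for checking the arithmetic, not the notebook's code.

```python
import numpy as np

def precompute_freqs(seq_len, head_dim):
    """Hypothetical analogue of part 1: complex factors e^{i * m * theta_i}."""
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")
    theta = 10000.0 ** (-np.arange(0, head_dim, 2) / head_dim)
    angles = np.outer(np.arange(seq_len), theta)       # (seq_len, head_dim//2)
    return np.cos(angles) + 1j * np.sin(angles)        # real part cos, imaginary part sin

def apply_rotary(x, freqs):
    """Hypothetical analogue of part 2; x has shape (batch, seq_len, heads, head_dim)."""
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    x_cplx = pairs[..., 0] + 1j * pairs[..., 1]         # (real, imag) from each pair
    rotated = x_cplx * freqs[None, :, None, :]          # broadcast over batch and heads
    out = np.stack([rotated.real, rotated.imag], axis=-1)
    return out.reshape(x.shape)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 10, 4, 8))
y = apply_rotary(x, precompute_freqs(10, 8))
print(y.shape)   # (2, 10, 4, 8)

# Cross-check the first pair at position m = 3, where theta_0 = 1 so the angle is 3
m = 3
expected0 = x[:, m, :, 0] * np.cos(m) - x[:, m, :, 1] * np.sin(m)
expected1 = x[:, m, :, 0] * np.sin(m) + x[:, m, :, 1] * np.cos(m)
print(np.allclose(y[:, m, :, 0], expected0), np.allclose(y[:, m, :, 1], expected1))  # True True
```

Because the complex product and the matrix form agree, the functional method and the class method should give the same output for the same input.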


In [21]:
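def precompute_theta_freqs(seq_len, dim):
  # Features are rotated in pairs, so the dimension must be even
  assert dim % 2 == 0, "dim must be even"
  pos=tf.range(0, dim,2, dtype=tf.float32)  # 0, 2, ..., dim-2, i.e. the values 2*i
  # theta_i = 10000^(-2i/dim); pos already holds 2*i, so no extra factor of 2
  theta=tf.pow(10000.0, -pos/tf.cast(dim, tf.float32))
  print(theta.shape)
  theta=tf.reshape(theta, (1,-1))
  m=tf.range(0, seq_len, dtype=tf.float32)
  m=tf.reshape(m, (-1,1))
  print(m.shape)
  freqs=tf.keras.ops.outer(m, theta)  # (seq_len, dim//2) table of angles m*theta_i
  # e^{i*angle} = cos(angle) + i*sin(angle): cosine is the real part
  return tf.complex(tf.cos(freqs), tf.sin(freqs))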


In [49]:
tf.concat([tf.shape(x)[:-1], [tf.shape(x)[-1]//2,2 ]], axis=0)

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 2, 10,  4,  4,  2], dtype=int32)>
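To see what this target shape does to the data, here is a small NumPy sketch (NumPy is assumed here only for a dependency-light illustration). It shows that the reshape groups adjacent features into pairs, and that stacking on the last axis is what undoes the split.

```python
import numpy as np

x = np.arange(8).reshape(1, 1, 1, 8)               # one token, one head, head_dim = 8
new_shape = x.shape[:-1] + (x.shape[-1] // 2, 2)   # same idea as the tf.concat above
pairs = x.reshape(new_shape)

print(new_shape)         # (1, 1, 1, 4, 2)
print(pairs[0, 0, 0])    # rows [0 1], [2 3], [4 5], [6 7]

# Stacking the two halves on a new last axis and reshaping restores the original order
restored = np.stack([pairs[..., 0], pairs[..., 1]], axis=-1).reshape(x.shape)
print(np.array_equal(restored, x))   # True

# Stacking on axis 1 instead still reshapes without error, but scrambles the features
scrambled = np.stack([pairs[..., 0], pairs[..., 1]], axis=1).reshape(x.shape)
print(scrambled[0, 0, 0])            # [0 2 4 6 1 3 5 7]
```

The scrambled case produces a tensor of the right shape, so a wrong stacking axis is easy to miss when only the output shape is inspected.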

In [44]:
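def apply_positionalEmbeddings(x, freq):
  batch, seq_len, n_heads, head_dim=x.shape
  # Split the last axis into pairs: (b, seq_len, heads, head_dim//2, 2)
  x_reshaped=tf.reshape(x, shape=tf.concat([tf.shape(x)[:-1], [tf.shape(x)[-1]//2,2 ]], axis=0))
  # First element of each pair is the real part, second is the imaginary part
  x_cplx=tf.complex(x_reshaped[...,0], x_reshaped[...,1])
  # Broadcast freq from (seq_len, head_dim//2) to (1, seq_len, 1, head_dim//2) and rotate
  val=tf.expand_dims(tf.expand_dims(freq, axis=0), axis=2)*x_cplx
  # Stack (real, imag) on a new last axis so each pair stays adjacent after the reshape
  x_real=tf.stack([tf.math.real(val), tf.math.imag(val)], axis=-1)
  return tf.reshape(x_real, shape=(batch, seq_len, n_heads, head_dim))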

In [50]:
x=tf.random.normal((2,10,4,8))
freq=precompute_theta_freqs(10,8)
apply_positionalEmbeddings(x, freq)

(4,)
(10, 1)


<tf.Tensor: shape=(2, 10, 4, 8), dtype=float32, numpy=
array([[[[-5.84691346e-01,  8.09939876e-02,  4.96796161e-01,
           1.03473210e+00,  3.40799272e-01, -7.30813563e-01,
           5.40869758e-02,  6.42201245e-01],
         [-2.62250185e-01,  5.40926754e-02,  4.63205546e-01,
           7.39895165e-01,  3.81825864e-01,  2.42858618e-01,
          -7.00922549e-01,  2.68965781e-01],
         [-2.07107854e+00, -9.17688727e-01, -1.74631476e-02,
           4.94003534e-01, -1.10503685e+00, -3.95400524e-01,
          -1.71925277e-01,  9.64548111e-01],
         [ 1.98673546e-01, -4.22701031e-01, -8.24392401e-03,
          -4.58068460e-01,  1.39493322e+00, -7.96380818e-01,
           4.85844254e-01, -9.00271177e-01]],

        [[ 1.19277954e+00,  3.49827200e-01,  1.90044200e+00,
           1.23597634e+00,  1.80475080e+00, -7.57934690e-01,
           8.29859197e-01,  1.96897829e+00],
         [ 4.57934201e-01,  8.13293606e-02,  6.06765747e-01,
           6.58186615e-01,  9.58878756e-01,  1.