<a href="https://colab.research.google.com/github/Shiveshrane/Research_paper_implementations/blob/main/Llama_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn


In [8]:
import tensorflow as tf
import numpy as np

In [3]:
import math

## Precompute RoPE

### Tensorflow version

In [9]:
class RotaryEmbeddingsTF(tf.keras.layers.Layer):
  def __init__(self, dim, max_seq_len, **kwargs):
    super().__init__(**kwargs)
    self.dim=dim
    self.max_seq_len=max_seq_len
    self.theta=tf.pow(10000, -2 * tf.range(0, dim, 2, dtype=tf.float32) / tf.cast(dim, tf.float32))
    positions=tf.range(max_seq_len, dtype=tf.float32)
    positions=tf.reshape(positions, (-1,1))
    self.cos=tf.cos(positions * self.theta)
    self.sin=tf.sin(positions * self.theta)

  def call(self, x, start_pos):
    batch, seq_len, n_heads, head_dim=x.shape
    x_reshaped=tf.reshape(x, (batch, seq_len, n_heads, head_dim//2, 2)) # Reshape to (batch, seq_len, n_heads, head_dim//2, 2)
    start_pos=tf.convert_to_tensor(start_pos, dtype=tf.int32)
    cos=tf.slice(self.cos, [start_pos, 0], [seq_len, head_dim//2]) # Slice cos based on head_dim//2
    sin=tf.slice(self.sin, [start_pos, 0], [seq_len, head_dim//2]) # Slice sin based on head_dim//2
    cos=tf.expand_dims(tf.expand_dims(cos, 0), 2) # Expand dimensions to match x_reshaped
    sin=tf.expand_dims(tf.expand_dims(sin, 0), 2) # Expand dimensions to match x_reshaped
    #print(self.cos, self.sin)
    x0=x_reshaped[..., 0]
    x1=x_reshaped[..., 1]
    x_rot=tf.stack([x0*cos - x1*sin, x0*sin + x1*cos], axis=-1)
    return tf.reshape(x_rot, (batch, seq_len, n_heads, head_dim)) # Reshape back to original shape

In [39]:
# Initialize layer
rotary_emb = RotaryEmbeddingsTF(dim=128, max_seq_len=2048)
tf.random.set_seed(42)

# Create dummy input (batch_size=2, seq_len=10, n_heads=8, head_dim=128)
x = tf.random.normal((2, 10, 8, 128))

# Apply rotary embeddings
output = rotary_emb(x, start_pos=0)
print(output.shape)  # (2, 10, 8, 128)

(2, 10, 8, 128)


### Pytorch version

In [None]:
class RotaryEmbeddings(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Precompute theta for dimension pairs
        theta = 10000 ** (-2 * torch.arange(0, dim//2, dtype=torch.float32) / dim)
        self.register_buffer("theta", theta)  # (dim//2,)

        # Precompute sin/cos for all positions
        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)  # (max_seq_len, 1)
        angles = positions * self.theta.unsqueeze(0)  # (max_seq_len, dim//2)
        self.register_buffer("cos", torch.cos(angles))  # (max_seq_len, dim//2)
        self.register_buffer("sin", torch.sin(angles))  # (max_seq_len, dim//2)

    def forward(self, x, start_pos=0):
        batch, seq_len, n_heads, head_dim = x.shape
        x = x.view(batch, seq_len, n_heads, head_dim//2, 2)  # (B, T, H, D//2, 2)

        # Slice cos/sin and add dimensions for broadcasting
        cos = self.cos[start_pos:start_pos+seq_len]  # (T_slice, D//2)
        sin = self.sin[start_pos:start_pos+seq_len]  # (T_slice, D//2)
        cos = cos.unsqueeze(0).unsqueeze(2)  # (1, T_slice, 1, D//2)
        sin = sin.unsqueeze(0).unsqueeze(2)  # (1, T_slice, 1, D//2)

        # Apply rotation: [x0, x1] â†’ [x0*cos - x1*sin, x0*sin + x1*cos]
        x_rot = torch.stack([
            x[..., 0] * cos - x[..., 1] * sin,
            x[..., 0] * sin + x[..., 1] * cos
        ], dim=-1)  # (B, T, H, D//2, 2)

        return x_rot.view(batch, seq_len, n_heads, head_dim)  # (B, T, H, D)

In [None]:
torch.manual_seed(42)
x=torch.randn(2, 10, 8, 128)
rotary_emb=RotaryEmbeddings(dim=128, max_seq_len=2048)
output=rotary_emb(x, start_pos=0)
output.shape


torch.Size([2, 10, 8, 128])

import torch
def apply_rotary_embedding_torch(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
    # Two consecutive values will become a single complex number
    # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension
    # (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
    # Which results in the rotation of the complex number as shown in the Figure 1 of the paper
    # (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2)
    x_rotated = x_complex * freqs_complex
    # Convert the complex number back to the real number
    # (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2)
    x_out = torch.view_as_real(x_rotated)
    # (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)

## RMSNorm

### RMSNorm in pytorch

In [None]:
class RMS_Norm_Torch(nn.Module):
  def __init__(self, dim, eps=1e-6 ):
    super().__init__()
    self.eps=eps
    ## Gamma parameter
    self.weight=nn.Parameter(torch.ones(dim))

  def norm(self, x):
    # x.shape= (B,Seq_len, Dim)
    rms_val=torch.sqrt(x.pow(2).mean(-1, keepdims=True)+self.eps)
    return x/rms_val

  def forward(self, x):
    return self.weight*self.norm(x)

In [None]:
## Test RMS
x=torch.ones(1,10,128)
rms=RMS_Norm_Torch(128)
rms(x)

tensor([[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]],
       grad_fn=<MulBackward0>)

### RMSNorm in Tensorflow

In [10]:
class RMS_Norm_TF(tf.keras.layers.Layer):
  def __init__(self, dim, eps=1e-6, **kwargs):
    super().__init__(**kwargs)
    self.eps=eps
    self.weight=self.add_weight(
        name="rms_weight",
        shape=(dim,),
        initializer="ones",
        trainable=True,
    )

  def rms_calc(self, x):
    rms_val=tf.sqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)+self.eps)
    return x/rms_val

  def call(self, x):
    return self.weight*self.rms_calc(x)

In [9]:
## Test
x=tf.ones((1,10,128))
rms=RMS_Norm_TF(128)
rms(x)

<tf.Tensor: shape=(1, 10, 128), dtype=float32, numpy=
array([[[0.9999995, 0.9999995, 0.9999995, ..., 0.9999995, 0.9999995,
         0.9999995],
        [0.9999995, 0.9999995, 0.9999995, ..., 0.9999995, 0.9999995,
         0.9999995],
        [0.9999995, 0.9999995, 0.9999995, ..., 0.9999995, 0.9999995,
         0.9999995],
        ...,
        [0.9999995, 0.9999995, 0.9999995, ..., 0.9999995, 0.9999995,
         0.9999995],
        [0.9999995, 0.9999995, 0.9999995, ..., 0.9999995, 0.9999995,
         0.9999995],
        [0.9999995, 0.9999995, 0.9999995, ..., 0.9999995, 0.9999995,
         0.9999995]]], dtype=float32)>

## Grouped Query Attention

### GQA in Pytorch

In [10]:
def repeat_kv(x, n_rep):
  batch_size, seq_len, n_kv_heads, head_dim=x.shape
  if n_rep==1:
    return x
  else:
    return(
        # (B, Seq_len, KV_heads, Head_dim)
        x[:, :, :, None, :]
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(batch_size, seq_len, n_kv_heads*n_rep, head_dim)
    )

In [None]:
import math

In [None]:
class SelfAttentionTorch(nn.Module):
  def __init__(self, n_heads, dim, max_seq_len, max_batch_size, use_cache=True, n_kv_heads=None):
    super().__init__()

    ## No of heads for Query
    self.n_heads_q=n_heads

    ## No of heads for Key and Value (same as query if kv heads=None)
    self.n_kv_heads= n_heads if n_kv_heads is None else n_kv_heads

    ## No of times key and value needs to be repeated to match the head of the queries
    self.n_rep=self.n_heads_q//self.n_kv_heads

    # indicates the dimension of each head
    self.head_dim=dim//self.n_heads_q

    self.wq=nn.Linear(dim, self.n_heads_q*self.head_dim, bias=False)
    self.wk=nn.Linear(dim, self.n_kv_heads*self.head_dim, bias=False)
    self.wv=nn.Linear(dim, self.n_kv_heads*self.head_dim, bias=False)
    self.wo=nn.Linear(self.n_heads_q*self.head_dim, dim, bias=False)


    self.inference=use_cache

    if self.inference:
      self.cache_k=torch.zeros((max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim))
      self.cache_v=torch.zeros((max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim))

    self.rotary=RotaryEmbeddings(self.head_dim, max_seq_len)


  def forward(self, x, start_pos):
    batch_size, seq_len, _=x.shape

    # (B,1, Dim)--> (B,1, H_Q*Head_dim)
    xq=self.wq(x)

    # (B,1, Dim)--> (B,1, H_kv*Head_dim)
    xk=self.wk(x)

    ## (B,1, Dim)--> (B,1, H_KV*Head_dim)
    xv=self.wv(x)

    # (B,1, HQ*Head_dim)--> (B,1, HQ, head_dim)
    xq=xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)

    # (B,1, HQ*Head_dim)--> (B,1, HKV, head_dim)
    xk=xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

    # (B,1, HQ*Head_dim)--> (B,1, HkV, head_dim)
    xv=xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)


      # Does not change shape
   # rotary_matrix = self.rotary.init_matrix(seq_len)
    #xq = torch.einsum("bihd,ijd->bihj", xq, rotary_matrix)
    #k = torch.einsum("bihd,ijd->bihj", xk, rotary_matrix)

    xq=self.rotary(xq, start_pos)
    xk=self.rotary(xk, start_pos)
    #print(xq.shape
    ## Einsum is same as matrix mult
    """
    xq= xq @ rotatory_matrix.transpose(perm=(0,2,1)) # Gives shape (B, seq_len, heads, head_dim) (since 1 and 2 of matrix is head_dim itself)
    xk= xk @ rotatory_matrix.transpose(perm=(0,2,1))


    """


      # Replace the entry in the cache for this token
    if self.inference:
      self.cache_k[:batch_size, start_pos:start_pos+seq_len]=xk
      self.cache_v[:batch_size, start_pos:start_pos+seq_len]=xv

     #(B, Seq_len_KV,h_kv, head_dim)
     # Retrieve all values.
      keys=self.cache_k[:batch_size, 0:start_pos+seq_len]
      values=self.cache_v[:batch_size, 0:start_pos+seq_len]


    # Repeat the heads of K and V to reach the number of heads of the queries

    else:
      keys=xk
      values=xv

    keys=repeat_kv(keys, self.n_rep)
    values=repeat_kv(values, self.n_rep)

    # (B, 1, H_Q, Head_dim)--> (B, H_Q, 1, head_dim)
    xq=xq.transpose(1,2)
    keys=keys.transpose(1,2)
    values=values.transpose(1,2)


    # (B, H_Q, 1, Head_Dim) @(B, H_kv, Head_dim, seq_len_kv)==> (B, HQ, 1, seq_len_kv)
    scores=torch.matmul(xq, keys.transpose(2,3))/math.sqrt(self.head_dim)

    if self.inference==False:
      mask=torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
      scores.masked_fill_(mask, -np.inf)


    scores=torch.softmax(scores, dim=-1).type_as(xq)
    #print(scores)

    # (B, hq,1,seq_len_kv)@ (B,H_kv, seq_len_kv, head_dim)--> (B,Hq, 1, head_dim)
    output=torch.matmul(scores, values)



    #(B, HQ, 1, Head_dim)--> (B, 1,HQ, Head_dim)
    output=(output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))
    return self.wo(output) #(B, 1, Dim)





In [None]:
x=torch.ones(1,10,128)
attn=SelfAttentionTorch(n_heads=16, dim=128, max_seq_len=10, max_batch_size=1, use_cache=False)
val=attn(x, 0)
val, val.shape

(tensor([[[ 0.0036, -0.0640,  0.0164,  ..., -0.4840, -0.5122, -0.0622],
          [ 0.0036, -0.0640,  0.0164,  ..., -0.4840, -0.5122, -0.0622],
          [ 0.0036, -0.0640,  0.0164,  ..., -0.4840, -0.5122, -0.0622],
          ...,
          [ 0.0036, -0.0640,  0.0164,  ..., -0.4840, -0.5122, -0.0622],
          [ 0.0036, -0.0640,  0.0164,  ..., -0.4840, -0.5122, -0.0622],
          [ 0.0036, -0.0640,  0.0164,  ..., -0.4840, -0.5122, -0.0622]]],
        grad_fn=<UnsafeViewBackward0>),
 torch.Size([1, 10, 128]))

In [None]:
x=torch.ones(1,10,64)
attn=SelfAttentionTorch(n_heads=16, dim=64, max_seq_len=10, max_batch_size=1, use_cache=False)
attn(x, 0).shape

torch.Size([1, 10, 64])

### GQA in Tensorflow

In [20]:
class SelfAttentionTF(tf.keras.layers.Layer):
  def __init__(self, n_heads, dim, max_seq_len, max_batch_size, use_cache=False, n_kv_heads=None):
    super().__init__()
    self.n_heads_q=n_heads
    self.n_kv_heads= n_heads if n_kv_heads is None else n_kv_heads
    self.n_rep=self.n_heads_q//self.n_kv_heads
    self.head_dim=dim//self.n_heads_q
    self.wq=tf.keras.layers.Dense(self.n_heads_q*self.head_dim, use_bias=False)
    self.wk=tf.keras.layers.Dense(self.n_kv_heads*self.head_dim, use_bias=False)
    self.wv=tf.keras.layers.Dense(self.n_kv_heads*self.head_dim, use_bias=False)
    self.wo=tf.keras.layers.Dense(dim, use_bias=False)
    #self.freqs=precompute_theta_pos_frequencies(self.head_dim, max_seq_len)//For using the function method

    self.inference=use_cache
    #if self.inference:
     # self.cache_k=tf.Variable(tf.zeros((max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)))
      #self.cache_v=tf.Variable(tf.zeros((max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)))
    if self.inference:
            self.cache_k = self.add_weight(
                name="cache_k",
                shape=(max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim),
                initializer="zeros",
                trainable=False,
            )
            self.cache_v = self.add_weight(
               name= "cache_v",
                shape=(max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim),
                initializer="zeros",
                trainable=False,
            )
    self.rotary=RotaryEmbeddingsTF(dim, max_seq_len)

  def call(self, x, start_pos):
      batch_size, seq_len, _=x.shape
      xq=self.wq(x)
      xk=self.wk(x)
      xv=self.wv(x)

      xq=tf.reshape(xq, shape=(batch_size, seq_len, self.n_heads_q, self.head_dim))
      xk=tf.reshape(xk, shape=(batch_size, seq_len, self.n_kv_heads, self.head_dim))
      xv=tf.reshape(xv, shape=(batch_size, seq_len, self.n_kv_heads, self.head_dim))

      xq=self.rotary(xq, start_pos=start_pos)
      xk=self.rotary(xk, start_pos=start_pos)
      #xq=apply_rotary_embeddings(xq, self.freqs)
      #xk=apply_rotary_embeddings(xk, self.freqs)
      if self.inference:
        self.cache_k[:batch_size, start_pos:start_pos+seq_len].assign(xk)
        self.cache_v[:batch_size, start_pos:start_pos+seq_len].assign(xv)
        keys=self.cache_k[:batch_size, :start_pos+seq_len]
        values=self.cache_v[:batch_size, :start_pos+seq_len]
      else:
        keys=xk
        values=xv

      keys=tf.repeat(keys, repeats=self.n_rep, axis=2)
      values=tf.repeat(values, repeats=self.n_rep, axis=2)
      #print(self.n_rep,keys.shape, values.shape)

      xq=tf.transpose(xq, perm=(0,2,1,3))
      keys=tf.transpose(keys, perm=(0,2,1,3))
      values=tf.transpose(values, perm=(0,2,1,3))
      scores=tf.matmul(xq, keys, transpose_b=True)/math.sqrt(self.head_dim)
      if self.inference==False:
        mask=tf.linalg.band_part(tf.ones((seq_len, seq_len), dtype=tf.bool), -1, 0)
        mask=tf.reshape(mask, (1,1,seq_len, seq_len))
        #print(mask)
        scores= tf.where(mask, scores, -np.inf)

      attn=tf.nn.softmax(scores, axis=-1)
      output=tf.matmul(attn, values)
      output=tf.transpose(output, perm=(0,2,1,3))
      output=tf.reshape(output, shape=(batch_size, seq_len, self.n_heads_q*self.head_dim))
      return self.wo(output)









In [23]:
def test_self_attention_layer():
    batch_size = 1
    seq_len = 10
    model_dim = 64
    n_heads = 4
    max_seq_len = 16
    max_batch_size = 4
    n_kv_heads = 2

    tf.random.set_seed(42)

    x = tf.random.uniform((batch_size, seq_len, model_dim))
    print("Input shape:", x.shape)

    print("\n--- Testing SelfAttentionTF in Training Mode (no caching) ---")
    sa_layer_train = SelfAttentionTF(n_heads, model_dim, max_seq_len, max_batch_size, n_kv_heads=n_kv_heads, use_cache=False)
    output_train = sa_layer_train(x, start_pos=0)  # Pass start_pos as a keyword argument
    print("Training mode output shape:", output_train.shape)
    output=RMS_Norm_TF(model_dim)(output_train)
    print(output)
   # print("Training mode output:", output_train.numpy())

if __name__ == '__main__':
    test_self_attention_layer()


Input shape: (1, 10, 64)

--- Testing SelfAttentionTF in Training Mode (no caching) ---
Training mode output shape: (1, 10, 64)
tf.Tensor(
[[[-1.32251644e+00  3.17307174e-01 -5.68085276e-02  2.32912445e+00
    1.05909801e+00 -2.28188252e+00 -1.40880480e-01 -7.88160622e-01
   -8.72630417e-01  1.38920736e+00 -1.93686873e-01  7.01617301e-01
    1.12299621e+00  7.74149776e-01 -7.75494814e-01  9.81954515e-01
    3.59548837e-01 -1.19232142e+00  2.78841078e-01 -3.78031582e-01
    9.25838947e-02  7.48536170e-01 -4.94152397e-01 -1.48785448e+00
   -7.93696165e-01  2.70074129e-01  4.26028550e-01  1.06547391e+00
    1.61096418e+00 -1.34831417e+00 -3.11176538e-01 -1.47922993e-01
   -7.99997866e-01 -1.42394245e+00 -3.36356193e-01  2.28861070e+00
   -1.55109298e+00 -1.37776339e+00 -1.17986917e+00  2.78266221e-01
   -8.96787703e-01 -1.68970430e+00 -7.40677476e-01 -1.57817125e+00
    1.00614083e+00  3.25292677e-01  2.98547834e-01  8.03892612e-01
   -1.54277131e-01 -3.02864194e-01 -1.03238754e-01 -1.141

## FeedForward (with Swiglu)

### FF in Pytorch

In [16]:
class FeedForward_Torch(nn.Module):
  def __init__(self, dim, multiple_of=256, custom_mult=None, dropout=0.0):
    super().__init__()
    hidden_dim=dim*4
    hidden_dim=int(2*hidden_dim/3)
    if custom_mult is not None:
      hidden_dim=int(dim*custom_mult)
    hidden_dim=multiple_of*((hidden_dim+multiple_of-1)//multiple_of)
    self.w1=nn.Linear(dim, hidden_dim)
    self.w2=nn.Linear(hidden_dim, dim)
    self.w3=nn.Linear(dim, hidden_dim)
    self.dropout=nn.Dropout(dropout)

  def forward(self, x):
    x=self.w1(x)
    swish=F.silu(x)
    x_V=self.w3(x)
    x=swish*x_V # element wise multiplication
    x=self.w2(x)
    return self.dropout(x)

In [18]:
class EncoderLayer(nn.Module):
  def __init__(self, dim, n_heads, dropout=0.0):
    super().__init__()
    self.attention=SelfAttentionTorch(dim=dim, n_heads=n_heads)
    self.feed_forward=FeedForward_Torch(dim=dim)
    self.norm1=RMS_Norm_Torch(dim)
    self.norm2=RMS_Norm_Torch(dim)
    self.dropout=nn.Dropout(dropout)
    def forward(self, x, start_pos):
      x=x+self.dropout(self.attention(self.norm1(x), start_pos=start_pos))
      x=x+self.dropout(self.feed_forward(self.norm2(x)))
      return x

### FF in Tensorflow

In [12]:
class FeedForward_TF(tf.keras.layers.Layer):
  def __init__(self, dim, multiple_of=256, custom_mult=None, dropout=0.0):
    super().__init__()
    hidden_dim=dim*4
    hidden_dim=int(2*hidden_dim/3)
    if custom_mult is not None:
      hidden_dim=int(dim*custom_mult)
    hidden_dim=multiple_of*((hidden_dim+multiple_of-1)//multiple_of)

    self.w1=tf.keras.layers.Dense(hidden_dim, use_bias=False)
    self.w2=tf.keras.layers.Dense(dim, use_bias=False) # Final layer
    self.w3=tf.keras.layers.Dense(hidden_dim, use_bias=False)
    self.dropout=tf.keras.layers.Dropout(dropout)


  def call(self,x):
    x=self.w1(x)
    swish=tf.nn.silu(x)
    x_V=self.w3(x)
    x=swish*x_V # element wise multiplication
    x=self.w2(x)
    return self.dropout(x)


In [26]:
class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, dim, n_heads, max_seq_len, batch_size, dropout=0.0, use_cache=False):
    super().__init__()
    self.attention=SelfAttentionTF(dim=dim, n_heads=n_heads, max_seq_len=max_seq_len, max_batch_size=batch_size, use_cache=use_cache)
    self.feed_forward=FeedForward_TF(dim=dim)
    self.norm1=RMS_Norm_TF(dim)
    self.norm2=RMS_Norm_TF(dim)
    self.dropout=tf.keras.layers.Dropout(dropout)
  def call(self, x, start_pos):
    x=x+self.dropout(self.attention(self.norm1(x), start_pos=start_pos))
    x=x+self.dropout(self.feed_forward(self.norm2(x)))
    return x

In [25]:
## Test Encoder
x=tf.ones((1,10,128))
enc=EncoderLayer(128, 16, 10, 32)
enc(x, start_pos=0)

<tf.Tensor: shape=(1, 10, 128), dtype=float32, numpy=
array([[[ 0.1210289 ,  0.4260625 ,  0.534716  , ...,  0.2922683 ,
          0.7803144 , -0.55545443],
        [ 0.12102893,  0.4260623 ,  0.534716  , ...,  0.2922682 ,
          0.78031427, -0.55545443],
        [ 0.1210289 ,  0.4260621 ,  0.53471595, ...,  0.29226857,
          0.780314  , -0.55545473],
        ...,
        [ 0.12102884,  0.42606243,  0.5347161 , ...,  0.29226837,
          0.78031456, -0.5554542 ],
        [ 0.12102899,  0.42606243,  0.5347157 , ...,  0.29226813,
          0.7803142 , -0.5554545 ],
        [ 0.12102884,  0.4260621 ,  0.5347161 , ...,  0.29226822,
          0.78031427, -0.55545455]]], dtype=float32)>

## Final showdown

### Encoder block

In [41]:
class Transformer(tf.keras.Model):
  def __init__(self, dim, n_heads, n_layers, max_seq_len, batch_size, vocab_size, eps=1e-6, dropout=0.0, use_cache=False):
    super().__init__()
    assert vocab_size!=-1, "Vocab size must be set"
    self.token_embeddings=tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=dim)
    self.encoder_layers=[EncoderLayer(dim=dim, n_heads=n_heads, max_seq_len=max_seq_len, batch_size=batch_size, dropout=dropout, use_cache=use_cache) for _ in range(n_layers)]
    self.norm=RMS_Norm_TF(dim, eps=eps)
    self.final_layer=tf.keras.layers.Dense(vocab_size, use_bias=False)

  def call(self, x):
    batch_size, seq_len=x.shape
    x=self.token_embeddings(x)
    for i in range(len(self.encoder_layers)):
      x=self.encoder_layers[i](x, start_pos=0)
    x=self.norm(x)
    logits=self.final_layer(x)
    return logits

In [42]:
## Test transformer
x=tf.random.uniform((1,10), minval=0, maxval=100, dtype=tf.int32)
transformer=Transformer(dim=64, n_heads=16, n_layers=1, max_seq_len=10, batch_size=1, vocab_size=100, use_cache=True)
transformer(x)


<tf.Tensor: shape=(1, 10, 100), dtype=float32, numpy=
array([[[-1.11166120e+00,  6.41244769e-01, -4.09252465e-01,
         -3.30242842e-01,  5.90650141e-01, -3.13680589e-01,
         -1.76802315e-02,  5.72568297e-01,  2.21795782e-01,
          1.16955495e+00, -1.19740689e+00,  1.27444410e+00,
         -9.74639654e-01, -4.57591921e-01,  6.84647739e-01,
          5.06904006e-01, -1.72054362e+00,  1.19893670e+00,
          8.80591035e-01,  6.67242527e-01, -3.37746799e-01,
         -5.29994726e-01,  1.05022812e+00,  1.85814071e+00,
          2.47316048e-01, -1.85167694e+00,  7.07224727e-01,
          1.16831601e+00, -2.14823937e+00, -1.51316512e+00,
          1.35644639e+00, -2.56015867e-01,  9.17008758e-01,
          1.12232864e-01,  5.42326927e-01, -9.00227189e-01,
          3.83093096e-02,  1.05395103e+00,  1.39789328e-01,
         -1.26404274e+00, -1.70159554e+00,  5.64920962e-01,
         -1.44117093e-02,  5.58448970e-01, -1.42117047e+00,
         -1.72615349e-02, -1.07714760e+00,  1.