In [1]:
import jax
import flax

# Record the library versions this notebook was executed with, so results can be reproduced.
print("JAX version:", jax.__version__)
print("Flax version:", flax.__version__)


JAX version: 0.5.3
Flax version: 0.10.4


In [2]:
from flax import linen as nn
import jax 
import jax.numpy as jnp

In [6]:
class Einsum(nn.Module):
    """Learnable weight tensor contracted with the input through ``jnp.einsum``.

    The weight is created by ``self.param`` under the name ``weight_name``
    using ``initializer``. The einsum equation is given per call, so one module
    type can express many projection layouts.

    Attributes:
        shape: shape of the learnable weight tensor.
        weight_name: parameter name under which the weight is registered.
        initializer: initializer passed to ``self.param`` together with
            ``shape`` and ``dtype``.
        dtype: forwarded unchanged to the initializer (may be ``None``).
    """

    shape: tuple[int,...]
    weight_name: str = "w"
    initializer: nn.initializers.Initializer = nn.initializers.normal()
    dtype: jnp.dtype | None = None

    @nn.compact
    def __call__(self, eqn: str,x: jax.Array) -> jax.Array:
        """Contract ``x`` with the learnable weight according to ``eqn``.

        Args:
            eqn: einsum subscripts; ``x`` is the first operand and the weight
                the second.
            x: input array.

        Returns:
            ``jnp.einsum(eqn, x, weight)``.
        """
        kernel = self.param(self.weight_name, self.initializer, self.shape, self.dtype)
        return jnp.einsum(eqn, x, kernel)
    

In [7]:
class RMSNorm(nn.Module):
    """Root Mean Square normalization (RMSNorm) with a learnable scale.

    Every vector along the last axis of the input is divided by its
    root-mean-square value and then multiplied by ``1 + scale``. ``scale`` is
    initialised to zeros, so at initialisation the layer is a pure RMS
    normalization.

    Attributes:
        epsilon: small constant added to the mean square before the reciprocal
            square root, to avoid division by zero for all-zero vectors.
    """

    epsilon: float = 1e-6

    @nn.compact
    def __call__(self,x):
        """Normalize ``x`` over its last axis.

        Args:
            x: input array of rank >= 1; normalization is over axis -1.

        Returns:
            Array with the same shape as ``x``.
        """
        # One learnable scale per feature. Initializers take a shape sequence,
        # so pass a 1-tuple rather than a bare int.
        scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))

        # Mean of the squared features; keepdims so it broadcasts back over x.
        mean_square = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        normed_inputs = x * jax.lax.rsqrt(mean_square + self.epsilon)

        # Reshape scale to [1, ..., 1, D] so it lines up with x's last axis.
        # `1 + scale` makes the zero-initialised scale an identity at init.
        scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1))
        return normed_inputs * (1 + scale)

In [10]:
# Longest wavelength / base used for positional encodings in this notebook.
ROPE_BASE_FREQUENCY = 10_000


def add_positional_embeddings(
        inputs: jax.Array,
        positions: jax.Array,
        max_wavelength : int = ROPE_BASE_FREQUENCY
) -> jax.Array:
    """Add fixed sinusoidal (absolute) positional embeddings to ``inputs``.

    This is the additive sin/cos embedding; it does not rotate the inputs
    (rotation is what ``apply_rope`` does).

    Args:
        inputs: array of shape [B, L, H, N] (B -> batch, L -> sequence length,
            H -> number of heads, N -> head dim).
        positions: positions of shape [B, L].
        max_wavelength: the timescales ``1 / inv_timescales`` are spaced
            geometrically from 1 up to ``max_wavelength``.

    Returns:
        ``inputs + embedding``. The embedding is broadcast across the head
        axis and matches the last dimension of ``inputs``. For an odd head dim,
        the final channel of the embedding is zero.
    """
    head_dim = inputs.shape[-1]
    # Half of the channels carry sin(), the other half cos() of the same angle.
    num_timescales = head_dim // 2

    # Geometric spacing: inv_timescales[i] = max_wavelength ** (-i / (n - 1)).
    # Take the log of max_wavelength alone, then divide by (n - 1).
    # jnp.maximum(..., 1) avoids dividing by zero when n <= 1.
    log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum(
        jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
    )
    inv_timescales = jnp.exp(
        jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
    )

    # Angle for every (position, frequency) pair: [B, L] -> [B, L, n].
    scaled_time = positions[..., jnp.newaxis] * inv_timescales
    # Insert a singleton head axis -> [B, L, 1, n] so it broadcasts over H.
    scaled_time = scaled_time[..., jnp.newaxis, :]

    signal = jnp.concatenate(
        [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
    )
    # With an odd head_dim, sin + cos fill only head_dim - 1 channels. Pad one
    # zero channel so the embedding matches the last axis of inputs.
    if head_dim % 2:
        pad_width = [(0, 0)] * (signal.ndim - 1) + [(0, 1)]
        signal = jnp.pad(signal, pad_width)

    position_embedding = signal.astype(jnp.float32)
    return inputs + position_embedding



def apply_rope(
        inputs: jax.Array,
        positions: jax.Array,
        base_frequencies: int,
        scale_factor: float = 1.0,
) -> jax.Array:
    """Apply rotary positional embeddings (RoPE) to ``inputs``.

    The head dimension is split into two halves. Channel ``i`` of the first
    half and channel ``i`` of the second half form a 2-D pair that is rotated
    by the angle ``position / base_frequencies ** (2i / head_dim) / scale_factor``.

    Args:
        inputs: array of shape [B, L, H, N] (batch, sequence length, heads,
            head dim). ``N`` must be even so it can be split into two halves.
        positions: positions of shape [B, L].
        base_frequencies: RoPE base; larger values give longer wavelengths.
        scale_factor: divisor applied to every angle; must be >= 1.0.

    Returns:
        The rotated array, with the same shape and dtype as ``inputs``.

    Raises:
        ValueError: if ``scale_factor < 1.0``.
    """
    # Validate before doing any work.
    # The scale factor is meant for sequence lengths longer than the pretrained
    # length. Without it, rotation angles at unseen positions go beyond those
    # seen in training and the positional signal degrades.
    if scale_factor < 1.0:
        raise ValueError(f'scale factor must be >= 1.0 , the given scale factor:{scale_factor}')

    head_dim = inputs.shape[-1]
    # Exponent 2i / head_dim per pair i, giving each pair its own frequency:
    # time_scale[i] = base_frequencies ** (2i / head_dim), shape [head_dim // 2].
    fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
    time_scale = base_frequencies ** fraction

    # [B, L, 1] / [1, 1, head_dim // 2] -> angles of shape [B, L, head_dim // 2].
    sinusoid_positions = (
        positions[..., jnp.newaxis] / time_scale[jnp.newaxis, jnp.newaxis, :]
    )
    # Singleton head axis -> [B, L, 1, head_dim // 2] so it broadcasts over H.
    sinusoid_positions = sinusoid_positions[..., jnp.newaxis, :]
    sinusoid_positions = sinusoid_positions / scale_factor

    sin = jnp.sin(sinusoid_positions)
    cos = jnp.cos(sinusoid_positions)

    # 2-D rotation of each (first_half[i], second_half[i]) pair:
    #   x1' = x1 * cos - x2 * sin
    #   x2' = x2 * cos + x1 * sin
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)






