### Measure Execution Time

In [5]:
from flax import linen as nn
from flax.training import train_state, checkpoints

import jax
import jax.numpy as jnp
from jax import random
# Seeding for random operations
main_rng = random.PRNGKey(42)

import time

from typing import Optional, List
from einops import rearrange
import optax

In [6]:
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "D:/Malky/research/saved_models/mla_hybridNorm_jax"

In [7]:
def rotary_embedding(x, max_seq_len, dim):
    """Apply RoPE to input tensor x."""
    # Ensure max_seq_len is a concrete value, not symbolic
    positions = jnp.arange(max_seq_len, dtype=jnp.float32)
    
    # Ensure freqs is calculated with concrete values for dim
    freqs = 1.0 / (10000 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    
    # Angle calculation with fixed values
    angles = positions[:, None] * freqs[None, :]
    
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    
    # Apply rotary embeddings to the input tensor x
    x1, x2 = x[..., ::2], x[..., 1::2]
    x_rotated = jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    
    return x_rotated

# Scaled Dot-Product Attention

@jax.jit
def scaled_dot_product(q, k, v, mask=None):
    """Computes the scaled dot-product attention."""
    # q -> [B, n_h, L, d_h + d_h^R] 
    # k -> [B, n_h, L, d_h + d_h^R] 
    scale = jnp.sqrt(q.shape[-1])  # Scaling factor for attention scores  d_h + d_h_r
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / scale  # Efficient batch matmul
    
    if mask is not None:
        scores += mask  # Apply mask in-place

    attention = nn.softmax(scores, axis=-1)  # Compute attention weights
    values = jnp.einsum("bhqk,bhkd->bhqd", attention, v)  # Apply attention to values

    return values, attention

# Helper function to support different mask shapes.
# Output shape supports (B, number of heads, seq length, seq length)
# If 2D: broadcasted over batch size and number of heads
# If 3D: broadcasted over number of heads
# If 4D: leave as is
@jax.jit
def expand_mask(mask):
    """Expands a mask tensor to shape (B, num_heads, L, L)."""
    ndim = mask.ndim
    assert ndim >= 2, "Mask must be at least 2D (L x L)"

    # Efficient broadcasting using jnp.reshape and jnp.expand_dims
    if ndim == 2:  # (L, L) → (1, 1, L, L)
        return mask[None, None, :, :]
    elif ndim == 3:  # (B, L, L) → (B, 1, L, L)
        return mask[:, None, :, :]
    
    return mask  # If already (B, num_heads, L, L), return as is


class MultiHeadLatentAttention(nn.Module):
    d_model: int  # Output dimension
    n_h: int  # Number of heads
    d_c: int  # Latent compression dimension
    d_c_: int #Latent caompression dimension fro Q

    def setup(self):
        #self.d_h = self.d_model // self.n_h  # Head dimension

        #RmsNorm after latent variable 
        self.rmsNormV = nn.RMSNorm()
        self.rmsNormK = nn.RMSNorm()
        self.rmsNormQ = nn.RMSNorm()


       # Fused projection layers for efficiency
        self.kv_proj = nn.Dense(2 * self.d_c, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros)
        self.q_proj = nn.Dense(self.d_c_, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros)

        # Up-projection for keys, values, and queries
        self.ukv_proj = nn.Dense(2 * self.d_model, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros)
        self.uq_proj = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros)

        # Output projection
        self.o_proj = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.zeros) 

    def __call__(self, x, mask=None):
        # B, L, _ = x.shape  # Batch size, Sequence length, Embedding dim

        if mask is not None:
            mask = expand_mask(mask)

        # Compute compressed KV projection
        c_kv = self.kv_proj(x)  # [B, L, 2 * d_c]
        k, v = jnp.split(self.ukv_proj(c_kv), 2, axis=-1)  # [B, L, d_model] each

        # Compute query projection
        q = self.uq_proj(self.q_proj(x))  # [B, L, d_model]

        # Reshape using einops for efficiency
        q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.n_h), [q, k, v])

        q = self.rmsNormQ(q)
        k = self.rmsNormK(k)
        v = self.rmsNormV(v)
        # Compute attention (optimized version)
        values, attention = scaled_dot_product(q, k, v, mask)

        # Reshape and output projection
        values = rearrange(values, "b h l d -> b l (h d)")
        return self.o_proj(values), attention
    

class TransformerBlock(nn.Module):
    """A Transformer block with RMS Norm and residual connections."""
    d_model: int  # Dimension of the model (hidden size)
    n_h: int  # Number of attention heads
    dim_feedforward: int    # Dimension of the feed-forward network
    dropout_rate: float = 0.1  # Dropout rate
    d_c : int = 64 # Latent dimension (d_c)
    d_c_: int = 64 # Latent dimension (d_c)
    d_h_R: int = 32  # Rotated dimension (d_h^R)
    position: bool = False


    def setup(self):
        # Attention layer
        if self.position:
            self.self_attn = MultiHeadLatentAttentionRope(
                d_model = self.d_model,
                n_h = self.n_h,
                d_c = self.d_c,
                d_c_ = self.d_c_,
                d_h_R = self.d_h_R)
        else:
            self.self_attn = MultiHeadLatentAttention(
                d_model = self.d_model,
                n_h = self.n_h,
                d_c = self.d_c,
                d_c_ = self.d_c_)
            

        # Feed-Forward Network
        self.ffn = [
            nn.Dense(features=self.dim_feedforward,
                     kernel_init=nn.initializers.xavier_uniform(),
                     bias_init=nn.initializers.zeros),
            nn.gelu,
            nn.Dropout(self.dropout_rate),
            nn.Dense(features=self.d_model,
                     kernel_init=nn.initializers.xavier_uniform(),
                     bias_init=nn.initializers.zeros),
            nn.Dropout(self.dropout_rate)
        ]
            
         
        # Layers to apply in between the main layers
        # self.rmsNorm1 = nn.RMSNorm()
        self.rmsNorm2 = nn.RMSNorm()
        self.dropout = nn.Dropout(self.dropout_rate)


    
    def __call__(self, x: jnp.ndarray, mask: Optional[jnp.ndarray] = None, train: bool = False) -> jnp.ndarray:
        """
        Args:
            x: Input tensor of shape (batch, length, d_model)
            mask: Attention mask of shape (batch, heads, length, length) or None
            train: Whether in training mode (for dropout)
        Returns:
            Output tensor of shape (batch, length, d_model)
        """
        # First RMS Norm + Residual Connection
        residual = x
        # x = self.rmsNorm1(x)
        x, _= self.self_attn(x, mask=mask)
        x = self.dropout(x, deterministic=not train)
        x = x + residual  # Residual connection

        # Second RMS Norm + Residual Connection
        
        x = self.rmsNorm2(x)
        x = self.dropout(x, deterministic=not train)

        residual = x
        
        for layer in self.ffn:
            x = layer(x) if not isinstance(layer, nn.Dropout) else layer(x, deterministic=not train)

        x = self.dropout(x, deterministic=not train)
        x = x + residual  # Residual connection

        return x
    

class TransformerEncoder(nn.Module):
    num_layers : int
    d_model: int  # Dimension of the model (hidden size)
    n_h: int  # Number of attention heads
    dim_feedforward: int    # Dimension of the feed-forward network
    dropout_rate: float = 0.1  # Dropout rate
    d_c : int = 64 # Latent dimension (d_c)
    d_c_: int = 64 # Latent dimension (d_c)
    d_h_R: int = 32  # Rotated dimension (d_h^R)
    position: bool = False

    def setup(self):

        # Initialize a list of Transformer blocks
        self.layers = [TransformerBlock(
            d_model=self.d_model,
            n_h=self.n_h,
            dim_feedforward=self.dim_feedforward,
            dropout_rate=self.dropout_rate,
            d_c=self.d_c,
            d_c_=self.d_c_,
            d_h_R=self.d_h_R,
            position=self.position
            ) for _ in range(self.num_layers)]


    def __call__(self, x:jnp.ndarray, mask:Optional[jnp.ndarray] = None, train:bool=True)-> jnp.ndarray:
        """
        Args:
            x: Input tensor of shape (batch, length, d_model)
            mask: Attention mask of shape (batch, heads, length, length) or None
            train: Whether in training mode (for dropout)
        Returns:
            Output tensor of shape (batch, length, d_model)
        """
        # Apply each Transformer block in sequence
        for layer in self.layers:
            x = layer(x, mask=mask, train=train)
        return x

    def get_attention_maps(self, x:jnp.ndarray, mask:Optional[jnp.ndarray] = None, train:bool=True)-> List[jnp.ndarray]:
        # A function to return the attention maps within the model for a single application
        # Used for visualization purpose later
        attention_maps = []
        for layer in self.layers:
    
            _, attn_map = layer.self_attn(x, mask=mask)
            attention_maps.append(attn_map)
            x = layer(x, mask=mask, train=train)
        return attention_maps
    


class TransformerPredictor(nn.Module):
    
    num_classes : int                 # Number of classes to predict per sequence element
    num_layers : int
    d_model: int  # Dimension of the model (hidden size)
    n_h: int  # Number of attention heads
    dim_feedforward: int    # Dimension of the feed-forward network
    dropout_rate: float = 0.1  # Dropout rate
    input_dropout_prob : float = 0.0  # Dropout to apply on the input features
    d_c : int = 64 # Latent dimension (d_c)
    d_c_: int = 64 # Latent dimension (d_c)
    d_h_R: int = 32  # Rotated dimension (d_h^R)
    position: bool= False

    def setup(self):
        # Input dim -> Model dim
        self.input_dropout = nn.Dropout(self.input_dropout_prob)
        self.input_layer = nn.Dense(self.d_model)

        # Transformer encoder
        self.transformer = TransformerEncoder(
            num_layers=self.num_layers,
            d_model=self.d_model,
            n_h=self.n_h,
            dim_feedforward=self.dim_feedforward,
            dropout_rate=self.dropout_rate,
            d_c=self.d_c,
            d_c_=self.d_c_,
            d_h_R=self.d_h_R,
            position=self.position
        )

        # Output classifier per sequence element
        self.output_net = [
            nn.Dense(self.d_model),
            nn.LayerNorm(),
            nn.relu,
            nn.Dropout(self.dropout_rate),
            nn.Dense(self.num_classes)
        ]


    def __call__(self, x:jnp.ndarray, mask:Optional[jnp.ndarray] = None, train:bool=True):
        """
        Inputs:
            x - Input features of shape [Batch, SeqLen, input_dim]
            mask - Mask to apply on the attention outputs (optional)
            add_positional_encoding - If True, we add the positional encoding to the input.
                                      Might not be desired for some tasks.
            train - If True, dropout is stochastic
        """

        # Apply input dropout and linear transformation
        x = self.input_dropout(x, deterministic=not train)
        x = self.input_layer(x)

        # Apply Transformer encoder
        x = self.transformer(x, mask=mask, train=train)


 
        for layer in self.output_net:
            x = layer(x) if not isinstance(layer, nn.Dropout) else layer(x, deterministic=not train)
        return x
        

    def get_attention_maps(self, x:jnp.ndarray, mask:Optional[jnp.ndarray] = None, train:bool=True):
        """
        Function for extracting the attention matrices of the whole Transformer for a single batch.
        Input arguments same as the forward pass.
        """
        # Apply input dropout and linear transformation
        x = self.input_dropout(x, deterministic=not train)
        x = self.input_layer(x)

        # Get attention maps from the Transformer encoder
        attention_maps = self.transformer.get_attention_maps(x, mask=mask, train=train)
        return attention_maps



In [8]:

def load_model(checkpoint_path: str, state):
    # Load model. We use different checkpoint for the pretrained model
    params = checkpoints.restore_checkpoint(ckpt_dir= checkpoint_path, target=state.params)
    return params
        

In [85]:
def benchmark_model(model, params, x, dropout_init_rng, mask=None, n_runs=10 ):
    """Benchmark the execution time of a Flax model."""


    # JIT compile the function
    model_fn = jax.jit(lambda params, x, mask: model.apply({'params': params}, x, mask, train = False, rngs={'dropout': dropout_init_rng}))
    # Warmup
    model_fn(params, x, mask)

    # Timing execution
    start = time.time()
    for _ in range(n_runs):
        _ = model_fn(params, x, mask)
    end = time.time()

    avg_time = (end - start) / n_runs
    return avg_time

def compute_flops(model, params, x, dropout_init_rng, mask=None):
    """Compute an approximate count of FLOPs for a Flax model."""
  
    # JAX tracing to get Jaxpr (computational graph)
    jaxpr = jax.make_jaxpr(lambda x, mask: model.apply(
        {'params': params}, x, mask, train=False, rngs={'dropout': dropout_init_rng}))(x, mask)

    # Count the number of floating-point operations
    flops = sum(eq.primitive.name in ["dot_general", "conv_general_dilated"] for eq in jaxpr.jaxpr.eqns)

    return flops


In [86]:
d_model = 128
n_h = 4
d_h = d_model//n_h
d_c = d_h // 2
d_c_ = d_h // 4

model_kwargs = {
    "d_model":d_model,
    "n_h":n_h,
    "num_classes":1,
    "num_layers":4,
    "dim_feedforward" : 256,
    "dropout_rate":0.1,
    "input_dropout_prob":0.1,
    "d_c" : d_c,
    "d_c_": d_c_,
    "position": False
}
model = TransformerPredictor(**model_kwargs)

main_rng, x_rng = random.split(main_rng)
x = random.normal(x_rng, (64,10,512)) 


In [87]:
rng = jax.random.PRNGKey(423)
rng, init_rng, dropout_init_rng = jax.random.split(rng, 3)
params = model.init({'params': init_rng, 'dropout': dropout_init_rng}, x, train=True)['params']
optimizer = optax.adam(learning_rate=1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

params = load_model(CHECKPOINT_PATH, state=state )

In [88]:
execution_time_mla = benchmark_model(model=model,params=params, x=x, dropout_init_rng=dropout_init_rng)
flops_mla = compute_flops(model=model,params=params, x=x, dropout_init_rng=dropout_init_rng)

In [89]:
print(f"MultiHeadLatentAttention Time: {execution_time_mla:.6f} sec")
print(f"MultiHeadLatentAttention FLOPs: {flops_mla}")

MultiHeadLatentAttention Time: 0.000100 sec
MultiHeadLatentAttention FLOPs: 31
