In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import dataclasses
from typing import Optional, List

In [None]:
@dataclasses.dataclass
class MoeArgs:  # MoE in mistral
    """Mixture-of-Experts routing configuration.

    ``__post_init__`` rejects ``n_experts_per_tok > n_experts``: the MoE layer
    runs ``torch.topk`` over ``n_experts`` gate logits and cannot select more
    experts than exist, so such a config would otherwise only fail at the first
    forward pass.
    """
    n_experts: int          # number of expert networks in each MoE layer
    n_experts_per_tok: int  # top-k: how many experts each token is routed to

    def __post_init__(self):
        if self.n_experts_per_tok > self.n_experts:
            raise ValueError(
                f"n_experts_per_tok ({self.n_experts_per_tok}) must not exceed n_experts ({self.n_experts})"
            )


@dataclasses.dataclass
class ModelArgs:
    """Hyper-parameters shared by the modules in this notebook (defaults form a tiny demo config)."""
    dim: int = 128      # embedding dimension for each input token and general dim in model layers
    n_layers: int = 4   # number of transformer layers in the model
    hidden_dim: int = 256   # hidden dimension of the feed-forward / expert MLPs
    head_dim: int = 32  # per-head dimension in attention (conventionally dim // n_heads; Q/K/V project to n_heads * head_dim, which need not equal dim)
    n_heads: int = 8  # number of heads for the Q
    n_kv_heads: Optional[int] = None  # number of K/V heads for GQA; None means n_heads (plain MHA); should divide n_heads
    vocab_size: int = 1000  # vocab size (number of possible tokens) usually from tokenizer.vocab_size
    norm_eps: float = 1e-5   # for numerical stability
    max_batch_size: int = 8     # maximum batch size
    max_seq_len: int = 64   # maximum sequence length: sizes the positional table, the default attention window and the generation length
    attn_window: Optional[int] = None  # attention window and rolling buffer size, if None, it is set to max_seq_len
    rope_theta: float = 10000.0  # base theta for the rotary / positional frequency table
    moe: Optional[MoeArgs] = None   # if set then use MoE otherwise normal FFN
    debug: Optional[bool] = False   # if True, modules print tensor shapes
    device: str = "cpu"  # NOTE(review): not read by the model code in this notebook; Mistral picks its own device


# RMSNorm

In [None]:

class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (Zhang & Sennrich, https://arxiv.org/pdf/1910.07467).

    Every vector along the last axis is divided by its RMS and then scaled
    element-wise by a learnable gain ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # added under the square root so an all-zero vector does not divide by zero
        self.eps = eps
        # learnable per-feature gain g from the paper
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # 1 / RMS over the last axis, kept as a size-1 axis so it broadcasts back over x
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor):
        # normalise in float32 for stability, cast back to the input dtype,
        # then apply the (Dim,) gain, which broadcasts over all leading axes
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed

# Self Attention with KV Cache

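**TODO**:
 * ~~Create RollingBuffer class support~~ (done: `RollingBufferKVCache` below)
 * ~~Add RollingBuffer in the self attention class~~ (done: `SelfAttention.kv_cache`)
 * Wire `SelfAttention` into `TransformerBlock`, which currently uses `nn.MultiheadAttention` without a KV cache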

In [None]:
class RollingBufferKVCache:
    """Fixed-size rolling-buffer KV cache for sliding-window attention.

    Based on the rolling buffer from the Mistral paper (https://arxiv.org/abs/2310.06825):
    the key/value for absolute position ``p`` lives in slot ``p % attn_window``,
    so memory stays bounded by ``attn_window`` and older entries are overwritten.

    Buffers have shape ``(max_batch_size, attn_window, n_kv_heads, head_dim)``.
    ``device``/``dtype`` default to torch's defaults; pass them to co-locate the
    cache with the model.
    """

    def __init__(self, max_batch_size, attn_window, n_kv_heads, head_dim, device=None, dtype=None):
        self.max_batch_size = max_batch_size
        self.attn_window = attn_window
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        shape = (self.max_batch_size, self.attn_window, self.n_kv_heads, self.head_dim)
        self.cache_k = torch.zeros(shape, device=device, dtype=dtype)
        self.cache_v = torch.zeros(shape, device=device, dtype=dtype)

    def update_cache(self, xk, xv, batch_size, start_pos):
        """Write the keys/values of absolute position ``start_pos`` into its rolling slot.

        xk, xv: (batch_size, 1, n_kv_heads, head_dim).
        """
        # modulo wraps the position around the fixed-size buffer
        cache_position = start_pos % self.attn_window
        self.cache_k[:batch_size, cache_position:cache_position + 1] = xk
        self.cache_v[:batch_size, cache_position:cache_position + 1] = xv

    def update_cache_multiple(self, xk, xv, batch_size, start_pos, seq_len):
        """Write ``seq_len`` consecutive positions, one at a time.

        Positions that map to the same slot overwrite each other, so only the
        last ``attn_window`` of them survive.
        """
        for i in range(seq_len):
            self.update_cache(xk[:, i:i + 1], xv[:, i:i + 1], batch_size, start_pos + i)

    def retrieve_cache(self, batch_size, start_pos):
        """Return the cached keys/values in chronological (oldest-first) order.

        ``start_pos`` is the number of positions written so far, i.e. the
        position *after* the last ``update_cache`` call.

        Returns:
            (keys, values), each of shape
            ``(batch_size, min(start_pos, attn_window), n_kv_heads, head_dim)``.
        """
        n_valid = min(max(start_pos, 0), self.attn_window)
        # once the buffer has wrapped, this slot holds the oldest surviving entry
        effective_start_pos = start_pos % self.attn_window
        # rotate the buffer so entries run oldest -> newest
        keys = torch.cat([
            self.cache_k[:batch_size, effective_start_pos:],
            self.cache_k[:batch_size, :effective_start_pos]
        ], dim=1)
        values = torch.cat([
            self.cache_v[:batch_size, effective_start_pos:],
            self.cache_v[:batch_size, :effective_start_pos]
        ], dim=1)
        # keep only the written slots (the newest n_valid); an explicit length avoids
        # the `x[:, -0:]` pitfall, which would return the whole unwritten buffer
        keys = keys[:, keys.shape[1] - n_valid:]
        values = values[:, values.shape[1] - n_valid:]
        return keys, values

In [None]:
class SelfAttention(nn.Module):
    """Single-token (decode-step) causal self-attention with GQA and a rolling KV cache.

    Inference only, ``seq_len == 1``. Each call does four things:
    projects the current token to Q/K/V, applies rotary embeddings to Q and K,
    writes K/V into the rolling-buffer cache, and attends over the cached window
    (at most ``attn_window`` positions). Grouped-query attention repeats each
    KV head ``n_heads // n_kv_heads`` times.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.debug = args.debug
        # GQA: number of K/V heads, defaulting to the number of Q heads (plain MHA)
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads_q = args.n_heads
        # how many Q heads share each KV head; assumes n_heads is a multiple of n_kv_heads
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = args.head_dim
        # sliding-window / rolling-buffer size; falls back to full attention over max_seq_len
        self.attn_window = args.attn_window if args.attn_window is not None else args.max_seq_len

        # q k v o projections (creation order kept so random initialisation is unchanged)
        self.wq = nn.Linear(args.dim, self.n_heads_q * self.head_dim)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim)
        # maps the concatenated heads (n_heads_q * head_dim, not necessarily == dim) back to dim
        self.wo = nn.Linear(self.n_heads_q * self.head_dim, args.dim)

        # KV cache with sliding-window attention and a rolling buffer
        self.kv_cache = RollingBufferKVCache(
            max_batch_size=args.max_batch_size,
            attn_window=self.attn_window,
            n_kv_heads=self.n_kv_heads,
            head_dim=self.head_dim
        )

    def repeat_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """Repeat each KV head ``n_rep`` times: (B, L, n_kv_heads, D) -> (B, L, n_kv_heads * n_rep, D)."""
        batch_size, seq_len, n_kv_heads, head_dim = kv.shape
        if self.n_rep == 1:  # plain MHA: nothing to repeat
            return kv
        return (
            kv[:, :, :, None, :]
            .expand(batch_size, seq_len, n_kv_heads, self.n_rep, head_dim)
            .reshape(batch_size, seq_len, n_kv_heads * self.n_rep, head_dim)
        )

    def _sync_kv_cache(self, ref: torch.Tensor) -> None:
        """Move/cast the cache tensors to ``ref``'s device and dtype if they differ.

        The cache holds plain tensors (not registered buffers), so ``module.to(...)``
        does not move them; without this the retrieved keys would stay on the CPU
        while the queries live on the model's device.
        """
        cache = self.kv_cache
        if cache.cache_k.device != ref.device or cache.cache_k.dtype != ref.dtype:
            cache.cache_k = cache.cache_k.to(device=ref.device, dtype=ref.dtype)
            cache.cache_v = cache.cache_v.to(device=ref.device, dtype=ref.dtype)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
        """Attend from the current token to the cached window.

        Args:
            x: (B, 1, Dim) hidden states for absolute position ``start_pos``.
            start_pos: absolute position of the token in ``x``.
            freqs_complex: rotary factors for this position, in a form accepted by
                ``apply_rotary_embeddings`` (e.g. a (1, head_dim // 2) slice of the table).

        Returns:
            (B, 1, Dim).
        """
        if self.debug: print("SelfAttention input shape", x.shape)
        batch_size, seq_len, _ = x.shape    # (B, 1, Dim)
        assert seq_len == 1, "only support 1D input for now for debugging"  # TODO test support when seq_len > 1

        # project and split into heads
        xq = self.wq(x).reshape(batch_size, seq_len, self.n_heads_q, self.head_dim)   # (B, 1, n_heads_q, Head_Dim)
        xk = self.wk(x).reshape(batch_size, seq_len, self.n_kv_heads, self.head_dim)  # (B, 1, n_kv_heads, Head_Dim)
        xv = self.wv(x).reshape(batch_size, seq_len, self.n_kv_heads, self.head_dim)  # (B, 1, n_kv_heads, Head_Dim)

        # rotary embeddings on Q and K (shape and dtype preserved)
        xq = self.apply_rotary_embeddings(xq, freqs_complex, x.device)
        xk = self.apply_rotary_embeddings(xk, freqs_complex, x.device)

        # write this position into the rolling buffer, then read the whole valid window back
        self._sync_kv_cache(xk)
        self.kv_cache.update_cache(xk, xv, batch_size, start_pos)
        keys, values = self.kv_cache.retrieve_cache(batch_size, start_pos + seq_len)  # (B, prefix_len, n_kv_heads, Head_Dim)

        # GQA: give every Q head its (shared) KV head
        keys = self.repeat_kv(keys)      # (B, prefix_len, n_heads_q, Head_Dim)
        values = self.repeat_kv(values)

        # heads become a batch dimension for matmul
        xq = xq.transpose(1, 2)          # (B, n_heads_q, 1, Head_Dim)
        keys = keys.transpose(1, 2)      # (B, n_heads_q, prefix_len, Head_Dim)
        values = values.transpose(1, 2)  # (B, n_heads_q, prefix_len, Head_Dim)

        # softmax(Q K^T / sqrt(d)) V, with the softmax computed in float32
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  # (B, n_heads_q, 1, prefix_len)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (B, n_heads_q, 1, Head_Dim)

        # merge heads and project back to Dim
        output = output.transpose(1, 2).contiguous()
        output = output.reshape(batch_size, seq_len, self.n_heads_q * self.head_dim)
        output = self.wo(output)  # (B, 1, Dim)

        if self.debug:
            print("SelfAttention output shape", output.shape)
        return output

    @staticmethod
    def apply_rotary_embeddings(x: torch.Tensor, freqs_cis: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Rotate consecutive pairs of the last axis of ``x`` by the complex factors in ``freqs_cis``.

        Args:
            x: (B, seq_len, n_heads, head_dim), head_dim even.
            freqs_cis: complex tensor broadcastable to (B, seq_len, n_heads, head_dim // 2).
                A 2-D (seq_len, head_dim // 2) per-position table is expanded to
                (1, seq_len, 1, head_dim // 2) so it broadcasts over batch and heads.
            device: device to move ``freqs_cis`` to before the multiply.

        Returns:
            A tensor with the same shape and dtype as ``x``.
        """
        # (B, S, H, D) -> (B, S, H, D/2, 2) -> complex (B, S, H, D/2)
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis.to(device)
        if freqs.dim() == 2:
            # align positions with the seq axis and broadcast over batch and heads
            freqs = freqs[None, :, None, :]
        # complex multiply == 2-D rotation; view_as_real gives (B, S, H, D/2, 2)
        x_rotated = torch.view_as_real(x_complex * freqs)
        # flatten back to (B, S, H, D) and restore the caller's dtype
        return x_rotated.flatten(3).type_as(x)


# FeedForward (SwiGLU)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` (activated branch) and ``w3`` (linear gate branch) both project
    ``dim -> hidden_dim``; ``w2`` projects the gated product back to ``dim``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.debug = args.debug
        hidden_dim = args.hidden_dim
        # layers are created in the order w1, w2, w3 so random initialisation stays the same
        self.w1 = nn.Linear(args.dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, args.dim)
        self.w3 = nn.Linear(args.dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.debug:
            print("FeedForward input shape", x.shape)

        # SiLU-activated branch gated element-wise by the linear branch: (B, Seq, Hidden_Dim)
        gated = F.silu(self.w1(x)) * self.w3(x)
        # back to the model dimension: (B, Seq, Dim)
        output = self.w2(gated)

        if self.debug:
            print("FeedForward output shape", output.shape)
        return output


# Mixture of Experts (MoE)

In [None]:
import torch
import torch.nn as nn
from typing import List

class MoE(nn.Module):
    """Sparse Mixture-of-Experts layer with top-k softmax routing.

    A ``gate`` module scores every expert for each token. Each token is sent to
    its ``moe_args.n_experts_per_tok`` highest-scoring experts. Their outputs are
    summed, weighted by a softmax over those top-k scores.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: "MoeArgs"):
        super().__init__()
        assert len(experts) == moe_args.n_experts, "Number of experts must be equal to n_experts"
        # registered as a ModuleList so the experts' parameters belong to this module
        self.experts = nn.ModuleList(experts)
        # maps (N, Dim) -> (N, n_experts) routing logits
        self.gate = gate
        self.moe_args = moe_args

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route every token of ``x`` (B, seq_len, Dim) to its top-k experts; returns (B, seq_len, Dim)."""
        B, seq_len, dim = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, are accepted
        x_flat = x.reshape(-1, dim)  # (B * seq_len, Dim)

        logits = self.gate(x_flat)  # (B * seq_len, n_experts)
        topk_logits, topk_indices = torch.topk(logits, self.moe_args.n_experts_per_tok, dim=1)
        # normalise only over the selected experts so each token's weights sum to 1
        topk_weights = torch.softmax(topk_logits, dim=1)  # (B * seq_len, n_experts_per_tok)

        results = torch.zeros_like(x_flat)
        for i, expert in enumerate(self.experts):
            # token_idx: rows routed to expert i; slot_idx: which of that row's k choices it was
            token_idx, slot_idx = torch.where(topk_indices == i)
            if token_idx.numel() == 0:
                continue  # no token picked this expert; skip its forward pass
            expert_out = expert(x_flat[token_idx])  # (K, Dim)
            # topk returns distinct experts per row, so token_idx has no duplicates and += is safe
            results[token_idx] += topk_weights[token_idx, slot_idx, None] * expert_out

        return results.reshape(B, seq_len, dim)


# Transformer Block

In [None]:
import torch
import torch.nn as nn
from typing import Optional
# The TransformerBlock class
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: ``h = x + Attn(Norm(x)); out = h + FFN(Norm(h))``.

    NOTE(review): attention is ``nn.MultiheadAttention`` applied only to the tokens
    passed in. It has no KV cache, no causal mask and no rotary embedding.
    ``start_pos`` and ``freqs_complex`` are accepted for interface compatibility
    but are unused. The "rms_norm" layers are ``nn.LayerNorm`` instances.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.debug = args.debug
        self.n_head = args.n_heads
        self.dim = args.dim

        # normalisation before attention and before the feed-forward (LayerNorm despite the names)
        self.rms_norm_attn = nn.LayerNorm(self.dim, eps=args.norm_eps)
        self.rms_norm_ffn = nn.LayerNorm(self.dim, eps=args.norm_eps)

        # Self-attention layer
        self.self_attention = nn.MultiheadAttention(self.dim, self.n_head, batch_first=True)

        if args.moe:
            # experts and gate only exist for MoE configs; args.moe is None for dense models
            self.experts = [
                nn.Sequential(nn.Linear(args.dim, args.hidden_dim), nn.ReLU(), nn.Linear(args.hidden_dim, args.dim))
                for _ in range(args.moe.n_experts)
            ]
            self.gate = nn.Linear(args.dim, args.moe.n_experts)
            self.feed_forward = MoE(self.experts, self.gate, args.moe)
        else:
            self.feed_forward = FeedForward(args)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
        """x: (B, seq_len, Dim) -> (B, seq_len, Dim)."""
        if self.debug:
            print("TransformerBlock input shape:", x.shape)

        # attention sub-layer: normalise, attend, and add back to the *un-normalised* input
        x_norm = self.rms_norm_attn(x)
        attn_output, _ = self.self_attention(x_norm, x_norm, x_norm, need_weights=False)
        h = x + attn_output

        # feed-forward sub-layer with its own residual connection
        ff_output = self.feed_forward(self.rms_norm_ffn(h))
        out = h + ff_output

        return out

# Transformer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Transformer(nn.Module):
    """Decoder stack: token embedding + positional signal -> TransformerBlocks -> final norm -> vocab logits.

    NOTE(review): the positional signal is the real part (cosine) of a rotary-style
    frequency table, added to the embeddings as an absolute encoding. The table
    is built with ``head_dim=args.dim * 2`` so that it has exactly ``dim`` columns.
    """

    def __init__(self, args: "ModelArgs"):
        super().__init__()
        assert args.vocab_size > 0, "vocab size should be set"
        self.args = args
        self.debug = args.debug
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        # token ids -> (dim,) vectors
        self.embedding = nn.Embedding(self.vocab_size, args.dim)

        # stack of n_layers blocks
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(args) for _ in range(self.n_layers)
        ])

        # final norm; uses args.norm_eps like the per-block norms
        self.rms_norm = nn.LayerNorm(args.dim, eps=args.norm_eps)

        # hidden states -> vocabulary logits
        self.output = nn.Linear(args.dim, self.vocab_size)

        # (max_seq_len, dim) complex table; a plain attribute, so forward moves the slice it needs
        self.freqs_complex = self.precompute_theta_pos_frequencies(
            head_dim=args.dim * 2,
            seq_len=args.max_seq_len,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            theta=args.rope_theta
        )

    @property
    def dtype(self) -> torch.dtype:
        # dtype of the first parameter
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        # device of the first parameter
        return next(self.parameters()).device

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            tokens: (B, 1) token ids (one token per step).
            start_pos: absolute position of those tokens; must be < max_seq_len.

        Returns:
            (B, 1, vocab_size) logits.
        """
        if self.debug:
            print("Transformer input shape", tokens.shape)

        batch_size, seq_len = tokens.shape
        assert seq_len == 1, "One token at a time at inference time"

        embeddings = self.embedding(tokens)  # (B, seq_len, dim)

        if self.debug:
            print("Start Position ", start_pos, "Seq Length", seq_len, "batch Size", batch_size)

        # this step's row of the table, moved next to the embeddings (the table does not
        # follow `model.to(device)`); `.real` keeps what a complex->float32 cast keeps, without the warning
        freqs = self.freqs_complex[start_pos:start_pos + seq_len].to(embeddings.device)  # (seq_len, dim) complex
        pos_encodings = freqs.real.expand(batch_size, seq_len, -1)  # (B, seq_len, dim)

        if self.debug:
            print("Embeddings shape:", embeddings.shape)
            print("Positional encodings shape:", pos_encodings.shape)

        h = embeddings + pos_encodings
        for layer in self.transformer_layers:
            h = layer(h, start_pos, self.freqs_complex)

        h = self.rms_norm(h)
        return self.output(h)

    def precompute_theta_pos_frequencies(
            self, head_dim: int, seq_len: int, device: torch.device, theta: float
    ) -> torch.Tensor:
        """Return the (seq_len, head_dim // 2) complex table ``exp(i * m * theta_j)``.

        ``theta_j = theta ** (-2j / head_dim)`` for ``j = 0 .. head_dim/2 - 1`` and
        ``m = 0 .. seq_len - 1``.
        """
        assert head_dim % 2 == 0, "head_dim must be even"

        theta_numerator = torch.arange(0, head_dim, 2).float()
        inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)

        # positions m = 0 .. seq_len - 1
        m = torch.arange(seq_len, device=device)
        # angle m * theta_j for every (position, frequency) pair
        freqs = torch.outer(m, inv_freq).float()

        # unit-magnitude complex numbers with those angles
        freqs_complex = torch.polar(torch.ones_like(freqs, dtype=torch.float32), freqs)
        if getattr(self, "debug", False):
            print("freqs_complex shape:", freqs_complex.shape)
        return freqs_complex



# Mistral

In [None]:
from typing import List, Optional, Tuple
import torch
from tqdm import tqdm

class Mistral:
    """Inference wrapper that builds a Transformer and decodes token ids autoregressively.

    No tokenizer is attached: prompts are lists of token ids, and the text list
    returned by ``generate`` is always empty.
    """

    def __init__(self, model_args: "ModelArgs", eos_id: Optional[int] = None):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.args = model_args
        self.model = Transformer(args=model_args).to(device)
        # id that terminates a sequence; None (default) disables early stopping
        self.eos_id = eos_id

    def generate(
            self, prompts: List[List[int]], temperature: float = 0.6,
            top_p: float = 0.9, max_gen_len: Optional[int] = None,
            show_progress: bool = True
    ):
        """Extend every prompt one token at a time.

        Args:
            prompts: token-id lists; at most ``max_batch_size`` of them, each at most ``max_seq_len`` long.
            temperature: 0 selects greedy argmax; otherwise logits are divided by it before top-p sampling.
            top_p: nucleus-sampling probability mass.
            max_gen_len: positions to generate beyond the longest prompt (default ``max_seq_len - 1``);
                the padded total length is capped at ``max_seq_len``.
            show_progress: wrap the decoding loop in a tqdm progress bar.

        Returns:
            ``(out_tokens, out_text)``. ``out_tokens`` holds one list per prompt with
            the prompt followed by generated ids, cut at the first generated ``eos_id``
            if one is set. ``out_text`` is an empty list.
        """
        if max_gen_len is None:
            max_gen_len = self.args.max_seq_len - 1

        prompt_tokens = prompts
        batch_size = len(prompt_tokens)
        assert batch_size <= self.args.max_batch_size, \
            f"batch size must be less than or equal to {self.args.max_batch_size}"
        if batch_size == 0:
            return ([], [])

        prompt_lens = [len(prompt) for prompt in prompt_tokens]
        max_prompt_len = max(prompt_lens)
        assert max_prompt_len <= self.args.max_seq_len, \
            f"prompt length must be less than or equal to {self.args.max_seq_len}"
        total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

        device = self.model.device
        pad_id = 0
        tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

        # prompt positions are marked by length, not by `tokens != pad_id`, so a prompt that
        # legitimately contains token id 0 is not overwritten by generated tokens
        positions = torch.arange(total_len, device=device)
        prompt_tokens_mask = positions[None, :] < torch.tensor(prompt_lens, device=device)[:, None]

        eos_id = getattr(self, "eos_id", None)
        eos_reached = torch.zeros(batch_size, dtype=torch.bool, device=device)

        steps = range(1, total_len)
        if show_progress:
            steps = tqdm(steps, desc="Generating tokens")
        for cur_pos in steps:
            with torch.no_grad():
                # the input token sits at position cur_pos - 1, so that is its start_pos
                logits = self.model.forward(tokens[:, cur_pos - 1:cur_pos], cur_pos - 1)

            next_token = self.sample_next_token(logits, temperature, top_p).reshape(-1)
            # keep prompt tokens; only positions past each prompt are filled
            next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            if eos_id is not None:
                eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == eos_id)
                if bool(eos_reached.all()):
                    break

        out_tokens = []
        out_text = []
        for prompt_len, row in zip(prompt_lens, tokens.tolist()):
            generated = row[prompt_len:]
            if eos_id is not None and eos_id in generated:
                # truncate at the first *generated* EOS; prompt tokens are never treated as EOS
                row = row[:prompt_len + generated.index(eos_id)]
            out_tokens.append(row)
        return (out_tokens, out_text)

    def sample_next_token(self, probs, temperature, top_p):
        """Pick the next token from the last-position logits.

        ``probs`` holds raw logits of shape (B, seq_len, vocab). With ``temperature == 0``
        this returns the argmax, shape (B,). Otherwise it applies a temperature softmax
        and samples from the smallest set of tokens whose probability mass reaches
        ``top_p``, returning shape (B, 1).
        """
        if temperature == 0:
            next_token = torch.argmax(probs[:, -1], dim=-1)
            return next_token

        probs = torch.softmax(probs[:, -1] / temperature, dim=-1)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # drop tokens whose preceding cumulative mass already exceeds top_p
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # map positions in the sorted order back to vocabulary ids
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token


In [None]:
!pip install fire

Collecting fire
  Downloading fire-0.6.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.4/88.4 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.6.0-py2.py3-none-any.whl size=117030 sha256=8259b683a3c65c1084e9de755c4d0232376abff8b65f4dca167c878afa3e2696
  Stored in directory: /root/.cache/pip/wheels/d6/6d/5d/5b73fa0f46d01a793713f8859201361e9e581ced8c75e5c6a3
Successfully built fire
Installing collected packages: fire
Successfully installed fire-0.6.0


In [None]:
pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->torchviz)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->torchviz)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->torchviz)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->torchviz)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->torchviz)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->torchviz)
  Using cached nvi

In [None]:
import logging
import os
from pathlib import Path
from typing import List, Optional

import torch
from torchviz import make_dot

# Ensure all necessary classes and functions are imported
# from mistral_model import Transformer, TransformerBlock, ModelArgs, MoeArgs, FeedForward
# from mistral_inference import Mistral

# Define necessary model arguments
args = ModelArgs(
    dim=128,
    n_layers=1,
    hidden_dim=256,
    head_dim=16,
    n_heads=8,
    n_kv_heads=2,
    vocab_size=1000,
    norm_eps=1e-5,
    max_batch_size=8,
    max_seq_len=64,
    attn_window=4,
    rope_theta=10000.0,
    moe=MoeArgs(n_experts=4, n_experts_per_tok=2),
    debug=False
)

# Seed once so weight initialisation and top-p sampling are reproducible on Restart & Run All
torch.manual_seed(0)

# Mistral and Transformer's positional table use CUDA when it is available; apply the
# same rule here so inputs, the visualised model and its table share one device
device = "cuda" if torch.cuda.is_available() else "cpu"
mistral = Mistral(args)
print(mistral.model)

# Visualize the model
model = Transformer(args).to(device)
x = torch.ones(3, 1, dtype=torch.long, device=device)
y = model(x, 0)

params = {name: param for name, param in model.named_parameters() if 'bias' not in name}
dot = make_dot(y, params=params)

# Simplify node names
simple_names = {}
for node in dot.body:
    if 'Conv' in node or 'BatchNorm' in node or 'ReLU' in node:
        name = node.split('[')[0]
        simple_names[name] = name.replace('\"', '').split('/')[-1]

for i in range(len(dot.body)):
    for key, value in simple_names.items():
        if key in dot.body[i]:
            dot.body[i] = dot.body[i].replace(key, value)

dot.render('model_visualization', format='png')

# Generate tokens
encoded_prompts = [[10, 2, 4, 4, 3, 7, 8], [4, 5, 6, 2, 3, 8, 9], [4, 5, 6, 2, 3, 4, 5, 6, 2, 3]]
tokens, text = mistral.generate(
    prompts=encoded_prompts,
    temperature=0.6,
    top_p=0.9,
    max_gen_len=10
)
print("Generated tokens:", tokens)

# Instantiate transformer and check parameters
transformer_block = TransformerBlock(args=args).to(device)
transformer = Transformer(args=args).to(device)
tok = transformer.embedding
encoded_prompts = [[100, 2, 4, 4, 3, 7, 8], [4, 5, 6, 2, 3, 8, 9], [4, 5, 6, 2, 3, 4, 5, 6, 2, 3]]
vocab_size = args.vocab_size          # keep in sync with the model config
batch_size = len(encoded_prompts)     # one row per prompt (3)

print(encoded_prompts)
prompt_chunks = [p[0:1] for p in encoded_prompts]
input_tensor = torch.tensor(sum(prompt_chunks, []), device=device, dtype=torch.long)
print(input_tensor.shape)
print(tok(input_tensor).shape)

def count_parameters(model):
    """Number of trainable scalar parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(transformer)
print(f"Total trainable parameters in the model: {total_params}")

# Single forward pass on the last vocabulary id (this model keeps no KV cache)
input_tensor = torch.tensor([[vocab_size - 1]] * batch_size, dtype=torch.long, device=device)
start_pos = 0
output = transformer(input_tensor, start_pos)
print("Inference output:", output)




freqs_complex shape: torch.Size([64, 128])
Transformer(
  (embedding): Embedding(1000, 128)
  (transformer_layers): ModuleList(
    (0): TransformerBlock(
      (rms_norm_attn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (rms_norm_ffn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (gate): Linear(in_features=128, out_features=4, bias=True)
      (feed_forward): MoE(
        (experts): ModuleList(
          (0-3): 4 x Sequential(
            (0): Linear(in_features=128, out_features=256, bias=True)
            (1): ReLU()
            (2): Linear(in_features=256, out_features=128, bias=True)
          )
        )
        (gate): Linear(in_features=128, out_features=4, bias=True)
      )
    )
  )
  (rms_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=128, out_feature

Generating tokens: 100%|██████████| 19/19 [00:00<00:00, 116.71it/s]


Start Position  1 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  2 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  3 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  4 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  5 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  6 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  7 Seq Length 1 batch Size 3
Embeddings shape: torch.Size([3, 1, 128])
Positional encodings shape: torch.Size([3, 1, 128])
Start Position  8 Seq Length 1 bat