In [8]:
# VS Code Requirement: You will almost certainly have this already installed in your VS Code or Cursor, so no need to do anything
# - The Microsoft Python extension (ms-python.python)
# This extension provides the ability to run '%%' cells interactively.
# Below you should see "Run Cell" and "Run Below" that is written just above '# %% [markdown]' line

# %% [markdown]
# # Building a Byte Pair Encoding (BPE) Tokenizer from Scratch
#
# This tutorial walks through the process of creating a basic BPE tokenizer, a common type of tokenizer used in Large Language Models (LLMs).
#
# ## Step 1: Prepare Training Data
#
# The first step in building any tokenizer is to have a corpus of text to train it on. The tokenizer learns merge rules based on the frequency of character pairs in this data.
#
# i: 1
#
# s: 2
#
# is: 3
#
# Even though "i" and "s" are separate tokens, we create a new token "is" by merging them as they frequently appear together (is, this, his, miss, dismiss, list, fist, twist, mist, whisk, visible, vision, revise, crisis), reducing computation needs by 2x at any place where we merge those 2 tokens. This is how we will itteratively merge most frequent pairs. The new tokens can also be further merged.
#
# Let's start with a small example corpus.

# %%
# Our sample training data
corpus = [
    "This is the first document.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

print("Training Corpus:")
for doc in corpus:
    print(doc)

# %% [markdown]
# ## Step 2: Initialize Vocabulary and Pre-tokenize
#
# The BPE algorithm starts with a base vocabulary consisting of all unique characters present in the training data.
#
# We also need to pre-tokenize the corpus. This usually involves splitting the text into words (or word-like units) and then representing each word as a sequence of its individual characters. We often add a special end-of-word token (like `</w>`) to mark word boundaries, which helps the tokenizer learn subword units that align better with whole words.

# %%
# Initialize vocabulary with unique characters
unique_chars = set()
for doc in corpus:
    for char in doc:
        unique_chars.add(char)

vocab = list(unique_chars)
vocab.sort() # For consistent order of characters, making the vocabulary list predictable

# Add a special end-of-word token
end_of_word = "</w>"
vocab.append(end_of_word)

print("Initial Vocabulary:")
print(vocab)
print(f"Vocabulary Size: {len(vocab)}")

# Pre-tokenize the corpus: Split into words and then characters
# We'll split by space for simplicity and add the end-of-word token
word_splits = {}
for doc in corpus:
    words = doc.split(' ')
    for word in words:
        if word:
            char_list = list(word) + [end_of_word]
            # Use tuple for immutability if storing counts later - you can't change tuple once it's created (values, order, adding, removing elements, etc.), so they can be used as dictionary keys because of that.
            word_tuple = tuple(char_list)
            if word_tuple not in word_splits:
                 word_splits[word_tuple] = 0
            word_splits[word_tuple] += 1 # Count frequency of each initial word split

print("\nPre-tokenized Word Frequencies:")
print(word_splits)

# %% [markdown]
# ### Helper Function: `get_pair_stats`
#
# This function takes the current word splits (represented as a dictionary where keys are tuples of symbols/characters forming a word and values are their frequencies) and calculates the frequency of each adjacent pair of symbols across the entire corpus.
#
# **Input Example (`splits`):**
# ```
# {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 2, ...}
# ```
# **Output Example (`pair_counts`):**
# ```
# {('i', 's'): 4, ('s', '</w>'): 4, ('T', 'h'): 2, ...}
# ```

# %%
import collections

def get_pair_stats(splits):
    """Counts the frequency of adjacent pairs in the word splits."""
    # Initialize a dictionary with default values of 0 to count pairs of symbols.
    # defaultdict: It's like a regular dictionary (dict), but with a key difference.
    # If you try to access or modify a key that doesn't exist, instead of raising a KeyError,
    # it automatically creates that key and assigns it a default value.
    # int: This is the "default factory" you provide when creating the defaultdict. When a new key is created, it needs a default value, defaultdict calls this factory function. int() called with no arguments returns 0.
    pair_counts = collections.defaultdict(int)
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i+1])
            pair_counts[pair] += freq # Add the frequency of the word to the pair count
    return pair_counts

# %% [markdown]
# ### Helper Function: `merge_pair`
#
# This function takes a specific pair (`pair_to_merge`) that we want to combine and the current `splits`. It iterates through all the word representations in `splits`, replaces occurrences of the `pair_to_merge` with a new single token (concatenation of the pair), and returns the updated `splits`.
#
# **Input Example:**
# - `pair_to_merge`: `('i', 's')`
# - `splits`: `{('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 2, ...}`
#
# **Output Example (`new_splits`):**
# - `{('T', 'h', 'is', '</w>'): 2, ('is', '</w>'): 2, ...}` (assuming 'is' is the merged token)

# %%
def merge_pair(pair_to_merge, splits):
    """Merges the specified pair in the word splits."""
    new_splits = {}
    (first, second) = pair_to_merge
    merged_token = first + second
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        new_symbols = []
        i = 0
        while i < len(symbols):
            # If the current and next symbol match the pair to merge
            if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:
                new_symbols.append(merged_token)
                i += 2 # Skip the next symbol
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_splits[tuple(new_symbols)] = freq # Use the updated symbol list as the key
    return new_splits

# %% [markdown]
# ### Step 3: Iterative BPE Merging Loop
#
# Now we perform the core BPE training. We'll loop for a fixed number of merges (`num_merges`). In each iteration:
# 1. Calculate the frequencies of all adjacent pairs in the current word representations using `get_pair_stats`.
# 2. Find the pair with the highest frequency (`best_pair`).
# 3. Merge this `best_pair` across all word representations using `merge_pair`.
# 4. Add the newly formed token (concatenation of `best_pair`) to our vocabulary (`vocab`).
# 5. Store the merge rule (mapping the pair to the new token) in the `merges` dictionary.
#
# We'll add print statements to observe the state at each step of the loop.

# %%
# --- BPE Training Loop Initialization ---
num_merges = 15
# Stores merge rules, e.g., {('a', 'b'): 'ab'}
# Example: {('T', 'h'): 'Th'}
merges = {}
# Initial word splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 2, ...}
current_splits = word_splits.copy() # Start with initial word splits

print("\n--- Starting BPE Merges ---")
print(f"Initial Splits: {current_splits}")
print("-" * 30)

for i in range(num_merges):
    print(f"\nMerge Iteration {i+1}/{num_merges}")

    # 1. Calculate Pair Frequencies
    pair_stats = get_pair_stats(current_splits)
    if not pair_stats:
        print("No more pairs to merge.")
        break
    # Optional: Print top 5 pairs for inspection
    sorted_pairs = sorted(pair_stats.items(), key=lambda item: item[1], reverse=True)
    print(f"Top 5 Pair Frequencies: {sorted_pairs[:5]}")

    # 2. Find Best Pair
    # The 'max' function iterates over all key-value pairs in the 'pair_stats' dictionary
    # The 'key=pair_stats.get' tells 'max' to use the frequency (value) for comparison, not the pair (key) itself
    # This way, 'max' selects the pair with the highest frequency
    best_pair = max(pair_stats, key=pair_stats.get)
    best_freq = pair_stats[best_pair]
    print(f"Found Best Pair: {best_pair} with Frequency: {best_freq}")

    # 3. Merge the Best Pair
    current_splits = merge_pair(best_pair, current_splits)
    new_token = best_pair[0] + best_pair[1]
    print(f"Merging {best_pair} into '{new_token}'")
    print(f"Splits after merge: {current_splits}")

    # 4. Update Vocabulary
    vocab.append(new_token)
    print(f"Updated Vocabulary: {vocab}")

    # 5. Store Merge Rule
    merges[best_pair] = new_token
    print(f"Updated Merges: {merges}")

    print("-" * 30)


# %% [markdown]
# ### Step 4: Review Final Results
#
# After the loop finishes, we can examine the final state:
# - The learned merge rules (`merges`).
# - The final representation of words after merges (`current_splits`).
# - The complete vocabulary (`vocab`) containing initial characters and learned subword tokens.

# %%
# --- BPE Merges Complete ---
print("\n--- BPE Merges Complete ---")
print(f"Final Vocabulary Size: {len(vocab)}")
print("\nLearned Merges (Pair -> New Token):")
# Pretty print merges
for pair, token in merges.items():
    print(f"{pair} -> '{token}'")

print("\nFinal Word Splits after all merges:")
print(current_splits)

print("\nFinal Vocabulary (sorted):")
# Sort for consistent viewing
final_vocab_sorted = sorted(list(set(vocab))) # Use set to remove potential duplicates if any step introduced them
print(final_vocab_sorted)

# %%

Training Corpus:
This is the first document.
This document is the second document.
And this is the third one.
Is this the first document?
Initial Vocabulary:
[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '</w>']
Vocabulary Size: 20

Pre-tokenized Word Frequencies:
{('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'): 1, ('s', 'e', 'c', 'o', 'n', 'd', '</w>'): 1, ('A', 'n', 'd', '</w>'): 1, ('t', 'h', 'i', 's', '</w>'): 2, ('t', 'h', 'i', 'r', 'd', '</w>'): 1, ('o', 'n', 'e', '.', '</w>'): 1, ('I', 's', '</w>'): 1, ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '</w>'): 1}

--- Starting BPE Merges ---
Initial Splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 3, ('t', 'h', 'e', '</w>'): 4, ('f', 'i', 'r', 's', 't', '</w>'): 2, ('d', 'o', 'c', 'u', 'm', '

In [9]:

# lesson_4_llama4_feedforward_code.py

# %% [markdown]
# # Understanding the Llama 4 Feed-Forward Network (FFN)
#
# This tutorial explores the Feed-Forward Network (FFN) used in the Llama 4 architecture, specifically the MLP (Multi-Layer Perceptron) variant used in dense layers. The FFN is applied independently to each token position after the attention mechanism and residual connection. Its role is to further process the information aggregated by the attention layer, adding non-linearity and increasing the model's representational capacity.
#
# Llama models typically use a specific FFN structure involving gated linear units (like SwiGLU), which has shown strong performance. We will break down the `Llama4TextMLP` module and its surrounding components (like Layer Normalization) step by step.

# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

# %% [markdown]
# ## Step 1: Setup and Configuration
#
# First, let's define configuration parameters relevant to the FFN and create sample input data. This input data represents the hidden state *after* the attention block and its residual connection, but *before* the post-attention layer normalization.

# %%
# Configuration (Simplified for clarity)
hidden_size = 128  # Dimensionality of the model's hidden states
# Intermediate size for the FFN. Often calculated based on hidden_size.
# A common pattern is around 2.67 * hidden_size, rounded up to a multiple (e.g., 256).
ffn_intermediate_ratio = 8 / 3
multiple_of = 32  # Common multiple for FFN intermediate size
intermediate_size = int(hidden_size * ffn_intermediate_ratio)
# This line of code adjusts the intermediate_size to be a multiple of 'multiple_of'.
# It does this by first adding 'multiple_of - 1' to 'intermediate_size', then performing integer division by 'multiple_of',
# and finally multiplying the result by 'multiple_of'. This effectively rounds up 'intermediate_size' to the nearest multiple of 'multiple_of'.
intermediate_size = ((intermediate_size + multiple_of -
                     1) // multiple_of) * multiple_of

hidden_act = "silu"  # Activation function (SiLU/Swish)
rms_norm_eps = 1e-5  # Epsilon for RMSNorm
ffn_bias = False  # Whether to use bias in FFN linear layers

# Sample Input (Represents output of Attention + Residual)
batch_size = 2
sequence_length = 10
# This is the state before the post-attention LayerNorm
input_to_ffn_block = torch.randn(batch_size, sequence_length, hidden_size)

print("Configuration:")
print(f"  hidden_size: {hidden_size}")
print(
    f"  intermediate_size: {intermediate_size} (Calculated from ratio {ffn_intermediate_ratio:.2f}, multiple of {multiple_of})")
print(f"  hidden_act: {hidden_act}")
print(f"  rms_norm_eps: {rms_norm_eps}")

print("\nSample Input Shape (Before FFN Block Norm):")
print(f"  input_to_ffn_block: {input_to_ffn_block.shape}")

# %% [markdown]
# ## Step 2: Pre-Normalization (Post-Attention LayerNorm)
#
# Before passing the hidden state through the FFN, Llama applies a Layer Normalization step. Unlike standard Transformers that often use LayerNorm *after* the FFN and residual connection, Llama uses a pre-normalization approach. Here, it's the "post-attention" normalization (`post_attention_layernorm` in the original `Llama4TextDecoderLayer`). Llama typically uses RMSNorm.

# %%
# Simplified RMSNorm Implementation


class SimplifiedRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable gain parameter
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Calculate in float32 for stability
        hidden_states = hidden_states.to(torch.float32)
        # Calculate variance (mean of squares) across the hidden dimension
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Normalize: input / sqrt(variance + epsilon)
        hidden_states = hidden_states * \
            torch.rsqrt(variance + self.variance_epsilon)
        # Apply learnable weight and cast back to original dtype
        return (self.weight * hidden_states).to(input_dtype)


# Instantiate and apply the normalization
post_attention_norm = SimplifiedRMSNorm(hidden_size, eps=rms_norm_eps)
normalized_hidden_states = post_attention_norm(input_to_ffn_block)

print("Shape after Post-Attention RMSNorm:")
print(f"  normalized_hidden_states: {normalized_hidden_states.shape}")

# %% [markdown]
# ## Step 3: The Feed-Forward Network (MLP with Gated Linear Unit)
#
# The core of the FFN in Llama's dense layers is an MLP using a gated mechanism, often referred to as SwiGLU (SiLU Gated Linear Unit). It consists of three linear projections:
#
# 1.  **`gate_proj`:** Projects the input to the `intermediate_size`.
# 2.  **`up_proj`:** Also projects the input to the `intermediate_size`.
# 3.  **`down_proj`:** Projects the result back down to the `hidden_size`.
#
# The calculation is: `down_proj( F.silu(gate_proj(x)) * up_proj(x) )`
#
# - The `gate_proj` output is passed through an activation function (SiLU/Swish).
# - This activated gate is element-wise multiplied by the `up_proj` output.
# - The result is then projected back to the original hidden dimension by `down_proj`.

# %%
# Define FFN layers
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=ffn_bias)

# Define the activation function (SiLU/Swish)
# ACT2FN could be used here, but for simplicity, we directly use nn.SiLU
if hidden_act == "silu":
    activation_fn = nn.SiLU()
else:
    # Add other activations if needed, otherwise raise error
    raise NotImplementedError(
        f"Activation {hidden_act} not implemented in this example.")

# Apply the FFN layers to the *normalized* hidden states
gate_output = gate_proj(normalized_hidden_states)
up_output = up_proj(normalized_hidden_states)

# Apply activation to the gate and perform element-wise multiplication
activated_gate = activation_fn(gate_output)
gated_result = activated_gate * up_output

# Apply the final down projection
ffn_output = down_proj(gated_result)

print("\nShapes within FFN:")
# (batch, seq_len, intermediate_size)
print(f"  gate_output: {gate_output.shape}")
# (batch, seq_len, intermediate_size)
print(f"  up_output: {up_output.shape}")
# (batch, seq_len, intermediate_size)
print(f"  gated_result: {gated_result.shape}")
print(f"  ffn_output: {ffn_output.shape}")   # (batch, seq_len, hidden_size)


# %% [markdown]
# ## Step 4: Residual Connection
#
# Similar to the attention block, a residual connection is used around the FFN block. The output of the FFN (`ffn_output`) is added to the input that went *into* the FFN block (i.e., the output of the attention block + its residual, stored here as `input_to_ffn_block`).

# %%
# Add the FFN output to the input of the FFN block (before normalization)
final_output = input_to_ffn_block + ffn_output

print("\nShape after FFN Residual Connection:")
# Should be (batch, seq_len, hidden_size)
print(f"  final_output: {final_output.shape}")

# %% [markdown]
# ## Step 5: Putting it Together (Simplified Llama4 FFN Block)
#
# Let's combine the normalization and MLP steps into a simplified module. Note that the residual connection is typically handled *outside* this specific module in the main `DecoderLayer`.

# %%


class SimplifiedLlama4FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.intermediate_size = config['intermediate_size']
        self.hidden_act = config['hidden_act']
        self.ffn_bias = config['ffn_bias']
        self.rms_norm_eps = config['rms_norm_eps']

        # Normalization Layer (applied before MLP)
        self.norm = SimplifiedRMSNorm(self.hidden_size, eps=self.rms_norm_eps)

        # MLP Layers
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=self.ffn_bias)
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=self.ffn_bias)
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=self.ffn_bias)

        # Activation
        if self.hidden_act == "silu":
            self.activation_fn = nn.SiLU()
        else:
            raise NotImplementedError(
                f"Activation {self.hidden_act} not implemented.")

    def forward(self, hidden_states):
        # 1. Apply pre-FFN normalization
        normalized_states = self.norm(hidden_states)

        # 2. Apply MLP (SwiGLU)
        gate = self.gate_proj(normalized_states)
        up = self.up_proj(normalized_states)
        down = self.down_proj(self.activation_fn(gate) * up)

        # This module returns *only* the MLP output.
        # The residual connection is applied outside.
        return down


# Instantiate and run the simplified module
ffn_config_dict = {
    'hidden_size': hidden_size,
    'intermediate_size': intermediate_size,
    'hidden_act': hidden_act,
    'ffn_bias': ffn_bias,
    'rms_norm_eps': rms_norm_eps,
}

simplified_ffn_module = SimplifiedLlama4FFN(ffn_config_dict)

# Run forward pass using the module
# Input is the state *before* the norm
mlp_output_from_module = simplified_ffn_module(input_to_ffn_block)

# Apply the residual connection externally
final_output_from_module = input_to_ffn_block + mlp_output_from_module

print("\nOutput shape from simplified FFN module (before residual):",
      mlp_output_from_module.shape)
print("Output shape after external residual connection:",
      final_output_from_module.shape)
# Verify that the manual calculation matches the module output (should be very close)
print("Outputs are close:", torch.allclose(
    final_output, final_output_from_module, atol=1e-6))


# %% [markdown]
# ## Conclusion
#
# The Llama 4 Feed-Forward Network block typically consists of:
# 1.  **Pre-Normalization:** An RMSNorm layer applied to the output of the previous (attention + residual) block.
# 2.  **Gated MLP (SwiGLU):** Two linear layers projecting to an intermediate dimension, combined using an activation (SiLU) and element-wise multiplication, followed by a projection back to the hidden dimension.
# 3.  **Residual Connection:** The output of the MLP is added back to the input of the normalization layer.
#
# This structure provides the necessary non-linearity and processing power for each token position within the transformer layer.

# %%


Configuration:
  hidden_size: 128
  intermediate_size: 352 (Calculated from ratio 2.67, multiple of 32)
  hidden_act: silu
  rms_norm_eps: 1e-05

Sample Input Shape (Before FFN Block Norm):
  input_to_ffn_block: torch.Size([2, 10, 128])
Shape after Post-Attention RMSNorm:
  normalized_hidden_states: torch.Size([2, 10, 128])

Shapes within FFN:
  gate_output: torch.Size([2, 10, 352])
  up_output: torch.Size([2, 10, 352])
  gated_result: torch.Size([2, 10, 352])
  ffn_output: torch.Size([2, 10, 128])

Shape after FFN Residual Connection:
  final_output: torch.Size([2, 10, 128])

Output shape from simplified FFN module (before residual): torch.Size([2, 10, 128])
Output shape after external residual connection: torch.Size([2, 10, 128])
Outputs are close: False


In [10]:
# lesson_4_llama4_feedforward_code.py

# %% [markdown]
# # Understanding the Llama 4 Feed-Forward Network (FFN)
#
# This tutorial explores the Feed-Forward Network (FFN) used in the Llama 4 architecture, specifically the MLP (Multi-Layer Perceptron) variant used in dense layers. The FFN is applied independently to each token position after the attention mechanism and residual connection. Its role is to further process the information aggregated by the attention layer, adding non-linearity and increasing the model's representational capacity.
#
# Llama models typically use a specific FFN structure involving gated linear units (like SwiGLU), which has shown strong performance. We will break down the `Llama4TextMLP` module and its surrounding components (like Layer Normalization) step by step.

# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

# %% [markdown]
# ## Step 1: Setup and Configuration
#
# First, let's define configuration parameters relevant to the FFN and create sample input data. This input data represents the hidden state *after* the attention block and its residual connection, but *before* the post-attention layer normalization.

# %%
# Configuration (Simplified for clarity)
hidden_size = 128  # Dimensionality of the model's hidden states
# Intermediate size for the FFN. Often calculated based on hidden_size.
# A common pattern is around 2.67 * hidden_size, rounded up to a multiple (e.g., 256).
ffn_intermediate_ratio = 8 / 3
multiple_of = 32  # Common multiple for FFN intermediate size
intermediate_size = int(hidden_size * ffn_intermediate_ratio)
# This line of code adjusts the intermediate_size to be a multiple of 'multiple_of'.
# It does this by first adding 'multiple_of - 1' to 'intermediate_size', then performing integer division by 'multiple_of',
# and finally multiplying the result by 'multiple_of'. This effectively rounds up 'intermediate_size' to the nearest multiple of 'multiple_of'.
intermediate_size = ((intermediate_size + multiple_of -
                     1) // multiple_of) * multiple_of

hidden_act = "silu"  # Activation function (SiLU/Swish)
rms_norm_eps = 1e-5  # Epsilon for RMSNorm
ffn_bias = False  # Whether to use bias in FFN linear layers

# Sample Input (Represents output of Attention + Residual)
batch_size = 2
sequence_length = 10
# This is the state before the post-attention LayerNorm
input_to_ffn_block = torch.randn(batch_size, sequence_length, hidden_size)

print("Configuration:")
print(f"  hidden_size: {hidden_size}")
print(
    f"  intermediate_size: {intermediate_size} (Calculated from ratio {ffn_intermediate_ratio:.2f}, multiple of {multiple_of})")
print(f"  hidden_act: {hidden_act}")
print(f"  rms_norm_eps: {rms_norm_eps}")

print("\nSample Input Shape (Before FFN Block Norm):")
print(f"  input_to_ffn_block: {input_to_ffn_block.shape}")

# %% [markdown]
# ## Step 2: Pre-Normalization (Post-Attention LayerNorm)
#
# Before passing the hidden state through the FFN, Llama applies a Layer Normalization step. Unlike standard Transformers that often use LayerNorm *after* the FFN and residual connection, Llama uses a pre-normalization approach. Here, it's the "post-attention" normalization (`post_attention_layernorm` in the original `Llama4TextDecoderLayer`). Llama typically uses RMSNorm.

# %%
# Simplified RMSNorm Implementation


class SimplifiedRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable gain parameter
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Calculate in float32 for stability
        hidden_states = hidden_states.to(torch.float32)
        # Calculate variance (mean of squares) across the hidden dimension
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Normalize: input / sqrt(variance + epsilon)
        hidden_states = hidden_states * \
            torch.rsqrt(variance + self.variance_epsilon)
        # Apply learnable weight and cast back to original dtype
        return (self.weight * hidden_states).to(input_dtype)


# Instantiate and apply the normalization
post_attention_norm = SimplifiedRMSNorm(hidden_size, eps=rms_norm_eps)
normalized_hidden_states = post_attention_norm(input_to_ffn_block)

print("Shape after Post-Attention RMSNorm:")
print(f"  normalized_hidden_states: {normalized_hidden_states.shape}")

# %% [markdown]
# ## Step 3: The Feed-Forward Network (MLP with Gated Linear Unit)
#
# The core of the FFN in Llama's dense layers is an MLP using a gated mechanism, often referred to as SwiGLU (SiLU Gated Linear Unit). It consists of three linear projections:
#
# 1.  **`gate_proj`:** Projects the input to the `intermediate_size`.
# 2.  **`up_proj`:** Also projects the input to the `intermediate_size`.
# 3.  **`down_proj`:** Projects the result back down to the `hidden_size`.
#
# The calculation is: `down_proj( F.silu(gate_proj(x)) * up_proj(x) )`
#
# - The `gate_proj` output is passed through an activation function (SiLU/Swish).
# - This activated gate is element-wise multiplied by the `up_proj` output.
# - The result is then projected back to the original hidden dimension by `down_proj`.

# %%
# Define FFN layers
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=ffn_bias)

# Define the activation function (SiLU/Swish)
# ACT2FN could be used here, but for simplicity, we directly use nn.SiLU
if hidden_act == "silu":
    activation_fn = nn.SiLU()
else:
    # Add other activations if needed, otherwise raise error
    raise NotImplementedError(
        f"Activation {hidden_act} not implemented in this example.")

# Apply the FFN layers to the *normalized* hidden states
gate_output = gate_proj(normalized_hidden_states)
up_output = up_proj(normalized_hidden_states)

# Apply activation to the gate and perform element-wise multiplication
activated_gate = activation_fn(gate_output)
gated_result = activated_gate * up_output

# Apply the final down projection
ffn_output = down_proj(gated_result)

print("\nShapes within FFN:")
# (batch, seq_len, intermediate_size)
print(f"  gate_output: {gate_output.shape}")
# (batch, seq_len, intermediate_size)
print(f"  up_output: {up_output.shape}")
# (batch, seq_len, intermediate_size)
print(f"  gated_result: {gated_result.shape}")
print(f"  ffn_output: {ffn_output.shape}")   # (batch, seq_len, hidden_size)


# %% [markdown]
# ## Step 4: Residual Connection
#
# Similar to the attention block, a residual connection is used around the FFN block. The output of the FFN (`ffn_output`) is added to the input that went *into* the FFN block (i.e., the output of the attention block + its residual, stored here as `input_to_ffn_block`).

# %%
# Add the FFN output to the input of the FFN block (before normalization)
final_output = input_to_ffn_block + ffn_output

print("\nShape after FFN Residual Connection:")
# Should be (batch, seq_len, hidden_size)
print(f"  final_output: {final_output.shape}")

# %% [markdown]
# ## Step 5: Putting it Together (Simplified Llama4 FFN Block)
#
# Let's combine the normalization and MLP steps into a simplified module. Note that the residual connection is typically handled *outside* this specific module in the main `DecoderLayer`.

# %%


class SimplifiedLlama4FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.intermediate_size = config['intermediate_size']
        self.hidden_act = config['hidden_act']
        self.ffn_bias = config['ffn_bias']
        self.rms_norm_eps = config['rms_norm_eps']

        # Normalization Layer (applied before MLP)
        self.norm = SimplifiedRMSNorm(self.hidden_size, eps=self.rms_norm_eps)

        # MLP Layers
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=self.ffn_bias)
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=self.ffn_bias)
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=self.ffn_bias)

        # Activation
        if self.hidden_act == "silu":
            self.activation_fn = nn.SiLU()
        else:
            raise NotImplementedError(
                f"Activation {self.hidden_act} not implemented.")

    def forward(self, hidden_states):
        # 1. Apply pre-FFN normalization
        normalized_states = self.norm(hidden_states)

        # 2. Apply MLP (SwiGLU)
        gate = self.gate_proj(normalized_states)
        up = self.up_proj(normalized_states)
        down = self.down_proj(self.activation_fn(gate) * up)

        # This module returns *only* the MLP output.
        # The residual connection is applied outside.
        return down


# Instantiate and run the simplified module
ffn_config_dict = {
    'hidden_size': hidden_size,
    'intermediate_size': intermediate_size,
    'hidden_act': hidden_act,
    'ffn_bias': ffn_bias,
    'rms_norm_eps': rms_norm_eps,
}

simplified_ffn_module = SimplifiedLlama4FFN(ffn_config_dict)

# Run forward pass using the module
# Input is the state *before* the norm
mlp_output_from_module = simplified_ffn_module(input_to_ffn_block)

# Apply the residual connection externally
final_output_from_module = input_to_ffn_block + mlp_output_from_module

print("\nOutput shape from simplified FFN module (before residual):",
      mlp_output_from_module.shape)
print("Output shape after external residual connection:",
      final_output_from_module.shape)
# Verify that the manual calculation matches the module output (should be very close)
print("Outputs are close:", torch.allclose(
    final_output, final_output_from_module, atol=1e-6))


# %% [markdown]
# ## Conclusion
#
# The Llama 4 Feed-Forward Network block typically consists of:
# 1.  **Pre-Normalization:** An RMSNorm layer applied to the output of the previous (attention + residual) block.
# 2.  **Gated MLP (SwiGLU):** Two linear layers projecting to an intermediate dimension, combined using an activation (SiLU) and element-wise multiplication, followed by a projection back to the hidden dimension.
# 3.  **Residual Connection:** The output of the MLP is added back to the input of the normalization layer.
#
# This structure provides the necessary non-linearity and processing power for each token position within the transformer layer.

# %%


Configuration:
  hidden_size: 128
  intermediate_size: 352 (Calculated from ratio 2.67, multiple of 32)
  hidden_act: silu
  rms_norm_eps: 1e-05

Sample Input Shape (Before FFN Block Norm):
  input_to_ffn_block: torch.Size([2, 10, 128])
Shape after Post-Attention RMSNorm:
  normalized_hidden_states: torch.Size([2, 10, 128])

Shapes within FFN:
  gate_output: torch.Size([2, 10, 352])
  up_output: torch.Size([2, 10, 352])
  gated_result: torch.Size([2, 10, 352])
  ffn_output: torch.Size([2, 10, 128])

Shape after FFN Residual Connection:
  final_output: torch.Size([2, 10, 128])

Output shape from simplified FFN module (before residual): torch.Size([2, 10, 128])
Output shape after external residual connection: torch.Size([2, 10, 128])
Outputs are close: False
