In [1]:
# importing required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional



# Config

*   The intermediate size is the width of the FFN's hidden layer, where the model learns more complex relationships. It is usually set to about `(8/3) * hidden_size` (approximately `2.67 * hidden_size`) and then rounded up.  

*   We round it up to a multiple of 32 because modern GPUs tend to be most efficient on matrix dimensions that are multiples of powers of two (8, 16, 32, 64, etc.).
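The rounding rule can be sketched as a small helper. `compute_intermediate_size` is a hypothetical name used only for illustration; the ceiling-division trick `(n + m - 1) // m * m` rounds `n` up to the next multiple of `m`.

```python
def compute_intermediate_size(hidden_size: int, ratio: float = 8 / 3, multiple_of: int = 32) -> int:
    """Scale hidden_size by `ratio`, then round UP to the nearest multiple of `multiple_of`."""
    raw = int(hidden_size * ratio)  # truncate, e.g. 128 * 8/3 = 341.33 -> 341
    # (raw + multiple_of - 1) // multiple_of is ceil(raw / multiple_of)
    return ((raw + multiple_of - 1) // multiple_of) * multiple_of


print(compute_intermediate_size(128))                    # 341 rounded up to a multiple of 32 -> 352
print(compute_intermediate_size(4096, multiple_of=256))  # 10922 rounded up to a multiple of 256 -> 11008
```

Rounding up (rather than down) means the FFN is never narrower than the ratio asks for.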

##### Layer Config

In [2]:
hidden_size = 128  # dimensionality of the model's hidden state
ffn_intermediate_ratio = 8/3  # approx. 2.67
multiple_of = 32  # to make sure the intermediate size is a multiple of this value
intermediate_size = int(hidden_size * ffn_intermediate_ratio)

# round "intermediate_size" up to the nearest multiple of "multiple_of"
intermediate_size = ((intermediate_size + multiple_of - 1) // multiple_of) * multiple_of

In [3]:
hidden_act = 'silu'  # activation function
rms_norm_eps = 1e-5
ffn_bias = False  # whether or not to use bias in FFN linear layers

##### Sample Input
*(a.k.a. the output of the attention mechanism)*

In [4]:
batch_size = 2
sequence_length = 10
input_to_ffn_block = torch.randn(batch_size, sequence_length, hidden_size)

In [5]:
print('Configuration:')
print(f'    hidden_size: {hidden_size}')
print(f"    intermediate_size: {intermediate_size} (Calculated from ratio {ffn_intermediate_ratio:.2f}, multiple of {multiple_of})")
print(f"    hidden_act: {hidden_act}")
print(f"    rms_norm_eps: {rms_norm_eps}")

print("\nSample Input Shape (Before FFN Block Norm):")
print(f"    input_to_ffn_block: {input_to_ffn_block.shape}")

Configuration:
    hidden_size: 128
    intermediate_size: 352 (Calculated from ratio 2.67, multiple of 32)
    hidden_act: silu
    rms_norm_eps: 1e-05

Sample Input Shape (Before FFN Block Norm):
    input_to_ffn_block: torch.Size([2, 10, 128])


# Pre-Normalization

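Unlike the original Transformer, which applies LayerNorm *after* the FFN and the residual addition (post-normalization), Llama uses a pre-normalization approach: the input is normalized *before* it enters the FFN. In the original `Llama4TextDecoderLayer` this is the `post-attention normalization`, named for its position right after the attention block.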
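A minimal sketch of the two orderings, assuming a generic `norm` and `sublayer` (here `nn.LayerNorm` and a zero-initialised `nn.Linear`, chosen only for illustration). When the sublayer outputs zeros, the pre-norm block passes its input through unchanged, while the post-norm block still rewrites it through the norm.

```python
import torch
import torch.nn as nn

def post_norm_block(x, norm, sublayer):
    # original Transformer ordering: add the residual first, then normalise the sum
    return norm(x + sublayer(x))

def pre_norm_block(x, norm, sublayer):
    # Llama-style ordering: normalise only the sublayer's input; the residual path is untouched
    return x + sublayer(norm(x))

torch.manual_seed(0)
d = 8
x = torch.randn(2, 3, d) * 5 + 2     # deliberately not zero-mean / unit-scale
norm = nn.LayerNorm(d)
sublayer = nn.Linear(d, d)
nn.init.zeros_(sublayer.weight)      # make the sublayer output exactly zero
nn.init.zeros_(sublayer.bias)

pre = pre_norm_block(x, norm, sublayer)
post = post_norm_block(x, norm, sublayer)
print(torch.equal(pre, x))           # True: x + 0 is x
print(torch.allclose(post, x))       # False: the norm rescales x
```

With pre-normalization the residual stream itself is never normalized, which is one reason the final output of a block can still carry the raw input's scale.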

##### EXPLANATION
Before the FFN, we apply Root Mean Square Normalization (RMSNorm) to stabilize the training process.

The problem with deep networks is that the values flowing through them can grow too large or shrink too small, which makes learning difficult. Normalization fixes this, but it also changes the vector's magnitude (its length, or signal strength). That original magnitude may have carried useful information, so the model needs a way to recover it.

RMSNorm solves this with a two-step process:

1.  **Normalize:** First, the input vector is scaled to a standard size. This stabilizes the numbers but changes the original magnitude.
2.  **Rescale:** Second, the normalized vector is multiplied by a learnable weight. This acts like a "volume knob" that lets the model learn how to scale the signal back up or down, effectively restoring the magnitude when the original strength was meaningful.
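Written out for a single token vector $x \in \mathbb{R}^d$ with learnable gain $g$ (the `self.weight` in the code) and small constant $\epsilon$ (`variance_epsilon`), the two steps are:

$$
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon},
\qquad
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\, g_i
$$

Dividing by $\mathrm{RMS}(x)$ is the "normalize" step; multiplying by $g_i$ is the "rescale" step.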

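`self.weight`: RMSNorm first rescales each token's hidden state so that its root mean square (RMS) is 1. Unlike LayerNorm, it does not subtract the mean, so this is *not* the same as unit standard deviation. `self.weight` (one learnable value per feature, initialised to ones) then multiplies the normalized state.  
( *During training, the model learns to scale each feature up or down as needed to improve performance.* )  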
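A quick numerical check of that difference, as a sketch: with the gain left at its initial value of ones, RMSNorm gives each token an RMS of about 1 but leaves its mean non-zero, whereas `F.layer_norm` (used here only for contrast) centres each token and scales it to unit standard deviation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16) + 3.0      # tokens whose features have a clearly non-zero mean
eps = 1e-5

# RMSNorm with gain = 1: divide by the root mean square, no mean subtraction
rms_out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
# LayerNorm without affine params: subtract the mean, divide by the standard deviation
ln_out = F.layer_norm(x, normalized_shape=(16,), eps=eps)

print(rms_out.pow(2).mean(-1).sqrt())  # ~1 for every token
print(rms_out.mean(-1))                # still clearly positive
print(ln_out.mean(-1))                 # ~0 for every token
```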


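`variance = hidden_states.pow(2).mean(-1, keepdim=True)`  
Despite its name, this is the mean of the squares (no mean is subtracted), which is exactly the quantity RMSNorm needs. To see what `.mean(-1, keepdim=True)` does on its own, take the example:  
`[[1, 2, 3], [4, 5, 6]]`  

*   `-1` selects the last dimension, so the mean is taken across the columns of each row:  
    *   Mean of the 1st row [1, 2, 3] is 2.  
    *   Mean of the 2nd row [4, 5, 6] is 5.  
    *   (With `.pow(2)` applied first, as in the code, the row results would be 14/3 and 77/3.)  

*   `keepdim`:
    *   `keepdim=False` (the default): the averaged dimension is removed. The output is a 1D tensor: [2, 5]
    *   `keepdim=True`: the dimension is kept, but its size becomes 1 (*we collapsed 3 numbers into 1 number*). The output is a 2D tensor: [[2], [5]]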
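The example can be run directly. It also shows why `keepdim=True` matters for the normalization line in `forward`: the `(2, 1)` result broadcasts against the `(2, 3)` input, while the `(2,)` result does not.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

row_means = x.mean(-1)                     # tensor([2., 5.]), shape (2,)
row_means_kept = x.mean(-1, keepdim=True)  # tensor([[2.], [5.]]), shape (2, 1)
mean_sq = x.pow(2).mean(-1, keepdim=True)  # what SimplifiedRMSNorm computes: [[14/3], [77/3]]

print(row_means.shape, row_means_kept.shape)
print(mean_sq)

print((x / row_means_kept).shape)          # (2, 3) / (2, 1) broadcasts -> torch.Size([2, 3])
try:
    x / row_means                          # (2, 3) / (2,): trailing sizes 3 and 2 do not match
except RuntimeError as err:
    print("broadcast error:", err)
```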

In [6]:
class SimplifiedRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable gain parameter
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)  # float32 for stability

        # calculate "variance" (actually the mean of squares; no mean is subtracted)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)

        # normalise: input / sqrt(variance + epsilon)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # apply the learnable weights and cast back to the original dtype
        return (self.weight * hidden_states).to(input_dtype)
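The `float32` upcast is not cosmetic. A minimal sketch, assuming a `float16` activation with magnitude 300: squaring it exceeds the largest finite `float16` value (65504), so the mean of squares would become `inf` without the upcast.

```python
import torch

x_half = torch.full((4,), 300.0, dtype=torch.float16)

sq_half = x_half * x_half                  # 90000 > 65504, so it overflows to inf in float16
sq_fp32 = x_half.float() * x_half.float()  # 90000 is representable in float32

print(torch.finfo(torch.float16).max)      # 65504.0
print(torch.isinf(sq_half).all().item())   # True
print(sq_fp32)                             # tensor([90000., 90000., 90000., 90000.])
```

`bfloat16` has the same exponent range as `float32`, so it would not overflow here, but its short mantissa makes the sum of squares less precise; upcasting covers both cases.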

In [7]:
post_attention_norm = SimplifiedRMSNorm(hidden_size, eps = rms_norm_eps)
normalized_hidden_states = post_attention_norm(input_to_ffn_block)

In [8]:
print("Shape after Post-Attention RMSNorm:")
print(f'    normalized_hidden_states: {normalized_hidden_states.shape}')

Shape after Post-Attention RMSNorm:
    normalized_hidden_states: torch.Size([2, 10, 128])


# Feed-Forward Network
(MLP with Gated Linear Unit)

##### Defining

The core of the LLM's dense layers is an MLP with a gating mechanism: the SiLU-gated linear unit (SwiGLU).  
It consists of three layers:
1.  **`gate_proj`:** Projects the input to the `intermediate_size`.
2.  **`up_proj`:** Also projects the input to the `intermediate_size`.
3.  **`down_proj`:** Projects the result back down to the `hidden_size`.
 
Calculation: `down_proj( F.silu(gate_proj(x)) * up_proj(x))`
- The `gate_proj` output is passed through an activation function (SiLU/Swish).
- This activated gate is element-wise multiplied by the `up_proj` output.
- The result is then projected back to the original hidden dimension by `down_proj`.
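The same calculation as a stateless function, sketched under the assumption that the weight matrices use the `nn.Linear` layout `(out_features, in_features)` so they can be passed straight to `F.linear`. The names `swiglu_ffn`, `w_gate`, `w_up` and `w_down` are hypothetical. The last lines check that `F.silu(z)` equals `z * sigmoid(z)`.

```python
import torch
import torch.nn.functional as F

def swiglu_ffn(x, w_gate, w_up, w_down):
    """down_proj(silu(gate_proj(x)) * up_proj(x)) with bias-free projections."""
    gate = F.linear(x, w_gate)                   # (..., hidden) -> (..., intermediate)
    up = F.linear(x, w_up)                       # (..., hidden) -> (..., intermediate)
    return F.linear(F.silu(gate) * up, w_down)   # (..., intermediate) -> (..., hidden)

torch.manual_seed(0)
hidden, inter = 128, 352
x = torch.randn(2, 10, hidden)
w_gate = torch.randn(inter, hidden) * 0.02
w_up = torch.randn(inter, hidden) * 0.02
w_down = torch.randn(hidden, inter) * 0.02

out = swiglu_ffn(x, w_gate, w_up, w_down)
print(out.shape)                                        # torch.Size([2, 10, 128])

z = torch.linspace(-4, 4, 9)
print(torch.allclose(F.silu(z), z * torch.sigmoid(z)))  # True: SiLU(z) = z * sigmoid(z)
```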

In [9]:
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=ffn_bias)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=ffn_bias)
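A common explanation for the 8/3 ratio, sketched as arithmetic: a gated FFN has three weight matrices instead of two, so shrinking the intermediate width from 4·d to (8/3)·d keeps the parameter count at 8·d², the same as a classic two-matrix FFN with a 4× expansion. Rounding up to a multiple of 32 adds a little on top.

```python
import torch.nn as nn

d, inter = 128, 352  # hidden_size and the rounded intermediate_size computed earlier

# the three bias-free projections of the gated FFN
swiglu_layers = [nn.Linear(d, inter, bias=False),
                 nn.Linear(d, inter, bias=False),
                 nn.Linear(inter, d, bias=False)]
swiglu_params = sum(p.numel() for layer in swiglu_layers for p in layer.parameters())

classic_params = 2 * d * (4 * d)  # two matrices with a 4x expansion: d*4d + 4d*d
print(swiglu_params)              # 3 * 128 * 352 = 135168
print(classic_params)             # 8 * 128**2 = 131072
```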

Normally, `ACT2FN` (from the Hugging Face `transformers` library) is used here. It is a lookup table that maps an activation name to its implementation, so the code doesn't have to hard-code `nn.SiLU`.  

*Example:*  
*In a config file we can write:*  
` "hidden_act": "silu"`  

*In the model code, we would:*  
`activation_function = ACT2FN[config.hidden_act]`
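A minimal stand-in for that lookup table, assuming only the activations that ship with `torch.nn`. `ACTIVATIONS` and `get_activation` are hypothetical names; this is not the `transformers` implementation, just the same idea of mapping a config string to a module.

```python
import torch.nn as nn

# hypothetical name -> module-class table, in the spirit of ACT2FN
ACTIVATIONS = {
    "silu": nn.SiLU,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}

def get_activation(name: str) -> nn.Module:
    try:
        return ACTIVATIONS[name]()  # instantiate a fresh module for each caller
    except KeyError:
        raise NotImplementedError(f"Activation {name} not implemented.") from None

config = {"hidden_act": "silu"}
activation_fn = get_activation(config["hidden_act"])
print(activation_fn)                # SiLU()
```

Adding a new activation then means adding one dictionary entry instead of another `elif` branch.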


In [10]:
if hidden_act == "silu":
    activation_fn = nn.SiLU()
else:
    # we can add any other activation function or just raise error
    raise NotImplementedError(f"Activation {hidden_act} not implemented as of now.")

##### Applying

In [11]:
gate_output = gate_proj(normalized_hidden_states)
up_output = up_proj(normalized_hidden_states)

# apply the SwiGLU equation explained above
activated_gate = activation_fn(gate_output)
gated_result = activated_gate * up_output
ffn_output = down_proj(gated_result)

In [12]:
print('Shapes within FFN:')
print(f'    gate_output: {gate_output.shape}')
print(f'    up_output: {up_output.shape}')
print()
print(f'    gated_result: {gated_result.shape}')
print(f'    ffn_output: {ffn_output.shape}')

Shapes within FFN:
    gate_output: torch.Size([2, 10, 352])
    up_output: torch.Size([2, 10, 352])

    gated_result: torch.Size([2, 10, 352])
    ffn_output: torch.Size([2, 10, 128])


# Residual Connection

A residual (skip) connection takes the input to a layer/block and adds it to that layer's output:
`final_output = input + f(input)`
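One reason the skip path helps training, as a sketch: since `final_output = input + f(input)`, the gradient with respect to `input` is the identity plus the gradient through `f`. Below, `f` is a zero-initialised `nn.Linear` (an assumption chosen so its contribution is exactly zero), yet the gradient still reaches the input in full through the identity path.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
f = nn.Linear(8, 8)
nn.init.zeros_(f.weight)   # f(x) == 0 and df/dx == 0
nn.init.zeros_(f.bias)

x = torch.randn(3, 8, requires_grad=True)

f(x).sum().backward()                  # no skip connection
grad_without_residual = x.grad.clone()

x.grad = None
(x + f(x)).sum().backward()            # with skip connection
grad_with_residual = x.grad.clone()

print(grad_without_residual.abs().max())  # tensor(0.): nothing flows back through f alone
print(grad_with_residual)                 # all ones: the identity path carries the gradient
```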

In [13]:
final_output = input_to_ffn_block + ffn_output

print('Shape after FFN Residual Connection: ')
print(f'    final_output: {final_output.shape}')

Shape after FFN Residual Connection: 
    final_output: torch.Size([2, 10, 128])


# Putting It All Together

##### Initialising class SimplifiedLlama4FFN

In [14]:
class SimplifiedLlama4FFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.intermediate_size = config['intermediate_size']
        self.hidden_act = config['hidden_act']
        self.ffn_bias = config['ffn_bias']
        self.rms_norm_eps = config['rms_norm_eps']

        # normalization before MLP
        self.norm = SimplifiedRMSNorm(self.hidden_size, eps=self.rms_norm_eps)

        # MLP layers
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=self.ffn_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=self.ffn_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.ffn_bias)

        # activation function
        if self.hidden_act == "silu":
            self.activation_fn = nn.SiLU()
        else:
            raise NotImplementedError(f'Activation {self.hidden_act} not implemented.')
        
    def forward(self, hidden_states):
        # applying normalization
        normalized_states = self.norm(hidden_states)

        # applying mlp
        gate = self.gate_proj(normalized_states)
        up = self.up_proj(normalized_states)
        down = self.down_proj(self.activation_fn(gate) * up)

        return down

##### Instantiate

In [15]:
ffn_config_dict = {
    'hidden_size': hidden_size,
    'intermediate_size': intermediate_size,
    'hidden_act': hidden_act,
    'ffn_bias': ffn_bias,
    'rms_norm_eps': rms_norm_eps
}

simplified_ffn_module = SimplifiedLlama4FFN(ffn_config_dict)

# forward pass using the module
mlp_output_from_module = simplified_ffn_module(input_to_ffn_block)

# applying residual connection
final_output_from_module = input_to_ffn_block + mlp_output_from_module

In [16]:
print(f'Output shape from simplified FFN module: {mlp_output_from_module.shape}')
print(f'Output shape after external residual connection: {final_output_from_module.shape}')

Output shape from simplified FFN module: torch.Size([2, 10, 128])
Output shape after external residual connection: torch.Size([2, 10, 128])


##### Simple element-wise verification

In [17]:
print(f'Outputs closely match: {torch.allclose(final_output, final_output_from_module, atol=1e-6)}')

Outputs closely match: False


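It's `False` because the manual layers (`post_attention_norm`, `gate_proj`, etc.) and the layers inside `SimplifiedLlama4FFN` are two independent sets of randomly initialized weights. (Both RMSNorm gains start as ones, so the mismatch comes from the linear layers.)

*To make the outputs match, we'd have to load the same weights into both.*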
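A self-contained sketch of that weight sharing. `TinyRMSNorm` and `TinySwiGLUFFN` are compact stand-ins assumed to mirror `SimplifiedRMSNorm` and `SimplifiedLlama4FFN` (same attribute names: `norm`, `gate_proj`, `up_proj`, `down_proj`). The manual layers' parameters are copied into the module with `load_state_dict`, after which the two paths agree.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        x32 = x.float()
        x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x32).to(x.dtype)

class TinySwiGLUFFN(nn.Module):
    def __init__(self, dim, inter, eps=1e-5):
        super().__init__()
        self.norm = TinyRMSNorm(dim, eps)
        self.gate_proj = nn.Linear(dim, inter, bias=False)
        self.up_proj = nn.Linear(dim, inter, bias=False)
        self.down_proj = nn.Linear(inter, dim, bias=False)

    def forward(self, x):
        h = self.norm(x)
        return self.down_proj(F.silu(self.gate_proj(h)) * self.up_proj(h))

torch.manual_seed(0)
dim, inter = 128, 352
x = torch.randn(2, 10, dim)

# "manual" layers, initialised independently of the module
manual_norm = TinyRMSNorm(dim)
gate_proj = nn.Linear(dim, inter, bias=False)
up_proj = nn.Linear(dim, inter, bias=False)
down_proj = nn.Linear(inter, dim, bias=False)
h = manual_norm(x)
manual_out = x + down_proj(F.silu(gate_proj(h)) * up_proj(h))

module = TinySwiGLUFFN(dim, inter)
before = torch.allclose(manual_out, x + module(x), atol=1e-6)  # False: different random weights

# copy the manual weights into the module, layer by layer
module.norm.load_state_dict(manual_norm.state_dict())
module.gate_proj.load_state_dict(gate_proj.state_dict())
module.up_proj.load_state_dict(up_proj.state_dict())
module.down_proj.load_state_dict(down_proj.state_dict())
after = torch.allclose(manual_out, x + module(x), atol=1e-6)   # True: same weights, same math

print(before, after)
```

Copying per layer keeps the example explicit; when the key names of two whole models line up, a single `load_state_dict` call on the parent module does the same job.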