In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import jax
import jax.numpy as jnp
from flax import struct
import numpy as np
from functools import partial
from typing import Optional

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"


In [3]:
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        # hidden dim gymnastics that Meta simplified only later
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        gate = self.w1(x)
        up = self.w3(x)
        return self.w2(F.silu(gate) * up)


In [None]:
bsz = 2
seqlen = 64
dim = 192
multiple_of = 32
dtype = np.float32

# Mimic hidden_dim calculation from PyTorch implementation
hidden_dim = int(2 * (4 * dim) / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# 2. Create shared weights and inputs
np.random.seed(0)
x_np = np.random.randn(bsz, seqlen, dim).astype(dtype)

# PyTorch names weights w1, w2, w3. JAX uses w_gate, w_up, w_down.
# Mapping: torch.w1 -> jax.w_gate, torch.w3 -> jax.w_up, torch.w2 -> jax.w_down
w1_np = np.random.normal(0, 0.02, (dim, hidden_dim)).astype(dtype)  # gate_proj
w3_np = np.random.normal(0, 0.02, (dim, hidden_dim)).astype(dtype)  # up_proj
w2_np = np.random.normal(0, 0.02, (hidden_dim, dim)).astype(dtype)  # down_proj

In [None]:
# Example: Initialize attention weights like Flax nn.initializers.normal(stddev=0.02)
# Define attention parameters (like in your LLaMa model)
n_heads = 8
n_kv_heads = 4  # For grouped query attention
head_dim = dim // n_heads

# Method 1: Using np.random.normal directly (equivalent to nn.initializers.normal(stddev=0.02))
wq_np = np.random.normal(0, 0.02, (dim, n_heads, head_dim)).astype(dtype)
wk_np = np.random.normal(0, 0.02, (dim, n_kv_heads, head_dim)).astype(dtype)
wv_np = np.random.normal(0, 0.02, (dim, n_kv_heads, head_dim)).astype(dtype)
wo_np = np.random.normal(0, 0.02, (n_heads * head_dim, dim)).astype(dtype)

# Method 2: Using np.random.randn and scaling (equivalent approach)
# wv_np = (np.random.randn(dim, n_kv_heads, head_dim) * 0.02).astype(dtype)

print(f"wv shape: {wv_np.shape}")
print(f"wv std: {wv_np.std():.6f} (should be ~0.02)")
print(f"wv mean: {wv_np.mean():.6f} (should be ~0.0)")


In [5]:
# 4. PyTorch setup
x_torch = torch.tensor(x_np, device=device)
torch_ff = FeedForward(
    dim=dim, hidden_dim=hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=None
)
torch_ff.w1.weight = torch.nn.Parameter(torch.tensor(w1_np.T, device=device))
torch_ff.w3.weight = torch.nn.Parameter(torch.tensor(w3_np.T, device=device))
torch_ff.w2.weight = torch.nn.Parameter(torch.tensor(w2_np.T, device=device))
output_torch = torch_ff(x_torch)

In [6]:
output_torch[0,:5,0]

tensor([ 3824.5667,  3250.8960, -1562.5957,  4669.1479, -2415.9500],
       grad_fn=<SelectBackward0>)

In [7]:

@struct.dataclass
class FeedForwardParams:
    w_gate: jax.Array  # Corresponds to gate_proj
    w_up: jax.Array  # Corresponds to up_proj
    w_down: jax.Array  # Corresponds to down_proj

@partial(jax.jit, static_argnames=["activation_fn"],donate_argnums=[0])
def feed_forward_jax(
    x: jax.Array,
    params: FeedForwardParams,
    activation_fn: str,  # Added activation function name
) -> jax.Array:
    """
    Compute FeedForward network (MLP) using a configurable activation function (like SwiGLU).

    Args:
        x: Input tensor of shape [batch_size, seqlen, dim].
        params: Dataclass containing weight matrices (w_gate, w_up, w_down).
        activation_fn: Name of the activation function ('silu', 'relu', 'gelu').

    Returns:
        Output tensor after MLP computation.
    """

    # Project input: x -> gate, up
    # x: [bs, seqlen, dim], w_gate: [dim, hidden_dim], w_up: [dim, hidden_dim]
    gate = jnp.einsum("bsd,dh->bsh", x, params.w_gate)
    up = jnp.einsum("bsd,dh->bsh", x, params.w_up)

    # Apply the specified activation function (SwiGLU style)
    if activation_fn == "silu":
        activated_gate = jax.nn.silu(gate)
    elif activation_fn == "relu":
        activated_gate = jax.nn.relu(gate)
    elif activation_fn == "gelu":
        # Use approximate=False for exact GELU, True for faster approximation
        activated_gate = jax.nn.gelu(gate, approximate=False)
    else:
        raise ValueError(f"Unsupported activation function: {activation_fn}")
        # replace error handling with chex

    fused_activation = activated_gate * up

    # Project down
    # fused_swiglu: [bs, seqlen, hidden_dim], w_down: [hidden_dim, dim]
    output = jnp.einsum("bsh,hd->bsd", fused_activation, params.w_down)

    return output

In [8]:

# 3. JAX setup
x_jax = jnp.array(x_np)
jax_params = FeedForwardParams(
    w_gate=jnp.array(w1_np), w_up=jnp.array(w3_np), w_down=jnp.array(w2_np)
)
output_jax = feed_forward_jax(x_jax, jax_params,"silu")


In [9]:
print(output_jax[0,:5,0])
print(output_torch[0,:5,0])

[ 3821.081   3259.9338 -1562.5793  4675.722  -2414.2063]
tensor([ 3824.5667,  3250.8960, -1562.5957,  4669.1479, -2415.9500],
       grad_fn=<SelectBackward0>)
