In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
from flax import struct
from typing import Optional

# torch_xla is optional: record whether it can be imported so the TPU path is only
# attempted when the package is actually present.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
except ImportError:
    TORCH_XLA_AVAILABLE = False
else:
    TORCH_XLA_AVAILABLE = True

# Backend priority: TPU (through XLA) first, then CUDA, then CPU.
if TORCH_XLA_AVAILABLE:
    device = torch_xla.device()  # TPU device
    print(f"Using PyTorch XLA device: {device}")
else:
    if torch.cuda.is_available():
        device = "cuda"
        print("Using CUDA device")
    else:
        device = "cpu"
        print("Using CPU device")

# Both frameworks run the comparison in bfloat16.
torch_dtype = torch.bfloat16
jax_dtype = jnp.bfloat16


In [None]:
# JAX FeedForward implementation
@struct.dataclass
class FeedForwardParams:
    """Weight matrices consumed by `feed_forward` (stored as [in, out] for the einsums)."""

    # gate_proj: multiplies the input before the activation; [dim, hidden_dim]
    w_gate: jax.Array
    # up_proj: multiplies the input for the un-activated branch; [dim, hidden_dim]
    w_up: jax.Array
    # down_proj: maps the gated product back to the model width; [hidden_dim, dim]
    w_down: jax.Array


@partial(jit, static_argnames=["activation_fn"], donate_argnums=[0])
def feed_forward(
    x: jax.Array,
    params: FeedForwardParams,
    activation_fn: str,
) -> jax.Array:
    """
    Gated feed-forward (MLP) block: ``(act(x @ w_gate) * (x @ w_up)) @ w_down``.

    With ``activation_fn="silu"`` this is the SwiGLU formulation.

    Args:
        x: Input tensor of shape [batch_size, seqlen, dim]. Its buffer is donated
            to the jitted computation (``donate_argnums=[0]``), so callers should
            not rely on ``x`` remaining usable after the call.
        params: Dataclass holding w_gate [dim, hidden_dim], w_up [dim, hidden_dim]
            and w_down [hidden_dim, dim].
        activation_fn: One of 'silu', 'relu', 'gelu' (exact GELU). It is a static
            argument, so the choice is resolved when the function is traced.

    Returns:
        Tensor of shape [batch_size, seqlen, dim].

    Raises:
        ValueError: If ``activation_fn`` is not one of the supported names.
    """
    # [bs, seqlen, dim] x [dim, hidden_dim] -> [bs, seqlen, hidden_dim] for both branches
    gate_proj = jnp.einsum("bsd,dh->bsh", x, params.w_gate)
    up_proj = jnp.einsum("bsd,dh->bsh", x, params.w_up)

    # Name -> callable lookup; GELU uses the exact (non-approximate) form.
    activations = {
        "silu": jax.nn.silu,
        "relu": jax.nn.relu,
        "gelu": partial(jax.nn.gelu, approximate=False),
    }
    if activation_fn not in activations:
        raise ValueError(f"Unsupported activation function: {activation_fn}")

    # Gate the up-projection with the activated gate branch.
    gated = activations[activation_fn](gate_proj) * up_proj

    # [bs, seqlen, hidden_dim] x [hidden_dim, dim] -> [bs, seqlen, dim]
    return jnp.einsum("bsh,hd->bsd", gated, params.w_down)


In [None]:
# PyTorch FeedForward implementation
class FeedForward(nn.Module):
    """SiLU-gated MLP with three bias-free linear layers: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Args:
            dim: Input and output feature size.
            hidden_dim: Nominal hidden size; it is scaled by 2/3 (and by
                ``ffn_dim_multiplier`` when given) before being rounded up.
            multiple_of: The final hidden size is rounded up to a multiple of this.
            ffn_dim_multiplier: Optional extra scale applied after the 2/3 factor.
        """
        super().__init__()
        # Derive the actual hidden width: 2/3 scaling, optional multiplier, round-up.
        scaled = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            scaled = int(ffn_dim_multiplier * scaled)
        n_chunks = (scaled + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * n_chunks

        # Creation order (w1, w2, w3) is kept so random initialisation is unchanged.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last dimension must equal ``dim``)."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)


In [None]:
# 1. Setup Parameters
bsz = 4
seqlen = 64
dim = 1536
multiple_of = 32
dtype = np.float32

# Reproduce the sizing done in the PyTorch FeedForward.__init__ for a nominal
# hidden size of 4 * dim: scale by 2/3, then round up to a multiple of
# `multiple_of`. (The earlier shortcut `(dim // 3) * 8` skipped the round-up and
# only matched for some values of `dim`.)
hidden_dim = int(2 * (4 * dim) / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(f"Batch size: {bsz}")
print(f"Sequence length: {seqlen}")
print(f"Model dimension: {dim}")
print(f"Hidden dimension: {hidden_dim}")


In [None]:
# 2. Create shared weights and inputs
np.random.seed(0)  # fixed seed: both frameworks must be fed identical data


def _sample_weight(rows, cols):
    """Draw a (rows, cols) matrix from a normal with mean 0 and std 0.02, cast to `dtype`."""
    return np.random.normal(0, 0.02, (rows, cols)).astype(dtype)


x_np = np.random.randn(bsz, seqlen, dim).astype(dtype)

# Naming differs between the two implementations:
#   torch.w1 -> jax.w_gate, torch.w3 -> jax.w_up, torch.w2 -> jax.w_down
# Matrices are stored [in, out]; sampling order (w1, w3, w2) is fixed for reproducibility.
w1_np = _sample_weight(dim, hidden_dim)  # gate_proj
w3_np = _sample_weight(dim, hidden_dim)  # up_proj
w2_np = _sample_weight(hidden_dim, dim)  # down_proj

print(f"Input shape: {x_np.shape}")
print(f"Weight shapes: w1={w1_np.shape}, w2={w2_np.shape}, w3={w3_np.shape}")


In [None]:
# 3. JAX setup
def _to_jax(array_np):
    """Copy a NumPy array into a JAX array of the comparison dtype."""
    return jnp.array(array_np, dtype=jax_dtype)


jax_params = FeedForwardParams(
    w_gate=_to_jax(w1_np),
    w_up=_to_jax(w3_np),
    w_down=_to_jax(w2_np),
)
# feed_forward is jitted with donate_argnums=[0], so x_jax is handed over to the
# call; it is rebuilt from x_np each time this cell runs.
x_jax = _to_jax(x_np)
output_jax = feed_forward(x_jax, jax_params, "silu")

print(f"JAX output shape: {output_jax.shape}")
print(f"JAX output dtype: {output_jax.dtype}")


In [None]:
# 4. PyTorch setup
x_torch = torch.tensor(x_np, device=device, dtype=torch_dtype)

# FeedForward.__init__ rescales its `hidden_dim` argument by 2/3 and rounds it up
# to `multiple_of`. Passing the already-final `hidden_dim` would build narrower
# layers whose in/out_features disagree with the weights assigned below, so pass
# the nominal 4 * dim and check that the resulting width is the shared one.
torch_ff = FeedForward(
    dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=None
)
assert torch_ff.w1.out_features == hidden_dim, (
    f"FeedForward hidden width {torch_ff.w1.out_features} != shared hidden_dim {hidden_dim}"
)

# nn.Linear stores weights as [out, in]; the shared NumPy matrices are [in, out].
torch_ff.w1.weight = nn.Parameter(torch.tensor(w1_np.T, device=device, dtype=torch_dtype))
torch_ff.w3.weight = nn.Parameter(torch.tensor(w3_np.T, device=device, dtype=torch_dtype))
torch_ff.w2.weight = nn.Parameter(torch.tensor(w2_np.T, device=device, dtype=torch_dtype))
output_torch = torch_ff(x_torch)

print(f"PyTorch output shape: {output_torch.shape}")
print(f"PyTorch output dtype: {output_torch.dtype}")


In [None]:
# 5. Compare outputs
# Bring both results to host memory as float32 so the comparison itself is not
# carried out in bfloat16 (the torch side was already upcast; the JAX side was not).
output_jax_np = np.array(output_jax.astype(jnp.float32))
output_torch_np = output_torch.float().detach().cpu().numpy()

# Check shapes match
assert output_jax_np.shape == output_torch_np.shape, f"Shape mismatch: JAX {output_jax_np.shape} vs PyTorch {output_torch_np.shape}"

# Compare with same tolerances as test_ops.py
np.testing.assert_allclose(
    output_jax_np, output_torch_np, rtol=1e-2, atol=1e-3
)

# Element-wise absolute error, computed once for both summary statistics.
abs_diff = np.abs(output_jax_np - output_torch_np)

print("✓ Feedforward test passed!")
print(f"Max absolute difference: {np.max(abs_diff)}")
print(f"Mean absolute difference: {np.mean(abs_diff)}")
