- Demonstrates differences between PyTorch and Flax (the neural network library built on JAX) when defining models
- Shows how to define a simple multilayer perceptron (MLP) in both frameworks
- Highlights stateful (PyTorch) vs stateless/lazy (Flax) parameter handling
- Includes code for model instantiation and discussion of where parameters "live"
- Useful for users transitioning from PyTorch to JAX + Flax


In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn  # Flax
from flax.core import freeze
import torch                  # PyTorch
import torch.nn as tnn
import numpy as np
import copy

#### Model definition

In [None]:
# --- PYTORCH (Stateful) ---
class TorchMLP(tnn.Module):
    def __init__(self, output_dim):
        super().__init__()
        # Layers are created AND weights are initialized here immediately.
        # The weights live inside 'self.fc1' and 'self.fc2'.
        self.fc1 = tnn.Linear(10, 128)
        self.fc2 = tnn.Linear(128, output_dim)

    def forward(self, x):
        x = tnn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# --- FLAX (Stateless) ---
class FlaxMLP(nn.Module):
    output_dim: int  # Type annotation (dataclass style)

    @nn.compact
    def __call__(self, x):
        # Layers are defined lazily.
        # No weights exist here yet; this only describes the computation.
        # The parameters are created later, when 'init' is called.
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim)(x)
        return x

#### Instantiation of the model

In [None]:
# --- PYTORCH ---
# Instantiation = Initialization.
# Random weights are generated immediately upon creation.
torch_model = TorchMLP(output_dim=1)
print("PyTorch: Model created. Weights are inside.")

# --- FLAX ---
# Instantiation = Configuration.
# No weights are generated yet.
flax_model = FlaxMLP(output_dim=1)
print("Flax:    Blueprint created. No weights yet.")

#### Weights generation in Flax

In [None]:
# Flax needs a PRNG key and a dummy input
key = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 10))

# 'init' returns the model variables as a pytree:
# a nested dict with a top-level 'params' collection.
flax_params = flax_model.init(key, dummy_input)

print("Flax:    Weights generated explicitly via 'init'.")

#### Why dummy_input? (Lazy Initialization)

Notice that in our Flax model, we defined `nn.Dense(features=128)` but we never said what 
the input size was.

Flax infers the input size from the last dimension of the data the first time it passes through a layer. 


To trigger this inference and create weights of the correct shape, 
we pass a single "dummy" input through the model using `init`. 
Zeros or ones are fine, because only the shape and dtype of the input matter.

If `dummy_input` has shape (1, 50), the first Dense layer gets a 50x128 kernel.

If `dummy_input` has shape (1, 10), the first Dense layer gets a 10x128 kernel.

#### Forward pass

The `.apply` method in Flax is used to run a forward pass of your model using 
explicitly provided parameters and any other state. 

When you define a Flax module like `FlaxMLP`, you typically implement a `__call__` method, 
which describes how your data flows through the model (the forward computation). 

However, to actually evaluate the model, Flax separates the *blueprint* 
(the module and its `__call__` definition) from the *parameters/state* 
(which are returned by `.init`). 

The `.apply` method takes the parameters (e.g., `flax_params`) and the input data, 
and runs the logic you wrote in `__call__`. So, it's effectively a way to say: 
“use *these* parameters in the `__call__` forward pass, with *this* input”.

**Summary:**
- `__call__`: You write this to define the forward pass computation.
- `.apply(params, x)`: This *invokes* your `__call__`, using the specified `params` and data `x`.



In [None]:
# Create input data
x_numpy = np.random.randn(5, 10).astype(np.float32)
x_torch = torch.from_numpy(x_numpy)
x_jax = jnp.array(x_numpy)

# --- PYTORCH ---
# Syntax: model(x)
# Implicitly: It finds 'self.fc1.weight' internally to do the math.
y_torch = torch_model(x_torch)

# --- FLAX ---
# Syntax: model.apply(params, x)
# Explicitly: We must hand it the weights we want to use.
y_flax = flax_model.apply(flax_params, x_jax)

print("Forward pass complete for both.")

#### Mutation

In [None]:
# --- PYTORCH (Mutation) ---
# We reach into the object and change memory in-place.
# Side effect: The 'torch_model' object is permanently changed.
#
# We use torch.no_grad() because we are modifying a parameter by hand
# (filling the bias with zeros) outside of a training step.
# 'fc1.bias' is a leaf tensor with requires_grad=True, and PyTorch raises a
# RuntimeError if such a tensor is modified in-place while autograd is recording.
# torch.no_grad() temporarily disables gradient tracking, so the in-place
# update is allowed and is not recorded for backpropagation.
with torch.no_grad():
    torch_model.fc1.bias.fill_(0.0)

# --- FLAX (Functional Update) ---
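# We do not modify 'flax_params' in place. Depending on the Flax version it is
# either an immutable FrozenDict or a plain dict, but the convention is the same:
# build a NEW set of parameters and leave the old one untouched.

# 1. Make an independent, mutable copy of the original parameters.
#    unfreeze() converts a FrozenDict into a nested dict (deepcopy on its own
#    would not make a FrozenDict mutable); deepcopy ensures nothing is shared.
from flax.core import unfreeze
mutable_params = copy.deepcopy(unfreeze(flax_params))

# 2. Modify the copy
mutable_params['params']['Dense_0']['bias'] = jnp.zeros((128,))

# 3. Refreeze (pack it back up as an immutable FrozenDict)
new_flax_params = freeze(mutable_params)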

# Proof: The old params still exist!
# We have branched the universe. We have 'flax_params' AND 'new_flax_params'.

#### Do you know deepcopy?

In [None]:
# 1. Create a nested list (a list containing another list)
original = [1, [2, 3], 4]

# 2. Make a Shallow Copy and a Deep Copy
shallow = copy.copy(original)
deep = copy.deepcopy(original)

# 3. Modify the nested list in the original
original[1][0] = 'CHANGED'

# 4. Results
print(f"Original: {original}")
print(f"Shallow:  {shallow}  <-- Affected (shares the nested list)")
print(f"Deep:     {deep}     <-- Unaffected (has its own nested list)")

#### Why this happens

**Shallow Copy (copy.copy):** Creates a new list, but references the same inner objects. 
If you change a mutable object inside (like the inner list), the change shows up in both.

**Deep Copy (copy.deepcopy):** Recursively creates copies of the list and everything inside it. 
The new object is completely independent of the original.
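
The sharing can also be observed without modifying anything, by comparing object identity with `is`. This is a small illustrative sketch using the same nested list as above.

```python
import copy

original = [1, [2, 3], 4]
shallow = copy.copy(original)
deep = copy.deepcopy(original)

# The outer lists are distinct objects in both cases.
print(shallow is original)        # False
print(deep is original)           # False

# The inner list is shared by the shallow copy only.
print(shallow[1] is original[1])  # True
print(deep[1] is original[1])     # False
```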