- Demonstrates the differences between PyTorch and Flax machine learning frameworks
- Shows how to:
  - Set up a simple linear regression dataset
  - Define an MLP model in both PyTorch (stateful) and Flax (stateless, functional)
  - Prepare data for each framework (NumPy, Torch, JAX arrays)
- Highlights model definition, instantiation, and initialization approaches in both libraries


In [None]:
import jax
import jax.numpy as jnp
import torch
import torch.optim as optim
import flax.linen as nn
from flax.training import train_state
import optax
import numpy as np
import torch.nn as tnn

# Data: y = 2x + 1, plus a small amount of Gaussian noise
x_numpy = np.random.randn(100, 1).astype(np.float32)
# np.random.randn returns float64, which would promote the whole expression
# (and therefore the targets) to float64. Cast back so inputs and targets
# share float32 in NumPy, PyTorch and JAX alike.
y_numpy = (2 * x_numpy + 1 + np.random.randn(100, 1) * 0.1).astype(np.float32)

# Prepare for PyTorch (from_numpy keeps the NumPy dtype)
x_torch = torch.from_numpy(x_numpy)
y_torch = torch.from_numpy(y_numpy)

# Prepare for JAX
x_jax = jnp.array(x_numpy)
y_jax = jnp.array(y_numpy)

#### Define the model

In [None]:
# --- PYTORCH (Stateful Object) ---
class TorchMLP(tnn.Module):
    """Single linear layer (1 input feature -> 1 output) that owns its weights."""

    def __init__(self):
        super(TorchMLP, self).__init__()
        # Constructing the Linear layer creates its parameters right away;
        # they live on this object as self.layer.
        self.layer = tnn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.layer(x)
        return prediction

# --- FLAX (Stateless Blueprint) ---
class FlaxMLP(nn.Module):
    """Describes a single Dense layer with one output feature."""

    @nn.compact
    def __call__(self, x):
        # The module stores no weights itself; this only describes the
        # computation. Parameters are supplied from outside on each call.
        dense = nn.Dense(features=1)
        return dense(x)

#### Instantiate the PyTorch model

In [None]:
# Step 1: build the model object; its weights exist as soon as it is constructed.
torch_model = TorchMLP()

# Step 2: give the optimizer a reference to the model's parameters.
torch_opt = optim.Adam(params=torch_model.parameters(), lr=0.1)

print("PyTorch: Model and Optimizer ready.")

#### Instantiate the Flax model

In [None]:
# Step 1: create the module (a description only; it holds no parameters).
flax_model = FlaxMLP()

# Step 2: create the parameters explicitly from a PRNG key and a dummy
# input of shape (1, 1), then pull out the 'params' collection.
key = jax.random.PRNGKey(0)
dummy_input = jnp.ones(shape=(1, 1))
params = flax_model.init(key, dummy_input)['params']

# Step 3: the optimizer is configured on its own, separate from the parameters.
tx = optax.adam(learning_rate=0.1)

# Step 4: bundle the apply function, parameters and optimizer into a
# TrainState, which represents the CURRENT snapshot of training.
state = train_state.TrainState.create(
    tx=tx,
    params=params,
    apply_fn=flax_model.apply,
)

print("Flax: TrainState ready.")

#### Training step in PyTorch


In [None]:
def torch_train_step(x, y):
    """Run one optimization step and return the loss.

    Operates on the module-level ``torch_model`` and ``torch_opt``.

    Args:
        x: Input tensor passed to ``torch_model``.
        y: Target tensor compared against the prediction.

    Returns:
        The mean squared error for this step as a Python number.
    """
    # Clear gradients left over from the previous step.
    torch_opt.zero_grad()

    # Forward pass, then mean squared error between prediction and target.
    residual = torch_model(x) - y
    loss = (residual ** 2).mean()

    # backward() stores the gradients IN the model's parameters;
    # step() then updates those parameters in place.
    loss.backward()
    torch_opt.step()

    return loss.item()

#### Training step in Flax

In [None]:
# Compiled with jax.jit for speed
@jax.jit
def flax_train_step(state, x, y):
    """Compute one gradient update and return the new state with the loss.

    Args:
        state: Training state exposing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        x: Model inputs.
        y: Targets compared against the model's predictions.

    Returns:
        A ``(new_state, loss)`` tuple. ``new_state`` is whatever
        ``state.apply_gradients`` returns; ``loss`` is the mean squared
        error evaluated at the parameters of the incoming ``state``.
    """

    def mse(params):
        # The loss must be a function of the parameters so that JAX can
        # differentiate with respect to them.
        prediction = state.apply_fn({'params': params}, x)
        residual = prediction - y
        return jnp.mean(residual ** 2)

    # value_and_grad gives the loss value and its gradients in one call.
    loss, grads = jax.value_and_grad(mse)(state.params)

    # The incoming state is not modified here: the result of
    # apply_gradients is returned, so the caller must keep it.
    return state.apply_gradients(grads=grads), loss

#### Training loop in PyTorch

In [None]:
print("--- PyTorch Training ---")
for i in range(20):
    loss = torch_train_step(x_torch, y_torch)
    # Report only on even-numbered steps.
    if i % 2 != 0:
        continue
    print(f"Step {i}, Loss: {loss:.4f}")

# Show the learned weight stored on the model's linear layer
print(f"Final Weights: {torch_model.layer.weight.data}")

#### Training loop in Flax

In [None]:
print("\n--- Flax Training ---")
for i in range(20):
    # CRITICAL: rebind 'state' to the state returned by the step. The step
    # does not modify its argument, so without this rebinding every
    # iteration would start from the same old state and nothing would be learned.
    state, loss = flax_train_step(state, x_jax, y_jax)

    # Report only on even-numbered steps.
    if i % 2 != 0:
        continue
    print(f"Step {i}, Loss: {loss:.4f}")

# Show the learned kernel held in the final state's parameters
print(f"Final Weights: {state.params['Dense_0']['kernel']}")