# Training an Equinox MLP on Spiral Data

This notebook demonstrates how to train a three-layer MLP (using Equinox) to classify the spiral dataset generated by `spiral_data.py`. The training loop uses JAX and Equinox, and each training step is JIT-compiled with `eqx.filter_jit` for efficiency.

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt
from spiral_data import generate_spiral_data
from mlp import MLP

import optax

In [None]:
# Generate spiral dataset
# Define the class count once so data generation, one-hot encoding and plotting agree.
num_classes = 3
X, y = generate_spiral_data(points_per_class=100, num_classes=num_classes, noise=0.2)
# One-hot encode labels for the cross-entropy loss; the raw class labels `y`
# are kept as well because the plotting cell selects points with `y == class_number`.
Y = jax.nn.one_hot(y, num_classes)
assert X.shape[0] == Y.shape[0], "features and labels must have one row per sample"
print(f"X shape: {X.shape}, Y shape: {Y.shape}")

In [None]:
# Loss function and accuracy metric

def compute_loss(model, x, y):
    """Return the mean softmax cross-entropy between model logits and one-hot targets.

    Args:
        model: Callable mapping a single input point to a vector of class logits.
            It is vmapped over the leading (batch) axis of ``x``.
        x: Batch of inputs, with the batch along axis 0.
        y: One-hot targets with shape ``(batch, num_classes)``.

    Returns:
        Scalar mean cross-entropy over the batch.
    """
    logits = jax.vmap(model)(x)
    # log_softmax is computed stably; log(softmax(.)) can underflow to -inf.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # With one-hot y, the inner sum picks out -log p(true class) for each sample.
    return -jnp.mean(jnp.sum(y * log_probs, axis=-1))

def compute_accuracy(model, x, y):
    """Return the fraction of points whose arg-max logit equals the target class.

    Args:
        model: Callable mapping a single input point to a vector of class logits.
            It is vmapped over the leading (batch) axis of ``x``.
        x: Batch of inputs, with the batch along axis 0.
        y: Targets, either one-hot with shape ``(batch, num_classes)`` or integer
            class indices with shape ``(batch,)``.

    Returns:
        Scalar accuracy in ``[0, 1]``.
    """
    logits = jax.vmap(model)(x)
    predicted = jnp.argmax(logits, axis=-1)
    # ndim is static under jit, so this branch is resolved at trace time.
    target = jnp.argmax(y, axis=-1) if y.ndim == logits.ndim else y
    return jnp.mean(predicted == target)


In [None]:
# Training step (JIT compiled)
@eqx.filter_jit
def train_step(model, x, y, opt_state, optimizer):
    """Apply one full-batch gradient update and return the new model and optimizer state.

    Args:
        model: Equinox model being trained.
        x: Inputs forwarded to ``compute_loss``.
        y: Targets forwarded to ``compute_loss``.
        opt_state: Current optax optimizer state.
        optimizer: optax gradient transformation whose ``update`` turns gradients
            into parameter updates.

    Returns:
        Tuple ``(model, opt_state)`` after the update.
    """
    # Only gradients are needed here; the training loop re-evaluates the loss
    # itself when it logs progress, so the loss value is not returned.
    grads = eqx.filter_grad(lambda current: compute_loss(current, x, y))(model)
    # optax receives only the array leaves of the model as its "params" argument.
    array_params = eqx.filter(model, eqx.is_array)
    updates, next_opt_state = optimizer.update(grads, opt_state, array_params)
    return eqx.apply_updates(model, updates), next_opt_state

In [None]:
# Initialize model and optimizer
# Adam step size: 1e-2 converges quickly on this small full-batch problem.
# Lower it (e.g. 1e-3) if the loss oscillates.
LEARNING_RATE = 1e-2

# NOTE(review): `key` is only useful if MLP takes a PRNG key for weight init — confirm against mlp.py.
key = jax.random.PRNGKey(0)
model = MLP()  # TODO: pass the constructor arguments mlp.MLP expects (signature not visible here)
optimizer = optax.adam(learning_rate=LEARNING_RATE)
# The optimizer state is initialised over the model's array leaves only.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

In [None]:
# Training loop
# Each epoch is one full-batch step over the whole dataset.
num_epochs = 1000
# Report metrics every `log_every` epochs, so output stays readable.
log_every = 100
for epoch in range(num_epochs):
    model, opt_state = train_step(model, X, Y, opt_state, optimizer)
    # Also report after the last epoch, so the final metrics are always shown
    # even when num_epochs is not a multiple of log_every.
    if (epoch + 1) % log_every == 0 or epoch + 1 == num_epochs:
        loss = compute_loss(model, X, Y)
        acc = compute_accuracy(model, X, Y)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Accuracy={acc:.4f}")

In [None]:
# Visualize decision boundaries and predictions
h = 0.01  # grid spacing; smaller values give smoother boundaries but more model evaluations
x_min, x_max = float(jnp.min(X[:, 0])) - 0.5, float(jnp.max(X[:, 0])) + 0.5
y_min, y_max = float(jnp.min(X[:, 1])) - 0.5, float(jnp.max(X[:, 1])) + 0.5
xx, yy = jnp.meshgrid(jnp.arange(x_min, x_max, h), jnp.arange(y_min, y_max, h))
grid = jnp.c_[xx.ravel(), yy.ravel()]
logits = jax.vmap(model)(grid)
preds = jnp.argmax(logits, axis=1).reshape(xx.shape)

# Use one discrete colour per class, shared by the filled regions and the scatter
# points, so each predicted region visibly matches the class it stands for.
class_cmap = plt.get_cmap("rainbow", num_classes)
# Band edges at k - 0.5 put exactly one integer class inside each filled band.
class_levels = [k - 0.5 for k in range(num_classes + 1)]

fig, ax = plt.subplots(figsize=(7, 7))
ax.contourf(xx, yy, preds, levels=class_levels, alpha=0.3, cmap=class_cmap)
for class_number in range(num_classes):
    in_class = y == class_number
    ax.scatter(X[in_class, 0], X[in_class, 1], color=class_cmap(class_number),
               label=f"Class {class_number}", edgecolor='k')
ax.legend()
ax.set_title("MLP Decision Boundaries on Spiral Data")
ax.set_xlabel("x1")
ax.set_ylabel("x2")
ax.axis("equal")
plt.show()
