From a9b83d60c96487ab423efa48f8ac6d32ca400742 Mon Sep 17 00:00:00 2001 From: Chase Coleman Date: Sun, 7 Dec 2025 02:32:24 +0000 Subject: [PATCH] BUG: Remove x_validate and y_validate as they're unnecessary --- lectures/jax_nn.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lectures/jax_nn.md b/lectures/jax_nn.md index 4b4d605..389092a 100644 --- a/lectures/jax_nn.md +++ b/lectures/jax_nn.md @@ -444,8 +444,6 @@ def train_jax_model( θ: list, # Initial parameters (pytree) x: jnp.ndarray, # Training input data y: jnp.ndarray, # Training target data - x_validate: jnp.ndarray, # Validation input data - y_validate: jnp.ndarray, # Validation target data config: Config # contains configuration data ): """ @@ -473,12 +471,12 @@ param_key = jax.random.PRNGKey(1234) θ = initialize_network(param_key, config) # Warmup run to trigger JIT compilation -train_jax_model(θ, x_train, y_train, x_validate, y_validate, config) +train_jax_model(θ, x_train, y_train, config) # Reset and time the actual run θ = initialize_network(param_key, config) start_time = time() -θ = train_jax_model(θ, x_train, y_train, x_validate, y_validate, config) +θ = train_jax_model(θ, x_train, y_train, config) θ[0].W.block_until_ready() # Ensure computation completes jax_runtime = time() - start_time