diff --git a/lectures/jax_nn.md b/lectures/jax_nn.md index 4b4d605..389092a 100644 --- a/lectures/jax_nn.md +++ b/lectures/jax_nn.md @@ -444,8 +444,6 @@ def train_jax_model( θ: list, # Initial parameters (pytree) x: jnp.ndarray, # Training input data y: jnp.ndarray, # Training target data - x_validate: jnp.ndarray, # Validation input data - y_validate: jnp.ndarray, # Validation target data config: Config # contains configuration data ): """ @@ -473,12 +471,12 @@ param_key = jax.random.PRNGKey(1234) θ = initialize_network(param_key, config) # Warmup run to trigger JIT compilation -train_jax_model(θ, x_train, y_train, x_validate, y_validate, config) +train_jax_model(θ, x_train, y_train, config) # Reset and time the actual run θ = initialize_network(param_key, config) start_time = time() -θ = train_jax_model(θ, x_train, y_train, x_validate, y_validate, config) +θ = train_jax_model(θ, x_train, y_train, config) θ[0].W.block_until_ready() # Ensure computation completes jax_runtime = time() - start_time