lectures/ifp_dl.md — 2 changes: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ From here the approach is

1. Replace $\Sigma$ with $\{\sigma(\cdot, \theta) \,:\, \theta \in \Theta\}$
   where $\sigma(\cdot, \theta)$ is an ANN with parameter vector $\theta$
-2. Replace the objective function with $M(\theta) := \int v_{\sigma(\cdot, \theta)} (a_0)$
+2. Replace the objective function with $M(\theta) := v_{\sigma(\cdot, \theta)} (a_0)$
3. Replace $M$ with a Monte Carlo approximation $\hat M$
4. Use gradient ascent to maximize $\hat M(\theta)$ over $\theta$.
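
To make steps 2–4 concrete, here is a minimal JAX sketch of how such a Monte Carlo objective and a gradient-ascent step might look. It is illustrative only: the policy `σ`, utility `u`, gross return `R`, discount factor `β`, and the budget constraint inside `step` are placeholder assumptions, not the lecture's actual definitions.

```python
import jax
import jax.numpy as jnp

def M_hat(θ, a0, shocks, σ, u, β=0.96, R=1.03):
    """Monte Carlo estimate of v_{σ(·, θ)}(a_0).

    σ, u, R, β and the budget constraint below are placeholders standing
    in for the lecture's policy network, utility, and asset dynamics.
    """
    def step(a, z):
        c = σ(a, θ)               # consumption chosen by the network
        a_next = R * (a - c) + z  # stylized budget constraint
        return a_next, u(c)

    def path_value(z_path):
        _, rewards = jax.lax.scan(step, a0, z_path)
        discounts = β ** jnp.arange(rewards.shape[0])
        return jnp.sum(discounts * rewards)

    # average discounted utility across simulated shock paths
    return jnp.mean(jax.vmap(path_value)(shocks))

# Step 4: one gradient-ascent update on the Monte Carlo objective
def ascent_step(θ, a0, shocks, σ, u, lr=1e-3):
    g = jax.grad(M_hat)(θ, a0, shocks, σ, u)
    return jax.tree_util.tree_map(lambda p, gp: p + lr * gp, θ, g)
```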

lectures/jax_nn.md — 31 changes: 7 additions & 24 deletions
@@ -453,17 +453,12 @@ def train_jax_model(
    Train model using gradient descent.

    """
-    def update(θ, _):
-        train_loss = loss_fn(θ, x, y)
-        val_loss = loss_fn(θ, x_validate, y_validate)
+    def update(_, θ):
        θ_new = update_parameters(θ, x, y, config)
-        accumulate = train_loss, val_loss
-        return θ_new, accumulate
+        return θ_new

-    θ_final, (training_losses, validation_losses) = jax.lax.scan(
-        update, θ, None, length=config.epochs
-    )
-    return θ_final, training_losses, validation_losses
+    θ_final = jax.lax.fori_loop(0, config.epochs, update, θ)
+    return θ_final
```
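
The change above replaces `jax.lax.scan` with `jax.lax.fori_loop`. The sketch below (not the lecture's code; the scalar `step` is a stand-in for `update_parameters`) shows the practical difference: `scan` threads a carry through the loop *and* stacks a per-iteration output, which is how the old version accumulated loss histories, whereas `fori_loop` threads only the carry, so the new version returns just the final parameters.

```python
import jax

def step(θ):
    # stand-in for one gradient-descent update: minimize θ² with step size 0.1
    return θ - 0.1 * (2 * θ)

# Old pattern: lax.scan carries state and stacks per-iteration outputs
def scan_update(θ, _):
    θ_new = step(θ)
    return θ_new, θ_new                     # (carry, output to accumulate)

θ_scan, history = jax.lax.scan(scan_update, 1.0, None, length=100)

# New pattern: lax.fori_loop carries only the state -- no history
θ_fori = jax.lax.fori_loop(0, 100, lambda i, θ: step(θ), 1.0)

print(θ_scan, θ_fori)    # same final value
print(history.shape)     # (100,): only scan keeps the trajectory
```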

### Execution
@@ -473,21 +468,18 @@ Let's run our code and see how it goes.
We'll reuse the data we generated earlier.

```{code-cell} ipython3
# Reset parameter vector
config = Config()
param_key = jax.random.PRNGKey(1234)
θ = initialize_network(param_key, config)
```

```{code-cell} ipython3
# Warmup run to trigger JIT compilation
train_jax_model(θ, x_train, y_train, x_validate, y_validate, config)

# Reset and time the actual run
θ = initialize_network(param_key, config)
start_time = time()
-θ, training_loss, validation_loss = train_jax_model(
-    θ, x_train, y_train, x_validate, y_validate, config
-)
+θ = train_jax_model(θ, x_train, y_train, x_validate, y_validate, config)
θ[0].W.block_until_ready() # Ensure computation completes
jax_runtime = time() - start_time

@@ -499,16 +491,7 @@ print(f"Final MSE on validation data = {jax_mse:.6f}")

Despite the simplicity of our implementation, we actually perform slightly better than Keras.

-This figure shows MSE across iterations:
-
-```{code-cell} ipython3
-fig, ax = plt.subplots()
-ax.plot(range(len(validation_loss)), validation_loss, label='validation loss')
-ax.legend()
-plt.show()
-```
-
-Here’s a visualization of the quality of our fit.
+Here's a visualization of the quality of our fit.

```{code-cell} ipython3
fig, ax = plt.subplots()