Commit

Migrate first example to Enzyme
simsurace committed Feb 11, 2024
1 parent 94f9823 commit a2ff622
Showing 2 changed files with 25 additions and 7 deletions.
examples/Project.toml: 2 changes (1 addition, 1 deletion)
@@ -1,10 +1,10 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
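
(Not part of the commit, just context: with Enzyme added to examples/Project.toml, the usual Pkg workflow pulls it in along with the other dependencies.)

    using Pkg
    Pkg.activate("examples")  # assumes the repository root as working directory
    Pkg.instantiate()
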
examples/exact_time_learning.jl: 30 changes (24 additions, 6 deletions)
@@ -1,5 +1,5 @@
# This is an extended version of exact_time_inference.jl. It combines it with
-# Optim + ParameterHandling + Zygote to learn the kernel parameters.
+# Optim + ParameterHandling + Enzyme to learn the kernel parameters.
# Each of these other packages knows nothing about TemporalGPs; they're just general-purpose
# packages which play nicely with TemporalGPs (and AbstractGPs).

@@ -12,7 +12,7 @@ using TemporalGPs: RegularSpacing
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Enzyme # Algorithmic Differentiation

# Declare model parameters using `ParameterHandling.jl` types.
# var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the
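
(The declaration these comments describe sits in the collapsed part of the diff. For reference, a ParameterHandling.jl declaration of that shape looks roughly like this sketch; the numeric values are illustrative, not the example's actual ones.)

    params = (
        var_kernel = positive(0.6),  # kernel variance (illustrative)
        λ = positive(2.5),           # inverse length scale (illustrative)
        var_noise = positive(0.1),   # observation-noise variance (illustrative)
    )
    flat_initial_params, unpack = ParameterHandling.value_flatten(params)
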
@@ -42,15 +42,33 @@ y = rand(f(x, params.var_noise));

# Specify an objective function for Optim to minimise in terms of x and y.
# We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
+function objective(x, y, params)
    f = build_gp(params)
    return -logpdf(f(x, params.var_noise), y)
end

-# Optimise using Optim. Zygote takes a little while to compile.
+# In order to compute the gradient with Enzyme, we define the following function:
+function enzyme_gradient(x, y, θ, unpack)
+    # Define shadows.
+    # It is unclear why the x and y shadows are needed here;
+    # making these variables `Const` leads to an error.
+    dθ = make_zero(θ)
+    dx = make_zero(x)
+    dy = make_zero(y)
+    autodiff(
+        Reverse,
+        (x, y, par, unpack) -> objective(x, y, unpack(par)),
+        Duplicated(x, dx), Duplicated(y, dy),
+        Duplicated(θ, dθ),
+        Const(unpack)
+    )
+    return dθ
+end

+# Optimise using Optim.
training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    θ -> objective(x, y, unpack(θ)),
+    θ -> enzyme_gradient(x, y, θ, unpack),
    flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
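
After optimisation, the learned parameters can be recovered with the same unpack used above; a minimal sketch, assuming the example's surrounding setup:

    final_params = unpack(training_results.minimizer)
    f_trained = build_gp(final_params)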
