This example extends exact_space_time_inference.jl by using Optim, ParameterHandling,
and Zygote to learn the kernel parameters.
If you understand how to use these three packages with an AbstractGP (for example, as
shown in the README for this package) and how exact_space_time_inference.jl works,
then you should be able to follow this file.

In [1]:
using AbstractGPs
using KernelFunctions
using TemporalGPs

using TemporalGPs: Separable, RectilinearGrid

Load standard packages from the Julia ecosystem.

In [2]:
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Zygote # Algorithmic Differentiation

Declare model parameters using `ParameterHandling.jl` types.

In [3]:
flat_initial_params, unflatten = ParameterHandling.flatten((
    var_kernel = positive(0.6),
    λ_space = positive(2.5),
    λ_time = positive(2.5),
    var_noise = positive(0.1),
));
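
As a rough illustration of what a positivity constraint buys you, the sketch below optimises over `log(value)` and maps back with `exp`, so any real-valued vector an optimiser proposes corresponds to strictly positive parameters. The helpers `to_unconstrained`, `to_positive`, and `unflatten_sketch` are hypothetical names written for this sketch; they are not ParameterHandling's implementation, which may differ in its details.

```julia
# Hypothetical stand-ins for a positivity constraint (illustration only,
# not ParameterHandling's actual code): optimise log(value), recover it with exp.
to_unconstrained(v::Real) = log(v)
to_positive(u::Real) = exp(u)

initial = (var_kernel = 0.6, λ_space = 2.5, λ_time = 2.5, var_noise = 0.1)

# Flatten: a plain vector of unconstrained reals, in field order.
flat = [to_unconstrained(v) for v in values(initial)]

# Unflatten: rebuild a NamedTuple with the same field names.
unflatten_sketch(θ) = NamedTuple{keys(initial)}(Tuple(to_positive.(θ)))

# Round trip recovers the initial values.
roundtrip = unflatten_sketch(flat)

# Any perturbation of the flat vector still maps to strictly positive values.
perturbed = unflatten_sketch(flat .- 10.0)
```

This is why the optimiser later in this example can add unconstrained noise to the flat vector without producing a negative variance or length scale.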

Construct a function to unpack flattened parameters and pull out the raw values,
and a function that builds the space-time GP from those values.

In [4]:
unpack = ParameterHandling.value ∘ unflatten;
params = unpack(flat_initial_params);

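function build_gp(params)
    # Squared-exponential kernel over space and Matérn-5/2 kernel over time,
    # each with its own inverse length scale.
    k_space = SEKernel() ∘ ScaleTransform(params.λ_space)
    k_time = Matern52Kernel() ∘ ScaleTransform(params.λ_time)
    # Separable (product) kernel over space-time, scaled by the kernel variance.
    k = params.var_kernel * Separable(k_space, k_time)
    # Convert to the state-space (SDE) form that TemporalGPs uses for inference.
    return to_sde(GP(k), ArrayStorage(Float64))
end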

build_gp (generic function with 1 method)
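
To see what separability means on a grid, the sketch below builds a small dense covariance matrix by hand and checks that it equals a Kronecker product of a temporal and a spatial covariance matrix. The kernel functions `se` and `matern52` are hand-written stand-ins, the grid is a tiny illustrative one, and the point ordering (space varying fastest) is an assumption chosen to match the `reshape(…, N, T_pr)` used for plotting later in this example.

```julia
using LinearAlgebra

# Hand-written stand-ins for the two kernels, with the input scaled by λ first.
se(x, y, λ) = exp(-(λ * (x - y))^2 / 2)
function matern52(t, s, λ)
    d = abs(λ * (t - s))
    return (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)
end

# Separable kernel: product of a spatial and a temporal kernel, scaled by a variance v.
k_sep(s1, t1, s2, t2; v = 0.6, λs = 2.5, λt = 2.5) =
    v * se(s1, s2, λs) * matern52(t1, t2, λt)

locs = [-1.0, 0.0, 1.5]          # small illustrative grid, not the one used here
times = [0.0, 0.1, 0.2, 0.3]

# Enumerate grid points with space varying fastest (an assumed ordering).
pts = [(s, t) for t in times for s in locs]

# Dense covariance over all grid points, built entry by entry.
K_full = [k_sep(p[1], p[2], q[1], q[2]) for p in pts, q in pts]

# The same matrix as a Kronecker product of the per-dimension covariances.
K_space = [se(a, b, 2.5) for a in locs, b in locs]
K_time = [matern52(a, b, 2.5) for a in times, b in times]
K_kron = 0.6 * kron(K_time, K_space)
```

Here `K_full ≈ K_kron` holds, which is the structure that makes a rectilinear grid special for a separable kernel.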

Construct a rectilinear grid of points in space and time.
Exact inference only works for such grids.
Times must be increasing; points in space can be anywhere.

In [5]:
N = 50;
T = 1_000;
points_in_space = collect(range(-3.0, 3.0; length=N));
points_in_time = RegularSpacing(0.0, 0.01, T);
x = RectilinearGrid(points_in_space, points_in_time);
y = rand(build_gp(params)(x, 1e-4));

Specify an objective function for Optim to minimise in terms of x and y.
We choose the usual negative log marginal likelihood (NLML).

In [6]:
function objective(params)
    f = build_gp(params)
    return -logpdf(f(x, params.var_noise), y)
end

objective (generic function with 1 method)
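
For reference, the quantity being minimised can be written out for a small dense problem. The sketch below computes the NLML of a zero-mean Gaussian with covariance `K + noise_var * I` using a Cholesky factorisation. The function `nlml_dense` and the toy data are hypothetical and only show what is being computed; TemporalGPs works with the state-space form of the GP rather than forming this matrix.

```julia
using LinearAlgebra

# Negative log marginal likelihood of y ~ N(0, K + noise_var * I),
# computed via a Cholesky factorisation (dense illustration only).
function nlml_dense(K::AbstractMatrix, noise_var::Real, y::AbstractVector)
    n = length(y)
    C = cholesky(Symmetric(K + noise_var * I))
    α = C \ y                       # solves (K + noise_var * I) α = y
    return (dot(y, α) + logdet(C) + n * log(2π)) / 2
end

# Toy inputs (hypothetical): four time points and a squared-exponential kernel.
ts = [0.0, 0.5, 1.0, 1.5]
K_toy = [exp(-(a - b)^2 / 2) for a in ts, b in ts]
y_toy = [0.3, -0.1, 0.4, 0.2]

nlml_value = nlml_dense(K_toy, 0.1, y_toy)
```

The three terms are the data fit, the log-determinant complexity penalty, and a constant that depends only on the number of observations.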

Optimise using Optim. This takes a little while to compile because of Zygote.

In [7]:
training_results = Optim.optimize(
    objective ∘ unpack,
    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
    flat_initial_params + randn(4), # Add some noise to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
    ),
    Optim.Options(show_trace = true);
    inplace=false,
);

Iter     Function value   Gradient norm 
     0    -1.269103e+04     1.958044e+04
 * time: 9.393692016601562e-5
     1    -3.660335e+04     2.372930e+04
 * time: 6.185869932174683
     2    -6.040538e+04     2.378800e+04
 * time: 11.515004873275757
     3    -6.110041e+04     2.379605e+04
 * time: 18.171354055404663
     4    -6.308566e+04     2.377875e+04
 * time: 23.646811962127686
     5    -8.594268e+04     2.327451e+04
 * time: 28.970414876937866
     6    -1.075644e+05     2.209347e+04
 * time: 35.12078785896301
     7    -1.276988e+05     1.930850e+04
 * time: 40.23667001724243
     8    -1.435205e+05     1.252901e+04
 * time: 45.67296504974365
     9    -1.486239e+05     5.320639e+03
 * time: 51.729368925094604
    10    -1.490844e+05     4.607494e+03
 * time: 57.3311550617218
    11    -1.495813e+05     3.362683e+03
 * time: 62.34265089035034
    12    -1.498416e+05     5.268581e+03
 * time: 68.16419291496277
    13    -1.499946e+05     7.867433e+02
 * time: 73.46008205413818
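
The optimiser above is configured with a backtracking line search. As a generic sketch of that idea (not Optim's implementation, which is more elaborate), the code below shrinks a trial step until the Armijo sufficient-decrease condition holds, and uses it inside plain gradient descent on a toy quadratic. All names here (`backtrack`, `f_toy`, `∇f_toy`, `descend`) and the constants are assumptions made for illustration.

```julia
using LinearAlgebra

# Armijo backtracking: shrink the step α until the objective decreases enough
# along the direction d, given the gradient g at x.
function backtrack(f, x, g, d; α = 1.0, ρ = 0.5, c = 1e-4, max_halvings = 50)
    fx = f(x)
    slope = dot(g, d)               # directional derivative; negative for descent
    for _ in 1:max_halvings
        f(x + α * d) <= fx + c * α * slope && return α
        α *= ρ
    end
    return α
end

# Toy problem (hypothetical): a quadratic with its analytic gradient.
f_toy(x) = 0.5 * dot(x, [4.0 0.0; 0.0 1.0] * x)
∇f_toy(x) = [4.0 0.0; 0.0 1.0] * x

# Gradient descent where every step length is chosen by backtracking.
function descend(x0; iters = 100)
    x = x0
    for _ in 1:iters
        g = ∇f_toy(x)
        α = backtrack(f_toy, x, g, -g)
        x = x - α * g
    end
    return x
end

x_min = descend([1.0, -2.0])
```

BFGS replaces the steepest-descent direction `-g` with a quasi-Newton direction, but the step-length logic follows the same pattern.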


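Extract the final values of the parameters.
They should be close to the values used to generate the data. Note that `y` was sampled
with a noise variance of 1e-4, so that value, rather than the initial 0.1, is the target
for `var_noise`.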

In [8]:
final_params = unpack(training_results.minimizer)

(var_kernel = 0.5269485845960719, λ_space = 2.5201829315334705, λ_time = 2.5822932473777636, var_noise = 0.00010072355943937465)

Construct the posterior as per usual.

In [9]:
f_post = posterior(build_gp(final_params)(x, final_params.var_noise), y);

Specify some locations at which to make predictions. The prediction grid reuses the
spatial points and extends 200 time steps beyond the training data.

In [10]:
T_pr = 1200;
points_in_time_pr = RegularSpacing(0.0, 0.01, T_pr);
x_pr = RectilinearGrid(points_in_space, points_in_time_pr);

Compute the exact posterior marginals at `x_pr`.
This isn't optimised at present, so it might take a little while.

In [11]:
f_post_marginals = marginals(f_post(x_pr));
m_post_marginals = mean.(f_post_marginals);
σ_post_marginals = std.(f_post_marginals);
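
To make "posterior marginals" concrete, the sketch below carries out standard Gaussian conditioning for a tiny dense GP and extracts the per-point posterior mean and standard deviation. The kernel `k_toy`, the data, and the noise level are hypothetical; this is the textbook dense computation, not the algorithm TemporalGPs runs on the grid.

```julia
using LinearAlgebra

k_toy(a, b) = exp(-(a - b)^2 / 2)       # hypothetical kernel for illustration

x_obs = [0.0, 0.5, 1.0]
y_obs = [0.2, 0.4, 0.1]
x_new = [0.25, 0.75, 2.0]
noise_var = 1e-2

# Covariance blocks: observed-observed (with noise), new-observed, and prior variances.
K_oo = [k_toy(a, b) for a in x_obs, b in x_obs] + noise_var * I
K_no = [k_toy(a, b) for a in x_new, b in x_obs]
K_nn_diag = [k_toy(a, a) for a in x_new]

C = cholesky(Symmetric(K_oo))

# Posterior mean: K_no * K_oo⁻¹ * y.
post_mean = K_no * (C \ y_obs)

# Marginal variances only need the diagonal of K_nn - K_no * K_oo⁻¹ * K_no'.
A = C.L \ K_no'
post_var = K_nn_diag .- vec(sum(abs2, A; dims = 1))
post_std = sqrt.(post_var)
```

The last prediction point lies outside the observed range, so its standard deviation is larger than that of the points between observations.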

Visualise the posterior marginals. We don't do this in CI because it causes
problems.

In [12]:
if get(ENV, "TESTING", "FALSE") == "FALSE"
    using Plots
    savefig(
        plot(
            heatmap(reshape(m_post_marginals, N, T_pr)),
            heatmap(reshape(σ_post_marginals, N, T_pr));
            layout=(1, 2),
        ),
        "posterior.png",
    );
end

"/home/runner/work/TemporalGPs.jl/TemporalGPs.jl/docs/src/examples/exact_space_time_learning/posterior.png"

<hr />
<h6>Package and system information</h6>
<details>
<summary>Package information (click to expand)</summary>
<pre>
Status &#96;~/work/TemporalGPs.jl/TemporalGPs.jl/examples/exact_space_time_learning/Project.toml&#96;
  &#91;99985d1d&#93; AbstractGPs v0.5.16
  &#91;ec8451be&#93; KernelFunctions v0.10.55
  &#91;98b081ad&#93; Literate v2.14.0
  &#91;429524aa&#93; Optim v1.7.5
  &#91;2412ca09&#93; ParameterHandling v0.4.6
  &#91;91a5bcdd&#93; Plots v1.38.10
  &#91;e155a3c4&#93; TemporalGPs v0.6.3 &#96;/home/runner/work/TemporalGPs.jl/TemporalGPs.jl#43bd0d9&#96;
  &#91;e88e6eb3&#93; Zygote v0.6.60
</pre>
To reproduce this notebook's package environment, you can
<a href="./Manifest.toml">
download the full Manifest.toml</a>.
</details>
<details>
<summary>System information (click to expand)</summary>
<pre>
Julia Version 1.8.5
Commit 17cfb8e65ea &#40;2023-01-08 06:45 UTC&#41;
Platform Info:
  OS: Linux &#40;x86_64-linux-gnu&#41;
  CPU: 2 × Intel&#40;R&#41; Xeon&#40;R&#41; Platinum 8370C CPU @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 &#40;ORCJIT, icelake-server&#41;
  Threads: 1 on 2 virtual cores
Environment:
  JULIA_DEBUG &#61; Documenter
  JULIA_LOAD_PATH &#61; :/home/runner/.julia/packages/JuliaGPsDocs/e8FS0/src
</pre>
</details>

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*