This is an extended version of approx_space_time_inference.jl. It combines that example
with Optim + ParameterHandling + Zygote to learn the kernel parameters.
If you understand how to use Optim + ParameterHandling + Zygote with an AbstractGP
(e.g. as shown in the README for this package) and how approx_space_time_inference.jl
works, then you should be able to follow this file.

In [1]:
using AbstractGPs
using KernelFunctions
using TemporalGPs

using TemporalGPs: Separable, approx_posterior_marginals, RegularInTime

Load standard packages from the Julia ecosystem.

In [2]:
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Zygote # Algorithmic Differentiation

using ParameterHandling: flatten

Declare model parameters using `ParameterHandling.jl` types.

In [3]:
flat_initial_params, unflatten = flatten((
    var_kernel = positive(0.6),
    λ_space = positive(0.5),
    λ_time = positive(0.1),
    var_noise = positive(0.1),
));
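
As a minimal sketch of what `flatten` provides (the two-field parameter set below is a made-up stand-in, not the model's): it returns a plain vector of unconstrained reals together with a function that rebuilds the structured parameters, and `ParameterHandling.value` then strips the `positive` wrappers.

```julia
using ParameterHandling

# Hypothetical two-parameter set, for illustration only.
θ = (a = positive(0.6), b = positive(0.1))

# `flat` is a Vector{Float64} of unconstrained values; `unflat` rebuilds θ's structure.
flat, unflat = ParameterHandling.flatten(θ)

# Round trip: unflatten, then extract the constrained values.
vals = ParameterHandling.value(unflat(flat))

# Any real vector maps back to strictly positive values, so an optimiser
# can move freely in the unconstrained space.
shifted = ParameterHandling.value(unflat(flat .- 5.0))
```

This is why the optimiser later in this example can work on an ordinary vector without any explicit positivity constraints.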

Construct a function to unpack flattened parameters and pull out the raw values.

In [4]:
unpack = ParameterHandling.value ∘ unflatten;
params = unpack(flat_initial_params);

function build_gp(params)
    k_space = SEKernel() ∘ ScaleTransform(params.λ_space)
    k_time = Matern52Kernel() ∘ ScaleTransform(params.λ_time)
    k = params.var_kernel * Separable(k_space, k_time)
    return to_sde(GP(k), ArrayStorage(Float64))
end

build_gp (generic function with 1 method)
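
For intuition, the kernel assembled in `build_gp` can be written out by hand. The sketch below is plain Julia rather than the package's implementation: the helper names are hypothetical, and it assumes that `Separable` multiplies a kernel over the spatial coordinate by a kernel over time, with inputs given as `(space, time)` tuples. Note that `ScaleTransform(λ)` multiplies its input by `λ`, so `λ_space` and `λ_time` act as inverse length scales.

```julia
# Squared-exponential kernel with input scaling λ: exp(-(λ (r1 - r2))² / 2).
se(r1, r2, λ) = exp(-abs2(λ * (r1 - r2)) / 2)

# Matérn-5/2 kernel with input scaling λ.
function matern52(t1, t2, λ)
    d = abs(λ * (t1 - t2))
    return (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)
end

# Hand-written separable space-time kernel (illustrative only).
function separable_kernel(a, b; var_kernel, λ_space, λ_time)
    (r1, t1), (r2, t2) = a, b
    return var_kernel * se(r1, r2, λ_space) * matern52(t1, t2, λ_time)
end

p = (var_kernel = 0.6, λ_space = 0.5, λ_time = 0.1)
k_same = separable_kernel((0.3, 1.0), (0.3, 1.0); p...)  # identical inputs
k_far = separable_kernel((0.3, 1.0), (2.3, 4.0); p...)   # inputs far apart
```

At identical inputs both factors equal one, so `k_same` is just `var_kernel`; the covariance decays as the inputs separate in either space or time.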

Construct inputs. Spatial locations change at each point in time;
a RectilinearGrid of inputs also works.
Times must be increasing, but points in space can be anywhere.

In [5]:
N = 50;
T = 1000;
points_in_space = [randn(N) for _ in 1:T];
points_in_time = RegularSpacing(0.0, 0.1, T);
x = RegularInTime(points_in_time, points_in_space);
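
To make the layout of these inputs concrete, here is a plain-Julia sketch with tiny made-up sizes and hypothetical variable names. It builds the kind of flattened collection that `collect(x)` is expected to produce: one `(space, time)` tuple per observation, grouped by time step. The tuple order and grouping are assumptions about `RegularInTime`, not something this example states explicitly.

```julia
N_demo, T_demo = 3, 4                        # tiny sizes, for illustration only
times = range(0.0; step=0.1, length=T_demo)  # regularly spaced, increasing
space = [randn(N_demo) for _ in 1:T_demo]    # fresh locations at every time step

# One (space, time) tuple per observation, all locations for the first time first.
inputs = [(r, t) for (rs, t) in zip(space, times) for r in rs]
```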

Since it's not straightforward to generate samples from this GP at `x`, use a known
function corrupted by a little iid Gaussian noise.

In [6]:
xs = collect(x);
y = sin.(first.(xs)) .+ cos.(last.(xs)) + sqrt.(params.var_noise) .* randn(length(xs));

Spatial pseudo-point inputs.

In [7]:
z_r = collect(range(-3.0, 3.0; length=5));

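Specify an objective function for Optim to minimise in terms of x and y.
Because inference here uses pseudo-points, we minimise the negative of the ELBO (a lower
bound on the log marginal likelihood) in place of the exact negative log marginal
likelihood (NLML).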

In [8]:
function objective(params)
    f = build_gp(params)
    return -elbo(f(x, params.var_noise), y, z_r)
end

objective (generic function with 1 method)
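
The relationship between this objective and the exact marginal likelihood can be stated in one line. This is the standard bound for pseudo-point approximations, written with $\theta$ for the kernel and noise parameters and $z$ for the pseudo-point locations (notation introduced for this note only):

```math
-\operatorname{ELBO}(\theta, z) \;\ge\; -\log p(y \mid \theta)
```

Minimising the objective therefore minimises an upper bound on the NLML, and the gap typically shrinks as more pseudo-points are used.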

Optimise using Optim. This takes a little while to compile because of Zygote.

In [9]:
training_results = Optim.optimize(
    objective ∘ unpack,
    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
    flat_initial_params,
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
    ),
    Optim.Options(show_trace = true);
    inplace=false,
);

Iter     Function value   Gradient norm 
     0     3.136132e+04     1.850971e+05
 * time: 9.202957153320312e-5
     1     2.592475e+04     7.247775e+05
 * time: 74.0780439376831
     2     2.160486e+04     7.244293e+05
 * time: 161.24048805236816
     3     2.160262e+04     7.231122e+05
 * time: 259.7037661075592
     4     2.160262e+04     7.231122e+05
 * time: 371.77883291244507
     5     2.160262e+04     7.231122e+05
 * time: 460.41095900535583
     6     2.160262e+04     7.231122e+05
 * time: 462.186017036438
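
The same calling pattern can be tried on a toy problem that runs in a fraction of the time. The sketch below is not part of the original example: `toy_objective` and `target` are made up, and `exp` stands in for the positivity transform. It also shows how to query convergence afterwards, which is worth doing for the real model since the trace above stops changing while the gradient norm is still large.

```julia
using Optim
using Zygote

# Made-up objective: recover the positive values in `target` by optimising their logs.
target = [2.0, 3.0]
toy_objective(θ) = sum(abs2, exp.(θ) .- target)

toy_results = Optim.optimize(
    toy_objective,
    θ -> only(Zygote.gradient(toy_objective, θ)),
    zeros(2),
    BFGS(),
    Optim.Options(show_trace = false);
    inplace=false,
)

recovered = exp.(Optim.minimizer(toy_results))  # should be close to `target`
converged = Optim.converged(toy_results)        # true if a convergence criterion was met
```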


Extract the final values of the parameters.
These should be close to the truth.

In [10]:
final_params = unpack(training_results.minimizer);

Locations in space at which to make predictions. Assumed to be the same at each point in
time, but this assumption could easily be relaxed.

In [11]:
N_pr = 150;
x_r_pr = range(-5.0, 5.0; length=N_pr);

Compute the approximate posterior marginals.

In [12]:
fx_final = build_gp(final_params)(x, final_params.var_noise)
f_post_marginals = approx_posterior_marginals(dtc, fx_final, y, z_r, x_r_pr);
m_post_marginals = mean.(f_post_marginals);
σ_post_marginals = std.(f_post_marginals);
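
A sketch of how these vectors might be post-processed. The arrays below are random stand-ins with tiny sizes, not output from the model; the point is only the layout. It assumes, as the plotting code in this example does, that the marginals are ordered with the spatial index varying fastest, so that reshaping to (space, time) puts one time step in each column.

```julia
N_s, N_t = 4, 3                      # stand-in sizes (space, time)
m = randn(N_s * N_t)                 # stand-in posterior means
σ = abs.(randn(N_s * N_t)) .+ 0.1    # stand-in posterior standard deviations

# Approximate 95% credible band for the latent function at each input.
lower = m .- 1.96 .* σ
upper = m .+ 1.96 .* σ

# Column t holds every spatial location at time step t.
M = reshape(m, N_s, N_t)
```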

Visualise the posterior marginals. We don't do this in CI because it causes
problems.

In [13]:
if get(ENV, "TESTING", "FALSE") == "FALSE"
    using Plots
    savefig(
        plot(
            heatmap(reshape(m_post_marginals, N_pr, T)),
            heatmap(reshape(σ_post_marginals, N_pr, T));
            layout=(1, 2),
        ),
        "posterior.png",
    );
end

"/home/runner/work/TemporalGPs.jl/TemporalGPs.jl/docs/src/examples/approx_space_time_learning/posterior.png"

<hr />
<h6>Package and system information</h6>
<details>
<summary>Package information (click to expand)</summary>
<pre>
Status &#96;~/work/TemporalGPs.jl/TemporalGPs.jl/examples/approx_space_time_learning/Project.toml&#96;
  &#91;99985d1d&#93; AbstractGPs v0.5.16
  &#91;ec8451be&#93; KernelFunctions v0.10.55
  &#91;98b081ad&#93; Literate v2.14.0
  &#91;429524aa&#93; Optim v1.7.5
  &#91;2412ca09&#93; ParameterHandling v0.4.6
  &#91;91a5bcdd&#93; Plots v1.38.10
  &#91;e155a3c4&#93; TemporalGPs v0.6.3 &#96;/home/runner/work/TemporalGPs.jl/TemporalGPs.jl#43bd0d9&#96;
  &#91;e88e6eb3&#93; Zygote v0.6.60
</pre>
To reproduce this notebook's package environment, you can
<a href="./Manifest.toml">
download the full Manifest.toml</a>.
</details>
<details>
<summary>System information (click to expand)</summary>
<pre>
Julia Version 1.8.5
Commit 17cfb8e65ea &#40;2023-01-08 06:45 UTC&#41;
Platform Info:
  OS: Linux &#40;x86_64-linux-gnu&#41;
  CPU: 2 × Intel&#40;R&#41; Xeon&#40;R&#41; Platinum 8370C CPU @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 &#40;ORCJIT, icelake-server&#41;
  Threads: 1 on 2 virtual cores
Environment:
  JULIA_DEBUG &#61; Documenter
  JULIA_LOAD_PATH &#61; :/home/runner/.julia/packages/JuliaGPsDocs/e8FS0/src
</pre>
</details>

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*