Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP, not to merge] Migrate to Enzyme #127

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk> and contributors"]
version = "0.6.8"
version = "0.7.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.5.17"
Bessels = "0.2.8"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7, 1"
KernelFunctions = "0.9, 0.10.1"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
Zygote = "0.6.65"
julia = "1.6"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
30 changes: 24 additions & 6 deletions examples/exact_time_learning.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This is an extended version of exact_time_inference.jl. It combines it with
# Optim + ParameterHandling + Zygote to learn the kernel parameters.
# Optim + ParameterHandling + Enzyme to learn the kernel parameters.
# Each of these other packages know nothing about TemporalGPs, they're just general-purpose
# packages which play nicely with TemporalGPs (and AbstractGPs).

Expand All @@ -12,7 +12,7 @@ using TemporalGPs: RegularSpacing
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Zygote # Algorithmic Differentiation
using Enzyme # Algorithmic Differentiation

# Declare model parameters using `ParameterHandling.jl` types.
# var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the
Expand Down Expand Up @@ -42,15 +42,33 @@ y = rand(f(x, params.var_noise));

# Specify an objective function for Optim to minimise in terms of x and y.
# We choose the usual negative log marginal likelihood (NLML).
function objective(params)
function objective(x, y, params)
f = build_gp(params)
return -logpdf(f(x, params.var_noise), y)
end

# Optimise using Optim. Zygote takes a little while to compile.
# In order to compute the gradient with Enzyme, we define the following function:
function enzyme_gradient(x, y, θ, unpack)
# Define shadows
# It is unclear why the x, y, shadows are needed here
# Making these variables `Const` leads to an error
dθ = make_zero(θ)
dx = make_zero(x)
dy = make_zero(y)
autodiff(
Reverse,
(x, y, par, unpack) -> objective(x, y, unpack(par)),
Duplicated(x, dx), Duplicated(y, dy),
Duplicated(θ, dθ),
Const(unpack)
)
return dθ
end

# Optimise using Optim.
training_results = Optim.optimize(
objective unpack,
θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
θ -> objective(x, y, unpack(θ)),
θ -> enzyme_gradient(x, y, θ, unpack),
flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial
BFGS(
alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
Expand Down
6 changes: 0 additions & 6 deletions src/TemporalGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ module TemporalGPs
using AbstractGPs
using Bessels: besseli
using BlockDiagonals
using ChainRulesCore
import ChainRulesCore: rrule
using FillArrays
using LinearAlgebra
using KernelFunctions
using Random
using StaticArrays
using StructArrays
using Zygote

using FillArrays: AbstractFill

Expand All @@ -36,12 +33,9 @@ module TemporalGPs
ApproxPeriodicKernel

# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "harmonise.jl"))
include(joinpath("util", "linear_algebra.jl"))
include(joinpath("util", "scan.jl"))
include(joinpath("util", "zygote_friendly_map.jl"))

include(joinpath("util", "chainrules.jl"))
include(joinpath("util", "gaussian.jl"))
include(joinpath("util", "mul.jl"))
include(joinpath("util", "storage_types.jl"))
Expand Down
101 changes: 36 additions & 65 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,30 @@ end
function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector)
m = get_mean(f)
k = get_kernel(f)
s = Zygote.literal_getfield(f, Val(:storage))
s = f.storage
As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, s)
return LGSSM(
GaussMarkovModel(Forward(), As, as, Qs, x0), build_emissions(emission_proj, Σys),
)
end

function build_lgssm(ft::FiniteLTISDE)
f = Zygote.literal_getfield(ft, Val(:f))
x = Zygote.literal_getfield(ft, Val(:x))
Σys = noise_var_to_time_form(x, Zygote.literal_getfield(ft, Val(:Σy)))
f = ft.f
x = ft.x
Σys = noise_var_to_time_form(x, ft.Σy)
return build_lgssm(f, x, Σys)
end

get_mean(f::LTISDE) = get_mean(Zygote.literal_getfield(f, Val(:f)))
get_mean(f::GP) = Zygote.literal_getfield(f, Val(:mean))
get_mean(f::LTISDE) = get_mean(f.f)
get_mean(f::GP) = f.mean

get_kernel(f::LTISDE) = get_kernel(Zygote.literal_getfield(f, Val(:f)))
get_kernel(f::GP) = Zygote.literal_getfield(f, Val(:kernel))
get_kernel(f::LTISDE) = get_kernel(f.f)
get_kernel(f::GP) = f.kernel

function build_emissions(
(Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector,
)
Hst = _map(adjoint, Hs)
Hst = map(adjoint, Hs)
return StructArray{get_type(Hst, hs, Σs)}((Hst, hs, Σs))
end

Expand All @@ -114,10 +114,6 @@ function get_type(Hs_prime, hs::AbstractVector{<:AbstractVector}, Σs)
return T
end

@inline function Zygote.wrap_chainrules_output(x::NamedTuple)
return map(Zygote.wrap_chainrules_output, x)
end

# Constructor for combining kernel and mean functions
function lgssm_components(
::ZeroMean, k::Kernel, t::AbstractVector, storage_type::StorageType
Expand All @@ -128,7 +124,7 @@ end
function lgssm_components(
m::AbstractGPs.MeanFunction, k::Kernel, t::AbstractVector, storage_type::StorageType
)
m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays.
m = mean_vector(m, t)
As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type)
hs = add_proj_mean(hs, m)

Expand All @@ -146,9 +142,9 @@ end
function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}) where {T}
P = Symmetric(x0.P)
t = vcat([first(t) - 1], t)
As = _map(Δt -> time_exp(F, T(Δt)), diff(t))
As = map(Δt -> time_exp(F, T(Δt)), diff(t))
as = Fill(Zeros{T}(size(first(As), 1)), length(As))
Qs = _map(A -> P - A * P * A', As)
Qs = map(A -> P - A * P * A', As)
Hs = Fill(H, length(As))
hs = Fill(zero(T), length(As))
As, as, Qs, Hs, hs
Expand All @@ -158,14 +154,16 @@ function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRange
P = Symmetric(x0.P)
A = time_exp(F, T(step(t)))
As = Fill(A, length(t))
as = @ignore_derivatives(Fill(Zeros{T}(size(F, 1)), length(t)))
as = Fill(Zeros{T}(size(F, 1)), length(t))
Q = Symmetric(P) - A * Symmetric(P) * A'
Qs = Fill(Q, length(t))
Hs = Fill(H, length(t))
hs = Fill(zero(T), length(As))
As, as, Qs, Hs, hs
end

time_exp(A, t) = exp(A * t)

function lgssm_components(
k::SimpleKernel, t::AbstractVector{<:Real}, storage::StorageType{T},
) where {T<:Real}
Expand Down Expand Up @@ -332,49 +330,49 @@ end
# Scaled

function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real}
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
_k = k.kernel
σ² = k.σ²
F, q, H = to_sde(_k, storage)
σ = sqrt(convert(eltype(storage), only(σ²)))
return F, σ^2 * q, σ * H
end

stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)
stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(k.kernel, storage)

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
_k = k.kernel
σ² = k.σ²
As, as, Qs, emission_proj, x0 = lgssm_components(_k, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(σ²)))
return As, as, Qs, _scale_emission_projections(emission_proj, σ), x0
end

function _scale_emission_projections((Hs, hs)::Tuple{AbstractVector, AbstractVector}, σ::Real)
return _map(H->σ * H, Hs), _map(h->σ * h, hs)
return map(H->σ * H, Hs), map(h->σ * h, hs)
end

function _scale_emission_projections((Cs, cs, Hs, hs), σ)
return (Cs, cs, _map(H -> σ * H, Hs), _map(h -> σ * h, hs))
return (Cs, cs, map(H -> σ * H, Hs), map(h -> σ * h, hs))
end

# Stretched

function to_sde(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
_k = k.kernel
s = k.transform.s
F, q, H = to_sde(_k, storage)
return F * only(s), q, H
end

stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)
stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(k.kernel, storage)

function lgssm_components(
k::TransformedKernel{<:Kernel, <:ScaleTransform},
ts::AbstractVector,
storage_type::StorageType,
)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
_k = k.kernel
s = k.transform.s
return lgssm_components(_k, apply_stretch(s[1], ts), storage_type)
end

Expand All @@ -383,9 +381,9 @@ apply_stretch(a, ts::AbstractVector{<:Real}) = a * ts
apply_stretch(a, ts::StepRangeLen) = a * ts

function apply_stretch(a, ts::RegularSpacing)
t0 = Zygote.literal_getfield(ts, Val(:t0))
Δt = Zygote.literal_getfield(ts, Val(:Δt))
N = Zygote.literal_getfield(ts, Val(:N))
t0 = ts.t0
Δt = ts.Δt
N = ts.N
return RegularSpacing(a * t0, a * Δt, N)
end

Expand Down Expand Up @@ -425,9 +423,9 @@ function lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::Storag
emission_proj_kernels = getindex.(lgssms, 4)
x0_kernels = getindex.(lgssms, 5)

As = _map(block_diagonal, As_kernels...)
as = _map(vcat, as_kernels...)
Qs = _map(block_diagonal, Qs_kernels...)
As = map(block_diagonal, As_kernels...)
as = map(vcat, as_kernels...)
Qs = map(block_diagonal, Qs_kernels...)
emission_projections = _sum_emission_projections(emission_proj_kernels...)
x0 = Gaussian(mapreduce(x -> getproperty(x, :m), vcat, x0_kernels), block_diagonal(getproperty.(x0_kernels, :P)...))
return As, as, Qs, emission_projections, x0
Expand All @@ -444,10 +442,10 @@ function _sum_emission_projections(
cs = getindex.(Cs_cs_Hs_hs, 2)
Hs = getindex.(Cs_cs_Hs_hs, 3)
hs = getindex.(Cs_cs_Hs_hs, 4)
C = _map(vcat, Cs...)
C = map(vcat, Cs...)
c = sum(cs)
H = _map(block_diagonal, Hs...)
h = _map(vcat, hs...)
H = map(block_diagonal, Hs...)
h = map(vcat, hs...)
return C, c, H, h
end

Expand All @@ -460,36 +458,9 @@ function block_diagonal(As::AbstractMatrix{T}...) where {T}
return hvcat(ntuple(_ -> nblocks, nblocks), Xs...)
end

function ChainRulesCore.rrule(::typeof(block_diagonal), As::AbstractMatrix...)
szs = size.(As)
row_szs = (0, cumsum(first.(szs))...)
col_szs = (0, cumsum(last.(szs))...)
block_diagonal_rrule(Δ::AbstractThunk) = block_diagonal_rrule(unthunk(Δ))
function block_diagonal_rrule(Δ)
ΔAs = ntuple(length(As)) do i
Δ[(row_szs[i]+1):row_szs[i+1], (col_szs[i]+1):col_szs[i+1]]
end
return NoTangent(), ΔAs...
end
return block_diagonal(As...), block_diagonal_rrule
end

function block_diagonal(As::SMatrix...)
nblocks = length(As)
sizes = size.(As)
Xs = [i == j ? As[i] : zeros(SMatrix{sizes[j][1], sizes[i][2]}) for i in 1:nblocks, j in 1:nblocks]
return hcat(Base.splat(vcat).(eachrow(Xs))...)
end

function ChainRulesCore.rrule(::typeof(block_diagonal), As::SMatrix...)
szs = size.(As)
row_szs = (0, cumsum(first.(szs))...)
col_szs = (0, cumsum(last.(szs))...)
function block_diagonal_rrule(Δ)
ΔAs = ntuple(length(As)) do i
Δ[SVector{szs[i][1]}((row_szs[i]+1):row_szs[i+1]), SVector{szs[i][2]}((col_szs[i]+1):col_szs[i+1])]
end
return NoTangent(), ΔAs...
end
return block_diagonal(As...), block_diagonal_rrule
end
18 changes: 9 additions & 9 deletions src/gp/posterior_lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ function AbstractGPs.marginals(fx::FinitePosteriorLTISDE)
model_post = replace_observation_noise_cov(posterior(model, ys), σ²s_pr_full)
return destructure(x, map(marginals, marginals(model_post))[pr_indices])
else
f = Zygote.literal_getfield(fx, Val(:f))
prior = Zygote.literal_getfield(f, Val(:prior))
x = Zygote.literal_getfield(fx, Val(:x))
data = Zygote.literal_getfield(f, Val(:data))
Σy = Zygote.literal_getfield(data, Val(:Σy))
Σy_diag = Zygote.literal_getfield(Σy, Val(:diag))
y = Zygote.literal_getfield(data, Val(:y))

Σy_new = Zygote.literal_getfield(fx, Val(:Σy))
f = fx.f
prior = f.prior
x = fx.x
data = f.data
Σy = data.Σy
Σy_diag = Σy.diag
y = data.y

Σy_new = fx.Σy

model = build_lgssm(AbstractGPs.FiniteGP(prior, x, Σy))
Σys_new = noise_var_to_time_form(x, Σy_new)
Expand Down