Skip to content

Commit

Permalink
Merge pull request #105 from JuliaGaussianProcesses/tgf/kernel_product
Browse files Browse the repository at this point in the history
Implement `KernelProduct` SDE representation.
  • Loading branch information
theogf committed Apr 18, 2023
2 parents 03c0961 + 937e9e1 commit fa8b4af
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk> and contributors"]
version = "0.6.2"
version = "0.6.3"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
87 changes: 63 additions & 24 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,44 +143,38 @@ end

# Generic constructors for base kernels.

function lgssm_components(
k::SimpleKernel, t::AbstractVector{<:Real}, storage::StorageType{T},
) where {T<:Real}

# Compute stationary distribution and sde.
x0 = stationary_distribution(k, storage)
P = x0.P
F, _, H = to_sde(k, storage)

# Use stationary distribution + sde to compute finite-dimensional Gauss-Markov model.
function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}) where {T}
P = Symmetric(x0.P)
t = vcat([first(t) - 1], t)
As = _map(Δt -> time_exp(F, T(Δt)), diff(t))
as = Fill(Zeros{T}(size(first(As), 1)), length(As))
Qs = _map(A -> Symmetric(P) - A * Symmetric(P) * A', As)
Qs = _map(A -> P - A * P * A', As)
Hs = Fill(H, length(As))
hs = Fill(zero(T), length(As))
emission_projections = (Hs, hs)

return As, as, Qs, emission_projections, x0
As, as, Qs, Hs, hs
end

function lgssm_components(
k::SimpleKernel, t::Union{StepRangeLen, RegularSpacing}, storage_type::StorageType{T},
) where {T<:Real}

# Compute stationary distribution and sde.
x0 = stationary_distribution(k, storage_type)
P = x0.P
F, _, H = to_sde(k, storage_type)

# Use stationary distribution + sde to compute finite-dimensional Gauss-Markov model.
function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRangeLen, RegularSpacing}, ::StorageType{T}) where {T}
P = Symmetric(x0.P)
A = time_exp(F, T(step(t)))
As = Fill(A, length(t))
as = @ignore_derivatives(Fill(Zeros{T}(size(F, 1)), length(t)))
Q = Symmetric(P) - A * Symmetric(P) * A'
Qs = Fill(Q, length(t))
Hs = Fill(H, length(t))
hs = Fill(zero(T), length(As))
As, as, Qs, Hs, hs
end

function lgssm_components(
k::SimpleKernel, t::AbstractVector{<:Real}, storage::StorageType{T},
) where {T<:Real}

# Compute stationary distribution and sde.
x0 = stationary_distribution(k, storage)
# Use stationary distribution + sde to compute finite-dimensional Gauss-Markov model.
As, as, Qs, Hs, hs = broadcast_components(to_sde(k, storage), x0, t, storage)

emission_projections = (Hs, hs)

return As, as, Qs, emission_projections, x0
Expand Down Expand Up @@ -265,6 +259,16 @@ end

# Scaled

function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real}
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
F, q, H = to_sde(_k, storage)
σ = sqrt(convert(eltype(storage), only(σ²)))
return F, σ^2 * q, σ * H
end

stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
Expand All @@ -283,6 +287,15 @@ end

# Stretched

function to_sde(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
F, q, H = to_sde(_k, storage)
return F * only(s), q, H
end

stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)

function lgssm_components(
k::TransformedKernel{<:Kernel, <:ScaleTransform},
ts::AbstractVector,
Expand All @@ -304,6 +317,32 @@ function apply_stretch(a, ts::RegularSpacing)
return RegularSpacing(a * t0, a * Δt, N)
end

# Product

function lgssm_components(k::KernelProduct, ts::AbstractVector, storage::StorageType)
sde_kernels = to_sde.(k.kernels, Ref(storage))
F_kernels = getindex.(sde_kernels, 1)
F = foldl(_kron_add, F_kernels)
q_kernels = getindex.(sde_kernels, 2)
q = kron(q_kernels...)
H_kernels = getindex.(sde_kernels, 3)
H = kron(H_kernels...)

x0_kernels = stationary_distribution.(k.kernels, Ref(storage))
m_kernels = getproperty.(x0_kernels, :m)
m = kron(m_kernels...)
P_kernels = getproperty.(x0_kernels, :P)
P = kron(P_kernels...)

x0 = Gaussian(m, P)
As, as, Qs, Hs, hs = broadcast_components((F, q, H), x0, ts, storage)
emission_projections = (Hs, hs)
return As, as, Qs, emission_projections, x0
end

_kron_add(A::AbstractMatrix, B::AbstractMatrix) = kron(A, I(size(B,1))) + kron(I(size(A,1)), B)
_kron_add(A::SMatrix{M,M}, B::SMatrix{N,N}) where {M, N} = kron(A, SMatrix{N, N}(I(N))) + kron(SMatrix{M,M}(I(M)), B)

# Sum

function lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::StorageType)
Expand Down
66 changes: 44 additions & 22 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,39 @@ println("lti_sde:")
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ))
end,

# TEST_TOFIX
# Gradients should be fixed on those composites.
# Error is mostly due do an incompatibility of Tangents
# between Zygote and FiniteDifferences.

# Product kernels
(
name="prod-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) *
Matern32Kernel() ScaleTransform(1.1),
skip_grad=true,
),
(
name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel",
val = 3.0 * Matern32Kernel() *
Matern52Kernel() *
ConstantKernel(),
skip_grad=true,
),

# Summed kernels.
# (
# name="sum-Matern12Kernel-Matern32Kernel",
# val=1.5 * Matern12Kernel() ∘ ScaleTransform(0.1) +
# 0.3 * Matern32Kernel() ∘ ScaleTransform(1.1),
# ),
# (
# name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel",
# val = 2.0 * Matern32Kernel() +
# 0.5 * Matern52Kernel() +
# 1.0 * ConstantKernel(),
# ),
(
name="sum-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) +
0.3 * Matern32Kernel() ScaleTransform(1.1),
skip_grad=true,
),
(
name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel",
val = 2.0 * Matern32Kernel() +
0.5 * Matern52Kernel() +
1.0 * ConstantKernel(),
skip_grad=true,
),
)

# Construct a Gauss-Markov model with either dense storage or static storage.
Expand Down Expand Up @@ -154,16 +174,18 @@ println("lti_sde:")
end

# Just need to ensure we can differentiate through construction properly.
test_zygote_grad(
_construction_tester,
f_naive,
storage.val,
σ².val,
t.val;
check_inferred=false,
rtol=1e-6,
atol=1e-6,
)
if !(hasfield(typeof(kernel), :skip_grad) && kernel.skip_grad)
test_zygote_grad(
_construction_tester,
f_naive,
storage.val,
σ².val,
t.val;
check_inferred=false,
rtol=1e-6,
atol=1e-6,
)
end
end
end
end
8 changes: 8 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ function to_vec(f::GP)
return gp_vec, GP_from_vec
end

function to_vec(k::ConstantKernel)
c, c_to_vec = to_vec(k.c)
function ConstantKernel_from_vec(c)
return ConstantKernel(c=first(c_to_vec(c)))
end
c, ConstantKernel_from_vec
end

Base.zero(x::AbstractGPs.ZeroMean) = x
Base.zero(x::Kernel) = x
Base.zero(x::TemporalGPs.LTISDE) = x
Expand Down

2 comments on commit fa8b4af

@theogf
Copy link
Member Author

@theogf theogf commented on fa8b4af Apr 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/81845

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.3 -m "<description of version>" fa8b4afa26a364dc1f557b4276e0ce3a2a7493a6
git push origin v0.6.3

Please sign in to comment.