Skip to content

Commit

Permalink
Merge 94f9823 into 784dbad
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace committed Jan 23, 2024
2 parents 784dbad + 94f9823 commit 63df7dc
Show file tree
Hide file tree
Showing 34 changed files with 109 additions and 2,037 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk> and contributors"]
version = "0.6.8"
version = "0.7.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.5.17"
Bessels = "0.2.8"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7, 1"
KernelFunctions = "0.9, 0.10.1"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
Zygote = "0.6.65"
julia = "1.6"
6 changes: 0 additions & 6 deletions src/TemporalGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ module TemporalGPs
using AbstractGPs
using Bessels: besseli
using BlockDiagonals
using ChainRulesCore
import ChainRulesCore: rrule
using FillArrays
using LinearAlgebra
using KernelFunctions
using Random
using StaticArrays
using StructArrays
using Zygote

using FillArrays: AbstractFill

Expand All @@ -36,12 +33,9 @@ module TemporalGPs
ApproxPeriodicKernel

# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "harmonise.jl"))
include(joinpath("util", "linear_algebra.jl"))
include(joinpath("util", "scan.jl"))
include(joinpath("util", "zygote_friendly_map.jl"))

include(joinpath("util", "chainrules.jl"))
include(joinpath("util", "gaussian.jl"))
include(joinpath("util", "mul.jl"))
include(joinpath("util", "storage_types.jl"))
Expand Down
101 changes: 36 additions & 65 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,30 @@ end
function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector)
m = get_mean(f)
k = get_kernel(f)
s = Zygote.literal_getfield(f, Val(:storage))
s = f.storage
As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, s)
return LGSSM(
GaussMarkovModel(Forward(), As, as, Qs, x0), build_emissions(emission_proj, Σys),
)
end

function build_lgssm(ft::FiniteLTISDE)
f = Zygote.literal_getfield(ft, Val(:f))
x = Zygote.literal_getfield(ft, Val(:x))
Σys = noise_var_to_time_form(x, Zygote.literal_getfield(ft, Val(:Σy)))
f = ft.f
x = ft.x
Σys = noise_var_to_time_form(x, ft.Σy)
return build_lgssm(f, x, Σys)
end

get_mean(f::LTISDE) = get_mean(Zygote.literal_getfield(f, Val(:f)))
get_mean(f::GP) = Zygote.literal_getfield(f, Val(:mean))
get_mean(f::LTISDE) = get_mean(f.f)
get_mean(f::GP) = f.mean

get_kernel(f::LTISDE) = get_kernel(Zygote.literal_getfield(f, Val(:f)))
get_kernel(f::GP) = Zygote.literal_getfield(f, Val(:kernel))
get_kernel(f::LTISDE) = get_kernel(f.f)
get_kernel(f::GP) = f.kernel

function build_emissions(
(Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector,
)
Hst = _map(adjoint, Hs)
Hst = map(adjoint, Hs)
return StructArray{get_type(Hst, hs, Σs)}((Hst, hs, Σs))
end

Expand All @@ -114,10 +114,6 @@ function get_type(Hs_prime, hs::AbstractVector{<:AbstractVector}, Σs)
return T
end

@inline function Zygote.wrap_chainrules_output(x::NamedTuple)
return map(Zygote.wrap_chainrules_output, x)
end

# Constructor for combining kernel and mean functions
function lgssm_components(
::ZeroMean, k::Kernel, t::AbstractVector, storage_type::StorageType
Expand All @@ -128,7 +124,7 @@ end
function lgssm_components(
m::AbstractGPs.MeanFunction, k::Kernel, t::AbstractVector, storage_type::StorageType
)
m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays.
m = mean_vector(m, t)
As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type)
hs = add_proj_mean(hs, m)

Expand All @@ -146,9 +142,9 @@ end
function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::AbstractVector{<:Real}, ::StorageType{T}) where {T}
P = Symmetric(x0.P)
t = vcat([first(t) - 1], t)
As = _map(Δt -> time_exp(F, T(Δt)), diff(t))
As = map(Δt -> time_exp(F, T(Δt)), diff(t))
as = Fill(Zeros{T}(size(first(As), 1)), length(As))
Qs = _map(A -> P - A * P * A', As)
Qs = map(A -> P - A * P * A', As)
Hs = Fill(H, length(As))
hs = Fill(zero(T), length(As))
As, as, Qs, Hs, hs
Expand All @@ -158,14 +154,16 @@ function broadcast_components((F, q, H)::Tuple, x0::Gaussian, t::Union{StepRange
P = Symmetric(x0.P)
A = time_exp(F, T(step(t)))
As = Fill(A, length(t))
as = @ignore_derivatives(Fill(Zeros{T}(size(F, 1)), length(t)))
as = Fill(Zeros{T}(size(F, 1)), length(t))
Q = Symmetric(P) - A * Symmetric(P) * A'
Qs = Fill(Q, length(t))
Hs = Fill(H, length(t))
hs = Fill(zero(T), length(As))
As, as, Qs, Hs, hs
end

time_exp(A, t) = exp(A * t)

function lgssm_components(
k::SimpleKernel, t::AbstractVector{<:Real}, storage::StorageType{T},
) where {T<:Real}
Expand Down Expand Up @@ -332,49 +330,49 @@ end
# Scaled

function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real}
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
_k = k.kernel
σ² = k.σ²
F, q, H = to_sde(_k, storage)
σ = sqrt(convert(eltype(storage), only(σ²)))
return F, σ^2 * q, σ * H
end

stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)
stationary_distribution(k::ScaledKernel, storage::StorageType) = stationary_distribution(k.kernel, storage)

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
_k = k.kernel
σ² = k.σ²
As, as, Qs, emission_proj, x0 = lgssm_components(_k, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(σ²)))
return As, as, Qs, _scale_emission_projections(emission_proj, σ), x0
end

function _scale_emission_projections((Hs, hs)::Tuple{AbstractVector, AbstractVector}, σ::Real)
return _map(H->σ * H, Hs), _map(h->σ * h, hs)
return map(H->σ * H, Hs), map(h->σ * h, hs)
end

function _scale_emission_projections((Cs, cs, Hs, hs), σ)
return (Cs, cs, _map(H -> σ * H, Hs), _map(h -> σ * h, hs))
return (Cs, cs, map(H -> σ * H, Hs), map(h -> σ * h, hs))
end

# Stretched

function to_sde(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
_k = k.kernel
s = k.transform.s
F, q, H = to_sde(_k, storage)
return F * only(s), q, H
end

stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(Zygote.literal_getfield(k, Val(:kernel)), storage)
stationary_distribution(k::TransformedKernel{<:Kernel, <:ScaleTransform}, storage::StorageType) = stationary_distribution(k.kernel, storage)

function lgssm_components(
k::TransformedKernel{<:Kernel, <:ScaleTransform},
ts::AbstractVector,
storage_type::StorageType,
)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
_k = k.kernel
s = k.transform.s
return lgssm_components(_k, apply_stretch(s[1], ts), storage_type)
end

Expand All @@ -383,9 +381,9 @@ apply_stretch(a, ts::AbstractVector{<:Real}) = a * ts
apply_stretch(a, ts::StepRangeLen) = a * ts

function apply_stretch(a, ts::RegularSpacing)
t0 = Zygote.literal_getfield(ts, Val(:t0))
Δt = Zygote.literal_getfield(ts, Val(:Δt))
N = Zygote.literal_getfield(ts, Val(:N))
t0 = ts.t0
Δt = ts.Δt
N = ts.N
return RegularSpacing(a * t0, a * Δt, N)
end

Expand Down Expand Up @@ -425,9 +423,9 @@ function lgssm_components(k::KernelSum, ts::AbstractVector, storage_type::Storag
emission_proj_kernels = getindex.(lgssms, 4)
x0_kernels = getindex.(lgssms, 5)

As = _map(block_diagonal, As_kernels...)
as = _map(vcat, as_kernels...)
Qs = _map(block_diagonal, Qs_kernels...)
As = map(block_diagonal, As_kernels...)
as = map(vcat, as_kernels...)
Qs = map(block_diagonal, Qs_kernels...)
emission_projections = _sum_emission_projections(emission_proj_kernels...)
x0 = Gaussian(mapreduce(x -> getproperty(x, :m), vcat, x0_kernels), block_diagonal(getproperty.(x0_kernels, :P)...))
return As, as, Qs, emission_projections, x0
Expand All @@ -444,10 +442,10 @@ function _sum_emission_projections(
cs = getindex.(Cs_cs_Hs_hs, 2)
Hs = getindex.(Cs_cs_Hs_hs, 3)
hs = getindex.(Cs_cs_Hs_hs, 4)
C = _map(vcat, Cs...)
C = map(vcat, Cs...)
c = sum(cs)
H = _map(block_diagonal, Hs...)
h = _map(vcat, hs...)
H = map(block_diagonal, Hs...)
h = map(vcat, hs...)
return C, c, H, h
end

Expand All @@ -460,36 +458,9 @@ function block_diagonal(As::AbstractMatrix{T}...) where {T}
return hvcat(ntuple(_ -> nblocks, nblocks), Xs...)
end

function ChainRulesCore.rrule(::typeof(block_diagonal), As::AbstractMatrix...)
szs = size.(As)
row_szs = (0, cumsum(first.(szs))...)
col_szs = (0, cumsum(last.(szs))...)
block_diagonal_rrule::AbstractThunk) = block_diagonal_rrule(unthunk(Δ))
function block_diagonal_rrule(Δ)
ΔAs = ntuple(length(As)) do i
Δ[(row_szs[i]+1):row_szs[i+1], (col_szs[i]+1):col_szs[i+1]]
end
return NoTangent(), ΔAs...
end
return block_diagonal(As...), block_diagonal_rrule
end

function block_diagonal(As::SMatrix...)
nblocks = length(As)
sizes = size.(As)
Xs = [i == j ? As[i] : zeros(SMatrix{sizes[j][1], sizes[i][2]}) for i in 1:nblocks, j in 1:nblocks]
return hcat(Base.splat(vcat).(eachrow(Xs))...)
end

function ChainRulesCore.rrule(::typeof(block_diagonal), As::SMatrix...)
szs = size.(As)
row_szs = (0, cumsum(first.(szs))...)
col_szs = (0, cumsum(last.(szs))...)
function block_diagonal_rrule(Δ)
ΔAs = ntuple(length(As)) do i
Δ[SVector{szs[i][1]}((row_szs[i]+1):row_szs[i+1]), SVector{szs[i][2]}((col_szs[i]+1):col_szs[i+1])]
end
return NoTangent(), ΔAs...
end
return block_diagonal(As...), block_diagonal_rrule
end
18 changes: 9 additions & 9 deletions src/gp/posterior_lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ function AbstractGPs.marginals(fx::FinitePosteriorLTISDE)
model_post = replace_observation_noise_cov(posterior(model, ys), σ²s_pr_full)
return destructure(x, map(marginals, marginals(model_post))[pr_indices])
else
f = Zygote.literal_getfield(fx, Val(:f))
prior = Zygote.literal_getfield(f, Val(:prior))
x = Zygote.literal_getfield(fx, Val(:x))
data = Zygote.literal_getfield(f, Val(:data))
Σy = Zygote.literal_getfield(data, Val(:Σy))
Σy_diag = Zygote.literal_getfield(Σy, Val(:diag))
y = Zygote.literal_getfield(data, Val(:y))

Σy_new = Zygote.literal_getfield(fx, Val(:Σy))
f = fx.f
prior = f.prior
x = fx.x
data = f.data
Σy = data.Σy
Σy_diag = Σy.diag
y = data.y

Σy_new = fx.Σy

model = build_lgssm(AbstractGPs.FiniteGP(prior, x, Σy))
Σys_new = noise_var_to_time_form(x, Σy_new)
Expand Down
34 changes: 1 addition & 33 deletions src/models/gauss_markov_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,6 @@ struct GaussMarkovModel{
x0::Tx0
end

# Helps Zygote out with some type-stability issues. Why this helps is unclear.
function ChainRulesCore.rrule(::Type{<:GaussMarkovModel}, ordering, As, as, Qs, x0)
function GaussMarkovModel_pullback(Δ)
return NoTangent(), NoTangent(), Δ.As, Δ.as, Δ.Qs, Δ.x0
end
return GaussMarkovModel(ordering, As, as, Qs, x0), GaussMarkovModel_pullback
end

ordering(model::GaussMarkovModel) = model.ordering

Base.eltype(model::GaussMarkovModel) = eltype(first(model.As))
Expand All @@ -65,28 +57,4 @@ function is_of_storage_type(model::GaussMarkovModel, s::StorageType)
return is_of_storage_type((model.As, model.as, model.Qs, model.x0), s)
end

x0(model::GaussMarkovModel) = Zygote.literal_getfield(model, Val(:x0))

function get_adjoint_storage(x::GaussMarkovModel, n::Int, Δx::Tangent{T,<:NamedTuple{(:A, :a, :Q)}}) where {T}
return (
ordering = NoTangent(),
As = get_adjoint_storage(x.As, n, Δx.A),
as = get_adjoint_storage(x.as, n, Δx.a),
Qs = get_adjoint_storage(x.Qs, n, Δx.Q),
x0 = NoTangent(),
)
end

function _accum_at(
Δxs::NamedTuple{(:ordering, :As, :as, :Qs, :x0)},
n::Int,
Δx::Tangent{T, <:NamedTuple{(:A, :a, :Q)}},
) where {T}
return (
ordering = NoTangent(),
As = _accum_at(Δxs.As, n, Δx.A),
as = _accum_at(Δxs.as, n, Δx.a),
Qs = _accum_at(Δxs.Qs, n, Δx.Q),
x0 = NoTangent(),
)
end
x0(model::GaussMarkovModel) = model.x0

0 comments on commit 63df7dc

Please sign in to comment.