Skip to content

Commit

Permalink
Optimise posteiror marginals and Zygote (#67)
Browse files Browse the repository at this point in the history
* Optimise marginals at same inputs

* Bump patch version
  • Loading branch information
willtebbutt committed Apr 23, 2021
1 parent 4099efc commit b92a7af
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk>"]
version = "0.5.6"
version = "0.5.7"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
43 changes: 35 additions & 8 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end

function AbstractGPs.mean_and_var(ft::FiniteLTISDE)
ms = marginals(ft)
return mean.(ms), var.(ms)
return map(mean, ms), map(var, ms)
end

AbstractGPs.mean(ft::FiniteLTISDE) = mean_and_var(ft)[1]
Expand Down Expand Up @@ -77,16 +77,34 @@ destructure(y::AbstractVector{<:Real}) = y
# Converting GPs into LGSSMs.

function build_lgssm(ft::FiniteLTISDE)
As, as, Qs, emission_proj, x0 = lgssm_components(ft.f.f.kernel, ft.x, ft.f.storage)
k = get_kernel(ft)
x = Zygote.literal_getfield(ft, Val(:x))
s = Zygote.literal_getfield(Zygote.literal_getfield(ft, Val(:f)), Val(:storage))
As, as, Qs, emission_proj, x0 = lgssm_components(k, x, s)
return LGSSM(
GaussMarkovModel(Forward(), As, as, Qs, x0),
build_emissions(emission_proj, build_Σs(ft)),
)
end

build_Σs(ft::FiniteLTISDE) = build_Σs(ft.x, ft.Σy)
function get_kernel(ft::FiniteLTISDE)
return Zygote.literal_getfield(
Zygote.literal_getfield(
Zygote.literal_getfield(ft, Val(:f)), Val(:f),
),
Val(:kernel),
)
end

function build_Σs(ft::FiniteLTISDE)
x = Zygote.literal_getfield(ft, Val(:x))
Σy = Zygote.literal_getfield(ft, Val(:Σy))
return build_Σs(x, Σy)
end

build_Σs(::AbstractVector{<:Real}, Σ::Diagonal{<:Real}) = Σ.diag
function build_Σs(::AbstractVector{<:Real}, Σ::Diagonal{<:Real})
return Zygote.literal_getfield(Σ, Val(:diag))
end

function build_emissions(
(Hs, hs)::Tuple{AbstractVector, AbstractVector}, Σs::AbstractVector,
Expand Down Expand Up @@ -252,8 +270,10 @@ end
# Scaled

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(k.σ²)))
_k = Zygote.literal_getfield(k, Val(:kernel))
σ² = Zygote.literal_getfield(k, Val(:σ²))
As, as, Qs, emission_proj, x0 = lgssm_components(_k, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(σ²)))
return As, as, Qs, _scale_emission_projections(emission_proj, σ), x0
end

Expand All @@ -274,14 +294,21 @@ function lgssm_components(
ts::AbstractVector,
storage_type::StorageType,
)
return lgssm_components(k.kernel, apply_stretch(only(k.transform.s), ts), storage_type)
_k = Zygote.literal_getfield(k, Val(:kernel))
s = Zygote.literal_getfield(Zygote.literal_getfield(k, Val(:transform)), Val(:s))
return lgssm_components(_k, apply_stretch(s[1], ts), storage_type)
end

apply_stretch(a, ts::AbstractVector{<:Real}) = a * ts

apply_stretch(a, ts::StepRangeLen) = a * ts

apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts.N)
function apply_stretch(a, ts::RegularSpacing)
t0 = Zygote.literal_getfield(ts, Val(:t0))
Δt = Zygote.literal_getfield(ts, Val(:Δt))
N = Zygote.literal_getfield(ts, Val(:N))
return RegularSpacing(a * t0, a * Δt, N)
end



Expand Down
31 changes: 24 additions & 7 deletions src/gp/posterior_lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,34 @@ function AbstractGPs.cov(fx::FinitePosteriorLTISDE)
end

function AbstractGPs.marginals(fx::FinitePosteriorLTISDE)
x, y, σ²s, pr_indices = build_inference_data(fx.f, fx.x)

model = build_lgssm(fx.f.prior(x, σ²s))
σ²s_pr_full = build_prediction_obs_vars(pr_indices, x, fx.Σy.diag)
model_post = replace_observation_noise_cov(posterior(model, y), σ²s_pr_full)
return map(marginals, marginals(model_post)[pr_indices])
if fx.x != fx.f.data.x
x, y, σ²s, pr_indices = build_inference_data(fx.f, fx.x)

model = build_lgssm(fx.f.prior(x, σ²s))
σ²s_pr_full = build_prediction_obs_vars(pr_indices, x, fx.Σy.diag)
model_post = replace_observation_noise_cov(posterior(model, y), σ²s_pr_full)
return map(marginals, marginals(model_post)[pr_indices])
else
f = Zygote.literal_getfield(fx, Val(:f))
prior = Zygote.literal_getfield(f, Val(:prior))
x = Zygote.literal_getfield(fx, Val(:x))
data = Zygote.literal_getfield(f, Val(:data))
Σy = Zygote.literal_getfield(data, Val(:Σy))
Σy_diag = Zygote.literal_getfield(Σy, Val(:diag))
y = Zygote.literal_getfield(data, Val(:y))

Σy_new = Zygote.literal_getfield(fx, Val(:Σy))
Σy_new_diag = Zygote.literal_getfield(Σy_new, Val(:diag))

model = build_lgssm(AbstractGPs.FiniteGP(prior, x, Σy))
model_post = replace_observation_noise_cov(posterior(model, y), Σy_new_diag)
return map(marginals, marginals(model_post))
end
end

function AbstractGPs.mean_and_var(fx::FinitePosteriorLTISDE)
ms = marginals(fx)
return mean.(ms), var.(ms)
return map(mean, ms), map(var, ms)
end

AbstractGPs.mean(fx::FinitePosteriorLTISDE) = mean_and_var(fx)[1]
Expand Down
6 changes: 5 additions & 1 deletion src/util/gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ AbstractGPs.mean(x::Gaussian) = Zygote.literal_getfield(x, Val(:m))

AbstractGPs.cov(x::Gaussian) = Zygote.literal_getfield(x, Val(:P))

AbstractGPs.var(x::Gaussian{<:AbstractVector}) = diag(cov(x))

AbstractGPs.var(x::Gaussian{<:Real}) = cov(x)

get_fields(x::Gaussian) = mean(x), cov(x)

Random.rand(rng::AbstractRNG, x::Gaussian) = vec(rand(rng, x, 1))
Expand Down Expand Up @@ -54,7 +58,7 @@ function Base.isapprox(x::Gaussian, y::Gaussian; kwargs...)
return isapprox(mean(x), mean(y); kwargs...) && isapprox(cov(x), cov(y); kwargs...)
end

AbstractGPs.marginals(x::Gaussian{<:Real, <:Real}) = Normal(mean(x), sqrt(cov(x)))
AbstractGPs.marginals(x::Gaussian{T, T}) where {T<:Real} = Normal{T}(mean(x), sqrt(cov(x)))

function AbstractGPs.marginals(x::Gaussian{<:AbstractVector, <:AbstractMatrix})
return Normal.(mean(x), sqrt.(diag(cov(x))))
Expand Down

2 comments on commit b92a7af

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/35171

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.7 -m "<description of version>" b92a7afe09ae76bb6e6e72c19665a565639d8046
git push origin v0.5.7

Please sign in to comment.