Merge pull request #100 from JuliaGaussianProcesses/tgf/nonzero-mean

Allow for non-zero mean

theogf committed Apr 11, 2023
2 parents f82dd5b + a59f6d9 commit 03c0961
Showing 11 changed files with 109 additions and 64 deletions.
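In short: `to_sde` previously required a `GP` with a `ZeroMean` mean function, and this PR threads an arbitrary prior mean through the conversion to a linear Gaussian state-space model instead. A minimal sketch of what now works, mirroring the updated example below (the constant mean `5.0` and the noise variance are illustrative):

```julia
using AbstractGPs, TemporalGPs
using TemporalGPs: RegularSpacing

f = to_sde(GP(5.0, Matern52Kernel()), SArrayStorage(Float64))  # non-zero mean now allowed
x = RegularSpacing(0.0, 0.1, 100)   # inputs must be increasing
fx = f(x, 0.1)                      # observation-noise variance 0.1
y = rand(fx)                        # sample fluctuates around 5.0
logpdf(fx, y)                       # log marginal likelihood, as before
```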
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <wt0881@my.bristol.ac.uk> and contributors"]
version = "0.6.1"
version = "0.6.2"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
@@ -16,7 +16,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.5"
AbstractGPs = "0.5.15"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7"
2 changes: 1 addition & 1 deletion examples/exact_time_inference.jl
@@ -15,7 +15,7 @@ using TemporalGPs
using TemporalGPs: RegularSpacing

# Build a GP as per usual, and wrap it inside a TemporalGPs.jl object.
-f_raw = GP(Matern52Kernel());
+f_raw = GP(5.0, Matern52Kernel());
f = to_sde(f_raw, SArrayStorage(Float64));

# Specify a collection of inputs. Must be increasing.
5 changes: 3 additions & 2 deletions examples/exact_time_learning.jl
@@ -18,6 +18,7 @@ using Zygote # Algorithmic Differentiation
# var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the
# variance of the observation noise. Note that they're all constrained to be positive.
flat_initial_params, unpack = ParameterHandling.value_flatten((
+    mean = 3.0,
    var_kernel = positive(0.6),
    λ = positive(0.1),
    var_noise = positive(2.0),
@@ -28,7 +29,7 @@ params = unpack(flat_initial_params);

function build_gp(params)
    k = params.var_kernel * Matern52Kernel() ∘ ScaleTransform(params.λ)
-    return to_sde(GP(k), SArrayStorage(Float64))
+    return to_sde(GP(params.mean, k), SArrayStorage(Float64))
end

# Specify a collection of inputs. Must be increasing.
@@ -50,7 +51,7 @@ end
training_results = Optim.optimize(
    objective ∘ unpack,
    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
-    flat_initial_params + randn(3), # Perturb the parameters to make learning non-trivial
+    flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
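Two small things in this hunk: the new `mean` entry is a plain unconstrained value while the others are wrapped in `positive`, and the perturbation becomes `flat_initial_params .+ randn.()` because the flat vector now has four entries rather than three, so the broadcasted zero-argument `randn.()` draws one perturbation per entry whatever the length. A minimal sketch of the flatten/unpack round-trip, assuming ParameterHandling's documented behaviour (the two-parameter tuple is illustrative):

```julia
using ParameterHandling

# Flatten a NamedTuple of (possibly constrained) parameters into a plain Vector{Float64},
# and get back a function mapping any such vector to parameter values.
flat, unpack = ParameterHandling.value_flatten((
    mean = 3.0,                 # unconstrained: passes through as-is
    var_kernel = positive(0.6), # constrained: reparametrised so it stays positive
))

params = unpack(flat)
params.mean                     # ≈ 3.0
params.var_kernel               # ≈ 0.6, and positive for any flat vector
unpack(flat .+ randn.())        # one independent perturbation per entry
```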
40 changes: 32 additions & 8 deletions src/gp/lti_sde.jl
@@ -2,14 +2,14 @@
LTISDE (Linear Time-Invariant Stochastic Differential Equation)
A lightweight wrapper around a `GP` `f` that tells this package to handle inference in `f`.
-Can be constructed via the `to_sde` function.
+Can be constructed via the [`to_sde`](@ref) function.
"""
-struct LTISDE{Tf<:GP{<:AbstractGPs.ZeroMean}, Tstorage<:StorageType} <: AbstractGP
+struct LTISDE{Tf<:GP, Tstorage<:StorageType} <: AbstractGP
    f::Tf
    storage::Tstorage
end

-function to_sde(f::GP{<:AbstractGPs.ZeroMean}, storage_type=ArrayStorage(Float64))
+function to_sde(f::GP, storage_type=ArrayStorage(Float64))
    return LTISDE(f, storage_type)
end

@@ -19,13 +19,13 @@ storage_type(f::LTISDE) = f.storage
const FiniteLTISDE = FiniteGP{<:LTISDE}
A `FiniteLTISDE` is just a regular `FiniteGP` that happens to contain an `LTISDE`, as
-opposed to any other `AbstractGP`.
+opposed to any other `AbstractGP`, useful for dispatching.
"""
const FiniteLTISDE = FiniteGP{<:LTISDE}

# Deal with a bug in AbstractGPs.
-function FiniteGP(f::LTISDE, x::AbstractVector{<:Real})
-    return FiniteGP(f, x, convert(eltype(storage_type(f)), 1e-12))
+function AbstractGPs.FiniteGP(f::LTISDE, x::AbstractVector{<:Real})
+    return AbstractGPs.FiniteGP(f, x, convert(eltype(storage_type(f)), 1e-12))
end

# Implement the AbstractGP API.
@@ -68,11 +68,11 @@ function _logpdf(ft::FiniteLTISDE, y::AbstractVector{<:Union{Missing, Real}})
end

# Converting GPs into LGSSMs (Linear Gaussian State-Space Models).

function build_lgssm(f::LTISDE, x::AbstractVector, Σys::AbstractVector)
+    m = get_mean(f)
    k = get_kernel(f)
    s = Zygote.literal_getfield(f, Val(:storage))
-    As, as, Qs, emission_proj, x0 = lgssm_components(k, x, s)
+    As, as, Qs, emission_proj, x0 = lgssm_components(m, k, x, s)
    return LGSSM(
        GaussMarkovModel(Forward(), As, as, Qs, x0), build_emissions(emission_proj, Σys),
    )
@@ -85,6 +85,9 @@ function build_lgssm(ft::FiniteLTISDE)
    return build_lgssm(f, x, Σys)
end

+get_mean(f::LTISDE) = get_mean(Zygote.literal_getfield(f, Val(:f)))
+get_mean(f::GP) = Zygote.literal_getfield(f, Val(:mean))

get_kernel(f::LTISDE) = get_kernel(Zygote.literal_getfield(f, Val(:f)))
get_kernel(f::GP) = Zygote.literal_getfield(f, Val(:kernel))

@@ -115,7 +118,28 @@ end
    return map(Zygote.wrap_chainrules_output, x)
end

+# Constructor for combining kernel and mean functions
+function lgssm_components(
+    ::ZeroMean, k::Kernel, t::AbstractVector, storage_type::StorageType
+)
+    return lgssm_components(k, t, storage_type)
+end
+
+function lgssm_components(
+    m::AbstractGPs.MeanFunction, k::Kernel, t::AbstractVector, storage_type::StorageType
+)
+    m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays.
+    As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type)
+    hs = add_proj_mean(hs, m)
+
+    return As, as, Qs, (Hs, hs), x0
+end
+
+# Either build a new vector or update an existing one with the mean.
+add_proj_mean(hs::AbstractVector{<:Real}, m) = hs .+ m
+function add_proj_mean(hs::AbstractVector, m)
+    return map((h, m) -> h + vcat(m, Zeros(length(h) - 1)), hs, m)
+end

# Generic constructors for base kernels.

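Note how the mean enters the model: `mean_vector(m, t)` evaluates the prior mean at the inputs, and `add_proj_mean` folds it into the emission biases `hs`, adding each scalar mean value to the first component of the corresponding bias in the vector-valued case. A self-contained sketch of that behaviour, mirroring the two definitions in the hunk above (`Zeros` comes from FillArrays):

```julia
using FillArrays: Zeros

# Same definitions as in the hunk above.
add_proj_mean(hs::AbstractVector{<:Real}, m) = hs .+ m
add_proj_mean(hs::AbstractVector, m) = map((h, m) -> h + vcat(m, Zeros(length(h) - 1)), hs, m)

m = [3.0, 5.0]                               # prior mean at two input locations

add_proj_mean([0.0, 0.0], m)                 # scalar biases:  [3.0, 5.0]
add_proj_mean([[0.0, 0.0], [0.0, 0.0]], m)   # vector biases:  [[3.0, 0.0], [5.0, 0.0]]
```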
7 changes: 4 additions & 3 deletions src/models/lgssm.jl
@@ -47,7 +47,11 @@ emission_type(model::LGSSM) = eltype(emissions(model))


# Functionality for indexing into an LGSSM.
"""
ElementOfLGSSM
Represents an element of [`LGSSM`](@ref) with a given ordering.
"""
struct ElementOfLGSSM{Tordering, Ttransition, Temission}
ordering::Tordering
transition::Ttransition
@@ -70,10 +74,7 @@ end
    return ElementOfLGSSM(ordering(model), model.transitions[n], model.emissions[n])
end



# Draw a sample from the model.

function AbstractGPs.rand(rng::AbstractRNG, model::LGSSM)
    iterable = zip(ε_randn(rng, model), model)
    init = rand(rng, x0(model))
39 changes: 23 additions & 16 deletions src/models/linear_gaussian_conditionals.jl
@@ -1,16 +1,22 @@
"""
-    abstract type AbstractLGC end
+    AbstractLGC
Represents a Gaussian conditional distribution:
```julia
y | x ∼ Gaussian(A * x + a, Q)
```
+Note that this can be used in two contexts:
+- Transition: `y` is the next state, `x` is the current state.
+- Emission: `y` is the observation, `x` is the state.
Subtypes have discretion over how to implement the interface for this type. In particular
`A`, `a`, and `Q` may not be represented explicitly so that structure can be exploited to
accelerate computations.
-# Interface:
+## Interface:
- `==`
- `eltype`
- `dim_out`
@@ -28,10 +34,11 @@ Base.:(==)(x::AbstractLGC, y::AbstractLGC) = (x.A == y.A) && (x.a == y.a) && (x.
Base.eltype(f::AbstractLGC) = eltype(f.A)

"""
-    predict(x::Gaussian, f::AbstractLGC)
+    predict(x::Gaussian, f::AbstractLGC)::Gaussian{Tm,AbstractMatrix}

-Compute the distribution "predicted" by this conditional given a `Gaussian` input `x`. Will
+Compute the distribution "predicted" by this conditional given a [`Gaussian`](@ref) input `x`. Will
be equivalent to
```julia
Gaussian(f.A * x.m + f.a, f.A * x.P * f.A' + f.Q)
```
@@ -44,12 +51,12 @@ function predict(x::Gaussian, f::AbstractLGC)
end

"""
-    predict_marginals(x::Gaussian, f::AbstractLGC)
+    predict_marginals(x::Gaussian, f::AbstractLGC)::Gaussian{Tm,Diagonal}
Equivalent to
```julia
-y = predict(x, f)
-Gaussian(mean(y), Diagonal(cov(y)))
+xᵗ⁺¹ = predict(xᵗ, f)
+Gaussian(mean(xᵗ⁺¹), Diagonal(cov(xᵗ⁺¹)))
```
"""
function predict_marginals(x::Gaussian, f::AbstractLGC)
@@ -64,7 +71,7 @@ end
conditional_rand(ε::AbstractVector, f::AbstractLGC, x::AbstractVector)
Sample from the conditional distribution `y | x`. `ε` is the randomness needed to generate
-this sample. If `rng` is provided, it will be used to construct `ε` via `ε_randn`.
+this sample. If `rng` is provided, it will be used to construct `ε` via [`ε_randn`](@ref).
If implementing a new `AbstractLGC`, implement the `ε` method as it avoids randomness, which
means that it plays nicely with `scan_emit`'s checkpointed rrule.
@@ -81,7 +88,7 @@ end
"""
ε_randn(rng::AbstractRNG, f::AbstractLGC)
-Generate the vector of random numbers needed inside `conditional_rand`.
+Generate the vector of random numbers needed inside [`conditional_rand`](@ref).
"""
ε_randn(rng::AbstractRNG, f::AbstractLGC) = ε_randn(rng, f.A)
ε_randn(rng::AbstractRNG, A::AbstractMatrix{T}) where {T<:Real} = randn(rng, T, size(A, 1))
@@ -101,10 +108,10 @@ ChainRulesCore.@non_differentiable scalar_type(x)
TA<:AbstractMatrix, Ta<:AbstractVector, TQ<:AbstractMatrix,
} <: AbstractLGC
-a.k.a. LGC. An `AbstractLGC` designed for problems in which `A` is a matrix, and
+a.k.a. LGC. An [`AbstractLGC`](@ref) designed for problems in which `A` is a matrix, and
`size(A, 1) < size(A, 2)`. It should still work (roughly) for problems in which
`size(A, 1) > size(A, 2)`, but one should expect worse accuracy and performance than a
-`LargeOutputLGC` in such circumstances.
+[`LargeOutputLGC`](@ref) in such circumstances.
"""
struct SmallOutputLGC{
TA<:AbstractMatrix, Ta<:AbstractVector, TQ<:AbstractMatrix,
@@ -156,7 +163,7 @@ end
TA<:AbstractMatrix, Ta<:AbstractVector, TQ<:AbstractMatrix,
} <: AbstractLGC
-A SmallOutputLGC (LGC) specialised for models in which the dimension of the
+A [`SmallOutputLGC`](@ref) (LGC) specialised for models in which the dimension of the
outputs are greater than that of the inputs. These specialisations both improve numerical
stability and performance (time and memory), so it's worth using if your model lives in
this regime.
@@ -234,8 +241,8 @@ end
"""
ScalarOutputLGC
-An LGC that operates on a vector-valued input space and a scalar-valued output space.
-Similar to `SmallOutputLGC` when its `dim_out` is 1 but, for example, `conditional_rand`
+An [`AbstractLGC`](@ref) that operates on a vector-valued input space and a scalar-valued output space.
+Similar to [`SmallOutputLGC`](@ref) when its `dim_out` is 1 but, for example, [`conditional_rand`](@ref)
returns a `Real` rather than an `AbstractVector` of length 1.
"""
struct ScalarOutputLGC{
@@ -284,10 +291,10 @@ end
BottleneckLGC
A composition of an affine map that projects onto a low-dimensional subspace and a
-`LargeOutputLGC`. This structure is exploited by only ever computing `Cholesky`
+[`LargeOutputLGC`](@ref). This structure is exploited by only ever computing `Cholesky`
factorisations in the space the affine map maps to, rather than the input or output space.
Letting, `H` and `h` parametrise the affine map, and `f` the "fan-out" `LargeOutputLGC`, the
Letting, `H` and `h` parametrise the affine map, and `f` the "fan-out" [`LargeOutputLGC`](@ref), the
conditional distribution that this model parametrises is
```julia
y | x ~ Gaussian(f.A * (H * x + h) + f.a, f.Q)
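For reference, the `predict` docstring above is the standard affine-Gaussian push-forward: if `x ∼ N(m, P)` and `y | x ∼ N(Ax + a, Q)`, then `y ∼ N(Am + a, APAᵀ + Q)`. A self-contained numeric sketch with plain arrays (standing in for the package's internal `Gaussian` type; the numbers are arbitrary):

```julia
using LinearAlgebra

A = [1.0 0.5; 0.0 1.0]      # transition/emission matrix
a = [0.1, 0.0]              # bias
Q = Diagonal([0.2, 0.2])    # conditional covariance

m = [1.0, -1.0]             # mean of x
P = [1.0 0.3; 0.3 1.0]      # covariance of x

m_pred = A * m + a          # == [0.6, -1.0]
P_pred = A * P * A' + Q     # full predictive covariance, as in `predict`
P_marg = Diagonal(P_pred)   # the diagonal kept by `predict_marginals`
```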
2 changes: 0 additions & 2 deletions src/space_time/pseudo_point.jl
@@ -254,8 +254,6 @@ function approx_posterior_marginals(
    z_r::AbstractVector,
    x_r::AbstractVector,
)
-    fx.f.f.mean isa AbstractGPs.ZeroMean || throw(error("Prior mean of GP isn't zero."))

    # Compute approximate posterior LGSSM.
    lgssm = build_lgssm(dtcify(z_r, fx))
    fx_post = posterior(lgssm, restructure(y, lgssm.emissions))
3 changes: 3 additions & 0 deletions src/util/chainrules.jl
@@ -131,6 +131,9 @@ function (project::ProjectTo{Fill})(dx::Tangent{<:Fill})
# end
    Fill(dx.value / prod(length, project.axes), project.axes)
end
+function (project::ProjectTo{Fill})(dx::Tangent{Any,<:NamedTuple{(:value, :axes)}})
+    Fill(dx.value / prod(length, project.axes), project.axes)
+end

# Yet another thing that should not happen
function Zygote.accum(x::Fill, y::NamedTuple{(:value, :axes)})
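The added `ProjectTo{Fill}` method is a dispatch fix rather than new logic: its body is identical to the method above it, but it catches tangents spelled as the weakly-typed `Tangent{Any}` carrying the same `(value, axes)` fields, which the `Tangent{<:Fill}` signature misses. A sketch of the two spellings, assuming ChainRulesCore's keyword `Tangent` constructor:

```julia
using ChainRulesCore, FillArrays

# The same structural tangent for Fill(2.0, 3) can arrive in two spellings:
t_fill = Tangent{Fill{Float64, 1, Tuple{Base.OneTo{Int}}}}(; value=6.0, axes=(Base.OneTo(3),))
t_any  = Tangent{Any}(; value=6.0, axes=(Base.OneTo(3),))

t_fill isa Tangent{<:Fill}                             # true: hits the existing method
t_any isa Tangent{Any, <:NamedTuple{(:value, :axes)}}  # true: hits the new method

# Both then project to Fill(6.0 / 3, (Base.OneTo(3),)), i.e. Fill(2.0, 3).
```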

2 comments on commit 03c0961

@theogf (Member, Author) commented on 03c0961, Apr 11, 2023

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/81417

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.6.2 -m "<description of version>" 03c09612bcc14047c64597f84e108e3621267a3e
git push origin v0.6.2
