-
Notifications
You must be signed in to change notification settings - Fork 5
/
expectations.jl
117 lines (98 loc) · 3.89 KB
/
expectations.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
using FastGaussQuadrature: gausshermite
using SpecialFunctions: loggamma
using ChainRulesCore: ChainRulesCore
using IrrationalConstants: sqrt2, invsqrtπ
"""
    DefaultExpectationMethod()

Marker passed to `expected_loglikelihood` to let `default_expectation_method(lik)` pick
the expectation method for the likelihood `lik`.
"""
struct DefaultExpectationMethod
end
"""
    AnalyticExpectation()

Marker passed to `expected_loglikelihood` to request the closed-form expected log
likelihood.

The generic method for a vector of `Normal` marginals throws an error. A closed form is
presumably supplied by more specific methods for individual likelihoods; confirm against
the likelihood definitions.
"""
struct AnalyticExpectation
end
"""
    GaussHermiteExpectation(n::Integer)
    GaussHermiteExpectation(xs, ws)

Approximate expectations under a `Normal` marginal with Gauss–Hermite quadrature.

`xs` are the quadrature nodes and `ws` the matching weights, as returned by
`gausshermite(n)`. They are stored as `Vector{Float64}`, and other real vectors are
converted. The one-argument constructor builds an `n`-point rule.

# Throws
- `DimensionMismatch` if `xs` and `ws` differ in length. Otherwise, when the rule is
  broadcast over `(xs, ws)`, a length-1 vector would silently pair with every entry of the
  other and produce a wrong rule.
- `DomainError` if `n < 1`: a rule needs at least one node.
"""
struct GaussHermiteExpectation
    xs::Vector{Float64}
    ws::Vector{Float64}
    function GaussHermiteExpectation(xs, ws)
        length(xs) == length(ws) || throw(
            DimensionMismatch(
                "nodes and weights must have equal length, got " *
                "$(length(xs)) and $(length(ws))",
            ),
        )
        # `new` converts the arguments to the declared `Vector{Float64}` field types.
        return new(xs, ws)
    end
end
function GaussHermiteExpectation(n::Integer)
    # DomainError matches what `gausshermite` itself throws for negative `n`.
    n > 0 || throw(DomainError(n, "the number of quadrature points must be positive"))
    return GaussHermiteExpectation(gausshermite(n)...)
end
# The quadrature nodes and weights depend only on the integer number of points `n`, so
# tell ChainRulesCore-based automatic differentiation to treat `gausshermite` as a
# constant instead of trying to differentiate through it.
ChainRulesCore.@non_differentiable gausshermite(n)
"""
    MonteCarloExpectation(n_samples)

Estimate expected log likelihoods by averaging over `n_samples` reparameterised draws
from each marginal.

# Throws
- `ArgumentError` if `n_samples` is not positive. With zero samples the estimator divides
  `0` by `0` and returns `NaN`; with a negative count it would only fail later, when the
  sample matrix is allocated.
"""
struct MonteCarloExpectation
    n_samples::Int
    function MonteCarloExpectation(n_samples)
        n_samples > 0 ||
            throw(ArgumentError("n_samples must be positive, got $n_samples"))
        # `new` converts to `Int`, as the default constructor did (e.g. `4.0` -> `4`).
        return new(n_samples)
    end
end
"""
    default_expectation_method(lik)

Return the expectation method that `DefaultExpectationMethod` resolves to for `lik`.

This generic fallback ignores `lik` and always returns a 20-point Gauss–Hermite rule.
Likelihoods with a closed-form expectation can add a more specific method.
"""
function default_expectation_method(lik)
    return GaussHermiteExpectation(20)
end
"""
    expected_loglikelihood(
        quadrature,
        lik,
        q_f::AbstractVector{<:Normal},
        y::AbstractVector,
    )

Compute the expected log likelihood

```math
∫ q(f) log p(y | f) df
```

where `p(y | f)` is the process likelihood. It is described by `lik`, which should be a
callable that takes `f` as input and returns a distribution over `y` that supports
`loglikelihood(lik(f), y)`.

`q(f)` is an approximation to the latent function values `f`, given by

```math
q(f) = ∫ p(f | u) q(u) du
```

where `q(u)` is the variational distribution over the inducing points. The marginal
distributions of `q(f)` are given by `q_f`.

`quadrature` determines which method is used to compute the expectation. The method
types defined alongside this function are `DefaultExpectationMethod()`,
`AnalyticExpectation()`, `GaussHermiteExpectation(n)` and
`MonteCarloExpectation(n_samples)`.

# Extended help

`q(f)` is assumed to be an `MvNormal` distribution, and `p(y | f)` is assumed to have
independent marginals, so only the marginals of `q(f)` are required. The expectation is
then the sum of one-dimensional terms, `∑ᵢ ∫ q(fᵢ) log p(yᵢ | fᵢ) dfᵢ`, with each
marginal `q_f[i]` paired with the observation `y[i]`.
"""
expected_loglikelihood(quadrature, lik, q_f, y)
"""
    expected_loglikelihood(
        ::DefaultExpectationMethod,
        lik,
        q_f::AbstractVector{<:Normal},
        y::AbstractVector,
    )

Compute the expected log likelihood with whichever method
`default_expectation_method(lik)` selects for `lik`, by forwarding to that method.

The default is meant to be the closed-form solution when one exists; otherwise it falls
back to Gauss–Hermite quadrature.
"""
function expected_loglikelihood(
    ::DefaultExpectationMethod, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    return expected_loglikelihood(default_expectation_method(lik), lik, q_f, y)
end
"""
    expected_loglikelihood(
        mc::MonteCarloExpectation,
        lik,
        q_f::AbstractVector{<:Normal},
        y::AbstractVector,
    )

Monte Carlo estimate of `∑ᵢ E_{q(fᵢ)}[log p(yᵢ | fᵢ)]`. Each marginal `q_f[i]` gets
`mc.n_samples` reparameterised draws `fᵢ = μᵢ + σᵢ ε` with `ε ~ N(0, 1)`, and the summed
log likelihoods are divided by `mc.n_samples`.

The draws come from the global random number generator (`randn` without an explicit
RNG), so results are reproducible only if that generator is seeded beforehand.

# Throws
- `DimensionMismatch` if `q_f` and `y` have different lengths.
"""
function expected_loglikelihood(
    mc::MonteCarloExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    # Broadcasting would silently pair a length-1 `y` with every marginal (or a length-1
    # `q_f` with every observation), so require an explicit one-to-one pairing.
    length(q_f) == length(y) || throw(
        DimensionMismatch(
            "q_f and y must have the same length, got $(length(q_f)) and $(length(y))",
        ),
    )
    f_μ = mean.(q_f)
    # Compute σᵢ once per marginal rather than once per (marginal, sample) pair.
    f_σ = std.(q_f)
    # One column per sample: fs[i, s] = μᵢ + σᵢ εᵢₛ (reparameterisation).
    fs = f_μ .+ f_σ .* randn(eltype(f_μ), length(q_f), mc.n_samples)
    # `y` broadcasts across columns, pairing observation yᵢ with every sample of fᵢ.
    lls = loglikelihood.(lik.(fs), y)
    return sum(lls) / mc.n_samples
end
"""
    expected_loglikelihood(
        gh::GaussHermiteExpectation,
        lik,
        q_f::AbstractVector{<:Normal},
        y::AbstractVector,
    )

Sum, over each pair of marginal `q_f[i]` and observation `y[i]`, the Gauss–Hermite
approximation of `E_{q(fᵢ)}[log p(yᵢ | fᵢ)]`. Each term is computed by the
single-observation method for a scalar `Normal` marginal.

# Throws
- `DimensionMismatch` if `q_f` and `y` have different lengths.
"""
function expected_loglikelihood(
    gh::GaussHermiteExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    # Broadcasting would silently pair a length-1 `y` with every marginal (or a length-1
    # `q_f` with every observation), so require an explicit one-to-one pairing.
    length(q_f) == length(y) || throw(
        DimensionMismatch(
            "q_f and y must have the same length, got $(length(q_f)) and $(length(y))",
        ),
    )
    # Lazy broadcast over (yᵢ, q(fᵢ)) pairs. `sum` reduces the instantiated object
    # directly, without first collecting the per-observation terms into an array.
    per_observation = Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ
        expected_loglikelihood(gh, lik, q_fᵢ, yᵢ)
    end
    return sum(Broadcast.instantiate(per_observation))
end
"""
    expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y)

Gauss–Hermite approximation of `E_{q_f}[log p(y | f)]` for a single observation `y`.

The change of variables `f = √2 σ x + μ` turns the Gaussian expectation into
`(1/√π) ∫ exp(-x²) log p(y | √2 σ x + μ) dx`, which is approximated by the weighted sum
over the nodes `gh.xs` and weights `gh.ws`
(see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
"""
function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y)
    μ = mean(q_f)
    scale = sqrt2 * std(q_f)
    # Contribution of a single node x with weight w to the quadrature sum.
    weighted_term(x, w) = loglikelihood(lik(scale * x + μ), y) * w
    lazy_terms = Broadcast.instantiate(Broadcast.broadcasted(weighted_term, gh.xs, gh.ws))
    return invsqrtπ * sum(lazy_terms)
end
"""
    expected_loglikelihood(
        ::AnalyticExpectation,
        lik,
        q_f::AbstractVector{<:Normal},
        y::AbstractVector,
    )

Fallback for likelihoods that have no closed-form expected log likelihood. It always
throws an `ErrorException` naming the likelihood type and pointing to the numerical
alternatives.
"""
function expected_loglikelihood(
    ::AnalyticExpectation, lik, q_f::AbstractVector{<:Normal}, y::AbstractVector
)
    msg = "No analytic solution exists for $(typeof(lik)). Use `DefaultExpectationMethod`, `GaussHermiteExpectation` or `MonteCarloExpectation` instead."
    throw(ErrorException(msg))
end