-
Notifications
You must be signed in to change notification settings - Fork 9
/
BufferedAD.jl
95 lines (78 loc) · 3.2 KB
/
BufferedAD.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Wrapper pairing a log density with pre-allocated storage, so that automatic
differentiation can write its gradient in place rather than allocating.
Stan log potentials are one example of where this is used.

Fields:
$FIELDS
"""
struct BufferedAD{E, LB, EB}
    """ The wrapped object, which satisfies the informal `LogDensityProblems` interface. """
    enclosed::E
    """ Pre-allocated storage that in-place gradient computations write into. """
    buffer::Vector{Float64}
    """ Storage used when evaluating the log density. """
    logd_buffer::LB
    """ Storage holding error flags. """
    err_buffer::EB
end
# Log density evaluation forwards directly to the wrapped object; the buffers
# are not touched here.
function LogDensityProblems.logdensity(b::BufferedAD, x)
    wrapped = b.enclosed
    return LogDensityProblems.logdensity(wrapped, x)
end
# The dimension is taken to be the length of the gradient buffer.
function LogDensityProblems.dimension(b::BufferedAD)
    gradient_buffer = b.buffer
    return length(gradient_buffer)
end
# Build a `BufferedAD` around `log_potential`. The gradient buffer is obtained
# from `buffers` under the key `:gradient_buffer`, sized to the dimension of
# `log_potential`. `logd_buffer` and `err_buffer` are stored as given and
# default to `nothing`.
function BufferedAD(
        log_potential,
        buffers::Augmentation,
        logd_buffer = nothing,
        err_buffer = nothing,
    )
    dim = LogDensityProblems.dimension(log_potential)
    gradient_buffer = get_buffer(buffers, :gradient_buffer, dim)
    return BufferedAD(log_potential, gradient_buffer, logd_buffer, err_buffer)
end
"""
Gradient wrapper for an `InterpolatedLogPotential` whose reference and target
are differentiated separately, possibly with different autodiff frameworks.
Provided both endpoint gradients are non-allocating, this allows autodiff of
`InterpolatedLogPotential`s to also be non-allocating.
For example, this is useful when the target is a Stan log potential
and the reference is a variational distribution with a hand-crafted,
allocation-free gradient.
Fields:
$FIELDS
"""
@auto struct InterpolatedAD
    """ The enclosed `InterpolatedLogPotential`. """
    enclosed
    """
    The result of `LogDensityProblemsAD.ADgradient()` on the reference, often a
    `BufferedAD`.
    """
    ref_ad
    """
    The same as `ref_ad`, but for the target.
    """
    target_ad
    """
    An extra buffer in which the gradients of the two endpoints are combined.
    It must not alias the gradient returned by `target_ad`, because the target
    gradient is accumulated into it.
    """
    buffer::Vector{Float64}
end
# Fallback for log potentials with no buffer-aware specialization: `buffers`
# is ignored and the plain `ADgradient(kind, log_potential)` is returned.
function LogDensityProblemsAD.ADgradient(kind::Symbol, log_potential, buffers::Augmentation)
    return LogDensityProblemsAD.ADgradient(kind, log_potential)
end
# Linearly interpolated paths: the reference and the target are each
# differentiated through their own `ADgradient` methods (sharing `buffers`).
# An extra buffer, sized to the reference dimension and stored under
# `:gradient_interpolated_buffer`, is where `InterpolatedAD` combines the
# two gradients.
function LogDensityProblemsAD.ADgradient(
        kind::Symbol,
        log_potential::InterpolatedLogPotential{InterpolatingPath{R, T, LinearInterpolator}, B},
        buffers::Augmentation,
    ) where {R, T, B}
    path = log_potential.path
    ref_ad = LogDensityProblemsAD.ADgradient(kind, path.ref, buffers)
    target_ad = LogDensityProblemsAD.ADgradient(kind, path.target, buffers)
    dim = LogDensityProblems.dimension(path.ref)
    combined = get_buffer(buffers, :gradient_interpolated_buffer, dim)
    return InterpolatedAD(log_potential, ref_ad, target_ad, combined)
end
# Log density of the linear interpolation `(1 - beta) * ref(x) + beta * target(x)`.
# `beta` is read from the enclosed `InterpolatedLogPotential`.
#
# At the endpoints (`beta == 0` or `beta == 1`), only the endpoint with
# non-zero weight is evaluated. Blending in the other one would produce
# `0.0 * -Inf == NaN` whenever that endpoint's log density is `-Inf` at `x`.
# For finite values the result is unchanged.
function LogDensityProblems.logdensity(log_potential::InterpolatedAD, x)
    beta = log_potential.enclosed.beta
    iszero(beta) && return LogDensityProblems.logdensity(log_potential.ref_ad, x)
    isone(beta) && return LogDensityProblems.logdensity(log_potential.target_ad, x)
    l1 = LogDensityProblems.logdensity(log_potential.ref_ad, x)
    l2 = LogDensityProblems.logdensity(log_potential.target_ad, x)
    return (1.0 - beta) * l1 + beta * l2
end
# Report the dimension of the reference endpoint's gradient object.
function LogDensityProblems.dimension(log_potential::InterpolatedAD)
    ref = log_potential.ref_ad
    return LogDensityProblems.dimension(ref)
end
# Value and gradient of `(1 - beta) * ref(x) + beta * target(x)`.
# The gradient is written into, and returned as, `log_potential.buffer`, so
# this method itself allocates no gradient array.
#
# At the endpoints (`beta == 0` or `beta == 1`), only the endpoint with
# non-zero weight is differentiated. Blending in the other one would give
# `0.0 * -Inf == NaN` for the value, or `0.0 * NaN` / `0.0 * Inf` for the
# gradient, whenever that endpoint is non-finite at `x`.
# For finite values the result is unchanged.
function LogDensityProblems.logdensity_and_gradient(log_potential::InterpolatedAD, x)
    beta = log_potential.enclosed.beta
    buffer = log_potential.buffer
    if iszero(beta)
        l, g = LogDensityProblems.logdensity_and_gradient(log_potential.ref_ad, x)
        buffer .= g
        return l, buffer
    elseif isone(beta)
        l, g = LogDensityProblems.logdensity_and_gradient(log_potential.target_ad, x)
        buffer .= g
        return l, buffer
    end
    l_ref, g_ref = LogDensityProblems.logdensity_and_gradient(log_potential.ref_ad, x)
    # Copy the weighted reference gradient out before calling the target: the
    # target's gradient call may reuse the same storage (e.g. a shared buffer).
    buffer .= (1.0 - beta) .* g_ref
    l_target, g_target = LogDensityProblems.logdensity_and_gradient(log_potential.target_ad, x)
    buffer .+= beta .* g_target
    return (1.0 - beta) * l_ref + beta * l_target, buffer
end