-
Notifications
You must be signed in to change notification settings - Fork 32
/
transform.jl
258 lines (226 loc) · 8.46 KB
/
transform.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
using Test
using Bijectors
using ForwardDiff: derivative, jacobian
using LinearAlgebra: logabsdet, I, norm
using Random
# Seed the global RNG so that the random draws used throughout these tests are reproducible.
Random.seed!(123)
# log|det(x)| for both scalars and arrays.
# `LinearAlgebra.logabsdet` only accepts arrays and returns the pair `(log|det|, sign)`;
# only the magnitude term is kept. A scalar is treated as its own determinant.
_logabsdet(x::Real) = log(abs(x))
_logabsdet(x::AbstractArray) = first(logabsdet(x))
# Draw standard-normal noise shaped like `x`: a single `Float64` when `x` is a `Real`,
# otherwise an array of `size(x)`. The two-argument form additionally overwrites the
# final entry with `e` (callers in this file pass `e = 0` for simplex distributions).
_rand_real(::Real) = randn()
_rand_real(x) = randn(size(x))
function _rand_real(x, e)
    noise = randn(size(x))
    noise[end] = e
    return noise
end
# Single-sample checks that, on top of the generic ones, compare
# `logpdf_with_trans(dist, x, true)` with a reference built from the plain `logpdf` and an
# automatic-differentiation Jacobian of `link`.
#
# `jac` is the differentiation routine: callers in this file pass `ForwardDiff.derivative`
# for scalar distributions and `ForwardDiff.jacobian` for vector-valued ones.
function single_sample_tests(dist, jac)
    # Generic round-trip and support checks first.
    single_sample_tests(dist)

    x = rand(dist)
    logpdf_ad = if dist isa SimplexDistribution
        # Differentiate the non-projected link (`false`) and evaluate the density at `x`
        # shifted by `eps(Float64)`, as the other simplex checks in this file do
        # (presumably to keep every component strictly positive).
        logpdf(dist, x .+ eps(Float64)) - _logabsdet(jac(z -> link(dist, z, false), x))
    else
        logpdf(dist, x) - _logabsdet(jac(z -> link(dist, z), x))
    end
    @test logpdf_ad ≈ logpdf_with_trans(dist, x, true)
end
# Standard single-sample checks applicable to every distribution. Unlike the two-argument
# method, this does not check that the `logpdf` implementation is consistent with the link
# function's Jacobian (for technical reasons).
function single_sample_tests(dist)
    ϵ = eps(Float64)

    # `invlink` must undo `link`. `link`/`invlink` are always handed a copy so that an
    # in-place implementation cannot modify `x` (the same convention as
    # `multi_sample_tests`).
    x = rand(dist)
    @test invlink(dist, link(dist, copy(x))) ≈ x atol=1e-9

    # `link` must undo `invlink`. Hopefully this mostly follows from the check above.
    # `x` is copied here too; without the copy an in-place `link` would corrupt `x` before
    # the `logpdf` and `typeof` checks below.
    y = link(dist, copy(x))
    if dist isa Dirichlet
        # `logit` and `logistic` are not exact inverses in floating point, so this round
        # trip can diverge noticeably. Example:
        #   julia> logit(logistic(0.9999999999999998))
        #   1.0
        #   julia> logistic(logit(0.9999999999999998))
        #   0.9999999999999998
        @test link(dist, invlink(dist, copy(y))) ≈ y atol=0.5
    else
        @test link(dist, invlink(dist, copy(y))) ≈ y atol=1e-9
    end

    if dist isa SimplexDistribution
        # Without the transformation this should probably match `logpdf` exactly.
        @test logpdf(dist, x .+ ϵ) == logpdf_with_trans(dist, x, false)
        # `invlink` must map random unconstrained inputs (last entry set to 0) back into
        # the appropriate constrained domain, i.e. give a finite log-density.
        @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x, 0)) .+ ϵ for _ in 1:100]))
    else
        # Without the transformation this should probably match `logpdf` exactly.
        @test logpdf(dist, x) == logpdf_with_trans(dist, x, false)
        # `invlink` must map random unconstrained inputs back into the support.
        @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100]))
    end

    # This is a quirk of the current implementation, of which it would be nice to be rid.
    @test typeof(x) == typeof(y)
end
# Standard multi-sample checks. `xs` must hold `N` repetitions of the single sample `x` in
# the batched layout the distribution expects: a vector of identical values for univariate
# distributions, a matrix with identical columns for vector-valued distributions, and a
# vector of identical matrices for matrix-valued distributions.
function multi_sample_tests(dist, x, xs, N)
    # Inputs are always copied, presumably so an in-place link/invlink cannot alter them.
    ys = link(dist, copy(xs))

    # Batched round trips in both directions.
    @test invlink(dist, link(dist, copy(xs))) ≈ xs atol=1e-9
    @test link(dist, invlink(dist, copy(ys))) ≈ ys atol=1e-9

    # The batched log-density must be the single-sample value repeated `N` times, both
    # with (`true`) and without (`false`) the transformation.
    for trans in (true, false)
        @test logpdf_with_trans(dist, xs, trans) == fill(logpdf_with_trans(dist, x, trans), N)
    end

    # This is a quirk of the current implementation, of which it would be nice to be rid.
    @test typeof(xs) == typeof(ys)
end
# Tests with scalar-valued (univariate) distributions.
@testset "scalar" begin
    let
        uni_dists = [
            Arcsine(2, 4),
            Beta(2,2),
            BetaPrime(),
            Biweight(),
            Cauchy(),
            Chi(3),
            Chisq(2),
            Cosine(),
            Epanechnikov(),
            Erlang(),
            Exponential(),
            FDist(1, 1),
            Frechet(),
            Gamma(),
            InverseGamma(),
            InverseGaussian(),
            # Kolmogorov(),
            Laplace(),
            Levy(),
            Logistic(),
            LogNormal(1.0, 2.5),
            Normal(0.1, 2.5),
            Pareto(),
            Rayleigh(1.0),
            TDist(2),
            truncated(Normal(0, 1), -Inf, 2),
        ]
        # Batch size for the multi-sample checks.
        N = 10
        for udist in uni_dists
            single_sample_tests(udist, derivative)
            # Batched checks on a vector holding `N` copies of one fresh draw.
            x = rand(udist)
            multi_sample_tests(udist, x, fill(x, N), N)
        end
    end
end
# Tests with vector-valued distributions.
@testset "vector" begin
    let ϵ = eps(Float64)
        # The two extreme-concentration Dirichlets were previously listed twice; each
        # configuration now appears once.
        vector_dists = [
            Dirichlet(2, 3),
            Dirichlet([1000 * one(Float64), eps(Float64)]),
            Dirichlet([eps(Float64), 1000 * one(Float64)]),
            MvNormal(randn(10), exp.(randn(10))),
            MvLogNormal(MvNormal(randn(10), exp.(randn(10)))),
        ]
        for dist in vector_dists
            if dist isa Dirichlet
                # Dirichlet gets the generic checks only; the AD consistency check is
                # done here against the Jacobian of the non-projected link (`false`).
                single_sample_tests(dist)
                # NOTE(review): an older comment said this "should fail at the minute" and
                # that the right way to test it is unclear; it is asserted with a plain
                # `@test`, so it is currently expected to pass — confirm.
                x = rand(dist)
                logpdf_turing = logpdf_with_trans(dist, x, true)
                J = jacobian(z -> link(dist, z, false), x)
                @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing

                # Issue #12: push linked samples far out in unconstrained space (steps of
                # order 1e10) and require the transformed log-density to stay finite,
                # i.e. neither Inf nor NaN.
                stepsize = 1e10
                dim = length(dist)
                logpdfs = map(1:1_000) do _
                    y = link(dist, rand(dist)) .+ randn(dim) .* stepsize
                    logpdf_with_trans(dist, invlink(dist, y), true)
                end
                @test all(isfinite, logpdfs)
            else
                single_sample_tests(dist, jacobian)
            end
            # Multi-sample tests. Columns are observations due to Distributions.jl
            # conventions.
            N = 10
            x = rand(dist)
            xs = repeat(x, 1, N)
            multi_sample_tests(dist, x, xs, N)
        end
    end
end
# Tests with matrix-valued distributions.
@testset "matrix" begin
    let
        matrix_dists = [
            Wishart(7, [1 0.5; 0.5 1]),
            InverseWishart(2, [1 0.5; 0.5 1]),
        ]
        for dist in matrix_dists
            single_sample_tests(dist)

            # Compare `logpdf_with_trans` against an AD Jacobian of `link`, evaluated at
            # the symmetrised sample `x + x' + 2I`.
            x = rand(dist)
            x = x + x' + 2I
            # Keep Jacobian rows at on/below-diagonal positions of `link`'s output and
            # columns at on/above-diagonal positions of the symmetric input (presumably
            # the independent entries on each side — confirm). The index variable is `ci`
            # rather than `I` so that `LinearAlgebra.I`, used just above, is not shadowed.
            linear = LinearIndices(x)
            lowerinds = [linear[ci] for ci in CartesianIndices(x) if ci[1] >= ci[2]]
            upperinds = [linear[ci] for ci in CartesianIndices(x) if ci[2] >= ci[1]]
            logpdf_turing = logpdf_with_trans(dist, x, true)
            J = jacobian(z -> link(dist, z), x)[lowerinds, upperinds]
            @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing

            # Multi-sample tests comprising vectors of matrices.
            N = 10
            x = rand(dist)
            xs = [x for _ in 1:N]
            multi_sample_tests(dist, x, xs, N)
        end
    end
end
################################## Miscellaneous old tests ##################################
# julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), exp.([-1000., -1000., -1000.]), true)
# NaN
# julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), [-1000., -1000., -1000.], true, true)
# -1999.30685281944
#
# julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), exp.([-1., -2., -3.]), true)
# -3.006450206744678
# julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), [-1., -2., -3.], true, true)
# -3.006450206744678
# Fixtures for the Dirichlet link/invlink checks below.
d = Dirichlet([1., 1., 1.])
r = [-1000., -1000., 0.0]
r2 = [-1., -2., 0.0]

# test vector invlink: each column of `x` is a separate unconstrained point, and every
# column of the result must sum to 1. The sums are floating-point, so compare with `≈`
# rather than exact equality.
dist = Dirichlet(ones(5))
x = [[-2.72689, -2.92751, 1.63114, -1.62054, 0.0] [-1.24249, 2.58902, -3.73043, -3.53685, 0.0]]
@test all(s -> s ≈ 1, sum(Bijectors.invlink(dist, x), dims = 1))

# test link
#link(d, r)

# test invlink
@test invlink(d, r) ≈ [0., 0., 1.] atol=1e-9

# test logpdf_with_trans
# NOTE(review): disabled. The original line was garbled and carried the remark
# `atol=NaN`, which suggests the computed value was NaN; confirm before re-enabling.
# @test logpdf_with_trans(d, invlink(d, r), true) ≈ -1999.30685281944 atol=1e-9
@test logpdf_with_trans(d, invlink(d, r2), true) ≈ -3.760398892580863 atol=1e-9
# `@aeq a b` evaluates to `true` when `norm(a - b) <= 1e-10`.
# `norm` is escaped, so it is resolved in the caller's scope; `a` and `b` are
# evaluated once each, `a` first.
macro aeq(x, y)
    normfn = esc(:norm)
    lhs = esc(x)
    rhs = esc(y)
    return :($normfn($lhs - $rhs) <= 1e-10)
end
@testset "Dirichlet Jacobians" begin
    # At a fresh random point, compare the analytic simplex link / invlink Jacobians with
    # ForwardDiff for both values of the projection flag. Then check that the
    # non-projected link and invlink Jacobians multiply to the identity.
    function test_link_and_invlink()
        dist = Dirichlet(4, 4)
        x = rand(dist)
        y = link(dist, x)
        for proj in (true, false)
            @test @aeq jacobian(z -> link(dist, z, proj), x) Bijectors.simplex_link_jacobian(x, proj)
        end
        for proj in (true, false)
            @test @aeq jacobian(z -> invlink(dist, z, proj), y) Bijectors.simplex_invlink_jacobian(y, proj)
        end
        @test @aeq Bijectors.simplex_link_jacobian(x, false) * Bijectors.simplex_invlink_jacobian(y, false) I
    end
    # Repeat with several random points.
    for _ in 1:4
        test_link_and_invlink()
    end
end