-
-
Notifications
You must be signed in to change notification settings - Fork 300
/
sampler.jl
64 lines (49 loc) · 1.75 KB
/
sampler.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#=
`MVNSampler` is used to draw samples from a multivariate normal distribution
N(mu, Sigma), where Sigma may be singular (positive semidefinite).
=#
import Base: ==
# Immutable sampler for a multivariate normal distribution.
#
# Fields:
#   mu    -- mean vector
#   Sigma -- covariance matrix, stored exactly as passed in
#   Q     -- factor with Q*Q' ≈ Sigma; `rand` returns mu + Q*z for standard
#            normal z. The two-argument outer constructor builds it from a
#            pivoted Cholesky factorization of Sigma.
#
# TQ is restricted to LinAlg.BlasReal (Float32/Float64), the element types
# of the LAPACK-backed factorization that produces Q.
# NOTE: there is no inner constructor, so the default three-argument
# constructor MVNSampler(mu, Sigma, Q) performs no validation; prefer the
# two-argument form MVNSampler(mu, Sigma).
immutable MVNSampler{TM<:Real,TS<:Real,TQ<:LinAlg.BlasReal}
    mu::Vector{TM}
    Sigma::Matrix{TS}
    Q::Matrix{TQ}
end
# Outer constructor: validate (mu, Sigma) and build a factor Q with
# Q*Q' ≈ Sigma from a pivoted Cholesky factorization of Sigma.
#
# Arguments
#   mu    -- mean vector of length n
#   Sigma -- n x n symmetric covariance matrix; may be rank deficient
#
# Throws ArgumentError when Sigma is not n x n, is not symmetric, or fails
# the positive-semidefiniteness checks below.
function MVNSampler{TM<:Real,TS<:Real}(mu::Vector{TM}, Sigma::Matrix{TS})
    # Tolerance for the trailing diagonal entries left by the factorization.
    diag_atol, diag_rtol = 1e-8, 1e-8
    # Tolerance for the final Q*Q' ≈ Sigma reconstruction check.
    recon_atol, recon_rtol = 1e-8, 1e-14

    dim = length(mu)
    size(Sigma) == (dim, dim) || throw(ArgumentError(
        "Sigma must be 2 dimensional and square matrix of same length to mu"
    ))
    issymmetric(Sigma) || throw(ArgumentError("Sigma must be symmetric"))

    # Pivoted Cholesky (Val{true}): `rank` counts accepted pivots and `piv`
    # is the pivoting permutation; the factor lives in the lower triangle.
    F = cholfact(Symmetric(Sigma, :L), Val{true})
    L = F.factors
    rk = F.rank
    unpiv = invperm(F.piv)  # undoes the symmetric row/column pivoting

    # Full rank: the lower triangle is a complete factor. Un-permuting its
    # rows and columns yields Q with Q*Q' ≈ Sigma.
    rk == dim && return MVNSampler(mu, Sigma, tril!(L)[unpiv, unpiv])

    # Rank deficient: reject when a trailing diagonal entry falls clearly
    # below zero, measured relative to the first pivot L[1, 1].
    psd_msg = "Sigma must be positive semidefinite"
    lower_bound = -diag_atol - diag_rtol * L[1, 1]
    all(k -> L[k, k] >= lower_bound, rk+1:dim) ||
        throw(ArgumentError(psd_msg))

    # Keep only the first rk columns of the lower factor; zero the rest.
    tril!(view(L, :, 1:rk))
    L[:, rk+1:end] = 0
    Q = L[unpiv, unpiv]

    # Final guard: the truncated factor must reproduce Sigma.
    isapprox(Q * Q', Sigma; rtol=recon_rtol, atol=recon_atol) ||
        throw(ArgumentError(psd_msg))
    return MVNSampler(mu, Sigma, Q)
end
# methods with the optional rng argument first
# Draw a single sample using `rng`: a vector of length(d.mu) equal to
# mu + Q*z, with z a standard normal vector.
function Base.rand(rng::AbstractRNG, d::MVNSampler)
    z = randn(rng, length(d.mu))
    return d.mu + d.Q * z
end

# Draw `n` samples using `rng`: a length(d.mu) x n matrix whose columns are
# independent draws; mu is broadcast across the columns of Q*Z.
function Base.rand(rng::AbstractRNG, d::MVNSampler, n::Integer)
    Z = randn(rng, (length(d.mu), n))
    return d.mu .+ d.Q * Z
end
# methods to draw from `MVNSampler`
# Single draw using the global RNG; forwards to the explicit-RNG method.
function Base.rand(d::MVNSampler)
    rng = Base.GLOBAL_RNG
    return rand(rng, d)
end

# `n` draws using the global RNG; forwards to the explicit-RNG method.
function Base.rand(d::MVNSampler, n::Integer)
    rng = Base.GLOBAL_RNG
    return rand(rng, d, n)
end
# Two samplers are equal when their mean, covariance and factor all compare
# equal with `==`.
==(f1::MVNSampler, f2::MVNSampler) =
    (f1.mu == f2.mu) && (f1.Sigma == f2.Sigma) && (f1.Q == f2.Q)

# `hash` must agree with `==`, because the generic `isequal` falls back to
# `==` and Dict/Set require isequal(x, y) => hash(x) == hash(y). The fallback
# hash is based on `object_id`, which for an immutable with Array fields
# depends on the arrays' identity rather than their contents. Without this
# method, equal samplers would land in different hash buckets.
# NOTE(review): as with any `==`-based equality, arrays differing only in
# the sign of a zero compare `==` but may hash differently.
Base.hash(f::MVNSampler, h::UInt) =
    hash(f.Q, hash(f.Sigma, hash(f.mu, hash(:MVNSampler, h))))