Skip to content

Commit

Permalink
Merge e1850d1 into 50a712a
Browse files Browse the repository at this point in the history
  • Loading branch information
emerali committed Oct 9, 2020
2 parents 50a712a + e1850d1 commit a26bb17
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 49 deletions.
80 changes: 31 additions & 49 deletions src/samplers/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ struct VonMisesFisherSampler
b::Float64
x0::Float64
c::Float64
Q::Matrix{Float64}
v::Vector{Float64}
end

function VonMisesFisherSampler::Vector{Float64}, κ::Float64)
p = length(μ)
b = _vmf_bval(p, κ)
x0 = (1.0 - b) / (1.0 + b)
c = κ * x0 + (p - 1) * log1p(-abs2(x0))
Q = _vmf_rotmat(μ)
VonMisesFisherSampler(p, κ, b, x0, c, Q)
v = _vmf_householder_vec(μ)
VonMisesFisherSampler(p, κ, b, x0, c, v)
end

function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler,
Expand All @@ -36,7 +36,9 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler,
end

# rotate
mul!(x, spl.Q, t)
copyto!(x, t)
scale = 2.0 * (spl.v' * t)
@. x -= (scale * spl.v)
return x
end

Expand All @@ -56,12 +58,13 @@ end

_vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1)))

function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
# generate the W value -- the key step in simulating vMF
#
# following movMF's document
#
function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ)
ξ = rand(rng)
w = 1.0 + (log+ (1.0 - ξ)*exp(-2κ))/κ)
return w::Float64
end

function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ)
r = (p - 1) / 2.0
betad = Beta(r, r)
z = rand(rng, betad)
Expand All @@ -73,50 +76,29 @@ function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
return w::Float64
end

# generate the W value -- the key step in simulating vMF
#
# following movMF's document for the p != 3 case
# and Wenzel Jakob's document for the p == 3 case
_vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) = (p == 3) ? _vmf_genw3(rng, p, b, x0, c, κ) : _vmf_genwp(rng, p, b, x0, c, κ)


_vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) =
_vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ)

function _vmf_rotmat(u::Vector{Float64})
# construct a rotation matrix Q
# s.t. Q * [1,0,...,0]^T --> u
#
# Strategy: construct a full-rank matrix
# with first column being u, and then
# perform QR factorization
#

p = length(u)
A = zeros(p, p)
copyto!(view(A,:,1), u)

# let k the be index of entry with max abs
k = 1
a = abs(u[1])
for i = 2:p
@inbounds ai = abs(u[i])
if ai > a
k = i
a = ai
end
end
function _vmf_householder_vec::Vector{Float64})
# assuming μ is a unit-vector (which it should be)
# can compute v in a single pass over μ

# other columns of A will be filled with
# indicator vectors, except the one
# that activates the k-th entry
i = 1
for j = 2:p
if i == k
i += 1
end
A[i, j] = 1.0
end
p = length(μ)
v = zeros(p)
v[1] = μ[1] - 1.0
s = sqrt(-2*v[1])
v[1] /= s

# perform QR factorization
Q = Matrix(qr!(A).Q)
if dot(view(Q,:,1), u) < 0.0 # the first column was negated
for i = 1:p
@inbounds Q[i,1] = -Q[i,1]
end
@inbounds for i in 2:p
v[i] = μ[i] / s
end
return Q

return v
end
37 changes: 37 additions & 0 deletions test/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,35 @@ function gen_vmf_tdata(n::Int, p::Int,
return X
end


function test_genw3::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
p = 3

if ismissing(rng)
μ = randn(p)
else
μ = randn(rng, p)
end
μ = μ ./ norm(μ)

s = Distributions.VonMisesFisherSampler(μ, float(κ))

genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]
genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]

@test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01)
@test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ)

# test mean and stdev against analytical formulas
coth_κ = coth(κ)
mean_w = coth_κ - 1/κ
var_w = 1 - coth_κ^2 + 1/κ^2

@test isapprox(mean(genw3_res), mean_w, atol=0.01)
@test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ)
end


function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int,
rng::Union{AbstractRNG, Missing} = missing)
if ismissing(rng)
Expand Down Expand Up @@ -65,6 +94,7 @@ function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int,
x = rand(rng, d)
end
@test norm(x) 1.0
@test insupport(d, x)

if ismissing(rng)
X = rand(d, n)
Expand All @@ -73,6 +103,7 @@ function test_vonmisesfisher(p::Int, κ::Real, n::Int, ns::Int,
end
for i = 1:n
@test norm(X[:,i]) 1.0
@test insupport(d, X[:,i])
end

# MLE
Expand Down Expand Up @@ -119,4 +150,10 @@ ns = 10^6
(2, 2)]
test_vonmisesfisher(p, κ, n, ns, rng)
end

if !ismissing(rng)
@testset "Testing genw with $key at (3, )" for κ in [0.1, 0.5, 1.0, 2.0, 5.0]
test_genw3(κ, ns, rng)
end
end
end

0 comments on commit a26bb17

Please sign in to comment.