Skip to content

Commit

Permalink
fix and test most Matrix MvNormal cases
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Jan 16, 2020
1 parent 69ba53e commit 4dca48c
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 34 deletions.
29 changes: 18 additions & 11 deletions src/multivariate.jl
Expand Up @@ -14,10 +14,10 @@ function TuringDenseMvNormal(m::AbstractVector, A::AbstractMatrix)
end
Base.length(d::TuringDenseMvNormal) = length(d.m)
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal)
return d.m .+ d.C.U' * randn(rng, dim(d))
return d.m .+ d.C.U' * randn(rng, length(d))
end
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int)
return d.m .+ d.C.U' * randn(rng, dim(d), n)
return d.m .+ d.C.U' * randn(rng, length(d), n)
end

"""
Expand Down Expand Up @@ -60,23 +60,23 @@ for T in (:AbstractVector, :AbstractMatrix)
end

function _logpdf(d::TuringScalMvNormal, x::AbstractVector)
return -(length(x) * log(2π * abs2(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2
return -(length(x) * log(2π * abs2(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))) / 2
end
function _logpdf(d::TuringScalMvNormal, x::AbstractMatrix)
return -(size(x, 2) * log(2π) .+ 2 * sum(log(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2
return -(size(x, 1) * log(2π * abs2(d.σ)) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
end

function _logpdf(d::TuringDiagMvNormal, x::AbstractVector)
return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2
return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))) / 2
end
function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix)
return -(size(x, 2) * log(2π) .+ 2 * sum(log.(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2
return -((size(x, 1) * log(2π) + 2 * sum(log.(d.σ))) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
end
function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
return -(length(x) * log(2π) + logdet(d.C) + sum(abs2, zygote_ldiv(d.C.U', x .- d.m))) / 2
return -(length(x) * log(2π) + logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2
end
function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
return -(size(x, 2) * log(2π) .+ logdet(d.C) .+ sum(abs2, zygote_ldiv(d.C.U', x .- d.m), dims=1)') ./ 2
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
end

# zero mean, dense covariance
Expand Down Expand Up @@ -137,7 +137,7 @@ struct TuringMvLogNormal{TD} <: AbstractMvLogNormal
end
MvLogNormal(d::TuringDenseMvNormal) = TuringMvLogNormal(d)
MvLogNormal(d::TuringDiagMvNormal) = TuringMvLogNormal(d)
Distributions.dim(d::TuringMvLogNormal) = length(d.normal)
Distributions.length(d::TuringMvLogNormal) = length(d.normal)
function Distributions.rand(rng::Random.AbstractRNG, d::TuringMvLogNormal)
return exp!(rand(rng, d.normal))
end
Expand All @@ -147,8 +147,15 @@ end
for T in (:(Tracker.TrackedVector), :(Tracker.TrackedMatrix))
@eval Distributions.logpdf(d::TuringMvLogNormal, x::$T) = _logpdf(d, x)
end
function _logpdf(d::TuringMvLogNormal, x::AbstractVecOrMat{T}) where {T<:Real}
return insupport(d, x) ? (_logpdf(d.normal, log.(x)) - sum(log.(x))) : -Inf
function _logpdf(d::TuringMvLogNormal, x::AbstractVector{T}) where {T<:Real}
return insupport(d, x) ? (_logpdf(d.normal, log.(x)) - sum(log.(x))) : -T(Inf)
end
function _logpdf(d::TuringMvLogNormal, x::AbstractMatrix{T}) where {T<:Real}
if all(insupport(d, x))
return _logpdf(d.normal, log.(x)) - vec(sum(log.(x), dims=1))
else
return fill(-T(Inf), size(x, 2))
end
end

# zero mean, dense covariance
Expand Down
66 changes: 45 additions & 21 deletions test/runtests.jl
Expand Up @@ -15,7 +15,8 @@ mean = zeros(dim)
cov_mat = Matrix{Float64}(I, dim, dim)
cov_vec = ones(dim)
cov_num = 1.0
norm_val = ones(dim)
norm_val_vec = ones(dim)
norm_val_mat = ones(dim, 2)
alpha = ones(4)
dir_val = fill(0.25, 4)
beta_mat = rand(MatrixBeta(dim, dim, dim))
Expand Down Expand Up @@ -183,32 +184,55 @@ separator()
@testset "Multivariate continuous distributions" begin
test_head("Testing: Multivariate continuous distributions")
mult_cont_dists = [
DistSpec(:MvNormal, (mean, cov_mat), norm_val),
DistSpec(:MvNormal, (mean, cov_vec), norm_val),
DistSpec(:MvNormal, (mean, cov_num), norm_val),
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val),
DistSpec(:MvNormal, (cov_mat,), norm_val),
DistSpec(:MvNormal, (cov_vec,), norm_val),
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val),
DistSpec(:MvLogNormal, (mean, cov_mat), norm_val),
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val),
DistSpec(:MvLogNormal, (mean, cov_num), norm_val),
DistSpec(:MvLogNormal, (cov_mat,), norm_val),
DistSpec(:MvLogNormal, (cov_vec,), norm_val),
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val),
# Vector case
DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec),
DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec),
DistSpec(:MvNormal, (mean, cov_num), norm_val_vec),
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_vec),
DistSpec(:MvNormal, (cov_mat,), norm_val_vec),
DistSpec(:MvNormal, (cov_vec,), norm_val_vec),
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_vec),
DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_vec),
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_vec),
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_vec),
DistSpec(:MvLogNormal, (cov_mat,), norm_val_vec),
DistSpec(:MvLogNormal, (cov_vec,), norm_val_vec),
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec),
# Matrix case
DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat),
DistSpec(:MvNormal, (mean, cov_num), norm_val_mat),
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_mat),
DistSpec(:MvNormal, (cov_vec,), norm_val_mat),
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_mat),
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_mat),
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_mat),
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
]

broken_mult_cont_dists = [
# Dispatch error
DistSpec(:MvNormalCanon, (mean, cov_mat), norm_val),
DistSpec(:MvNormalCanon, (mean, cov_vec), norm_val),
DistSpec(:MvNormalCanon, (mean, cov_num), norm_val),
DistSpec(:MvNormalCanon, (cov_mat,), norm_val),
DistSpec(:MvNormalCanon, (cov_vec,), norm_val),
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val),
DistSpec(:MvNormalCanon, (mean, cov_mat), norm_val_vec),
DistSpec(:MvNormalCanon, (mean, cov_vec), norm_val_vec),
DistSpec(:MvNormalCanon, (mean, cov_num), norm_val_vec),
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_vec),
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_vec),
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_vec),
DistSpec(:Dirichlet, (alpha,), dir_val),
DistSpec(:MvNormalCanon, (mean, cov_mat), norm_val_mat),
DistSpec(:MvNormalCanon, (mean, cov_vec), norm_val_mat),
DistSpec(:MvNormalCanon, (mean, cov_num), norm_val_mat),
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat),
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat),
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat),
DistSpec(:Dirichlet, (alpha,), dir_val),
# Test failure
DistSpec(:(() -> Product(Normal.(randn(dim), 1))), (), norm_val),
DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat),
DistSpec(:MvNormal, (cov_mat,), norm_val_mat),
DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_mat),
DistSpec(:MvLogNormal, (cov_mat,), norm_val_mat),
DistSpec(:(() -> Product(Normal.(randn(dim), 1))), (), norm_val_vec),
DistSpec(:(() -> Product(Normal.(randn(dim), 1))), (), norm_val_mat),
]

for d in mult_cont_dists
Expand Down
23 changes: 21 additions & 2 deletions test/test_utils.jl
Expand Up @@ -47,7 +47,16 @@ function get_function(dist::DistSpec, inds, val)
if val
sym = gensym()
push!(syms, sym)
expr = :(($(syms...),) -> logpdf($(dist.name)($(args...)), $(sym)))
expr = quote
($(syms...),) -> begin
temp = logpdf($(dist.name)($(args...)), $(sym))
if temp isa AbstractVector
return sum(temp)
else
return temp
end
end
end
if length(inds) == 0
f = x -> Base.invokelatest(eval(expr), unpack(x, dist.x)...)
return ADTestFunction(string(expr), f, pack(dist.x))
Expand All @@ -57,7 +66,17 @@ function get_function(dist::DistSpec, inds, val)
end
else
@assert length(inds) > 0
expr = :(($(syms...),) -> logpdf($(dist.name)($(args...)), $(dist.x)))
expr = quote
($(syms...),) -> begin
d = $(dist.name)($(args...))
temp = logpdf(d, $(dist.x))
if temp isa AbstractVector
return sum(temp)
else
return temp
end
end
end
f = x -> Base.invokelatest(eval(expr), unpack(x, dist.θ[inds]...)...)
return ADTestFunction(string(expr), f, pack(dist.θ[inds]...))
end
Expand Down

0 comments on commit 4dca48c

Please sign in to comment.