Skip to content

Commit

Permalink
Merge pull request #14 from TuringLang/tor/turing-diag-normal-extras
Browse files Browse the repository at this point in the history
added length, params, and entropy for TuringDiagNormal
  • Loading branch information
mohamed82008 authored Feb 10, 2020
2 parents a28e134 + b8e274c commit 80eca28
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
13 changes: 9 additions & 4 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ struct TuringDiagMvNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: Continuous
σ::Tσ
end

Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ)
Distributions.dim(d::TuringDiagMvNormal) = length(d.m)
Base.length(d::TuringDiagMvNormal) = length(d.m)
Base.size(d::TuringDiagMvNormal) = (length(d), length(d))
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal)
return d.m .+ d.σ .* randn(rng, length(d))
end
Base.size(d::TuringDiagMvNormal) = (length(d), )
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int)
return d.m .+ d.σ .* randn(rng, length(d), n)
end
Expand Down Expand Up @@ -79,6 +78,12 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
end

import StatsBase: entropy
function entropy(d::TuringDiagMvNormal)
T = eltype(d.σ)
return (length(d) * (T(log2π) + one(T)) / 2 + sum(log.(d.σ)))
end

# zero mean, dense covariance
MvNormal(A::TrackedMatrix) = TuringMvNormal(A)

Expand Down
10 changes: 10 additions & 0 deletions test/others.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using StatsBase: entropy

@testset "Others" begin
@test fill(param(1.0), 3) isa TrackedArray
x = rand(3)
Expand All @@ -11,3 +13,11 @@
B = copy(A)
@test DistributionsAD.zygote_ldiv(A, B) == A \ B
end

@testset "Extras from StatsBase.jl" begin
sigmas = exp.(randn(10))
d1 = TuringDiagMvNormal(zeros(10), sigmas)
d2 = MvNormal(zeros(10), sigmas)

@test entropy(d1) == entropy(d2)
end

0 comments on commit 80eca28

Please sign in to comment.