Skip to content

Commit

Permalink
Merge a2c2439 into a13a865
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Nov 5, 2019
2 parents a13a865 + a2c2439 commit 4942262
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand Down
8 changes: 8 additions & 0 deletions src/multivariate.jl
Expand Up @@ -27,6 +27,8 @@ struct TuringDiagNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: ContinuousMu
σ::Tσ
end

Distributions.params(d::TuringDiagNormal) = (d.m, d.σ)
Distributions.length(d::TuringDiagNormal) = length(d.m)
Distributions.dim(d::TuringDiagNormal) = length(d.m)
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagNormal)
return d.m .+ d.σ .* randn(rng, dim(d))
Expand Down Expand Up @@ -55,6 +57,12 @@ function _logpdf(d::MvNormal, x::Union{Tracker.TrackedVector, Tracker.TrackedMat
_logpdf(TuringMvNormal(d.μ, getchol(d.Σ)), x)
end

import StatsBase: entropy
function entropy(d::TuringDiagNormal)
T = eltype(d.σ)
return (length(d) * (T(log2π) + one(T)) / 2 + sum(log.(d.σ)))
end

# zero mean, dense covariance
MvNormal(A::TrackedMatrix) = MvNormal(zeros(size(A, 1)), A)

Expand Down

0 comments on commit 4942262

Please sign in to comment.