Skip to content

Commit

Permalink
Merge 91536ac into b296d39
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Feb 24, 2020
2 parents b296d39 + 91536ac commit a30f222
Show file tree
Hide file tree
Showing 3 changed files with 277 additions and 2 deletions.
95 changes: 95 additions & 0 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,91 @@
## Dirichlet ##

struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
alpha::TV
alpha0::T
lmnB::T
end
function check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end
Zygote.@nograd DistributionsAD.check

function TuringDirichlet(alpha::AbstractVector)
check(alpha)
alpha0 = sum(alpha)
lmnB = sum(loggamma, alpha) - loggamma(alpha0)
T = promote_type(typeof(alpha0), typeof(lmnB))
TV = typeof(alpha)
TuringDirichlet{T, TV}(alpha, alpha0, lmnB)
end

function TuringDirichlet(d::Integer, alpha::Real)
alpha0 = alpha * d
_alpha = fill(alpha, d)
lmnB = loggamma(alpha) * d - loggamma(alpha0)
T = promote_type(typeof(alpha0), typeof(lmnB))
TV = typeof(_alpha)
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
end
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
Tf = float(T)
TuringDirichlet(convert(AbstractVector{Tf}, alpha))
end
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha))

Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

function Distributions.logpdf(d::TuringDirichlet, x::AbstractVector)
simplex_logpdf(d.alpha, d.lmnB, x)
end
function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix)
simplex_logpdf(d.alpha, d.lmnB, x)
end
function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T}
TV = typeof(d.alpha)
logpdf(TuringDirichlet{T, TV}(d.alpha, d.alpha0, d.lmnB), x)
end

ZygoteRules.@adjoint function Distributions.Dirichlet(alpha)
return pullback(TuringDirichlet, alpha)
end
ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
return pullback(TuringDirichlet, d, alpha)
end

function simplex_logpdf(alpha, lmnB, x::AbstractVector)
sum((alpha .- 1) .* log.(x)) - lmnB
end
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB)
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
sum((alpha .- 1) .* log.(c)) - lmnB
end
end

Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
.* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1))
end
end
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
(log.(data(x)) * Δ, -sum(Δ), repeat(data(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ))
end
end

ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
simplex_logpdf(alpha, lmnB, x), Δ ->.* log.(x), -Δ, Δ .* (alpha .- 1))
end

ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
simplex_logpdf(alpha, lmnB, x), Δ -> begin
(log.(x) * Δ, -sum(Δ), repeat(alpha .- 1, 1, size(x, 2)) * Diagonal(Δ))
end
end

## MvNormal ##

"""
Expand Down Expand Up @@ -68,13 +156,20 @@ end
function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix)
return -((size(x, 1) * log(2π) + 2 * sum(log.(d.σ))) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
end

function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
return -(length(x) * log(2π) + logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2
end
function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
end

for T in (:TrackedVector, :TrackedMatrix)
@eval function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
end
end

import StatsBase: entropy
function entropy(d::TuringDiagMvNormal)
T = eltype(d.σ)
Expand Down
182 changes: 181 additions & 1 deletion src/univariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,184 @@ function _dft_zygote(x::Vector{T}) where T
end
return copy(y)
end
=#
=#

## Categorical ##

struct TuringDiscreteNonParametric{T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} <: DiscreteUnivariateDistribution
support::Ts
p::Ps

function TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args=true) where {
T <: Real,
P <: Real,
Ts <: AbstractVector{T},
Ps <: AbstractVector{P},
}
check_args || return new{T, P, Ts, Ps}(vs, ps)
Distributions.@check_args(TuringDiscreteNonParametric, length(vs) == length(ps))
Distributions.@check_args(TuringDiscreteNonParametric, isprobvec(ps))
Distributions.@check_args(TuringDiscreteNonParametric, allunique(vs))
sort_order = sortperm(vs)
vs = vs[sort_order]
ps = ps[sort_order]
new{T, P, Ts, Ps}(vs, ps)
end
end
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
T <: Real,
P <: Real,
Ts <: AbstractVector{T},
Ps <: AbstractVector{P},
}
return TuringDiscreteNonParametric{T, P, Ts, Ps}(vs, ps; check_args = check_args)
end
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
T <: Real,
P <: Real,
Ts <: AbstractVector{T},
Ps <: SubArray,
}
_ps = collect(ps)
_Ps = typeof(ps)
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
end
function TuringDiscreteNonParametric(vs::Ts, ps::Ps; check_args=true) where {
T <: Real,
P <: Real,
Ts <: AbstractVector{T},
Ps <: TrackedVector{P, <:SubArray},
}
_ps = ps[:]
_Ps = typeof(_ps)
return TuringDiscreteNonParametric{T, P, Ts, _Ps}(vs, _ps, check_args = check_args)
end

Base.eltype(::Type{<:TuringDiscreteNonParametric{T}}) where T = T

# Accessors
Distributions.support(d::TuringDiscreteNonParametric) = d.support

Distributions.probs(d::TuringDiscreteNonParametric) = d.p

Base.isapprox(c1::D, c2::D) where D <: TuringDiscreteNonParametric =
(support(c1) support(c2) || all(support(c1) .≈ support(c2))) &&
(probs(c1) probs(c2) || all(probs(c1) .≈ probs(c2)))

function Distributions.rand(rng::AbstractRNG, d::TuringDiscreteNonParametric{T,P}) where {T,P}
x = support(d)
p = probs(d)
n = length(p)
draw = rand(rng, P)
cp = zero(P)
i = 0
while cp < draw && i < n
cp += p[i +=1]
end
x[max(i,1)]
end

Distributions.rand(d::TuringDiscreteNonParametric) = rand(GLOBAL_RNG, d)

Distributions.sampler(d::TuringDiscreteNonParametric) =
DiscreteNonParametricSampler(support(d), probs(d))

Distributions.get_evalsamples(d::TuringDiscreteNonParametric, ::Float64) = support(d)

Distributions.pdf(d::TuringDiscreteNonParametric) = copy(probs(d))

# Helper functions for pdf and cdf required to fix ambiguous method
# error involving [pc]df(::DisceteUnivariateDistribution, ::Int)
function _pdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
idx_range = searchsorted(support(d), x)
if length(idx_range) > 0
return probs(d)[first(idx_range)]
else
return zero(P)
end
end
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Int) where T = _pdf(d, convert(T, x))
Distributions.pdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _pdf(d, convert(T, x))

function _cdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
x > maximum(d) && return 1.0
s = zero(P)
ps = probs(d)
stop_idx = searchsortedlast(support(d), x)
for i in 1:stop_idx
s += ps[i]
end
return s
end
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _cdf(d, convert(T, x))
Distributions.cdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _cdf(d, convert(T, x))

function _ccdf(d::TuringDiscreteNonParametric{T,P}, x::T) where {T,P}
x < minimum(d) && return 1.0
s = zero(P)
ps = probs(d)
stop_idx = searchsortedlast(support(d), x)
for i in (stop_idx+1):length(ps)
s += ps[i]
end
return s
end
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Integer) where T = _ccdf(d, convert(T, x))
Distributions.ccdf(d::TuringDiscreteNonParametric{T}, x::Real) where T = _ccdf(d, convert(T, x))

function Distributions.quantile(d::TuringDiscreteNonParametric, q::Real)
0 <= q <= 1 || throw(DomainError())
x = support(d)
p = probs(d)
k = length(x)
i = 1
cp = p[1]
while cp < q && i < k #Note: is i < k necessary?
i += 1
@inbounds cp += p[i]
end
x[i]
end

Base.minimum(d::TuringDiscreteNonParametric) = first(support(d))
Base.maximum(d::TuringDiscreteNonParametric) = last(support(d))
Distributions.insupport(d::TuringDiscreteNonParametric, x::Real) =
length(searchsorted(support(d), x)) > 0

Distributions.mean(d::TuringDiscreteNonParametric) = dot(probs(d), support(d))

function Distributions.var(d::TuringDiscreteNonParametric{T}) where T
m = mean(d)
x = support(d)
p = probs(d)
k = length(x)
σ² = zero(T)
for i in 1:k
@inbounds σ² += abs2(x[i] - m) * p[i]
end
σ²
end

Distributions.mode(d::TuringDiscreteNonParametric) = support(d)[argmax(probs(d))]
function Distributions.modes(d::TuringDiscreteNonParametric{T,P}) where {T,P}
x = support(d)
p = probs(d)
k = length(x)
mds = T[]
max_p = zero(P)
@inbounds for i in 1:k
pi = p[i]
xi = x[i]
if pi > max_p
max_p = pi
mds = [xi]
elseif pi == max_p
push!(mds, xi)
end
end
mds
end

function Distributions.Categorical(p::TrackedVector; check_args = true)
return TuringDiscreteNonParametric(1:length(p), p, check_args = check_args)
end
2 changes: 1 addition & 1 deletion test/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ separator()
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
DistSpec(:Dirichlet, (alpha,), dir_val),
]

broken_mult_cont_dists = [
Expand All @@ -231,7 +232,6 @@ separator()
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat),
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat),
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat),
DistSpec(:Dirichlet, (alpha,), dir_val),
# Test failure
DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat),
DistSpec(:MvNormal, (cov_mat,), norm_val_mat),
Expand Down

0 comments on commit a30f222

Please sign in to comment.