Skip to content

Commit

Permalink
Merge 0cf7b85 into c93d8e7
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Feb 15, 2020
2 parents c93d8e7 + 0cf7b85 commit 1b6460e
Show file tree
Hide file tree
Showing 10 changed files with 489 additions and 81 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ version = "0.3.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -18,10 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7"
DiffRules = "0.1, 1.0"
Distributions = "0.22"
FillArrays = "0.8"
FiniteDifferences = "0.9"
ForwardDiff = "0.10.6"
PDMats = "0.9"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.7"
Expand Down
14 changes: 11 additions & 3 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ using PDMats,
StatsFuns

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, data
using ZygoteRules: ZygoteRules, pullback
TrackedVecOrMat, track, @grad, data
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop

import StatsFuns: logsumexp,
binomlogpdf,
Expand All @@ -35,11 +38,16 @@ export TuringScalMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart
TuringInverseWishart,
ArrayDist,
FillDist

include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("matrixvariate.jl")
include("flatten.jl")
include("array_dist.jl")
include("multi.jl")

end
90 changes: 90 additions & 0 deletions src/array_dist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Univariate

const VectorOfUnivariate{
S <: ValueSupport,
Tdist <: UnivariateDistribution{S},
Tdists <: AbstractVector{Tdist},
} = Distributions.Product{S, Tdist, Tdists}

function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T}
if T <: TrackedReal
init_m = vcat(dists[1].μ)
means = mapreduce(vcat, drop(dists, 1); init = init_m) do d
d.μ
end
init_v = vcat(dists[1].σ^2)
vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d
d.σ^2
end
else
means = [d.μ for d in dists]
vars = [d.σ^2 for d in dists]
end

return MvNormal(means, vars)
end
function ArrayDist(dists::AbstractVector{<:UnivariateDistribution})
return Distributions.Product(dists)
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return sum(logpdf.(dist.v, x))
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# Any other more efficient implementation breaks Zygote
return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)]
end
function Distributions.logpdf(
dist::VectorOfUnivariate,
x::AbstractVector{<:AbstractMatrix{<:Real}},
)
return logpdf.(Ref(dist), x)
end

struct MatrixOfUnivariate{
S <: ValueSupport,
Tdist <: UnivariateDistribution{S},
Tdists <: AbstractMatrix{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# Broadcasting here breaks Tracker for some reason
return sum(zip(dist.dists, x)) do (dist, x)
logpdf(dist, x)
end
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
end

# Multivariate

struct VectorOfMultivariate{
S <: ValueSupport,
Tdist <: MultivariateDistribution{S},
Tdists <: AbstractVector{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
function ArrayDist(dists::AbstractVector{<:MultivariateDistribution})
return VectorOfMultivariate(dists)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist))
end
function Distributions.logpdf(
dist::VectorOfMultivariate,
x::AbstractVector{<:AbstractVector{<:Real}},
)
return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist))
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
end
56 changes: 47 additions & 9 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
## Generic ##

Base.one(::Irrational) = 1

function Base.fill(
value::TrackedReal,
dims::Vararg{Union{Integer, AbstractUnitRange}},
)
return track(fill, value, dims...)
end
Tracker.@grad function Base.fill(value::Real, dims...)
@grad function Base.fill(value::Real, dims...)
return fill(data(value), dims...), function(Δ)
size(Δ) dims && error("Dimension mismatch")
return (sum(Δ), map(_->nothing, dims)...)
Expand All @@ -16,15 +18,15 @@ end
## StatsFuns ##

logsumexp(x::TrackedArray) = track(logsumexp, x)
Tracker.@grad function logsumexp(x::TrackedArray)
@grad function logsumexp(x::TrackedArray)
lse = logsumexp(data(x))
return lse, Δ ->.* exp.(x .- lse),)
end

## Linear algebra ##

LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
end

Expand All @@ -39,27 +41,27 @@ function turing_chol(A::AbstractMatrix, check)
(chol.factors, chol.info)
end
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
@grad function turing_chol(A::AbstractMatrix, check)
C, back = pullback(unsafe_cholesky, data(A), data(check))
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end

unsafe_cholesky(x, check) = cholesky(x, check=check)
ZygoteRules.@adjoint function unsafe_cholesky::Real, check)
@adjoint function unsafe_cholesky::Real, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || return (zero(Σ), nothing)
.factors[1, 1] / (2 * C.U[1, 1]), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky::Diagonal, check)
@adjoint function unsafe_cholesky::Diagonal, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || (Diagonal(zero(diag.factors))), nothing)
(Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || return (zero.factors), nothing)
Expand All @@ -78,7 +80,7 @@ end
# Specialised logdet for cholesky to target the triangle directly.
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
@grad function logdet_chol_tri(U::AbstractMatrix)
U_data = data(U)
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
end
Expand All @@ -88,6 +90,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
end

# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.

zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
return track(zygote_ldiv, A, B)
Expand All @@ -96,11 +99,46 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
return track(zygote_ldiv, A, B)
end
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
Tracker.@grad function zygote_ldiv(A, B)
@grad function zygote_ldiv(A, B)
Y, back = pullback(\, data(A), data(B))
return Y, Δ->back(data(Δ))
end

function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
return (a.U \ (a.U' \ b))
end

# SpecialFunctions

function SpecialFunctions.logabsgamma(x::TrackedReal)
v = loggamma(x)
return v, sign(data(v))
end

# Some Tracker fixes

for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
if :TrackedReal in c
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end
end
@grad function vcat(x::Real)
vcat(data(x)), (Δ) -> (Δ[1],)
end
@grad function vcat(x1::Real, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
end
@grad function vcat(x1::AbstractVector, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end

# Zygote fill has issues with non-numbers

@adjoint function fill(x::T, dims...) where {T}
function zfill(x, dims...,)
return reshape([x for i in 1:prod(dims)], dims)
end
pullback(zfill, x, dims...)
end
76 changes: 76 additions & 0 deletions src/flatten.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
macro register(dist)
return quote
DistributionsAD.eval(getexpr($(esc(dist))))
DistributionsAD.toflatten(::$(esc(dist))) = true
end
end
function getexpr(Tdist)
x = gensym()
fnames = fieldnames(Tdist)
flattened_args = Expr(:tuple, [:(dist.$f) for f in fnames]...)
func = Expr(:->,
Expr(:tuple, fnames..., x),
Expr(:block,
Expr(:call, :logpdf,
Expr(:call, :($Tdist), fnames...),
x,
)
)
)
return :(flatten(dist::$Tdist) = ($func, $flattened_args))
end
const flattened_dists = [ Bernoulli,
BetaBinomial,
Binomial,
Geometric,
NegativeBinomial,
Poisson,
Skellam,
PoissonBinomial,
Arcsine,
Beta,
BetaPrime,
Biweight,
Cauchy,
Chernoff,
Chi,
Chisq,
Cosine,
Epanechnikov,
Erlang,
Exponential,
FDist,
Frechet,
Gamma,
GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
InverseGamma,
InverseGaussian,
Kolmogorov,
Laplace,
Levy,
LocationScale,
Logistic,
LogitNormal,
LogNormal,
Normal,
NormalCanon,
NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Rayleigh,
SymTriangularDist,
TDist,
TriangularDist,
Triweight,
Categorical,
Truncated,
]
for T in flattened_dists
@eval toflatten(::$T) = true
end
toflatten(::Distribution) = false
for T in flattened_dists
eval(getexpr(T))
end
4 changes: 2 additions & 2 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ end

## Adjoints

ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
return pullback(TuringWishart, df, S)
end
ZygoteRules.@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
return pullback(TuringInverseWishart, df, S)
end

Expand Down

0 comments on commit 1b6460e

Please sign in to comment.