From 839c882d4f79baf6d314ad3592041a8880c562c2 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Mon, 8 Jul 2019 21:26:44 +0100 Subject: [PATCH 01/83] initial implementation of new interface --- src/Bijectors.jl | 14 +++++- src/interface.jl | 110 ++++++++++++++++++++++++++++++++++++++++++++++ test/interface.jl | 52 ++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 175 insertions(+), 2 deletions(-) create mode 100644 src/interface.jl create mode 100644 test/interface.jl diff --git a/src/Bijectors.jl b/src/Bijectors.jl index ea59e59f..3d693f01 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -6,7 +6,7 @@ using StatsFuns using LinearAlgebra using MappedArrays -export TransformDistribution, +export TransformDistribution, RealDistribution, PositiveDistribution, UnitDistribution, @@ -14,7 +14,15 @@ export TransformDistribution, PDMatDistribution, link, invlink, - logpdf_with_trans + logpdf_with_trans, + transform, + inverse, + logdetinvjac, + Bijector, + DefaultBijector, + transformed, + UnivariateTransformed, + MultivariateTransformed const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) @@ -420,4 +428,6 @@ function logpdf_with_trans( return logpdf_with_trans.(Ref(d), X, Ref(transform)) end +include("interface.jl") + end # module diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 00000000..fb8a3ed2 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,110 @@ +using Distributions, Bijectors +using ForwardDiff +using Tracker + +using Turing + +import Random: AbstractRNG +import Distributions: logpdf, rand, rand!, _rand!, _logpdf + + +abstract type Bijector end +abstract type ADBijector{AD} <: Bijector end + +Broadcast.broadcastable(b::Bijector) = Ref(b) + +"Computes the transformation." +transform(b::Bijector, x) = begin end +transform(b::Bijector) = x -> transform(b, x) + +"Computes the inverse transformation of the Bijector." +inverse(b::Bijector, y) = begin end +inverse(b::Bijector) = y -> inverse(b, y) + +# TODO: rename? a bit of a mouthful +# TODO: allow batch-computation, especially for univariate case +"Computes the absolute determinant of the Jacobian of the inverse-transformation." +logdetinvjac(b::Bijector, y) = begin end +logdetinvjac(b::ADBijector{AD}, y::T) where {AD <: Turing.Core.ForwardDiffAD, T <: Real} = log(abs(ForwardDiff.derivative(z -> inverse(b, z), y))) +logdetinvjac(b::ADBijector{AD}, y::AbstractVector) where AD <: Turing.Core.ForwardDiffAD = logabsdet(ForwardDiff.jacobian(z -> inverse(b, z), y))[1] + +# FIXME: untrack? i.e. `Tracker.data(...)` +logdetinvjac(b::ADBijector{AD}, y::T) where {AD <: Turing.Core.TrackerAD, T <: Real} = log(abs(Tracker.gradient(z -> inverse(b, z[1]), [y])[1][1])) +logdetinvjac(b::ADBijector{AD}, y::AbstractVector) where AD <: Turing.Core.TrackerAD = logabsdet(Tracker.jacobian(z -> inverse(b, z), y))[1] + +# Example bijector +struct Identity <: Bijector end +transform(::Identity, x) = x +inverse(::Identity, y) = y +logdetinvjac(::Identity, y::T) where T <: Real = zero(T) +logdetinvjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) + +# Simply uses `link` and `invlink` as transforms with AD to get jacobian +struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution + dist::D +end +DistributionBijector(dist::D) where D <: Distribution = DistributionBijector{Turing.Core.ADBackend(), D}(dist) + +transform(b::DistributionBijector, x) = link(b.dist, x) +inverse(b::DistributionBijector, y) = invlink(b.dist, y) + +# Transformed distributions +struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} + dist::D + transform::B +end + +struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} where {D <: MultivariateDistribution, B <: Bijector} + dist::D + transform::B +end + + +# implement these on a case-by-case basis, e.g. `PDMatDistribution = Union{InverseWishart, Wishart}` +transformed(d::UnivariateDistribution, b::Bijector) = UnivariateTransformed(d, b) +transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed(d, b) + +transformed(d) = transformed(d, DistributionBijector(d)) + +# can specialize further by +transformed(d::Normal) = transformed(d, Identity()) + +############################## +# Distributions.jl interface # +############################## + +# size +Base.length(td::MultivariateTransformed) = length(td.dist) + +# logp +function logpdf(td::UnivariateTransformed, y::T where T <: Real) + # FIXME: `logpdf_with_trans` give different results from the this: + # logpdf(td.dist, inverse(td.transform, y)) .+ logdetinvjac(td.transform, y) + + logpdf_with_trans(td.dist, y, true) +end +function _logpdf(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) + # FIXME: same as above + # logpdf(td.dist, inverse(td.transform, y)) .+ logdetinvjac(td.transform, y) + logpdf_with_trans(td.dist, y, true) +end + +function logpdf_with_jac(td::UnivariateTransformed, y::T where T <: Real) + # FIXME: different results from `logpdf`; see above + z = logdetinvjac(td.transform, y) + return (logpdf(td.dist, inverse(td.transform, y)) .+ z, z) +end + +function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) + # FIXME: different results from `logpdf`; see above + z = logdetinvjac(td.transform, y) + return (logpdf(td.dist, inverse(td.transform, y)) .+ z, z) +end + +# rand +rand(rng::AbstractRNG, td::UnivariateTransformed) = transform(td.transform, rand(td.dist)) +function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{T} where T <: Real) + rand!(rng, td.dist, x) + y = transform(td.transform, x) + copyto!(x, y) +end diff --git a/test/interface.jl b/test/interface.jl new file mode 100644 index 00000000..6129a85d --- /dev/null +++ b/test/interface.jl @@ -0,0 +1,52 @@ +using Test +using Bijectors +using Random + +Random.seed!(123) + +# Scalar tests +@testset "Interface" begin + # Tests with scalar-valued distributions. + uni_dists = [ + Arcsine(2, 4), + Beta(2,2), + BetaPrime(), + Biweight(), + Cauchy(), + Chi(3), + Chisq(2), + Cosine(), + Epanechnikov(), + Erlang(), + Exponential(), + FDist(1, 1), + Frechet(), + Gamma(), + InverseGamma(), + InverseGaussian(), + # Kolmogorov(), + Laplace(), + Levy(), + Logistic(), + LogNormal(1.0, 2.5), + Normal(0.1, 2.5), + Pareto(), + Rayleigh(1.0), + TDist(2), + TruncatedNormal(0, 1, -Inf, 2), + ] + + for dist in uni_dists + @testset "$dist" begin + td = transformed(dist) + + # single sample + y = rand(td) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, y, true) + + # # multi-sample + y = rand(td, 10) + @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, y, true) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 66760196..1e5e773e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,4 +2,5 @@ using Bijectors, Random Random.seed!(123456) +include("interface.jl") include("transform.jl") From 65fce15ba3e9d480b5943e2a1108ea59831fc84e Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Mon, 8 Jul 2019 21:58:38 +0100 Subject: [PATCH 02/83] fixed the testing --- src/interface.jl | 10 +++------- test/interface.jl | 6 ++++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index fb8a3ed2..65424e64 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -78,25 +78,21 @@ Base.length(td::MultivariateTransformed) = length(td.dist) # logp function logpdf(td::UnivariateTransformed, y::T where T <: Real) - # FIXME: `logpdf_with_trans` give different results from the this: # logpdf(td.dist, inverse(td.transform, y)) .+ logdetinvjac(td.transform, y) - - logpdf_with_trans(td.dist, y, true) + logpdf_with_trans(td.dist, inverse(td.transform, y), true) end function _logpdf(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) - # FIXME: same as above # logpdf(td.dist, inverse(td.transform, y)) .+ logdetinvjac(td.transform, y) - logpdf_with_trans(td.dist, y, true) + logpdf_with_trans(td.dist, inverse(td.transform, y), true) end +# TODO: implement these using analytical expressions? function logpdf_with_jac(td::UnivariateTransformed, y::T where T <: Real) - # FIXME: different results from `logpdf`; see above z = logdetinvjac(td.transform, y) return (logpdf(td.dist, inverse(td.transform, y)) .+ z, z) end function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) - # FIXME: different results from `logpdf`; see above z = logdetinvjac(td.transform, y) return (logpdf(td.dist, inverse(td.transform, y)) .+ z, z) end diff --git a/test/interface.jl b/test/interface.jl index 6129a85d..4f5f002c 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -42,11 +42,13 @@ Random.seed!(123) # single sample y = rand(td) - @test logpdf(td, y) ≈ logpdf_with_trans(dist, y, true) + x = inverse(td.transform, y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # # multi-sample y = rand(td, 10) - @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, y, true) + x = inverse(td.transform).(y) + @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) end end end From 0be22a656dc81d1ef8558b2f537b96ca6229dc5a Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Mon, 8 Jul 2019 23:04:22 +0100 Subject: [PATCH 03/83] now adheres to the style guide --- src/interface.jl | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 65424e64..03cf0537 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -25,12 +25,20 @@ inverse(b::Bijector) = y -> inverse(b, y) # TODO: allow batch-computation, especially for univariate case "Computes the absolute determinant of the Jacobian of the inverse-transformation." logdetinvjac(b::Bijector, y) = begin end -logdetinvjac(b::ADBijector{AD}, y::T) where {AD <: Turing.Core.ForwardDiffAD, T <: Real} = log(abs(ForwardDiff.derivative(z -> inverse(b, z), y))) -logdetinvjac(b::ADBijector{AD}, y::AbstractVector) where AD <: Turing.Core.ForwardDiffAD = logabsdet(ForwardDiff.jacobian(z -> inverse(b, z), y))[1] +function logdetinvjac(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) + log(abs(ForwardDiff.derivative(z -> inverse(b, z), y))) +end +function logdetinvjac(b::ADBijector{<:Turing.Core.ForwardDiffAD}, y::AbstractVector{<:Real}) + logabsdet(ForwardDiff.jacobian(z -> inverse(b, z), y))[1] +end # FIXME: untrack? i.e. `Tracker.data(...)` -logdetinvjac(b::ADBijector{AD}, y::T) where {AD <: Turing.Core.TrackerAD, T <: Real} = log(abs(Tracker.gradient(z -> inverse(b, z[1]), [y])[1][1])) -logdetinvjac(b::ADBijector{AD}, y::AbstractVector) where AD <: Turing.Core.TrackerAD = logabsdet(Tracker.jacobian(z -> inverse(b, z), y))[1] +function logdetinvjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) + log(abs(Tracker.gradient(z -> inverse(b, z[1]), [y])[1][1])) +end +function logdetinvjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector{<: Real}) + logabsdet(Tracker.jacobian(z -> inverse(b, z), y))[1] +end # Example bijector struct Identity <: Bijector end @@ -43,24 +51,28 @@ logdetinvjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution dist::D end -DistributionBijector(dist::D) where D <: Distribution = DistributionBijector{Turing.Core.ADBackend(), D}(dist) +function DistributionBijector(dist::D) where D <: Distribution + DistributionBijector{Turing.Core.ADBackend(), D}(dist) +end transform(b::DistributionBijector, x) = link(b.dist, x) inverse(b::DistributionBijector, y) = invlink(b.dist, y) # Transformed distributions -struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} +struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} + where {D <: UnivariateDistribution, B <: Bijector} dist::D transform::B end -struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} where {D <: MultivariateDistribution, B <: Bijector} +struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} + where {D <: MultivariateDistribution, B <: Bijector} dist::D transform::B end -# implement these on a case-by-case basis, e.g. `PDMatDistribution = Union{InverseWishart, Wishart}` +# Can implement these on a case-by-case basis transformed(d::UnivariateDistribution, b::Bijector) = UnivariateTransformed(d, b) transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed(d, b) @@ -99,7 +111,7 @@ end # rand rand(rng::AbstractRNG, td::UnivariateTransformed) = transform(td.transform, rand(td.dist)) -function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{T} where T <: Real) +function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{<: Real}) rand!(rng, td.dist, x) y = transform(td.transform, x) copyto!(x, y) From 16d83aa35bf1a4fdb6d2c13d7345b8354f4c66dd Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Mon, 8 Jul 2019 23:13:49 +0100 Subject: [PATCH 04/83] cant have type templates on next line... --- src/interface.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 03cf0537..963526d3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -59,14 +59,12 @@ transform(b::DistributionBijector, x) = link(b.dist, x) inverse(b::DistributionBijector, y) = invlink(b.dist, y) # Transformed distributions -struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} - where {D <: UnivariateDistribution, B <: Bijector} +struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} dist::D transform::B end -struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} - where {D <: MultivariateDistribution, B <: Bijector} +struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} where {D <: MultivariateDistribution, B <: Bijector} dist::D transform::B end From 4c90e1efcaa01490d88deb3b119a3648bc876508 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Wed, 10 Jul 2019 20:15:51 +0100 Subject: [PATCH 05/83] made more similar to the interface suggested by @xukai92 --- src/Bijectors.jl | 7 +++--- src/interface.jl | 64 ++++++++++++++++++++++++++++------------------- test/interface.jl | 5 ++-- 3 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 3d693f01..ae30ca01 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -16,10 +16,11 @@ export TransformDistribution, invlink, logpdf_with_trans, transform, - inverse, - logdetinvjac, + forward, + logabsdetjac, Bijector, - DefaultBijector, + Inversed, + DistributionBijector, transformed, UnivariateTransformed, MultivariateTransformed diff --git a/src/interface.jl b/src/interface.jl index 963526d3..64761708 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -11,41 +11,53 @@ import Distributions: logpdf, rand, rand!, _rand!, _logpdf abstract type Bijector end abstract type ADBijector{AD} <: Bijector end +struct Inversed{B <: Bijector} <: Bijector + orig::B +end + Broadcast.broadcastable(b::Bijector) = Ref(b) -"Computes the transformation." -transform(b::Bijector, x) = begin end +logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = + error("`logabsdetjacob(b::$T1, y::$T2)` is not implemented.") +forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = + error("`forward(b::$T1, y::$T2)` is not implemented.") +transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = + error("`transform(b::$T1, y::$T2)` is not implemented.") + + transform(b::Bijector) = x -> transform(b, x) +forward(ib::Inversed{<: Bijector}, y) = (transform(ib, y), logabsdetjac(ib, y)) +logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, transform(ib, y)) + +Base.inv(b::Bijector) = Inversed(b) +Base.inv(ib::Inversed{<:Bijector}) = ib.orig -"Computes the inverse transformation of the Bijector." -inverse(b::Bijector, y) = begin end -inverse(b::Bijector) = y -> inverse(b, y) # TODO: rename? a bit of a mouthful # TODO: allow batch-computation, especially for univariate case "Computes the absolute determinant of the Jacobian of the inverse-transformation." -logdetinvjac(b::Bijector, y) = begin end -function logdetinvjac(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) - log(abs(ForwardDiff.derivative(z -> inverse(b, z), y))) +function logabsdetjac(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) + log(abs(ForwardDiff.derivative(z -> transform(b, z), y))) end -function logdetinvjac(b::ADBijector{<:Turing.Core.ForwardDiffAD}, y::AbstractVector{<:Real}) - logabsdet(ForwardDiff.jacobian(z -> inverse(b, z), y))[1] +function logabsdetjac(b::ADBijector{<:Turing.Core.ForwardDiffAD}, y::AbstractVector{<:Real}) + logabsdet(ForwardDiff.jacobian(z -> transform(b, z), y))[1] end # FIXME: untrack? i.e. `Tracker.data(...)` -function logdetinvjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) - log(abs(Tracker.gradient(z -> inverse(b, z[1]), [y])[1][1])) +function logabsdetjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) + log(abs(Tracker.gradient(z -> transform(b, z[1]), [y])[1][1])) end -function logdetinvjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector{<: Real}) - logabsdet(Tracker.jacobian(z -> inverse(b, z), y))[1] +function logabsdetjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector{<: Real}) + logabsdet(Tracker.jacobian(z -> transform(b, z), y))[1] end # Example bijector struct Identity <: Bijector end transform(::Identity, x) = x -inverse(::Identity, y) = y -logdetinvjac(::Identity, y::T) where T <: Real = zero(T) -logdetinvjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) +transform(::Inversed{Identity}, y) = y +forward(::Identity, x) = (x, zero(x)) +logabsdetjac(::Identity, y::T) where T <: Real = zero(T) +logabsdetjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) # Simply uses `link` and `invlink` as transforms with AD to get jacobian struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution @@ -56,7 +68,7 @@ function DistributionBijector(dist::D) where D <: Distribution end transform(b::DistributionBijector, x) = link(b.dist, x) -inverse(b::DistributionBijector, y) = invlink(b.dist, y) +transform(ib::Inversed{<: DistributionBijector}, y) = invlink(ib.orig.dist, y) # Transformed distributions struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} @@ -88,23 +100,23 @@ Base.length(td::MultivariateTransformed) = length(td.dist) # logp function logpdf(td::UnivariateTransformed, y::T where T <: Real) - # logpdf(td.dist, inverse(td.transform, y)) .+ logdetinvjac(td.transform, y) - logpdf_with_trans(td.dist, inverse(td.transform, y), true) + # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) + logpdf_with_trans(td.dist, transform(inv(td.transform), y), true) end function _logpdf(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) - # logpdf(td.dist, inverse(td.transform, y)) .+ logdetinvjac(td.transform, y) - logpdf_with_trans(td.dist, inverse(td.transform, y), true) + # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) + logpdf_with_trans(td.dist, transform(inv(td.transform), y), true) end # TODO: implement these using analytical expressions? function logpdf_with_jac(td::UnivariateTransformed, y::T where T <: Real) - z = logdetinvjac(td.transform, y) - return (logpdf(td.dist, inverse(td.transform, y)) .+ z, z) + z = logabsdetjac(inv(td.transform), y) + return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) end function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) - z = logdetinvjac(td.transform, y) - return (logpdf(td.dist, inverse(td.transform, y)) .+ z, z) + z = logabsdetjac(td.transform, y) + return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) end # rand diff --git a/test/interface.jl b/test/interface.jl index 4f5f002c..3d15ba41 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,3 +1,4 @@ +using Revise using Test using Bijectors using Random @@ -42,12 +43,12 @@ Random.seed!(123) # single sample y = rand(td) - x = inverse(td.transform, y) + x = transform(inv(td.transform), y) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # # multi-sample y = rand(td, 10) - x = inverse(td.transform).(y) + x = transform.(inv(td.transform), y) @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) end end From c182b1ff0bfb3791161e5c460712b08e72de9064 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Wed, 10 Jul 2019 20:19:51 +0100 Subject: [PATCH 06/83] fixed typo in error message --- src/interface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 64761708..d5cd29f4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -18,7 +18,7 @@ end Broadcast.broadcastable(b::Bijector) = Ref(b) logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = - error("`logabsdetjacob(b::$T1, y::$T2)` is not implemented.") + error("`logabsdetjac(b::$T1, y::$T2)` is not implemented.") forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = error("`forward(b::$T1, y::$T2)` is not implemented.") transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = @@ -115,7 +115,7 @@ function logpdf_with_jac(td::UnivariateTransformed, y::T where T <: Real) end function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) - z = logabsdetjac(td.transform, y) + z = logabsdetjac(inv(td.transform), y) return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) end From 52b5deb95a1594cad122f7726f7d163599a51ca0 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Wed, 10 Jul 2019 20:58:22 +0100 Subject: [PATCH 07/83] added composition of Bijectors --- src/interface.jl | 46 +++++++++++++++++++++++++++++++++++++++------- test/interface.jl | 17 +++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d5cd29f4..f628b569 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -7,6 +7,8 @@ using Turing import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf +import Base: inv + abstract type Bijector end abstract type ADBijector{AD} <: Bijector end @@ -17,23 +19,34 @@ end Broadcast.broadcastable(b::Bijector) = Ref(b) +"Computes the log(abs(det(J(x)))) where J is the jacobian of the transform." logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = error("`logabsdetjac(b::$T1, y::$T2)` is not implemented.") + +"Transforms the input using the bijector." +transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = + error("`transform(b::$T1, y::$T2)` is not implemented.") + +"Computes both `transform` and `logabsdetjac` in one forward pass." forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = error("`forward(b::$T1, y::$T2)` is not implemented.") -transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = - error("`transform(b::$T1, y::$T2)` is not implemented.") transform(b::Bijector) = x -> transform(b, x) -forward(ib::Inversed{<: Bijector}, y) = (transform(ib, y), logabsdetjac(ib, y)) + +# default `forward` implementations; should in general implement efficient way +# of computing both `transform` and `logabsdetjac` together. +forward(b::Bijector, x) = (rv=transform(b, x), logabsdetjac=logabsdetjac(b, x)) +forward(ib::Inversed{<: Bijector}, y) = (rv=transform(ib, y), logabsdetjac=logabsdetjac(ib, y)) + +# defaults implementation for inverses logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, transform(ib, y)) -Base.inv(b::Bijector) = Inversed(b) -Base.inv(ib::Inversed{<:Bijector}) = ib.orig +inv(b::Bijector) = Inversed(b) +inv(ib::Inversed{<:Bijector}) = ib.orig +# AD implementations -# TODO: rename? a bit of a mouthful # TODO: allow batch-computation, especially for univariate case "Computes the absolute determinant of the Jacobian of the inverse-transformation." function logabsdetjac(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) @@ -51,11 +64,30 @@ function logabsdetjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector logabsdet(Tracker.jacobian(z -> transform(b, z), y))[1] end +# Composition + +struct Composed{B<:Bijector} <: Bijector + ts::Vector{B} +end + +compose(ts...) = Composed([ts...]) + +inv(ct::Composed{B}) where {B<:Bijector} = Composed(map(inv, reverse(ct.ts))) + +function forward(ct::Composed{<:Bijector}, x) + res = (rv=x, logabsdetjac=0) + for t in ct.ts + res′ = forward(t, res.rv) + res = (rv=res′.rv, logabsdetjac=res.logabsdetjac + res′.logabsdetjac) + end + return res +end + # Example bijector struct Identity <: Bijector end transform(::Identity, x) = x transform(::Inversed{Identity}, y) = y -forward(::Identity, x) = (x, zero(x)) +forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) logabsdetjac(::Identity, y::T) where T <: Real = zero(T) logabsdetjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) diff --git a/test/interface.jl b/test/interface.jl index 3d15ba41..ec2d2e1c 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -52,4 +52,21 @@ Random.seed!(123) @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) end end + + @testset "Composition" begin + d = Beta() + td = transformed(d) + + x = rand(d) + y = transform(td.transform, x) + + forward(td.transform, x) + forward(inv(td.transform), y) + + b = Bijectors.compose(td.transform, Bijectors.Identity()) + ib = inv(b) + + @test forward(b, x) == forward(td.transform, x) + @test forward(ib, y) == forward(inv(td.transform), y) + end end From 88c681d9fc5ca14ff74fee8880f6fa0c4e31e9a0 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 11 Jul 2019 14:19:50 +0100 Subject: [PATCH 08/83] improved composing compositions and added overloading of circ --- src/interface.jl | 60 +++++++++++++++++++++++++++++++++++++++++------ test/interface.jl | 27 ++++++++++++++++++--- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index f628b569..088bc1fa 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -7,7 +7,7 @@ using Turing import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf -import Base: inv +import Base: inv, ∘ abstract type Bijector end @@ -33,6 +33,7 @@ forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = transform(b::Bijector) = x -> transform(b, x) +(ib::Inversed{<: Bijector})(y) = transform(ib, y) # default `forward` implementations; should in general implement efficient way # of computing both `transform` and `logabsdetjac` together. @@ -64,34 +65,77 @@ function logabsdetjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector logabsdet(Tracker.jacobian(z -> transform(b, z), y))[1] end -# Composition +############### +# Composition # +############### struct Composed{B<:Bijector} <: Bijector ts::Vector{B} end -compose(ts...) = Composed([ts...]) +function compose(ts...) + res = [] + + for b ∈ ts + if b isa Composed + # "lift" the transformations + for b_ ∈ b.ts + push!(res, b_) + end + else + push!(res, b) + end + end + + Composed([res...]) +end + +# The transformation of `Composed` applies functions left-to-right +# but in mathematics we usually go from right-to-left; this reversal ensures that +# when we use the mathematical composition ∘ we get the expected behavior. +# TODO: change behavior of `transform` of `Composed`? +∘(b1::B1, b2::B2) where {B1 <: Bijector, B2 <: Bijector} = Bijectors.compose(b2, b1) inv(ct::Composed{B}) where {B<:Bijector} = Composed(map(inv, reverse(ct.ts))) -function forward(ct::Composed{<:Bijector}, x) +# TODO: can we implement this recursively, and with aggressive inlining, make this type-stable? +function transform(cb::Composed{<: Bijector}, x) + res = x + for b ∈ cb.ts + res = transform(b, res) + end + + return res +end + +(cb::Composed{<: Bijector})(x) = transform(cb, x) + +function forward(cb::Composed{<:Bijector}, x) res = (rv=x, logabsdetjac=0) - for t in ct.ts + for t in cb.ts res′ = forward(t, res.rv) res = (rv=res′.rv, logabsdetjac=res.logabsdetjac + res′.logabsdetjac) end return res end -# Example bijector +############################## +# Example bijector: Identity # +############################## + struct Identity <: Bijector end transform(::Identity, x) = x transform(::Inversed{Identity}, y) = y +(b::Identity)(x) = transform(b, x) + forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) + logabsdetjac(::Identity, y::T) where T <: Real = zero(T) logabsdetjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) -# Simply uses `link` and `invlink` as transforms with AD to get jacobian +####################################################### +# Constrained to unconstrained distribution bijectors # +####################################################### struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution dist::D end @@ -99,8 +143,10 @@ function DistributionBijector(dist::D) where D <: Distribution DistributionBijector{Turing.Core.ADBackend(), D}(dist) end +# Simply uses `link` and `invlink` as transforms with AD to get jacobian transform(b::DistributionBijector, x) = link(b.dist, x) transform(ib::Inversed{<: DistributionBijector}, y) = invlink(ib.orig.dist, y) +(b::DistributionBijector)(x) = transform(b, x) # Transformed distributions struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} diff --git a/test/interface.jl b/test/interface.jl index ec2d2e1c..49d739f3 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -60,13 +60,34 @@ Random.seed!(123) x = rand(d) y = transform(td.transform, x) - forward(td.transform, x) - forward(inv(td.transform), y) - b = Bijectors.compose(td.transform, Bijectors.Identity()) ib = inv(b) @test forward(b, x) == forward(td.transform, x) @test forward(ib, y) == forward(inv(td.transform), y) + + # inverse works fine for composition + cb = b ∘ ib + @test transform(cb, x) ≈ x + + cb2 = cb ∘ cb + @test transform(cb, x) ≈ x + + # order of composed evaluation + b1 = DistributionBijector(d) + b2 = DistributionBijector(Gamma()) + + cb = b1 ∘ b2 + @test cb(x) ≈ b1(b2(x)) + end + + @testset "Example: ADVI" begin + # Usage in ADVI + d = Beta() + b = DistributionBijector(d) # [0, 1] → ℝ + ib = inv(b) # ℝ → [0, 1] + td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] + x = rand(td) # ∈ [0, 1] + @test 0 ≤ x ≤ 1 # => true end end From b413972901a47d3cdcbb09e804598cb9fda015cc Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 11 Jul 2019 14:56:07 +0100 Subject: [PATCH 09/83] composition with its own inverse now results in Identity --- src/Bijectors.jl | 4 ++++ src/interface.jl | 10 ++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index ae30ca01..a6e5f57b 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -19,7 +19,11 @@ export TransformDistribution, forward, logabsdetjac, Bijector, + ADBijector, Inversed, + Composed, + compose, + Identity, DistributionBijector, transformed, UnivariateTransformed, diff --git a/src/interface.jl b/src/interface.jl index 088bc1fa..d67bc81b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -83,11 +83,17 @@ function compose(ts...) push!(res, b_) end else - push!(res, b) + # TODO: do we want this? + if (length(res) > 0) && (res[end] == inv(b)) + # remove if inverse + pop!(res) + else + push!(res, b) + end end end - Composed([res...]) + length(res) == 0 ? Identity() : Composed([res...]) end # The transformation of `Composed` applies functions left-to-right From aa657232fe38a50ec096acec48f37cddc49eb83b Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 11 Jul 2019 15:00:28 +0100 Subject: [PATCH 10/83] made a constant IdentityBijector --- src/interface.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d67bc81b..2189be67 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -93,7 +93,7 @@ function compose(ts...) end end - length(res) == 0 ? Identity() : Composed([res...]) + length(res) == 0 ? IdentityBijector : Composed([res...]) end # The transformation of `Composed` applies functions left-to-right @@ -139,6 +139,8 @@ forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) logabsdetjac(::Identity, y::T) where T <: Real = zero(T) logabsdetjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) +const IdentityBijector = Identity() + ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### @@ -173,7 +175,7 @@ transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed( transformed(d) = transformed(d, DistributionBijector(d)) # can specialize further by -transformed(d::Normal) = transformed(d, Identity()) +transformed(d::Normal) = transformed(d, IdentityBijector) ############################## # Distributions.jl interface # From 08f5e1ac8bf4f474fbdbe5e279efde7d7868e1b6 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 11 Jul 2019 15:20:54 +0100 Subject: [PATCH 11/83] added Logit bijector and specialization for Beta distribution --- src/Bijectors.jl | 1 + src/interface.jl | 27 ++++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index a6e5f57b..c87e9047 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -25,6 +25,7 @@ export TransformDistribution, compose, Identity, DistributionBijector, + bijector, transformed, UnivariateTransformed, MultivariateTransformed diff --git a/src/interface.jl b/src/interface.jl index 2189be67..904c4257 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -141,6 +141,24 @@ logabsdetjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) const IdentityBijector = Identity() +############################### +# Example: Logit and Logistic # +############################### +using StatsFuns: logit, logistic + +struct Logit{T<:Real} <: Bijector + a::T + b::T +end + +transform(b::Logit, x::Real) = logit((x - b.a) / (b.b - b.a)) +transform(b::Inversed{Logit{<:Real}}, y::Real) = (b.b - b.a) * logistic(y) + b.a +(b::Logit)(x) = transform(b, x) + +logabsdetjac(b::Logit{<:Real}, x::Real) = log((x - b.a) * (b.b - x) / (b.b - b.a)) +forward(b::Logit, x::Real) = (rv=transform(b, x), logabsdetjac=-logabsdetjac(b, x)) + + ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### @@ -156,6 +174,9 @@ transform(b::DistributionBijector, x) = link(b.dist, x) transform(ib::Inversed{<: DistributionBijector}, y) = invlink(ib.orig.dist, y) (b::DistributionBijector)(x) = transform(b, x) +"Returns the constrained-to-unconstrained bijector for distribution `d`." +bijector(d::Distribution) = DistributionBijector(d) + # Transformed distributions struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} dist::D @@ -171,11 +192,11 @@ end # Can implement these on a case-by-case basis transformed(d::UnivariateDistribution, b::Bijector) = UnivariateTransformed(d, b) transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed(d, b) - -transformed(d) = transformed(d, DistributionBijector(d)) +transformed(d) = transformed(d, bijector(d)) # can specialize further by -transformed(d::Normal) = transformed(d, IdentityBijector) +bijector(d::Normal) = IdentityBijector +bijector(d::Beta{T}) where T <: Real = Logit(zero(T), one(T)) ############################## # Distributions.jl interface # From 2e715971b87bf61e20d8b88ff5e9fb12beb65d6f Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 11 Jul 2019 15:24:54 +0100 Subject: [PATCH 12/83] fixed a typo in transform of inv(Logit) --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 904c4257..d632f611 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -152,7 +152,7 @@ struct Logit{T<:Real} <: Bijector end transform(b::Logit, x::Real) = logit((x - b.a) / (b.b - b.a)) -transform(b::Inversed{Logit{<:Real}}, y::Real) = (b.b - b.a) * logistic(y) + b.a +transform(ib::Inversed{Logit{T}}, y::Real) where T <: Real = (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a (b::Logit)(x) = transform(b, x) logabsdetjac(b::Logit{<:Real}, x::Real) = log((x - b.a) * (b.b - x) / (b.b - b.a)) From 625917e398486ad18e2d07a7f7ccd19d94f65de0 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Fri, 12 Jul 2019 14:45:56 +0100 Subject: [PATCH 13/83] fixed typo --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index d632f611..b98ad1a7 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -100,7 +100,7 @@ end # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -∘(b1::B1, b2::B2) where {B1 <: Bijector, B2 <: Bijector} = Bijectors.compose(b2, b1) +∘(b1::B1, b2::B2) where {B1 <: Bijector, B2 <: Bijector} = compose(b2, b1) inv(ct::Composed{B}) where {B<:Bijector} = Composed(map(inv, reverse(ct.ts))) From d043ef4c68a29f988d56428e295194f550377111 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 13 Jul 2019 11:51:23 +0100 Subject: [PATCH 14/83] added jacobian function for bijectors, with AD implementations --- src/interface.jl | 28 +++++++++++++++++----------- test/interface.jl | 22 ++++++++++++++++++++-- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index b98ad1a7..5cc6622c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -47,22 +47,28 @@ inv(b::Bijector) = Inversed(b) inv(ib::Inversed{<:Bijector}) = ib.orig # AD implementations +# FIXME: `Inverse` of `ADBijector` is NOT a an `ADBijector` +function jacobian(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) + ForwardDiff.derivative(z -> transform(b, z), y) +end +function jacobian(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::AbstractVector{<: Real}) + ForwardDiff.jacobian(z -> transform(b, z), y) +end -# TODO: allow batch-computation, especially for univariate case -"Computes the absolute determinant of the Jacobian of the inverse-transformation." -function logabsdetjac(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) - log(abs(ForwardDiff.derivative(z -> transform(b, z), y))) +function jacobian(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) + Tracker.gradient(z -> transform(b, z), y)[1] end -function logabsdetjac(b::ADBijector{<:Turing.Core.ForwardDiffAD}, y::AbstractVector{<:Real}) - logabsdet(ForwardDiff.jacobian(z -> transform(b, z), y))[1] +function jacobian(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector{<: Real}) + Tracker.jacobian(z -> transform(b, z), y) end -# FIXME: untrack? i.e. `Tracker.data(...)` -function logabsdetjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) - log(abs(Tracker.gradient(z -> transform(b, z[1]), [y])[1][1])) +# TODO: allow batch-computation, especially for univariate case? +"Computes the absolute determinant of the Jacobian of the inverse-transformation." +function logabsdetjac(b::ADBijector, y::Real) + log(abs(jacobian(b, y))) end -function logabsdetjac(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector{<: Real}) - logabsdet(Tracker.jacobian(z -> transform(b, z), y))[1] +function logabsdetjac(b::ADBijector, y::AbstractVector{<:Real}) + logabsdet(jacobian(b, y))[1] end ############### diff --git a/test/interface.jl b/test/interface.jl index 49d739f3..8245e870 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -2,6 +2,8 @@ using Revise using Test using Bijectors using Random +using Turing +using LinearAlgebra Random.seed!(123) @@ -38,7 +40,7 @@ Random.seed!(123) ] for dist in uni_dists - @testset "$dist" begin + @testset "$dist: dist" begin td = transformed(dist) # single sample @@ -46,11 +48,27 @@ Random.seed!(123) x = transform(inv(td.transform), y) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - # # multi-sample + # multi-sample y = rand(td, 10) x = transform.(inv(td.transform), y) @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) end + + @testset "$dist: ForwardDiff AD" begin + x = rand(dist) + b = DistributionBijector{Turing.Core.ADBackend(:forward_diff), typeof(dist)}(dist) + + @test abs(det(Bijectors.jacobian(b, x))) > 0 + @test logabsdetjac(b, x) ≠ Inf + end + + @testset "$dist: Tracker AD" begin + x = rand(dist) + b = DistributionBijector{Turing.Core.ADBackend(:reverse_diff), typeof(dist)}(dist) + + @test abs(det(Bijectors.jacobian(b, x))) > 0 + @test logabsdetjac(b, x) ≠ Inf + end end @testset "Composition" begin From 1940748fb557cfd30fb72709b313f9163f86075e Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 13 Jul 2019 11:55:09 +0100 Subject: [PATCH 15/83] added jacobian using AD for inverses of ADBijector --- src/interface.jl | 11 ++++++++--- test/interface.jl | 10 ++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 5cc6622c..9b63194a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -47,18 +47,23 @@ inv(b::Bijector) = Inversed(b) inv(ib::Inversed{<:Bijector}) = ib.orig # AD implementations -# FIXME: `Inverse` of `ADBijector` is NOT a an `ADBijector` function jacobian(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) ForwardDiff.derivative(z -> transform(b, z), y) end -function jacobian(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::AbstractVector{<: Real}) +function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.ForwardDiffAD}}, y::Real) + ForwardDiff.derivative(z -> transform(b, z), y) +end +function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.ForwardDiffAD}}, y::AbstractVector{<: Real}) ForwardDiff.jacobian(z -> transform(b, z), y) end function jacobian(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) Tracker.gradient(z -> transform(b, z), y)[1] end -function jacobian(b::ADBijector{<: Turing.Core.TrackerAD}, y::AbstractVector{<: Real}) +function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.TrackerAD}}, y::Real) + Tracker.gradient(z -> transform(b, z), y)[1] +end +function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.TrackerAD}}, y::AbstractVector{<: Real}) Tracker.jacobian(z -> transform(b, z), y) end diff --git a/test/interface.jl b/test/interface.jl index 8245e870..2a2a63d7 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -60,6 +60,11 @@ Random.seed!(123) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf + + y = transform(b, x) + b⁻¹ = inv(b) + @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 + @test logabsdetjac(b⁻¹, y) ≠ Inf end @testset "$dist: Tracker AD" begin @@ -68,6 +73,11 @@ Random.seed!(123) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf + + y = transform(b, x) + b⁻¹ = inv(b) + @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 + @test logabsdetjac(b⁻¹, y) ≠ Inf end end From 15f1af236940a8eaf09ac1bc098ba42a505fcdab Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Jul 2019 18:34:02 +0200 Subject: [PATCH 16/83] added simple AD-types and removed dep on Turing --- Manifest.toml | 119 ++++++++++++++++++++++++++++++++++++++++------- Project.toml | 3 ++ src/interface.jl | 44 ++++++++++++++---- 3 files changed, 139 insertions(+), 27 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index dbb1a0a8..2dce8390 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,8 +1,14 @@ +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "1.0.0" + [[Arpack]] -deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Random", "SparseArrays", "Test"] -git-tree-sha1 = "1ce1ce9984683f0b6a587d5bdbc688ecb480096f" +deps = ["BinaryProvider", "Libdl", "LinearAlgebra"] +git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6" uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" -version = "0.3.0" +version = "0.3.1" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -14,10 +20,22 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" version = "0.8.10" [[BinaryProvider]] -deps = ["Libdl", "Pkg", "SHA", "Test"] -git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.3" +version = "0.5.6" + +[[CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "376a39f1862000442011390f1edf5e7f4dcc7142" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "0.6.0" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -25,11 +43,17 @@ git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" version = "2.1.0" +[[Crayons]] +deps = ["Test"] +git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.0.0" + [[DataStructures]] -deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] -git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.15.0" +version = "0.17.0" [[Dates]] deps = ["Printf"] @@ -39,6 +63,18 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +[[DiffResults]] +deps = ["Compat", "StaticArrays"] +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "0.0.4" + +[[DiffRules]] +deps = ["Random", "Test"] +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "0.0.10" + [[Distributed]] deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -49,6 +85,12 @@ git-tree-sha1 = "022e6610c320b6e19b454502d759c672580abe00" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" version = "0.18.0" +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.3" + [[InteractiveUtils]] deps = ["LinearAlgebra", "Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -66,6 +108,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[MacroTools]] +deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"] +git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.1" + [[MappedArrays]] deps = ["Test"] git-tree-sha1 = "923441c5ac942b60bd3a842d5377d96646bcbf46" @@ -77,14 +125,26 @@ deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[Missings]] -deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] -git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" +deps = ["SparseArrays", "Test"] +git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.0" +version = "0.4.1" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] +git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.0" + +[[NaNMath]] +deps = ["Compat"] +git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.2" + [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" @@ -93,9 +153,9 @@ version = "1.1.0" [[PDMats]] deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551" +git-tree-sha1 = "8b68513175b2dc4023a564cb0e917ce90e74fd69" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.9.6" +version = "0.9.7" [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -166,15 +226,21 @@ git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "0.7.2" +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.11.0" + [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] -deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" +deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "2b6ca97be7ddfad5d9f16a13fe277d29f3d11c23" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.29.0" +version = "0.31.0" [[StatsFuns]] deps = ["Rmath", "SpecialFunctions", "Test"] @@ -183,13 +249,30 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "0.8.0" [[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "SparseArrays"] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[TimerOutputs]] +deps = ["Crayons", "Printf", "Test", "Unicode"] +git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.0" + +[[Tokenize]] +git-tree-sha1 = "c8a8b00ae44a94950814ff77850470711a360225" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.5" + +[[Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] +git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.2.2" + [[URIParser]] deps = ["Test", "Unicode"] git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" diff --git a/Project.toml b/Project.toml index cacd2bc7..5a2ec43d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,11 +4,14 @@ version = "0.3.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/src/interface.jl b/src/interface.jl index 9b63194a..68ddad5d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -2,13 +2,39 @@ using Distributions, Bijectors using ForwardDiff using Tracker -using Turing +import Base: inv, ∘ import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf -import Base: inv, ∘ +####################################### +# AD stuff "extracted" from Turing.jl # +####################################### + +abstract type ADBackend end +struct ForwardDiffAD <: ADBackend end +struct TrackerAD <: ADBackend end + +const ADBACKEND = Ref(:forward) +function setadbackend(backend_sym) + @assert backend_sym == :forward_diff || backend_sym == :reverse_diff + backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40) + ADBACKEND[] = backend_sym +end + +ADBackend() = ADBackend(ADBACKEND[]) +ADBackend(T::Symbol) = ADBackend(Val(T)) +function ADBackend(::Val{T}) where {T} + if T === :forward_diff + return ForwardDiffAD + else + return TrackerAD + end +end +###################### +# Bijector interface # +###################### abstract type Bijector end abstract type ADBijector{AD} <: Bijector end @@ -47,23 +73,23 @@ inv(b::Bijector) = Inversed(b) inv(ib::Inversed{<:Bijector}) = ib.orig # AD implementations -function jacobian(b::ADBijector{<: Turing.Core.ForwardDiffAD}, y::Real) +function jacobian(b::ADBijector{<: ForwardDiffAD}, y::Real) ForwardDiff.derivative(z -> transform(b, z), y) end -function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.ForwardDiffAD}}, y::Real) +function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::Real) ForwardDiff.derivative(z -> transform(b, z), y) end -function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.ForwardDiffAD}}, y::AbstractVector{<: Real}) +function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::AbstractVector{<: Real}) ForwardDiff.jacobian(z -> transform(b, z), y) end -function jacobian(b::ADBijector{<: Turing.Core.TrackerAD}, y::Real) +function jacobian(b::ADBijector{<: TrackerAD}, y::Real) Tracker.gradient(z -> transform(b, z), y)[1] end -function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.TrackerAD}}, y::Real) +function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::Real) Tracker.gradient(z -> transform(b, z), y)[1] end -function jacobian(b::Inversed{<: ADBijector{<: Turing.Core.TrackerAD}}, y::AbstractVector{<: Real}) +function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: Real}) Tracker.jacobian(z -> transform(b, z), y) end @@ -177,7 +203,7 @@ struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution dist::D end function DistributionBijector(dist::D) where D <: Distribution - DistributionBijector{Turing.Core.ADBackend(), D}(dist) + DistributionBijector{ADBackend(), D}(dist) end # Simply uses `link` and `invlink` as transforms with AD to get jacobian From 9758bcdebaab2ecfff72c4f809475e963a5ee601 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 23 Jul 2019 10:54:12 +0200 Subject: [PATCH 17/83] Add norm flows --- src/Bijectors.jl | 2 ++ src/norm_flows.jl | 88 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 src/norm_flows.jl diff --git a/src/Bijectors.jl b/src/Bijectors.jl index c87e9047..a790cb77 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -436,4 +436,6 @@ end include("interface.jl") +include("norm_flows.jl") + end # module diff --git a/src/norm_flows.jl b/src/norm_flows.jl new file mode 100644 index 00000000..6949fe3d --- /dev/null +++ b/src/norm_flows.jl @@ -0,0 +1,88 @@ +using Distributions +using LinearAlgebra +using Random +using Flux + + +######################################################################################################################### +# Planar and Radial Flows : Variational Inference with Normalizing Flows, D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # +######################################################################################################################### + +mutable struct PlanarLayer <: Bijector + w + + u + u_hat + b +end + +mutable struct RadialLayer <: Bijector + α + β + z_not +end + +function update_u_hat(u, w) + # to preserve invertibility + u_hat = u + (m(transpose(w)*u) - transpose(w)*u)[1]*w/(norm(w[:,1],2)^2) +end + +function update_u_hat!(flow::PlanarLayer) + flow.u_hat = flow.u + (m(transpose(flow.w)*flow.u) - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) +end + + +function PlanarLayer(dims::Int) + w = param(randn(dims, 1)) + u = param(randn(dims, 1)) + b = param(randn(1)) + u_hat = update_u_hat(u, w) + return PlanarLayer(w, u, u_hat, b) +end + +function RadialLayer(dims::Int) + α_ = params(randn(1)) + β = params(randn(1)) + z_not = param(randn(dims, 1)) + return RadialLayer(α, β, z_not) +end + +m(x) = -1 .+ log.(1 .+ exp.(x)) #for planar flow +dtanh(x) = 1 .- (tanh.(x)).^2 #for planar flow +ψ(z, w, b) = dtanh(transpose(w)*z .+ b).*w #for planar flow +softplus(x) = log.(1 .+ exp.(x)) #for radial flow +h(α, r) = 1 ./ (α .+ r) #for radial flow +dh(α, r) = -dh(α, r).^2 #for radial flow + +function transform(flow::PlanarLayer, z) + return z + flow.u_hat*tanh.(transpose(flow.w)*z .+ flow.b) +end + +function transform(flow::RadialLayer, z) + α = softplus(flow.α_) + β_hat = -α + softplus(flow.β) + r = norm.(z - flow.z_not, 1) + return z + β_hat*h(α, r)*(z - flow.z_not) +end + +function forward(flow::T, z) where {T<:PlanarLayer} + update_u_hat!(flow) + # compute log_det_jacobian + transformed = transform(flow, z) + psi = ψ(transformed, flow.w, flow.b) + log_det_jacobian = log.(abs.(1.0 .+ transpose(psi)*flow.u_hat)) + + return (rv=transformed, logabsdetjacob=Bijector) +end + + +function forward(flow::T, z) where {T<:RadialLayer} + # compute log_det_jacobian + transformed = transform(flow, z) + α = softplus(flow.α_) + β_hat = -α + softplus(flow.β) + r = norm.(z - flow.z_not, 1) + d = size(flow.z_not)[1] + log_det_jacobian = log.(((1 + β_hat*h(α, r)).^(d-1)) .* ( 1 + β_hat*h(α, r) + β_hat*dh(α, r)*r)) + return (rv=transformed, logabsdetjacob=log_det_jacobian) +end \ No newline at end of file From b8c22db35d90a35443481d572f1d6d23839b4caa Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 23 Jul 2019 11:03:27 +0200 Subject: [PATCH 18/83] minor changes --- src/norm_flows.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 6949fe3d..5f2ed94e 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -4,13 +4,15 @@ using Random using Flux -######################################################################################################################### -# Planar and Radial Flows : Variational Inference with Normalizing Flows, D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # -######################################################################################################################### +################################################################################ +# Planar and Radial Flows # +# Ref: Variational Inference with Normalizing Flows, # +# D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # +################################################################################ mutable struct PlanarLayer <: Bijector w - + u u_hat b @@ -28,7 +30,8 @@ function update_u_hat(u, w) end function update_u_hat!(flow::PlanarLayer) - flow.u_hat = flow.u + (m(transpose(flow.w)*flow.u) - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) + flow.u_hat = flow.u + (m(transpose(flow.w)*flow.u) \ + - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) end @@ -83,6 +86,7 @@ function forward(flow::T, z) where {T<:RadialLayer} β_hat = -α + softplus(flow.β) r = norm.(z - flow.z_not, 1) d = size(flow.z_not)[1] - log_det_jacobian = log.(((1 + β_hat*h(α, r)).^(d-1)) .* ( 1 + β_hat*h(α, r) + β_hat*dh(α, r)*r)) + log_det_jacobian = log.(((1 + β_hat*h(α, r)).^(d-1)) \ + .* ( 1 + β_hat*h(α, r) + β_hat*dh(α, r)*r)) return (rv=transformed, logabsdetjacob=log_det_jacobian) end \ No newline at end of file From 22b1936822c45cd6c3e4924be3477b2e051275c9 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 23 Jul 2019 21:40:49 +0200 Subject: [PATCH 19/83] fix bugs --- src/Bijectors.jl | 28 +++++++++++++++------------- src/norm_flows.jl | 26 +++++++++++--------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index a790cb77..fb638655 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -5,8 +5,8 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays - -export TransformDistribution, +using Tracker: param +export TransformDistribution, RealDistribution, PositiveDistribution, UnitDistribution, @@ -28,7 +28,9 @@ export TransformDistribution, bijector, transformed, UnivariateTransformed, - MultivariateTransformed + MultivariateTransformed, + PlanarLayer, + RadialLayer const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) @@ -177,8 +179,8 @@ function _clamp(x::T, dist::SimplexDistribution) where T end function link( - d::SimplexDistribution, - x::AbstractVector{T}, + d::SimplexDistribution, + x::AbstractVector{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} y, K = similar(x), length(x) @@ -206,8 +208,8 @@ end # Vectorised implementation of the above. function link( - d::SimplexDistribution, - X::AbstractMatrix{T}, + d::SimplexDistribution, + X::AbstractMatrix{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} Y, K, N = similar(X), size(X, 1), size(X, 2) @@ -234,8 +236,8 @@ function link( end function invlink( - d::SimplexDistribution, - y::AbstractVector{T}, + d::SimplexDistribution, + y::AbstractVector{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} x, K = similar(y), length(y) @@ -260,8 +262,8 @@ end # Vectorised implementation of the above. function invlink( - d::SimplexDistribution, - Y::AbstractMatrix{T}, + d::SimplexDistribution, + Y::AbstractMatrix{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) @@ -355,8 +357,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{T}) where {T<:Real} end function logpdf_with_trans( - d::PDMatDistribution, - X::AbstractMatrix{<:Real}, + d::PDMatDistribution, + X::AbstractMatrix{<:Real}, transform::Bool ) T = eltype(X) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 5f2ed94e..c6c06a0f 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -1,8 +1,6 @@ using Distributions using LinearAlgebra using Random -using Flux - ################################################################################ # Planar and Radial Flows # @@ -19,7 +17,7 @@ mutable struct PlanarLayer <: Bijector end mutable struct RadialLayer <: Bijector - α + α_ β z_not end @@ -30,8 +28,7 @@ function update_u_hat(u, w) end function update_u_hat!(flow::PlanarLayer) - flow.u_hat = flow.u + (m(transpose(flow.w)*flow.u) \ - - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) + flow.u_hat = flow.u + (m(transpose(flow.w)*flow.u) - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) end @@ -44,10 +41,10 @@ function PlanarLayer(dims::Int) end function RadialLayer(dims::Int) - α_ = params(randn(1)) - β = params(randn(1)) + α_ = param(randn(1)) + β = param(randn(1)) z_not = param(randn(dims, 1)) - return RadialLayer(α, β, z_not) + return RadialLayer(α_, β, z_not) end m(x) = -1 .+ log.(1 .+ exp.(x)) #for planar flow @@ -64,8 +61,8 @@ end function transform(flow::RadialLayer, z) α = softplus(flow.α_) β_hat = -α + softplus(flow.β) - r = norm.(z - flow.z_not, 1) - return z + β_hat*h(α, r)*(z - flow.z_not) + r = norm.(z .- flow.z_not, 1) + return z + β_hat.*h(α, r).*(z .- flow.z_not) end function forward(flow::T, z) where {T<:PlanarLayer} @@ -75,7 +72,7 @@ function forward(flow::T, z) where {T<:PlanarLayer} psi = ψ(transformed, flow.w, flow.b) log_det_jacobian = log.(abs.(1.0 .+ transpose(psi)*flow.u_hat)) - return (rv=transformed, logabsdetjacob=Bijector) + return (rv=transformed, logabsdetjacob=log_det_jacobian) end @@ -84,9 +81,8 @@ function forward(flow::T, z) where {T<:RadialLayer} transformed = transform(flow, z) α = softplus(flow.α_) β_hat = -α + softplus(flow.β) - r = norm.(z - flow.z_not, 1) + r = norm.(z .- flow.z_not, 1) d = size(flow.z_not)[1] - log_det_jacobian = log.(((1 + β_hat*h(α, r)).^(d-1)) \ - .* ( 1 + β_hat*h(α, r) + β_hat*dh(α, r)*r)) + log_det_jacobian = log.(((1.0 .+ β_hat.*h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat.*h(α, r) + β_hat.*dh(α, r).*r)) return (rv=transformed, logabsdetjacob=log_det_jacobian) -end \ No newline at end of file +end From 5719e5dd075422a6a5234410a6083b6a9273fd3b Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 23 Jul 2019 22:01:18 +0200 Subject: [PATCH 20/83] fix more bugs --- src/norm_flows.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index c6c06a0f..7f600623 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -52,7 +52,7 @@ dtanh(x) = 1 .- (tanh.(x)).^2 #for planar flow ψ(z, w, b) = dtanh(transpose(w)*z .+ b).*w #for planar flow softplus(x) = log.(1 .+ exp.(x)) #for radial flow h(α, r) = 1 ./ (α .+ r) #for radial flow -dh(α, r) = -dh(α, r).^2 #for radial flow +dh(α, r) = -h(α, r).^2 #for radial flow function transform(flow::PlanarLayer, z) return z + flow.u_hat*tanh.(transpose(flow.w)*z .+ flow.b) @@ -81,7 +81,7 @@ function forward(flow::T, z) where {T<:RadialLayer} transformed = transform(flow, z) α = softplus(flow.α_) β_hat = -α + softplus(flow.β) - r = norm.(z .- flow.z_not, 1) + r = transpose(norm.([z[:,i] .- flow.z_not[:,:] for i in 1:size(z)[2]], 1)) d = size(flow.z_not)[1] log_det_jacobian = log.(((1.0 .+ β_hat.*h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat.*h(α, r) + β_hat.*dh(α, r).*r)) return (rv=transformed, logabsdetjacob=log_det_jacobian) From 8104a284589c9028d4d08c32ca75a0886da14dbc Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 26 Jul 2019 00:18:18 +0200 Subject: [PATCH 21/83] fix bugs and follow style guide --- src/Bijectors.jl | 1 + src/norm_flows.jl | 32 +++++++++++++++----------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index fb638655..e20dacbb 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -6,6 +6,7 @@ using StatsFuns using LinearAlgebra using MappedArrays using Tracker: param + export TransformDistribution, RealDistribution, PositiveDistribution, diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 7f600623..9bbb629f 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -1,6 +1,7 @@ using Distributions using LinearAlgebra using Random +using StatsFuns: softplus ################################################################################ # Planar and Radial Flows # @@ -10,7 +11,6 @@ using Random mutable struct PlanarLayer <: Bijector w - u u_hat b @@ -24,11 +24,11 @@ end function update_u_hat(u, w) # to preserve invertibility - u_hat = u + (m(transpose(w)*u) - transpose(w)*u)[1]*w/(norm(w[:,1],2)^2) + u_hat = u + (planar_flow_m(transpose(w)*u) - transpose(w)*u)[1]*w/(norm(w[:,1],2)^2) end function update_u_hat!(flow::PlanarLayer) - flow.u_hat = flow.u + (m(transpose(flow.w)*flow.u) - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) + flow.u_hat = flow.u + (planar_flow_m(transpose(flow.w)*flow.u) - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) end @@ -47,29 +47,27 @@ function RadialLayer(dims::Int) return RadialLayer(α_, β, z_not) end -m(x) = -1 .+ log.(1 .+ exp.(x)) #for planar flow -dtanh(x) = 1 .- (tanh.(x)).^2 #for planar flow -ψ(z, w, b) = dtanh(transpose(w)*z .+ b).*w #for planar flow -softplus(x) = log.(1 .+ exp.(x)) #for radial flow -h(α, r) = 1 ./ (α .+ r) #for radial flow -dh(α, r) = -h(α, r).^2 #for radial flow +planar_flow_m(x) = -1 .+ log.(1 .+ exp.(x)) # For planar flow +dtanh(x) = 1 .- (tanh.(x)).^2 # For planar flow +ψ(z, w, b) = dtanh(transpose(w)*z .+ b).*w # For planar flow +h(α, r) = 1 ./ (α .+ r) # For radial flow +dh(α, r) = -h(α, r).^2 # For radial flow function transform(flow::PlanarLayer, z) return z + flow.u_hat*tanh.(transpose(flow.w)*z .+ flow.b) end function transform(flow::RadialLayer, z) - α = softplus(flow.α_) - β_hat = -α + softplus(flow.β) + α = softplus(flow.α_[1]) + β_hat = -α + softplus(flow.β[1]) r = norm.(z .- flow.z_not, 1) return z + β_hat.*h(α, r).*(z .- flow.z_not) end function forward(flow::T, z) where {T<:PlanarLayer} update_u_hat!(flow) - # compute log_det_jacobian - transformed = transform(flow, z) - psi = ψ(transformed, flow.w, flow.b) + # Compute log_det_jacobian + psi = ψ(z, flow.w, flow.b) log_det_jacobian = log.(abs.(1.0 .+ transpose(psi)*flow.u_hat)) return (rv=transformed, logabsdetjacob=log_det_jacobian) @@ -77,10 +75,10 @@ end function forward(flow::T, z) where {T<:RadialLayer} - # compute log_det_jacobian + # Compute log_det_jacobian transformed = transform(flow, z) - α = softplus(flow.α_) - β_hat = -α + softplus(flow.β) + α = softplus(flow.α_[1]) + β_hat = -α + softplus(flow.β[1]) r = transpose(norm.([z[:,i] .- flow.z_not[:,:] for i in 1:size(z)[2]], 1)) d = size(flow.z_not)[1] log_det_jacobian = log.(((1.0 .+ β_hat.*h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat.*h(α, r) + β_hat.*dh(α, r).*r)) From 494075d761d9f4548c21e6692da40fdf5270fb35 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 26 Jul 2019 00:32:28 +0200 Subject: [PATCH 22/83] add tests on logabsdetjacob --- test/norm_flows.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 test/norm_flows.jl diff --git a/test/norm_flows.jl b/test/norm_flows.jl new file mode 100644 index 00000000..fa2a3fd8 --- /dev/null +++ b/test/norm_flows.jl @@ -0,0 +1,22 @@ +using Test +using Bijectors, ForwardDiff, LinearAlgebra + +@testset "planar flows" begin + for i in 1:100 + flow = PlanarLayer(10) + z = randn(10,100) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t).data, z)))) + our_method = sum(forward(flow, z).logabsdetjacob) + @test our_method ≈ forward_diff + end +end + +@testset "radial flows" begin + for i in 1:100 + flow = RadialLayer(10) + z = randn(10,100) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t).data, z)))) + our_method = sum(forward(flow, z).logabsdetjacob) + @test our_method ≈ forward_diff + end +end From 59aeceba62cb0e66ffbfb1bce7365849a97c9aca Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Fri, 26 Jul 2019 00:42:42 +0200 Subject: [PATCH 23/83] fix spaces --- src/norm_flows.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 9bbb629f..19e87fa4 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -24,11 +24,11 @@ end function update_u_hat(u, w) # to preserve invertibility - u_hat = u + (planar_flow_m(transpose(w)*u) - transpose(w)*u)[1]*w/(norm(w[:,1],2)^2) + u_hat = u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] * w/(norm(w[:,1],2) ^ 2) end function update_u_hat!(flow::PlanarLayer) - flow.u_hat = flow.u + (planar_flow_m(transpose(flow.w)*flow.u) - transpose(flow.w)*flow.u)[1]*flow.w/(norm(flow.w,2)^2) + flow.u_hat = flow.u + (planar_flow_m(transpose(flow.w) * flow.u) - transpose(flow.w) * flow.u)[1] * flow.w/(norm(flow.w,2)^2) end @@ -49,7 +49,7 @@ end planar_flow_m(x) = -1 .+ log.(1 .+ exp.(x)) # For planar flow dtanh(x) = 1 .- (tanh.(x)).^2 # For planar flow -ψ(z, w, b) = dtanh(transpose(w)*z .+ b).*w # For planar flow +ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # For planar flow h(α, r) = 1 ./ (α .+ r) # For radial flow dh(α, r) = -h(α, r).^2 # For radial flow @@ -61,7 +61,7 @@ function transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) β_hat = -α + softplus(flow.β[1]) r = norm.(z .- flow.z_not, 1) - return z + β_hat.*h(α, r).*(z .- flow.z_not) + return z + β_hat .* h(α, r) .* (z .- flow.z_not) end function forward(flow::T, z) where {T<:PlanarLayer} @@ -81,6 +81,6 @@ function forward(flow::T, z) where {T<:RadialLayer} β_hat = -α + softplus(flow.β[1]) r = transpose(norm.([z[:,i] .- flow.z_not[:,:] for i in 1:size(z)[2]], 1)) d = size(flow.z_not)[1] - log_det_jacobian = log.(((1.0 .+ β_hat.*h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat.*h(α, r) + β_hat.*dh(α, r).*r)) + log_det_jacobian = log.(((1.0 .+ β_hat .* h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat .* h(α, r) + β_hat .* dh(α, r) .* r)) return (rv=transformed, logabsdetjacob=log_det_jacobian) end From 0d7a70ebc71d7c207b61a853310a98b7bc9d0885 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Fri, 26 Jul 2019 07:43:20 -0400 Subject: [PATCH 24/83] defaults to using Tuple but allow any containers --- src/interface.jl | 44 ++++++++++++-------------------------------- test/interface.jl | 5 ++--- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 68ddad5d..0290f64d 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -95,43 +95,23 @@ end # TODO: allow batch-computation, especially for univariate case? "Computes the absolute determinant of the Jacobian of the inverse-transformation." -function logabsdetjac(b::ADBijector, y::Real) - log(abs(jacobian(b, y))) +function logabsdetjac(b::ADBijector, x::Real) + log(abs(jacobian(b, x))) end -function logabsdetjac(b::ADBijector, y::AbstractVector{<:Real}) - logabsdet(jacobian(b, y))[1] +function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) + fact = lu(jacobian(b, x), check=false) + issuccess(fact) ? log(abs(det(fact))) : -Inf # TODO: or smallest possible float? end ############### # Composition # ############### -struct Composed{B<:Bijector} <: Bijector - ts::Vector{B} +struct Composed{A} <: Bijector + ts::A end -function compose(ts...) - res = [] - - for b ∈ ts - if b isa Composed - # "lift" the transformations - for b_ ∈ b.ts - push!(res, b_) - end - else - # TODO: do we want this? - if (length(res) > 0) && (res[end] == inv(b)) - # remove if inverse - pop!(res) - else - push!(res, b) - end - end - end - - length(res) == 0 ? IdentityBijector : Composed([res...]) -end +compose(ts...) = Composed(ts) # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that @@ -139,10 +119,10 @@ end # TODO: change behavior of `transform` of `Composed`? ∘(b1::B1, b2::B2) where {B1 <: Bijector, B2 <: Bijector} = compose(b2, b1) -inv(ct::Composed{B}) where {B<:Bijector} = Composed(map(inv, reverse(ct.ts))) +inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) # TODO: can we implement this recursively, and with aggressive inlining, make this type-stable? -function transform(cb::Composed{<: Bijector}, x) +function transform(cb::Composed, x) res = x for b ∈ cb.ts res = transform(b, res) @@ -151,9 +131,9 @@ function transform(cb::Composed{<: Bijector}, x) return res end -(cb::Composed{<: Bijector})(x) = transform(cb, x) +(cb::Composed)(x) = transform(cb, x) -function forward(cb::Composed{<:Bijector}, x) +function forward(cb::Composed, x) res = (rv=x, logabsdetjac=0) for t in cb.ts res′ = forward(t, res.rv) diff --git a/test/interface.jl b/test/interface.jl index 2a2a63d7..b8d5dfa6 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -2,7 +2,6 @@ using Revise using Test using Bijectors using Random -using Turing using LinearAlgebra Random.seed!(123) @@ -56,7 +55,7 @@ Random.seed!(123) @testset "$dist: ForwardDiff AD" begin x = rand(dist) - b = DistributionBijector{Turing.Core.ADBackend(:forward_diff), typeof(dist)}(dist) + b = DistributionBijector{Bijectors.ADBackend(:forward_diff), typeof(dist)}(dist) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf @@ -69,7 +68,7 @@ Random.seed!(123) @testset "$dist: Tracker AD" begin x = rand(dist) - b = DistributionBijector{Turing.Core.ADBackend(:reverse_diff), typeof(dist)}(dist) + b = DistributionBijector{Bijectors.ADBackend(:reverse_diff), typeof(dist)}(dist) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf From a3bf6148cd2a06762c3b047726c5d4fa15f782ec Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 27 Jul 2019 16:02:28 +0200 Subject: [PATCH 25/83] fixing style-issues --- src/interface.jl | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 0290f64d..938dcb20 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -74,33 +74,31 @@ inv(ib::Inversed{<:Bijector}) = ib.orig # AD implementations function jacobian(b::ADBijector{<: ForwardDiffAD}, y::Real) - ForwardDiff.derivative(z -> transform(b, z), y) + return ForwardDiff.derivative(z -> transform(b, z), y) end function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::Real) - ForwardDiff.derivative(z -> transform(b, z), y) + return ForwardDiff.derivative(z -> transform(b, z), y) end function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::AbstractVector{<: Real}) - ForwardDiff.jacobian(z -> transform(b, z), y) + return ForwardDiff.jacobian(z -> transform(b, z), y) end function jacobian(b::ADBijector{<: TrackerAD}, y::Real) - Tracker.gradient(z -> transform(b, z), y)[1] + return Tracker.gradient(z -> transform(b, z), y)[1] end function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::Real) - Tracker.gradient(z -> transform(b, z), y)[1] + return Tracker.gradient(z -> transform(b, z), y)[1] end function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: Real}) - Tracker.jacobian(z -> transform(b, z), y) + return Tracker.jacobian(z -> transform(b, z), y) end # TODO: allow batch-computation, especially for univariate case? "Computes the absolute determinant of the Jacobian of the inverse-transformation." -function logabsdetjac(b::ADBijector, x::Real) - log(abs(jacobian(b, x))) -end +logabsdetjac(b::ADBijector, x::Real) = log(abs(jacobian(b, x))) function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) - issuccess(fact) ? log(abs(det(fact))) : -Inf # TODO: or smallest possible float? + return issuccess(fact) ? log(abs(det(fact))) : -Inf # TODO: or smallest possible float? end ############### @@ -117,7 +115,7 @@ compose(ts...) = Composed(ts) # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -∘(b1::B1, b2::B2) where {B1 <: Bijector, B2 <: Bijector} = compose(b2, b1) +∘(b1::Bijector, b2::Bijector) = compose(b2, b1) inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) @@ -223,22 +221,21 @@ bijector(d::Beta{T}) where T <: Real = Logit(zero(T), one(T)) Base.length(td::MultivariateTransformed) = length(td.dist) # logp -function logpdf(td::UnivariateTransformed, y::T where T <: Real) +function logpdf(td::UnivariateTransformed, y::Real) # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) logpdf_with_trans(td.dist, transform(inv(td.transform), y), true) end -function _logpdf(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) +function _logpdf(td::MultivariateTransformed, y::AbstractVector{<: Real}) # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) logpdf_with_trans(td.dist, transform(inv(td.transform), y), true) end -# TODO: implement these using analytical expressions? -function logpdf_with_jac(td::UnivariateTransformed, y::T where T <: Real) +function logpdf_with_jac(td::UnivariateTransformed, y::Real) z = logabsdetjac(inv(td.transform), y) return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) end -function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{T} where T <: Real) +function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{<:Real}) z = logabsdetjac(inv(td.transform), y) return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) end From b0a01e94ad23639a64f8306835c4dc8f440bc64e Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 31 Jul 2019 15:30:55 +0200 Subject: [PATCH 26/83] add iterative norm for planar flows --- src/norm_flows.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 19e87fa4..1b602ede 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -2,6 +2,7 @@ using Distributions using LinearAlgebra using Random using StatsFuns: softplus +using Roots, LinearAlgebra # for inverse ################################################################################ # Planar and Radial Flows # @@ -84,3 +85,14 @@ function forward(flow::T, z) where {T<:RadialLayer} log_det_jacobian = log.(((1.0 .+ β_hat .* h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat .* h(α, r) + β_hat .* dh(α, r) .* r)) return (rv=transformed, logabsdetjacob=log_det_jacobian) end + +function inv(flow::PlanarLayer, y) + function f(y) + return loss(alpha) = (transpose(flow.w.data)*y)[1] - alpha -(transpose(flow.w.data)*flow.u_hat.data)[1]*tanh(alpha+flow.b.data[1]) + end + alphas = transpose([find_zero(f(y[:,i:i]), 0, Order16()) for i in 1:size(z)[2]]) + z_para = (flow.w.data ./norm(flow.w.data,2))*alphas + z_per = y - z_para - flow.u_hat.data*tanh.(transpose(flow.w.data)*z_para .+ flow.b.data) + + return z_para+z_per +end From 96c2c7fc1cd1fd1b021d4f09d39d9bfec08f0853 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 31 Jul 2019 15:39:51 +0200 Subject: [PATCH 27/83] fix minor bug --- src/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 1b602ede..66ecd3a2 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -90,7 +90,7 @@ function inv(flow::PlanarLayer, y) function f(y) return loss(alpha) = (transpose(flow.w.data)*y)[1] - alpha -(transpose(flow.w.data)*flow.u_hat.data)[1]*tanh(alpha+flow.b.data[1]) end - alphas = transpose([find_zero(f(y[:,i:i]), 0, Order16()) for i in 1:size(z)[2]]) + alphas = transpose([find_zero(f(y[:,i:i]), 0, Order16()) for i in 1:size(y)[2]]) z_para = (flow.w.data ./norm(flow.w.data,2))*alphas z_per = y - z_para - flow.u_hat.data*tanh.(transpose(flow.w.data)*z_para .+ flow.b.data) From 2109d6d5914317b706ba09a79612e65068820348 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Thu, 1 Aug 2019 11:01:31 +0530 Subject: [PATCH 28/83] adhere to stylecode, add radial inverse, remove tracker dependency, restructure code --- src/Bijectors.jl | 1 - src/norm_flows.jl | 126 ++++++++++++++++++++++++++++++---------------- 2 files changed, 82 insertions(+), 45 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e20dacbb..c8cc661d 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -5,7 +5,6 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays -using Tracker: param export TransformDistribution, RealDistribution, diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 66ecd3a2..2a9119df 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -17,54 +17,34 @@ mutable struct PlanarLayer <: Bijector b end -mutable struct RadialLayer <: Bijector - α_ - β - z_not -end - -function update_u_hat(u, w) +function get_u_hat(u, w) # to preserve invertibility - u_hat = u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] * w/(norm(w[:,1],2) ^ 2) + u_hat = ( + u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] + * w/(norm(w[:,1],2) ^ 2) + ) end function update_u_hat!(flow::PlanarLayer) - flow.u_hat = flow.u + (planar_flow_m(transpose(flow.w) * flow.u) - transpose(flow.w) * flow.u)[1] * flow.w/(norm(flow.w,2)^2) + flow.u_hat = get_u_hat(flow.u, flow.w) end -function PlanarLayer(dims::Int) - w = param(randn(dims, 1)) - u = param(randn(dims, 1)) - b = param(randn(1)) - u_hat = update_u_hat(u, w) +function PlanarLayer(dims::Int, container=Array) + w = container(randn(dims, 1)) + u = container(randn(dims, 1)) + b = container(randn(1)) + u_hat = get_u_hat(u, w) return PlanarLayer(w, u, u_hat, b) end -function RadialLayer(dims::Int) - α_ = param(randn(1)) - β = param(randn(1)) - z_not = param(randn(dims, 1)) - return RadialLayer(α_, β, z_not) -end - -planar_flow_m(x) = -1 .+ log.(1 .+ exp.(x)) # For planar flow -dtanh(x) = 1 .- (tanh.(x)).^2 # For planar flow -ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # For planar flow -h(α, r) = 1 ./ (α .+ r) # For radial flow -dh(α, r) = -h(α, r).^2 # For radial flow +planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow +dtanh(x) = 1 .- (tanh.(x)).^2 # for planar flow function transform(flow::PlanarLayer, z) return z + flow.u_hat*tanh.(transpose(flow.w)*z .+ flow.b) end -function transform(flow::RadialLayer, z) - α = softplus(flow.α_[1]) - β_hat = -α + softplus(flow.β[1]) - r = norm.(z .- flow.z_not, 1) - return z + β_hat .* h(α, r) .* (z .- flow.z_not) -end - function forward(flow::T, z) where {T<:PlanarLayer} update_u_hat!(flow) # Compute log_det_jacobian @@ -74,25 +54,83 @@ function forward(flow::T, z) where {T<:PlanarLayer} return (rv=transformed, logabsdetjacob=log_det_jacobian) end +function inv(flow::PlanarLayer, y) + function f(y) + return loss(alpha) = ( + (transpose(flow.w.data)*y)[1] - alpha + - (transpose(flow.w.data)*flow.u_hat.data)[1] + * tanh(alpha+flow.b.data[1]) + ) + end + alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]] + alphas = transpose(alphas) + z_para = (flow.w.data ./norm(flow.w.data,2)) * alphas + z_per = ( + y - z_para - flow.u_hat.data*tanh.( + transpose(flow.w.data) * z_para + .+ flow.b.data + ) + ) + + return z_para+z_per +end + +mutable struct RadialLayer <: Bijector + α_ + β + z_0 +end + +function RadialLayer(dims::Int, container=Array) + α_ = container(randn(1)) + β = container(randn(1)) + z_0 = container(randn(dims, 1)) + return RadialLayer(α_, β, z_0) +end + + + +ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow +h(α, r) = 1 ./ (α .+ r) # for radial flow +dh(α, r) = -h(α, r).^2 # for radial flow + + +function transform(flow::RadialLayer, z) + α = softplus(flow.α_[1]) + β_hat = -α + softplus(flow.β[1]) + r = norm.(z .- flow.z_0, 1) + return z + β_hat .* h(α, r) .* (z .- flow.z_0) +end + + function forward(flow::T, z) where {T<:RadialLayer} - # Compute log_det_jacobian + # compute log_det_jacobian transformed = transform(flow, z) α = softplus(flow.α_[1]) β_hat = -α + softplus(flow.β[1]) - r = transpose(norm.([z[:,i] .- flow.z_not[:,:] for i in 1:size(z)[2]], 1)) - d = size(flow.z_not)[1] - log_det_jacobian = log.(((1.0 .+ β_hat .* h(α, r)).^(d-1)) .* ( 1.0 .+ β_hat .* h(α, r) + β_hat .* dh(α, r) .* r)) + r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z)[2]], 1)) + d = size(flow.z_0)[1] + h_ = h(α, r) + h_ = h(α, r) + log_det_jacobian = @. ( + (d-1) * log(1.0 + β_hat * h_) + + log(1.0 + β_hat * h_ + β_hat * (- h_^2) * r) + ) return (rv=transformed, logabsdetjacob=log_det_jacobian) end -function inv(flow::PlanarLayer, y) +function inv(flow::RadialLayer, y) + α = softplus(flow.α_.data[1]) + β_hat = -α + softplus(flow.β.data[1]) function f(y) - return loss(alpha) = (transpose(flow.w.data)*y)[1] - alpha -(transpose(flow.w.data)*flow.u_hat.data)[1]*tanh(alpha+flow.b.data[1]) + return loss(r) = ( + norm(y - flow.z_not.data, 2) + - r * (1 + β_hat / (α + r)) + ) end - alphas = transpose([find_zero(f(y[:,i:i]), 0, Order16()) for i in 1:size(y)[2]]) - z_para = (flow.w.data ./norm(flow.w.data,2))*alphas - z_per = y - z_para - flow.u_hat.data*tanh.(transpose(flow.w.data)*z_para .+ flow.b.data) - - return z_para+z_per + rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]] + rs = transpose(alphas) + z = (y.-flow.z_not.data) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) + return z end From b727f5e1dad2007487020e897057c292738f7af8 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 7 Aug 2019 21:35:52 +0530 Subject: [PATCH 29/83] fix radius bug --- src/norm_flows.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 2a9119df..3d45c5b4 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -98,7 +98,7 @@ dh(α, r) = -h(α, r).^2 # for radial flow function transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) β_hat = -α + softplus(flow.β[1]) - r = norm.(z .- flow.z_0, 1) + r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z)[2]], 1)) return z + β_hat .* h(α, r) .* (z .- flow.z_0) end @@ -112,7 +112,6 @@ function forward(flow::T, z) where {T<:RadialLayer} r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z)[2]], 1)) d = size(flow.z_0)[1] h_ = h(α, r) - h_ = h(α, r) log_det_jacobian = @. ( (d-1) * log(1.0 + β_hat * h_) + log(1.0 + β_hat * h_ + β_hat * (- h_^2) * r) From 6c1a5a0070c8aefaeb1e9a3d28ce066a9ae3b9c2 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 7 Aug 2019 21:50:11 +0530 Subject: [PATCH 30/83] fix param dependency --- src/norm_flows.jl | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 3d45c5b4..6abbc9f9 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -57,18 +57,18 @@ end function inv(flow::PlanarLayer, y) function f(y) return loss(alpha) = ( - (transpose(flow.w.data)*y)[1] - alpha - - (transpose(flow.w.data)*flow.u_hat.data)[1] - * tanh(alpha+flow.b.data[1]) + (transpose(flow.w)*y)[1] - alpha + - (transpose(flow.w)*flow.u_hat)[1] + * tanh(alpha+flow.b[1]) ) end alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]] alphas = transpose(alphas) - z_para = (flow.w.data ./norm(flow.w.data,2)) * alphas + z_para = (flow.w ./norm(flow.w,2)) * alphas z_per = ( - y - z_para - flow.u_hat.data*tanh.( - transpose(flow.w.data) * z_para - .+ flow.b.data + y - z_para - flow.u_hat*tanh.( + transpose(flow.w) * z_para + .+ flow.b ) ) @@ -120,16 +120,20 @@ function forward(flow::T, z) where {T<:RadialLayer} end function inv(flow::RadialLayer, y) - α = softplus(flow.α_.data[1]) - β_hat = -α + softplus(flow.β.data[1]) + α = softplus(flow.α_[1]) + β_hat = -α + softplus(flow.β[1]) function f(y) return loss(r) = ( - norm(y - flow.z_not.data, 2) + norm(y - flow.z_0, 2) - r * (1 + β_hat / (α + r)) ) end rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]] rs = transpose(alphas) - z = (y.-flow.z_not.data) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) + z = (y.-flow.z_0) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) return z end + +flow = RadialLayer(2) + +transform(flow, randn(2,100)) From 0c026d4af51e2b8e35185401ac1ddd883fc04d5b Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 13 Aug 2019 11:33:47 +0530 Subject: [PATCH 31/83] fix inv() for radial flow; follow style guidelines --- src/norm_flows.jl | 71 ++++++++++++++++++++++------------------------ test/norm_flows.jl | 10 ++++--- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 6abbc9f9..858078d6 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -10,19 +10,19 @@ using Roots, LinearAlgebra # for inverse # D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # ################################################################################ -mutable struct PlanarLayer <: Bijector - w - u - u_hat - b +mutable struct PlanarLayer{T1,T2} <: Bijector + w::T1 + u::T1 + u_hat::T1 + b::T2 end function get_u_hat(u, w) - # to preserve invertibility + # To preserve invertibility u_hat = ( u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] - * w/(norm(w[:,1],2) ^ 2) - ) + * w / (norm(w[:,1],2) ^ 2) + ) end function update_u_hat!(flow::PlanarLayer) @@ -39,17 +39,17 @@ function PlanarLayer(dims::Int, container=Array) end planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow -dtanh(x) = 1 .- (tanh.(x)).^2 # for planar flow +dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow function transform(flow::PlanarLayer, z) - return z + flow.u_hat*tanh.(transpose(flow.w)*z .+ flow.b) + return z + flow.u_hat * tanh.(transpose(flow.w) * z .+ flow.b) end function forward(flow::T, z) where {T<:PlanarLayer} update_u_hat!(flow) # Compute log_det_jacobian psi = ψ(z, flow.w, flow.b) - log_det_jacobian = log.(abs.(1.0 .+ transpose(psi)*flow.u_hat)) + log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * flow.u_hat)) return (rv=transformed, logabsdetjacob=log_det_jacobian) end @@ -57,28 +57,28 @@ end function inv(flow::PlanarLayer, y) function f(y) return loss(alpha) = ( - (transpose(flow.w)*y)[1] - alpha - - (transpose(flow.w)*flow.u_hat)[1] + (transpose(flow.w) * y)[1] - alpha + - (transpose(flow.w) * flow.u_hat)[1] * tanh(alpha+flow.b[1]) ) end - alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]] - alphas = transpose(alphas) - z_para = (flow.w ./norm(flow.w,2)) * alphas + alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] + alphas = transpose(alphas_) + z_para = (flow.w ./ norm(flow.w,2)) * alphas z_per = ( - y - z_para - flow.u_hat*tanh.( + y - z_para - flow.u_hat * tanh.( transpose(flow.w) * z_para .+ flow.b - ) ) + ) - return z_para+z_per + return z_para + z_per end -mutable struct RadialLayer <: Bijector - α_ - β - z_0 +mutable struct RadialLayer{T1,T2} <: Bijector + α_::T1 + β::T1 + z_0::T2 end function RadialLayer(dims::Int, container=Array) @@ -92,48 +92,45 @@ end ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow h(α, r) = 1 ./ (α .+ r) # for radial flow -dh(α, r) = -h(α, r).^2 # for radial flow +dh(α, r) = - h(α, r) .^ 2 # for radial flow function transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) β_hat = -α + softplus(flow.β[1]) - r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z)[2]], 1)) + r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z, 2)], 1)) return z + β_hat .* h(α, r) .* (z .- flow.z_0) end function forward(flow::T, z) where {T<:RadialLayer} - # compute log_det_jacobian + # Compute log_det_jacobian transformed = transform(flow, z) α = softplus(flow.α_[1]) β_hat = -α + softplus(flow.β[1]) - r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z)[2]], 1)) - d = size(flow.z_0)[1] + r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z, 2)], 1)) + d = size(flow.z_0, 1) h_ = h(α, r) log_det_jacobian = @. ( (d-1) * log(1.0 + β_hat * h_) - + log(1.0 + β_hat * h_ + β_hat * (- h_^2) * r) + + log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) return (rv=transformed, logabsdetjacob=log_det_jacobian) end function inv(flow::RadialLayer, y) α = softplus(flow.α_[1]) - β_hat = -α + softplus(flow.β[1]) + β_hat = - α + softplus(flow.β[1]) function f(y) return loss(r) = ( norm(y - flow.z_0, 2) - r * (1 + β_hat / (α + r)) ) end - rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]] - rs = transpose(alphas) - z = (y.-flow.z_0) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) + rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] + rs = transpose(rs_) + z_hat = (y .- flow.z_0) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) + z = flow.z_0 .+ rs .* z_hat return z end - -flow = RadialLayer(2) - -transform(flow, randn(2,100)) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index fa2a3fd8..e162eb6a 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -2,21 +2,23 @@ using Test using Bijectors, ForwardDiff, LinearAlgebra @testset "planar flows" begin - for i in 1:100 + for i in 1:1 flow = PlanarLayer(10) z = randn(10,100) - forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t).data, z)))) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) our_method = sum(forward(flow, z).logabsdetjacob) @test our_method ≈ forward_diff + # @test inv(flow, transform(flow, z)) ≈ z end end @testset "radial flows" begin - for i in 1:100 + for i in 1:10 flow = RadialLayer(10) z = randn(10,100) - forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t).data, z)))) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) our_method = sum(forward(flow, z).logabsdetjacob) @test our_method ≈ forward_diff + # @test inv(flow, transform(flow, z)) ≈ z end end From d3ff9a68c9e4a083fe956e503ec0f44784e0d942 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 13 Aug 2019 11:36:47 +0530 Subject: [PATCH 32/83] minor change to test --- test/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index e162eb6a..7e963c97 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -2,7 +2,7 @@ using Test using Bijectors, ForwardDiff, LinearAlgebra @testset "planar flows" begin - for i in 1:1 + for i in 1:10 flow = PlanarLayer(10) z = randn(10,100) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) From 8ca489ae15c3ed31514b08c971a7307ec2a0a09c Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 14 Aug 2019 06:16:54 +0530 Subject: [PATCH 33/83] minor fix --- src/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 858078d6..c0a6bbac 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -19,7 +19,7 @@ end function get_u_hat(u, w) # To preserve invertibility - u_hat = ( + return ( u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] * w / (norm(w[:,1],2) ^ 2) ) From fb2399f8086d98acef92b52643537591897cdb27 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 14 Aug 2019 06:31:43 +0530 Subject: [PATCH 34/83] add ref to paper for each equation --- src/norm_flows.jl | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index c0a6bbac..6343af3f 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -22,7 +22,7 @@ function get_u_hat(u, w) return ( u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] * w / (norm(w[:,1],2) ^ 2) - ) + ) # from A.1 end function update_u_hat!(flow::PlanarLayer) @@ -38,23 +38,25 @@ function PlanarLayer(dims::Int, container=Array) return PlanarLayer(w, u, u_hat, b) end -planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow +planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1 dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow +ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow from eq(11) function transform(flow::PlanarLayer, z) - return z + flow.u_hat * tanh.(transpose(flow.w) * z .+ flow.b) + return z + flow.u_hat * tanh.(transpose(flow.w) * z .+ flow.b) # from eq(10) end function forward(flow::T, z) where {T<:PlanarLayer} update_u_hat!(flow) # Compute log_det_jacobian psi = ψ(z, flow.w, flow.b) - log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * flow.u_hat)) + log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * flow.u_hat)) # from eq(12) return (rv=transformed, logabsdetjacob=log_det_jacobian) end function inv(flow::PlanarLayer, y) + # Implemented with reference from A.1 function f(y) return loss(alpha) = ( (transpose(flow.w) * y)[1] - alpha @@ -90,47 +92,48 @@ end -ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow -h(α, r) = 1 ./ (α .+ r) # for radial flow -dh(α, r) = - h(α, r) .^ 2 # for radial flow +h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) +dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() function transform(flow::RadialLayer, z) - α = softplus(flow.α_[1]) - β_hat = -α + softplus(flow.β[1]) + α = softplus(flow.α_[1]) # from A.2 + β_hat = -α + softplus(flow.β[1]) # from A.2 r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z, 2)], 1)) - return z + β_hat .* h(α, r) .* (z .- flow.z_0) + return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) end function forward(flow::T, z) where {T<:RadialLayer} # Compute log_det_jacobian - transformed = transform(flow, z) - α = softplus(flow.α_[1]) - β_hat = -α + softplus(flow.β[1]) + α = softplus(flow.α_[1]) # from A.2 + β_hat = -α + softplus(flow.β[1]) # from A.2 r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z, 2)], 1)) + transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) + d = size(flow.z_0, 1) h_ = h(α, r) log_det_jacobian = @. ( (d-1) * log(1.0 + β_hat * h_) + log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) - ) + ) # from eq(14) return (rv=transformed, logabsdetjacob=log_det_jacobian) end function inv(flow::RadialLayer, y) - α = softplus(flow.α_[1]) - β_hat = - α + softplus(flow.β[1]) + α = softplus(flow.α_[1]) # from A.2 + β_hat = - α + softplus(flow.β[1]) # from A.2 function f(y) + # From eq(26) return loss(r) = ( norm(y - flow.z_0, 2) - r * (1 + β_hat / (α + r)) ) end - rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] + rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] # A.2 rs = transpose(rs_) - z_hat = (y .- flow.z_0) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) - z = flow.z_0 .+ rs .* z_hat + z_hat = (y .- flow.z_0) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) # from eq(25) + z = flow.z_0 .+ rs .* z_hat # from A.2 return z end From 45a67134b9c7f9e2b99be98199e631beacd30caa Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 14 Aug 2019 06:44:37 +0530 Subject: [PATCH 35/83] fix forward and remove redundant --- src/norm_flows.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 6343af3f..0a2576d4 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -51,8 +51,8 @@ function forward(flow::T, z) where {T<:PlanarLayer} # Compute log_det_jacobian psi = ψ(z, flow.w, flow.b) log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * flow.u_hat)) # from eq(12) - - return (rv=transformed, logabsdetjacob=log_det_jacobian) + transformed = z + flow.u_hat * tanh.(transpose(flow.w) * z .+ flow.b) + return (rv=transformed, logabsdetjacob=log_det_jacobian) # from eq(10) end function inv(flow::PlanarLayer, y) @@ -99,7 +99,7 @@ dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() function transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 - r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z, 2)], 1)) + r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 1)) return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) end @@ -109,7 +109,7 @@ function forward(flow::T, z) where {T<:RadialLayer} # Compute log_det_jacobian α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 - r = transpose(norm.([z[:,i] .- flow.z_0[:,:] for i in 1:size(z, 2)], 1)) + r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 1)) transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) d = size(flow.z_0, 1) From 3f86c99167870129bcc698d9d204c54f45bed6e8 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Thu, 15 Aug 2019 11:49:55 +0530 Subject: [PATCH 36/83] remove update_u_hat!() requirement --- src/norm_flows.jl | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 0a2576d4..641c2616 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -2,7 +2,7 @@ using Distributions using LinearAlgebra using Random using StatsFuns: softplus -using Roots, LinearAlgebra # for inverse +using Roots # for inverse ################################################################################ # Planar and Radial Flows # @@ -13,7 +13,6 @@ using Roots, LinearAlgebra # for inverse mutable struct PlanarLayer{T1,T2} <: Bijector w::T1 u::T1 - u_hat::T1 b::T2 end @@ -25,17 +24,11 @@ function get_u_hat(u, w) ) # from A.1 end -function update_u_hat!(flow::PlanarLayer) - flow.u_hat = get_u_hat(flow.u, flow.w) -end - - function PlanarLayer(dims::Int, container=Array) w = container(randn(dims, 1)) u = container(randn(dims, 1)) b = container(randn(1)) - u_hat = get_u_hat(u, w) - return PlanarLayer(w, u, u_hat, b) + return PlanarLayer(w, u, b) end planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1 @@ -43,24 +36,26 @@ dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow from eq(11) function transform(flow::PlanarLayer, z) - return z + flow.u_hat * tanh.(transpose(flow.w) * z .+ flow.b) # from eq(10) + u_hat = get_u_hat(flow.u, flow.w) + return z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) # from eq(10) end function forward(flow::T, z) where {T<:PlanarLayer} - update_u_hat!(flow) + u_hat = get_u_hat(flow.u, flow.w) # Compute log_det_jacobian psi = ψ(z, flow.w, flow.b) - log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * flow.u_hat)) # from eq(12) - transformed = z + flow.u_hat * tanh.(transpose(flow.w) * z .+ flow.b) + log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * u_hat)) # from eq(12) + transformed = z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) return (rv=transformed, logabsdetjacob=log_det_jacobian) # from eq(10) end function inv(flow::PlanarLayer, y) + u_hat = get_u_hat(flow.u, flow.w) # Implemented with reference from A.1 function f(y) return loss(alpha) = ( (transpose(flow.w) * y)[1] - alpha - - (transpose(flow.w) * flow.u_hat)[1] + - (transpose(flow.w) * u_hat)[1] * tanh(alpha+flow.b[1]) ) end @@ -68,7 +63,7 @@ function inv(flow::PlanarLayer, y) alphas = transpose(alphas_) z_para = (flow.w ./ norm(flow.w,2)) * alphas z_per = ( - y - z_para - flow.u_hat * tanh.( + y - z_para - u_hat * tanh.( transpose(flow.w) * z_para .+ flow.b ) @@ -99,7 +94,7 @@ dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() function transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 - r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 1)) + r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) end @@ -109,7 +104,7 @@ function forward(flow::T, z) where {T<:RadialLayer} # Compute log_det_jacobian α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 - r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 1)) + r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) d = size(flow.z_0, 1) From d88240688b17195b9d8c8a63e6c94e009b0997bd Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sat, 17 Aug 2019 16:01:53 +0530 Subject: [PATCH 37/83] update tests and ifx bug --- src/Bijectors.jl | 1 + src/norm_flows.jl | 10 ++-------- test/norm_flows.jl | 23 +++++++++++++++++------ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index c8cc661d..217628aa 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -5,6 +5,7 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays +using Roots export TransformDistribution, RealDistribution, diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 641c2616..13f7d59c 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -85,12 +85,9 @@ function RadialLayer(dims::Int, container=Array) return RadialLayer(α_, β, z_0) end - - h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() - function transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 @@ -98,15 +95,12 @@ function transform(flow::RadialLayer, z) return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) end - - function forward(flow::T, z) where {T<:RadialLayer} - # Compute log_det_jacobian α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) - + # Compute log_det_jacobian d = size(flow.z_0, 1) h_ = h(α, r) log_det_jacobian = @. ( @@ -128,7 +122,7 @@ function inv(flow::RadialLayer, y) end rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] # A.2 rs = transpose(rs_) - z_hat = (y .- flow.z_0) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) ) # from eq(25) + z_hat = (y .- flow.z_0) ./ (rs .* (1 .+ β_hat ./ (α .+ rs)) ) # from eq(25) z = flow.z_0 .+ rs .* z_hat # from A.2 return z end diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 7e963c97..d3c1ef0b 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -3,22 +3,33 @@ using Bijectors, ForwardDiff, LinearAlgebra @testset "planar flows" begin for i in 1:10 - flow = PlanarLayer(10) - z = randn(10,100) + flow = PlanarLayer(2) + z = randn(2, 1) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) our_method = sum(forward(flow, z).logabsdetjacob) @test our_method ≈ forward_diff - # @test inv(flow, transform(flow, z)) ≈ z end + + w = ones(10, 1) + u = zeros(10, 1) + b = ones(1) + flow = PlanarLayer(w, u, b) + z = ones(10, 100) + @test inv(flow, transform(flow, z)) ≈ z end @testset "radial flows" begin for i in 1:10 - flow = RadialLayer(10) - z = randn(10,100) + flow = RadialLayer(1,0,zeros(2,1)) + z = randn(2, 1) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) our_method = sum(forward(flow, z).logabsdetjacob) @test our_method ≈ forward_diff - # @test inv(flow, transform(flow, z)) ≈ z end + α_ = ones(1) + β = ones(1) + z_0 = zeros(10, 1) + z = ones(10, 100) + flow = RadialLayer(α_, β, z_0) + @test inv(flow, transform(flow, z)) ≈ z end From bac200f837566e6186cbc2bb5cbb063de4f3157f Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sat, 17 Aug 2019 16:21:28 +0530 Subject: [PATCH 38/83] update tests --- test/norm_flows.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index d3c1ef0b..14b1c784 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -3,11 +3,14 @@ using Bijectors, ForwardDiff, LinearAlgebra @testset "planar flows" begin for i in 1:10 - flow = PlanarLayer(2) - z = randn(2, 1) + flow = PlanarLayer(10) + z = randn(10, 100) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) our_method = sum(forward(flow, z).logabsdetjacob) @test our_method ≈ forward_diff + + # Inverse not accurate enough to pass with `≈` operator. + @test_broken inv(flow, transform(flow, z)) ≈ z end w = ones(10, 1) @@ -20,11 +23,12 @@ end @testset "radial flows" begin for i in 1:10 - flow = RadialLayer(1,0,zeros(2,1)) - z = randn(2, 1) + flow = RadialLayer(2) + z = randn(2, 100) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) our_method = sum(forward(flow, z).logabsdetjacob) @test our_method ≈ forward_diff + @test inv(flow, transform(flow, z)) ≈ z end α_ = ones(1) β = ones(1) From f276e57d4178eefc1d34d7cf742149febcd5f22a Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sat, 17 Aug 2019 18:22:32 +0530 Subject: [PATCH 39/83] implement Bijector call. We can now transform using BijectorName(x) --- src/norm_flows.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 13f7d59c..a105a71f 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -10,6 +10,8 @@ using Roots # for inverse # D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # ################################################################################ +(b::Bijector)(x) = transform(b, x) + mutable struct PlanarLayer{T1,T2} <: Bijector w::T1 u::T1 @@ -46,7 +48,7 @@ function forward(flow::T, z) where {T<:PlanarLayer} psi = ψ(z, flow.w, flow.b) log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * u_hat)) # from eq(12) transformed = z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) - return (rv=transformed, logabsdetjacob=log_det_jacobian) # from eq(10) + return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10) end function inv(flow::PlanarLayer, y) @@ -107,7 +109,7 @@ function forward(flow::T, z) where {T<:RadialLayer} (d-1) * log(1.0 + β_hat * h_) + log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (rv=transformed, logabsdetjacob=log_det_jacobian) + return (rv=transformed, logabsdetjac=log_det_jacobian) end function inv(flow::RadialLayer, y) From 23a0501d6897bd59940e873e7cb6afdea94e2509 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sat, 17 Aug 2019 18:23:37 +0530 Subject: [PATCH 40/83] Add inv and rand functions for composed bijectors --- src/interface.jl | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 68ddad5d..5113f240 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -46,7 +46,7 @@ end Broadcast.broadcastable(b::Bijector) = Ref(b) "Computes the log(abs(det(J(x)))) where J is the jacobian of the transform." -logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = +logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = error("`logabsdetjac(b::$T1, y::$T2)` is not implemented.") "Transforms the input using the bijector." @@ -54,7 +54,7 @@ transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = error("`transform(b::$T1, y::$T2)` is not implemented.") "Computes both `transform` and `logabsdetjac` in one forward pass." -forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = +forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = error("`forward(b::$T1, y::$T2)` is not implemented.") @@ -112,7 +112,7 @@ end function compose(ts...) res = [] - + for b ∈ ts if b isa Composed # "lift" the transformations @@ -151,6 +151,14 @@ function transform(cb::Composed{<: Bijector}, x) return res end +function inv(cb::Composed{<: Bijector}, y) + res = y + for b ∈ reverse(cb.ts) + res = inv(b, res) + end + return res +end + (cb::Composed{<: Bijector})(x) = transform(cb, x) function forward(cb::Composed{<:Bijector}, x) @@ -162,6 +170,12 @@ function forward(cb::Composed{<:Bijector}, x) return res end +function rand(flow::Composed, dims::Integer, shape::Integer=1) + dims = [dims] + append!(dims, shape) + print(dims) + return transform(flow, randn(dims...)) +end ############################## # Example bijector: Identity # ############################## From dc73717daeb5879c06c5529ebafcb408e9c9d498 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 18 Aug 2019 10:03:38 +0530 Subject: [PATCH 41/83] Add direct calls for norm flows --- src/interface.jl | 1 - src/norm_flows.jl | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 5113f240..da832f8c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -173,7 +173,6 @@ end function rand(flow::Composed, dims::Integer, shape::Integer=1) dims = [dims] append!(dims, shape) - print(dims) return transform(flow, randn(dims...)) end ############################## diff --git a/src/norm_flows.jl b/src/norm_flows.jl index a105a71f..95775e45 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -10,14 +10,14 @@ using Roots # for inverse # D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # ################################################################################ -(b::Bijector)(x) = transform(b, x) - mutable struct PlanarLayer{T1,T2} <: Bijector w::T1 u::T1 b::T2 end +(b::PlanarLayer)(x) = transform(b, x) + function get_u_hat(u, w) # To preserve invertibility return ( @@ -87,6 +87,8 @@ function RadialLayer(dims::Int, container=Array) return RadialLayer(α_, β, z_0) end +(b::RadialLayer)(x) = transform(b, x) + h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() From d754b74212a63031a1bc21d86c752c4c8d8160cf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Aug 2019 20:18:23 +0200 Subject: [PATCH 42/83] added docstrings and moved away from transform to callables --- src/interface.jl | 205 +++++++++++++++++++++++++++++++++------------- test/interface.jl | 22 +++-- 2 files changed, 160 insertions(+), 67 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 938dcb20..f7e2fb32 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -15,10 +15,10 @@ abstract type ADBackend end struct ForwardDiffAD <: ADBackend end struct TrackerAD <: ADBackend end -const ADBACKEND = Ref(:forward) +const ADBACKEND = Ref(:forward_diff) function setadbackend(backend_sym) @assert backend_sym == :forward_diff || backend_sym == :reverse_diff - backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40) + backend_sym == :forward_diff ADBACKEND[] = backend_sym end @@ -36,61 +36,85 @@ end # Bijector interface # ###################### +"Abstract type for a `Bijector`." abstract type Bijector end + +Broadcast.broadcastable(b::Bijector) = Ref(b) + +"Abstract type for a `Bijector` making use of auto-differentation (AD)." abstract type ADBijector{AD} <: Bijector end +""" + inv(b::Bijector) + Inversed(b::Bijector) + +A `Bijector` representing the inverse transform of `b`. +""" struct Inversed{B <: Bijector} <: Bijector orig::B end -Broadcast.broadcastable(b::Bijector) = Ref(b) - -"Computes the log(abs(det(J(x)))) where J is the jacobian of the transform." -logabsdetjac(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = - error("`logabsdetjac(b::$T1, y::$T2)` is not implemented.") +inv(b::Bijector) = Inversed(b) +inv(ib::Inversed{<:Bijector}) = ib.orig -"Transforms the input using the bijector." -transform(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = - error("`transform(b::$T1, y::$T2)` is not implemented.") +""" + logabsdetjac(b::Bijector, x) + logabsdetjac(ib::Inversed{<: Bijector}, y) -"Computes both `transform` and `logabsdetjac` in one forward pass." -forward(b::T1, y::T2) where {T<:Bijector,T1<:Inversed{T},T2} = - error("`forward(b::$T1, y::$T2)` is not implemented.") +Computes the log(abs(det(J(x)))) where J is the jacobian of the transform. +Similarily for the inverse-transform. +Default implementation for `Inversed{<: Bijector}` is implemented as +`- logabsdetjac` of original `Bijector`. +""" +logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) -transform(b::Bijector) = x -> transform(b, x) -(ib::Inversed{<: Bijector})(y) = transform(ib, y) +""" + forward(b::Bijector, x) + forward(ib::Inversed{<: Bijector}, y) -# default `forward` implementations; should in general implement efficient way -# of computing both `transform` and `logabsdetjac` together. -forward(b::Bijector, x) = (rv=transform(b, x), logabsdetjac=logabsdetjac(b, x)) -forward(ib::Inversed{<: Bijector}, y) = (rv=transform(ib, y), logabsdetjac=logabsdetjac(ib, y)) +Computes both `transform` and `logabsdetjac` in one forward pass, and +returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. -# defaults implementation for inverses -logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, transform(ib, y)) +This defaults to the call above, but often one can re-use computation +in the computation of the forward pass and the computation of the +`logabsdetjac`. `forward` allows the user to take advantange of such +efficiencies, if they exist. +""" +forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) +forward(ib::Inversed{<: Bijector}, y) = ( + rv=ib(y), + logabsdetjac=logabsdetjac(ib, y) +) -inv(b::Bijector) = Inversed(b) -inv(ib::Inversed{<:Bijector}) = ib.orig # AD implementations -function jacobian(b::ADBijector{<: ForwardDiffAD}, y::Real) - return ForwardDiff.derivative(z -> transform(b, z), y) +function jacobian(b::ADBijector{<: ForwardDiffAD}, x::Real) + return ForwardDiff.derivative(b, x) end function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::Real) - return ForwardDiff.derivative(z -> transform(b, z), y) + return ForwardDiff.derivative(b, y) +end +function jacobian(b::ADBijector{<: ForwardDiffAD}, x::AbstractVector{<: Real}) + return ForwardDiff.jacobian(b, x) end function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::AbstractVector{<: Real}) - return ForwardDiff.jacobian(z -> transform(b, z), y) + return ForwardDiff.jacobian(b, y) end -function jacobian(b::ADBijector{<: TrackerAD}, y::Real) - return Tracker.gradient(z -> transform(b, z), y)[1] +function jacobian(b::ADBijector{<: TrackerAD}, x::Real) + return Tracker.gradient(b, x)[1] end function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::Real) - return Tracker.gradient(z -> transform(b, z), y)[1] + return Tracker.gradient(b, y)[1] +end +function jacobian(b::ADBijector{<: TrackerAD}, x::AbstractVector{<: Real}) + # we extract `data` so that we don't returne a `Tracked` type + return Tracker.data(Tracker.jacobian(b, x)) end function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: Real}) - return Tracker.jacobian(z -> transform(b, z), y) + # we extract `data` so that we don't returne a `Tracked` type + return Tracker.data(Tracker.jacobian(b, y)) end # TODO: allow batch-computation, especially for univariate case? @@ -98,18 +122,44 @@ end logabsdetjac(b::ADBijector, x::Real) = log(abs(jacobian(b, x))) function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) - return issuccess(fact) ? log(abs(det(fact))) : -Inf # TODO: or smallest possible float? + return issuccess(fact) ? log(abs(det(fact))) : -Inf # TODO: do this or not? end +""" + logabsdetjacinv(b::Bijector, x) + +Just an alias for `logabsdetjac(inv(b), b(x))`. +""" +logabsdetjacinv(b::Bijector, x) = logabsdetjac(inv(b), b(x)) + ############### # Composition # ############### +""" + ∘(b1::Bijector, b2::Bijector) + compose(ts::Bijector...) + +A `Bijector` representing composition of bijectors. + +# Examples +It's important to note that `∘` does what is expected mathematically, which means that the +bijectors are applied to the input right-to-left, e.g. first applying `b2` and then `b1`: +``` +(b1 ∘ b2)(x) == b1(b2(x)) # => true +``` +But in the `Composed` struct itself, we store the bijectors left-to-right, so that +``` +cb1 = b1 ∘ b2 # => Composed.ts == [b2, b1] +cb2 = compose(b2, b1) +cb1(x) == cb2(x) == b1(b2(x)) # => true +``` +""" struct Composed{A} <: Bijector ts::A end -compose(ts...) = Composed(ts) +compose(ts::Bijector...) = Composed(ts) # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that @@ -119,18 +169,33 @@ compose(ts...) = Composed(ts) inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) -# TODO: can we implement this recursively, and with aggressive inlining, make this type-stable? -function transform(cb::Composed, x) +# # TODO: should arrays also be using recursive implementation instead? +function (cb::Composed{<: AbstractArray{<: Bijector}})(x) res = x for b ∈ cb.ts - res = transform(b, res) + res = b(res) end return res end -(cb::Composed)(x) = transform(cb, x) +# recursive implementation like this allows type-inference +_transform(x, b1::Bijector, b2::Bijector) = b2(b1(x)) +_transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) +(cb::Composed{<: Tuple})(x) = _transform(x, cb.ts...) + +function _logabsdetjac(x, b1::Bijector, b2::Bijector) + logabsdetjac(b2, b1(x)) + logabsdetjac(b1, x) + res = forward(b1, x) + return logabsdetjac(b2, res.rv) + res.logabsdetjac +end +function _logabsdetjac(x, b1::Bijector, bs::Bijector...) + res = forward(b1, x) + return _logabsdetjac(res.rv, bs...) + res.logabsdetjac +end +logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) +# TODO: implement `forward` recursively function forward(cb::Composed, x) res = (rv=x, logabsdetjac=0) for t in cb.ts @@ -145,9 +210,8 @@ end ############################## struct Identity <: Bijector end -transform(::Identity, x) = x -transform(::Inversed{Identity}, y) = y -(b::Identity)(x) = transform(b, x) +(::Identity)(x) = x +(::Inversed{Identity})(y) = y forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) @@ -166,17 +230,25 @@ struct Logit{T<:Real} <: Bijector b::T end -transform(b::Logit, x::Real) = logit((x - b.a) / (b.b - b.a)) -transform(ib::Inversed{Logit{T}}, y::Real) where T <: Real = (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -(b::Logit)(x) = transform(b, x) +(b::Logit)(x) = @. logit((x - b.a) / (b.b - b.a)) +(ib::Inversed{<: Logit{<: Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -logabsdetjac(b::Logit{<:Real}, x::Real) = log((x - b.a) * (b.b - x) / (b.b - b.a)) -forward(b::Logit, x::Real) = (rv=transform(b, x), logabsdetjac=-logabsdetjac(b, x)) +logabsdetjac(b::Logit{<:Real}, x) = log((x - b.a) * (b.b - x) / (b.b - b.a)) +forward(b::Logit, x) = (rv=b(x), logabsdetjac=-logabsdetjac(b, x)) ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### +""" + DistributionBijector(d::Distribution) + DistributionBijector{<: ADBackend, D}(d::Distribution) + +This is the default `Bijector` for a distribution. + +It uses `link` and `invlink` to compute the transformations, and `AD` to compute +the `jacobian` and `logabsdetjac`. +""" struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution dist::D end @@ -185,9 +257,8 @@ function DistributionBijector(dist::D) where D <: Distribution end # Simply uses `link` and `invlink` as transforms with AD to get jacobian -transform(b::DistributionBijector, x) = link(b.dist, x) -transform(ib::Inversed{<: DistributionBijector}, y) = invlink(ib.orig.dist, y) -(b::DistributionBijector)(x) = transform(b, x) +(b::DistributionBijector)(x) = link(b.dist, x) +(ib::Inversed{<: DistributionBijector})(y) = invlink(ib.orig.dist, y) "Returns the constrained-to-unconstrained bijector for distribution `d`." bijector(d::Distribution) = DistributionBijector(d) @@ -204,13 +275,27 @@ struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} w end -# Can implement these on a case-by-case basis +""" + transformed(d::Distribution) + transformed(d::Distribution, b::Bijector) + +Couples the distribution `d` with the bijector `b` by returning a `UnivariateTransformed` +or `MultivariateTransformed`, depending on type `D`. + +If not bijector is provided, i.e. `transformed(d)` is called, +then `transformed(d, bijector(d))` is returned. +""" transformed(d::UnivariateDistribution, b::Bijector) = UnivariateTransformed(d, b) transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed(d, b) transformed(d) = transformed(d, bijector(d)) -# can specialize further by +""" + bijector(d::Distribution) + +Returns the constrained-to-unconstrained bijector for distribution `d`. +""" bijector(d::Normal) = IdentityBijector +bijector(d::MvNormal) = IdentityBijector bijector(d::Beta{T}) where T <: Real = Logit(zero(T), one(T)) ############################## @@ -223,27 +308,29 @@ Base.length(td::MultivariateTransformed) = length(td.dist) # logp function logpdf(td::UnivariateTransformed, y::Real) # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) - logpdf_with_trans(td.dist, transform(inv(td.transform), y), true) + logpdf_with_trans(td.dist, inv(td.transform)(y), true) end function _logpdf(td::MultivariateTransformed, y::AbstractVector{<: Real}) # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) - logpdf_with_trans(td.dist, transform(inv(td.transform), y), true) + logpdf_with_trans(td.dist, inv(td.transform)(y), true) end function logpdf_with_jac(td::UnivariateTransformed, y::Real) z = logabsdetjac(inv(td.transform), y) - return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) + return (logpdf(td.dist, inv(td.transform)(y)) .+ z, z) end function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{<:Real}) z = logabsdetjac(inv(td.transform), y) - return (logpdf(td.dist, transform(inv(td.transform), y)) .+ z, z) + return (logpdf(td.dist, inv(td.transform)(y)) .+ z, z) end # rand -rand(rng::AbstractRNG, td::UnivariateTransformed) = transform(td.transform, rand(td.dist)) -function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{<: Real}) - rand!(rng, td.dist, x) - y = transform(td.transform, x) - copyto!(x, y) +rand(td::UnivariateTransformed) = td.transform(rand(td.dist)) +rand(rng::AbstractRNG, td::UnivariateTransformed) = td.transform(rand(rng, td.dist)) + +rand(td::MultivariateTransformed) = td.transform(rand(td.dist)) +function rand(td::MultivariateTransformed, num_samples::Int) + res = hcat([td.transform(rand(td.dist)) for i = 1:num_samples]...) + return res end diff --git a/test/interface.jl b/test/interface.jl index b8d5dfa6..9ca72181 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -44,12 +44,12 @@ Random.seed!(123) # single sample y = rand(td) - x = transform(inv(td.transform), y) + x = inv(td.transform)(y) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # multi-sample y = rand(td, 10) - x = transform.(inv(td.transform), y) + x = inv(td.transform).(y) @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) end @@ -60,7 +60,7 @@ Random.seed!(123) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf - y = transform(b, x) + y = b(x) b⁻¹ = inv(b) @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 @test logabsdetjac(b⁻¹, y) ≠ Inf @@ -73,7 +73,7 @@ Random.seed!(123) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf - y = transform(b, x) + y = b(x) b⁻¹ = inv(b) @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 @test logabsdetjac(b⁻¹, y) ≠ Inf @@ -85,7 +85,7 @@ Random.seed!(123) td = transformed(d) x = rand(d) - y = transform(td.transform, x) + y = td.transform(x) b = Bijectors.compose(td.transform, Bijectors.Identity()) ib = inv(b) @@ -95,10 +95,10 @@ Random.seed!(123) # inverse works fine for composition cb = b ∘ ib - @test transform(cb, x) ≈ x + @test cb(x) ≈ x cb2 = cb ∘ cb - @test transform(cb, x) ≈ x + @test cb(x) ≈ x # order of composed evaluation b1 = DistributionBijector(d) @@ -106,6 +106,12 @@ Random.seed!(123) cb = b1 ∘ b2 @test cb(x) ≈ b1(b2(x)) + + # contrived example + b = bijector(d) + cb = inv(b) ∘ b + cb = cb ∘ cb + @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x end @testset "Example: ADVI" begin @@ -115,6 +121,6 @@ Random.seed!(123) ib = inv(b) # ℝ → [0, 1] td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] x = rand(td) # ∈ [0, 1] - @test 0 ≤ x ≤ 1 # => true + @test 0 ≤ x ≤ 1 end end From 4a7b2600b4b1a00777eb33b62c600b6a341ac715 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 19 Aug 2019 21:15:15 +0100 Subject: [PATCH 43/83] Fix the two remaining issues in https://github.com/torfjelde/Bijectors.jl/pull/1 (#1) * add dep and update deps * update test and fix the randomness * replace transpose(x) with x' * use tab to align comments * replace transpose(x) with x' * unify transform code for RadialLayer * unify transform code for PlanarLayer * improve code style and add comments --- Manifest.toml | 39 +++++++++++++------- Project.toml | 1 + src/norm_flows.jl | 91 ++++++++++++++++++++-------------------------- test/norm_flows.jl | 26 +++++++------ 4 files changed, 81 insertions(+), 76 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 2dce8390..387a9ea6 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,3 +1,5 @@ +# This file is machine-generated - editing it directly is not advised + [[Adapt]] deps = ["LinearAlgebra"] git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" @@ -27,9 +29,9 @@ version = "0.5.6" [[CSTParser]] deps = ["Tokenize"] -git-tree-sha1 = "376a39f1862000442011390f1edf5e7f4dcc7142" +git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "0.6.0" +version = "0.6.2" [[CommonSubexpressions]] deps = ["Test"] @@ -49,6 +51,11 @@ git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.0.0" +[[DataAPI]] +git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.0.1" + [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a" @@ -76,7 +83,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "0.0.10" [[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] +deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[Distributions]] @@ -92,7 +99,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.3" [[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] +deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[LibGit2]] @@ -153,9 +160,9 @@ version = "1.1.0" [[PDMats]] deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "8b68513175b2dc4023a564cb0e917ce90e74fd69" +git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.9.7" +version = "0.9.9" [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -197,6 +204,12 @@ git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" version = "0.5.0" +[[Roots]] +deps = ["Printf"] +git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3" +uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +version = "0.8.3" + [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -237,10 +250,10 @@ deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] -deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "2b6ca97be7ddfad5d9f16a13fe277d29f3d11c23" +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.31.0" +version = "0.32.0" [[StatsFuns]] deps = ["Rmath", "SpecialFunctions", "Test"] @@ -249,7 +262,7 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "0.8.0" [[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +deps = ["Libdl", "LinearAlgebra", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[Test]] @@ -263,9 +276,9 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.0" [[Tokenize]] -git-tree-sha1 = "c8a8b00ae44a94950814ff77850470711a360225" +git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.5" +version = "0.5.6" [[Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] @@ -280,7 +293,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random"] +deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] diff --git a/Project.toml b/Project.toml index 5a2ec43d..4bc9f7b8 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 95775e45..bf5a2cd3 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -21,9 +21,9 @@ end function get_u_hat(u, w) # To preserve invertibility return ( - u + (planar_flow_m(transpose(w) * u) - transpose(w) * u)[1] + u + (planar_flow_m(w' * u) - w' * u)[1] * w / (norm(w[:,1],2) ^ 2) - ) # from A.1 + ) # from A.1 end function PlanarLayer(dims::Int, container=Array) @@ -33,43 +33,36 @@ function PlanarLayer(dims::Int, container=Array) return PlanarLayer(w, u, b) end -planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1 -dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow -ψ(z, w, b) = dtanh(transpose(w) * z .+ b) .* w # for planar flow from eq(11) +planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1 +dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow +ψ(z, w, b) = dtanh(w' * z .+ b) .* w # for planar flow from eq(11) -function transform(flow::PlanarLayer, z) +# An internal version of transform that returns intermediate variables +function _transform(flow::PlanarLayer, z) u_hat = get_u_hat(flow.u, flow.w) - return z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) # from eq(10) + transformed = z + u_hat * tanh.(flow.w' * z .+ flow.b) # from eq(10) + return (transformed=transformed, u_hat=u_hat) end +transform(flow::PlanarLayer, z) = _transform(flow, z).transformed + function forward(flow::T, z) where {T<:PlanarLayer} - u_hat = get_u_hat(flow.u, flow.w) + transformed, u_hat = _transform(flow, z) # Compute log_det_jacobian psi = ψ(z, flow.w, flow.b) - log_det_jacobian = log.(abs.(1.0 .+ transpose(psi) * u_hat)) # from eq(12) - transformed = z + u_hat * tanh.(transpose(flow.w) * z .+ flow.b) - return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10) + log_det_jacobian = log.(abs.(1.0 .+ psi' * u_hat)) # from eq(12) + return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10) end function inv(flow::PlanarLayer, y) u_hat = get_u_hat(flow.u, flow.w) - # Implemented with reference from A.1 - function f(y) - return loss(alpha) = ( - (transpose(flow.w) * y)[1] - alpha - - (transpose(flow.w) * u_hat)[1] - * tanh(alpha+flow.b[1]) - ) - end + # Define the objective functional; implemented with reference from A.1 + f(y) = alpha -> (flow.w' * y)[1] - alpha - (flow.w' * u_hat)[1] * tanh(alpha+flow.b[1]) + # Run solver alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] - alphas = transpose(alphas_) + alphas = alphas_' z_para = (flow.w ./ norm(flow.w,2)) * alphas - z_per = ( - y - z_para - u_hat * tanh.( - transpose(flow.w) * z_para - .+ flow.b - ) - ) + z_per = y - z_para - u_hat * tanh.(flow.w' * z_para .+ flow.b) return z_para + z_per end @@ -89,44 +82,40 @@ end (b::RadialLayer)(x) = transform(b, x) -h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) -dh(α, r) = - h(α, r) .^ 2 # for radial flow, derivative of h() +h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) +dh(α, r) = - h(α, r) .^ 2 # for radial flow; derivative of h() -function transform(flow::RadialLayer, z) - α = softplus(flow.α_[1]) # from A.2 - β_hat = -α + softplus(flow.β[1]) # from A.2 - r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) - return z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) +# An internal version of transform that returns intermediate variables +function _transform(flow::RadialLayer, z) + α = softplus(flow.α_[1]) # from A.2 + β_hat = -α + softplus(flow.β[1]) # from A.2 + r = norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)' + transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) + return (transformed=transformed, α=α, β_hat=β_hat, r=r) end +transform(flow::RadialLayer, z) = _transform(flow, z).transformed + function forward(flow::T, z) where {T<:RadialLayer} - α = softplus(flow.α_[1]) # from A.2 - β_hat = -α + softplus(flow.β[1]) # from A.2 - r = transpose(norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)) - transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) + transformed, α, β_hat, r = _transform(flow, z) # Compute log_det_jacobian d = size(flow.z_0, 1) h_ = h(α, r) log_det_jacobian = @. ( - (d-1) * log(1.0 + β_hat * h_) + (d - 1) * log(1.0 + β_hat * h_) + log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) - ) # from eq(14) + ) # from eq(14) return (rv=transformed, logabsdetjac=log_det_jacobian) end function inv(flow::RadialLayer, y) - α = softplus(flow.α_[1]) # from A.2 - β_hat = - α + softplus(flow.β[1]) # from A.2 - function f(y) - # From eq(26) - return loss(r) = ( - norm(y - flow.z_0, 2) - - r * (1 + β_hat / (α + r)) - ) - end - rs_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] # A.2 - rs = transpose(rs_) - z_hat = (y .- flow.z_0) ./ (rs .* (1 .+ β_hat ./ (α .+ rs)) ) # from eq(25) + α = softplus(flow.α_[1]) # from A.2 + β_hat = - α + softplus(flow.β[1]) # from A.2 + # Define the objective functional + f(y) = r -> norm(y - flow.z_0, 2) - r * (1 + β_hat / (α + r)) # from eq(26) + # Run solver + rs = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)]' # from A.2 + z_hat = (y .- flow.z_0) ./ (rs .* (1 .+ β_hat ./ (α .+ rs))) # from eq(25) z = flow.z_0 .+ rs .* z_hat # from A.2 return z end diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 14b1c784..8c4d7e96 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -1,16 +1,17 @@ using Test using Bijectors, ForwardDiff, LinearAlgebra +using Random: seed! -@testset "planar flows" begin - for i in 1:10 - flow = PlanarLayer(10) - z = randn(10, 100) +seed!(1) + +@testset "PlanarLayer" begin + for i in 1:4 + flow = PlanarLayer(2) + z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) - our_method = sum(forward(flow, z).logabsdetjacob) + our_method = sum(forward(flow, z).logabsdetjac) @test our_method ≈ forward_diff - - # Inverse not accurate enough to pass with `≈` operator. - @test_broken inv(flow, transform(flow, z)) ≈ z + @test inv(flow, transform(flow, z)) ≈ z rtol=0.2 end w = ones(10, 1) @@ -21,15 +22,16 @@ using Bijectors, ForwardDiff, LinearAlgebra @test inv(flow, transform(flow, z)) ≈ z end -@testset "radial flows" begin - for i in 1:10 +@testset "RadialLayer" begin + for i in 1:4 flow = RadialLayer(2) - z = randn(2, 100) + z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) - our_method = sum(forward(flow, z).logabsdetjacob) + our_method = sum(forward(flow, z).logabsdetjac) @test our_method ≈ forward_diff @test inv(flow, transform(flow, z)) ≈ z end + α_ = ones(1) β = ones(1) z_0 = zeros(10, 1) From 1a712b26dcc8d391ade9b556c4184bbff7df9dfb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 14:45:43 +0200 Subject: [PATCH 44/83] adapted flows to new interface --- src/interface.jl | 5 ----- src/norm_flows.jl | 15 +++++++-------- test/norm_flows.jl | 16 ++++++++++------ 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 176b8b3f..f7e2fb32 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -205,11 +205,6 @@ function forward(cb::Composed, x) return res end -function rand(flow::Composed, dims::Integer, shape::Integer=1) - dims = [dims] - append!(dims, shape) - return transform(flow, randn(dims...)) -end ############################## # Example bijector: Identity # ############################## diff --git a/src/norm_flows.jl b/src/norm_flows.jl index bf5a2cd3..29604476 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -16,8 +16,6 @@ mutable struct PlanarLayer{T1,T2} <: Bijector b::T2 end -(b::PlanarLayer)(x) = transform(b, x) - function get_u_hat(u, w) # To preserve invertibility return ( @@ -44,7 +42,7 @@ function _transform(flow::PlanarLayer, z) return (transformed=transformed, u_hat=u_hat) end -transform(flow::PlanarLayer, z) = _transform(flow, z).transformed +(b::PlanarLayer)(z) = _transform(b, z).transformed function forward(flow::T, z) where {T<:PlanarLayer} transformed, u_hat = _transform(flow, z) @@ -54,7 +52,8 @@ function forward(flow::T, z) where {T<:PlanarLayer} return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10) end -function inv(flow::PlanarLayer, y) +function (ib::Inversed{<: PlanarLayer})(y) + flow = ib.orig u_hat = get_u_hat(flow.u, flow.w) # Define the objective functional; implemented with reference from A.1 f(y) = alpha -> (flow.w' * y)[1] - alpha - (flow.w' * u_hat)[1] * tanh(alpha+flow.b[1]) @@ -80,8 +79,6 @@ function RadialLayer(dims::Int, container=Array) return RadialLayer(α_, β, z_0) end -(b::RadialLayer)(x) = transform(b, x) - h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) dh(α, r) = - h(α, r) .^ 2 # for radial flow; derivative of h() @@ -94,7 +91,7 @@ function _transform(flow::RadialLayer, z) return (transformed=transformed, α=α, β_hat=β_hat, r=r) end -transform(flow::RadialLayer, z) = _transform(flow, z).transformed +(b::RadialLayer)(z) = _transform(b, z).transformed function forward(flow::T, z) where {T<:RadialLayer} transformed, α, β_hat, r = _transform(flow, z) @@ -108,7 +105,9 @@ function forward(flow::T, z) where {T<:RadialLayer} return (rv=transformed, logabsdetjac=log_det_jacobian) end -function inv(flow::RadialLayer, y) +# function inv(flow::RadialLayer, y) +function (ib::Inversed{<: RadialLayer})(y) + flow = ib.orig α = softplus(flow.α_[1]) # from A.2 β_hat = - α + softplus(flow.β[1]) # from A.2 # Define the objective functional diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 8c4d7e96..0bb8ad18 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -8,10 +8,12 @@ seed!(1) for i in 1:4 flow = PlanarLayer(2) z = randn(2, 20) - forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) our_method = sum(forward(flow, z).logabsdetjac) + @test our_method ≈ forward_diff - @test inv(flow, transform(flow, z)) ≈ z rtol=0.2 + @test inv(flow)(flow(z)) ≈ z rtol=0.2 + @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2 end w = ones(10, 1) @@ -19,17 +21,19 @@ seed!(1) b = ones(1) flow = PlanarLayer(w, u, b) z = ones(10, 100) - @test inv(flow, transform(flow, z)) ≈ z + @test inv(flow)(flow(z)) ≈ z end @testset "RadialLayer" begin for i in 1:4 flow = RadialLayer(2) z = randn(2, 20) - forward_diff = log(abs(det(ForwardDiff.jacobian(t -> transform(flow, t), z)))) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) our_method = sum(forward(flow, z).logabsdetjac) + @test our_method ≈ forward_diff - @test inv(flow, transform(flow, z)) ≈ z + @test inv(flow)(flow(z)) ≈ z rtol=0.2 + @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2 end α_ = ones(1) @@ -37,5 +41,5 @@ end z_0 = zeros(10, 1) z = ones(10, 100) flow = RadialLayer(α_, β, z_0) - @test inv(flow, transform(flow, z)) ≈ z + @test inv(flow)(flow(z)) ≈ z end From 1c4b77d17810ea6feffd55b01e126806eeabe5e8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 15:01:32 +0200 Subject: [PATCH 45/83] removed random Revise import --- test/interface.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index 9ca72181..556286d4 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,4 +1,3 @@ -using Revise using Test using Bijectors using Random From f321c82cc42053aceb07d7f02b2ad2b1b32942ac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 16:37:26 +0200 Subject: [PATCH 46/83] introduced TransformedDistribution and MatrixTransformed as subtype --- src/interface.jl | 42 ++++++++---- test/interface.jl | 159 +++++++++++++++++++++++++++------------------- 2 files changed, 121 insertions(+), 80 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index f7e2fb32..206c1ce3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -264,29 +264,31 @@ end bijector(d::Distribution) = DistributionBijector(d) # Transformed distributions -struct UnivariateTransformed{D, B} <: Distribution{Univariate, Continuous} where {D <: UnivariateDistribution, B <: Bijector} +struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D <: Distribution{V, Continuous}, B <: Bijector} dist::D transform::B end - -struct MultivariateTransformed{D, B} <: Distribution{Multivariate, Continuous} where {D <: MultivariateDistribution, B <: Bijector} - dist::D - transform::B +function TransformedDistribution(d::D, b::B) where {V <: VariateForm, B <: Bijector, D <: Distribution{V, Continuous}} + return TransformedDistribution{D, B, V}(d, b) end +const UnivariateTransformed = TransformedDistribution{<: Distribution, <: Bijector, Univariate} +const MultivariateTransformed = TransformedDistribution{<: Distribution, <: Bijector, Multivariate} +const MatrixTransformed = TransformedDistribution{<: Distribution, <: Bijector, Matrixvariate} +const Transformed = Union{UnivariateTransformed, MultivariateTransformed, MatrixTransformed} + + """ transformed(d::Distribution) transformed(d::Distribution, b::Bijector) -Couples the distribution `d` with the bijector `b` by returning a `UnivariateTransformed` -or `MultivariateTransformed`, depending on type `D`. +Couples distribution `d` with the bijector `b` by returning a `TransformedDistribution`. -If not bijector is provided, i.e. `transformed(d)` is called, -then `transformed(d, bijector(d))` is returned. +If no bijector is provided, i.e. `transformed(d)` is called, then +`transformed(d, bijector(d))` is returned. """ -transformed(d::UnivariateDistribution, b::Bijector) = UnivariateTransformed(d, b) -transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed(d, b) +transformed(d::Distribution, b::Bijector) = TransformedDistribution(d, b) transformed(d) = transformed(d, bijector(d)) """ @@ -303,7 +305,8 @@ bijector(d::Beta{T}) where T <: Real = Logit(zero(T), one(T)) ############################## # size -Base.length(td::MultivariateTransformed) = length(td.dist) +Base.length(td::Transformed) = length(td.dist) +Base.size(td::Transformed) = size(td.dist) # logp function logpdf(td::UnivariateTransformed, y::Real) @@ -329,8 +332,21 @@ end rand(td::UnivariateTransformed) = td.transform(rand(td.dist)) rand(rng::AbstractRNG, td::UnivariateTransformed) = td.transform(rand(rng, td.dist)) +# These ovarloadings are useful for differentiating sampling wrt. params of `td.dist` +# or params of `Bijector`, as they are not inplace like the default `rand` rand(td::MultivariateTransformed) = td.transform(rand(td.dist)) -function rand(td::MultivariateTransformed, num_samples::Int) +rand(rng::AbstractRNG, td::MultivariateTransformed) = td.transform(rand(rng, td.dist)) +function rand(rng::AbstractRNG, td::MultivariateTransformed, num_samples::Int) res = hcat([td.transform(rand(td.dist)) for i = 1:num_samples]...) return res end + +function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{<: Real}) + rand!(rng, td.dist, x) + x .= td.transform(x) +end + +function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<: Real}) + rand!(rng, td.dist, x) + x .= td.transform(x) +end diff --git a/test/interface.jl b/test/interface.jl index 556286d4..e5a6f7aa 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -7,75 +7,98 @@ Random.seed!(123) # Scalar tests @testset "Interface" begin - # Tests with scalar-valued distributions. - uni_dists = [ - Arcsine(2, 4), - Beta(2,2), - BetaPrime(), - Biweight(), - Cauchy(), - Chi(3), - Chisq(2), - Cosine(), - Epanechnikov(), - Erlang(), - Exponential(), - FDist(1, 1), - Frechet(), - Gamma(), - InverseGamma(), - InverseGaussian(), - # Kolmogorov(), - Laplace(), - Levy(), - Logistic(), - LogNormal(1.0, 2.5), - Normal(0.1, 2.5), - Pareto(), - Rayleigh(1.0), - TDist(2), - TruncatedNormal(0, 1, -Inf, 2), - ] - - for dist in uni_dists - @testset "$dist: dist" begin - td = transformed(dist) - - # single sample - y = rand(td) - x = inv(td.transform)(y) - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - - # multi-sample - y = rand(td, 10) - x = inv(td.transform).(y) - @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) - end - - @testset "$dist: ForwardDiff AD" begin - x = rand(dist) - b = DistributionBijector{Bijectors.ADBackend(:forward_diff), typeof(dist)}(dist) - - @test abs(det(Bijectors.jacobian(b, x))) > 0 - @test logabsdetjac(b, x) ≠ Inf - - y = b(x) - b⁻¹ = inv(b) - @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 - @test logabsdetjac(b⁻¹, y) ≠ Inf + @testset "Univariate" begin + # Tests with scalar-valued distributions. + uni_dists = [ + Arcsine(2, 4), + Beta(2,2), + BetaPrime(), + Biweight(), + Cauchy(), + Chi(3), + Chisq(2), + Cosine(), + Epanechnikov(), + Erlang(), + Exponential(), + FDist(1, 1), + Frechet(), + Gamma(), + InverseGamma(), + InverseGaussian(), + # Kolmogorov(), + Laplace(), + Levy(), + Logistic(), + LogNormal(1.0, 2.5), + Normal(0.1, 2.5), + Pareto(), + Rayleigh(1.0), + TDist(2), + TruncatedNormal(0, 1, -Inf, 2), + ] + + for dist in uni_dists + @testset "$dist: dist" begin + td = transformed(dist) + + # single sample + y = rand(td) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform).(y) + @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) + end + + @testset "$dist: ForwardDiff AD" begin + x = rand(dist) + b = DistributionBijector{Bijectors.ADBackend(:forward_diff), typeof(dist)}(dist) + + @test abs(det(Bijectors.jacobian(b, x))) > 0 + @test logabsdetjac(b, x) ≠ Inf + + y = b(x) + b⁻¹ = inv(b) + @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 + @test logabsdetjac(b⁻¹, y) ≠ Inf + end + + @testset "$dist: Tracker AD" begin + x = rand(dist) + b = DistributionBijector{Bijectors.ADBackend(:reverse_diff), typeof(dist)}(dist) + + @test abs(det(Bijectors.jacobian(b, x))) > 0 + @test logabsdetjac(b, x) ≠ Inf + + y = b(x) + b⁻¹ = inv(b) + @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 + @test logabsdetjac(b⁻¹, y) ≠ Inf + end end + end - @testset "$dist: Tracker AD" begin - x = rand(dist) - b = DistributionBijector{Bijectors.ADBackend(:reverse_diff), typeof(dist)}(dist) - - @test abs(det(Bijectors.jacobian(b, x))) > 0 - @test logabsdetjac(b, x) ≠ Inf - - y = b(x) - b⁻¹ = inv(b) - @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 - @test logabsdetjac(b⁻¹, y) ≠ Inf + @testset "Matrix variate" begin + v = 7.0 + S = Matrix(1.0I, 2, 2) + S[1, 2] = S[2, 1] = 0.5 + + matrix_dists = [ + Wishart(v,S), + InverseWishart(v,S) + ] + + for dist in matrix_dists + @testset "$dist: dist" begin + td = transformed(dist) + y = rand(rng, td) + + x = inv(td.transform)(y) + @test td.transform(x) ≈ y + end end end @@ -123,3 +146,5 @@ Random.seed!(123) @test 0 ≤ x ≤ 1 end end + + From 71fcca516ffae211a946ae62cb241514c189148c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 16:39:30 +0200 Subject: [PATCH 47/83] fixed typo --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index e5a6f7aa..efe88806 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -94,7 +94,7 @@ Random.seed!(123) for dist in matrix_dists @testset "$dist: dist" begin td = transformed(dist) - y = rand(rng, td) + y = rand(td) x = inv(td.transform)(y) @test td.transform(x) ≈ y From a4ffc54072285ee80191ba08598e22a6f9c60bd5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 19:56:17 +0200 Subject: [PATCH 48/83] added proper implementation of bijector(d) for UnitDistribution --- src/Bijectors.jl | 2 ++ src/interface.jl | 59 +++++++++++++++++++++++-------- test/interface.jl | 89 ++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 135 insertions(+), 15 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 9824db4e..97c12b72 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -18,6 +18,7 @@ export TransformDistribution, transform, forward, logabsdetjac, + logabsdetjacinv, Bijector, ADBijector, Inversed, @@ -29,6 +30,7 @@ export TransformDistribution, transformed, UnivariateTransformed, MultivariateTransformed, + logpdf_with_jac, PlanarLayer, RadialLayer diff --git a/src/interface.jl b/src/interface.jl index 206c1ce3..88d033c5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -126,11 +126,11 @@ function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) end """ - logabsdetjacinv(b::Bijector, x) + logabsdetjacinv(b::Bijector, y) -Just an alias for `logabsdetjac(inv(b), b(x))`. +Just an alias for `logabsdetjac(inv(b), y)`. """ -logabsdetjacinv(b::Bijector, x) = logabsdetjac(inv(b), b(x)) +logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) ############### # Composition # @@ -234,8 +234,20 @@ end (ib::Inversed{<: Logit{<: Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a logabsdetjac(b::Logit{<:Real}, x) = log((x - b.a) * (b.b - x) / (b.b - b.a)) -forward(b::Logit, x) = (rv=b(x), logabsdetjac=-logabsdetjac(b, x)) +struct Exp <: Bijector end +struct Log <: Bijector end +const exp_b = Exp() +const log_b = Log() + +(b::Log)(x) = @. log(x) +(b::Exp)(y) = @. exp(y) + +inv(b::Log) = exp_b +inv(b::Exp) = log_b + +logabsdetjac(b::Log, x) = log(x) +logabsdetjac(b::Exp, y) = - y ####################################################### # Constrained to unconstrained distribution bijectors # @@ -298,7 +310,17 @@ Returns the constrained-to-unconstrained bijector for distribution `d`. """ bijector(d::Normal) = IdentityBijector bijector(d::MvNormal) = IdentityBijector -bijector(d::Beta{T}) where T <: Real = Logit(zero(T), one(T)) +bijector(d::PositiveDistribution) = log_b + +_union2tuple(T1::Type, T2::Type) = (T1, T2) +_union2tuple(T1::Type, T2::Union) = (T1, union2tuple(T2.a, T2.b)...) +_union2tuple(T::Union) = union2tuple(T.a, T.b) + +bijector(d::Kolmogorov) = Logit(zero(eltype(d)), zero(eltype(d))) +for D in _union2tuple(PositiveDistribution)[2:end] + # skipping Kolmogorov because it's a DataType + @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), zero(T)) +end ############################## # Distributions.jl interface # @@ -308,24 +330,33 @@ bijector(d::Beta{T}) where T <: Real = Logit(zero(T), one(T)) Base.length(td::Transformed) = length(td.dist) Base.size(td::Transformed) = size(td.dist) -# logp +# TODO: should eventually drop using `logpdf_with_trans` and replace with +# res = forward(inv(td.transform), y) +# logpdf(td.dist, res.rv) .- res.logabsdetjac function logpdf(td::UnivariateTransformed, y::Real) - # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) - logpdf_with_trans(td.dist, inv(td.transform)(y), true) + return logpdf_with_trans(td.dist, inv(td.transform)(y), true) end function _logpdf(td::MultivariateTransformed, y::AbstractVector{<: Real}) - # logpdf(td.dist, transform(inv(td.transform), y)) .+ logabsdetjac(inv(td.transform), y) - logpdf_with_trans(td.dist, inv(td.transform)(y), true) + return logpdf_with_trans(td.dist, inv(td.transform)(y), true) +end + +function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) + return logpdf_with_trans(td.dist, inv(td.transform)(y), true) end function logpdf_with_jac(td::UnivariateTransformed, y::Real) - z = logabsdetjac(inv(td.transform), y) - return (logpdf(td.dist, inv(td.transform)(y)) .+ z, z) + res = forward(inv(td.transform), y) + return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) end function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{<:Real}) - z = logabsdetjac(inv(td.transform), y) - return (logpdf(td.dist, inv(td.transform)(y)) .+ z, z) + res = forward(inv(td.transform), y) + return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) +end + +function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) end # rand diff --git a/test/interface.jl b/test/interface.jl index efe88806..108a99a9 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -47,6 +47,11 @@ Random.seed!(123) x = inv(td.transform)(y) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + # logpdf_with_jac + lp, logjac = logpdf_with_jac(td, y) + @test lp ≈ logpdf(td, y) + @test logjac ≈ logabsdetjacinv(td.transform, y) + # multi-sample y = rand(td, 10) x = inv(td.transform).(y) @@ -81,6 +86,39 @@ Random.seed!(123) end end + @testset "Multivariate" begin + vector_dists = [ + Dirichlet(2, 3), + Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([eps(Float64), 1000 * one(Float64)]), + MvNormal(randn(10), exp.(randn(10))), + # MvLogNormal(MvNormal(randn(10), exp.(randn(10)))), + Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([eps(Float64), 1000 * one(Float64)]), + ] + + for dist in vector_dists + @testset "$dist: dist" begin + td = transformed(dist) + + # single sample + y = rand(td) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # logpdf_with_jac + lp, logjac = logpdf_with_jac(td, y) + @test lp ≈ logpdf(td, y) + @test logjac ≈ logabsdetjacinv(td.transform, y) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + end + end + end + @testset "Matrix variate" begin v = 7.0 S = Matrix(1.0I, 2, 2) @@ -94,10 +132,22 @@ Random.seed!(123) for dist in matrix_dists @testset "$dist: dist" begin td = transformed(dist) + + # single sample y = rand(td) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # TODO: implement `logabsdetjac` for these + # logpdf_with_jac + # lp, logjac = logpdf_with_jac(td, y) + # @test lp ≈ logpdf(td, y) + # @test logjac ≈ logabsdetjacinv(td.transform, y) + # multi-sample + y = rand(td, 10) x = inv(td.transform)(y) - @test td.transform(x) ≈ y + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) end end end @@ -147,4 +197,41 @@ Random.seed!(123) end end +d = BetaPrime(Dual(1.0), Dual(1.0)) +b = bijector(d) + + + +bijector() + + +A = Union{[Beta, Normal, Gamma, InverseGamma]...} + +using ForwardDiff: Dual +@code_warntype bijector(Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0))) +d = Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0)) + +d = InverseGamma() +b = bijector(d) +eltype(d) +x = rand(d) +y = b(x) + + +x == inv(b)(y) + +# logabsdetjacinv(b, y) == logabsdetjac(inv(b), y) +# logabsdetjacinv(b, y) == - y + +@test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) +@test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) + + + + +logpdf_with_trans(d, x, true) - logpdf_with_trans(d, x, false) + +logpdf_with_trans(d, invlink(d, y), true) - logpdf_with_trans(d, invlink(d, y), false) +logabsdetjac(b, x) +logpdf(d, x) + logabsdetjacinv(b, y) From a8c2d8f1ab5adfc79cd87f30189a05f5cb81f973 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 20:29:12 +0200 Subject: [PATCH 49/83] added Scale and Shift together with bijector for Truncated --- src/interface.jl | 43 +++++++++++++++++++++++++++++++---- test/interface.jl | 57 +++++++++++++++-------------------------------- 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 88d033c5..2e484ede 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -235,6 +235,10 @@ end logabsdetjac(b::Logit{<:Real}, x) = log((x - b.a) * (b.b - x) / (b.b - b.a)) +############# +# Exp & Log # +############# + struct Exp <: Bijector end struct Log <: Bijector end const exp_b = Exp() @@ -249,6 +253,25 @@ inv(b::Exp) = log_b logabsdetjac(b::Log, x) = log(x) logabsdetjac(b::Exp, y) = - y +################# +# Shift & Scale # +################# +struct Shift{T} <: Bijector + a::T +end + +(b::Shift)(x) = b.a + x +inv(b::Shift) = Shift(-b.a) +logabsdetjac(b::Shift, x::T) where T = zero(T) + +struct Scale{T} <: Bijector + a::T +end + +(b::Scale)(x) = b.a * x +inv(b::Scale) = Scale(b^(-1)) +logabsdetjac(b::Scale, x) = log(abs(b.a)) + ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### @@ -313,13 +336,25 @@ bijector(d::MvNormal) = IdentityBijector bijector(d::PositiveDistribution) = log_b _union2tuple(T1::Type, T2::Type) = (T1, T2) -_union2tuple(T1::Type, T2::Union) = (T1, union2tuple(T2.a, T2.b)...) -_union2tuple(T::Union) = union2tuple(T.a, T.b) +_union2tuple(T1::Type, T2::Union) = (T1, _union2tuple(T2.a, T2.b)...) +_union2tuple(T::Union) = _union2tuple(T.a, T.b) bijector(d::Kolmogorov) = Logit(zero(eltype(d)), zero(eltype(d))) -for D in _union2tuple(PositiveDistribution)[2:end] +for D in _union2tuple(UnitDistribution)[2:end] # skipping Kolmogorov because it's a DataType - @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), zero(T)) + @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), one(T)) +end + +function bijector(d::Truncated{D}) where D <: Distribution + a, b = minimum(d), maximum(d) + lowerbounded, upperbounded = isfinite(a), isfinite(b) + if lowerbounded && upperbounded + return Logit(a, b) + elseif lowerbounded + return (log_b ∘ Shift(- a)) + else + return (log_b ∘ Shift(b) ∘ Scale(- one(typeof(b)))) + end end ############################## diff --git a/test/interface.jl b/test/interface.jl index 108a99a9..d75c96c5 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -40,6 +40,7 @@ Random.seed!(123) for dist in uni_dists @testset "$dist: dist" begin + dist = Erlang() td = transformed(dist) # single sample @@ -86,6 +87,23 @@ Random.seed!(123) end end + @testset "Truncated" begin + d = Truncated(Normal(), -1, 1) + b = bijector(d) + x = rand(d) + @test b(x) == link(d, x) + + d = Truncated(Normal(), -Inf, 1) + b = bijector(d) + x = rand(d) + @test b(x) == link(d, x) + + d = Truncated(Normal(), 1, Inf) + b = bijector(d) + x = rand(d) + @test b(x) == link(d, x) + end + @testset "Multivariate" begin vector_dists = [ Dirichlet(2, 3), @@ -196,42 +214,3 @@ Random.seed!(123) @test 0 ≤ x ≤ 1 end end - -d = BetaPrime(Dual(1.0), Dual(1.0)) -b = bijector(d) - - - -bijector() - - -A = Union{[Beta, Normal, Gamma, InverseGamma]...} - -using ForwardDiff: Dual -@code_warntype bijector(Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0))) -d = Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0)) - -d = InverseGamma() -b = bijector(d) -eltype(d) -x = rand(d) -y = b(x) - - -x == inv(b)(y) - -# logabsdetjacinv(b, y) == logabsdetjac(inv(b), y) -# logabsdetjacinv(b, y) == - y - -@test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) -@test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) - - - - -logpdf_with_trans(d, x, true) - logpdf_with_trans(d, x, false) - -logpdf_with_trans(d, invlink(d, y), true) - logpdf_with_trans(d, invlink(d, y), false) -logabsdetjac(b, x) - -logpdf(d, x) + logabsdetjacinv(b, y) From 5f8b4a047f9e8c1f38730f3251888ea39014dd8b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 20:34:48 +0200 Subject: [PATCH 50/83] removed redundant line in logabsdetjac --- src/interface.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 2e484ede..160abd7a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -185,7 +185,6 @@ _transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) (cb::Composed{<: Tuple})(x) = _transform(x, cb.ts...) function _logabsdetjac(x, b1::Bijector, b2::Bijector) - logabsdetjac(b2, b1(x)) + logabsdetjac(b1, x) res = forward(b1, x) return logabsdetjac(b2, res.rv) + res.logabsdetjac end From 3054ad2853c33c0e4666267f1889d9a40d0352df Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 21:05:27 +0200 Subject: [PATCH 51/83] fixed a typo and added some more tests --- src/interface.jl | 9 ++++-- test/interface.jl | 76 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 160abd7a..1d411f0a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -268,7 +268,7 @@ struct Scale{T} <: Bijector end (b::Scale)(x) = b.a * x -inv(b::Scale) = Scale(b^(-1)) +inv(b::Scale) = Scale(b.a^(-1)) logabsdetjac(b::Scale, x) = log(abs(b.a)) ####################################################### @@ -344,15 +344,18 @@ for D in _union2tuple(UnitDistribution)[2:end] @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), one(T)) end -function bijector(d::Truncated{D}) where D <: Distribution +# FIXME: can we make this typestable? +function bijector(d::TransformDistribution) where D <: Distribution a, b = minimum(d), maximum(d) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded return Logit(a, b) elseif lowerbounded return (log_b ∘ Shift(- a)) - else + elseif upperbounded return (log_b ∘ Shift(b) ∘ Scale(- one(typeof(b)))) + else + return IdentityBijector end end diff --git a/test/interface.jl b/test/interface.jl index d75c96c5..fa33eb9f 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -40,7 +40,6 @@ Random.seed!(123) for dist in uni_dists @testset "$dist: dist" begin - dist = Erlang() td = transformed(dist) # single sample @@ -57,6 +56,14 @@ Random.seed!(123) y = rand(td, 10) x = inv(td.transform).(y) @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) + + # logpdf corresponds to logpdf_with_trans + d = dist + b = bijector(d) + x = rand(d) + y = b(x) + @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) + @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) end @testset "$dist: ForwardDiff AD" begin @@ -190,6 +197,11 @@ Random.seed!(123) cb2 = cb ∘ cb @test cb(x) ≈ x + # ensures that the `logabsdetjac` is correct + x = rand(d) + b = inv(bijector(d)) + @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) + # order of composed evaluation b1 = DistributionBijector(d) b2 = DistributionBijector(Gamma()) @@ -214,3 +226,65 @@ Random.seed!(123) @test 0 ≤ x ≤ 1 end end + +# using ForwardDiff: Dual +# d = BetaPrime(Dual(1.0), Dual(1.0)) +# b = bijector(d) + + + +# bijector() + + +# A = Union{[Beta, Normal, Gamma, InverseGamma]...} + +# @code_warntype bijector(Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0))) +# d = Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0)) + +# d = InverseGamma() +# b = bijector(d) +# eltype(d) +# x = rand(d) +# y = b(x) + + +# x == inv(b)(y) + +# # logabsdetjacinv(b, y) == logabsdetjac(inv(b), y) +# # logabsdetjacinv(b, y) == - y + +# @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) +# @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) + + + + +# logpdf_with_trans(d, x, true) - logpdf_with_trans(d, x, false) + +# logpdf_with_trans(d, invlink(d, y), true) - logpdf_with_trans(d, invlink(d, y), false) +# logabsdetjac(b, x) + +# logpdf(d, x) + logabsdetjacinv(b, y) + + +# using BenchmarkTools + +# @btime b(x) +# @btime link(d, x) + +# @btime logabsdetjac(b, x) + +# d = Truncated(Normal(), -1, 1) +# b = bijector(d) +# x = rand(d) +# @test b(x) == link(d, x) + +# d = Truncated(Normal(), -Inf, 1) +# b = bijector(d) +# x = rand(d) +# @test b(x) == link(d, x) + +# d = Truncated(Normal(), 1, Inf) +# b = bijector(d) +# x = rand(d) +# @test b(x) == link(d, x) From 883a8ecedef58c968eae9bcfccac31b10184536e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 21:06:10 +0200 Subject: [PATCH 52/83] removed some unecessary commented code --- test/interface.jl | 61 ----------------------------------------------- 1 file changed, 61 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index fa33eb9f..8f724eea 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -227,64 +227,3 @@ Random.seed!(123) end end -# using ForwardDiff: Dual -# d = BetaPrime(Dual(1.0), Dual(1.0)) -# b = bijector(d) - - - -# bijector() - - -# A = Union{[Beta, Normal, Gamma, InverseGamma]...} - -# @code_warntype bijector(Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0))) -# d = Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0)) - -# d = InverseGamma() -# b = bijector(d) -# eltype(d) -# x = rand(d) -# y = b(x) - - -# x == inv(b)(y) - -# # logabsdetjacinv(b, y) == logabsdetjac(inv(b), y) -# # logabsdetjacinv(b, y) == - y - -# @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) -# @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) - - - - -# logpdf_with_trans(d, x, true) - logpdf_with_trans(d, x, false) - -# logpdf_with_trans(d, invlink(d, y), true) - logpdf_with_trans(d, invlink(d, y), false) -# logabsdetjac(b, x) - -# logpdf(d, x) + logabsdetjacinv(b, y) - - -# using BenchmarkTools - -# @btime b(x) -# @btime link(d, x) - -# @btime logabsdetjac(b, x) - -# d = Truncated(Normal(), -1, 1) -# b = bijector(d) -# x = rand(d) -# @test b(x) == link(d, x) - -# d = Truncated(Normal(), -Inf, 1) -# b = bijector(d) -# x = rand(d) -# @test b(x) == link(d, x) - -# d = Truncated(Normal(), 1, Inf) -# b = bijector(d) -# x = rand(d) -# @test b(x) == link(d, x) From 5484f6d63370138dadc809047657175866c58075 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Aug 2019 21:18:45 +0200 Subject: [PATCH 53/83] addded some comments --- src/interface.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 1d411f0a..0652277c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -340,11 +340,15 @@ _union2tuple(T::Union) = _union2tuple(T.a, T.b) bijector(d::Kolmogorov) = Logit(zero(eltype(d)), zero(eltype(d))) for D in _union2tuple(UnitDistribution)[2:end] - # skipping Kolmogorov because it's a DataType + # Skipping Kolmogorov because it's a DataType @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), one(T)) end -# FIXME: can we make this typestable? +# FIXME: Can we make this type-stable? +# Everything but `Truncated` can probably be made type-stable +# by explicit implementation. Can also make a `TruncatedBijector` +# which has the same transform as the `link` function. +# E.g. (b::Truncated)(x) = link(b.d, x) or smth function bijector(d::TransformDistribution) where D <: Distribution a, b = minimum(d), maximum(d) lowerbounded, upperbounded = isfinite(a), isfinite(b) From 3565a181208ce9bd39452740038a76ddba5abd37 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 23 Aug 2019 01:03:31 +0200 Subject: [PATCH 54/83] added SimplexBijector and forward(flow, ::Vector) now returns vector --- src/Bijectors.jl | 1 + src/interface.jl | 240 ++++++++++++++++++++++++++++++++++++++++----- src/norm_flows.jl | 49 +++++++-- test/interface.jl | 148 ++++++++++++++++++++++++++++ test/norm_flows.jl | 18 ++++ 5 files changed, 426 insertions(+), 30 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 97c12b72..f5b1781e 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -31,6 +31,7 @@ export TransformDistribution, UnivariateTransformed, MultivariateTransformed, logpdf_with_jac, + logpdf_forward, PlanarLayer, RadialLayer diff --git a/src/interface.jl b/src/interface.jl index 0652277c..46919558 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -215,7 +215,7 @@ struct Identity <: Bijector end forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) logabsdetjac(::Identity, y::T) where T <: Real = zero(T) -logabsdetjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T) +logabsdetjac(::Identity, y::AbstractArray{T}) where T <: Real = zero(T) const IdentityBijector = Identity() @@ -271,6 +271,135 @@ end inv(b::Scale) = Scale(b.a^(-1)) logabsdetjac(b::Scale, x) = log(abs(b.a)) +#################### +# Simplex bijector # +#################### +struct SimplexBijector{T} <: Bijector where T end + +const simplex_b = SimplexBijector{Val{false}}() +const simplex_b_proj = SimplexBijector{Val{true}}() + +function _clamp(x::T, b::SimplexBijector) where T + bounds = (zero(T), one(T)) + clamped_x = clamp(x, bounds...) + DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x" + return clamped_x +end + +function (b::SimplexBijector{Val{proj}})(x::AbstractVector{T}) where {T, proj} + y, K = similar(x), length(x) + + ϵ = _eps(T) + sum_tmp = zero(T) + @inbounds z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ] + @inbounds y[1] = StatsFuns.logit(z) + log(T(K - 1)) + @inbounds @simd for k in 2:(K - 1) + sum_tmp += x[k - 1] + # z ∈ [ϵ, 1-ϵ] + # x[k] = 0 && sum_tmp = 1 -> z ≈ 1 + z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) + y[k] = StatsFuns.logit(z) + log(T(K - k)) + end + @inbounds sum_tmp += x[K - 1] + @inbounds if proj + y[K] = zero(T) + else + y[K] = one(T) - sum_tmp - x[K] + end + + return y +end + +# Vectorised implementation of the above. +function (b::SimplexBijector{Val{proj}})(X::AbstractMatrix{T}) where {T<:Real, proj} + Y, K, N = similar(X), size(X, 1), size(X, 2) + + ϵ = _eps(T) + @inbounds @simd for n in 1:size(X, 2) + sum_tmp = zero(T) + z = X[1, n] * (one(T) - 2ϵ) + ϵ + Y[1, n] = StatsFuns.logit(z) + log(T(K - 1)) + for k in 2:(K - 1) + sum_tmp += X[k - 1, n] + z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) + Y[k, n] = StatsFuns.logit(z) + log(T(K - k)) + end + sum_tmp += X[K-1, n] + if proj + Y[K, n] = zero(T) + else + Y[K, n] = one(T) - sum_tmp - X[K, n] + end + end + + return Y +end + +function (ib::Inversed{<: SimplexBijector{Val{proj}}})(y::AbstractVector{T}) where {T, proj} + x, K = similar(y), length(y) + + ϵ = _eps(T) + @inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1))) + @inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), ib.orig) + sum_tmp = zero(T) + @inbounds @simd for k = 2:(K - 1) + z = StatsFuns.logistic(y[k] - log(T(K - k))) + sum_tmp += x[k-1] + x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, ib.orig) + end + @inbounds sum_tmp += x[K - 1] + @inbounds if proj + x[K] = _clamp(one(T) - sum_tmp, ib.orig) + else + x[K] = _clamp(one(T) - sum_tmp - y[K], ib.orig) + end + + return x +end + +# Vectorised implementation of the above. +function (ib::Inversed{<: SimplexBijector{Val{proj}}})(Y::AbstractMatrix{T}) where {T<:Real, proj} + X, K, N = similar(Y), size(Y, 1), size(Y, 2) + + ϵ = _eps(T) + @inbounds @simd for n in 1:size(X, 2) + sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] - log(T(K - 1))) + X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), ib.orig) + for k in 2:(K - 1) + z = StatsFuns.logistic(Y[k, n] - log(T(K - k))) + sum_tmp += X[k - 1] + X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, ib.orig) + end + sum_tmp += X[K - 1, n] + if proj + X[K, n] = _clamp(one(T) - sum_tmp, ib.orig) + else + X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], ib.orig) + end + end + + return X +end + + +function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where T + ϵ = _eps(T) + lp = zero(T) + + K = length(x) + + sum_tmp = zero(eltype(x)) + @inbounds z = x[1] + lp += log(z + ϵ) + log((one(T) + ϵ) - z) + @inbounds @simd for k in 2:(K - 1) + sum_tmp += x[k-1] + z = x[k] / ((one(T) + ϵ) - sum_tmp) + lp += log(z + ϵ) + log((one(T) + ϵ) - z) + log((one(T) + ϵ) - sum_tmp) + end + + return lp +end + ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### @@ -294,6 +423,12 @@ end (b::DistributionBijector)(x) = link(b.dist, x) (ib::Inversed{<: DistributionBijector})(y) = invlink(ib.orig.dist, y) +# HACK: he he he, we gottem' boys +function logabsdetjac(b::B, x::AbstractVector{<:Real}) where {AD, B <: DistributionBijector{AD, <: Dirichlet}} + return logpdf_with_trans(b.dist, x, true) - logpdf_with_trans(b.dist, x, false) +end + + "Returns the constrained-to-unconstrained bijector for distribution `d`." bijector(d::Distribution) = DistributionBijector(d) @@ -309,8 +444,9 @@ end const UnivariateTransformed = TransformedDistribution{<: Distribution, <: Bijector, Univariate} const MultivariateTransformed = TransformedDistribution{<: Distribution, <: Bijector, Multivariate} +const MvTransformed = MultivariateTransformed const MatrixTransformed = TransformedDistribution{<: Distribution, <: Bijector, Matrixvariate} -const Transformed = Union{UnivariateTransformed, MultivariateTransformed, MatrixTransformed} +const Transformed = TransformedDistribution """ @@ -333,6 +469,7 @@ Returns the constrained-to-unconstrained bijector for distribution `d`. bijector(d::Normal) = IdentityBijector bijector(d::MvNormal) = IdentityBijector bijector(d::PositiveDistribution) = log_b +bijector(d::SimplexDistribution) = simplex_b_proj _union2tuple(T1::Type, T2::Type) = (T1, T2) _union2tuple(T1::Type, T2::Union) = (T1, _union2tuple(T2.a, T2.b)...) @@ -371,33 +508,30 @@ end Base.length(td::Transformed) = length(td.dist) Base.size(td::Transformed) = size(td.dist) -# TODO: should eventually drop using `logpdf_with_trans` and replace with -# res = forward(inv(td.transform), y) -# logpdf(td.dist, res.rv) .- res.logabsdetjac function logpdf(td::UnivariateTransformed, y::Real) - return logpdf_with_trans(td.dist, inv(td.transform)(y), true) -end -function _logpdf(td::MultivariateTransformed, y::AbstractVector{<: Real}) - return logpdf_with_trans(td.dist, inv(td.transform)(y), true) -end - -function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - return logpdf_with_trans(td.dist, inv(td.transform)(y), true) + res = forward(inv(td.transform), y) + return logpdf(td.dist, res.rv) .- res.logabsdetjac end -function logpdf_with_jac(td::UnivariateTransformed, y::Real) +# TODO: implement more efficiently for flows in the case of `Matrix` +function _logpdf(td::MvTransformed, y::AbstractVector{<: Real}) res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) + return logpdf(td.dist, res.rv) .- res.logabsdetjac end -function logpdf_with_jac(td::MultivariateTransformed, y::AbstractVector{<:Real}) +function _logpdf(td::MvTransformed{<: Dirichlet}, y::AbstractVector{<: Real}) + T = eltype(y) + ϵ = _eps(T) + res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) + return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .- res.logabsdetjac end -function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) +# TODO: should eventually drop using `logpdf_with_trans` and replace with +# res = forward(inv(td.transform), y) +# logpdf(td.dist, res.rv) .- res.logabsdetjac +function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) + return logpdf_with_trans(td.dist, inv(td.transform)(y), true) end # rand @@ -406,14 +540,15 @@ rand(rng::AbstractRNG, td::UnivariateTransformed) = td.transform(rand(rng, td.di # These ovarloadings are useful for differentiating sampling wrt. params of `td.dist` # or params of `Bijector`, as they are not inplace like the default `rand` -rand(td::MultivariateTransformed) = td.transform(rand(td.dist)) -rand(rng::AbstractRNG, td::MultivariateTransformed) = td.transform(rand(rng, td.dist)) -function rand(rng::AbstractRNG, td::MultivariateTransformed, num_samples::Int) +rand(td::MvTransformed) = td.transform(rand(td.dist)) +rand(rng::AbstractRNG, td::MvTransformed) = td.transform(rand(rng, td.dist)) +# TODO: implement more efficiently for flows +function rand(rng::AbstractRNG, td::MvTransformed, num_samples::Int) res = hcat([td.transform(rand(td.dist)) for i = 1:num_samples]...) return res end -function _rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{<: Real}) +function _rand!(rng::AbstractRNG, td::MvTransformed, x::AbstractVector{<: Real}) rand!(rng, td.dist, x) x .= td.transform(x) end @@ -422,3 +557,60 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<: Real} rand!(rng, td.dist, x) x .= td.transform(x) end + +############################################################# +# Additional useful functions for `TransformedDistribution` # +############################################################# +""" + logpdf_with_jac(td::UnivariateTransformed, y::Real) + logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) + logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) + +Makes use of the `forward` method to potentially re-use computation +and returns a tuple `(logpdf, logabsdetjac)`. +""" +function logpdf_with_jac(td::UnivariateTransformed, y::Real) + res = forward(inv(td.transform), y) + return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) +end + +# TODO: implement more efficiently for flows in the case of `Matrix` +function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) +end + +function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) +end + +function logpdf_with_jac(td::MvTransformed{<: Dirichlet}, y::AbstractVector{<:Real}) + T = eltype(y) + ϵ = _eps(T) + + res = forward(inv(td.transform), y) + return (logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .- res.logabsdetjac, res.logabsdetjac) +end + +# TODO: should eventually drop using `logpdf_with_trans` +function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) +end + +""" + logpdf_forward(td::Transformed, x) + logpdf_forward(td::Transformed, x, logjac) + +Computes the `logpdf` using the forward pass of the bijector rather than using +the inverse transform to compute the necessary `logabsdetjac`. + +This is similar to `logpdf_with_trans`. +""" +# TODO: implement more efficiently for flows in the case of `Matrix` +logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .+ logjac +logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) + +forward(d::Transformed, x) = forward(d.transform, x) +forward(d::Transformed) = forward(d, rand(d.dist)) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 29604476..c9bb9a48 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -10,6 +10,10 @@ using Roots # for inverse # D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # ################################################################################ +############### +# PlanarLayer # +############### + mutable struct PlanarLayer{T1,T2} <: Bijector w::T1 u::T1 @@ -44,15 +48,23 @@ end (b::PlanarLayer)(z) = _transform(b, z).transformed -function forward(flow::T, z) where {T<:PlanarLayer} +function _forward(flow::PlanarLayer, z) transformed, u_hat = _transform(flow, z) # Compute log_det_jacobian psi = ψ(z, flow.w, flow.b) - log_det_jacobian = log.(abs.(1.0 .+ psi' * u_hat)) # from eq(12) - return (rv=transformed, logabsdetjac=log_det_jacobian) # from eq(10) + log_det_jacobian = log.(abs.(1.0 .+ psi' * u_hat)) # from eq(12) + return (rv=transformed, logabsdetjac=vec(log_det_jacobian)) # from eq(10) +end + +forward(flow::PlanarLayer, z) = _forward(flow, z) + +function forward(flow::PlanarLayer, z::AbstractVector{<: Real}) + res = _forward(flow, z) + return (rv=res.rv, logabsdetjac=res.logabsdetjac[1]) end -function (ib::Inversed{<: PlanarLayer})(y) + +function (ib::Inversed{<: PlanarLayer})(y::AbstractMatrix{<: Real}) flow = ib.orig u_hat = get_u_hat(flow.u, flow.w) # Define the objective functional; implemented with reference from A.1 @@ -66,6 +78,18 @@ function (ib::Inversed{<: PlanarLayer})(y) return z_para + z_per end +function (ib::Inversed{<: PlanarLayer})(y::AbstractVector{<: Real}) + return vec(ib(reshape(y, (length(y), 1)))) +end + +logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac + +############### +# RadialLayer # +############### + +# FIXME: using `TrackedArray` for the parameters, we end up with +# nested tracked structures; don't want this. mutable struct RadialLayer{T1,T2} <: Bijector α_::T1 β::T1 @@ -93,7 +117,7 @@ end (b::RadialLayer)(z) = _transform(b, z).transformed -function forward(flow::T, z) where {T<:RadialLayer} +function _forward(flow::RadialLayer, z) transformed, α, β_hat, r = _transform(flow, z) # Compute log_det_jacobian d = size(flow.z_0, 1) @@ -102,7 +126,14 @@ function forward(flow::T, z) where {T<:RadialLayer} (d - 1) * log(1.0 + β_hat * h_) + log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (rv=transformed, logabsdetjac=log_det_jacobian) + return (rv=transformed, logabsdetjac=vec(log_det_jacobian)) +end + +forward(flow::RadialLayer, z) = _forward(flow, z) + +function forward(flow::RadialLayer, z::AbstractVector{<: Real}) + res = forward(flow, z) + return (rv=res.rv, logabsdetjac=res.logabsdetjac[1]) end # function inv(flow::RadialLayer, y) @@ -118,3 +149,9 @@ function (ib::Inversed{<: RadialLayer})(y) z = flow.z_0 .+ rs .* z_hat # from A.2 return z end + +function (ib::Inversed{<: RadialLayer})(y::AbstractVector{<: Real}) + return vec(ib(reshape(y, (length(y), 1)))) +end + +logabsdetjac(flow::RadialLayer, x) = forward(flow, x).logabsdetjac diff --git a/test/interface.jl b/test/interface.jl index 8f724eea..623dcb57 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -227,3 +227,151 @@ Random.seed!(123) end end +# using ForwardDiff: Dual +# d = BetaPrime(Dual(1.0), Dual(1.0)) +# b = bijector(d) + + + +# bijector() + + +# A = Union{[Beta, Normal, Gamma, InverseGamma]...} + +# @code_warntype bijector(Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0))) +# d = Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0)) + +# d = InverseGamma() +# b = bijector(d) +# eltype(d) +# x = rand(d) +# y = b(x) + + +# x == inv(b)(y) + +# # logabsdetjacinv(b, y) == logabsdetjac(inv(b), y) +# # logabsdetjacinv(b, y) == - y + +# @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) +# @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) + + + + +# logpdf_with_trans(d, x, true) - logpdf_with_trans(d, x, false) + +# logpdf_with_trans(d, invlink(d, y), true) - logpdf_with_trans(d, invlink(d, y), false) +# logabsdetjac(b, x) + +# logpdf(d, x) + logabsdetjacinv(b, y) + + +# using BenchmarkTools + +# @btime b(x) +# @btime link(d, x) + +# @btime logabsdetjac(b, x) + +# d = Truncated(Normal(), -1, 1) +# b = bijector(d) +# x = rand(d) +# @test b(x) == link(d, x) + +# d = Truncated(Normal(), -Inf, 1) +# b = bijector(d) +# x = rand(d) +# @test b(x) == link(d, x) + +# d = Truncated(Normal(), 1, Inf) +# b = bijector(d) +# x = rand(d) +# @test b(x) == link(d, x) + + +# d = Beta() +# x = rand(d) +# f(x, d) = bijector(d)(x) +# @code_warntype f(x, d) + + +# d = MvNormal(zeros(10), ones(10)) +# b = PlanarLayer(10) +# flow = transformed(d, b) # <= Radial flow +# y = rand(flow, 5) + + + +# res = forward(flow) +# res.rv +# res.logabsdetjac + +# x = rand(d, 5) +# res = forward(flow, x) +# res.rv +# res.logabsdetjac + +# logpdf_with_jac(flow, y) + +# forward(b, rand(d)) + +# @code_typed logpdf(flow, y) + + +# using Tracker +# b = PlanarLayer(10, param) +# flow = transformed(d, b) +# y = rand(flow) +# sum(y) + +# Tracker.back!(sum(y), 1.0) +# Tracker.grad(b.u) + + +# x = rand(d) +# @code_warntype forward(b, x) +# y = rand(flow, 5) +# Tracker.back!(mean(sum(y, dims=1)), 1.0) + +# Tracker.grad(b.u) + +# @code_warntype forward(b, y) + + +# Bijectors.logpdf_forward(flow, x) + +# forward(inv(flow.transform), rand(flow)).rv +# rand(flow) + +# Bijectors.get_u_hat(flow.transform.u, flow.transform.w) + + +# Bijectors.logpdf_forward(flow, x) + +# logpdf(flow, y) + +# using Tracker +# b = RadialLayer(10, param) + +# b.α_ +# b.β +# b.z_0 + +# b(x)[1] + + +# x = rand(d) +# (b ∘ b)(x) +# b(x) + + +# rb = PlanarLayer(10, param) +# pb = PlanarLayer(10, param) + +# flow = transformed(d, rb ∘ pb) +# y = rand(flow) + +# Tracker.back!(sum(y.^2), 1.0) +# Tracker.grad(rb.u) + diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 0bb8ad18..7f5d94cf 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -43,3 +43,21 @@ end flow = RadialLayer(α_, β, z_0) @test inv(flow)(flow(z)) ≈ z end + +@testset "Flows" begin + d = MvNormal(zeros(2), ones(2)) + b = PlanarLayer(2) + flow = transformed(d, b) # <= Radial flow + + y = rand(flow) + @test logpdf(flow, y) != 0.0 + + x = rand(d) + y = flow.transform(x) + res = forward(flow, x) + lp = logpdf_forward(flow, x, res.logabsdetjac) + + @test res.rv ≈ y + @test logpdf(flow, y) ≈ lp rtol=0.1 +end + From 7cf3f6c42a6e5f87e229a7dfaea682d6268fd5f2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 13:48:45 +0200 Subject: [PATCH 55/83] removed left-over commented out code --- src/interface.jl | 6 +- test/interface.jl | 149 ---------------------------------------------- 2 files changed, 5 insertions(+), 150 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 46919558..62e07a64 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -279,6 +279,8 @@ struct SimplexBijector{T} <: Bijector where T end const simplex_b = SimplexBijector{Val{false}}() const simplex_b_proj = SimplexBijector{Val{true}}() +# the following implementations are basically just copy-paste from `invlink` and +# `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. function _clamp(x::T, b::SimplexBijector) where T bounds = (zero(T), one(T)) clamped_x = clamp(x, bounds...) @@ -358,7 +360,9 @@ function (ib::Inversed{<: SimplexBijector{Val{proj}}})(y::AbstractVector{T}) whe end # Vectorised implementation of the above. -function (ib::Inversed{<: SimplexBijector{Val{proj}}})(Y::AbstractMatrix{T}) where {T<:Real, proj} +function (ib::Inversed{<: SimplexBijector{Val{proj}}})( + Y::AbstractMatrix{T} +) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) ϵ = _eps(T) diff --git a/test/interface.jl b/test/interface.jl index 623dcb57..a1b68db1 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -226,152 +226,3 @@ Random.seed!(123) @test 0 ≤ x ≤ 1 end end - -# using ForwardDiff: Dual -# d = BetaPrime(Dual(1.0), Dual(1.0)) -# b = bijector(d) - - - -# bijector() - - -# A = Union{[Beta, Normal, Gamma, InverseGamma]...} - -# @code_warntype bijector(Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0))) -# d = Beta(ForwardDiff.Dual(2.0), ForwardDiff.Dual(3.0)) - -# d = InverseGamma() -# b = bijector(d) -# eltype(d) -# x = rand(d) -# y = b(x) - - -# x == inv(b)(y) - -# # logabsdetjacinv(b, y) == logabsdetjac(inv(b), y) -# # logabsdetjacinv(b, y) == - y - -# @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) -# @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) - - - - -# logpdf_with_trans(d, x, true) - logpdf_with_trans(d, x, false) - -# logpdf_with_trans(d, invlink(d, y), true) - logpdf_with_trans(d, invlink(d, y), false) -# logabsdetjac(b, x) - -# logpdf(d, x) + logabsdetjacinv(b, y) - - -# using BenchmarkTools - -# @btime b(x) -# @btime link(d, x) - -# @btime logabsdetjac(b, x) - -# d = Truncated(Normal(), -1, 1) -# b = bijector(d) -# x = rand(d) -# @test b(x) == link(d, x) - -# d = Truncated(Normal(), -Inf, 1) -# b = bijector(d) -# x = rand(d) -# @test b(x) == link(d, x) - -# d = Truncated(Normal(), 1, Inf) -# b = bijector(d) -# x = rand(d) -# @test b(x) == link(d, x) - - -# d = Beta() -# x = rand(d) -# f(x, d) = bijector(d)(x) -# @code_warntype f(x, d) - - -# d = MvNormal(zeros(10), ones(10)) -# b = PlanarLayer(10) -# flow = transformed(d, b) # <= Radial flow -# y = rand(flow, 5) - - - -# res = forward(flow) -# res.rv -# res.logabsdetjac - -# x = rand(d, 5) -# res = forward(flow, x) -# res.rv -# res.logabsdetjac - -# logpdf_with_jac(flow, y) - -# forward(b, rand(d)) - -# @code_typed logpdf(flow, y) - - -# using Tracker -# b = PlanarLayer(10, param) -# flow = transformed(d, b) -# y = rand(flow) -# sum(y) - -# Tracker.back!(sum(y), 1.0) -# Tracker.grad(b.u) - - -# x = rand(d) -# @code_warntype forward(b, x) -# y = rand(flow, 5) -# Tracker.back!(mean(sum(y, dims=1)), 1.0) - -# Tracker.grad(b.u) - -# @code_warntype forward(b, y) - - -# Bijectors.logpdf_forward(flow, x) - -# forward(inv(flow.transform), rand(flow)).rv -# rand(flow) - -# Bijectors.get_u_hat(flow.transform.u, flow.transform.w) - - -# Bijectors.logpdf_forward(flow, x) - -# logpdf(flow, y) - -# using Tracker -# b = RadialLayer(10, param) - -# b.α_ -# b.β -# b.z_0 - -# b(x)[1] - - -# x = rand(d) -# (b ∘ b)(x) -# b(x) - - -# rb = PlanarLayer(10, param) -# pb = PlanarLayer(10, param) - -# flow = transformed(d, rb ∘ pb) -# y = rand(flow) - -# Tracker.back!(sum(y.^2), 1.0) -# Tracker.grad(rb.u) - From 6cb2f61904e02d69e4d0d0b1cb9e87481538f695 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 13:51:44 +0200 Subject: [PATCH 56/83] updated Manifest --- Manifest.toml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 387a9ea6..f1c882b3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,5 +1,3 @@ -# This file is machine-generated - editing it directly is not advised - [[Adapt]] deps = ["LinearAlgebra"] git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" @@ -83,7 +81,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "0.0.10" [[Distributed]] -deps = ["Random", "Serialization", "Sockets"] +deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[Distributions]] @@ -99,7 +97,7 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210" version = "0.10.3" [[InteractiveUtils]] -deps = ["Markdown"] +deps = ["LinearAlgebra", "Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[LibGit2]] @@ -262,7 +260,7 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "0.8.0" [[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "SparseArrays"] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[Test]] @@ -293,7 +291,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67" version = "0.4.0" [[UUIDs]] -deps = ["Random", "SHA"] +deps = ["Random"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] From 3ee2584af88484cc84816ab5c3da9b51a5b6c7e6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 15:44:40 +0200 Subject: [PATCH 57/83] fixed issue with _transform for RadialLayer when using Tracked --- src/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index c9bb9a48..c8068775 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -110,7 +110,7 @@ dh(α, r) = - h(α, r) .^ 2 # for radial flow; derivative of h() function _transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 - r = norm.([z[:,i] .- flow.z_0 for i in 1:size(z, 2)], 2)' + r = sum((z .- flow.z_0).^2; dims = 1) transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) return (transformed=transformed, α=α, β_hat=β_hat, r=r) end From 939709885367b095b17600ca2e7b98d74edfc5b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 15:46:32 +0200 Subject: [PATCH 58/83] forgot a sqrt in previous commit --- src/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index c8068775..d4571fea 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -110,7 +110,7 @@ dh(α, r) = - h(α, r) .^ 2 # for radial flow; derivative of h() function _transform(flow::RadialLayer, z) α = softplus(flow.α_[1]) # from A.2 β_hat = -α + softplus(flow.β[1]) # from A.2 - r = sum((z .- flow.z_0).^2; dims = 1) + r = sqrt.(sum((z .- flow.z_0).^2; dims = 1)) transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) return (transformed=transformed, α=α, β_hat=β_hat, r=r) end From 2bf33d5a015b48956923159e9b4744d9d867e620 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 16:15:22 +0200 Subject: [PATCH 59/83] removed now redundant hack for Dirichlet --- src/interface.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 62e07a64..807e38d8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -427,11 +427,6 @@ end (b::DistributionBijector)(x) = link(b.dist, x) (ib::Inversed{<: DistributionBijector})(y) = invlink(ib.orig.dist, y) -# HACK: he he he, we gottem' boys -function logabsdetjac(b::B, x::AbstractVector{<:Real}) where {AD, B <: DistributionBijector{AD, <: Dirichlet}} - return logpdf_with_trans(b.dist, x, true) - logpdf_with_trans(b.dist, x, false) -end - "Returns the constrained-to-unconstrained bijector for distribution `d`." bijector(d::Distribution) = DistributionBijector(d) From 0db99881cb0f9ac636ecd10f6bdb0350813f4811 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 17:03:05 +0200 Subject: [PATCH 60/83] fixed stackoverflow on forward(::RadialLayer, x::AbstractArray) due to typo --- src/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index d4571fea..e2877d96 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -132,7 +132,7 @@ end forward(flow::RadialLayer, z) = _forward(flow, z) function forward(flow::RadialLayer, z::AbstractVector{<: Real}) - res = forward(flow, z) + res = _forward(flow, z) return (rv=res.rv, logabsdetjac=res.logabsdetjac[1]) end From 0d5b78c67b4074c8eae9940845e7a14ba8671598 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 17:04:50 +0200 Subject: [PATCH 61/83] added recursive implementation for forward(cb::Composed, x) --- src/interface.jl | 29 +++++++++++++++++++++++------ test/interface.jl | 18 ++++++++++++++++-- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 807e38d8..5e98c337 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -194,15 +194,32 @@ function _logabsdetjac(x, b1::Bijector, bs::Bijector...) end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) -# TODO: implement `forward` recursively -function forward(cb::Composed, x) - res = (rv=x, logabsdetjac=0) +# recursive implementation of `forward` +function _forward(f, b1::Bijector, b2::Bijector) + f1 = forward(b1, f.rv) + f2 = forward(b2, f1.rv) + return (rv=f2.rv, logabsdetjac=f2.logabsdetjac + f1.logabsdetjac + f.logabsdetjac) +end +function _forward(f, b::Bijector, bs::Bijector...) + f1 = forward(b, f.rv) + f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac + f.logabsdetjac) + return _forward(f_, bs...) +end +forward(cb::Composed{<: Tuple}, x, logjac) = _forward((rv=x, logabsdetjac=logjac), cb.ts...) +forward(cb::Composed{<: Tuple}, x) = forward(cb, x, zero(eltype(x))) + +function forward(cb::Composed, x, logjac) + rv = x + logjac_ = logjac + for t in cb.ts - res′ = forward(t, res.rv) - res = (rv=res′.rv, logabsdetjac=res.logabsdetjac + res′.logabsdetjac) + res = forward(t, rv) + rv = res.rv + logjac_ = res.logabsdetjac + logjac_ end - return res + return (rv=rv, logabsdetjac=logjac_) end +forward(cb::Composed, x) = forward(cb, x, zero(eltype(x))) ############################## # Example bijector: Identity # diff --git a/test/interface.jl b/test/interface.jl index a1b68db1..48acda96 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -177,7 +177,7 @@ Random.seed!(123) end end - @testset "Composition" begin + @testset "Composition <: Bijector" begin d = Beta() td = transformed(d) @@ -214,9 +214,23 @@ Random.seed!(123) cb = inv(b) ∘ b cb = cb ∘ cb @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x + + # forward for tuple and array + d = Beta() + b = inv(bijector(d)) + b⁻¹ = inv(b) + x = rand(d) + + cb_t = b⁻¹ ∘ b⁻¹ + f_t = forward(cb_t, x) + + cb_a = Composed([b⁻¹, b⁻¹]) + f_a = forward(cb_a, x) + + @test f_t == f_a end - @testset "Example: ADVI" begin + @testset "Example: ADVI single" begin # Usage in ADVI d = Beta() b = DistributionBijector(d) # [0, 1] → ℝ From a6d4c36a3c1d3c09aecde4b0c9f5819ce2930fda Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 17:33:20 +0200 Subject: [PATCH 62/83] support for batch computation using forward(b::Bijector, x, logjac) --- src/interface.jl | 50 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 5e98c337..be8c2c1e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -41,7 +41,10 @@ abstract type Bijector end Broadcast.broadcastable(b::Bijector) = Ref(b) -"Abstract type for a `Bijector` making use of auto-differentation (AD)." +""" +Abstract type for a `Bijector` making use of auto-differentation (AD) to +implement `jacobian` and, by impliciation, `logabsdetjac`. +""" abstract type ADBijector{AD} <: Bijector end """ @@ -71,17 +74,43 @@ logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) """ forward(b::Bijector, x) - forward(ib::Inversed{<: Bijector}, y) + forward(b::Bijector, x, logjac) Computes both `transform` and `logabsdetjac` in one forward pass, and returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. +`forward(b::Bijector, x, logjac)` allows the user to specify the accumulation +variable for the `logabsdetjac` field. This is useful for doing batch computations. +See below for example. + This defaults to the call above, but often one can re-use computation in the computation of the forward pass and the computation of the `logabsdetjac`. `forward` allows the user to take advantange of such efficiencies, if they exist. + +# Examples +`forward(b::Bijector, x, logjac)` allows the user to specify the accumulation +variable for the `logabsdetjac` field. This is useful for doing **batch computations**. +``` +julia> b = PlanarLayer(2); + +julia> cb = b ∘ b; + +julia> x = randn(2, 5) +2×5 Array{Float64,2}: + 0.35499 0.244763 -0.790103 -0.551053 -0.315193 + -0.210833 -0.854612 -0.197942 -0.486802 0.864496 + +julia> forward(cb, x) +ERROR: MethodError: no method matching +(::Array{Float64,1}, ::Float64) + ... +julia> forward(cb, x, zeros(size(x, 2))) +(rv = [0.0277394 -0.16401 … -0.253787 0.145421; 0.0144131 -0.573255 … -0.69141 0.547456], logabsdetjac = [-2.26736, -1.63996, -0.884037, -2.49163, -1.24809]) +``` + """ -forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) +forward(b::Bijector, x) = forward(b, x, zero(eltype(x))) +forward(b::Bijector, x, logjac) = (rv=b(x), logabsdetjac=logjac + logabsdetjac(b, x)) forward(ib::Inversed{<: Bijector}, y) = ( rv=ib(y), logabsdetjac=logabsdetjac(ib, y) @@ -150,8 +179,8 @@ bijectors are applied to the input right-to-left, e.g. first applying `b2` and t ``` But in the `Composed` struct itself, we store the bijectors left-to-right, so that ``` -cb1 = b1 ∘ b2 # => Composed.ts == [b2, b1] -cb2 = compose(b2, b1) +cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) +cb2 = compose(b2, b1) # => Composed.ts == (b2, b1) cb1(x) == cb2(x) == b1(b2(x)) # => true ``` """ @@ -205,8 +234,12 @@ function _forward(f, b::Bijector, bs::Bijector...) f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac + f.logabsdetjac) return _forward(f_, bs...) end -forward(cb::Composed{<: Tuple}, x, logjac) = _forward((rv=x, logabsdetjac=logjac), cb.ts...) -forward(cb::Composed{<: Tuple}, x) = forward(cb, x, zero(eltype(x))) +# if `x` represents multiple elements to act on, we want to allow the user to +# specify the `logjac` accumulation field since it's ambigious, e.g. should +# it be a vector or a float? +function forward(cb::Composed{<: Tuple}, x, logjac) + _forward((rv=x, logabsdetjac=logjac), cb.ts...) +end function forward(cb::Composed, x, logjac) rv = x @@ -219,7 +252,6 @@ function forward(cb::Composed, x, logjac) end return (rv=rv, logabsdetjac=logjac_) end -forward(cb::Composed, x) = forward(cb, x, zero(eltype(x))) ############################## # Example bijector: Identity # @@ -249,7 +281,7 @@ end (b::Logit)(x) = @. logit((x - b.a) / (b.b - b.a)) (ib::Inversed{<: Logit{<: Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -logabsdetjac(b::Logit{<:Real}, x) = log((x - b.a) * (b.b - x) / (b.b - b.a)) +logabsdetjac(b::Logit{<:Real}, x) = @. log((x - b.a) * (b.b - x) / (b.b - b.a)) ############# # Exp & Log # From e2f26dfcf784cd7f36ce38752f77b57940fa4a56 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 17:39:42 +0200 Subject: [PATCH 63/83] edited some comments --- src/interface.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index be8c2c1e..4536bdad 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -89,23 +89,23 @@ in the computation of the forward pass and the computation of the efficiencies, if they exist. # Examples -`forward(b::Bijector, x, logjac)` allows the user to specify the accumulation -variable for the `logabsdetjac` field. This is useful for doing **batch computations**. +`forward(b::Bijector, x, logjac)` allows specification of the accumulation variable +for the `logabsdetjac` field. This is useful for doing **batch computations**. ``` julia> b = PlanarLayer(2); julia> cb = b ∘ b; -julia> x = randn(2, 5) -2×5 Array{Float64,2}: - 0.35499 0.244763 -0.790103 -0.551053 -0.315193 - -0.210833 -0.854612 -0.197942 -0.486802 0.864496 +julia> x = randn(2, 3) +2×3 Array{Float64,2}: + 0.0660476 -0.77195 -1.7832 + -0.147743 -1.46459 0.264924 julia> forward(cb, x) ERROR: MethodError: no method matching +(::Array{Float64,1}, ::Float64) ... julia> forward(cb, x, zeros(size(x, 2))) -(rv = [0.0277394 -0.16401 … -0.253787 0.145421; 0.0144131 -0.573255 … -0.69141 0.547456], logabsdetjac = [-2.26736, -1.63996, -0.884037, -2.49163, -1.24809]) +(rv = [1.10887 0.32029 -0.704563; -0.639206 -1.97935 -0.243419], logabsdetjac = [0.018534, 1.46352e-5, 0.00521633]) ``` """ From 4638dd2e11ef62e1dcbf8545e29ad122ae36b8e7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 18:36:01 +0200 Subject: [PATCH 64/83] fixed a typo in RadialLayer forward --- src/norm_flows.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index e2877d96..272c1ed3 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -133,7 +133,7 @@ forward(flow::RadialLayer, z) = _forward(flow, z) function forward(flow::RadialLayer, z::AbstractVector{<: Real}) res = _forward(flow, z) - return (rv=res.rv, logabsdetjac=res.logabsdetjac[1]) + return (rv=res.rv[:, 1], logabsdetjac=res.logabsdetjac[1]) end # function inv(flow::RadialLayer, y) From 24ce4cc321266539c8deaf636f474ebb87bbe987 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 20:49:41 +0200 Subject: [PATCH 65/83] replaced forward(b, x, logjac) with fused logjac instead --- src/interface.jl | 53 +++++++++++------------------------------------- 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 4536bdad..2cb3cae1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -74,43 +74,17 @@ logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) """ forward(b::Bijector, x) - forward(b::Bijector, x, logjac) + forward(ib::Inversed{<: Bijector}, y) Computes both `transform` and `logabsdetjac` in one forward pass, and returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. -`forward(b::Bijector, x, logjac)` allows the user to specify the accumulation -variable for the `logabsdetjac` field. This is useful for doing batch computations. -See below for example. - This defaults to the call above, but often one can re-use computation in the computation of the forward pass and the computation of the `logabsdetjac`. `forward` allows the user to take advantange of such efficiencies, if they exist. - -# Examples -`forward(b::Bijector, x, logjac)` allows specification of the accumulation variable -for the `logabsdetjac` field. This is useful for doing **batch computations**. -``` -julia> b = PlanarLayer(2); - -julia> cb = b ∘ b; - -julia> x = randn(2, 3) -2×3 Array{Float64,2}: - 0.0660476 -0.77195 -1.7832 - -0.147743 -1.46459 0.264924 - -julia> forward(cb, x) -ERROR: MethodError: no method matching +(::Array{Float64,1}, ::Float64) - ... -julia> forward(cb, x, zeros(size(x, 2))) -(rv = [1.10887 0.32029 -0.704563; -0.639206 -1.97935 -0.243419], logabsdetjac = [0.018534, 1.46352e-5, 0.00521633]) -``` - """ -forward(b::Bijector, x) = forward(b, x, zero(eltype(x))) -forward(b::Bijector, x, logjac) = (rv=b(x), logabsdetjac=logjac + logabsdetjac(b, x)) +forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) forward(ib::Inversed{<: Bijector}, y) = ( rv=ib(y), logabsdetjac=logabsdetjac(ib, y) @@ -215,11 +189,11 @@ _transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) function _logabsdetjac(x, b1::Bijector, b2::Bijector) res = forward(b1, x) - return logabsdetjac(b2, res.rv) + res.logabsdetjac + return logabsdetjac(b2, res.rv) .+ res.logabsdetjac end function _logabsdetjac(x, b1::Bijector, bs::Bijector...) res = forward(b1, x) - return _logabsdetjac(res.rv, bs...) + res.logabsdetjac + return _logabsdetjac(res.rv, bs...) .+ res.logabsdetjac end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) @@ -227,30 +201,27 @@ logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) function _forward(f, b1::Bijector, b2::Bijector) f1 = forward(b1, f.rv) f2 = forward(b2, f1.rv) - return (rv=f2.rv, logabsdetjac=f2.logabsdetjac + f1.logabsdetjac + f.logabsdetjac) + return (rv=f2.rv, logabsdetjac=f2.logabsdetjac .+ f1.logabsdetjac .+ f.logabsdetjac) end function _forward(f, b::Bijector, bs::Bijector...) f1 = forward(b, f.rv) - f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac + f.logabsdetjac) + f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac .+ f.logabsdetjac) return _forward(f_, bs...) end -# if `x` represents multiple elements to act on, we want to allow the user to -# specify the `logjac` accumulation field since it's ambigious, e.g. should -# it be a vector or a float? -function forward(cb::Composed{<: Tuple}, x, logjac) - _forward((rv=x, logabsdetjac=logjac), cb.ts...) +function forward(cb::Composed{<: Tuple}, x) + _forward((rv=x, logabsdetjac=zero(eltype(x))), cb.ts...) end -function forward(cb::Composed, x, logjac) +function forward(cb::Composed, x) rv = x - logjac_ = logjac + logjac = zero(eltype(x)) for t in cb.ts res = forward(t, rv) rv = res.rv - logjac_ = res.logabsdetjac + logjac_ + logjac = res.logabsdetjac .+ logjac end - return (rv=rv, logabsdetjac=logjac_) + return (rv=rv, logabsdetjac=logjac) end ############################## From 7c157f9f2da6047e6c568c71a0b2c06af8803565 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Aug 2019 22:34:26 +0200 Subject: [PATCH 66/83] initializing recursive forward call using first result --- src/interface.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 2cb3cae1..1d3c2730 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -198,25 +198,29 @@ end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) # recursive implementation of `forward` -function _forward(f, b1::Bijector, b2::Bijector) +# HACK: we need this one in the case where `length(cb.ts) == 2` +# in which case forward(...) immediately calls `_forward(::NamedTuple, b::Bijector)` +function _forward(f::NamedTuple, b::Bijector) + y, logjac = forward(b, f.rv) + return (rv=y, logabsdetjac=logjac .+ f.logabsdetjac) +end +function _forward(f::NamedTuple, b1::Bijector, b2::Bijector) f1 = forward(b1, f.rv) f2 = forward(b2, f1.rv) return (rv=f2.rv, logabsdetjac=f2.logabsdetjac .+ f1.logabsdetjac .+ f.logabsdetjac) end -function _forward(f, b::Bijector, bs::Bijector...) +function _forward(f::NamedTuple, b::Bijector, bs::Bijector...) f1 = forward(b, f.rv) f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac .+ f.logabsdetjac) return _forward(f_, bs...) end -function forward(cb::Composed{<: Tuple}, x) - _forward((rv=x, logabsdetjac=zero(eltype(x))), cb.ts...) -end +_forward(x, b::Bijector, bs::Bijector...) = _forward(forward(b, x), bs...) +forward(cb::Composed{<: Tuple}, x) = _forward(x, cb.ts...) function forward(cb::Composed, x) - rv = x - logjac = zero(eltype(x)) + rv, logjac = forward(cb.ts[1], x) - for t in cb.ts + for t in cb.ts[2:end] res = forward(t, rv) rv = res.rv logjac = res.logabsdetjac .+ logjac @@ -280,6 +284,9 @@ struct Shift{T} <: Bijector end (b::Shift)(x) = b.a + x +(b::Shift{<: Real})(x::AbstractVector) = b.a .+ x +(b::Shift{<: AbstractVector})(x::AbstractMatrix) = b.a .+ x + inv(b::Shift) = Shift(-b.a) logabsdetjac(b::Shift, x::T) where T = zero(T) @@ -288,6 +295,9 @@ struct Scale{T} <: Bijector end (b::Scale)(x) = b.a * x +(b::Scale{<: Real})(x::AbstractVector) = b.a .* x +(b::Scale{<: AbstractVector{<: Real}})(x::AbstractMatrix{<: Real}) = b.a * x + inv(b::Scale) = Scale(b.a^(-1)) logabsdetjac(b::Scale, x) = log(abs(b.a)) From aa0f8b1bfc849519e1c2fa0aafd50cc3fbf74008 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 25 Aug 2019 14:28:06 +0200 Subject: [PATCH 67/83] fixed logabsdetjac of Shift for batch, though it is ambiguous imo --- src/interface.jl | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 1d3c2730..d2c54c22 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -284,21 +284,31 @@ struct Shift{T} <: Bijector end (b::Shift)(x) = b.a + x -(b::Shift{<: Real})(x::AbstractVector) = b.a .+ x +(b::Shift{<: Real})(x::AbstractArray) = b.a .+ x (b::Shift{<: AbstractVector})(x::AbstractMatrix) = b.a .+ x inv(b::Shift) = Shift(-b.a) -logabsdetjac(b::Shift, x::T) where T = zero(T) +logabsdetjac(b::Shift, x) = zero(eltype(x)) +# FIXME: ambiguous whether or not this is actually a batch or whatever +logabsdetjac(b::Shift{<: Real}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) +logabsdetjac(b::Shift{<: AbstractVector}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) struct Scale{T} <: Bijector a::T end (b::Scale)(x) = b.a * x -(b::Scale{<: Real})(x::AbstractVector) = b.a .* x -(b::Scale{<: AbstractVector{<: Real}})(x::AbstractMatrix{<: Real}) = b.a * x +(b::Scale{<: Real})(x::AbstractArray) = b.a .* x +(b::Scale{<: AbstractVector{<: Real}})(x::AbstractMatrix{<: Real}) = x * b.a -inv(b::Scale) = Scale(b.a^(-1)) +inv(b::Scale) = Scale(inv(b.a)) +inv(b::Scale{<: AbstractVector}) = Scale(inv.(b.a)) + +# TODO: should this be implemented for batch-computation? +# There's an ambiguity issue +# logabsdetjac(b::Scale{<: AbstractVector}, x::AbstractMatrix) +# Is this a batch or is it simply a matrix we want to scale differently +# in each component? logabsdetjac(b::Scale, x) = log(abs(b.a)) #################### From 28bd7107b2cba27383bba331a6bb55d6f0cbffcf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Aug 2019 02:38:45 +0200 Subject: [PATCH 68/83] added test for transform and inverse for univariate transformed --- test/interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/interface.jl b/test/interface.jl index 48acda96..4db523c6 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -45,6 +45,7 @@ Random.seed!(123) # single sample y = rand(td) x = inv(td.transform)(y) + @test y == td.transform(x) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # logpdf_with_jac From afce05c31e34fe3350995a642e8f6924f2788599 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Aug 2019 02:54:15 +0200 Subject: [PATCH 69/83] MvLogMvNormal no longer uses DistributionBijector --- src/interface.jl | 5 +++-- test/interface.jl | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d2c54c22..7be4ac01 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -273,8 +273,8 @@ const log_b = Log() inv(b::Log) = exp_b inv(b::Exp) = log_b -logabsdetjac(b::Log, x) = log(x) -logabsdetjac(b::Exp, y) = - y +logabsdetjac(b::Log, x) = sum(log.(x)) +logabsdetjac(b::Exp, y) = - sum(y) ################# # Shift & Scale # @@ -508,6 +508,7 @@ Returns the constrained-to-unconstrained bijector for distribution `d`. bijector(d::Normal) = IdentityBijector bijector(d::MvNormal) = IdentityBijector bijector(d::PositiveDistribution) = log_b +bijector(d::MvLogNormal) = log_b bijector(d::SimplexDistribution) = simplex_b_proj _union2tuple(T1::Type, T2::Type) = (T1, T2) diff --git a/test/interface.jl b/test/interface.jl index 4db523c6..39466782 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -118,7 +118,7 @@ Random.seed!(123) Dirichlet([1000 * one(Float64), eps(Float64)]), Dirichlet([eps(Float64), 1000 * one(Float64)]), MvNormal(randn(10), exp.(randn(10))), - # MvLogNormal(MvNormal(randn(10), exp.(randn(10)))), + MvLogNormal(MvNormal(randn(10), exp.(randn(10)))), Dirichlet([1000 * one(Float64), eps(Float64)]), Dirichlet([eps(Float64), 1000 * one(Float64)]), ] @@ -130,6 +130,7 @@ Random.seed!(123) # single sample y = rand(td) x = inv(td.transform)(y) + @test y == td.transform(x) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # logpdf_with_jac From 4044155bf2dfcf500239a3c076f39c3ac8e44e44 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Aug 2019 03:09:10 +0200 Subject: [PATCH 70/83] captialized comments --- src/interface.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 7be4ac01..ba28ebd6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -112,11 +112,11 @@ function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::Real) return Tracker.gradient(b, y)[1] end function jacobian(b::ADBijector{<: TrackerAD}, x::AbstractVector{<: Real}) - # we extract `data` so that we don't returne a `Tracked` type + # We extract `data` so that we don't returne a `Tracked` type return Tracker.data(Tracker.jacobian(b, x)) end function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: Real}) - # we extract `data` so that we don't returne a `Tracked` type + # We extract `data` so that we don't returne a `Tracked` type return Tracker.data(Tracker.jacobian(b, y)) end @@ -197,7 +197,7 @@ function _logabsdetjac(x, b1::Bijector, bs::Bijector...) end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) -# recursive implementation of `forward` +# Recursive implementation of `forward` # HACK: we need this one in the case where `length(cb.ts) == 2` # in which case forward(...) immediately calls `_forward(::NamedTuple, b::Bijector)` function _forward(f::NamedTuple, b::Bijector) @@ -238,8 +238,8 @@ struct Identity <: Bijector end forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) -logabsdetjac(::Identity, y::T) where T <: Real = zero(T) -logabsdetjac(::Identity, y::AbstractArray{T}) where T <: Real = zero(T) +logabsdetjac(::Identity, y::T) where {T<:Real} = zero(T) +logabsdetjac(::Identity, y::AbstractArray{T}) where {T <: Real} = zero(T) const IdentityBijector = Identity() @@ -319,7 +319,7 @@ struct SimplexBijector{T} <: Bijector where T end const simplex_b = SimplexBijector{Val{false}}() const simplex_b_proj = SimplexBijector{Val{true}}() -# the following implementations are basically just copy-paste from `invlink` and +# The following implementations are basically just copy-paste from `invlink` and # `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. function _clamp(x::T, b::SimplexBijector) where T bounds = (zero(T), one(T)) From eb4bbc8460904c7dd9fdb0b37cf594ca6be425ec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Aug 2019 03:09:54 +0200 Subject: [PATCH 71/83] added my name to a comment --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index ba28ebd6..575009fe 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -521,7 +521,7 @@ for D in _union2tuple(UnitDistribution)[2:end] @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), one(T)) end -# FIXME: Can we make this type-stable? +# FIXME: (TOR) Can we make this type-stable? # Everything but `Truncated` can probably be made type-stable # by explicit implementation. Can also make a `TruncatedBijector` # which has the same transform as the `link` function. From a568f86d53f0143db0fc334edaa081efb2b1731b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Aug 2019 03:16:12 +0200 Subject: [PATCH 72/83] fixed a typo leading to Kolmogorov being treated as unit-contrained --- src/interface.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 575009fe..1e064418 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -515,9 +515,12 @@ _union2tuple(T1::Type, T2::Type) = (T1, T2) _union2tuple(T1::Type, T2::Union) = (T1, _union2tuple(T2.a, T2.b)...) _union2tuple(T::Union) = _union2tuple(T.a, T.b) -bijector(d::Kolmogorov) = Logit(zero(eltype(d)), zero(eltype(d))) -for D in _union2tuple(UnitDistribution)[2:end] - # Skipping Kolmogorov because it's a DataType +bijector(d::KSOneSided) = Logit(zero(eltype(d)), zero(eltype(d))) +for D in _union2tuple(UnitDistribution) + # Skipping KSOneSided because it's not a parametric type + if D == KSOneSided + continue + end @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), one(T)) end From 0d5a3be879a98194acae3e2ccbc99eb9bcb11c2d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Aug 2019 19:54:24 +0200 Subject: [PATCH 73/83] added SingularJacobianException to logabsdetjac of ADBijector --- src/interface.jl | 13 +++++++++++-- test/interface.jl | 9 +++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 1e064418..4f812e12 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -120,12 +120,21 @@ function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: return Tracker.data(Tracker.jacobian(b, y)) end +struct SingularJacobianException{B} <: Exception where B <: Bijector + b::B +end +Base.showerror(io::IO, e::SingularJacobianException) = print(io, "jacobian of $(e.b) is singular") + # TODO: allow batch-computation, especially for univariate case? "Computes the absolute determinant of the Jacobian of the inverse-transformation." -logabsdetjac(b::ADBijector, x::Real) = log(abs(jacobian(b, x))) +function logabsdetjac(b::ADBijector, x::Real) + res = log(abs(jacobian(b, x))) + return isfinite(res) ? res : throw(SingularJacobianException(b)) +end + function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real}) fact = lu(jacobian(b, x), check=false) - return issuccess(fact) ? log(abs(det(fact))) : -Inf # TODO: do this or not? + return issuccess(fact) ? log(abs(det(fact))) : throw(SingularJacobianException(b)) end """ diff --git a/test/interface.jl b/test/interface.jl index 39466782..68d60465 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -5,8 +5,17 @@ using LinearAlgebra Random.seed!(123) +struct NonInvertibleBijector{AD} <: ADBijector{AD} end + # Scalar tests @testset "Interface" begin + @testset "<: ADBijector{AD}" begin + (b::NonInvertibleBijector)(x) = clamp.(x, 0, 1) + + b = NonInvertibleBijector{Bijectors.ADBackend()}() + @test_throws Bijectors.SingularJacobianException logabsdetjac(b, [1.0, 10.0]) + end + @testset "Univariate" begin # Tests with scalar-valued distributions. uni_dists = [ From 323e61a358b87b5282d3c7fc68460400798ae4a9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Aug 2019 20:09:55 +0200 Subject: [PATCH 74/83] made changes to adhere to style-guide --- src/interface.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 4f812e12..c266c4a0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -70,7 +70,7 @@ Similarily for the inverse-transform. Default implementation for `Inversed{<: Bijector}` is implemented as `- logabsdetjac` of original `Bijector`. """ -logabsdetjac(ib::Inversed{<: Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) +logabsdetjac(ib::Inversed{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) """ forward(b::Bijector, x) @@ -85,7 +85,7 @@ in the computation of the forward pass and the computation of the efficiencies, if they exist. """ forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) -forward(ib::Inversed{<: Bijector}, y) = ( +forward(ib::Inversed{<:Bijector}, y) = ( rv=ib(y), logabsdetjac=logabsdetjac(ib, y) ) @@ -120,7 +120,7 @@ function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: return Tracker.data(Tracker.jacobian(b, y)) end -struct SingularJacobianException{B} <: Exception where B <: Bijector +struct SingularJacobianException{B} <: Exception where {B<:Bijector} b::B end Base.showerror(io::IO, e::SingularJacobianException) = print(io, "jacobian of $(e.b) is singular") @@ -248,7 +248,7 @@ struct Identity <: Bijector end forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) logabsdetjac(::Identity, y::T) where {T<:Real} = zero(T) -logabsdetjac(::Identity, y::AbstractArray{T}) where {T <: Real} = zero(T) +logabsdetjac(::Identity, y::AbstractArray{T}) where {T<:Real} = zero(T) const IdentityBijector = Identity() @@ -323,14 +323,14 @@ logabsdetjac(b::Scale, x) = log(abs(b.a)) #################### # Simplex bijector # #################### -struct SimplexBijector{T} <: Bijector where T end +struct SimplexBijector{T} <: Bijector where {T} end const simplex_b = SimplexBijector{Val{false}}() const simplex_b_proj = SimplexBijector{Val{true}}() # The following implementations are basically just copy-paste from `invlink` and # `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. -function _clamp(x::T, b::SimplexBijector) where T +function _clamp(x::T, b::SimplexBijector) where {T} bounds = (zero(T), one(T)) clamped_x = clamp(x, bounds...) DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x" @@ -465,10 +465,10 @@ This is the default `Bijector` for a distribution. It uses `link` and `invlink` to compute the transformations, and `AD` to compute the `jacobian` and `logabsdetjac`. """ -struct DistributionBijector{AD, D} <: ADBijector{AD} where D <: Distribution +struct DistributionBijector{AD, D} <: ADBijector{AD} where {D<:Distribution} dist::D end -function DistributionBijector(dist::D) where D <: Distribution +function DistributionBijector(dist::D) where {D<:Distribution} DistributionBijector{ADBackend(), D}(dist) end @@ -481,19 +481,19 @@ end bijector(d::Distribution) = DistributionBijector(d) # Transformed distributions -struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D <: Distribution{V, Continuous}, B <: Bijector} +struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} dist::D transform::B end -function TransformedDistribution(d::D, b::B) where {V <: VariateForm, B <: Bijector, D <: Distribution{V, Continuous}} +function TransformedDistribution(d::D, b::B) where {V<:VariateForm, B<:Bijector, D<:Distribution{V, Continuous}} return TransformedDistribution{D, B, V}(d, b) end -const UnivariateTransformed = TransformedDistribution{<: Distribution, <: Bijector, Univariate} -const MultivariateTransformed = TransformedDistribution{<: Distribution, <: Bijector, Multivariate} +const UnivariateTransformed = TransformedDistribution{<: Distribution, <:Bijector, Univariate} +const MultivariateTransformed = TransformedDistribution{<: Distribution, <:Bijector, Multivariate} const MvTransformed = MultivariateTransformed -const MatrixTransformed = TransformedDistribution{<: Distribution, <: Bijector, Matrixvariate} +const MatrixTransformed = TransformedDistribution{<: Distribution, <:Bijector, Matrixvariate} const Transformed = TransformedDistribution @@ -530,7 +530,7 @@ for D in _union2tuple(UnitDistribution) if D == KSOneSided continue end - @eval bijector(d::$D{T}) where T <: Real = Logit(zero(T), one(T)) + @eval bijector(d::$D{T}) where {T<:Real} = Logit(zero(T), one(T)) end # FIXME: (TOR) Can we make this type-stable? @@ -538,7 +538,7 @@ end # by explicit implementation. Can also make a `TruncatedBijector` # which has the same transform as the `link` function. # E.g. (b::Truncated)(x) = link(b.d, x) or smth -function bijector(d::TransformDistribution) where D <: Distribution +function bijector(d::TransformDistribution) where {D<:Distribution} a, b = minimum(d), maximum(d) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded From 04f5afabdefecb3dafb9fb6b49ecc5d58d9d6984 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Aug 2019 20:13:25 +0200 Subject: [PATCH 75/83] mode style-changes --- src/interface.jl | 70 ++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index c266c4a0..7ec98536 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -62,19 +62,19 @@ inv(ib::Inversed{<:Bijector}) = ib.orig """ logabsdetjac(b::Bijector, x) - logabsdetjac(ib::Inversed{<: Bijector}, y) + logabsdetjac(ib::Inversed{<:Bijector}, y) Computes the log(abs(det(J(x)))) where J is the jacobian of the transform. Similarily for the inverse-transform. -Default implementation for `Inversed{<: Bijector}` is implemented as +Default implementation for `Inversed{<:Bijector}` is implemented as `- logabsdetjac` of original `Bijector`. """ logabsdetjac(ib::Inversed{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) """ forward(b::Bijector, x) - forward(ib::Inversed{<: Bijector}, y) + forward(ib::Inversed{<:Bijector}, y) Computes both `transform` and `logabsdetjac` in one forward pass, and returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. @@ -92,30 +92,30 @@ forward(ib::Inversed{<:Bijector}, y) = ( # AD implementations -function jacobian(b::ADBijector{<: ForwardDiffAD}, x::Real) +function jacobian(b::ADBijector{<:ForwardDiffAD}, x::Real) return ForwardDiff.derivative(b, x) end -function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::Real) +function jacobian(b::Inversed{<:ADBijector{<:ForwardDiffAD}}, y::Real) return ForwardDiff.derivative(b, y) end -function jacobian(b::ADBijector{<: ForwardDiffAD}, x::AbstractVector{<: Real}) +function jacobian(b::ADBijector{<:ForwardDiffAD}, x::AbstractVector{<:Real}) return ForwardDiff.jacobian(b, x) end -function jacobian(b::Inversed{<: ADBijector{<: ForwardDiffAD}}, y::AbstractVector{<: Real}) +function jacobian(b::Inversed{<:ADBijector{<:ForwardDiffAD}}, y::AbstractVector{<:Real}) return ForwardDiff.jacobian(b, y) end -function jacobian(b::ADBijector{<: TrackerAD}, x::Real) +function jacobian(b::ADBijector{<:TrackerAD}, x::Real) return Tracker.gradient(b, x)[1] end -function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::Real) +function jacobian(b::Inversed{<:ADBijector{<:TrackerAD}}, y::Real) return Tracker.gradient(b, y)[1] end -function jacobian(b::ADBijector{<: TrackerAD}, x::AbstractVector{<: Real}) +function jacobian(b::ADBijector{<:TrackerAD}, x::AbstractVector{<:Real}) # We extract `data` so that we don't returne a `Tracked` type return Tracker.data(Tracker.jacobian(b, x)) end -function jacobian(b::Inversed{<: ADBijector{<: TrackerAD}}, y::AbstractVector{<: Real}) +function jacobian(b::Inversed{<:ADBijector{<:TrackerAD}}, y::AbstractVector{<:Real}) # We extract `data` so that we don't returne a `Tracked` type return Tracker.data(Tracker.jacobian(b, y)) end @@ -182,7 +182,7 @@ compose(ts::Bijector...) = Composed(ts) inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) # # TODO: should arrays also be using recursive implementation instead? -function (cb::Composed{<: AbstractArray{<: Bijector}})(x) +function (cb::Composed{<:AbstractArray{<:Bijector}})(x) res = x for b ∈ cb.ts res = b(res) @@ -194,7 +194,7 @@ end # recursive implementation like this allows type-inference _transform(x, b1::Bijector, b2::Bijector) = b2(b1(x)) _transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) -(cb::Composed{<: Tuple})(x) = _transform(x, cb.ts...) +(cb::Composed{<:Tuple})(x) = _transform(x, cb.ts...) function _logabsdetjac(x, b1::Bijector, b2::Bijector) res = forward(b1, x) @@ -207,7 +207,7 @@ end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) # Recursive implementation of `forward` -# HACK: we need this one in the case where `length(cb.ts) == 2` +# NOTE: we need this one in the case where `length(cb.ts) == 2` # in which case forward(...) immediately calls `_forward(::NamedTuple, b::Bijector)` function _forward(f::NamedTuple, b::Bijector) y, logjac = forward(b, f.rv) @@ -224,7 +224,7 @@ function _forward(f::NamedTuple, b::Bijector, bs::Bijector...) return _forward(f_, bs...) end _forward(x, b::Bijector, bs::Bijector...) = _forward(forward(b, x), bs...) -forward(cb::Composed{<: Tuple}, x) = _forward(x, cb.ts...) +forward(cb::Composed{<:Tuple}, x) = _forward(x, cb.ts...) function forward(cb::Composed, x) rv, logjac = forward(cb.ts[1], x) @@ -263,7 +263,7 @@ struct Logit{T<:Real} <: Bijector end (b::Logit)(x) = @. logit((x - b.a) / (b.b - b.a)) -(ib::Inversed{<: Logit{<: Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a +(ib::Inversed{<:Logit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a logabsdetjac(b::Logit{<:Real}, x) = @. log((x - b.a) * (b.b - x) / (b.b - b.a)) @@ -293,25 +293,25 @@ struct Shift{T} <: Bijector end (b::Shift)(x) = b.a + x -(b::Shift{<: Real})(x::AbstractArray) = b.a .+ x -(b::Shift{<: AbstractVector})(x::AbstractMatrix) = b.a .+ x +(b::Shift{<:Real})(x::AbstractArray) = b.a .+ x +(b::Shift{<:AbstractVector})(x::AbstractMatrix) = b.a .+ x inv(b::Shift) = Shift(-b.a) logabsdetjac(b::Shift, x) = zero(eltype(x)) # FIXME: ambiguous whether or not this is actually a batch or whatever -logabsdetjac(b::Shift{<: Real}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) -logabsdetjac(b::Shift{<: AbstractVector}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) +logabsdetjac(b::Shift{<:Real}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) +logabsdetjac(b::Shift{<:AbstractVector}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) struct Scale{T} <: Bijector a::T end (b::Scale)(x) = b.a * x -(b::Scale{<: Real})(x::AbstractArray) = b.a .* x -(b::Scale{<: AbstractVector{<: Real}})(x::AbstractMatrix{<: Real}) = x * b.a +(b::Scale{<:Real})(x::AbstractArray) = b.a .* x +(b::Scale{<:AbstractVector{<:Real}})(x::AbstractMatrix{<:Real}) = x * b.a inv(b::Scale) = Scale(inv(b.a)) -inv(b::Scale{<: AbstractVector}) = Scale(inv.(b.a)) +inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) # TODO: should this be implemented for batch-computation? # There's an ambiguity issue @@ -386,7 +386,7 @@ function (b::SimplexBijector{Val{proj}})(X::AbstractMatrix{T}) where {T<:Real, p return Y end -function (ib::Inversed{<: SimplexBijector{Val{proj}}})(y::AbstractVector{T}) where {T, proj} +function (ib::Inversed{<:SimplexBijector{Val{proj}}})(y::AbstractVector{T}) where {T, proj} x, K = similar(y), length(y) ϵ = _eps(T) @@ -409,7 +409,7 @@ function (ib::Inversed{<: SimplexBijector{Val{proj}}})(y::AbstractVector{T}) whe end # Vectorised implementation of the above. -function (ib::Inversed{<: SimplexBijector{Val{proj}}})( +function (ib::Inversed{<:SimplexBijector{Val{proj}}})( Y::AbstractMatrix{T} ) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) @@ -458,7 +458,7 @@ end ####################################################### """ DistributionBijector(d::Distribution) - DistributionBijector{<: ADBackend, D}(d::Distribution) + DistributionBijector{<:ADBackend, D}(d::Distribution) This is the default `Bijector` for a distribution. @@ -474,7 +474,7 @@ end # Simply uses `link` and `invlink` as transforms with AD to get jacobian (b::DistributionBijector)(x) = link(b.dist, x) -(ib::Inversed{<: DistributionBijector})(y) = invlink(ib.orig.dist, y) +(ib::Inversed{<:DistributionBijector})(y) = invlink(ib.orig.dist, y) "Returns the constrained-to-unconstrained bijector for distribution `d`." @@ -490,10 +490,10 @@ function TransformedDistribution(d::D, b::B) where {V<:VariateForm, B<:Bijector, end -const UnivariateTransformed = TransformedDistribution{<: Distribution, <:Bijector, Univariate} -const MultivariateTransformed = TransformedDistribution{<: Distribution, <:Bijector, Multivariate} +const UnivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Univariate} +const MultivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Multivariate} const MvTransformed = MultivariateTransformed -const MatrixTransformed = TransformedDistribution{<: Distribution, <:Bijector, Matrixvariate} +const MatrixTransformed = TransformedDistribution{<:Distribution, <:Bijector, Matrixvariate} const Transformed = TransformedDistribution @@ -566,12 +566,12 @@ function logpdf(td::UnivariateTransformed, y::Real) end # TODO: implement more efficiently for flows in the case of `Matrix` -function _logpdf(td::MvTransformed, y::AbstractVector{<: Real}) +function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) return logpdf(td.dist, res.rv) .- res.logabsdetjac end -function _logpdf(td::MvTransformed{<: Dirichlet}, y::AbstractVector{<: Real}) +function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) T = eltype(y) ϵ = _eps(T) @@ -600,12 +600,12 @@ function rand(rng::AbstractRNG, td::MvTransformed, num_samples::Int) return res end -function _rand!(rng::AbstractRNG, td::MvTransformed, x::AbstractVector{<: Real}) +function _rand!(rng::AbstractRNG, td::MvTransformed, x::AbstractVector{<:Real}) rand!(rng, td.dist, x) x .= td.transform(x) end -function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<: Real}) +function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) rand!(rng, td.dist, x) x .= td.transform(x) end @@ -637,7 +637,7 @@ function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) end -function logpdf_with_jac(td::MvTransformed{<: Dirichlet}, y::AbstractVector{<:Real}) +function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) T = eltype(y) ϵ = _eps(T) From 68355c76a4c88c9c9d81f0e0ca5c5a5f787f2f68 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Aug 2019 20:49:22 +0200 Subject: [PATCH 76/83] added a more useful forward(td::TransformedDistribution) --- src/interface.jl | 43 ++++++++++++++++++++++++++++++++++++++----- test/interface.jl | 14 ++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 7ec98536..3ebf3a2c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -245,10 +245,9 @@ struct Identity <: Bijector end (::Identity)(x) = x (::Inversed{Identity})(y) = y -forward(::Identity, x) = (rv=x, logabsdetjac=zero(x)) +forward(::Identity, x) = (rv=x, logabsdetjac=zero(eltype(x))) -logabsdetjac(::Identity, y::T) where {T<:Real} = zero(T) -logabsdetjac(::Identity, y::AbstractArray{T}) where {T<:Real} = zero(T) +logabsdetjac(::Identity, y) = zero(eltype(y)) const IdentityBijector = Identity() @@ -664,5 +663,39 @@ This is similar to `logpdf_with_trans`. logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .+ logjac logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) -forward(d::Transformed, x) = forward(d.transform, x) -forward(d::Transformed) = forward(d, rand(d.dist)) + +# forward function +const GLOBAL_RNG = Distributions.GLOBAL_RNG + +function _forward(d::UnivariateDistribution, x) + y, logjac = forward(IdentityBijector, x) + return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) +end + +forward(rng::AbstractRNG, d::Distribution) = _forward(d, rand(rng, d)) +function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) + return _forward(d, rand(rng, d, num_samples)) +end +function _forward(d::Distribution, x) + y, logjac = forward(IdentityBijector, x) + return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) +end + +function _forward(td::Transformed, x) + y, logjac = forward(td.transform, x) + return ( + x = x, + y = y, + logabsdetjac = logjac, + logpdf = logpdf_forward(td, x, logjac) + ) +end +function forward(rng::AbstractRNG, td::Transformed) + return _forward(td, rand(rng, td.dist)) +end +function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) + return _forward(td, rand(rng, td.dist, num_samples)) +end + +forward(td::Distribution) = forward(GLOBAL_RNG, td) +forward(td::Distribution, num_samples::Int) = forward(GLOBAL_RNG, td, num_samples) diff --git a/test/interface.jl b/test/interface.jl index 68d60465..437a1467 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -74,6 +74,13 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end y = b(x) @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) + + # forward + f = forward(td) + @test f.x ≈ inv(td.transform)(f.y) + @test f.y ≈ td.transform(f.x) + @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) + @test f.logpdf ≈ logpdf(td.dist, f.x) + f.logabsdetjac end @testset "$dist: ForwardDiff AD" begin @@ -151,6 +158,13 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end y = rand(td, 10) x = inv(td.transform)(y) @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # forward + f = forward(td) + @test f.x ≈ inv(td.transform)(f.y) + @test f.y ≈ td.transform(f.x) + @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) + @test f.logpdf ≈ logpdf(td.dist, f.x) + f.logabsdetjac end end end From a2336d5cc919185afbfa278f2649581453cd792c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Aug 2019 20:50:15 +0200 Subject: [PATCH 77/83] replaced constant variables log_b and exp_b with Log and Exp --- src/interface.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 3ebf3a2c..fb313d62 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -272,14 +272,12 @@ logabsdetjac(b::Logit{<:Real}, x) = @. log((x - b.a) * (b.b - x) / (b.b - b.a)) struct Exp <: Bijector end struct Log <: Bijector end -const exp_b = Exp() -const log_b = Log() (b::Log)(x) = @. log(x) (b::Exp)(y) = @. exp(y) -inv(b::Log) = exp_b -inv(b::Exp) = log_b +inv(b::Log) = Exp() +inv(b::Exp) = Log() logabsdetjac(b::Log, x) = sum(log.(x)) logabsdetjac(b::Exp, y) = - sum(y) @@ -515,8 +513,8 @@ Returns the constrained-to-unconstrained bijector for distribution `d`. """ bijector(d::Normal) = IdentityBijector bijector(d::MvNormal) = IdentityBijector -bijector(d::PositiveDistribution) = log_b -bijector(d::MvLogNormal) = log_b +bijector(d::PositiveDistribution) = Log() +bijector(d::MvLogNormal) = Log() bijector(d::SimplexDistribution) = simplex_b_proj _union2tuple(T1::Type, T2::Type) = (T1, T2) @@ -543,9 +541,9 @@ function bijector(d::TransformDistribution) where {D<:Distribution} if lowerbounded && upperbounded return Logit(a, b) elseif lowerbounded - return (log_b ∘ Shift(- a)) + return (Log() ∘ Shift(- a)) elseif upperbounded - return (log_b ∘ Shift(b) ∘ Scale(- one(typeof(b)))) + return (Log() ∘ Shift(b) ∘ Scale(- one(typeof(b)))) else return IdentityBijector end From d71b87e73337fea87874f4e9070a36985d080806 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Aug 2019 08:02:18 +0200 Subject: [PATCH 78/83] compose replaced by composel and composer --- src/interface.jl | 13 ++++++++----- test/interface.jl | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index fb313d62..be725a36 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -150,9 +150,11 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) """ ∘(b1::Bijector, b2::Bijector) - compose(ts::Bijector...) + composel(ts::Bijector...) + composer(ts::Bijector...) -A `Bijector` representing composition of bijectors. +A `Bijector` representing composition of bijectors. `composel` and `composer` results in a +`Composed` for which application occurs from left-to-right and right-to-left, respectively. # Examples It's important to note that `∘` does what is expected mathematically, which means that the @@ -163,7 +165,7 @@ bijectors are applied to the input right-to-left, e.g. first applying `b2` and t But in the `Composed` struct itself, we store the bijectors left-to-right, so that ``` cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) -cb2 = compose(b2, b1) # => Composed.ts == (b2, b1) +cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) cb1(x) == cb2(x) == b1(b2(x)) # => true ``` """ @@ -171,13 +173,14 @@ struct Composed{A} <: Bijector ts::A end -compose(ts::Bijector...) = Composed(ts) +composel(ts::Bijector...) = Composed(ts) +composer(ts::Bijector...) = Composed(inv(ts)) # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -∘(b1::Bijector, b2::Bijector) = compose(b2, b1) +∘(b1::Bijector, b2::Bijector) = composel(b2, b1) inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) diff --git a/test/interface.jl b/test/interface.jl index 437a1467..35f999f4 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -209,7 +209,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end x = rand(d) y = td.transform(x) - b = Bijectors.compose(td.transform, Bijectors.Identity()) + b = Bijectors.composel(td.transform, Bijectors.Identity()) ib = inv(b) @test forward(b, x) == forward(td.transform, x) From d6856d138df14043707675a0420c294e30e306cb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Aug 2019 01:05:24 +0200 Subject: [PATCH 79/83] fixed sign typo and added AD verification --- src/interface.jl | 22 +++++++++++----------- test/interface.jl | 29 +++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index be725a36..bf6f0355 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -267,7 +267,7 @@ end (b::Logit)(x) = @. logit((x - b.a) / (b.b - b.a)) (ib::Inversed{<:Logit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a -logabsdetjac(b::Logit{<:Real}, x) = @. log((x - b.a) * (b.b - x) / (b.b - b.a)) +logabsdetjac(b::Logit{<:Real}, x) = @. - log((x - b.a) * (b.b - x) / (b.b - b.a)) ############# # Exp & Log # @@ -282,8 +282,8 @@ struct Log <: Bijector end inv(b::Log) = Exp() inv(b::Exp) = Log() -logabsdetjac(b::Log, x) = sum(log.(x)) -logabsdetjac(b::Exp, y) = - sum(y) +logabsdetjac(b::Log, x) = - sum(log.(x)) +logabsdetjac(b::Exp, y) = sum(y) ################# # Shift & Scale # @@ -450,7 +450,7 @@ function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where T lp += log(z + ϵ) + log((one(T) + ϵ) - z) + log((one(T) + ϵ) - sum_tmp) end - return lp + return - lp end ####################################################### @@ -562,13 +562,13 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) .- res.logabsdetjac + return logpdf(td.dist, res.rv) .+ res.logabsdetjac end # TODO: implement more efficiently for flows in the case of `Matrix` function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) .- res.logabsdetjac + return logpdf(td.dist, res.rv) .+ res.logabsdetjac end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -576,7 +576,7 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .- res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with @@ -623,18 +623,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. """ function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .- res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -642,7 +642,7 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - return (logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .- res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac, res.logabsdetjac) end # TODO: should eventually drop using `logpdf_with_trans` diff --git a/test/interface.jl b/test/interface.jl index 35f999f4..d783f4ef 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -2,6 +2,7 @@ using Test using Bijectors using Random using LinearAlgebra +using ForwardDiff Random.seed!(123) @@ -72,8 +73,8 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end b = bijector(d) x = rand(d) y = b(x) - @test logpdf(d, inv(b)(y)) - logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) - @test logpdf(d, x) + logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) + @test logpdf(d, inv(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) + @test logpdf(d, x) - logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) # forward f = forward(td) @@ -81,6 +82,14 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test f.y ≈ td.transform(f.x) @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) @test f.logpdf ≈ logpdf(td.dist, f.x) + f.logabsdetjac + + # verify against AD + d = dist + b = bijector(d) + x = rand(d) + y = b(x) + @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) + @test log(abs(ForwardDiff.derivative(inv(b), y))) ≈ logabsdetjac(inv(b), y) end @testset "$dist: ForwardDiff AD" begin @@ -165,6 +174,22 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test f.y ≈ td.transform(f.x) @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) @test f.logpdf ≈ logpdf(td.dist, f.x) + f.logabsdetjac + + # verify against AD + # similar to what we do in test/transform.jl for Dirichlet + if dist isa Dirichlet + b = Bijectors.SimplexBijector{Val{false}}() + x = rand(dist) + y = b(x) + @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) + @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) + else + b = bijector(dist) + x = rand(dist) + y = b(x) + @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) + @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) + end end end end From a851abe06cd15e9576a9e836e51d954fc3749dd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Aug 2019 03:50:17 +0200 Subject: [PATCH 80/83] added special logpdf_forward for dirichlet in forward --- src/interface.jl | 9 ++++++++- test/interface.jl | 6 ++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index bf6f0355..49d187b8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -661,9 +661,16 @@ the inverse transform to compute the necessary `logabsdetjac`. This is similar to `logpdf_with_trans`. """ # TODO: implement more efficiently for flows in the case of `Matrix` -logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .+ logjac +logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .- logjac logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) +function logpdf_forward(td::MvTransformed{<:Dirichlet}, x, logjac) + T = eltype(x) + ϵ = _eps(T) + + return logpdf(td.dist, mappedarray(z->z+ϵ, x)) .- logjac +end + # forward function const GLOBAL_RNG = Distributions.GLOBAL_RNG diff --git a/test/interface.jl b/test/interface.jl index d783f4ef..24a40768 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -81,7 +81,8 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test f.x ≈ inv(td.transform)(f.y) @test f.y ≈ td.transform(f.x) @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) - @test f.logpdf ≈ logpdf(td.dist, f.x) + f.logabsdetjac + @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) + @test f.logpdf ≈ logpdf(td.dist, f.x) - f.logabsdetjac # verify against AD d = dist @@ -150,6 +151,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end for dist in vector_dists @testset "$dist: dist" begin + dist = Dirichlet([eps(Float64), 1000 * one(Float64)]) td = transformed(dist) # single sample @@ -173,7 +175,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test f.x ≈ inv(td.transform)(f.y) @test f.y ≈ td.transform(f.x) @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) - @test f.logpdf ≈ logpdf(td.dist, f.x) + f.logabsdetjac + @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) # verify against AD # similar to what we do in test/transform.jl for Dirichlet From ea6f1ee32062e60c9aa5517c6e5ce16b8d57adc9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Aug 2019 20:01:28 +0200 Subject: [PATCH 81/83] added docstring to forward(d::Distribution) --- src/interface.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 49d187b8..9cbdfc10 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -705,5 +705,23 @@ function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) return _forward(td, rand(rng, td.dist, num_samples)) end +""" + forward(td::Distribution) + forward(td::Distribution, num_samples::Int) + +Returns a `NamedTuple` with fields `x`, `y`, `logabsdetjac` and `logpdf`. + +In the case where `d isa TransformedDistribution`, this means +- `x = rand(td.dist)` +- `y = td.transform(x)` +- `logabsdetjac` is the logabsdetjac of the "forward" transform. +- `logpdf` is the logpdf of `y`, not `x` + +In the case where `d isa Distribution`, this means +- `x = rand(td)` +- `y = x` +- `logabsdetjac = 0.0` +- `logpdf` is logpdf of `x` +""" forward(td::Distribution) = forward(GLOBAL_RNG, td) forward(td::Distribution, num_samples::Int) = forward(GLOBAL_RNG, td, num_samples) From fc2351ab3cadbb4fb9115ceb6ff352a4a30460ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Aug 2019 20:02:22 +0200 Subject: [PATCH 82/83] change variable name in forward(d::Distribution) --- src/interface.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9cbdfc10..f45588ca 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -706,22 +706,22 @@ function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) end """ - forward(td::Distribution) - forward(td::Distribution, num_samples::Int) + forward(d::Distribution) + forward(d::Distribution, num_samples::Int) Returns a `NamedTuple` with fields `x`, `y`, `logabsdetjac` and `logpdf`. In the case where `d isa TransformedDistribution`, this means -- `x = rand(td.dist)` -- `y = td.transform(x)` +- `x = rand(d.dist)` +- `y = d.transform(x)` - `logabsdetjac` is the logabsdetjac of the "forward" transform. - `logpdf` is the logpdf of `y`, not `x` In the case where `d isa Distribution`, this means -- `x = rand(td)` +- `x = rand(d)` - `y = x` - `logabsdetjac = 0.0` - `logpdf` is logpdf of `x` """ -forward(td::Distribution) = forward(GLOBAL_RNG, td) -forward(td::Distribution, num_samples::Int) = forward(GLOBAL_RNG, td, num_samples) +forward(d::Distribution) = forward(GLOBAL_RNG, d) +forward(d::Distribution, num_samples::Int) = forward(GLOBAL_RNG, d, num_samples) From b49aac5184f80c3644e4ace6609c89f8ed7233ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Aug 2019 21:51:54 +0200 Subject: [PATCH 83/83] increment version number to 0.4.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6264830f..f75e2029 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.3.2" +version = "0.4.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"