diff --git a/Manifest.toml b/Manifest.toml index dbb1a0a8..f1c882b3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,8 +1,14 @@ +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "1.0.0" + [[Arpack]] -deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Random", "SparseArrays", "Test"] -git-tree-sha1 = "1ce1ce9984683f0b6a587d5bdbc688ecb480096f" +deps = ["BinaryProvider", "Libdl", "LinearAlgebra"] +git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6" uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" -version = "0.3.0" +version = "0.3.1" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -14,10 +20,22 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" version = "0.8.10" [[BinaryProvider]] -deps = ["Libdl", "Pkg", "SHA", "Test"] -git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.3" +version = "0.5.6" + +[[CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "0.6.2" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -25,11 +43,22 @@ git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" version = "2.1.0" +[[Crayons]] +deps = ["Test"] +git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" 
+version = "4.0.0" + +[[DataAPI]] +git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.0.1" + [[DataStructures]] -deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] -git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.15.0" +version = "0.17.0" [[Dates]] deps = ["Printf"] @@ -39,6 +68,18 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +[[DiffResults]] +deps = ["Compat", "StaticArrays"] +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "0.0.4" + +[[DiffRules]] +deps = ["Random", "Test"] +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "0.0.10" + [[Distributed]] deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -49,6 +90,12 @@ git-tree-sha1 = "022e6610c320b6e19b454502d759c672580abe00" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" version = "0.18.0" +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.3" + [[InteractiveUtils]] deps = ["LinearAlgebra", "Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -66,6 +113,12 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[MacroTools]] +deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"] +git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76" 
+uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.1" + [[MappedArrays]] deps = ["Test"] git-tree-sha1 = "923441c5ac942b60bd3a842d5377d96646bcbf46" @@ -77,14 +130,26 @@ deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[Missings]] -deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] -git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" +deps = ["SparseArrays", "Test"] +git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.0" +version = "0.4.1" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] +git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.0" + +[[NaNMath]] +deps = ["Compat"] +git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.2" + [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" @@ -93,9 +158,9 @@ version = "1.1.0" [[PDMats]] deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551" +git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.9.6" +version = "0.9.9" [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -137,6 +202,12 @@ git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" version = "0.5.0" +[[Roots]] +deps = ["Printf"] +git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3" +uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +version = "0.8.3" + [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -166,15 +237,21 @@ git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" uuid = 
"276daf66-3868-5448-9aa4-cd146d93841b" version = "0.7.2" +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.11.0" + [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] -deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.29.0" +version = "0.32.0" [[StatsFuns]] deps = ["Rmath", "SpecialFunctions", "Test"] @@ -183,13 +260,30 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "0.8.0" [[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "SparseArrays"] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[TimerOutputs]] +deps = ["Crayons", "Printf", "Test", "Unicode"] +git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.0" + +[[Tokenize]] +git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.6" + +[[Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] +git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.2.2" + [[URIParser]] deps = ["Test", 
"Unicode"] git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" diff --git a/Project.toml b/Project.toml index 02207c27..f75e2029 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,18 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.3.2" +version = "0.4.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 55e4758b..f5b1781e 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -5,6 +5,7 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays +using Roots export TransformDistribution, PositiveDistribution, @@ -13,7 +14,26 @@ export TransformDistribution, PDMatDistribution, link, invlink, - logpdf_with_trans + logpdf_with_trans, + transform, + forward, + logabsdetjac, + logabsdetjacinv, + Bijector, + ADBijector, + Inversed, + Composed, + compose, + Identity, + DistributionBijector, + bijector, + transformed, + UnivariateTransformed, + MultivariateTransformed, + logpdf_with_jac, + logpdf_forward, + PlanarLayer, + RadialLayer const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) @@ -162,8 +182,8 @@ function _clamp(x::T, dist::SimplexDistribution) where T end function link( - d::SimplexDistribution, - x::AbstractVector{T}, + d::SimplexDistribution, + x::AbstractVector{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} y, K = similar(x), length(x) @@ -191,8 +211,8 @@ end # Vectorised implementation of 
the above. function link( - d::SimplexDistribution, - X::AbstractMatrix{T}, + d::SimplexDistribution, + X::AbstractMatrix{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} Y, K, N = similar(X), size(X, 1), size(X, 2) @@ -219,8 +239,8 @@ function link( end function invlink( - d::SimplexDistribution, - y::AbstractVector{T}, + d::SimplexDistribution, + y::AbstractVector{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} x, K = similar(y), length(y) @@ -245,8 +265,8 @@ end # Vectorised implementation of the above. function invlink( - d::SimplexDistribution, - Y::AbstractMatrix{T}, + d::SimplexDistribution, + Y::AbstractMatrix{T}, ::Type{Val{proj}} = Val{true} ) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) @@ -335,8 +355,8 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{<:Real}) end function logpdf_with_trans( - d::PDMatDistribution, - X::AbstractMatrix{<:Real}, + d::PDMatDistribution, + X::AbstractMatrix{<:Real}, transform::Bool ) T = eltype(X) @@ -424,4 +444,8 @@ function logpdf_with_trans( return map(x -> logpdf_with_trans(d, x, transform), X) end +include("interface.jl") + +include("norm_flows.jl") + end # module diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 00000000..f45588ca --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,727 @@ +using Distributions, Bijectors +using ForwardDiff +using Tracker + +import Base: inv, ∘ + +import Random: AbstractRNG +import Distributions: logpdf, rand, rand!, _rand!, _logpdf + +####################################### +# AD stuff "extracted" from Turing.jl # +####################################### + +abstract type ADBackend end +struct ForwardDiffAD <: ADBackend end +struct TrackerAD <: ADBackend end + +const ADBACKEND = Ref(:forward_diff) +function setadbackend(backend_sym) + @assert backend_sym == :forward_diff || backend_sym == :reverse_diff + backend_sym == :forward_diff + ADBACKEND[] = backend_sym +end + +ADBackend() = ADBackend(ADBACKEND[]) 
+ADBackend(T::Symbol) = ADBackend(Val(T))
+function ADBackend(::Val{T}) where {T}
+    if T === :forward_diff
+        return ForwardDiffAD
+    else
+        return TrackerAD
+    end
+end
+
+######################
+# Bijector interface #
+######################
+
+"Abstract type for a `Bijector`."
+abstract type Bijector end
+
+Broadcast.broadcastable(b::Bijector) = Ref(b)
+
+"""
+Abstract type for a `Bijector` making use of auto-differentiation (AD) to
+implement `jacobian` and, by implication, `logabsdetjac`.
+"""
+abstract type ADBijector{AD} <: Bijector end
+
+"""
+    inv(b::Bijector)
+    Inversed(b::Bijector)
+
+A `Bijector` representing the inverse transform of `b`.
+"""
+struct Inversed{B <: Bijector} <: Bijector
+    orig::B
+end
+
+inv(b::Bijector) = Inversed(b)
+inv(ib::Inversed{<:Bijector}) = ib.orig
+
+"""
+    logabsdetjac(b::Bijector, x)
+    logabsdetjac(ib::Inversed{<:Bijector}, y)
+
+Computes the log(abs(det(J(x)))) where J is the jacobian of the transform.
+Similarly for the inverse-transform.
+
+Default implementation for `Inversed{<:Bijector}` is implemented as
+`- logabsdetjac` of original `Bijector`.
+"""
+logabsdetjac(ib::Inversed{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y))
+
+"""
+    forward(b::Bijector, x)
+    forward(ib::Inversed{<:Bijector}, y)
+
+Computes both `transform` and `logabsdetjac` in one forward pass, and
+returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`.
+
+This defaults to the call above, but often one can re-use computation
+in the computation of the forward pass and the computation of the
+`logabsdetjac`. `forward` allows the user to take advantage of such
+efficiencies, if they exist.
+"""
+forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x))
+forward(ib::Inversed{<:Bijector}, y) = (
+    rv=ib(y),
+    logabsdetjac=logabsdetjac(ib, y)
+)
+
+
+# AD implementations
+function jacobian(b::ADBijector{<:ForwardDiffAD}, x::Real)
+    return ForwardDiff.derivative(b, x)
+end
+function jacobian(b::Inversed{<:ADBijector{<:ForwardDiffAD}}, y::Real)
+    return ForwardDiff.derivative(b, y)
+end
+function jacobian(b::ADBijector{<:ForwardDiffAD}, x::AbstractVector{<:Real})
+    return ForwardDiff.jacobian(b, x)
+end
+function jacobian(b::Inversed{<:ADBijector{<:ForwardDiffAD}}, y::AbstractVector{<:Real})
+    return ForwardDiff.jacobian(b, y)
+end
+
+function jacobian(b::ADBijector{<:TrackerAD}, x::Real)
+    return Tracker.gradient(b, x)[1]
+end
+function jacobian(b::Inversed{<:ADBijector{<:TrackerAD}}, y::Real)
+    return Tracker.gradient(b, y)[1]
+end
+function jacobian(b::ADBijector{<:TrackerAD}, x::AbstractVector{<:Real})
+    # We extract `data` so that we don't return a `Tracked` type
+    return Tracker.data(Tracker.jacobian(b, x))
+end
+function jacobian(b::Inversed{<:ADBijector{<:TrackerAD}}, y::AbstractVector{<:Real})
+    # We extract `data` so that we don't return a `Tracked` type
+    return Tracker.data(Tracker.jacobian(b, y))
+end
+
+struct SingularJacobianException{B} <: Exception where {B<:Bijector}
+    b::B
+end
+Base.showerror(io::IO, e::SingularJacobianException) = print(io, "jacobian of $(e.b) is singular")
+
+# TODO: allow batch-computation, especially for univariate case?
+"Computes the absolute determinant of the Jacobian of the inverse-transformation."
+function logabsdetjac(b::ADBijector, x::Real)
+    res = log(abs(jacobian(b, x)))
+    return isfinite(res) ? res : throw(SingularJacobianException(b))
+end
+
+function logabsdetjac(b::ADBijector, x::AbstractVector{<:Real})
+    fact = lu(jacobian(b, x), check=false)
+    return issuccess(fact) ? 
log(abs(det(fact))) : throw(SingularJacobianException(b)) +end + +""" + logabsdetjacinv(b::Bijector, y) + +Just an alias for `logabsdetjac(inv(b), y)`. +""" +logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) + +############### +# Composition # +############### + +""" + ∘(b1::Bijector, b2::Bijector) + composel(ts::Bijector...) + composer(ts::Bijector...) + +A `Bijector` representing composition of bijectors. `composel` and `composer` results in a +`Composed` for which application occurs from left-to-right and right-to-left, respectively. + +# Examples +It's important to note that `∘` does what is expected mathematically, which means that the +bijectors are applied to the input right-to-left, e.g. first applying `b2` and then `b1`: +``` +(b1 ∘ b2)(x) == b1(b2(x)) # => true +``` +But in the `Composed` struct itself, we store the bijectors left-to-right, so that +``` +cb1 = b1 ∘ b2 # => Composed.ts == (b2, b1) +cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) +cb1(x) == cb2(x) == b1(b2(x)) # => true +``` +""" +struct Composed{A} <: Bijector + ts::A +end + +composel(ts::Bijector...) = Composed(ts) +composer(ts::Bijector...) = Composed(inv(ts)) + +# The transformation of `Composed` applies functions left-to-right +# but in mathematics we usually go from right-to-left; this reversal ensures that +# when we use the mathematical composition ∘ we get the expected behavior. +# TODO: change behavior of `transform` of `Composed`? +∘(b1::Bijector, b2::Bijector) = composel(b2, b1) + +inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) + +# # TODO: should arrays also be using recursive implementation instead? +function (cb::Composed{<:AbstractArray{<:Bijector}})(x) + res = x + for b ∈ cb.ts + res = b(res) + end + + return res +end + +# recursive implementation like this allows type-inference +_transform(x, b1::Bijector, b2::Bijector) = b2(b1(x)) +_transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) 
+(cb::Composed{<:Tuple})(x) = _transform(x, cb.ts...) + +function _logabsdetjac(x, b1::Bijector, b2::Bijector) + res = forward(b1, x) + return logabsdetjac(b2, res.rv) .+ res.logabsdetjac +end +function _logabsdetjac(x, b1::Bijector, bs::Bijector...) + res = forward(b1, x) + return _logabsdetjac(res.rv, bs...) .+ res.logabsdetjac +end +logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) + +# Recursive implementation of `forward` +# NOTE: we need this one in the case where `length(cb.ts) == 2` +# in which case forward(...) immediately calls `_forward(::NamedTuple, b::Bijector)` +function _forward(f::NamedTuple, b::Bijector) + y, logjac = forward(b, f.rv) + return (rv=y, logabsdetjac=logjac .+ f.logabsdetjac) +end +function _forward(f::NamedTuple, b1::Bijector, b2::Bijector) + f1 = forward(b1, f.rv) + f2 = forward(b2, f1.rv) + return (rv=f2.rv, logabsdetjac=f2.logabsdetjac .+ f1.logabsdetjac .+ f.logabsdetjac) +end +function _forward(f::NamedTuple, b::Bijector, bs::Bijector...) + f1 = forward(b, f.rv) + f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac .+ f.logabsdetjac) + return _forward(f_, bs...) +end +_forward(x, b::Bijector, bs::Bijector...) = _forward(forward(b, x), bs...) +forward(cb::Composed{<:Tuple}, x) = _forward(x, cb.ts...) 
+ +function forward(cb::Composed, x) + rv, logjac = forward(cb.ts[1], x) + + for t in cb.ts[2:end] + res = forward(t, rv) + rv = res.rv + logjac = res.logabsdetjac .+ logjac + end + return (rv=rv, logabsdetjac=logjac) +end + +############################## +# Example bijector: Identity # +############################## + +struct Identity <: Bijector end +(::Identity)(x) = x +(::Inversed{Identity})(y) = y + +forward(::Identity, x) = (rv=x, logabsdetjac=zero(eltype(x))) + +logabsdetjac(::Identity, y) = zero(eltype(y)) + +const IdentityBijector = Identity() + +############################### +# Example: Logit and Logistic # +############################### +using StatsFuns: logit, logistic + +struct Logit{T<:Real} <: Bijector + a::T + b::T +end + +(b::Logit)(x) = @. logit((x - b.a) / (b.b - b.a)) +(ib::Inversed{<:Logit{<:Real}})(y) = @. (ib.orig.b - ib.orig.a) * logistic(y) + ib.orig.a + +logabsdetjac(b::Logit{<:Real}, x) = @. - log((x - b.a) * (b.b - x) / (b.b - b.a)) + +############# +# Exp & Log # +############# + +struct Exp <: Bijector end +struct Log <: Bijector end + +(b::Log)(x) = @. log(x) +(b::Exp)(y) = @. 
exp(y) + +inv(b::Log) = Exp() +inv(b::Exp) = Log() + +logabsdetjac(b::Log, x) = - sum(log.(x)) +logabsdetjac(b::Exp, y) = sum(y) + +################# +# Shift & Scale # +################# +struct Shift{T} <: Bijector + a::T +end + +(b::Shift)(x) = b.a + x +(b::Shift{<:Real})(x::AbstractArray) = b.a .+ x +(b::Shift{<:AbstractVector})(x::AbstractMatrix) = b.a .+ x + +inv(b::Shift) = Shift(-b.a) +logabsdetjac(b::Shift, x) = zero(eltype(x)) +# FIXME: ambiguous whether or not this is actually a batch or whatever +logabsdetjac(b::Shift{<:Real}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) +logabsdetjac(b::Shift{<:AbstractVector}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) + +struct Scale{T} <: Bijector + a::T +end + +(b::Scale)(x) = b.a * x +(b::Scale{<:Real})(x::AbstractArray) = b.a .* x +(b::Scale{<:AbstractVector{<:Real}})(x::AbstractMatrix{<:Real}) = x * b.a + +inv(b::Scale) = Scale(inv(b.a)) +inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) + +# TODO: should this be implemented for batch-computation? +# There's an ambiguity issue +# logabsdetjac(b::Scale{<: AbstractVector}, x::AbstractMatrix) +# Is this a batch or is it simply a matrix we want to scale differently +# in each component? +logabsdetjac(b::Scale, x) = log(abs(b.a)) + +#################### +# Simplex bijector # +#################### +struct SimplexBijector{T} <: Bijector where {T} end + +const simplex_b = SimplexBijector{Val{false}}() +const simplex_b_proj = SimplexBijector{Val{true}}() + +# The following implementations are basically just copy-paste from `invlink` and +# `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. +function _clamp(x::T, b::SimplexBijector) where {T} + bounds = (zero(T), one(T)) + clamped_x = clamp(x, bounds...) 
+ DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x" + return clamped_x +end + +function (b::SimplexBijector{Val{proj}})(x::AbstractVector{T}) where {T, proj} + y, K = similar(x), length(x) + + ϵ = _eps(T) + sum_tmp = zero(T) + @inbounds z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ] + @inbounds y[1] = StatsFuns.logit(z) + log(T(K - 1)) + @inbounds @simd for k in 2:(K - 1) + sum_tmp += x[k - 1] + # z ∈ [ϵ, 1-ϵ] + # x[k] = 0 && sum_tmp = 1 -> z ≈ 1 + z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) + y[k] = StatsFuns.logit(z) + log(T(K - k)) + end + @inbounds sum_tmp += x[K - 1] + @inbounds if proj + y[K] = zero(T) + else + y[K] = one(T) - sum_tmp - x[K] + end + + return y +end + +# Vectorised implementation of the above. +function (b::SimplexBijector{Val{proj}})(X::AbstractMatrix{T}) where {T<:Real, proj} + Y, K, N = similar(X), size(X, 1), size(X, 2) + + ϵ = _eps(T) + @inbounds @simd for n in 1:size(X, 2) + sum_tmp = zero(T) + z = X[1, n] * (one(T) - 2ϵ) + ϵ + Y[1, n] = StatsFuns.logit(z) + log(T(K - 1)) + for k in 2:(K - 1) + sum_tmp += X[k - 1, n] + z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) + Y[k, n] = StatsFuns.logit(z) + log(T(K - k)) + end + sum_tmp += X[K-1, n] + if proj + Y[K, n] = zero(T) + else + Y[K, n] = one(T) - sum_tmp - X[K, n] + end + end + + return Y +end + +function (ib::Inversed{<:SimplexBijector{Val{proj}}})(y::AbstractVector{T}) where {T, proj} + x, K = similar(y), length(y) + + ϵ = _eps(T) + @inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1))) + @inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), ib.orig) + sum_tmp = zero(T) + @inbounds @simd for k = 2:(K - 1) + z = StatsFuns.logistic(y[k] - log(T(K - k))) + sum_tmp += x[k-1] + x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, ib.orig) + end + @inbounds sum_tmp += x[K - 1] + @inbounds if proj + x[K] = _clamp(one(T) - sum_tmp, ib.orig) + else + x[K] = _clamp(one(T) - sum_tmp - y[K], ib.orig) + end + + return x +end + +# Vectorised implementation of 
the above. +function (ib::Inversed{<:SimplexBijector{Val{proj}}})( + Y::AbstractMatrix{T} +) where {T<:Real, proj} + X, K, N = similar(Y), size(Y, 1), size(Y, 2) + + ϵ = _eps(T) + @inbounds @simd for n in 1:size(X, 2) + sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] - log(T(K - 1))) + X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), ib.orig) + for k in 2:(K - 1) + z = StatsFuns.logistic(Y[k, n] - log(T(K - k))) + sum_tmp += X[k - 1] + X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, ib.orig) + end + sum_tmp += X[K - 1, n] + if proj + X[K, n] = _clamp(one(T) - sum_tmp, ib.orig) + else + X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], ib.orig) + end + end + + return X +end + + +function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where T + ϵ = _eps(T) + lp = zero(T) + + K = length(x) + + sum_tmp = zero(eltype(x)) + @inbounds z = x[1] + lp += log(z + ϵ) + log((one(T) + ϵ) - z) + @inbounds @simd for k in 2:(K - 1) + sum_tmp += x[k-1] + z = x[k] / ((one(T) + ϵ) - sum_tmp) + lp += log(z + ϵ) + log((one(T) + ϵ) - z) + log((one(T) + ϵ) - sum_tmp) + end + + return - lp +end + +####################################################### +# Constrained to unconstrained distribution bijectors # +####################################################### +""" + DistributionBijector(d::Distribution) + DistributionBijector{<:ADBackend, D}(d::Distribution) + +This is the default `Bijector` for a distribution. + +It uses `link` and `invlink` to compute the transformations, and `AD` to compute +the `jacobian` and `logabsdetjac`. 
+""" +struct DistributionBijector{AD, D} <: ADBijector{AD} where {D<:Distribution} + dist::D +end +function DistributionBijector(dist::D) where {D<:Distribution} + DistributionBijector{ADBackend(), D}(dist) +end + +# Simply uses `link` and `invlink` as transforms with AD to get jacobian +(b::DistributionBijector)(x) = link(b.dist, x) +(ib::Inversed{<:DistributionBijector})(y) = invlink(ib.orig.dist, y) + + +"Returns the constrained-to-unconstrained bijector for distribution `d`." +bijector(d::Distribution) = DistributionBijector(d) + +# Transformed distributions +struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} + dist::D + transform::B +end +function TransformedDistribution(d::D, b::B) where {V<:VariateForm, B<:Bijector, D<:Distribution{V, Continuous}} + return TransformedDistribution{D, B, V}(d, b) +end + + +const UnivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Univariate} +const MultivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Multivariate} +const MvTransformed = MultivariateTransformed +const MatrixTransformed = TransformedDistribution{<:Distribution, <:Bijector, Matrixvariate} +const Transformed = TransformedDistribution + + +""" + transformed(d::Distribution) + transformed(d::Distribution, b::Bijector) + +Couples distribution `d` with the bijector `b` by returning a `TransformedDistribution`. + +If no bijector is provided, i.e. `transformed(d)` is called, then +`transformed(d, bijector(d))` is returned. +""" +transformed(d::Distribution, b::Bijector) = TransformedDistribution(d, b) +transformed(d) = transformed(d, bijector(d)) + +""" + bijector(d::Distribution) + +Returns the constrained-to-unconstrained bijector for distribution `d`. 
+""" +bijector(d::Normal) = IdentityBijector +bijector(d::MvNormal) = IdentityBijector +bijector(d::PositiveDistribution) = Log() +bijector(d::MvLogNormal) = Log() +bijector(d::SimplexDistribution) = simplex_b_proj + +_union2tuple(T1::Type, T2::Type) = (T1, T2) +_union2tuple(T1::Type, T2::Union) = (T1, _union2tuple(T2.a, T2.b)...) +_union2tuple(T::Union) = _union2tuple(T.a, T.b) + +bijector(d::KSOneSided) = Logit(zero(eltype(d)), zero(eltype(d))) +for D in _union2tuple(UnitDistribution) + # Skipping KSOneSided because it's not a parametric type + if D == KSOneSided + continue + end + @eval bijector(d::$D{T}) where {T<:Real} = Logit(zero(T), one(T)) +end + +# FIXME: (TOR) Can we make this type-stable? +# Everything but `Truncated` can probably be made type-stable +# by explicit implementation. Can also make a `TruncatedBijector` +# which has the same transform as the `link` function. +# E.g. (b::Truncated)(x) = link(b.d, x) or smth +function bijector(d::TransformDistribution) where {D<:Distribution} + a, b = minimum(d), maximum(d) + lowerbounded, upperbounded = isfinite(a), isfinite(b) + if lowerbounded && upperbounded + return Logit(a, b) + elseif lowerbounded + return (Log() ∘ Shift(- a)) + elseif upperbounded + return (Log() ∘ Shift(b) ∘ Scale(- one(typeof(b)))) + else + return IdentityBijector + end +end + +############################## +# Distributions.jl interface # +############################## + +# size +Base.length(td::Transformed) = length(td.dist) +Base.size(td::Transformed) = size(td.dist) + +function logpdf(td::UnivariateTransformed, y::Real) + res = forward(inv(td.transform), y) + return logpdf(td.dist, res.rv) .+ res.logabsdetjac +end + +# TODO: implement more efficiently for flows in the case of `Matrix` +function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) + res = forward(inv(td.transform), y) + return logpdf(td.dist, res.rv) .+ res.logabsdetjac +end + +function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) + T = 
eltype(y)
+    ϵ = _eps(T)
+
+    res = forward(inv(td.transform), y)
+    return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac
+end
+
+# TODO: should eventually drop using `logpdf_with_trans` and replace with
+#   res = forward(inv(td.transform), y)
+#   logpdf(td.dist, res.rv) .- res.logabsdetjac
+function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real})
+    return logpdf_with_trans(td.dist, inv(td.transform)(y), true)
+end
+
+# rand
+rand(td::UnivariateTransformed) = td.transform(rand(td.dist))
+rand(rng::AbstractRNG, td::UnivariateTransformed) = td.transform(rand(rng, td.dist))
+
+# These overloads are useful for differentiating sampling wrt. params of `td.dist`
+# or params of `Bijector`, as they are not in-place like the default `rand`
+rand(td::MvTransformed) = td.transform(rand(td.dist))
+rand(rng::AbstractRNG, td::MvTransformed) = td.transform(rand(rng, td.dist))
+# TODO: implement more efficiently for flows
+function rand(rng::AbstractRNG, td::MvTransformed, num_samples::Int)
+    res = hcat([td.transform(rand(td.dist)) for i = 1:num_samples]...)
+    return res
+end
+
+function _rand!(rng::AbstractRNG, td::MvTransformed, x::AbstractVector{<:Real})
+    rand!(rng, td.dist, x)
+    x .= td.transform(x)
+end
+
+function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real})
+    rand!(rng, td.dist, x)
+    x .= td.transform(x)
+end
+
+#############################################################
+# Additional useful functions for `TransformedDistribution` #
+#############################################################
+"""
+    logpdf_with_jac(td::UnivariateTransformed, y::Real)
+    logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real})
+    logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real})
+
+Makes use of the `forward` method to potentially re-use computation
+and returns a tuple `(logpdf, logabsdetjac)`.
+""" +function logpdf_with_jac(td::UnivariateTransformed, y::Real) + res = forward(inv(td.transform), y) + return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) +end + +# TODO: implement more efficiently for flows in the case of `Matrix` +function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) +end + +function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) +end + +function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) + T = eltype(y) + ϵ = _eps(T) + + res = forward(inv(td.transform), y) + return (logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac, res.logabsdetjac) +end + +# TODO: should eventually drop using `logpdf_with_trans` +function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) + res = forward(inv(td.transform), y) + return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) +end + +""" + logpdf_forward(td::Transformed, x) + logpdf_forward(td::Transformed, x, logjac) + +Computes the `logpdf` using the forward pass of the bijector rather than using +the inverse transform to compute the necessary `logabsdetjac`. + +This is similar to `logpdf_with_trans`. 
+""" +# TODO: implement more efficiently for flows in the case of `Matrix` +logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .- logjac +logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) + +function logpdf_forward(td::MvTransformed{<:Dirichlet}, x, logjac) + T = eltype(x) + ϵ = _eps(T) + + return logpdf(td.dist, mappedarray(z->z+ϵ, x)) .- logjac +end + + +# forward function +const GLOBAL_RNG = Distributions.GLOBAL_RNG + +function _forward(d::UnivariateDistribution, x) + y, logjac = forward(IdentityBijector, x) + return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) +end + +forward(rng::AbstractRNG, d::Distribution) = _forward(d, rand(rng, d)) +function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) + return _forward(d, rand(rng, d, num_samples)) +end +function _forward(d::Distribution, x) + y, logjac = forward(IdentityBijector, x) + return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) +end + +function _forward(td::Transformed, x) + y, logjac = forward(td.transform, x) + return ( + x = x, + y = y, + logabsdetjac = logjac, + logpdf = logpdf_forward(td, x, logjac) + ) +end +function forward(rng::AbstractRNG, td::Transformed) + return _forward(td, rand(rng, td.dist)) +end +function forward(rng::AbstractRNG, td::Transformed, num_samples::Int) + return _forward(td, rand(rng, td.dist, num_samples)) +end + +""" + forward(d::Distribution) + forward(d::Distribution, num_samples::Int) + +Returns a `NamedTuple` with fields `x`, `y`, `logabsdetjac` and `logpdf`. + +In the case where `d isa TransformedDistribution`, this means +- `x = rand(d.dist)` +- `y = d.transform(x)` +- `logabsdetjac` is the logabsdetjac of the "forward" transform. 
+- `logpdf` is the logpdf of `y`, not `x` + +In the case where `d isa Distribution`, this means +- `x = rand(d)` +- `y = x` +- `logabsdetjac = 0.0` +- `logpdf` is logpdf of `x` +""" +forward(d::Distribution) = forward(GLOBAL_RNG, d) +forward(d::Distribution, num_samples::Int) = forward(GLOBAL_RNG, d, num_samples) diff --git a/src/norm_flows.jl b/src/norm_flows.jl new file mode 100644 index 00000000..272c1ed3 --- /dev/null +++ b/src/norm_flows.jl @@ -0,0 +1,157 @@ +using Distributions +using LinearAlgebra +using Random +using StatsFuns: softplus +using Roots # for inverse + +################################################################################ +# Planar and Radial Flows # +# Ref: Variational Inference with Normalizing Flows, # +# D. Rezende, S. Mohamed(2015) arXiv:1505.05770 # +################################################################################ + +############### +# PlanarLayer # +############### + +mutable struct PlanarLayer{T1,T2} <: Bijector + w::T1 + u::T1 + b::T2 +end + +function get_u_hat(u, w) + # To preserve invertibility + return ( + u + (planar_flow_m(w' * u) - w' * u)[1] + * w / (norm(w[:,1],2) ^ 2) + ) # from A.1 +end + +function PlanarLayer(dims::Int, container=Array) + w = container(randn(dims, 1)) + u = container(randn(dims, 1)) + b = container(randn(1)) + return PlanarLayer(w, u, b) +end + +planar_flow_m(x) = -1 .+ softplus.(x) # for planar flow from A.1 +dtanh(x) = 1 .- (tanh.(x)) .^ 2 # for planar flow +ψ(z, w, b) = dtanh(w' * z .+ b) .* w # for planar flow from eq(11) + +# An internal version of transform that returns intermediate variables +function _transform(flow::PlanarLayer, z) + u_hat = get_u_hat(flow.u, flow.w) + transformed = z + u_hat * tanh.(flow.w' * z .+ flow.b) # from eq(10) + return (transformed=transformed, u_hat=u_hat) +end + +(b::PlanarLayer)(z) = _transform(b, z).transformed + +function _forward(flow::PlanarLayer, z) + transformed, u_hat = _transform(flow, z) + # Compute log_det_jacobian + psi = ψ(z, 
flow.w, flow.b) + log_det_jacobian = log.(abs.(1.0 .+ psi' * u_hat)) # from eq(12) + return (rv=transformed, logabsdetjac=vec(log_det_jacobian)) # from eq(10) +end + +forward(flow::PlanarLayer, z) = _forward(flow, z) + +function forward(flow::PlanarLayer, z::AbstractVector{<: Real}) + res = _forward(flow, z) + return (rv=res.rv, logabsdetjac=res.logabsdetjac[1]) +end + + +function (ib::Inversed{<: PlanarLayer})(y::AbstractMatrix{<: Real}) + flow = ib.orig + u_hat = get_u_hat(flow.u, flow.w) + # Define the objective functional; implemented with reference from A.1 + f(y) = alpha -> (flow.w' * y)[1] - alpha - (flow.w' * u_hat)[1] * tanh(alpha+flow.b[1]) + # Run solver + alphas_ = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)] + alphas = alphas_' + z_para = (flow.w ./ norm(flow.w,2)) * alphas + z_per = y - z_para - u_hat * tanh.(flow.w' * z_para .+ flow.b) + + return z_para + z_per +end + +function (ib::Inversed{<: PlanarLayer})(y::AbstractVector{<: Real}) + return vec(ib(reshape(y, (length(y), 1)))) +end + +logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac + +############### +# RadialLayer # +############### + +# FIXME: using `TrackedArray` for the parameters, we end up with +# nested tracked structures; don't want this. 
+mutable struct RadialLayer{T1,T2} <: Bijector + α_::T1 + β::T1 + z_0::T2 +end + +function RadialLayer(dims::Int, container=Array) + α_ = container(randn(1)) + β = container(randn(1)) + z_0 = container(randn(dims, 1)) + return RadialLayer(α_, β, z_0) +end + +h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) +dh(α, r) = - h(α, r) .^ 2 # for radial flow; derivative of h() + +# An internal version of transform that returns intermediate variables +function _transform(flow::RadialLayer, z) + α = softplus(flow.α_[1]) # from A.2 + β_hat = -α + softplus(flow.β[1]) # from A.2 + r = sqrt.(sum((z .- flow.z_0).^2; dims = 1)) + transformed = z + β_hat .* h(α, r) .* (z .- flow.z_0) # from eq(14) + return (transformed=transformed, α=α, β_hat=β_hat, r=r) +end + +(b::RadialLayer)(z) = _transform(b, z).transformed + +function _forward(flow::RadialLayer, z) + transformed, α, β_hat, r = _transform(flow, z) + # Compute log_det_jacobian + d = size(flow.z_0, 1) + h_ = h(α, r) + log_det_jacobian = @. ( + (d - 1) * log(1.0 + β_hat * h_) + + log(1.0 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) + ) # from eq(14) + return (rv=transformed, logabsdetjac=vec(log_det_jacobian)) +end + +forward(flow::RadialLayer, z) = _forward(flow, z) + +function forward(flow::RadialLayer, z::AbstractVector{<: Real}) + res = _forward(flow, z) + return (rv=res.rv[:, 1], logabsdetjac=res.logabsdetjac[1]) +end + +# function inv(flow::RadialLayer, y) +function (ib::Inversed{<: RadialLayer})(y) + flow = ib.orig + α = softplus(flow.α_[1]) # from A.2 + β_hat = - α + softplus(flow.β[1]) # from A.2 + # Define the objective functional + f(y) = r -> norm(y - flow.z_0, 2) - r * (1 + β_hat / (α + r)) # from eq(26) + # Run solver + rs = [find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y, 2)]' # from A.2 + z_hat = (y .- flow.z_0) ./ (rs .* (1 .+ β_hat ./ (α .+ rs))) # from eq(25) + z = flow.z_0 .+ rs .* z_hat # from A.2 + return z +end + +function (ib::Inversed{<: RadialLayer})(y::AbstractVector{<: Real}) + return 
vec(ib(reshape(y, (length(y), 1)))) +end + +logabsdetjac(flow::RadialLayer, x) = forward(flow, x).logabsdetjac diff --git a/test/interface.jl b/test/interface.jl new file mode 100644 index 00000000..24a40768 --- /dev/null +++ b/test/interface.jl @@ -0,0 +1,294 @@ +using Test +using Bijectors +using Random +using LinearAlgebra +using ForwardDiff + +Random.seed!(123) + +struct NonInvertibleBijector{AD} <: ADBijector{AD} end + +# Scalar tests +@testset "Interface" begin + @testset "<: ADBijector{AD}" begin + (b::NonInvertibleBijector)(x) = clamp.(x, 0, 1) + + b = NonInvertibleBijector{Bijectors.ADBackend()}() + @test_throws Bijectors.SingularJacobianException logabsdetjac(b, [1.0, 10.0]) + end + + @testset "Univariate" begin + # Tests with scalar-valued distributions. + uni_dists = [ + Arcsine(2, 4), + Beta(2,2), + BetaPrime(), + Biweight(), + Cauchy(), + Chi(3), + Chisq(2), + Cosine(), + Epanechnikov(), + Erlang(), + Exponential(), + FDist(1, 1), + Frechet(), + Gamma(), + InverseGamma(), + InverseGaussian(), + # Kolmogorov(), + Laplace(), + Levy(), + Logistic(), + LogNormal(1.0, 2.5), + Normal(0.1, 2.5), + Pareto(), + Rayleigh(1.0), + TDist(2), + TruncatedNormal(0, 1, -Inf, 2), + ] + + for dist in uni_dists + @testset "$dist: dist" begin + td = transformed(dist) + + # single sample + y = rand(td) + x = inv(td.transform)(y) + @test y == td.transform(x) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # logpdf_with_jac + lp, logjac = logpdf_with_jac(td, y) + @test lp ≈ logpdf(td, y) + @test logjac ≈ logabsdetjacinv(td.transform, y) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform).(y) + @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) + + # logpdf corresponds to logpdf_with_trans + d = dist + b = bijector(d) + x = rand(d) + y = b(x) + @test logpdf(d, inv(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) + @test logpdf(d, x) - logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) + + # forward + f = forward(td) + @test f.x ≈ 
inv(td.transform)(f.y) + @test f.y ≈ td.transform(f.x) + @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) + @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) + @test f.logpdf ≈ logpdf(td.dist, f.x) - f.logabsdetjac + + # verify against AD + d = dist + b = bijector(d) + x = rand(d) + y = b(x) + @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) + @test log(abs(ForwardDiff.derivative(inv(b), y))) ≈ logabsdetjac(inv(b), y) + end + + @testset "$dist: ForwardDiff AD" begin + x = rand(dist) + b = DistributionBijector{Bijectors.ADBackend(:forward_diff), typeof(dist)}(dist) + + @test abs(det(Bijectors.jacobian(b, x))) > 0 + @test logabsdetjac(b, x) ≠ Inf + + y = b(x) + b⁻¹ = inv(b) + @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 + @test logabsdetjac(b⁻¹, y) ≠ Inf + end + + @testset "$dist: Tracker AD" begin + x = rand(dist) + b = DistributionBijector{Bijectors.ADBackend(:reverse_diff), typeof(dist)}(dist) + + @test abs(det(Bijectors.jacobian(b, x))) > 0 + @test logabsdetjac(b, x) ≠ Inf + + y = b(x) + b⁻¹ = inv(b) + @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 + @test logabsdetjac(b⁻¹, y) ≠ Inf + end + end + end + + @testset "Truncated" begin + d = Truncated(Normal(), -1, 1) + b = bijector(d) + x = rand(d) + @test b(x) == link(d, x) + + d = Truncated(Normal(), -Inf, 1) + b = bijector(d) + x = rand(d) + @test b(x) == link(d, x) + + d = Truncated(Normal(), 1, Inf) + b = bijector(d) + x = rand(d) + @test b(x) == link(d, x) + end + + @testset "Multivariate" begin + vector_dists = [ + Dirichlet(2, 3), + Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([eps(Float64), 1000 * one(Float64)]), + MvNormal(randn(10), exp.(randn(10))), + MvLogNormal(MvNormal(randn(10), exp.(randn(10)))), + Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([eps(Float64), 1000 * one(Float64)]), + ] + + for dist in vector_dists + @testset "$dist: dist" begin + dist = Dirichlet([eps(Float64), 1000 * one(Float64)]) + td = transformed(dist) + + # single sample 
+ y = rand(td) + x = inv(td.transform)(y) + @test y == td.transform(x) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # logpdf_with_jac + lp, logjac = logpdf_with_jac(td, y) + @test lp ≈ logpdf(td, y) + @test logjac ≈ logabsdetjacinv(td.transform, y) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # forward + f = forward(td) + @test f.x ≈ inv(td.transform)(f.y) + @test f.y ≈ td.transform(f.x) + @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) + @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) + + # verify against AD + # similar to what we do in test/transform.jl for Dirichlet + if dist isa Dirichlet + b = Bijectors.SimplexBijector{Val{false}}() + x = rand(dist) + y = b(x) + @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) + @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) + else + b = bijector(dist) + x = rand(dist) + y = b(x) + @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) + @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) + end + end + end + end + + @testset "Matrix variate" begin + v = 7.0 + S = Matrix(1.0I, 2, 2) + S[1, 2] = S[2, 1] = 0.5 + + matrix_dists = [ + Wishart(v,S), + InverseWishart(v,S) + ] + + for dist in matrix_dists + @testset "$dist: dist" begin + td = transformed(dist) + + # single sample + y = rand(td) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # TODO: implement `logabsdetjac` for these + # logpdf_with_jac + # lp, logjac = logpdf_with_jac(td, y) + # @test lp ≈ logpdf(td, y) + # @test logjac ≈ logabsdetjacinv(td.transform, y) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + end + end + end + + @testset "Composition <: Bijector" begin + d = Beta() + td = transformed(d) + + x = rand(d) + y = td.transform(x) + + b = 
Bijectors.composel(td.transform, Bijectors.Identity()) + ib = inv(b) + + @test forward(b, x) == forward(td.transform, x) + @test forward(ib, y) == forward(inv(td.transform), y) + + # inverse works fine for composition + cb = b ∘ ib + @test cb(x) ≈ x + + cb2 = cb ∘ cb + @test cb(x) ≈ x + + # ensures that the `logabsdetjac` is correct + x = rand(d) + b = inv(bijector(d)) + @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) + + # order of composed evaluation + b1 = DistributionBijector(d) + b2 = DistributionBijector(Gamma()) + + cb = b1 ∘ b2 + @test cb(x) ≈ b1(b2(x)) + + # contrived example + b = bijector(d) + cb = inv(b) ∘ b + cb = cb ∘ cb + @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x + + # forward for tuple and array + d = Beta() + b = inv(bijector(d)) + b⁻¹ = inv(b) + x = rand(d) + + cb_t = b⁻¹ ∘ b⁻¹ + f_t = forward(cb_t, x) + + cb_a = Composed([b⁻¹, b⁻¹]) + f_a = forward(cb_a, x) + + @test f_t == f_a + end + + @testset "Example: ADVI single" begin + # Usage in ADVI + d = Beta() + b = DistributionBijector(d) # [0, 1] → ℝ + ib = inv(b) # ℝ → [0, 1] + td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] + x = rand(td) # ∈ [0, 1] + @test 0 ≤ x ≤ 1 + end +end diff --git a/test/norm_flows.jl b/test/norm_flows.jl new file mode 100644 index 00000000..7f5d94cf --- /dev/null +++ b/test/norm_flows.jl @@ -0,0 +1,63 @@ +using Test +using Bijectors, ForwardDiff, LinearAlgebra +using Random: seed! 
+ +seed!(1) + +@testset "PlanarLayer" begin + for i in 1:4 + flow = PlanarLayer(2) + z = randn(2, 20) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) + our_method = sum(forward(flow, z).logabsdetjac) + + @test our_method ≈ forward_diff + @test inv(flow)(flow(z)) ≈ z rtol=0.2 + @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2 + end + + w = ones(10, 1) + u = zeros(10, 1) + b = ones(1) + flow = PlanarLayer(w, u, b) + z = ones(10, 100) + @test inv(flow)(flow(z)) ≈ z +end + +@testset "RadialLayer" begin + for i in 1:4 + flow = RadialLayer(2) + z = randn(2, 20) + forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) + our_method = sum(forward(flow, z).logabsdetjac) + + @test our_method ≈ forward_diff + @test inv(flow)(flow(z)) ≈ z rtol=0.2 + @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2 + end + + α_ = ones(1) + β = ones(1) + z_0 = zeros(10, 1) + z = ones(10, 100) + flow = RadialLayer(α_, β, z_0) + @test inv(flow)(flow(z)) ≈ z +end + +@testset "Flows" begin + d = MvNormal(zeros(2), ones(2)) + b = PlanarLayer(2) + flow = transformed(d, b) # <= Radial flow + + y = rand(flow) + @test logpdf(flow, y) != 0.0 + + x = rand(d) + y = flow.transform(x) + res = forward(flow, x) + lp = logpdf_forward(flow, x, res.logabsdetjac) + + @test res.rv ≈ y + @test logpdf(flow, y) ≈ lp rtol=0.1 +end + diff --git a/test/runtests.jl b/test/runtests.jl index 66760196..1e5e773e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,4 +2,5 @@ using Bijectors, Random Random.seed!(123456) +include("interface.jl") include("transform.jl")