diff --git a/src/bijectors/product_bijector.jl b/src/bijectors/product_bijector.jl new file mode 100644 index 00000000..94cd1fb1 --- /dev/null +++ b/src/bijectors/product_bijector.jl @@ -0,0 +1,66 @@ +struct ProductBijector{Bs,N} <: Transform + bs::Bs +end + +ProductBijector(bs::AbstractArray{<:Any,N}) where {N} = ProductBijector{typeof(bs),N}(bs) + +inverse(b::ProductBijector) = ProductBijector(map(inverse, b.bs)) + +function _product_bijector_check_dim(::Val{N}, ::Val{M}) where {N,M} + if N > M + throw( + DimensionMismatch( + "Number of bijectors needs to be smaller than or equal to the number of dimensions", + ), + ) + end +end + +function _product_bijector_slices( + ::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M} +) where {N,M} + _product_bijector_check_dim(Val(N), Val(M)) + + # If N < M, then the bijectors expect an input vector of dimension `M - N`. + # To achieve this, we need to slice along the last `N` dimensions. + return eachslice(x; dims=ntuple(i -> i + (M - N), N)) +end + +# Specialization for case where we're just applying elementwise. +function transform( + b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,N} +) where {N} + return map(transform, b.bs, x) +end +# General case. +function transform( + b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M} +) where {N,M} + slices = _product_bijector_slices(b, x) + return stack(map(transform, b.bs, slices)) +end + +function with_logabsdet_jacobian( + b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,N} +) where {N} + results = map(with_logabsdet_jacobian, b.bs, x) + return map(first, results), sum(last, results) +end +function with_logabsdet_jacobian( + b::ProductBijector{<:AbstractArray,N}, x::AbstractArray{<:Real,M} +) where {N,M} + slices = _product_bijector_slices(b, x) + results = map(with_logabsdet_jacobian, b.bs, slices) + return stack(map(first, results)), sum(last, results) +end + +# Other utilities. +function output_size(b::ProductBijector{<:AbstractArray,N}, sz::NTuple{M}) where {N,M} + _product_bijector_check_dim(Val(N), Val(M)) + + sz_redundant = ntuple(i -> sz[i + (M - N)], N) + sz_example = ntuple(i -> sz[i], M - N) + # NOTE: `Base.stack`, which is used in the transformation, only supports the scenario where + # all `b.bs` have the same output sizes => only need to check the first one. + return (output_size(first(b.bs), sz_example)..., sz_redundant...) +end diff --git a/src/interface.jl b/src/interface.jl index 1ba9aa60..0b84523c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -251,6 +251,7 @@ include("bijectors/corr.jl") include("bijectors/truncated.jl") include("bijectors/named_bijector.jl") include("bijectors/ordered.jl") +include("bijectors/product_bijector.jl") # Normalizing flow related include("bijectors/planar_layer.jl") diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 0126db05..65a3e69c 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -66,6 +66,11 @@ has_constant_bijector(d::Type{<:KSOneSided}) = true function has_constant_bijector(::Type{<:Product{Continuous,D}}) where {D} return has_constant_bijector(D) end +function has_constant_bijector( + ::Type{<:Distributions.ProductDistribution{<:Any,<:Any,A}} +) where {A} + return has_constant_bijector(eltype(A)) +end # Container distributions. bijector(d::DiscreteUnivariateDistribution) = identity @@ -93,6 +98,22 @@ end end end +function bijector(d::Distributions.ProductDistribution{N,0,A}) where {N,A} + # This is the univariate scenario, so if we have a constant bijector + # we can just use the same one for all elements. + return if has_constant_bijector(eltype(A)) + elementwise(bijector(d.dists[1])) + else + ProductBijector(map(bijector, d.dists)) + end +end + +function bijector(d::Distributions.ProductDistribution{N,M,A}) where {N,M,A} + dists = d.dists + bs = bijector.(dists) + return ProductBijector{typeof(bs),N - M}(bs) +end + # Specialized implementations. bijector(d::Normal) = identity bijector(d::Distributions.AbstractMvNormal) = identity diff --git a/test/Project.toml b/test/Project.toml index f20d117b..39875f6e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -21,6 +22,7 @@ ChainRulesTestUtils = "0.7, 1" ChangesOfVariables = "0.1" Combinatorics = "1.0.2" DistributionsAD = "0.6.3" +FillArrays = "1" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" Functors = "0.1, 0.2, 0.3, 0.4" diff --git a/test/bijectors/product_bijector.jl b/test/bijectors/product_bijector.jl new file mode 100644 index 00000000..818f89c0 --- /dev/null +++ b/test/bijectors/product_bijector.jl @@ -0,0 +1,69 @@ +using Bijectors: ProductBijector +using FillArrays + +has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) + +@testset "ProductBijector" begin + # Some distributions. + ds = [ + # 1D. + (Normal(), true), + (InverseGamma(), false), + (Beta(), false), + # 2D. + (MvNormal(Zeros(3), I), true), + (Dirichlet(Ones(3)), false), + ] + + # Stacking a single dimension. + N = 4 + @testset "Single-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds + b = bijector(d) + xs = [rand(d) for _ in 1:N] + x = stack(xs) + + d_prod = product_distribution(Fill(d, N)) + b_prod = bijector(d_prod) + + sz_true = (Bijectors.output_size(b, size(xs[1]))..., N) + @test Bijectors.output_size(b_prod, size(x)) == sz_true + + results = map(xs) do x + with_logabsdet_jacobian(b, x) + end + y, logjac = stack(map(first, results)), sum(last, results) + + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + end + + @testset "Two-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds + b = bijector(d) + xs = [rand(d) for _ in 1:N, _ in 1:(N + 1)] + x = stack(xs) + + d_prod = product_distribution(Fill(d, N, N + 1)) + b_prod = bijector(d_prod) + + sz_true = (Bijectors.output_size(b, size(xs[1]))..., N, N + 1) + @test Bijectors.output_size(b_prod, size(x)) == sz_true + + results = map(Base.Fix1(with_logabsdet_jacobian, b), xs) + y, logjac = stack(map(first, results)), sum(last, results) + + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5c4f2df6..01b81fb8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,7 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/pd.jl") include("bijectors/reshape.jl") include("bijectors/corr.jl") + include("bijectors/product_bijector.jl") include("distributionsad.jl") end