From 55e1e2a74a3ff24fa95516750ae46224c15056e0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 30 Aug 2019 04:52:06 +0200 Subject: [PATCH 01/89] added the Stacked bijector --- src/Bijectors.jl | 1 + src/interface.jl | 76 +++++++++++++++++++++++++++++++++++++++++++++-- test/interface.jl | 59 ++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 3 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index f5b1781e..82556b5b 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -24,6 +24,7 @@ export TransformDistribution, Inversed, Composed, compose, + Stacked, Identity, DistributionBijector, bijector, diff --git a/src/interface.jl b/src/interface.jl index f45588ca..64aa3af8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -123,7 +123,9 @@ end struct SingularJacobianException{B} <: Exception where {B<:Bijector} b::B end -Base.showerror(io::IO, e::SingularJacobianException) = print(io, "jacobian of $(e.b) is singular") +function Base.showerror(io::IO, e::SingularJacobianException) + print(io, "jacobian of $(e.b) is singular") +end # TODO: allow batch-computation, especially for univariate case? "Computes the absolute determinant of the Jacobian of the inverse-transformation." @@ -240,6 +242,73 @@ function forward(cb::Composed, x) return (rv=rv, logabsdetjac=logjac) end +-########### +-# Stacked # +-########### +""" + Stacked(bs) + Stacked(bs, ranges) + vcat(bs::Bijector...) + +A `Bijector` which stacks bijectors together which can then be applied to a vector +where `bs[i]::Bijector` is applied to `x[ranges[i]]`. + +# Examples +``` +b1 = Logistic(0.0, 1.0) +b2 = Identity() +b = vcat(b1, b2) +b([0.0, 1.0]) == [b1(0.0), 1.0] # => true +``` +""" +struct Stacked{B, N} <: Bijector where N + bs::B + ranges::NTuple{N, UnitRange{Int}} +end +Stacked(bs) = Stacked(bs, NTuple{length(bs), UnitRange{Int}}([i:i for i = 1:length(bs)])) +Stacked(bs, ranges) = Stacked(bs, NTuple{length(bs), UnitRange{Int}}(ranges)) + +Base.vcat(bs::Bijector...) = Stacked(bs) + +inv(sb::Stacked) = Stacked(inv.(sb.bs), sb.ranges) + +# TODO: Is there a better approach to this? +@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs::Bijector...) where N + exprs = [] + for i = 1:N + push!(exprs, :(bs[$i](x[rs[$i]]))) + end + + return :(vcat($(exprs...))) +end +_transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) = b(x) + +(sb::Stacked)(x::AbstractArray{<: Real}) = _transform(x, sb.ranges, sb.bs...) +(sb::Stacked)(x::AbstractMatrix{<: Real}) = hcat([sb(x[:, i]) for i = 1:size(x, 2)]...) +function (sb::Stacked)(x::TrackedArray{A, 2}) where {A} + return Tracker.collect(hcat([sb(x[:, i]) for i = 1:size(x, 2)]...)) +end + +@generated function _logabsdetjac( + x, + rs::NTuple{N, UnitRange{Int}}, + bs::Bijector... +) where {N} + exprs = [] + for i = 1:N + push!(exprs, :(sum(logabsdetjac(bs[$i], x[rs[$i]])))) + end + + return :(sum([$(exprs...), ])) +end +logabsdetjac(b::Stacked, x::AbstractVector{<: Real}) = _logabsdetjac(x, b.ranges, b.bs...) 
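Usage note for the `Stacked` docstring above: the `Logistic(0.0, 1.0)` in its example is only illustrative. A minimal sketch using bijectors this file actually defines, assuming Bijectors.jl with this patch applied:

```
using Bijectors

b1 = bijector(Beta())   # Logit{Float64}(0.0, 1.0): (0, 1) → ℝ
b2 = Identity()
sb = vcat(b1, b2)       # Stacked((b1, b2), (1:1, 2:2))

x = [0.3, 1.0]
y = sb(x)               # [b1(x[1]), x[2]]
inv(sb)(y) ≈ x          # the same ranges are reused for the inverse
logabsdetjac(sb, x)     # sum of the per-block log-abs-det-Jacobians
```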
+function logabsdetjac(sb::Stacked, x::AbstractMatrix{<: Real}) + return hcat([logabsdetjac(sb, x[:, i]) for i = 1:size(x, 2)]) +end +function logabsdetjac(sb::Stacked, x::TrackedArray{A, 2}) where {A} + return Tracker.collect(hcat([logabsdetjac(sb, x[:, i]) for i = 1:size(x, 2)])) +end + ############################## # Example bijector: Identity # ############################## @@ -324,13 +393,14 @@ logabsdetjac(b::Scale, x) = log(abs(b.a)) # Simplex bijector # #################### struct SimplexBijector{T} <: Bijector where {T} end +SimplexBijector() = SimplexBijector{Val{true}}() const simplex_b = SimplexBijector{Val{false}}() const simplex_b_proj = SimplexBijector{Val{true}}() # The following implementations are basically just copy-paste from `invlink` and # `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. -function _clamp(x::T, b::SimplexBijector) where {T} +function _clamp(x::T, b::Union{SimplexBijector, Inversed{<:SimplexBijector}}) where {T} bounds = (zero(T), one(T)) clamped_x = clamp(x, bounds...) DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x" @@ -518,7 +588,7 @@ bijector(d::Normal) = IdentityBijector bijector(d::MvNormal) = IdentityBijector bijector(d::PositiveDistribution) = Log() bijector(d::MvLogNormal) = Log() -bijector(d::SimplexDistribution) = simplex_b_proj +bijector(d::SimplexDistribution) = SimplexBijector{Val{true}}() _union2tuple(T1::Type, T2::Type) = (T1, T2) _union2tuple(T1::Type, T2::Union) = (T1, _union2tuple(T2.a, T2.b)...) diff --git a/test/interface.jl b/test/interface.jl index 24a40768..86ee6d84 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -282,6 +282,65 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test f_t == f_a end + @testset "Stacked <: Bijector" begin + # `logabsdetjac` without AD + d = Beta() + b = bijector(d) + x = rand(d) + y = b(x) + sb = vcat(b, b, inv(b), inv(b)) + @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 + + # `logabsdetjac` with AD + b = DistributionBijector(d) + y = b(x) + sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple + sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 + @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 + + @testset "Stacked: ADVI with MvNormal" begin + # MvNormal test + d = MvNormal(zeros(10), ones(10)) + dists = [ + Beta(), + Beta(), + Beta(), + InverseGamma(), + InverseGamma(), + Gamma(), + Gamma(), + InverseGamma(), + Cauchy(), + Gamma() + ] + bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists + ibs = inv.(bs) # invert, so we get unconstrained-to-constrained + sb = vcat(ibs...) # => Stacked <: Bijector + @test sb isa Stacked + + td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} + @test td isa Distribution{Multivariate, Continuous} + + y = rand(td) + + bs = bijector.(tuple(dists...)) + ibs = inv.(bs) + sb = vcat(ibs...) 
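The "Stacked: ADVI with MvNormal" test around this hunk builds a base `MvNormal` and maps each coordinate into the support of a target distribution. A compact sketch of that pattern, assuming this patch is applied:

```
using Bijectors

dists = (Beta(), InverseGamma(), Normal())
ibs = inv.(bijector.(dists))         # unconstrained → constrained, one per target
sb = vcat(ibs...)                    # Stacked over ranges 1:1, 2:2, 3:3

base = MvNormal(zeros(3), ones(3))
td = transformed(base, sb)
y = rand(td)
(0 ≤ y[1] ≤ 1) && (y[2] > 0)         # components land in their target supports
```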
+ isb = inv(sb) + @test sb isa Stacked{<: Tuple} + + # inverse + y = rand(td) + x = isb(y) + @test sb(x) ≈ y + + # AD verification + @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) + @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) + end + end + @testset "Example: ADVI single" begin # Usage in ADVI d = Beta() From d82c9da698e088a66f7f434299427ca088f40aae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 30 Aug 2019 04:53:20 +0200 Subject: [PATCH 02/89] a couple of style fixes --- src/interface.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 64aa3af8..2cd662d5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -555,7 +555,10 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D< dist::D transform::B end -function TransformedDistribution(d::D, b::B) where {V<:VariateForm, B<:Bijector, D<:Distribution{V, Continuous}} +function TransformedDistribution( + d::D, + b::B +) where {V<:VariateForm, B<:Bijector, D<:Distribution{V, Continuous}} return TransformedDistribution{D, B, V}(d, b) end @@ -712,7 +715,8 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - return (logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac, res.logabsdetjac) + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac + return (lp, res.logabsdetjac) end # TODO: should eventually drop using `logpdf_with_trans` From d557b3dff2c6fbda3a42dfd7a2a47f3eaa71f610 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 30 Aug 2019 05:24:22 +0200 Subject: [PATCH 03/89] added size assertion to Stacked and better testing --- src/interface.jl | 17 ++++++++++++----- test/interface.jl | 43 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 2cd662d5..07df4d87 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -283,7 +283,14 @@ inv(sb::Stacked) = Stacked(inv.(sb.bs), sb.ranges) end _transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) = b(x) -(sb::Stacked)(x::AbstractArray{<: Real}) = _transform(x, sb.ranges, sb.bs...) +function (sb::Stacked)(x::AbstractArray{<: Real}) + y = _transform(x, sb.ranges, sb.bs...) + + # TODO: maybe tell user to check their ranges? + @assert size(y) == size(x) + + return y +end (sb::Stacked)(x::AbstractMatrix{<: Real}) = hcat([sb(x[:, i]) for i = 1:size(x, 2)]...) function (sb::Stacked)(x::TrackedArray{A, 2}) where {A} return Tracker.collect(hcat([sb(x[:, i]) for i = 1:size(x, 2)]...)) @@ -302,11 +309,11 @@ end return :(sum([$(exprs...), ])) end logabsdetjac(b::Stacked, x::AbstractVector{<: Real}) = _logabsdetjac(x, b.ranges, b.bs...) 
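The new `@assert size(y) == size(x)` means mis-specified ranges now fail loudly instead of silently returning a shorter vector. A small sketch, assuming this patch:

```
using Bijectors

# default ranges from vcat are 1:1, 2:2, which cannot cover a length-3 input
sb = vcat(inv(bijector(Beta())), inv(bijector(MvNormal(zeros(2), ones(2)))))
try
    sb(randn(3))
catch e
    e isa AssertionError    # true: the output would have length 2, not 3
end
```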
-function logabsdetjac(sb::Stacked, x::AbstractMatrix{<: Real}) - return hcat([logabsdetjac(sb, x[:, i]) for i = 1:size(x, 2)]) +function logabsdetjac(b::Stacked, x::AbstractMatrix{<: Real}) + return hcat([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)]) end -function logabsdetjac(sb::Stacked, x::TrackedArray{A, 2}) where {A} - return Tracker.collect(hcat([logabsdetjac(sb, x[:, i]) for i = 1:size(x, 2)])) +function logabsdetjac(b::Stacked, x::TrackedArray{A, 2}) where {A} + return Tracker.collect(hcat([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)])) end ############################## diff --git a/test/interface.jl b/test/interface.jl index 86ee6d84..8f603267 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -301,7 +301,6 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @testset "Stacked: ADVI with MvNormal" begin # MvNormal test - d = MvNormal(zeros(10), ones(10)) dists = [ Beta(), Beta(), @@ -312,29 +311,59 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end Gamma(), InverseGamma(), Cauchy(), - Gamma() + Gamma(), + MvNormal(zeros(2), ones(2)) ] - bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists - ibs = inv.(bs) # invert, so we get unconstrained-to-constrained - sb = vcat(ibs...) # => Stacked <: Bijector + + ranges = [] + idx = 1 + for i = 1:length(dists) + d = dists[i] + push!(ranges, idx:idx + length(d) - 1) + idx += length(d) + end + + num_params = ranges[end][end] + d = MvNormal(zeros(num_params), ones(num_params)) + + # Stacked{<:Array} + bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists + ibs = inv.(bs) # invert, so we get unconstrained-to-constrained + sb = Stacked(ibs, ranges) # => Stacked <: Bijector + x = rand(d) + sb(x) @test sb isa Stacked td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} @test td isa Distribution{Multivariate, Continuous} - y = rand(td) + # check that wrong ranges fails + sb = vcat(ibs...) + td = transformed(d, sb) + x = rand(d) + @test_throws AssertionError sb(x) + # Stacked{<:Tuple} bs = bijector.(tuple(dists...)) ibs = inv.(bs) - sb = vcat(ibs...) + sb = Stacked(ibs, ranges) isb = inv(sb) @test sb isa Stacked{<: Tuple} # inverse + td = transformed(d, sb) y = rand(td) x = isb(y) @test sb(x) ≈ y + # verification of computation + x = rand(d) + y = sb(x) + y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) + x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) 
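When the stacked components have different lengths, as in the test above, the ranges have to be passed explicitly. One way to build them from the distribution lengths (a sketch under the same assumptions as the test):

```
using Bijectors

dists = [Beta(), MvNormal(zeros(2), ones(2)), InverseGamma()]
ends = cumsum(length.(dists))                     # [1, 3, 4]
ranges = [(i == 1 ? 1 : ends[i - 1] + 1):ends[i] for i in eachindex(dists)]

sb = Stacked(inv.(bijector.(dists)), ranges)      # ranges: [1:1, 2:3, 4:4]
y = sb(randn(4))
(0 ≤ y[1] ≤ 1) && (y[4] > 0)                      # y[2:3] stays unconstrained
```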
+ @test x ≈ x_ + @test y ≈ y_ + # AD verification @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) From e168ce0c1fbf4eea1a7474389452aa35a3025906 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 30 Aug 2019 06:20:25 +0200 Subject: [PATCH 04/89] added Stacked test to tests for norm flows --- test/norm_flows.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 7f5d94cf..9b12da44 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -59,5 +59,21 @@ end @test res.rv ≈ y @test logpdf(flow, y) ≈ lp rtol=0.1 -end + # flow with unconstrained-to-constrained + d1 = Beta() + b1 = inv(bijector(d1)) + d2 = InverseGamma() + b2 = inv(bijector(d2)) + + x = rand(d) .+ 10 + y = b(x) + + sb = vcat(b1, b1) + @test all((sb ∘ b)(x) .≤ 1.0) + + sb = vcat(b1, b2) + cb = (sb ∘ b) + y = cb(x) + @test (0 ≤ y[1] ≤ 1.0) && (0 < y[2]) +end From ea5b57c81e1cda88148f22b8f4d054bf46b37ee2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Sep 2019 18:41:57 +0200 Subject: [PATCH 05/89] export TransformedDistribution --- src/Bijectors.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 82556b5b..f375b989 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -29,6 +29,7 @@ export TransformDistribution, DistributionBijector, bijector, transformed, + TransformedDistribution, UnivariateTransformed, MultivariateTransformed, logpdf_with_jac, From fae5dbf3794b58a2c563b660647fe365a639a36d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Sep 2019 18:50:09 +0200 Subject: [PATCH 06/89] added som useful implementations for TransformedDistribution --- Manifest.toml | 6 +++--- Project.toml | 1 + src/Bijectors.jl | 1 + src/interface.jl | 34 +++++++++++++++++++++++++++++++--- 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index f1c882b3..cc6f7d86 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -232,10 +232,10 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] -git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" +deps = ["BinDeps", "BinaryProvider", "Libdl"] +git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.7.2" +version = "0.8.0" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] diff --git a/Project.toml b/Project.toml index f75e2029..17f58a7c 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index f375b989..14b2cc58 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -2,6 +2,7 @@ module Bijectors using Reexport, Requires @reexport using Distributions +@reexport using StatsBase using StatsFuns using LinearAlgebra using MappedArrays diff --git a/src/interface.jl b/src/interface.jl index 07df4d87..468cef47 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -5,7 +5,8 @@ using Tracker import Base: inv, ∘ import Random: AbstractRNG 
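The norm_flows test added above composes a flow layer with a `Stacked` of inverse bijectors so the flow's output lands in the right supports. The same pattern with a plain `Shift` standing in for the flow layer (a sketch; the flow construction itself is outside the hunk shown):

```
using Bijectors
using Bijectors: Shift

sb = vcat(inv(bijector(Beta())), inv(bijector(InverseGamma())))
cb = sb ∘ Shift(ones(2))        # the Shift is applied first, then each component is constrained
y = cb(randn(2))
(0 ≤ y[1] ≤ 1) && (y[2] > 0)
```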
-import Distributions: logpdf, rand, rand!, _rand!, _logpdf +import Distributions: logpdf, rand, rand!, _rand!, _logpdf, params +import StatsBase: entropy ####################################### # AD stuff "extracted" from Turing.jl # @@ -64,7 +65,7 @@ inv(ib::Inversed{<:Bijector}) = ib.orig logabsdetjac(b::Bijector, x) logabsdetjac(ib::Inversed{<:Bijector}, y) -Computes the log(abs(det(J(x)))) where J is the jacobian of the transform. +Computes the log(abs(det(J(b(x))))) where J is the jacobian of the transform. Similarily for the inverse-transform. Default implementation for `Inversed{<:Bijector}` is implemented as @@ -554,7 +555,11 @@ end (ib::Inversed{<:DistributionBijector})(y) = invlink(ib.orig.dist, y) -"Returns the constrained-to-unconstrained bijector for distribution `d`." +""" + bijector(d::Distribution) + +Returns the constrained-to-unconstrained bijector for distribution `d`. +""" bijector(d::Distribution) = DistributionBijector(d) # Transformed distributions @@ -806,3 +811,26 @@ In the case where `d isa Distribution`, this means """ forward(d::Distribution) = forward(GLOBAL_RNG, d) forward(d::Distribution, num_samples::Int) = forward(GLOBAL_RNG, d, num_samples) + +# utility stuff +params(td::Transformed) = params(td.dist) +entropy(td::Transformed) = entropy(td.dist) + +# logabsdetjac for distributions +logabsdetjacinv(d::UnivariateDistribution, x::T) where T <: Real = zero(T) +logabsdetjacinv(d::MultivariateDistribution, x::AbstractVector{T}) where {T<:Real} = zero(T) + +# for transformed distributions the `y` is going to be the transformed variable +# and so we use the inverse transform to get what we want +# TODO: should this be renamed to `logabsdetinvjac`? +""" + logabsdetjacinv(td::UnivariateTransformed, y::Real) + logabsdetjacinv(td::MultivariateTransformed, y::AbstractVector{<:Real}) + +Computes the `logabsdetjac` of the _inverse_ transformation, since `rand(td)` returns +the _transformed_ random variable. +""" +logabsdetjacinv(td::UnivariateTransformed, y::Real) = logabsdetjac(inv(td.transform), y) +function logabsdetjacinv(td::MvTransformed, y::AbstractVector{<:Real}) + return logabsdetjac(inv(td.transform), y) +end From fee6aad89d016bd52a196875cf3be44f5c6c8542 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Sep 2019 11:16:49 +0200 Subject: [PATCH 07/89] fixed composer which had a bug leftover from previous PR --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 468cef47..fac30111 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -177,7 +177,7 @@ struct Composed{A} <: Bijector end composel(ts::Bijector...) = Composed(ts) -composer(ts::Bijector...) = Composed(inv(ts)) +composer(ts::Bijector...) = Composed(reverse(ts)) # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that From af38dc2611023f604eff341dfd0bf65ac6e983eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Sep 2019 11:18:35 +0200 Subject: [PATCH 08/89] removed vectorization in logabsdetjac to fail loadly rather than silently --- src/interface.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index fac30111..0076f5f5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -204,11 +204,11 @@ _transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) 
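The `params`/`entropy`/`logabsdetjacinv` utilities added a couple of patches back simply forward to the base distribution and to the inverse transform. A quick sketch, assuming these commits:

```
using Bijectors

td = transformed(Beta(2.0, 2.0))     # y = logit(x) with x ~ Beta(2, 2)
y = rand(td)

Bijectors.logabsdetjacinv(td, y) ≈ logabsdetjac(inv(td.transform), y)
entropy(td) == entropy(td.dist)      # entropy and params currently just forward
params(td) == params(td.dist)        # to the underlying distribution
```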
function _logabsdetjac(x, b1::Bijector, b2::Bijector) res = forward(b1, x) - return logabsdetjac(b2, res.rv) .+ res.logabsdetjac + return logabsdetjac(b2, res.rv) + res.logabsdetjac end function _logabsdetjac(x, b1::Bijector, bs::Bijector...) res = forward(b1, x) - return _logabsdetjac(res.rv, bs...) .+ res.logabsdetjac + return _logabsdetjac(res.rv, bs...) + res.logabsdetjac end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) @@ -311,10 +311,10 @@ end end logabsdetjac(b::Stacked, x::AbstractVector{<: Real}) = _logabsdetjac(x, b.ranges, b.bs...) function logabsdetjac(b::Stacked, x::AbstractMatrix{<: Real}) - return hcat([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)]) + return [logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)] end function logabsdetjac(b::Stacked, x::TrackedArray{A, 2}) where {A} - return Tracker.collect(hcat([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)])) + return Tracker.collect([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)]) end ############################## From 6ca67bd9f8ed42ab19b5557024e0b73ce5cf9754 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Sep 2019 11:20:44 +0200 Subject: [PATCH 09/89] initial implementation of coupling-layers --- src/Bijectors.jl | 2 + src/couplings.jl | 125 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 src/couplings.jl diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 14b2cc58..47772d16 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -451,4 +451,6 @@ include("interface.jl") include("norm_flows.jl") +include("couplings.jl") + end # module diff --git a/src/couplings.jl b/src/couplings.jl new file mode 100644 index 00000000..e712e894 --- /dev/null +++ b/src/couplings.jl @@ -0,0 +1,125 @@ +using SparseArrays + +# TODO: should we add another field `A_3` which we can use to filter out those +# parts of the vector to which we apply the identity? E.g. +# you want to use x[1] to parameterize transform of x[2], but you don't want +# to do anything with x[3] +struct PartitionMask{A} + A_1::A + A_2::A + A_3::A +end + +function PartitionMask( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::AbstractVector{Int}, + indices_3::AbstractVector{Int} +) + A_1 = spzeros(n, length(indices_1)); + A_2 = spzeros(n, length(indices_2)); + A_3 = spzeros(n, length(indices_3)); + + for (i, idx) in enumerate(indices_1) + A_1[idx, i] = 1.0 + end + + for (i, idx) in enumerate(indices_2) + A_2[idx, i] = 1.0 + end + + for (i, idx) in enumerate(indices_3) + A_3[idx, i] = 1.0 + end + + return PartitionMask(A_1, A_2, A_3) +end + +PartitionMask( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::AbstractVector{Int}, + indices_3::Nothing +) = PartitionMask(n, indices_1, indices_2, Int[]) + +PartitionMask( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::Nothing, + indices_3::AbstractVector{Int} +) = PartitionMask(n, indices_1, Int[], indices_3) + +""" + PartitionMask(n::Int, indices) + +Assumes you want to _split_ the vector, where `indices` refer to the +parts of the vector you want to apply the bijector to. 
+""" +function PartitionMask(n::Int, indices) + indices_2 = [i for i in 1:n if i ∉ indices] + + # sparse arrays <3 + A_1 = spzeros(n, length(indices)); + A_2 = spzeros(n, length(indices_2)); + + # Like doing: + # A[1, 1] = 1.0 + # A[3, 2] = 1.0 + for (i, idx) in enumerate(indices) + A_1[idx, i] = 1.0 + end + + for (i, idx) in enumerate(indices_2) + A_2[idx, i] = 1.0 + end + + return PartitionMask(A_1, A_2, spzeros(n, 0)) +end +PartitionMask(x::AbstractVector, indices) = PartitionMask(length(x), indices) + +@inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3 +@inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) + + +# CouplingLayer + +struct CouplingLayer{B, M, F} <: Bijector where {B, M <: PartitionMask, F} + mask::M + θ::F +end + +CouplingLayer(B, mask::M, θ::F) where {M, F} = CouplingLayer{B, M, F}(mask, θ) +function CouplingLayer(B, θ, n::Int) + idx = Int(floor(n / 2)) + return CouplingLayer(B, PartitionMask(n, 1:idx), θ) +end + + + +function (cl::CouplingLayer{B})(x) where {B} + x_1, x_2, x_3 = partition(cl.mask, x) + + b = B(cl.θ(x_2)) + + return combine(cl.mask, b(x_1), x_2, x_3) +end + + +function (icl::Inversed{<:CouplingLayer{B}})(y) where {B} + cl = icl.orig + + y_1, y_2, y_3 = partition(cl.mask, y) + + b = B(cl.θ(y_2)) + ib = inv(b) + + return combine(cl.mask, ib(y_1), y_2, y_3) +end + + +function logabsdetjac(cl::CouplingLayer{B}, x) where {B} + x_1, x_2, x_3 = partition(cl.mask, x) + b = B(cl.θ(x_2)) + + return logabsdetjac(b, x_1) +end From 3b9440c88a10b8f0124aa8e6d5ac67193e977574 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Sep 2019 11:48:31 +0200 Subject: [PATCH 10/89] added some docstrings --- src/couplings.jl | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/couplings.jl b/src/couplings.jl index e712e894..71f2f407 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -1,9 +1,14 @@ using SparseArrays -# TODO: should we add another field `A_3` which we can use to filter out those -# parts of the vector to which we apply the identity? E.g. -# you want to use x[1] to parameterize transform of x[2], but you don't want -# to do anything with x[3] +""" + PartitionMask{A}(A_1::A, A_2::A, A_3::A) where {A} + +This is used to partition and recombine a vector into 3 disjoint "subvectors". + +Implements +- `partition(m::PartitionMask, x)`: partitions `x` into 3 disjoint "subvectors" +- `combine(m::PartitionMask, x_1, x_2, x_3)`: combines 3 disjoint vectors into a single one +""" struct PartitionMask{A} A_1::A A_2::A @@ -77,12 +82,31 @@ function PartitionMask(n::Int, indices) end PartitionMask(x::AbstractVector, indices) = PartitionMask(length(x), indices) +""" + combine(m::PartitionMask, x_1, x_2, x_3) + +Combines `x_1`, `x_2`, and `x_3` into a single vector. +""" @inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3 + +""" + partition(m::PartitionMask, x) + +Partitions `x` into 3 disjoint subvectors. +""" @inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) # CouplingLayer +""" + CouplingLayer{B, M, F}(mask::M, θ::F) + +Implements a coupling-layer as defined in [1]. + +# References +[1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). 
+""" struct CouplingLayer{B, M, F} <: Bijector where {B, M <: PartitionMask, F} mask::M θ::F @@ -95,12 +119,14 @@ function CouplingLayer(B, θ, n::Int) end - function (cl::CouplingLayer{B})(x) where {B} + # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) + # construct bijector `B` using θ(x₂) b = B(cl.θ(x_2)) + # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end From 8a970ee7886135c20fe1e3a5889b77dad102ec0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Sep 2019 11:49:00 +0200 Subject: [PATCH 11/89] currently relies on Tracker#master due to Tracker/issues/12 --- Manifest.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index cc6f7d86..ceaf096d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -280,9 +280,11 @@ version = "0.5.6" [[Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] -git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b" +git-tree-sha1 = "1aa443d3b4bfa91a8aec32f169a479cb87309910" +repo-rev = "master" +repo-url = "https://github.com/FluxML/Tracker.jl.git" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.2" +version = "0.2.3" [[URIParser]] deps = ["Test", "Unicode"] From ba2e2438fcd290b2e7ea469fbad41e6a99f8d8fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 11 Sep 2019 19:37:19 +0200 Subject: [PATCH 12/89] added extra construct of coupling layer from couplinglayer and mask --- src/couplings.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/couplings.jl b/src/couplings.jl index 71f2f407..575d9071 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -118,6 +118,9 @@ function CouplingLayer(B, θ, n::Int) return CouplingLayer(B, PartitionMask(n, 1:idx), θ) end +function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B} + return CouplingLayer(B, mask, cl.θ) +end function (cl::CouplingLayer{B})(x) where {B} # partition vector using `cl.mask::PartitionMask` From 33758bd189e449c853b30db64b8246fe20c3fa69 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Sep 2019 19:15:29 +0200 Subject: [PATCH 13/89] added dimension of expected input to Bijector type --- src/interface.jl | 78 +++++++++++++++++++++++++++++------------------ src/norm_flows.jl | 4 +-- test/interface.jl | 8 ++--- 3 files changed, 55 insertions(+), 35 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index d0705dd4..0bdd7cf3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -37,7 +37,7 @@ end ###################### "Abstract type for a `Bijector`." -abstract type Bijector end +abstract type Bijector{N} end Broadcast.broadcastable(b::Bijector) = Ref(b) @@ -45,7 +45,7 @@ Broadcast.broadcastable(b::Bijector) = Ref(b) Abstract type for a `Bijector` making use of auto-differentation (AD) to implement `jacobian` and, by impliciation, `logabsdetjac`. """ -abstract type ADBijector{AD} <: Bijector end +abstract type ADBijector{AD, N} <: Bijector{N} end """ inv(b::Bijector) @@ -53,10 +53,14 @@ abstract type ADBijector{AD} <: Bijector end A `Bijector` representing the inverse transform of `b`. """ -struct Inversed{B <: Bijector} <: Bijector +# TODO: can we do something like `Bijector{N}` instead? 
+struct Inversed{B <: Bijector, N} <: Bijector{N} orig::B + + Inversed(b::B) where {N, B<:Bijector{N}} = new{B, N}(b) end + inv(b::Bijector) = Inversed(b) inv(ib::Inversed{<:Bijector}) = ib.orig @@ -174,31 +178,40 @@ cb2 = composel(b2, b1) # => Composed.ts == (b2, b1) cb1(x) == cb2(x) == b1(b2(x)) # => true ``` """ -struct Composed{A} <: Bijector +struct Composed{A, N} <: Bijector{N} ts::A end +Composed(ts::A) where {N, A <: AbstractArray{<: Bijector{N}}} = Composed{A, N}(ts) + """ composel(ts::Bijector...)::Composed{<:Tuple} Constructs `Composed` such that `ts` are applied left-to-right. """ -composel(ts::Bijector...) = Composed(ts) +composel(ts::Bijector{N}...) where {N} = Composed{typeof(ts), N}(ts) """ composer(ts::Bijector...)::Composed{<:Tuple} Constructs `Composed` such that `ts` are applied right-to-left. """ +<<<<<<< HEAD composer(ts::Bijector...) = Composed(reverse(ts)) +======= +function composer(ts::Bijector{N}...) where {N} + its = reverse(ts) + return Composed{typeof(its), N}(its) +end +>>>>>>> added dimension of expected input to Bijector type # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -∘(b1::Bijector, b2::Bijector) = composel(b2, b1) +∘(b1::Bijector{N}, b2::Bijector{N}) where {N} = composel(b2, b1) -inv(ct::Composed) = Composed(map(inv, reverse(ct.ts))) +inv(ct::Composed) = composer(map(inv, ct.ts)...) # # TODO: should arrays also be using recursive implementation instead? function (cb::Composed{<:AbstractArray{<:Bijector}})(x) @@ -260,7 +273,7 @@ end # Example bijector: Identity # ############################## -struct Identity <: Bijector end +struct Identity{N} <: Bijector{N} end (::Identity)(x) = x (::Inversed{<:Identity})(y) = y @@ -268,14 +281,12 @@ forward(::Identity, x) = (rv=x, logabsdetjac=zero(eltype(x))) logabsdetjac(::Identity, y) = zero(eltype(y)) -const IdentityBijector = Identity() - ############################### # Example: Logit and Logistic # ############################### using StatsFuns: logit, logistic -struct Logit{T<:Real} <: Bijector +struct Logit{T<:Real} <: Bijector{0} a::T b::T end @@ -289,14 +300,17 @@ logabsdetjac(b::Logit{<:Real}, x) = @. - log((x - b.a) * (b.b - x) / (b.b - b.a) # Exp & Log # ############# -struct Exp <: Bijector end -struct Log <: Bijector end +struct Exp{N} <: Bijector{N} end +struct Log{N} <: Bijector{N} end + +Exp() = Exp{0}() +Log() = Log{0}() (b::Log)(x) = @. log(x) (b::Exp)(y) = @. 
exp(y) -inv(b::Log) = Exp() -inv(b::Exp) = Log() +inv(b::Log{N}) where {N} = Exp{N}() +inv(b::Exp{N}) where {N} = Log{N}() logabsdetjac(b::Log, x) = - sum(log.(x)) logabsdetjac(b::Exp, y) = sum(y) @@ -304,10 +318,13 @@ logabsdetjac(b::Exp, y) = sum(y) ################# # Shift & Scale # ################# -struct Shift{T} <: Bijector +struct Shift{T, N} <: Bijector{N} a::T end +Shift(a::T) where {T<:Real} = Shift{T, 0}(a) +Shift(a::AbstractArray{T, N}) where {T, N} = Shift{T, N}(a) + (b::Shift)(x) = b.a + x (b::Shift{<:Real})(x::AbstractArray) = b.a .+ x (b::Shift{<:AbstractVector})(x::AbstractMatrix) = b.a .+ x @@ -318,10 +335,13 @@ logabsdetjac(b::Shift, x) = zero(eltype(x)) logabsdetjac(b::Shift{<:Real}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) logabsdetjac(b::Shift{<:AbstractVector}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) -struct Scale{T} <: Bijector +struct Scale{T, N} <: Bijector{N} a::T end +Scale(a::T) where {T<:Real} = Scale{T, 0}(a) +Scale(a::AbstractArray{T, N}) where {T, N} = Scale{T, N}(a) + (b::Scale)(x) = b.a * x (b::Scale{<:Real})(x::AbstractArray) = b.a .* x (b::Scale{<:AbstractVector{<:Real}})(x::AbstractMatrix{<:Real}) = x * b.a @@ -339,7 +359,7 @@ logabsdetjac(b::Scale, x) = log(abs(b.a)) #################### # Simplex bijector # #################### -struct SimplexBijector{T} <: Bijector where {T} end +struct SimplexBijector{T} <: Bijector{1} where {T} end const simplex_b = SimplexBijector{Val{false}}() const simplex_b_proj = SimplexBijector{Val{true}}() @@ -481,11 +501,11 @@ This is the default `Bijector` for a distribution. It uses `link` and `invlink` to compute the transformations, and `AD` to compute the `jacobian` and `logabsdetjac`. """ -struct DistributionBijector{AD, D} <: ADBijector{AD} where {D<:Distribution} +struct DistributionBijector{AD, D, N} <: ADBijector{AD, N} where {D<:Distribution} dist::D end function DistributionBijector(dist::D) where {D<:Distribution} - DistributionBijector{ADBackend(), D}(dist) + DistributionBijector{ADBackend(), D, length(size(dist))}(dist) end # Simply uses `link` and `invlink` as transforms with AD to get jacobian @@ -497,12 +517,12 @@ end bijector(d::Distribution) = DistributionBijector(d) # Transformed distributions -struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} +struct TransformedDistribution{D, B, V, N} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector{N}} dist::D transform::B end function TransformedDistribution(d::D, b::B) where {V<:VariateForm, B<:Bijector, D<:Distribution{V, Continuous}} - return TransformedDistribution{D, B, V}(d, b) + return TransformedDistribution{D, B, V, length(size(d))}(d, b) end @@ -530,10 +550,10 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. 
""" -bijector(d::Normal) = IdentityBijector -bijector(d::MvNormal) = IdentityBijector -bijector(d::PositiveDistribution) = Log() -bijector(d::MvLogNormal) = Log() +bijector(d::Normal) = Identity{0}() +bijector(d::MvNormal) = Identity{1}() +bijector(d::PositiveDistribution) = Log{0}() +bijector(d::MvLogNormal) = Log{0}() bijector(d::SimplexDistribution) = simplex_b_proj _union2tuple(T1::Type, T2::Type) = (T1, T2) @@ -564,7 +584,7 @@ function bijector(d::TransformDistribution) where {D<:Distribution} elseif upperbounded return (Log() ∘ Shift(b) ∘ Scale(- one(typeof(b)))) else - return IdentityBijector + return Identity{0}() end end @@ -692,7 +712,7 @@ end const GLOBAL_RNG = Distributions.GLOBAL_RNG function _forward(d::UnivariateDistribution, x) - y, logjac = forward(IdentityBijector, x) + y, logjac = forward(Identity{0}(), x) return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) end @@ -701,7 +721,7 @@ function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) return _forward(d, rand(rng, d, num_samples)) end function _forward(d::Distribution, x) - y, logjac = forward(IdentityBijector, x) + y, logjac = forward(Identity{length(size(d))}(), x) return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) end diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 272c1ed3..120ae676 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -14,7 +14,7 @@ using Roots # for inverse # PlanarLayer # ############### -mutable struct PlanarLayer{T1,T2} <: Bijector +mutable struct PlanarLayer{T1,T2} <: Bijector{1} w::T1 u::T1 b::T2 @@ -90,7 +90,7 @@ logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac # FIXME: using `TrackedArray` for the parameters, we end up with # nested tracked structures; don't want this. -mutable struct RadialLayer{T1,T2} <: Bijector +mutable struct RadialLayer{T1,T2} <: Bijector{1} α_::T1 β::T1 z_0::T2 diff --git a/test/interface.jl b/test/interface.jl index a04f1ddd..f04163dc 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -6,7 +6,7 @@ using ForwardDiff Random.seed!(123) -struct NonInvertibleBijector{AD} <: ADBijector{AD} end +struct NonInvertibleBijector{AD} <: ADBijector{AD, 2} end # Scalar tests @testset "Interface" begin @@ -95,7 +95,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @testset "$dist: ForwardDiff AD" begin x = rand(dist) - b = DistributionBijector{Bijectors.ADBackend(:forward_diff), typeof(dist)}(dist) + b = DistributionBijector{Bijectors.ADBackend(:forward_diff), typeof(dist), length(size(dist))}(dist) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf @@ -108,7 +108,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @testset "$dist: Tracker AD" begin x = rand(dist) - b = DistributionBijector{Bijectors.ADBackend(:reverse_diff), typeof(dist)}(dist) + b = DistributionBijector{Bijectors.ADBackend(:reverse_diff), typeof(dist), length(size(dist))}(dist) @test abs(det(Bijectors.jacobian(b, x))) > 0 @test logabsdetjac(b, x) ≠ Inf @@ -236,7 +236,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end x = rand(d) y = td.transform(x) - b = Bijectors.composel(td.transform, Bijectors.Identity()) + b = Bijectors.composel(td.transform, Bijectors.Identity{0}()) ib = inv(b) @test forward(b, x) == forward(td.transform, x) From ad54c16eb1cf96135fc15c0f34f98fd795c96874 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:56:50 +0200 Subject: [PATCH 14/89] composition now fails upon dimension mismatch --- src/interface.jl | 9 ++++++++- 1 file changed, 
8 insertions(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 0bdd7cf3..fcef0dcf 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -209,7 +209,14 @@ end # but in mathematics we usually go from right-to-left; this reversal ensures that # when we use the mathematical composition ∘ we get the expected behavior. # TODO: change behavior of `transform` of `Composed`? -∘(b1::Bijector{N}, b2::Bijector{N}) where {N} = composel(b2, b1) +@generated function ∘(b1::Bijector{N1}, b2::Bijector{N2}) where {N1, N2} + if N1 == N2 + return :(composel(b2, b1)) + else + # FIXME: this doesn't give a stack trace? + return :(throw(DimensionMismatch("$(typeof(b1)) expects $(N1)-dim but $(typeof(b2)) expects $(N2)-dim"))) + end +end inv(ct::Composed) = composer(map(inv, ct.ts)...) From 0ddf71decf4fce9d9a2459463790cc5eac17bcb8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:57:19 +0200 Subject: [PATCH 15/89] removed dots in logabsdetjac accumulates so as to not fail silently --- src/interface.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index fcef0dcf..a0033fa8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -237,11 +237,11 @@ _transform(x, b::Bijector, bs::Bijector...) = _transform(b(x), bs...) function _logabsdetjac(x, b1::Bijector, b2::Bijector) res = forward(b1, x) - return logabsdetjac(b2, res.rv) .+ res.logabsdetjac + return logabsdetjac(b2, res.rv) + res.logabsdetjac end function _logabsdetjac(x, b1::Bijector, bs::Bijector...) res = forward(b1, x) - return _logabsdetjac(res.rv, bs...) .+ res.logabsdetjac + return _logabsdetjac(res.rv, bs...) + res.logabsdetjac end logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) @@ -250,16 +250,16 @@ logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...) # in which case forward(...) immediately calls `_forward(::NamedTuple, b::Bijector)` function _forward(f::NamedTuple, b::Bijector) y, logjac = forward(b, f.rv) - return (rv=y, logabsdetjac=logjac .+ f.logabsdetjac) + return (rv=y, logabsdetjac=logjac + f.logabsdetjac) end function _forward(f::NamedTuple, b1::Bijector, b2::Bijector) f1 = forward(b1, f.rv) f2 = forward(b2, f1.rv) - return (rv=f2.rv, logabsdetjac=f2.logabsdetjac .+ f1.logabsdetjac .+ f.logabsdetjac) + return (rv=f2.rv, logabsdetjac=f2.logabsdetjac + f1.logabsdetjac + f.logabsdetjac) end function _forward(f::NamedTuple, b::Bijector, bs::Bijector...) f1 = forward(b, f.rv) - f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac .+ f.logabsdetjac) + f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac + f.logabsdetjac) return _forward(f_, bs...) end _forward(x, b::Bijector, bs::Bijector...) = _forward(forward(b, x), bs...) 
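With the dimensionality now part of the type, composing bijectors that expect different input dimensions is rejected at composition time. A sketch, assuming these commits:

```
using Bijectors

b0 = Identity{0}()      # acts on scalars
b1 = Identity{1}()      # acts on vectors

b0 ∘ b0                 # fine: both are 0-dimensional
try
    b0 ∘ b1
catch e
    e isa DimensionMismatch    # true: mixed dimensions now fail instead of composing silently
end
```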
@@ -271,7 +271,7 @@ function forward(cb::Composed, x) for t in cb.ts[2:end] res = forward(t, rv) rv = res.rv - logjac = res.logabsdetjac .+ logjac + logjac = res.logabsdetjac + logjac end return (rv=rv, logabsdetjac=logjac) end From 32ed5c4d10a7dc35593e6a7b9419ba16402c19b0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:58:01 +0200 Subject: [PATCH 16/89] updated Identity to work with batches --- src/interface.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index a0033fa8..6c15ea5a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -282,11 +282,23 @@ end struct Identity{N} <: Bijector{N} end (::Identity)(x) = x -(::Inversed{<:Identity})(y) = y +inv(b::Identity) = b forward(::Identity, x) = (rv=x, logabsdetjac=zero(eltype(x))) -logabsdetjac(::Identity, y) = zero(eltype(y)) +logabsdetjac(::Identity, x::Real) = zero(eltype(x)) +@generated function logabsdetjac( + b::Identity{N1}, + x::AbstractArray{T2, N2} +) where {N1, T2, N2} + if N1 == N2 + return :(zero(eltype(x))) + elseif N1 + 1 == N2 + return :(zeros(eltype(x), size(x, $N2))) + else + return :(throw(MethodError(logabsdetjac, (b, x)))) + end +end ############################### # Example: Logit and Logistic # From 29562af15b7d71f6a6945cd2ee0ffa860e92a17d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:58:34 +0200 Subject: [PATCH 17/89] batch-specialization for Log and Exp --- src/interface.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6c15ea5a..c04b33bd 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -325,14 +325,21 @@ struct Log{N} <: Bijector{N} end Exp() = Exp{0}() Log() = Log{0}() -(b::Log)(x) = @. log(x) (b::Exp)(y) = @. exp(y) +(b::Log)(x) = @. 
log(x) -inv(b::Log{N}) where {N} = Exp{N}() inv(b::Exp{N}) where {N} = Log{N}() +inv(b::Log{N}) where {N} = Exp{N}() + +logabsdetjac(b::Exp{0}, x::Real) = x +logabsdetjac(b::Exp{0}, x::AbstractVector) = x +logabsdetjac(b::Exp{1}, x::AbstractVector) = sum(x) +logabsdetjac(b::Exp{1}, x::AbstractMatrix) = vec(sum(x; dims = 1)) -logabsdetjac(b::Log, x) = - sum(log.(x)) -logabsdetjac(b::Exp, y) = sum(y) +logabsdetjac(b::Log{0}, x::Real) = -log(x) +logabsdetjac(b::Log{0}, x::AbstractVector) = -log.(x) +logabsdetjac(b::Log{1}, x::AbstractVector) = - sum(log.(x)) +logabsdetjac(b::Log{1}, x::AbstractMatrix) = - vec(sum(log.(x); dims = 1)) ################# # Shift & Scale # From 40ed8116a5f1c75ddb30d723830b244e9879eff4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:59:05 +0200 Subject: [PATCH 18/89] batch-specialization for Scale and Shift --- src/interface.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index c04b33bd..32c8cc34 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -349,28 +349,29 @@ struct Shift{T, N} <: Bijector{N} end Shift(a::T) where {T<:Real} = Shift{T, 0}(a) -Shift(a::AbstractArray{T, N}) where {T, N} = Shift{T, N}(a) +Shift(a::A) where {T, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) (b::Shift)(x) = b.a + x (b::Shift{<:Real})(x::AbstractArray) = b.a .+ x (b::Shift{<:AbstractVector})(x::AbstractMatrix) = b.a .+ x inv(b::Shift) = Shift(-b.a) -logabsdetjac(b::Shift, x) = zero(eltype(x)) -# FIXME: ambiguous whether or not this is actually a batch or whatever -logabsdetjac(b::Shift{<:Real}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) -logabsdetjac(b::Shift{<:AbstractVector}, x::AbstractMatrix) = zeros(eltype(x), size(x, 2)) + +logabsdetjac(b::Shift{<:Real, 0}, x::Real) = zero(eltype(x)) +logabsdetjac(b::Shift{<:Real, 0}, x::AbstractVector) = zeros(eltype(x), length(x)) +logabsdetjac(b::Shift{T, 1}, x::AbstractVector) where {T<:Union{Real, AbstractVector}} = zero(eltype(x)) +logabsdetjac(b::Shift{T, 1}, x::AbstractMatrix) where {T<:Union{Real, AbstractVector}} = zeros(eltype(x), size(x, 2)) struct Scale{T, N} <: Bijector{N} a::T end Scale(a::T) where {T<:Real} = Scale{T, 0}(a) -Scale(a::AbstractArray{T, N}) where {T, N} = Scale{T, N}(a) +Scale(a::A) where {T, N, A<:AbstractArray{T, N}} = Scale{A, N}(a) -(b::Scale)(x) = b.a * x +(b::Scale)(x) = b.a .* x (b::Scale{<:Real})(x::AbstractArray) = b.a .* x -(b::Scale{<:AbstractVector{<:Real}})(x::AbstractMatrix{<:Real}) = x * b.a +(b::Scale{<:AbstractVector{<:Real}, 2})(x::AbstractMatrix{<:Real}) = b.a .* x inv(b::Scale) = Scale(inv(b.a)) inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) @@ -380,7 +381,11 @@ inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) # logabsdetjac(b::Scale{<: AbstractVector}, x::AbstractMatrix) # Is this a batch or is it simply a matrix we want to scale differently # in each component? 
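Batch behaviour after the `Exp`/`Log` and `Shift` specialisations above (a sketch, assuming these commits): a 0-dimensional bijector treats a vector as a batch of scalars, while a 1-dimensional bijector treats the columns of a matrix as a batch of vectors.

```
using Bijectors
using Bijectors: Exp, Shift

xs = randn(3)                                  # batch of 3 scalars
logabsdetjac(Exp{0}(), xs) == xs               # one log-abs-det-Jacobian per element

X = randn(2, 3)                                # batch of 3 two-dimensional inputs
logabsdetjac(Exp{1}(), X) == vec(sum(X; dims = 1))
logabsdetjac(Shift(ones(2)), X) == zeros(3)    # shifts never change volume
```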
-logabsdetjac(b::Scale, x) = log(abs(b.a)) +logabsdetjac(b::Scale{<:Real, 0}, x::Real) = log(abs(b.a)) +logabsdetjac(b::Scale{<:Real, 0}, x::AbstractVector) = log(abs(b.a)) .* ones(eltype(x), length(x)) +logabsdetjac(b::Scale{<:Real, 1}, x::AbstractVector) = log(abs(b.a)) * length(x) +logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractVector) = sum(log.(abs.(b.a))) +logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractMatrix) = sum(log.(abs.(b.a))) * ones(eltype(x), size(x, 2)) #################### # Simplex bijector # From 669e72d2eac889ccc65fbe1b0988ee8f66d1d83d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:59:18 +0200 Subject: [PATCH 19/89] added dimension method for bijectors --- src/interface.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 32c8cc34..ae730c31 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -39,6 +39,8 @@ end "Abstract type for a `Bijector`." abstract type Bijector{N} end +dimension(b::Bijector{N}) where {N} = N + Broadcast.broadcastable(b::Bijector) = Ref(b) """ From f08dc5178c135365c3d89722ac093ef3091bf710 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 09:59:58 +0200 Subject: [PATCH 20/89] added tests for batch computation for a bunch of bijectors --- test/interface.jl | 69 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index f04163dc..53fe1199 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -4,9 +4,11 @@ using Random using LinearAlgebra using ForwardDiff +using Bijectors: Log, Exp, Shift, Scale, Logit + Random.seed!(123) -struct NonInvertibleBijector{AD} <: ADBijector{AD, 2} end +struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end # Scalar tests @testset "Interface" begin @@ -121,6 +123,71 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 2} end end end + @testset "Batch computation" begin + bs_xs = [ + (Scale(2.0), randn(3)), + (Scale([1.0, 2.0]), randn(2, 3)), + (Shift(2.0), randn(3)), + (Shift([1.0, 2.0]), randn(2, 3)), + (Log{0}(), exp.(randn(3))), + (Log{1}(), exp.(randn(2, 3))), + (Exp{0}(), randn(3)), + (Exp{1}(), randn(2, 3)), + (Log{1}() ∘ Exp{1}(), randn(2, 3)), + (inv(Logit(-1.0, 1.0)), randn(3)), + (Identity{0}(), randn(3)), + (Identity{1}(), randn(2, 3)) + ] + + for (b, xs) in bs_xs + @testset "$b" begin + D = Bijectors.dimension(b) + ib = inv(b) + + @test Bijectors.dimension(ib) == D + + x = D == 0 ? xs[1] : xs[:, 1] + + y = b(x) + ys = b(xs) + + x_ = ib(y) + xs_ = ib(ys) + + @test size(y) == size(x) + @test size(ys) == size(xs) + @test size(x_) == size(x) + @test size(xs_) == size(xs) + + if D == 0 + @test y == ys[1] + + @test length(logabsdetjac(b, xs)) == length(xs) + @test logabsdetjac(b, x) == logabsdetjac(b, xs)[1] + + @test length(logabsdetjac(ib, ys)) == length(xs) + @test logabsdetjac(ib, y) == logabsdetjac(ib, ys)[1] + elseif Bijectors.dimension(b) == 1 + @test y == ys[:, 1] + # Comparing sizes instead of lengths ensures we catch errors s.t. + # length(x) == 3 when size(x) == (1, 3). 
+ # We want the return value to + @test size(logabsdetjac(b, xs)) == (size(xs, 2), ) + @test logabsdetjac(b, x) == logabsdetjac(b, xs)[1] + + @test size(logabsdetjac(ib, ys)) == (size(xs, 2), ) + @test logabsdetjac(ib, y) == logabsdetjac(ib, ys)[1] + else + error("tests not implemented yet") + end + end + end + + @testset "Composition" begin + @test_throws DimensionMismatch (Exp{1}() ∘ Log{0}()) + end + end + @testset "Truncated" begin d = Truncated(Normal(), -1, 1) b = bijector(d) From e26c893a5a6f18c137498522dad7fc12dc92987b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 10:08:05 +0200 Subject: [PATCH 21/89] removed a false comment --- src/interface.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index ae730c31..c11fcba2 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -215,7 +215,6 @@ end if N1 == N2 return :(composel(b2, b1)) else - # FIXME: this doesn't give a stack trace? return :(throw(DimensionMismatch("$(typeof(b1)) expects $(N1)-dim but $(typeof(b2)) expects $(N2)-dim"))) end end From 589e56456362a7a1ddc2c215a850b6dc57da64b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Sep 2019 13:31:25 +0200 Subject: [PATCH 22/89] removed some left-over stuff from a failed merge --- src/interface.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index c11fcba2..e624e71b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -198,14 +198,10 @@ composel(ts::Bijector{N}...) where {N} = Composed{typeof(ts), N}(ts) Constructs `Composed` such that `ts` are applied right-to-left. """ -<<<<<<< HEAD -composer(ts::Bijector...) = Composed(reverse(ts)) -======= function composer(ts::Bijector{N}...) where {N} its = reverse(ts) return Composed{typeof(its), N}(its) end ->>>>>>> added dimension of expected input to Bijector type # The transformation of `Composed` applies functions left-to-right # but in mathematics we usually go from right-to-left; this reversal ensures that From a16f05a2730bcf85359edbbf160b02425e07ee9d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 12:18:45 +0200 Subject: [PATCH 23/89] added some convenient constructors for SimplexBijector --- src/interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 0076f5f5..1d229270 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -401,7 +401,8 @@ logabsdetjac(b::Scale, x) = log(abs(b.a)) # Simplex bijector # #################### struct SimplexBijector{T} <: Bijector where {T} end -SimplexBijector() = SimplexBijector{Val{true}}() +SimplexBijector(proj:::bool) = SimplexBijector{Val{proj}}() +SimplexBijector() = SimplexBijector(true) const simplex_b = SimplexBijector{Val{false}}() const simplex_b_proj = SimplexBijector{Val{true}}() From 09d1bd660016aacfa2367bb2287405dcfc148e12 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 12:28:40 +0200 Subject: [PATCH 24/89] fixed bug from previous commit and added tests for Stacked --- src/interface.jl | 2 +- test/interface.jl | 326 +++++++++++++++++++++++----------------------- 2 files changed, 167 insertions(+), 161 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 1d229270..01b30b87 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -401,7 +401,7 @@ logabsdetjac(b::Scale, x) = log(abs(b.a)) # Simplex bijector # #################### struct SimplexBijector{T} <: Bijector where {T} end -SimplexBijector(proj:::bool) = SimplexBijector{Val{proj}}() 
+SimplexBijector(proj::Bool) = SimplexBijector{Val(proj)}() SimplexBijector() = SimplexBijector(true) const simplex_b = SimplexBijector{Val{false}}() diff --git a/test/interface.jl b/test/interface.jl index 8f603267..f5ec3a79 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -196,187 +196,193 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end end end - @testset "Matrix variate" begin - v = 7.0 - S = Matrix(1.0I, 2, 2) - S[1, 2] = S[2, 1] = 0.5 - - matrix_dists = [ - Wishart(v,S), - InverseWishart(v,S) - ] - - for dist in matrix_dists - @testset "$dist: dist" begin - td = transformed(dist) +@testset "Matrix variate" begin + v = 7.0 + S = Matrix(1.0I, 2, 2) + S[1, 2] = S[2, 1] = 0.5 - # single sample - y = rand(td) - x = inv(td.transform)(y) - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + matrix_dists = [ + Wishart(v,S), + InverseWishart(v,S) + ] - # TODO: implement `logabsdetjac` for these - # logpdf_with_jac - # lp, logjac = logpdf_with_jac(td, y) - # @test lp ≈ logpdf(td, y) - # @test logjac ≈ logabsdetjacinv(td.transform, y) + for dist in matrix_dists + @testset "$dist: dist" begin + td = transformed(dist) - # multi-sample - y = rand(td, 10) - x = inv(td.transform)(y) - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - end + # single sample + y = rand(td) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + + # TODO: implement `logabsdetjac` for these + # logpdf_with_jac + # lp, logjac = logpdf_with_jac(td, y) + # @test lp ≈ logpdf(td, y) + # @test logjac ≈ logabsdetjacinv(td.transform, y) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) end end +end - @testset "Composition <: Bijector" begin - d = Beta() - td = transformed(d) +@testset "Composition <: Bijector" begin + d = Beta() + td = transformed(d) - x = rand(d) - y = td.transform(x) + x = rand(d) + y = td.transform(x) - b = Bijectors.composel(td.transform, Bijectors.Identity()) - ib = inv(b) + b = Bijectors.composel(td.transform, Bijectors.Identity()) + ib = inv(b) - @test forward(b, x) == forward(td.transform, x) - @test forward(ib, y) == forward(inv(td.transform), y) + @test forward(b, x) == forward(td.transform, x) + @test forward(ib, y) == forward(inv(td.transform), y) - # inverse works fine for composition - cb = b ∘ ib - @test cb(x) ≈ x + # inverse works fine for composition + cb = b ∘ ib + @test cb(x) ≈ x - cb2 = cb ∘ cb - @test cb(x) ≈ x + cb2 = cb ∘ cb + @test cb(x) ≈ x - # ensures that the `logabsdetjac` is correct - x = rand(d) - b = inv(bijector(d)) - @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) + # ensures that the `logabsdetjac` is correct + x = rand(d) + b = inv(bijector(d)) + @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) - # order of composed evaluation - b1 = DistributionBijector(d) - b2 = DistributionBijector(Gamma()) + # order of composed evaluation + b1 = DistributionBijector(d) + b2 = DistributionBijector(Gamma()) - cb = b1 ∘ b2 - @test cb(x) ≈ b1(b2(x)) + cb = b1 ∘ b2 + @test cb(x) ≈ b1(b2(x)) - # contrived example - b = bijector(d) - cb = inv(b) ∘ b - cb = cb ∘ cb - @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x - - # forward for tuple and array - d = Beta() - b = inv(bijector(d)) - b⁻¹ = inv(b) - x = rand(d) + # contrived example + b = bijector(d) + cb = inv(b) ∘ b + cb = cb ∘ cb + @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x - cb_t = b⁻¹ ∘ b⁻¹ - f_t = forward(cb_t, x) + # forward for tuple and array + d = Beta() + b = 
inv(bijector(d)) + b⁻¹ = inv(b) + x = rand(d) - cb_a = Composed([b⁻¹, b⁻¹]) - f_a = forward(cb_a, x) + cb_t = b⁻¹ ∘ b⁻¹ + f_t = forward(cb_t, x) - @test f_t == f_a - end + cb_a = Composed([b⁻¹, b⁻¹]) + f_a = forward(cb_a, x) - @testset "Stacked <: Bijector" begin - # `logabsdetjac` without AD - d = Beta() - b = bijector(d) - x = rand(d) - y = b(x) - sb = vcat(b, b, inv(b), inv(b)) - @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 - - # `logabsdetjac` with AD - b = DistributionBijector(d) - y = b(x) - sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple - sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array - @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 - @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 - - @testset "Stacked: ADVI with MvNormal" begin - # MvNormal test - dists = [ - Beta(), - Beta(), - Beta(), - InverseGamma(), - InverseGamma(), - Gamma(), - Gamma(), - InverseGamma(), - Cauchy(), - Gamma(), - MvNormal(zeros(2), ones(2)) - ] - - ranges = [] - idx = 1 - for i = 1:length(dists) - d = dists[i] - push!(ranges, idx:idx + length(d) - 1) - idx += length(d) - end + @test f_t == f_a +end - num_params = ranges[end][end] - d = MvNormal(zeros(num_params), ones(num_params)) - - # Stacked{<:Array} - bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists - ibs = inv.(bs) # invert, so we get unconstrained-to-constrained - sb = Stacked(ibs, ranges) # => Stacked <: Bijector - x = rand(d) - sb(x) - @test sb isa Stacked - - td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} - @test td isa Distribution{Multivariate, Continuous} - - # check that wrong ranges fails - sb = vcat(ibs...) - td = transformed(d, sb) - x = rand(d) - @test_throws AssertionError sb(x) - - # Stacked{<:Tuple} - bs = bijector.(tuple(dists...)) - ibs = inv.(bs) - sb = Stacked(ibs, ranges) - isb = inv(sb) - @test sb isa Stacked{<: Tuple} - - # inverse - td = transformed(d, sb) - y = rand(td) - x = isb(y) - @test sb(x) ≈ y - - # verification of computation - x = rand(d) - y = sb(x) - y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) - x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) 
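The AD checks just below compare `logabsdetjac` of the stacked bijector against a ForwardDiff Jacobian; the same check for a single scalar bijector looks like this (sketch):

```
using Bijectors, ForwardDiff

b = inv(bijector(Beta()))    # ℝ → (0, 1)
x = randn()
log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x)
```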
- @test x ≈ x_ - @test y ≈ y_ - - # AD verification - @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) - @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) +@testset "Stacked <: Bijector" begin + # `logabsdetjac` without AD + d = Beta() + b = bijector(d) + x = rand(d) + y = b(x) + sb = vcat(b, b, inv(b), inv(b)) + @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 + + # `logabsdetjac` with AD + b = DistributionBijector(d) + y = b(x) + sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple + sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 + @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 + + # stuff + x = ones(3) + sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) + @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] + @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) + + @testset "Stacked: ADVI with MvNormal" begin + # MvNormal test + dists = [ + Beta(), + Beta(), + Beta(), + InverseGamma(), + InverseGamma(), + Gamma(), + Gamma(), + InverseGamma(), + Cauchy(), + Gamma(), + MvNormal(zeros(2), ones(2)) + ] + + ranges = [] + idx = 1 + for i = 1:length(dists) + d = dists[i] + push!(ranges, idx:idx + length(d) - 1) + idx += length(d) end - end - @testset "Example: ADVI single" begin - # Usage in ADVI - d = Beta() - b = DistributionBijector(d) # [0, 1] → ℝ - ib = inv(b) # ℝ → [0, 1] - td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] - x = rand(td) # ∈ [0, 1] - @test 0 ≤ x ≤ 1 + num_params = ranges[end][end] + d = MvNormal(zeros(num_params), ones(num_params)) + + # Stacked{<:Array} + bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists + ibs = inv.(bs) # invert, so we get unconstrained-to-constrained + sb = Stacked(ibs, ranges) # => Stacked <: Bijector + x = rand(d) + sb(x) + @test sb isa Stacked + + td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} + @test td isa Distribution{Multivariate, Continuous} + + # check that wrong ranges fails + sb = vcat(ibs...) + td = transformed(d, sb) + x = rand(d) + @test_throws AssertionError sb(x) + + # Stacked{<:Tuple} + bs = bijector.(tuple(dists...)) + ibs = inv.(bs) + sb = Stacked(ibs, ranges) + isb = inv(sb) + @test sb isa Stacked{<: Tuple} + + # inverse + td = transformed(d, sb) + y = rand(td) + x = isb(y) + @test sb(x) ≈ y + + # verification of computation + x = rand(d) + y = sb(x) + y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) + x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) 
+ @test x ≈ x_ + @test y ≈ y_ + + # AD verification + @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) + @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) end end + +@testset "Example: ADVI single" begin + # Usage in ADVI + d = Beta() + b = DistributionBijector(d) # [0, 1] → ℝ + ib = inv(b) # ℝ → [0, 1] + td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] + x = rand(td) # ∈ [0, 1] + @test 0 ≤ x ≤ 1 +end +end From 353e80c9fcc36f11cbf237234aa616bfb6b68f19 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 13:02:34 +0200 Subject: [PATCH 25/89] fixed some left-over stuff from a merge --- test/interface.jl | 344 ++++++++++++++++++++++------------------------ 1 file changed, 162 insertions(+), 182 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 58b16b96..0c8746b9 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -196,215 +196,195 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end end end -@testset "Matrix variate" begin - v = 7.0 - S = Matrix(1.0I, 2, 2) - S[1, 2] = S[2, 1] = 0.5 + @testset "Matrix variate" begin + v = 7.0 + S = Matrix(1.0I, 2, 2) + S[1, 2] = S[2, 1] = 0.5 + + matrix_dists = [ + Wishart(v,S), + InverseWishart(v,S) + ] - matrix_dists = [ - Wishart(v,S), - InverseWishart(v,S) - ] + for dist in matrix_dists + @testset "$dist: dist" begin + td = transformed(dist) - for dist in matrix_dists - @testset "$dist: dist" begin - td = transformed(dist) + # single sample + y = rand(td) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - # single sample - y = rand(td) - x = inv(td.transform)(y) - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) - - # TODO: implement `logabsdetjac` for these - # logpdf_with_jac - # lp, logjac = logpdf_with_jac(td, y) - # @test lp ≈ logpdf(td, y) - # @test logjac ≈ logabsdetjacinv(td.transform, y) - - # multi-sample - y = rand(td, 10) - x = inv(td.transform)(y) - @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + # TODO: implement `logabsdetjac` for these + # logpdf_with_jac + # lp, logjac = logpdf_with_jac(td, y) + # @test lp ≈ logpdf(td, y) + # @test logjac ≈ logabsdetjacinv(td.transform, y) + + # multi-sample + y = rand(td, 10) + x = inv(td.transform)(y) + @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) + end end end -end -@testset "Composition <: Bijector" begin - d = Beta() - td = transformed(d) + @testset "Composition <: Bijector" begin + d = Beta() + td = transformed(d) - x = rand(d) - y = td.transform(x) + x = rand(d) + y = td.transform(x) - b = Bijectors.composel(td.transform, Bijectors.Identity()) - ib = inv(b) + b = Bijectors.composel(td.transform, Bijectors.Identity()) + ib = inv(b) - @test forward(b, x) == forward(td.transform, x) - @test forward(ib, y) == forward(inv(td.transform), y) + @test forward(b, x) == forward(td.transform, x) + @test forward(ib, y) == forward(inv(td.transform), y) - # inverse works fine for composition - cb = b ∘ ib - @test cb(x) ≈ x + # inverse works fine for composition + cb = b ∘ ib + @test cb(x) ≈ x - cb2 = cb ∘ cb - @test cb(x) ≈ x + cb2 = cb ∘ cb + @test cb(x) ≈ x - # ensures that the `logabsdetjac` is correct - x = rand(d) - b = inv(bijector(d)) - @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) + # ensures that the `logabsdetjac` is correct + x = rand(d) + b = inv(bijector(d)) + @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) -<<<<<<< HEAD - # order of composed evaluation - b1 = 
DistributionBijector(d) - b2 = DistributionBijector(Gamma()) -======= @test forward(b, x) == forward(Bijectors.composer(b.ts...), x) # inverse works fine for composition cb = b ∘ ib @test cb(x) ≈ x ->>>>>>> master - - cb = b1 ∘ b2 - @test cb(x) ≈ b1(b2(x)) - - # contrived example - b = bijector(d) - cb = inv(b) ∘ b - cb = cb ∘ cb - @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x - - # forward for tuple and array - d = Beta() - b = inv(bijector(d)) - b⁻¹ = inv(b) - x = rand(d) - - cb_t = b⁻¹ ∘ b⁻¹ - f_t = forward(cb_t, x) - cb_a = Composed([b⁻¹, b⁻¹]) - f_a = forward(cb_a, x) + cb = b1 ∘ b2 + @test cb(x) ≈ b1(b2(x)) - @test f_t == f_a -end - -@testset "Stacked <: Bijector" begin - # `logabsdetjac` without AD - d = Beta() - b = bijector(d) - x = rand(d) - y = b(x) - sb = vcat(b, b, inv(b), inv(b)) - @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 - - # `logabsdetjac` with AD - b = DistributionBijector(d) - y = b(x) - sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple - sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array - @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 - @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 - - # stuff - x = ones(3) - sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) - @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] - @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) - - @testset "Stacked: ADVI with MvNormal" begin - # MvNormal test - dists = [ - Beta(), - Beta(), - Beta(), - InverseGamma(), - InverseGamma(), - Gamma(), - Gamma(), - InverseGamma(), - Cauchy(), - Gamma(), - MvNormal(zeros(2), ones(2)) - ] - -<<<<<<< HEAD - ranges = [] - idx = 1 - for i = 1:length(dists) - d = dists[i] - push!(ranges, idx:idx + length(d) - 1) - idx += length(d) - end - - num_params = ranges[end][end] - d = MvNormal(zeros(num_params), ones(num_params)) - - # Stacked{<:Array} - bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists - ibs = inv.(bs) # invert, so we get unconstrained-to-constrained - sb = Stacked(ibs, ranges) # => Stacked <: Bijector + # contrived example + b = bijector(d) + cb = inv(b) ∘ b + cb = cb ∘ cb + @test (cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x + + # forward for tuple and array + d = Beta() + b = inv(bijector(d)) + b⁻¹ = inv(b) x = rand(d) - sb(x) - @test sb isa Stacked - td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} - @test td isa Distribution{Multivariate, Continuous} -======= - @test f_t == f_a + cb_t = b⁻¹ ∘ b⁻¹ + f_t = forward(cb_t, x) - # `composer` and `composel` - cb_l = Bijectors.composel(b⁻¹, b⁻¹, b) - cb_r = Bijectors.composer(reverse(cb_l.ts)...) - y = cb_l(x) - @test y == Bijectors.composel(cb_r.ts...)(x) + cb_a = Composed([b⁻¹, b⁻¹]) + f_a = forward(cb_a, x) - k = length(cb_l.ts) - @test all([cb_l.ts[i] == cb_r.ts[i] for i = 1:k]) + @test f_t == f_a end ->>>>>>> master - # check that wrong ranges fails - sb = vcat(ibs...) - td = transformed(d, sb) - x = rand(d) - @test_throws AssertionError sb(x) - - # Stacked{<:Tuple} - bs = bijector.(tuple(dists...)) - ibs = inv.(bs) - sb = Stacked(ibs, ranges) - isb = inv(sb) - @test sb isa Stacked{<: Tuple} - - # inverse - td = transformed(d, sb) - y = rand(td) - x = isb(y) - @test sb(x) ≈ y - - # verification of computation + @testset "Stacked <: Bijector" begin + # `logabsdetjac` without AD + d = Beta() + b = bijector(d) x = rand(d) - y = sb(x) - y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) - x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) 
- @test x ≈ x_ - @test y ≈ y_ - - # AD verification - @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) - @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) + y = b(x) + sb = vcat(b, b, inv(b), inv(b)) + @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 + + # `logabsdetjac` with AD + b = DistributionBijector(d) + y = b(x) + sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple + sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 + @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 + + # stuff + x = ones(3) + sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) + @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] + @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) + + @testset "Stacked: ADVI with MvNormal" begin + # MvNormal test + dists = [ + Beta(), + Beta(), + Beta(), + InverseGamma(), + InverseGamma(), + Gamma(), + Gamma(), + InverseGamma(), + Cauchy(), + Gamma(), + MvNormal(zeros(2), ones(2)) + ] + + ranges = [] + idx = 1 + for i = 1:length(dists) + d = dists[i] + push!(ranges, idx:idx + length(d) - 1) + idx += length(d) + end + + num_params = ranges[end][end] + d = MvNormal(zeros(num_params), ones(num_params)) + + # Stacked{<:Array} + bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists + ibs = inv.(bs) # invert, so we get unconstrained-to-constrained + sb = Stacked(ibs, ranges) # => Stacked <: Bijector + x = rand(d) + sb(x) + @test sb isa Stacked + + td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} + @test td isa Distribution{Multivariate, Continuous} + + # check that wrong ranges fails + sb = vcat(ibs...) + td = transformed(d, sb) + x = rand(d) + @test_throws AssertionError sb(x) + + # Stacked{<:Tuple} + bs = bijector.(tuple(dists...)) + ibs = inv.(bs) + sb = Stacked(ibs, ranges) + isb = inv(sb) + @test sb isa Stacked{<: Tuple} + + # inverse + td = transformed(d, sb) + y = rand(td) + x = isb(y) + @test sb(x) ≈ y + + # verification of computation + x = rand(d) + y = sb(x) + y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) + x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) 
+ @test x ≈ x_ + @test y ≈ y_ + + # AD verification + @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) + @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) + end end -end -@testset "Example: ADVI single" begin - # Usage in ADVI - d = Beta() - b = DistributionBijector(d) # [0, 1] → ℝ - ib = inv(b) # ℝ → [0, 1] - td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] - x = rand(td) # ∈ [0, 1] - @test 0 ≤ x ≤ 1 -end + @testset "Example: ADVI single" begin + # Usage in ADVI + d = Beta() + b = DistributionBijector(d) # [0, 1] → ℝ + ib = inv(b) # ℝ → [0, 1] + td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] + x = rand(td) # ∈ [0, 1] + @test 0 ≤ x ≤ 1 + end end From 98860a5f012dcb3e0c1daa7be0ed75c9cdb02673 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 13:33:00 +0200 Subject: [PATCH 26/89] added comments convincing myself that entropy is invariant --- src/interface.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 26bda9f7..446ec27c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -831,6 +831,16 @@ forward(d::Distribution, num_samples::Int) = forward(GLOBAL_RNG, d, num_samples) # utility stuff params(td::Transformed) = params(td.dist) + +# ℍ(p̃(y)) +# = ∫ p̃(y) log p̃(y) dy +# = ∫ p(f⁻¹(y)) |det J(f⁻¹, y)| log (p(f⁻¹(y)) |det J(f⁻¹, y)|) dy +# = ∫ p(x) (log p(x) |det J(f⁻¹, f(x))|) dx +# = ∫ p(x) (log p(x) |det J(f⁻¹ ∘ f, x)|) dx +# = ∫ p(x) log (p(x) |det J(id, x)|) dx +# = ∫ p(x) log (p(x) ⋅ 1) dx +# = ∫ p(x) log p(x) dx +# = ℍ(p(x)) entropy(td::Transformed) = entropy(td.dist) # logabsdetjac for distributions From 6e7a174099082ea12d5b313572c9b571a62d32f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 13:35:37 +0200 Subject: [PATCH 27/89] removed redundant comment --- src/interface.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 446ec27c..723b4ea0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -847,9 +847,7 @@ entropy(td::Transformed) = entropy(td.dist) logabsdetjacinv(d::UnivariateDistribution, x::T) where T <: Real = zero(T) logabsdetjacinv(d::MultivariateDistribution, x::AbstractVector{T}) where {T<:Real} = zero(T) -# for transformed distributions the `y` is going to be the transformed variable -# and so we use the inverse transform to get what we want -# TODO: should this be renamed to `logabsdetinvjac`? 
+ """ logabsdetjacinv(td::UnivariateTransformed, y::Real) logabsdetjacinv(td::MultivariateTransformed, y::AbstractVector{<:Real}) From b63fb2951f95a82b7705265f90b83028b663d633 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 13:42:06 +0200 Subject: [PATCH 28/89] fixed tests which had be messed up in the merge --- test/interface.jl | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 0c8746b9..7ec86282 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -242,6 +242,8 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test forward(b, x) == forward(td.transform, x) @test forward(ib, y) == forward(inv(td.transform), y) + @test forward(b, x) == forward(Bijectors.composer(b.ts...), x) + # inverse works fine for composition cb = b ∘ ib @test cb(x) ≈ x @@ -254,11 +256,9 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end b = inv(bijector(d)) @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) - @test forward(b, x) == forward(Bijectors.composer(b.ts...), x) - - # inverse works fine for composition - cb = b ∘ ib - @test cb(x) ≈ x + # order of composed evaluation + b1 = DistributionBijector(d) + b2 = DistributionBijector(Gamma()) cb = b1 ∘ b2 @test cb(x) ≈ b1(b2(x)) @@ -282,6 +282,15 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end f_a = forward(cb_a, x) @test f_t == f_a + + # `composer` and `composel` + cb_l = Bijectors.composel(b⁻¹, b⁻¹, b) + cb_r = Bijectors.composer(reverse(cb_l.ts)...) + y = cb_l(x) + @test y == Bijectors.composel(cb_r.ts...)(x) + + k = length(cb_l.ts) + @test all([cb_l.ts[i] == cb_r.ts[i] for i = 1:k]) end @testset "Stacked <: Bijector" begin From edaed9a7e279cb530102c7f4b98168b6c473bd24 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 13:53:27 +0200 Subject: [PATCH 29/89] added message for size-discrepancy in Stacked and tests --- src/interface.jl | 4 ++-- test/interface.jl | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 723b4ea0..d3524d36 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -304,7 +304,7 @@ function (sb::Stacked)(x::AbstractArray{<: Real}) y = _transform(x, sb.ranges, sb.bs...) # TODO: maybe tell user to check their ranges? - @assert size(y) == size(x) + @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" return y end @@ -417,7 +417,7 @@ logabsdetjac(b::Scale, x) = log(abs(b.a)) # Simplex bijector # #################### struct SimplexBijector{T} <: Bijector where {T} end -SimplexBijector(proj::Bool) = SimplexBijector{Val(proj)}() +SimplexBijector(proj::Bool) = SimplexBijector{Val{proj}}() SimplexBijector() = SimplexBijector(true) const simplex_b = SimplexBijector{Val{false}}() diff --git a/test/interface.jl b/test/interface.jl index 7ec86282..207f6329 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -310,12 +310,16 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 - # stuff + # value-test x = ones(3) sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) + # TODO: change when we have dimensionality in the type + sb = vcat([Bijectors.Exp(), Bijectors.SimplexBijector()]...) 
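# A minimal sketch — not part of the original test — of what the new size assertion
# above catches, assuming the `Stacked` behaviour defined earlier and `Bijectors` loaded:
# with the default ranges (1:1, 2:2), extra entries of the input are silently dropped,
# so the output no longer has the same length as the input.
sb_bad = Stacked((Bijectors.Exp(), Bijectors.Exp()))  # default ranges (1:1, 2:2)
sb_bad(ones(2))  # fine: output has the same length as the input
sb_bad(ones(3))  # now throws AssertionError("x is size (3,) but y is (2,)")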
+ @test_throws AssertionError sb(x) + @testset "Stacked: ADVI with MvNormal" begin # MvNormal test dists = [ From 976ddf6aab0fbf6ee66a0f0848932d7e024dd167 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 14:06:52 +0200 Subject: [PATCH 30/89] remvoed reexport of StatsBase --- src/Bijectors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 14b2cc58..f8876ea1 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -2,7 +2,6 @@ module Bijectors using Reexport, Requires @reexport using Distributions -@reexport using StatsBase using StatsFuns using LinearAlgebra using MappedArrays @@ -33,10 +32,11 @@ export TransformDistribution, TransformedDistribution, UnivariateTransformed, MultivariateTransformed, + entropy, logpdf_with_jac, logpdf_forward, PlanarLayer, - RadialLayer + RadialLayer, const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) From dd6ec2aedb8230a55389830f0b813154aab70844 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 14:16:35 +0200 Subject: [PATCH 31/89] fixed a couple of typos --- src/Bijectors.jl | 2 +- test/interface.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index f8876ea1..c4b87878 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,7 +36,7 @@ export TransformDistribution, logpdf_with_jac, logpdf_forward, PlanarLayer, - RadialLayer, + RadialLayer const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) diff --git a/test/interface.jl b/test/interface.jl index 207f6329..f2175080 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -318,7 +318,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end # TODO: change when we have dimensionality in the type sb = vcat([Bijectors.Exp(), Bijectors.SimplexBijector()]...) - @test_throws AssertionError sb(x) + @test_throws AssertionError sb(x ./ sum(x)) @testset "Stacked: ADVI with MvNormal" begin # MvNormal test From c38f203f501abf58971182dc880d612a0bfb00b3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 15:28:46 +0200 Subject: [PATCH 32/89] checking something with CI --- test/interface.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/interface.jl b/test/interface.jl index f2175080..acc90723 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -318,6 +318,8 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end # TODO: change when we have dimensionality in the type sb = vcat([Bijectors.Exp(), Bijectors.SimplexBijector()]...) + println(sb) + println(x) @test_throws AssertionError sb(x ./ sum(x)) @testset "Stacked: ADVI with MvNormal" begin From f5d68eb1bd3b890a73b88f9f940289ff8e7b5bb3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 15:39:53 +0200 Subject: [PATCH 33/89] more CI test stuff --- test/interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/interface.jl b/test/interface.jl index acc90723..27bbc6c3 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -320,6 +320,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end sb = vcat([Bijectors.Exp(), Bijectors.SimplexBijector()]...) 
println(sb) println(x) + sb(x ./ sum(x)) @test_throws AssertionError sb(x ./ sum(x)) @testset "Stacked: ADVI with MvNormal" begin From 898ae50dad330ea04187aeae9f3df02f78c5bfae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:03:51 +0200 Subject: [PATCH 34/89] fixed the test which failed on CI --- test/interface.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 27bbc6c3..ef71ef85 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -317,10 +317,8 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) # TODO: change when we have dimensionality in the type - sb = vcat([Bijectors.Exp(), Bijectors.SimplexBijector()]...) - println(sb) - println(x) - sb(x ./ sum(x)) + x = ones(4) + sb = vcat(Bijectors.Exp(), Bijectors.SimplexBijector()...) @test_throws AssertionError sb(x ./ sum(x)) @testset "Stacked: ADVI with MvNormal" begin From ddc8cd7650a92d33608d46ec6afff849e3761b32 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:05:01 +0200 Subject: [PATCH 35/89] okay now I fixed it --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index ef71ef85..6192cfc9 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -318,7 +318,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end # TODO: change when we have dimensionality in the type x = ones(4) - sb = vcat(Bijectors.Exp(), Bijectors.SimplexBijector()...) + sb = vcat(Bijectors.Exp(), Bijectors.SimplexBijector()) @test_throws AssertionError sb(x ./ sum(x)) @testset "Stacked: ADVI with MvNormal" begin From 8b4eecb9942123b5dd9337455bd97d6b50a83c90 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:28:54 +0200 Subject: [PATCH 36/89] okay NOW i fixed the test --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index 6192cfc9..ca9aa893 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -318,7 +318,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end # TODO: change when we have dimensionality in the type x = ones(4) - sb = vcat(Bijectors.Exp(), Bijectors.SimplexBijector()) + sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) @test_throws AssertionError sb(x ./ sum(x)) @testset "Stacked: ADVI with MvNormal" begin From 3be1af3269cac3172c0a3eb6eca25b421d7c1876 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:29:00 +0200 Subject: [PATCH 37/89] added asserts to evaluation of SimplexBijector to catch length-1 --- src/interface.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index d3524d36..16be93e4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -434,6 +434,7 @@ end function (b::SimplexBijector{Val{proj}})(x::AbstractVector{T}) where {T, proj} y, K = similar(x), length(x) + @assert K > 1 ϵ = _eps(T) sum_tmp = zero(T) @@ -459,6 +460,7 @@ end # Vectorised implementation of the above. 
function (b::SimplexBijector{Val{proj}})(X::AbstractMatrix{T}) where {T<:Real, proj} Y, K, N = similar(X), size(X, 1), size(X, 2) + @assert K > 1 ϵ = _eps(T) @inbounds @simd for n in 1:size(X, 2) @@ -483,6 +485,7 @@ end function (ib::Inversed{<:SimplexBijector{Val{proj}}})(y::AbstractVector{T}) where {T, proj} x, K = similar(y), length(y) + @assert K > 1 ϵ = _eps(T) @inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1))) @@ -508,6 +511,7 @@ function (ib::Inversed{<:SimplexBijector{Val{proj}}})( Y::AbstractMatrix{T} ) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) + @assert K > 1 ϵ = _eps(T) @inbounds @simd for n in 1:size(X, 2) From 6e4f3c82a88d693cee6fdb2af43f55388a8848e8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:48:14 +0200 Subject: [PATCH 38/89] made Stacked constructor more strict as it should be --- src/interface.jl | 17 ++++++++++++++++- test/interface.jl | 3 +++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index 16be93e4..86e46465 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -281,9 +281,24 @@ b([0.0, 1.0]) == [b1(0.0), 1.0] # => true struct Stacked{B, N} <: Bijector where N bs::B ranges::NTuple{N, UnitRange{Int}} + + function Stacked( + bs::C, + ranges::NTuple{N, UnitRange{Int}} + ) where {N, C<:Tuple{Vararg{<:Bijector, N}}} + return new{C, N}(bs, ranges) + end + + function Stacked( + bs::A, + ranges::NTuple{N, UnitRange{Int}} + ) where {N, A<:AbstractArray{<:Bijector}} + @assert length(bs) == N "number of bijectors is not same as number of ranges" + return new{A, N}(bs, ranges) + end end Stacked(bs) = Stacked(bs, NTuple{length(bs), UnitRange{Int}}([i:i for i = 1:length(bs)])) -Stacked(bs, ranges) = Stacked(bs, NTuple{length(bs), UnitRange{Int}}(ranges)) +Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...)) Base.vcat(bs::Bijector...) = Stacked(bs) diff --git a/test/interface.jl b/test/interface.jl index ca9aa893..c0038994 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -321,6 +321,9 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) @test_throws AssertionError sb(x ./ sum(x)) + @test_throws AssertionError Stacked([Bijectors.Exp(), ], (1:1, 2:3)) + @test_throws MethodError Stacked((Bijectors.Exp(), ), (1:1, 2:3)) + @testset "Stacked: ADVI with MvNormal" begin # MvNormal test dists = [ From 8a925b2ab4403b99169b34c892ffb0d7d5e43955 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:49:22 +0200 Subject: [PATCH 39/89] added messages to assertions for SimplexBijector --- src/interface.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 86e46465..a5231718 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -449,7 +449,7 @@ end function (b::SimplexBijector{Val{proj}})(x::AbstractVector{T}) where {T, proj} y, K = similar(x), length(x) - @assert K > 1 + @assert K > 1 "x needs to be of length greater than 1" ϵ = _eps(T) sum_tmp = zero(T) @@ -475,7 +475,7 @@ end # Vectorised implementation of the above. 
function (b::SimplexBijector{Val{proj}})(X::AbstractMatrix{T}) where {T<:Real, proj} Y, K, N = similar(X), size(X, 1), size(X, 2) - @assert K > 1 + @assert K > 1 "x needs to be of length greater than 1" ϵ = _eps(T) @inbounds @simd for n in 1:size(X, 2) @@ -500,7 +500,7 @@ end function (ib::Inversed{<:SimplexBijector{Val{proj}}})(y::AbstractVector{T}) where {T, proj} x, K = similar(y), length(y) - @assert K > 1 + @assert K > 1 "x needs to be of length greater than 1" ϵ = _eps(T) @inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1))) @@ -526,7 +526,7 @@ function (ib::Inversed{<:SimplexBijector{Val{proj}}})( Y::AbstractMatrix{T} ) where {T<:Real, proj} X, K, N = similar(Y), size(Y, 1), size(Y, 2) - @assert K > 1 + @assert K > 1 "x needs to be of length greater than 1" ϵ = _eps(T) @inbounds @simd for n in 1:size(X, 2) From ea2eddeeb1ae1c1375d6109dda23302083ce6fcc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 16:57:10 +0200 Subject: [PATCH 40/89] updated Manifest --- Manifest.toml | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index cc6f7d86..26258388 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -85,10 +85,10 @@ deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[Distributions]] -deps = ["Distributed", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "022e6610c320b6e19b454502d759c672580abe00" +deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "baaf9e165ba8a2d11fb4fb3511782ee070ee3694" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.18.0" +version = "0.21.1" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] @@ -130,10 +130,9 @@ deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[Missings]] -deps = ["SparseArrays", "Test"] -git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007" +git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.1" +version = "0.4.2" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -158,9 +157,9 @@ version = "1.1.0" [[PDMats]] deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde" +git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.9.9" +version = "0.9.10" [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -232,10 +231,10 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["BinDeps", "BinaryProvider", "Libdl"] -git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e" +deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] +git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.8.0" +version = "0.7.2" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] @@ -280,9 +279,9 @@ version = "0.5.6" [[Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", 
"Test"] -git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b" +git-tree-sha1 = "1aa443d3b4bfa91a8aec32f169a479cb87309910" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.2" +version = "0.2.3" [[URIParser]] deps = ["Test", "Unicode"] From a5af0fc82cd08006cf303dbf90c2cfa575952bc2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 17:45:47 +0200 Subject: [PATCH 41/89] updated README to include info about Stacked --- README.md | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/README.md b/README.md index b2e2cfc4..ba142a93 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,82 @@ inv(td.transform)(rand(td)) will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._ +### Multivariate ADVI example +We can also do _multivariate_ ADVI using the `Stacked` bijector. `Stacked` gives us a way to combine univariate and/or multivariate bijectors into a singe multivariate bijector. Say you have a vector `x` of length 2 and you want to transform the first entry using `Exp` and the second entry using `Log`. `Stacked` gives you an easy and efficient way of representing such a bijector. + +```julia +julia> using Bijectors + +julia> using Bijectors: Exp, Log, SimplexBijector + +julia> # Original distributions + dists = ( + Beta(), + InverseGamma(), + Dirichlet(2, 3) + ); + +julia> # Construct the corresponding ranges + ranges = []; + +julia> idx = 1; + +julia> for i = 1:length(dists) + d = dists[i] + push!(ranges, idx:idx + length(d) - 1) + + global idx + idx += length(d) + end; + +julia> ranges +3-element Array{Any,1}: + 1:1 + 2:2 + 3:4 + +julia> # Base distribution; mean-field normal + num_params = ranges[end][end] +4 + +julia> d = MvNormal(zeros(num_params), ones(num_params)) +DiagNormal( +dim: 4 +μ: [0.0, 0.0, 0.0, 0.0] +Σ: [1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.0 1.0 0.0; 0.0 0.0 0.0 1.0] +) + + +julia> # Construct the transform + bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists +(Logit{Float64}(0.0, 1.0), Log(), SimplexBijector{Val{true}}()) + +julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained +(Inversed{Logit{Float64}}(Logit{Float64}(0.0, 1.0)), Exp(), Inversed{SimplexBijector{Val{true}}}(SimplexBijector{Val{true}}())) + +julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector +Stacked{Tuple{Inversed{Logit{Float64}},Exp,Inversed{SimplexBijector{Val{true}}}},3}((Inversed{Logit{Float64}}(Logit{Float64}(0.0, 1.0)), Exp(), Inversed{SimplexBijector{Val{true}}}(SimplexBijector{Val{true}}())), (1:1, 2:2, 3:4)) + +julia> # Mean-field normal with unconstrained-to-constrained stacked bijector + td = transformed(d, sb); + +julia> y = rand(td) +4-element Array{Float64,1}: + 0.33551575658457006 + 0.12139631354191643 + 0.3900060432982573 + 0.6099939567017427 + +julia> 0.0 ≤ y[1] ≤ 1.0 # => true +true + +julia> 0.0 < y[2] # => true +true + +julia> sum(y[3:4]) ≈ 1.0 # => true +true +``` + ### Normalizing flows A very interesting application is that of _normalizing flows_.[1] Usually this is done by sampling from a multivariate normal distribution, and then transforming this to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarFlow` and `RadialFlow`. 
Let's create a flow with a single `PlanarLayer`:
@@ -285,6 +361,33 @@ julia> logpdf_forward(flow, x) # more efficent and accurate
 -2.2489445532797867
 ```
 
+Similarly to the multivariate ADVI example, we could use `Stacked` to get a _bounded_ flow:
+
+```julia
+julia> d = MvNormal(zeros(2), ones(2));
+
+julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));
+
+julia> sb = vcat(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)])
+Stacked{Tuple{Exp,Inversed{Logit{Float64}}},2}((Exp(), Inversed{Logit{Float64}}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
+
+julia> b = sb ∘ PlanarLayer(2)
+Composed{Tuple{PlanarLayer{Array{Float64,2},Array{Float64,1}},Stacked{Tuple{Exp,Inversed{Logit{Float64}}},2}}}((PlanarLayer{Array{Float64,2},Array{Float64,1}}([-2.00615; 1.17336], [0.248405; -0.319774], [0.481679]), Stacked{Tuple{Exp,Inversed{Logit{Float64}}},2}((Exp(), Inversed{Logit{Float64}}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))))
+
+julia> td = transformed(d, b);
+
+julia> y = rand(td)
+2-element Array{Float64,1}:
+ 1.026123210859092
+ 0.4412529471603579
+
+julia> 0 < y[1]
+true
+
+julia> 0 ≤ y[2] ≤ 1
+true
+```
+
 Want to fit the flow?
 
 ```julia
@@ -349,6 +452,10 @@ julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns a
 
 This method is for example useful when computing quantities such as the _expected lower bound (ELBO)_ between this transformed distribution and some other joint density. If no analytical expression is available, we have to approximate the ELBO by a Monte Carlo estimate. But one term in the ELBO is the entropy of the base density, which we _do_ know analytically in this case. Using the analytical expression for the entropy and then using a monte carlo estimate for the rest of the terms in the ELBO gives an estimate with lower variance than if we used the monte carlo estimate for the entire expectation.
+
+### Normalizing flows with bounded support
+
+
 ## Implementing your own `Bijector`
 There's mainly two ways you can implement your own `Bijector`, and which way you choose mainly depends on the following question: are you bothered enough to manually implement `logabsdetjac`? If the answer is "Yup!", then you subtype from `Bijector`, if "Naaaah" then you subtype `ADBijector`.

From 7ff118dc5616a9f6d5d773a364ac1e09e4ecdc60 Mon Sep 17 00:00:00 2001
From: Tor Erlend Fjelde
Date: Mon, 16 Sep 2019 17:47:17 +0200
Subject: [PATCH 42/89] added Stacked to reference-section

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index ba142a93..303a015f 100644
--- a/README.md
+++ b/README.md
@@ -634,6 +634,7 @@ The following are the bijectors available:
 - `ADBijector{AD} <: Bijector`: subtypes of this only require the user to implement `(b::UserBijector)(x)` and `(ib::Inversed{<:UserBijector})(y)`. Automatic differentation will be used to compute the `jacobian(b, x)` and thus `logabsdetjac(b, x).
 - Concrete:
   - `Composed`: represents a composition of bijectors.
+  - `Stacked`: stacks univariate and multivariate bijectors
   - `Identity`: does what it says, i.e. nothing. 
- `Logit` - `Exp` From 2707443d2f82aeee9ae1dfb84281f369f7229469 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 18:53:58 +0200 Subject: [PATCH 43/89] removed useless line --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index c0038994..9b7452ad 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -356,7 +356,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end ibs = inv.(bs) # invert, so we get unconstrained-to-constrained sb = Stacked(ibs, ranges) # => Stacked <: Bijector x = rand(d) - sb(x) + @test sb isa Stacked td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} From 5a7a76ae7526c1f47c43ae2f1172f5af10dede9b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 20:50:57 +0200 Subject: [PATCH 44/89] added forward-specialization for Stacked --- src/interface.jl | 77 ++++++++++++++++++++++++++++++++++++----------- test/interface.jl | 36 +++++++++++++++++++--- 2 files changed, 92 insertions(+), 21 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index a5231718..0c0de410 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -305,17 +305,7 @@ Base.vcat(bs::Bijector...) = Stacked(bs) inv(sb::Stacked) = Stacked(inv.(sb.bs), sb.ranges) # TODO: Is there a better approach to this? -@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs::Bijector...) where N - exprs = [] - for i = 1:N - push!(exprs, :(bs[$i](x[rs[$i]]))) - end - - return :(vcat($(exprs...))) -end -_transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) = b(x) - -function (sb::Stacked)(x::AbstractArray{<: Real}) +function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) y = _transform(x, sb.ranges, sb.bs...) # TODO: maybe tell user to check their ranges? @@ -323,24 +313,33 @@ function (sb::Stacked)(x::AbstractArray{<: Real}) return y end +function (sb::Stacked{<:AbstractArray, N})(x::AbstractVector{<:Real}) where {N} + return vcat([sb.bs[i](x[sb.ranges[i]]) for i = 1:N]...) +end + (sb::Stacked)(x::AbstractMatrix{<: Real}) = hcat([sb(x[:, i]) for i = 1:size(x, 2)]...) function (sb::Stacked)(x::TrackedArray{A, 2}) where {A} return Tracker.collect(hcat([sb(x[:, i]) for i = 1:size(x, 2)]...)) end -@generated function _logabsdetjac( - x, - rs::NTuple{N, UnitRange{Int}}, - bs::Bijector... +@generated function logabsdetjac( + b::Stacked{<:Tuple, N}, + x::AbstractVector{<:Real} ) where {N} exprs = [] for i = 1:N - push!(exprs, :(sum(logabsdetjac(bs[$i], x[rs[$i]])))) + push!(exprs, :(sum(logabsdetjac(b.bs[$i], x[b.ranges[$i]])))) end return :(sum([$(exprs...), ])) end -logabsdetjac(b::Stacked, x::AbstractVector{<: Real}) = _logabsdetjac(x, b.ranges, b.bs...) +function logabsdetjac( + b::Stacked{<:AbstractArray, N}, + x::AbstractVector{<:Real} +) where {N} + # TODO: drop the `sum` when we have dimensionality + return sum([sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) for i = 1:N]) +end function logabsdetjac(b::Stacked, x::AbstractMatrix{<: Real}) return [logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)] end @@ -348,6 +347,50 @@ function logabsdetjac(b::Stacked, x::TrackedArray{A, 2}) where {A} return Tracker.collect([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)]) end +# Generates something similar to: +# +# quote +# (y_1, logjac) = forward(b.bs[1], x[b.ranges[1]]) +# (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) +# logjac += _logjac +# y = vcat(tuple(y_1, y_2)...) 
+# return (rv = y, logabsdetjac = logjac) +# end +@generated function forward(b::Stacked{T, N}, x::AbstractVector) where {N, T<:Tuple} + expr = Expr(:block) + e = Expr(:call, :tuple) + + push!(expr.args, :((y_1, logjac) = forward(b.bs[1], x[b.ranges[1]]))) + push!(e.args, :y_1) + for i = 2:length(T.parameters) + y_name = Symbol("y_$i") + push!(expr.args, :(($y_name, _logjac) = forward(b.bs[$i], x[b.ranges[$i]]))) + push!(expr.args, :(logjac += _logjac)) + + push!(e.args, y_name) + end + + push!(expr.args, :(y = vcat($e...))) + + # TODO: drop the `sum` when we have dimensionality + push!(expr.args, :(return (rv = y, logabsdetjac = sum(logjac)))) + + return expr +end + +function forward(sb::Stacked{<:AbstractArray, N}, x::AbstractVector) where {N} + ys = [] + logjacs = [] + for i = 1:N + y, logjac = forward(sb.bs[i], x[sb.ranges[i]]) + push!(ys, y) + # TODO: drop the `sum` when we have dimensionality + push!(logjacs, sum(logjac)) + end + + return (rv = vcat(ys...), logabsdetjac = sum(logjacs)) +end + ############################## # Example bijector: Identity # ############################## diff --git a/test/interface.jl b/test/interface.jl index 9b7452ad..c696458e 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -294,27 +294,55 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end end @testset "Stacked <: Bijector" begin - # `logabsdetjac` without AD + # `logabsdetjac` withOUT AD d = Beta() b = bijector(d) x = rand(d) y = b(x) - sb = vcat(b, b, inv(b), inv(b)) - @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 + + sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple + res1 = forward(sb1, [x, x, y, y]) + + @test sb1([x, x, y, y]) == res1.rv + @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 + @test res1.logabsdetjac ≈ 0.0 + + sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + res2 = forward(sb2, [x, x, y, y]) + + @test sb2([x, x, y, y]) == res2.rv + @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 + @test res2.logabsdetjac ≈ 0.0 # `logabsdetjac` with AD b = DistributionBijector(d) y = b(x) + sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple - sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + res1 = forward(sb1, [x, x, y, y]) + + @test sb1([x, x, y, y]) == res1.rv @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 + @test res1.logabsdetjac ≈ 0.0 + + sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + res2 = forward(sb2, [x, x, y, y]) + + @test sb2([x, x, y, y]) == res2.rv @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 + @test res2.logabsdetjac ≈ 0.0 + + @which logabsdetjac(sb2, [x, x, y, y]) # value-test x = ones(3) sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) + res = forward(sb, x) @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] + @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) + @test res.logabsdetjac == logabsdetjac(sb, x) + # TODO: change when we have dimensionality in the type x = ones(4) From 64070ea449114f6ac990add8a461e232b382150e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 20:54:02 +0200 Subject: [PATCH 45/89] removed something by accident last commit --- src/interface.jl | 10 ++++++++++ test/interface.jl | 2 -- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 0c0de410..bcc6c206 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -305,6 +305,16 @@ Base.vcat(bs::Bijector...) = Stacked(bs) inv(sb::Stacked) = Stacked(inv.(sb.bs), sb.ranges) # TODO: Is there a better approach to this? 
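# A rough guide to the generated function below (a sketch, not taken from the patch):
# for two bijectors the generated body expands to roughly
#     vcat(bs[1](x[rs[1]]), bs[2](x[rs[2]]))
# i.e. each bijector is applied to its own slice of `x` and the results are
# concatenated back into a single vector.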
+@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs::Bijector...) where N + exprs = [] + for i = 1:N + push!(exprs, :(bs[$i](x[rs[$i]]))) + end + + return :(vcat($(exprs...))) +end +_transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) = b(x) + function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) y = _transform(x, sb.ranges, sb.bs...) diff --git a/test/interface.jl b/test/interface.jl index c696458e..28b7a613 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -332,8 +332,6 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 @test res2.logabsdetjac ≈ 0.0 - @which logabsdetjac(sb2, [x, x, y, y]) - # value-test x = ones(3) sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) From 9318f95cbb35e5da9b12f7b4cd21badaa587cbff Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 21:06:54 +0200 Subject: [PATCH 46/89] fixed mixed Stacked and added tests --- src/interface.jl | 11 +++++++---- test/interface.jl | 13 ++++++++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index bcc6c206..5ac22930 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -370,20 +370,23 @@ end expr = Expr(:block) e = Expr(:call, :tuple) - push!(expr.args, :((y_1, logjac) = forward(b.bs[1], x[b.ranges[1]]))) + push!(expr.args, :((y_1, _logjac) = forward(b.bs[1], x[b.ranges[1]]))) + # TODO: drop the `sum` when we have dimensionality + push!(expr.args, :(logjac = sum(_logjac))) push!(e.args, :y_1) for i = 2:length(T.parameters) y_name = Symbol("y_$i") push!(expr.args, :(($y_name, _logjac) = forward(b.bs[$i], x[b.ranges[$i]]))) - push!(expr.args, :(logjac += _logjac)) + + # TODO: drop the `sum` when we have dimensionality + push!(expr.args, :(logjac += sum(_logjac))) push!(e.args, y_name) end push!(expr.args, :(y = vcat($e...))) - # TODO: drop the `sum` when we have dimensionality - push!(expr.args, :(return (rv = y, logabsdetjac = sum(logjac)))) + push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) return expr end diff --git a/test/interface.jl b/test/interface.jl index 28b7a613..df3de1f2 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -338,14 +338,21 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end res = forward(sb, x) @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] - @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[i]) for i = 1:3]) + @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[sb.ranges[i]]) for i = 1:3]) @test res.logabsdetjac == logabsdetjac(sb, x) # TODO: change when we have dimensionality in the type - x = ones(4) sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) - @test_throws AssertionError sb(x ./ sum(x)) + x = ones(3) ./ 3.0 + res = forward(sb, x) + @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] 
+ @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[sb.ranges[i]]) for i = 1:2]) + @test res.logabsdetjac == logabsdetjac(sb, x) + + x = ones(4) ./ 4.0 + @test_throws AssertionError sb(x) @test_throws AssertionError Stacked([Bijectors.Exp(), ], (1:1, 2:3)) @test_throws MethodError Stacked((Bijectors.Exp(), ), (1:1, 2:3)) From 89c16cdf1a4bc8b47ebbec1d05e398497dbb0849 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 21:16:02 +0200 Subject: [PATCH 47/89] removed no-longer-needed TODO and added more docstring --- src/interface.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 5ac22930..2b606765 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -268,7 +268,14 @@ end vcat(bs::Bijector...) A `Bijector` which stacks bijectors together which can then be applied to a vector -where `bs[i]::Bijector` is applied to `x[ranges[i]]`. +where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. + +# Arguments +- `bs` can be either a `Tuple` or an `AbstractArray` of bijectors. + - If `bs` is a `Tuple`, implementations are type-stable using generated functions + - If `bs` is an `AbstractArray`, implementations are _not_ type-stable and uses iterative methods +- `ranges` needs to be an iterable consisting of `UnitRange{Int}` + - `length(bs) == length(ranges)` needs to be true. # Examples ``` @@ -317,8 +324,6 @@ _transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) = b(x) function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) y = _transform(x, sb.ranges, sb.bs...) - - # TODO: maybe tell user to check their ranges? @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" return y From fd4eec1597beb21c3a0238e9bac04e1e11ecf727 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 21:17:47 +0200 Subject: [PATCH 48/89] fixed comment --- src/interface.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 2b606765..33a5d047 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -366,10 +366,11 @@ end # # quote # (y_1, logjac) = forward(b.bs[1], x[b.ranges[1]]) +# logjac = logjac_ # (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) -# logjac += _logjac -# y = vcat(tuple(y_1, y_2)...) -# return (rv = y, logabsdetjac = logjac) +# logjac += _logjac +# y = vcat(tuple(y_1, y_2)...) 
+# return (rv = y, logabsdetjac = logjac) # end @generated function forward(b::Stacked{T, N}, x::AbstractVector) where {N, T<:Tuple} expr = Expr(:block) From 2f8c1a05788d4345947253dd5e8d2635e83492dc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 21:19:20 +0200 Subject: [PATCH 49/89] removed redundant whitespace --- src/interface.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 33a5d047..129251c5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -391,9 +391,7 @@ end end push!(expr.args, :(y = vcat($e...))) - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) - return expr end From 597bb964037a460aef172b0eaf2ac8b879f5d3c2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 16 Sep 2019 21:28:45 +0200 Subject: [PATCH 50/89] added assert-check to array-impl of Stacked --- src/interface.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 129251c5..e32a7592 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -325,11 +325,12 @@ _transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector) = b(x) function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real}) y = _transform(x, sb.ranges, sb.bs...) @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" - return y end function (sb::Stacked{<:AbstractArray, N})(x::AbstractVector{<:Real}) where {N} - return vcat([sb.bs[i](x[sb.ranges[i]]) for i = 1:N]...) + y = vcat([sb.bs[i](x[sb.ranges[i]]) for i = 1:N]...) + @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" + return y end (sb::Stacked)(x::AbstractMatrix{<: Real}) = hcat([sb(x[:, i]) for i = 1:size(x, 2)]...) From 6daf7aa01be8d8d7ad47c2aac83a25dc221c1765 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Sep 2019 09:16:16 +0200 Subject: [PATCH 51/89] removed redundant type and fixed typo --- src/interface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index e32a7592..7a0baab2 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -273,7 +273,7 @@ where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. # Arguments - `bs` can be either a `Tuple` or an `AbstractArray` of bijectors. - If `bs` is a `Tuple`, implementations are type-stable using generated functions - - If `bs` is an `AbstractArray`, implementations are _not_ type-stable and uses iterative methods + - If `bs` is an `AbstractArray`, implementations are _not_ type-stable and use iterative methods - `ranges` needs to be an iterable consisting of `UnitRange{Int}` - `length(bs) == length(ranges)` needs to be true. @@ -304,7 +304,7 @@ struct Stacked{B, N} <: Bijector where N return new{A, N}(bs, ranges) end end -Stacked(bs) = Stacked(bs, NTuple{length(bs), UnitRange{Int}}([i:i for i = 1:length(bs)])) +Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...)) Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...)) Base.vcat(bs::Bijector...) 
= Stacked(bs) From fce177bed6f5703c84e5045af4766591e6eaa791 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Sep 2019 10:04:50 +0200 Subject: [PATCH 52/89] removed unused consts --- src/interface.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 7a0baab2..3e978da1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -496,9 +496,6 @@ struct SimplexBijector{T} <: Bijector where {T} end SimplexBijector(proj::Bool) = SimplexBijector{Val{proj}}() SimplexBijector() = SimplexBijector(true) -const simplex_b = SimplexBijector{Val{false}}() -const simplex_b_proj = SimplexBijector{Val{true}}() - # The following implementations are basically just copy-paste from `invlink` and # `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. function _clamp(x::T, b::Union{SimplexBijector, Inversed{<:SimplexBijector}}) where {T} From 170b15e2595fb2e2564f39c4c9a1d37209ec5dd1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Sep 2019 10:49:56 +0200 Subject: [PATCH 53/89] made the forward(b::Stacked, ...) generated slightly nicer --- src/interface.jl | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 3e978da1..d287281f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -366,21 +366,20 @@ end # Generates something similar to: # # quote -# (y_1, logjac) = forward(b.bs[1], x[b.ranges[1]]) -# logjac = logjac_ -# (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) -# logjac += _logjac -# y = vcat(tuple(y_1, y_2)...) -# return (rv = y, logabsdetjac = logjac) +# (y_1, _logjac) = forward(b.bs[1], x[b.ranges[1]]) +# logjac = sum(_logjac) +# (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) +# logjac += sum(_logjac) +# return (rv = vcat(y_1, y_2), logabsdetjac = logjac) # end @generated function forward(b::Stacked{T, N}, x::AbstractVector) where {N, T<:Tuple} expr = Expr(:block) - e = Expr(:call, :tuple) + y_names = [] push!(expr.args, :((y_1, _logjac) = forward(b.bs[1], x[b.ranges[1]]))) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) - push!(e.args, :y_1) + push!(y_names, :y_1) for i = 2:length(T.parameters) y_name = Symbol("y_$i") push!(expr.args, :(($y_name, _logjac) = forward(b.bs[$i], x[b.ranges[$i]]))) @@ -388,11 +387,10 @@ end # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac += sum(_logjac))) - push!(e.args, y_name) + push!(y_names, y_name) end - push!(expr.args, :(y = vcat($e...))) - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (rv = vcat($(y_names...)), logabsdetjac = logjac))) return expr end From 568b192490087d670840082bebddcf621e742a21 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Sep 2019 16:38:31 +0200 Subject: [PATCH 54/89] added batch-impls for CouplingLayer since this is unambiguous --- src/couplings.jl | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/couplings.jl b/src/couplings.jl index 575d9071..bcf8d2f4 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -122,7 +122,7 @@ function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B} return CouplingLayer(B, mask, cl.θ) end -function (cl::CouplingLayer{B})(x) where {B} +function (cl::CouplingLayer{B})(x::AbstractVector) where {B} # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) @@ -132,9 +132,13 @@ function (cl::CouplingLayer{B})(x) where 
{B} # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end +# TODO: drop when we have dimensionality +function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B} + return mapslices(z -> cl(z), x; dims = 1) +end -function (icl::Inversed{<:CouplingLayer{B}})(y) where {B} +function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractVector) where {B} cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) @@ -144,11 +148,20 @@ function (icl::Inversed{<:CouplingLayer{B}})(y) where {B} return combine(cl.mask, ib(y_1), y_2, y_3) end +# TODO: drop when we have dimensionality +function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} + return mapslices(z -> icl(z), y; dims = 1) +end - -function logabsdetjac(cl::CouplingLayer{B}, x) where {B} +function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B} x_1, x_2, x_3 = partition(cl.mask, x) b = B(cl.θ(x_2)) return logabsdetjac(b, x_1) end + +function logabsdetjac(cl::CouplingLayer{B}, x::AbstractMatrix) where {B} + return vec(mapslices(z -> logabsdetjac(cl, z), x; dims = 1)) +end + + From 89fd618dfa55da4f65be9679151ed58e82b50cfb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Sep 2019 16:48:08 +0200 Subject: [PATCH 55/89] added dependency on StaticArrays --- Manifest.toml | 9 ++++----- Project.toml | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index ceaf096d..dfe2b2e1 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -130,10 +130,9 @@ deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[Missings]] -deps = ["SparseArrays", "Test"] -git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007" +git-tree-sha1 = "29858ce6c8ae629cf2d733bffa329619a1c843d0" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.1" +version = "0.4.2" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -158,9 +157,9 @@ version = "1.1.0" [[PDMats]] deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde" +git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.9.9" +version = "0.9.10" [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] diff --git a/Project.toml b/Project.toml index 17f58a7c..5bd77bbf 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" From 648c3bebda5f93c10d4f8da5b4c07e66a3c55055 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 18 Sep 2019 17:06:12 +0200 Subject: [PATCH 56/89] export coupling layer --- src/Bijectors.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 47772d16..89df41f0 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,7 +36,8 @@ export TransformDistribution, logpdf_with_jac, logpdf_forward, PlanarLayer, - RadialLayer + RadialLayer, + CouplingLayer const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) From 0f06735540e67b2ec6922514d590b5af4863cb2c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 
00:11:23 +0200 Subject: [PATCH 57/89] bump to 0.4.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8022aa93..2e71c86a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.4.0" +version = "0.4.1" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" From a85da36f2df428cdf6e868fb0ace614fcb3374bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 09:30:37 +0200 Subject: [PATCH 58/89] fixed size-stuff for the flows --- src/norm_flows.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/norm_flows.jl b/src/norm_flows.jl index 120ae676..aebf032b 100644 --- a/src/norm_flows.jl +++ b/src/norm_flows.jl @@ -56,11 +56,11 @@ function _forward(flow::PlanarLayer, z) return (rv=transformed, logabsdetjac=vec(log_det_jacobian)) # from eq(10) end -forward(flow::PlanarLayer, z) = _forward(flow, z) +forward(flow::PlanarLayer, z::AbstractMatrix) = _forward(flow, z) function forward(flow::PlanarLayer, z::AbstractVector{<: Real}) res = _forward(flow, z) - return (rv=res.rv, logabsdetjac=res.logabsdetjac[1]) + return (rv=vec(res.rv), logabsdetjac=res.logabsdetjac[1]) end @@ -115,7 +115,8 @@ function _transform(flow::RadialLayer, z) return (transformed=transformed, α=α, β_hat=β_hat, r=r) end -(b::RadialLayer)(z) = _transform(b, z).transformed +(b::RadialLayer)(z::AbstractMatrix{<:Real}) = _transform(b, z).transformed +(b::RadialLayer)(z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) function _forward(flow::RadialLayer, z) transformed, α, β_hat, r = _transform(flow, z) @@ -131,13 +132,12 @@ end forward(flow::RadialLayer, z) = _forward(flow, z) -function forward(flow::RadialLayer, z::AbstractVector{<: Real}) +function forward(flow::RadialLayer, z::AbstractVector{<:Real}) res = _forward(flow, z) - return (rv=res.rv[:, 1], logabsdetjac=res.logabsdetjac[1]) + return (rv=vec(res.rv), logabsdetjac=res.logabsdetjac[1]) end -# function inv(flow::RadialLayer, y) -function (ib::Inversed{<: RadialLayer})(y) +function (ib::Inversed{<:RadialLayer})(y) flow = ib.orig α = softplus(flow.α_[1]) # from A.2 β_hat = - α + softplus(flow.β[1]) # from A.2 @@ -150,7 +150,7 @@ function (ib::Inversed{<: RadialLayer})(y) return z end -function (ib::Inversed{<: RadialLayer})(y::AbstractVector{<: Real}) +function (ib::Inversed{<: RadialLayer})(y::AbstractVector{<:Real}) return vec(ib(reshape(y, (length(y), 1)))) end From 07c0d87f27f55243d7e1816012459c81537d57b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 09:49:12 +0200 Subject: [PATCH 59/89] added more tests for batch computation --- test/interface.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 95c81911..1e446f48 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -4,7 +4,7 @@ using Random using LinearAlgebra using ForwardDiff -using Bijectors: Log, Exp, Shift, Scale, Logit +using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector Random.seed!(123) @@ -136,7 +136,12 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end (Log{1}() ∘ Exp{1}(), randn(2, 3)), (inv(Logit(-1.0, 1.0)), randn(3)), (Identity{0}(), randn(3)), - (Identity{1}(), randn(2, 3)) + (Identity{1}(), randn(2, 3)), + (PlanarLayer(2), randn(2, 3)), + (RadialLayer(2), randn(2, 3)), + (PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), + (Exp{1}() ∘ PlanarLayer(2) ∘ RadialLayer(2), 
randn(2, 3)), + (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)) ] for (b, xs) in bs_xs From b458dab32534521d28e85439b31fcc6bd20826a9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 09:50:00 +0200 Subject: [PATCH 60/89] updated SimplexBijector --- src/interface.jl | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 26c3ccd4..2ea1dca8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -388,9 +388,8 @@ logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractMatrix) = sum(log.(abs.(b # Simplex bijector # #################### struct SimplexBijector{T} <: Bijector{1} where {T} end - -const simplex_b = SimplexBijector{Val{false}}() -const simplex_b_proj = SimplexBijector{Val{true}}() +SimplexBijector(proj::Bool) = SimplexBijector{Val{proj}}() +SimplexBijector() = SimplexBijector(true) # The following implementations are basically just copy-paste from `invlink` and # `link` for `SimplexDistributions` but dropping the dependence on the `Distribution`. @@ -517,6 +516,10 @@ function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where T return - lp end +function logabsdetjac(b::SimplexBijector, x::AbstractMatrix{<:Real}) + return vec(mapslices(z -> logabsdetjac(b, z), x; dims = 1)) +end + ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### @@ -532,8 +535,14 @@ the `jacobian` and `logabsdetjac`. struct DistributionBijector{AD, D, N} <: ADBijector{AD, N} where {D<:Distribution} dist::D end -function DistributionBijector(dist::D) where {D<:Distribution} - DistributionBijector{ADBackend(), D, length(size(dist))}(dist) +function DistributionBijector(dist::D) where {D<:UnivariateDistribution} + DistributionBijector{ADBackend(), D, 0}(dist) +end +function DistributionBijector(dist::D) where {D<:MultivariateDistribution} + DistributionBijector{ADBackend(), D, 1}(dist) +end +function DistributionBijector(dist::D) where {D<:MatrixDistribution} + DistributionBijector{ADBackend(), D, 2}(dist) end # Simply uses `link` and `invlink` as transforms with AD to get jacobian @@ -574,11 +583,12 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. 
""" +bijector(d::Distribution) = DistributionBijector(d) bijector(d::Normal) = Identity{0}() bijector(d::MvNormal) = Identity{1}() bijector(d::PositiveDistribution) = Log{0}() bijector(d::MvLogNormal) = Log{0}() -bijector(d::SimplexDistribution) = simplex_b_proj +bijector(d::SimplexDistribution) = SimplexBijector{Val{true}}() bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) bijector_bounded(d, a=minimum(d), b=maximum(d)) = Logit(a, b) From a47516d2fe7c07c5f2a8c719f63e73cf6a47e333 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 11:55:11 +0200 Subject: [PATCH 61/89] added more and better testing for batch support --- src/interface.jl | 2 -- test/interface.jl | 48 +++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 2ea1dca8..b6a83012 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -281,8 +281,6 @@ struct Identity{N} <: Bijector{N} end (::Identity)(x) = x inv(b::Identity) = b -forward(::Identity, x) = (rv=x, logabsdetjac=zero(eltype(x))) - logabsdetjac(::Identity, x::Real) = zero(eltype(x)) @generated function logabsdetjac( b::Identity{N1}, diff --git a/test/interface.jl b/test/interface.jl index 1e446f48..a34c1ee6 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -159,29 +159,61 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end x_ = ib(y) xs_ = ib(ys) + result = forward(b, x) + results = forward(b, xs) + + iresult = forward(ib, y) + iresults = forward(ib, ys) + + # Sizes @test size(y) == size(x) @test size(ys) == size(xs) + @test size(x_) == size(x) @test size(xs_) == size(xs) + @test size(result.rv) == size(x) + @test size(results.rv) == size(xs) + + @test size(iresult.rv) == size(y) + @test size(iresults.rv) == size(ys) + + # Values + @test ys == mapslices(z -> b(z), xs; dims = 1) + @test ys ≈ results.rv + if D == 0 + # Sizes @test y == ys[1] @test length(logabsdetjac(b, xs)) == length(xs) - @test logabsdetjac(b, x) == logabsdetjac(b, xs)[1] - @test length(logabsdetjac(ib, ys)) == length(xs) - @test logabsdetjac(ib, y) == logabsdetjac(ib, ys)[1] - elseif Bijectors.dimension(b) == 1 + + @test size(results.logabsdetjac) == size(xs, ) + @test size(iresults.logabsdetjac) == size(ys, ) + + # Values + @test logabsdetjac.(b, xs) == logabsdetjac(b, xs) + @test logabsdetjac.(ib, ys) == logabsdetjac(ib, ys) + + @test results.logabsdetjac ≈ vec(logabsdetjac.(b, xs)) + @test iresults.logabsdetjac ≈ vec(logabsdetjac.(ib, ys)) + elseif D == 1 @test y == ys[:, 1] # Comparing sizes instead of lengths ensures we catch errors s.t. # length(x) == 3 when size(x) == (1, 3). 
- # We want the return value to + # Sizes @test size(logabsdetjac(b, xs)) == (size(xs, 2), ) - @test logabsdetjac(b, x) == logabsdetjac(b, xs)[1] - @test size(logabsdetjac(ib, ys)) == (size(xs, 2), ) - @test logabsdetjac(ib, y) == logabsdetjac(ib, ys)[1] + + @test size(results.logabsdetjac) == (size(xs, 2), ) + @test size(iresults.logabsdetjac) == (size(ys, 2), ) + + # Test all values + @test logabsdetjac(b, xs) == vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) + @test logabsdetjac(ib, ys) == vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) + @test results.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) + @test iresults.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) else error("tests not implemented yet") end From 3182f0d647cf86c7581143da83d45430d7a24686 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 12:08:08 +0200 Subject: [PATCH 62/89] added dimensionality to Stacked and tests --- src/interface.jl | 2 +- test/interface.jl | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 3878deda..7fa4b00e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -302,7 +302,7 @@ b = vcat(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{B, N} <: Bijector where N +struct Stacked{B, N} <: Bijector{1} where N bs::B ranges::NTuple{N, UnitRange{Int}} diff --git a/test/interface.jl b/test/interface.jl index fd82eccb..1aecaa21 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -141,7 +141,10 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end (RadialLayer(2), randn(2, 3)), (PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), (Exp{1}() ∘ PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), - (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)) + (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)), + (vcat(Exp{1}(), Scale(2.0)), randn(2, 3)), + (Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]), + mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)) ] for (b, xs) in bs_xs @@ -451,7 +454,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end res = forward(sb, x) @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] - @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[sb.ranges[i]]) for i = 1:3]) + @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3]) @test res.logabsdetjac == logabsdetjac(sb, x) @@ -461,7 +464,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end res = forward(sb, x) @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] 
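+        # Note: the inner `sum` in the updated test below mirrors `Stacked`'s own
+        # `logabsdetjac`, which sums within each block so that a 0-dim bijector
+        # (which treats its sub-vector as a batch) contributes a single scalar.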
- @test logabsdetjac(sb, x) == sum([logabsdetjac(sb.bs[i], x[sb.ranges[i]]) for i = 1:2]) + @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) @test res.logabsdetjac == logabsdetjac(sb, x) x = ones(4) ./ 4.0 From b89423f48a2dd52e75276bcf3dd2fbe9970a9757 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Sep 2019 12:40:31 +0200 Subject: [PATCH 63/89] replaced vcat with stack to avoid unnecessary confusion --- README.md | 2 +- src/Bijectors.jl | 1 + src/interface.jl | 32 +++++++++++++++++++------------- test/interface.jl | 12 ++++++------ test/norm_flows.jl | 4 ++-- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ab537b5c..01912ed9 100644 --- a/README.md +++ b/README.md @@ -376,7 +376,7 @@ julia> d = MvNormal(zeros(2), ones(2)); julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta()))); -julia> sb = vcat(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)] +julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)] Stacked{Tuple{Exp,Inversed{Logit{Float64}}},2}((Exp(), Inversed{Logit{Float64}}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2)) julia> b = sb ∘ PlanarLayer(2) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index c4b87878..376951ea 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -25,6 +25,7 @@ export TransformDistribution, Composed, compose, Stacked, + stack, Identity, DistributionBijector, bijector, diff --git a/src/interface.jl b/src/interface.jl index 7fa4b00e..02ea33e9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -279,16 +279,18 @@ end -########### -# Stacked # -########### +const ZeroOrOneDimBijector = Union{Bijector{0}, Bijector{1}} + """ Stacked(bs) Stacked(bs, ranges) - vcat(bs::Bijector...) + stack(bs::Bijector{Dim=0}...) A `Bijector` which stacks bijectors together which can then be applied to a vector where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. # Arguments -- `bs` can be either a `Tuple` or an `AbstractArray` of bijectors. +- `bs` can be either a `Tuple` or an `AbstractArray` of 0- and/or 1-dimensional bijectors - If `bs` is a `Tuple`, implementations are type-stable using generated functions - If `bs` is an `AbstractArray`, implementations are _not_ type-stable and use iterative methods - `ranges` needs to be an iterable consisting of `UnitRange{Int}` @@ -296,20 +298,20 @@ where `bs[i]::Bijector` is applied to `x[ranges[i]]::UnitRange{Int}`. # Examples ``` -b1 = Logistic(0.0, 1.0) -b2 = Identity() -b = vcat(b1, b2) +b1 = Logit(0.0, 1.0) +b2 = Identity{0}() +b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{B, N} <: Bijector{1} where N - bs::B +struct Stacked{Bs, N} <: Bijector{1} where N + bs::Bs ranges::NTuple{N, UnitRange{Int}} function Stacked( bs::C, ranges::NTuple{N, UnitRange{Int}} - ) where {N, C<:Tuple{Vararg{<:Bijector, N}}} + ) where {N, C<:Tuple{Vararg{<:ZeroOrOneDimBijector, N}}} return new{C, N}(bs, ranges) end @@ -318,13 +320,14 @@ struct Stacked{B, N} <: Bijector{1} where N ranges::NTuple{N, UnitRange{Int}} ) where {N, A<:AbstractArray{<:Bijector}} @assert length(bs) == N "number of bijectors is not same as number of ranges" + @assert all(isa.(bs, ZeroOrOneDimBijector)) return new{A, N}(bs, ranges) end end -Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...)) Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...)) +Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...)) -Base.vcat(bs::Bijector...) 
= Stacked(bs) +stack(bs::Bijector{0}...) = Stacked(bs) inv(sb::Stacked) = Stacked(inv.(sb.bs), sb.ranges) @@ -350,7 +353,10 @@ function (sb::Stacked{<:AbstractArray, N})(x::AbstractVector{<:Real}) where {N} return y end -(sb::Stacked)(x::AbstractMatrix{<: Real}) = hcat([sb(x[:, i]) for i = 1:size(x, 2)]...) +# (sb::Stacked)(x::AbstractMatrix{<: Real}) = hcat([sb(x[:, i]) for i = 1:size(x, 2)]...) +(sb::Stacked)(x::AbstractMatrix{<: Real}) = mapslices(z -> sb(z), x; dims = 1) + +# TODO: implement custom adjoint since we can exploit block-diagonal nature of `Stacked` function (sb::Stacked)(x::TrackedArray{A, 2}) where {A} return Tracker.collect(hcat([sb(x[:, i]) for i = 1:size(x, 2)]...)) end @@ -374,10 +380,10 @@ function logabsdetjac( return sum([sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) for i = 1:N]) end function logabsdetjac(b::Stacked, x::AbstractMatrix{<: Real}) - return [logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)] + return vec(mapslices(z -> logabsdetjac(b, z), x; dims = 1)) end function logabsdetjac(b::Stacked, x::TrackedArray{A, 2}) where {A} - return Tracker.collect([logabsdetjac(b, x[:, i]) for i = 1:size(x, 2)]) + return Tracker.collect(vec(mapslices(z -> logabsdetjac(b, z), x; dims = 1))) end # Generates something similar to: diff --git a/test/interface.jl b/test/interface.jl index 1aecaa21..bace393e 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -142,7 +142,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end (PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), (Exp{1}() ∘ PlanarLayer(2) ∘ RadialLayer(2), randn(2, 3)), (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)), - (vcat(Exp{1}(), Scale(2.0)), randn(2, 3)), + (stack(Exp{0}(), Scale(2.0)), randn(2, 3)), (Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]), mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)) ] @@ -416,7 +416,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end x = rand(d) y = b(x) - sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple + sb1 = stack(b, b, inv(b), inv(b)) # <= Tuple res1 = forward(sb1, [x, x, y, y]) @test sb1([x, x, y, y]) == res1.rv @@ -434,7 +434,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end b = DistributionBijector(d) y = b(x) - sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple + sb1 = stack(b, b, inv(b), inv(b)) # <= Tuple res1 = forward(sb1, [x, x, y, y]) @test sb1([x, x, y, y]) == res1.rv @@ -450,7 +450,7 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end # value-test x = ones(3) - sb = vcat(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) + sb = stack(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) res = forward(sb, x) @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] @@ -512,8 +512,8 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end @test td isa Distribution{Multivariate, Continuous} # check that wrong ranges fails - sb = vcat(ibs...) - td = transformed(d, sb) + @test_throws MethodError stack(ibs...) 
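+            # `stack` only accepts 0-dim bijectors (`stack(bs::Bijector{0}...)`), and `ibs`
+            # mixes in a 1-dim bijector (presumably the `MvNormal` component), hence the
+            # `MethodError` above. `Stacked(ibs)` still constructs, but its default
+            # per-index ranges (`i:i`) do not match the bijector dimensions, which the
+            # `AssertionError` check below exercises.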
+ sb = Stacked(ibs) x = rand(d) @test_throws AssertionError sb(x) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 9b12da44..29d48061 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -69,10 +69,10 @@ end x = rand(d) .+ 10 y = b(x) - sb = vcat(b1, b1) + sb = stack(b1, b1) @test all((sb ∘ b)(x) .≤ 1.0) - sb = vcat(b1, b2) + sb = stack(b1, b2) cb = (sb ∘ b) y = cb(x) @test (0 ≤ y[1] ≤ 1.0) && (0 < y[2]) From 7e7c5afaabe9c4e540f370d08bc4cf75dc233d38 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2019 00:44:54 +0200 Subject: [PATCH 64/89] dropped redundant dimensionality for TransformedDistribution --- src/interface.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index b4b74878..e6ed05f2 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -708,12 +708,13 @@ end (ib::Inversed{<:DistributionBijector})(y) = invlink(ib.orig.dist, y) # Transformed distributions -struct TransformedDistribution{D, B, V, N} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector{N}} +struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B<:Bijector} dist::D transform::B -end -function TransformedDistribution(d::D, b::B) where {V<:VariateForm, B<:Bijector, D<:Distribution{V, Continuous}} - return TransformedDistribution{D, B, V, length(size(d))}(d, b) + + TransformedDistribution(d::UnivariateDistribution, b::Bijector{0}) = new{typeof(d), typeof(b), Univariate}(d, b) + TransformedDistribution(d::MultivariateDistribution, b::Bijector{1}) = new{typeof(d), typeof(b), Multivariate}(d, b) + TransformedDistribution(d::MatrixDistribution, b::Bijector{2}) = new{typeof(d), typeof(b), Matrixvariate}(d, b) end From d3bb89204b9290730c4006d36d016668b6d1e537 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2019 12:30:28 +0200 Subject: [PATCH 65/89] fixed something that was mis-resolved --- test/interface.jl | 88 ----------------------------------------------- 1 file changed, 88 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 93511949..2d5b84af 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -437,94 +437,6 @@ struct NonInvertibleBijector{AD} <: ADBijector{AD} end end end - @testset "Stacked <: Bijector" begin - # `logabsdetjac` without AD - d = Beta() - b = bijector(d) - x = rand(d) - y = b(x) - sb = vcat(b, b, inv(b), inv(b)) - @test logabsdetjac(sb, [x, x, y, y]) ≈ 0.0 - - # `logabsdetjac` with AD - b = DistributionBijector(d) - y = b(x) - sb1 = vcat(b, b, inv(b), inv(b)) # <= tuple - sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array - @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0.0 - @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 - - @testset "Stacked: ADVI with MvNormal" begin - # MvNormal test - dists = [ - Beta(), - Beta(), - Beta(), - InverseGamma(), - InverseGamma(), - Gamma(), - Gamma(), - InverseGamma(), - Cauchy(), - Gamma(), - MvNormal(zeros(2), ones(2)) - ] - - ranges = [] - idx = 1 - for i = 1:length(dists) - d = dists[i] - push!(ranges, idx:idx + length(d) - 1) - idx += length(d) - end - - num_params = ranges[end][end] - d = MvNormal(zeros(num_params), ones(num_params)) - - # Stacked{<:Array} - bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists - ibs = inv.(bs) # invert, so we get unconstrained-to-constrained - sb = Stacked(ibs, ranges) # => Stacked <: Bijector - x = rand(d) - sb(x) - @test sb isa Stacked - - td = transformed(d, sb) # => MultivariateTransformed <: 
Distribution{Multivariate, Continuous} - @test td isa Distribution{Multivariate, Continuous} - - # check that wrong ranges fails - sb = vcat(ibs...) - td = transformed(d, sb) - x = rand(d) - @test_throws AssertionError sb(x) - - # Stacked{<:Tuple} - bs = bijector.(tuple(dists...)) - ibs = inv.(bs) - sb = Stacked(ibs, ranges) - isb = inv(sb) - @test sb isa Stacked{<: Tuple} - - # inverse - td = transformed(d, sb) - y = rand(td) - x = isb(y) - @test sb(x) ≈ y - - # verification of computation - x = rand(d) - y = sb(x) - y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) - x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) - @test x ≈ x_ - @test y ≈ y_ - - # AD verification - @test log(abs(det(ForwardDiff.jacobian(sb, x)))) ≈ logabsdetjac(sb, x) - @test log(abs(det(ForwardDiff.jacobian(isb, y)))) ≈ logabsdetjac(isb, y) - end - end - @testset "Example: ADVI single" begin # Usage in ADVI d = Beta() From 9b2b27f69222cc6c1ce81b6e332e665b1e1ffdee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 20 Sep 2019 12:38:46 +0200 Subject: [PATCH 66/89] adapted CouplingLayer to new interface --- src/couplings.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/couplings.jl b/src/couplings.jl index bcf8d2f4..6463af21 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -107,7 +107,7 @@ Implements a coupling-layer as defined in [1]. # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). """ -struct CouplingLayer{B, M, F} <: Bijector where {B, M <: PartitionMask, F} +struct CouplingLayer{B, M, F} <: Bijector{1} where {B <: ZeroOrOneDimBijector, M <: PartitionMask, F} mask::M θ::F end From fec862e4b01f70dec45431b08b122026808b1e75 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Sep 2019 19:51:11 +0200 Subject: [PATCH 67/89] WIP making couplings work properly with dimensionality --- src/couplings.jl | 6 +++--- src/interface.jl | 23 ++++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/couplings.jl b/src/couplings.jl index 6463af21..f2a3217a 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -157,11 +157,11 @@ function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B} x_1, x_2, x_3 = partition(cl.mask, x) b = B(cl.θ(x_2)) - return logabsdetjac(b, x_1) + # `B` might be 0-dim in which case it will treat `x_1` as a batch + # therefore we sum to ensure such a thing does not happen + return sum(logabsdetjac(b, x_1)) end function logabsdetjac(cl::CouplingLayer{B}, x::AbstractMatrix) where {B} return vec(mapslices(z -> logabsdetjac(cl, z), x; dims = 1)) end - - diff --git a/src/interface.jl b/src/interface.jl index e6ed05f2..4c2f78c5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -517,8 +517,8 @@ struct Scale{T, N} <: Bijector{N} a::T end -Scale(a::T) where {T<:Real} = Scale{T, 0}(a) -Scale(a::A) where {T, N, A<:AbstractArray{T, N}} = Scale{A, N}(a) +Scale(a::T; dim::Val{D} = Val{0}()) where {T<:Real, D} = Scale{T, D}(a) +Scale(a::A; dim::Val{D} = Val{N}()) where {T, D, N, A<:AbstractArray{T, N}} = Scale{A, D}(a) (b::Scale)(x) = b.a .* x (b::Scale{<:Real})(x::AbstractArray) = b.a .* x @@ -535,6 +535,7 @@ inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) logabsdetjac(b::Scale{<:Real, 0}, x::Real) = log(abs(b.a)) logabsdetjac(b::Scale{<:Real, 0}, x::AbstractVector) = log(abs(b.a)) .* ones(eltype(x), length(x)) logabsdetjac(b::Scale{<:Real, 1}, x::AbstractVector) = log(abs(b.a)) * length(x) +logabsdetjac(b::Scale{<:Real, 1}, 
x::AbstractMatrix) = log(abs(b.a)) * length(x) * ones(eltype(x), size(x, 2)) logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractVector) = sum(log.(abs.(b.a))) logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractMatrix) = sum(log.(abs.(b.a))) * ones(eltype(x), size(x, 2)) @@ -791,13 +792,13 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) .+ res.logabsdetjac + return logpdf(td.dist, res.rv) + res.logabsdetjac end # TODO: implement more efficiently for flows in the case of `Matrix` function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) .+ res.logabsdetjac + return logpdf(td.dist, res.rv) + res.logabsdetjac end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -805,7 +806,7 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with @@ -852,18 +853,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. """ function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -871,7 +872,7 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac return (lp, res.logabsdetjac) end @@ -891,14 +892,14 @@ the inverse transform to compute the necessary `logabsdetjac`. This is similar to `logpdf_with_trans`. 
""" # TODO: implement more efficiently for flows in the case of `Matrix` -logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .- logjac +logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) - logjac logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) function logpdf_forward(td::MvTransformed{<:Dirichlet}, x, logjac) T = eltype(x) ϵ = _eps(T) - return logpdf(td.dist, mappedarray(z->z+ϵ, x)) .- logjac + return logpdf(td.dist, mappedarray(z->z+ϵ, x)) - logjac end From 0e8a8b9d433feea9095ff3ec0b4228fe2eb4cab4 Mon Sep 17 00:00:00 2001 From: "T Fjelde (RA H Ge)" Date: Tue, 24 Sep 2019 19:32:49 +0100 Subject: [PATCH 68/89] added comments to impl custom adjoints for Shift and Scale --- src/interface.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 4c2f78c5..d0faf8c8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -508,8 +508,9 @@ Shift(a::A) where {T, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) inv(b::Shift) = Shift(-b.a) +# FIXME: implement custom adjoint to ensure we don't get tracking logabsdetjac(b::Shift{<:Real, 0}, x::Real) = zero(eltype(x)) -logabsdetjac(b::Shift{<:Real, 0}, x::AbstractVector) = zeros(eltype(x), length(x)) +logabsdetjac(b::Shift{<:Real, 0}, x::AbstractVector{T}) where {T<:Real} = zeros(T, length(x)) logabsdetjac(b::Shift{T, 1}, x::AbstractVector) where {T<:Union{Real, AbstractVector}} = zero(eltype(x)) logabsdetjac(b::Shift{T, 1}, x::AbstractMatrix) where {T<:Union{Real, AbstractVector}} = zeros(eltype(x), size(x, 2)) @@ -527,11 +528,7 @@ Scale(a::A; dim::Val{D} = Val{N}()) where {T, D, N, A<:AbstractArray{T, N}} = Sc inv(b::Scale) = Scale(inv(b.a)) inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) -# TODO: should this be implemented for batch-computation? -# There's an ambiguity issue -# logabsdetjac(b::Scale{<: AbstractVector}, x::AbstractMatrix) -# Is this a batch or is it simply a matrix we want to scale differently -# in each component? +# FIXME: implement custom adjoint to ensure we don't get tracking logabsdetjac(b::Scale{<:Real, 0}, x::Real) = log(abs(b.a)) logabsdetjac(b::Scale{<:Real, 0}, x::AbstractVector) = log(abs(b.a)) .* ones(eltype(x), length(x)) logabsdetjac(b::Scale{<:Real, 1}, x::AbstractVector) = log(abs(b.a)) * length(x) From e8a49bb47e5c3b6f61fed9e649857a7036b08056 Mon Sep 17 00:00:00 2001 From: "T Fjelde (RA H Ge)" Date: Wed, 25 Sep 2019 17:40:16 +0100 Subject: [PATCH 69/89] added slightly better support for Tracker.jl in CouplingLayer --- src/couplings.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/couplings.jl b/src/couplings.jl index f2a3217a..556aa196 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -134,7 +134,7 @@ function (cl::CouplingLayer{B})(x::AbstractVector) where {B} end # TODO: drop when we have dimensionality function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B} - return mapslices(z -> cl(z), x; dims = 1) + return hcat([cl(x[:, i]) for i = 1:size(x, 2)]...) end @@ -150,7 +150,7 @@ function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractVector) where {B} end # TODO: drop when we have dimensionality function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} - return mapslices(z -> icl(z), y; dims = 1) + return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...) 
end function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B} @@ -163,5 +163,13 @@ function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B} end function logabsdetjac(cl::CouplingLayer{B}, x::AbstractMatrix) where {B} - return vec(mapslices(z -> logabsdetjac(cl, z), x; dims = 1)) + r = [logabsdetjac(cl, x[:, i]) for i = 1:size(x, 2)] + + # FIXME: this really needs to be handled in a better way + # We need to return a `TrackedArray` + if Tracker.istracked(r[1]) + return Tracker.collect(r) + else + return r + end end From dd5770c56c898fc8f122d080266ad649bfd0f748 Mon Sep 17 00:00:00 2001 From: "T Fjelde (RA H Ge)" Date: Wed, 25 Sep 2019 17:43:26 +0100 Subject: [PATCH 70/89] improved support for Tracker.jl --- src/interface.jl | 92 +++++++++++++++++++++++++++++++++++++++-------- test/interface.jl | 49 +++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 14 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index e6ed05f2..6252acd4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,6 +1,7 @@ using Distributions, Bijectors using ForwardDiff using Tracker +using Tracker: TrackedReal, TrackedArray, track, @grad, data import Base: inv, ∘ @@ -508,10 +509,27 @@ Shift(a::A) where {T, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) inv(b::Shift) = Shift(-b.a) -logabsdetjac(b::Shift{<:Real, 0}, x::Real) = zero(eltype(x)) -logabsdetjac(b::Shift{<:Real, 0}, x::AbstractVector) = zeros(eltype(x), length(x)) -logabsdetjac(b::Shift{T, 1}, x::AbstractVector) where {T<:Union{Real, AbstractVector}} = zero(eltype(x)) -logabsdetjac(b::Shift{T, 1}, x::AbstractMatrix) where {T<:Union{Real, AbstractVector}} = zeros(eltype(x), size(x, 2)) +# FIXME: implement custom adjoint to ensure we don't get tracking +logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val{N}) + +_logabsdetjac_shift(a::Real, x::Real, ::Type{Val{0}}) = zero(eltype(x)) +_logabsdetjac_shift(a::Real, x::AbstractVector{T}, ::Type{Val{0}}) where {T<:Real} = zeros(T, length(x)) +_logabsdetjac_shift(a::T1, x::AbstractVector{T2}, ::Type{Val{1}}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zero(T2) +_logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Type{Val{1}}) where {T1<:Union{Real, AbstractVector}, T2<:Real} = zeros(T2, size(x, 2)) + +function _logabsdetjac_shift(a::TrackedReal, x::Real, ::Type{Val{0}}) + return Tracker.param(_logabsdetjac_shift(data(a), data(x), Val{0})) +end +function _logabsdetjac_shift(a::TrackedReal, x::AbstractVector{T}, ::Type{Val{0}}) where {T<:Real} + return Tracker.param(_logabsdetjac_shift(data(a), data(x), Val{0})) +end +function _logabsdetjac_shift(a::T1, x::AbstractVector{T2}, ::Type{Val{1}}) where {T1<:Union{TrackedReal, TrackedVector}, T2<:Real} + return Tracker.param(_logabsdetjac_shift(data(a), data(x), Val{1})) +end +function _logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Type{Val{1}}) where {T1<:Union{TrackedReal, TrackedVector}, T2<:Real} + return Tracker.param(_logabsdetjac_shift(data(a), data(x), Val{1})) +end + struct Scale{T, N} <: Bijector{N} a::T @@ -527,16 +545,62 @@ Scale(a::A) where {T, N, A<:AbstractArray{T, N}} = Scale{A, N}(a) inv(b::Scale) = Scale(inv(b.a)) inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) -# TODO: should this be implemented for batch-computation? -# There's an ambiguity issue -# logabsdetjac(b::Scale{<: AbstractVector}, x::AbstractMatrix) -# Is this a batch or is it simply a matrix we want to scale differently -# in each component? 
-logabsdetjac(b::Scale{<:Real, 0}, x::Real) = log(abs(b.a)) -logabsdetjac(b::Scale{<:Real, 0}, x::AbstractVector) = log(abs(b.a)) .* ones(eltype(x), length(x)) -logabsdetjac(b::Scale{<:Real, 1}, x::AbstractVector) = log(abs(b.a)) * length(x) -logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractVector) = sum(log.(abs.(b.a))) -logabsdetjac(b::Scale{<:AbstractVector, 1}, x::AbstractMatrix) = sum(log.(abs.(b.a))) * ones(eltype(x), size(x, 2)) +# We're going to implement custom adjoint for this +logabsdetjac(b::Scale{T, N}, x) where {T, N} = _logabsdetjac_scale(b.a, x, Val{N}) + +_logabsdetjac_scale(a::Real, x::Real, ::Type{Val{0}}) = log(abs(a)) +_logabsdetjac_scale(a::Real, x::AbstractVector, ::Type{Val{0}}) = fill(log(abs(a)), length(x)) +_logabsdetjac_scale(a::Real, x::AbstractVector, ::Type{Val{1}}) = log(abs(a)) * length(x) +_logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Type{Val{1}}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) +_logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Type{Val{1}}) = sum(log.(abs.(a))) +_logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Type{Val{1}}) = fill(sum(log.(abs.(a))), size(x, 2)) + +# adjoints for 0-dim and 1-dim `Scale` using `Real` +function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Type{Val{0}}) + return track(_logabsdetjac_scale, a, data(x), Val{0}) +end +@grad function _logabsdetjac_scale(a::Real, x::Real, ::Type{Val{0}}) + return _logabsdetjac_scale(data(a), data(x), Val{0}), Δ -> (inv(data(a)) .* Δ, nothing, nothing) +end + +# need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors +function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Type{Val{0}}) + return track(_logabsdetjac_scale, a, data(x), Val{0}) +end +@grad function _logabsdetjac_scale(a::Real, x::AbstractVector, ::Type{Val{0}}) + da = data(a) + J = fill(inv.(da), length(x)) + return _logabsdetjac_scale(da, data(x), Val{0}), Δ -> (transpose(J) * Δ, nothing, nothing) +end + +function _logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Type{Val{0}}) + return track(_logabsdetjac_scale, a, data(x), Val{0}) +end +@grad function _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Type{Val{0}}) + da = data(a) + J = fill(size(x, 1) / da, size(x, 2)) + return _logabsdetjac_scale(da, data(x), Val{0}), Δ -> (transpose(J) * Δ, nothing, nothing) +end + +# adjoints for 1-dim and 2-dim `Scale` using `AbstractVector` +function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Type{Val{1}}) + return track(_logabsdetjac_scale, a, data(x), Val{1}) +end +@grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Type{Val{1}}) + da = data(a) + J = sum(inv.(da)) + return _logabsdetjac_scale(da, data(x), Val{1}), Δ -> (transpose(J) * Δ, nothing, nothing) +end +function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Type{Val{1}}) + return track(_logabsdetjac_scale, a, data(x), Val{1}) +end + +@grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Type{Val{1}}) + da = data(a) + J = sum(inv.(da)) .* ones(size(x, 2)) + return _logabsdetjac_scale(da, data(x), Val{1}), Δ -> (transpose(J) * Δ, nothing, nothing) +end + #################### # Simplex bijector # diff --git a/test/interface.jl b/test/interface.jl index bace393e..c143a52d 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -3,6 +3,7 @@ using Bijectors using Random using LinearAlgebra using ForwardDiff +using Tracker using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector @@ -226,6 +227,54 @@ struct 
NonInvertibleBijector{AD} <: ADBijector{AD, 1} end @testset "Composition" begin @test_throws DimensionMismatch (Exp{1}() ∘ Log{0}()) end + + @testset "Batch-computation with Tracker.jl" begin + @testset "Scale" begin + # 0-dim with `Real` parameter + b = Scale(param(2.0)) + lj = logabsdetjac(b, 1.0) + Tracker.back!(lj, 1.0) + @test Tracker.extract_grad!(b.a) == 0.5 + + # 0-dim with `Real` parameter for batch-computation + lj = logabsdetjac(b, [1.0, 2.0, 3.0]) + Tracker.back!(lj, [1.0, 1.0, 1.0]) + @test Tracker.extract_grad!(b.a) == sum([0.5, 0.5, 0.5]) + + + # 1-dim with `Vector` parameter + b = Scale(param([2.0, 3.0, 5.0])) + lj = logabsdetjac(b, [3.0, 4.0, 5.0]) + Tracker.back!(lj) + @test Tracker.extract_grad!(b.a) == fill(sum(inv.(b.a)), 3) + + lj = logabsdetjac(b, [3.0 4.0 5.0; 6.0 7.0 8.0]) + Tracker.back!(lj, [1.0, 1.0, 1.0]) + @test Tracker.extract_grad!(b.a) == fill(sum(inv.(b.a)), 3) .* 3 + end + + @testset "Shift" begin + b = Shift(param(1.0)) + lj = logabsdetjac(b, 1.0) + Tracker.back!(lj, 1.0) + @test Tracker.extract_grad!(b.a) == 0.0 + + # 0-dim with `Real` parameter for batch-computation + lj = logabsdetjac(b, [1.0, 2.0, 3.0]) + Tracker.back!(lj, [1.0, 1.0, 1.0]) + @test Tracker.extract_grad!(b.a) == 0.0 + + # 1-dim with `Vector` parameter + b = Shift(param([2.0, 3.0, 5.0])) + lj = logabsdetjac(b, [3.0, 4.0, 5.0]) + Tracker.back!(lj) + @test Tracker.extract_grad!(b.a) == zeros(3) + + lj = logabsdetjac(b, [3.0 4.0 5.0; 6.0 7.0 8.0]) + Tracker.back!(lj, [1.0, 1.0, 1.0]) + @test Tracker.extract_grad!(b.a) == zeros(3) + end + end end @testset "Truncated" begin From d98da129f6938f519a9fcd65ef0d269067e40f12 Mon Sep 17 00:00:00 2001 From: "T Fjelde (RA H Ge)" Date: Sun, 29 Sep 2019 11:58:50 +0100 Subject: [PATCH 71/89] made some improvements --- src/interface.jl | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6252acd4..8078cbe0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -500,8 +500,8 @@ struct Shift{T, N} <: Bijector{N} a::T end -Shift(a::T) where {T<:Real} = Shift{T, 0}(a) -Shift(a::A) where {T, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) +Shift(a::T; dim::Type{Val{D}} = Val{0}) where {T<:Real, D} = Shift{T, D}(a) +Shift(a::A; dim::Type{Val{D}} = Val{N}) where {T, D, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) (b::Shift)(x) = b.a + x (b::Shift{<:Real})(x::AbstractArray) = b.a .+ x @@ -535,15 +535,16 @@ struct Scale{T, N} <: Bijector{N} a::T end -Scale(a::T) where {T<:Real} = Scale{T, 0}(a) -Scale(a::A) where {T, N, A<:AbstractArray{T, N}} = Scale{A, N}(a) +Scale(a::T; dim::Type{Val{D}} = Val{0}) where {T<:Real, D} = Scale{T, D}(a) +Scale(a::A; dim::Type{Val{D}} = Val{N}) where {T, D, N, A<:AbstractArray{T, N}} = Scale{A, D}(a) (b::Scale)(x) = b.a .* x (b::Scale{<:Real})(x::AbstractArray) = b.a .* x +(b::Scale{<:AbstractMatrix})(x::AbstractArray) = b.a * x (b::Scale{<:AbstractVector{<:Real}, 2})(x::AbstractMatrix{<:Real}) = b.a .* x -inv(b::Scale) = Scale(inv(b.a)) -inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) +inv(b::Scale{T, D}) where {T, D} = Scale(inv(b.a); dim = Val{D}) +inv(b::Scale{<:AbstractVector, D}) where {D} = Scale(inv.(b.a); dim = Val{D}) # We're going to implement custom adjoint for this logabsdetjac(b::Scale{T, N}, x) where {T, N} = _logabsdetjac_scale(b.a, x, Val{N}) @@ -554,6 +555,8 @@ _logabsdetjac_scale(a::Real, x::AbstractVector, ::Type{Val{1}}) = log(abs(a)) * _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Type{Val{1}}) = fill(log(abs(a)) * 
size(x, 1), size(x, 2)) _logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Type{Val{1}}) = sum(log.(abs.(a))) _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Type{Val{1}}) = fill(sum(log.(abs.(a))), size(x, 2)) +_logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Type{Val{1}}) = log(abs(det(a))) +_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Type{Val{1}}) where {T} = log(abs(det(a))) * ones(T, size(x, 2)) # adjoints for 0-dim and 1-dim `Scale` using `Real` function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Type{Val{0}}) @@ -591,16 +594,24 @@ end J = sum(inv.(da)) return _logabsdetjac_scale(da, data(x), Val{1}), Δ -> (transpose(J) * Δ, nothing, nothing) end + function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Type{Val{1}}) return track(_logabsdetjac_scale, a, data(x), Val{1}) end - @grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Type{Val{1}}) da = data(a) J = sum(inv.(da)) .* ones(size(x, 2)) return _logabsdetjac_scale(da, data(x), Val{1}), Δ -> (transpose(J) * Δ, nothing, nothing) end +# TODO: implement analytical gradient for scaling a vector using a matrix +# function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Type{Val{1}}) +# track(_logabsdetjac_scale, a, data(x), Val{1}) +# end +# @grad function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Type{Val{1}}) +# throw +# end + #################### # Simplex bijector # @@ -859,6 +870,12 @@ function logpdf(td::UnivariateTransformed, y::Real) end # TODO: implement more efficiently for flows in the case of `Matrix` +function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) + # batch-implementation for multivariate + res = forward(inv(td.transform), y) + return logpdf(td.dist, res.rv) + res.logabsdetjac +end + function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) return logpdf(td.dist, res.rv) .+ res.logabsdetjac From ff56f75393f278326b861e1e29ab36bc021140b5 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Tue, 1 Oct 2019 14:04:01 +0100 Subject: [PATCH 72/89] fixed Shift inverse and added matrix-scaling of vectors --- src/interface.jl | 51 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6252acd4..089ea9f9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -500,14 +500,14 @@ struct Shift{T, N} <: Bijector{N} a::T end -Shift(a::T) where {T<:Real} = Shift{T, 0}(a) -Shift(a::A) where {T, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) +Shift(a::T; dim::Type{Val{D}} = Val{0}) where {T<:Real, D} = Shift{T, D}(a) +Shift(a::A; dim::Type{Val{D}} = Val{N}) where {T, D, N, A<:AbstractArray{T, N}} = Shift{A, N}(a) (b::Shift)(x) = b.a + x (b::Shift{<:Real})(x::AbstractArray) = b.a .+ x (b::Shift{<:AbstractVector})(x::AbstractMatrix) = b.a .+ x -inv(b::Shift) = Shift(-b.a) +inv(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val{N}) @@ -535,15 +535,16 @@ struct Scale{T, N} <: Bijector{N} a::T end -Scale(a::T) where {T<:Real} = Scale{T, 0}(a) -Scale(a::A) where {T, N, A<:AbstractArray{T, N}} = Scale{A, N}(a) +Scale(a::T; dim::Type{Val{D}} = Val{0}) where {T<:Real, D} = Scale{T, D}(a) +Scale(a::A; dim::Type{Val{D}} = Val{N}) where {T, D, N, A<:AbstractArray{T, N}} = Scale{A, D}(a) (b::Scale)(x) = b.a .* x 
(b::Scale{<:Real})(x::AbstractArray) = b.a .* x +(b::Scale{<:AbstractMatrix})(x::AbstractArray) = b.a * x (b::Scale{<:AbstractVector{<:Real}, 2})(x::AbstractMatrix{<:Real}) = b.a .* x -inv(b::Scale) = Scale(inv(b.a)) -inv(b::Scale{<:AbstractVector}) = Scale(inv.(b.a)) +inv(b::Scale{T, D}) where {T, D} = Scale(inv(b.a); dim = Val{D}) +inv(b::Scale{<:AbstractVector, D}) where {D} = Scale(inv.(b.a); dim = Val{D}) # We're going to implement custom adjoint for this logabsdetjac(b::Scale{T, N}, x) where {T, N} = _logabsdetjac_scale(b.a, x, Val{N}) @@ -554,6 +555,8 @@ _logabsdetjac_scale(a::Real, x::AbstractVector, ::Type{Val{1}}) = log(abs(a)) * _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Type{Val{1}}) = fill(log(abs(a)) * size(x, 1), size(x, 2)) _logabsdetjac_scale(a::AbstractVector, x::AbstractVector, ::Type{Val{1}}) = sum(log.(abs.(a))) _logabsdetjac_scale(a::AbstractVector, x::AbstractMatrix, ::Type{Val{1}}) = fill(sum(log.(abs.(a))), size(x, 2)) +_logabsdetjac_scale(a::AbstractMatrix, x::AbstractVector, ::Type{Val{1}}) = log(abs(det(a))) +_logabsdetjac_scale(a::AbstractMatrix, x::AbstractMatrix{T}, ::Type{Val{1}}) where {T} = log(abs(det(a))) * ones(T, size(x, 2)) # adjoints for 0-dim and 1-dim `Scale` using `Real` function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Type{Val{0}}) @@ -591,16 +594,24 @@ end J = sum(inv.(da)) return _logabsdetjac_scale(da, data(x), Val{1}), Δ -> (transpose(J) * Δ, nothing, nothing) end + function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Type{Val{1}}) return track(_logabsdetjac_scale, a, data(x), Val{1}) end - @grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Type{Val{1}}) da = data(a) J = sum(inv.(da)) .* ones(size(x, 2)) return _logabsdetjac_scale(da, data(x), Val{1}), Δ -> (transpose(J) * Δ, nothing, nothing) end +# TODO: implement analytical gradient for scaling a vector using a matrix +# function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Type{Val{1}}) +# track(_logabsdetjac_scale, a, data(x), Val{1}) +# end +# @grad function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Type{Val{1}}) +# throw +# end + #################### # Simplex bijector # @@ -855,13 +866,19 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) .+ res.logabsdetjac + return logpdf(td.dist, res.rv) + res.logabsdetjac end # TODO: implement more efficiently for flows in the case of `Matrix` +function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) + # batch-implementation for multivariate + res = forward(inv(td.transform), y) + return logpdf(td.dist, res.rv) + res.logabsdetjac +end + function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) .+ res.logabsdetjac + return logpdf(td.dist, res.rv) + res.logabsdetjac end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -869,7 +886,7 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with @@ -916,18 +933,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. 
""" function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) .+ res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -935,7 +952,7 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) .+ res.logabsdetjac + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac return (lp, res.logabsdetjac) end @@ -955,14 +972,14 @@ the inverse transform to compute the necessary `logabsdetjac`. This is similar to `logpdf_with_trans`. """ # TODO: implement more efficiently for flows in the case of `Matrix` -logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) .- logjac +logpdf_forward(td::Transformed, x, logjac) = logpdf(td.dist, x) - logjac logpdf_forward(td::Transformed, x) = logpdf_forward(td, x, logabsdetjac(td.transform, x)) function logpdf_forward(td::MvTransformed{<:Dirichlet}, x, logjac) T = eltype(x) ϵ = _eps(T) - return logpdf(td.dist, mappedarray(z->z+ϵ, x)) .- logjac + return logpdf(td.dist, mappedarray(z->z+ϵ, x)) - logjac end From 4ceb033588562e0eab41469319a00340dc467e4b Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 17 Oct 2019 02:54:44 +0100 Subject: [PATCH 73/89] added convenient couple and coupling methods --- src/couplings.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/couplings.jl b/src/couplings.jl index 556aa196..11165e4e 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -122,6 +122,17 @@ function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B} return CouplingLayer(B, mask, cl.θ) end +coupling(cl::CouplingLayer{B}) where {B} = B +function couple(cl::CouplingLayer{B}, x::AbstractVector) where {B} + # partition vector using `cl.mask::PartitionMask` + x_1, x_2, x_3 = partition(cl.mask, x) + + # construct bijector `B` using θ(x₂) + b = B(cl.θ(x_2)) + + return b +end + function (cl::CouplingLayer{B})(x::AbstractVector) where {B} # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) From 16cc948e054df78529aca5bbd3d65479e7f3ba82 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 19 Oct 2019 12:33:49 +0100 Subject: [PATCH 74/89] added testing and a bit more docstrings --- src/couplings.jl | 68 +++++++++++++++++++++++++++++++++++++++---- test/couplings.jl | 73 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 5 deletions(-) create mode 100644 test/couplings.jl diff --git a/src/couplings.jl b/src/couplings.jl index 11165e4e..c52b5b21 100644 --- a/src/couplings.jl +++ b/src/couplings.jl @@ -8,11 +8,43 @@ This is used to partition and recombine a vector into 3 disjoint "subvectors". 
Implements - `partition(m::PartitionMask, x)`: partitions `x` into 3 disjoint "subvectors" - `combine(m::PartitionMask, x_1, x_2, x_3)`: combines 3 disjoint vectors into a single one + +Note that `PartitionMask` is _not_ a `Bijector`. It is indeed a bijection, but +does not follow the `Bijector` interface. + +Its main use is in `CouplingLayer` where we want to partition the input into 3 parts, +one part to transform, one part to map into the parameter-space of the transform applied +to the first part, and the last part of the vector is not used for anything. + +# Examples +```julia-repl +julia> m = PartitionMask(3, [1], [2]) # <= assumes input-length 3 +PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( + [1, 1] = 1.0, + [2, 1] = 1.0, + [3, 1] = 1.0) + +julia> # Partition into 3 parts; the last part is inferred to be indices `[3, ]` from + # the fact that `[1]` and `[2]` does not make up all indices in `1:3`. + x1, x2, x3 = partition(m, [1., 2., 3.]) +([1.0], [2.0], [3.0]) + +julia> # Recombines the partitions into a vector + combine(m, x1, x2, x3) +3-element Array{Float64,1}: + 1.0 + 2.0 + 3.0 +``` + """ struct PartitionMask{A} A_1::A A_2::A A_3::A + + # Only make it possible to construct using matrices + PartitionMask(A_1::A, A_2::A, A_3::A) where {A <: AbstractMatrix{<:Real}} = new{A}(A_1, A_2, A_3) end function PartitionMask( @@ -40,19 +72,25 @@ function PartitionMask( return PartitionMask(A_1, A_2, A_3) end +PartitionMask( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::AbstractVector{Int} +) = PartitionMask(n, indices_1, indices_2, nothing) + PartitionMask( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, indices_3::Nothing -) = PartitionMask(n, indices_1, indices_2, Int[]) +) = PartitionMask(n, indices_1, indices_2, [i for i in 1:n if i ∉ (indices_1 ∪ indices_2)]) PartitionMask( n::Int, indices_1::AbstractVector{Int}, indices_2::Nothing, indices_3::AbstractVector{Int} -) = PartitionMask(n, indices_1, Int[], indices_3) +) = PartitionMask(n, indices_1, [i for i in 1:n if i ∉ (indices_1 ∪ indices_3)], indices_3) """ PartitionMask(n::Int, indices) @@ -104,10 +142,29 @@ Partitions `x` into 3 disjoint subvectors. Implements a coupling-layer as defined in [1]. +# Examples +```julia-repl +julia> m = PartitionMask(3, [1], [2]) # <= going to use x[2] to parameterize transform of x[1] +PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( + [1, 1] = 1.0, + [2, 1] = 1.0, + [3, 1] = 1.0) + +julia> cl = CouplingLayer(Shift, m, identity) # <= will do `y[1:1] = x[1:1] + x[2:2]`; + +julia> x = [1., 2., 3.]; + +julia> cl(x) +3-element Array{Float64,1}: + 3.0 + 2.0 + 3.0 +``` + # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). """ -struct CouplingLayer{B, M, F} <: Bijector{1} where {B <: ZeroOrOneDimBijector, M <: PartitionMask, F} +struct CouplingLayer{B, M, F} <: Bijector{1} where {B, M <: PartitionMask, F} mask::M θ::F end @@ -122,7 +179,10 @@ function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B} return CouplingLayer(B, mask, cl.θ) end +"Returns the constructor of the coupling law." coupling(cl::CouplingLayer{B}) where {B} = B + +"Returns the coupling law constructed from `x`." 
function couple(cl::CouplingLayer{B}, x::AbstractVector) where {B} # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) @@ -143,7 +203,6 @@ function (cl::CouplingLayer{B})(x::AbstractVector) where {B} # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -# TODO: drop when we have dimensionality function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B} return hcat([cl(x[:, i]) for i = 1:size(x, 2)]...) end @@ -159,7 +218,6 @@ function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractVector) where {B} return combine(cl.mask, ib(y_1), y_2, y_3) end -# TODO: drop when we have dimensionality function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...) end diff --git a/test/couplings.jl b/test/couplings.jl new file mode 100644 index 00000000..b1fea825 --- /dev/null +++ b/test/couplings.jl @@ -0,0 +1,73 @@ +using Test +using Bijectors +using Random +using LinearAlgebra +using ForwardDiff +using Tracker +using Flux + +using Bijectors: + CouplingLayer, + PartitionMask, + coupling, + couple, + partition, + combine, + Shift + +@testset "CouplingLayer" begin + @testset "PartitionMask" begin + m1 = PartitionMask(3, [1], [2]) + m2 = PartitionMask(3, [1], [2], [3]) + + @test (m1.A_1 == m2.A_1) & (m1.A_2 == m2.A_2) & (m1.A_3 == m2.A_3) + + x = [1., 2., 3.] + x1, x2, x3 = partition(m1, x) + @test (x1 == [1.]) & (x2 == [2.]) & (x3 == [3.]) + + y = combine(m1, x1, x2, x3) + @test y == x + end + + @testset "Basics" begin + m = PartitionMask(3, [1], [2]) + cl1 = CouplingLayer(Shift, m, x -> x[1]) + + x = [1., 2., 3.] + @test cl1(x) == [3., 2., 3.] + + cl2 = CouplingLayer(θ -> Shift(θ[1]), m, identity) + @test cl2(x) == cl1(x) + + # inversion + icl1 = inv(cl1) + @test icl1(cl1(x)) == x + @test inv(cl2)(cl2(x)) == x + + # This `cl2` should result in + b = Shift(x[2:2]) + + # logabsdetjac + @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) + + # forward + @test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x)) + @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) + end + + @testset "Tracker" begin + Random.seed!(123) + x = [1., 2., 3.] 
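+        # the Flux network defined below maps the slice x[2:2] to the parameter of the
+        # `Shift` applied to x[1:1], so only the first entry is transformed and the
+        # output should carry tracked gradients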
+ + m = PartitionMask(length(x), [1], [2]) + nn = Chain(Dense(1, 2, relu), Dense(2, 1)) + cl = CouplingLayer(Shift, m, nn) + + # should leave two last indices unchanged + @test cl(x)[2:3] == x[2:3] + + # verify that indeed it's tracked + @test Tracker.istracked(cl(x)) + end +end From f0fc0e5463f97e3def0ac88426eb7e67e4a63070 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 7 Nov 2019 10:45:13 +0000 Subject: [PATCH 75/89] fixed a merge-mistake --- Project.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Project.toml b/Project.toml index a8b6394d..62aed0a5 100644 --- a/Project.toml +++ b/Project.toml @@ -11,11 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" -<<<<<<< HEAD SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -======= ->>>>>>> master StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" From c0053c88d44389ed48c4431a130a88cba93b56f8 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 7 Nov 2019 10:51:41 +0000 Subject: [PATCH 76/89] added CouplingLayer tests to the interface-tests --- test/couplings.jl | 80 +++++++++++++++++++++++------------------------ test/interface.jl | 1 + 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/test/couplings.jl b/test/couplings.jl index b1fea825..8029243b 100644 --- a/test/couplings.jl +++ b/test/couplings.jl @@ -15,59 +15,57 @@ using Bijectors: combine, Shift -@testset "CouplingLayer" begin - @testset "PartitionMask" begin - m1 = PartitionMask(3, [1], [2]) - m2 = PartitionMask(3, [1], [2], [3]) +@testset "PartitionMask" begin + m1 = PartitionMask(3, [1], [2]) + m2 = PartitionMask(3, [1], [2], [3]) - @test (m1.A_1 == m2.A_1) & (m1.A_2 == m2.A_2) & (m1.A_3 == m2.A_3) + @test (m1.A_1 == m2.A_1) & (m1.A_2 == m2.A_2) & (m1.A_3 == m2.A_3) - x = [1., 2., 3.] - x1, x2, x3 = partition(m1, x) - @test (x1 == [1.]) & (x2 == [2.]) & (x3 == [3.]) + x = [1., 2., 3.] + x1, x2, x3 = partition(m1, x) + @test (x1 == [1.]) & (x2 == [2.]) & (x3 == [3.]) - y = combine(m1, x1, x2, x3) - @test y == x - end + y = combine(m1, x1, x2, x3) + @test y == x +end - @testset "Basics" begin - m = PartitionMask(3, [1], [2]) - cl1 = CouplingLayer(Shift, m, x -> x[1]) +@testset "Basics" begin + m = PartitionMask(3, [1], [2]) + cl1 = CouplingLayer(Shift, m, x -> x[1]) - x = [1., 2., 3.] - @test cl1(x) == [3., 2., 3.] + x = [1., 2., 3.] + @test cl1(x) == [3., 2., 3.] 
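For readers following the test above: the expected value `[3., 2., 3.]` falls straight out of the partition/couple/recombine steps that `CouplingLayer` performs internally. A minimal sketch of those steps, assuming only the `PartitionMask`, `partition`, `combine`, and `Shift` helpers already introduced in this series (imported exactly as the test file does):

```julia
using Bijectors: PartitionMask, partition, combine, Shift

m = PartitionMask(3, [1], [2])
x = [1., 2., 3.]
x1, x2, x3 = partition(m, x)    # ([1.0], [2.0], [3.0])
b = Shift((x -> x[1])(x2))      # coupling law built from x2, here Shift(2.0)
combine(m, b(x1), x2, x3)       # [3.0, 2.0, 3.0]; only x[1] is shifted
```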
- cl2 = CouplingLayer(θ -> Shift(θ[1]), m, identity) - @test cl2(x) == cl1(x) + cl2 = CouplingLayer(θ -> Shift(θ[1]), m, identity) + @test cl2(x) == cl1(x) - # inversion - icl1 = inv(cl1) - @test icl1(cl1(x)) == x - @test inv(cl2)(cl2(x)) == x + # inversion + icl1 = inv(cl1) + @test icl1(cl1(x)) == x + @test inv(cl2)(cl2(x)) == x - # This `cl2` should result in - b = Shift(x[2:2]) + # This `cl2` should result in + b = Shift(x[2:2]) - # logabsdetjac - @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) + # logabsdetjac + @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) - # forward - @test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x)) - @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) - end + # forward + @test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x)) + @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) +end - @testset "Tracker" begin - Random.seed!(123) - x = [1., 2., 3.] +@testset "Tracker" begin + Random.seed!(123) + x = [1., 2., 3.] - m = PartitionMask(length(x), [1], [2]) - nn = Chain(Dense(1, 2, relu), Dense(2, 1)) - cl = CouplingLayer(Shift, m, nn) + m = PartitionMask(length(x), [1], [2]) + nn = Chain(Dense(1, 2, relu), Dense(2, 1)) + cl = CouplingLayer(Shift, m, nn) - # should leave two last indices unchanged - @test cl(x)[2:3] == x[2:3] + # should leave two last indices unchanged + @test cl(x)[2:3] == x[2:3] - # verify that indeed it's tracked - @test Tracker.istracked(cl(x)) - end + # verify that indeed it's tracked + @test Tracker.istracked(cl(x)) end diff --git a/test/interface.jl b/test/interface.jl index 4ef08412..58ef5726 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -623,3 +623,4 @@ end @test 0 ≤ x ≤ 1 end +include("couplings.jl") From d0071a3daf035a473031e75727d76f71878ecbe6 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Thu, 23 Jan 2020 19:25:41 +0000 Subject: [PATCH 77/89] fixed tests with CouplingLayer and Tracker.jl --- Project.toml | 3 ++- test/couplings.jl | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 274ea700..89db5704 100644 --- a/Project.toml +++ b/Project.toml @@ -26,9 +26,10 @@ Tracker = "0.2.3" julia = "1" [extras] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["ForwardDiff", "Test", "Tracker"] +test = ["ForwardDiff", "Test", "Tracker", "Flux"] diff --git a/test/couplings.jl b/test/couplings.jl index 8029243b..2da3caf1 100644 --- a/test/couplings.jl +++ b/test/couplings.jl @@ -60,8 +60,9 @@ end x = [1., 2., 3.] m = PartitionMask(length(x), [1], [2]) - nn = Chain(Dense(1, 2, relu), Dense(2, 1)) - cl = CouplingLayer(Shift, m, nn) + nn = Chain(Dense(1, 2, sigmoid), Dense(2, 1)) + nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? 
Tracker.param(x) : x, nn) + cl = CouplingLayer(Shift, m, nn_tracked) # should leave two last indices unchanged @test cl(x)[2:3] == x[2:3] From 733393a2e1fdf056f8d7c52f737854ea4173ca5f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 16 Feb 2020 20:15:45 +0000 Subject: [PATCH 78/89] now using Bool for the sparse arrays in PartitionMask --- src/bijectors/coupling_layer.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling_layer.jl index c52b5b21..d98a36f7 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling_layer.jl @@ -53,9 +53,9 @@ function PartitionMask( indices_2::AbstractVector{Int}, indices_3::AbstractVector{Int} ) - A_1 = spzeros(n, length(indices_1)); - A_2 = spzeros(n, length(indices_2)); - A_3 = spzeros(n, length(indices_3)); + A_1 = spzeros(Bool, n, length(indices_1)); + A_2 = spzeros(Bool, n, length(indices_2)); + A_3 = spzeros(Bool, n, length(indices_3)); for (i, idx) in enumerate(indices_1) A_1[idx, i] = 1.0 @@ -102,8 +102,8 @@ function PartitionMask(n::Int, indices) indices_2 = [i for i in 1:n if i ∉ indices] # sparse arrays <3 - A_1 = spzeros(n, length(indices)); - A_2 = spzeros(n, length(indices_2)); + A_1 = spzeros(Bool, n, length(indices)); + A_2 = spzeros(Bool, n, length(indices_2)); # Like doing: # A[1, 1] = 1.0 @@ -116,7 +116,7 @@ function PartitionMask(n::Int, indices) A_2[idx, i] = 1.0 end - return PartitionMask(A_1, A_2, spzeros(n, 0)) + return PartitionMask(A_1, A_2, spzeros(Bool, n, 0)) end PartitionMask(x::AbstractVector, indices) = PartitionMask(length(x), indices) From cd487661693ab7e8248ccfa7b491af0cf78d9877 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 16 Feb 2020 20:16:39 +0000 Subject: [PATCH 79/89] PartitionMask construction now using true instead of 1.0 --- src/bijectors/coupling_layer.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling_layer.jl index d98a36f7..93087330 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling_layer.jl @@ -58,15 +58,15 @@ function PartitionMask( A_3 = spzeros(Bool, n, length(indices_3)); for (i, idx) in enumerate(indices_1) - A_1[idx, i] = 1.0 + A_1[idx, i] = true end for (i, idx) in enumerate(indices_2) - A_2[idx, i] = 1.0 + A_2[idx, i] = true end for (i, idx) in enumerate(indices_3) - A_3[idx, i] = 1.0 + A_3[idx, i] = true end return PartitionMask(A_1, A_2, A_3) @@ -109,11 +109,11 @@ function PartitionMask(n::Int, indices) # A[1, 1] = 1.0 # A[3, 2] = 1.0 for (i, idx) in enumerate(indices) - A_1[idx, i] = 1.0 + A_1[idx, i] = true end for (i, idx) in enumerate(indices_2) - A_2[idx, i] = 1.0 + A_2[idx, i] = true end return PartitionMask(A_1, A_2, spzeros(Bool, n, 0)) From df2573f8447ce591c087075e7428662e08b6d6ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Feb 2020 00:54:21 +0000 Subject: [PATCH 80/89] fixed some tests I broke --- test/bijectors/couplings.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/couplings.jl b/test/bijectors/couplings.jl index 53a36289..19ca7fe4 100644 --- a/test/bijectors/couplings.jl +++ b/test/bijectors/couplings.jl @@ -61,7 +61,7 @@ using Bijectors: x = [1., 2., 3.] m = PartitionMask(length(x), [1], [2]) - nn = Chain(Dense(1, 2, sigmoid), Dense(2, 1)) + nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1)) nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? 
Tracker.param(x) : x, nn) cl = CouplingLayer(Shift, m, nn_tracked) From 6bfcab48d42b2d28bdc950144cd9b75f124dc5d6 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Mon, 18 May 2020 03:05:38 +0200 Subject: [PATCH 81/89] added convenient default constructor to CouplingLayer --- src/bijectors/coupling_layer.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling_layer.jl index 93087330..cf0d5a1c 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling_layer.jl @@ -169,6 +169,7 @@ struct CouplingLayer{B, M, F} <: Bijector{1} where {B, M <: PartitionMask, F} θ::F end +CouplingLayer(B, mask::M) where {M} = CouplingLayer{B, M, typeof(identity)}(mask, identity) CouplingLayer(B, mask::M, θ::F) where {M, F} = CouplingLayer{B, M, F}(mask, θ) function CouplingLayer(B, θ, n::Int) idx = Int(floor(n / 2)) @@ -208,7 +209,7 @@ function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B} end -function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractVector) where {B} +function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractVector) where {B} cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) @@ -218,7 +219,7 @@ function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractVector) where {B} return combine(cl.mask, ib(y_1), y_2, y_3) end -function (icl::Inversed{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} +function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...) end From e9f27f566b0c70334cd42ff875765144953b9a74 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Mon, 18 May 2020 03:56:14 +0200 Subject: [PATCH 82/89] added to docstring for CouplingLayer --- src/bijectors/coupling_layer.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling_layer.jl index cf0d5a1c..5904d506 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling_layer.jl @@ -159,6 +159,18 @@ julia> cl(x) 3.0 2.0 3.0 + +julia> inv(cl)(cl(x)) +3-element Array{Float64,1}: + 1.0 + 2.0 + 3.0 + +julia> coupling(cl) # get the `Bijector` map `θ -> b(⋅, θ)` +Shift + +julia> couple(cl, x) # get the `Bijector` resulting from `x` +Shift{Array{Float64,1},1}([2.0]) ``` # References From efc42bb743eb9e6db97ebf4bc1e44520a5e31901 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 12 Sep 2020 18:42:11 +0200 Subject: [PATCH 83/89] renamed CouplingLayer to Coupling and combined the two functions --- src/bijectors/coupling_layer.jl | 46 ++++++++++++++++----------------- test/bijectors/couplings.jl | 10 +++---- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling_layer.jl index 5904d506..68400f02 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling_layer.jl @@ -12,7 +12,7 @@ Implements Note that `PartitionMask` is _not_ a `Bijector`. It is indeed a bijection, but does not follow the `Bijector` interface. -Its main use is in `CouplingLayer` where we want to partition the input into 3 parts, +Its main use is in `Coupling` where we want to partition the input into 3 parts, one part to transform, one part to map into the parameter-space of the transform applied to the first part, and the last part of the vector is not used for anything. @@ -135,10 +135,10 @@ Partitions `x` into 3 disjoint subvectors. 
@inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) -# CouplingLayer +# Coupling """ - CouplingLayer{B, M, F}(mask::M, θ::F) + Coupling{F, M}(θ::F, mask::M) Implements a coupling-layer as defined in [1]. @@ -150,7 +150,7 @@ PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( [2, 1] = 1.0, [3, 1] = 1.0) -julia> cl = CouplingLayer(Shift, m, identity) # <= will do `y[1:1] = x[1:1] + x[2:2]`; +julia> cl = Coupling(Shift, m, identity) # <= will do `y[1:1] = x[1:1] + x[2:2]`; julia> x = [1., 2., 3.]; @@ -176,75 +176,73 @@ Shift{Array{Float64,1},1}([2.0]) # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). """ -struct CouplingLayer{B, M, F} <: Bijector{1} where {B, M <: PartitionMask, F} - mask::M +struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} θ::F + mask::M end -CouplingLayer(B, mask::M) where {M} = CouplingLayer{B, M, typeof(identity)}(mask, identity) -CouplingLayer(B, mask::M, θ::F) where {M, F} = CouplingLayer{B, M, F}(mask, θ) -function CouplingLayer(B, θ, n::Int) +function Coupling(θ, n::Int) idx = Int(floor(n / 2)) - return CouplingLayer(B, PartitionMask(n, 1:idx), θ) + return Coupling(θ, PartitionMask(n, 1:idx)) end -function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B} - return CouplingLayer(B, mask, cl.θ) +function Coupling(cl::Coupling{B}, mask::PartitionMask) where {B} + return Coupling(cl.θ, mask) end "Returns the constructor of the coupling law." -coupling(cl::CouplingLayer{B}) where {B} = B +coupling(cl::Coupling) = cl.θ "Returns the coupling law constructed from `x`." -function couple(cl::CouplingLayer{B}, x::AbstractVector) where {B} +function couple(cl::Coupling, x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) # construct bijector `B` using θ(x₂) - b = B(cl.θ(x_2)) + b = cl.θ(x_2) return b end -function (cl::CouplingLayer{B})(x::AbstractVector) where {B} +function (cl::Coupling)(x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) # construct bijector `B` using θ(x₂) - b = B(cl.θ(x_2)) + b = cl.θ(x_2) # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B} +function (cl::Coupling)(x::AbstractMatrix) return hcat([cl(x[:, i]) for i = 1:size(x, 2)]...) end -function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractVector) where {B} +function (icl::Inverse{<:Coupling})(y::AbstractVector) cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) - b = B(cl.θ(y_2)) + b = cl.θ(y_2) ib = inv(b) return combine(cl.mask, ib(y_1), y_2, y_3) end -function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} +function (icl::Inverse{<:Coupling})(y::AbstractMatrix) return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...) 
end -function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B} +function logabsdetjac(cl::Coupling, x::AbstractVector) x_1, x_2, x_3 = partition(cl.mask, x) - b = B(cl.θ(x_2)) + b = cl.θ(x_2) # `B` might be 0-dim in which case it will treat `x_1` as a batch # therefore we sum to ensure such a thing does not happen return sum(logabsdetjac(b, x_1)) end -function logabsdetjac(cl::CouplingLayer{B}, x::AbstractMatrix) where {B} +function logabsdetjac(cl::Coupling, x::AbstractMatrix) r = [logabsdetjac(cl, x[:, i]) for i = 1:size(x, 2)] # FIXME: this really needs to be handled in a better way diff --git a/test/bijectors/couplings.jl b/test/bijectors/couplings.jl index 19ca7fe4..49356c26 100644 --- a/test/bijectors/couplings.jl +++ b/test/bijectors/couplings.jl @@ -7,7 +7,7 @@ using Tracker import Flux using Bijectors: - CouplingLayer, + Coupling, PartitionMask, coupling, couple, @@ -15,7 +15,7 @@ using Bijectors: combine, Shift -@testset "CouplingLayer" begin +@testset "Coupling" begin @testset "PartitionMask" begin m1 = PartitionMask(3, [1], [2]) m2 = PartitionMask(3, [1], [2], [3]) @@ -32,12 +32,12 @@ using Bijectors: @testset "Basics" begin m = PartitionMask(3, [1], [2]) - cl1 = CouplingLayer(Shift, m, x -> x[1]) + cl1 = Coupling(x -> Shift(x[1]), m) x = [1., 2., 3.] @test cl1(x) == [3., 2., 3.] - cl2 = CouplingLayer(θ -> Shift(θ[1]), m, identity) + cl2 = Coupling(θ -> Shift(θ[1]), m) @test cl2(x) == cl1(x) # inversion @@ -63,7 +63,7 @@ using Bijectors: m = PartitionMask(length(x), [1], [2]) nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1)) nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? Tracker.param(x) : x, nn) - cl = CouplingLayer(Shift, m, nn_tracked) + cl = Coupling(θ -> Shift(nn_tracked(θ)), m) # should leave two last indices unchanged @test cl(x)[2:3] == x[2:3] From 6e002f536ed113f5f9be1ec3f656c0db8b354481 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 12 Sep 2020 21:19:09 +0200 Subject: [PATCH 84/89] improved testing for coupling and small changes to batch impls --- .../{coupling_layer.jl => coupling.jl} | 4 +- test/ad/utils.jl | 8 +- test/bijectors/{couplings.jl => coupling.jl} | 49 +++-- test/bijectors/utils.jl | 194 ++++++++++++++++++ test/runtests.jl | 6 +- 5 files changed, 234 insertions(+), 27 deletions(-) rename src/bijectors/{coupling_layer.jl => coupling.jl} (98%) rename test/bijectors/{couplings.jl => coupling.jl} (54%) create mode 100644 test/bijectors/utils.jl diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling.jl similarity index 98% rename from src/bijectors/coupling_layer.jl rename to src/bijectors/coupling.jl index 68400f02..c112c41e 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling.jl @@ -215,7 +215,7 @@ function (cl::Coupling)(x::AbstractVector) return combine(cl.mask, b(x_1), x_2, x_3) end function (cl::Coupling)(x::AbstractMatrix) - return hcat([cl(x[:, i]) for i = 1:size(x, 2)]...) + return eachcolmaphcat(cl, x) end @@ -230,7 +230,7 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector) return combine(cl.mask, ib(y_1), y_2, y_3) end function (icl::Inverse{<:Coupling})(y::AbstractMatrix) - return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...) + return eachcolmaphcat(icl, y) end function logabsdetjac(cl::Coupling, x::AbstractVector) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index c95b33d5..117d916c 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -140,10 +140,10 @@ function test_ad(dist::DistSpec; kwargs...) 
end end -function test_ad(f, x; rtol = 1e-6, atol = 1e-6) +function test_ad(f, x; rtol = 1e-6, atol = 1e-6, ad = AD) finitediff = FiniteDiff.finite_difference_gradient(f, x) - if AD == "All" || AD == "ForwardDiff_Tracker" + if ad == "All" || ad == "ForwardDiff_Tracker" tracker = Tracker.data(Tracker.gradient(f, x)[1]) @test tracker ≈ finitediff rtol=rtol atol=atol @@ -151,12 +151,12 @@ function test_ad(f, x; rtol = 1e-6, atol = 1e-6) @test forward ≈ finitediff rtol=rtol atol=atol end - if AD == "All" || AD == "Zygote" + if ad == "All" || ad == "Zygote" zygote = Zygote.gradient(f, x)[1] @test zygote ≈ finitediff rtol=rtol atol=atol end - if AD == "All" || AD == "ReverseDiff" + if ad == "All" || ad == "ReverseDiff" reversediff = ReverseDiff.gradient(f, x) @test reversediff ≈ finitediff rtol=rtol atol=atol end diff --git a/test/bijectors/couplings.jl b/test/bijectors/coupling.jl similarity index 54% rename from test/bijectors/couplings.jl rename to test/bijectors/coupling.jl index 49356c26..948002a4 100644 --- a/test/bijectors/couplings.jl +++ b/test/bijectors/coupling.jl @@ -1,11 +1,3 @@ -using Test -using Bijectors -using Random -using LinearAlgebra -using ForwardDiff -using Tracker -import Flux - using Bijectors: Coupling, PartitionMask, @@ -13,7 +5,8 @@ using Bijectors: couple, partition, combine, - Shift + Shift, + Scale @testset "Coupling" begin @testset "PartitionMask" begin @@ -56,19 +49,35 @@ using Bijectors: @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) end - @testset "Tracker" begin - Random.seed!(123) - x = [1., 2., 3.] + # @testset "Tracker" begin + # Random.seed!(123) + # x = [1., 2., 3.] + + # m = PartitionMask(length(x), [1], [2]) + # nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1)) + # nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? Tracker.param(x) : x, nn) + # cl = Coupling(θ -> Shift(nn_tracked(θ)), m) - m = PartitionMask(length(x), [1], [2]) - nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1)) - nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? Tracker.param(x) : x, nn) - cl = Coupling(θ -> Shift(nn_tracked(θ)), m) + # # should leave two last indices unchanged + # @test cl(x)[2:3] == x[2:3] - # should leave two last indices unchanged - @test cl(x)[2:3] == x[2:3] + # # verify that indeed it's tracked + # @test Tracker.istracked(cl(x)) + # end + + @testset "Classic" begin + m = PartitionMask(3, [1], [2]) - # verify that indeed it's tracked - @test Tracker.istracked(cl(x)) + # With `Scale` + cl = Coupling(x -> Scale(x[1]), m) + x = hcat([-1., -2., -3.], [1., 2., 3.]) + y = hcat([2., -2., -3.], [2., 2., 3.]) + test_bijector(cl, x, y, log.([2., 2.])) + + # With `Shift` + cl = Coupling(x -> Shift(x[1]), m) + x = hcat([-1., -2., -3.], [1., 2., 3.]) + y = hcat([-3., -2., -3.], [3., 2., 3.]) + test_bijector(cl, x, y, zeros(2)) end end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl new file mode 100644 index 00000000..67c35ca0 --- /dev/null +++ b/test/bijectors/utils.jl @@ -0,0 +1,194 @@ +function test_bijector_reals( + b::Bijector{0}, + x_true::Real, + y_true::Real, + logjac_true::Real; + isequal = true, + tol = 1e-6 +) + ib = @inferred inv(b) + y = @inferred b(x_true) + logjac = @inferred logabsdetjac(b, x_true) + ilogjac = @inferred logabsdetjac(ib, y_true) + res = @inferred forward(b, x_true) + + # If `isequal` is false, then we use the computed `y`, + # but if it's true, we use the true `y`. + ires = isequal ? 
@inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) + + # Always want the following to hold + @test ires.rv ≈ x_true atol=tol + @test ires.logabsdetjac ≈ -logjac atol=tol + + if isequal + @test y ≈ y_true atol=tol # forward + @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse + @test logjac ≈ logjac_true # logjac forward + @test res.rv ≈ y_true atol=tol # forward using `forward` + @test res.logabsdetjac ≈ logjac_true atol=tol # logjac using `forward` + else + @test y ≠ y_true # forward + @test (@inferred ib(y)) ≈ x_true atol=tol # inverse + @test logjac ≠ logjac_true # logjac forward + @test res.rv ≠ y_true # forward using `forward` + @test res.logabsdetjac ≠ logjac_true # logjac using `forward` + end +end + +function test_bijector_arrays( + b::Bijector, + xs_true::AbstractArray{<:Real}, + ys_true::AbstractArray{<:Real}, + logjacs_true::Union{Real, AbstractArray{<:Real}}; + isequal = true, + tol = 1e-6 +) + ib = @inferred inv(b) + ys = @inferred b(xs_true) + logjacs = @inferred logabsdetjac(b, xs_true) + res = @inferred forward(b, xs_true) + # If `isequal` is false, then we use the computed `y`, + # but if it's true, we use the true `y`. + ires = isequal ? @inferred(forward(inv(b), ys_true)) : @inferred(forward(inv(b), ys)) + + # always want the following to hold + @test ys isa typeof(ys_true) + @test logjacs isa typeof(logjacs_true) + @test mean(abs, ires.rv - xs_true) ≤ tol + @test mean(abs, ires.logabsdetjac + logjacs) ≤ tol + + if isequal + @test mean(abs, ys - ys_true) ≤ tol # forward + @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse + @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward + @test mean(abs, res.rv - ys_true) ≤ tol # forward using `forward` + @test mean(abs, res.logabsdetjac - logjacs_true) ≤ tol # logjac `forward` + @test mean(abs, ires.logabsdetjac + logjacs_true) ≤ tol # inverse logjac `forward` + else + # Don't want the following to be equal to their "true" values + @test mean(abs, ys - ys_true) > tol # forward + @test mean(abs, logjacs - logjacs_true) > tol # logjac forward + @test mean(abs, res.rv - ys_true) > tol # forward using `forward` + + # Still want the following to be equal to the COMPUTED values + @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse + @test mean(abs, res.logabsdetjac - logjacs) ≤ tol # logjac forward using `forward` + end +end + +""" + test_bijector(b::Bijector, xs::Array; kwargs...) + test_bijector(b::Bijector, xs::Array, ys::Array, logjacs::Array; kwargs...) + +Tests the bijector `b` on the inputs `xs` against the, optionally, provided `ys` +and `logjacs`. + +If `ys` and `logjacs` are NOT provided, `isequal` will be set to `false` and +`ys` and `logjacs` will be set to `zeros`. These `ys` and `logjacs` will be +treated as "counter-examples", i.e. values NOT to match. + +# Arguments +- `b::Bijector`: the bijector to test +- `xs`: inputs (has to be several!!!)(has to be several, i.e. a batch!!!) to test +- `ys`: outputs (has to be several, i.e. a batch!!!) to test against +- `logjacs`: `logabsdetjac` outputs (has to be several!!!)(has to be several, i.e. + a batch!!!) to test against + +# Keywords +- `isequal = true`: if `false`, it will be assumed that the given values are + provided as "counter-examples" in the sense that the inputs `xs` should NOT map + to the given outputs. This is useful in cases where one might not know the expected + output, but still wants to test that the evaluation, etc. works. + This is set to `true` by default if `ys` and `logjacs` are not provided. 
+- `tol = 1e-6`: the absolute tolerance used for the checks. This is also used to check + arrays where we check that the L1-norm is sufficiently small. +""" +function test_bijector(b::Bijector{0}, xs::AbstractVector{<:Real}) + return test_bijector(b, xs, zeros(length(xs)), zeros(length(xs)); isequal = false) +end + +function test_bijector(b::Bijector{1}, xs::AbstractMatrix{<:Real}) + return test_bijector(b, xs, zeros(size(xs)), zeros(size(xs, 2)); isequal = false) +end + +function test_bijector( + b::Bijector{0}, + xs_true::AbstractVector{<:Real}, + ys_true::AbstractVector{<:Real}, + logjacs_true::AbstractVector{<:Real}; + kwargs... +) + ib = inv(b) + + # Batch + test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + + # Test `logabsdetjac` against jacobians + test_logabsdetjac(b, xs_true) + test_logabsdetjac(b, ys_true) + + for (x_true, y_true, logjac_true) in zip(xs_true, ys_true, logjacs_true) + test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...) + + # Test AD + if isclosedform(b) + test_ad(x -> b(first(x)), [x_true, ]) + end + + if isclosedform(ib) + y = b(x_true) + test_ad(x -> ib(first(x)), [y, ]) + end + + test_ad(x -> logabsdetjac(b, first(x)), [x_true, ]) + end +end + + +function test_bijector( + b::Bijector{1}, + xs_true::AbstractMatrix{<:Real}, + ys_true::AbstractMatrix{<:Real}, + logjacs_true::AbstractVector{<:Real}; + kwargs... +) + ib = inv(b) + + # Batch + test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + + # Test `logabsdetjac` against jacobians + test_logabsdetjac(b, xs_true) + test_logabsdetjac(b, ys_true) + + for (x_true, y_true, logjac_true) in zip(eachcol(xs_true), eachcol(ys_true), logjacs_true) + # HACK: collect to avoid dealing with sub-arrays and thus allowing us to compare the + # type of the computed output to the "true" output. + test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...) + + # Test AD + if isclosedform(b) + test_ad(x -> sum(b(x)), collect(x_true)) + end + if isclosedform(ib) + y = b(x_true) + test_ad(x -> sum(ib(x)), y) + end + + test_ad(x -> logabsdetjac(b, x), x_true) + end +end + +function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6) + if isclosedform(b) + logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)] + @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol + end +end + +function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6) + if isclosedform(b) + logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs] + @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d601d491..3d6acba9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,14 +19,18 @@ using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, const is_TRAVIS = haskey(ENV, "TRAVIS") const GROUP = get(ENV, "GROUP", "All") +# Always include this since it can be useful for other tests. 
+include("ad/utils.jl") +include("bijectors/utils.jl") + if GROUP == "All" || GROUP == "Interface" include("interface.jl") include("transform.jl") include("norm_flows.jl") include("bijectors/permute.jl") + include("bijectors/coupling.jl") end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") - include("ad/utils.jl") include("ad/distributions.jl") end From 600fcda672739a1ddf01ec31d3ddf805c7f04495 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Sat, 12 Sep 2020 21:25:08 +0200 Subject: [PATCH 85/89] forgot to include the renamed file --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index d3f6ee32..a1a8d462 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -155,7 +155,7 @@ include("bijectors/truncated.jl") # Normalizing flow related include("bijectors/planar_layer.jl") include("bijectors/radial_layer.jl") -include("bijectors/coupling_layer.jl") +include("bijectors/coupling.jl") include("bijectors/normalise.jl") ################## From b50c27bb4bbca55a40179aea9a1affc2b1b78b28 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 15 Sep 2020 12:42:48 +0100 Subject: [PATCH 86/89] fixed doc for PartitionMask and added keyword to specify type --- src/bijectors/coupling.jl | 82 ++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index c112c41e..73239abd 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -18,11 +18,13 @@ to the first part, and the last part of the vector is not used for anything. # Examples ```julia-repl +julia> using Bijectors: PartitionMask, partition, combine + julia> m = PartitionMask(3, [1], [2]) # <= assumes input-length 3 -PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( - [1, 1] = 1.0, - [2, 1] = 1.0, - [3, 1] = 1.0) +PartitionMask{SparseArrays.SparseMatrixCSC{Bool,Int64}}( + [1, 1] = true, + [2, 1] = true, + [3, 1] = true) julia> # Partition into 3 parts; the last part is inferred to be indices `[3, ]` from # the fact that `[1]` and `[2]` does not make up all indices in `1:3`. @@ -36,7 +38,15 @@ julia> # Recombines the partitions into a vector 2.0 3.0 ``` - +Note that the underlying `SparseMatrix` is using `Bool` as the element type. 
We can also +specify this to be some other type using the `sp_type` keyword: +```julia-repl +julia> m = PartitionMask(3, [1], [2]; sp_type = Float32) +PartitionMask{SparseArrays.SparseMatrixCSC{Float32,Int64}}( + [1, 1] = 1.0, + [2, 1] = 1.0, + [3, 1] = 1.0) +``` """ struct PartitionMask{A} A_1::A @@ -51,22 +61,23 @@ function PartitionMask( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, - indices_3::AbstractVector{Int} + indices_3::AbstractVector{Int}; + sp_type = Bool ) - A_1 = spzeros(Bool, n, length(indices_1)); - A_2 = spzeros(Bool, n, length(indices_2)); - A_3 = spzeros(Bool, n, length(indices_3)); + A_1 = spzeros(sp_type, n, length(indices_1)); + A_2 = spzeros(sp_type, n, length(indices_2)); + A_3 = spzeros(sp_type, n, length(indices_3)); for (i, idx) in enumerate(indices_1) - A_1[idx, i] = true + A_1[idx, i] = one(sp_type) end for (i, idx) in enumerate(indices_2) - A_2[idx, i] = true + A_2[idx, i] = one(sp_type) end for (i, idx) in enumerate(indices_3) - A_3[idx, i] = true + A_3[idx, i] = one(sp_type) end return PartitionMask(A_1, A_2, A_3) @@ -75,22 +86,31 @@ end PartitionMask( n::Int, indices_1::AbstractVector{Int}, - indices_2::AbstractVector{Int} -) = PartitionMask(n, indices_1, indices_2, nothing) + indices_2::AbstractVector{Int}; + kwargs... +) = PartitionMask(n, indices_1, indices_2, nothing; kwargs...) PartitionMask( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, - indices_3::Nothing -) = PartitionMask(n, indices_1, indices_2, [i for i in 1:n if i ∉ (indices_1 ∪ indices_2)]) + indices_3::Nothing; + kwargs... +) = PartitionMask( + n, indices_1, indices_2, [i for i in 1:n if i ∉ (indices_1 ∪ indices_2)]; + kwargs... +) PartitionMask( n::Int, indices_1::AbstractVector{Int}, indices_2::Nothing, - indices_3::AbstractVector{Int} -) = PartitionMask(n, indices_1, [i for i in 1:n if i ∉ (indices_1 ∪ indices_3)], indices_3) + indices_3::AbstractVector{Int}; + kwargs... +) = PartitionMask( + n, indices_1, [i for i in 1:n if i ∉ (indices_1 ∪ indices_3)], indices_3; + kwargs... +) """ PartitionMask(n::Int, indices) @@ -98,27 +118,29 @@ PartitionMask( Assumes you want to _split_ the vector, where `indices` refer to the parts of the vector you want to apply the bijector to. """ -function PartitionMask(n::Int, indices) +function PartitionMask(n::Int, indices; sp_type = Bool) indices_2 = [i for i in 1:n if i ∉ indices] # sparse arrays <3 - A_1 = spzeros(Bool, n, length(indices)); - A_2 = spzeros(Bool, n, length(indices_2)); + A_1 = spzeros(sp_type, n, length(indices)); + A_2 = spzeros(sp_type, n, length(indices_2)); # Like doing: # A[1, 1] = 1.0 # A[3, 2] = 1.0 for (i, idx) in enumerate(indices) - A_1[idx, i] = true + A_1[idx, i] = one(sp_type) end for (i, idx) in enumerate(indices_2) - A_2[idx, i] = true + A_2[idx, i] = one(sp_type) end - return PartitionMask(A_1, A_2, spzeros(Bool, n, 0)) + return PartitionMask(A_1, A_2, spzeros(sp_type, n, 0)) +end +function PartitionMask(x::AbstractVector, indices; kwargs...) + return PartitionMask(length(x), indices; kwargs...) 
end -PartitionMask(x::AbstractVector, indices) = PartitionMask(length(x), indices) """ combine(m::PartitionMask, x_1, x_2, x_3) @@ -150,7 +172,7 @@ PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( [2, 1] = 1.0, [3, 1] = 1.0) -julia> cl = Coupling(Shift, m, identity) # <= will do `y[1:1] = x[1:1] + x[2:2]`; +julia> cl = Coupling(θ -> Shift(θ[1]), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; julia> x = [1., 2., 3.]; @@ -214,9 +236,7 @@ function (cl::Coupling)(x::AbstractVector) # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -function (cl::Coupling)(x::AbstractMatrix) - return eachcolmaphcat(cl, x) -end +(cl::Coupling)(x::AbstractMatrix) = eachcolmaphcat(cl, x) function (icl::Inverse{<:Coupling})(y::AbstractVector) @@ -229,9 +249,7 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector) return combine(cl.mask, ib(y_1), y_2, y_3) end -function (icl::Inverse{<:Coupling})(y::AbstractMatrix) - return eachcolmaphcat(icl, y) -end +(icl::Inverse{<:Coupling})(y::AbstractMatrix) = eachcolmaphcat(icl, y) function logabsdetjac(cl::Coupling, x::AbstractVector) x_1, x_2, x_3 = partition(cl.mask, x) From 20de4c9889fe79c7ed3a89d4f9e6ef56f846c8bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 15 Sep 2020 13:21:43 +0100 Subject: [PATCH 87/89] make use of setdiff Co-authored-by: David Widmann --- src/bijectors/coupling.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 73239abd..3741787b 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -97,7 +97,7 @@ PartitionMask( indices_3::Nothing; kwargs... ) = PartitionMask( - n, indices_1, indices_2, [i for i in 1:n if i ∉ (indices_1 ∪ indices_2)]; + n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2); kwargs... ) @@ -108,7 +108,7 @@ PartitionMask( indices_3::AbstractVector{Int}; kwargs... ) = PartitionMask( - n, indices_1, [i for i in 1:n if i ∉ (indices_1 ∪ indices_3)], indices_3; + n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3; kwargs... ) @@ -119,7 +119,7 @@ Assumes you want to _split_ the vector, where `indices` refer to the parts of the vector you want to apply the bijector to. """ function PartitionMask(n::Int, indices; sp_type = Bool) - indices_2 = [i for i in 1:n if i ∉ indices] + indices_2 = setdiff(1:n, indices) # sparse arrays <3 A_1 = spzeros(sp_type, n, length(indices)); From 74451e89162ff8a95ce90debe7ce645ed3600d9d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 15 Sep 2020 14:19:41 +0100 Subject: [PATCH 88/89] made sp_type a type-parameter instead of kwarg --- src/bijectors/coupling.jl | 66 +++++++++++++++++--------------------- test/bijectors/coupling.jl | 16 --------- 2 files changed, 30 insertions(+), 52 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 3741787b..fcfda61c 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -21,7 +21,7 @@ to the first part, and the last part of the vector is not used for anything. julia> using Bijectors: PartitionMask, partition, combine julia> m = PartitionMask(3, [1], [2]) # <= assumes input-length 3 -PartitionMask{SparseArrays.SparseMatrixCSC{Bool,Int64}}( +PartitionMask{Bool,SparseArrays.SparseMatrixCSC{Bool,Int64}}( [1, 1] = true, [2, 1] = true, [3, 1] = true) @@ -41,76 +41,70 @@ julia> # Recombines the partitions into a vector Note that the underlying `SparseMatrix` is using `Bool` as the element type. 
We can also specify this to be some other type using the `sp_type` keyword: ```julia-repl -julia> m = PartitionMask(3, [1], [2]; sp_type = Float32) -PartitionMask{SparseArrays.SparseMatrixCSC{Float32,Int64}}( +julia> m = PartitionMask{Float32}(3, [1], [2]) +PartitionMask{Float32,SparseArrays.SparseMatrixCSC{Float32,Int64}}( [1, 1] = 1.0, [2, 1] = 1.0, [3, 1] = 1.0) ``` """ -struct PartitionMask{A} +struct PartitionMask{T, A} A_1::A A_2::A A_3::A # Only make it possible to construct using matrices - PartitionMask(A_1::A, A_2::A, A_3::A) where {A <: AbstractMatrix{<:Real}} = new{A}(A_1, A_2, A_3) + PartitionMask(A_1::A, A_2::A, A_3::A) where {T<:Real, A <: AbstractMatrix{T}} = new{T, A}(A_1, A_2, A_3) end -function PartitionMask( +PartitionMask(args...; kwargs...) = PartitionMask{Bool}(args...; kwargs...) + +function PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, - indices_3::AbstractVector{Int}; - sp_type = Bool -) - A_1 = spzeros(sp_type, n, length(indices_1)); - A_2 = spzeros(sp_type, n, length(indices_2)); - A_3 = spzeros(sp_type, n, length(indices_3)); + indices_3::AbstractVector{Int} +) where {T<:Real} + A_1 = spzeros(T, n, length(indices_1)); + A_2 = spzeros(T, n, length(indices_2)); + A_3 = spzeros(T, n, length(indices_3)); for (i, idx) in enumerate(indices_1) - A_1[idx, i] = one(sp_type) + A_1[idx, i] = one(T) end for (i, idx) in enumerate(indices_2) - A_2[idx, i] = one(sp_type) + A_2[idx, i] = one(T) end for (i, idx) in enumerate(indices_3) - A_3[idx, i] = one(sp_type) + A_3[idx, i] = one(T) end return PartitionMask(A_1, A_2, A_3) end -PartitionMask( +PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}; - kwargs... -) = PartitionMask(n, indices_1, indices_2, nothing; kwargs...) +) where {T} = PartitionMask{T}(n, indices_1, indices_2, nothing) -PartitionMask( +PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, indices_3::Nothing; kwargs... -) = PartitionMask( - n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2); - kwargs... -) +) where {T} = PartitionMask{T}(n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2)) -PartitionMask( +PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::Nothing, indices_3::AbstractVector{Int}; kwargs... -) = PartitionMask( - n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3; - kwargs... -) +) where {T} = PartitionMask{T}(n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3) """ PartitionMask(n::Int, indices) @@ -118,28 +112,28 @@ PartitionMask( Assumes you want to _split_ the vector, where `indices` refer to the parts of the vector you want to apply the bijector to. """ -function PartitionMask(n::Int, indices; sp_type = Bool) +function PartitionMask{T}(n::Int, indices) where {T} indices_2 = setdiff(1:n, indices) # sparse arrays <3 - A_1 = spzeros(sp_type, n, length(indices)); - A_2 = spzeros(sp_type, n, length(indices_2)); + A_1 = spzeros(T, n, length(indices)); + A_2 = spzeros(T, n, length(indices_2)); # Like doing: # A[1, 1] = 1.0 # A[3, 2] = 1.0 for (i, idx) in enumerate(indices) - A_1[idx, i] = one(sp_type) + A_1[idx, i] = one(T) end for (i, idx) in enumerate(indices_2) - A_2[idx, i] = one(sp_type) + A_2[idx, i] = one(T) end - return PartitionMask(A_1, A_2, spzeros(sp_type, n, 0)) + return PartitionMask(A_1, A_2, spzeros(T, n, 0)) end -function PartitionMask(x::AbstractVector, indices; kwargs...) - return PartitionMask(length(x), indices; kwargs...) 
+function PartitionMask{T}(x::AbstractVector, indices) where {T} + return PartitionMask(length(x), indices) end """ diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index 948002a4..7e251816 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -49,22 +49,6 @@ using Bijectors: @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) end - # @testset "Tracker" begin - # Random.seed!(123) - # x = [1., 2., 3.] - - # m = PartitionMask(length(x), [1], [2]) - # nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1)) - # nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? Tracker.param(x) : x, nn) - # cl = Coupling(θ -> Shift(nn_tracked(θ)), m) - - # # should leave two last indices unchanged - # @test cl(x)[2:3] == x[2:3] - - # # verify that indeed it's tracked - # @test Tracker.istracked(cl(x)) - # end - @testset "Classic" begin m = PartitionMask(3, [1], [2]) From f16fc63dac0f1cb1ce85eacbfb046934595c20e6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 15 Sep 2020 14:36:40 +0100 Subject: [PATCH 89/89] removed AD testing of coupling with Scale due to Scale has nothing as adjoint --- test/bijectors/coupling.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index 7e251816..fcf1c402 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -57,11 +57,5 @@ using Bijectors: x = hcat([-1., -2., -3.], [1., 2., 3.]) y = hcat([2., -2., -3.], [2., 2., 3.]) test_bijector(cl, x, y, log.([2., 2.])) - - # With `Shift` - cl = Coupling(x -> Shift(x[1]), m) - x = hcat([-1., -2., -3.], [1., 2., 3.]) - y = hcat([-3., -2., -3.], [3., 2., 3.]) - test_bijector(cl, x, y, zeros(2)) end end