diff --git a/src/Bijectors.jl b/src/Bijectors.jl index a526f7f8..e435ebf7 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -66,6 +66,7 @@ export TransformDistribution, logpdf_forward, PlanarLayer, RadialLayer, + CouplingLayer, InvertibleBatchNorm if VERSION < v"1.1" diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl new file mode 100644 index 00000000..fcfda61c --- /dev/null +++ b/src/bijectors/coupling.jl @@ -0,0 +1,267 @@ +using SparseArrays + +""" + PartitionMask{A}(A_1::A, A_2::A, A_3::A) where {A} + +This is used to partition and recombine a vector into 3 disjoint "subvectors". + +Implements +- `partition(m::PartitionMask, x)`: partitions `x` into 3 disjoint "subvectors" +- `combine(m::PartitionMask, x_1, x_2, x_3)`: combines 3 disjoint vectors into a single one + +Note that `PartitionMask` is _not_ a `Bijector`. It is indeed a bijection, but +does not follow the `Bijector` interface. + +Its main use is in `Coupling` where we want to partition the input into 3 parts, +one part to transform, one part to map into the parameter-space of the transform applied +to the first part, and the last part of the vector is not used for anything. + +# Examples +```julia-repl +julia> using Bijectors: PartitionMask, partition, combine + +julia> m = PartitionMask(3, [1], [2]) # <= assumes input-length 3 +PartitionMask{Bool,SparseArrays.SparseMatrixCSC{Bool,Int64}}( + [1, 1] = true, + [2, 1] = true, + [3, 1] = true) + +julia> # Partition into 3 parts; the last part is inferred to be indices `[3, ]` from + # the fact that `[1]` and `[2]` does not make up all indices in `1:3`. + x1, x2, x3 = partition(m, [1., 2., 3.]) +([1.0], [2.0], [3.0]) + +julia> # Recombines the partitions into a vector + combine(m, x1, x2, x3) +3-element Array{Float64,1}: + 1.0 + 2.0 + 3.0 +``` +Note that the underlying `SparseMatrix` is using `Bool` as the element type. We can also +specify this to be some other type using the `sp_type` keyword: +```julia-repl +julia> m = PartitionMask{Float32}(3, [1], [2]) +PartitionMask{Float32,SparseArrays.SparseMatrixCSC{Float32,Int64}}( + [1, 1] = 1.0, + [2, 1] = 1.0, + [3, 1] = 1.0) +``` +""" +struct PartitionMask{T, A} + A_1::A + A_2::A + A_3::A + + # Only make it possible to construct using matrices + PartitionMask(A_1::A, A_2::A, A_3::A) where {T<:Real, A <: AbstractMatrix{T}} = new{T, A}(A_1, A_2, A_3) +end + +PartitionMask(args...; kwargs...) = PartitionMask{Bool}(args...; kwargs...) + +function PartitionMask{T}( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::AbstractVector{Int}, + indices_3::AbstractVector{Int} +) where {T<:Real} + A_1 = spzeros(T, n, length(indices_1)); + A_2 = spzeros(T, n, length(indices_2)); + A_3 = spzeros(T, n, length(indices_3)); + + for (i, idx) in enumerate(indices_1) + A_1[idx, i] = one(T) + end + + for (i, idx) in enumerate(indices_2) + A_2[idx, i] = one(T) + end + + for (i, idx) in enumerate(indices_3) + A_3[idx, i] = one(T) + end + + return PartitionMask(A_1, A_2, A_3) +end + +PartitionMask{T}( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::AbstractVector{Int}; +) where {T} = PartitionMask{T}(n, indices_1, indices_2, nothing) + +PartitionMask{T}( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::AbstractVector{Int}, + indices_3::Nothing; + kwargs... +) where {T} = PartitionMask{T}(n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2)) + +PartitionMask{T}( + n::Int, + indices_1::AbstractVector{Int}, + indices_2::Nothing, + indices_3::AbstractVector{Int}; + kwargs... +) where {T} = PartitionMask{T}(n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3) + +""" + PartitionMask(n::Int, indices) + +Assumes you want to _split_ the vector, where `indices` refer to the +parts of the vector you want to apply the bijector to. +""" +function PartitionMask{T}(n::Int, indices) where {T} + indices_2 = setdiff(1:n, indices) + + # sparse arrays <3 + A_1 = spzeros(T, n, length(indices)); + A_2 = spzeros(T, n, length(indices_2)); + + # Like doing: + # A[1, 1] = 1.0 + # A[3, 2] = 1.0 + for (i, idx) in enumerate(indices) + A_1[idx, i] = one(T) + end + + for (i, idx) in enumerate(indices_2) + A_2[idx, i] = one(T) + end + + return PartitionMask(A_1, A_2, spzeros(T, n, 0)) +end +function PartitionMask{T}(x::AbstractVector, indices) where {T} + return PartitionMask(length(x), indices) +end + +""" + combine(m::PartitionMask, x_1, x_2, x_3) + +Combines `x_1`, `x_2`, and `x_3` into a single vector. +""" +@inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3 + +""" + partition(m::PartitionMask, x) + +Partitions `x` into 3 disjoint subvectors. +""" +@inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) + + +# Coupling + +""" + Coupling{F, M}(θ::F, mask::M) + +Implements a coupling-layer as defined in [1]. + +# Examples +```julia-repl +julia> m = PartitionMask(3, [1], [2]) # <= going to use x[2] to parameterize transform of x[1] +PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( + [1, 1] = 1.0, + [2, 1] = 1.0, + [3, 1] = 1.0) + +julia> cl = Coupling(θ -> Shift(θ[1]), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; + +julia> x = [1., 2., 3.]; + +julia> cl(x) +3-element Array{Float64,1}: + 3.0 + 2.0 + 3.0 + +julia> inv(cl)(cl(x)) +3-element Array{Float64,1}: + 1.0 + 2.0 + 3.0 + +julia> coupling(cl) # get the `Bijector` map `θ -> b(⋅, θ)` +Shift + +julia> couple(cl, x) # get the `Bijector` resulting from `x` +Shift{Array{Float64,1},1}([2.0]) +``` + +# References +[1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). +""" +struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} + θ::F + mask::M +end + +function Coupling(θ, n::Int) + idx = Int(floor(n / 2)) + return Coupling(θ, PartitionMask(n, 1:idx)) +end + +function Coupling(cl::Coupling{B}, mask::PartitionMask) where {B} + return Coupling(cl.θ, mask) +end + +"Returns the constructor of the coupling law." +coupling(cl::Coupling) = cl.θ + +"Returns the coupling law constructed from `x`." +function couple(cl::Coupling, x::AbstractVector) + # partition vector using `cl.mask::PartitionMask` + x_1, x_2, x_3 = partition(cl.mask, x) + + # construct bijector `B` using θ(x₂) + b = cl.θ(x_2) + + return b +end + +function (cl::Coupling)(x::AbstractVector) + # partition vector using `cl.mask::PartitionMask` + x_1, x_2, x_3 = partition(cl.mask, x) + + # construct bijector `B` using θ(x₂) + b = cl.θ(x_2) + + # recombine the vector again using the `PartitionMask` + return combine(cl.mask, b(x_1), x_2, x_3) +end +(cl::Coupling)(x::AbstractMatrix) = eachcolmaphcat(cl, x) + + +function (icl::Inverse{<:Coupling})(y::AbstractVector) + cl = icl.orig + + y_1, y_2, y_3 = partition(cl.mask, y) + + b = cl.θ(y_2) + ib = inv(b) + + return combine(cl.mask, ib(y_1), y_2, y_3) +end +(icl::Inverse{<:Coupling})(y::AbstractMatrix) = eachcolmaphcat(icl, y) + +function logabsdetjac(cl::Coupling, x::AbstractVector) + x_1, x_2, x_3 = partition(cl.mask, x) + b = cl.θ(x_2) + + # `B` might be 0-dim in which case it will treat `x_1` as a batch + # therefore we sum to ensure such a thing does not happen + return sum(logabsdetjac(b, x_1)) +end + +function logabsdetjac(cl::Coupling, x::AbstractMatrix) + r = [logabsdetjac(cl, x[:, i]) for i = 1:size(x, 2)] + + # FIXME: this really needs to be handled in a better way + # We need to return a `TrackedArray` + if Tracker.istracked(r[1]) + return Tracker.collect(r) + else + return r + end +end diff --git a/src/interface.jl b/src/interface.jl index 534b7b4f..a1a8d462 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -155,6 +155,7 @@ include("bijectors/truncated.jl") # Normalizing flow related include("bijectors/planar_layer.jl") include("bijectors/radial_layer.jl") +include("bijectors/coupling.jl") include("bijectors/normalise.jl") ################## diff --git a/test/ad/utils.jl b/test/ad/utils.jl index c95b33d5..117d916c 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -140,10 +140,10 @@ function test_ad(dist::DistSpec; kwargs...) end end -function test_ad(f, x; rtol = 1e-6, atol = 1e-6) +function test_ad(f, x; rtol = 1e-6, atol = 1e-6, ad = AD) finitediff = FiniteDiff.finite_difference_gradient(f, x) - if AD == "All" || AD == "ForwardDiff_Tracker" + if ad == "All" || ad == "ForwardDiff_Tracker" tracker = Tracker.data(Tracker.gradient(f, x)[1]) @test tracker ≈ finitediff rtol=rtol atol=atol @@ -151,12 +151,12 @@ function test_ad(f, x; rtol = 1e-6, atol = 1e-6) @test forward ≈ finitediff rtol=rtol atol=atol end - if AD == "All" || AD == "Zygote" + if ad == "All" || ad == "Zygote" zygote = Zygote.gradient(f, x)[1] @test zygote ≈ finitediff rtol=rtol atol=atol end - if AD == "All" || AD == "ReverseDiff" + if ad == "All" || ad == "ReverseDiff" reversediff = ReverseDiff.gradient(f, x) @test reversediff ≈ finitediff rtol=rtol atol=atol end diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl new file mode 100644 index 00000000..fcf1c402 --- /dev/null +++ b/test/bijectors/coupling.jl @@ -0,0 +1,61 @@ +using Bijectors: + Coupling, + PartitionMask, + coupling, + couple, + partition, + combine, + Shift, + Scale + +@testset "Coupling" begin + @testset "PartitionMask" begin + m1 = PartitionMask(3, [1], [2]) + m2 = PartitionMask(3, [1], [2], [3]) + + @test (m1.A_1 == m2.A_1) & (m1.A_2 == m2.A_2) & (m1.A_3 == m2.A_3) + + x = [1., 2., 3.] + x1, x2, x3 = partition(m1, x) + @test (x1 == [1.]) & (x2 == [2.]) & (x3 == [3.]) + + y = combine(m1, x1, x2, x3) + @test y == x + end + + @testset "Basics" begin + m = PartitionMask(3, [1], [2]) + cl1 = Coupling(x -> Shift(x[1]), m) + + x = [1., 2., 3.] + @test cl1(x) == [3., 2., 3.] + + cl2 = Coupling(θ -> Shift(θ[1]), m) + @test cl2(x) == cl1(x) + + # inversion + icl1 = inv(cl1) + @test icl1(cl1(x)) == x + @test inv(cl2)(cl2(x)) == x + + # This `cl2` should result in + b = Shift(x[2:2]) + + # logabsdetjac + @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) + + # forward + @test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x)) + @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) + end + + @testset "Classic" begin + m = PartitionMask(3, [1], [2]) + + # With `Scale` + cl = Coupling(x -> Scale(x[1]), m) + x = hcat([-1., -2., -3.], [1., 2., 3.]) + y = hcat([2., -2., -3.], [2., 2., 3.]) + test_bijector(cl, x, y, log.([2., 2.])) + end +end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl new file mode 100644 index 00000000..67c35ca0 --- /dev/null +++ b/test/bijectors/utils.jl @@ -0,0 +1,194 @@ +function test_bijector_reals( + b::Bijector{0}, + x_true::Real, + y_true::Real, + logjac_true::Real; + isequal = true, + tol = 1e-6 +) + ib = @inferred inv(b) + y = @inferred b(x_true) + logjac = @inferred logabsdetjac(b, x_true) + ilogjac = @inferred logabsdetjac(ib, y_true) + res = @inferred forward(b, x_true) + + # If `isequal` is false, then we use the computed `y`, + # but if it's true, we use the true `y`. + ires = isequal ? @inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) + + # Always want the following to hold + @test ires.rv ≈ x_true atol=tol + @test ires.logabsdetjac ≈ -logjac atol=tol + + if isequal + @test y ≈ y_true atol=tol # forward + @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse + @test logjac ≈ logjac_true # logjac forward + @test res.rv ≈ y_true atol=tol # forward using `forward` + @test res.logabsdetjac ≈ logjac_true atol=tol # logjac using `forward` + else + @test y ≠ y_true # forward + @test (@inferred ib(y)) ≈ x_true atol=tol # inverse + @test logjac ≠ logjac_true # logjac forward + @test res.rv ≠ y_true # forward using `forward` + @test res.logabsdetjac ≠ logjac_true # logjac using `forward` + end +end + +function test_bijector_arrays( + b::Bijector, + xs_true::AbstractArray{<:Real}, + ys_true::AbstractArray{<:Real}, + logjacs_true::Union{Real, AbstractArray{<:Real}}; + isequal = true, + tol = 1e-6 +) + ib = @inferred inv(b) + ys = @inferred b(xs_true) + logjacs = @inferred logabsdetjac(b, xs_true) + res = @inferred forward(b, xs_true) + # If `isequal` is false, then we use the computed `y`, + # but if it's true, we use the true `y`. + ires = isequal ? @inferred(forward(inv(b), ys_true)) : @inferred(forward(inv(b), ys)) + + # always want the following to hold + @test ys isa typeof(ys_true) + @test logjacs isa typeof(logjacs_true) + @test mean(abs, ires.rv - xs_true) ≤ tol + @test mean(abs, ires.logabsdetjac + logjacs) ≤ tol + + if isequal + @test mean(abs, ys - ys_true) ≤ tol # forward + @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse + @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward + @test mean(abs, res.rv - ys_true) ≤ tol # forward using `forward` + @test mean(abs, res.logabsdetjac - logjacs_true) ≤ tol # logjac `forward` + @test mean(abs, ires.logabsdetjac + logjacs_true) ≤ tol # inverse logjac `forward` + else + # Don't want the following to be equal to their "true" values + @test mean(abs, ys - ys_true) > tol # forward + @test mean(abs, logjacs - logjacs_true) > tol # logjac forward + @test mean(abs, res.rv - ys_true) > tol # forward using `forward` + + # Still want the following to be equal to the COMPUTED values + @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse + @test mean(abs, res.logabsdetjac - logjacs) ≤ tol # logjac forward using `forward` + end +end + +""" + test_bijector(b::Bijector, xs::Array; kwargs...) + test_bijector(b::Bijector, xs::Array, ys::Array, logjacs::Array; kwargs...) + +Tests the bijector `b` on the inputs `xs` against the, optionally, provided `ys` +and `logjacs`. + +If `ys` and `logjacs` are NOT provided, `isequal` will be set to `false` and +`ys` and `logjacs` will be set to `zeros`. These `ys` and `logjacs` will be +treated as "counter-examples", i.e. values NOT to match. + +# Arguments +- `b::Bijector`: the bijector to test +- `xs`: inputs (has to be several!!!)(has to be several, i.e. a batch!!!) to test +- `ys`: outputs (has to be several, i.e. a batch!!!) to test against +- `logjacs`: `logabsdetjac` outputs (has to be several!!!)(has to be several, i.e. + a batch!!!) to test against + +# Keywords +- `isequal = true`: if `false`, it will be assumed that the given values are + provided as "counter-examples" in the sense that the inputs `xs` should NOT map + to the given outputs. This is useful in cases where one might not know the expected + output, but still wants to test that the evaluation, etc. works. + This is set to `true` by default if `ys` and `logjacs` are not provided. +- `tol = 1e-6`: the absolute tolerance used for the checks. This is also used to check + arrays where we check that the L1-norm is sufficiently small. +""" +function test_bijector(b::Bijector{0}, xs::AbstractVector{<:Real}) + return test_bijector(b, xs, zeros(length(xs)), zeros(length(xs)); isequal = false) +end + +function test_bijector(b::Bijector{1}, xs::AbstractMatrix{<:Real}) + return test_bijector(b, xs, zeros(size(xs)), zeros(size(xs, 2)); isequal = false) +end + +function test_bijector( + b::Bijector{0}, + xs_true::AbstractVector{<:Real}, + ys_true::AbstractVector{<:Real}, + logjacs_true::AbstractVector{<:Real}; + kwargs... +) + ib = inv(b) + + # Batch + test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + + # Test `logabsdetjac` against jacobians + test_logabsdetjac(b, xs_true) + test_logabsdetjac(b, ys_true) + + for (x_true, y_true, logjac_true) in zip(xs_true, ys_true, logjacs_true) + test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...) + + # Test AD + if isclosedform(b) + test_ad(x -> b(first(x)), [x_true, ]) + end + + if isclosedform(ib) + y = b(x_true) + test_ad(x -> ib(first(x)), [y, ]) + end + + test_ad(x -> logabsdetjac(b, first(x)), [x_true, ]) + end +end + + +function test_bijector( + b::Bijector{1}, + xs_true::AbstractMatrix{<:Real}, + ys_true::AbstractMatrix{<:Real}, + logjacs_true::AbstractVector{<:Real}; + kwargs... +) + ib = inv(b) + + # Batch + test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) + + # Test `logabsdetjac` against jacobians + test_logabsdetjac(b, xs_true) + test_logabsdetjac(b, ys_true) + + for (x_true, y_true, logjac_true) in zip(eachcol(xs_true), eachcol(ys_true), logjacs_true) + # HACK: collect to avoid dealing with sub-arrays and thus allowing us to compare the + # type of the computed output to the "true" output. + test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...) + + # Test AD + if isclosedform(b) + test_ad(x -> sum(b(x)), collect(x_true)) + end + if isclosedform(ib) + y = b(x_true) + test_ad(x -> sum(ib(x)), y) + end + + test_ad(x -> logabsdetjac(b, x), x_true) + end +end + +function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6) + if isclosedform(b) + logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)] + @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol + end +end + +function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6) + if isclosedform(b) + logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs] + @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d601d491..3d6acba9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,14 +19,18 @@ using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, const is_TRAVIS = haskey(ENV, "TRAVIS") const GROUP = get(ENV, "GROUP", "All") +# Always include this since it can be useful for other tests. +include("ad/utils.jl") +include("bijectors/utils.jl") + if GROUP == "All" || GROUP == "Interface" include("interface.jl") include("transform.jl") include("norm_flows.jl") include("bijectors/permute.jl") + include("bijectors/coupling.jl") end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") - include("ad/utils.jl") include("ad/distributions.jl") end