diff --git a/src/bijectors/coupling_layer.jl b/src/bijectors/coupling_layer.jl index 5904d506..68400f02 100644 --- a/src/bijectors/coupling_layer.jl +++ b/src/bijectors/coupling_layer.jl @@ -12,7 +12,7 @@ Implements Note that `PartitionMask` is _not_ a `Bijector`. It is indeed a bijection, but does not follow the `Bijector` interface. -Its main use is in `CouplingLayer` where we want to partition the input into 3 parts, +Its main use is in `Coupling` where we want to partition the input into 3 parts, one part to transform, one part to map into the parameter-space of the transform applied to the first part, and the last part of the vector is not used for anything. @@ -135,10 +135,10 @@ Partitions `x` into 3 disjoint subvectors. @inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) -# CouplingLayer +# Coupling """ - CouplingLayer{B, M, F}(mask::M, θ::F) + Coupling{F, M}(θ::F, mask::M) Implements a coupling-layer as defined in [1]. @@ -150,7 +150,7 @@ PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}( [2, 1] = 1.0, [3, 1] = 1.0) -julia> cl = CouplingLayer(Shift, m, identity) # <= will do `y[1:1] = x[1:1] + x[2:2]`; +julia> cl = Coupling(θ -> Shift(θ), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`; julia> x = [1., 2., 3.]; @@ -176,75 +176,73 @@ Shift{Array{Float64,1},1}([2.0]) # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). 
""" -struct CouplingLayer{B, M, F} <: Bijector{1} where {B, M <: PartitionMask, F} - mask::M +struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask} θ::F + mask::M end -CouplingLayer(B, mask::M) where {M} = CouplingLayer{B, M, typeof(identity)}(mask, identity) -CouplingLayer(B, mask::M, θ::F) where {M, F} = CouplingLayer{B, M, F}(mask, θ) -function CouplingLayer(B, θ, n::Int) +function Coupling(θ, n::Int) idx = Int(floor(n / 2)) - return CouplingLayer(B, PartitionMask(n, 1:idx), θ) + return Coupling(θ, PartitionMask(n, 1:idx)) end -function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B} - return CouplingLayer(B, mask, cl.θ) +function Coupling(cl::Coupling{B}, mask::PartitionMask) where {B} + return Coupling(cl.θ, mask) end "Returns the constructor of the coupling law." -coupling(cl::CouplingLayer{B}) where {B} = B +coupling(cl::Coupling) = cl.θ "Returns the coupling law constructed from `x`." -function couple(cl::CouplingLayer{B}, x::AbstractVector) where {B} +function couple(cl::Coupling, x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) # construct bijector `B` using θ(x₂) - b = B(cl.θ(x_2)) + b = cl.θ(x_2) return b end -function (cl::CouplingLayer{B})(x::AbstractVector) where {B} +function (cl::Coupling)(x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) # construct bijector `B` using θ(x₂) - b = B(cl.θ(x_2)) + b = cl.θ(x_2) # recombine the vector again using the `PartitionMask` return combine(cl.mask, b(x_1), x_2, x_3) end -function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B} +function (cl::Coupling)(x::AbstractMatrix) return hcat([cl(x[:, i]) for i = 1:size(x, 2)]...) 
end -function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractVector) where {B} +function (icl::Inverse{<:Coupling})(y::AbstractVector) cl = icl.orig y_1, y_2, y_3 = partition(cl.mask, y) - b = B(cl.θ(y_2)) + b = cl.θ(y_2) ib = inv(b) return combine(cl.mask, ib(y_1), y_2, y_3) end -function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractMatrix) where {B} +function (icl::Inverse{<:Coupling})(y::AbstractMatrix) return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...) end -function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B} +function logabsdetjac(cl::Coupling, x::AbstractVector) x_1, x_2, x_3 = partition(cl.mask, x) - b = B(cl.θ(x_2)) + b = cl.θ(x_2) # `B` might be 0-dim in which case it will treat `x_1` as a batch # therefore we sum to ensure such a thing does not happen return sum(logabsdetjac(b, x_1)) end -function logabsdetjac(cl::CouplingLayer{B}, x::AbstractMatrix) where {B} +function logabsdetjac(cl::Coupling, x::AbstractMatrix) r = [logabsdetjac(cl, x[:, i]) for i = 1:size(x, 2)] # FIXME: this really needs to be handled in a better way diff --git a/test/bijectors/couplings.jl b/test/bijectors/couplings.jl index 19ca7fe4..49356c26 100644 --- a/test/bijectors/couplings.jl +++ b/test/bijectors/couplings.jl @@ -7,7 +7,7 @@ using Tracker import Flux using Bijectors: - CouplingLayer, + Coupling, PartitionMask, coupling, couple, @@ -15,7 +15,7 @@ using Bijectors: combine, Shift -@testset "CouplingLayer" begin +@testset "Coupling" begin @testset "PartitionMask" begin m1 = PartitionMask(3, [1], [2]) m2 = PartitionMask(3, [1], [2], [3]) @@ -32,12 +32,12 @@ using Bijectors: @testset "Basics" begin m = PartitionMask(3, [1], [2]) - cl1 = CouplingLayer(Shift, m, x -> x[1]) + cl1 = Coupling(x -> Shift(x[1]), m) x = [1., 2., 3.] @test cl1(x) == [3., 2., 3.] 
- cl2 = CouplingLayer(θ -> Shift(θ[1]), m, identity) + cl2 = Coupling(θ -> Shift(θ[1]), m) @test cl2(x) == cl1(x) # inversion @@ -63,7 +63,7 @@ using Bijectors: m = PartitionMask(length(x), [1], [2]) nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1)) nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? Tracker.param(x) : x, nn) - cl = CouplingLayer(Shift, m, nn_tracked) + cl = Coupling(θ -> Shift(nn_tracked(θ)), m) # should leave two last indices unchanged @test cl(x)[2:3] == x[2:3]