Skip to content

Commit

Permalink
renamed CouplingLayer to Coupling and combined the two functions
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 12, 2020
1 parent 3ca50a6 commit efc42bb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
46 changes: 22 additions & 24 deletions src/bijectors/coupling_layer.jl
Expand Up @@ -12,7 +12,7 @@ Implements
Note that `PartitionMask` is _not_ a `Bijector`. It is indeed a bijection, but
does not follow the `Bijector` interface.
Its main use is in `CouplingLayer` where we want to partition the input into 3 parts,
Its main use is in `Coupling` where we want to partition the input into 3 parts,
one part to transform, one part to map into the parameter-space of the transform applied
to the first part, and the last part of the vector is not used for anything.
Expand Down Expand Up @@ -135,10 +135,10 @@ Partitions `x` into 3 disjoint subvectors.
@inline function partition(m::PartitionMask, x)
    # Multiply `x` by the transpose of each of the mask's three matrices in turn;
    # the results are returned in the order (A_1, A_2, A_3).
    x_1 = transpose(m.A_1) * x
    x_2 = transpose(m.A_2) * x
    x_3 = transpose(m.A_3) * x
    return (x_1, x_2, x_3)
end


# CouplingLayer
# Coupling

"""
CouplingLayer{B, M, F}(mask::M, θ::F)
Coupling{F, M}(θ::F, mask::M)
Implements a coupling-layer as defined in [1].
Expand All @@ -150,7 +150,7 @@ PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}(
[2, 1] = 1.0,
[3, 1] = 1.0)
julia> cl = CouplingLayer(Shift, m, identity) # <= will do `y[1:1] = x[1:1] + x[2:2]`;
julia> cl = Coupling(θ -> Shift(θ), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`;
julia> x = [1., 2., 3.];
Expand All @@ -176,75 +176,73 @@ Shift{Array{Float64,1},1}([2.0])
# References
[1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019).
"""
struct CouplingLayer{B, M, F} <: Bijector{1} where {B, M <: PartitionMask, F}
mask::M
struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask}
θ::F
mask::M
end

CouplingLayer(B, mask::M) where {M} = CouplingLayer{B, M, typeof(identity)}(mask, identity)
CouplingLayer(B, mask::M, θ::F) where {M, F} = CouplingLayer{B, M, F}(mask, θ)
function CouplingLayer(B, θ, n::Int)
function Coupling(θ, n::Int)
idx = Int(floor(n / 2))
return CouplingLayer(B, PartitionMask(n, 1:idx), θ)
return Coupling(θ, PartitionMask(n, 1:idx))
end

function CouplingLayer(cl::CouplingLayer{B}, mask::PartitionMask) where {B}
return CouplingLayer(B, mask, cl.θ)
function Coupling(cl::Coupling{B}, mask::PartitionMask) where {B}
return Coupling(cl.θ, mask)
end

"Returns the constructor of the coupling law."
coupling(cl::CouplingLayer{B}) where {B} = B
coupling(cl::Coupling) = cl.θ

"Returns the coupling law constructed from `x`."
function couple(cl::CouplingLayer{B}, x::AbstractVector) where {B}
function couple(cl::Coupling, x::AbstractVector)
# partition vector using `cl.mask::PartitionMask`
x_1, x_2, x_3 = partition(cl.mask, x)

# construct bijector `B` using θ(x₂)
b = B(cl.θ(x_2))
b = cl.θ(x_2)

return b
end

function (cl::CouplingLayer{B})(x::AbstractVector) where {B}
function (cl::Coupling)(x::AbstractVector)
# partition vector using `cl.mask::PartitionMask`
x_1, x_2, x_3 = partition(cl.mask, x)

# construct bijector `B` using θ(x₂)
b = B(cl.θ(x_2))
b = cl.θ(x_2)

# recombine the vector again using the `PartitionMask`
return combine(cl.mask, b(x_1), x_2, x_3)
end
function (cl::CouplingLayer{B})(x::AbstractMatrix) where {B}
function (cl::Coupling)(x::AbstractMatrix)
return hcat([cl(x[:, i]) for i = 1:size(x, 2)]...)
end


function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractVector) where {B}
function (icl::Inverse{<:Coupling})(y::AbstractVector)
cl = icl.orig

y_1, y_2, y_3 = partition(cl.mask, y)

b = B(cl.θ(y_2))
b = cl.θ(y_2)
ib = inv(b)

return combine(cl.mask, ib(y_1), y_2, y_3)
end
function (icl::Inverse{<:CouplingLayer{B}})(y::AbstractMatrix) where {B}
function (icl::Inverse{<:Coupling})(y::AbstractMatrix)
return hcat([icl(y[:, i]) for i = 1:size(y, 2)]...)
end

function logabsdetjac(cl::CouplingLayer{B}, x::AbstractVector) where {B}
function logabsdetjac(cl::Coupling, x::AbstractVector)
x_1, x_2, x_3 = partition(cl.mask, x)
b = B(cl.θ(x_2))
b = cl.θ(x_2)

# `B` might be 0-dim in which case it will treat `x_1` as a batch
# therefore we sum to ensure such a thing does not happen
return sum(logabsdetjac(b, x_1))
end

function logabsdetjac(cl::CouplingLayer{B}, x::AbstractMatrix) where {B}
function logabsdetjac(cl::Coupling, x::AbstractMatrix)
r = [logabsdetjac(cl, x[:, i]) for i = 1:size(x, 2)]

# FIXME: this really needs to be handled in a better way
Expand Down
10 changes: 5 additions & 5 deletions test/bijectors/couplings.jl
Expand Up @@ -7,15 +7,15 @@ using Tracker
import Flux

using Bijectors:
CouplingLayer,
Coupling,
PartitionMask,
coupling,
couple,
partition,
combine,
Shift

@testset "CouplingLayer" begin
@testset "Coupling" begin
@testset "PartitionMask" begin
m1 = PartitionMask(3, [1], [2])
m2 = PartitionMask(3, [1], [2], [3])
Expand All @@ -32,12 +32,12 @@ using Bijectors:

@testset "Basics" begin
m = PartitionMask(3, [1], [2])
cl1 = CouplingLayer(Shift, m, x -> x[1])
cl1 = Coupling(x -> Shift(x[1]), m)

x = [1., 2., 3.]
@test cl1(x) == [3., 2., 3.]

cl2 = CouplingLayer(θ -> Shift(θ[1]), m, identity)
cl2 = Coupling(θ -> Shift(θ[1]), m)
@test cl2(x) == cl1(x)

# inversion
Expand All @@ -63,7 +63,7 @@ using Bijectors:
m = PartitionMask(length(x), [1], [2])
nn = Flux.Chain(Flux.Dense(1, 2, Flux.sigmoid), Flux.Dense(2, 1))
nn_tracked = Flux.fmap(x -> (x isa AbstractArray) ? Tracker.param(x) : x, nn)
cl = CouplingLayer(Shift, m, nn_tracked)
cl = Coupling(θ -> Shift(nn_tracked(θ)), m)

# should leave two last indices unchanged
@test cl(x)[2:3] == x[2:3]
Expand Down

0 comments on commit efc42bb

Please sign in to comment.