Skip to content

Commit

Permalink
Merge branch 'master' into tor/leaky-relu
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 22, 2020
2 parents 6077578 + 84314e7 commit 410166b
Show file tree
Hide file tree
Showing 10 changed files with 513 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.3"
version = "0.8.5"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
1 change: 1 addition & 0 deletions src/Bijectors.jl
Expand Up @@ -66,6 +66,7 @@ export TransformDistribution,
logpdf_forward,
PlanarLayer,
RadialLayer,
CouplingLayer,
InvertibleBatchNorm

if VERSION < v"1.1"
Expand Down
234 changes: 234 additions & 0 deletions src/bijectors/coupling.jl
@@ -0,0 +1,234 @@
using SparseArrays

"""
PartitionMask{A}(A_1::A, A_2::A, A_3::A) where {A}
This is used to partition and recombine a vector into 3 disjoint "subvectors".
Implements
- `partition(m::PartitionMask, x)`: partitions `x` into 3 disjoint "subvectors"
- `combine(m::PartitionMask, x_1, x_2, x_3)`: combines 3 disjoint vectors into a single one
Note that `PartitionMask` is _not_ a `Bijector`. It is indeed a bijection, but
does not follow the `Bijector` interface.
Its main use is in `Coupling` where we want to partition the input into 3 parts,
one part to transform, one part to map into the parameter-space of the transform applied
to the first part, and the last part of the vector is not used for anything.
# Examples
```julia-repl
julia> using Bijectors: PartitionMask, partition, combine
julia> m = PartitionMask(3, [1], [2]) # <= assumes input-length 3
PartitionMask{Bool,SparseArrays.SparseMatrixCSC{Bool,Int64}}(
[1, 1] = true,
[2, 1] = true,
[3, 1] = true)
julia> # Partition into 3 parts; the last part is inferred to be indices `[3, ]` from
# the fact that `[1]` and `[2]` does not make up all indices in `1:3`.
x1, x2, x3 = partition(m, [1., 2., 3.])
([1.0], [2.0], [3.0])
julia> # Recombines the partitions into a vector
combine(m, x1, x2, x3)
3-element Array{Float64,1}:
1.0
2.0
3.0
```
Note that the underlying `SparseMatrix` is using `Bool` as the element type. We can also
specify this to be some other type using the `sp_type` keyword:
```julia-repl
julia> m = PartitionMask{Float32}(3, [1], [2])
PartitionMask{Float32,SparseArrays.SparseMatrixCSC{Float32,Int64}}(
[1, 1] = 1.0,
[2, 1] = 1.0,
[3, 1] = 1.0)
```
"""
struct PartitionMask{T, A}
A_1::A
A_2::A
A_3::A

# Only make it possible to construct using matrices
PartitionMask(A_1::A, A_2::A, A_3::A) where {T<:Real, A <: AbstractMatrix{T}} = new{T, A}(A_1, A_2, A_3)
end

PartitionMask(args...) = PartitionMask{Bool}(args...)

function PartitionMask{T}(
n::Int,
indices_1::AbstractVector{Int},
indices_2::AbstractVector{Int},
indices_3::AbstractVector{Int}
) where {T<:Real}
A_1 = sparse(indices_1, 1:length(indices_1), one(T), n, length(indices_1))
A_2 = sparse(indices_2, 1:length(indices_2), one(T), n, length(indices_2))
A_3 = sparse(indices_3, 1:length(indices_3), one(T), n, length(indices_3))

return PartitionMask(A_1, A_2, A_3)
end

PartitionMask{T}(
n::Int,
indices_1::AbstractVector{Int},
indices_2::AbstractVector{Int};
) where {T} = PartitionMask{T}(n, indices_1, indices_2, nothing)

PartitionMask{T}(
n::Int,
indices_1::AbstractVector{Int},
indices_2::AbstractVector{Int},
indices_3::Nothing,
) where {T} = PartitionMask{T}(n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2))

PartitionMask{T}(
n::Int,
indices_1::AbstractVector{Int},
indices_2::Nothing,
indices_3::AbstractVector{Int},
) where {T} = PartitionMask{T}(n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3)

"""
PartitionMask(n::Int, indices)
Assumes you want to _split_ the vector, where `indices` refer to the
parts of the vector you want to apply the bijector to.
"""
function PartitionMask{T}(n::Int, indices) where {T}
indices_2 = setdiff(1:n, indices)

# sparse arrays <3
A_1 = sparse(indices, 1:length(indices), one(T), n, length(indices))
A_2 = sparse(indices_2, 1:length(indices_2), one(T), n, length(indices_2))

return PartitionMask(A_1, A_2, spzeros(T, n, 0))
end
function PartitionMask{T}(x::AbstractVector, indices) where {T}
return PartitionMask{T}(length(x), indices)
end

"""
combine(m::PartitionMask, x_1, x_2, x_3)
Combines `x_1`, `x_2`, and `x_3` into a single vector.
"""
@inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3

"""
partition(m::PartitionMask, x)
Partitions `x` into 3 disjoint subvectors.
"""
@inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x)


# Coupling

"""
Coupling{F, M}(θ::F, mask::M)
Implements a coupling-layer as defined in [1].
# Examples
```julia-repl
julia> m = PartitionMask(3, [1], [2]) # <= going to use x[2] to parameterize transform of x[1]
PartitionMask{SparseArrays.SparseMatrixCSC{Float64,Int64}}(
[1, 1] = 1.0,
[2, 1] = 1.0,
[3, 1] = 1.0)
julia> cl = Coupling(θ -> Shift(θ[1]), m) # <= will do `y[1:1] = x[1:1] + x[2:2]`;
julia> x = [1., 2., 3.];
julia> cl(x)
3-element Array{Float64,1}:
3.0
2.0
3.0
julia> inv(cl)(cl(x))
3-element Array{Float64,1}:
1.0
2.0
3.0
julia> coupling(cl) # get the `Bijector` map `θ -> b(⋅, θ)`
Shift
julia> couple(cl, x) # get the `Bijector` resulting from `x`
Shift{Array{Float64,1},1}([2.0])
```
# References
[1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019).
"""
struct Coupling{F, M} <: Bijector{1} where {F, M <: PartitionMask}
θ::F
mask::M
end

function Coupling(θ, n::Int)
idx = div(n, 2)
return Coupling(θ, PartitionMask(n, 1:idx))
end

function Coupling(cl::Coupling, mask::PartitionMask)
return Coupling(cl.θ, mask)
end

"Returns the constructor of the coupling law."
coupling(cl::Coupling) = cl.θ

"Returns the coupling law constructed from `x`."
function couple(cl::Coupling, x::AbstractVector)
# partition vector using `cl.mask::PartitionMask`
x_1, x_2, x_3 = partition(cl.mask, x)

# construct bijector `B` using θ(x₂)
b = cl.θ(x_2)

return b
end

function (cl::Coupling)(x::AbstractVector)
# partition vector using `cl.mask::PartitionMask`
x_1, x_2, x_3 = partition(cl.mask, x)

# construct bijector `B` using θ(x₂)
b = cl.θ(x_2)

# recombine the vector again using the `PartitionMask`
return combine(cl.mask, b(x_1), x_2, x_3)
end
(cl::Coupling)(x::AbstractMatrix) = eachcolmaphcat(cl, x)


function (icl::Inverse{<:Coupling})(y::AbstractVector)
cl = icl.orig

y_1, y_2, y_3 = partition(cl.mask, y)

b = cl.θ(y_2)
ib = inv(b)

return combine(cl.mask, ib(y_1), y_2, y_3)
end
(icl::Inverse{<:Coupling})(y::AbstractMatrix) = eachcolmaphcat(icl, y)

function logabsdetjac(cl::Coupling, x::AbstractVector)
x_1, x_2, x_3 = partition(cl.mask, x)
b = cl.θ(x_2)

# `B` might be 0-dim in which case it will treat `x_1` as a batch
# therefore we sum to ensure such a thing does not happen
return sum(logabsdetjac(b, x_1))
end

function logabsdetjac(cl::Coupling, x::AbstractMatrix)
return [logabsdetjac(cl, view(x, :, i)) for i in axes(x, 2)]
end
4 changes: 4 additions & 0 deletions src/bijectors/stacked.jl
Expand Up @@ -104,6 +104,10 @@ function logabsdetjac(
end
end

function logabsdetjac(b::Stacked{<:Any, 1}, x::AbstractVector{<:Real})
return sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
end

function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real})
return map(eachcol(x)) do c
logabsdetjac(b, c)
Expand Down
1 change: 1 addition & 0 deletions src/interface.jl
Expand Up @@ -156,6 +156,7 @@ include("bijectors/truncated.jl")
include("bijectors/planar_layer.jl")
include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
include("bijectors/normalise.jl")

##################
Expand Down
8 changes: 4 additions & 4 deletions test/ad/utils.jl
Expand Up @@ -140,23 +140,23 @@ function test_ad(dist::DistSpec; kwargs...)
end
end

function test_ad(f, x; rtol = 1e-6, atol = 1e-6)
function test_ad(f, x; rtol = 1e-6, atol = 1e-6, ad = AD)
finitediff = FiniteDiff.finite_difference_gradient(f, x)

if AD == "All" || AD == "ForwardDiff_Tracker"
if ad == "All" || ad == "ForwardDiff_Tracker"
tracker = Tracker.data(Tracker.gradient(f, x)[1])
@test tracker finitediff rtol=rtol atol=atol

forward = ForwardDiff.gradient(f, x)
@test forward finitediff rtol=rtol atol=atol
end

if AD == "All" || AD == "Zygote"
if ad == "All" || ad == "Zygote"
zygote = Zygote.gradient(f, x)[1]
@test zygote finitediff rtol=rtol atol=atol
end

if AD == "All" || AD == "ReverseDiff"
if ad == "All" || ad == "ReverseDiff"
reversediff = ReverseDiff.gradient(f, x)
@test reversediff finitediff rtol=rtol atol=atol
end
Expand Down
61 changes: 61 additions & 0 deletions test/bijectors/coupling.jl
@@ -0,0 +1,61 @@
using Bijectors:
Coupling,
PartitionMask,
coupling,
couple,
partition,
combine,
Shift,
Scale

@testset "Coupling" begin
@testset "PartitionMask" begin
m1 = PartitionMask(3, [1], [2])
m2 = PartitionMask(3, [1], [2], [3])

@test (m1.A_1 == m2.A_1) & (m1.A_2 == m2.A_2) & (m1.A_3 == m2.A_3)

x = [1., 2., 3.]
x1, x2, x3 = partition(m1, x)
@test (x1 == [1.]) & (x2 == [2.]) & (x3 == [3.])

y = combine(m1, x1, x2, x3)
@test y == x
end

@testset "Basics" begin
m = PartitionMask(3, [1], [2])
cl1 = Coupling(x -> Shift(x[1]), m)

x = [1., 2., 3.]
@test cl1(x) == [3., 2., 3.]

cl2 = Coupling-> Shift(θ[1]), m)
@test cl2(x) == cl1(x)

# inversion
icl1 = inv(cl1)
@test icl1(cl1(x)) == x
@test inv(cl2)(cl2(x)) == x

# This `cl2` should result in
b = Shift(x[2:2])

# logabsdetjac
@test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1])

# forward
@test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x))
@test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x))
end

@testset "Classic" begin
m = PartitionMask(3, [1], [2])

# With `Scale`
cl = Coupling(x -> Scale(x[1]), m)
x = hcat([-1., -2., -3.], [1., 2., 3.])
y = hcat([2., -2., -3.], [2., 2., 3.])
test_bijector(cl, x, y, log.([2., 2.]))
end
end

0 comments on commit 410166b

Please sign in to comment.