Skip to content

Commit

Permalink
Merge f6151db into dfca913
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 11, 2020
2 parents dfca913 + f6151db commit 94c65b1
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 0 deletions.
230 changes: 230 additions & 0 deletions src/bijectors/named_bijector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
abstract type AbstractNamedBijector <: AbstractBijector end

forward(b::AbstractBijector, x) = (rv = b(x), logabsdetjac = logabsdetjac(b, x))

#######################
### `NamedBijector` ###
#######################
"""
NamedBijector <: AbstractNamedBijector
Wraps a `NamedTuple` of key -> `Bijector` pairs, implementing evaluation, inversion, etc.
# Examples
```julia-repl
julia> using Bijectors: NamedBijector, Scale, Exp
julia> b = NamedBijector((a = Scale(2.0), b = Exp()));
julia> x = (a = 1., b = 0., c = 42.);
julia> b(x)
(a = 2.0, b = 1.0, c = 42.0)
julia> (a = 2 * x.a, b = exp(x.b), c = x.c)
(a = 2.0, b = 1.0, c = 42.0)
```
"""
struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector
bs::Bs
end

@inline names_to_bijectors(b::NamedBijector) = b.bs

@generated function (b::NamedBijector{names})(x::NamedTuple) where {names}
return :(merge(x, ($([:($n = b.bs.$n(x.$n)) for n in names]...), )))
end

@generated function Base.inv(b::NamedBijector{names}) where {names}
return :(NamedBijector(($([:($n = inv(b.bs.$n)) for n in names]...), )))
end

@generated function logabsdetjac(b::NamedBijector{names}, x::NamedTuple) where {names}
return :(sum([$([:(logabsdetjac(b.bs.$n, x.$n)) for n in names]...), ]))
end


######################
### `NamedInverse` ###
######################
"""
NamedInverse <: AbstractNamedBijector
Represents the inverse of a `AbstractNamedBijector`, similarily to `Inverse` for `Bijector`.
See also: [`Inverse`](@ref)
"""
struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector
orig::B
end
Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb)
Base.inv(ni::NamedInverse) = ni.orig

logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inv(ni), ni(y))

##########################
### `NamedComposition` ###
##########################
"""
NamedComposition <: AbstractNamedBijector
Wraps a tuple of array of `AbstractNamedBijector` and implements their composition.
This is very similar to `Composed` for `Bijector`, with the exception that we do not require
the inputs to have the same "dimension", which in this case refers to the *symbols* for the
`NamedTuple` that this takes as input.
See also: [`Composed`](@ref)
"""
struct NamedComposition{Bs} <: AbstractNamedBijector
bs::Bs
end

# Essentially just copy-paste from impl of composition for 'standard' bijectors,
# with minor changes here and there.
composel(bs::AbstractNamedBijector...) = NamedComposition(bs)
composer(bs::AbstractNamedBijector...) = NamedComposition(reverse(bs))
(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1)

inv(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs)))

function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x)
@assert length(cb.bs) > 0
res = cb.bs[1](x)
for b Base.Iterators.drop(cb.bs, 1)
res = b(res)
end

return res
end

@generated function (cb::NamedComposition{T})(x) where {T<:Tuple}
@assert length(T.parameters) > 0
expr = :(x)
for i in 1:length(T.parameters)
expr = :(cb.bs[$i]($expr))
end
return expr
end

function logabsdetjac(cb::NamedComposition, x)
y, logjac = forward(cb.bs[1], x)
for i = 2:length(cb.bs)
res = forward(cb.bs[i], y)
y = res.rv
logjac += res.logabsdetjac
end

return logjac
end

@generated function logabsdetjac(cb::NamedComposition{T}, x) where {T<:Tuple}
N = length(T.parameters)

expr = Expr(:block)
push!(expr.args, :((y, logjac) = forward(cb.bs[1], x)))

for i = 2:N - 1
temp = gensym(:res)
push!(expr.args, :($temp = forward(cb.bs[$i], y)))
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
end
# don't need to evaluate the last bijector, only it's `logabsdetjac`
push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y)))

push!(expr.args, :(return logjac))

return expr
end


function forward(cb::NamedComposition, x)
rv, logjac = forward(cb.bs[1], x)

for t in cb.bs[2:end]
res = forward(t, rv)
rv = res.rv
logjac = res.logabsdetjac + logjac
end
return (rv=rv, logabsdetjac=logjac)
end


@generated function forward(cb::NamedComposition{T}, x) where {T<:Tuple}
expr = Expr(:block)
push!(expr.args, :((y, logjac) = forward(cb.bs[1], x)))
for i = 2:length(T.parameters)
temp = gensym(:temp)
push!(expr.args, :($temp = forward(cb.bs[$i], y)))
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
end
push!(expr.args, :(return (rv = y, logabsdetjac = logjac)))

return expr
end


############################
### `NamedCouplingLayer` ###
############################
# TODO: Add ref to `Coupling` or `CouplingLayer` once that's merged.
"""
NamedCoupling{target, deps, F} <: AbstractNamedBijector
Implements a coupling layer for named bijectors.
# Examples
```julia-repl
julia> using Bijectors: NamedCoupling, Scale
julia> b = NamedCoupling(:b, (:a, :c), (a, c) -> Scale(a + c))
NamedCoupling{:b,(:a, :c),var"#3#4"}(var"#3#4"())
julia> x = (a = 1., b = 2., c = 3.);
julia> b(x)
(a = 1.0, b = 8.0, c = 3.0)
julia> (a = x.a, b = (x.a + x.c) * x.b, c = x.c)
(a = 1.0, b = 8.0, c = 3.0)
```
"""
struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target}
f::F
end

NamedCoupling(target, deps, f::F) where {F} = NamedCoupling{target, deps, F}(f)
function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F}
return NamedCoupling{target, deps, F}(f)
end

coupling(b::NamedCoupling) = b.f
# For some reason trying to use the parameteric types doesn't always work
# so we have to do this weird approach of extracting type and then index `parameters`.
target(b::NamedCoupling) = typeof(b).parameters[1]
deps(b::NamedCoupling) = typeof(b).parameters[2]

@generated function (nc::NamedCoupling{target, deps, F})(x::NamedTuple) where {target, deps, F}
return quote
b = nc.f($([:(x.$d) for d in deps]...))
return merge(x, ($target = b(x.$target), ))
end
end

@generated function (ni::NamedInverse{<:NamedCoupling{target, deps, F}})(
x::NamedTuple
) where {target, deps, F}
return quote
b = ni.orig.f($([:(x.$d) for d in deps]...))
return merge(x, ($target = inv(b)(x.$target), ))
end
end

@generated function logabsdetjac(nc::NamedCoupling{target, deps, F}, x::NamedTuple) where {target, deps, F}
return quote
b = nc.f($([:(x.$d) for d in deps]...))
return logabsdetjac(b, x.$target)
end
end
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ include("bijectors/simplex.jl")
include("bijectors/pd.jl")
include("bijectors/corr.jl")
include("bijectors/truncated.jl")
include("bijectors/named_bijector.jl")

# Normalizing flow related
include("bijectors/planar_layer.jl")
Expand Down
54 changes: 54 additions & 0 deletions test/bijectors/named_bijector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using Test
using Bijectors
using Bijectors: Exp, Log, Logit, AbstractNamedBijector, NamedBijector, NamedInverse, NamedCoupling, NamedComposition, Shift

@testset "NamedBijector" begin
b = NamedBijector((a = Exp(), b = Log()))
@test b((a = 0.0, b = exp(1.0))) == (a = 1.0, b = 1.0)
end

@testset "NamedComposition" begin
b = NamedBijector((a = Exp(), ))
x = (a = 0., b = 1.)

nc1 = NamedComposition((b, b))
@test nc1(x) == b(b(x))
@test logabsdetjac(nc1, x) logabsdetjac(b, x) + logabsdetjac(b, b(x))

nc2 = b b
@test nc1 == nc2

inc2 = inv(nc2)
@test (inc2 nc2)(x) == x
@test logabsdetjac((inc2 nc2), x) 0.0
end

@testset "NamedCoupling" begin
nc = NamedCoupling(Val(:b), Val((:a, )), a -> Logit(zero(a), a))
@inferred NamedCoupling(Val(:b), Val((:a, )), Shift)

nc = NamedCoupling(:b, (:a, ), a -> Logit(0., a)) # <= not type-inferrable but eh

@test Bijectors.target(nc) == :b
@test Bijectors.deps(nc) == (:a, )

@inferred Bijectors.target(nc)
@inferred Bijectors.deps(nc)

x = (a = 1.0, b = 0.5, c = 99999.)
@test Bijectors.coupling(nc)(x.a) isa Logit
@test inv(nc)(nc(x)) == x

@test logabsdetjac(nc, x) == logabsdetjac(Logit(0., 1.), x.b)
@test logabsdetjac(inv(nc), nc(x)) == -logabsdetjac(nc, x)

x = (a = 0.0, b = 2.0, c = 1.0)
nc = NamedCoupling(:c, (:a, :b), (a, b) -> Logit(a, b))
@test nc(x).c == 0.0
@test inv(nc)(nc(x)) == x

x = (a = 0.0, b = 2.0, c = 1.0)
nc = NamedCoupling(:c, (:b, ), b -> Shift(b))
@test nc(x).c == 3.0
@test inv(nc)(nc(x)) == x
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ if GROUP == "All" || GROUP == "Interface"
include("transform.jl")
include("norm_flows.jl")
include("bijectors/permute.jl")
include("bijectors/named_bijector.jl")
end

if !is_TRAVIS && (GROUP == "All" || GROUP == "AD")
Expand Down

0 comments on commit 94c65b1

Please sign in to comment.