Implement Functors.functor #162

Merged · 2 commits · Jan 16, 2021
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,11 +1,12 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.12"
version = "0.8.13"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -21,6 +22,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ArgCheck = "1, 2"
Compat = "3"
Distributions = "0.23.3, 0.24"
Functors = "0.1"
MappedArrays = "0.2.2, 0.3"
NNlib = "0.6, 0.7"
NonlinearSolve = "0.3"
15 changes: 4 additions & 11 deletions README.md
@@ -362,7 +362,7 @@ true
```

### Normalizing flows
-A very interesting application is that of _normalizing flows_.[1] Usually this is done by sampling from a multivariate normal distribution, and then transforming this to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarFlow` and `RadialFlow`. Let's create a flow with a single `PlanarLayer`:
+A very interesting application is that of _normalizing flows_.[1] Usually this is done by sampling from a multivariate normal distribution, and then transforming this to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarLayer` and `RadialLayer`. Let's create a flow with a single `PlanarLayer`:

```julia
julia> d = MvNormal(zeros(2), ones(2));
@@ -394,7 +394,7 @@ julia> y = rand(flow)
1.3337915588180933
1.010861989639227

-julia> logpdf(flow, y) # uses inverse of `b`; not very efficient for `PlanarFlow` and not 100% accurate
+julia> logpdf(flow, y) # uses inverse of `b`
-2.8996106373788293

julia> x = rand(flow.dist)
@@ -471,22 +471,15 @@ julia> Tracker.grad(b.w)
0.0013039074681623036
```

-We can easily create more complex flows by simply doing `PlanarFlow(10) ∘ PlanarFlow(10) ∘ RadialFlow(10)` and so on.
+We can easily create more complex flows by simply doing `PlanarLayer(10) ∘ PlanarLayer(10) ∘ RadialLayer(10)` and so on.

-In those cases, it might be useful to use Flux.jl's `treelike` to extract the parameters:
+In those cases, it might be useful to use Flux.jl's `Flux.params` to extract the parameters:
```julia
julia> using Flux

-julia> @Flux.treelike Composed
-
-julia> @Flux.treelike TransformedDistribution
-
-julia> @Flux.treelike PlanarLayer
-
julia> Flux.params(flow)
Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)])
```
-Though we might just do this for you in the future, so then all you'll have to do is call `Flux.params`.

Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in the process using the most efficient computation path.

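With `Functors.functor` now defined for these types (see the source changes below), the `@Flux.treelike` step disappears entirely. A minimal sketch of the resulting workflow, assuming a Flux version whose `params` collects parameters via Functors.jl:

```julia
using Bijectors, Distributions, Flux

d = MvNormal(zeros(2), ones(2))
b = PlanarLayer(2) ∘ PlanarLayer(2)  # a `Composed` bijector
flow = transformed(d, b)             # a `TransformedDistribution`

# `Flux.params` can now walk TransformedDistribution -> Composed -> PlanarLayer
# through `Functors.functor` and collect the numerical leaves (w, u, b) directly.
ps = Flux.params(flow)
```
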
2 changes: 1 addition & 1 deletion src/Bijectors.jl
@@ -35,6 +35,7 @@ using LinearAlgebra
using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular
+import Functors
import NonlinearSolve

export TransformDistribution,
@@ -266,7 +267,6 @@ function __init__()
    @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl")
    @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl")
    @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl")
-    @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("compat/flux.jl")
end

end # module
3 changes: 3 additions & 0 deletions src/bijectors/composed.jl
@@ -92,6 +92,9 @@ end
Composed(bs::Tuple{Vararg{<:Bijector{N}}}) where N = Composed{typeof(bs),N}(bs)
Composed(bs::AbstractArray{<:Bijector{N}}) where N = Composed{typeof(bs),N}(bs)

+# field contains nested numerical parameters
+Functors.@functor Composed
+
isclosedform(b::Composed) = all(isclosedform, b.ts)
up1(b::Composed) = Composed(up1.(b.ts))
function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N}
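
`Composed` stores its sub-bijectors in the single field `ts`, so the macro form suffices here. As a rough sketch (not the macro's literal expansion), `Functors.@functor Composed` amounts to:

```julia
# Expose every field of `Composed` to Functors and rebuild from a NamedTuple;
# the constructors above accept both tuples and arrays of bijectors.
function Functors.functor(::Type{<:Composed}, x)
    reconstruct_composed(xs) = Composed(xs.ts)
    return (ts = x.ts,), reconstruct_composed
end
```
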
8 changes: 8 additions & 0 deletions src/bijectors/leaky_relu.jl
@@ -14,6 +14,14 @@ end
LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

+# field is a numerical parameter
+function Functors.functor(::Type{LeakyReLU{<:Any,N}}, x) where N
+    function reconstruct_leakyrelu(xs)
+        return LeakyReLU{typeof(xs.α),N}(xs.α)
+    end
+    return (α = x.α,), reconstruct_leakyrelu
+end
+
up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α)

# (N=0) Univariate case
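
The manual definition here (rather than `Functors.@functor LeakyReLU`) pins the dimensionality parameter `N` in the reconstructed type while letting the type of `α` change. A hypothetical round trip with illustrative values:

```julia
using Bijectors, Functors

b = LeakyReLU(0.1)       # LeakyReLU{Float64,0}
b32 = fmap(Float32, b)   # maps the single leaf α = 0.1 to 0.1f0
# `reconstruct_leakyrelu` reuses N from the type, so `b32` is a
# LeakyReLU{Float32,0}: only the parameter's type has changed.
```
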
10 changes: 10 additions & 0 deletions src/bijectors/logit.jl
@@ -11,6 +11,16 @@ function Logit(a, b)
    T = promote_type(typeof(a), typeof(b))
    Logit{0, T}(a, b)
end
+
+# fields are numerical parameters
+function Functors.functor(::Type{<:Logit{N}}, x) where N
+    function reconstruct_logit(xs)
+        T = promote_type(typeof(xs.a), typeof(xs.b))
+        return Logit{N,T}(xs.a, xs.b)
+    end
+    return (a = x.a, b = x.b,), reconstruct_logit
+end
+
up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b)
# For equality of Logit with Float64 fields to one with Duals
Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b
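
The `promote_type` inside `reconstruct_logit` mirrors the outer constructor above: even if `fmap` maps `a` and `b` to different types, the rebuilt `Logit` gets a single common field type. A hypothetical round trip:

```julia
using Bijectors, Functors

lg = Logit(0.0, 1.0)               # Logit{0,Float64}
fields, re = Functors.functor(lg)  # ((a = 0.0, b = 1.0), reconstruct_logit)
lg32 = re(map(Float32, fields))    # both fields promoted, giving Logit{0,Float32}
```
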
8 changes: 8 additions & 0 deletions src/bijectors/named_bijector.jl
@@ -29,6 +29,14 @@ struct NamedBijector{names, Bs<:NamedTuple{names}} <: AbstractNamedBijector
    bs::Bs
end

+# fields contain nested numerical parameters
+function Functors.functor(::Type{<:NamedBijector{names}}, x) where names
+    function reconstruct_namedbijector(xs)
+        return NamedBijector{names,typeof(xs.bs)}(xs.bs)
+    end
+    return (bs = x.bs,), reconstruct_namedbijector
+end
+
names_to_bijectors(b::NamedBijector) = b.bs

@generated function (b::NamedBijector{names1})(
10 changes: 10 additions & 0 deletions src/bijectors/normalise.jl
@@ -38,6 +38,16 @@ function InvertibleBatchNorm(
    )
end

+# define numerical parameters
+# TODO: replace with `Functors.@functor InvertibleBatchNorm (b, logs)` when
+# https://github.com/FluxML/Functors.jl/pull/7 is merged
+function Functors.functor(::Type{<:InvertibleBatchNorm}, x)
+    function reconstruct_invertiblebatchnorm(xs)
+        return InvertibleBatchNorm(xs.b, xs.logs, x.m, x.v, x.eps, x.mtm)
+    end
+    return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm
+end
+
function forward(bn::InvertibleBatchNorm, x)
    dims = ndims(x)
    size(x, dims - 1) == length(bn.b) ||
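
Only `b` and `logs` are exposed as functor fields; the running statistics `m`, `v` and the constants `eps`, `mtm` are taken from the original layer inside the closure, so `fmap` and parameter collection never touch them. A hypothetical check (assuming the `chs::Int` constructor shown above):

```julia
using Bijectors, Functors

bn = InvertibleBatchNorm(2)
fields, re = Functors.functor(bn)
keys(fields) == (:b, :logs)  # true: m, v, eps and mtm stay untracked
```
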
3 changes: 3 additions & 0 deletions src/bijectors/planar_layer.jl
@@ -30,6 +30,9 @@ function PlanarLayer(dims::Int, wrapper=identity)
    return PlanarLayer(w, u, b)
end

+# all fields are numerical parameters
+Functors.@functor PlanarLayer
+
"""
get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real})

3 changes: 3 additions & 0 deletions src/bijectors/radial_layer.jl
@@ -28,6 +28,9 @@ function RadialLayer(dims::Int, wrapper=identity)
    return RadialLayer(α_, β, z_0)
end

+# all fields are numerical parameters
+Functors.@functor RadialLayer
+
h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14)
#dh(α, r) = .- (1 ./ (α .+ r)) .^ 2 # for radial flow; derivative of h()

8 changes: 8 additions & 0 deletions src/bijectors/scale.jl
@@ -8,6 +8,14 @@ function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where
    return Scale{typeof(a), D}(a)
end

+# field is a numerical parameter
+function Functors.functor(::Type{<:Scale{<:Any,N}}, x) where N
+    function reconstruct_scale(xs)
+        return Scale{typeof(xs.a),N}(xs.a)
+    end
+    return (a = x.a,), reconstruct_scale
+end
+
up1(b::Scale{T, N}) where {N, T} = Scale{T, N + 1}(b.a)

(b::Scale)(x) = b.a .* x
9 changes: 9 additions & 0 deletions src/bijectors/shift.jl
@@ -10,6 +10,15 @@ Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a
function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
    return Shift{typeof(a), D}(a)
end
+
+# field is a numerical parameter
+function Functors.functor(::Type{<:Shift{<:Any,N}}, x) where N
+    function reconstruct_shift(xs)
+        return Shift{typeof(xs.a),N}(xs.a)
+    end
+    return (a = x.a,), reconstruct_shift
+end
+
up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a)

(b::Shift)(x) = b.a .+ x
10 changes: 10 additions & 0 deletions src/bijectors/stacked.jl
@@ -44,6 +44,16 @@ end
Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))

+# define nested numerical parameters
+# TODO: replace with `Functors.@functor Stacked (bs,)` when
+# https://github.com/FluxML/Functors.jl/pull/7 is merged
+function Functors.functor(::Type{<:Stacked}, x)
+    function reconstruct_stacked(xs)
+        return Stacked(xs.bs, x.ranges)
+    end
+    return (bs = x.bs,), reconstruct_stacked
+end
+
function Base.:(==)(b1::Stacked, b2::Stacked)
    bs1, bs2 = b1.bs, b2.bs
    if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector)
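
Here `ranges` is structural rather than numerical, so it is deliberately captured from the original `x` instead of being exposed as a functor field: `fmap` may rebuild the bijectors, but the index ranges are reused as-is. A hypothetical illustration:

```julia
using Bijectors, Functors

sb = Stacked((Exp(), Log()), (1:1, 2:2))
fields, re = Functors.functor(sb)  # fields == (bs = (Exp(), Log()),)
sb2 = re(fields)                   # `ranges` comes along from `sb` itself
sb2.ranges == sb.ranges            # true
```
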
9 changes: 9 additions & 0 deletions src/bijectors/truncated.jl
@@ -9,6 +9,15 @@ TruncatedBijector(lb, ub) = TruncatedBijector{0}(lb, ub)
function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2}
    return TruncatedBijector{N, T1, T2}(lb, ub)
end
+
+# fields are numerical parameters
+function Functors.functor(::Type{<:TruncatedBijector{N}}, x) where N
+    function reconstruct_truncatedbijector(xs)
+        return TruncatedBijector{N}(xs.lb, xs.ub)
+    end
+    return (lb = x.lb, ub = x.ub,), reconstruct_truncatedbijector
+end
+
up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub)

function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector)
3 changes: 0 additions & 3 deletions src/compat/flux.jl

This file was deleted.

4 changes: 4 additions & 0 deletions src/interface.jl
@@ -66,6 +66,10 @@ struct Inverse{B <: Bijector, N} <: Bijector{N}

    Inverse(b::B) where {N, B<:Bijector{N}} = new{B, N}(b)
end
+
+# field contains nested numerical parameters
+Functors.@functor Inverse
+
up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
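
Since `Inverse` simply wraps the original bijector in its `orig` field, registering it as a functor means the parameters reachable through `inv(b)` are the very same arrays as `b`'s, which is what the `test_functor(inv(flow), (orig = flow,))` checks in the test changes below rely on. A small sketch:

```julia
using Bijectors, Functors

pl = PlanarLayer(2)
ib = inv(pl)                      # an Inverse{<:PlanarLayer,1}
fields, _ = Functors.functor(ib)
fields.orig === pl                # true: no copying, so gradients taken
                                  # through `ib` update `pl`'s parameters
```
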
2 changes: 2 additions & 0 deletions src/transformed_distribution.jl
@@ -8,6 +8,8 @@ struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<
    TransformedDistribution(d::MatrixDistribution, b::Bijector{2}) = new{typeof(d), typeof(b), Matrixvariate}(d, b)
end

+# fields may contain nested numerical parameters
+Functors.@functor TransformedDistribution

const UnivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Univariate}
const MultivariateTransformed = TransformedDistribution{<:Distribution, <:Bijector, Multivariate}
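
Registering the wrapper type lets a whole flow be traversed as one parameter tree: the base distribution and the bijector both become subtrees. A sketch, assuming the struct's field names `dist` and `transform` (the former also appears as `flow.dist` in the README above):

```julia
using Bijectors, Distributions, Functors

flow = transformed(MvNormal(zeros(2), ones(2)), PlanarLayer(2))
fields, re = Functors.functor(flow)
keys(fields)        # (:dist, :transform): both subtrees are exposed
flow2 = re(fields)  # rebuilds an equivalent TransformedDistribution
```
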
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -3,6 +3,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -16,6 +17,7 @@ Combinatorics = "1.0.2"
DistributionsAD = "0.6.3"
FiniteDifferences = "0.11"
ForwardDiff = "0.10.12"
Functors = "0.1"
NNlib = "0.7"
ReverseDiff = "1.4.2"
Tracker = "0.2.11"
7 changes: 7 additions & 0 deletions test/bijectors/utils.jl
@@ -192,3 +192,10 @@ function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6)
        @test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol
    end
end
+
+# Check if `Functors.functor` works properly
+function test_functor(x, xs)
+    _xs, re = Functors.functor(x)
+    @test x == re(_xs)
+    @test _xs == xs
+end

Comment on lines +197 to +201 (test_functor):

Member: Is this being used anywhere?

Member (author): Yes, in the norm_flows tests.

Member: Ah, missed that 👍
8 changes: 8 additions & 0 deletions test/norm_flows.jl
@@ -17,6 +17,8 @@ seed!(1)
    y, ladj = forward(bn, x)
    @test log(abs(det(ForwardDiff.jacobian(bn, x)))) ≈ sum(ladj)
    @test log(abs(det(ForwardDiff.jacobian(inv(bn), y)))) ≈ sum(logabsdetjac(inv(bn), y))
+
+    test_functor(bn, (b = bn.b, logs = bn.logs))
end

@testset "PlanarLayer" begin
@@ -37,6 +39,9 @@
    flow = PlanarLayer(w, u, b)
    z = ones(10, 100)
    @test inv(flow)(flow(z)) ≈ z
+
+    test_functor(flow, (w = w, u = u, b = b))
+    test_functor(inv(flow), (orig = flow,))
end

@testset "RadialLayer" begin
@@ -57,6 +62,9 @@
    z = ones(10, 100)
    flow = RadialLayer(α_, β, z_0)
    @test inv(flow)(flow(z)) ≈ z
+
+    test_functor(flow, (α_ = α_, β = β, z_0 = z_0))
+    test_functor(inv(flow), (orig = flow,))
end

@testset "Flows" begin
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -4,6 +4,7 @@ using Combinatorics
using DistributionsAD
using FiniteDifferences
using ForwardDiff
+using Functors
using ReverseDiff
using Tracker
using Zygote