LeakyReLU #81

Merged: 29 commits, Sep 22, 2020

Commits (changes shown are from the first 7 of 29 commits)
1e27430
added LeakyReLU as a Bijector
torfjelde Feb 9, 2020
180a527
added Compat.jl
torfjelde Feb 11, 2020
0f5200a
forgot to use Compat.jl
torfjelde Feb 12, 2020
7d36fde
Merge branch 'master' into tor/leaky-relu
torfjelde Sep 10, 2020
533b86e
added some tests for LeakyReLU
torfjelde Sep 10, 2020
fd2e33b
using masks rather than ifelse for type-stability and simplicity
torfjelde Sep 10, 2020
8b69cab
removed some redundant comments
torfjelde Sep 10, 2020
ba299ff
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
9b4d0d6
removed unnecessary broadcasting and useless import
torfjelde Sep 11, 2020
dbeb414
Merge branch 'tor/leaky-relu' of https://github.com/TuringLang/Biject…
torfjelde Sep 11, 2020
7771155
fixed a typo
torfjelde Sep 11, 2020
6e9d9f7
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
18e30f9
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c5b058b
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c72fd11
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
68d3518
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c14a5db
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
4a21b73
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
b4119fd
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
1b97a0f
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
1ebd9ba
Apply suggestions from code review
torfjelde Sep 11, 2020
07cc632
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
f2f167e
Apply suggestions from code review
torfjelde Sep 12, 2020
003dfb6
Apply suggestions from code review
torfjelde Sep 12, 2020
e17d5d7
Update src/bijectors/leaky_relu.jl
torfjelde Sep 12, 2020
a83cb7b
Update src/bijectors/leaky_relu.jl
torfjelde Sep 12, 2020
28158ae
Apply suggestions from code review
torfjelde Sep 12, 2020
6077578
Apply suggestions from code review
torfjelde Sep 12, 2020
410166b
Merge branch 'master' into tor/leaky-relu
torfjelde Sep 22, 2020
95 changes: 95 additions & 0 deletions src/bijectors/leaky_relu.jl
@@ -0,0 +1,95 @@
"""
LeakyReLU{T, N}(α::T) <: Bijector{N}

Defines the invertible mapping

x ↦ x if x ≥ 0 else αx

where α > 0.
"""
struct LeakyReLU{T, N} <: Bijector{N}
α::T
end

LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α)

# (N=0) Univariate case
function (b::LeakyReLU{<:Any, 0})(x::Real)
mask = x < zero(x)
return mask * b.α * x + (1 - mask) * x
end
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = b.(x)

function (ib::Inverse{<:LeakyReLU, 0})(y::Real)
mask = y < zero(y)
return mask * (y / ib.orig.α) + (1 - mask) * y
end
(ib::Inverse{<:LeakyReLU{<:Any}, 0})(y::AbstractVector{<:Real}) = ib.(y)

function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + (1 - mask) * one(x)
return log(abs(J))
end
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = logabsdetjac.(b, x)


# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + (1 - mask) * one(x)
return (rv=J * x, logabsdetjac=log(abs(J)))
end

# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
mask = x .< zero(eltype(x))
J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))
Member commented:
Maybe

Suggested change:
-mask = x .< zero(eltype(x))
-J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))
+J = @. (x < zero(x)) * b.α + (x > zero(x)) * one(x)

would be more efficient (since it avoids the allocation of mask) and more stable (since it does not require computations based on types).

As a side remark, IMO it is a bit unfortunate that currently so many almost identical implementations of a function are needed to define a bijector. Maybe this could be resolved by defining defaults at a high level for basically all bijectors (or for a subtype such as BaseBijectors, similar to BaseKernels in KernelFunctions) that dispatch to just one or two methods users would have to implement if their bijectors belong to these simple standard groups. With Julia >= 1.3 it is also possible to define call methods on abstract types, such as

(f::Bijector{T,0})(x::AbstractVecOrMat{<:Real}) where T = map(f, x)

torfjelde (Member, Author) replied:

> would be more efficient (since it avoids the allocation of mask) and more stable (since it does not require computations based on types).

Won't this call zero(x) n times though, i.e. have the same allocation? Other than that, it seems like a good idea 👍

> As a side remark

100%.

> Maybe this could be resolved by defining defaults at a high level for basically all bijectors (or for a subtype such as BaseBijectors, similar to BaseKernels in KernelFunctions) that dispatch to just one or two methods users would have to implement if their bijectors belong to these simple standard groups.

I'm unaware of what KernelFunctions.jl does. Are they simply not making the struct callable? If so: it was a deliberate design decision to go with callable structs when we started out. We were debating whether to do this or to include some transform method so we could define "one method to rule them all" on an abstract type. We ended up going with the "callable struct" approach, but it definitely has its issues, e.g. redundant code.

I've recently played around a bit with actually using a transform method under the hood, plus a macro that lets you write @deftransform function transform(b, x) ... end and adds (b::Bijector)(x) = transform(b, x) just after the method declaration (see the sketch after this comment). This would also allow us to implement in-place versions of all the methods, i.e. transform!(b, x, out), logabsdetjac(b, x, out), and so on. Thoughts?

> With Julia >= 1.3 it is also possible to define call methods on abstract types

Woah, really? If so, that would solve the problem, no?
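
[Editorial note] Purely to illustrate the idea: a minimal sketch of such a macro, under the assumption that the definition uses the simple signature form transform(b::MyBijector, x) with no where-clause. No @deftransform macro exists in Bijectors.jl; all of this is hypothetical.

# Hypothetical sketch, not Bijectors.jl API: take a `transform`
# definition and also emit the corresponding callable-struct method.
macro deftransform(ex)
    sig   = ex.args[1]    # the call signature, e.g. `transform(b::MyBijector, x)`
    barg  = sig.args[2]   # the first argument, `b::MyBijector`
    btype = barg.args[2]  # the bijector type, `MyBijector`
    return esc(quote
        $ex                            # define `transform` as written
        (b::$btype)(x) = transform(b, x)  # and make the struct callable
    end)
end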

Member replied:

> I'm unaware of what KernelFunctions.jl does. Are they simply not making the struct callable? If so: it was a deliberate design decision to go with callable structs when we started out. We were debating whether to do this or to include some transform method so we could define "one method to rule them all" on an abstract type. We ended up going with the "callable struct" approach, but it definitely has its issues, e.g. redundant code.

No, the KernelFunctions API is based on callable structs only (using a function was removed a while ago). But e.g. translation-invariant kernels are all built in the same way: usually they use some simple function to evaluate the distance between the inputs (e.g. using Distances) and then apply some nonlinear mapping afterwards. With this special structure all kinds of optimizations are possible when constructing kernel matrices etc. Hence there is a special type for such kernels (SimpleKernel IIRC), and users just define their metric and the nonlinear mapping and get everything else for free. There's more information here: https://juliagaussianprocesses.github.io/KernelFunctions.jl/dev/create_kernel/
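
[Editorial note] For concreteness, a rough sketch of how that pattern might look transplanted to bijectors. All names here (SimpleBijector, transform_impl, jacobian_impl) are hypothetical, not existing Bijectors.jl or KernelFunctions.jl API.

using Bijectors: Bijector
import Bijectors: logabsdetjac, forward

# Hypothetical pattern: a user implements only two methods
# for a SimpleBijector subtype...
abstract type SimpleBijector <: Bijector{0} end
function transform_impl end   # y = transform_impl(b, x)
function jacobian_impl end    # dy/dx = jacobian_impl(b, x)

# ...and inherits the rest for free (call methods on abstract
# types require Julia >= 1.3):
(b::SimpleBijector)(x::Real) = transform_impl(b, x)
(b::SimpleBijector)(x::AbstractVector{<:Real}) = map(b, x)
logabsdetjac(b::SimpleBijector, x::Real) = log(abs(jacobian_impl(b, x)))
forward(b::SimpleBijector, x::Real) = (rv = b(x), logabsdetjac = logabsdetjac(b, x))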

torfjelde (Member, Author) commented on Sep 11, 2020:
To add on to the above, because it sucks to have to implement (b::Bijector), logabsdetjac and then finally forward, would it be an idea to add a macro that would allow you to define all in one go?

EDIT: Just see #137 :)

torfjelde (Member, Author) replied:
> But e.g. translation-invariant kernels are all built in the same way: usually they use some simple function to evaluate the distance between the inputs (e.g. using Distances) and then apply some nonlinear mapping afterwards.

Ah, I see. I don't think it's as easy to do for bijectors, as we have much more structure than just "forward" evaluation, no?

return (rv=J .* x, logabsdetjac=log.(abs.(J)))
end

# (N=1) Multivariate case, with univariate parameter `α`
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat)
mask = x .< zero(eltype(x))
return mask .* b.α .* x .+ (1 .- mask) .* x
end

function (ib::Inverse{<:LeakyReLU, 1})(y::AbstractVecOrMat)
mask = y .< zero(eltype(y))
return mask .* (y ./ ib.orig.α) .+ (1 .- mask) .* y
end

function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
mask = x .< zero(eltype(x))
J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))

if x isa AbstractVector
return sum(log.(abs.(J)))
elseif x isa AbstractMatrix
return vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end
end

# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
mask = x .< zero(eltype(x))
J = mask .* b.α .+ (1 .- mask) .* one(eltype(x))

if x isa AbstractVector
logjac = sum(log.(abs.(J)))
elseif x isa AbstractMatrix
logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end

y = J .* x
return (rv=y, logabsdetjac=logjac)
end
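
[Editorial note] Not part of the diff: a quick usage sketch of the new bijector, assuming the API introduced above.

using Bijectors
using Bijectors: LeakyReLU

b = LeakyReLU(0.1)     # univariate bijector (dim = Val(0)) with α = 0.1
y = b(-2.0)            # -0.2: negative inputs are scaled by α
inv(b)(y)              # -2.0: the inverse divides negative inputs by α
logabsdetjac(b, -2.0)  # log(0.1), since the local slope at x < 0 is α
forward(b, -2.0)       # (rv = -0.2, logabsdetjac = log(0.1)) in one pass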
1 change: 1 addition & 0 deletions src/interface.jl
@@ -155,6 +155,7 @@ include("bijectors/truncated.jl")
# Normalizing flow related
include("bijectors/planar_layer.jl")
include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/normalise.jl")

##################
88 changes: 88 additions & 0 deletions test/bijectors/leaky_relu.jl
@@ -0,0 +1,88 @@
using Test

using Compat

using Bijectors
using Bijectors: LeakyReLU

using LinearAlgebra
using ForwardDiff
Member commented:
I think my only real comment here is on testing other AD backends like ReverseDiff, but I'm not sure how important that is here.

torfjelde (Member, Author) replied:
There are some tests for it in test/bijectors/interface.jl. But yeah, testing in Bijectors is honestly a bit of a mess atm. In a couple of the other PRs I've added some functionality which makes it easier to use a "standardized" testing suite for a new Bijector, so the plan is to use that in the future 👍


true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x))
true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x))
true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1]
true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs))

@testset "0-dim parameter, 0-dim input" begin
b = LeakyReLU(0.1; dim=Val(0))
x = 1.
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(Float32(1.))) == Float64
@assert eltype(b(Float64(1.))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(0))
@assert eltype(b(Float32(1.))) == Float32
@assert eltype(b(Float64(1.))) == Float64

# logabsdetjac
@test logabsdetjac(b, x) == true_logabsdetjac(b, x)
@test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x)

# Batch
xs = randn(10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(xs)) == true_logabsdetjac(b, Float32.(xs))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac ≈ logabsdetjac(b, xs)
@test f.rv ≈ b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv ≈ b(Float32.(xs))
end

@testset "0-dim parameter, 1-dim input" begin
d = 2

b = LeakyReLU(0.1; dim=Val(1))
x = ones(d)
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Batch
xs = randn(d, 10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac ≈ logabsdetjac(b, xs)
@test f.rv ≈ b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv ≈ b(Float32.(xs))

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(ones(Float32, 2))) == Float64
@assert eltype(b(ones(Float64, 2))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(1))
@assert eltype(b(ones(Float32, 2))) == Float32
@assert eltype(b(ones(Float64, 2))) == Float64
end
8 changes: 5 additions & 3 deletions test/interface.jl
@@ -7,7 +7,7 @@ using Tracker
using DistributionsAD

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, LeakyReLU

Random.seed!(123)

@@ -159,7 +159,10 @@ end
(SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)),
(stack(Exp{0}(), Scale(2.0)), randn(2, 3)),
(Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]),
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1))
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)),
(LeakyReLU(0.1), randn(3)),
(LeakyReLU(Float32(0.1)), randn(3)),
(LeakyReLU(0.1; dim = Val(1)), randn(2, 3))
]

for (b, xs) in bs_xs
@@ -172,7 +175,6 @@ end
x = D == 0 ? xs[1] : xs[:, 1]

y = @inferred b(x)

ys = @inferred b(xs)

# Computations which do not have closed-form implementations are not necessarily
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -24,9 +24,11 @@ if GROUP == "All" || GROUP == "Interface"
include("transform.jl")
include("norm_flows.jl")
include("bijectors/permute.jl")
include("bijectors/leaky_relu.jl")
end

if !is_TRAVIS && (GROUP == "All" || GROUP == "AD")
include("ad/utils.jl")
include("ad/distributions.jl")
end