LeakyReLU #81

Merged: 29 commits from tor/leaky-relu into master, Sep 22, 2020

Commits (changes shown are from 27 of the 29 commits)
1e27430
added LeakyReLU as a Bijector
torfjelde Feb 9, 2020
180a527
added Compat.jl
torfjelde Feb 11, 2020
0f5200a
forgot to use Compat.jl
torfjelde Feb 12, 2020
7d36fde
Merge branch 'master' into tor/leaky-relu
torfjelde Sep 10, 2020
533b86e
added some tests for LeakyReLU
torfjelde Sep 10, 2020
fd2e33b
using masks rather than ifelse for type-stability and simplicity
torfjelde Sep 10, 2020
8b69cab
removed some redundant comments
torfjelde Sep 10, 2020
ba299ff
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
9b4d0d6
removed unnecessary broadcasting and useless import
torfjelde Sep 11, 2020
dbeb414
Merge branch 'tor/leaky-relu' of https://github.com/TuringLang/Biject…
torfjelde Sep 11, 2020
7771155
fixed a typo
torfjelde Sep 11, 2020
6e9d9f7
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
18e30f9
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c5b058b
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c72fd11
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
68d3518
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c14a5db
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
4a21b73
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
b4119fd
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
1b97a0f
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
1ebd9ba
Apply suggestions from code review
torfjelde Sep 11, 2020
07cc632
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
f2f167e
Apply suggestions from code review
torfjelde Sep 12, 2020
003dfb6
Apply suggestions from code review
torfjelde Sep 12, 2020
e17d5d7
Update src/bijectors/leaky_relu.jl
torfjelde Sep 12, 2020
a83cb7b
Update src/bijectors/leaky_relu.jl
torfjelde Sep 12, 2020
28158ae
Apply suggestions from code review
torfjelde Sep 12, 2020
6077578
Apply suggestions from code review
torfjelde Sep 12, 2020
410166b
Merge branch 'master' into tor/leaky-relu
torfjelde Sep 22, 2020
93 changes: 93 additions & 0 deletions src/bijectors/leaky_relu.jl
@@ -0,0 +1,93 @@
"""
LeakyReLU{T, N}(α::T) <: Bijector{N}

Defines the invertible mapping

x ↦ x if x ≥ 0 else αx

where α > 0.
"""
struct LeakyReLU{T, N} <: Bijector{N}
α::T
end

LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α)

# (N=0) Univariate case
function (b::LeakyReLU{<:Any, 0})(x::Real)
mask = x < zero(x)
return mask * b.α * x + !mask * x
end
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x)

function Base.inv(b::LeakyReLU{<:Any,N}) where N
invα = inv.(b.α)
return LeakyReLU{typeof(invα),N}(invα)
end

function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + (1 - mask) * one(x)
return log(abs(J))
end
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x)


# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + !mask * one(x)
return (rv=J * x, logabsdetjac=log(abs(J)))
end

# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
J = let z = zero(x), o = one(x)
torfjelde marked this conversation as resolved.

Member:

IMO that's not a good fix since it destroys the whole point of just having one broadcast expression that is fused together. Now we would allocate two additional vectors of ones and zeros just for Tracker.

Did you check if we can work around the Tracker bug by using separate functions ispositive(x) = x > zero(x) and isnegative(x) = x < zero(x) that we can include in the broadcast expression, instead of the explicit x < zero(x) etc. statements?

Member Author:

> Now we would allocate two additional vectors of ones and zeros just for Tracker.

Not two vectors; just two numbers, right? (I realize the above impl was wrong; it was supposed to contain eltype.)

Member Author:

Also, even if we use isnegative, we're running into the same issue because we're broadcasting over one(x), so we would have to also have this in a separate function, e.g. mul1(x, y) = x * one(y) or something 😕

Member:

With eltype it's just a number, but IMO it's still unfortunate, since there would be no need for this type-based computation if Tracker were fixed (I opened a PR: FluxML/Tracker.jl#85). In general it's better not to use types if not needed, since instances carry more information (a similar argument to why generated functions should only be used when needed). It should just work with this single broadcast expression here.

Apart from that, are the let blocks actually needed?

Member:

BTW it seems the tests still fail?

Member Author:

It still breaks, but now the only one failing is LeakyReLU{..., 1}. Again it hits a snag at materialize:

Variables
  b::LeakyReLU{Float64,1}
  x::TrackedArray{…,Array{Float64,2}}
  z::Tracker.TrackedReal{Float64}

Body::Any
1 ─ %1 = Bijectors.eltype(x)::Core.Compiler.Const(Tracker.TrackedReal{Float64}, false)
│        (z = Bijectors.zero(%1))
│   %3 = Base.broadcasted(Bijectors.:<, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}}
│   %4 = Base.getproperty(b, :α)::Float64
│   %5 = Base.broadcasted(Bijectors.:*, %3, %4)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},Float64}}
│   %6 = Base.broadcasted(Bijectors.:>, x, z)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}}
│   %7 = Base.broadcasted(Bijectors.:*, %6, x)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},TrackedArray{…,Array{Float64,2}}}}
│   %8 = Base.broadcasted(Bijectors.:+, %5, %7)::Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(<),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},Float64}},Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Tracker.TrackedStyle,Nothing,typeof(>),Tuple{TrackedArray{…,Array{Float64,2}},Tracker.TrackedReal{Float64}}},TrackedArray{…,Array{Float64,2}}}}}}
│   %9 = Base.materialize(%8)::Any
└──      return %9

Member Author:

> IMO it's still unfortunate since there would be no need for this type-based computation if Tracker were fixed (I opened a PR: FluxML/Tracker.jl#85)

100% agree! And nice!

> Apart from that, are the let blocks actually needed?

Nah, I just added them to make it explicit that it's only needed for this particular statement, and not to clutter the names in the full function. Was nicer IMO, but of course subjective.
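
For reference, a minimal sketch of the helper-function workaround suggested above (the helper names and the final function are illustrative only, not the code this PR ends up with):

# Named predicates that can sit inside a single fused broadcast expression,
# instead of writing x < zero(x) / x > zero(x) inline.
isnegative(x) = x < zero(x)
ispositive(x) = x > zero(x)

# Diagonal of the Jacobian of the leaky ReLU, computed with one broadcast and
# no explicit vectors of zeros/ones (illustrative sketch).
leakyrelu_jacdiag(α, x::AbstractVector) = isnegative.(x) .* α .+ ispositive.(x)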

@. (x < z) * b.α + (x > z) * o
end
return (rv=J .* x, logabsdetjac=log.(abs.(J)))
end

# (N=1) Multivariate case
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat)
return let z = zero(x)
@. (x < z) * b.α * x + (x > z) * x
end
end

function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let z = zero(x), o = one(x)
@. (x < z) * b.α + (x > z) * o
end

if x isa AbstractVector
return sum(log.(abs.(J)))
elseif x isa AbstractMatrix
return vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end
end

# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let z = zero(x), o = one(x)
@. (x < z) * b.α + (x > z) * o
end

if x isa AbstractVector
logjac = sum(log.(abs.(J)))
elseif x isa AbstractMatrix
logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end

y = J .* x
return (rv=y, logabsdetjac=logjac)
end
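
For readers skimming the diff, a minimal usage sketch of the new bijector, based on the definitions above and the existing Bijectors.jl interface (inv, logabsdetjac, forward); the values in the comments assume α = 0.1:

using Bijectors
using Bijectors: LeakyReLU

b = LeakyReLU(0.1)            # univariate bijector (dim = Val(0)) with α = 0.1
x = -2.0

y = b(x)                      # -0.2: negative inputs are scaled by α
inv(b)(y)                     # -2.0: the inverse simply uses 1/α
logabsdetjac(b, x)            # log(abs(0.1)) for a negative input

# forward returns the transformed value and log-abs-det-Jacobian in one pass;
# this is what makes rand on a TransformedDistribution cheaper.
forward(b, x)                 # (rv = -0.2, logabsdetjac = log(0.1))

# Elementwise over a whole vector treated as one multivariate input:
b1 = LeakyReLU(0.1; dim = Val(1))
xs = randn(3)
forward(b1, xs).logabsdetjac ≈ logabsdetjac(b1, xs)   # true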
1 change: 1 addition & 0 deletions src/interface.jl
@@ -155,6 +155,7 @@ include("bijectors/truncated.jl")
# Normalizing flow related
include("bijectors/planar_layer.jl")
include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/normalise.jl")

##################
86 changes: 86 additions & 0 deletions test/bijectors/leaky_relu.jl
@@ -0,0 +1,86 @@
using Test

using Bijectors
using Bijectors: LeakyReLU

using LinearAlgebra
using ForwardDiff

Member:

I think my only real comment here is on testing other AD backends like ReverseDiff, but I'm not sure how important that is here.

Member Author:

There are some tests for it in test/bijectors/interface.jl. But yeah, testing in Bijectors is honestly a bit of a mess atm. In a couple of the other PRs I've added some functionality which makes it easier to use a "standardized" testing suite for a new Bijector, so the plan is to use that in the future 👍
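
As a rough illustration of what such a standardized check could look like (a hypothetical helper, not part of this PR or of the suite mentioned above):

using Test
using ForwardDiff
using Bijectors
using Bijectors: LeakyReLU

# Hypothetical helper: round-trip through the inverse and compare logabsdetjac
# against a ForwardDiff reference, for a univariate (0-dim) bijector.
function test_bijector_basics(b, x::Real)
    @test inv(b)(b(x)) ≈ x
    @test logabsdetjac(b, x) ≈ log(abs(ForwardDiff.derivative(b, x)))
end

test_bijector_basics(LeakyReLU(0.1), -2.0)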


true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x))
true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x))
true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1]
true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs))

@testset "0-dim parameter, 0-dim input" begin
b = LeakyReLU(0.1; dim=Val(0))
x = 1.
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(Float32(1.))) == Float64
@assert eltype(b(Float64(1.))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(0))
@assert eltype(b(Float32(1.))) == Float32
@assert eltype(b(Float64(1.))) == Float64

# logabsdetjac
@test logabsdetjac(b, x) == true_logabsdetjac(b, x)
@test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x)

# Batch
xs = randn(10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac ≈ logabsdetjac(b, xs)
@test f.rv ≈ b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv ≈ b(Float32.(xs))
end

@testset "0-dim parameter, 1-dim input" begin
d = 2

b = LeakyReLU(0.1; dim=Val(1))
x = ones(d)
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Batch
xs = randn(d, 10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac ≈ logabsdetjac(b, xs)
@test f.rv ≈ b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv ≈ b(Float32.(xs))

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(ones(Float32, 2))) == Float64
@assert eltype(b(ones(Float64, 2))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(1))
@assert eltype(b(ones(Float32, 2))) == Float32
@assert eltype(b(ones(Float64, 2))) == Float64
end
8 changes: 5 additions & 3 deletions test/interface.jl
@@ -7,7 +7,7 @@ using Tracker
using DistributionsAD

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, LeakyReLU

Random.seed!(123)

@@ -159,7 +159,10 @@ end
(SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)),
(stack(Exp{0}(), Scale(2.0)), randn(2, 3)),
(Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]),
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1))
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)),
(LeakyReLU(0.1), randn(3)),
(LeakyReLU(Float32(0.1)), randn(3)),
(LeakyReLU(0.1; dim = Val(1)), randn(2, 3))
]

for (b, xs) in bs_xs
@@ -172,7 +175,6 @@ end
x = D == 0 ? xs[1] : xs[:, 1]

y = @inferred b(x)

ys = @inferred b(xs)

# Computations which do not have closed-form implementations are not necessarily
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -24,9 +24,11 @@ if GROUP == "All" || GROUP == "Interface"
include("transform.jl")
include("norm_flows.jl")
include("bijectors/permute.jl")
include("bijectors/leaky_relu.jl")
end

if !is_TRAVIS && (GROUP == "All" || GROUP == "AD")
include("ad/utils.jl")
include("ad/distributions.jl")
end