-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
LeakyReLU #81
Merged
Merged
LeakyReLU #81
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
1e27430
added LeakyReLU as a Bijector
torfjelde 180a527
added Compat.jl
torfjelde 0f5200a
forgot to use Compat.jl
torfjelde 7d36fde
Merge branch 'master' into tor/leaky-relu
torfjelde 533b86e
added some tests for LeakyReLU
torfjelde fd2e33b
using masks rather than ifelse for type-stability and simplicity
torfjelde 8b69cab
removed some redundant comments
torfjelde ba299ff
Update src/bijectors/leaky_relu.jl
torfjelde 9b4d0d6
removed unnecessary broadcasting and useless import
torfjelde dbeb414
Merge branch 'tor/leaky-relu' of https://github.com/TuringLang/Biject…
torfjelde 7771155
fixed a typo
torfjelde 6e9d9f7
Update src/bijectors/leaky_relu.jl
torfjelde 18e30f9
Update src/bijectors/leaky_relu.jl
torfjelde c5b058b
Update src/bijectors/leaky_relu.jl
torfjelde c72fd11
Update src/bijectors/leaky_relu.jl
torfjelde 68d3518
Update src/bijectors/leaky_relu.jl
torfjelde c14a5db
Update src/bijectors/leaky_relu.jl
torfjelde 4a21b73
Update src/bijectors/leaky_relu.jl
torfjelde b4119fd
Update src/bijectors/leaky_relu.jl
torfjelde 1b97a0f
Update src/bijectors/leaky_relu.jl
torfjelde 1ebd9ba
Apply suggestions from code review
torfjelde 07cc632
Update src/bijectors/leaky_relu.jl
torfjelde f2f167e
Apply suggestions from code review
torfjelde 003dfb6
Apply suggestions from code review
torfjelde e17d5d7
Update src/bijectors/leaky_relu.jl
torfjelde a83cb7b
Update src/bijectors/leaky_relu.jl
torfjelde 28158ae
Apply suggestions from code review
torfjelde 6077578
Apply suggestions from code review
torfjelde 410166b
Merge branch 'master' into tor/leaky-relu
torfjelde File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
""" | ||
LeakyReLU{T, N}(α::T) <: Bijector{N} | ||
|
||
Defines the invertible mapping | ||
|
||
x ↦ x if x ≥ 0 else αx | ||
|
||
where α > 0. | ||
""" | ||
struct LeakyReLU{T, N} <: Bijector{N} | ||
α::T | ||
end | ||
|
||
LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) | ||
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) | ||
|
||
up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α) | ||
|
||
# (N=0) Univariate case | ||
function (b::LeakyReLU{<:Any, 0})(x::Real) | ||
mask = x < zero(x) | ||
return mask * b.α * x + !mask * x | ||
end | ||
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) | ||
|
||
function Base.inv(b::LeakyReLU{<:Any,N}) where N | ||
invα = inv.(b.α) | ||
return LeakyReLU{typeof(invα),N}(invα) | ||
end | ||
|
||
function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real) | ||
mask = x < zero(x) | ||
J = mask * b.α + (1 - mask) * one(x) | ||
return log(abs(J)) | ||
end | ||
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) | ||
|
||
|
||
# We implement `forward` by hand since we can re-use the computation of | ||
# the Jacobian of the transformation. This will lead to faster sampling | ||
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. | ||
function forward(b::LeakyReLU{<:Any, 0}, x::Real) | ||
mask = x < zero(x) | ||
J = mask * b.α + !mask * one(x) | ||
return (rv=J * x, logabsdetjac=log(abs(J))) | ||
end | ||
|
||
# Batched version | ||
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) | ||
J = let T = eltype(x), z = zero(T), o = one(T) | ||
@. (x < z) * b.α + (x > z) * o | ||
end | ||
return (rv=J .* x, logabsdetjac=log.(abs.(J))) | ||
end | ||
|
||
# (N=1) Multivariate case | ||
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat) | ||
return let z = zero(eltype(x)) | ||
@. (x < z) * b.α * x + (x > z) * x | ||
end | ||
end | ||
|
||
function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) | ||
# Is really diagonal of jacobian | ||
J = let T = eltype(x), z = zero(T), o = one(T) | ||
@. (x < z) * b.α + (x > z) * o | ||
end | ||
|
||
if x isa AbstractVector | ||
return sum(log.(abs.(J))) | ||
elseif x isa AbstractMatrix | ||
return vec(sum(log.(abs.(J)); dims = 1)) # sum along column | ||
end | ||
end | ||
|
||
# We implement `forward` by hand since we can re-use the computation of | ||
# the Jacobian of the transformation. This will lead to faster sampling | ||
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. | ||
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) | ||
# Is really diagonal of jacobian | ||
J = let T = eltype(x), z = zero(T), o = one(T) | ||
@. (x < z) * b.α + (x > z) * o | ||
end | ||
|
||
if x isa AbstractVector | ||
logjac = sum(log.(abs.(J))) | ||
elseif x isa AbstractMatrix | ||
logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column | ||
end | ||
|
||
y = J .* x | ||
return (rv=y, logabsdetjac=logjac) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
using Test | ||
|
||
using Bijectors | ||
using Bijectors: LeakyReLU | ||
|
||
using LinearAlgebra | ||
using ForwardDiff | ||
|
||
true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x)) | ||
true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x)) | ||
true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1] | ||
true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs)) | ||
|
||
@testset "0-dim parameter, 0-dim input" begin | ||
b = LeakyReLU(0.1; dim=Val(0)) | ||
x = 1. | ||
@test inv(b)(b(x)) == x | ||
@test inv(b)(b(-x)) == -x | ||
|
||
# Mixing of types | ||
# 1. Changes in input-type | ||
@assert eltype(b(Float32(1.))) == Float64 | ||
@assert eltype(b(Float64(1.))) == Float64 | ||
|
||
# 2. Changes in parameter-type | ||
b = LeakyReLU(Float32(0.1); dim=Val(0)) | ||
@assert eltype(b(Float32(1.))) == Float32 | ||
@assert eltype(b(Float64(1.))) == Float64 | ||
|
||
# logabsdetjac | ||
@test logabsdetjac(b, x) == true_logabsdetjac(b, x) | ||
@test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x) | ||
|
||
# Batch | ||
xs = randn(10) | ||
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) | ||
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) | ||
|
||
@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) | ||
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) | ||
|
||
# Forward | ||
f = forward(b, xs) | ||
@test f.logabsdetjac ≈ logabsdetjac(b, xs) | ||
@test f.rv ≈ b(xs) | ||
|
||
f = forward(b, Float32.(xs)) | ||
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) | ||
@test f.rv ≈ b(Float32.(xs)) | ||
end | ||
|
||
@testset "0-dim parameter, 1-dim input" begin | ||
d = 2 | ||
|
||
b = LeakyReLU(0.1; dim=Val(1)) | ||
x = ones(d) | ||
@test inv(b)(b(x)) == x | ||
@test inv(b)(b(-x)) == -x | ||
|
||
# Batch | ||
xs = randn(d, 10) | ||
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) | ||
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) | ||
|
||
@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) | ||
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) | ||
|
||
# Forward | ||
f = forward(b, xs) | ||
@test f.logabsdetjac ≈ logabsdetjac(b, xs) | ||
@test f.rv ≈ b(xs) | ||
|
||
f = forward(b, Float32.(xs)) | ||
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) | ||
@test f.rv ≈ b(Float32.(xs)) | ||
|
||
# Mixing of types | ||
# 1. Changes in input-type | ||
@assert eltype(b(ones(Float32, 2))) == Float64 | ||
@assert eltype(b(ones(Float64, 2))) == Float64 | ||
|
||
# 2. Changes in parameter-type | ||
b = LeakyReLU(Float32(0.1); dim=Val(1)) | ||
@assert eltype(b(ones(Float32, 2))) == Float32 | ||
@assert eltype(b(ones(Float64, 2))) == Float64 | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think my only real comment here is on testing other AD backends like ReverseDiff, but I'm not sure how important that is here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some tests for it in
test/bijectors/interface.jl
. But yeah, testing in Bijectors is honestly a bit of mess atm. In a couple of the other PRs I've added some functionality which makes it easier to use a "standardized" testing suite for a newBijector
, so the plan is to use that in the future 👍