-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
187 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
""" | ||
LeakyReLU{T, N}(α::T) <: Bijector{N} | ||
Defines the invertible mapping | ||
x ↦ x if x ≥ 0 else αx | ||
where α > 0. | ||
""" | ||
struct LeakyReLU{T, N} <: Bijector{N} | ||
α::T | ||
end | ||
|
||
LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) | ||
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) | ||
|
||
up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α) | ||
|
||
# (N=0) Univariate case | ||
function (b::LeakyReLU{<:Any, 0})(x::Real) | ||
mask = x < zero(x) | ||
return mask * b.α * x + !mask * x | ||
end | ||
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) | ||
|
||
function Base.inv(b::LeakyReLU{<:Any,N}) where N | ||
invα = inv.(b.α) | ||
return LeakyReLU{typeof(invα),N}(invα) | ||
end | ||
|
||
function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real) | ||
mask = x < zero(x) | ||
J = mask * b.α + (1 - mask) * one(x) | ||
return log(abs(J)) | ||
end | ||
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) | ||
|
||
|
||
# We implement `forward` by hand since we can re-use the computation of | ||
# the Jacobian of the transformation. This will lead to faster sampling | ||
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. | ||
function forward(b::LeakyReLU{<:Any, 0}, x::Real) | ||
mask = x < zero(x) | ||
J = mask * b.α + !mask * one(x) | ||
return (rv=J * x, logabsdetjac=log(abs(J))) | ||
end | ||
|
||
# Batched version | ||
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) | ||
J = let T = eltype(x), z = zero(T), o = one(T) | ||
@. (x < z) * b.α + (x > z) * o | ||
end | ||
return (rv=J .* x, logabsdetjac=log.(abs.(J))) | ||
end | ||
|
||
# (N=1) Multivariate case | ||
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat) | ||
return let z = zero(eltype(x)) | ||
@. (x < z) * b.α * x + (x > z) * x | ||
end | ||
end | ||
|
||
function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) | ||
# Is really diagonal of jacobian | ||
J = let T = eltype(x), z = zero(T), o = one(T) | ||
@. (x < z) * b.α + (x > z) * o | ||
end | ||
|
||
if x isa AbstractVector | ||
return sum(log.(abs.(J))) | ||
elseif x isa AbstractMatrix | ||
return vec(sum(log.(abs.(J)); dims = 1)) # sum along column | ||
end | ||
end | ||
|
||
# We implement `forward` by hand since we can re-use the computation of | ||
# the Jacobian of the transformation. This will lead to faster sampling | ||
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. | ||
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) | ||
# Is really diagonal of jacobian | ||
J = let T = eltype(x), z = zero(T), o = one(T) | ||
@. (x < z) * b.α + (x > z) * o | ||
end | ||
|
||
if x isa AbstractVector | ||
logjac = sum(log.(abs.(J))) | ||
elseif x isa AbstractMatrix | ||
logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column | ||
end | ||
|
||
y = J .* x | ||
return (rv=y, logabsdetjac=logjac) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
using Test | ||
|
||
using Bijectors | ||
using Bijectors: LeakyReLU | ||
|
||
using LinearAlgebra | ||
using ForwardDiff | ||
|
||
true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x)) | ||
true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x)) | ||
true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1] | ||
true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs)) | ||
|
||
@testset "0-dim parameter, 0-dim input" begin | ||
b = LeakyReLU(0.1; dim=Val(0)) | ||
x = 1. | ||
@test inv(b)(b(x)) == x | ||
@test inv(b)(b(-x)) == -x | ||
|
||
# Mixing of types | ||
# 1. Changes in input-type | ||
@assert eltype(b(Float32(1.))) == Float64 | ||
@assert eltype(b(Float64(1.))) == Float64 | ||
|
||
# 2. Changes in parameter-type | ||
b = LeakyReLU(Float32(0.1); dim=Val(0)) | ||
@assert eltype(b(Float32(1.))) == Float32 | ||
@assert eltype(b(Float64(1.))) == Float64 | ||
|
||
# logabsdetjac | ||
@test logabsdetjac(b, x) == true_logabsdetjac(b, x) | ||
@test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x) | ||
|
||
# Batch | ||
xs = randn(10) | ||
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) | ||
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) | ||
|
||
@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) | ||
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) | ||
|
||
# Forward | ||
f = forward(b, xs) | ||
@test f.logabsdetjac ≈ logabsdetjac(b, xs) | ||
@test f.rv ≈ b(xs) | ||
|
||
f = forward(b, Float32.(xs)) | ||
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) | ||
@test f.rv ≈ b(Float32.(xs)) | ||
end | ||
|
||
@testset "0-dim parameter, 1-dim input" begin | ||
d = 2 | ||
|
||
b = LeakyReLU(0.1; dim=Val(1)) | ||
x = ones(d) | ||
@test inv(b)(b(x)) == x | ||
@test inv(b)(b(-x)) == -x | ||
|
||
# Batch | ||
xs = randn(d, 10) | ||
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) | ||
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x)) | ||
|
||
@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) | ||
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) | ||
|
||
# Forward | ||
f = forward(b, xs) | ||
@test f.logabsdetjac ≈ logabsdetjac(b, xs) | ||
@test f.rv ≈ b(xs) | ||
|
||
f = forward(b, Float32.(xs)) | ||
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) | ||
@test f.rv ≈ b(Float32.(xs)) | ||
|
||
# Mixing of types | ||
# 1. Changes in input-type | ||
@assert eltype(b(ones(Float32, 2))) == Float64 | ||
@assert eltype(b(ones(Float64, 2))) == Float64 | ||
|
||
# 2. Changes in parameter-type | ||
b = LeakyReLU(Float32(0.1); dim=Val(1)) | ||
@assert eltype(b(ones(Float32, 2))) == Float32 | ||
@assert eltype(b(ones(Float64, 2))) == Float64 | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters