Skip to content

Commit

Permalink
Merge 410166b into 84314e7
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 22, 2020
2 parents 84314e7 + 410166b commit 73d7880
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 3 deletions.
93 changes: 93 additions & 0 deletions src/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
LeakyReLU{T, N}(α::T) <: Bijector{N}
Defines the invertible mapping
x ↦ x if x ≥ 0 else αx
where α > 0.
"""
struct LeakyReLU{T, N} <: Bijector{N}
α::T
end

LeakyReLU::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α)

# (N=0) Univariate case
function (b::LeakyReLU{<:Any, 0})(x::Real)
mask = x < zero(x)
return mask * b.α * x + !mask * x
end
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x)

function Base.inv(b::LeakyReLU{<:Any,N}) where N
invα = inv.(b.α)
return LeakyReLU{typeof(invα),N}(invα)
end

function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + (1 - mask) * one(x)
return log(abs(J))
end
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x)


# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + !mask * one(x)
return (rv=J * x, logabsdetjac=log(abs(J)))
end

# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end
return (rv=J .* x, logabsdetjac=log.(abs.(J)))
end

# (N=1) Multivariate case
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat)
return let z = zero(eltype(x))
@. (x < z) * b.α * x + (x > z) * x
end
end

function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end

if x isa AbstractVector
return sum(log.(abs.(J)))
elseif x isa AbstractMatrix
return vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end
end

# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end

if x isa AbstractVector
logjac = sum(log.(abs.(J)))
elseif x isa AbstractMatrix
logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end

y = J .* x
return (rv=y, logabsdetjac=logjac)
end
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ include("bijectors/truncated.jl")
# Normalizing flow related
include("bijectors/planar_layer.jl")
include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
include("bijectors/normalise.jl")

Expand Down
86 changes: 86 additions & 0 deletions test/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Test

using Bijectors
using Bijectors: LeakyReLU

using LinearAlgebra
using ForwardDiff

true_logabsdetjac(b::Bijector{0}, x::Real) = (log abs)(ForwardDiff.derivative(b, x))
true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log abs).(ForwardDiff.derivative.(b, x))
true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1]
true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs))

@testset "0-dim parameter, 0-dim input" begin
b = LeakyReLU(0.1; dim=Val(0))
x = 1.
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(Float32(1.))) == Float64
@assert eltype(b(Float64(1.))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(0))
@assert eltype(b(Float32(1.))) == Float32
@assert eltype(b(Float64(1.))) == Float64

# logabsdetjac
@test logabsdetjac(b, x) == true_logabsdetjac(b, x)
@test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x)

# Batch
xs = randn(10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac logabsdetjac(b, xs)
@test f.rv b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv b(Float32.(xs))
end

@testset "0-dim parameter, 1-dim input" begin
d = 2

b = LeakyReLU(0.1; dim=Val(1))
x = ones(d)
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Batch
xs = randn(d, 10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac logabsdetjac(b, xs)
@test f.rv b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv b(Float32.(xs))

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(ones(Float32, 2))) == Float64
@assert eltype(b(ones(Float64, 2))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(1))
@assert eltype(b(ones(Float32, 2))) == Float32
@assert eltype(b(ones(Float64, 2))) == Float64
end
8 changes: 5 additions & 3 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Tracker
using DistributionsAD

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, LeakyReLU

Random.seed!(123)

Expand Down Expand Up @@ -159,7 +159,10 @@ end
(SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)),
(stack(Exp{0}(), Scale(2.0)), randn(2, 3)),
(Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]),
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1))
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)),
(LeakyReLU(0.1), randn(3)),
(LeakyReLU(Float32(0.1)), randn(3)),
(LeakyReLU(0.1; dim = Val(1)), randn(2, 3))
]

for (b, xs) in bs_xs
Expand All @@ -172,7 +175,6 @@ end
x = D == 0 ? xs[1] : xs[:, 1]

y = @inferred b(x)

ys = @inferred b(xs)

# Computations which do not have closed-form implementations are not necessarily
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ if GROUP == "All" || GROUP == "Interface"
include("transform.jl")
include("norm_flows.jl")
include("bijectors/permute.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/coupling.jl")
end

if !is_TRAVIS && (GROUP == "All" || GROUP == "AD")
include("ad/distributions.jl")
end

0 comments on commit 73d7880

Please sign in to comment.