Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 12, 2020
1 parent 28158ae commit 6077578
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,22 @@ end

# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
J = let z = zero(x), o = one(x)
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end
return (rv=J .* x, logabsdetjac=log.(abs.(J)))
end

# (N=1) Multivariate case
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat)
return let z = zero(x)
return let z = zero(eltype(x))
@. (x < z) * b.α * x + (x > z) * x
end
end

function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let z = zero(x), o = one(x)
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end

Expand All @@ -78,7 +78,7 @@ end
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let z = zero(x), o = one(x)
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end

Expand Down

0 comments on commit 6077578

Please sign in to comment.