Skip to content

Commit

Permalink
fix: ReLU is not invertible
Browse files Browse the repository at this point in the history
  • Loading branch information
schillic committed Feb 28, 2024
1 parent e467be1 commit d81475b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/BackwardAlgorithms/backward_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ function backward(y::AbstractVector, act::ActivationFunction, ::BackwardAlgorith
end

_inverse(x::AbstractVector, act::ActivationFunction) = [_inverse(xi, act) for xi in x]
_inverse(x::Number, ::ReLU) = x >= zero(x) ? x : zero(x)
_inverse(x::Number, ::ReLU) = x > zero(x) ? x : throw(ArgumentError("ReLU cannot be inverted"))
_inverse(x::Number, ::Sigmoid) = @. -log(1 / x - 1)
_inverse(x::Number, act::LeakyReLU) = x >= zero(x) ? x : x / act.slope

Expand Down
4 changes: 2 additions & 2 deletions test/BackwardAlgorithms/backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ end
## union is too complex -> only perform partial tests
@test X Y && low(X) == [-Inf, 1.0] && high(X) == [Inf, Inf]
# union
Y = UnionSetArray([LineSegment([1.0, 1.0], [2.0, 2.0]), Singleton([0.0, 0.0])])
@test backward(Y, ReLU(), algo) == UnionSetArray([Y[1], Pneg])
Y = UnionSetArray([LineSegment([1.0, 1.0], [2.0, 2.0]), Singleton([1.0, 1.0])])
@test backward(Y, ReLU(), algo) == UnionSetArray([Y[1], Singleton([1.0, 1.0])])

# 3D
# positive point
Expand Down

0 comments on commit d81475b

Please sign in to comment.