Skip to content

Commit

Permalink
Merge pull request #17 from JuliaReach/schillic/backward
Browse files Browse the repository at this point in the history
Define backward for singletons and invertible activations; fix: ReLU is not invertible
  • Loading branch information
schillic committed Feb 28, 2024
2 parents c15a58e + d81475b commit 4793aa2
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/BackwardAlgorithms/BoxBackward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,10 @@ for T in (Sigmoid, LeakyReLU)
h = _inverse(high(Y), act)
return Hyperrectangle(; low=l, high=h)
end

# disambiguation
function backward(Y::Singleton, act::$T, algo::BoxBackward)
return Singleton(backward(element(Y), act, algo))
end
end
end
12 changes: 12 additions & 0 deletions src/BackwardAlgorithms/PolyhedraBackward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ end

# apply inverse ReLU activation function
function backward(Y::LazySet, act::ReLU, ::PolyhedraBackward)
return _backward_PolyhedraBackward(Y, act)
end

function _backward_PolyhedraBackward(Y::LazySet, act::ReLU)
n = dim(Y)
if n == 1
X = _backward_1D(Y, act)
Expand Down Expand Up @@ -347,6 +351,14 @@ end
# disambiguation
for T in (:ReLU, :LeakyReLU)
@eval begin
function backward(Y::Singleton, act::$T, algo::PolyhedraBackward)
if all(>(0), element(Y))
return Singleton(backward(element(Y), act, algo))
else
return _backward_PolyhedraBackward(Y, act)
end
end

function backward(Y::UnionSetArray, act::$T, algo::PolyhedraBackward)
return _backward_union(Y, act, algo)
end
Expand Down
13 changes: 10 additions & 3 deletions src/BackwardAlgorithms/backward_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,13 @@ append_sets!(Xs, X::LazySet) = push!(Xs, X)
append_sets!(Xs, X::UnionSetArray) = append!(Xs, array(X))

# apply inverse piecewise-affine activation function to a union of sets
# COV_EXCL_START
for T in (:ReLU, :LeakyReLU)
@eval begin
function backward(Y::UnionSetArray, act::$T, algo::BackwardAlgorithm)
return _backward_union(Y, act, algo)
end
end
end
# COV_EXCL_STOP

function _backward_union(Y::LazySet{N}, act::ActivationFunction,
algo::BackwardAlgorithm) where {N}
Expand All @@ -97,10 +95,19 @@ function backward(y::AbstractVector, act::ActivationFunction, ::BackwardAlgorith
end

_inverse(x::AbstractVector, act::ActivationFunction) = [_inverse(xi, act) for xi in x]
_inverse(x::Number, ::ReLU) = x >= zero(x) ? x : zero(x)
_inverse(x::Number, ::ReLU) = x > zero(x) ? x : throw(ArgumentError("ReLU cannot be inverted"))
_inverse(x::Number, ::Sigmoid) = @. -log(1 / x - 1)
_inverse(x::Number, act::LeakyReLU) = x >= zero(x) ? x : x / act.slope

# invertible activations defined for numbers can be defined for singletons
for T in (:Sigmoid, :LeakyReLU)
@eval begin
function backward(Y::Singleton, act::$T, algo::BackwardAlgorithm)
return Singleton(backward(element(Y), act, algo))
end
end
end

# activation functions must be explicitly supported for sets
function backward(X::LazySet, act::ActivationFunction, algo::BackwardAlgorithm)
throw(ArgumentError("activation function $act not supported by algorithm " *
Expand Down
15 changes: 13 additions & 2 deletions test/BackwardAlgorithms/backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ end
## union is too complex -> only perform partial tests
@test X Y && low(X) == [-Inf, 1.0] && high(X) == [Inf, Inf]
# union
Y = UnionSetArray([LineSegment([1.0, 1.0], [2.0, 2.0]), Singleton([0.0, 0.0])])
@test backward(Y, ReLU(), algo) == UnionSetArray([Y[1], Pneg])
Y = UnionSetArray([LineSegment([1.0, 1.0], [2.0, 2.0]), Singleton([1.0, 1.0])])
@test backward(Y, ReLU(), algo) == UnionSetArray([Y[1], Singleton([1.0, 1.0])])

# 3D
# positive point
Expand Down Expand Up @@ -363,6 +363,17 @@ end
for algo in (BoxBackward(),)
@test isequivalent(backward(Y, lr, algo), X)
end

# default algorithm for union
for algo in (DummyBackward(),)
y1 = Singleton([2.0])
y2 = Singleton([3.0])
x1 = backward(y1, lr, algo)
x2 = backward(y2, lr, algo)
Y2 = UnionSetArray([y1, y2])
X2 = backward(Y2, lr, algo)
@test X2 == UnionSetArray([x1, x2])
end
end

@testset "Backward layer" begin
Expand Down

0 comments on commit 4793aa2

Please sign in to comment.