Skip to content

Commit

Permalink
Add FlatRule and WSquareRule
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Jul 1, 2022
1 parent e9f7d35 commit a6e2c59
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/ExplainableAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export LRP
# LRP rules
export AbstractLRPRule
export LRP_CONFIG
export ZeroRule, EpsilonRule, GammaRule, ZBoxRule, PassRule
export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule, ZBoxRule, PassRule
export modify_input, modify_denominator
export modify_param!, modify_layer!
export check_model
Expand Down
25 changes: 24 additions & 1 deletion src/lrp_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,29 @@ function modify_param!(r::GammaRule, param::AbstractArray{T}) where {T}
return nothing
end

"""
WSquareRule()
LRP-``W^2`` rule. Commonly used on the first layer when values are unbounded.
# References
[1]: G. Montavon et al., Explaining nonlinear classification decisions with deep Taylor decomposition
"""
struct WSquareRule <: AbstractLRPRule end
modify_param!(::WSquareRule, p) = p .^= 2
modify_input(::WSquareRule, input) = ones_like(input)

"""
FlatRule()
LRP-Flat rule. Similar to the [`WSquareRule`](@ref), but with all parameters set to one.
# References
[1]: S. Lapuschkin et al., Unmasking Clever Hans predictors and assessing what machines really learn
"""
struct FlatRule <: AbstractLRPRule end
modify_param!(::FlatRule, p) = fill!(p, 0)
modify_input(::FlatRule, input) = ones_like(input)

"""
PassRule()
Expand Down Expand Up @@ -238,7 +261,7 @@ for R in (ZeroRule, EpsilonRule)
end

# Fast implementation for Dense layer using Tullio.jl's einsum notation:
for R in (ZeroRule, EpsilonRule, GammaRule)
for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
@eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
reset! = get_layer_resetter(rule, layer)
modify_layer!(rule, layer)
Expand Down
23 changes: 23 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,29 @@ CartesianIndex(5, 3)
"""
drop_batch_index(C::CartesianIndex) = CartesianIndex(C.I[1:(end - 1)])

"""
ones_like(x)
Returns array of ones of same shape and type as `x`.
## Example
```julia-repl
julia> x = rand(Float16, 2, 4, 1)
2×4×1 Array{Float16, 3}:
[:, :, 1] =
0.2148 0.9053 0.751 0.358
0.38 0.09033 0.04053 0.6543
julia> ones_like(x)
2×4×1 Array{Float16, 3}:
[:, :, 1] =
1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
```
"""
ones_like(x::AbstractArray) = ones(eltype(x), size(x))
ones_like(x::Number) = oneunit(x)

# Utils for printing model check summary using PrettyTable.jl
_print_name(layer) = "$layer"
_print_name(layer::Parallel) = "Parallel(...)"
Expand Down

0 comments on commit a6e2c59

Please sign in to comment.