Simplify LRP analyzer and clean-up rule default parameters (#110)
* Simplify LRP analyzer call

* Make color of composite types more visible

* Move LRP rule default params to top of file

* Links to dev docs of "How it works internally"

* Fixes to documentation
adrhill committed Nov 21, 2022
1 parent d65686e commit 0b695aa
Showing 7 changed files with 72 additions and 67 deletions.
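For orientation, the user-facing API is unchanged by this commit; only the analyzer's internals are simplified. A minimal usage sketch (illustrative only, not part of the diff; the toy model and input shape below are made up):

using ExplainableAI, Flux

# Hypothetical two-layer model and a single-sample batch, for illustration only.
model = Chain(Dense(10, 5, relu), Dense(5, 2))
input = rand(Float32, 10, 1)

analyzer = LRP(model)            # assigns the ZeroRule to every layer (see src/lrp/lrp.jl below)
expl = analyze(input, analyzer)  # returns an Explanation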
docs/src/literate/advanced_lrp.jl (6 changes: 2 additions & 4 deletions)
@@ -274,10 +274,8 @@ analyzer = LRP(model)
 # The best point of entry into the source code is
 # [`/src/lrp/rules.jl`](https://github.com/adrhill/ExplainableAI.jl/blob/master/src/lrp/rules.jl).
 #
-# Internally, ExplainableAI pre-allocates modified layers by dispatching `modify_layer`
-# on rule and layer types. This constructs the `state` of a LRP analyzer.
-#
-# Calling `analyze` on a LRP-model then applies a forward-pass of the model,
+# Calling `analyze` on a LRP-analyzer pre-allocates modified layers by dispatching
+# `modify_layer` on rule and layer types. It then applies a forward-pass of the model,
 # keeping track of the activations `aₖ` for each layer `k`.
 # The relevance `Rₖ₊₁` is then set to the output neuron activation and the rules are applied
 # in a backward-pass over the model layers and previous activations.
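As a rough, illustrative sketch of the control flow described above (the function below is hypothetical; the names `acts`, `rels`, and `lrp!` mirror those used in src/lrp/lrp.jl further down, and `output_idx` stands in for the neuron selector's result):

using Flux
using ExplainableAI: lrp!  # internal rule application, shown in src/lrp/rules.jl below

# Illustrative sketch only, not the verbatim implementation.
function lrp_backward_pass(model, rules, modified_layers, input, output_idx)
    acts = [input, Flux.activations(model, input)...]  # aₖ for every layer k
    rels = similar.(acts)                              # Rₖ for every layer k
    rels[end] .= 0
    rels[end][output_idx] = acts[end][output_idx]      # Rₖ₊₁ of the output layer
    for k in length(rules):-1:1                        # backward pass, last layer first
        lrp!(rels[k], rules[k], modified_layers[k], acts[k], rels[k + 1])
    end
    return first(rels)                                 # relevance with respect to the input
end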
src/analyze_api.jl (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ abstract type AbstractXAIMethod end
 
 const BATCHDIM_MISSING = ArgumentError(
     """The input is a 1D vector and therefore missing the required batch dimension.
-    Call analyze with the keyword argument add_batch_dim=false."""
+    Call `analyze` with the keyword argument `add_batch_dim=false`."""
 )
 
 """
src/lrp/lrp.jl (62 changes: 25 additions & 37 deletions)
@@ -14,32 +14,25 @@ or by passing a composite, see [`Composite`](@ref) for an example.
 [1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
 [2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications
 """
-struct LRP{R<:AbstractVector{<:Tuple}} <: AbstractXAIMethod
+struct LRP{R<:AbstractVector{<:AbstractLRPRule}} <: AbstractXAIMethod
     model::Chain
-    state::R # each entry is a tuple `(rule, modified_layer)`
-end
-rules(analyzer::LRP) = map(first, analyzer.state)
-modified_layers(analyzer::LRP) = map(last, analyzer.state)
-
-# Construct rule_layer_tuples from model and array of rules
-function LRP(
-    model::Chain,
-    rules::AbstractVector{<:AbstractLRPRule};
-    is_flat=false,
-    skip_checks=false,
-    verbose=true,
-)
-    !is_flat && (model = flatten_model(model))
-    if !skip_checks
-        check_output_softmax(model)
-        check_model(Val(:LRP), model; verbose=verbose)
-    end
+    rules::R
 
-    state = map(zip(rules, model.layers)) do (r, l)
-        !is_compatible(r, l) && throw(LRPCompatibilityError(r, l))
-        return (r, modify_layer(r, l))
+    # Construct LRP analyzer by assigning a rule to each layer
+    function LRP(
+        model::Chain,
+        rules::AbstractVector{<:AbstractLRPRule};
+        is_flat=false,
+        skip_checks=false,
+        verbose=true,
+    )
+        !is_flat && (model = flatten_model(model))
+        if !skip_checks
+            check_output_softmax(model)
+            check_model(Val(:LRP), model; verbose=verbose)
+        end
+        return new{typeof(rules)}(model, rules)
     end
-    return LRP(model, state)
 end
 
 # Construct vector of rules by applying composite
@@ -53,28 +46,23 @@ end
 LRP(model::Chain; kwargs...) = LRP(model, Composite(ZeroRule()); kwargs...)
 
 # The call to the LRP analyzer.
-function (analyzer::LRP)(
+function (lrp::LRP)(
     input::AbstractArray{T}, ns::AbstractNeuronSelector; layerwise_relevances=false
 ) where {T}
-    # Compute layerwise activations on forward pass through model:
-    acts = [input, Flux.activations(analyzer.model, input)...]
-    # Allocate array for layerwise relevances:
-    rels = similar.(acts)
-
-    # Mask output neuron
-    output_indices = ns(acts[end])
-    rels[end] .= zero(T)
-    rels[end][output_indices] = acts[end][output_indices]
+    acts = [input, Flux.activations(lrp.model, input)...] # compute aₖ for all layers k
+    rels = similar.(acts) # allocate Rₖ for all layers k
+    mask_output_neuron!(rels[end], acts[end], ns) # compute Rₖ₊₁ of output layer
 
-    # Backward pass through layers, applying LRP rules
-    for (i, (rule, modified_layer)) in Iterators.reverse(enumerate(analyzer.state))
-        lrp!(rels[i], rule, modified_layer, acts[i], rels[i + 1]) # inplace update rels[i]
+    modified_layers = get_modified_layers(lrp.rules, lrp.model.layers)
+    for i in length(lrp.rules):-1:1
+        # Backward-pass applying LRP rules, inplace updating rels[i]
+        lrp!(rels[i], lrp.rules[i], modified_layers[i], acts[i], rels[i + 1])
     end
 
     return Explanation(
         first(rels),
         last(acts),
-        output_indices,
+        ns(last(acts)),
         :LRP,
         ifelse(layerwise_relevances, rels, Nothing),
     )
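A usage-level sketch of the simplified struct (illustrative only; the model and rule choices below are hypothetical): the analyzer now stores the rule vector directly, so it can be read back as a plain field instead of going through the removed `rules(analyzer)` accessor.

using ExplainableAI, Flux

model = Chain(Dense(784, 100, relu), Dense(100, 10))  # hypothetical model
analyzer = LRP(model, [GammaRule(), EpsilonRule()])   # one rule per layer

analyzer.rules   # the rule vector, stored as-is on the new `rules` field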
src/lrp/rules.jl (47 changes: 30 additions & 17 deletions)
@@ -1,4 +1,4 @@
-# https://adrhill.github.io/ExplainableAI.jl/stable/generated/advanced_lrp/#How-it-works-internally
+# https://adrhill.github.io/ExplainableAI.jl/dev/generated/advanced_lrp/#How-it-works-internally
 abstract type AbstractLRPRule end
 
 # Bibliography
@@ -8,6 +8,13 @@ const REF_MONTAVON_DTD = "G. Montavon et al., *Explaining Nonlinear Classificati
 const REF_MONTAVON_OVERVIEW = "G. Montavon et al., *Layer-Wise Relevance Propagation: An Overview*"
 const REF_ANDEOL_DOMAIN_INVARIANT = "L. Andéol et al., *Learning Domain Invariant Representations by Joint Wasserstein Distance Minimization*"
 
+# Default parameters
+const LRP_DEFAULT_GAMMA = 0.25f0
+const LRP_DEFAULT_EPSILON = 1.0f-6
+const LRP_DEFAULT_STABILIZER = 1.0f-9
+const LRP_DEFAULT_ALPHA = 2.0f0
+const LRP_DEFAULT_BETA = 1.0f0
+
 # Generic LRP rule. Used by all rules without custom implementations.
 function lrp!(Rₖ, rule::AbstractLRPRule, modified_layer, aₖ, Rₖ₊₁)
     ãₖ = modify_input(rule, aₖ)
@@ -66,7 +73,7 @@ modify_input(rule, input) = input
 Modify denominator ``z`` for numerical stability on the forward pass.
 """
-modify_denominator(rule, d) = stabilize_denom(d, 1.0f-9)
+modify_denominator(rule, d) = stabilize_denom(d, LRP_DEFAULT_STABILIZER)
 
 """
     is_compatible(rule, layer)
@@ -133,6 +140,13 @@ function modify_layer(rule, layer; keep_bias=true)
     return copy_layer(layer, w, b)
 end
 
+function get_modified_layers(rules, layers)
+    return map(zip(rules, layers)) do (r, l)
+        !is_compatible(r, l) && throw(LRPCompatibilityError(r, l))
+        modify_layer(r, l)
+    end
+end
+
 # Useful presets, used e.g. in AlphaBetaRule, ZBoxRule & ZPlusRule:
 modify_parameters(::Val{:keep_positive}, p) = keep_positive(p)
 modify_parameters(::Val{:keep_negative}, p) = keep_negative(p)
@@ -162,7 +176,7 @@ struct ZeroRule <: AbstractLRPRule end
 is_compatible(::ZeroRule, layer) = true # compatible with all layer types
 
 """
-    EpsilonRule([epsilon=1.0f-6])
+    EpsilonRule([epsilon=$(LRP_DEFAULT_EPSILON)])
 LRP-``ϵ`` rule. Commonly used on middle layers.
@@ -173,21 +187,20 @@ R_j^k = \\sum_i\\frac{w_{ij}a_j^k}{\\epsilon +\\sum_{l}w_{il}a_l^k+b_i} R_i^{k+1
 ```
 # Optional arguments
-- `epsilon`: Optional stabilization parameter, defaults to `1f-6`.
+- `epsilon`: Optional stabilization parameter, defaults to `$(LRP_DEFAULT_EPSILON)`.
 # References
 - $REF_BACH_LRP
 """
 struct EpsilonRule{T<:Real} <: AbstractLRPRule
     ϵ::T
-    EpsilonRule(epsilon=1.0f-6) = new{eltype(epsilon)}(epsilon)
+    EpsilonRule(epsilon=LRP_DEFAULT_EPSILON) = new{eltype(epsilon)}(epsilon)
 end
 modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
 is_compatible(::EpsilonRule, layer) = true # compatible with all layer types
 
-const LRP_GAMMA_DEFAULT = 0.25f0
 """
-    GammaRule([gamma=$(LRP_GAMMA_DEFAULT)])
+    GammaRule([gamma=$(LRP_DEFAULT_GAMMA)])
 LRP-``γ`` rule. Commonly used on lower layers.
@@ -199,14 +212,14 @@ R_j^k = \\sum_i\\frac{(w_{ij}+\\gamma w_{ij}^+)a_j^k}
 ```
 # Optional arguments
-- `γ`: Optional multiplier for added positive weights, defaults to `$(LRP_GAMMA_DEFAULT)`.
+- `gamma`: Optional multiplier for added positive weights, defaults to `$(LRP_DEFAULT_GAMMA)`.
 # References
 - $REF_MONTAVON_OVERVIEW
 """
 struct GammaRule{T<:Real} <: AbstractLRPRule
     γ::T
-    GammaRule(gamma=LRP_GAMMA_DEFAULT) = new{eltype(gamma)}(gamma)
+    GammaRule(gamma=LRP_DEFAULT_GAMMA) = new{eltype(gamma)}(gamma)
 end
 function modify_parameters(r::GammaRule, param::AbstractArray)
     γ = convert(eltype(param), r.γ)
@@ -216,7 +229,7 @@ end
 # Internally used for GeneralizedGammaRule:
 struct NegativeGammaRule{T<:Real} <: AbstractLRPRule
     γ::T
-    NegativeGammaRule(gamma=LRP_GAMMA_DEFAULT) = new{eltype(gamma)}(gamma)
+    NegativeGammaRule(gamma=LRP_DEFAULT_GAMMA) = new{eltype(gamma)}(gamma)
 end
 function modify_parameters(r::NegativeGammaRule, param::AbstractArray)
     γ = convert(eltype(param), r.γ)
@@ -347,7 +360,7 @@ function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
 end
 
 """
-    AlphaBetaRule([alpha=2.0], [beta=1.0])
+    AlphaBetaRule([alpha=$(LRP_DEFAULT_ALPHA)], [beta=$(LRP_DEFAULT_BETA)])
 LRP-``αβ`` rule. Weights positive and negative contributions according to the
 parameters `alpha` and `beta` respectively. The difference ``α-β`` must be equal to one.
@@ -363,8 +376,8 @@ R_j^k = \\sum_i\\left(
 ```
 # Optional arguments
-- `alpha`: Multiplier for the positive output term, defaults to `2.0`.
-- `beta`: Multiplier for the negative output term, defaults to `1.0`.
+- `alpha`: Multiplier for the positive output term, defaults to `$(LRP_DEFAULT_ALPHA)`.
+- `beta`: Multiplier for the negative output term, defaults to `$(LRP_DEFAULT_BETA)`.
 # References
 - $REF_BACH_LRP
@@ -373,7 +386,7 @@ R_j^k = \\sum_i\\left(
 struct AlphaBetaRule{T<:Real} <: AbstractLRPRule
     α::T
     β::T
-    function AlphaBetaRule(alpha=2.0f0, beta=1.0f0)
+    function AlphaBetaRule(alpha=LRP_DEFAULT_ALPHA, beta=LRP_DEFAULT_BETA)
        alpha < 0 && throw(ArgumentError("Parameter `alpha` must be ≥0."))
        beta < 0 && throw(ArgumentError("Parameter `beta` must be ≥0."))
        !isone(alpha - beta) && throw(ArgumentError("`alpha - beta` must be equal one."))
@@ -452,7 +465,7 @@ function lrp!(Rₖ, rule::ZPlusRule, modified_layers, aₖ, Rₖ₊₁)
 end
 
 """
-    GeneralizedGammaRule([gamma=$(LRP_GAMMA_DEFAULT)])
+    GeneralizedGammaRule([gamma=$(LRP_DEFAULT_GAMMA)])
 Generalized LRP-``γ`` rule. Can be used on layers with `leakyrelu` activation functions.
@@ -470,14 +483,14 @@ I(z_k<0) \\cdot R^{k+1}_i
 ```
 # Optional arguments
-- `γ`: Optional multiplier for added positive weights, defaults to `$(LRP_GAMMA_DEFAULT)`.
+- `gamma`: Optional multiplier for added positive weights, defaults to `$(LRP_DEFAULT_GAMMA)`.
 # References
 - $REF_ANDEOL_DOMAIN_INVARIANT
 """
 struct GeneralizedGammaRule{T<:Real} <: AbstractLRPRule
     γ::T
-    GeneralizedGammaRule(gamma=LRP_GAMMA_DEFAULT) = new{eltype(gamma)}(gamma)
+    GeneralizedGammaRule(gamma=LRP_DEFAULT_GAMMA) = new{eltype(gamma)}(gamma)
 end
 function modify_layer(rule::GeneralizedGammaRule, layer)
     # ˡ/ʳ: LHS/RHS of the generalized Gamma-rule equation
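The point of the new constants is a single source of truth: they act as constructor defaults and are interpolated into the docstrings above. A small illustrative sketch (the default values in the comments are taken from the constants added in this diff):

using ExplainableAI

EpsilonRule()               # ϵ defaults to LRP_DEFAULT_EPSILON = 1.0f-6
GammaRule()                 # γ defaults to LRP_DEFAULT_GAMMA   = 0.25f0
AlphaBetaRule()             # α = 2.0f0, β = 1.0f0
AlphaBetaRule(1.5f0, 0.5f0) # custom values still have to satisfy α - β == 1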
src/lrp/show.jl (9 changes: 4 additions & 5 deletions)
@@ -1,6 +1,6 @@
 const COLOR_COMMENT = :light_black
 const COLOR_RULE = :yellow
-const COLOR_TYPE = :blue
+const COLOR_TYPE = :light_blue
 const COLOR_RANGE = :green
 
 typename(x) = string(nameof(typeof(x)))
@@ -10,13 +10,12 @@ typename(x) = string(nameof(typeof(x)))
 ################
 
 _print_layer(io::IO, l) = string(sprint(show, l; context=io))
-function Base.show(io::IO, m::MIME"text/plain", analyzer::LRP)
-    layer_names = [_print_layer(io, l) for l in analyzer.model]
-    rs = rules(analyzer)
+function Base.show(io::IO, m::MIME"text/plain", lrp::LRP)
+    layer_names = [_print_layer(io, layer) for layer in lrp.model]
     npad = maximum(length.(layer_names)) + 1 # padding to align rules with rpad
 
     println(io, "LRP", "(")
-    for (r, l) in zip(rs, layer_names)
+    for (r, l) in zip(lrp.rules, layer_names)
        print(io, " ", rpad(l, npad), " => ")
        printstyled(io, r; color=COLOR_RULE)
        println(io, ",")
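With the brighter `:light_blue` type color and the field rename, the printed analyzer lists one `layer => rule` pair per line. Roughly, for a hypothetical two-layer chain (a sketch only; exact spacing, colors, and layer printing depend on Flux and the terminal):

julia> LRP(Chain(Dense(10, 5, relu), Dense(5, 2)))
LRP(
  Dense(10 => 5, relu) => ZeroRule(),
  Dense(5 => 2)        => ZeroRule(),
)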
src/neuron_selection.jl (8 changes: 8 additions & 0 deletions)
@@ -1,5 +1,13 @@
 abstract type AbstractNeuronSelector end
 
+function mask_output_neuron!(
+    output_relevance, output_activation, ns::AbstractNeuronSelector
+)
+    fill!(output_relevance, 0)
+    neuron_selection = ns(output_activation)
+    output_relevance[neuron_selection] = output_activation[neuron_selection]
+end
+
 """
     MaxActivationSelector()
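The new helper centralizes the masking that previously lived in the analyzer call: zero all output relevances, then copy over the activation of the selected neuron(s). A self-contained worked example of that logic (illustrative; `argmax` stands in for what a max-activation selector would pick):

output_activation = [0.1f0, 2.3f0, 0.7f0]
output_relevance  = similar(output_activation)

fill!(output_relevance, 0)
idx = argmax(output_activation)            # index 2, the largest activation
output_relevance[idx] = output_activation[idx]

output_relevance                           # Float32[0.0, 2.3, 0.0]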
test/test_composite.jl (5 changes: 2 additions & 3 deletions)
@@ -1,5 +1,4 @@
 using ExplainableAI
-using ExplainableAI: rules
 using Flux
 
 # Load VGG model:
@@ -40,7 +39,7 @@ composite1 = Composite(
 )
 
 analyzer1 = LRP(model, composite1)
-@test rules(analyzer1) == [
+@test analyzer1.rules == [
     ZBoxRule(-3.0f0, 3.0f0)
     EpsilonRule(1.0f-6)
     FlatRule()
@@ -81,7 +80,7 @@ composite2 = Composite(
     ),
 )
 analyzer2 = LRP(model, composite2)
-@test rules(analyzer2) == [
+@test analyzer2.rules == [
     AlphaBetaRule(2.0f0, 1.0f0)
     ZeroRule()
     ZeroRule()
