Simplify LRP analyzer and clean-up rule default parameters #110

Merged · 5 commits · Nov 21, 2022
Changes from all commits
6 changes: 2 additions & 4 deletions docs/src/literate/advanced_lrp.jl
@@ -274,10 +274,8 @@ analyzer = LRP(model)
# The best point of entry into the source code is
# [`/src/lrp/rules.jl`](https://github.com/adrhill/ExplainableAI.jl/blob/master/src/lrp/rules.jl).
#
# Internally, ExplainableAI pre-allocates modified layers by dispatching `modify_layer`
# on rule and layer types. This constructs the `state` of a LRP analyzer.
#
# Calling `analyze` on a LRP-model then applies a forward-pass of the model,
# Calling `analyze` on a LRP-analyzer pre-allocates modified layers by dispatching
# `modify_layer` on rule and layer types. It then applies a forward-pass of the model,
# keeping track of the activations `aₖ` for each layer `k`.
# The relevance `Rₖ₊₁` is then set to the output neuron activation and the rules are applied
# in a backward-pass over the model layers and previous activations.
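For illustration, a minimal sketch of the pass described above, mirroring the `(analyzer::LRP)(...)` implementation in `src/lrp/lrp.jl` further down in this diff. The toy model, rules and input are made up, and `modify_layer`/`lrp!` are internal (unexported) functions:

using Flux, ExplainableAI

# Toy setup: a two-layer model with one rule per layer (illustrative only).
model = Chain(Dense(4, 3, relu), Dense(3, 2))
rules = [GammaRule(), EpsilonRule()]
input = rand(Float32, 4, 1)

acts = [input, Flux.activations(model, input)...]   # activations aₖ for every layer k
rels = similar.(acts)                               # relevance buffers Rₖ
rels[end] .= 0
idx = argmax(acts[end])                             # selected output neuron
rels[end][idx] = acts[end][idx]                     # Rₖ₊₁ of the output layer

for k in length(rules):-1:1                         # backward pass over the layers
    layer = ExplainableAI.modify_layer(rules[k], model[k])
    ExplainableAI.lrp!(rels[k], rules[k], layer, acts[k], rels[k + 1])  # in-place update of Rₖ
end
heatmap = first(rels)                               # input-space relevance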
2 changes: 1 addition & 1 deletion src/analyze_api.jl
@@ -4,7 +4,7 @@ abstract type AbstractXAIMethod end

const BATCHDIM_MISSING = ArgumentError(
"""The input is a 1D vector and therefore missing the required batch dimension.
Call analyze with the keyword argument add_batch_dim=false."""
Call `analyze` with the keyword argument `add_batch_dim=false`."""
)

"""
62 changes: 25 additions & 37 deletions src/lrp/lrp.jl
@@ -14,32 +14,25 @@ or by passing a composite, see [`Composite`](@ref) for an example.
[1] G. Montavon et al., Layer-Wise Relevance Propagation: An Overview
[2] W. Samek et al., Explaining Deep Neural Networks and Beyond: A Review of Methods and Applications
"""
struct LRP{R<:AbstractVector{<:Tuple}} <: AbstractXAIMethod
struct LRP{R<:AbstractVector{<:AbstractLRPRule}} <: AbstractXAIMethod
model::Chain
state::R # each entry is a tuple `(rule, modified_layer)`
end
rules(analyzer::LRP) = map(first, analyzer.state)
modified_layers(analyzer::LRP) = map(last, analyzer.state)

# Construct rule_layer_tuples from model and array of rules
function LRP(
model::Chain,
rules::AbstractVector{<:AbstractLRPRule};
is_flat=false,
skip_checks=false,
verbose=true,
)
!is_flat && (model = flatten_model(model))
if !skip_checks
check_output_softmax(model)
check_model(Val(:LRP), model; verbose=verbose)
end
rules::R

state = map(zip(rules, model.layers)) do (r, l)
!is_compatible(r, l) && throw(LRPCompatibilityError(r, l))
return (r, modify_layer(r, l))
# Construct LRP analyzer by assigning a rule to each layer
function LRP(
model::Chain,
rules::AbstractVector{<:AbstractLRPRule};
is_flat=false,
skip_checks=false,
verbose=true,
)
!is_flat && (model = flatten_model(model))
if !skip_checks
check_output_softmax(model)
check_model(Val(:LRP), model; verbose=verbose)
end
return new{typeof(rules)}(model, rules)
end
return LRP(model, state)
end

# Construct vector of rules by applying composite
@@ -53,28 +46,23 @@
LRP(model::Chain; kwargs...) = LRP(model, Composite(ZeroRule()); kwargs...)
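For reference, a hedged usage sketch of the three ways to construct an analyzer shown above (the two-layer `Chain` and the rule choices are made up for illustration):

using Flux, ExplainableAI

model = Chain(Dense(784, 100, relu), Dense(100, 10))

analyzer1 = LRP(model)                               # default: Composite(ZeroRule())
analyzer2 = LRP(model, [GammaRule(), EpsilonRule()]) # explicit rule per layer
analyzer3 = LRP(model, Composite(EpsilonRule()))     # composite assigns rules to matching layers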

# The call to the LRP analyzer.
function (analyzer::LRP)(
function (lrp::LRP)(
input::AbstractArray{T}, ns::AbstractNeuronSelector; layerwise_relevances=false
) where {T}
# Compute layerwise activations on forward pass through model:
acts = [input, Flux.activations(analyzer.model, input)...]
# Allocate array for layerwise relevances:
rels = similar.(acts)

# Mask output neuron
output_indices = ns(acts[end])
rels[end] .= zero(T)
rels[end][output_indices] = acts[end][output_indices]
acts = [input, Flux.activations(lrp.model, input)...] # compute aₖ for all layers k
rels = similar.(acts) # allocate Rₖ for all layers k
mask_output_neuron!(rels[end], acts[end], ns) # compute Rₖ₊₁ of output layer

# Backward pass through layers, applying LRP rules
for (i, (rule, modified_layer)) in Iterators.reverse(enumerate(analyzer.state))
lrp!(rels[i], rule, modified_layer, acts[i], rels[i + 1]) # inplace update rels[i]
modified_layers = get_modified_layers(lrp.rules, lrp.model.layers)
for i in length(lrp.rules):-1:1
# Backward-pass applying LRP rules, inplace updating rels[i]
lrp!(rels[i], lrp.rules[i], modified_layers[i], acts[i], rels[i + 1])
end

return Explanation(
first(rels),
last(acts),
output_indices,
ns(last(acts)),
:LRP,
ifelse(layerwise_relevances, rels, Nothing),
)
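A hedged sketch of invoking this low-level call directly (ordinary usage goes through `analyze`; the model, input shape, and the unexported `MaxActivationSelector` are assumptions for illustration):

using Flux, ExplainableAI

model    = Chain(Dense(784, 100, relu), Dense(100, 10))
analyzer = LRP(model, [GammaRule(), EpsilonRule()])
input    = rand(Float32, 784, 1)                     # features × batch

sel  = ExplainableAI.MaxActivationSelector()
expl = analyzer(input, sel; layerwise_relevances=true)
# `expl` is an `Explanation`; with `layerwise_relevances=true` it also carries every Rₖ.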
47 changes: 30 additions & 17 deletions src/lrp/rules.jl
@@ -1,4 +1,4 @@
# https://adrhill.github.io/ExplainableAI.jl/stable/generated/advanced_lrp/#How-it-works-internally
# https://adrhill.github.io/ExplainableAI.jl/dev/generated/advanced_lrp/#How-it-works-internally
abstract type AbstractLRPRule end

# Bibliography
@@ -8,6 +8,13 @@ const REF_MONTAVON_DTD = "G. Montavon et al., *Explaining Nonlinear Classificati
const REF_MONTAVON_OVERVIEW = "G. Montavon et al., *Layer-Wise Relevance Propagation: An Overview*"
const REF_ANDEOL_DOMAIN_INVARIANT = "L. Andéol et al., *Learning Domain Invariant Representations by Joint Wasserstein Distance Minimization*"

# Default parameters
const LRP_DEFAULT_GAMMA = 0.25f0
const LRP_DEFAULT_EPSILON = 1.0f-6
const LRP_DEFAULT_STABILIZER = 1.0f-9
const LRP_DEFAULT_ALPHA = 2.0f0
const LRP_DEFAULT_BETA = 1.0f0
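These constants surface as constructor defaults of the rules defined below; a hedged sketch of that relationship (field names as declared in the corresponding structs, evaluated inside the package module since the constants are not exported):

EpsilonRule().ϵ   == LRP_DEFAULT_EPSILON   # 1.0f-6
GammaRule().γ     == LRP_DEFAULT_GAMMA     # 0.25f0
AlphaBetaRule().α == LRP_DEFAULT_ALPHA     # 2.0f0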

# Generic LRP rule. Used by all rules without custom implementations.
function lrp!(Rₖ, rule::AbstractLRPRule, modified_layer, aₖ, Rₖ₊₁)
ãₖ = modify_input(rule, aₖ)
@@ -66,7 +73,7 @@ modify_input(rule, input) = input

Modify denominator ``z`` for numerical stability on the forward pass.
"""
modify_denominator(rule, d) = stabilize_denom(d, 1.0f-9)
modify_denominator(rule, d) = stabilize_denom(d, LRP_DEFAULT_STABILIZER)

"""
is_compatible(rule, layer)
@@ -133,6 +140,13 @@ function modify_layer(rule, layer; keep_bias=true)
return copy_layer(layer, w, b)
end

function get_modified_layers(rules, layers)
return map(zip(rules, layers)) do (r, l)
!is_compatible(r, l) && throw(LRPCompatibilityError(r, l))
modify_layer(r, l)
end
end

# Useful presets, used e.g. in AlphaBetaRule, ZBoxRule & ZPlusRule:
modify_parameters(::Val{:keep_positive}, p) = keep_positive(p)
modify_parameters(::Val{:keep_negative}, p) = keep_negative(p)
@@ -162,7 +176,7 @@ struct ZeroRule <: AbstractLRPRule end
is_compatible(::ZeroRule, layer) = true # compatible with all layer types

"""
EpsilonRule([epsilon=1.0f-6])
EpsilonRule([epsilon=$(LRP_DEFAULT_EPSILON)])

LRP-``ϵ`` rule. Commonly used on middle layers.

Expand All @@ -173,21 +187,20 @@ R_j^k = \\sum_i\\frac{w_{ij}a_j^k}{\\epsilon +\\sum_{l}w_{il}a_l^k+b_i} R_i^{k+1
```

# Optional arguments
- `epsilon`: Optional stabilization parameter, defaults to `1f-6`.
- `epsilon`: Optional stabilization parameter, defaults to `$(LRP_DEFAULT_EPSILON)`.

# References
- $REF_BACH_LRP
"""
struct EpsilonRule{T<:Real} <: AbstractLRPRule
ϵ::T
EpsilonRule(epsilon=1.0f-6) = new{eltype(epsilon)}(epsilon)
EpsilonRule(epsilon=LRP_DEFAULT_EPSILON) = new{eltype(epsilon)}(epsilon)
end
modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
is_compatible(::EpsilonRule, layer) = true # compatible with all layer types
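A hedged usage sketch (the ϵ value is arbitrary; per `is_compatible` above, the rule can be attached to any layer type):

rule = EpsilonRule(1.0f-4)   # custom stabilizer instead of the LRP_DEFAULT_EPSILON = 1.0f-6 default
# modify_denominator(rule, z) then stabilizes the denominator with ϵ = 1.0f-4
# instead of the global LRP_DEFAULT_STABILIZER used by the generic fallback.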

const LRP_GAMMA_DEFAULT = 0.25f0
"""
GammaRule([gamma=$(LRP_GAMMA_DEFAULT)])
GammaRule([gamma=$(LRP_DEFAULT_GAMMA)])

LRP-``γ`` rule. Commonly used on lower layers.

Expand All @@ -199,14 +212,14 @@ R_j^k = \\sum_i\\frac{(w_{ij}+\\gamma w_{ij}^+)a_j^k}
```

# Optional arguments
- `γ`: Optional multiplier for added positive weights, defaults to `$(LRP_GAMMA_DEFAULT)`.
- `gamma`: Optional multiplier for added positive weights, defaults to `$(LRP_DEFAULT_GAMMA)`.

# References
- $REF_MONTAVON_OVERVIEW
"""
struct GammaRule{T<:Real} <: AbstractLRPRule
γ::T
GammaRule(gamma=LRP_GAMMA_DEFAULT) = new{eltype(gamma)}(gamma)
GammaRule(gamma=LRP_DEFAULT_GAMMA) = new{eltype(gamma)}(gamma)
end
function modify_parameters(r::GammaRule, param::AbstractArray)
γ = convert(eltype(param), r.γ)
@@ -216,7 +229,7 @@ end
# Internally used for GeneralizedGammaRule:
struct NegativeGammaRule{T<:Real} <: AbstractLRPRule
γ::T
NegativeGammaRule(gamma=LRP_GAMMA_DEFAULT) = new{eltype(gamma)}(gamma)
NegativeGammaRule(gamma=LRP_DEFAULT_GAMMA) = new{eltype(gamma)}(gamma)
end
function modify_parameters(r::NegativeGammaRule, param::AbstractArray)
γ = convert(eltype(param), r.γ)
@@ -347,7 +360,7 @@ function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
end

"""
AlphaBetaRule([alpha=2.0], [beta=1.0])
AlphaBetaRule([alpha=$(LRP_DEFAULT_ALPHA)], [beta=$(LRP_DEFAULT_BETA)])

LRP-``αβ`` rule. Weights positive and negative contributions according to the
parameters `alpha` and `beta` respectively. The difference ``α-β`` must be equal to one.
@@ -363,8 +376,8 @@ R_j^k = \\sum_i\\left(
```

# Optional arguments
- `alpha`: Multiplier for the positive output term, defaults to `2.0`.
- `beta`: Multiplier for the negative output term, defaults to `1.0`.
- `alpha`: Multiplier for the positive output term, defaults to `$(LRP_DEFAULT_ALPHA)`.
- `beta`: Multiplier for the negative output term, defaults to `$(LRP_DEFAULT_BETA)`.

# References
- $REF_BACH_LRP
@@ -373,7 +386,7 @@ R_j^k = \\sum_i\\left(
struct AlphaBetaRule{T<:Real} <: AbstractLRPRule
α::T
β::T
function AlphaBetaRule(alpha=2.0f0, beta=1.0f0)
function AlphaBetaRule(alpha=LRP_DEFAULT_ALPHA, beta=LRP_DEFAULT_BETA)
alpha < 0 && throw(ArgumentError("Parameter `alpha` must be ≥0."))
beta < 0 && throw(ArgumentError("Parameter `beta` must be ≥0."))
!isone(alpha - beta) && throw(ArgumentError("`alpha - beta` must be equal one."))
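A hedged sketch of what these argument checks enforce (values chosen for illustration):

AlphaBetaRule()              # defaults α = 2.0f0, β = 1.0f0; α - β == 1 ✓
AlphaBetaRule(1.5f0, 0.5f0)  # also valid, since 1.5 - 0.5 == 1
AlphaBetaRule(1.0f0, 1.0f0)  # throws ArgumentError: `alpha - beta` must equal one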
@@ -452,7 +465,7 @@ function lrp!(Rₖ, rule::ZPlusRule, modified_layers, aₖ, Rₖ₊₁)
end

"""
GeneralizedGammaRule([gamma=$(LRP_GAMMA_DEFAULT)])
GeneralizedGammaRule([gamma=$(LRP_DEFAULT_GAMMA)])

Generalized LRP-``γ`` rule. Can be used on layers with `leakyrelu` activation functions.

@@ -470,14 +483,14 @@ I(z_k<0) \\cdot R^{k+1}_i
```

# Optional arguments
- `γ`: Optional multiplier for added positive weights, defaults to `$(LRP_GAMMA_DEFAULT)`.
- `gamma`: Optional multiplier for added positive weights, defaults to `$(LRP_DEFAULT_GAMMA)`.

# References
- $REF_ANDEOL_DOMAIN_INVARIANT
"""
struct GeneralizedGammaRule{T<:Real} <: AbstractLRPRule
γ::T
GeneralizedGammaRule(gamma=LRP_GAMMA_DEFAULT) = new{eltype(gamma)}(gamma)
GeneralizedGammaRule(gamma=LRP_DEFAULT_GAMMA) = new{eltype(gamma)}(gamma)
end
function modify_layer(rule::GeneralizedGammaRule, layer)
# ˡ/ʳ: LHS/RHS of the generalized Gamma-rule equation
9 changes: 4 additions & 5 deletions src/lrp/show.jl
@@ -1,6 +1,6 @@
const COLOR_COMMENT = :light_black
const COLOR_RULE = :yellow
const COLOR_TYPE = :blue
const COLOR_TYPE = :light_blue
const COLOR_RANGE = :green

typename(x) = string(nameof(typeof(x)))
@@ -10,13 +10,12 @@ typename(x) = string(nameof(typeof(x)))
################

_print_layer(io::IO, l) = string(sprint(show, l; context=io))
function Base.show(io::IO, m::MIME"text/plain", analyzer::LRP)
layer_names = [_print_layer(io, l) for l in analyzer.model]
rs = rules(analyzer)
function Base.show(io::IO, m::MIME"text/plain", lrp::LRP)
layer_names = [_print_layer(io, layer) for layer in lrp.model]
npad = maximum(length.(layer_names)) + 1 # padding to align rules with rpad

println(io, "LRP", "(")
for (r, l) in zip(rs, layer_names)
for (r, l) in zip(lrp.rules, layer_names)
print(io, " ", rpad(l, npad), " => ")
printstyled(io, r; color=COLOR_RULE)
println(io, ",")
8 changes: 8 additions & 0 deletions src/neuron_selection.jl
@@ -1,5 +1,13 @@
abstract type AbstractNeuronSelector end

function mask_output_neuron!(
output_relevance, output_activation, ns::AbstractNeuronSelector
)
fill!(output_relevance, 0)
neuron_selection = ns(output_activation)
output_relevance[neuron_selection] = output_activation[neuron_selection]
end
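A hedged sketch of what this helper does, assuming `MaxActivationSelector` (documented just below) selects the per-sample maximum, as its name suggests:

using ExplainableAI

out  = Float32[0.1 0.4; 2.3 0.2; 0.7 1.5]   # 3 output neurons × 2 batch samples
rels = similar(out)
ExplainableAI.mask_output_neuron!(rels, out, ExplainableAI.MaxActivationSelector())
# rels is now zero everywhere except the per-sample maxima (2.3 and 1.5).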

"""
MaxActivationSelector()

5 changes: 2 additions & 3 deletions test/test_composite.jl
@@ -1,5 +1,4 @@
using ExplainableAI
using ExplainableAI: rules
using Flux

# Load VGG model:
@@ -40,7 +39,7 @@ composite1 = Composite(
)

analyzer1 = LRP(model, composite1)
@test rules(analyzer1) == [
@test analyzer1.rules == [
ZBoxRule(-3.0f0, 3.0f0)
EpsilonRule(1.0f-6)
FlatRule()
@@ -81,7 +80,7 @@ composite2 = Composite(
),
)
analyzer2 = LRP(model, composite2)
@test rules(analyzer2) == [
@test analyzer2.rules == [
AlphaBetaRule(2.0f0, 1.0f0)
ZeroRule()
ZeroRule()