# Basic usage of LRP

We start out by loading a small convolutional neural network:

In [1]:
using RelevancePropagation
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1),
        Conv((3, 3), 8 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16; pad=1),
        BatchNorm(16, relu),
        Conv((3, 3), 16 => 8, relu; pad=1),
        BatchNorm(8, relu),
    ),
    Chain(Flux.flatten, Dense(2048 => 512, relu), Dropout(0.5), Dense(512 => 100, softmax)),
);

This model contains two chains: the convolutional layers and the fully connected layers.

## Model preparation
### Stripping the output softmax
When using LRP, it is recommended to explain output logits instead of probabilities.
This can be done by stripping the output softmax activation from the model
using the `strip_softmax` function:

In [2]:
model = strip_softmax(model)

Chain(
  Chain(
    Conv((3, 3), 3 => 8, relu, pad=1),  # 224 parameters
    Conv((3, 3), 8 => 8, relu, pad=1),  # 584 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 8 => 16, pad=1),       # 1_168 parameters
    BatchNorm(16, relu),                # 32 parameters, plus 32
    Conv((3, 3), 16 => 8, relu, pad=1),  # 1_160 parameters
    BatchNorm(8, relu),                 # 16 parameters, plus 16
  ),
  Chain(
    Flux.flatten,
    Dense(2048 => 512, relu),           # 1_049_088 parameters
    Dropout(0.5),
    Dense(512 => 100),                  # 51_300 parameters
  ),
)         # Total: 16 trainable arrays, 1_103_572 parameters,
          # plus 4 non-trainable, 48 parameters, summarysize 4.213 MiB.

If you don't remove the output softmax,
model checks will fail.
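
Stripping the softmax does not change which class the model predicts, since softmax preserves the ordering of the logits. The following plain-Julia sketch illustrates this; the hand-written `softmax` helper and the logit values are assumptions made for illustration and are not taken from the model above.

```julia
# Numerically stable softmax, written out by hand for illustration.
softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

logits = [2.0, -1.0, 0.5]   # hypothetical output logits
probs = softmax(logits)

# The winning class is the same before and after the softmax.
argmax(logits) == argmax(probs)   # true

# Shifting every logit by a constant leaves the probabilities unchanged,
# so different logit vectors can map to the same probabilities.
softmax(logits .+ 10.0) ≈ probs   # true
```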

### Canonizing the model
LRP is not invariant to a model's implementation.
Applying the `GammaRule` to two linear layers in a row will yield different results
than first fusing the two layers into one linear layer and then applying the rule.
This fusing is called "canonization" and can be done using the `canonize` function:

In [3]:
model_canonized = canonize(model)

Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)         # Total: 14 trainable arrays, 1_103_540 parameters,
          # plus 2 non-trainable, 16 parameters, summarysize 4.212 MiB.

After canonization, the first `BatchNorm` layer has been fused into the preceding `Conv` layer.
The second `BatchNorm` layer wasn't fused
since its preceding `Conv` layer has a ReLU activation function.
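
The implementation dependence mentioned at the start of this section can be checked by hand on a toy example. The sketch below uses a simplified gamma rule for bias-free linear layers; the helper `gamma_rule`, the weights, and the choice `γ=0.5` are assumptions made for illustration, and this is not the package's `GammaRule` implementation.

```julia
# Simplified gamma rule for a bias-free linear layer `W` with input `a`
# and output relevance `R`: positive weights are scaled by a factor (1 + γ).
function gamma_rule(W, a, R; γ=0.5)
    Wγ = W .+ γ .* max.(W, 0)   # modified weights
    z = Wγ * a                  # modified pre-activations
    s = R ./ z                  # relevance per unit of pre-activation
    return a .* (Wγ' * s)       # redistribute onto the inputs
end

x  = [1.0, 2.0]
W1 = [1.0 -1.0; 2.0 1.0]
W2 = [1.0 1.0]
y  = W2 * W1 * x                # network output, used as output relevance

# Rule applied layer by layer
R_hidden  = gamma_rule(W2, W1 * x, y)
R_layered = gamma_rule(W1, x, R_hidden)   # ≈ [5.0, -2.0]

# Rule applied to the fused layer W2 * W1
R_fused = gamma_rule(W2 * W1, x, y)       # ≈ [3.0, 0.0]
```

Both results sum to the network output of 3, but they distribute it differently over the two inputs, even though the two networks compute the same function.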

### Flattening the model
RelevancePropagation.jl's LRP implementation supports nested Flux Chains and Parallel layers.
However, it is recommended to flatten the model before analyzing it.

LRP is implemented by first running a forward pass through the model,
keeping track of the intermediate activations, followed by a backward pass
that computes the relevances.
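
This two-pass structure can be sketched in plain Julia. The sketch assumes a bias-free network given as a list of weight matrices with ReLU activations on all but the last layer, and uses the basic zero rule; the function `lrp_zero` and the weights are hypothetical and do not reflect the package's internals.

```julia
relu(x) = max(x, zero(x))

# Hypothetical bias-free network: weight matrices, each followed by ReLU,
# except for the last layer, which is linear.
layers = [[1.0 2.0; -1.0 1.0; 0.5 0.5], [-1.0 -1.0 1.0; 0.0 1.0 1.0], [1.0 1.0]]

function lrp_zero(layers, x)
    # Forward pass: store the input of every layer, plus the final output.
    acts = [x]
    for (i, W) in enumerate(layers)
        z = W * acts[end]
        push!(acts, i < length(layers) ? relu.(z) : z)
    end

    # Backward pass: start from the output and redistribute layer by layer.
    rels = [acts[end]]
    for i in reverse(eachindex(layers))
        W, a = layers[i], acts[i]
        z = W * a
        # Guard against division by zero for exactly-zero pre-activations.
        s = [zk == 0 ? 0.0 : Rk / zk for (Rk, zk) in zip(rels[1], z)]
        pushfirst!(rels, a .* (W' * s))
    end
    return rels   # rels[1] is the input relevance, rels[end] the output relevance
end

rels = lrp_zero(layers, [1.0, 2.0])
map(sum, rels)   # every entry sums to the network output
```

Because the sketch has no biases, the total relevance is conserved from layer to layer, which is why all sums agree.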

To keep the LRP implementation simple and maintainable,
RelevancePropagation.jl does not pre-compute "nested" activations.
Instead, for every internal chain, a new forward pass is run to compute activations.

By "flattening" a model, this overhead can be avoided.
For this purpose, RelevancePropagation.jl provides the function `flatten_model`:

In [4]:
model_flat = flatten_model(model)

Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, pad=1),         # 1_168 parameters
  BatchNorm(16, relu),                  # 32 parameters, plus 32
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)         # Total: 16 trainable arrays, 1_103_572 parameters,
          # plus 4 non-trainable, 48 parameters, summarysize 4.212 MiB.

This function is called by default when creating an LRP analyzer.
Note that we pass the unflattened model to the analyzer, but `analyzer.model` is flattened:

In [5]:
analyzer = LRP(model)
analyzer.model

Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, pad=1),         # 1_168 parameters
  BatchNorm(16, relu),                  # 32 parameters, plus 32
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)         # Total: 16 trainable arrays, 1_103_572 parameters,
          # plus 4 non-trainable, 48 parameters, summarysize 4.212 MiB.

If this flattening is not desired, it can be disabled
by passing the keyword argument `flatten=false` to the `LRP` constructor.

## LRP rules
By default, the `LRP` constructor will assign the `ZeroRule` to all layers.

In [6]:
LRP(model)

LRP(
  Conv((3, 3), 3 => 8, relu, pad=1)  => ZeroRule(),
  Conv((3, 3), 8 => 8, relu, pad=1)  => ZeroRule(),
  MaxPool((2, 2))                    => ZeroRule(),
  Conv((3, 3), 8 => 16, pad=1)       => ZeroRule(),
  BatchNorm(16, relu)                => ZeroRule(),
  Conv((3, 3), 16 => 8, relu, pad=1) => ZeroRule(),
  BatchNorm(8, relu)                 => ZeroRule(),
  Flux.flatten                       => ZeroRule(),
  Dense(2048 => 512, relu)           => ZeroRule(),
  Dropout(0.5)                       => ZeroRule(),
  Dense(512 => 100)                  => ZeroRule(),
)

This analyzer will return heatmaps that look identical to those of the `InputTimesGradient` analyzer
from [ExplainableAI.jl](https://github.com/Julia-XAI/ExplainableAI.jl).
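
For a small ReLU network, this correspondence can be verified by hand. The sketch below assumes a bias-free network with one hidden layer and writes out both the zero rule and the gradient explicitly; the weights and input are hypothetical, and no package code is involved.

```julia
# Hypothetical bias-free network: y = w2' * relu.(W1 * x)
W1 = [1.0 -2.0; 0.5 1.0; -1.0 1.0]
w2 = [1.0, -1.0, 2.0]
x  = [2.0, 0.5]

z = W1 * x          # hidden pre-activations
a = max.(z, 0)      # ReLU activations
y = w2' * a         # scalar output, also used as output relevance

# Zero rule: each input receives a share proportional to its contribution.
R_hidden = a .* w2 .* (y / y)             # output layer: relevance y, pre-activation y
R_input  = x .* (W1' * (R_hidden ./ z))   # hidden layer

# Input times gradient, with the gradient written out by hand:
# dy/dx = W1' * (w2 .* (z .> 0))
grad = W1' * (w2 .* (z .> 0))

x .* grad ≈ R_input   # true for this example
```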

LRP's strength lies in assigning different rules to different layers,
based on their functionality in the neural network[^1].
RelevancePropagation.jl implements many LRP rules out of the box,
but it is also possible to *implement custom rules*.

To assign different rules to different layers,
use one of the composite presets,
or create your own composite, as described in
*Assigning rules to layers*.

In [7]:
composite = EpsilonPlusFlat() # using composite preset EpsilonPlusFlat

Composite(
  GlobalTypeMap(  # all layers
    Flux.Conv               => ZPlusRule(),
    Flux.ConvTranspose      => ZPlusRule(),
    Flux.CrossCor           => ZPlusRule(),
    Flux.Dense              => EpsilonRule{Float32}(1.0f-6),
    typeof(NNlib.dropout)   => PassRule(),
    Flux.AlphaDropout       => PassRule(),
    Flux.Dropout            => PassRule(),
    Flux.BatchNorm          => PassRule(),
    typeof(Flux.flatten)    => PassRule(),
    typeof(MLUtils.flatten) => PassRule(),
    typeof(identity)        => PassRule(),
 ),
  FirstLayerTypeMap(  # first layer
    Flux.Conv          => FlatRule(),
    Flux.ConvTranspose => FlatRule(),
    Flux.CrossCor      => FlatRule(),
 ),
)

In [8]:
LRP(model, composite)

LRP(
  Conv((3, 3), 3 => 8, relu, pad=1)  => FlatRule(),
  Conv((3, 3), 8 => 8, relu, pad=1)  => ZPlusRule(),
  MaxPool((2, 2))                    => ZeroRule(),
  Conv((3, 3), 8 => 16, pad=1)       => ZPlusRule(),
  BatchNorm(16, relu)                => PassRule(),
  Conv((3, 3), 16 => 8, relu, pad=1) => ZPlusRule(),
  BatchNorm(8, relu)                 => PassRule(),
  Flux.flatten                       => PassRule(),
  Dense(2048 => 512, relu)           => EpsilonRule{Float32}(1.0f-6),
  Dropout(0.5)                       => PassRule(),
  Dense(512 => 100)                  => EpsilonRule{Float32}(1.0f-6),
)
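
A custom composite can be assembled from the same building blocks that appear in the printout of the preset. The following is a minimal sketch rather than a recommended configuration: the small model and the particular choice of rules are assumptions made for illustration.

```julia
using RelevancePropagation
using Flux

# Hypothetical custom composite, built from the primitives shown
# in the printout of `EpsilonPlusFlat`.
composite = Composite(
    GlobalTypeMap(                          # applied to all layers
        Conv                 => ZPlusRule(),
        Dense                => EpsilonRule(1.0f-6),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerTypeMap(Conv => FlatRule()),  # applied to the first layer only
)

# Small illustrative model for 8×8 inputs with 3 channels
model = Chain(
    Conv((3, 3), 3 => 4, relu; pad=1),
    Flux.flatten,
    Dense(4 * 8 * 8 => 10),
)

analyzer = LRP(model, composite)
```

Layer types that are not matched by any entry keep the default `ZeroRule`, as seen for `MaxPool` in the output above.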

## Computing layerwise relevances
If you are interested in computing layerwise relevances,
call `analyze` with an LRP analyzer and the keyword argument
`layerwise_relevances=true`.

The layerwise relevances can be accessed in the `extras` field
of the returned `Explanation`:

In [9]:
input = rand(Float32, 32, 32, 3, 1) # dummy input for our convolutional neural network

expl = analyze(input, analyzer; layerwise_relevances=true)
expl.extras.layerwise_relevances

(Float32[-0.045444563 -0.26116747 … -0.18418919 0.022211475; -0.02525588 0.07895095 … 0.15981244 -0.13078083; … ; -0.04687914 0.07910321 … 0.09199613 0.12717393; 0.03390475 -0.25860354 … -0.061920624 -0.019728703;;; 0.08199054 0.15446149 … -1.6924783 0.11057547; -0.37049145 -0.21730764 … -0.048435368 0.079642646; … ; 0.17469604 0.19805294 … 0.0055019576 0.12714502; -0.03183676 -0.12583175 … 0.19378994 -0.12801933;;; -0.0010350663 0.03946558 … -0.49190733 0.6596138; -0.043050125 -0.16689013 … 0.9517378 -0.2919031; … ; 0.013863304 0.021302825 … -0.16041997 -0.047749992; -0.024313372 0.16763961 … 0.16531514 0.018675799;;;;], Float32[0.0036868607 -0.0 … -0.0 0.0; 0.009831537 -0.0 … -0.0 -0.0; … ; 0.07072032 -0.0 … -0.0 -0.0; 0.0060353205 0.0 … 0.0 0.0;;; 0.00032768803 0.019801872 … -0.0 -0.0; 0.0 0.0 … -0.12721023 -0.0046421215; … ; 0.0 0.0 … -0.0 -0.0; 0.0 -0.0 … 0.0 0.0;;; -0.0061789453 0.0 … 0.22457685 0.0; 0.004490305 -0.0 … 0.077684134 -0.0; … ; 0.051793803 0.0415779 … -0.09648164 -0.

Note that the layerwise relevances are only kept for layers in the outermost `Chain` of the model.
When using our unflattened model, we only obtain three layerwise relevances,
one for each chain in the model and the output relevance:

In [10]:
analyzer = LRP(model; flatten=false) # use unflattened model

expl = analyze(input, analyzer; layerwise_relevances=true)
expl.extras.layerwise_relevances

(Float32[0.014908061 -0.010275087 … -0.114645995 0.0061288048; 0.013130265 -0.37091494 … -0.040975712 -0.16242087; … ; -0.25511387 0.26326725 … -0.57306737 -0.59616524; -0.031121291 -0.45590526 … -0.5059504 0.121867225;;; 0.11382516 -0.10646851 … -0.3485705 0.17627715; -0.10605026 -0.10684224 … 0.08777002 0.03606443; … ; 0.40929225 0.32206985 … -2.2140057 -0.103317104; -0.055774953 -0.37625435 … 0.51358277 -0.05808303;;; -6.3059333f-6 -0.5549344 … -0.32855216 0.26615795; -0.13330628 -0.1338728 … 0.6101553 0.20615782; … ; 0.04696441 0.15260619 … -1.0518061 0.07603169; -0.037169613 0.08803467 … -1.2858689 -0.42897636;;;;], Float32[-0.0 0.0 … 0.015541518 -0.0022511436; 0.0013909015 -0.0014181173 … 0.02218824 0.020792758; … ; 0.0011591922 0.00852769 … 0.015254115 0.018995356; 0.00028160607 0.012736286 … -0.0 0.004941677;;; 0.00057267403 -0.002083079 … 0.006641929 0.014116885; -0.0 -0.0033896693 … 0.00012088958 0.0002560582; … ; 0.0 0.0029619779 … -0.0059263497 -0.0010557839; -0.0 -0.0 … -0
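
Since the raw arrays are hard to read, it can help to look only at how many relevances are stored and at their sizes. The sketch below does this for a small hypothetical model with two nested chains, once with and once without flattening; the model and input are assumptions made for illustration.

```julia
using RelevancePropagation
using Flux

# Small hypothetical model with two nested chains
model = Chain(
    Chain(Dense(4 => 3, relu), Dense(3 => 3, relu)),
    Chain(Dense(3 => 2)),
)
input = rand(Float32, 4, 1)

for flatten in (true, false)
    analyzer = LRP(model; flatten=flatten)
    expl = analyze(input, analyzer; layerwise_relevances=true)
    rels = expl.extras.layerwise_relevances
    # Print the number of stored relevances and the size of each one
    println("flatten=$flatten: ", length(rels), " relevances with sizes ", map(size, rels))
end
```

Following the description above, the unflattened analyzer should store three relevances for this model, while the flattened one should store one per layer plus the output relevance.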

## Performance tips
### Using LRP with a GPU
All LRP analyzers support GPU backends,
building on top of [Flux.jl's GPU support](https://fluxml.ai/Flux.jl/stable/gpu/).
Using a GPU only requires moving the input array and model weights to the GPU.

For example, using [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl):

```julia
using CUDA, cuDNN
using Flux
using RelevancePropagation

# move input array and model weights to GPU
input = input |> gpu # or gpu(input)
model = model |> gpu # or gpu(model)

# analyzers don't require calling `gpu`
analyzer = LRP(model)

# explanations are computed on the GPU
expl = analyze(input, analyzer)
```

Some operations, like saving, require moving explanations back to the CPU.
This can be done using Flux's `cpu` function:

```julia
val = expl.val |> cpu # or cpu(expl.val)

using BSON
BSON.@save "explanation.bson" val
```

### Using LRP without a GPU
Using Julia's package extension mechanism,
RelevancePropagation.jl's LRP implementation can optionally make use of
[Tullio.jl](https://github.com/mcabbott/Tullio.jl) and
[LoopVectorization.jl](https://github.com/JuliaSIMD/LoopVectorization.jl)
for faster LRP rules on dense layers.

This only requires loading the packages before loading RelevancePropagation.jl:
```julia
using LoopVectorization, Tullio
using RelevancePropagation
```

[^1]: G. Montavon et al., [Layer-Wise Relevance Propagation: An Overview](https://link.springer.com/chapter/10.1007/978-3-030-28954-6_10)

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*