# Basic usage of LRP
This example will show you best practices for using LRP,
building on the basics shown in the *Getting started* section.

We start out by loading a small convolutional neural network:

In [1]:
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1),
        Conv((3, 3), 8 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16; pad=1),
        BatchNorm(16, relu),
        Conv((3, 3), 16 => 8, relu; pad=1),
        BatchNorm(8, relu),
    ),
    Chain(
        Flux.flatten,
        Dense(2048 => 512, relu),
        Dropout(0.5),
        Dense(512 => 100, softmax)
    ),
);

This model contains two chains: the convolutional layers and the fully connected layers.

## Model preparation
### Stripping the output softmax
When using LRP, it is recommended to explain output logits instead of probabilities.
This can be done by stripping the output softmax activation from the model
using the `strip_softmax` function:

In [2]:
model = strip_softmax(model)

Chain(
  Chain(
    Conv((3, 3), 3 => 8, relu, pad=1),  # 224 parameters
    Conv((3, 3), 8 => 8, relu, pad=1),  # 584 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 8 => 16, pad=1),       # 1_168 parameters
    BatchNorm(16, relu),                # 32 parameters, plus 32
    Conv((3, 3), 16 => 8, relu, pad=1),  # 1_160 parameters
    BatchNorm(8, relu),                 # 16 parameters, plus 16
  ),
  Chain(
    Flux.flatten,
    Dense(2048 => 512, relu),           # 1_049_088 parameters
    Dropout(0.5),
    Dense(512 => 100),                  # 51_300 parameters
  ),
)         # Total: 16 trainable arrays, 1_103_572 parameters,
          # plus 4 non-trainable, 48 parameters, summarysize 4.213 MiB.

If you don't remove the output softmax,
model checks will fail.
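
The effect of `strip_softmax` can be checked on a toy model. The sketch below uses a small, hypothetical two-layer network (not the model from this example) and verifies that re-applying `softmax` to the stripped model's logits reproduces the original probabilities:

```julia
using ExplainableAI
using Flux

# Hypothetical toy classifier, used only for illustration
toy = Chain(Dense(4 => 8, relu), Dense(8 => 3, softmax))
toy_logits = strip_softmax(toy)   # same layers, final softmax activation removed

x = rand(Float32, 4, 1)           # one sample with 4 features

probs  = toy(x)                   # probabilities from the original model
logits = toy_logits(x)            # unnormalized scores from the stripped model

# Applying softmax to the logits should recover the probabilities
softmax(logits) ≈ probs
```

Explaining `logits` rather than `probs` keeps the explained output independent of the normalization across classes.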

### Canonizing the model
LRP is not invariant to a model's implementation.
Applying the `GammaRule` to two linear layers in a row will yield different results
than first fusing the two layers into one linear layer and then applying the rule.
This fusing is called "canonization" and can be done using the `canonize` function:

In [3]:
model = canonize(model)

Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)         # Total: 14 trainable arrays, 1_103_540 parameters,
          # plus 2 non-trainable, 16 parameters, summarysize 4.212 MiB.
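
As a sanity check, canonization should leave the model's predictions unchanged, since fusing a `BatchNorm` layer into the preceding `Conv` layer only rewrites parameters. A minimal sketch, assuming a hypothetical two-layer toy model evaluated in test mode (so that `BatchNorm` uses its running statistics):

```julia
using ExplainableAI
using Flux

# Hypothetical toy model: Conv without activation, followed by BatchNorm
toy = Chain(Conv((3, 3), 3 => 4; pad=1), BatchNorm(4, relu))
Flux.testmode!(toy)            # use running statistics instead of batch statistics

toy_canonized = canonize(toy)  # BatchNorm is fused into the preceding Conv layer

x = rand(Float32, 8, 8, 3, 1)  # dummy input in WHCN format

# Outputs should agree up to numerical error
isapprox(toy(x), toy_canonized(x); atol=1f-4)
```

Note that in the output above, the final `BatchNorm(8, relu)` was not fused: the preceding `Conv` layer already applies `relu`, so the two layers cannot be merged into a single linear layer.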

### Flattening the model
ExplainableAI.jl's LRP implementation supports nested Flux Chains and Parallel layers.
However, it is recommended to flatten the model before analyzing it.

LRP is implemented by first running a forward pass through the model,
keeping track of the intermediate activations, followed by a backward pass
that computes the relevances.

To keep the LRP implementation simple and maintainable,
ExplainableAI.jl does not pre-compute "nested" activations.
Instead, for every internal chain, a new forward pass is run to compute activations.

By "flattening" a model, this overhead can be avoided.
For this purpose, ExplainableAI.jl provides the function `flatten_model`:

In [4]:
model_flat = flatten_model(model)

Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)         # Total: 14 trainable arrays, 1_103_540 parameters,
          # plus 2 non-trainable, 16 parameters, summarysize 4.212 MiB.

This function is called by default when creating an LRP analyzer.
Note that we pass the unflattened model to the analyzer, but `analyzer.model` is flattened:

In [5]:
analyzer = LRP(model)
analyzer.model

Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 8, relu, pad=1),   # 1_160 parameters
  BatchNorm(8, relu),                   # 16 parameters, plus 16
  Flux.flatten,
  Dense(2048 => 512, relu),             # 1_049_088 parameters
  Dropout(0.5),
  Dense(512 => 100),                    # 51_300 parameters
)         # Total: 14 trainable arrays, 1_103_540 parameters,
          # plus 2 non-trainable, 16 parameters, summarysize 4.212 MiB.

If this flattening is not desired, it can be disabled
by passing the keyword argument `flatten=false` to the `LRP` constructor.
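
Flattening only changes how the layers are organized, not what the model computes. The following sketch illustrates this on a hypothetical nested toy model (an assumption for illustration, not the model used above):

```julia
using ExplainableAI
using Flux

# Hypothetical nested toy model: two inner chains
nested = Chain(
    Chain(Dense(4 => 8, relu), Dense(8 => 8, relu)),
    Chain(Dense(8 => 3)),
)
flat = flatten_model(nested)

length(nested)  # 2: the two inner chains
length(flat)    # 3: the three Dense layers

x = rand(Float32, 4, 1)
flat(x) ≈ nested(x)  # flattening should not change the output
```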

## LRP rules
By default, the `LRP` constructor will assign the `ZeroRule` to all layers.

In [6]:
LRP(model)

LRP(
  Conv((3, 3), 3 => 8, relu, pad=1)  => ZeroRule(),
  Conv((3, 3), 8 => 8, relu, pad=1)  => ZeroRule(),
  MaxPool((2, 2))                    => ZeroRule(),
  Conv((3, 3), 8 => 16, relu, pad=1) => ZeroRule(),
  Conv((3, 3), 16 => 8, relu, pad=1) => ZeroRule(),
  BatchNorm(8, relu)                 => ZeroRule(),
  Flux.flatten                       => ZeroRule(),
  Dense(2048 => 512, relu)           => ZeroRule(),
  Dropout(0.5)                       => ZeroRule(),
  Dense(512 => 100)                  => ZeroRule(),
)

This analyzer will return heatmaps that look identical to `InputTimesGradient`.
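
This correspondence can be checked numerically. The sketch below is an illustration under assumptions: it uses a hypothetical ReLU-only toy model and assumes that the attribution is stored in the `val` field of the returned `Explanation`. Small deviations are expected, since LRP rules stabilize their denominators:

```julia
using ExplainableAI
using Flux

# Hypothetical toy model with ReLU activations and logit outputs
toy = Chain(Dense(4 => 8, relu), Dense(8 => 3))
x = rand(Float32, 4, 1)

expl_lrp = analyze(x, LRP(toy))                 # ZeroRule on all layers
expl_ixg = analyze(x, InputTimesGradient(toy))  # gradient-based baseline

# Largest absolute deviation between the two attributions
# (`val` is assumed to hold the attribution array)
maximum(abs, expl_lrp.val .- expl_ixg.val)
```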

LRP's strength lies in assigning different rules to different layers,
based on their functionality in the neural network[^1].
ExplainableAI.jl implements many LRP rules out of the box,
but it is also possible to *implement custom rules*.

To assign different rules to different layers,
use one of the composite presets,
or create your own composite, as described in
*Assigning rules to layers*.

In [7]:
composite = EpsilonPlusFlat() # using composite preset EpsilonPlusFlat

Composite(
  GlobalTypeMap(  # all layers
    Flux.Conv               => ZPlusRule(),
    Flux.ConvTranspose      => ZPlusRule(),
    Flux.CrossCor           => ZPlusRule(),
    Flux.Dense              => EpsilonRule{Float32}(1.0f-6),
    typeof(NNlib.dropout)   => PassRule(),
    Flux.AlphaDropout       => PassRule(),
    Flux.Dropout            => PassRule(),
    Flux.BatchNorm          => PassRule(),
    typeof(Flux.flatten)    => PassRule(),
    typeof(MLUtils.flatten) => PassRule(),
    typeof(identity)        => PassRule(),
 ),
  FirstLayerTypeMap(  # first layer
    Flux.Conv          => FlatRule(),
    Flux.ConvTranspose => FlatRule(),
    Flux.CrossCor      => FlatRule(),
    F

In [8]:
LRP(model, composite)

LRP(
  Conv((3, 3), 3 => 8, relu, pad=1)  => FlatRule(),
  Conv((3, 3), 8 => 8, relu, pad=1)  => ZPlusRule(),
  MaxPool((2, 2))                    => ZeroRule(),
  Conv((3, 3), 8 => 16, relu, pad=1) => ZPlusRule(),
  Conv((3, 3), 16 => 8, relu, pad=1) => ZPlusRule(),
  BatchNorm(8, relu)                 => PassRule(),
  Flux.flatten                       => PassRule(),
  Dense(2048 => 512, relu)           => EpsilonRule{Float32}(1.0f-6),
  Dropout(0.5)                       => PassRule(),
  Dense(512 => 100)                  => EpsilonRule{Float32}(1.0f-6),
)
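
A custom composite can be assembled from the same building blocks that appear in the printed preset above. The following is a minimal sketch, assuming a hypothetical toy model and an arbitrary choice of rules; it is meant to illustrate the structure of a composite, not to recommend a particular rule assignment:

```julia
using ExplainableAI
using Flux

# Hypothetical toy model for 8×8 RGB inputs
toy = Chain(
    Conv((3, 3), 3 => 4, relu; pad=1),
    Conv((3, 3), 4 => 4, relu; pad=1),
    Flux.flatten,
    Dense(256 => 10),   # 8 * 8 * 4 = 256 flattened features
)

composite = Composite(
    GlobalTypeMap(                              # applied to all layers
        Conv                 => ZPlusRule(),
        Dense                => EpsilonRule(),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerTypeMap(                          # overrides the rule on the first layer
        Conv => FlatRule(),
    ),
)

analyzer = LRP(toy, composite)
```

Layer types that are not matched by any map in the composite fall back to the default `ZeroRule`, as seen for `MaxPool` in the output above.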

## Computing layerwise relevances
If you are interested in computing layerwise relevances,
call `analyze` with an LRP analyzer and the keyword argument
`layerwise_relevances=true`.

The layerwise relevances can be accessed in the `extras` field
of the returned `Explanation`:

In [9]:
input = rand(Float32, 32, 32, 3, 1) # dummy input for our convolutional neural network

expl = analyze(input, analyzer; layerwise_relevances=true)
expl.extras.layerwise_relevances

([0.070911594 -0.005739877 … 0.020050077 0.013649693; 0.058823954 -0.021245474 … -0.029268607 0.019316878; … ; -0.0025407865 0.028541824 … -0.0029151158 -0.0011176834; -0.01845132 -0.0014127557 … 0.020842059 0.005983601;;; -0.009115747 0.102853075 … 0.01257232 0.048881076; -0.07704271 -0.042541314 … -0.06782364 0.07234024; … ; -0.002587716 -0.03424223 … 0.07197469 -0.0026967996; -0.012738505 0.00072014984 … 0.00542081 0.02169352;;; -0.013080878 0.0036326468 … -0.04036488 0.012548849; -0.007655174 0.027398698 … 0.028663527 0.005195543; … ; -0.010705031 -0.0046712295 … 0.0007241109 0.0013092858; -0.004629274 -0.022525892 … -0.011950095 0.029804965;;;;], [0.0 -0.0 … 0.0 -0.0; 0.0 -0.0 … 0.0 0.0; … ; -0.0 -0.0 … -0.0 -0.0; -0.0021022335 -0.0 … 0.04941562 0.0;;; 0.0 0.0050020204 … -0.0 0.0; 0.0 -0.0 … 0.0 -0.024198135; … ; -0.0 -0.0 … 0.0 -0.00372457; 0.0 -0.00019197342 … -0.0 -0.010762867;;; -0.0 -0.022589045 … 0.0 0.023636764; -0.0 0.0 … -0.0053951135 0.024860203; … ; -0.0 -0.004086365 … 

Note that the layerwise relevances are only kept for layers in the outermost `Chain` of the model.
When using our unflattened model, we only obtain three layerwise relevances,
one for each chain in the model and the output relevance:

In [10]:
analyzer = LRP(model; flatten=false) # use unflattened model

expl = analyze(input, analyzer; layerwise_relevances=true)
expl.extras.layerwise_relevances

([0.06675495 -0.004981801 … 0.011021007 0.04421971; 0.048869647 -0.028399747 … -0.027329648 0.01106981; … ; -0.000514915 0.023441497 … 0.0043694833 -0.0018634752; -0.015467346 -0.002745655 … 0.005718846 -0.00504807;;; -0.011888327 0.08724773 … 0.034354433 0.046655826; -0.06444999 -0.047307562 … -0.059081946 0.064309664; … ; -0.0035586779 -0.019582044 … 0.033041686 -0.0013432204; -0.01314467 -4.676871f-5 … -0.007872148 0.011304704;;; -0.0080184415 0.0034944855 … -0.013814093 -0.006194211; -0.00012937235 0.032900084 … 0.012610558 -0.0019307476; … ; -0.0052919756 -0.0015672878 … 0.0014218562 0.00025216877; -0.0031280252 -0.02015292 … -0.0060569393 0.0216214;;;;], [0.0 -0.0 … 0.0 -0.0; 0.0 -0.0 … -0.0 0.0; … ; -0.0 -0.0 … -0.0 -0.0; -0.0028410077 -0.0 … 0.023995642 0.0;;; 0.0 0.0055916235 … -0.0 -0.0; 0.0 -0.0 … 0.0 -0.019809151; … ; -0.0 -0.0 … 0.0 0.00026082754; 0.0 0.0010287425 … 0.0 -0.0064365435;;; -0.0 -0.017906604 … 0.0 0.005572118; -0.0 0.0 … -0.003467294 0.010574062; … ; -0.0 -0.0
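
Since the raw arrays are hard to read, it can help to summarize them. The sketch below uses a hypothetical nested toy model (an assumption for illustration) and compares the layerwise relevances of a flattened and an unflattened analyzer:

```julia
using ExplainableAI
using Flux

# Hypothetical nested toy model: two inner chains, three layers in total
toy = Chain(
    Chain(Dense(4 => 8, relu), Dense(8 => 8, relu)),
    Chain(Dense(8 => 3)),
)
x = rand(Float32, 4, 1)

# Default: the model is flattened, so relevances are kept for every layer
R_flat = analyze(x, LRP(toy); layerwise_relevances=true).extras.layerwise_relevances

# Unflattened: relevances are only kept for the outermost chain
R_nested = analyze(x, LRP(toy; flatten=false); layerwise_relevances=true).extras.layerwise_relevances

length(R_flat)     # one entry per layer, plus the output relevance
length(R_nested)   # one entry per inner chain, plus the output relevance
map(size, R_flat)  # shape of each relevance array
map(sum, R_flat)   # total relevance at each stage
```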

## Performance tips
### Using LRP with a GPU
Like all other analyzers, LRP can be used on GPUs.
Follow the instructions on *GPU support*.

### Using LRP without a GPU
Using Julia's package extension mechanism,
ExplainableAI.jl's LRP implementation can optionally make use of
[Tullio.jl](https://github.com/mcabbott/Tullio.jl) and
[LoopVectorization.jl](https://github.com/JuliaSIMD/LoopVectorization.jl)
for faster LRP rules on dense layers.

This only requires loading the packages before loading ExplainableAI.jl:
```julia
using LoopVectorization, Tullio
using ExplainableAI
```

[^1]: G. Montavon et al., [Layer-Wise Relevance Propagation: An Overview](https://link.springer.com/chapter/10.1007/978-3-030-28954-6_10)

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*