# Assigning LRP rules to layers
In this example, we will show how to assign LRP rules to specific layers.
For this purpose, we first define a small VGG-like convolutional neural network:

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1),
        Conv((3, 3), 8 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1),
        Conv((3, 3), 16 => 16, relu; pad=1),
        MaxPool((2, 2)),
    ),
    Chain(
        Flux.flatten,
        Dense(1024 => 512, relu),
        Dropout(0.5),
        Dense(512 => 100, relu)
    ),
);
```
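The model does not state its input size. The first `Dense` layer determines it, because that layer expects 1024 features. Flux uses the WHCN layout (width, height, channels, batch). The padded 3×3 convolutions keep the spatial size, and each `MaxPool((2, 2))` halves it. A 32×32 RGB image therefore ends up as an 8×8×16 feature map, which gives 8 · 8 · 16 = 1024 features. The following sketch checks this on a random input. The 32×32 input size is our own assumption and is not part of the original example.

```julia
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)

# Assumption: 32×32 RGB inputs in WHCN layout, inferred from Dense(1024 => 512).
x = rand(Float32, 32, 32, 3, 1)

# Convolutional part: 32 → 16 → 8 in each spatial dimension, 16 channels.
println(size(model[1](x)))  # (8, 8, 16, 1)

# Full model: 100 outputs per sample.
println(size(model(x)))     # (100, 1)
```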

## Manually assigning rules
When creating an LRP analyzer, we can assign individual rules to each layer.
As we can see above, our model is a `Chain` of two Flux `Chain`s.
Using `flatten_model`, we can flatten the model into a single `Chain`:

```julia
model_flat = flatten_model(model)
```

```
Chain(
  Conv((3, 3), 3 => 8, relu, pad=1),    # 224 parameters
  Conv((3, 3), 8 => 8, relu, pad=1),    # 584 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu, pad=1),   # 1_168 parameters
  Conv((3, 3), 16 => 16, relu, pad=1),  # 2_320 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(1024 => 512, relu),             # 524_800 parameters
  Dropout(0.5),
  Dense(512 => 100, relu),              # 51_300 parameters
)                   # Total: 12 arrays, 580_396 parameters, 2.216 MiB.
```
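Flattening changes only how the layers are nested. It does not change what the model computes. The following sketch is a quick sanity check. It assumes the same ExplainableAI.jl version as this example and random 32×32 RGB inputs, which is the size implied by `Dense(1024 => 512)`. `Dropout` is inactive outside of training, so both forward passes are deterministic and can be compared exactly.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)
model_flat = flatten_model(model)

# Hypothetical input batch: four random 32×32 RGB images in WHCN layout.
x = rand(Float32, 32, 32, 3, 4)

println(length(model_flat))         # 10: one entry per layer, hence ten rules below
println(model_flat(x) == model(x))  # true: the same layers are applied in the same order
```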

This allows us to define an LRP analyzer using a vector of rules that contains
one rule per layer of the flattened chain. Here, that means ten rules for ten layers:

```julia
rules = [
    FlatRule(),
    ZPlusRule(),
    ZeroRule(),
    ZPlusRule(),
    ZPlusRule(),
    ZeroRule(),
    PassRule(),
    EpsilonRule(),
    PassRule(),
    EpsilonRule(),
];
```

The `LRP` analyzer shows a summary of how layers and rules were matched:

```julia
LRP(model_flat, rules)
```

```
LRP(
  Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
  Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
  MaxPool((2, 2))                     => ZeroRule(),
  Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
  Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
  MaxPool((2, 2))                     => ZeroRule(),
  Flux.flatten                        => PassRule(),
  Dense(1024 => 512, relu)            => EpsilonRule{Float32}(1.0f-6),
  Dropout(0.5)                        => PassRule(),
  Dense(512 => 100, relu)             => EpsilonRule{Float32}(1.0f-6),
)
```
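Writing the rule vector by hand becomes tedious for deeper models. As an alternative, the same ten rules can be derived from the layer types of the flattened chain. The helper `rule_for` below is our own illustration and is not part of ExplainableAI.jl. The `Composite` mechanism described further down is the library's built-in way to do this. The sketch assumes the same ExplainableAI.jl version as this example.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)
model_flat = flatten_model(model)

# Hypothetical helper (not part of ExplainableAI.jl): choose a rule by layer type.
rule_for(layer::Conv) = ZPlusRule()
rule_for(layer::Dense) = EpsilonRule()
rule_for(layer::MaxPool) = ZeroRule()
rule_for(layer) = PassRule()  # fallback, e.g. for Flux.flatten and Dropout

# One rule per layer of the flattened chain, so the lengths match by construction.
rules = [rule_for(layer) for layer in model_flat.layers]
rules[1] = FlatRule()  # override the rule of the first layer

analyzer = LRP(model_flat, rules)
```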

However, this approach only works for models that can be fully flattened.
For unflattened models and models containing `Parallel` layers, we can instead compose rules using
`ChainTuple`s and `ParallelTuple`s that mirror the model structure, and pass `flatten=false` to `LRP`:

```julia
rules = ChainTuple(
    ChainTuple(
        FlatRule(),
        ZPlusRule(),
        ZeroRule(),
        ZPlusRule(),
        ZPlusRule(),
        ZeroRule()
    ),
    ChainTuple(
        PassRule(),
        EpsilonRule(),
        PassRule(),
        EpsilonRule(),
    ),
)

analyzer = LRP(model, rules; flatten=false)
```

```
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => PassRule(),
    Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)             => PassRule(),
    Dense(512 => 100, relu)  => EpsilonRule{Float32}(1.0f-6),
  ),
)
```
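After the rules are assigned, the analyzer can be applied to inputs. The following minimal sketch assumes the same ExplainableAI.jl version as this example and uses a random 32×32 RGB input. The explanation of a random input is meaningless, so the sketch only shows the mechanics. `analyze` returns an `Explanation` whose `val` field holds the relevance, and the relevance has the same shape as the input.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)

rules = ChainTuple(
    ChainTuple(FlatRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), ZPlusRule(), ZeroRule()),
    ChainTuple(PassRule(), EpsilonRule(), PassRule(), EpsilonRule()),
)
analyzer = LRP(model, rules; flatten=false)

# Placeholder input: one random 32×32 RGB image in WHCN layout.
x = rand(Float32, 32, 32, 3, 1)

# By default, the output neuron with the highest activation is explained.
expl = analyze(x, analyzer)

# One relevance value per input entry.
println(size(expl.val))  # (32, 32, 3, 1)
```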

## Custom composites
Instead of manually defining a list of rules, we can also define a `Composite`.
A composite constructs a list of LRP rules by sequentially applying the
composite primitives it contains.

To obtain the same set of rules as in the previous example, we can define the following composite:

```julia
composite = Composite(
    GlobalTypeMap( # the following maps of layer types to LRP rules are applied globally
        Conv                 => ZPlusRule(),   # apply ZPlusRule on all Conv layers
        Dense                => EpsilonRule(), # apply EpsilonRule on all Dense layers
        Dropout              => PassRule(),    # apply PassRule on all Dropout layers
        MaxPool              => ZeroRule(),    # apply ZeroRule on all MaxPool layers
        typeof(Flux.flatten) => PassRule(),    # apply PassRule on all flatten layers
    ),
    FirstLayerMap( # the following rule is applied to the first layer
        FlatRule()
    ),
);
```

We now construct an LRP analyzer from `composite`:

```julia
analyzer = LRP(model, composite; flatten=false)
```

```
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => PassRule(),
    Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)             => PassRule(),
    Dense(512 => 100, relu)  => EpsilonRule{Float32}(1.0f-6),
  ),
)
```

As you can see, this analyzer contains the same rules as our previous one.
To compute rules for a model without creating an analyzer, use `lrp_rules`:

```julia
lrp_rules(model, composite)
```

```
ChainTuple(
  ChainTuple(
    FlatRule(),
    ZPlusRule(),
    ZeroRule(),
    ZPlusRule(),
    ZPlusRule(),
    ZeroRule(),
  ),
  ChainTuple(
    PassRule(),
    EpsilonRule{Float32}(1.0f-6),
    PassRule(),
    EpsilonRule{Float32}(1.0f-6),
  ),
)
```
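`lrp_rules` returns the same kind of nested `ChainTuple` that we wrote by hand in the section on manually assigning rules, so the two can be compared directly. The following check assumes the same ExplainableAI.jl version as this example. It relies on Julia's default `==` for immutable structs, which compares their contents.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)

# Rules written by hand, mirroring the nested model structure.
manual_rules = ChainTuple(
    ChainTuple(FlatRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), ZPlusRule(), ZeroRule()),
    ChainTuple(PassRule(), EpsilonRule(), PassRule(), EpsilonRule()),
)

# The same rules expressed as a composite.
composite = Composite(
    GlobalTypeMap(
        Conv                 => ZPlusRule(),
        Dense                => EpsilonRule(),
        Dropout              => PassRule(),
        MaxPool              => ZeroRule(),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerMap(FlatRule()),
)

println(lrp_rules(model, composite) == manual_rules)  # true
```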


## Composite primitives
The following [composite primitives](@ref api-composite-primitives) can be used to construct a `Composite`.

To apply a single rule, use:
* `LayerMap` to apply a rule to a layer at a given index
* `GlobalMap` to apply a rule to all layers
* `RangeMap` to apply a rule to a positional range of layers
* `FirstLayerMap` to apply a rule to the first layer
* `LastLayerMap` to apply a rule to the last layer

To apply a set of rules to layers based on their type, use:
* `GlobalTypeMap` to apply a dictionary that maps layer types to LRP-rules
* `RangeTypeMap` for a `TypeMap` on generalized ranges
* `FirstLayerTypeMap` for a `TypeMap` on the first layer of a model
* `LastLayerTypeMap` for a `TypeMap` on the last layer
* `FirstNTypeMap` for a `TypeMap` on the first `n` layers

Primitives are applied sequentially, in the order in which they were passed to the `Composite`,
and each primitive overwrites rules assigned by previous primitives.
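Because later primitives overwrite earlier ones, the order of the primitives matters. The following sketch builds two composites from the same primitives in opposite order and compares the resulting rules with `lrp_rules`. It assumes the same ExplainableAI.jl version as this example.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)

type_map = GlobalTypeMap(
    Conv                 => ZPlusRule(),
    Dense                => EpsilonRule(),
    Dropout              => PassRule(),
    MaxPool              => ZeroRule(),
    typeof(Flux.flatten) => PassRule(),
)

# FirstLayerMap is applied last, so the first Conv layer gets FlatRule.
rules_a = lrp_rules(model, Composite(type_map, FirstLayerMap(FlatRule())))

# GlobalTypeMap is applied last, so it overwrites FlatRule with ZPlusRule again.
rules_b = lrp_rules(model, Composite(FirstLayerMap(FlatRule()), type_map))

expected_b = ChainTuple(
    ChainTuple(ZPlusRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), ZPlusRule(), ZeroRule()),
    ChainTuple(PassRule(), EpsilonRule(), PassRule(), EpsilonRule()),
)

println(rules_a == rules_b)     # false
println(rules_b == expected_b)  # true
```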

## Assigning a rule to a specific layer
To assign a rule to a specific layer, we can use `LayerMap`,
which assigns an LRP rule to the layer at the given index in the model.

To display indices, use the `show_layer_indices` helper function:

```julia
show_layer_indices(model)
```

```
ChainTuple(
  ChainTuple(
    (1, 1),
    (1, 2),
    (1, 3),
    (1, 4),
    (1, 5),
    (1, 6),
  ),
  ChainTuple(
    (2, 1),
    (2, 2),
    (2, 3),
    (2, 4),
  ),
)
```


Let's demonstrate `LayerMap` by assigning a specific rule to the last `Conv` layer
at index `(1, 5)`:

```julia
composite = Composite(LayerMap((1, 5), EpsilonRule()))

LRP(model, composite; flatten=false)
```

```
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => ZeroRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZeroRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZeroRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => EpsilonRule{Float32}(1.0f-6),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => ZeroRule(),
    Dense(1024 => 512, relu) => ZeroRule(),
    Dropout(0.5)             => ZeroRule(),
    Dense(512 => 100, relu)  => ZeroRule(),
  ),
)
```
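`LayerMap` can also be combined with type-based primitives. Later primitives overwrite earlier ones, so we can append a `LayerMap` to the composite from the "Custom composites" section. This keeps all other rules and switches only layer `(1, 5)` to `EpsilonRule`. The following sketch writes out the expected rules by hand and assumes the same ExplainableAI.jl version as this example.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)

composite = Composite(
    GlobalTypeMap(
        Conv                 => ZPlusRule(),
        Dense                => EpsilonRule(),
        Dropout              => PassRule(),
        MaxPool              => ZeroRule(),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerMap(FlatRule()),
    LayerMap((1, 5), EpsilonRule()),  # applied last, so it decides the rule of layer (1, 5)
)

expected = ChainTuple(
    ChainTuple(FlatRule(), ZPlusRule(), ZeroRule(), ZPlusRule(), EpsilonRule(), ZeroRule()),
    ChainTuple(PassRule(), EpsilonRule(), PassRule(), EpsilonRule()),
)

println(lrp_rules(model, composite) == expected)  # true
```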

This approach also works with `Parallel` layers.
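As an illustration, consider a small, made-up model with a `Parallel` layer that sums two `Dense` branches. This sketch makes two assumptions. First, LRP is assumed to support `Parallel` layers with a `+` connection. Second, the index `(2, 2)` is assumed to address the second branch of the `Parallel` layer at position 2 of the outer chain. Check `show_layer_indices` on your own model before relying on an index.

```julia
using ExplainableAI
using Flux

# Hypothetical toy model: the outputs of the two Parallel branches are summed with `+`.
model = Chain(
    Dense(4 => 8, relu),
    Parallel(+, Dense(8 => 8, relu), Dense(8 => 8, relu)),
    Dense(8 => 2),
)

# Assumption: (2, 2) is the second branch inside the Parallel layer.
composite = Composite(
    GlobalTypeMap(Dense => EpsilonRule()),
    LayerMap((2, 2), ZPlusRule()),
)
rules = lrp_rules(model, composite)

# Rules for the Parallel branches are grouped in a ParallelTuple.
expected = ChainTuple(
    EpsilonRule(),
    ParallelTuple(EpsilonRule(), ZPlusRule()),
    EpsilonRule(),
)
println(rules == expected)

analyzer = LRP(model, composite; flatten=false)
```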

## Composite presets
ExplainableAI.jl provides a set of default composites.
A list of all implemented default composites can be found
in the API reference.
One example is the `EpsilonPlusFlat` composite:

```julia
composite = EpsilonPlusFlat()
```

```
Composite(
  GlobalTypeMap(  # all layers
    Flux.Conv               => ZPlusRule(),
    Flux.ConvTranspose      => ZPlusRule(),
    Flux.CrossCor           => ZPlusRule(),
    Flux.Dense              => EpsilonRule{Float32}(1.0f-6),
    typeof(NNlib.dropout)   => PassRule(),
    Flux.AlphaDropout       => PassRule(),
    Flux.Dropout            => PassRule(),
    Flux.BatchNorm          => PassRule(),
    typeof(Flux.flatten)    => PassRule(),
    typeof(MLUtils.flatten) => PassRule(),
    typeof(identity)        => PassRule(),
  ),
  FirstLayerTypeMap(  # first layer
    Flux.Conv          => FlatRule(),
    Flux.ConvTranspose => FlatRule(),
    Flux.CrossCor      => FlatRule(),
    …
```

```julia
analyzer = LRP(model, composite; flatten=false)
```

```
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 8, relu, pad=1)   => FlatRule(),
    Conv((3, 3), 8 => 8, relu, pad=1)   => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
    Conv((3, 3), 8 => 16, relu, pad=1)  => ZPlusRule(),
    Conv((3, 3), 16 => 16, relu, pad=1) => ZPlusRule(),
    MaxPool((2, 2))                     => ZeroRule(),
  ),
  ChainTuple(
    Flux.flatten             => PassRule(),
    Dense(1024 => 512, relu) => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)             => PassRule(),
    Dense(512 => 100, relu)  => EpsilonRule{Float32}(1.0f-6),
  ),
)
```
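For this particular model, the printed summaries show that `EpsilonPlusFlat` assigns exactly the same rules as the custom composite defined earlier. The following short check uses `lrp_rules` and assumes the same ExplainableAI.jl version as this example.

```julia
using ExplainableAI
using Flux

model = Chain(
    Chain(
        Conv((3, 3), 3 => 8, relu; pad=1), Conv((3, 3), 8 => 8, relu; pad=1), MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1), Conv((3, 3), 16 => 16, relu; pad=1), MaxPool((2, 2)),
    ),
    Chain(Flux.flatten, Dense(1024 => 512, relu), Dropout(0.5), Dense(512 => 100, relu)),
)

# The custom composite from the "Custom composites" section.
custom = Composite(
    GlobalTypeMap(
        Conv                 => ZPlusRule(),
        Dense                => EpsilonRule(),
        Dropout              => PassRule(),
        MaxPool              => ZeroRule(),
        typeof(Flux.flatten) => PassRule(),
    ),
    FirstLayerMap(FlatRule()),
)

println(lrp_rules(model, EpsilonPlusFlat()) == lrp_rules(model, custom))  # true
```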

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*