Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested indexing in composite primitive LayerMap #131

Merged
merged 9 commits into from
Aug 24, 2023

Conversation

adrhill
Copy link
Member

@adrhill adrhill commented Aug 24, 2023

This allows users to assign rules to specific layers in nested Chains of Chains, as introduced in #119.
For this purpose, the helper function show_layer_indices is introduced.

Closes #121.

Example

julia> using ExplainableAI, Metalhead

julia> vgg11 = VGG(11).layers
Chain(
  Chain(
    Conv((3, 3), 3 => 64, relu, pad=1),  # 1_792 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 64 => 128, relu, pad=1),  # 73_856 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 128 => 256, relu, pad=1),  # 295_168 parameters
    Conv((3, 3), 256 => 256, relu, pad=1),  # 590_080 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 256 => 512, relu, pad=1),  # 1_180_160 parameters
    Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
    MaxPool((2, 2)),
    Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
    Conv((3, 3), 512 => 512, relu, pad=1),  # 2_359_808 parameters
    MaxPool((2, 2)),
  ),
  Chain(
    MLUtils.flatten,
    Dense(25088 => 4096, relu),         # 102_764_544 parameters
    Dropout(0.5),
    Dense(4096 => 4096, relu),          # 16_781_312 parameters
    Dropout(0.5),
    Dense(4096 => 1000),                # 4_097_000 parameters
  ),
)                   # Total: 22 arrays, 132_863_336 parameters, 506.839 MiB.

julia> show_layer_indices(vgg11)
ChainTuple(
  ChainTuple(
    (1, 1),
    (1, 2),
    (1, 3),
    (1, 4),
    (1, 5),
    (1, 6),
    (1, 7),
    (1, 8),
    (1, 9),
    (1, 10),
    (1, 11),
    (1, 12),
    (1, 13),
  ),
  ChainTuple(
    (2, 1),
    (2, 2),
    (2, 3),
    (2, 4),
    (2, 5),
    (2, 6),
  ),
)


julia> c = Composite(LayerMap((2, 4), EpsilonRule()))
Composite(
  LayerMap: layer (2, 4) => EpsilonRule{Float32}(1.0f-6),
)


julia> LRP(vgg11, c, flatten=false)
LRP(
  ChainTuple(
    Conv((3, 3), 3 => 64, relu, pad=1)    => ZeroRule(),
    MaxPool((2, 2))                       => ZeroRule(),
    Conv((3, 3), 64 => 128, relu, pad=1)  => ZeroRule(),
    MaxPool((2, 2))                       => ZeroRule(),
    Conv((3, 3), 128 => 256, relu, pad=1) => ZeroRule(),
    Conv((3, 3), 256 => 256, relu, pad=1) => ZeroRule(),
    MaxPool((2, 2))                       => ZeroRule(),
    Conv((3, 3), 256 => 512, relu, pad=1) => ZeroRule(),
    Conv((3, 3), 512 => 512, relu, pad=1) => ZeroRule(),
    MaxPool((2, 2))                       => ZeroRule(),
    Conv((3, 3), 512 => 512, relu, pad=1) => ZeroRule(),
    Conv((3, 3), 512 => 512, relu, pad=1) => ZeroRule(),
    MaxPool((2, 2))                       => ZeroRule(),
  ),
  ChainTuple(
    MLUtils.flatten            => ZeroRule(),
    Dense(25088 => 4096, relu) => ZeroRule(),
    Dropout(0.5)               => ZeroRule(),
    Dense(4096 => 4096, relu)  => EpsilonRule{Float32}(1.0f-6),
    Dropout(0.5)               => ZeroRule(),
    Dense(4096 => 1000)        => ZeroRule(),
  ),
)

@codecov
Copy link

codecov bot commented Aug 24, 2023

Codecov Report

Merging #131 (d11b485) into master (a31ea9b) will increase coverage by 1.31%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #131      +/-   ##
==========================================
+ Coverage   92.59%   93.91%   +1.31%     
==========================================
  Files          18       18              
  Lines         648      657       +9     
==========================================
+ Hits          600      617      +17     
+ Misses         48       40       -8     
Files Changed Coverage Δ
src/ExplainableAI.jl 100.00% <ø> (ø)
src/flux_chain_utils.jl 92.13% <100.00%> (+9.18%) ⬆️
src/lrp/composite.jl 100.00% <100.00%> (ø)
src/lrp/show.jl 97.80% <100.00%> (ø)

@adrhill adrhill merged commit 15a7a9d into master Aug 24, 2023
7 checks passed
@adrhill adrhill deleted the ah/positional-primitive branch August 24, 2023 14:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add composite primitive to assign rule at specific position in model
1 participant