Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Parallel layers in LRP #135

Merged
merged 5 commits into from
Aug 29, 2023
Merged

Support Parallel layers in LRP #135

merged 5 commits into from
Aug 29, 2023

Conversation

adrhill
Copy link
Member

@adrhill adrhill commented Aug 29, 2023

Closes #10, building on top of infrastructure introduced in #119. 🎉

Examples

Using Metalhead's pre-trained ResNet(152) model without canonization on the Readme example.
Input for reference:

input-image

Heapmaps for class "castle"

Using EpsilonPlusFlat composite:
resnet152_castle_EpsilonPlusFlat

Using EpsilonPlus composite:
resnet152_castle_EpsilonPlus

Heapmaps for class "street sign"

Using EpsilonPlusFlat composite:
resnet152_stsign_EpsilonPlusFlat

Using EpsilonPlus composite:
resnet152_stsign_EpsilonPlus

Code used to generate heatmaps:

using ExplainableAI
using Flux
using Metalhead                         # pre-trained vision models
using HTTP, FileIO, ImageMagick         # load image from URL
using ImageInTerminal

# Load input
url = HTTP.URI("https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
img = load(url)
input = preprocess_imagenet(img)
input = reshape(input, 224, 224, 3, :)  # reshape to WHCN format

# Define model sizes
resnet(size) = ResNet(size, pretrain=true).layers 
resnet_sizes = (18, 50, 152)

# Load composites
composites = Dict(
    "EpsilonPlus"            => EpsilonPlus(),  
    "EpsilonPlusFlat"        => EpsilonPlusFlat(), 
    "EpsilonAlpha2Beta1"     => EpsilonAlpha2Beta1(), 
    "EpsilonAlpha2Beta1Flat" => EpsilonAlpha2Beta1Flat(),
)


# Generate heatmaps for classes castle and street sign
for s in resnet_sizes
    model = resnet(s)
    for (cname, c) in composites
        @info "Analyzing ResNet$s on composite $cname" 
        analyzer = LRP(model, c)
        h_castle = heatmap(input, analyzer)
        h_stsign = heatmap(input, analyzer, 920)
        
        display(h_castle)
        display(h_stsign)
        save("resnet$(s)_$(cname)_castle.png", h_castle)
        save("resnet$(s)_$(cname)_stsign.png", h_stsign)
    end
end

@adrhill adrhill merged commit aecad85 into master Aug 29, 2023
5 checks passed
@adrhill adrhill deleted the ah/parallel2 branch August 29, 2023 16:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add LRP support for Parallel layer
1 participant