-
Notifications
You must be signed in to change notification settings - Fork 2
/
heatmap.jl
108 lines (99 loc) · 4.47 KB
/
heatmap.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
# Default heatmapping settings per analyzer, used when `heatmap` is called on an `Explanation`.
# Each entry maps an analyzer name to a tuple `(colorscheme, reduce, rangescale)`.
const HEATMAPPING_PRESETS = let
    # Attribution methods produce signed relevances: sum over channels, center the color range.
    attribution = (ColorSchemes.seismic, :sum, :centered)
    # Gradients are shown by channel-wise 2-norm, scaled to their extrema.
    gradient = (ColorSchemes.grays, :norm, :extrema)
    Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
        :LRP => attribution,
        :InputTimesGradient => attribution,
        :Gradient => gradient,
    )
end
"""
    heatmap(explanation)
    heatmap(input, analyzer)
    heatmap(input, analyzer, neuron_selection)

Visualize an explanation as a heatmap image.
Assumes Flux's WHCN convention (width, height, color channels, batch size).

When called with a batch, one image is produced per sample and a `Vector` of images is
returned. If the batch holds a single sample and `unpack_singleton=true`, the image itself
is returned instead.

See also [`analyze`](@ref).

## Keyword arguments
- `cs::ColorScheme`: color scheme from ColorSchemes.jl that is applied.
    When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
    When calling `heatmap` with an array, the default is `ColorSchemes.seismic`.
- `reduce::Symbol`: selects how color channels are reduced to a single number to apply a color scheme.
    The following methods can be selected, which are then applied over the color channels
    for each "pixel" in the explanation:
    - `:sum`: sum up color channels
    - `:norm`: compute 2-norm over the color channels
    - `:maxabs`: compute `maximum(abs, x)` over the color channels
    If the explanation has a single color channel, no reduction is applied and the values
    are used as-is.
    When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
    When calling `heatmap` with an array, the default is `:sum`.
- `rangescale::Symbol`: selects how the color channel reduced heatmap is normalized
    before the color scheme is applied. Can be either `:extrema` or `:centered`.
    When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
    When calling `heatmap` with an array, the default for use with the `seismic` color scheme is `:centered`.
- `permute::Bool`: Whether to swap the width and height dimensions (via `permutedims`)
    before applying the color scheme. Default is `true`.
- `unpack_singleton::Bool`: When heatmapping a batch with a single sample, setting `unpack_singleton=true`
    will return an image instead of a `Vector` containing a single image. Default is `true`.

**Note:** keyword arguments can't be used when calling `heatmap` with an analyzer.

## Throws
- `DomainError` if the input array is not 4-dimensional (WHCN).
"""
function heatmap(
    attr::AbstractArray{T,N};
    cs::ColorScheme=ColorSchemes.seismic,
    reduce::Symbol=:sum,
    rangescale::Symbol=:centered,
    permute::Bool=true,
    unpack_singleton::Bool=true,
) where {T,N}
    # Only WHCN (4-dimensional) input is supported.
    N != 4 && throw(
        DomainError(
            N,
            """heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
            Please reshape your explanation to match this format if your model doesn't adhere to this convention.""",
        ),
    )
    # A batch of one sample is unpacked to a single image on request.
    if unpack_singleton && size(attr, 4) == 1
        return _heatmap(attr[:, :, :, 1], cs, reduce, rangescale, permute)
    end
    # Otherwise, produce one image per sample along the batch dimension.
    return map(a -> _heatmap(a, cs, reduce, rangescale, permute), eachslice(attr; dims=4))
end
# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation.
# The analyzer's preset (looked up by `expl.analyzer`; a missing entry raises a KeyError)
# supplies defaults for `cs`, `reduce` and `rangescale`. Any of these passed explicitly as
# keywords take precedence. `permute` and `unpack_singleton` are forwarded to the array
# method. Other keywords are accepted but ignored.
function heatmap(
    expl::Explanation; permute::Bool=true, unpack_singleton::Bool=true, kwargs...
)
    # Named `preset_*` so the locals don't shadow the module-level `_reduce` helper.
    preset_cs, preset_reduce, preset_rangescale = HEATMAPPING_PRESETS[expl.analyzer]
    return heatmap(
        expl.val;
        reduce=get(kwargs, :reduce, preset_reduce),
        rangescale=get(kwargs, :rangescale, preset_rangescale),
        cs=get(kwargs, :cs, preset_cs),
        permute=permute,
        unpack_singleton=unpack_singleton,
    )
end
# Analyze & heatmap in one go: all trailing positional and keyword arguments are passed to
# `analyze`. The resulting explanation is then rendered with its analyzer's default
# heatmapping settings.
function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
    expl = analyze(input, analyzer, args...; kwargs...)
    return heatmap(expl)
end
# Lower level function that is mapped along the batch dimension.
# Takes a single W×H×C sample, reduces its color channels to one value per pixel,
# optionally swaps the W and H dimensions, and maps the result through the color scheme.
function _heatmap(
    attr::AbstractArray{T,3},
    cs::ColorScheme,
    reduce::Symbol,
    rangescale::Symbol,
    permute::Bool,
) where {T<:Real}
    channel_reduced = _reduce(attr, reduce)
    flat = dropdims(channel_reduced; dims=3)
    img = permute ? permutedims(flat) : flat
    return ColorSchemes.get(cs, img, rangescale)
end
# Reduce explanations across color channels into a single scalar – assumes WHCN convention.
# `attr` is a single W×H×C sample. The result keeps a singleton third dimension.
# Supported methods: `:sum`, `:maxabs` (maximum absolute value) and `:norm` (2-norm).
# When there is only one color channel, `attr` is returned unchanged.
# Throws an `ArgumentError` for any other `method`.
function _reduce(attr::AbstractArray{T,3}, method::Symbol) where {T}
    # Validate first, so an unsupported method is rejected even for single-channel input,
    # where no reduction would otherwise happen.
    method in (:sum, :maxabs, :norm) || throw(
        ArgumentError(
            "Color channel reducer :$method not supported, `reduce` should be :maxabs, :sum or :norm",
        ),
    )
    size(attr, 3) == 1 && return attr # nothing to reduce
    if method == :sum
        return reduce(+, attr; dims=3)
    elseif method == :maxabs
        # `init` keeps an empty channel dimension from erroring (yields zeros).
        return mapreduce(abs, max, attr; dims=3, init=zero(T))
    else # method == :norm
        return sqrt.(sum(abs2, attr; dims=3))
    end
end