Skip to content

Commit

Permalink
Add XAIBase package extension (#4)
Browse files Browse the repository at this point in the history
Continues work started in Julia-XAI/XAIBase.jl#16 and Julia-XAI/VisionHeatmaps.jl#7 by moving `heatmap` methods on `Explanation` type to TextHeatmaps.jl via package extensions on XAIBase.
  • Loading branch information
adrhill committed Feb 19, 2024
1 parent 5b982d9 commit 41874be
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 45 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -33,7 +34,9 @@ jobs:
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
with:
Expand Down
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
name = "TextHeatmaps"
uuid = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "1.1.0"
version = "1.2.0-DEV"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[weakdeps]
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[extensions]
TextHeatmapsXAIBaseExt = "XAIBase"

[compat]
ColorSchemes = "3"
Colors = "0.12"
Crayons = "4"
FixedPointNumbers = "0.8"
Requires = "1"
XAIBase = "3"
julia = "1.6"
79 changes: 79 additions & 0 deletions ext/TextHeatmapsXAIBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
module TextHeatmapsXAIBaseExt

using TextHeatmaps, XAIBase

struct HeatmapConfig
colorscheme::Symbol
reduce::Symbol
rangescale::Symbol
end

const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered
const DEFAULT_HEATMAP_PRESET = HeatmapConfig(
DEFAULT_COLORSCHEME, DEFAULT_REDUCE, DEFAULT_RANGESCALE
)

const HEATMAP_PRESETS = Dict{Symbol,HeatmapConfig}(
:attribution => HeatmapConfig(:seismic, :sum, :centered),
:sensitivity => HeatmapConfig(:grays, :norm, :extrema),
:cam => HeatmapConfig(:jet, :sum, :extrema),
)

# Select HeatmapConfig preset based on heatmapping style in Explanation
function get_heatmapping_config(heatmap::Symbol)
return get(HEATMAP_PRESETS, heatmap, DEFAULT_HEATMAP_PRESET)
end

# Override HeatmapConfig preset with keyword arguments
function get_heatmapping_config(expl::Explanation; kwargs...)
c = get_heatmapping_config(expl.heatmap)

colorscheme = get(kwargs, :colorscheme, c.colorscheme)
rangescale = get(kwargs, :rangescale, c.rangescale)
reduce = get(kwargs, :reduce, c.reduce)
return HeatmapConfig(colorscheme, reduce, rangescale)
end

"""
heatmap(explanation, text)
Visualize [`Explanation`](@ref) from XAIBase as text heatmap.
Text should be a vector containing vectors of strings, one for each input in the batched explanation.
## Keyword arguments
- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl.
Defaults to `:$DEFAULT_COLORSCHEME`.
- `rangescale::Symbol`: selects how the color channel reduced heatmap is normalized
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`.
"""
function TextHeatmaps.heatmap(
expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs...
)
ndims(expl.val) != 2 && throw(
ArgumentError(
"To heatmap text, `explanation.val` must be 2D array of shape `(input_length, batchsize)`. Got array of shape $(size(x)) instead.",
),
)
batchsize = size(expl.val, 2)
textsize = length(texts)
batchsize != textsize && throw(
ArgumentError("Batchsize $batchsize doesn't match number of texts $textsize.")
)

c = get_heatmapping_config(expl; kwargs...)
return [
TextHeatmaps.heatmap(v, t; colorscheme=c.colorscheme, rangescale=c.rangescale) for
(v, t) in zip(eachcol(expl.val), texts)
]
end

function TextHeatmaps.heatmap(
expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...
)
return heatmap(expl, [text]; kwargs...)
end

end # module
10 changes: 10 additions & 0 deletions src/TextHeatmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@ using Crayons: Crayon
using FixedPointNumbers: N0f8
using Colors: Colorant, RGB, hex
using ColorSchemes: ColorScheme, colorschemes, get, seismic
using Requires: @require

include("heatmap.jl")

if !isdefined(Base, :get_extension)
using Requires
function __init__()
@require XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" include(
"../ext/TextHeatmapsXAIBaseExt.jl"
)
end
end

export heatmap

end # module
5 changes: 4 additions & 1 deletion src/heatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ struct TextHeatmap{
end

function TextHeatmap(
val, words; colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME, rangescale=DEFAULT_RANGESCALE
val,
words;
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
rangescale=DEFAULT_RANGESCALE,
)
if size(val) != size(words)
throw(ArgumentError("Sizes of values and words don't match"))
Expand Down
8 changes: 2 additions & 6 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[compat]
Aqua = "0.7"
ColorSchemes = "3"
Colors = "0.12"
FixedPointNumbers = "0.8"
ReferenceTests = "0.10"
1 change: 1 addition & 0 deletions test/references/Gradient1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test Text Heatmap
1 change: 1 addition & 0 deletions test/references/Gradient2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
another dummy input
1 change: 1 addition & 0 deletions test/references/LRP1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test Text Heatmap
1 change: 1 addition & 0 deletions test/references/LRP1_extrema.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test Text Heatmap
1 change: 1 addition & 0 deletions test/references/LRP2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
another dummy input
1 change: 1 addition & 0 deletions test/references/LRP2_extrema.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
another dummy input
48 changes: 11 additions & 37 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,23 @@ using FixedPointNumbers
using Test
using ReferenceTests
using Aqua
using JuliaFormatter

@testset "TextHeatmaps.jl" begin
@testset "Aqua.jl" begin
@info "Running Aqua.jl's auto quality assurance tests. These might print warnings from dependencies."
Aqua.test_all(TextHeatmaps)
end
@testset "JuliaFormatter.jl" begin
@info "Running JuliaFormatter's code formatting tests."
@test format(TextHeatmaps; verbose=false, overwrite=false)
end
@testset "Heatmap" begin
words = ["Test", "TextHeatmaps"]
val = [4.2, -1.0]

colorscheme = TextHeatmaps.seismic
cmin = get(colorscheme, 0) # red
cmax = get(colorscheme, 1) # blue

# Test default ColorScheme seismic
h = heatmap(val, words)
@test h.colors[1] cmax
@test h.colors[2] != cmin
@test_reference "references/seismic_centered.txt" repr("text/plain", h)
@test_reference "references/seismic_centered_html.txt" repr("text/html", h)

h = heatmap(val, words; rangescale=:extrema)
@test h.colors[1] cmax
@test h.colors[2] cmin
@test_reference "references/seismic_extrema.txt" repr("text/plain", h)

# Test other colorschemes
colorscheme = ColorSchemes.inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test colorscheme symbols
colorscheme = :inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test errors
@test_throws ArgumentError heatmap(val, ["Test", "Text", "Heatmaps"])
# Test inner constructor
@test_throws ArgumentError TextHeatmaps.TextHeatmap(val, words, [cmin, cmax, cmax])
@info "Testing heatmaps..."
include("test_heatmap.jl")
end
@testset "XAIBase extension" begin
@info "Testing heatmaps on XAIBase explanations..."
include("test_xaibase_ext.jl")
end
end
37 changes: 37 additions & 0 deletions test/test_heatmap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
words = ["Test", "TextHeatmaps"]
val = [4.2, -1.0]

colorscheme = TextHeatmaps.seismic
cmin = get(colorscheme, 0) # red
cmax = get(colorscheme, 1) # blue

# Test default ColorScheme seismic
h = heatmap(val, words)
@test h.colors[1] cmax
@test h.colors[2] != cmin
@test_reference "references/seismic_centered.txt" repr("text/plain", h)
@test_reference "references/seismic_centered_html.txt" repr("text/html", h)

h = heatmap(val, words; rangescale=:extrema)
@test h.colors[1] cmax
@test h.colors[2] cmin
@test_reference "references/seismic_extrema.txt" repr("text/plain", h)

# Test other colorschemes
colorscheme = ColorSchemes.inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test colorscheme symbols
colorscheme = :inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test errors
@test_throws ArgumentError heatmap(val, ["Test", "Text", "Heatmaps"])
# Test inner constructor
@test_throws ArgumentError TextHeatmaps.TextHeatmap(val, words, [cmin, cmax, cmax])
24 changes: 24 additions & 0 deletions test/test_xaibase_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using XAIBase

val = output = [1 6; 2 5; 3 4]
text = [["Test", "Text", "Heatmap"], ["another", "dummy", "input"]]
output_selection = [CartesianIndex(1, 2), CartesianIndex(3, 4)] # irrelevant
expl = Explanation(val, output, output_selection, :Gradient, :sensitivity)
h = heatmap(expl, text)
@test_reference "references/Gradient1.txt" repr("text/plain", h[1])
@test_reference "references/Gradient2.txt" repr("text/plain", h[2])

expl = Explanation(
val[:, 1:1], output[:, 1:1], output_selection[1], :Gradient, :sensitivity
)
h = heatmap(expl, text[1])
@test_reference "references/Gradient1.txt" repr("text/plain", only(h))

expl = Explanation(val, output, output_selection, :LRP, :attribution)
h = heatmap(expl, text)
@test_reference "references/LRP1.txt" repr("text/plain", h[1])
@test_reference "references/LRP2.txt" repr("text/plain", h[2])

h = heatmap(expl, text; rangescale=:extrema)
@test_reference "references/LRP1_extrema.txt" repr("text/plain", h[1])
@test_reference "references/LRP2_extrema.txt" repr("text/plain", h[2])

0 comments on commit 41874be

Please sign in to comment.