diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8b99d00..7921e65 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,6 +20,7 @@ jobs: matrix: version: - '1.6' + - '1' - 'nightly' os: - ubuntu-latest @@ -33,7 +34,9 @@ jobs: arch: ${{ matrix.arch }} - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 + continue-on-error: ${{ matrix.version == 'nightly' }} - uses: julia-actions/julia-runtest@v1 + continue-on-error: ${{ matrix.version == 'nightly' }} - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v3 with: diff --git a/Project.toml b/Project.toml index 832dae6..f6f3cbc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,17 +1,27 @@ name = "TextHeatmaps" uuid = "2dd6718a-6083-4824-b9f7-90e4a57f72d2" authors = ["Adrian Hill "] -version = "1.1.0" +version = "1.2.0-DEV" [deps] ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" +XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" + +[weakdeps] +XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" + +[extensions] +TextHeatmapsXAIBaseExt = "XAIBase" [compat] ColorSchemes = "3" Colors = "0.12" Crayons = "4" FixedPointNumbers = "0.8" +Requires = "1" +XAIBase = "3" julia = "1.6" diff --git a/ext/TextHeatmapsXAIBaseExt.jl b/ext/TextHeatmapsXAIBaseExt.jl new file mode 100644 index 0000000..9e30917 --- /dev/null +++ b/ext/TextHeatmapsXAIBaseExt.jl @@ -0,0 +1,79 @@ +module TextHeatmapsXAIBaseExt + +using TextHeatmaps, XAIBase + +struct HeatmapConfig + colorscheme::Symbol + reduce::Symbol + rangescale::Symbol +end + +const DEFAULT_COLORSCHEME = :seismic +const DEFAULT_REDUCE = :sum +const DEFAULT_RANGESCALE = :centered +const DEFAULT_HEATMAP_PRESET = HeatmapConfig( + DEFAULT_COLORSCHEME, DEFAULT_REDUCE, DEFAULT_RANGESCALE +) + +const HEATMAP_PRESETS = Dict{Symbol,HeatmapConfig}( + :attribution => HeatmapConfig(:seismic, :sum, :centered), + :sensitivity => HeatmapConfig(:grays, :norm, :extrema), + :cam => HeatmapConfig(:jet, :sum, :extrema), +) + +# Select HeatmapConfig preset based on heatmapping style in Explanation +function get_heatmapping_config(heatmap::Symbol) + return get(HEATMAP_PRESETS, heatmap, DEFAULT_HEATMAP_PRESET) +end + +# Override HeatmapConfig preset with keyword arguments +function get_heatmapping_config(expl::Explanation; kwargs...) + c = get_heatmapping_config(expl.heatmap) + + colorscheme = get(kwargs, :colorscheme, c.colorscheme) + rangescale = get(kwargs, :rangescale, c.rangescale) + reduce = get(kwargs, :reduce, c.reduce) + return HeatmapConfig(colorscheme, reduce, rangescale) +end + +""" + heatmap(explanation, text) + +Visualize [`Explanation`](@ref) from XAIBase as text heatmap. +Text should be a vector containing vectors of strings, one for each input in the batched explanation. + +## Keyword arguments +- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl. + Defaults to `:$DEFAULT_COLORSCHEME`. +- `rangescale::Symbol`: selects how the color channel reduced heatmap is normalized + before the color scheme is applied. Can be either `:extrema` or `:centered`. + Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`. +""" +function TextHeatmaps.heatmap( + expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs... +) + ndims(expl.val) != 2 && throw( + ArgumentError( + "To heatmap text, `explanation.val` must be 2D array of shape `(input_length, batchsize)`. Got array of shape $(size(x)) instead.", + ), + ) + batchsize = size(expl.val, 2) + textsize = length(texts) + batchsize != textsize && throw( + ArgumentError("Batchsize $batchsize doesn't match number of texts $textsize.") + ) + + c = get_heatmapping_config(expl; kwargs...) + return [ + TextHeatmaps.heatmap(v, t; colorscheme=c.colorscheme, rangescale=c.rangescale) for + (v, t) in zip(eachcol(expl.val), texts) + ] +end + +function TextHeatmaps.heatmap( + expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs... +) + return heatmap(expl, [text]; kwargs...) +end + +end # module diff --git a/src/TextHeatmaps.jl b/src/TextHeatmaps.jl index 09da549..05eedf7 100644 --- a/src/TextHeatmaps.jl +++ b/src/TextHeatmaps.jl @@ -4,9 +4,19 @@ using Crayons: Crayon using FixedPointNumbers: N0f8 using Colors: Colorant, RGB, hex using ColorSchemes: ColorScheme, colorschemes, get, seismic +using Requires: @require include("heatmap.jl") +if !isdefined(Base, :get_extension) + using Requires + function __init__() + @require XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" include( + "../ext/TextHeatmapsXAIBaseExt.jl" + ) + end +end + export heatmap end # module diff --git a/src/heatmap.jl b/src/heatmap.jl index fb19f79..0d45270 100644 --- a/src/heatmap.jl +++ b/src/heatmap.jl @@ -39,7 +39,10 @@ struct TextHeatmap{ end function TextHeatmap( - val, words; colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME, rangescale=DEFAULT_RANGESCALE + val, + words; + colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME, + rangescale=DEFAULT_RANGESCALE, ) if size(val) != size(words) throw(ArgumentError("Sizes of values and words don't match")) diff --git a/test/Project.toml b/test/Project.toml index 233d918..09bd8ba 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,12 +3,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" -[compat] -Aqua = "0.7" -ColorSchemes = "3" -Colors = "0.12" -FixedPointNumbers = "0.8" -ReferenceTests = "0.10" diff --git a/test/references/Gradient1.txt b/test/references/Gradient1.txt new file mode 100644 index 0000000..255b698 --- /dev/null +++ b/test/references/Gradient1.txt @@ -0,0 +1 @@ +Test Text Heatmap \ No newline at end of file diff --git a/test/references/Gradient2.txt b/test/references/Gradient2.txt new file mode 100644 index 0000000..7374908 --- /dev/null +++ b/test/references/Gradient2.txt @@ -0,0 +1 @@ +another dummy input \ No newline at end of file diff --git a/test/references/LRP1.txt b/test/references/LRP1.txt new file mode 100644 index 0000000..0cb3d72 --- /dev/null +++ b/test/references/LRP1.txt @@ -0,0 +1 @@ +Test Text Heatmap \ No newline at end of file diff --git a/test/references/LRP1_extrema.txt b/test/references/LRP1_extrema.txt new file mode 100644 index 0000000..995cb76 --- /dev/null +++ b/test/references/LRP1_extrema.txt @@ -0,0 +1 @@ +Test Text Heatmap \ No newline at end of file diff --git a/test/references/LRP2.txt b/test/references/LRP2.txt new file mode 100644 index 0000000..2f73e2e --- /dev/null +++ b/test/references/LRP2.txt @@ -0,0 +1 @@ +another dummy input \ No newline at end of file diff --git a/test/references/LRP2_extrema.txt b/test/references/LRP2_extrema.txt new file mode 100644 index 0000000..3b2c8d6 --- /dev/null +++ b/test/references/LRP2_extrema.txt @@ -0,0 +1 @@ +another dummy input \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 582ca22..f2c7725 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,49 +6,23 @@ using FixedPointNumbers using Test using ReferenceTests using Aqua +using JuliaFormatter @testset "TextHeatmaps.jl" begin @testset "Aqua.jl" begin @info "Running Aqua.jl's auto quality assurance tests. These might print warnings from dependencies." Aqua.test_all(TextHeatmaps) end + @testset "JuliaFormatter.jl" begin + @info "Running JuliaFormatter's code formatting tests." + @test format(TextHeatmaps; verbose=false, overwrite=false) + end @testset "Heatmap" begin - words = ["Test", "TextHeatmaps"] - val = [4.2, -1.0] - - colorscheme = TextHeatmaps.seismic - cmin = get(colorscheme, 0) # red - cmax = get(colorscheme, 1) # blue - - # Test default ColorScheme seismic - h = heatmap(val, words) - @test h.colors[1] ≈ cmax - @test h.colors[2] != cmin - @test_reference "references/seismic_centered.txt" repr("text/plain", h) - @test_reference "references/seismic_centered_html.txt" repr("text/html", h) - - h = heatmap(val, words; rangescale=:extrema) - @test h.colors[1] ≈ cmax - @test h.colors[2] ≈ cmin - @test_reference "references/seismic_extrema.txt" repr("text/plain", h) - - # Test other colorschemes - colorscheme = ColorSchemes.inferno - h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered) - @test_reference "references/inferno_centered.txt" repr("text/plain", h) - h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema) - @test_reference "references/inferno_extrema.txt" repr("text/plain", h) - - # Test colorscheme symbols - colorscheme = :inferno - h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered) - @test_reference "references/inferno_centered.txt" repr("text/plain", h) - h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema) - @test_reference "references/inferno_extrema.txt" repr("text/plain", h) - - # Test errors - @test_throws ArgumentError heatmap(val, ["Test", "Text", "Heatmaps"]) - # Test inner constructor - @test_throws ArgumentError TextHeatmaps.TextHeatmap(val, words, [cmin, cmax, cmax]) + @info "Testing heatmaps..." + include("test_heatmap.jl") + end + @testset "XAIBase extension" begin + @info "Testing heatmaps on XAIBase explanations..." + include("test_xaibase_ext.jl") end end diff --git a/test/test_heatmap.jl b/test/test_heatmap.jl new file mode 100644 index 0000000..fd8c729 --- /dev/null +++ b/test/test_heatmap.jl @@ -0,0 +1,37 @@ +words = ["Test", "TextHeatmaps"] +val = [4.2, -1.0] + +colorscheme = TextHeatmaps.seismic +cmin = get(colorscheme, 0) # red +cmax = get(colorscheme, 1) # blue + +# Test default ColorScheme seismic +h = heatmap(val, words) +@test h.colors[1] ≈ cmax +@test h.colors[2] != cmin +@test_reference "references/seismic_centered.txt" repr("text/plain", h) +@test_reference "references/seismic_centered_html.txt" repr("text/html", h) + +h = heatmap(val, words; rangescale=:extrema) +@test h.colors[1] ≈ cmax +@test h.colors[2] ≈ cmin +@test_reference "references/seismic_extrema.txt" repr("text/plain", h) + +# Test other colorschemes +colorscheme = ColorSchemes.inferno +h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered) +@test_reference "references/inferno_centered.txt" repr("text/plain", h) +h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema) +@test_reference "references/inferno_extrema.txt" repr("text/plain", h) + +# Test colorscheme symbols +colorscheme = :inferno +h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered) +@test_reference "references/inferno_centered.txt" repr("text/plain", h) +h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema) +@test_reference "references/inferno_extrema.txt" repr("text/plain", h) + +# Test errors +@test_throws ArgumentError heatmap(val, ["Test", "Text", "Heatmaps"]) +# Test inner constructor +@test_throws ArgumentError TextHeatmaps.TextHeatmap(val, words, [cmin, cmax, cmax]) diff --git a/test/test_xaibase_ext.jl b/test/test_xaibase_ext.jl new file mode 100644 index 0000000..074dc71 --- /dev/null +++ b/test/test_xaibase_ext.jl @@ -0,0 +1,24 @@ +using XAIBase + +val = output = [1 6; 2 5; 3 4] +text = [["Test", "Text", "Heatmap"], ["another", "dummy", "input"]] +output_selection = [CartesianIndex(1, 2), CartesianIndex(3, 4)] # irrelevant +expl = Explanation(val, output, output_selection, :Gradient, :sensitivity) +h = heatmap(expl, text) +@test_reference "references/Gradient1.txt" repr("text/plain", h[1]) +@test_reference "references/Gradient2.txt" repr("text/plain", h[2]) + +expl = Explanation( + val[:, 1:1], output[:, 1:1], output_selection[1], :Gradient, :sensitivity +) +h = heatmap(expl, text[1]) +@test_reference "references/Gradient1.txt" repr("text/plain", only(h)) + +expl = Explanation(val, output, output_selection, :LRP, :attribution) +h = heatmap(expl, text) +@test_reference "references/LRP1.txt" repr("text/plain", h[1]) +@test_reference "references/LRP2.txt" repr("text/plain", h[2]) + +h = heatmap(expl, text; rangescale=:extrema) +@test_reference "references/LRP1_extrema.txt" repr("text/plain", h[1]) +@test_reference "references/LRP2_extrema.txt" repr("text/plain", h[2])