Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XAIBase package extension #4

Merged
merged 4 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
matrix:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand All @@ -33,7 +34,9 @@ jobs:
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.version == 'nightly' }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v3
with:
Expand Down
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
name = "TextHeatmaps"
uuid = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "1.1.0"
version = "1.2.0-DEV"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[weakdeps]
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[extensions]
TextHeatmapsXAIBaseExt = "XAIBase"

[compat]
ColorSchemes = "3"
Colors = "0.12"
Crayons = "4"
FixedPointNumbers = "0.8"
Requires = "1"
XAIBase = "3"
julia = "1.6"
79 changes: 79 additions & 0 deletions ext/TextHeatmapsXAIBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
module TextHeatmapsXAIBaseExt

using TextHeatmaps, XAIBase

struct HeatmapConfig
colorscheme::Symbol
reduce::Symbol
rangescale::Symbol
end

const DEFAULT_COLORSCHEME = :seismic
const DEFAULT_REDUCE = :sum
const DEFAULT_RANGESCALE = :centered
const DEFAULT_HEATMAP_PRESET = HeatmapConfig(
DEFAULT_COLORSCHEME, DEFAULT_REDUCE, DEFAULT_RANGESCALE
)

const HEATMAP_PRESETS = Dict{Symbol,HeatmapConfig}(
:attribution => HeatmapConfig(:seismic, :sum, :centered),
:sensitivity => HeatmapConfig(:grays, :norm, :extrema),
:cam => HeatmapConfig(:jet, :sum, :extrema),
)

# Select HeatmapConfig preset based on heatmapping style in Explanation
function get_heatmapping_config(heatmap::Symbol)
return get(HEATMAP_PRESETS, heatmap, DEFAULT_HEATMAP_PRESET)
end

# Override HeatmapConfig preset with keyword arguments
function get_heatmapping_config(expl::Explanation; kwargs...)
c = get_heatmapping_config(expl.heatmap)

colorscheme = get(kwargs, :colorscheme, c.colorscheme)
rangescale = get(kwargs, :rangescale, c.rangescale)
reduce = get(kwargs, :reduce, c.reduce)
return HeatmapConfig(colorscheme, reduce, rangescale)
end

"""
heatmap(explanation, text)

Visualize [`Explanation`](@ref) from XAIBase as text heatmap.
Text should be a vector containing vectors of strings, one for each input in the batched explanation.

## Keyword arguments
- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl.
Defaults to `:$DEFAULT_COLORSCHEME`.
- `rangescale::Symbol`: selects how the color channel reduced heatmap is normalized
before the color scheme is applied. Can be either `:extrema` or `:centered`.
Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`.
"""
function TextHeatmaps.heatmap(
expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs...
)
ndims(expl.val) != 2 && throw(
ArgumentError(
"To heatmap text, `explanation.val` must be 2D array of shape `(input_length, batchsize)`. Got array of shape $(size(x)) instead.",
),
)
batchsize = size(expl.val, 2)
textsize = length(texts)
batchsize != textsize && throw(
ArgumentError("Batchsize $batchsize doesn't match number of texts $textsize.")
)

c = get_heatmapping_config(expl; kwargs...)
return [
TextHeatmaps.heatmap(v, t; colorscheme=c.colorscheme, rangescale=c.rangescale) for
(v, t) in zip(eachcol(expl.val), texts)
]
end

function TextHeatmaps.heatmap(
expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...
)
return heatmap(expl, [text]; kwargs...)
end

end # module
10 changes: 10 additions & 0 deletions src/TextHeatmaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@ using Crayons: Crayon
using FixedPointNumbers: N0f8
using Colors: Colorant, RGB, hex
using ColorSchemes: ColorScheme, colorschemes, get, seismic
using Requires: @require

include("heatmap.jl")

if !isdefined(Base, :get_extension)
using Requires
function __init__()
@require XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" include(
"../ext/TextHeatmapsXAIBaseExt.jl"
)
end
end

export heatmap

end # module
5 changes: 4 additions & 1 deletion src/heatmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ struct TextHeatmap{
end

function TextHeatmap(
val, words; colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME, rangescale=DEFAULT_RANGESCALE
val,
words;
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
rangescale=DEFAULT_RANGESCALE,
)
if size(val) != size(words)
throw(ArgumentError("Sizes of values and words don't match"))
Expand Down
8 changes: 2 additions & 6 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"

[compat]
Aqua = "0.7"
ColorSchemes = "3"
Colors = "0.12"
FixedPointNumbers = "0.8"
ReferenceTests = "0.10"
1 change: 1 addition & 0 deletions test/references/Gradient1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test Text Heatmap
1 change: 1 addition & 0 deletions test/references/Gradient2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
another dummy input
1 change: 1 addition & 0 deletions test/references/LRP1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test Text Heatmap
1 change: 1 addition & 0 deletions test/references/LRP1_extrema.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Test Text Heatmap
1 change: 1 addition & 0 deletions test/references/LRP2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
another dummy input
1 change: 1 addition & 0 deletions test/references/LRP2_extrema.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
another dummy input
48 changes: 11 additions & 37 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,23 @@ using FixedPointNumbers
using Test
using ReferenceTests
using Aqua
using JuliaFormatter

@testset "TextHeatmaps.jl" begin
@testset "Aqua.jl" begin
@info "Running Aqua.jl's auto quality assurance tests. These might print warnings from dependencies."
Aqua.test_all(TextHeatmaps)
end
@testset "JuliaFormatter.jl" begin
@info "Running JuliaFormatter's code formatting tests."
@test format(TextHeatmaps; verbose=false, overwrite=false)
end
@testset "Heatmap" begin
words = ["Test", "TextHeatmaps"]
val = [4.2, -1.0]

colorscheme = TextHeatmaps.seismic
cmin = get(colorscheme, 0) # red
cmax = get(colorscheme, 1) # blue

# Test default ColorScheme seismic
h = heatmap(val, words)
@test h.colors[1] ≈ cmax
@test h.colors[2] != cmin
@test_reference "references/seismic_centered.txt" repr("text/plain", h)
@test_reference "references/seismic_centered_html.txt" repr("text/html", h)

h = heatmap(val, words; rangescale=:extrema)
@test h.colors[1] ≈ cmax
@test h.colors[2] ≈ cmin
@test_reference "references/seismic_extrema.txt" repr("text/plain", h)

# Test other colorschemes
colorscheme = ColorSchemes.inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test colorscheme symbols
colorscheme = :inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test errors
@test_throws ArgumentError heatmap(val, ["Test", "Text", "Heatmaps"])
# Test inner constructor
@test_throws ArgumentError TextHeatmaps.TextHeatmap(val, words, [cmin, cmax, cmax])
@info "Testing heatmaps..."
include("test_heatmap.jl")
end
@testset "XAIBase extension" begin
@info "Testing heatmaps on XAIBase explanations..."
include("test_xaibase_ext.jl")
end
end
37 changes: 37 additions & 0 deletions test/test_heatmap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
words = ["Test", "TextHeatmaps"]
val = [4.2, -1.0]

colorscheme = TextHeatmaps.seismic
cmin = get(colorscheme, 0) # red
cmax = get(colorscheme, 1) # blue

# Test default ColorScheme seismic
h = heatmap(val, words)
@test h.colors[1] ≈ cmax
@test h.colors[2] != cmin
@test_reference "references/seismic_centered.txt" repr("text/plain", h)
@test_reference "references/seismic_centered_html.txt" repr("text/html", h)

h = heatmap(val, words; rangescale=:extrema)
@test h.colors[1] ≈ cmax
@test h.colors[2] ≈ cmin
@test_reference "references/seismic_extrema.txt" repr("text/plain", h)

# Test other colorschemes
colorscheme = ColorSchemes.inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test colorscheme symbols
colorscheme = :inferno
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:centered)
@test_reference "references/inferno_centered.txt" repr("text/plain", h)
h = heatmap(val, words; colorscheme=colorscheme, rangescale=:extrema)
@test_reference "references/inferno_extrema.txt" repr("text/plain", h)

# Test errors
@test_throws ArgumentError heatmap(val, ["Test", "Text", "Heatmaps"])
# Test inner constructor
@test_throws ArgumentError TextHeatmaps.TextHeatmap(val, words, [cmin, cmax, cmax])
24 changes: 24 additions & 0 deletions test/test_xaibase_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using XAIBase

val = output = [1 6; 2 5; 3 4]
text = [["Test", "Text", "Heatmap"], ["another", "dummy", "input"]]
output_selection = [CartesianIndex(1, 2), CartesianIndex(3, 4)] # irrelevant
expl = Explanation(val, output, output_selection, :Gradient, :sensitivity)
h = heatmap(expl, text)
@test_reference "references/Gradient1.txt" repr("text/plain", h[1])
@test_reference "references/Gradient2.txt" repr("text/plain", h[2])

expl = Explanation(
val[:, 1:1], output[:, 1:1], output_selection[1], :Gradient, :sensitivity
)
h = heatmap(expl, text[1])
@test_reference "references/Gradient1.txt" repr("text/plain", only(h))

expl = Explanation(val, output, output_selection, :LRP, :attribution)
h = heatmap(expl, text)
@test_reference "references/LRP1.txt" repr("text/plain", h[1])
@test_reference "references/LRP2.txt" repr("text/plain", h[2])

h = heatmap(expl, text; rangescale=:extrema)
@test_reference "references/LRP1_extrema.txt" repr("text/plain", h[1])
@test_reference "references/LRP2_extrema.txt" repr("text/plain", h[2])
Loading