diff --git a/src/measures/confusion_matrix.jl b/src/measures/confusion_matrix.jl index f25e16dd..40742d7e 100644 --- a/src/measures/confusion_matrix.jl +++ b/src/measures/confusion_matrix.jl @@ -13,8 +13,9 @@ end ConfusionMatrix(m, labels) Instantiates a confusion matrix out of a square integer matrix `m`. -Rows are the predicted class, columns the ground truth. See also -the [wikipedia article](https://en.wikipedia.org/wiki/Confusion_matrix). +Rows are the predicted class, columns the ground truth. See also the +[wikipedia article](https://en.wikipedia.org/wiki/Confusion_matrix). + """ function ConfusionMatrix(m::Matrix{Int}, labels::Vector{String}) s = size(m) @@ -46,17 +47,25 @@ The ordering follows that of `levels(y)`. ## Note -To decrease the risk of unexpected errors, if `y` does not have scientific type -`OrderedFactor{2}` (and so does not have a "natural ordering" -negative-positive), a warning is shown indicating the current order unless the -user explicitly specifies either `rev` or `perm` in which case it's assumed the -user is aware of the class ordering. +To decrease the risk of unexpected errors, if `y` does not have +scientific type `OrderedFactor{2}` (and so does not have a "natural +ordering" negative-positive), a warning is shown indicating the +current order unless the user explicitly specifies either `rev` or +`perm` in which case it's assumed the user is aware of the class +ordering. + +The `confusion_matrix` is a measure (although neither a score nor a +loss) and so may be specified as such in calls to `evaluate`, +`evaluate!`, although not in `TunedModel`s. In this case, however, +there no way to specify an ordering different from `levels(y)`, where +`y` is the target. + """ -function confusion_matrix(ŷ::VC, y::VC; +function confusion_matrix(ŷ::Vec{<:CategoricalElement}, y::Vec{<:CategoricalElement}; rev::Union{Nothing,Bool}=nothing, perm::Union{Nothing,Vector{<:Integer}}=nothing, - warn::Bool=true - ) where VC <: Vec{<:CategoricalElement} + warn::Bool=true) + check_dimensions(ŷ, y) levels_ = levels(y) nc = length(levels_) @@ -172,3 +181,32 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrix{C} "└" * "─"^cw * "┴" * ("─"^cw * "┴")^(C-1) * ("─"^cw * "┘") |> wline write(stream, take!(iob)) end + + +## MAKE CONFUSION MATRIX A MEASURE + +const Confusion = typeof(confusion_matrix) + +is_measure(::Confusion) = true +is_measure_type(::Type{Confusion}) = true + +MLJModelInterface.name(::Type{Confusion}) = "confusion_matrix" +MLJModelInterface.target_scitype(::Type{Confusion}) = + AbstractVector{<:Finite} +MLJModelInterface.supports_weights(::Type{Confusion}) = false +MLJModelInterface.prediction_type(::Type{Confusion}) = :deterministic +MLJModelInterface.docstring(::Type{Confusion}) = + "confusion matrix; aliases: confusion_matrix, confmat. " +orientation(::Type{Confusion}) = :other +reports_each_observation(::Type{Confusion}) = false +is_feature_dependent(::Type{Confusion}) = false +aggregation(::Type{Confusion}) = Sum() + +# aggregation: +Base.round(m::MLJBase.ConfusionMatrix; kws...) = m +function Base.:+(m1::ConfusionMatrix, m2::ConfusionMatrix) + if m1.labels != m2.labels + throw(ArgumentError("Confusion matrix labels must agree")) + end + ConfusionMatrix(m1.mat + m2.mat, m1.labels) +end diff --git a/src/measures/registry.jl b/src/measures/registry.jl index 101e542e..d6ae7742 100644 --- a/src/measures/registry.jl +++ b/src/measures/registry.jl @@ -16,6 +16,8 @@ const LOSSFUNCTIONS_MEASURE_TYPES = const MEASURE_TYPES = vcat(LOCAL_MEASURE_TYPES, LOSSFUNCTIONS_MEASURE_TYPES) +push!(MEASURE_TYPES, Confusion) + const MeasureProxy = NamedTuple{Tuple(MEASURE_TRAITS)} Base.show(stream::IO, p::MeasureProxy) = diff --git a/src/resampling.jl b/src/resampling.jl index 45d07c50..b052156a 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -383,10 +383,11 @@ end measure=nothing, weights=nothing, operation=predict, - n = 1, + repeats = 1, acceleration=default_resource(), force=false, - verbosity=1) + verbosity=1, + check_measure=true) Estimate the performance of a machine `mach` wrapping a supervised model in data, using the specified `resampling` strategy (defaulting diff --git a/test/measures/finite.jl b/test/measures/finite.jl index 3ec2a911..a364ecd0 100644 --- a/test/measures/finite.jl +++ b/test/measures/finite.jl @@ -35,56 +35,6 @@ seed!(51803) @test -mean(BrierScore()(yhat2, y2)) / 2 ≈ 0.21875 end -@testset "confusion matrix" begin - y = categorical(['m', 'f', 'n', 'f', 'm', 'n', 'n', 'm', 'f']) - ŷ = categorical(['f', 'f', 'm', 'f', 'n', 'm', 'n', 'm', 'f']) - l = levels(y) # f, m, n - cm = confmat(ŷ, y; warn=false) - e(l,i,j) = sum((ŷ .== l[i]) .& (y .== l[j])) - for i in 1:3, j in 1:3 - @test cm[i,j] == e(l,i,j) - end - perm = [3, 1, 2] - l2 = l[perm] - cm2 = confmat(ŷ, y; perm=perm) # no warning because permutation is given - for i in 1:3, j in 1:3 - @test cm2[i,j] == e(l2,i,j) - end - @test_logs (:warn, "The classes are un-ordered,\nusing order: ['f', 'm', 'n'].\nTo suppress this warning, consider coercing to OrderedFactor.") confmat(ŷ, y) - ŷc = coerce(ŷ, OrderedFactor) - yc = coerce(y, OrderedFactor) - @test confmat(ŷc, yc).mat == cm.mat - - y = categorical(['a','b','a','b']) - ŷ = categorical(['b','b','a','a']) - @test_logs (:warn, "The classes are un-ordered,\nusing: negative='a' and positive='b'.\nTo suppress this warning, consider coercing to OrderedFactor.") confmat(ŷ, y) - - # more tests for coverage - y = categorical([1,2,3,1,2,3,1,2,3]) - ŷ = categorical([1,2,3,1,2,3,1,2,3]) - @test_throws ArgumentError confmat(ŷ, y, rev=true) - - # silly test for display - ŷ = coerce(y, OrderedFactor) - y = coerce(y, OrderedFactor) - iob = IOBuffer() - Base.show(iob, MIME("text/plain"), confmat(ŷ, y)) - siob = String(take!(iob)) - @test strip(siob) == strip(""" - ┌─────────────────────────────────────────┐ - │ Ground Truth │ - ┌─────────────┼─────────────┬─────────────┬─────────────┤ - │ Predicted │ 1 │ 2 │ 3 │ - ├─────────────┼─────────────┼─────────────┼─────────────┤ - │ 1 │ 3 │ 0 │ 0 │ - ├─────────────┼─────────────┼─────────────┼─────────────┤ - │ 2 │ 0 │ 3 │ 0 │ - ├─────────────┼─────────────┼─────────────┼─────────────┤ - │ 3 │ 0 │ 0 │ 3 │ - └─────────────┴─────────────┴─────────────┴─────────────┘""") - -end - @testset "mcr, acc, bacc, mcc" begin y = categorical(['m', 'f', 'n', 'f', 'm', 'n', 'n', 'm', 'f']) ŷ = categorical(['f', 'f', 'm', 'f', 'n', 'm', 'n', 'm', 'f'])