Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions src/measures/confusion_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ end
ConfusionMatrix(m, labels)

Instantiates a confusion matrix out of a square integer matrix `m`.
Rows are the predicted class, columns the ground truth. See also
the [wikipedia article](https://en.wikipedia.org/wiki/Confusion_matrix).
Rows are the predicted class, columns the ground truth. See also the
[wikipedia article](https://en.wikipedia.org/wiki/Confusion_matrix).

"""
function ConfusionMatrix(m::Matrix{Int}, labels::Vector{String})
s = size(m)
Expand Down Expand Up @@ -46,17 +47,25 @@ The ordering follows that of `levels(y)`.

## Note

To decrease the risk of unexpected errors, if `y` does not have scientific type
`OrderedFactor{2}` (and so does not have a "natural ordering"
negative-positive), a warning is shown indicating the current order unless the
user explicitly specifies either `rev` or `perm` in which case it's assumed the
user is aware of the class ordering.
To decrease the risk of unexpected errors, if `y` does not have
scientific type `OrderedFactor{2}` (and so does not have a "natural
ordering" negative-positive), a warning is shown indicating the
current order unless the user explicitly specifies either `rev` or
`perm` in which case it's assumed the user is aware of the class
ordering.

The `confusion_matrix` is a measure (although neither a score nor a
loss) and so may be specified as such in calls to `evaluate`,
`evaluate!`, although not in `TunedModel`s. In this case, however,
there no way to specify an ordering different from `levels(y)`, where
`y` is the target.

"""
function confusion_matrix(ŷ::VC, y::VC;
function confusion_matrix(ŷ::Vec{<:CategoricalElement}, y::Vec{<:CategoricalElement};
rev::Union{Nothing,Bool}=nothing,
perm::Union{Nothing,Vector{<:Integer}}=nothing,
warn::Bool=true
) where VC <: Vec{<:CategoricalElement}
warn::Bool=true)

check_dimensions(ŷ, y)
levels_ = levels(y)
nc = length(levels_)
Expand Down Expand Up @@ -172,3 +181,32 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrix{C}
"└" * "─"^cw * "┴" * ("─"^cw * "┴")^(C-1) * ("─"^cw * "┘") |> wline
write(stream, take!(iob))
end


## MAKE CONFUSION MATRIX A MEASURE

const Confusion = typeof(confusion_matrix)

is_measure(::Confusion) = true
is_measure_type(::Type{Confusion}) = true

MLJModelInterface.name(::Type{Confusion}) = "confusion_matrix"
MLJModelInterface.target_scitype(::Type{Confusion}) =
AbstractVector{<:Finite}
MLJModelInterface.supports_weights(::Type{Confusion}) = false
MLJModelInterface.prediction_type(::Type{Confusion}) = :deterministic
MLJModelInterface.docstring(::Type{Confusion}) =
"confusion matrix; aliases: confusion_matrix, confmat. "
orientation(::Type{Confusion}) = :other
reports_each_observation(::Type{Confusion}) = false
is_feature_dependent(::Type{Confusion}) = false
aggregation(::Type{Confusion}) = Sum()

# aggregation:
Base.round(m::MLJBase.ConfusionMatrix; kws...) = m
function Base.:+(m1::ConfusionMatrix, m2::ConfusionMatrix)
if m1.labels != m2.labels
throw(ArgumentError("Confusion matrix labels must agree"))
end
ConfusionMatrix(m1.mat + m2.mat, m1.labels)
end
2 changes: 2 additions & 0 deletions src/measures/registry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ const LOSSFUNCTIONS_MEASURE_TYPES =

const MEASURE_TYPES = vcat(LOCAL_MEASURE_TYPES, LOSSFUNCTIONS_MEASURE_TYPES)

push!(MEASURE_TYPES, Confusion)

const MeasureProxy = NamedTuple{Tuple(MEASURE_TRAITS)}

Base.show(stream::IO, p::MeasureProxy) =
Expand Down
5 changes: 3 additions & 2 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,11 @@ end
measure=nothing,
weights=nothing,
operation=predict,
n = 1,
repeats = 1,
acceleration=default_resource(),
force=false,
verbosity=1)
verbosity=1,
check_measure=true)

Estimate the performance of a machine `mach` wrapping a supervised
model in data, using the specified `resampling` strategy (defaulting
Expand Down
50 changes: 0 additions & 50 deletions test/measures/finite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,56 +35,6 @@ seed!(51803)
@test -mean(BrierScore()(yhat2, y2)) / 2 ≈ 0.21875
end

@testset "confusion matrix" begin
y = categorical(['m', 'f', 'n', 'f', 'm', 'n', 'n', 'm', 'f'])
ŷ = categorical(['f', 'f', 'm', 'f', 'n', 'm', 'n', 'm', 'f'])
l = levels(y) # f, m, n
cm = confmat(ŷ, y; warn=false)
e(l,i,j) = sum((ŷ .== l[i]) .& (y .== l[j]))
for i in 1:3, j in 1:3
@test cm[i,j] == e(l,i,j)
end
perm = [3, 1, 2]
l2 = l[perm]
cm2 = confmat(ŷ, y; perm=perm) # no warning because permutation is given
for i in 1:3, j in 1:3
@test cm2[i,j] == e(l2,i,j)
end
@test_logs (:warn, "The classes are un-ordered,\nusing order: ['f', 'm', 'n'].\nTo suppress this warning, consider coercing to OrderedFactor.") confmat(ŷ, y)
ŷc = coerce(ŷ, OrderedFactor)
yc = coerce(y, OrderedFactor)
@test confmat(ŷc, yc).mat == cm.mat

y = categorical(['a','b','a','b'])
ŷ = categorical(['b','b','a','a'])
@test_logs (:warn, "The classes are un-ordered,\nusing: negative='a' and positive='b'.\nTo suppress this warning, consider coercing to OrderedFactor.") confmat(ŷ, y)

# more tests for coverage
y = categorical([1,2,3,1,2,3,1,2,3])
ŷ = categorical([1,2,3,1,2,3,1,2,3])
@test_throws ArgumentError confmat(ŷ, y, rev=true)

# silly test for display
ŷ = coerce(y, OrderedFactor)
y = coerce(y, OrderedFactor)
iob = IOBuffer()
Base.show(iob, MIME("text/plain"), confmat(ŷ, y))
siob = String(take!(iob))
@test strip(siob) == strip("""
┌─────────────────────────────────────────┐
│ Ground Truth │
┌─────────────┼─────────────┬─────────────┬─────────────┤
│ Predicted │ 1 │ 2 │ 3 │
├─────────────┼─────────────┼─────────────┼─────────────┤
│ 1 │ 3 │ 0 │ 0 │
├─────────────┼─────────────┼─────────────┼─────────────┤
│ 2 │ 0 │ 3 │ 0 │
├─────────────┼─────────────┼─────────────┼─────────────┤
│ 3 │ 0 │ 0 │ 3 │
└─────────────┴─────────────┴─────────────┴─────────────┘""")

end

@testset "mcr, acc, bacc, mcc" begin
y = categorical(['m', 'f', 'n', 'f', 'm', 'n', 'n', 'm', 'f'])
ŷ = categorical(['f', 'f', 'm', 'f', 'n', 'm', 'n', 'm', 'f'])
Expand Down