diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl
index 951f890..97fe7c9 100644
--- a/src/LossFunctions.jl
+++ b/src/LossFunctions.jl
@@ -31,12 +31,15 @@ import LearnBase:
 
 export
 
+    value,
+    value!,
     deriv2!,
     value_fun,
     deriv_fun,
     deriv2_fun,
     value_deriv_fun,
 
+    ZeroOneLoss,
     LogitMarginLoss,
     PerceptronLoss,
     HingeLoss,
@@ -61,6 +64,7 @@ export
     QuantileLoss,
     PinballLoss,
 
+    MisclassLoss,
     PoissonLoss,
     LogitProbLoss,
     CrossEntropyLoss,
diff --git a/src/supervised/other.jl b/src/supervised/other.jl
index d2b17fb..7389def 100644
--- a/src/supervised/other.jl
+++ b/src/supervised/other.jl
@@ -1,3 +1,36 @@
+@doc doc"""
+    MisclassLoss <: SupervisedLoss
+
+Misclassification loss that assigns `1` for misclassified
+examples and `0` otherwise. It is a generalization of
+`ZeroOneLoss` for more than two classes.
+"""
+struct MisclassLoss <: SupervisedLoss end
+
+agreement(target, output) = target == output
+
+value(::MisclassLoss, agreement::Bool) = agreement ? 0 : 1
+deriv(::MisclassLoss, agreement::Bool) = 0
+deriv2(::MisclassLoss, agreement::Bool) = 0
+value_deriv(::MisclassLoss, agreement::Bool) = agreement ? (0, 0) : (1, 0)
+
+value(loss::MisclassLoss, target::Number, output::Number) = value(loss, agreement(target, output))
+deriv(loss::MisclassLoss, target::Number, output::Number) = deriv(loss, agreement(target, output))
+deriv2(loss::MisclassLoss, target::Number, output::Number) = deriv2(loss, agreement(target, output))
+
+isminimizable(::MisclassLoss) = false
+isdifferentiable(::MisclassLoss) = false
+isdifferentiable(::MisclassLoss, at) = at != 0
+istwicedifferentiable(::MisclassLoss) = false
+istwicedifferentiable(::MisclassLoss, at) = at != 0
+isnemitski(::MisclassLoss) = false
+islipschitzcont(::MisclassLoss) = false
+isconvex(::MisclassLoss) = false
+isclasscalibrated(::MisclassLoss) = false
+isclipable(::MisclassLoss) = false
+
+# ===============================================================
+
 @doc doc"""
     PoissonLoss <: SupervisedLoss
 
diff --git a/test/runtests.jl b/test/runtests.jl
index a2264f6..1078231 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,6 +35,10 @@ margin_losses = [
     DWDMarginLoss(1), DWDMarginLoss(2)
 ]
 
+other_losses = [
+    MisclassLoss(), PoissonLoss(), CrossEntropyLoss()
+]
+
 for t in tests
     @testset "$t" begin
         include(t)
diff --git a/test/tst_loss.jl b/test/tst_loss.jl
index 49cf6ec..2f2fa9e 100644
--- a/test/tst_loss.jl
+++ b/test/tst_loss.jl
@@ -483,6 +483,9 @@ end
 
 @testset "Test other loss against reference function" begin
+    _misclassloss(y, t) = y == t ? 0 : 1
+    test_value(MisclassLoss(), _misclassloss, 1:10, vcat(1:5,7:11))
+
     _crossentropyloss(y, t) = -y*log(t) - (1-y)*log(1-t)
     test_value(CrossEntropyLoss(), _crossentropyloss, 0:0.01:1, 0.01:0.01:0.99)
 
 