Skip to content

Commit

Permalink
Merge pull request #122 from juliohm/misclassification
Browse files Browse the repository at this point in the history
Add MisclassLoss as a generalization of ZeroOneLoss
  • Loading branch information
joshday committed Mar 26, 2020
2 parents 0056439 + 8c84c4e commit 9df99f6
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/LossFunctions.jl
Expand Up @@ -31,12 +31,15 @@ import LearnBase:

export

value,
value!,
deriv2!,
value_fun,
deriv_fun,
deriv2_fun,
value_deriv_fun,

ZeroOneLoss,
LogitMarginLoss,
PerceptronLoss,
HingeLoss,
Expand All @@ -61,6 +64,7 @@ export
QuantileLoss,
PinballLoss,

MisclassLoss,
PoissonLoss,
LogitProbLoss,
CrossEntropyLoss,
Expand Down
33 changes: 33 additions & 0 deletions src/supervised/other.jl
@@ -1,3 +1,36 @@
@doc doc"""
MisclassLoss <: SupervisedLoss
Misclassification loss that assigns `1` for misclassified
examples and `0` otherwise. It is a generalization of
`ZeroOneLoss` for more than two classes.
"""
struct MisclassLoss <: SupervisedLoss end

agreement(target, output) = target == output

value(::MisclassLoss, agreement::Bool) = agreement ? 0 : 1
deriv(::MisclassLoss, agreement::Bool) = 0
deriv2(::MisclassLoss, agreement::Bool) = 0
value_deriv(::MisclassLoss, agreement::Bool) = agreement ? (0, 0) : (1, 0)

value(loss::MisclassLoss, target::Number, output::Number) = value(loss, agreement(target, output))
deriv(loss::MisclassLoss, target::Number, output::Number) = deriv(loss, agreement(target, output))
deriv2(loss::MisclassLoss, target::Number, output::Number) = deriv2(loss, agreement(target, output))

isminimizable(::MisclassLoss) = false
isdifferentiable(::MisclassLoss) = false
isdifferentiable(::MisclassLoss, at) = at != 0
istwicedifferentiable(::MisclassLoss) = false
istwicedifferentiable(::MisclassLoss, at) = at != 0
isnemitski(::MisclassLoss) = false
islipschitzcont(::MisclassLoss) = false
isconvex(::MisclassLoss) = false
isclasscalibrated(::MisclassLoss) = false
isclipable(::MisclassLoss) = false

# ===============================================================

@doc doc"""
PoissonLoss <: SupervisedLoss
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Expand Up @@ -35,6 +35,10 @@ margin_losses = [
DWDMarginLoss(1), DWDMarginLoss(2)
]

other_losses = [
MisclassLoss(), PoissonLoss(), CrossEntropyLoss()
]

for t in tests
@testset "$t" begin
include(t)
Expand Down
3 changes: 3 additions & 0 deletions test/tst_loss.jl
Expand Up @@ -483,6 +483,9 @@ end


@testset "Test other loss against reference function" begin
_misclassloss(y, t) = y == t ? 0 : 1
test_value(MisclassLoss(), _misclassloss, 1:10, vcat(1:5,7:11))

_crossentropyloss(y, t) = -y*log(t) - (1-y)*log(1-t)
test_value(CrossEntropyLoss(), _crossentropyloss, 0:0.01:1, 0.01:0.01:0.99)

Expand Down

0 comments on commit 9df99f6

Please sign in to comment.