
Merge pull request #123 from juliohm/crossentropy
Rename CrossentropyLoss to CrossEntropyLoss
joshday committed Mar 23, 2020
2 parents 8c09d60 + f1a63f6 commit 0056439
Showing 4 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/LossFunctions.jl
@@ -63,7 +63,7 @@ export
 
     PoissonLoss,
     LogitProbLoss,
-    CrossentropyLoss,
+    CrossEntropyLoss,
    ZeroOneLoss,
 
    OrdinalMarginLoss,
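For downstream users, only the exported spelling changes; the existing LogitProbLoss alias is untouched. A minimal sketch assuming only what this diff introduces (the alias is the const binding defined in src/supervised/other.jl below):

using LossFunctions

CrossEntropyLoss()                    # new exported name
LogitProbLoss()                       # pre-existing alias, unchanged by this commit
CrossEntropyLoss === LogitProbLoss    # true: the alias points at the same type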
1 change: 1 addition & 0 deletions src/deprecated.jl
@@ -5,6 +5,7 @@
 @deprecate scaled(loss, ::Type{Val{K}}) where {K} scaled(loss, Val(K))
 @deprecate weightedloss(loss::Loss, ::Type{Val{W}}) where {W} weightedloss(loss, Val(W))
 @deprecate WeightedBinaryLoss(loss, ::Type{Val{W}}) where {W} WeightedBinaryLoss(loss, Val(W))
+@deprecate CrossentropyLoss CrossEntropyLoss
 
 Base.@deprecate_binding AverageMode AggregateMode
 export AvgMode
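Because the rename goes through @deprecate, code that still says CrossentropyLoss keeps working (emitting a deprecation warning when warnings are enabled), while new code should use CrossEntropyLoss. A small migration sketch under that assumption:

using LossFunctions

# Old spelling: now a deprecated forwarding definition, not a type.
old = CrossentropyLoss()

# New spelling: the actual struct defined in src/supervised/other.jl.
new = CrossEntropyLoss()

old isa CrossEntropyLoss   # true: the deprecated name constructs the renamed type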
26 changes: 13 additions & 13 deletions src/supervised/other.jl
@@ -27,16 +27,16 @@ isstronglyconvex(::PoissonLoss) = false
 # ===============================================================
 
 @doc doc"""
-    CrossentropyLoss <: SupervisedLoss
+    CrossEntropyLoss <: SupervisedLoss
 Cross-entropy loss also known as log loss and logistic loss is defined as:
 ``L(target, output) = - target*ln(output) - (1-target)*ln(1-output)``
 """
-struct CrossentropyLoss <: SupervisedLoss end
-const LogitProbLoss = CrossentropyLoss
+struct CrossEntropyLoss <: SupervisedLoss end
+const LogitProbLoss = CrossEntropyLoss
 
-function value(loss::CrossentropyLoss, target::Number, output::Number)
+function value(loss::CrossEntropyLoss, target::Number, output::Number)
     target >= 0 && target <=1 || error("target must be in [0,1]")
     output >= 0 && output <=1 || error("output must be in [0,1]")
     if target == 0
@@ -47,15 +47,15 @@ function value(loss::CrossentropyLoss, target::Number, output::Number)
         -(target * log(output) + (1-target) * log(1-output))
     end
 end
-deriv(loss::CrossentropyLoss, target::Number, output::Number) = (1-target) / (1-output) - target / output
-deriv2(loss::CrossentropyLoss, target::Number, output::Number) = (1-target) / (1-output)^2 + target / output^2
-value_deriv(loss::CrossentropyLoss, target::Number, output::Number) = (value(loss,target,output), deriv(loss,target,output))
-
-isdifferentiable(::CrossentropyLoss) = true
-isdifferentiable(::CrossentropyLoss, y, t) = t != 0 && t != 1
-istwicedifferentiable(::CrossentropyLoss) = true
-istwicedifferentiable(::CrossentropyLoss, y, t) = t != 0 && t != 1
-isconvex(::CrossentropyLoss) = true
+deriv(loss::CrossEntropyLoss, target::Number, output::Number) = (1-target) / (1-output) - target / output
+deriv2(loss::CrossEntropyLoss, target::Number, output::Number) = (1-target) / (1-output)^2 + target / output^2
+value_deriv(loss::CrossEntropyLoss, target::Number, output::Number) = (value(loss,target,output), deriv(loss,target,output))
+
+isdifferentiable(::CrossEntropyLoss) = true
+isdifferentiable(::CrossEntropyLoss, y, t) = t != 0 && t != 1
+istwicedifferentiable(::CrossEntropyLoss) = true
+istwicedifferentiable(::CrossEntropyLoss, y, t) = t != 0 && t != 1
+isconvex(::CrossEntropyLoss) = true
 
 # ===============================================================
 # L(target, output) = sign(agreement) < 0 ? 1 : 0
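The docstring formula and the deriv definition above are easy to check by hand; a short sketch, assuming value and deriv are accessible from the package as they are in the test suite:

using LossFunctions

y, t = 0.3, 0.6   # target and predicted probability

# value: -y*log(t) - (1-y)*log(1-t) = -0.3*log(0.6) - 0.7*log(0.4) ≈ 0.795
value(CrossEntropyLoss(), y, t)

# deriv: (1-y)/(1-t) - y/t = 0.7/0.4 - 0.3/0.6 = 1.25
deriv(CrossEntropyLoss(), y, t)

# finite-difference check of deriv with respect to the output argument
h = 1e-6
(value(CrossEntropyLoss(), y, t + h) - value(CrossEntropyLoss(), y, t - h)) / (2h)   # ≈ 1.25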
6 changes: 3 additions & 3 deletions test/tst_loss.jl
@@ -484,7 +484,7 @@ end
 
 @testset "Test other loss against reference function" begin
     _crossentropyloss(y, t) = -y*log(t) - (1-y)*log(1-t)
-    test_value(CrossentropyLoss(), _crossentropyloss, 0:0.01:1, 0.01:0.01:0.99)
+    test_value(CrossEntropyLoss(), _crossentropyloss, 0:0.01:1, 0.01:0.01:0.99)
 
     _poissonloss(y, t) = exp(t) - t*y
     test_value(PoissonLoss(), _poissonloss, 0:10, range(0,stop=10,length=11))
@@ -540,8 +540,8 @@ end
 @testset "Test first and second derivatives of other losses" begin
     test_deriv(PoissonLoss(), -10:.2:10, 0:30)
     test_deriv2(PoissonLoss(), -10:.2:10, 0:30)
-    test_deriv(CrossentropyLoss(), 0:0.01:1, 0.01:0.01:0.99)
-    test_deriv2(CrossentropyLoss(), 0:0.01:1, 0.01:0.01:0.99)
+    test_deriv(CrossEntropyLoss(), 0:0.01:1, 0.01:0.01:0.99)
+    test_deriv2(CrossEntropyLoss(), 0:0.01:1, 0.01:0.01:0.99)
 end
 
 @testset "Test second derivatives of distance-based losses" begin
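The test_value and test_deriv helpers are defined elsewhere in the test suite; conceptually they compare the loss against the hand-written reference over a grid of targets and outputs. A self-contained sketch of that idea for the renamed loss, using only the Test standard library:

using Test, LossFunctions

_crossentropyloss(y, t) = -y*log(t) - (1-y)*log(1-t)

@testset "CrossEntropyLoss matches reference" begin
    for y in 0:0.01:1, t in 0.01:0.01:0.99
        @test value(CrossEntropyLoss(), y, t) ≈ _crossentropyloss(y, t)
    end
end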
