Skip to content

Commit

Permalink
Merge pull request #80 from nishnik/cross_entropy
Browse files Browse the repository at this point in the history
Added test for Cross Entropy derivative
  • Loading branch information
Evizero authored Mar 1, 2017
2 parents 7334d54 + 0e1ed54 commit 086367e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
20 changes: 16 additions & 4 deletions src/supervised/other.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# ===============================================================
# L(y, t) = exp(t) - t*y

"""
doc"""
PoissonLoss <: SupervisedLoss
Loss under a Poisson noise distribution (KL-divergence)
``L(target, output) = exp(output) - target*output``
"""
immutable PoissonLoss <: SupervisedLoss end

Expand All @@ -26,12 +27,21 @@ isconvex(::PoissonLoss) = true
isstronglyconvex(::PoissonLoss) = false

# ===============================================================
# L(target, output) = - target*ln(output) - (1-target)*ln(1-output)

doc"""
CrossentropyLoss <: SupervisedLoss
Cross-entropy loss also known as log loss and logistic loss is defined as:
``L(target, output) = - target*ln(output) - (1-target)*ln(1-output)``
"""

immutable CrossentropyLoss <: SupervisedLoss end
typealias LogitProbLoss CrossentropyLoss

function value(loss::CrossentropyLoss, target::Number, output::Number)
target >= 0 && target <=1 || error("target must be in [0,1]")
output >= 0 && output <=1 || error("output must be in [0,1]")
if target == 0
-log(1 - output)
elseif target == 1
Expand All @@ -45,8 +55,10 @@ deriv2(loss::CrossentropyLoss, target::Number, output::Number) = (1-target) / (1
value_deriv(loss::CrossentropyLoss, target::Number, output::Number) = (value(loss,target,output), deriv(loss,target,output))

isdifferentiable(::CrossentropyLoss) = true
isdifferentiable(::CrossentropyLoss, y, t) = t != 0 && t != 1
istwicedifferentiable(::CrossentropyLoss) = true
istwicedifferentiable(::CrossentropyLoss, y, t) = t != 0 && t != 1
isconvex(::CrossentropyLoss) = true

# ===============================================================
# L(target, output) = sign(agreement) < 0 ? 1 : 0

29 changes: 25 additions & 4 deletions test/tst_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ function test_deriv(l::DistanceLoss, t_vec)
end
end

function test_deriv(l::SupervisedLoss, t_vec)
function test_deriv(l::SupervisedLoss, y_vec, t_vec)
@testset "$(l): " begin
for y in -10:.2:10, t in t_vec
for y in y_vec, t in t_vec
if isdifferentiable(l, y, t)
d_dual = epsilon(LossFunctions.value(l, y, dual(t, 1)))
d_comp = @inferred deriv(l, y, t)
Expand Down Expand Up @@ -210,6 +210,24 @@ function test_deriv2(l::DistanceLoss, t_vec)
end
end

function test_deriv2(l::SupervisedLoss, y_vec, t_vec)
@testset "$(l): " begin
for y in y_vec, t in t_vec
if istwicedifferentiable(l, y, t) && isdifferentiable(l, y, t)
d2_dual = epsilon(deriv(l, dual(y, 0), dual(t, 1)))
d2_comp = @inferred deriv2(l, y, t)
@test abs(d2_dual - d2_comp) < 1e-10
@test_approx_eq d2_comp @inferred(l''(y, t))
@test_approx_eq d2_comp @inferred deriv2(l, y, t)
@test_approx_eq d2_comp deriv2_fun(l)(y, t)
else
# y-t == 0 ? print(".") : print("$(y-t) ")
#print(".")
end
end
end
end

function test_scaledloss(l::Loss, t_vec, y_vec)
@testset "Scaling for $(l): " begin
for λ = (2.0, 2)
Expand Down Expand Up @@ -509,8 +527,11 @@ end
end
end

@testset "Test first derivatives of other losses" begin
test_deriv(PoissonLoss(), 0:30)
@testset "Test first and second derivatives of other losses" begin
test_deriv(PoissonLoss(), -10:.2:10, 0:30)
test_deriv2(PoissonLoss(), -10:.2:10, 0:30)
test_deriv(CrossentropyLoss(), 0:0.01:1, 0.01:0.01:0.99)
test_deriv2(CrossentropyLoss(), 0:0.01:1, 0.01:0.01:0.99)
end

@testset "Test second derivatives of distance-based losses" begin
Expand Down

0 comments on commit 086367e

Please sign in to comment.