Merge pull request #88 from mihirparadkar/master
Ordinals and minor changes to LogitDistLoss
Evizero committed Jul 10, 2017
2 parents 1c097fd + e635127 commit 2a87c9e
Showing 5 changed files with 93 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/LossFunctions.jl
@@ -46,6 +46,8 @@ export
CrossentropyLoss,
ZeroOneLoss,

OrdinalMarginLoss,

weightedloss,

AvgMode
@@ -60,6 +62,7 @@ include("supervised/margin.jl")
include("supervised/scaledloss.jl")
include("supervised/weightedbinary.jl")
include("supervised/other.jl")
include("supervised/ordinal.jl")
include("supervised/io.jl")

# allow using some special losses as function
4 changes: 2 additions & 2 deletions src/supervised/distance.jl
@@ -410,7 +410,7 @@ struct LogitDistLoss <: DistanceLoss end
function value(loss::LogitDistLoss, difference::Number)
er = exp(difference)
T = typeof(er)
-log(T(4) * er / abs2(one(T) + er))
-log(T(4)) - difference + 2log(one(T) + er)
end
function deriv{T<:Number}(loss::LogitDistLoss, difference::T)
tanh(difference / T(2))
@@ -424,7 +424,7 @@ function value_deriv(loss::LogitDistLoss, difference::Number)
er = exp(difference)
T = typeof(er)
er1 = one(T) + er
-log(T(4) * er / abs2(er1)), (er - one(T)) / (er1)
-log(T(4)) - difference + 2log(er1), (er - one(T)) / (er1)
end

issymmetric(::LogitDistLoss) = true
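
Note on this change: `value` (and, mirrored below, `value_deriv`) replaces `-log(T(4) * er / abs2(one(T) + er))` with the expanded form `-log(T(4)) - difference + 2log(one(T) + er)`. The two are algebraically identical, since log(4 * exp(d) / (1 + exp(d))^2) = log(4) + d - 2 * log(1 + exp(d)); the expanded form avoids forming the ratio and squaring `1 + er`. A minimal equivalence check in Julia, not part of the commit:

# Sketch: both closed forms of the LogitDistLoss value agree numerically.
old_form(d) = -log(4 * exp(d) / abs2(1 + exp(d)))
new_form(d) = -log(4) - d + 2 * log(1 + exp(d))
all(isapprox(old_form(d), new_form(d)) for d in -5:0.25:5)   # true
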
72 changes: 72 additions & 0 deletions src/supervised/ordinal.jl
@@ -0,0 +1,72 @@
"""
OrdinalMarginLoss <: SupervisedLoss
Modifies a margin loss `loss` to be used on an ordinal domain with
number of levels `N`. It treats each level as an integer between
1 and `N`, inclusive, and penalizes the output according to the sum
of level thresholds crossed relative to the target.
Assumes the target is encoded in an Index encoding scheme where levels
are numbered between 1 and `N`.
"""
struct OrdinalMarginLoss{L<:MarginLoss, N} <: SupervisedLoss
loss::L
end

function OrdinalMarginLoss(loss::T, ::Type{Val{N}}) where {T<:MarginLoss,N}
typeof(N) <: Number || _serror()
OrdinalMarginLoss{T,N}(loss)
end

#=
for fun in (:value, :deriv, :deriv2)
@eval @fastmath @generated function ($fun)(loss::OrdinalMarginLoss{T, N},
target::Number, output::Number) where {T <: MarginLoss, N}
quote
retval = zero(output)
@nexprs $N t -> begin
not_target = (t != target)
sgn = sign(target - t)
retval += not_target * ($($fun))(loss.loss, sgn, output - t)
end
retval
end
end
end =#

for fun in (:value, :deriv, :deriv2)
@eval @fastmath function ($fun)(loss::OrdinalMarginLoss{T, N},
target::Number, output::Number) where {T <: MarginLoss, N}
not_target = 1 != target
sgn = sign(target - 1)
retval = not_target * ($fun)(loss.loss, sgn, output - 1)
for t = 2:N
not_target = (t != target)
sgn = sign(target - t)
retval += not_target * ($fun)(loss.loss, sgn, output - t)
end
retval
end
end

for prop in [:isminimizable, :isdifferentiable,
:istwicedifferentiable,
:isconvex, :isstrictlyconvex,
:isstronglyconvex, :isnemitski,
:isunivfishercons, :isfishercons,
:islipschitzcont, :islocallylipschitzcont]
@eval ($prop)(l::OrdinalMarginLoss) = ($prop)(l.loss)
end

for fun in (:isdifferentiable, :istwicedifferentiable)
@eval function ($fun)(loss::OrdinalMarginLoss{T, N},
target::Number, output::Number) where {T, N}
for t = 1:target - 1
($fun)(loss.loss, output - t) || return false
end
for t = target + 1:N
($fun)(loss.loss, t - output) || return false
end
return true
end
end
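
In short, for a target level y in 1:N and a real-valued output, the `value` method above accumulates one margin-loss term per non-target level t, using sign(y - t) as the binary target and output - t as the prediction. A standalone sketch of that sum (the helper `ordinal_value` and the inline `hinge` are illustrative, not part of this file):

# One margin-loss term per non-target level, as in the value method above.
ordinal_value(margin_value, N, y, ŷ) =
    sum(margin_value(sign(y - t), ŷ - t) for t in 1:N if t != y)

# With the hinge loss max(0, 1 - target * output) and 5 levels:
hinge(target, output) = max(0, 1 - target * output)
ordinal_value(hinge, 5, 3, 2.5)   # 0.5: only the threshold at level 2 is violated
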
1 change: 0 additions & 1 deletion test/tst_api.jl
@@ -316,4 +316,3 @@ end
end
end
end

17 changes: 16 additions & 1 deletion test/tst_loss.jl
@@ -476,6 +476,22 @@ end
test_value(QuantileLoss(.7), _quantileloss, yr, tr)
end

@testset "Test ordinal losses against reference function" begin
function _ordinalhingeloss(y, t)
val = 0
for yp = 1:y - 1
val += max(0, 1 - t + yp)
end
for yp = y + 1:5
val += max(0, 1 + t - yp)
end
val
end
y = rand(1:5, 10); t = randn(10) .+ 3
test_value(OrdinalMarginLoss(HingeLoss(), Val{5}), _ordinalhingeloss, y, t)
end
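
The reference function unrolls the same threshold sum for `N = 5` with the hinge loss, so a single hand-checked point ties it to the new type (the concrete numbers below are illustrative and assume the 2017 LossFunctions API shown in this diff):

# Target level 2, output 3.2: the thresholds above the target at levels 3
# and 4 are violated by 1.2 and 0.2; every other term is zero.
_ordinalhingeloss(2, 3.2)                                # ≈ 1.4
value(OrdinalMarginLoss(HingeLoss(), Val{5}), 2, 3.2)    # ≈ 1.4
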


@testset "Test other loss against reference function" begin
_crossentropyloss(y, t) = -y*log(t) - (1-y)*log(1-t)
test_value(CrossentropyLoss(), _crossentropyloss, 0:0.01:1, 0.01:0.01:0.99)
@@ -598,4 +614,3 @@ end
end
end
end
