Skip to content

Commit

Permalink
Merge pull request #89 from JuliaML/cs/ordinal
Browse files Browse the repository at this point in the history
streamline OrdinalMarginLoss construction
  • Loading branch information
Evizero committed Jul 13, 2017
2 parents 2a87c9e + b6b2519 commit 9f45628
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/LossFunctions.jl
Expand Up @@ -47,6 +47,7 @@ export
ZeroOneLoss,

OrdinalMarginLoss,
OrdinalHingeLoss,

weightedloss,

Expand Down
59 changes: 28 additions & 31 deletions src/supervised/ordinal.jl 100755 → 100644
Expand Up @@ -9,41 +9,33 @@ level thresholds crossed relative to target
Assumes target is encoded in an Index encoding scheme where levels are
numbered between 1 and `N`
"""
struct OrdinalMarginLoss{L<:MarginLoss, N} <: SupervisedLoss
struct OrdinalMarginLoss{L<:MarginLoss} <: SupervisedLoss
loss::L
N::Int
OrdinalMarginLoss{L}(loss::L, N::Int) where {L<:MarginLoss} = new{L}(loss, N)
end

function OrdinalMarginLoss(loss::T, ::Type{Val{N}}) where {T<:MarginLoss,N}
typeof(N) <: Number || _serror()
OrdinalMarginLoss{T,N}(loss)
function OrdinalMarginLoss(loss::L, N::Int) where {L<:MarginLoss}
OrdinalMarginLoss{L}(loss, N)
end

#=
for fun in (:value, :deriv, :deriv2)
@eval @fastmath @generated function ($fun)(loss::OrdinalMarginLoss{T, N},
target::Number, output::Number) where {T <: MarginLoss, N}
quote
retval = zero(output)
@nexprs $N t -> begin
not_target = (t != target)
sgn = sign(target - t)
retval += not_target * ($($fun))(loss.loss, sgn, output - t)
end
retval
end
end
end =#
@generated function (::Type{T})(N::Int, args...) where {T<:OrdinalMarginLoss}
L = typeof(T) == UnionAll ? T.var.ub : T.parameters[1]
:(OrdinalMarginLoss($L(args...), N))
end

for fun in (:value, :deriv, :deriv2)
@eval @fastmath function ($fun)(loss::OrdinalMarginLoss{T, N},
target::Number, output::Number) where {T <: MarginLoss, N}
not_target = 1 != target
@eval @fastmath function ($fun)(
l::OrdinalMarginLoss,
target::Number,
output::Number)
not_target = Int(1 != target)
sgn = sign(target - 1)
retval = not_target * ($fun)(loss.loss, sgn, output - 1)
for t = 2:N
not_target = (t != target)
retval = not_target * ($fun)(l.loss, sgn, output - 1)
for t = 2:l.N
not_target = Int(t != target)
sgn = sign(target - t)
retval += not_target * ($fun)(loss.loss, sgn, output - t)
retval += not_target * ($fun)(l.loss, sgn, output - t)
end
retval
end
Expand All @@ -59,14 +51,19 @@ for prop in [:isminimizable, :isdifferentiable,
end

for fun in (:isdifferentiable, :istwicedifferentiable)
@eval function ($fun)(loss::OrdinalMarginLoss{T, N},
target::Number, output::Number) where {T, N}
@eval function ($fun)(
l::OrdinalMarginLoss,
target::Number,
output::Number)
for t = 1:target - 1
($fun)(loss.loss, output - t) || return false
($fun)(l.loss, output - t) || return false
end
for t = target + 1:N
($fun)(loss.loss, t - output) || return false
for t = target + 1:l.N
($fun)(l.loss, t - output) || return false
end
return true
end
end

const OrdinalHingeLoss = OrdinalMarginLoss{HingeLoss}
# const OrdinalSmoothedHingeLoss = OrdinalMarginLoss{<:SmoothedL1HingeLoss}
7 changes: 6 additions & 1 deletion test/tst_loss.jl
Expand Up @@ -318,6 +318,7 @@ end
@test L2DistLoss === LPDistLoss{2}
@test HingeLoss === L1HingeLoss
@test EpsilonInsLoss === L1EpsilonInsLoss
@test OrdinalHingeLoss === OrdinalMarginLoss{HingeLoss}
end

@testset "Test typestable supervised loss for type stability" begin
Expand Down Expand Up @@ -476,6 +477,9 @@ end
test_value(QuantileLoss(.7), _quantileloss, yr, tr)
end

const OrdinalSmoothedHingeLoss = OrdinalMarginLoss{<:SmoothedL1HingeLoss}
@test OrdinalSmoothedHingeLoss(4, 2.1) === OrdinalMarginLoss(SmoothedL1HingeLoss(2.1), 4)

@testset "Test ordinal losses against reference function" begin
function _ordinalhingeloss(y, t)
val = 0
Expand All @@ -488,7 +492,8 @@ end
val
end
y = rand(1:5, 10); t = randn(10) .+ 3
test_value(OrdinalMarginLoss(HingeLoss(), Val{5}), _ordinalhingeloss, y, t)
@test OrdinalHingeLoss(5) === OrdinalMarginLoss(HingeLoss(), 5)
test_value(OrdinalMarginLoss(HingeLoss(), 5), _ordinalhingeloss, y, t)
end


Expand Down

0 comments on commit 9f45628

Please sign in to comment.