From b6b2519b55e2fb371853d7a63f422f3070b34f80 Mon Sep 17 00:00:00 2001 From: Christof Stocker Date: Mon, 10 Jul 2017 16:00:12 +0200 Subject: [PATCH] streamline OrdinalMarginLoss construction --- src/LossFunctions.jl | 1 + src/supervised/ordinal.jl | 59 +++++++++++++++++++-------------------- test/tst_loss.jl | 7 ++++- 3 files changed, 35 insertions(+), 32 deletions(-) mode change 100755 => 100644 src/supervised/ordinal.jl diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index adbe96c..35ce9af 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -47,6 +47,7 @@ export ZeroOneLoss, OrdinalMarginLoss, + OrdinalHingeLoss, weightedloss, diff --git a/src/supervised/ordinal.jl b/src/supervised/ordinal.jl old mode 100755 new mode 100644 index edc867e..66d39e8 --- a/src/supervised/ordinal.jl +++ b/src/supervised/ordinal.jl @@ -9,41 +9,33 @@ level thresholds crossed relative to target Assumes target is encoded in an Index encoding scheme where levels are numbered between 1 and `N` """ -struct OrdinalMarginLoss{L<:MarginLoss, N} <: SupervisedLoss +struct OrdinalMarginLoss{L<:MarginLoss} <: SupervisedLoss loss::L + N::Int + OrdinalMarginLoss{L}(loss::L, N::Int) where {L<:MarginLoss} = new{L}(loss, N) end -function OrdinalMarginLoss(loss::T, ::Type{Val{N}}) where {T<:MarginLoss,N} - typeof(N) <: Number || _serror() - OrdinalMarginLoss{T,N}(loss) +function OrdinalMarginLoss(loss::L, N::Int) where {L<:MarginLoss} + OrdinalMarginLoss{L}(loss, N) end -#= -for fun in (:value, :deriv, :deriv2) - @eval @fastmath @generated function ($fun)(loss::OrdinalMarginLoss{T, N}, - target::Number, output::Number) where {T <: MarginLoss, N} - quote - retval = zero(output) - @nexprs $N t -> begin - not_target = (t != target) - sgn = sign(target - t) - retval += not_target * ($($fun))(loss.loss, sgn, output - t) - end - retval - end - end -end =# +@generated function (::Type{T})(N::Int, args...) where {T<:OrdinalMarginLoss} + L = typeof(T) == UnionAll ? T.var.ub : T.parameters[1] + :(OrdinalMarginLoss($L(args...), N)) +end for fun in (:value, :deriv, :deriv2) - @eval @fastmath function ($fun)(loss::OrdinalMarginLoss{T, N}, - target::Number, output::Number) where {T <: MarginLoss, N} - not_target = 1 != target + @eval @fastmath function ($fun)( + l::OrdinalMarginLoss, + target::Number, + output::Number) + not_target = Int(1 != target) sgn = sign(target - 1) - retval = not_target * ($fun)(loss.loss, sgn, output - 1) - for t = 2:N - not_target = (t != target) + retval = not_target * ($fun)(l.loss, sgn, output - 1) + for t = 2:l.N + not_target = Int(t != target) sgn = sign(target - t) - retval += not_target * ($fun)(loss.loss, sgn, output - t) + retval += not_target * ($fun)(l.loss, sgn, output - t) end retval end @@ -59,14 +51,19 @@ for prop in [:isminimizable, :isdifferentiable, end for fun in (:isdifferentiable, :istwicedifferentiable) - @eval function ($fun)(loss::OrdinalMarginLoss{T, N}, - target::Number, output::Number) where {T, N} + @eval function ($fun)( + l::OrdinalMarginLoss, + target::Number, + output::Number) for t = 1:target - 1 - ($fun)(loss.loss, output - t) || return false + ($fun)(l.loss, output - t) || return false end - for t = target + 1:N - ($fun)(loss.loss, t - output) || return false + for t = target + 1:l.N + ($fun)(l.loss, t - output) || return false end return true end end + +const OrdinalHingeLoss = OrdinalMarginLoss{HingeLoss} +# const OrdinalSmoothedHingeLoss = OrdinalMarginLoss{<:SmoothedL1HingeLoss} diff --git a/test/tst_loss.jl b/test/tst_loss.jl index b316012..13c30a7 100644 --- a/test/tst_loss.jl +++ b/test/tst_loss.jl @@ -318,6 +318,7 @@ end @test L2DistLoss === LPDistLoss{2} @test HingeLoss === L1HingeLoss @test EpsilonInsLoss === L1EpsilonInsLoss + @test OrdinalHingeLoss === OrdinalMarginLoss{HingeLoss} end @testset "Test typestable supervised loss for type stability" begin @@ -476,6 +477,9 @@ end test_value(QuantileLoss(.7), _quantileloss, yr, tr) end +const OrdinalSmoothedHingeLoss = OrdinalMarginLoss{<:SmoothedL1HingeLoss} +@test OrdinalSmoothedHingeLoss(4, 2.1) === OrdinalMarginLoss(SmoothedL1HingeLoss(2.1), 4) + @testset "Test ordinal losses against reference function" begin function _ordinalhingeloss(y, t) val = 0 @@ -488,7 +492,8 @@ end val end y = rand(1:5, 10); t = randn(10) .+ 3 - test_value(OrdinalMarginLoss(HingeLoss(), Val{5}), _ordinalhingeloss, y, t) + @test OrdinalHingeLoss(5) === OrdinalMarginLoss(HingeLoss(), 5) + test_value(OrdinalMarginLoss(HingeLoss(), 5), _ordinalhingeloss, y, t) end