Ordinals and minor changes to LogitDistLoss #88
Conversation
To address the question you asked me on Discourse.
That is a very fair question, since that file uses a couple of advanced concepts. I have not yet looked into ordinal losses, so I don't have any comments on the concrete implementation yet, but I am happy to walk through how to define such a "decorator" type.

First we need to define the type itself, which needs two type parameters, as you correctly assumed. We bind a type variable L for the wrapped loss and N for the number of levels:

struct OrdinalMarginLoss{L<:MarginLoss,N} <: SupervisedLoss
    loss::L
end

With this type defined we have in principle everything we need to work with it, but it would be unnecessarily verbose:

julia> OrdinalMarginLoss{HingeLoss,5}(HingeLoss())
OrdinalMarginLoss{LossFunctions.L1HingeLoss,5}(LossFunctions.L1HingeLoss())

To make this type more convenient to work with we should also define an outer constructor. We will use the type Val to pass N:

OrdinalMarginLoss(loss::T, ::Type{Val{N}}) where {T,N} = OrdinalMarginLoss{T,N}(loss)

This will allow a shorter constructor where we don't have to repeat typing the loss:

julia> OrdinalMarginLoss(HingeLoss(), Val{5})
OrdinalMarginLoss{LossFunctions.L1HingeLoss,5}(LossFunctions.L1HingeLoss())
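For reference, the whole pattern can be sketched in a self-contained way. The abstract types and DummyLoss below are minimal stand-ins for illustration, not the real LossFunctions definitions, and nlevels is a hypothetical helper:

```julia
# Minimal stand-in type hierarchy -- not the real LossFunctions definitions.
abstract type SupervisedLoss end
abstract type MarginLoss <: SupervisedLoss end
struct DummyLoss <: MarginLoss end

# Decorator type: L is the wrapped loss type, N the number of ordinal levels.
struct OrdinalMarginLoss{L<:MarginLoss,N} <: SupervisedLoss
    loss::L
end

# Outer constructor lifting N into the type domain via Val.
OrdinalMarginLoss(loss::T, ::Type{Val{N}}) where {T,N} = OrdinalMarginLoss{T,N}(loss)

# Hypothetical helper: since N is a type parameter, it can be
# recovered at compile time without storing it as a field.
nlevels(::OrdinalMarginLoss{L,N}) where {L,N} = N

ord = OrdinalMarginLoss(DummyLoss(), Val{5})
@assert ord isa OrdinalMarginLoss{DummyLoss,5}
@assert nlevels(ord) == 5
```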
src/supervised/ordinal.jl
for fun in (:value, :deriv, :deriv2)
    @eval @fastmath function ($fun)(loss::OrdinalMarginLoss, target::Number, output::Number)
        retval = 0
I think a problem you may have is that initializing this variable with an Int may cause type instability, since it is likely that value or deriv return some float.
You're absolutely right, I just checked it with @code_warntype and it shows that retval is of type Any. I'm thinking that I should initialize it with the type of the output, since the type instability is making the function quite slow.
Since you are only using + to accumulate items, it should work if you initialize with the first result.
i.e. set it to retval = ($fun)(loss.loss, 1, output - 1) and change the first loop to for t = 2:target-1.
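A generic sketch of this seeding pattern, outside the PR's macro context (f below is a stand-in for value/deriv, not actual LossFunctions code):

```julia
# f stands in for value/deriv: it may return a Float64 even for Int input.
f(x) = abs2(x) / 2

# Seeding with an Int literal makes retval change type on the first +=.
function accumulate_unstable(xs)
    retval = 0
    for x in xs
        retval += f(x)
    end
    retval
end

# Seeding with the first result fixes retval's type from the start.
function accumulate_stable(xs)
    retval = f(xs[1])
    for x in xs[2:end]
        retval += f(x)
    end
    retval
end

@assert accumulate_stable([1, 2, 3]) == 7.0
@assert accumulate_stable([1, 2, 3]) == accumulate_unstable([1, 2, 3])
```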
Actually, it seems that even if I specify retval = zero(output), I still get type instability: the type of retval is inferred to be Any. How would you enforce type stability in this case?
That's because in

struct OrdinalMarginLoss <: SupervisedLoss
    loss::MarginLoss
    nlevels::Int
end

the member loss is weakly typed (its declared type is the abstract MarginLoss). This means that no matter how beautifully clean the function value is, the call value(loss.loss, ...) will always be a run-time lookup where the compiler can't know which method will be called. Try defining the type as I outlined in an earlier post, using type variables.
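A self-contained toy illustration of the difference (hypothetical types, not the LossFunctions ones):

```julia
abstract type AbstractThing end
struct Thing <: AbstractThing end
compute(::Thing, x) = 2x

# Abstractly typed field: compute(w.inner, x) is a run-time dispatch,
# because the compiler only knows that inner isa AbstractThing.
struct WeakWrapper
    inner::AbstractThing
end

# Field type as a parameter: for a StrongWrapper{Thing} the compiler
# knows inner's concrete type and can dispatch statically.
struct StrongWrapper{T<:AbstractThing}
    inner::T
end

call(w, x) = compute(w.inner, x)

# Same results either way; compare the two with @code_warntype to see
# that only the parametric version infers a concrete return type.
@assert call(WeakWrapper(Thing()), 3) == 6
@assert call(StrongWrapper(Thing()), 3) == 6
```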
Sorry, I forgot to provide some more context. I've redefined

struct OrdinalMarginLoss{L<:MarginLoss, N} <: SupervisedLoss
    loss::MarginLoss
end

function OrdinalMarginLoss(loss::T, ::Type{Val{N}}) where {T<:MarginLoss,N}
    typeof(N) <: Number || _serror()
    OrdinalMarginLoss{T,N}(loss)
end

as you outlined, and I'm trying to define the value, deriv, and deriv2 functions as follows:
for fun in (:value, :deriv, :deriv2)
    @eval @fastmath @generated function ($fun)(loss::OrdinalMarginLoss{T, N},
            target::Number, output::Number) where {T <: MarginLoss, N}
        quote
            retval = zero(output)
            @nexprs $N t -> begin
                not_target = (t != target)
                sgn = sign(target - t)
                retval += not_target * ($($fun))(loss.loss, sgn, output - t)
            end
            retval
        end
    end
end
It gives the correct answer, but @code_warntype still flags retval as being of type ::Any. Is there a different reason for this?
struct OrdinalMarginLoss{L<:MarginLoss,N} <: SupervisedLoss
    loss::L
end

Note the loss::L. This is the important part and the sole reason we even want the type of the loss as a type parameter.
I wouldn't bother with the generated function and @fastmath at first. Try to make the currently submitted implementation type stable first; otherwise you are dealing with many complex factors at the same time, which can be quite a hassle when trying to track down issues.
Plus, I think almost all of the performance benefit will come from type stability, with the unrolling probably being a small micro-optimization.
I missed the
Thanks for investing time in this.
After some testing, it seems that the loop version and the unrolled version are almost equally fast, but I think the loop is more readable. I just pushed the update to this PR.
src/supervised/ordinal.jl
for fun in (:value, :deriv, :deriv2)
    @eval @fastmath function ($fun)(loss::OrdinalMarginLoss{T, N},
            target::Number, output::Number) where {T <: MarginLoss, N}
        retval = zero(output)
I would still suggest using the first loop element as initialization, to catch all kinds of sneaky type-instability edge cases. For example, this code is unstable in the following case:
julia> using LossFunctions
julia> output = 1 # Int
1
julia> value(L2HingeLoss(), -1., output)
4.0
julia> value(LogitMarginLoss(), -1, output)
1.3132616875182228
This is because the type of the return value depends on the type of target, the type of output, and on what the loss does. Some losses, such as LogitMarginLoss, will always result in a float no matter the type of output.
That does make sense, and it's easier to reason about types if the underlying MarginLoss is responsible for determining the output type.
Changes made.
src/supervised/ordinal.jl
retval = zero(output)
for t = 1:N
    not_target = (t != target)
    sgn = sign(target - t)
Does this always work out? I remember sign having some inconvenient behaviour, for example:

julia> sign(0)
0
The only time that sign(x) ∉ {-1, 1} is indeed when x == 0. However, in this case target == t, so the whole thing gets multiplied by 0 anyway, which is what should happen.
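A concrete sketch of that edge case (simplified: the multiplier below stands in for the loss term in the actual loop):

```julia
# At t == target, sign(target - t) degenerates to 0, but the not_target
# flag is also 0, so the whole term vanishes regardless.
target, t, output = 3, 3, 2.5
not_target = (t != target)      # false, which acts as 0 when multiplied
sgn = sign(target - t)          # sign(0) == 0
term = not_target * sgn * (output - t)
@assert sgn == 0
@assert term == 0
```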
Sorry for the delay. This PR is on my agenda today!
Very nice, thanks!
LogitDistLoss incurs numerical problems for large values of diff, since there is an exp(abs2(...)) term which blows up to Inf very quickly. These problems are improved somewhat by simplifying the equation.
I also added an "ordinalization" of margin losses that allows them to predict over ordinal targets, so that an output greater than the maximum or less than the minimum target is not penalized as strongly as it would be by a DiffLoss.
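A minimal illustration of the kind of overflow being addressed (generic, not the exact LogitDistLoss formula):

```julia
# exp of a large squared difference overflows Float64 to Inf...
diff = 30.0
@assert exp(abs2(diff)) == Inf    # exp(900) overflows (exp caps near 709)

# ...whereas algebraically rearranged forms that keep the exponent small,
# e.g. working through log1p, stay finite:
@assert isfinite(log1p(exp(-abs(diff))))
```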