Simplified Recorder and Metric #10
Conversation
I noticed all of the metrics here have a "Recorder" suffix. Would it help to add a new namespace and/or abstract type such that these names can be trimmed?
Agreed. I like that. There is also a lot of repeated code in recorder.jl. Is there a nice idiomatic way to improve that?
The duplication doesn't look too bad, but one way would be to extract the {Train,Validation} part into a field or type parameter. I also noticed that fast.ai uses a single class for their training and validation loss recorder; this could work as well (more so for the loss than for the training-phase-agnostic metrics).
I don't think inheritance is appropriate here to avoid duplication. Firstly, defining

Base.getindex(lr::Recorder, i) = lr.log[i]

creates an implicit assumption about the fields of concrete subtypes of Recorder. The correct thing to do would be to define an interface getlog(lr::Recorder) that each subtype must extend, then use getlog(lr)[i] instead of lr.log[i]. But I am not sure such an interface is worth it. If we want to avoid maintaining a separate getindex definition, then we should just use a function:

recorder_getindex(lr, i) = lr.log[i]
Base.getindex(lr::TrainLoss, i) = recorder_getindex(lr, i)

This has the advantage of only maintaining one function, but it does mean we have to define a separate Base.getindex for TrainLoss, etc. But I think with the other suggestions below, the code bloat is minimal.
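(If that per-type boilerplate ever does bother us, a small @eval loop could stamp out the forwarding methods; this is just a sketch, and the concrete type names below are assumptions:)

for T in (:TrainLoss, :ValidateLoss, :SmoothTrainLoss, :SmoothValidateLoss)
    @eval Base.getindex(lr::$T, i) = recorder_getindex(lr, i)
end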
I also feel weird about having SmoothRecorder <: Recorder. Multiple levels of inheritance are rare for Julia code in my experience. More abstractly, we should think about what a Recorder.SmoothTrainLoss is: it needs to take a state from the Learner (i.e. the loss), apply a smoothing function to it, then record the value. What makes this different from Recorder.TrainLoss? In the latter case, we take a state from the Learner (again the loss) and record the value. Could we generalize these two cases so that a single type can perform both duties? All that is missing from Recorder.TrainLoss is the smoothing function; instead it implicitly applies a function that does nothing to the loss. We could have a single type like so:
"""
Utility type for smoothing a series of values
"""
mutable struct Smoother
alpha::Real
val::Real
end
Smoother(alpha) = Smoother(alpha, 0.0)
function (asl::Smoother)(value)
asl.val = asl.alpha*asl.val+(1-asl.alpha)*value
return asl.val
end
mutable struct TrainLoss{F, T<:Real} <: AbstractCallback
    f::F
    log::Array{T, 2}
end

TrainLoss() = TrainLoss(identity, zeros(Float32, 0, 0))
SmoothTrainLoss(alpha = 0.98) = TrainLoss(Smoother(alpha), zeros(Float32, 0, 0))
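(A quick usage sketch with made-up values, showing how the one type covers both duties; the variable names are hypothetical and the numbers are just the running exponential average:)

plain  = TrainLoss()           # f == identity, records raw losses
smooth = SmoothTrainLoss(0.9)  # f == Smoother(0.9), records smoothed losses

s = Smoother(0.9)
s(1.0)  # 0.9 * 0.0 + 0.1 * 1.0 == 0.1
s(1.0)  # 0.9 * 0.1 + 0.1 * 1.0 == 0.19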
To me, this is more Julian and functional. Instead of using inheritance, we use higher-order functions. The only issue with the above is reset!(smoother::Smoother). This can be handled by dispatch:

function before_fit(lr::TrainLoss{Smoother}, lrn::Learner, epoch_count, batch_size)
    reset!(lr.f)
    lr.log = zeros(epoch_count, batch_size)
end
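(reset! itself isn't defined anywhere above; a minimal sketch, assuming a reset just zeroes the running value:)

reset!(s::Smoother) = (s.val = 0.0; s)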
We can take this even further and make Recorder a concrete type... but that may go too far. Though it would make the whole getindex code-bloat issue simpler. In this scenario, we would define

mutable struct Recorder{S, F, T<:Real} <: AbstractCallback
    f::F
    log::Array{T, 2}
end

ValidateLoss() = Recorder{:Validate}(identity, zeros(Float32, 0, 0))
batch_validate_loss(lr::Recorder{:Validate}, ...) = # do stuff
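(One wrinkle: with three type parameters, Recorder{:Validate}(identity, ...) needs an outer constructor that fills in F and T from the arguments; a sketch:)

Recorder{S}(f::F, log::Array{T, 2}) where {S, F, T<:Real} = Recorder{S, F, T}(f, log)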
What does everyone think?
src/recorder.jl (outdated)

https://github.com/fastai/fastai2/blob/master/fastai2/learner.py
The documentation is copied from here
https://dev.fast.ai/learner#Recorder
=#
module recorder
Module names should be capitalized.
I really like

mutable struct Recorder{S, F, T<:Real} <: AbstractCallback
    f::F
    log::Array{T, 2}
end

ValidateLoss() = Recorder{:Validate}(identity, zeros(Float32, 0, 0))
batch_validate_loss(lr::Recorder{:Validate}, ...) = # do stuff

Can we use the same trick with smoothing, like this?

ValidateLoss() = Recorder{:Nothing, :Validate}(identity, zeros(Float32, 0, 0))
SmoothValidateLoss() = Recorder{:Smooth, :Validate}(identity, zeros(Float32, 0, 0))
batch_validate_loss(lr::Recorder{:Smooth, :Validate}, ...) = # do stuff with Smoother
Yeah, but for the smooth loss you don't need to specify :Smooth. The type parameter F in Recorder is the type of f, which is the function applied to the loss before recording it. This type will be Smoother when you pass in a Smoother as the function. So you can specialize on that:
mutable struct Recorder{S, F, T<:Real} <: AbstractCallback
    f::F
    log::Array{T, 2}
end

TrainLoss() = Recorder{:Train}(identity, zeros(Float32, 0, 0))
SmoothTrainLoss(alpha = 0.98) = Recorder{:Train}(Smoother(alpha), zeros(Float32, 0, 0))
ValidateLoss() = Recorder{:Validate}(identity, zeros(Float32, 0, 0))
SmoothValidateLoss(alpha = 0.98) = Recorder{:Validate}(Smoother(alpha), zeros(Float32, 0, 0))

function _record_state!(lr::Recorder, epoch, batch, state)
    lr.log[epoch, batch] = lr.f(state)
end

# example specializing on :Train vs :Validate
batch_train_loss(lr::Recorder{:Train}, lrn::AbstractLearner, epoch, batch, loss) =
    _record_state!(lr, epoch, batch, loss)
batch_validate_loss(lr::Recorder{:Validate}, lrn::AbstractLearner, epoch, batch, loss) =
    _record_state!(lr, epoch, batch, loss)

# example specializing on Smoother vs identity
function before_fit(lr::Recorder, lrn::Learner, epoch_count, batch_size)
    lr.log = zeros(epoch_count, batch_size)
end

function before_fit(lr::Recorder{<:Any, Smoother}, lrn::Learner, epoch_count, batch_size)
    reset!(lr.f)
    lr.log = zeros(epoch_count, batch_size)
end
You can do various combinations of the Recorder{S, F, T} type parameters to specialize.
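(For example, a hypothetical method that fires only for smoothed training-loss recorders, specializing on both the phase tag S and the function type F at once:)

batch_train_loss(lr::Recorder{:Train, Smoother}, lrn::AbstractLearner, epoch, batch, loss) =
    _record_state!(lr, epoch, batch, loss)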
src/recorder.jl (outdated)

end
return rec
end
abstract type Recorder <: AbstractCallback end
Since this is now in the namespace of recorder, you need using ..FastAI: AbstractCallback after module recorder to use AbstractCallback here.
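(Concretely, the top of src/recorder.jl would look something like the following sketch; the parent module name FastAI is an assumption:)

module recorder

using ..FastAI: AbstractCallback

abstract type Recorder <: AbstractCallback end

# ... recorder types go here ...

end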
src/FastAI.jl (outdated)

export TrainLossRecorder
export ValidateLossRecorder
export TrainSmoothLossRecorder
export ValidateSmoothLossRecorder
I think @ToucheSir's comment about type names means that here we should export the module Recorder but not the types within it. Then in user code, a person needs to write Recorder.TrainLoss to get the training loss recorder. I agree that this will look cleaner than having "Recorder" at the end of each type name.
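(In src/FastAI.jl, that would look roughly like the following sketch:)

include("recorder.jl")
export Recorder  # export the module itself, not the types inside it

# user code then writes:
rec = Recorder.TrainLoss()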
src/recorder.jl (outdated)

log::Array{Real,2}
end
TrainLoss()=TrainLoss([])
batch_train_loss(lr::TraninLoss,lrn::AbstractLearner, epoch, batch, loss) = lr.log[epoch,batch] = loss
Think this is a typo. Should be TrainLoss, not TraninLoss.
Many thanks for all the suggestions here and in the "Common Code" thread. I spent the week learning those techniques and rewriting this code several ways. In the end, however, I went with my original design. Type arguments did not save many lines of code, limited a Recorder to a log and a preprocessing function, and were harder for a non-Julian to understand. I thought the last point was important if we want to attract Python developers and have them add their own Recorders. However, I will definitely use type arguments in the future. It is a really cool and elegant technique.
Simplified the design of the Recorder and Metric concepts. Recorders are now simple Callbacks that log particular statistics; the separate Metric and Recorder abstractions have gone away.