Realize performance improvements for models implementing new data front-end #501
Squashed commits:
- add fit_only! logic that caches reformatted/resampled user-data
- add test
- minor comment change
- "Bump version to 0.16.3"
- resolve #499
- add test
- add tests for reformat front-end
- oops
- add model api tests for reformat front-end
- tidy up
- add test for reformat logic in learning network context
- do not update state of a Machine{<:Static} on fit! unless hyperparams change
- add tests
- allow speedup buy-out with machine(model, args..., cache=false)
- oops
- have KNN models buy into reformat data front-end for better testing
- introduce "back-end" resampling in evaluate!
- implement data front-end on prediction/transformation side
- update machine show
- more tests; add cache=... to evaluate(model, ...)
- more tests
- make `cache` hyperparam of Resampler for passing onto wrapped machines
- more tests
- Composite models do not cache data by default (no benefit)
- correct comment
- bump [compat] MLJModelInterface = "^0.3.7" (essential)
Codecov Report

```diff
@@            Coverage Diff             @@
##              dev     #501      +/-   ##
==========================================
- Coverage   82.72%   82.57%    -0.15%
==========================================
  Files          39       39
  Lines        2935     2962       +27
==========================================
+ Hits         2428     2446       +18
- Misses        507      516        +9
```
LGTM
Continuation of #492 with squashed commits.
Addresses key parts of #309.
Builds on MLJModelInterface PR 76 (now part of MLJModelInterface 0.3.7), which provides a way for an implementer of MLJ's model interface to articulate: (i) how user data (eg, a table) is transformed into some model-specific representation needed by the core algorithm (eg, a matrix) - the implementer overloads `reformat`; and (ii) how one constructs subsamples of the model-specific representation, given a vector of observation indices - the implementer overloads `selectrows`. Detailed responsibilities of this optional "data front-end" implementation are given in this proposed doc PR.

This PR introduces optional caching of the model-specific data representations for machines, and for subsamples of these representations, which can speed up computations at the cost of larger memory allocations. Machines have a new type parameter `C` which determines whether or not they buy into data caching, with `C=true` the default. A user opts out during machine construction, as in `mach = machine(model, X, y, cache=false)`.
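For orientation, here is a minimal sketch of what buying into the front-end looks like, following the pattern documented for MLJModelInterface. Here `SomeSupervised` is a hypothetical model type, and the adjoint-matrix representation is just one possible choice:

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# hypothetical deterministic supervised model:
mutable struct SomeSupervised <: MMI.Deterministic end

# (i) user data -> model-specific representation (table -> adjoint matrix):
MMI.reformat(::SomeSupervised, X, y) = (MMI.matrix(X)', y)
MMI.reformat(::SomeSupervised, X) = (MMI.matrix(X)',)

# (ii) subsampling the model-specific representation, given row indices I:
MMI.selectrows(::SomeSupervised, I, Xmatrix, y) =
    (view(Xmatrix, :, I), view(y, I))
MMI.selectrows(::SomeSupervised, I, Xmatrix) = (view(Xmatrix, :, I),)
```

With these overloadings, `fit` and the operations receive `Xmatrix` (and `y`) in place of the user-supplied data, and machines can cache and subsample that representation directly.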
Some key implementation details to help with a review:
- A new machine type parameter `C` (see above), defaulting to `true`, except for composite models, where caching provides no obvious advantage. (Machines within a learning network defining the composite may still cache with some benefit.)
- New machine fields `data` and `resampled_data`, for caching model-specific representations of what is returned by `[arg() for arg in args()]` (recall `args` is a tuple of `Source` or `Node` objects). We will refer to the latter as user-supplied data.
- `fit_only!` at https://github.com/alan-turing-institute/MLJBase.jl/blob/2eda522877451fdd26838faf279ac74c5a5a837c/src/machines.jl#L417 is adapted to apply `reformat` to the user-supplied data when `fit!` is first called, or when the user-supplied data has changed because of upstream changes in a learning network. This data is cached for future calls to `fit!`; `data` is resampled according to the `rows` specified in the `fit!` call and also cached. Future calls to `fit!` only perform `reformat` and `selectrows` as needed. All this assumes `C=true`; otherwise `reformat` and `selectrows` are called every time and nothing is cached.
- If `predict`, say, is called on a machine with the `rows=...` syntax then, assuming the machine has caching enabled, `selectrows` is applied to the cached `data`, avoiding a call to `reformat`. If new data `Xnew` is supplied instead, as in `predict(mach, Xnew)`, then `reformat` is naturally necessary. The relevant changes are in src/operations.jl. (See the first sketch following this list.)
- In `evaluate!`, instead of calling `predict(mach, Xtest)` as currently, we call `predict(mach, rows=test)`. This avoids one call to `reformat`. However, there is one case distinction: the uncommon case where one of the measures has `is_feature_dependent=true` (for measures with `X` in the signature), in which case `Xtest` must still be created, using the slow `selectrows(X, test)` fallback.
- A new `cache=...` keyword in machine constructors, and in the `Resampler` objects. An interface point for `cache` in `TunedModels` is also needed, and is addressed in this MLJTuning PR, which passes locally (Julia 1.5). (See the second sketch following this list.)
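Below is a usage sketch of the machine-level caching behavior (the first sketch referenced in the list above). The `MMI.fit`/`MMI.predict` methods are minimal hypothetical stand-ins so the example is self-contained; the comments record where `reformat` and `selectrows` are expected to fire:

```julia
using MLJBase, Statistics

# minimal methods so SomeSupervised (defined earlier) is usable; a real
# model does something less trivial than predicting the training mean:
MMI.fit(::SomeSupervised, verbosity, Xmatrix, y) = (mean(y), nothing, nothing)
MMI.predict(::SomeSupervised, fitresult, Xmatrix) =
    fill(fitresult, size(Xmatrix, 2))

X, y = make_regression(200, 3)          # synthetic table and target vector

mach = machine(SomeSupervised(), X, y)  # caching on by default (C=true)

fit!(mach, rows=1:100)       # reformat applied to user-supplied data and
                             # cached; rows 1:100 cached as resampled_data
fit!(mach, rows=1:100)       # neither reformat nor selectrows called again

predict(mach, rows=101:150)  # selectrows on the cached data; no reformat
Xnew = selectrows(X, 151:160)
predict(mach, Xnew)          # new user-data, so reformat is applied

# opting out:
mach2 = machine(SomeSupervised(), X, y, cache=false)  # reformat/selectrows
                                                      # called every time
```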
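And the second sketch referenced in the list: the new `cache` interface points. The exact keyword sets shown for `evaluate` and `Resampler` are assumptions based on this PR's commit messages, not a definitive API:

```julia
# cache=false is passed on to the machines created internally:
evaluate(SomeSupervised(), X, y,
         resampling=CV(nfolds=3),
         measure=rms,
         cache=false)

# the same hyperparameter on the Resampler wrapper:
resampler = MLJBase.Resampler(model=SomeSupervised(),
                              resampling=Holdout(fraction_train=0.7),
                              measure=rms,
                              cache=false)
mach = machine(resampler, X, y)
fit!(mach)
```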