Description
There is a fundamental problem with the way we handle batches, at least for applications where the extra GPU speed gained from batching is important. The issue is the incompatibility of batching with observation resampling, as conventionally understood.
So, for example, if we are wrapping a model for cross-validation, then the observations get split up multiple times into different test/train sets. At present, a "batch" is understood to consist of multiple "observations", which means that resampling an MLJFlux model breaks up the batches, which is an expensive operation for large data sets.
I'm guessing this is a very familiar problem to people in deep learning, so I am copying some of them in for comment and will post a link on Slack. The solution I am considering for MLJ is to regard a "batch" of images as an unbreakable object that we consequently view as an observation, by definition. It would be natural to introduce a new parametric scientific type `Batch{SomeAtomicScitype}` to articulate a model's participation in this convention.
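To make the idea concrete, here is a purely illustrative sketch of what "batch as observation" would mean for resampling. `ImageBatch` and `make_image` are placeholder names, not proposed API, and the `Batch` scitype itself is not defined here:

```julia
# Purely illustrative: `ImageBatch` is a placeholder wrapper, not proposed API.
struct ImageBatch{T}
    images::Vector{T}   # e.g. 32 images kept together and moved to the GPU as one unit
end

# The data set becomes a vector of batches, so the "observations" seen by
# Holdout/CV resampling are whole batches; a train/test split can never cut
# a batch in half:
make_image() = rand(Float32, 28, 28, 3)
data = [ImageBatch([make_image() for _ in 1:32]) for _ in 1:100]
length(data)   # 100 observations (batches), not 3200 images
```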
Thoughts anyone?
Some consequences of this breaking change would be:
- `batch_size` disappears as a hyper-parameter of MLJFlux models, at least for `ImageClassifier`, but probably for all the models, for simplicity. Changing the batch size then becomes the responsibility of a pre-processing transformer external to the model (see the first sketch after this list). I need to give some thought to transformers that reduce the number of observations when inserted into MLJ pipelines (and learning networks, more generally). If that works, "smart" training of MLJ pipelines would mean no "re-batching" when retraining the composite model, unless the batch size changes, which is good.
- with this change one could implement the `reformat` and `selectrows` (now the same as "select batches") functions that constitute buy-in for MLJ's new data front-end (see the second sketch after this list).
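For the first point, a rough sketch of what an external re-batching transformer might look like, assuming the `Static` model type and `transform` signature from MLJModelInterface; `Rebatcher` and its internals are hypothetical, for discussion only:

```julia
# Hypothetical external re-batching transformer (all names are placeholders).
import MLJModelInterface
const MMI = MLJModelInterface

mutable struct Rebatcher <: MMI.Static
    batch_size::Int
end

# Group a flat vector of images into batches of `batch_size`; each batch then
# counts as a single observation for everything downstream, resampling included.
function MMI.transform(model::Rebatcher, _fitresult, images)
    n = model.batch_size
    return [images[i:min(i + n - 1, length(images))] for i in 1:n:length(images)]
end
```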
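For the second point, a minimal sketch of the buy-in, assuming the `reformat`/`selectrows` signatures of the data front-end in MLJModelInterface; the `ImageClassifier` method bodies and the `to_flux_batch` helper are placeholders, not a real implementation:

```julia
import MLJModelInterface
const MMI = MLJModelInterface
using MLJFlux: ImageClassifier

# Placeholder conversion: in reality this would build Float32 arrays and move
# them to the GPU; here it is just an identity stand-in.
to_flux_batch(batch) = batch

function MMI.reformat(::ImageClassifier, X, y)
    # Under the proposed convention X and y are already vectors of batches;
    # convert each once, up front (returned as a tuple, per the front-end contract):
    return (map(to_flux_batch, X), map(to_flux_batch, y))
end

# "Rows" are now whole batches, so resampling indexes batches and never splits one:
MMI.selectrows(::ImageClassifier, I, Xbatches, ybatches) =
    (view(Xbatches, I), view(ybatches, I))
```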