Add parameter to control flattening behavior of built-in MLP model. #78
The built-in MLP model currently flattens input tensors from an assumed shape of `[batch_size, ...]` to `[batch_size, -1]` before passing them to the first linear layer. For example, an input tensor of shape `[8, 32, 32]` is reshaped to `[8, 1024]`, on the assumption that the first dimension (or the last, in Fortran order) is the batch size and all remaining dimensions are features. This behavior is over-prescriptive and not necessarily obvious, especially since it deviates from standard PyTorch broadcasting behavior.

This PR codifies the behavior by adding a new parameter to the MLP config (`flatten_non_batch_dims`) that enables the automatic flattening. If set to `false`, no reshaping takes place and standard PyTorch broadcasting rules apply to the inputs and the MLP dimensions. To maintain backwards compatibility, this parameter defaults to `true`.
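To illustrate the two modes, here is a minimal sketch of the flattening logic. The `MLP` class, its constructor signature, and the layer sizes are hypothetical stand-ins, not the library's actual implementation; only the parameter name `flatten_non_batch_dims` comes from this PR.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical sketch of the built-in MLP's flattening behavior.

    Only the `flatten_non_batch_dims` parameter is from the PR; the
    class structure and layer sizes here are illustrative.
    """

    def __init__(self, in_features: int, out_features: int,
                 flatten_non_batch_dims: bool = True):
        super().__init__()
        self.flatten_non_batch_dims = flatten_non_batch_dims
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.flatten_non_batch_dims:
            # Treat dim 0 as batch_size and collapse everything else
            # into one feature dimension: [B, ...] -> [B, -1].
            x = x.reshape(x.shape[0], -1)
        # With flattening disabled, standard PyTorch broadcasting
        # applies: nn.Linear maps only the last dimension.
        return self.fc(x)


# Default (flattening on): [8, 32, 32] -> [8, 1024] -> [8, 10]
mlp = MLP(32 * 32, 10)
print(mlp(torch.randn(8, 32, 32)).shape)  # torch.Size([8, 10])

# Flattening off: nn.Linear(32, 10) acts on the last dim -> [8, 32, 10]
mlp_raw = MLP(32, 10, flatten_non_batch_dims=False)
print(mlp_raw(torch.randn(8, 32, 32)).shape)  # torch.Size([8, 32, 10])
```

Note that with `flatten_non_batch_dims=false` the first linear layer's input size must match the input's last dimension rather than the product of all non-batch dimensions.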