Conversation

@romerojosh
Collaborator

The built-in MLP model currently flattens input tensors from an assumed shape of [batch_size, ...] to [batch_size, -1] before passing them to the first linear layer. For example, if a user passes in an input tensor of shape [8, 32, 32], that input is reshaped to [8, 1024], on the assumption that the first dimension (or the last dimension in Fortran) is the batch size and all other dimensions are features. This behavior is a bit over-prescriptive and not necessarily obvious, especially as it deviates from standard PyTorch broadcasting behavior.
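For reference, a minimal standalone PyTorch sketch of the current flattening behavior (the tensor and layer sizes are illustrative, taken from the example above; this is not the library's actual implementation):

```python
import torch

# Input of shape [batch_size, ...], e.g. [8, 32, 32].
x = torch.randn(8, 32, 32)

# All non-batch dimensions are flattened into a single feature
# dimension, so the first linear layer sees [8, 1024].
x_flat = x.reshape(x.shape[0], -1)

linear = torch.nn.Linear(32 * 32, 16)  # in_features must equal 1024
y = linear(x_flat)                     # shape: [8, 16]
```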

This PR makes this behavior explicit and configurable by adding a new parameter to the MLP config (flatten_non_batch_dims) that enables the automatic flattening. If set to false, no reshaping takes place and standard PyTorch broadcasting rules apply between the inputs and the MLP dimensions. To maintain backwards compatibility, this parameter defaults to true.
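With flattening disabled, inputs follow plain PyTorch broadcasting; a minimal sketch of the resulting semantics (again with illustrative sizes, and omitting the actual config plumbing):

```python
import torch

# With flatten_non_batch_dims set to false, the input is passed through
# unmodified: nn.Linear applies to the last dimension only, and all
# leading dimensions broadcast.
x = torch.randn(8, 32, 32)
linear = torch.nn.Linear(32, 16)  # in_features matches the last dim
y = linear(x)                     # shape: [8, 32, 16]
```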

@romerojosh romerojosh changed the title Add parameter to control flattening behavior of built-in MLP models. Add parameter to control flattening behavior of built-in MLP model. Sep 18, 2025
@romerojosh
Collaborator Author

/build_and_test

@github-actions

🚀 Build workflow triggered! View run

@github-actions

✅ Build workflow passed! View run

@romerojosh
Collaborator Author

/build_and_test

@github-actions

🚀 Build workflow triggered! View run

@github-actions

✅ Build workflow passed! View run

Signed-off-by: Josh Romero <joshr@nvidia.com>
@azrael417 azrael417 merged commit 90c385d into master Sep 30, 2025
4 checks passed
@romerojosh romerojosh deleted the mlp_flatten_option branch October 2, 2025 16:31