
Add use_linear option to replace Conv3d tokenizer with Linear layers#48

Merged
pzhanggit merged 1 commit into ORNL:main from nicholasmalaya:feature/use-linear-tokenizer
Apr 20, 2026

Conversation

@nicholasmalaya

When kernel_size == stride (non-overlapping patches), Conv3d is mathematically equivalent to reshape + nn.Linear. This avoids the im2col/col2im overhead and replaces MIOpen's implicit GEMM backward-weight path with standard rocBLAS matmul backward.
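The equivalence can be checked directly. The sketch below (shapes and sizes are illustrative, not taken from the PR) copies a `Conv3d` tokenizer's weights into an `nn.Linear` and verifies that applying the linear layer to flattened non-overlapping patches reproduces the convolution output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C_in, C_out, p = 2, 3, 8, 4          # batch, channels, patch size (illustrative)
x = torch.randn(N, C_in, 2 * p, 2 * p, 2 * p)

conv = nn.Conv3d(C_in, C_out, kernel_size=p, stride=p)

# Equivalent Linear tokenizer: one weight row per output channel,
# acting on flattened (C_in * p**3) patches.
lin = nn.Linear(C_in * p**3, C_out)
with torch.no_grad():
    lin.weight.copy_(conv.weight.view(C_out, -1))
    lin.bias.copy_(conv.bias)

# Extract non-overlapping patches: (N, C, d, h, w, p, p, p)
patches = x.unfold(2, p, p).unfold(3, p, p).unfold(4, p, p)
d, h, w = patches.shape[2:5]
# Reorder so each patch flattens as (C, p, p, p), matching Conv3d's weight layout
patches = patches.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(N, d, h, w, -1)

out_lin = lin(patches)                     # (N, d, h, w, C_out)
out_conv = conv(x).permute(0, 2, 3, 4, 1)  # (N, d, h, w, C_out)
print(torch.allclose(out_lin, out_conv, atol=1e-5))  # True
```

The key detail is the flatten order: `Conv3d` weights are laid out as `(C_out, C_in, kD, kH, kW)`, so the patch features must be channel-major before the reshape for a plain `view` of the conv weight to line up.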

Profiling on MI355X (gfx950) shows the backward-weight GEMM (kernel_batched_gemm_xdlops_bwd_weight) consumed 79.3% of compute time. With use_linear=True, this kernel is eliminated entirely, yielding a 2.87x end-to-end training speedup with identical loss convergence.
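The MIOpen kernel name above comes from a ROCm trace; a generic `torch.profiler` sketch of how such a kernel-level breakdown can be collected (CPU activities here so it runs anywhere; on an AMD GPU one would add `ProfilerActivity.CUDA`, which covers HIP):

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Illustrative tokenizer-sized Conv3d; not the PR's actual model.
conv = nn.Conv3d(3, 8, kernel_size=4, stride=4)
x = torch.randn(2, 3, 16, 16, 16, requires_grad=True)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    conv(x).sum().backward()

# Per-op/kernel table; the backward-weight GEMM dominates in the ROCm case.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```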

Enabled via config: use_linear: !!bool True (default False, fully backward compatible).


Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@pzhanggit pzhanggit requested review from TsChala and pzhanggit April 7, 2026 18:36
Collaborator

@pzhanggit pzhanggit left a comment


@nicholasmalaya thank you very much for the optimization, Nick!
@TsChala the PR looks good to me. Could you do a test run on Frontier when it's back from maintenance? We should extend the changes to other models for better performance as well. Thanks

@TsChala
Collaborator

TsChala commented Apr 8, 2026

Thanks for the edits @nicholasmalaya !

@pzhanggit I ran some tests on the JHUTDB dataset today. Using the Turbulence Transformer I see around a 2x speed-up! This comes only from hMLP_stem and hMLP_output; further speed-up could probably be achieved if we replace the Conv3d layers in the upsampling parts as well.

For the vit_all2all model I see similar runtimes so far with and without use_linear. I'm not exactly sure why; I can look into it more, but we mainly use the TurbT anyway.

Collaborator

@TsChala TsChala left a comment


This looks good to me. After this PR is merged I can work on expanding it to the other models as well.

@pzhanggit pzhanggit merged commit 976e788 into ORNL:main Apr 20, 2026
mrowan137 added a commit to mrowan137/MATEY that referenced this pull request Apr 21, 2026

This PR (follow-on to ORNL#48) addresses two potential issues:
1. smooth layer silently dropped when notransposed=True:
In hMLP_output.__init__ and forward, ORNL#48 added extra indentation to the
'if self.smooth' blocks, so the smooth layer is skipped when
self.notransposed=True and self.smooth=True. This PR reverts to the previous
behavior by un-indenting.

2. expand_projections shape mismatch with use_linear=True:
Within BaseModel.expand_projections, ORNL#48 allows new_debed.out_head to be
either nn.Linear or nn.ConvTranspose3d; however, a later copy is still not
generalized for both possibilities and can produce a shape mismatch when
use_linear=True. This PR generalizes it to handle both.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
mrowan137 added a commit to mrowan137/MATEY that referenced this pull request Apr 21, 2026
This PR (follow-on to ORNL#48) addresses three potential issues:

1. smooth layer silently dropped when notransposed=True:
In hMLP_output.__init__ and forward, ORNL#48 added extra indentation to the
'if self.smooth' blocks, so the smooth layer is skipped when
self.notransposed=True and self.smooth=True. This PR reverts to the previous
behavior by un-indenting.

2. expand_projections shape mismatch with use_linear=True:
Within BaseModel.expand_projections, ORNL#48 allows new_debed.out_head to be
either nn.Linear or nn.ConvTranspose3d; however, a later copy is still not
generalized for both possibilities and can produce a shape mismatch when
use_linear=True. This PR generalizes it to handle both.

3. Collapse redundant if/else in hMLP_output.forward:
After un-indenting the smooth block in (1), the remaining
'if self.notransposed / else' branches both reduce to 'x = self.out_head(x)',
so this PR collapses them.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
pzhanggit added a commit that referenced this pull request Apr 21, 2026
Fix smooth layer and expand_projections regressions from #48
