
Wmma support for grouped convolution bwd weight #2947

Merged

illsilin merged 78 commits into develop from streamhpc/conv_bwd_weight_wmma on Dec 17, 2025

Conversation

@EnricoDeg (Contributor)

Proposed changes

Summary:

  • Modify the gridwise implementation to work with convolution (grid descriptors are no longer created internally but are passed in from the device level)
  • Add device-level implementations: DeviceGroupedConvBwdWeight_Wmma_CShuffleV3, DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3, and DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
  • Add a device implementation of batched GEMM with multiple Ds (needed for the explicit-GEMM path of conv bwd weight)
  • Adapt the existing explicit-GEMM device implementation to work with both the XDL and WMMA implementations of batched GEMM with multiple Ds
  • Add occupancy-based split-K support for the one-stage and two-stage implementations of grouped conv bwd weight
  • Create instances
  • Add examples
  • Remove old instances (they don't support splitk)
  • Add tests for bwd weight scale

The implementations are based on CShuffleV3, but the functionality matches the XDL versions.
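The occupancy-based split-K selection mentioned above can be sketched roughly as follows. This is a hypothetical illustration, not the actual CK logic: the function name, parameters, and the power-of-two growth heuristic are all assumptions; the real implementation would query device occupancy at runtime.

```cpp
#include <cassert>

// Hypothetical sketch of occupancy-based split-K selection: grow the
// split-K factor until the launched grid has enough workgroups to
// saturate the device, while keeping the K dimension evenly divisible.
inline int PickSplitK(int grid_blocks,    // workgroups at split_k = 1
                      int num_cus,        // compute units on the device
                      int blocks_per_cu,  // occupancy per compute unit
                      int k_size,         // reduction (K) dimension
                      int max_split_k = 32)
{
    const int target = num_cus * blocks_per_cu; // blocks needed to fill the GPU
    int split_k      = 1;
    while(split_k < max_split_k && grid_blocks * split_k < target &&
          k_size % (split_k * 2) == 0)
    {
        split_k *= 2;
    }
    return split_k;
}
```

The idea is that small convolution problems launch too few workgroups to occupy the GPU, so splitting the reduction dimension recovers parallelism at the cost of an extra accumulation pass.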

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation that enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

…/conv_bwd_weight_wmma'

Convolution bwd weight device implementation

See merge request amd/ai/composable_kernel!38
 - rdna3 compilation error
 - gridwise layouts (need to be correct to ensure that CheckValidity()
   works correctly)
…re/conv_bwd_weight_wmma'

Grouped conv: Instances and example bwd weight

See merge request amd/ai/composable_kernel!47
Device implementation of explicit gemm for grouped conv bwd weight

See merge request amd/ai/composable_kernel!52
krithalith and others added 17 commits December 15, 2025 08:56
…le V3 instances. CShuffleBlockTransferScalarPerVector adapted to 4, and mergegroups fixed to 1 for now. No more special instance lists.
…duplications. Also removing stride1pad0 support for NHWGC since we can use explicit for those cases.
… layout / datatype support as before the instance selection process.
… NHWGC. They are never faster and support is already carried by CShuffleV3 and Explicit.
…fwd declarations, cmakelists entries. Also merge the "wmma" and "wmma v3" instance list files, which are both v3.
…NHWGCxGKYXC and F16 or BF16 (no mixed in-out types).
…tance_selection

WIP: Grouped convolution bwd weight wmma v3 instance selection
@bartekxk bartekxk requested review from Copilot and removed request for aska-0096 December 16, 2025 08:50
Copilot AI left a comment

Pull request overview

Copilot reviewed 82 out of 83 changed files in this pull request and generated no new comments.



@bartekxk (Contributor) left a comment

lgtm, minor comments


    GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
-       p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
+       p_shared, splitk_batch_offset, karg, epilogue_args, 0, k_id);
Contributor
Please add in a /*...*/ comment what the 0 means

Contributor (Author)
Done
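The idiom the reviewer asked for is annotating positional literals with the parameter name in a `/*name=*/` comment. A minimal sketch (the function below is a hypothetical stand-in for the real `Run()` argument list, whose signature is not shown in this excerpt):

```cpp
#include <cassert>

// Hypothetical stand-in for a call with several positional arguments.
// The /*name=*/ idiom makes a bare literal like 0 self-documenting at
// the call site.
inline int Combine(int batch_offset, int reduce_idx, int k_id)
{
    return batch_offset * 100 + reduce_idx * 10 + k_id;
}

// Annotated call site: the reader can see which parameter the 0 binds to
// without looking up the declaration.
inline int Example()
{
    return Combine(/*batch_offset=*/1, /*reduce_idx=*/0, /*k_id=*/2);
}
```

clang-format recognizes this comment style and keeps the annotation attached to its argument.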

//################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| |
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
Contributor
Why are most of them disabled?

Contributor
This is also described in the instance selection document I sent, but in short: for NHWGCxGKCYX we also have explicit GEMM, and perf tests showed that the TwoStage implementation was never faster or needed for support on this layout. Therefore we only use a single generic instance for now. Two-stage becomes more relevant for other layouts, but we are not adding those right now.
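The layout-driven selection described here can be sketched as a simple dispatch. Names are illustrative only, not CK's actual instance-selection API:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical sketch of the dispatch described above: for the
// NHWGC/GKCYX layout combination, explicit GEMM is preferred and
// TwoStage contributes only a single generic fallback instance.
inline std::vector<std::string> SelectInstances(const std::string& in_layout,
                                                const std::string& wei_layout)
{
    if(in_layout == "NHWGC" && wei_layout == "GKCYX")
    {
        // Perf tests showed TwoStage was never faster here, so explicit
        // GEMM comes first and TwoStage keeps one generic instance.
        return {"ExplicitGemm", "TwoStageGeneric"};
    }
    // Other layouts would rely more heavily on tuned TwoStage
    // instances, which are not added in this PR.
    return {"TwoStageGeneric"};
}
```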

Contributor
Ok, I understand. For the future, it is better to add a comment in the code, because nobody will read documents shared in a private chat regarding a public repo.

@EnricoDeg EnricoDeg marked this pull request as ready for review December 17, 2025 09:53
@illsilin illsilin merged commit 87dd073 into develop Dec 17, 2025
25 of 27 checks passed
@illsilin illsilin deleted the streamhpc/conv_bwd_weight_wmma branch December 17, 2025 23:59
SreecharanGundaboluAMD added a commit that referenced this pull request Dec 21, 2025
This reverts commit 87dd073.

Note: Resolved merge conflicts on a best-effort basis.

5 participants