[BugFix][Relax] Fix scatter_elements and scatter_nd CUDA compilation#19497
[BugFix][Relax] Fix scatter_elements and scatter_nd CUDA compilation#19497tlopex merged 1 commit intoapache:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces GPU-specific implementations for the scatter_nd and scatter_elements operators in TVM TOPI, utilizing explicit thread and block bindings for GPU targets while maintaining existing CPU paths. The changes include the addition of helper functions for reduction and index calculation. A review comment suggests refactoring scatter_elements.py to reduce logic duplication between the GPU and CPU paths by extracting common operations into helper functions.
75435f3 to
b6b11f9
Compare
|
@tlopex I considered going with a dispatch pass like DispatchSortScan but ended up fixing the lowering inline. The CPU and GPU paths here use the same algorithm so adding a separate topi/gpu/ implementation plus dispatch felt like mostly boilerplate and duplication. |
|
Thanks for working on this. I agree that fixing this in the scatter lowering is the right layer, and this is much closer to what I had in mind than a generic That said, I’d prefer the fix to restore a proper target-specific TOPI dispatch rather than make the generic The root issue in #19451 is that Relax legalization lost the old CUDA scatter lowering path: the op is lowered through Even if the CPU and GPU algorithms are mostly the same, the generated TIR is not target-neutral: GPU needs explicit So my preference would be:
I don’t think we need the broad pass from #19363, but I would like this PR to fix the issue through explicit GPU scatter lowering rather than inline target checks in the generic implementation. |
b6b11f9 to
5542f66
Compare
`topi.scatter_elements` and `topi.scatter_nd` emit bare `T.parallel` loops in
their `te.extern` IRBuilder bodies which trips `VerifyMemory` on CUDA targets:
RuntimeError: Memory verification failed
...
Did you forget to bind?
CPU (LLVM) is unaffected.
This fix makes the IRBuilder body in both `topi/scatter_elements.py` and
`topi/scatter.py` target-aware. When `Target.current()` is a GPU target it emits
thread bindings instead of `T.parallel`.
Fixes apache#19451.
5542f66 to
2628c11
Compare
|
@tlopex Thanks for the feedback and telling me your preferences. I reworked the PR. The generic topi.scatter_elements and topi.scatter_nd are untouched, GPU lowering lives in topi/gpu/scatter_elements.py and topi/gpu/scatter_nd.py with explicit blockIdx/threadIdx bindings, and dispatch happens in legalize_ops/manipulate.py. I also added a CUDA compile/build regression tests for both ops next to the existing scatter tests in test_transform_legalize_ops_manipulate.py. |
topi.scatter_elementsandtopi.scatter_ndemit bareT.parallelloops in their te.extern IRBuilder bodies which tripsVerifyMemoryon CUDA targets:CPU (LLVM) is unaffected.
This fix makes the IRBuilder body in both
topi/scatter_elements.pyandtopi/scatter.pytarget-aware. WhenTarget.current()is a GPU target it emits thread bindings instead ofT.parallel.Fixes #19451.