Skip to content

Commit

Permalink
[Ir][Primitives] fix __shfl_xor_sync (hidet-org#155)
Browse files Browse the repository at this point in the history
Fix `shfl_xor_sync`: it called the `cuda_shfl_down_sync` primitive, which made `__shfl_xor_sync`
effectively an alias of `__shfl_down_sync`. I don't know why it was written that way. Was this intentional?

Co-authored-by: xiaocenxiaocen <xiao.zhang@centml.ai>
  • Loading branch information
xiaocenxiaocen and xiaocenxiaocen committed Apr 19, 2024
1 parent 160b787 commit 4a09a38
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/hidet/ir/primitives/cuda/shfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def register_primitive_functions():
('cuda_shfl_sync', '__shfl_sync', FuncType(type_infer_func=type_infer)),
('cuda_shfl_up_sync', '__shfl_up_sync', FuncType(type_infer_func=type_infer)),
('cuda_shfl_down_sync', '__shfl_down_sync', FuncType(type_infer_func=type_infer)),
('cuda_shfl_xor_sync', '__shfl_xor_sync', FuncType(type_infer_func=type_infer)),
]
for name, codegen_name, func_type in functions:
register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name)
Expand All @@ -44,7 +45,7 @@ def shfl_down_sync(mask, var, delta, width=32):


# Wrapper that builds a call to a registered CUDA shuffle primitive, forwarding
# [mask, var, lane_mask, width] unchanged; `width` defaults to 32.
def shfl_xor_sync(mask, var, lane_mask, width=32):
# Diff view: the next line is the one this commit deletes. It dispatched to the
# 'cuda_shfl_down_sync' primitive (codegen name `__shfl_down_sync`), not the xor variant.
return call_primitive_func('cuda_shfl_down_sync', [mask, var, lane_mask, width])
# Diff view: the next line is the replacement this commit adds. It dispatches to
# 'cuda_shfl_xor_sync', registered in the hunk above with codegen name `__shfl_xor_sync`.
return call_primitive_func('cuda_shfl_xor_sync', [mask, var, lane_mask, width])


def active_mask():
Expand Down

0 comments on commit 4a09a38

Please sign in to comment.