From 031563377ce47f894b6fb247342a7e810c0cf069 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 20:26:05 -0400 Subject: [PATCH 1/8] initial support Signed-off-by: Yaoyao Ding --- .../tilus/backends/emitters/cuda/mbarrier.py | 46 +++++++++++++ python/tilus/ir/builders/stmt_builder.py | 21 ++++++ python/tilus/ir/instructions/__init__.py | 4 ++ python/tilus/ir/instructions/cuda/__init__.py | 1 + python/tilus/ir/instructions/cuda/mbarrier.py | 60 +++++++++++++++++ python/tilus/ir/mfunction/ops.py | 2 +- python/tilus/lang/script.py | 66 ++++++++++++++++++- scripts/sign-commits.sh | 41 ++++++------ 8 files changed, 218 insertions(+), 23 deletions(-) create mode 100644 python/tilus/backends/emitters/cuda/mbarrier.py create mode 100644 python/tilus/ir/instructions/cuda/mbarrier.py diff --git a/python/tilus/backends/emitters/cuda/mbarrier.py b/python/tilus/backends/emitters/cuda/mbarrier.py new file mode 100644 index 00000000..784a5299 --- /dev/null +++ b/python/tilus/backends/emitters/cuda/mbarrier.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hidet.ir.dtypes import boolean +from hidet.ir.primitives.cuda.barrier import mbarrier_arrive, mbarrier_init, mbarrier_wait + +from tilus.backends.codegen import BaseInstEmitter, register_emitter +from tilus.ir.instructions import ArriveBarrierInst, ArriveRemoteBarrierInst, InitBarrierInst, WaitBarrierInst +from tilus.target import nvgpu_sm80 + + +@register_emitter(InitBarrierInst, target=nvgpu_sm80) +class InitBarrierInstEmitter(BaseInstEmitter): + def emit(self, inst: InitBarrierInst) -> None: + with self.if_then(self.current_worker == 0): + mbarrier_init(~inst.barrier, inst.count) + + +@register_emitter(ArriveBarrierInst, target=nvgpu_sm80) +class ArriveBarrierInstEmitter(BaseInstEmitter): + def emit(self, inst: ArriveBarrierInst) -> None: + mbarrier_arrive(~inst.barrier) + + +@register_emitter(ArriveRemoteBarrierInst, target=nvgpu_sm80) +class ArriveRemoteBarrierInstEmitter(BaseInstEmitter): + def emit(self, inst: ArriveRemoteBarrierInst) -> None: + mbarrier_arrive(~inst.barrier, inst.remote_block, pred=boolean.true) + + +@register_emitter(WaitBarrierInst, target=nvgpu_sm80) +class WaitBarrierInstEmitter(BaseInstEmitter): + def emit(self, inst: WaitBarrierInst) -> None: + mbarrier_wait(~inst.barrier, inst.phase) diff --git a/python/tilus/ir/builders/stmt_builder.py b/python/tilus/ir/builders/stmt_builder.py index dcfd8292..27c0dfa5 100644 --- a/python/tilus/ir/builders/stmt_builder.py +++ b/python/tilus/ir/builders/stmt_builder.py @@ -27,16 +27,20 @@ from tilus.ir.inst import Instruction, InstructionError from tilus.ir.instructions.annotation import AnnotateLayoutInst from tilus.ir.instructions.cuda import ( + ArriveBarrierInst, + ArriveRemoteBarrierInst, CopyAsyncCommitGroupInst, CopyAsyncGenericInst, CopyAsyncInst, CopyAsyncWaitAllInst, CopyAsyncWaitGroupInst, DotInst, + InitBarrierInst, LoadMatrixConfig, LoadMatrixInst, LockSemaphoreInst, ReleaseSemaphoreInst, + WaitBarrierInst, ) from tilus.ir.instructions.generic import ( AddInst, @@ -944,6 +948,23 
@@ def exit(self) -> None: inst = ExitInst.create() self.append(inst) + # barrier + def init_barrier(self, barrier: Expr, count: Expr) -> None: + inst = InitBarrierInst.create(barrier=barrier, count=count) + self.append(inst) + + def arrive_barrier(self, barrier: Expr) -> None: + inst = ArriveBarrierInst.create(barrier=barrier) + self.append(inst) + + def arrive_remote_barrier(self, barrier: Expr, remote_block: Expr) -> None: + inst = ArriveRemoteBarrierInst.create(barrier=barrier, remote_block=remote_block) + self.append(inst) + + def wait_barrier(self, barrier: Expr, phase: Expr) -> None: + inst = WaitBarrierInst.create(barrier=barrier, phase=phase) + self.append(inst) + # annotations def annotate_layout(self, tensor: RegisterTensor, layout: RegisterLayout) -> None: inst = AnnotateLayoutInst.create(tensor=tensor, layout=layout) diff --git a/python/tilus/ir/instructions/__init__.py b/python/tilus/ir/instructions/__init__.py index 8322bc8a..27998c69 100644 --- a/python/tilus/ir/instructions/__init__.py +++ b/python/tilus/ir/instructions/__init__.py @@ -16,17 +16,21 @@ from .annotation import AnnotateLayoutInst from .cuda import ( + ArriveBarrierInst, + ArriveRemoteBarrierInst, CopyAsyncCommitGroupInst, CopyAsyncGenericInst, CopyAsyncInst, CopyAsyncWaitAllInst, CopyAsyncWaitGroupInst, DotInst, + InitBarrierInst, LoadMatrixConfig, LoadMatrixInst, LockSemaphoreInst, ReleaseSemaphoreInst, SimtDotInst, + WaitBarrierInst, ) from .generic import ( AddInst, diff --git a/python/tilus/ir/instructions/cuda/__init__.py b/python/tilus/ir/instructions/cuda/__init__.py index 20b13caf..f4b33b75 100644 --- a/python/tilus/ir/instructions/cuda/__init__.py +++ b/python/tilus/ir/instructions/cuda/__init__.py @@ -20,6 +20,7 @@ CopyAsyncWaitGroupInst, ) from .ldmatrix import LoadMatrixConfig, LoadMatrixInst +from .mbarrier import ArriveBarrierInst, ArriveRemoteBarrierInst, InitBarrierInst, WaitBarrierInst from .mma_dot import DotInst from .semaphore import LockSemaphoreInst, 
ReleaseSemaphoreInst from .simt_dot import SimtDotInst diff --git a/python/tilus/ir/instructions/cuda/mbarrier.py b/python/tilus/ir/instructions/cuda/mbarrier.py new file mode 100644 index 00000000..22689f39 --- /dev/null +++ b/python/tilus/ir/instructions/cuda/mbarrier.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass + +from hidet.ir.expr import Expr + +from tilus.ir.inst import Instruction + + +@dataclass(frozen=True, eq=False) +class InitBarrierInst(Instruction): + barrier: Expr + count: Expr + + @staticmethod + def create(barrier: Expr, count: Expr) -> InitBarrierInst: + return InitBarrierInst(output=None, inputs=(), barrier=barrier, count=count) + + +@dataclass(frozen=True, eq=False) +class ArriveBarrierInst(Instruction): + barrier: Expr + + @staticmethod + def create(barrier: Expr) -> ArriveBarrierInst: + return ArriveBarrierInst(output=None, inputs=(), barrier=barrier) + + +@dataclass(frozen=True, eq=False) +class ArriveRemoteBarrierInst(Instruction): + barrier: Expr + remote_block: Expr + + @staticmethod + def create(barrier: Expr, remote_block: Expr) -> ArriveRemoteBarrierInst: + return ArriveRemoteBarrierInst(output=None, inputs=(), barrier=barrier, remote_block=remote_block) + + +@dataclass(frozen=True, eq=False) +class 
WaitBarrierInst(Instruction): + barrier: Expr + phase: Expr + + @staticmethod + def create(barrier: Expr, phase: Expr) -> WaitBarrierInst: + return WaitBarrierInst(output=None, inputs=(), barrier=barrier, phase=phase) diff --git a/python/tilus/ir/mfunction/ops.py b/python/tilus/ir/mfunction/ops.py index b7b63ab9..edc60925 100644 --- a/python/tilus/ir/mfunction/ops.py +++ b/python/tilus/ir/mfunction/ops.py @@ -156,7 +156,7 @@ def cover( """Check whether the multi-function fa covers the multi-function fb. The size and image size of both multi-functions must be the same. For every x in the domain, if we have - fb(x) \subseteq fa(x), + fb(x) subset_eq fa(x), then we say that fa covers fb. In other words, for every x in the domain of fb, the image of fb at x is a subset of the image of fa at x. diff --git a/python/tilus/lang/script.py b/python/tilus/lang/script.py index cba5f931..20ecde21 100644 --- a/python/tilus/lang/script.py +++ b/python/tilus/lang/script.py @@ -17,7 +17,7 @@ import typing from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Type, Union -from hidet.ir.dtypes import boolean +from hidet.ir.dtypes import boolean, int32 from hidet.ir.expr import Constant, Equal, Expr, LogicalAnd, Mod, Var, as_expr from hidet.ir.primitives.cuda.vars import blockIdx, dim3, gridDim from hidet.ir.tools import infer_type @@ -1627,6 +1627,70 @@ def assign(self, dst: RegisterTensor, src: RegisterTensor) -> None: raise InstructionError("The dtypes of dst and src must match, got {} and {}".format(dst.dtype, src.dtype)) self._builder.assign_register(dst, src) + def init_barrier(self, barrier: Expr, count: Expr | int) -> None: + """Initialize a barrier. + + This instruction initializes a memory barrier in shared memory. The `barrier` parameter must be an addressable + expression whose address is in shared memory and aligned to 8 bytes. The barrier itself must be of type uint64. + + Parameters + ---------- + barrier: Expr + The expression with type uint64. 
It must be addressable in shared memory. + count: Expr | int + The number of threads that must arrive at the barrier before any of them can proceed. It must be evaluated + to a positive int32. + """ + self._builder.init_barrier(barrier, count if isinstance(count, Expr) else int32(count)) + + def arrive_barrier(self, barrier: Expr) -> None: + """Arrive at a barrier. + + This instruction indicates that the thread block (or the current partition of the thread block) has reached the + specified barrier. + + Parameters + ---------- + barrier: Expr + The expression with type uint64 that represents the barrier to arrive at. It must be addressable in shared + memory. + """ + self._builder.arrive_barrier(barrier) + + def arrive_remote_barrier(self, barrier: Expr, remote_block: Expr) -> None: + """Arrive at a remote barrier. + + This instruction signals the current block's arrival at a remote thread block's barrier. It is used for + inter-block synchronization, allowing one thread block to signal another thread block that it has reached a + barrier. + + Parameters + ---------- + barrier: Expr + The expression with type uint64 that represents the barrier to arrive at. It must be addressable in shared + memory. + remote_block: Expr + The thread block index of the remote thread block that the current block is signaling the arrival to. It + should be an expression that evaluates to a non-negative int32. + """ + self._builder.arrive_remote_barrier(barrier, remote_block) + + def wait_barrier(self, barrier: Expr, phase: Expr | int) -> None: + """Wait at a barrier. + + This instruction makes the thread block (or the current partition of the thread block) wait at the specified + barrier until the entire thread block (or the current partition) has arrived at the barrier. + + Parameters + ---------- + barrier: Expr + The expression with type uint64 that represents the barrier to wait at. It must be addressable in shared + memory. + phase: Expr | int + The phase value to wait for. 
It must be evaluated to either 0 or 1. + """ + self._builder.wait_barrier(barrier, phase if isinstance(phase, Expr) else int32(phase)) + @staticmethod def static_assert(cond: bool | Expr, msg: str) -> None: if not isinstance(cond, Constant) and not isinstance(cond, bool): diff --git a/scripts/sign-commits.sh b/scripts/sign-commits.sh index de7fd239..31d14faf 100755 --- a/scripts/sign-commits.sh +++ b/scripts/sign-commits.sh @@ -1,31 +1,30 @@ #!/bin/bash -# This script signs off all unsigned commits in the current branch that are not in the main branch, -# and shows the newly signed commits. - set -e MAIN_BRANCH="main" git fetch origin - BASE=$(git merge-base HEAD origin/$MAIN_BRANCH) -## List commits after the common ancestor missing "Signed-off-by" -#UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do -# if ! git show --quiet --format=%B $commit | grep -q "Signed-off-by:"; then -# echo $commit -# fi -#done) -# -#if [ -z "$UNSIGNED_COMMITS" ]; then -# echo "No unsigned commits to sign off." -# exit 0 -#fi +# Find unsigned commits +UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do + if ! git verify-commit $commit &>/dev/null; then + echo $commit + fi +done) + +if [ -z "$UNSIGNED_COMMITS" ]; then + echo "No commits need GPG signing." 
+ exit 0 +fi -# Rebase with signoff -git rebase --signoff $BASE +# Sign each unsigned commit with GPG +echo "Signing the following commits with GPG:" +for commit in $UNSIGNED_COMMITS; do + git rebase --onto $commit^ $commit --exec "git commit --amend -S --no-edit" +done -#echo "Newly signed commits:" -#for commit in $UNSIGNED_COMMITS; do -# git log --format="* %h %s" -n 1 $commit -#done \ No newline at end of file +# Print signed commits +for commit in $UNSIGNED_COMMITS; do + git log --format="* %h %s" -n 1 $commit +done From 2c4a9cb5c67e90c96ee8adeca34bc9102a02f96f Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 6 Sep 2025 20:34:33 -0400 Subject: [PATCH 2/8] update sm requirement for try_wait Signed-off-by: Yaoyao Ding --- python/tilus/backends/emitters/cuda/__init__.py | 2 +- python/tilus/backends/emitters/cuda/mbarrier.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tilus/backends/emitters/cuda/__init__.py b/python/tilus/backends/emitters/cuda/__init__.py index a79baa76..e9623d62 100644 --- a/python/tilus/backends/emitters/cuda/__init__.py +++ b/python/tilus/backends/emitters/cuda/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import cp_async, ldmatrix, mma_dot, semaphore, simt_dot +from . 
import cp_async, ldmatrix, mma_dot, semaphore, simt_dot, mbarrier diff --git a/python/tilus/backends/emitters/cuda/mbarrier.py b/python/tilus/backends/emitters/cuda/mbarrier.py index 784a5299..3512536c 100644 --- a/python/tilus/backends/emitters/cuda/mbarrier.py +++ b/python/tilus/backends/emitters/cuda/mbarrier.py @@ -18,29 +18,29 @@ from tilus.backends.codegen import BaseInstEmitter, register_emitter from tilus.ir.instructions import ArriveBarrierInst, ArriveRemoteBarrierInst, InitBarrierInst, WaitBarrierInst -from tilus.target import nvgpu_sm80 +from tilus.target import nvgpu_sm80, nvgpu_sm90 @register_emitter(InitBarrierInst, target=nvgpu_sm80) class InitBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: InitBarrierInst) -> None: with self.if_then(self.current_worker == 0): - mbarrier_init(~inst.barrier, inst.count) + self.append(mbarrier_init(~inst.barrier, inst.count)) @register_emitter(ArriveBarrierInst, target=nvgpu_sm80) class ArriveBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: ArriveBarrierInst) -> None: - mbarrier_arrive(~inst.barrier) + self.append(mbarrier_arrive(~inst.barrier)) @register_emitter(ArriveRemoteBarrierInst, target=nvgpu_sm80) class ArriveRemoteBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: ArriveRemoteBarrierInst) -> None: - mbarrier_arrive(~inst.barrier, inst.remote_block, pred=boolean.true) + self.append(mbarrier_arrive(~inst.barrier, inst.remote_block, pred=boolean.true)) -@register_emitter(WaitBarrierInst, target=nvgpu_sm80) +@register_emitter(WaitBarrierInst, target=nvgpu_sm90) class WaitBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: WaitBarrierInst) -> None: - mbarrier_wait(~inst.barrier, inst.phase) + self.append(mbarrier_wait(~inst.barrier, inst.phase)) From be0778b11b6a768a1c97353a4b4bee480c565b29 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sun, 7 Sep 2025 04:34:41 +0000 Subject: [PATCH 3/8] fix Signed-off-by: Yaoyao Ding --- .../tilus/backends/emitters/cuda/__init__.py | 2 +- 
.../tilus/backends/emitters/cuda/mbarrier.py | 8 +-- python/tilus/lang/script.py | 11 ++-- tests/lang/test_mbarrier.py | 54 +++++++++++++++++++ 4 files changed, 63 insertions(+), 12 deletions(-) create mode 100644 tests/lang/test_mbarrier.py diff --git a/python/tilus/backends/emitters/cuda/__init__.py b/python/tilus/backends/emitters/cuda/__init__.py index e9623d62..82a756d4 100644 --- a/python/tilus/backends/emitters/cuda/__init__.py +++ b/python/tilus/backends/emitters/cuda/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from . import cp_async, ldmatrix, mma_dot, semaphore, simt_dot, mbarrier +from . import cp_async, ldmatrix, mbarrier, mma_dot, semaphore, simt_dot diff --git a/python/tilus/backends/emitters/cuda/mbarrier.py b/python/tilus/backends/emitters/cuda/mbarrier.py index 3512536c..5530833a 100644 --- a/python/tilus/backends/emitters/cuda/mbarrier.py +++ b/python/tilus/backends/emitters/cuda/mbarrier.py @@ -25,22 +25,22 @@ class InitBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: InitBarrierInst) -> None: with self.if_then(self.current_worker == 0): - self.append(mbarrier_init(~inst.barrier, inst.count)) + self.append(mbarrier_init(inst.barrier, inst.count)) @register_emitter(ArriveBarrierInst, target=nvgpu_sm80) class ArriveBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: ArriveBarrierInst) -> None: - self.append(mbarrier_arrive(~inst.barrier)) + self.append(mbarrier_arrive(inst.barrier)) @register_emitter(ArriveRemoteBarrierInst, target=nvgpu_sm80) class ArriveRemoteBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: ArriveRemoteBarrierInst) -> None: - self.append(mbarrier_arrive(~inst.barrier, inst.remote_block, pred=boolean.true)) + self.append(mbarrier_arrive(inst.barrier, inst.remote_block, pred=boolean.true)) @register_emitter(WaitBarrierInst, target=nvgpu_sm90) 
class WaitBarrierInstEmitter(BaseInstEmitter): def emit(self, inst: WaitBarrierInst) -> None: - self.append(mbarrier_wait(~inst.barrier, inst.phase)) + self.append(mbarrier_wait(inst.barrier, inst.phase)) diff --git a/python/tilus/lang/script.py b/python/tilus/lang/script.py index 20ecde21..6a85b342 100644 --- a/python/tilus/lang/script.py +++ b/python/tilus/lang/script.py @@ -1636,7 +1636,7 @@ def init_barrier(self, barrier: Expr, count: Expr | int) -> None: Parameters ---------- barrier: Expr - The expression with type uint64. It must be addressable in shared memory. + The pointer to the barrier in shared memory. count: Expr | int The number of threads that must arrive at the barrier before any of them can proceed. It must be evaluated to a positive int32. @@ -1652,8 +1652,7 @@ def arrive_barrier(self, barrier: Expr) -> None: Parameters ---------- barrier: Expr - The expression with type uint64 that represents the barrier to arrive at. It must be addressable in shared - memory. + The pointer to the barrier in shared memory. """ self._builder.arrive_barrier(barrier) @@ -1667,8 +1666,7 @@ def arrive_remote_barrier(self, barrier: Expr, remote_block: Expr) -> None: Parameters ---------- barrier: Expr - The expression with type uint64 that represents the barrier to arrive at. It must be addressable in shared - memory. + The pointer to the barrier in shared memory. remote_block: Expr The thread block index of the remote thread block that the current block is signaling the arrival to. It should be an expression that evaluates to a non-negative int32. @@ -1684,8 +1682,7 @@ def wait_barrier(self, barrier: Expr, phase: Expr | int) -> None: Parameters ---------- barrier: Expr - The expression with type uint64 that represents the barrier to wait at. It must be addressable in shared - memory. + The pointer to the barrier in shared memory. phase: Expr | int The phase value to wait for. It must be evaluated to either 0 or 1. 
""" diff --git a/tests/lang/test_mbarrier.py b/tests/lang/test_mbarrier.py new file mode 100644 index 00000000..3ae2cffb --- /dev/null +++ b/tests/lang/test_mbarrier.py @@ -0,0 +1,54 @@ +import torch + +import tilus +from tilus import int32, uint64 + + +class DemoBarrier(tilus.Script): + def __init__(self): + super().__init__() + self.block_size = 128 + + def __call__(self, n: int32, x_ptr: ~int32, y_ptr: ~int32): + self.attrs.blocks = 1 + self.attrs.warps = 2 + + g_x = self.global_view(x_ptr, dtype=int32, shape=[n]) + g_y = self.global_view(y_ptr, dtype=int32, shape=[n]) + s_x = self.shared_tensor(dtype=int32, shape=[self.block_size]) + barriers = self.shared_tensor(dtype=uint64, shape=[1]) + + self.init_barrier(~barriers[0], count=self.attrs.warps * 32) + self.sync() + + phase: int32 = 0 + for bi in self.range(0, n, self.block_size): + self.store_shared( + dst=s_x, + src=self.load_global(g_x, offsets=[bi * self.block_size], shape=[self.block_size]) + ) + + self.arrive_barrier(~barriers[0]) + self.wait_barrier(~barriers[0], phase) + phase ^= 1 + + self.store_global( + dst=g_y, + src=self.load_shared(s_x) + 1, + offsets=[bi * self.block_size] + ) + + self.arrive_barrier(~barriers[0]) + self.wait_barrier(~barriers[0], phase) + phase ^= 1 + + +def test_mbarrier(): + n = 128 + x = torch.arange(n, dtype=torch.int32).cuda() + y = torch.zeros(n, dtype=torch.int32).cuda() + kernel = DemoBarrier() + kernel(n, x, y) + torch.cuda.synchronize() + torch.testing.assert_close(y, x + 1) + From 1643e6dfb463136052e60392098a51a50c981bf9 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sun, 7 Sep 2025 05:01:47 +0000 Subject: [PATCH 4/8] update Signed-off-by: Yaoyao Ding --- scripts/sign-commits.sh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/scripts/sign-commits.sh b/scripts/sign-commits.sh index 31d14faf..b81bb649 100755 --- a/scripts/sign-commits.sh +++ b/scripts/sign-commits.sh @@ -18,13 +18,12 @@ if [ -z "$UNSIGNED_COMMITS" ]; then exit 0 fi -# Sign 
each unsigned commit with GPG +# Sign all unsigned commits in one rebase echo "Signing the following commits with GPG:" -for commit in $UNSIGNED_COMMITS; do - git rebase --onto $commit^ $commit --exec "git commit --amend -S --no-edit" -done +echo "$UNSIGNED_COMMITS" +git rebase --exec "git commit --amend -S --no-edit" $BASE # Print signed commits for commit in $UNSIGNED_COMMITS; do git log --format="* %h %s" -n 1 $commit -done +done \ No newline at end of file From 0329c49d20afa4ec7c3e2c702495f960c028184c Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sun, 7 Sep 2025 05:04:17 +0000 Subject: [PATCH 5/8] update Signed-off-by: Yaoyao Ding --- scripts/sign-commits.sh | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/scripts/sign-commits.sh b/scripts/sign-commits.sh index b81bb649..89b94797 100755 --- a/scripts/sign-commits.sh +++ b/scripts/sign-commits.sh @@ -6,24 +6,5 @@ MAIN_BRANCH="main" git fetch origin BASE=$(git merge-base HEAD origin/$MAIN_BRANCH) -# Find unsigned commits -UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do - if ! git verify-commit $commit &>/dev/null; then - echo $commit - fi -done) - -if [ -z "$UNSIGNED_COMMITS" ]; then - echo "No commits need GPG signing." 
- exit 0 -fi - # Sign all unsigned commits in one rebase -echo "Signing the following commits with GPG:" -echo "$UNSIGNED_COMMITS" -git rebase --exec "git commit --amend -S --no-edit" $BASE - -# Print signed commits -for commit in $UNSIGNED_COMMITS; do - git log --format="* %h %s" -n 1 $commit -done \ No newline at end of file +git rebase $BASE --signoff From 255a5fdfc0c8302f2faafedca0e7715c7c257b5b Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sun, 7 Sep 2025 05:18:10 +0000 Subject: [PATCH 6/8] add barrier Signed-off-by: Yaoyao Ding --- tests/lang/test_mbarrier.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/lang/test_mbarrier.py b/tests/lang/test_mbarrier.py index 3ae2cffb..97df4aa2 100644 --- a/tests/lang/test_mbarrier.py +++ b/tests/lang/test_mbarrier.py @@ -1,6 +1,5 @@ -import torch - import tilus +import torch from tilus import int32, uint64 @@ -24,19 +23,14 @@ def __call__(self, n: int32, x_ptr: ~int32, y_ptr: ~int32): phase: int32 = 0 for bi in self.range(0, n, self.block_size): self.store_shared( - dst=s_x, - src=self.load_global(g_x, offsets=[bi * self.block_size], shape=[self.block_size]) + dst=s_x, src=self.load_global(g_x, offsets=[bi * self.block_size], shape=[self.block_size]) ) self.arrive_barrier(~barriers[0]) self.wait_barrier(~barriers[0], phase) phase ^= 1 - self.store_global( - dst=g_y, - src=self.load_shared(s_x) + 1, - offsets=[bi * self.block_size] - ) + self.store_global(dst=g_y, src=self.load_shared(s_x) + 1, offsets=[bi * self.block_size]) self.arrive_barrier(~barriers[0]) self.wait_barrier(~barriers[0], phase) @@ -51,4 +45,3 @@ def test_mbarrier(): kernel(n, x, y) torch.cuda.synchronize() torch.testing.assert_close(y, x + 1) - From 2ce56d0b31e3be6bcf8652dba705c813b8af508e Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sun, 7 Sep 2025 05:41:54 +0000 Subject: [PATCH 7/8] update tests Signed-off-by: Yaoyao Ding --- python/tilus/testing/__init__.py | 1 + python/tilus/testing/requires.py | 10 
+++++++-- tests/conftest.py | 32 --------------------------- tests/lang/test_blocks_per_cluster.py | 5 ++--- tests/lang/test_mbarrier.py | 2 ++ 5 files changed, 13 insertions(+), 37 deletions(-) create mode 100644 python/tilus/testing/__init__.py diff --git a/python/tilus/testing/__init__.py b/python/tilus/testing/__init__.py new file mode 100644 index 00000000..04eb469b --- /dev/null +++ b/python/tilus/testing/__init__.py @@ -0,0 +1 @@ +from .requires import requires diff --git a/python/tilus/testing/requires.py b/python/tilus/testing/requires.py index 9445fb6c..e66fb736 100644 --- a/python/tilus/testing/requires.py +++ b/python/tilus/testing/requires.py @@ -1,10 +1,11 @@ from typing import Callable import pytest -from tilus.target import Target, get_current_target +from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90 -def requires(target: Target) -> Callable[[Callable], Callable]: + +def _requires(target: Target) -> Callable[[Callable], Callable]: """ Pytest fixture decorator that skips tests if the current GPU doesn't support the required architecture. @@ -33,3 +34,8 @@ def decorator(test_func): return pytest.mark.skip(f"Cannot determine current GPU capability: {e}")(test_func) return decorator + + +class requires: + nvgpu_sm90 = _requires(nvgpu_sm90) + nvgpu_sm80 = _requires(nvgpu_sm80) diff --git a/tests/conftest.py b/tests/conftest.py index b5bca303..3f3a0d5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,38 +29,6 @@ import pytest import tilus import tilus.utils -from tilus.target import Target, get_current_target - - -def requires(target: Target): - """ - Pytest fixture decorator that skips tests if the current GPU doesn't support the required architecture. - - Parameters - ---------- - target : Target - The required target architecture. 
Examples include 'sm_90a', 'sm_80', - """ - - def decorator(test_func): - try: - required_target = target - current_target = get_current_target() - current_capability = current_target.properties.compute_capability - - if not current_target.supports(required_target): - return pytest.mark.skip( - f"Test requires architecture {required_target}, but current GPU capability is {current_capability}" - )(test_func) - return test_func - except ValueError as e: - # If we can't parse the architecture string, skip the test - return pytest.mark.skip(f"Invalid architecture requirement: {e}")(test_func) - except Exception as e: - # If we can't determine current capability, skip the test - return pytest.mark.skip(f"Cannot determine current GPU capability: {e}")(test_func) - - return decorator def pytest_sessionstart(session): diff --git a/tests/lang/test_blocks_per_cluster.py b/tests/lang/test_blocks_per_cluster.py index 052bd8c4..f0d3ee96 100644 --- a/tests/lang/test_blocks_per_cluster.py +++ b/tests/lang/test_blocks_per_cluster.py @@ -1,5 +1,4 @@ import tilus -from tilus.target import nvgpu_sm80, nvgpu_sm90 from tilus.testing.requires import requires @@ -16,13 +15,13 @@ def __call__(self): self.printf("blockIdx: [%d, %d, %d]\n", self.blockIdx.x, self.blockIdx.y, self.blockIdx.z) -@requires(nvgpu_sm90) +@requires.nvgpu_sm90 def test_script_blocks_per_cluster_post_sm90(): kernel = DemoBlockCluster((2, 2, 1)) kernel() -@requires(nvgpu_sm80) +@requires.nvgpu_sm80 def test_script_blocks_per_cluster_pre_sm90(): kernel = DemoBlockCluster((1, 1, 1)) kernel() diff --git a/tests/lang/test_mbarrier.py b/tests/lang/test_mbarrier.py index 97df4aa2..92be1ad0 100644 --- a/tests/lang/test_mbarrier.py +++ b/tests/lang/test_mbarrier.py @@ -1,6 +1,7 @@ import tilus import torch from tilus import int32, uint64 +from tilus.testing import requires class DemoBarrier(tilus.Script): @@ -37,6 +38,7 @@ def __call__(self, n: int32, x_ptr: ~int32, y_ptr: ~int32): phase ^= 1 +@requires.nvgpu_sm80 def 
test_mbarrier(): n = 128 x = torch.arange(n, dtype=torch.int32).cuda() From 39786bd51f52a96a275121da65b8d9c8fe74afda Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sun, 7 Sep 2025 05:58:53 +0000 Subject: [PATCH 8/8] require sm90 to test barrier Signed-off-by: Yaoyao Ding --- tests/lang/test_mbarrier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lang/test_mbarrier.py b/tests/lang/test_mbarrier.py index 92be1ad0..18a872e7 100644 --- a/tests/lang/test_mbarrier.py +++ b/tests/lang/test_mbarrier.py @@ -38,7 +38,7 @@ def __call__(self, n: int32, x_ptr: ~int32, y_ptr: ~int32): phase ^= 1 -@requires.nvgpu_sm80 +@requires.nvgpu_sm90 def test_mbarrier(): n = 128 x = torch.arange(n, dtype=torch.int32).cuda()