Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tilus/backends/emitters/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import cp_async, ldmatrix, mma_dot, semaphore, simt_dot
from . import cp_async, ldmatrix, mbarrier, mma_dot, semaphore, simt_dot
46 changes: 46 additions & 0 deletions python/tilus/backends/emitters/cuda/mbarrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from hidet.ir.dtypes import boolean
from hidet.ir.primitives.cuda.barrier import mbarrier_arrive, mbarrier_init, mbarrier_wait

from tilus.backends.codegen import BaseInstEmitter, register_emitter
from tilus.ir.instructions import ArriveBarrierInst, ArriveRemoteBarrierInst, InitBarrierInst, WaitBarrierInst
from tilus.target import nvgpu_sm80, nvgpu_sm90


@register_emitter(InitBarrierInst, target=nvgpu_sm80)
class InitBarrierInstEmitter(BaseInstEmitter):
def emit(self, inst: InitBarrierInst) -> None:
with self.if_then(self.current_worker == 0):
self.append(mbarrier_init(inst.barrier, inst.count))


@register_emitter(ArriveBarrierInst, target=nvgpu_sm80)
class ArriveBarrierInstEmitter(BaseInstEmitter):
def emit(self, inst: ArriveBarrierInst) -> None:
self.append(mbarrier_arrive(inst.barrier))


@register_emitter(ArriveRemoteBarrierInst, target=nvgpu_sm80)
class ArriveRemoteBarrierInstEmitter(BaseInstEmitter):
def emit(self, inst: ArriveRemoteBarrierInst) -> None:
self.append(mbarrier_arrive(inst.barrier, inst.remote_block, pred=boolean.true))


@register_emitter(WaitBarrierInst, target=nvgpu_sm90)
class WaitBarrierInstEmitter(BaseInstEmitter):
def emit(self, inst: WaitBarrierInst) -> None:
self.append(mbarrier_wait(inst.barrier, inst.phase))
21 changes: 21 additions & 0 deletions python/tilus/ir/builders/stmt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@
from tilus.ir.inst import Instruction, InstructionError
from tilus.ir.instructions.annotation import AnnotateLayoutInst
from tilus.ir.instructions.cuda import (
ArriveBarrierInst,
ArriveRemoteBarrierInst,
CopyAsyncCommitGroupInst,
CopyAsyncGenericInst,
CopyAsyncInst,
CopyAsyncWaitAllInst,
CopyAsyncWaitGroupInst,
DotInst,
InitBarrierInst,
LoadMatrixConfig,
LoadMatrixInst,
LockSemaphoreInst,
ReleaseSemaphoreInst,
WaitBarrierInst,
)
from tilus.ir.instructions.generic import (
AddInst,
Expand Down Expand Up @@ -944,6 +948,23 @@ def exit(self) -> None:
inst = ExitInst.create()
self.append(inst)

# barrier
def init_barrier(self, barrier: Expr, count: Expr) -> None:
inst = InitBarrierInst.create(barrier=barrier, count=count)
self.append(inst)

def arrive_barrier(self, barrier: Expr) -> None:
inst = ArriveBarrierInst.create(barrier=barrier)
self.append(inst)

def arrive_remote_barrier(self, barrier: Expr, remote_block: Expr) -> None:
inst = ArriveRemoteBarrierInst.create(barrier=barrier, remote_block=remote_block)
self.append(inst)

def wait_barrier(self, barrier: Expr, phase: Expr) -> None:
inst = WaitBarrierInst.create(barrier=barrier, phase=phase)
self.append(inst)

# annotations
def annotate_layout(self, tensor: RegisterTensor, layout: RegisterLayout) -> None:
inst = AnnotateLayoutInst.create(tensor=tensor, layout=layout)
Expand Down
4 changes: 4 additions & 0 deletions python/tilus/ir/instructions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@

from .annotation import AnnotateLayoutInst
from .cuda import (
ArriveBarrierInst,
ArriveRemoteBarrierInst,
CopyAsyncCommitGroupInst,
CopyAsyncGenericInst,
CopyAsyncInst,
CopyAsyncWaitAllInst,
CopyAsyncWaitGroupInst,
DotInst,
InitBarrierInst,
LoadMatrixConfig,
LoadMatrixInst,
LockSemaphoreInst,
ReleaseSemaphoreInst,
SimtDotInst,
WaitBarrierInst,
)
from .generic import (
AddInst,
Expand Down
1 change: 1 addition & 0 deletions python/tilus/ir/instructions/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CopyAsyncWaitGroupInst,
)
from .ldmatrix import LoadMatrixConfig, LoadMatrixInst
from .mbarrier import ArriveBarrierInst, ArriveRemoteBarrierInst, InitBarrierInst, WaitBarrierInst
from .mma_dot import DotInst
from .semaphore import LockSemaphoreInst, ReleaseSemaphoreInst
from .simt_dot import SimtDotInst
60 changes: 60 additions & 0 deletions python/tilus/ir/instructions/cuda/mbarrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from dataclasses import dataclass

from hidet.ir.expr import Expr

from tilus.ir.inst import Instruction


@dataclass(frozen=True, eq=False)
class InitBarrierInst(Instruction):
barrier: Expr
count: Expr

@staticmethod
def create(barrier: Expr, count: Expr) -> InitBarrierInst:
return InitBarrierInst(output=None, inputs=(), barrier=barrier, count=count)


@dataclass(frozen=True, eq=False)
class ArriveBarrierInst(Instruction):
barrier: Expr

@staticmethod
def create(barrier: Expr) -> ArriveBarrierInst:
return ArriveBarrierInst(output=None, inputs=(), barrier=barrier)


@dataclass(frozen=True, eq=False)
class ArriveRemoteBarrierInst(Instruction):
barrier: Expr
remote_block: Expr

@staticmethod
def create(barrier: Expr, remote_block: Expr) -> ArriveRemoteBarrierInst:
return ArriveRemoteBarrierInst(output=None, inputs=(), barrier=barrier, remote_block=remote_block)


@dataclass(frozen=True, eq=False)
class WaitBarrierInst(Instruction):
barrier: Expr
phase: Expr

@staticmethod
def create(barrier: Expr, phase: Expr) -> WaitBarrierInst:
return WaitBarrierInst(output=None, inputs=(), barrier=barrier, phase=phase)
2 changes: 1 addition & 1 deletion python/tilus/ir/mfunction/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def cover(
"""Check whether the multi-function fa covers the multi-function fb.

The size and image size of both multi-functions must be the same. For every x in the domain, if we have
fb(x) \subseteq fa(x),
fb(x) subset_eq fa(x),
then we say that fa covers fb. In other words, for every x in the domain of fb, the image of fb at x is a subset
of the image of fa at x.

Expand Down
63 changes: 62 additions & 1 deletion python/tilus/lang/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import typing
from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Type, Union

from hidet.ir.dtypes import boolean
from hidet.ir.dtypes import boolean, int32
from hidet.ir.expr import Constant, Equal, Expr, LogicalAnd, Mod, Var, as_expr
from hidet.ir.primitives.cuda.vars import blockIdx, dim3, gridDim
from hidet.ir.tools import infer_type
Expand Down Expand Up @@ -1627,6 +1627,67 @@ def assign(self, dst: RegisterTensor, src: RegisterTensor) -> None:
raise InstructionError("The dtypes of dst and src must match, got {} and {}".format(dst.dtype, src.dtype))
self._builder.assign_register(dst, src)

def init_barrier(self, barrier: Expr, count: Expr | int) -> None:
"""Initialize a barrier.

This instruction initializes a memory barrier in shared memory. The `barrier` parameter must be an addressable
expression whose address is in shared memory and aligned to 8 bytes. The barrier itself must be of type uint64.

Parameters
----------
barrier: Expr
The pointer to the barrier in shared memory.
count: Expr | int
The number of threads that must arrive at the barrier before any of them can proceed. It must be evaluated
to a positive int32.
"""
self._builder.init_barrier(barrier, count if isinstance(count, Expr) else int32(count))

def arrive_barrier(self, barrier: Expr) -> None:
"""Arrive at a barrier.

This instruction indicates that the thread block (or the current partition of the thread block) has reached the
specified barrier.

Parameters
----------
barrier: Expr
The pointer to the barrier in shared memory.
"""
self._builder.arrive_barrier(barrier)

def arrive_remote_barrier(self, barrier: Expr, remote_block: Expr) -> None:
"""Arrive at a remote barrier.

This instruction indicates that a remote thread block has reached the specified barrier. It is used for
inter-block synchronization, allowing one thread block to signal another thread block that it has reached a
barrier.

Parameters
----------
barrier: Expr
The pointer to the barrier in shared memory.
remote_block: Expr
The thread block index of the remote thread block that the current block is signaling the arrival to. It
should be an expression that evaluates to a non-negative int32.
"""
self._builder.arrive_remote_barrier(barrier, remote_block)

def wait_barrier(self, barrier: Expr, phase: Expr | int) -> None:
"""Wait at a barrier.

This instruction makes the thread block (or the current partition of the thread block) wait at the specified
barrier until the entire thread block (or the current partition) has arrived at the barrier.

Parameters
----------
barrier: Expr
The pointer to the barrier in shared memory.
phase: Expr | int
The phase value to wait for. It must be evaluated to either 0 or 1.
"""
self._builder.wait_barrier(barrier, phase if isinstance(phase, Expr) else int32(phase))

@staticmethod
def static_assert(cond: bool | Expr, msg: str) -> None:
if not isinstance(cond, Constant) and not isinstance(cond, bool):
Expand Down
1 change: 1 addition & 0 deletions python/tilus/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .requires import requires
10 changes: 8 additions & 2 deletions python/tilus/testing/requires.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Callable

import pytest
from tilus.target import Target, get_current_target

from tilus.target import Target, get_current_target, nvgpu_sm80, nvgpu_sm90

def requires(target: Target) -> Callable[[Callable], Callable]:

def _requires(target: Target) -> Callable[[Callable], Callable]:
"""
Pytest fixture decorator that skips tests if the current GPU doesn't support the required architecture.

Expand Down Expand Up @@ -33,3 +34,8 @@ def decorator(test_func):
return pytest.mark.skip(f"Cannot determine current GPU capability: {e}")(test_func)

return decorator


class requires:
nvgpu_sm90 = _requires(nvgpu_sm90)
nvgpu_sm80 = _requires(nvgpu_sm80)
25 changes: 2 additions & 23 deletions scripts/sign-commits.sh
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
#!/bin/bash
# This script signs off all unsigned commits in the current branch that are not in the main branch,
# and shows the newly signed commits.

set -e

MAIN_BRANCH="main"

git fetch origin

BASE=$(git merge-base HEAD origin/$MAIN_BRANCH)

## List commits after the common ancestor missing "Signed-off-by"
#UNSIGNED_COMMITS=$(git rev-list $BASE..HEAD | while read commit; do
# if ! git show --quiet --format=%B $commit | grep -q "Signed-off-by:"; then
# echo $commit
# fi
#done)
#
#if [ -z "$UNSIGNED_COMMITS" ]; then
# echo "No unsigned commits to sign off."
# exit 0
#fi

# Rebase with signoff
git rebase --signoff $BASE

#echo "Newly signed commits:"
#for commit in $UNSIGNED_COMMITS; do
# git log --format="* %h %s" -n 1 $commit
#done
# Sign all unsigned commits in one rebase
git rebase $BASE --signoff
32 changes: 0 additions & 32 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,6 @@
import pytest
import tilus
import tilus.utils
from tilus.target import Target, get_current_target


def requires(target: Target):
"""
Pytest fixture decorator that skips tests if the current GPU doesn't support the required architecture.

Parameters
----------
target : Target
The required target architecture. Examples include 'sm_90a', 'sm_80',
"""

def decorator(test_func):
try:
required_target = target
current_target = get_current_target()
current_capability = current_target.properties.compute_capability

if not current_target.supports(required_target):
return pytest.mark.skip(
f"Test requires architecture {required_target}, but current GPU capability is {current_capability}"
)(test_func)
return test_func
except ValueError as e:
# If we can't parse the architecture string, skip the test
return pytest.mark.skip(f"Invalid architecture requirement: {e}")(test_func)
except Exception as e:
# If we can't determine current capability, skip the test
return pytest.mark.skip(f"Cannot determine current GPU capability: {e}")(test_func)

return decorator


def pytest_sessionstart(session):
Expand Down
5 changes: 2 additions & 3 deletions tests/lang/test_blocks_per_cluster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tilus
from tilus.target import nvgpu_sm80, nvgpu_sm90
from tilus.testing.requires import requires


Expand All @@ -16,13 +15,13 @@ def __call__(self):
self.printf("blockIdx: [%d, %d, %d]\n", self.blockIdx.x, self.blockIdx.y, self.blockIdx.z)


@requires(nvgpu_sm90)
@requires.nvgpu_sm90
def test_script_blocks_per_cluster_post_sm90():
kernel = DemoBlockCluster((2, 2, 1))
kernel()


@requires(nvgpu_sm80)
@requires.nvgpu_sm80
def test_script_blocks_per_cluster_pre_sm90():
kernel = DemoBlockCluster((1, 1, 1))
kernel()
Loading
Loading