From afd4af7de6b259cf927f16ba332618f6e74a1896 Mon Sep 17 00:00:00 2001
From: Jiacheng Huang
Date: Mon, 7 Jul 2025 13:13:54 +0800
Subject: [PATCH 1/2] Extend `_cached_make` to accept `num_warps`,
 `num_stages`, and `max_num_configs`

---
 src/ntops/torch.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index e99b28e..e180682 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -538,5 +538,12 @@ def tanh(input, *, out=None):
 
 
 @functools.cache
-def _cached_make(premake, *args, **keywords):
-    return ninetoothed.make(*premake(*args, **keywords))
+def _cached_make(
+    premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
+):
+    return ninetoothed.make(
+        *premake(*args, **keywords),
+        num_warps=num_warps,
+        num_stages=num_stages,
+        max_num_configs=max_num_configs,
+    )

From 3346029cbd2b895ede9b617220bf86efbaac6ea2 Mon Sep 17 00:00:00 2001
From: Jiacheng Huang
Date: Mon, 7 Jul 2025 13:28:41 +0800
Subject: [PATCH 2/2] Add `rotary_position_embedding` operator

---
 .../kernels/rotary_position_embedding.py | 66 +++++++++++++++
 src/ntops/torch.py                       | 26 ++++++
 tests/test_rotary_position_embedding.py  | 81 +++++++++++++++++++
 3 files changed, 173 insertions(+)
 create mode 100644 src/ntops/kernels/rotary_position_embedding.py
 create mode 100644 tests/test_rotary_position_embedding.py

diff --git a/src/ntops/kernels/rotary_position_embedding.py b/src/ntops/kernels/rotary_position_embedding.py
new file mode 100644
index 0000000..072c2ff
--- /dev/null
+++ b/src/ntops/kernels/rotary_position_embedding.py
@@ -0,0 +1,66 @@
+import functools
+
+from ninetoothed import Tensor
+
+
+def arrangement(input, sin_table, cos_table, output, interleaved=True):
+    emb_dim = input.shape[-1]
+    tile_shape = (1, 1, 1, emb_dim // 2)
+
+    if interleaved:
+        strides = (-1, -1, -1, 1)
+        dilation = (1, 1, 1, 2)
+    else:
+        strides = None
+        dilation = None
+
+    def _arrange_input_or_output(tensor):
+        tensor_arranged = tensor.tile(tile_shape, strides=strides, dilation=dilation)
+        tensor_arranged = tensor_arranged.tile((1, 1, 1, -1))
+        tensor_arranged.dtype = tensor_arranged.dtype.squeeze((0, 1, 2))
+        tensor_arranged.dtype.dtype = tensor_arranged.dtype.dtype.squeeze((0, 1, 2))
+
+        return tensor_arranged
+
+    def _arrange_table(table):
+        table_arranged = table.tile(tile_shape)
+        table_arranged.dtype = table_arranged.dtype.squeeze((0, 1, 2))
+
+        return table_arranged
+
+    input_arranged = _arrange_input_or_output(input)
+    sin_table_arranged = _arrange_table(sin_table)
+    cos_table_arranged = _arrange_table(cos_table)
+    output_arranged = _arrange_input_or_output(output)
+
+    return input_arranged, sin_table_arranged, cos_table_arranged, output_arranged
+
+
+def application(input, sin_table, cos_table, output):
+    sin_table_loaded = sin_table
+    cos_table_loaded = cos_table
+
+    input_0 = input[0]
+    input_1 = input[1]
+
+    output[0] = input_0 * cos_table_loaded - input_1 * sin_table_loaded
+    output[1] = input_0 * sin_table_loaded + input_1 * cos_table_loaded
+
+
+def premake(ndim, emb_dim=None, dtype=None, interleaved=True):
+    arrangement_ = functools.partial(arrangement, interleaved=interleaved)
+
+    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128})
+
+    tensors = (
+        Tensor(ndim, dtype=dtype, shape_options=shape_options),
+        Tensor(ndim, dtype=dtype, shape_options=shape_options),
+        Tensor(ndim, dtype=dtype, shape_options=shape_options),
+        Tensor(ndim, dtype=dtype, shape_options=shape_options),
+    )
+
+    if emb_dim is not None:
+        for tensor in tensors:
+            tensor.shape = tensor.shape[:-1] + (emb_dim,)
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/torch.py b/src/ntops/torch.py
index e180682..2a01960 100644
--- a/src/ntops/torch.py
+++ b/src/ntops/torch.py
@@ -32,6 +32,7 @@
 import ntops.kernels.pow
 import ntops.kernels.relu
 import ntops.kernels.rms_norm
+import ntops.kernels.rotary_position_embedding
 import ntops.kernels.rsqrt
 import ntops.kernels.scaled_dot_product_attention
 import ntops.kernels.sigmoid
@@ -371,6 +372,31 @@ def rms_norm(input, normalized_shape, weight=None, eps=None):
     return output
 
 
+def rotary_position_embedding(
+    input, sin_table, cos_table, interleaved=True, inplace=False
+):
+    if inplace:
+        output = input
+    else:
+        output = torch.empty_like(input)
+
+    batch_size, _, num_heads, _ = input.shape
+
+    sin_table = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1)
+    cos_table = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1)
+
+    kernel = _cached_make(
+        ntops.kernels.rotary_position_embedding.premake,
+        input.ndim,
+        interleaved=interleaved,
+        num_warps=1,
+    )
+
+    kernel(input, sin_table, cos_table, output)
+
+    return output
+
+
 def rsqrt(input, *, out=None):
     if out is None:
         out = torch.empty_like(input)
diff --git a/tests/test_rotary_position_embedding.py b/tests/test_rotary_position_embedding.py
new file mode 100644
index 0000000..b78ffbf
--- /dev/null
+++ b/tests/test_rotary_position_embedding.py
@@ -0,0 +1,81 @@
+import pytest
+import torch
+
+import ntops.torch
+from tests.skippers import skip_if_cuda_not_available
+
+
+def _torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True):
+    batch_size, seq_len, num_heads, emb_dim = input.shape
+
+    assert emb_dim % 2 == 0, "The embedding dimension must be even."
+
+    sin_table = sin_table[None, :, None, :]
+    cos_table = cos_table[None, :, None, :]
+
+    if interleaved:
+        pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2)
+        input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1]
+        input_0_rotated = input_0 * cos_table - input_1 * sin_table
+        input_1_rotated = input_0 * sin_table + input_1 * cos_table
+
+        return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape)
+    else:
+        input_0 = input[..., : input.shape[-1] // 2]
+        input_1 = input[..., input.shape[-1] // 2 :]
+        input_0_rotated = input_0 * cos_table - input_1 * sin_table
+        input_1_rotated = input_0 * sin_table + input_1 * cos_table
+
+        return torch.cat((input_0_rotated, input_1_rotated), dim=-1)
+
+
+def _generate_sin_and_cos_tables(
+    seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda"
+):
+    assert emb_dim % 2 == 0, "The embedding dimension must be even."
+
+    theta = base ** (
+        -2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim)
+    )
+
+    positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1)
+    sin_table = torch.sin(positions * theta)
+    cos_table = torch.cos(positions * theta)
+
+    return sin_table, cos_table
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(
+    "dtype, atol, rtol", ((torch.float32, 0.001, 0), (torch.float16, 0.001, 0.001))
+)
+@pytest.mark.parametrize("inplace", (False, True))
+@pytest.mark.parametrize("interleaved", (False, True))
+@pytest.mark.parametrize("emb_dim", (32, 64))
+@pytest.mark.parametrize("num_heads", (1, 8))
+@pytest.mark.parametrize("seq_len", (1, 128))
+@pytest.mark.parametrize("batch_size", (1, 4))
+def test_cuda(
+    batch_size, seq_len, num_heads, emb_dim, interleaved, inplace, dtype, atol, rtol
+):
+    device = "cuda"
+
+    input = torch.randn(
+        batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device
+    )
+    sin_table, cos_table = _generate_sin_and_cos_tables(
+        seq_len, emb_dim, dtype=dtype, device=device
+    )
+
+    ninetoothed_output = ntops.torch.rotary_position_embedding(
+        input.clone() if inplace else input,
+        sin_table,
+        cos_table,
+        interleaved=interleaved,
+        inplace=inplace,
+    )
+    reference_output = _torch_rotary_position_embedding(
+        input, sin_table, cos_table, interleaved=interleaved
+    )
+
+    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
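For reviewers, the snippet below sketches how the new operator is invoked; it mirrors the call pattern and sin/cos table construction used in tests/test_rotary_position_embedding.py. The concrete shapes, device, and variable names here are illustrative only and are not part of the patch.

    import torch

    import ntops.torch

    # Illustrative shapes; the wrapper expects input of shape
    # (batch_size, seq_len, num_heads, emb_dim).
    batch_size, seq_len, num_heads, emb_dim = 4, 128, 8, 64
    device = "cuda"

    input = torch.randn(batch_size, seq_len, num_heads, emb_dim, device=device)

    # Build (seq_len, emb_dim // 2) sin/cos tables as in the test helper; the
    # wrapper broadcasts them over the batch and head dimensions internally.
    positions = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
    theta = 10000 ** (
        -2 * (torch.arange(emb_dim // 2, dtype=torch.float32, device=device) / emb_dim)
    )
    sin_table = torch.sin(positions * theta)
    cos_table = torch.cos(positions * theta)

    # interleaved=True rotates adjacent (even, odd) element pairs;
    # inplace=True would write the result back into `input` instead.
    output = ntops.torch.rotary_position_embedding(
        input, sin_table, cos_table, interleaved=True, inplace=False
    )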