Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/ntops/kernels/rotary_position_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import functools

from ninetoothed import Tensor


def arrangement(input, sin_table, cos_table, output, interleaved=True):
    """Describe how the tensors are tiled for the rotary-embedding kernel.

    `input` and `output` are assumed to be 4-D — presumably
    (batch, seq_len, num_heads, emb_dim); confirm against the caller in
    ntops.torch, which expands the sin/cos tables to the same rank. Each
    kernel instance processes half of the embedding dimension, i.e. one of
    the two rotation halves, exposed to `application` as index 0 and 1.
    """
    emb_dim = input.shape[-1]
    # One tile per (batch, position, head) covering half the embedding dim.
    tile_shape = (1, 1, 1, emb_dim // 2)

    if interleaved:
        # Interleaved layout: rotation pairs are adjacent elements, so each
        # half is every other element along the last axis (dilation 2).
        # NOTE(review): the -1 strides presumably let ninetoothed infer the
        # leading strides automatically — TODO confirm with ninetoothed docs.
        strides = (-1, -1, -1, 1)
        dilation = (1, 1, 1, 2)
    else:
        # Half-split layout: the two rotation halves are contiguous, so the
        # default tiling (no custom strides/dilation) already yields them.
        strides = None
        dilation = None

    def _arrange_input_or_output(tensor):
        # Tile into half-embedding slices, then group them with a second
        # tile so the two halves of one element become indexable as [0]/[1];
        # the squeezes drop the singleton leading dims at both nesting
        # levels of the hierarchical (tiled) dtype.
        tensor_arranged = tensor.tile(tile_shape, strides=strides, dilation=dilation)
        tensor_arranged = tensor_arranged.tile((1, 1, 1, -1))
        tensor_arranged.dtype = tensor_arranged.dtype.squeeze((0, 1, 2))
        tensor_arranged.dtype.dtype = tensor_arranged.dtype.dtype.squeeze((0, 1, 2))

        return tensor_arranged

    def _arrange_table(table):
        # Tables only need the single-level tiling: one half-dim tile per
        # program, with the singleton leading dims squeezed away.
        table_arranged = table.tile(tile_shape)
        table_arranged.dtype = table_arranged.dtype.squeeze((0, 1, 2))

        return table_arranged

    input_arranged = _arrange_input_or_output(input)
    sin_table_arranged = _arrange_table(sin_table)
    cos_table_arranged = _arrange_table(cos_table)
    output_arranged = _arrange_input_or_output(output)

    return input_arranged, sin_table_arranged, cos_table_arranged, output_arranged


def application(input, sin_table, cos_table, output):
    """Rotate one tile pair: the classic RoPE 2-D rotation.

    `input[0]`/`input[1]` are the two rotation halves produced by
    `arrangement`; the rotated halves are written to `output[0]`/`output[1]`.
    """
    sin_vals = sin_table
    cos_vals = cos_table

    # Read both halves before any store, so the kernel stays correct when
    # `output` aliases `input` (the in-place case).
    first_half = input[0]
    second_half = input[1]

    output[0] = first_half * cos_vals - second_half * sin_vals
    output[1] = first_half * sin_vals + second_half * cos_vals


def premake(ndim, emb_dim=None, dtype=None, interleaved=True):
    """Build the (arrangement, application, tensors) triple for kernel creation.

    `ndim` is the rank of all four tensors (input, sin table, cos table,
    output). When `emb_dim` is given, the last dimension of every tensor is
    pinned to it; `interleaved` selects the pairing layout baked into the
    arrangement.
    """
    bound_arrangement = functools.partial(arrangement, interleaved=interleaved)

    # Keep the last (embedding) dimension a compile-time constant with a
    # bounded size so it can be specialized by the compiler.
    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 128})

    # Input, sin table, cos table, and output share the same symbolic shape.
    tensors = tuple(
        Tensor(ndim, dtype=dtype, shape_options=shape_options) for _ in range(4)
    )

    if emb_dim is not None:
        # Fix the embedding dimension to the requested concrete size.
        for tensor in tensors:
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    return bound_arrangement, application, tensors
37 changes: 35 additions & 2 deletions src/ntops/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import ntops.kernels.pow
import ntops.kernels.relu
import ntops.kernels.rms_norm
import ntops.kernels.rotary_position_embedding
import ntops.kernels.rsqrt
import ntops.kernels.scaled_dot_product_attention
import ntops.kernels.sigmoid
Expand Down Expand Up @@ -371,6 +372,31 @@ def rms_norm(input, normalized_shape, weight=None, eps=None):
return output


def rotary_position_embedding(
    input, sin_table, cos_table, interleaved=True, inplace=False
):
    """Apply rotary position embedding to `input` using a ninetoothed kernel.

    `input` is expected to be (batch, seq_len, num_heads, emb_dim) and the
    tables 2-D — presumably (seq_len, emb_dim // 2); confirm against the
    table generation in the tests. When `inplace` is true, `input` is
    overwritten and returned; otherwise a fresh tensor is returned.
    """
    output = input if inplace else torch.empty_like(input)

    batch_size, _, num_heads, _ = input.shape

    # Broadcast the 2-D tables across the batch and head dimensions so all
    # four kernel arguments share the same rank.
    expanded_sin = sin_table[None, :, None, :].expand(batch_size, -1, num_heads, -1)
    expanded_cos = cos_table[None, :, None, :].expand(batch_size, -1, num_heads, -1)

    kernel = _cached_make(
        ntops.kernels.rotary_position_embedding.premake,
        input.ndim,
        interleaved=interleaved,
        num_warps=1,
    )

    kernel(input, expanded_sin, expanded_cos, output)

    return output


def rsqrt(input, *, out=None):
if out is None:
out = torch.empty_like(input)
Expand Down Expand Up @@ -538,5 +564,12 @@ def tanh(input, *, out=None):


@functools.cache
def _cached_make(
    premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
):
    """Create a ninetoothed kernel, memoizing on all arguments.

    `premake` returns the (arrangement, application, tensors) triple that
    `ninetoothed.make` consumes; the launch-tuning options are forwarded
    separately. All arguments must be hashable for `functools.cache`.
    """
    # NOTE: the pasted diff left a stale duplicate `def _cached_make` here
    # that shadowed the cached definition and dropped the decorator; this is
    # the single, correct post-change definition.
    return ninetoothed.make(
        *premake(*args, **keywords),
        num_warps=num_warps,
        num_stages=num_stages,
        max_num_configs=max_num_configs,
    )
81 changes: 81 additions & 0 deletions tests/test_rotary_position_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest
import torch

import ntops.torch
from tests.skippers import skip_if_cuda_not_available


def _torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True):
batch_size, seq_len, num_heads, emb_dim = input.shape

assert emb_dim % 2 == 0, "The embedding dimension must be even."

sin_table = sin_table[None, :, None, :]
cos_table = cos_table[None, :, None, :]

if interleaved:
pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2)
input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1]
input_0_rotated = input_0 * cos_table - input_1 * sin_table
input_1_rotated = input_0 * sin_table + input_1 * cos_table

return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape)
else:
input_0 = input[..., : input.shape[-1] // 2]
input_1 = input[..., input.shape[-1] // 2 :]
input_0_rotated = input_0 * cos_table - input_1 * sin_table
input_1_rotated = input_0 * sin_table + input_1 * cos_table

return torch.cat((input_0_rotated, input_1_rotated), dim=-1)


def _generate_sin_and_cos_tables(
seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda"
):
assert emb_dim % 2 == 0, "The embedding dimension must be even."

theta = base ** (
-2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim)
)

positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1)
sin_table = torch.sin(positions * theta)
cos_table = torch.cos(positions * theta)

return sin_table, cos_table


@skip_if_cuda_not_available
@pytest.mark.parametrize(
    "dtype, atol, rtol", ((torch.float32, 0.001, 0), (torch.float16, 0.001, 0.001))
)
@pytest.mark.parametrize("inplace", (False, True))
@pytest.mark.parametrize("interleaved", (False, True))
@pytest.mark.parametrize("emb_dim", (32, 64))
@pytest.mark.parametrize("num_heads", (1, 8))
@pytest.mark.parametrize("seq_len", (1, 128))
@pytest.mark.parametrize("batch_size", (1, 4))
def test_cuda(
    batch_size, seq_len, num_heads, emb_dim, interleaved, inplace, dtype, atol, rtol
):
    """Compare the ntops kernel against the pure-PyTorch reference on CUDA."""
    device = "cuda"

    shape = (batch_size, seq_len, num_heads, emb_dim)
    input = torch.randn(shape, dtype=dtype, device=device)
    tables = _generate_sin_and_cos_tables(
        seq_len, emb_dim, dtype=dtype, device=device
    )

    # Clone for the in-place case so the reference still sees the original data.
    kernel_input = input.clone() if inplace else input
    ninetoothed_output = ntops.torch.rotary_position_embedding(
        kernel_input,
        *tables,
        interleaved=interleaved,
        inplace=inplace,
    )
    reference_output = _torch_rotary_position_embedding(
        input, *tables, interleaved=interleaved
    )

    assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)