174 changes: 80 additions & 94 deletions test/infiniop/rotary_embedding.py
@@ -1,27 +1,31 @@
import ctypes
from ctypes import c_float, POINTER, c_void_p, c_int32, c_uint64, Structure, byref
from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref
import sys
import os


sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
open_lib,
to_tensor,
get_test_devices,
check_error,
rearrange_tensor,
rearrange_if_needed,
create_workspace,
U64,
test_operator,
get_args,
debug,
profile_operation,
InfiniDtype,
)

from operatorspy.tests.test_utils import get_args
import torch

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
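# These module-level defaults are overridden from the parsed command-line
# arguments in __main__ below (args.debug, args.profile, args.num_prerun,
# args.num_iterations).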


class RoPEDescriptor(Structure):
_fields_ = [("device", c_int32)]
@@ -40,15 +44,21 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):

def rotary_embedding(t, pos, theta, torch_device):
dh = t.shape[2]
freqs = (1.0 / (theta ** (torch.arange(0, dh, 2)[: (dh // 2)].float() / dh))).to(
torch_device
)
freqs = torch.outer(pos, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
assert dh % 2 == 0, "Embedding dimension must be even."
t_even = t[..., 0::2] # [seq_len, n_head, dh // 2]
t_odd = t[..., 1::2] # [seq_len, n_head, dh // 2]
freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
freqs = torch.outer(pos, freqs) # [seq_len, dh // 2]
cos = torch.cos(freqs).unsqueeze(1) # [seq_len, 1, dh // 2]
sin = torch.sin(freqs).unsqueeze(1) # [seq_len, 1, dh // 2]

t_out_even = t_even * cos - t_odd * sin
t_out_odd = t_even * sin + t_odd * cos

t_out = torch.empty_like(t)
t_out[..., 0::2] = t_out_even
t_out[..., 1::2] = t_out_odd

t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, t_)
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
return t_out
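The rewritten reference above replaces the old view_as_complex formulation with an explicit even/odd rotation. The two are equivalent: treating each pair (t_even, t_odd) as the complex number t_even + i*t_odd and multiplying by cos(theta) + i*sin(theta) gives exactly (t_even*cos - t_odd*sin, t_even*sin + t_odd*cos). Note also that the old [: (dh // 2)] slice was redundant, since torch.arange(0, dh, 2) already has dh // 2 elements for even dh, which the new code now asserts. A minimal self-check sketch of that equivalence (helper names are illustrative, not part of this file; assumes only torch):

import torch

def rope_pairs(t, pos, theta=1e4):
    # Explicit even/odd rotation, mirroring the new reference above.
    dh = t.shape[-1]
    freqs = torch.outer(pos.float(), 1.0 / theta ** (torch.arange(0, dh, 2).float() / dh))
    cos, sin = torch.cos(freqs).unsqueeze(1), torch.sin(freqs).unsqueeze(1)
    out = torch.empty_like(t)
    out[..., 0::2] = t[..., 0::2] * cos - t[..., 1::2] * sin
    out[..., 1::2] = t[..., 0::2] * sin + t[..., 1::2] * cos
    return out

def rope_complex(t, pos, theta=1e4):
    # The pre-change formulation: view_as_complex pairs adjacent elements,
    # and unsqueeze(1) stands in for reshape_for_broadcast on a 3-D input.
    dh = t.shape[-1]
    freqs = torch.outer(pos.float(), 1.0 / theta ** (torch.arange(0, dh, 2).float() / dh))
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs).unsqueeze(1)
    t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
    return torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)

t, pos = torch.rand(3, 4, 8), torch.arange(3)
assert torch.allclose(rope_pairs(t, pos), rope_complex(t, pos), atol=1e-6)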


@@ -71,29 +81,23 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
)

t = torch.rand(shape, dtype=dtype)
if strides is not None:
t = rearrange_tensor(t, strides)
posTmp = torch.arange(0, t.shape[0])
t = rearrange_if_needed(t, strides).to(torch_device)
posTmp = torch.arange(0, t.shape[0]).to(torch_device)
pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
for i in range(posTmp.shape[0]):
pos[2 * i] = posTmp[i]
pos[2 * i + 1] = 0
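# The loop above interleaves each real position with a zero; a vectorized
# equivalent (a sketch) would be:
#   pos = torch.stack([posTmp, torch.zeros_like(posTmp)], dim=1).flatten().to(torch.int32)
# Only the first t.shape[0] entries are handed to the kernel via
# pos[: t.shape[0]] below, presumably to exercise an oversized position buffer.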
pos = pos.to(torch_device)
theta = 1e4
if torch_device == "mlu" or torch_device == "npu":
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
pos = pos.to(torch_device)
t = t.to(torch_device)
else:
t = t.to(torch_device)
pos = pos.to(torch_device)
ans = rotary_embedding(t, posTmp.to(torch_device), theta, torch_device)

ans = rotary_embedding(t, posTmp, theta, torch_device)

descriptor = infiniopRoPEDescriptor_t()
# 2x table length for test
sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
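# sin_cos_table is defined in a collapsed part of this file. Judging only from
# this call site, it plausibly builds [max_pos, dh // 2] lookup tables, roughly
# (an assumed sketch, not the actual implementation):
#   freqs = 1.0 / theta ** (torch.arange(0, dh, 2).float() / dh)
#   angles = torch.outer(torch.arange(max_pos).float(), freqs)
#   sin_table, cos_table = torch.sin(angles).to(device), torch.cos(angles).to(device)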
t_tensor = to_tensor(t, lib)
pos_tensor = to_tensor(pos[: t.shape[0]], lib)
pos_tensor.descriptor.contents.dt = U64
pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
sin_table_tensor = to_tensor(sin_table, lib)
cos_table_tensor = to_tensor(cos_table, lib)

@@ -122,69 +126,52 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, t.device)
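# create_workspace may return None (the calls below guard with
# "if workspace is not None"), presumably when the reported workspace size is zero.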
check_error(
lib.infiniopRoPE(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
t_tensor.data,
pos_tensor.data,
sin_table_tensor.data,
cos_table_tensor.data,
None,

def lib_rope():
check_error(
lib.infiniopRoPE(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
t_tensor.data,
pos_tensor.data,
sin_table_tensor.data,
cos_table_tensor.data,
None,
)
)
)

lib_rope()
if DEBUG:
debug(t, ans, atol=1e-4, rtol=1e-2)
assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))


def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
handle = create_handle(lib, device)
for shape, strides, dtype in test_cases:
test(lib, handle, "cpu", shape, strides, dtype)
destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
device = DeviceEnum.DEVICE_CUDA
handle = create_handle(lib, device)
for shape, strides, dtype in test_cases:
test(lib, handle, "cuda", shape, strides, dtype)
destroy_handle(lib, handle)


def test_bang(lib, test_cases):
import torch_mlu

device = DeviceEnum.DEVICE_BANG
handle = create_handle(lib, device)
for shape, strides, dtype in test_cases:
test(lib, handle, "mlu", shape, strides, dtype)
destroy_handle(lib, handle)


def test_ascend(lib, test_cases):
import torch_npu
if PROFILE:
profile_operation(
"PyTorch",
lambda: rotary_embedding(t, posTmp, theta, torch_device),
torch_device,
NUM_PRERUN,
NUM_ITERATIONS,
)
profile_operation(
" lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
)

device = DeviceEnum.DEVICE_ASCEND
handle = create_handle(lib, device)
for shape, strides, dtype in test_cases:
test(lib, handle, "npu", shape, strides, dtype)
destroy_handle(lib, handle)
check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))


if __name__ == "__main__":
test_cases = [
((1, 32, 128), None, torch.float16),
((1, 32, 64), None, torch.float16),
# (t_shape, t_strides)
((1, 32, 128), None),
((1, 32, 64), None),
# Ascend does not yet pass this case: a last dimension <= 32 fails, likely
# related to the internal implementation of its core GatherMask interface;
# 48, 64, and 128 all work at present.
((4, 1, 32), None, torch.float16),
((1, 32, 128), None, torch.float16),
((3, 32, 128), (8000, 200, 1), torch.float16),
((4, 1, 32), None),
((1, 32, 128), None),
((3, 32, 128), (8000, 200, 1)),
]
test_dtypes = [torch.float16]
args = get_args()
lib = open_lib()
lib.infiniopCreateRoPEDescriptor.restype = c_int32
@@ -216,14 +203,13 @@ def test_ascend(lib, test_cases):
lib.infiniopDestroyRoPEDescriptor.argtypes = [
infiniopRoPEDescriptor_t,
]
if args.cpu:
test_cpu(lib, test_cases)
if args.cuda:
test_cuda(lib, test_cases)
if args.bang:
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
test_cpu(lib, test_cases)
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations

# Execute tests
for device in get_test_devices(args):
test_operator(lib, device, test, test_cases, test_dtypes)
print("\033[92mTest passed!\033[0m")