From eb1ae65879bc63e20209ecf567473adf2d180235 Mon Sep 17 00:00:00 2001
From: PanZezhong
Date: Tue, 18 Feb 2025 15:08:38 +0800
Subject: [PATCH] issue/48/test: refactor the RoPE test script
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/infiniop/rotary_embedding.py | 174 ++++++++++++++----------------
 1 file changed, 80 insertions(+), 94 deletions(-)

diff --git a/test/infiniop/rotary_embedding.py b/test/infiniop/rotary_embedding.py
index e4af9a57a..9e9a29866 100644
--- a/test/infiniop/rotary_embedding.py
+++ b/test/infiniop/rotary_embedding.py
@@ -1,27 +1,31 @@
 import ctypes
-from ctypes import c_float, POINTER, c_void_p, c_int32, c_uint64, Structure, byref
+from ctypes import POINTER, c_void_p, c_int32, c_uint64, Structure, byref
 import sys
 import os
-
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
-from operatorspy import (
-    open_lib,
-    to_tensor,
-    DeviceEnum,
+from libinfiniop import (
     infiniopHandle_t,
     infiniopTensorDescriptor_t,
-    create_handle,
-    destroy_handle,
+    open_lib,
+    to_tensor,
+    get_test_devices,
     check_error,
-    rearrange_tensor,
+    rearrange_if_needed,
     create_workspace,
-    U64,
+    test_operator,
+    get_args,
+    debug,
+    profile_operation,
+    InfiniDtype,
 )
-
-from operatorspy.tests.test_utils import get_args
 import torch
 
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
 
 class RoPEDescriptor(Structure):
     _fields_ = [("device", c_int32)]
@@ -40,15 +44,21 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 
 def rotary_embedding(t, pos, theta, torch_device):
     dh = t.shape[2]
-    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2)[: (dh // 2)].float() / dh))).to(
-        torch_device
-    )
-    freqs = torch.outer(pos, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    assert dh % 2 == 0, "Embedding dimension must be even."
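+    # Pair i (channels 2i and 2i+1) at position p is rotated by the angle
+    # p * theta**(-2*i/dh): out_even = even*cos - odd*sin and
+    # out_odd = even*sin + odd*cos. This is the same complex multiply that
+    # the removed torch.view_as_complex formulation performed.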
+    t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
+    t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
+    freqs = (1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))).to(torch_device)
+    freqs = torch.outer(pos, freqs)  # [seq_len, dh // 2]
+    cos = torch.cos(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+    sin = torch.sin(freqs).unsqueeze(1)  # [seq_len, 1, dh // 2]
+
+    t_out_even = t_even * cos - t_odd * sin
+    t_out_odd = t_even * sin + t_odd * cos
+
+    t_out = torch.empty_like(t)
+    t_out[..., 0::2] = t_out_even
+    t_out[..., 1::2] = t_out_odd
-    t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, t_)
-    t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
+
     return t_out
 
 
@@ -71,29 +81,23 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     )
 
     t = torch.rand(shape, dtype=dtype)
-    if strides is not None:
-        t = rearrange_tensor(t, strides)
-    posTmp = torch.arange(0, t.shape[0])
+    t = rearrange_if_needed(t, strides).to(torch_device)
+    posTmp = torch.arange(0, t.shape[0]).to(torch_device)
     pos = torch.zeros(2 * posTmp.shape[0], dtype=torch.int32)
     for i in range(posTmp.shape[0]):
         pos[2 * i] = posTmp[i]
         pos[2 * i + 1] = 0
+    pos = pos.to(torch_device)
     theta = 1e4
-    if torch_device == "mlu" or torch_device == "npu":
-        ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
-        pos = pos.to(torch_device)
-        t = t.to(torch_device)
-    else:
-        t = t.to(torch_device)
-        pos = pos.to(torch_device)
-        ans = rotary_embedding(t, posTmp.to(torch_device), theta, torch_device)
+
+    ans = rotary_embedding(t, posTmp, theta, torch_device)
     descriptor = infiniopRoPEDescriptor_t()
 
     # 2x table length for test
     sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
     t_tensor = to_tensor(t, lib)
     pos_tensor = to_tensor(pos[: t.shape[0]], lib)
-    pos_tensor.descriptor.contents.dt = U64
+    pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
     sin_table_tensor = to_tensor(sin_table, lib)
     cos_table_tensor = to_tensor(cos_table, lib)
@@ -122,69 +126,52 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         lib.infiniopGetRoPEWorkspaceSize(descriptor, ctypes.byref(workspace_size))
     )
     workspace = create_workspace(workspace_size.value, t.device)
-    check_error(
-        lib.infiniopRoPE(
-            descriptor,
-            workspace.data_ptr() if workspace is not None else None,
-            workspace_size.value,
-            t_tensor.data,
-            pos_tensor.data,
-            sin_table_tensor.data,
-            cos_table_tensor.data,
-            None,
+
+    def lib_rope():
+        check_error(
+            lib.infiniopRoPE(
+                descriptor,
+                workspace.data_ptr() if workspace is not None else None,
+                workspace_size.value,
+                t_tensor.data,
+                pos_tensor.data,
+                sin_table_tensor.data,
+                cos_table_tensor.data,
+                None,
+            )
         )
-    )
+
+    lib_rope()
+
+    if DEBUG:
+        debug(t, ans, atol=1e-4, rtol=1e-2)
     assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
-    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
-
-
-def test_cpu(lib, test_cases):
-    device = DeviceEnum.DEVICE_CPU
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cpu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_cuda(lib, test_cases):
-    device = DeviceEnum.DEVICE_CUDA
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "cuda", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_bang(lib, test_cases):
-    import torch_mlu
-
-    device = DeviceEnum.DEVICE_BANG
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "mlu", shape, strides, dtype)
-    destroy_handle(lib, handle)
-
-
-def test_ascend(lib, test_cases):
-    import torch_npu
+    if PROFILE:
+        profile_operation(
+            "PyTorch",
+            lambda: rotary_embedding(t, posTmp, theta, torch_device),
+            torch_device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        profile_operation(
+            "    lib", lambda: lib_rope(), torch_device, NUM_PRERUN, NUM_ITERATIONS
+        )
 
-    device = DeviceEnum.DEVICE_ASCEND
-    handle = create_handle(lib, device)
-    for shape, strides, dtype in test_cases:
-        test(lib, handle, "npu", shape, strides, dtype)
-    destroy_handle(lib, handle)
+    check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
 
 
 if __name__ == "__main__":
     test_cases = [
-        ((1, 32, 128), None, torch.float16),
-        ((1, 32, 64), None, torch.float16),
+        # (t_shape, t_strides)
+        ((1, 32, 128), None),
+        ((1, 32, 64), None),
         # Ascend cannot handle this case yet: a last dimension <= 32 fails,
         # likely related to the internal implementation of its core
        # GatherMask interface; 48, 64, and 128 all work at the moment.
-        ((4, 1, 32), None, torch.float16),
-        ((1, 32, 128), None, torch.float16),
-        ((3, 32, 128), (8000, 200, 1), torch.float16),
+        ((4, 1, 32), None),
+        ((1, 32, 128), None),
+        ((3, 32, 128), (8000, 200, 1)),
     ]
+    test_dtypes = [torch.float16]
    args = get_args()
    lib = open_lib()
    lib.infiniopCreateRoPEDescriptor.restype = c_int32
@@ -216,14 +203,13 @@ def test_ascend(lib, test_cases):
     lib.infiniopDestroyRoPEDescriptor.argtypes = [
         infiniopRoPEDescriptor_t,
     ]
-    if args.cpu:
-        test_cpu(lib, test_cases)
-    if args.cuda:
-        test_cuda(lib, test_cases)
-    if args.bang:
-        test_bang(lib, test_cases)
-    if args.ascend:
-        test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
-        test_cpu(lib, test_cases)
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, test_cases, test_dtypes)
     print("\033[92mTest passed!\033[0m")
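
For reviewers: below is a standalone sanity check (not part of the patch) that
the new explicit even/odd reference above matches the old
torch.view_as_complex formulation it replaces. It assumes only PyTorch; the
helper names (rope_interleaved, rope_complex) and the test shape are
illustrative, not taken from the repository.

    import torch

    def rope_interleaved(t, pos, theta):
        # New reference: rotate each (even, odd) channel pair with explicit cos/sin.
        dh = t.shape[2]
        freqs = 1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))
        angles = torch.outer(pos, freqs)      # [seq_len, dh // 2]
        cos = torch.cos(angles).unsqueeze(1)  # [seq_len, 1, dh // 2]
        sin = torch.sin(angles).unsqueeze(1)
        t_even, t_odd = t[..., 0::2], t[..., 1::2]
        out = torch.empty_like(t)
        out[..., 0::2] = t_even * cos - t_odd * sin
        out[..., 1::2] = t_even * sin + t_odd * cos
        return out

    def rope_complex(t, pos, theta):
        # Old reference: treat each (even, odd) pair as one complex number.
        dh = t.shape[2]
        freqs = 1.0 / (theta ** (torch.arange(0, dh, 2).float() / dh))
        freqs_cis = torch.polar(torch.ones(pos.shape[0], dh // 2), torch.outer(pos, freqs))
        t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
        return torch.view_as_real(t_ * freqs_cis.unsqueeze(1)).flatten(2)

    t = torch.rand(4, 32, 128)  # [seq_len, n_head, dh]
    pos = torch.arange(4).float()
    assert torch.allclose(rope_interleaved(t, pos, 1e4), rope_complex(t, pos, 1e4), atol=1e-5)
    print("interleaved and complex RoPE references agree")

One plausible motivation for the rewrite: torch.view_as_complex requires a
densely strided last dimension and complex dtype support, neither of which is
uniform across the accelerator backends this script targets, while the
explicit even/odd formulation computes the same result with ordinary
elementwise ops on every device.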