diff --git a/test/infiniop/libinfiniop/devices.py b/test/infiniop/libinfiniop/devices.py index af2d11a4f..e5b674e1a 100644 --- a/test/infiniop/libinfiniop/devices.py +++ b/test/infiniop/libinfiniop/devices.py @@ -8,3 +8,17 @@ class InfiniDeviceEnum: ILUVATAR = 6, KUNLUN = 7, SUGON = 8, + + +# Mapping that maps InfiniDeviceEnum to torch device string +infiniDeviceEnum_str_map = { + InfiniDeviceEnum.CPU: "cpu", + InfiniDeviceEnum.NVIDIA: "cuda", + InfiniDeviceEnum.CAMBRICON: "mlu", + InfiniDeviceEnum.ASCEND: "npu", + InfiniDeviceEnum.METAX: "cuda", + InfiniDeviceEnum.MOORE: "musa", + InfiniDeviceEnum.ILUVATAR: "cuda", + InfiniDeviceEnum.KUNLUN: "cuda", + InfiniDeviceEnum.SUGON: "cuda", +} diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index 8c855761e..9cc863f79 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -1,5 +1,7 @@ import ctypes from .datatypes import * +from .devices import * +from typing import Sequence from .liboperators import infiniopTensorDescriptor_t, CTensor, infiniopHandle_t @@ -46,12 +48,15 @@ def to_tensor(tensor, lib): # Create Tensor return CTensor(tensor_desc, data_ptr) + def create_workspace(size, torch_device): + print(f" - Workspace Size : {size}") if size == 0: return None import torch return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device) + def create_handle(lib, device, id=0): handle = infiniopHandle_t() check_error(lib.infiniopCreateHandle(ctypes.byref(handle), device, id)) @@ -106,3 +111,276 @@ def rearrange_tensor(tensor, new_strides): new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides)) return new_tensor + + +def rearrange_if_needed(tensor, stride): + """ + Rearrange a PyTorch tensor if the given stride is not None. + """ + return rearrange_tensor(tensor, stride) if stride is not None else tensor + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser(description="Test Operator") + parser.add_argument( + "--profile", + action="store_true", + help="Whether profile tests", + ) + parser.add_argument( + "--num_prerun", + type=lambda x: max(0, int(x)), + default=10, + help="Set the number of pre-runs before profiling. Default is 10. Must be a non-negative integer.", + ) + parser.add_argument( + "--num_iterations", + type=lambda x: max(0, int(x)), + default=1000, + help="Set the number of iterations for profiling. Default is 1000. Must be a non-negative integer.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Whether to turn on debug mode. If turned on, it will display detailed information about the tensors and discrepancies.", + ) + parser.add_argument( + "--cpu", + action="store_true", + help="Run CPU test", + ) + parser.add_argument( + "--nvidia", + action="store_true", + help="Run NVIDIA GPU test", + ) + parser.add_argument( + "--cambricon", + action="store_true", + help="Run Cambricon MLU test", + ) + parser.add_argument( + "--ascend", + action="store_true", + help="Run ASCEND NPU test", + ) + + return parser.parse_args() + + +def synchronize_device(torch_device): + import torch + if torch_device == "cuda": + torch.cuda.synchronize() + elif torch_device == "npu": + torch.npu.synchronize() + elif torch_device == "mlu": + torch.mlu.synchronize() + + +def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): + """ + Debugging function to compare two tensors (actual and desired) and print discrepancies. + Arguments: + ---------- + - actual : The tensor containing the actual computed values. + - desired : The tensor containing the expected values that `actual` should be compared to. + - atol : optional (default=0) + The absolute tolerance for the comparison. + - rtol : optional (default=1e-2) + The relative tolerance for the comparison. + - equal_nan : bool, optional (default=False) + If True, `NaN` values in `actual` and `desired` will be considered equal. + - verbose : bool, optional (default=True) + If True, the function will print detailed information about any discrepancies between the tensors. + """ + import numpy as np + print_discrepancy(actual, desired, atol, rtol, verbose) + np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True) + + +def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True): + """ + Debugging function to compare two sequences of values (actual and desired) pair by pair, results + are linked by the given logical condition, and prints discrepancies + Arguments: + ---------- + - actual_vals (Sequence): A sequence (e.g., list or tuple) of actual computed values. + - desired_vals (Sequence): A sequence (e.g., list or tuple) of desired (expected) values to compare against. + - condition (str): A string specifying the condition for passing the test. It must be either: + - 'or': Test passes if any pair of actual and desired values satisfies the tolerance criteria. + - 'and': Test passes if all pairs of actual and desired values satisfy the tolerance criteria. + - atol (float, optional): Absolute tolerance. Default is 0. + - rtol (float, optional): Relative tolerance. Default is 1e-2. + - equal_nan (bool, optional): If True, NaN values in both actual and desired are considered equal. Default is False. + - verbose (bool, optional): If True, detailed output is printed for each comparison. Default is True. + Raises: + ---------- + - AssertionError: If the condition is not satisfied based on the provided `condition`, `atol`, and `rtol`. + - ValueError: If the length of `actual_vals` and `desired_vals` do not match. + - AssertionError: If the specified `condition` is not 'or' or 'and'. + """ + assert len(actual_vals) == len(desired_vals), "Invalid Length" + assert condition in {"or", "and"}, "Invalid condition: should be either 'or' or 'and'" + import numpy as np + + passed = False if condition == "or" else True + + for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)): + print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}") + indices = print_discrepancy(actual, desired, atol, rtol, verbose) + if condition == "or": + if not passed and len(indices) == 0: + passed = True + elif condition == "and": + if passed and len(indices) != 0: + passed = False + print(f"\033[31mThe condition has not been satisfied: Condition #{index + 1}\033[0m") + np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True) + assert passed, "\033[31mThe condition has not been satisfied\033[0m" + + +def print_discrepancy( + actual, expected, atol=0, rtol=1e-3, verbose=True +): + if actual.shape != expected.shape: + raise ValueError("Tensors must have the same shape to compare.") + + import torch + import sys + + is_terminal = sys.stdout.isatty() + + # Calculate the difference mask based on atol and rtol + diff_mask = torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)) + diff_indices = torch.nonzero(diff_mask, as_tuple=False) + delta = actual - expected + + # Display format: widths for columns + col_width = [18, 20, 20, 20] + decimal_places = [0, 12, 12, 12] + total_width = sum(col_width) + sum(decimal_places) + + def add_color(text, color_code): + if is_terminal: + return f"\033[{color_code}m{text}\033[0m" + else: + return text + + if verbose: + for idx in diff_indices: + index_tuple = tuple(idx.tolist()) + actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}" + expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}" + delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}" + print( + f" > Index: {str(index_tuple):<{col_width[0]}}" + f"actual: {add_color(actual_str, 31)}" + f"expect: {add_color(expected_str, 32)}" + f"delta: {add_color(delta_str, 33)}" + ) + + print(add_color(" INFO:", 35)) + print(f" - Actual dtype: {actual.dtype}") + print(f" - Desired dtype: {expected.dtype}") + print(f" - Atol: {atol}") + print(f" - Rtol: {rtol}") + print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)") + print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}") + print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}") + print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}") + print("-" * total_width + "\n") + + return diff_indices + + +def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3): + """ + Returns the atol and rtol for a given tensor data type in the tolerance_map. + If the given data type is not found, it returns the provided default tolerance values. + """ + return tolerance_map.get(tensor_dtype, {'atol': default_atol, 'rtol': default_rtol}).values() + + +def timed_op(func, num_iterations, device): + import time + """ Function for timing operations with synchronization. """ + synchronize_device(device) + start = time.time() + for _ in range(num_iterations): + func() + synchronize_device(device) + return (time.time() - start) / num_iterations + + +def profile_operation(desc, func, torch_device, NUM_PRERUN, NUM_ITERATIONS): + """ + Unified profiling workflow that is used to profile the execution time of a given function. + It first performs a number of warmup runs, then performs timed execution and + prints the average execution time. + + Arguments: + ---------- + - desc (str): Description of the operation, used for output display. + - func (callable): The operation function to be profiled. + - torch_device (str): The device on which the operation runs, provided for timed execution. + - NUM_PRERUN (int): The number of warmup runs. + - NUM_ITERATIONS (int): The number of timed execution iterations, used to calculate the average execution time. + """ + # Warmup runs + for _ in range(NUM_PRERUN): + func() + + # Timed execution + elapsed = timed_op(lambda: func(), NUM_ITERATIONS, torch_device) + print(f" {desc} time: {elapsed * 1000 :6f} ms") + + +def test_operator(lib, device, test_func, test_cases, tensor_dtypes): + """ + Testing a specified operator on the given device with the given test function, test cases, and tensor data types. + + Arguments: + ---------- + - lib (ctypes.CDLL): The library object containing the operator implementations. + - device (InfiniDeviceEnum): The device on which the operator should be tested. See device.py. + - test_func (function): The test function to be executed for each test case. + - test_cases (list of tuples): A list of test cases, where each test case is a tuple of parameters + to be passed to `test_func`. + - tensor_dtypes (list): A list of tensor data types (e.g., `torch.float32`) to test. + """ + handle = create_handle(lib, device) + try: + for test_case in test_cases: + for tensor_dtype in tensor_dtypes: + test_func(lib, handle, infiniDeviceEnum_str_map[device], *test_case, tensor_dtype) + finally: + destroy_handle(lib, handle) + + +def get_test_devices(args): + """ + Using the given parsed Namespace to determine the devices to be tested. + + Argument: + - args: the parsed Namespace object. + + Return: + - devices_to_test: the devices that will be tested. Default is CPU. + """ + devices_to_test = [] + + if args.cpu: devices_to_test.append(InfiniDeviceEnum.CPU) + if args.nvidia: devices_to_test.append(InfiniDeviceEnum.NVIDIA) + if args.cambricon: + import torch_mlu + devices_to_test.append(InfiniDeviceEnum.CAMBRICON) + if args.ascend: + import torch_npu + devices_to_test.append(InfiniDeviceEnum.ASCEND) + if not devices_to_test: + devices_to_test = [InfiniDeviceEnum.CPU] + + return devices_to_test diff --git a/test/infiniop/matmul.py b/test/infiniop/matmul.py index 6fbe54684..bbe422743 100644 --- a/test/infiniop/matmul.py +++ b/test/infiniop/matmul.py @@ -1,50 +1,62 @@ -from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float +import torch import ctypes -import sys -import os -import time - -sys.path.append("..") - +from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float from libinfiniop import ( - open_lib, - to_tensor, - CTensor, - InfiniDeviceEnum, - infiniopHandle_t, - infiniopTensorDescriptor_t, - create_handle, - destroy_handle, - check_error, - rearrange_tensor, - create_workspace, + infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices, + check_error, rearrange_if_needed, create_workspace, test_operator, get_args, + debug, get_tolerance, profile_operation, ) -from test_utils import get_args, synchronize_device -import torch - +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES = [ + # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), + (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), + (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), + (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None), + (1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None), +] + +# Data types used for testing +_TENSOR_DTYPES = [torch.float16, torch.float32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + torch.float16: {'atol': 0, 'rtol': 1e-2}, + torch.float32: {'atol': 0, 'rtol': 1e-3}, +} + +DEBUG = False PROFILE = False NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# ============================================================================== +# Definitions +# ============================================================================== class MatmulDescriptor(Structure): _fields_ = [("device", c_int32)] infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor) +# PyTorch implementation for matrix multiplication def matmul(_c, beta, _a, _b, alpha): - a = _a.clone() - b = _b.clone() - c = _c.clone() - input_dtype = c.dtype - ans = ( - alpha * torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(input_dtype) - + beta * c - ) - return ans - + a, b, c = _a.clone(), _b.clone(), _c.clone() + result_dtype = c.dtype + fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + return alpha * fp32_result.to(result_dtype) + beta * c +# The argument list should be (lib, handle, torch_device, , dtype) +# The should keep the same order as the one specified in _TEST_CASES def test( lib, handle, @@ -60,26 +72,22 @@ def test( dtype=torch.float16, ): print( - f"Testing Matmul on {torch_device} with a_shape:{a_shape} b_shape:{b_shape} c_shape:{c_shape}" - f" a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}" + f"Testing Matmul on {torch_device} with alpha:{alpha}, beta:{beta}," + f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape}," + f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}" ) + # Initialize tensors a = torch.rand(a_shape, dtype=dtype).to(torch_device) b = torch.rand(b_shape, dtype=dtype).to(torch_device) c = torch.ones(c_shape, dtype=dtype).to(torch_device) + # Compute the PyTorch reference result ans = matmul(c, beta, a, b, alpha) - if a_stride is not None: - a = rearrange_tensor(a, a_stride) - if b_stride is not None: - b = rearrange_tensor(b, b_stride) - if c_stride is not None: - c = rearrange_tensor(c, c_stride) + a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])] + a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]] - a_tensor = to_tensor(a, lib) - b_tensor = to_tensor(b, lib) - c_tensor = to_tensor(c, lib) descriptor = infiniopMatmulDescriptor_t() check_error( lib.infiniopCreateMatmulDescriptor( @@ -92,20 +100,19 @@ def test( ) # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel - a_tensor.descriptor.contents.invalidate() - b_tensor.descriptor.contents.invalidate() - c_tensor.descriptor.contents.invalidate() + for tensor in [a_tensor, b_tensor, c_tensor]: + tensor.descriptor.contents.invalidate() + # Get workspace size and create workspace workspace_size = c_uint64(0) - check_error( - lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size)) - ) + check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))) workspace = create_workspace(workspace_size.value, a.device) - check_error( - lib.infiniopMatmul( - descriptor, - workspace.data_ptr() if workspace is not None else None, + # Execute infiniop matmul operator + def lib_matmul(): + check_error(lib.infiniopMatmul( + descriptor, + workspace.data_ptr() if workspace else None, workspace_size.value, c_tensor.data, a_tensor.data, @@ -113,201 +120,27 @@ def test( alpha, beta, None, - ) - ) + )) + lib_matmul() - assert torch.allclose(c, ans, atol=0, rtol=1e-2) + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(c, ans, atol=atol, rtol=rtol) + assert torch.allclose(c, ans, atol=atol, rtol=rtol) + # Profiling workflow if PROFILE: - for i in range(NUM_PRERUN): - _ = matmul(c, beta, a, b, alpha) - synchronize_device(torch_device) - start_time = time.time() - for i in range(NUM_ITERATIONS): - _ = matmul(c, beta, a, b, alpha) - synchronize_device(torch_device) - elapsed = (time.time() - start_time) / NUM_ITERATIONS - print(f" pytorch time: {elapsed * 1000 :6f} ms") - for i in range(NUM_PRERUN): - check_error( - lib.infiniopMatmul( - descriptor, - workspace.data_ptr() if workspace is not None else None, - workspace_size.value, - c_tensor.data, - a_tensor.data, - b_tensor.data, - None, - ) - ) - synchronize_device(torch_device) - start_time = time.time() - for i in range(NUM_ITERATIONS): - check_error( - lib.infiniopMatmul( - descriptor, - workspace.data_ptr() if workspace is not None else None, - workspace_size.value, - c_tensor.data, - a_tensor.data, - b_tensor.data, - None, - ) - ) - synchronize_device(torch_device) - elapsed = (time.time() - start_time) / NUM_ITERATIONS - print(f" lib time: {elapsed * 1000 :6f} ms") + profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS) check_error(lib.infiniopDestroyMatmulDescriptor(descriptor)) -def test_cpu(lib, test_cases): - device = InfiniDeviceEnum.CPU - handle = create_handle(lib, device) - - for ( - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) in test_cases: - test( - lib, - handle, - "cpu", - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) - - destroy_handle(lib, handle) - - -def test_nvidia(lib, test_cases): - device = InfiniDeviceEnum.NVIDIA - handle = create_handle(lib, device) - - for ( - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) in test_cases: - test( - lib, - handle, - "cuda", - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) - - destroy_handle(lib, handle) - - -def test_cambricon(lib, test_cases): - import torch_mlu - device = InfiniDeviceEnum.CAMBRICON - handle = create_handle(lib, device) - - for ( - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) in test_cases: - test( - lib, - handle, - "mlu", - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) - - destroy_handle(lib, handle) - -def test_ascend(lib, test_cases): - import torch_npu - - device = InfiniDeviceEnum.ASCEND - handle = create_handle(lib, device) - - for ( - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) in test_cases: - test( - lib, - handle, - "npu", - alpha, - beta, - a_shape, - b_shape, - c_shape, - a_stride, - b_stride, - c_stride, - dtype, - ) - - destroy_handle(lib, handle) - +# ============================================================================== +# Main Execution +# ============================================================================== if __name__ == "__main__": - test_cases = [ - # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype - (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float16), - (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float32), - (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float16), - (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float32), - (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float16), - (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float32), - (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float16), - (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float32), - (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None, torch.float16), - (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None, torch.float32), - ] args = get_args() lib = open_lib() @@ -344,16 +177,14 @@ def test_ascend(lib, test_cases): infiniopMatmulDescriptor_t, ] - if args.profile: - PROFILE = True - if args.cpu: - test_cpu(lib, test_cases) - if args.nvidia: - test_nvidia(lib, test_cases) - if args.cambricon: - test_cambricon(lib, test_cases) - if args.ascend: - test_ascend(lib, test_cases) - if not (args.cpu or args.nvidia or args.cambricon or args.ascend): - test_cpu(lib, test_cases) + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/test_utils.py b/test/infiniop/test_utils.py deleted file mode 100644 index a4624cbeb..000000000 --- a/test/infiniop/test_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -def get_args(): - import argparse - - parser = argparse.ArgumentParser(description="Test Operator") - parser.add_argument( - "--profile", - action="store_true", - help="Whether profile tests", - ) - parser.add_argument( - "--cpu", - action="store_true", - help="Run CPU test", - ) - parser.add_argument( - "--nvidia", - action="store_true", - help="Run NVIDIA GPU test", - ) - parser.add_argument( - "--cambricon", - action="store_true", - help="Run Cambricon MLU test", - ) - parser.add_argument( - "--ascend", - action="store_true", - help="Run ASCEND NPU test", - ) - - return parser.parse_args() - - -def synchronize_device(torch_device): - import torch - if torch_device == "cuda": - torch.cuda.synchronize() - elif torch_device == "npu": - torch.npu.synchronize() - elif torch_device == "mlu": - torch.mlu.synchronize()