diff --git a/test/infiniop/libinfiniop/devices.py b/test/infiniop/libinfiniop/devices.py
index af2d11a4f..e5b674e1a 100644
--- a/test/infiniop/libinfiniop/devices.py
+++ b/test/infiniop/libinfiniop/devices.py
@@ -8,3 +8,17 @@ class InfiniDeviceEnum:
ILUVATAR = 6,
KUNLUN = 7,
SUGON = 8,
+
+
+# Lookup table from each InfiniDeviceEnum member to the torch device string
+# used for it. Several members share the "cuda" string.
+infiniDeviceEnum_str_map = dict(
+    (
+        (InfiniDeviceEnum.CPU, "cpu"),
+        (InfiniDeviceEnum.NVIDIA, "cuda"),
+        (InfiniDeviceEnum.CAMBRICON, "mlu"),
+        (InfiniDeviceEnum.ASCEND, "npu"),
+        (InfiniDeviceEnum.METAX, "cuda"),
+        (InfiniDeviceEnum.MOORE, "musa"),
+        (InfiniDeviceEnum.ILUVATAR, "cuda"),
+        (InfiniDeviceEnum.KUNLUN, "cuda"),
+        (InfiniDeviceEnum.SUGON, "cuda"),
+    )
+)
diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py
index 8c855761e..9cc863f79 100644
--- a/test/infiniop/libinfiniop/utils.py
+++ b/test/infiniop/libinfiniop/utils.py
@@ -1,5 +1,7 @@
import ctypes
from .datatypes import *
+from .devices import *
+from typing import Sequence
from .liboperators import infiniopTensorDescriptor_t, CTensor, infiniopHandle_t
@@ -46,12 +48,15 @@ def to_tensor(tensor, lib):
# Create Tensor
return CTensor(tensor_desc, data_ptr)
+
def create_workspace(size, torch_device):
+ print(f" - Workspace Size : {size}")
if size == 0:
return None
import torch
return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)
+
def create_handle(lib, device, id=0):
handle = infiniopHandle_t()
check_error(lib.infiniopCreateHandle(ctypes.byref(handle), device, id))
@@ -106,3 +111,276 @@ def rearrange_tensor(tensor, new_strides):
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
+
+
+def rearrange_if_needed(tensor, stride):
+    """
+    Return `tensor` unchanged when `stride` is None; otherwise return the
+    result of `rearrange_tensor(tensor, stride)`.
+    """
+    if stride is None:
+        return tensor
+    return rearrange_tensor(tensor, stride)
+
+
+def get_args():
+    """
+    Build the command-line parser used by the operator tests and parse sys.argv.
+
+    Recognised options:
+    - --profile, --debug: boolean switches.
+    - --num_prerun (default 10), --num_iterations (default 1000): integers;
+      negative values are clamped to 0 rather than rejected.
+    - --cpu, --nvidia, --cambricon, --ascend: boolean device switches.
+
+    Return:
+    - the argparse.Namespace produced by parse_args().
+    """
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Test Operator")
+
+    def add_switch(option, help_text):
+        parser.add_argument(option, action="store_true", help=help_text)
+
+    def add_counter(option, default, help_text):
+        # int() parses the text, max() turns a negative value into 0.
+        parser.add_argument(
+            option,
+            type=lambda x: max(0, int(x)),
+            default=default,
+            help=help_text,
+        )
+
+    add_switch("--profile", "Whether profile tests")
+    add_counter(
+        "--num_prerun",
+        10,
+        "Set the number of pre-runs before profiling. Default is 10. Must be a non-negative integer.",
+    )
+    add_counter(
+        "--num_iterations",
+        1000,
+        "Set the number of iterations for profiling. Default is 1000. Must be a non-negative integer.",
+    )
+    add_switch(
+        "--debug",
+        "Whether to turn on debug mode. If turned on, it will display detailed information about the tensors and discrepancies.",
+    )
+    add_switch("--cpu", "Run CPU test")
+    add_switch("--nvidia", "Run NVIDIA GPU test")
+    add_switch("--cambricon", "Run Cambricon MLU test")
+    add_switch("--ascend", "Run ASCEND NPU test")
+
+    return parser.parse_args()
+
+
+def synchronize_device(torch_device):
+    """
+    Call `synchronize()` on the torch backend module named by `torch_device`.
+
+    Only the strings "cuda", "npu" and "mlu" are handled; any other value
+    (for example "cpu") is a no-op.
+    """
+    import torch
+
+    for backend in ("cuda", "npu", "mlu"):
+        if torch_device == backend:
+            # The handled device string is also the attribute name of its torch backend module.
+            getattr(torch, backend).synchronize()
+            break
+
+
+def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
+    """
+    Print the discrepancies between two tensors, then assert that they are close.
+
+    Arguments:
+    ----------
+    - actual : tensor holding the computed values.
+    - desired : tensor holding the reference values `actual` is compared to.
+    - atol : absolute tolerance (default 0).
+    - rtol : relative tolerance (default 1e-2).
+    - equal_nan : forwarded to np.testing.assert_allclose (default False).
+    - verbose : forwarded to print_discrepancy (default True); the final
+      assert_allclose call always runs with verbose=True.
+
+    Raises:
+    ----------
+    - AssertionError: raised by np.testing.assert_allclose when the check fails.
+    """
+    import numpy as np
+
+    print_discrepancy(actual, desired, atol, rtol, verbose)
+
+    actual_cpu, desired_cpu = actual.cpu(), desired.cpu()
+    np.testing.assert_allclose(
+        actual_cpu,
+        desired_cpu,
+        rtol=rtol,
+        atol=atol,
+        equal_nan=equal_nan,
+        verbose=True,
+        strict=True,
+    )
+
+
+def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
+ """
+ Debugging function to compare two sequences of values (actual and desired) pair by pair, results
+ are linked by the given logical condition, and prints discrepancies.
+ Each pair is passed to `print_discrepancy`; a pair counts as satisfied when no mismatching index is returned.
+ Arguments:
+ ----------
+ - actual_vals (Sequence): A sequence (e.g., list or tuple) of actual computed values.
+ - desired_vals (Sequence): A sequence (e.g., list or tuple) of desired (expected) values to compare against.
+ - condition (str): A string specifying the condition for passing the test. It must be either:
+ - 'or': Test passes if any pair of actual and desired values satisfies the tolerance criteria.
+ - 'and': Test passes if all pairs of actual and desired values satisfy the tolerance criteria.
+ - atol (float, optional): Absolute tolerance. Default is 0.
+ - rtol (float, optional): Relative tolerance. Default is 1e-2.
+ - equal_nan (bool, optional): Forwarded to `np.testing.assert_allclose`. Default is False.
+ - verbose (bool, optional): Forwarded to `print_discrepancy`. Default is True.
+ Raises:
+ ----------
+ - AssertionError: If the condition is not satisfied based on the provided `condition`, `atol`, and `rtol`.
+ - AssertionError: If the length of `actual_vals` and `desired_vals` do not match (checked with `assert`, not `ValueError`).
+ - AssertionError: If the specified `condition` is not 'or' or 'and'.
+ NOTE(review): both argument checks use `assert`, so they are skipped when Python runs with -O.
+ """
+ assert len(actual_vals) == len(desired_vals), "Invalid Length"
+ assert condition in {"or", "and"}, "Invalid condition: should be either 'or' or 'and'"
+ import numpy as np
+
+ # 'or' starts as failed and needs one satisfied pair; 'and' starts as passed and needs no failing pair
+ passed = False if condition == "or" else True
+
+ for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)):
+ print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}")
+ indices = print_discrepancy(actual, desired, atol, rtol, verbose)
+ if condition == "or":
+ if not passed and len(indices) == 0:
+ passed = True
+ elif condition == "and":
+ if passed and len(indices) != 0:
+ passed = False
+ print(f"\033[31mThe condition has not been satisfied: Condition #{index + 1}\033[0m")
+ # NOTE(review): confirm the intended nesting of this call (only on the first failing pair, or on every pair in 'and' mode)
+ np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True)
+ assert passed, "\033[31mThe condition has not been satisfied\033[0m"
+
+
+def print_discrepancy(
+ actual, expected, atol=0, rtol=1e-3, verbose=True
+):
+ """
+ Print the elements of `actual` that differ from `expected` beyond the given tolerances,
+ followed by a summary, and return the indices of the mismatching elements.
+ Arguments:
+ ----------
+ - actual : tensor holding the computed values.
+ - expected : tensor holding the reference values; must have the same shape as `actual`.
+ - atol : absolute tolerance. Default is 0.
+ - rtol : relative tolerance. Default is 1e-3.
+ - verbose : when True, one line is printed per mismatching element. Default is True.
+ Returns:
+ ----------
+ - The result of `torch.nonzero(..., as_tuple=False)` on the mismatch mask (one row per mismatching element).
+ Raises:
+ ----------
+ - ValueError: If `actual` and `expected` have different shapes.
+ NOTE(review): confirm whether the INFO summary below is meant to be printed only when `verbose` is True.
+ NOTE(review): for an empty tensor `actual.numel()` is 0, so the percentage division below would fail — confirm callers never pass one.
+ """
+ if actual.shape != expected.shape:
+ raise ValueError("Tensors must have the same shape to compare.")
+
+ import torch
+ import sys
+
+ # ANSI colours are only emitted when stdout is a terminal
+ is_terminal = sys.stdout.isatty()
+
+ # Calculate the difference mask based on atol and rtol
+ # An element mismatches when |actual - expected| > atol + rtol * |expected|
+ diff_mask = torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
+ diff_indices = torch.nonzero(diff_mask, as_tuple=False)
+ delta = actual - expected
+
+ # Display format: widths for columns
+ # Columns are: index, actual, expected, delta
+ col_width = [18, 20, 20, 20]
+ decimal_places = [0, 12, 12, 12]
+ # Only used as the length of the closing separator line
+ total_width = sum(col_width) + sum(decimal_places)
+
+ def add_color(text, color_code):
+ # Wrap `text` in the given ANSI colour code, or return it unchanged when not on a terminal
+ if is_terminal:
+ return f"\033[{color_code}m{text}\033[0m"
+ else:
+ return text
+
+ if verbose:
+ for idx in diff_indices:
+ index_tuple = tuple(idx.tolist())
+ actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
+ expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
+ delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
+ print(
+ f" > Index: {str(index_tuple):<{col_width[0]}}"
+ f"actual: {add_color(actual_str, 31)}"
+ f"expect: {add_color(expected_str, 32)}"
+ f"delta: {add_color(delta_str, 33)}"
+ )
+
+ # Summary: dtypes, tolerances, mismatch count and value ranges
+ print(add_color(" INFO:", 35))
+ print(f" - Actual dtype: {actual.dtype}")
+ print(f" - Desired dtype: {expected.dtype}")
+ print(f" - Atol: {atol}")
+ print(f" - Rtol: {rtol}")
+ print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)")
+ print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}")
+ print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}")
+ print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}")
+ print("-" * total_width + "\n")
+
+ return diff_indices
+
+
+def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
+    """
+    Return the `(atol, rtol)` pair registered for `tensor_dtype` in `tolerance_map`.
+
+    Arguments:
+    ----------
+    - tolerance_map (dict): maps a data type to a dict with the keys 'atol' and 'rtol'.
+    - tensor_dtype: the data type to look up.
+    - default_atol: value used when the data type, or its 'atol' key, is missing. Default is 0.
+    - default_rtol: value used when the data type, or its 'rtol' key, is missing. Default is 1e-3.
+
+    Returns:
+    ----------
+    - tuple: `(atol, rtol)`, always in this order.
+    """
+    entry = tolerance_map.get(tensor_dtype, {})
+    # Look the keys up by name: relying on `.values()` would make the result
+    # depend on the order in which 'atol' and 'rtol' were written in the map.
+    return entry.get('atol', default_atol), entry.get('rtol', default_rtol)
+
+
+def timed_op(func, num_iterations, device):
+    """
+    Return the average time, in seconds, of one call to `func`.
+
+    Arguments:
+    ----------
+    - func (callable): the operation to time; called with no arguments.
+    - num_iterations (int): how many times `func` is called.
+    - device (str): passed to `synchronize_device` before the first call and
+      after the last one, so the measured interval covers the whole loop.
+
+    Returns:
+    ----------
+    - float: elapsed time divided by `num_iterations`, or 0.0 when
+      `num_iterations` is not positive (nothing is run in that case).
+    """
+    import time
+
+    if num_iterations <= 0:
+        # get_args clamps --num_iterations to >= 0, so 0 can reach this point;
+        # return early instead of dividing by zero.
+        return 0.0
+
+    synchronize_device(device)
+    # perf_counter is the clock intended for measuring short intervals.
+    start = time.perf_counter()
+    for _ in range(num_iterations):
+        func()
+    synchronize_device(device)
+    return (time.perf_counter() - start) / num_iterations
+
+
+def profile_operation(desc, func, torch_device, NUM_PRERUN, NUM_ITERATIONS):
+    """
+    Profile `func`: run it a number of times as warm-up, then time it with
+    `timed_op` and print the average execution time in milliseconds.
+
+    Arguments:
+    ----------
+    - desc (str): Label of the operation, shown in the printed line.
+    - func (callable): The operation to profile; called with no arguments.
+    - torch_device (str): The device the operation runs on, forwarded to `timed_op`.
+    - NUM_PRERUN (int): Number of warm-up calls made before timing.
+    - NUM_ITERATIONS (int): Number of timed calls the average is taken over.
+    """
+    # Warm-up calls are not timed
+    for _ in range(NUM_PRERUN):
+        func()
+
+    # Timed calls; timed_op returns seconds per call
+    seconds_per_call = timed_op(func, NUM_ITERATIONS, torch_device)
+    print(f" {desc} time: {seconds_per_call * 1000 :6f} ms")
+
+
+def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
+    """
+    Run `test_func` for every combination of test case and tensor data type on one device.
+
+    A handle is created for the device first and is destroyed afterwards, even
+    when a test raises.
+
+    Arguments:
+    ----------
+    - lib (ctypes.CDLL): The library object containing the operator implementations.
+    - device (InfiniDeviceEnum): The device to test on; it is also the key used to look up
+                the torch device string in `infiniDeviceEnum_str_map`. See devices.py.
+    - test_func (function): Called as `test_func(lib, handle, torch_device, *test_case, tensor_dtype)`.
+    - test_cases (list of tuples): Each tuple holds the parameters of one test case.
+    - tensor_dtypes (list): The tensor data types (e.g., `torch.float32`) to test.
+    """
+    handle = create_handle(lib, device)
+    try:
+        runs = ((case, dtype) for case in test_cases for dtype in tensor_dtypes)
+        for case, dtype in runs:
+            test_func(lib, handle, infiniDeviceEnum_str_map[device], *case, dtype)
+    finally:
+        destroy_handle(lib, handle)
+
+
+def get_test_devices(args):
+    """
+    Translate the parsed device switches into the list of devices to test.
+
+    Argument:
+    - args: the parsed Namespace object; its `cpu`, `nvidia`, `cambricon` and
+      `ascend` attributes are read.
+
+    Return:
+    - the selected InfiniDeviceEnum members, in the order CPU, NVIDIA,
+      CAMBRICON, ASCEND; `[InfiniDeviceEnum.CPU]` when no switch is set.
+    """
+    selected = []
+
+    if args.cpu:
+        selected.append(InfiniDeviceEnum.CPU)
+    if args.nvidia:
+        selected.append(InfiniDeviceEnum.NVIDIA)
+    if args.cambricon:
+        # Imported only when requested; presumably registers the "mlu" backend with torch -- TODO confirm
+        import torch_mlu
+
+        selected.append(InfiniDeviceEnum.CAMBRICON)
+    if args.ascend:
+        # Imported only when requested; presumably registers the "npu" backend with torch -- TODO confirm
+        import torch_npu
+
+        selected.append(InfiniDeviceEnum.ASCEND)
+
+    return selected or [InfiniDeviceEnum.CPU]
diff --git a/test/infiniop/matmul.py b/test/infiniop/matmul.py
index 6fbe54684..bbe422743 100644
--- a/test/infiniop/matmul.py
+++ b/test/infiniop/matmul.py
@@ -1,50 +1,62 @@
-from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
+import torch
import ctypes
-import sys
-import os
-import time
-
-sys.path.append("..")
-
+from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
- open_lib,
- to_tensor,
- CTensor,
- InfiniDeviceEnum,
- infiniopHandle_t,
- infiniopTensorDescriptor_t,
- create_handle,
- destroy_handle,
- check_error,
- rearrange_tensor,
- create_workspace,
+ infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices,
+ check_error, rearrange_if_needed, create_workspace, test_operator, get_args,
+ debug, get_tolerance, profile_operation,
)
-from test_utils import get_args, synchronize_device
-import torch
-
+# ==============================================================================
+# Configuration (Internal Use Only)
+# ==============================================================================
+# These are not meant to be imported from other modules
+_TEST_CASES = [
+    # Each row is one matmul configuration. Every row is run once per data type
+    # in _TENSOR_DTYPES, so a configuration is listed only once here.
+    # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
+    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
+    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
+    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
+    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
+    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
+]
+
+# Data types used for testing
+_TENSOR_DTYPES = [torch.float16, torch.float32]
+
+# Tolerance map for different data types
+_TOLERANCE_MAP = {
+    torch.float16: {'atol': 0, 'rtol': 1e-2},
+    torch.float32: {'atol': 0, 'rtol': 1e-3},
+}
+
+# Overwritten from the --debug command-line flag when this file is run as a script
+DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
+# ==============================================================================
+# Definitions
+# ==============================================================================
class MatmulDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
+# PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha):
- a = _a.clone()
- b = _b.clone()
- c = _c.clone()
- input_dtype = c.dtype
- ans = (
- alpha * torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(input_dtype)
- + beta * c
- )
- return ans
-
+ a, b, c = _a.clone(), _b.clone(), _c.clone()
+ result_dtype = c.dtype
+ fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
+ return alpha * fp32_result.to(result_dtype) + beta * c
+# The argument list should be (lib, handle, torch_device, *test_case_params, dtype)
+# The test-case parameters should keep the same order as the one specified in _TEST_CASES
def test(
lib,
handle,
@@ -60,26 +72,22 @@ def test(
dtype=torch.float16,
):
print(
- f"Testing Matmul on {torch_device} with a_shape:{a_shape} b_shape:{b_shape} c_shape:{c_shape}"
- f" a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
+ f"Testing Matmul on {torch_device} with alpha:{alpha}, beta:{beta},"
+ f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
+ f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}"
)
+ # Initialize tensors
a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device)
+ # Compute the PyTorch reference result
ans = matmul(c, beta, a, b, alpha)
- if a_stride is not None:
- a = rearrange_tensor(a, a_stride)
- if b_stride is not None:
- b = rearrange_tensor(b, b_stride)
- if c_stride is not None:
- c = rearrange_tensor(c, c_stride)
+ a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])]
+ a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
- a_tensor = to_tensor(a, lib)
- b_tensor = to_tensor(b, lib)
- c_tensor = to_tensor(c, lib)
descriptor = infiniopMatmulDescriptor_t()
check_error(
lib.infiniopCreateMatmulDescriptor(
@@ -92,20 +100,19 @@ def test(
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
- a_tensor.descriptor.contents.invalidate()
- b_tensor.descriptor.contents.invalidate()
- c_tensor.descriptor.contents.invalidate()
+ for tensor in [a_tensor, b_tensor, c_tensor]:
+ tensor.descriptor.contents.invalidate()
+ # Get workspace size and create workspace
workspace_size = c_uint64(0)
- check_error(
- lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
- )
+ check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size)))
workspace = create_workspace(workspace_size.value, a.device)
- check_error(
- lib.infiniopMatmul(
- descriptor,
- workspace.data_ptr() if workspace is not None else None,
+ # Execute infiniop matmul operator
+    def lib_matmul():
+        """Call infiniopMatmul once and pass the returned status to check_error."""
+        check_error(lib.infiniopMatmul(
+            descriptor,
+            # Compare with None explicitly: truth-testing a tensor with more than
+            # one element raises, and a single zero byte would count as False.
+            workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
@@ -113,201 +120,27 @@ def test(
alpha,
beta,
None,
- )
- )
+ ))
+ lib_matmul()
- assert torch.allclose(c, ans, atol=0, rtol=1e-2)
+ # Validate results
+ atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+ if DEBUG:
+ debug(c, ans, atol=atol, rtol=rtol)
+ assert torch.allclose(c, ans, atol=atol, rtol=rtol)
+ # Profiling workflow
if PROFILE:
- for i in range(NUM_PRERUN):
- _ = matmul(c, beta, a, b, alpha)
- synchronize_device(torch_device)
- start_time = time.time()
- for i in range(NUM_ITERATIONS):
- _ = matmul(c, beta, a, b, alpha)
- synchronize_device(torch_device)
- elapsed = (time.time() - start_time) / NUM_ITERATIONS
- print(f" pytorch time: {elapsed * 1000 :6f} ms")
- for i in range(NUM_PRERUN):
- check_error(
- lib.infiniopMatmul(
- descriptor,
- workspace.data_ptr() if workspace is not None else None,
- workspace_size.value,
- c_tensor.data,
- a_tensor.data,
- b_tensor.data,
- None,
- )
- )
- synchronize_device(torch_device)
- start_time = time.time()
- for i in range(NUM_ITERATIONS):
- check_error(
- lib.infiniopMatmul(
- descriptor,
- workspace.data_ptr() if workspace is not None else None,
- workspace_size.value,
- c_tensor.data,
- a_tensor.data,
- b_tensor.data,
- None,
- )
- )
- synchronize_device(torch_device)
- elapsed = (time.time() - start_time) / NUM_ITERATIONS
- print(f" lib time: {elapsed * 1000 :6f} ms")
+ profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+ profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))
-def test_cpu(lib, test_cases):
- device = InfiniDeviceEnum.CPU
- handle = create_handle(lib, device)
-
- for (
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- ) in test_cases:
- test(
- lib,
- handle,
- "cpu",
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- )
-
- destroy_handle(lib, handle)
-
-
-def test_nvidia(lib, test_cases):
- device = InfiniDeviceEnum.NVIDIA
- handle = create_handle(lib, device)
-
- for (
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- ) in test_cases:
- test(
- lib,
- handle,
- "cuda",
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- )
-
- destroy_handle(lib, handle)
-
-
-def test_cambricon(lib, test_cases):
- import torch_mlu
- device = InfiniDeviceEnum.CAMBRICON
- handle = create_handle(lib, device)
-
- for (
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- ) in test_cases:
- test(
- lib,
- handle,
- "mlu",
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- )
-
- destroy_handle(lib, handle)
-
-def test_ascend(lib, test_cases):
- import torch_npu
-
- device = InfiniDeviceEnum.ASCEND
- handle = create_handle(lib, device)
-
- for (
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- ) in test_cases:
- test(
- lib,
- handle,
- "npu",
- alpha,
- beta,
- a_shape,
- b_shape,
- c_shape,
- a_stride,
- b_stride,
- c_stride,
- dtype,
- )
-
- destroy_handle(lib, handle)
-
+# ==============================================================================
+# Main Execution
+# ==============================================================================
if __name__ == "__main__":
- test_cases = [
- # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
- (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float16),
- (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float32),
- (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float16),
- (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float32),
- (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float16),
- (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float32),
- (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float16),
- (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float32),
- (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None, torch.float16),
- (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None, torch.float32),
- ]
args = get_args()
lib = open_lib()
@@ -344,16 +177,14 @@ def test_ascend(lib, test_cases):
infiniopMatmulDescriptor_t,
]
- if args.profile:
- PROFILE = True
- if args.cpu:
- test_cpu(lib, test_cases)
- if args.nvidia:
- test_nvidia(lib, test_cases)
- if args.cambricon:
- test_cambricon(lib, test_cases)
- if args.ascend:
- test_ascend(lib, test_cases)
- if not (args.cpu or args.nvidia or args.cambricon or args.ascend):
- test_cpu(lib, test_cases)
+ # Configure testing options
+ DEBUG = args.debug
+ PROFILE = args.profile
+ NUM_PRERUN = args.num_prerun
+ NUM_ITERATIONS = args.num_iterations
+
+ # Execute tests
+ for device in get_test_devices(args):
+ test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
+
print("\033[92mTest passed!\033[0m")
diff --git a/test/infiniop/test_utils.py b/test/infiniop/test_utils.py
deleted file mode 100644
index a4624cbeb..000000000
--- a/test/infiniop/test_utils.py
+++ /dev/null
@@ -1,41 +0,0 @@
-def get_args():
- import argparse
-
- parser = argparse.ArgumentParser(description="Test Operator")
- parser.add_argument(
- "--profile",
- action="store_true",
- help="Whether profile tests",
- )
- parser.add_argument(
- "--cpu",
- action="store_true",
- help="Run CPU test",
- )
- parser.add_argument(
- "--nvidia",
- action="store_true",
- help="Run NVIDIA GPU test",
- )
- parser.add_argument(
- "--cambricon",
- action="store_true",
- help="Run Cambricon MLU test",
- )
- parser.add_argument(
- "--ascend",
- action="store_true",
- help="Run ASCEND NPU test",
- )
-
- return parser.parse_args()
-
-
-def synchronize_device(torch_device):
- import torch
- if torch_device == "cuda":
- torch.cuda.synchronize()
- elif torch_device == "npu":
- torch.npu.synchronize()
- elif torch_device == "mlu":
- torch.mlu.synchronize()