Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions test/infiniop/libinfiniop/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,17 @@ class InfiniDeviceEnum:
ILUVATAR = 6,
KUNLUN = 7,
SUGON = 8,


# Lookup table: InfiniDeviceEnum member -> torch device string to use for
# tensors created on that device.
# NOTE(review): several accelerators share the "cuda" string, presumably
# because their PyTorch builds expose a CUDA-compatible device type — confirm.
infiniDeviceEnum_str_map = {
    InfiniDeviceEnum.CPU: "cpu",
    InfiniDeviceEnum.NVIDIA: "cuda",
    InfiniDeviceEnum.CAMBRICON: "mlu",
    InfiniDeviceEnum.ASCEND: "npu",
    InfiniDeviceEnum.METAX: "cuda",
    InfiniDeviceEnum.MOORE: "musa",
    InfiniDeviceEnum.ILUVATAR: "cuda",
    InfiniDeviceEnum.KUNLUN: "cuda",
    InfiniDeviceEnum.SUGON: "cuda",
}
278 changes: 278 additions & 0 deletions test/infiniop/libinfiniop/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ctypes
from .datatypes import *
from .devices import *
from typing import Sequence
from .liboperators import infiniopTensorDescriptor_t, CTensor, infiniopHandle_t


Expand Down Expand Up @@ -46,12 +48,15 @@ def to_tensor(tensor, lib):
# Create Tensor
return CTensor(tensor_desc, data_ptr)


def create_workspace(size, torch_device):
    """
    Allocate a zero-filled byte buffer to serve as an operator workspace.

    Arguments:
    ----------
    - size : Number of bytes requested. The value is always printed.
    - torch_device : Torch device on which the buffer is allocated.

    Returns:
    ----------
    - None when `size` equals 0, otherwise a 1-D `torch.uint8` tensor with
      `size` elements, all zero.
    """
    print(f" - Workspace Size : {size}")
    if not size == 0:
        # torch is imported lazily so the zero-size path never needs it.
        import torch

        workspace_shape = (size,)
        return torch.zeros(size=workspace_shape, dtype=torch.uint8, device=torch_device)
    return None


def create_handle(lib, device, id=0):
handle = infiniopHandle_t()
check_error(lib.infiniopCreateHandle(ctypes.byref(handle), device, id))
Expand Down Expand Up @@ -106,3 +111,276 @@ def rearrange_tensor(tensor, new_strides):
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))

return new_tensor


def rearrange_if_needed(tensor, stride):
    """
    Return `tensor` unchanged when `stride` is None; otherwise return the
    result of `rearrange_tensor(tensor, stride)`.
    """
    if stride is None:
        return tensor
    return rearrange_tensor(tensor, stride)


def get_args():
    """
    Build the command-line parser for the operator tests and parse `sys.argv`.

    Options:
    ----------
    - --profile, --debug : boolean switches.
    - --num_prerun, --num_iterations : integers; negative values are clamped to 0.
    - --cpu, --nvidia, --cambricon, --ascend : boolean switches selecting devices.

    Returns:
    ----------
    - The `argparse.Namespace` produced by `parser.parse_args()`.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test Operator")

    def add_switch(flag, help_text):
        # Boolean option: False unless the flag is present.
        parser.add_argument(flag, action="store_true", help=help_text)

    def add_count(flag, default, help_text):
        # Integer option; values below zero are clamped to zero.
        parser.add_argument(
            flag,
            type=lambda x: max(0, int(x)),
            default=default,
            help=help_text,
        )

    # Registration order is kept as-is because it determines the --help layout.
    add_switch("--profile", "Whether profile tests")
    add_count(
        "--num_prerun",
        10,
        "Set the number of pre-runs before profiling. Default is 10. Must be a non-negative integer.",
    )
    add_count(
        "--num_iterations",
        1000,
        "Set the number of iterations for profiling. Default is 1000. Must be a non-negative integer.",
    )
    add_switch(
        "--debug",
        "Whether to turn on debug mode. If turned on, it will display detailed information about the tensors and discrepancies.",
    )
    add_switch("--cpu", "Run CPU test")
    add_switch("--nvidia", "Run NVIDIA GPU test")
    add_switch("--cambricon", "Run Cambricon MLU test")
    add_switch("--ascend", "Run ASCEND NPU test")

    return parser.parse_args()


def synchronize_device(torch_device):
    """
    Call the torch synchronize function for the given device string.

    Only the exact strings "cuda", "npu" and "mlu" trigger a call
    (`torch.<name>.synchronize()`); any other value is a no-op.
    """
    import torch

    for backend in ("cuda", "npu", "mlu"):
        if torch_device == backend:
            # Resolve torch.<backend> lazily: the npu/mlu namespaces are only
            # looked up when that device is actually requested.
            getattr(torch, backend).synchronize()
            return


def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
    """
    Report the discrepancies between two tensors, then assert that they are close.

    Arguments:
    ----------
    - actual : The tensor containing the actual computed values.
    - desired : The tensor containing the expected values `actual` is compared to.
    - atol : optional (default=0)
        Absolute tolerance of the comparison.
    - rtol : optional (default=1e-2)
        Relative tolerance of the comparison.
    - equal_nan : bool, optional (default=False)
        Forwarded to `numpy.testing.assert_allclose`; it is not used by the
        discrepancy report.
    - verbose : bool, optional (default=True)
        Forwarded to `print_discrepancy`. The final `assert_allclose` call is
        always made with verbose=True.

    Raises:
    ----------
    - AssertionError: raised by `numpy.testing.assert_allclose` when the tensors
      are not close (called with strict=True).
    """
    import numpy as np

    # Report first so the details are printed even if the assertion below fails.
    print_discrepancy(actual, desired, atol, rtol, verbose)

    host_actual = actual.cpu()
    host_desired = desired.cpu()
    np.testing.assert_allclose(
        host_actual,
        host_desired,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        verbose=True,
        strict=True,
    )


def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
    """
    Compare two sequences of tensors pair by pair and combine the per-pair
    results with a logical condition, printing the discrepancies of every pair.

    Arguments:
    ----------
    - actual_vals (Sequence): Actual computed values.
    - desired_vals (Sequence): Expected values, paired positionally with `actual_vals`.
    - condition (str): How the per-pair results are combined:
        - 'or': passes if at least one pair has no mismatched element.
        - 'and': passes only if every pair has no mismatched element.
    - atol (float, optional): Absolute tolerance. Default is 0.
    - rtol (float, optional): Relative tolerance. Default is 1e-2.
    - equal_nan (bool, optional): Forwarded to `numpy.testing.assert_allclose`
      for the first failing pair in 'and' mode. Default is False.
    - verbose (bool, optional): Forwarded to `print_discrepancy`. Default is True.

    Raises:
    ----------
    - AssertionError: If the two sequences differ in length.
    - AssertionError: If `condition` is neither 'or' nor 'and'.
    - AssertionError: If the combined condition is not satisfied. In 'and' mode
      this is raised by `assert_allclose` on the first failing pair.
    """
    assert len(actual_vals) == len(desired_vals), "Invalid Length"
    assert condition in {"or", "and"}, "Invalid condition: should be either 'or' or 'and'"
    import numpy as np

    any_mode = condition == "or"
    all_mode = condition == "and"
    # 'or' starts unsatisfied and needs one clean pair; otherwise start satisfied.
    passed = not any_mode

    for ordinal, (actual, desired) in enumerate(zip(actual_vals, desired_vals), start=1):
        print(f" \033[36mCondition #{ordinal}:\033[0m {actual} == {desired}")
        mismatches = print_discrepancy(actual, desired, atol, rtol, verbose)
        pair_matches = len(mismatches) == 0

        if any_mode:
            passed = passed or pair_matches
        elif all_mode and passed and not pair_matches:
            passed = False
            print(f"\033[31mThe condition has not been satisfied: Condition #{ordinal}\033[0m")
            # Raises with numpy's detailed mismatch report for the failing pair.
            np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True)

    assert passed, "\033[31mThe condition has not been satisfied\033[0m"


def print_discrepancy(
    actual, expected, atol=0, rtol=1e-3, verbose=True
):
    """
    Locate, and optionally print, the elements of `actual` that differ from `expected`.

    An element counts as a mismatch when
    `|actual - expected| > atol + rtol * |expected|`, or when exactly one of
    the two values is NaN. A NaN present at the same position in both tensors
    is not reported.

    Arguments:
    ----------
    - actual : Tensor holding the computed values.
    - expected : Tensor holding the reference values; must have the same shape.
    - atol : optional (default=0)
        Absolute tolerance.
    - rtol : optional (default=1e-3)
        Relative tolerance.
    - verbose : bool, optional (default=True)
        If True, print one line per mismatched element followed by a summary.
        ANSI colors are used only when stdout is a terminal.

    Returns:
    ----------
    - The indices of the mismatched elements, as returned by
      `torch.nonzero(mask, as_tuple=False)` (one row per mismatch).

    Raises:
    ----------
    - ValueError: If the two tensors do not have the same shape.
    """
    if actual.shape != expected.shape:
        raise ValueError("Tensors must have the same shape to compare.")

    import torch
    import sys

    is_terminal = sys.stdout.isatty()

    # Compute the difference once; it feeds both the mask and the report.
    delta = actual - expected
    diff_mask = torch.abs(delta) > (atol + rtol * torch.abs(expected))
    # Comparisons against NaN are always False, so a NaN on only one side
    # would otherwise slip through as a "match".
    diff_mask = diff_mask | torch.logical_xor(torch.isnan(actual), torch.isnan(expected))
    diff_indices = torch.nonzero(diff_mask, as_tuple=False)

    # Display format: widths for columns
    col_width = [18, 20, 20, 20]
    decimal_places = [0, 12, 12, 12]
    total_width = sum(col_width) + sum(decimal_places)

    def add_color(text, color_code):
        if is_terminal:
            return f"\033[{color_code}m{text}\033[0m"
        else:
            return text

    if verbose:
        for idx in diff_indices:
            index_tuple = tuple(idx.tolist())
            actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
            expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
            delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
            print(
                f" > Index: {str(index_tuple):<{col_width[0]}}"
                f"actual: {add_color(actual_str, 31)}"
                f"expect: {add_color(expected_str, 32)}"
                f"delta: {add_color(delta_str, 33)}"
            )

        total = actual.numel()
        mismatched = len(diff_indices)
        # Guard against empty tensors: no division by zero, no min/max of nothing.
        percentage = mismatched / total * 100 if total else 0.0

        print(add_color(" INFO:", 35))
        print(f" - Actual dtype: {actual.dtype}")
        print(f" - Desired dtype: {expected.dtype}")
        print(f" - Atol: {atol}")
        print(f" - Rtol: {rtol}")
        print(f" - Mismatched elements: {mismatched} / {total} ({percentage}%)")
        if total:
            print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}")
            print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}")
            print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}")
        print("-" * total_width + "\n")

    return diff_indices


def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
    """
    Return the `(atol, rtol)` pair registered for `tensor_dtype` in `tolerance_map`.

    Arguments:
    ----------
    - tolerance_map (dict): Maps a data type to a dict with 'atol' and 'rtol' keys.
    - tensor_dtype : The data type to look up.
    - default_atol : Value used when no 'atol' is available. Default is 0.
    - default_rtol : Value used when no 'rtol' is available. Default is 1e-3.

    Returns:
    ----------
    - A tuple `(atol, rtol)`. The values are read by key, so the result does
      not depend on the order in which an entry lists its keys. A key missing
      from an entry, or a data type missing from the map, falls back to the
      corresponding default.
    """
    entry = tolerance_map.get(tensor_dtype)
    if entry is None:
        return default_atol, default_rtol
    return entry.get('atol', default_atol), entry.get('rtol', default_rtol)


def timed_op(func, num_iterations, device):
    """
    Return the average wall-clock time, in seconds, of one call to `func`.

    The device is synchronized (via `synchronize_device`) before the clock
    starts and again before it stops.

    Arguments:
    ----------
    - func (callable): Called with no arguments, `num_iterations` times.
    - num_iterations (int): Number of timed calls.
    - device : Device string forwarded to `synchronize_device`.

    Returns:
    ----------
    - Average seconds per call, or 0.0 when `num_iterations` is not positive
      (nothing is run and no division by zero occurs).
    """
    import time

    if num_iterations <= 0:
        return 0.0

    synchronize_device(device)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(num_iterations):
        func()
    synchronize_device(device)
    return (time.perf_counter() - start) / num_iterations


def profile_operation(desc, func, torch_device, NUM_PRERUN, NUM_ITERATIONS):
    """
    Unified profiling workflow that is used to profile the execution time of a given function.
    It first performs a number of warmup runs, then performs timed execution and
    prints the average execution time in milliseconds.

    Arguments:
    ----------
    - desc (str): Description of the operation, used for output display.
    - func (callable): The operation function to be profiled; called with no arguments.
    - torch_device (str): The device on which the operation runs, forwarded to `timed_op`.
    - NUM_PRERUN (int): The number of warmup runs.
    - NUM_ITERATIONS (int): The number of timed execution iterations, used to calculate the average execution time.
    """
    # Warmup runs
    for _ in range(NUM_PRERUN):
        func()

    # Timed execution. `func` is passed directly: wrapping it in a lambda
    # would add an extra Python call to every measured iteration.
    elapsed = timed_op(func, NUM_ITERATIONS, torch_device)
    print(f" {desc} time: {elapsed * 1000 :6f} ms")


def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
    """
    Run `test_func` for every combination of test case and tensor data type on one device.

    A handle is created for `device` up front and is always destroyed
    afterwards, even when a test raises.

    Arguments:
    ----------
    - lib (ctypes.CDLL): The library object containing the operator implementations.
    - device (InfiniDeviceEnum): The device on which the operator should be tested. See devices.py.
    - test_func (function): Called as
      `test_func(lib, handle, torch_device_string, *test_case, tensor_dtype)`.
    - test_cases (list of tuples): Each tuple is unpacked into the positional
      parameters of `test_func`.
    - tensor_dtypes (list): A list of tensor data types (e.g., `torch.float32`) to test.
    """
    handle = create_handle(lib, device)
    try:
        # Test cases form the outer loop, data types the inner one.
        combinations = (
            (case, dtype) for case in test_cases for dtype in tensor_dtypes
        )
        for case, dtype in combinations:
            torch_device = infiniDeviceEnum_str_map[device]
            test_func(lib, handle, torch_device, *case, dtype)
    finally:
        destroy_handle(lib, handle)


def get_test_devices(args):
    """
    Translate the parsed command-line flags into the list of devices to test.

    Argument:
    - args: the parsed Namespace object; its `cpu`, `nvidia`, `cambricon` and
      `ascend` attributes are read.

    Return:
    - devices_to_test: the selected InfiniDeviceEnum members, in the order
      CPU, NVIDIA, CAMBRICON, ASCEND. If no flag is set, `[InfiniDeviceEnum.CPU]`.
    """
    selected = []

    if args.cpu:
        selected.append(InfiniDeviceEnum.CPU)

    if args.nvidia:
        selected.append(InfiniDeviceEnum.NVIDIA)

    if args.cambricon:
        # Imported before the device is registered; an ImportError here
        # propagates to the caller.
        import torch_mlu

        selected.append(InfiniDeviceEnum.CAMBRICON)

    if args.ascend:
        # Same pattern as above for the Ascend extension.
        import torch_npu

        selected.append(InfiniDeviceEnum.ASCEND)

    # Fall back to CPU when nothing was requested.
    return selected or [InfiniDeviceEnum.CPU]
Loading