From e4e309109e0207995c1bf3fde96e6b7e609c3411 Mon Sep 17 00:00:00 2001 From: WintersMontagne Date: Fri, 14 Nov 2025 17:12:22 +0000 Subject: [PATCH 1/5] Add unit tests for triton_utils.py --- .../ops/triton_ops/triton_utils.py | 314 ++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 tests/model_executor/ops/triton_ops/triton_utils.py diff --git a/tests/model_executor/ops/triton_ops/triton_utils.py b/tests/model_executor/ops/triton_ops/triton_utils.py new file mode 100644 index 00000000000..f9fd41579de --- /dev/null +++ b/tests/model_executor/ops/triton_ops/triton_utils.py @@ -0,0 +1,314 @@ +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +import paddle +import triton + +from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( + KernelInterface, + SubstituteTemplate, + build_package, + extract_triton_kernel, + find_so_path, + get_dtype_str, + get_op_name_with_suffix, + get_pointer_hint, + get_value_hint, + multi_process_do, + paddle_use_triton, + rename_c_to_cu, + rendering_common_template, +) + + +class TestTritonUtils(unittest.TestCase): + + @patch("triton.runtime.jit.JITFunction") + @patch("os.system") + @patch("multiprocessing.Process") + def test_kernel_interface_initialization(self, mock_process, mock_system, mock_jit): + def mock_func(a, b): + return a + b + + mock_func.__annotations__ = {"a": int, "b": int} + + kernel_interface = KernelInterface(mock_func, other_config={}) + + self.assertIsNotNone(kernel_interface.func) + self.assertEqual(kernel_interface.key_args, ["1"]) + self.assertIn("a", kernel_interface.arg_names) + self.assertIn("b", kernel_interface.arg_names) + + @patch("triton.runtime.jit.JITFunction") + @patch("os.system") + def test_paddle_use_triton_decorator(self, mock_system, mock_jit): + mock_jit.return_value.fn = MagicMock() + + @paddle_use_triton() + def mock_kernel(a, b): + return a + b + + self.assertIsInstance(mock_kernel, KernelInterface) + + @patch("os.system") + def test_build_package(self, mock_system): + generated_dir = "/tmp/generated" + python_package_name = "test_package" + + mock_system.return_value = 0 + build_package(generated_dir, python_package_name) + + mock_system.assert_called_with(f"cd {generated_dir} && {sys.executable} setup_cuda.py build") + + @triton.jit + def simple_kernel(x, y): + return x + y + + @patch("builtins.open", new_callable=MagicMock) + def test_extract_triton_kernel_with_real_kernel(self, mock_open): + mock_file = MagicMock() + mock_file.write = MagicMock() + mock_open.return_value = mock_file + file_name = "kernel.py" + extract_triton_kernel(self.simple_kernel, file_name) + mock_open.assert_called_with(file_name, "w") + + @patch("os.system") + @patch("multiprocessing.Process") + def test_multi_process_do(self, mock_process, mock_system): + commands = ["echo 'hello'"] * 5 + + mock_system.return_value = 0 + + mock_process_instance = MagicMock() + mock_process.return_value = mock_process_instance + + multi_process_do(commands) + + self.assertEqual(mock_process.call_count, 40) + mock_process_instance.start.assert_called() + mock_process_instance.join.assert_called() + + @patch("os.rename") + def test_rename_c_to_cu(self, mock_rename): + generated_dir = "/tmp/generated" + os.makedirs(generated_dir, exist_ok=True) + + with open(os.path.join(generated_dir, "file1.c"), "w") as f: + f.write("content") + + rename_c_to_cu(generated_dir) + + mock_rename.assert_called_with(os.path.join(generated_dir, "file1.c"), os.path.join(generated_dir, "file1.cu")) + + def test_substitute_template(self): + template = "Hello, ${name}! Welcome to ${place}." + values = {"name": "Alice", "place": "Wonderland"} + result = SubstituteTemplate(template, values) + self.assertEqual(result, "Hello, Alice! Welcome to Wonderland.") + + @patch("os.walk") + def test_find_so_path_found(self, mock_os_walk): + mock_os_walk.return_value = [("/path/to/dir", [], ["file1.so", "file2.so"])] + so_path = find_so_path("/path/to/dir", "file1") + self.assertEqual(so_path, "/path/to/dir/file1.so") + + @patch("os.walk") + def test_find_so_path_not_found(self, mock_os_walk): + mock_os_walk.return_value = [("/path/to/dir", [], ["file1.txt", "file2.txt"])] + so_path = find_so_path("/path/to/dir", "file") + self.assertIsNone(so_path) + + def test_get_op_name_with_suffix(self): + result = get_op_name_with_suffix("op_name", [16, 1, 32]) + self.assertEqual(result, "op_name16_1_16") + + def test_get_value_hint(self): + result = get_value_hint([16, 1, 32]) + self.assertEqual(result, "i64:16,i64:1,i64:16,") + + def test_get_dtype_str(self): + result = get_dtype_str(paddle.float32) + self.assertEqual(result, "_fp32") + + with self.assertRaises(ValueError): + get_dtype_str(paddle.bool) + + def test_get_pointer_hint(self): + result = get_pointer_hint([paddle.float16, paddle.int32, paddle.uint8]) + self.assertEqual(result, "*fp16:16,*i32:16,*u8:16,") + + +class TestRenderingCommonTemplate(unittest.TestCase): + + def mock_function(self): + def func(a: int, b: float = 2.0, c: bool = True, d: str = "test"): + pass + + return func + + def test_rendering_with_no_return_tensor(self): + func = self.mock_function() + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + + result = rendering_common_template(func, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel) + + self.assertIn('Outputs({"useless"}', result) + + def test_rendering_with_return_tensor(self): + func = self.mock_function() + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + return_tensor_names = "out_tensor" + + result = rendering_common_template( + func, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names=return_tensor_names, + ) + + self.assertIn('Outputs({"out_tensor"})', result) + self.assertIn("std::vector> ${op_name}_InferShape", result) + self.assertIn("std::vector ${op_name}_InferDtype", result) + + def test_rendering_with_d2s_infer_code(self): + func = self.mock_function() + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + return_tensor_names = "out_tensor" + d2s_infer_code = "existing_infer_code" + + result = rendering_common_template( + func, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names=return_tensor_names, + d2s_infer_code=d2s_infer_code, + ) + + self.assertIn("existing_infer_code", result) + + def test_rendering_with_default_parameters(self): + func = self.mock_function() + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + + result = rendering_common_template(func, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel) + + self.assertIn("float b", result) + self.assertIn("bool c", result) + self.assertIn("std::string d", result) + + def test_rendering_with_invalid_function(self): + def invalid_func(): + pass + + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + + result = rendering_common_template(invalid_func, prepare_attr_for_triton_kernel, prepare_ptr_for_triton_kernel) + + self.assertIn("useless", result) + + def test_rendering_with_multiple_return_tensors(self): + func = self.mock_function() + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + return_tensor_names = "out_tensor, aux_tensor" + + result = rendering_common_template( + func, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names=return_tensor_names, + ) + + self.assertIn('Outputs({"out_tensor","aux_tensor"})', result) + + def test_rendering_with_edge_case_return_tensor_names(self): + func = self.mock_function() + prepare_attr_for_triton_kernel = "prepare_attr_code" + prepare_ptr_for_triton_kernel = "prepare_ptr_code" + return_tensor_names = "" + + result = rendering_common_template( + func, + prepare_attr_for_triton_kernel, + prepare_ptr_for_triton_kernel, + return_tensor_names=return_tensor_names, + ) + + self.assertIn('Outputs({""}', result) + + +class TestKernelInterface(unittest.TestCase): + + @patch( + "fastdeploy.model_executor.ops.triton_ops.triton_utils.paddle.utils.cpp_extension.load_op_meta_info_and_register_op" + ) + @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.OpProtoHolder.instance") + @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.multi_process_do") + @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.build_package") + @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.find_so_path") + @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.extract_triton_kernel") + @patch("paddle.distributed.get_rank") + @patch("os.path") + @patch("os.makedirs") + @patch("builtins.open", new_callable=MagicMock) + @patch("os.system") + @patch("os.rename") + @patch("os.listdir") + def test_kernel_interface_initialization( + self, + mock_listdir, + mock_rename, + mock_system, + mock_open, + mock_makedirs, + mock_os_path, + mock_get_rank, + mock_extract_triton_kernel, + mock_find_so_path, + mock_build_package, + mock_multi_process_do, + mock_op_proto_instance, + mock_register_op, + ): + mock_system.return_value = 0 + mock_get_rank.return_value = 0 + mock_extract_triton_kernel.return_value = None + mock_find_so_path.return_value = None + mock_build_package.return_value = None + mock_multi_process_do.return_value = None + mock_op_proto_map = {"simple_op": "some_value"} + mock_op_proto_instance_return_value = MagicMock() + mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map + mock_op_proto_instance.return_value = mock_op_proto_instance_return_value + + mock_register_op.return_value = None + + def mock_kernel_func(a, b: int, c: str): + return a + b + + kernel_interface = KernelInterface(mock_kernel_func, other_config={}) + + kernel_interface.op_name = "simple_op" + kernel_interface.custom_op_template = "custom_template" + kernel_interface.grid = [1, 1, 1] + kernel_interface.tune_config = {} + + self.assertIsNotNone(kernel_interface.func) + self.assertEqual(kernel_interface.key_args, ["1"]) + self.assertIn("a", kernel_interface.arg_names) + self.assertIn("b", kernel_interface.arg_names) + self.assertIn("c", kernel_interface.arg_names) + + kernel_interface.decorator("simple_op", "custom_template", [1, 1, 1]) + + +if __name__ == "__main__": + unittest.main() From 677e9279ca2e683056a9bbeea85faa77d24d3917 Mon Sep 17 00:00:00 2001 From: WintersMontagne Date: Sat, 15 Nov 2025 01:24:44 +0000 Subject: [PATCH 2/5] update name --- .../ops/triton_ops/{triton_utils.py => test_triton_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/model_executor/ops/triton_ops/{triton_utils.py => test_triton_utils.py} (100%) diff --git a/tests/model_executor/ops/triton_ops/triton_utils.py b/tests/model_executor/ops/triton_ops/test_triton_utils.py similarity index 100% rename from tests/model_executor/ops/triton_ops/triton_utils.py rename to tests/model_executor/ops/triton_ops/test_triton_utils.py From 741767e69573236e7db32bfc8039cad7ea37ea23 Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 Date: Sat, 15 Nov 2025 01:43:30 +0000 Subject: [PATCH 3/5] update --- tests/model_executor/ops/triton_ops/test_triton_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/model_executor/ops/triton_ops/test_triton_utils.py b/tests/model_executor/ops/triton_ops/test_triton_utils.py index f9fd41579de..dcbde2d7f67 100644 --- a/tests/model_executor/ops/triton_ops/test_triton_utils.py +++ b/tests/model_executor/ops/triton_ops/test_triton_utils.py @@ -52,8 +52,9 @@ def mock_kernel(a, b): self.assertIsInstance(mock_kernel, KernelInterface) + @patch("builtins.open", new_callable=MagicMock) @patch("os.system") - def test_build_package(self, mock_system): + def test_build_package(self, mock_system, mock_open): generated_dir = "/tmp/generated" python_package_name = "test_package" From 7b40f02c9e381345b2d839c4c98bbce58e0cd0a4 Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 Date: Sat, 15 Nov 2025 04:08:40 +0000 Subject: [PATCH 4/5] update --- .../ops/triton_ops/test_triton_utils.py | 100 ++++++++++++------ 1 file changed, 66 insertions(+), 34 deletions(-) diff --git a/tests/model_executor/ops/triton_ops/test_triton_utils.py b/tests/model_executor/ops/triton_ops/test_triton_utils.py index dcbde2d7f67..5a90cc3d0a7 100644 --- a/tests/model_executor/ops/triton_ops/test_triton_utils.py +++ b/tests/model_executor/ops/triton_ops/test_triton_utils.py @@ -22,6 +22,9 @@ rendering_common_template, ) +TRITON_UTILS_PATH = "fastdeploy.model_executor.ops.triton_ops.triton_utils" +MOCK_GENERATED_DIR = "/tmp/generated" + class TestTritonUtils(unittest.TestCase): @@ -29,12 +32,12 @@ class TestTritonUtils(unittest.TestCase): @patch("os.system") @patch("multiprocessing.Process") def test_kernel_interface_initialization(self, mock_process, mock_system, mock_jit): - def mock_func(a, b): + def mock_function(a, b): return a + b - mock_func.__annotations__ = {"a": int, "b": int} + mock_function.__annotations__ = {"a": int, "b": int} - kernel_interface = KernelInterface(mock_func, other_config={}) + kernel_interface = KernelInterface(mock_function, other_config={}) self.assertIsNotNone(kernel_interface.func) self.assertEqual(kernel_interface.key_args, ["1"]) @@ -55,34 +58,33 @@ def mock_kernel(a, b): @patch("builtins.open", new_callable=MagicMock) @patch("os.system") def test_build_package(self, mock_system, mock_open): - generated_dir = "/tmp/generated" - python_package_name = "test_package" - + MOCK_NAME = "test_package" mock_system.return_value = 0 - build_package(generated_dir, python_package_name) - mock_system.assert_called_with(f"cd {generated_dir} && {sys.executable} setup_cuda.py build") + build_package(MOCK_GENERATED_DIR, MOCK_NAME) - @triton.jit - def simple_kernel(x, y): - return x + y + mock_system.assert_called_with(f"cd {MOCK_GENERATED_DIR} && {sys.executable} setup_cuda.py build") @patch("builtins.open", new_callable=MagicMock) def test_extract_triton_kernel_with_real_kernel(self, mock_open): + @triton.jit + def mock_kernel(x, y): + return x + y + mock_file = MagicMock() mock_file.write = MagicMock() mock_open.return_value = mock_file file_name = "kernel.py" - extract_triton_kernel(self.simple_kernel, file_name) + + extract_triton_kernel(mock_kernel, file_name) + mock_open.assert_called_with(file_name, "w") @patch("os.system") @patch("multiprocessing.Process") def test_multi_process_do(self, mock_process, mock_system): commands = ["echo 'hello'"] * 5 - mock_system.return_value = 0 - mock_process_instance = MagicMock() mock_process.return_value = mock_process_instance @@ -94,44 +96,54 @@ def test_multi_process_do(self, mock_process, mock_system): @patch("os.rename") def test_rename_c_to_cu(self, mock_rename): - generated_dir = "/tmp/generated" - os.makedirs(generated_dir, exist_ok=True) + os.makedirs(MOCK_GENERATED_DIR, exist_ok=True) - with open(os.path.join(generated_dir, "file1.c"), "w") as f: + with open(os.path.join(MOCK_GENERATED_DIR, "file1.c"), "w") as f: f.write("content") - rename_c_to_cu(generated_dir) + rename_c_to_cu(MOCK_GENERATED_DIR) - mock_rename.assert_called_with(os.path.join(generated_dir, "file1.c"), os.path.join(generated_dir, "file1.cu")) + mock_rename.assert_called_with( + os.path.join(MOCK_GENERATED_DIR, "file1.c"), os.path.join(MOCK_GENERATED_DIR, "file1.cu") + ) def test_substitute_template(self): template = "Hello, ${name}! Welcome to ${place}." values = {"name": "Alice", "place": "Wonderland"} + result = SubstituteTemplate(template, values) + self.assertEqual(result, "Hello, Alice! Welcome to Wonderland.") @patch("os.walk") def test_find_so_path_found(self, mock_os_walk): mock_os_walk.return_value = [("/path/to/dir", [], ["file1.so", "file2.so"])] + so_path = find_so_path("/path/to/dir", "file1") + self.assertEqual(so_path, "/path/to/dir/file1.so") @patch("os.walk") def test_find_so_path_not_found(self, mock_os_walk): mock_os_walk.return_value = [("/path/to/dir", [], ["file1.txt", "file2.txt"])] + so_path = find_so_path("/path/to/dir", "file") + self.assertIsNone(so_path) def test_get_op_name_with_suffix(self): result = get_op_name_with_suffix("op_name", [16, 1, 32]) + self.assertEqual(result, "op_name16_1_16") def test_get_value_hint(self): result = get_value_hint([16, 1, 32]) + self.assertEqual(result, "i64:16,i64:1,i64:16,") def test_get_dtype_str(self): result = get_dtype_str(paddle.float32) + self.assertEqual(result, "_fp32") with self.assertRaises(ValueError): @@ -139,6 +151,7 @@ def test_get_dtype_str(self): def test_get_pointer_hint(self): result = get_pointer_hint([paddle.float16, paddle.int32, paddle.uint8]) + self.assertEqual(result, "*fp16:16,*i32:16,*u8:16,") @@ -248,14 +261,11 @@ def test_rendering_with_edge_case_return_tensor_names(self): class TestKernelInterface(unittest.TestCase): - @patch( - "fastdeploy.model_executor.ops.triton_ops.triton_utils.paddle.utils.cpp_extension.load_op_meta_info_and_register_op" - ) - @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.OpProtoHolder.instance") - @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.multi_process_do") - @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.build_package") - @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.find_so_path") - @patch("fastdeploy.model_executor.ops.triton_ops.triton_utils.extract_triton_kernel") + @patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance") + @patch(f"{TRITON_UTILS_PATH}.multi_process_do") + @patch(f"{TRITON_UTILS_PATH}.build_package") + @patch(f"{TRITON_UTILS_PATH}.find_so_path") + @patch(f"{TRITON_UTILS_PATH}.extract_triton_kernel") @patch("paddle.distributed.get_rank") @patch("os.path") @patch("os.makedirs") @@ -277,7 +287,6 @@ def test_kernel_interface_initialization( mock_build_package, mock_multi_process_do, mock_op_proto_instance, - mock_register_op, ): mock_system.return_value = 0 mock_get_rank.return_value = 0 @@ -285,19 +294,16 @@ def test_kernel_interface_initialization( mock_find_so_path.return_value = None mock_build_package.return_value = None mock_multi_process_do.return_value = None - mock_op_proto_map = {"simple_op": "some_value"} + mock_op_proto_map = {"mock_op": "some_value"} mock_op_proto_instance_return_value = MagicMock() mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map mock_op_proto_instance.return_value = mock_op_proto_instance_return_value - mock_register_op.return_value = None - def mock_kernel_func(a, b: int, c: str): return a + b kernel_interface = KernelInterface(mock_kernel_func, other_config={}) - - kernel_interface.op_name = "simple_op" + kernel_interface.op_name = "mock_op" kernel_interface.custom_op_template = "custom_template" kernel_interface.grid = [1, 1, 1] kernel_interface.tune_config = {} @@ -308,7 +314,33 @@ def mock_kernel_func(a, b: int, c: str): self.assertIn("b", kernel_interface.arg_names) self.assertIn("c", kernel_interface.arg_names) - kernel_interface.decorator("simple_op", "custom_template", [1, 1, 1]) + kernel_interface.decorator("mock_op", "custom_template", [1, 1, 1]) + + mock_op_proto_instance.assert_called_once_with() + mock_extract_triton_kernel.assert_called_once_with( + mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py" + ) + mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w") + mock_system.assert_called() + mock_build_package.assert_called_once_with("/tmp/triton_cache/rank0/mock_op", "mock_op_package") + + @patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance") + def test_getitem(self, mock_op_proto_instance): + def mock_kernel_func(a, b): + return a + b + + mock_op_proto_map = {"mock_op": "some_value"} + mock_op_proto_instance_return_value = MagicMock() + mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map + mock_op_proto_instance.return_value = mock_op_proto_instance_return_value + kernel_interface = KernelInterface(mock_kernel_func, other_config={}) + op_name_and_grid = ["mock_op", "custom_template", [1, 1, 1]] + + kernel_interface[op_name_and_grid] + + self.assertEqual(kernel_interface.op_name, "mock_op") + self.assertEqual(kernel_interface.custom_op_template, "custom_template") + self.assertEqual(kernel_interface.grid, [1, 1, 1]) if __name__ == "__main__": From 05cc79b0b50be9d022d36bac98fb52d2e804c9bd Mon Sep 17 00:00:00 2001 From: WintersMontagne10335 Date: Sat, 15 Nov 2025 12:46:19 +0000 Subject: [PATCH 5/5] update --- .../ops/triton_ops/triton_utils.py | 2 +- .../ops/triton_ops/test_triton_utils.py | 224 +++++++++++++++--- 2 files changed, 198 insertions(+), 28 deletions(-) diff --git a/fastdeploy/model_executor/ops/triton_ops/triton_utils.py b/fastdeploy/model_executor/ops/triton_ops/triton_utils.py index 2a2a00d0d09..a61268044bd 100644 --- a/fastdeploy/model_executor/ops/triton_ops/triton_utils.py +++ b/fastdeploy/model_executor/ops/triton_ops/triton_utils.py @@ -710,7 +710,7 @@ def decorator(*args, **kwargs): + f""" -s"{address_hint} {value_hint} {const_args}" """ + f""" -g "{lanuch_grid}" """ ) - all_tune_config = list(self.tune_config) + all_tune_config = [{key: value} for key, value in self.tune_config.items()] if len(all_tune_config) == 0: # when user do not specify config, we use const_hint_dict as config. all_tune_config = [const_hint_dict] diff --git a/tests/model_executor/ops/triton_ops/test_triton_utils.py b/tests/model_executor/ops/triton_ops/test_triton_utils.py index 5a90cc3d0a7..aecd2ae5733 100644 --- a/tests/model_executor/ops/triton_ops/test_triton_utils.py +++ b/tests/model_executor/ops/triton_ops/test_triton_utils.py @@ -28,45 +28,57 @@ class TestTritonUtils(unittest.TestCase): + # Test case to validate KernelInterface initialization @patch("triton.runtime.jit.JITFunction") @patch("os.system") @patch("multiprocessing.Process") def test_kernel_interface_initialization(self, mock_process, mock_system, mock_jit): + # Mock function for testing def mock_function(a, b): return a + b + # Add type annotations to the mock function mock_function.__annotations__ = {"a": int, "b": int} + # Initialize KernelInterface with the mock function kernel_interface = KernelInterface(mock_function, other_config={}) + # Validate that the function and argument names are correctly initialized self.assertIsNotNone(kernel_interface.func) self.assertEqual(kernel_interface.key_args, ["1"]) self.assertIn("a", kernel_interface.arg_names) self.assertIn("b", kernel_interface.arg_names) + # Test case for validating the paddle_use_triton decorator @patch("triton.runtime.jit.JITFunction") @patch("os.system") def test_paddle_use_triton_decorator(self, mock_system, mock_jit): mock_jit.return_value.fn = MagicMock() + # Apply the paddle_use_triton decorator to a mock function @paddle_use_triton() def mock_kernel(a, b): return a + b + # Validate the result of the decorator self.assertIsInstance(mock_kernel, KernelInterface) + # Test case for validating the build_package function @patch("builtins.open", new_callable=MagicMock) @patch("os.system") def test_build_package(self, mock_system, mock_open): MOCK_NAME = "test_package" mock_system.return_value = 0 + # Call build_package with mocked directory and package name build_package(MOCK_GENERATED_DIR, MOCK_NAME) + # Assert that system command was called correctly mock_system.assert_called_with(f"cd {MOCK_GENERATED_DIR} && {sys.executable} setup_cuda.py build") + # Test case for extracting Triton kernel with a JIT kernel @patch("builtins.open", new_callable=MagicMock) - def test_extract_triton_kernel_with_real_kernel(self, mock_open): + def test_extract_triton_kernel_with_jit_kernel(self, mock_open): @triton.jit def mock_kernel(x, y): return x + y @@ -76,10 +88,30 @@ def mock_kernel(x, y): mock_open.return_value = mock_file file_name = "kernel.py" + # Extract the Triton kernel and write to the specified file extract_triton_kernel(mock_kernel, file_name) + # Assert that file write was performed as expected mock_open.assert_called_with(file_name, "w") + # Test case for extracting Triton kernel with a Python kernel + @patch("builtins.open", new_callable=MagicMock) + def test_extract_triton_kernel_with_python_kernel(self, mock_open): + def mock_kernel(x, y): + return x + y + + mock_file = MagicMock() + mock_file.write = MagicMock() + mock_open.return_value = mock_file + file_name = "kernel.py" + + # Extract the Triton kernel and write to the specified file + extract_triton_kernel(mock_kernel, file_name) + + # Assert that file write was performed as expected + mock_open.assert_called_with(file_name, "w") + + # Test case for validating multi-process execution of commands @patch("os.system") @patch("multiprocessing.Process") def test_multi_process_do(self, mock_process, mock_system): @@ -88,70 +120,90 @@ def test_multi_process_do(self, mock_process, mock_system): mock_process_instance = MagicMock() mock_process.return_value = mock_process_instance + # Call multi_process_do with the list of commands multi_process_do(commands) + # Assert the expected behavior of mock_process self.assertEqual(mock_process.call_count, 40) mock_process_instance.start.assert_called() mock_process_instance.join.assert_called() + # Test case for renaming .c files to .cu in the specified directory + @patch("os.listdir") @patch("os.rename") - def test_rename_c_to_cu(self, mock_rename): - os.makedirs(MOCK_GENERATED_DIR, exist_ok=True) - - with open(os.path.join(MOCK_GENERATED_DIR, "file1.c"), "w") as f: - f.write("content") + def test_rename_c_to_cu(self, mock_rename, mock_listdir): + mock_listdir.return_value = ["file1.c"] + # Call rename_c_to_cu to rename the files in the directory rename_c_to_cu(MOCK_GENERATED_DIR) + # Assert that the rename operation was performed correctly mock_rename.assert_called_with( os.path.join(MOCK_GENERATED_DIR, "file1.c"), os.path.join(MOCK_GENERATED_DIR, "file1.cu") ) + # Test case for validating template substitution def test_substitute_template(self): template = "Hello, ${name}! Welcome to ${place}." values = {"name": "Alice", "place": "Wonderland"} + # Call SubstituteTemplate to replace placeholders in the template result = SubstituteTemplate(template, values) + # Assert that the substitution worked as expected self.assertEqual(result, "Hello, Alice! Welcome to Wonderland.") + # Test case for finding shared object (.so) file in a directory @patch("os.walk") def test_find_so_path_found(self, mock_os_walk): mock_os_walk.return_value = [("/path/to/dir", [], ["file1.so", "file2.so"])] + # Call find_so_path to locate the .so file so_path = find_so_path("/path/to/dir", "file1") + # Assert that the correct path is returned self.assertEqual(so_path, "/path/to/dir/file1.so") + # Test case for handling the scenario when the .so file is not found @patch("os.walk") def test_find_so_path_not_found(self, mock_os_walk): mock_os_walk.return_value = [("/path/to/dir", [], ["file1.txt", "file2.txt"])] + # Call find_so_path when the .so file is not present so_path = find_so_path("/path/to/dir", "file") + # Assert that None is returned when the .so file is not found self.assertIsNone(so_path) + # Test case for getting the operator name with suffix def test_get_op_name_with_suffix(self): result = get_op_name_with_suffix("op_name", [16, 1, 32]) + # Assert the correct suffix is added to the operator name self.assertEqual(result, "op_name16_1_16") + # Test case for getting value hints from a list def test_get_value_hint(self): result = get_value_hint([16, 1, 32]) + # Assert that the correct value hint string is generated self.assertEqual(result, "i64:16,i64:1,i64:16,") + # Test case for getting the string representation of a data type def test_get_dtype_str(self): result = get_dtype_str(paddle.float32) + # Assert the correct dtype string is returned self.assertEqual(result, "_fp32") with self.assertRaises(ValueError): get_dtype_str(paddle.bool) + # Test case for getting pointer hints for different data types def test_get_pointer_hint(self): result = get_pointer_hint([paddle.float16, paddle.int32, paddle.uint8]) + # Assert the correct pointer hint string is generated self.assertEqual(result, "*fp16:16,*i32:16,*u8:16,") @@ -163,6 +215,7 @@ def func(a: int, b: float = 2.0, c: bool = True, d: str = "test"): return func + # Test case for rendering a function template without return tensor def test_rendering_with_no_return_tensor(self): func = self.mock_function() prepare_attr_for_triton_kernel = "prepare_attr_code" @@ -172,6 +225,7 @@ def test_rendering_with_no_return_tensor(self): self.assertIn('Outputs({"useless"}', result) + # Test case for rendering a function template with return tensor def test_rendering_with_return_tensor(self): func = self.mock_function() prepare_attr_for_triton_kernel = "prepare_attr_code" @@ -189,6 +243,7 @@ def test_rendering_with_return_tensor(self): self.assertIn("std::vector> ${op_name}_InferShape", result) self.assertIn("std::vector ${op_name}_InferDtype", result) + # Test case for rendering a function template with d2s inference code def test_rendering_with_d2s_infer_code(self): func = self.mock_function() prepare_attr_for_triton_kernel = "prepare_attr_code" @@ -206,6 +261,7 @@ def test_rendering_with_d2s_infer_code(self): self.assertIn("existing_infer_code", result) + # Test case for rendering a function template with default parameters def test_rendering_with_default_parameters(self): func = self.mock_function() prepare_attr_for_triton_kernel = "prepare_attr_code" @@ -217,6 +273,7 @@ def test_rendering_with_default_parameters(self): self.assertIn("bool c", result) self.assertIn("std::string d", result) + # Test case for rendering a function template with an invalid function def test_rendering_with_invalid_function(self): def invalid_func(): pass @@ -228,6 +285,7 @@ def invalid_func(): self.assertIn("useless", result) + # Test case for rendering a function template with multiple return tensors def test_rendering_with_multiple_return_tensors(self): func = self.mock_function() prepare_attr_for_triton_kernel = "prepare_attr_code" @@ -243,6 +301,7 @@ def test_rendering_with_multiple_return_tensors(self): self.assertIn('Outputs({"out_tensor","aux_tensor"})', result) + # Test case for rendering a function template with edge case return tensor names def test_rendering_with_edge_case_return_tensor_names(self): func = self.mock_function() prepare_attr_for_triton_kernel = "prepare_attr_code" @@ -260,7 +319,18 @@ def test_rendering_with_edge_case_return_tensor_names(self): class TestKernelInterface(unittest.TestCase): + # A mock kernel function with constant arguments + def mock_kernel_func( + self, + a, + b: int, + config_key0: triton.language.core.constexpr, + config_key1: triton.language.core.constexpr, + config_key2: triton.language.core.constexpr, + ): + return a + b + # Test case for when values do not match in decorator @patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance") @patch(f"{TRITON_UTILS_PATH}.multi_process_do") @patch(f"{TRITON_UTILS_PATH}.build_package") @@ -273,7 +343,7 @@ class TestKernelInterface(unittest.TestCase): @patch("os.system") @patch("os.rename") @patch("os.listdir") - def test_kernel_interface_initialization( + def test_with_values_do_not_match( self, mock_listdir, mock_rename, @@ -299,48 +369,148 @@ def test_kernel_interface_initialization( mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map mock_op_proto_instance.return_value = mock_op_proto_instance_return_value - def mock_kernel_func(a, b: int, c: str): - return a + b + kernel_interface = KernelInterface(self.mock_kernel_func, other_config={}) + op_name_and_grid = [ + "mock_op", + "custom_template", + [1, "1", 1], + {"config_key0": "config_value0", "config_key1": "config_value1", "config_key2": "config_value2"}, + ] + kernel_interface[op_name_and_grid] + + self.assertIsNotNone(kernel_interface.func) + self.assertIn("a", kernel_interface.arg_names) + self.assertIn("b", kernel_interface.arg_names) + self.assertIn("config_key0", kernel_interface.arg_names) + + # Test if ValueError is raised when values do not match + with self.assertRaises(ValueError): + kernel_interface.decorator("mock_op", "custom_template", "config_value", "config_value1", "config_value2") - kernel_interface = KernelInterface(mock_kernel_func, other_config={}) - kernel_interface.op_name = "mock_op" - kernel_interface.custom_op_template = "custom_template" - kernel_interface.grid = [1, 1, 1] - kernel_interface.tune_config = {} + mock_extract_triton_kernel.assert_called_once_with( + self.mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py" + ) + mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w") + + # Test case for when parameter values match in decorator + @patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance") + @patch(f"{TRITON_UTILS_PATH}.multi_process_do") + @patch(f"{TRITON_UTILS_PATH}.build_package") + @patch(f"{TRITON_UTILS_PATH}.find_so_path") + @patch(f"{TRITON_UTILS_PATH}.extract_triton_kernel") + @patch("paddle.distributed.get_rank") + @patch("os.path") + @patch("os.makedirs") + @patch("builtins.open", new_callable=MagicMock) + @patch("os.system") + @patch("os.rename") + @patch("os.listdir") + def test_with_values_match( + self, + mock_listdir, + mock_rename, + mock_system, + mock_open, + mock_makedirs, + mock_os_path, + mock_get_rank, + mock_extract_triton_kernel, + mock_find_so_path, + mock_build_package, + mock_multi_process_do, + mock_op_proto_instance, + ): + mock_system.return_value = 0 + mock_get_rank.return_value = 0 + mock_extract_triton_kernel.return_value = None + mock_find_so_path.return_value = None + mock_build_package.return_value = None + mock_multi_process_do.return_value = None + mock_op_proto_map = {"mock_op": "some_value"} + mock_op_proto_instance_return_value = MagicMock() + mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map + mock_op_proto_instance.return_value = mock_op_proto_instance_return_value + + kernel_interface = KernelInterface(self.mock_kernel_func, other_config={}) + op_name_and_grid = [ + "mock_op", + "custom_template", + [1, "1", 1], + {"config_key0": "config_value0", "config_key1": "config_value1"}, + ] + kernel_interface[op_name_and_grid] self.assertIsNotNone(kernel_interface.func) - self.assertEqual(kernel_interface.key_args, ["1"]) self.assertIn("a", kernel_interface.arg_names) self.assertIn("b", kernel_interface.arg_names) - self.assertIn("c", kernel_interface.arg_names) + self.assertIn("config_key0", kernel_interface.arg_names) - kernel_interface.decorator("mock_op", "custom_template", [1, 1, 1]) + # Validate if the decorator works correctly when values match + kernel_interface.decorator("mock_op", "custom_template", "config_value0", "config_value1", "config_value2") mock_op_proto_instance.assert_called_once_with() mock_extract_triton_kernel.assert_called_once_with( - mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py" + self.mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py" ) mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w") mock_system.assert_called() mock_build_package.assert_called_once_with("/tmp/triton_cache/rank0/mock_op", "mock_op_package") + # Test case for when parameter values match in decorator @patch(f"{TRITON_UTILS_PATH}.OpProtoHolder.instance") - def test_getitem(self, mock_op_proto_instance): - def mock_kernel_func(a, b): - return a + b - + @patch(f"{TRITON_UTILS_PATH}.multi_process_do") + @patch(f"{TRITON_UTILS_PATH}.build_package") + @patch(f"{TRITON_UTILS_PATH}.find_so_path") + @patch(f"{TRITON_UTILS_PATH}.extract_triton_kernel") + @patch("paddle.distributed.get_rank") + @patch("os.path") + @patch("os.makedirs") + @patch("builtins.open", new_callable=MagicMock) + @patch("os.system") + @patch("os.rename") + @patch("os.listdir") + def test_with_missing_parameter( + self, + mock_listdir, + mock_rename, + mock_system, + mock_open, + mock_makedirs, + mock_os_path, + mock_get_rank, + mock_extract_triton_kernel, + mock_find_so_path, + mock_build_package, + mock_multi_process_do, + mock_op_proto_instance, + ): + mock_system.return_value = 0 + mock_get_rank.return_value = 0 + mock_extract_triton_kernel.return_value = None + mock_find_so_path.return_value = None + mock_build_package.return_value = None + mock_multi_process_do.return_value = None mock_op_proto_map = {"mock_op": "some_value"} mock_op_proto_instance_return_value = MagicMock() mock_op_proto_instance_return_value.op_proto_map = mock_op_proto_map mock_op_proto_instance.return_value = mock_op_proto_instance_return_value - kernel_interface = KernelInterface(mock_kernel_func, other_config={}) - op_name_and_grid = ["mock_op", "custom_template", [1, 1, 1]] + kernel_interface = KernelInterface(self.mock_kernel_func, other_config={}) + op_name_and_grid = ["mock_op", "custom_template", [1, "1", 1]] kernel_interface[op_name_and_grid] - self.assertEqual(kernel_interface.op_name, "mock_op") - self.assertEqual(kernel_interface.custom_op_template, "custom_template") - self.assertEqual(kernel_interface.grid, [1, 1, 1]) + self.assertIsNotNone(kernel_interface.func) + self.assertIn("a", kernel_interface.arg_names) + self.assertIn("b", kernel_interface.arg_names) + self.assertIn("config_key0", kernel_interface.arg_names) + + with self.assertRaises(AssertionError): + kernel_interface.decorator("mock_op", "custom_template") + + mock_extract_triton_kernel.assert_called_once_with( + self.mock_kernel_func, "/tmp/triton_cache/rank0/mock_op/triton_kernels.py" + ) + mock_open.assert_called_once_with("/tmp/triton_cache/rank0/mock_op/mock_op.cu", "w") if __name__ == "__main__":