From f8a13eb2d18b0489cf7683813342e9018f7e22af Mon Sep 17 00:00:00 2001 From: Lingjie Date: Fri, 29 Mar 2024 17:32:52 +0800 Subject: [PATCH 1/2] fix(dipu,python): print device index for cuda tensor Example: ```python import torch import torch_dipu torch.cuda.set_device(0) a = torch.tensor([1, 2, 3]).cuda() print(a) ``` Expected output: ``` tensor([1, 2, 3], device='cuda:0') ``` Current output (before this commit): ``` tensor([1, 2, 3], device='cuda') ``` --- dipu/.clang-format | 2 +- .../python/unittests/test_python_device.py | 25 +++++++++++++++++++ .../csrc_dipu/binding/patchCsrcDevice.cpp | 20 +++++++-------- 3 files changed, 35 insertions(+), 12 deletions(-) create mode 100644 dipu/tests/python/unittests/test_python_device.py diff --git a/dipu/.clang-format b/dipu/.clang-format index 61244b861..6f52f1b75 100644 --- a/dipu/.clang-format +++ b/dipu/.clang-format @@ -13,7 +13,7 @@ IncludeCategories: - Regex: '^("|<)Python\.h' Priority: 50 CaseSensitive: false - - Regex: '^("|<)(frameobject|structmember)\.h' + - Regex: '^("|<)(descrobject|frameobject|object|structmember)\.h' Priority: 50 SortPriority: 51 CaseSensitive: false diff --git a/dipu/tests/python/unittests/test_python_device.py b/dipu/tests/python/unittests/test_python_device.py new file mode 100644 index 000000000..002e7f809 --- /dev/null +++ b/dipu/tests/python/unittests/test_python_device.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024, DeepLink. +import torch +import torch_dipu +from torch_dipu.testing._internal.common_utils import TestCase, run_tests + + +class TestAbs(TestCase): + def test_cpu(self): + a = torch.tensor([1, 2, 3]) + self.assertEqual(str(a.device), "cpu") + self.assertEqual(repr(a.device), "device(type='cpu')") + self.assertEqual(str(a), "tensor([1, 2, 3])") + self.assertEqual(repr(a), "tensor([1, 2, 3])") + + def test_cuda(self): + torch.cuda.set_device(0) + a = torch.tensor([1, 2, 3]).cuda() + self.assertEqual(str(a.device), "cuda:0") + self.assertEqual(repr(a.device), "device(type='cuda', index=0)") + self.assertEqual(str(a), "tensor([1, 2, 3], device='cuda:0')") + self.assertEqual(repr(a), "tensor([1, 2, 3], device='cuda:0')") + + +if __name__ == "__main__": + run_tests() diff --git a/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp b/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp index 0cb5e088c..84fc68e6a 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp @@ -1,24 +1,22 @@ // Copyright (c) 2023, DeepLink. #include -#include -#include +#include #include -#include -#include +#include +#include #include #include #include -#include -#include -#include -#include -#include #include #include -#include +#include +#include +#include +#include +#include #include @@ -72,7 +70,7 @@ PyObject* DIPU_THPDevice_repr(THPDevice* self) { PyObject* DIPU_THPDevice_str(THPDevice* self) { std::ostringstream oss; - oss << _get_dipu_python_type(self->device); + oss << at::Device(_get_dipu_python_type(self->device), self->device.index()); return THPUtils_packString(oss.str().c_str()); } From 4bbe4fbac63ad45bce46e6de6fe1c60ede369591 Mon Sep 17 00:00:00 2001 From: Lingjie Date: Fri, 29 Mar 2024 18:12:41 +0800 Subject: [PATCH 2/2] test(dipu,python): enhance test_python_device.py --- dipu/tests/python/unittests/test_python_device.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dipu/tests/python/unittests/test_python_device.py b/dipu/tests/python/unittests/test_python_device.py index 002e7f809..dade7f231 100644 --- a/dipu/tests/python/unittests/test_python_device.py +++ b/dipu/tests/python/unittests/test_python_device.py @@ -4,7 +4,7 @@ from torch_dipu.testing._internal.common_utils import TestCase, run_tests -class TestAbs(TestCase): +class TestPythonDevice(TestCase): def test_cpu(self): a = torch.tensor([1, 2, 3]) self.assertEqual(str(a.device), "cpu") @@ -13,12 +13,13 @@ def test_cpu(self): self.assertEqual(repr(a), "tensor([1, 2, 3])") def test_cuda(self): - torch.cuda.set_device(0) + device_index = 0 # NOTE: maybe 0 is not available, fix me if this happens + torch.cuda.set_device(device_index) a = torch.tensor([1, 2, 3]).cuda() - self.assertEqual(str(a.device), "cuda:0") - self.assertEqual(repr(a.device), "device(type='cuda', index=0)") - self.assertEqual(str(a), "tensor([1, 2, 3], device='cuda:0')") - self.assertEqual(repr(a), "tensor([1, 2, 3], device='cuda:0')") + self.assertEqual(str(a.device), f"cuda:{device_index}") + self.assertEqual(repr(a.device), f"device(type='cuda', index={device_index})") + self.assertEqual(str(a), f"tensor([1, 2, 3], device='cuda:{device_index}')") + self.assertEqual(repr(a), f"tensor([1, 2, 3], device='cuda:{device_index}')") if __name__ == "__main__":