diff --git a/dipu/.clang-format b/dipu/.clang-format index 61244b861..6f52f1b75 100644 --- a/dipu/.clang-format +++ b/dipu/.clang-format @@ -13,7 +13,7 @@ IncludeCategories: - Regex: '^("|<)Python\.h' Priority: 50 CaseSensitive: false - - Regex: '^("|<)(frameobject|structmember)\.h' + - Regex: '^("|<)(descrobject|frameobject|object|structmember)\.h' Priority: 50 SortPriority: 51 CaseSensitive: false diff --git a/dipu/tests/python/unittests/test_python_device.py b/dipu/tests/python/unittests/test_python_device.py new file mode 100644 index 000000000..dade7f231 --- /dev/null +++ b/dipu/tests/python/unittests/test_python_device.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024, DeepLink. +import torch +import torch_dipu +from torch_dipu.testing._internal.common_utils import TestCase, run_tests + + +class TestPythonDevice(TestCase): + def test_cpu(self): + a = torch.tensor([1, 2, 3]) + self.assertEqual(str(a.device), "cpu") + self.assertEqual(repr(a.device), "device(type='cpu')") + self.assertEqual(str(a), "tensor([1, 2, 3])") + self.assertEqual(repr(a), "tensor([1, 2, 3])") + + def test_cuda(self): + device_index = 0 # NOTE: maybe 0 is not available, fix me if this happens + torch.cuda.set_device(device_index) + a = torch.tensor([1, 2, 3]).cuda() + self.assertEqual(str(a.device), f"cuda:{device_index}") + self.assertEqual(repr(a.device), f"device(type='cuda', index={device_index})") + self.assertEqual(str(a), f"tensor([1, 2, 3], device='cuda:{device_index}')") + self.assertEqual(repr(a), f"tensor([1, 2, 3], device='cuda:{device_index}')") + + +if __name__ == "__main__": + run_tests() diff --git a/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp b/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp index 0cb5e088c..84fc68e6a 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/patchCsrcDevice.cpp @@ -1,24 +1,22 @@ // Copyright (c) 2023, DeepLink. #include -#include -#include +#include #include -#include -#include +#include +#include #include #include #include -#include -#include -#include -#include -#include #include #include -#include +#include +#include +#include +#include +#include #include @@ -72,7 +70,7 @@ PyObject* DIPU_THPDevice_repr(THPDevice* self) { PyObject* DIPU_THPDevice_str(THPDevice* self) { std::ostringstream oss; - oss << _get_dipu_python_type(self->device); + oss << at::Device(_get_dipu_python_type(self->device), self->device.index()); return THPUtils_packString(oss.str().c_str()); }