From 26f8ae573a3f138cb826b593f27636109f9a9dd4 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Tue, 18 Feb 2025 16:11:23 +0800 Subject: [PATCH] =?UTF-8?q?issue/60:=20to=5Ftensor=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E5=8E=9Ftorch=E5=BC=A0=E9=87=8F=EF=BC=8C=E5=A2=9E=E5=8A=A0INFI?= =?UTF-8?q?NI=5FROOT=E9=BB=98=E8=AE=A4=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/infiniop/libinfiniop/liboperators.py | 8 +++++--- test/infiniop/libinfiniop/utils.py | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/infiniop/libinfiniop/liboperators.py b/test/infiniop/libinfiniop/liboperators.py index 0fd1f0f9c..f753d9785 100644 --- a/test/infiniop/libinfiniop/liboperators.py +++ b/test/infiniop/libinfiniop/liboperators.py @@ -5,11 +5,12 @@ from ctypes import c_int, c_int64, c_uint64, Structure, POINTER, c_size_t from .datatypes import * from .devices import * +from pathlib import Path Device = c_int Optype = c_int -INFINI_ROOT = os.environ.get("INFINI_ROOT") +INFINI_ROOT = os.getenv("INFINI_ROOT") or str(Path.home() / ".infini") class TensorDescriptor(Structure): @@ -30,9 +31,10 @@ def invalidate(self): class CTensor: - def __init__(self, desc, data): + def __init__(self, desc, torch_tensor): self.descriptor = desc - self.data = data + self.torch_tensor_ = torch_tensor + self.data = torch_tensor.data_ptr() class Handle(Structure): diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index cc103093d..6fd8487d2 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -19,7 +19,6 @@ def to_tensor(tensor, lib): ndim = tensor.ndimension() shape = (ctypes.c_size_t * ndim)(*tensor.shape) strides = (ctypes.c_int64 * ndim)(*(tensor.stride())) - data_ptr = tensor.data_ptr() # fmt: off dt = ( InfiniDtype.I8 if tensor.dtype == torch.int8 else @@ -46,7 +45,7 @@ def to_tensor(tensor, lib): ctypes.byref(tensor_desc), ndim, shape, strides, dt ) # Create Tensor - return CTensor(tensor_desc, data_ptr) + return CTensor(tensor_desc, tensor) def create_workspace(size, torch_device):