diff --git a/test/infiniop/libinfiniop/liboperators.py b/test/infiniop/libinfiniop/liboperators.py index 0fd1f0f9c..f753d9785 100644 --- a/test/infiniop/libinfiniop/liboperators.py +++ b/test/infiniop/libinfiniop/liboperators.py @@ -5,11 +5,12 @@ from ctypes import c_int, c_int64, c_uint64, Structure, POINTER, c_size_t from .datatypes import * from .devices import * +from pathlib import Path Device = c_int Optype = c_int -INFINI_ROOT = os.environ.get("INFINI_ROOT") +INFINI_ROOT = os.getenv("INFINI_ROOT") or str(Path.home() / ".infini") class TensorDescriptor(Structure): @@ -30,9 +31,10 @@ def invalidate(self): class CTensor: - def __init__(self, desc, data): + def __init__(self, desc, torch_tensor): self.descriptor = desc - self.data = data + self.torch_tensor_ = torch_tensor + self.data = torch_tensor.data_ptr() class Handle(Structure): diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index cc103093d..6fd8487d2 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -19,7 +19,6 @@ def to_tensor(tensor, lib): ndim = tensor.ndimension() shape = (ctypes.c_size_t * ndim)(*tensor.shape) strides = (ctypes.c_int64 * ndim)(*(tensor.stride())) - data_ptr = tensor.data_ptr() # fmt: off dt = ( InfiniDtype.I8 if tensor.dtype == torch.int8 else @@ -46,7 +45,7 @@ def to_tensor(tensor, lib): ctypes.byref(tensor_desc), ndim, shape, strides, dt ) # Create Tensor - return CTensor(tensor_desc, data_ptr) + return CTensor(tensor_desc, tensor) def create_workspace(size, torch_device):