-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Open
Labels
needs-triage — PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug
Description
Actual behavior
Mismatched type on argument #0 when calling: vm.builtin.reshape(0: ffi.Tensor, 1: ffi.Shape) -> ffi.Tensor. Expected ffi.Tensor but got DLTensor*
Environment
Ubuntu 22.04 LTS
TVM version: 0.23.0 built from source code
Steps to reproduce
Run the code pasted below.
Hi, I am new to TVM and followed the "Cross Compilation and RPC" tutorial: I copied the code from the page, set up an RPC server in one terminal, and ran the code in another terminal. Please help out — thank you.
The code is below
import tvm
from tvm import relax
from tvm.contrib import utils
from tvm.relax.frontend.torch import from_exported_program
try:
import torch
from torch.export import export
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
local_demo = True
def run_pytorch_model_via_rpc():
    """Export a small PyTorch MLP, compile it with TVM Relax, and run it
    through an RPC session (local loopback when ``local_demo`` is True).

    Pipeline: torch.export -> Relax conversion -> optimize/compile ->
    export shared library -> upload over RPC -> inference -> benchmark.
    """
    if not HAS_TORCH:
        print("PyTorch is not installed. Skipping the test.")
        return

    # Step 1: define the model and capture it with torch.export.
    class TorchMLP(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Flatten(),
                torch.nn.Linear(28 * 28, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 10),
            )

        def forward(self, data: torch.Tensor) -> torch.Tensor:
            return self.net(data)

    model = TorchMLP().eval()
    sample_inputs = (torch.randn(1, 1, 28, 28, dtype=torch.float32),)
    with torch.no_grad():
        program = export(model, sample_inputs)

    # Step 2: translate to Relax; keep_params_as_input lets us detach the
    # weights and ship them separately from the compiled library.
    mod = from_exported_program(program, keep_params_as_input=True)
    mod, params = relax.frontend.detach_params(mod)
    print("Converted torch model to relax:")
    print(f" - Number of parameters: {len(params['main'])}")

    # Step 3: choose the compilation target (local llvm vs. aarch64 cross).
    if local_demo:
        target = tvm.target.Target("llvm")
        print("Using local target for demonstration.")
    else:
        target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu")
        print(f"Using remote target: {target}")

    # Optimize with the default Relax pipeline, then compile to an
    # executable and export it as a shared library.
    pipeline = relax.get_pipeline()
    with target:
        optimized = pipeline(mod)
    exe = tvm.compile(optimized, target=target)

    workdir = utils.tempdir()
    so_path = workdir.relpath("model_deployed.so")
    exe.export_library(so_path)
    print(f"Exported library to {so_path}")

    # Persist the detached parameters next to the library.
    import numpy as np

    npz_path = workdir.relpath("model_params.npz")
    np.savez(npz_path, **{f"p_{i}": p.numpy() for i, p in enumerate(params["main"])})
    print(f"Saved parameters to {npz_path}")

    # Step 4: connect to the RPC server and stage both artifacts there.
    host = "127.0.0.1"
    port = 9090
    remote = tvm.rpc.connect(host, port)
    print(f"Connected to remote device at {host}:{port}")
    remote.upload(so_path)
    remote.upload(npz_path)
    print("Uploaded library and parameters to remote device.")

    lib = remote.load_module("model_deployed.so")
    dev = remote.cpu()

    # Build the Relax VM against the remote module and materialize the
    # parameters on the remote device.
    vm = relax.VirtualMachine(lib, dev)
    loaded = np.load(npz_path)
    dev_params = [
        tvm.runtime.tensor(loaded[f"p_{i}"], device=dev) for i in range(len(loaded))
    ]

    # Step 5: run one inference with a random input.
    x = np.random.randn(1, 1, 28, 28).astype("float32")
    x_dev = tvm.runtime.tensor(x, device=dev)
    output = vm["main"](x_dev, *dev_params)
    # Relax functions may return a tuple (tvm.ir.Array); unwrap the first item.
    out = output[0] if isinstance(output, tvm.ir.Array) and len(output) > 0 else output
    out_np = out.numpy()
    print("Inference result on remote device:")
    print(" Output shape:", out_np.shape)
    print(f" Predicted class: {np.argmax(out_np)}")

    # Step 6: measure latency via the VM's built-in time evaluator.
    bench = vm.time_evaluator("main", dev, number=10, repeat=3)
    stats = bench(x_dev, *dev_params)
    print(f"Average inference time: {stats.mean * 1000:.2f} ms")
# Entry point: only exercise the demo when torch is importable and we are
# running in local-loopback mode; surface any failure as a printed message
# rather than a traceback.
if HAS_TORCH and local_demo:
    try:
        run_pytorch_model_via_rpc()
    except Exception as exc:
        print(f"Error during RPC demo: {exc}")
my output is
python deploy_remote.py
/data/miniconda3/envs/tvm_base/lib/python3.11/copyreg.py:105: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
Converted torch model to relax:
- Number of parameters: 4
Using local target for demonstration.
Exported library to /tmp/tmpbrmko0n3/model_deployed.so
Saved parameters to /tmp/tmpbrmko0n3/model_params.npz
Connected to remote device at 127.0.0.1:9090
Uploaded library and parameters to remote device.
Error during RPC demo: Error caught from RPC call:
Mismatched type on argument #0 when calling: `vm.builtin.reshape(0: ffi.Tensor, 1: ffi.Shape) -> ffi.Tensor`. Expected `ffi.Tensor` but got `DLTensor*`
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
needs-triage — PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug