In [1]:
import onnx
from tvm import relay
import onnx
from tvm import relay, autotvm
from tvm.contrib import graph_executor
import tvm
import numpy as np

In [2]:
# Step 1: Load the ONNX model
# The path can be overridden with the MNIST_ONNX_PATH environment variable so the
# notebook is not tied to one machine's home directory; the default is the
# original location.
import os

onnx_model_path = os.environ.get(
    "MNIST_ONNX_PATH", "/home/ragul-1819/tvm_assignments/mnist_model.onnx"
)
model = onnx.load(onnx_model_path)

# Print each graph input as `name [dims]`. Fixed dims show their integer
# dim_value; symbolic dims (e.g. a dynamic batch size) show their dim_param name;
# dims with neither set show "?".
print("Inputs:")
for graph_input in model.graph.input:
    dims = [
        dim.dim_value if dim.HasField("dim_value") else (dim.dim_param or "?")
        for dim in graph_input.type.tensor_type.shape.dim
    ]
    print(graph_input.name, dims)

Inputs:
input dim {
  dim_param: "unk__102"
}
dim {
  dim_value: 28
}
dim {
  dim_value: 28
}
dim {
  dim_value: 1
}



In [3]:
# Step 2: Define the input shape and create a Relay module
# Take the input name from the model instead of hard-coding it. Graph inputs that
# are also listed as initializers (weights, in older ONNX exports) are skipped,
# so the first remaining entry is the data input.
initializer_names = {init.name for init in model.graph.initializer}
input_name = next(
    (inp.name for inp in model.graph.input if inp.name not in initializer_names),
    None,
)
if input_name is None:
    raise ValueError("ONNX model has no non-initializer graph inputs")

# NOTE(review): the model declares its input as (N, 28, 28, 1) (see the Step 1
# output), but the shape below is (1, 1, 28, 28). Both hold 784 elements, so this
# may only work if the graph flattens the input first. TODO confirm the intended
# layout; if it is NHWC, use (1, 28, 28, 1) here (Step 5 must build its input with
# the same shape).
shape_dict = {input_name: (1, 1, 28, 28)}

# Convert ONNX model to Relay IR
mod, params = relay.frontend.from_onnx(model, shape_dict)

In [None]:
# Step 3: Compile the Relay module with optimization level 3.
# The target string selects the code generator: "llvm" builds for the host CPU,
# "cuda" would build for an NVIDIA GPU instead.
target = "llvm"
build_ctx = tvm.transform.PassContext(opt_level=3)
with build_ctx:
    lib = relay.build(mod, params=params, target=target)

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


In [9]:
# Step 4: Export the compiled module to a shared-library file on disk.
compiled_lib_path = "tvm_model.so"
lib.export_library(compiled_lib_path)
print("Optimized model saved to {}".format(compiled_lib_path))


Optimized model saved to tvm_model.so


In [6]:
# Step 5: Load and run the model
# Create a runtime executor on the device matching the build target. tvm.device
# accepts the target string, so setting `target = "cuda"` in Step 3 selects GPU 0
# here with no other change (the old tvm.gpu(0) helper is deprecated in favor of
# tvm.cuda(0)).
dev = tvm.device(target, 0)
module = graph_executor.GraphModule(lib["default"](dev))

# Prepare input data: random float32 values with exactly the shape the model was
# compiled for (shape_dict from Step 2), so the two cannot drift apart. A fixed
# seed makes the printed output reproducible across runs.
rng = np.random.default_rng(0)
input_data = rng.random(shape_dict[input_name], dtype=np.float32)
module.set_input(input_name, input_data)

# Run inference
module.run()
output = module.get_output(0).numpy()
print("Inference output:", output)

Inference output: [[3.1590557e-09 3.8065000e-14 4.3464289e-07 2.1407073e-12 5.0909078e-04
  6.9542942e-15 5.4215809e-11 2.4733222e-09 9.9949050e-01 9.5523500e-09]]


In [10]:
# Write the compiled graph's JSON description (from lib.get_graph_json()) to a
# separate file in the working directory.
with open("tvm_model_json", "w") as json_out:
    graph_json = lib.get_graph_json()
    json_out.write(graph_json)