In [1]:
import os

import numpy as np
import onnx
import tvm
from tvm import relay
from tvm.contrib import graph_executor

In [2]:
# Install onnx into this kernel's environment (%pip targets the kernel, unlike !pip).
# NOTE(review): cell [1] already imports onnx, so if this install changes the
# package the running kernel keeps the previously imported module until it is
# restarted (pip prints a restart note). Consider moving this above the imports
# and pinning a version, e.g. `%pip install -q onnx==<version>`.
%pip install onnx

Note: you may need to restart the kernel to use updated packages.


In [13]:
# Inspect the ONNX graph's runtime inputs (name and shape) so the compile step
# can be given a matching input name and a fully static shape.

# Path to the ONNX model. Set the ONNX_MODEL_PATH environment variable to use a
# different location; the author's local path is kept as the default.
onnx_model_path = os.environ.get(
    "ONNX_MODEL_PATH",
    "/home/rakshan/TVM_assignment/onnx_lightweight_resnet_mnist.onnx",
)

# Load ONNX model
model = onnx.load(onnx_model_path)

# Some exporters also list weight initializers under graph.input; skip those so
# only inputs that must be fed at run time are reported.
initializer_names = {initializer.name for initializer in model.graph.initializer}

print("Inputs:")
for input_tensor in model.graph.input:
    if input_tensor.name in initializer_names:
        continue
    name = input_tensor.name
    type_proto = input_tensor.type
    shape = None

    # Extract the shape when the input is a tensor with a declared shape.
    # Dimensions without a concrete dim_value (e.g. a symbolic batch size)
    # are reported as None.
    if type_proto.HasField("tensor_type"):
        tensor_type = type_proto.tensor_type
        if tensor_type.HasField("shape"):
            dims = tensor_type.shape.dim
            shape = [dim.dim_value if dim.HasField("dim_value") else None for dim in dims]

    print(f"  Name: {name}")
    print(f"  Shape: {shape}")


Inputs:
  Name: input
  Shape: [None, 112, 112, 3]


In [17]:

# Compile the ONNX model with TVM Relay, export the compiled library, and run a
# single inference on random data as a smoke test.

# Path to the ONNX model. Set the ONNX_MODEL_PATH environment variable to use a
# different location; the author's local path is kept as the default.
onnx_model_path = os.environ.get(
    "ONNX_MODEL_PATH",
    "/home/rakshan/TVM_assignment/onnx_lightweight_resnet_mnist.onnx",
)
onnx_model = onnx.load(onnx_model_path)

# Compilation target; change to "cuda" for an NVIDIA GPU.
target = "llvm"

# Input name and static shape. The graph declares the batch dimension as unknown
# (None), so it is pinned to 1 here to give relay.build a fully static shape.
input_name = "input"
input_shape = (1, 112, 112, 3)

# Fail early with a readable message if the name does not match the graph,
# rather than surfacing an opaque error from the Relay frontend.
graph_input_names = [graph_input.name for graph_input in onnx_model.graph.input]
if input_name not in graph_input_names:
    raise ValueError(
        f"input {input_name!r} not found in ONNX graph inputs {graph_input_names}"
    )

# Convert ONNX model to Relay IR
shape_dict = {input_name: input_shape}
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

# Optimize and compile the model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# Export the compiled library
lib.export_library("optimized_model.so")

# Instantiate the compiled graph on the device matching the target
dev = tvm.device(target, 0)
module = graph_executor.GraphModule(lib["default"](dev))

# Seeded generator so the smoke-test input, and therefore the output, is
# reproducible across Restart & Run All. Values are uniform in [0, 1).
rng = np.random.default_rng(0)
input_data = rng.random(input_shape, dtype=np.float32)
module.set_input(input_name, input_data)
module.run()

# NDArray.numpy() replaces the deprecated NDArray.asnumpy().
output_data = module.get_output(0).numpy()
print("Output shape:", output_data.shape)

Output shape: (1, 10)


In [18]:
# Persist the compiled graph's JSON description next to the notebook.
with open("model_json", mode="w") as json_file:
    graph_json = lib.get_graph_json()
    json_file.write(graph_json)