In [2]:
import torch
from models.csrnet_mbv3 import MobileCSRNet
import os

# --- Configuration ----------------------------------------------------------
# Checkpoint to load and the artifacts this cell writes (all under build/).
CHECKPOINT_PATH = os.path.join("build", "csrnet_mobile_B.pt")
TORCHSCRIPT_PATH = os.path.join("build", "csrnet_mobile_B_scripted.pt")
ONNX_PATH = os.path.join("build", "csrnet_mobile_B.onnx")
INPUT_SIZE = 512   # H and W of the dummy input used for the sanity check, tracing and export
ONNX_OPSET = 11
torch.manual_seed(0)  # make the random dummy input reproducible across re-runs

# --- Load the trained model -------------------------------------------------
model = MobileCSRNet()
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code. Only load trusted files (or pass weights_only=True on torch>=1.13).
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # put layers such as dropout / batch-norm (if present) into inference mode

# --- Sanity-check a forward pass --------------------------------------------
dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE)
with torch.no_grad():  # inference only: don't build an autograd graph
    output = model(dummy_input)
print("🔍 Output shape:", output.shape)
# The ONNX export below marks dim 0 of the output as the dynamic batch axis.
assert output.shape[0] == dummy_input.shape[0], "output dim 0 must be the batch dimension"

# --- Optional: TorchScript --------------------------------------------------
# This is a *traced* module (despite the "_scripted" file name): only the ops
# executed for `dummy_input` are recorded, so data-dependent control flow is frozen.
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save(TORCHSCRIPT_PATH)

# --- ONNX export ------------------------------------------------------------
# Only the batch dimension is dynamic; height/width are fixed to INPUT_SIZE in the graph.
torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=ONNX_OPSET,
    export_params=True,
    do_constant_folding=True
)


🔍 Output shape: torch.Size([1, 1, 64, 64])


In [3]:
import onnx

# Validate the exported ONNX graph. A distinct name (`onnx_model`) is used so the
# PyTorch `model` from the export cell is not overwritten by an onnx ModelProto,
# which would break re-running the export code afterwards.
onnx_model_path = "build/csrnet_mobile_B.onnx"
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)  # raises if the model fails ONNX validation

print("✅ ONNX model is valid!")
print("Opset version:", [op.version for op in onnx_model.opset_import])
print("Inputs:", [i.name for i in onnx_model.graph.input])
print("Outputs:", [o.name for o in onnx_model.graph.output])

✅ ONNX model is valid!
Opset version: [11]
Inputs: ['input']
Outputs: ['output']


In [16]:
# Next step: convert ONNX -> TensorFlow (.pb) -> TFLite using this Colab notebook: https://colab.research.google.com/drive/1_fbN7caPwhKvxNDkljw-dHt_xtel7Eob#scrollTo=ld0uLw7QRBir