# Export Classification Models via Torchvision

> This script needs to be run on **Google Colab** or a **Custom Server**. If you are using a custom server, we also recommend that you set up a virtual environment via conda before running this script, as it requires the specified version of the framework to run properly.

In [1]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0



* Config the `model_name` and `torchvision.models.<backbone>` you want.

> **Note**: You can find all the available model options in the official  [Torchvision](https://pytorch.org/vision/stable/models.html) documentation.

In [2]:
import torch, torchvision

model_name = 'resnet18'
model = torchvision.models.resnet18(True).cpu()
model.eval()

# You can also use the PyTorch API here to train the model and then convert it to runtime format, but remember to switch to eval() mode before doing so.



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### Convert the Torch model into ONNX format

Before converting to TFLite, be sure to check which OPS version of ONNX is supported by the accelerator.

In [3]:
!pip install onnxruntime
!pip install onnx



In [4]:
import onnxruntime

# Please set the input dimensions of the model and turn off the gradient mode to adjust it to run mode.
inputs = torch.randn(1, 3, 320, 224, requires_grad=False).cpu()

torch.onnx.export(model, inputs, f"{model_name}.onnx", verbose=False, opset_version=16, do_constant_folding=False, dynamic_axes=None)
print(f"Input shape: {onnxruntime.InferenceSession(f'{model_name}.onnx').get_inputs()[0].shape}")

Input shape: [1, 3, 320, 224]


In [5]:
# Verify the accuracy of the model output
import numpy as np
ort_session = onnxruntime.InferenceSession(f'{model_name}.onnx')
np.unique(model(inputs).detach().numpy().astype(np.float16)==ort_session.run(None, {ort_session.get_inputs()[0].name: inputs.numpy()})[0].astype(np.float16), return_counts=True)

(array([False,  True]), array([  2, 998]))

### Convert the ONNX model into TFLite format

Before converting to TFLite, check which OPS version of TFLite is supported by the accelerator.


In [6]:
!pip install keras
!pip install tf_keras
!pip install "sng4onnx>=1.0.1
!pip install "onnx_graphsurgeon>=0.3.26"
!pip install "onnx2tf>1.17.5,<=1.22.3",
!pip install "onnxslim>=0.1.31",

/bin/bash: -c: line 1: unexpected EOF while looking for matching `"'
/bin/bash: -c: line 2: syntax error: unexpected end of file


In [7]:
import os, onnx2tf
import tensorflow as tf

!rm -rf {model_name}
os.mkdir(model_name)

onnx2tf.convert(
    input_onnx_file_path=f"{model_name}.onnx",
    output_folder_path=model_name,
    not_use_onnxsim=True, non_verbose=False, verbosity=1,
    copy_onnx_input_output_names_to_tflite=True,
    output_integer_quantized_tflite=False,
    quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
)

print(f"Input shape: {tf.lite.Interpreter(model_path=f'{model_name}/{model_name}_float32.tflite').get_input_details()[0]['shape']}")


[32mAutomatic generation of each OP name complete![0m


[32msaved_model output complete![0m
wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz && tar -zxvf flatc.tar.gz && sudo chmod +x flatc && sudo mv flatc /usr/bin/
[32mFloat32 tflite output complete![0m
wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz && tar -zxvf flatc.tar.gz && sudo chmod +x flatc && sudo mv flatc /usr/bin/
[32mFloat16 tflite output complete![0m
Input shape: [  1 320 224   3]


In [13]:
interpreter = tf.lite.Interpreter(model_path=f'{model_name}/{model_name}_float32.tflite')
interpreter.allocate_tensors()

interpreter.set_tensor(interpreter.get_input_details()[0]['index'], inputs.permute(0, 2, 3, 1).numpy())
interpreter.invoke()
print()

np.unique(model(inputs).detach().numpy()[0].astype(np.float16)==interpreter.get_tensor(interpreter.get_output_details()[0]['index'])[0].astype(np.float16), return_counts=True)




(array([False,  True]), array([999,   1]))