# Export Classification Models via Torchvision

> This notebook needs to be run on **Google Colab** or a **Custom Server**. If you are using a custom server, we recommend first setting up a virtual environment with Python's `venv` module or with conda, because the notebook depends on the specific framework versions pinned in the install cells below.

In [1]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0

Collecting torch==2.4.0
  Downloading torch-2.4.0-cp38-cp38-manylinux1_x86_64.whl (797.2 MB)
[K     |████████████████████████████████| 797.2 MB 29 kB/s  eta 0:00:011   |██▏                             | 54.7 MB 3.4 MB/s eta 0:03:37     |█████▍                          | 134.4 MB 8.0 MB/s eta 0:01:23     |███████▋                        | 189.3 MB 5.7 MB/s eta 0:01:48     |███████████████▋                | 388.1 MB 10.1 MB/s eta 0:00:41     |█████████████████▏              | 427.6 MB 4.0 MB/s eta 0:01:33     |████████████████████▎           | 503.9 MB 524 kB/s eta 0:09:20MB 4.0 MB/s eta 0:00:571 MB 5.4 MB/s eta 0:00:08
[?25hCollecting torchvision==0.19.0
  Downloading torchvision-0.19.0-cp38-cp38-manylinux1_x86_64.whl (7.0 MB)
[K     |████████████████████████████████| 7.0 MB 1.5 MB/s eta 0:00:01
[?25hCollecting torchaudio==2.4.0
  Downloading torchaudio-2.4.0-cp38-cp38-manylinux1_x86_64.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 460 kB/s eta 0:00:01
[?25hCollect

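* Set `model_name` to the Torchvision backbone you want to export (e.g. `resnet18`), and make sure the model constructor in the next cell builds that same `torchvision.models.<backbone>`.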

> **Note**: You can find all available models in the official [Torchvision](https://pytorch.org/vision/stable/models.html) documentation.

In [2]:
import torch
import torchvision

# Name of the Torchvision backbone to export; it is also the base name of every exported file below.
model_name = 'resnet18'

# `get_model` builds the backbone named by `model_name`, so the name and the architecture cannot drift apart.
# weights="DEFAULT" selects Torchvision's default ImageNet weights (for resnet18: IMAGENET1K_V1, the same checkpoint
# the legacy `pretrained=True` flag loaded). Passing `pretrained` (positionally or by name) is deprecated since 0.13.
model = torchvision.models.get_model(model_name, weights="DEFAULT").cpu()

# You can also train/fine-tune the model with the PyTorch API here before converting it to the runtime format,
# but switch to eval() mode before exporting (BatchNorm uses running statistics, dropout is disabled).
model.eval();

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/r300/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_11492/3926246098.py", line 1, in <module>
    import torch, torchvision
ModuleNotFoundError: No module named 'torch'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/r300/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2105, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/home/r300/.local/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1396, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/home/r300/.local/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1287, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/home/r300/.local/lib/python3.8/site-packages/IPython/core/u

### Convert the Torch model into ONNX format

Before exporting, check which ONNX opset version your target accelerator supports, and set `opset_version` in the export call below accordingly.

In [None]:
!pip install onnxruntime
!pip install onnx

In [None]:
import onnxruntime

# Dummy input used to trace the model for export and, in the following cells, to compare runtime outputs.
# Set the shape to what your deployment expects, in PyTorch's NCHW order: (batch, channels, height, width).
# torch.randn creates tensors with requires_grad=False by default; eval() mode was already set on the model above.
torch.manual_seed(0)  # fixed seed so the accuracy comparisons below are reproducible across runs
input_shape = (1, 3, 320, 224)
inputs = torch.randn(*input_shape).cpu()

onnx_path = f"{model_name}.onnx"
# opset_version: set to an opset your accelerator supports (see the note above).
torch.onnx.export(model, inputs, onnx_path, verbose=False, opset_version=16, do_constant_folding=False, dynamic_axes=None)

# Reload the exported file to confirm ONNX Runtime can load it, and report its input shape.
# providers is passed explicitly because some ONNX Runtime builds (e.g. GPU builds) require it.
ort_check = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
print(f"Input shape: {ort_check.get_inputs()[0].shape}")

In [None]:
# Verify that ONNX Runtime reproduces the PyTorch output for the same dummy input.
import numpy as np

ort_session = onnxruntime.InferenceSession(f'{model_name}.onnx', providers=["CPUExecutionProvider"])
ort_input_name = ort_session.get_inputs()[0].name
onnx_out = ort_session.run(None, {ort_input_name: inputs.numpy()})[0]

with torch.no_grad():  # inference only; no autograd graph needed
    torch_out = model(inputs).numpy()

# Exact element-wise equality (even after casting both sides to float16) is too strict for float32 results from two
# different runtimes: values straddling a rounding boundary compare unequal. Compare with tolerances instead;
# rtol follows the PyTorch ONNX/ONNX Runtime tutorial and both values are a starting point you may tune.
print(f"Max abs difference (PyTorch vs ONNX Runtime): {np.abs(torch_out - onnx_out).max():.3e}")
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-4)

### Convert the ONNX model into TFLite format

Before converting to TFLite, check which TFLite operators (and operator versions) your target accelerator supports.


In [None]:
!pip install keras
!pip install tf_keras
!pip install "sng4onnx>=1.0.1
!pip install "onnx_graphsurgeon>=0.3.26"
!pip install "onnx2tf>1.17.5,<=1.22.3",
!pip install "onnxslim>=0.1.31",

In [None]:
import os
import shutil

import onnx2tf
import tensorflow as tf

# Start from an empty output folder so artifacts from a previous run cannot be picked up.
# Pure-Python removal avoids interpolating `model_name` into a shell command (`!rm -rf {model_name}`).
shutil.rmtree(model_name, ignore_errors=True)
os.makedirs(model_name)

onnx2tf.convert(
    input_onnx_file_path=f"{model_name}.onnx",
    output_folder_path=model_name,
    not_use_onnxsim=True,
    non_verbose=False,
    verbosity="info",  # onnx2tf expects a level name: "debug", "info", "warn" or "error"
    copy_onnx_input_output_names_to_tflite=True,
    output_integer_quantized_tflite=False,
    # Relevant when integer quantization is enabled: "per-tensor" (faster) or "per-channel" (slower but more accurate)
    quant_type="per-tensor",
)

tflite_path = f"{model_name}/{model_name}_float32.tflite"
print(f"Input shape: {tf.lite.Interpreter(model_path=tflite_path).get_input_details()[0]['shape']}")

In [None]:
# Run the float32 TFLite model on the same dummy input and compare it with the PyTorch output.
interpreter = tf.lite.Interpreter(model_path=f'{model_name}/{model_name}_float32.tflite')
interpreter.allocate_tensors()

# The converted model takes channels-last input, so permute the NCHW PyTorch tensor to NHWC before feeding it.
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], inputs.permute(0, 2, 3, 1).numpy())
interpreter.invoke()
tflite_out = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])[0]

with torch.no_grad():  # inference only; no autograd graph needed
    torch_out = model(inputs).numpy()[0]

# Tolerance-based comparison instead of exact equality after a float16 cast, which flips on values that straddle
# a rounding boundary. The tolerances are a starting point you may tune for your model.
print(f"Max abs difference (PyTorch vs TFLite): {np.abs(torch_out - tflite_out).max():.3e}")
np.testing.assert_allclose(torch_out, tflite_out, rtol=1e-3, atol=1e-4)