# Export Classification Models via Torchvision

This notebook is meant to run on Google Colab or on your own server. On your own server, we recommend creating a conda virtual environment first, because the notebook depends on the specific framework versions pinned below.


In [None]:
#!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0
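On your own server, you can check that the environment actually has the pinned versions before going further. The sketch below is a minimal check using only the standard library; `version_mismatches` and `PINNED_VERSIONS` are hypothetical names, and the pins simply mirror the install line above.

```python
from importlib import metadata

# Versions pinned by the install cell above.
PINNED_VERSIONS = {"torch": "2.4.0", "torchvision": "0.19.0", "torchaudio": "2.4.0"}


def version_mismatches(pinned=PINNED_VERSIONS):
    """Return {package: (wanted, installed)} for every package that does not match its pin."""
    mismatches = {}
    for package, wanted in pinned.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build tags such as "+cu121" when comparing.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


print(version_mismatches() or "All pinned versions are installed.")
```

An empty result means every pinned package is present at the expected version.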

* Set `model_name` and choose the `torchvision.models.<backbone>` constructor you want to export.

> **Note**: You can find all the available model options in the official [Torchvision](https://pytorch.org/vision/stable/models.html) documentation.

In [None]:
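import torch
import torchvision

model_name = 'resnet18'
# `weights=...` replaces the deprecated `pretrained=True` argument; DEFAULT selects the best available weights.
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

# You can also train or fine-tune the model here with the PyTorch APIs before converting it to a runtime format.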

### Convert the Torch model into ONNX format

Before exporting, check which ONNX opset version your accelerator (and the later ONNX-to-TFLite conversion) supports, and set `opset_version` to match.

In [None]:
!pip install onnx onnxruntime

In [None]:
import onnxruntime

# Switch to inference mode (dropout off, BatchNorm uses running statistics), then export with a
# dummy input of the expected shape (N, C, H, W). Set opset_version to one your accelerator supports.
model.eval()
torch.onnx.export(model, torch.randn(1, 3, 224, 224), f"{model_name}.onnx", opset_version=16)

session = onnxruntime.InferenceSession(f"{model_name}.onnx", providers=["CPUExecutionProvider"])
print(f"Input shape: {session.get_inputs()[0].shape}")

### Convert the ONNX model into TFLite format

Before converting to TFLite, check which TFLite ops (and op versions) your accelerator supports.
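A minimal sketch, assuming the accelerator runs only TFLite builtin ops: pin `converter.target_spec.supported_ops` to `tf.lite.OpsSet.TFLITE_BUILTINS` (this is also the converter's default, so the line mainly documents the constraint), so that an op without a builtin kernel makes `convert()` raise instead of being exported, then execute the model once with `tf.lite.Interpreter` to confirm it runs. `convert_builtins_only` and `run_tflite` are hypothetical helper names.

```python
import numpy as np
import tensorflow as tf


def convert_builtins_only(saved_model_dir: str) -> bytes:
    """Convert a SavedModel to TFLite, allowing only TFLite builtin ops."""
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    # Builtins only (also the default): an op with no builtin kernel makes convert() raise
    # instead of producing a model that needs the TensorFlow (Flex) runtime.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    return converter.convert()


def run_tflite(tflite_model: bytes, x: np.ndarray) -> np.ndarray:
    """Run one inference with the TFLite interpreter on the model's first input and output."""
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    interpreter.set_tensor(input_detail["index"], x.astype(input_detail["dtype"]))
    interpreter.invoke()
    return interpreter.get_tensor(output_detail["index"])
```

Once the SavedModel in `./onnx2tf` exists, you could call `run_tflite(convert_builtins_only('./onnx2tf'), x)` with an `x` shaped like the model input.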


In [None]:
# Recent onnx-tf releases work with TensorFlow 2.x. Colab no longer ships TensorFlow 1.x
# (`%tensorflow_version 1.x` raises "Tensorflow 1 is unsupported in Colab"), so use the preinstalled TF 2.
# If `from onnx_tf.backend import prepare` later fails on a missing `tensorflow_probability` module,
# also install a tensorflow-probability release that matches your TensorFlow version.
!pip install onnx-tf
import tensorflow as tf

In [None]:
import onnx
from onnx_tf.backend import prepare

# Convert the ONNX graph to a TensorFlow SavedModel.
tf_rep = prepare(onnx.load(f'{model_name}.onnx'))
tf_rep.export_graph('./onnx2tf')

# Convert the SavedModel to TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model('./onnx2tf')

# Optional settings if conversion fails on unsupported ops:
# converter.allow_custom_ops = True
# converter.experimental_new_converter = True  # already the default in recent TensorFlow 2.x
# Fall back to TensorFlow (Flex) ops; the target runtime must then support Select TF ops.
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

tflite_model = converter.convert()

with open(f'{model_name}.tflite', 'wb') as f:
    f.write(tflite_model)