# Export Torchvision Classification Models via PyTorch

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a href="https://drive.google.com/file/d/129iCu2nUBs-EwaaTsVT3HMSvukFpXkh8/view?usp=sharing">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
</table>
<br>

**Import models from the Torchvision framework (PyTorch):**

Load pre-trained ResNet-family models (ResNet, ResNeXt and Wide ResNet) from Torchvision, Meta's open-source computer vision library for PyTorch. Select one of the currently validated image classification models with `model_name`.

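`model_name` is a free-form string, so it can help to check it against the validated list before passing it to a model builder. The following is a minimal sketch. It assumes the tuple below mirrors the `@param` options in the model-loading cell, and `VALIDATED_MODELS` and `resolve_model_name` are hypothetical names, not part of Torchvision.

```python
# Hypothetical allowlist mirroring the @param options used in this notebook.
VALIDATED_MODELS = (
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d", "resnext101_64x4d",
    "wide_resnet50_2", "wide_resnet101_2",
)


def resolve_model_name(name: str) -> str:
    """Normalize a model name and reject anything outside the validated list."""
    normalized = name.strip().lower()
    if normalized not in VALIDATED_MODELS:
        raise ValueError(
            f"{name!r} is not a validated model; choose one of: {', '.join(VALIDATED_MODELS)}"
        )
    return normalized


print(resolve_model_name(" ResNet34 "))  # resnet34
```

You can then pass the validated name to `torchvision.models.get_model(name, weights=...)`, which looks up the builder by name.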
In [None]:
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0
!pip install torchsummary



In [None]:
import torch, torchvision
from torchsummary import summary

model_name = "resnet34"   # @param ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d", "resnext101_64x4d", "wide_resnet50_2", "wide_resnet101_2"]
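# Look up the builder by name; IMAGENET1K_V1 are the weights the legacy `pretrained=True` flag loaded
model = torchvision.models.get_model(model_name, weights="IMAGENET1K_V1").eval()

summary(model, (3, 224, 224), device="cpu")  # the model stays on the CPU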




----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,
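You can check the first rows of the summary by hand. A convolution without bias has `in_channels * out_channels * k * k` weights. A BatchNorm2d layer learns one scale and one shift per channel. The spatial size follows `floor((H + 2p - k) / s) + 1`. The sketch below assumes the standard ResNet stem hyperparameters: a 7×7 convolution with stride 2 and padding 3, 3×3 max pooling with stride 2 and padding 1, and 3×3 convolutions with stride 1 and padding 1 inside a BasicBlock. The helper names are illustrative.

```python
def conv2d_params(in_ch, out_ch, k, bias=False):
    """Learnable parameters of a square-kernel Conv2d with groups=1."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def batchnorm2d_params(ch):
    """BatchNorm2d learns one weight (gamma) and one bias (beta) per channel."""
    return 2 * ch


def out_size(h, k, s, p):
    """Spatial output size of a convolution or pooling layer (dilation=1)."""
    return (h + 2 * p - k) // s + 1


# Stem: Conv2d-1 (7x7, stride 2, pad 3), then MaxPool2d-4 (3x3, stride 2, pad 1).
h = out_size(224, k=7, s=2, p=3)      # 112
h_pool = out_size(h, k=3, s=2, p=1)   # 56
print(h, h_pool)
print(conv2d_params(3, 64, 7))    # 9408  -> Conv2d-1
print(batchnorm2d_params(64))     # 128   -> BatchNorm2d-2
print(conv2d_params(64, 64, 3))   # 36864 -> Conv2d-5
```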

# Delegation

## Export to ONNX

**Step 1.** Export the model to ONNX format.

In [None]:
!pip install onnx

dummy_input = torch.randn(1, 3, 224, 224)  # fixed input shape: batch 1, NCHW
# When delegating to Hailo, also pass opset_version=11 (required there)
torch.onnx.export(model, dummy_input, f"{model_name}.onnx", verbose=False, do_constant_folding=False, dynamic_axes=None)



**Step 2.** Verify that the ONNX model reproduces the PyTorch outputs.

In [None]:
!pip install onnxruntime
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession(f'{model_name}.onnx', providers=["CPUExecutionProvider"])

with torch.no_grad():  # reference output from the PyTorch model
    OUTPUT_TORCHSCRIPT = model(dummy_input).numpy()
OUTPUT_ONNX = ort_session.run([ort_session.get_outputs()[0].name], {ort_session.get_inputs()[0].name: dummy_input.numpy()})[0]

print('Model Consistency Check Passed (ONNX):', np.allclose(OUTPUT_TORCHSCRIPT, OUTPUT_ONNX, rtol=1e-3, atol=1e-6))

Model Consistency Check Passed (ONNX): True

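`np.allclose(a, b, rtol, atol)` passes when every element satisfies `|a - b| <= atol + rtol * |b|`. A bare boolean hides how close the outputs actually are. The sketch below is a small NumPy-only helper that also reports the largest deviation. `compare_outputs` is a hypothetical helper, and the test data is synthetic.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-6):
    """Compare two model outputs and report the worst deviation.

    Element-wise rule: |candidate - reference| <= atol + rtol * |reference|.
    """
    # Flatten both so that (1, 1000) and (1000,) outputs compare element by element.
    reference = np.asarray(reference, dtype=np.float64).reshape(-1)
    candidate = np.asarray(candidate, dtype=np.float64).reshape(-1)
    if reference.shape != candidate.shape:
        raise ValueError(f"size mismatch: {reference.shape} vs {candidate.shape}")
    abs_err = np.abs(candidate - reference)
    tolerance = atol + rtol * np.abs(reference)
    return {
        "passed": bool(np.all(abs_err <= tolerance)),
        "max_abs_err": float(abs_err.max()),
        # Index of the element with the smallest remaining margin (or largest violation).
        "worst_index": int(np.argmax(abs_err - tolerance)),
    }
```

In the notebook this could be called as `compare_outputs(OUTPUT_TORCHSCRIPT, OUTPUT_ONNX)`. It scales the tolerance by its first (reference) argument, whereas `np.allclose` uses its second. Borderline elements can therefore be judged slightly differently from the cell above.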

## Export to TFLite

**Step 1.** Export the model to TFLite format.

In [None]:
!pip install keras tf_keras
!pip install "sng4onnx>=1.0.1"
!pip install "onnx_graphsurgeon>=0.3.26"
!pip install "onnx2tf>1.17.5,<=1.22.3"
!pip install "onnxslim>=0.1.31"

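The onnx2tf pin above (`>1.17.5,<=1.22.3`) only helps if pip actually installed a matching version, so it is worth confirming before the conversion. The following is a minimal standard-library sketch. It assumes plain numeric release strings such as `1.22.3` and does not handle pre-release or local tags (the `packaging` library's `SpecifierSet` is the robust option). `version_tuple` and `in_pinned_range` are hypothetical helpers.

```python
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v: str) -> tuple:
    """Turn '1.22.3' into (1, 22, 3); trailing zeros are dropped so '1.22.0' == '1.22'."""
    parts = [int(part) for part in v.split(".")]
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def in_pinned_range(v: str, lower_exclusive: str, upper_inclusive: str) -> bool:
    """True if lower_exclusive < v <= upper_inclusive, compared numerically."""
    return version_tuple(lower_exclusive) < version_tuple(v) <= version_tuple(upper_inclusive)


try:
    installed = version("onnx2tf")
except PackageNotFoundError:
    installed = None

if installed is None:
    print("onnx2tf is not installed")
else:
    print(f"onnx2tf {installed} satisfies >1.17.5,<=1.22.3:",
          in_pinned_range(installed, "1.17.5", "1.22.3"))
```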

In [None]:
import onnx2tf, os
import tensorflow as tf

os.makedirs(model_name, exist_ok=True)  # does not fail when the cell is re-run
onnx2tf.convert(
    input_onnx_file_path=f"{model_name}.onnx", output_folder_path=model_name, not_use_onnxsim=True, non_verbose=False, verbosity=1,
    copy_onnx_input_output_names_to_tflite=True, output_integer_quantized_tflite=False,
    quant_type="per-tensor",  # for integer quantization: "per-tensor" (faster) or "per-channel" (slower but more accurate)
)


[32mAutomatic generation of each OP name complete![0m


[32msaved_model output complete![0m
wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz && tar -zxvf flatc.tar.gz && sudo chmod +x flatc && sudo mv flatc /usr/bin/
[32mFloat32 tflite output complete![0m
wget https://github.com/PINTO0309/onnx2tf/releases/download/1.16.31/flatc.tar.gz && tar -zxvf flatc.tar.gz && sudo chmod +x flatc && sudo mv flatc /usr/bin/
[32mFloat16 tflite output complete![0m


<tf_keras.src.engine.functional.Functional at 0x7b72142d0460>
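`quant_type` presumably affects only the integer-quantized outputs, which are disabled here (`output_integer_quantized_tflite=False`). The accuracy trade-off in its comment can be illustrated with symmetric int8 weight quantization:

- **Per-tensor** quantization uses one scale for the whole weight tensor.
- **Per-channel** quantization uses one scale per output channel, so channels with small weights are not forced onto a coarse grid.

This is a simplified NumPy sketch, not onnx2tf's actual quantization code. `quantize_symmetric` is a hypothetical helper, and the weights are synthetic.

```python
import numpy as np


def quantize_symmetric(w, axis=None, num_bits=8):
    """Symmetric fake-quantization: round w onto a signed integer grid and back.

    axis=None -> one scale for the whole tensor (per-tensor).
    axis=0    -> one scale per output channel (per-channel, OIHW layout).
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    if axis is None:
        max_abs = np.max(np.abs(w))
    else:
        reduce_axes = tuple(i for i in range(w.ndim) if i != axis)
        max_abs = np.max(np.abs(w), axis=reduce_axes, keepdims=True)
    scale = np.where(max_abs > 0, max_abs / qmax, 1.0)  # avoid a zero scale
    q = np.clip(np.round(w / scale), -qmax, qmax)
    return q * scale


rng = np.random.default_rng(0)
# Synthetic conv weights (out=8, in=4, 3x3) whose channel magnitudes span 1000x.
channel_scales = np.logspace(-3, 0, 8).reshape(8, 1, 1, 1)
w = rng.standard_normal((8, 4, 3, 3)) * channel_scales

err_tensor = np.mean((w - quantize_symmetric(w)) ** 2)
err_channel = np.mean((w - quantize_symmetric(w, axis=0)) ** 2)
print(f"per-tensor MSE:  {err_tensor:.3e}")
print(f"per-channel MSE: {err_channel:.3e}")
```

Per-channel scales never exceed the single per-tensor scale, so small-magnitude channels keep more resolution. The extra scales are the bookkeeping cost behind the "slower" label.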

**Step 2.** Verify that the TFLite model reproduces the PyTorch outputs.

In [None]:
interpreter = tf.lite.Interpreter(model_path=f'{model_name}/{model_name}_float32.tflite')
interpreter.allocate_tensors()

# onnx2tf converts the model from NCHW to NHWC, so the input is permuted accordingly
input_nhwc = np.ascontiguousarray(dummy_input.permute(0, 2, 3, 1).numpy())
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], input_nhwc)
interpreter.invoke()
OUTPUT_TFLITE = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])

# Compare the TFLite output (not the ONNX output) against the PyTorch reference
print('Model Consistency Check Passed (TFLite):', np.allclose(OUTPUT_TORCHSCRIPT, OUTPUT_TFLITE, rtol=1e-3, atol=1e-6))

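onnx2tf converts the NCHW layout of the PyTorch/ONNX model to the NHWC layout used by TFLite, which is why the input is permuted with `permute(0, 2, 3, 1)` before `set_tensor`. NumPy can confirm the mapping: element `[n, c, h, w]` of the NCHW array lands at `[n, h, w, c]`. This is a sketch with a small synthetic array standing in for `dummy_input.numpy()`.

```python
import numpy as np

# Synthetic NCHW batch standing in for dummy_input.numpy(): (N, C, H, W) = (1, 3, 4, 5).
x_nchw = np.arange(1 * 3 * 4 * 5, dtype=np.float32).reshape(1, 3, 4, 5)

# Same axis order as torch's permute(0, 2, 3, 1): N, H, W, C.
x_nhwc = np.ascontiguousarray(np.transpose(x_nchw, (0, 2, 3, 1)))

print(x_nchw.shape, "->", x_nhwc.shape)  # (1, 3, 4, 5) -> (1, 4, 5, 3)

# Every element keeps its value; only the axis order changes.
n, c, h, w = 0, 2, 1, 3
assert x_nhwc[n, h, w, c] == x_nchw[n, c, h, w]

# Transposing back with (0, 3, 1, 2) restores the original layout.
assert np.array_equal(np.transpose(x_nhwc, (0, 3, 1, 2)), x_nchw)
```

`np.ascontiguousarray` turns the transposed view into a row-major copy, matching the dense layout of the interpreter's input tensor.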