In [7]:
import os, sys
sys.path.append('../')
import torch
import torch.nn as nn
from converter import Torch2TFLiteConverter

In [8]:
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1, groups=1):
    """Build a Conv2d -> BatchNorm2d -> SiLU stage.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        kernel_size: square convolution kernel size. Padding is ``kernel_size // 2``,
            so odd kernel sizes preserve the spatial size when ``stride == 1``
            (even kernel sizes grow the output by one pixel per side pair).
        stride: convolution stride.
        groups: number of convolution groups; ``inp`` and ``oup`` must be divisible by it.

    Returns:
        nn.Sequential of (Conv2d without bias, BatchNorm2d, SiLU).
    """
    # A fixed padding of 1 is only "same" padding for 3x3 kernels; derive it from the
    # kernel size so e.g. 1x1 or 5x5 stages keep the feature-map size too.
    padding = kernel_size // 2
    return nn.Sequential(
        # No conv bias: the following BatchNorm applies its own learned shift.
        nn.Conv2d(inp, oup, kernel_size, stride, padding=padding, groups=groups, bias=False),
        nn.BatchNorm2d(oup),
        # In the converted TFLite graph this activation appears via a "Logistic" op,
        # which is TFLite's name for sigmoid (see Logistic.h in the TFLite sources).
        nn.SiLU(),
    )

class Model(nn.Module):
    """Toy network for the conversion demo: one 3x3 conv-BN-SiLU stage mapping 3 -> 16 channels."""

    def __init__(self):
        super().__init__()
        # The attribute must stay named `conv` so state_dict keys remain `conv.*`.
        self.conv = conv_nxn_bn(inp=3, oup=16)

    def forward(self, x):
        """Run the single conv-BN-SiLU stage on ``x`` and return its output."""
        features = self.conv(x)
        return features

In [9]:
# Reproducible weights, sanity input and calibration data.
torch.manual_seed(0)

INPUT_SIZE = 32                              # model input is INPUT_SIZE x INPUT_SIZE, 3 channels
TARGET_SHAPE = (INPUT_SIZE, INPUT_SIZE, 3)   # HWC shape handed to the converter
MODEL_PATH = './model.pth'
N_CALIBRATION = 32                           # number of samples in the int8 representative dataset

model = Model()
# Sanity forward pass in eval mode: in train mode BatchNorm would update its running
# statistics from this random batch, and those altered stats would then be saved.
model.eval()
inputs = torch.rand(1, 3, INPUT_SIZE, INPUT_SIZE)
with torch.no_grad():
    outputs = model(inputs)
assert outputs.shape[-2:] == (INPUT_SIZE, INPUT_SIZE), outputs.shape

# torch.save(model, ...) pickles the whole module (not only a state_dict); the converter
# is given this path.
torch.save(model, MODEL_PATH)

# float32 conversion.
converter = Torch2TFLiteConverter(MODEL_PATH, tflite_model_save_path='model_float32.lite',
                                  target_shape=TARGET_SHAPE)
converter.convert()

# int8 post-training quantization. Calibration samples must have the same shape as the
# real model input, i.e. (N, 3, INPUT_SIZE, INPUT_SIZE) in the same NCHW layout as above,
# not 224x224. Values cover the full 8-bit range [0, 255] (randint's high is exclusive).
# NOTE(review): the sanity input above is in [0, 1); if deployed inputs are normalized
# that way, calibrate with the same range instead — confirm the expected input scale.
representative_dataset = torch.randint(0, 256, (N_CALIBRATION, 3, INPUT_SIZE, INPUT_SIZE)).float()
converter = Torch2TFLiteConverter(MODEL_PATH, tflite_model_save_path='./model_int8.lite',
                                  target_shape=TARGET_SHAPE,
                                  representative_dataset=representative_dataset)

converter.convert()

INFO:root:Old temp directory removed
INFO:root:Temp directory created at /tmp/model_converter/
INFO:root:PyTorch model successfully loaded and mapped to CPU
INFO:root:Sample input file path not specified, random data will be generated
INFO:root:Sample input randomly generated
INFO:root:Onnx model is saved to /tmp/model_converter/model.onnx


INFO:tensorflow:Assets written to: /tmp/model_converter/tf_model/assets


INFO:tensorflow:Assets written to: /tmp/model_converter/tf_model/assets
INFO:root:Tflite model is saved to model_float32.lite
INFO:root:TFLite interpreter successfully loaded from, model_float32.lite
INFO:root:MSE (Mean-Square-Error): 1.1007556314945768e-15	MAE (Mean-Absolute-Error): 2.0235816577951482e-08
INFO:root:Old temp directory removed
INFO:root:Temp directory created at /tmp/model_converter/
INFO:root:PyTorch model successfully loaded and mapped to CPU
INFO:root:Sample input file path not specified, random data will be generated
INFO:root:Sample input randomly generated
INFO:root:Onnx model is saved to /tmp/model_converter/model.onnx


INFO:tensorflow:Assets written to: /tmp/model_converter/tf_model/assets


INFO:tensorflow:Assets written to: /tmp/model_converter/tf_model/assets
fully_quantize: 0, inference_type: 6, input_inference_type: 0, output_inference_type: 0
INFO:root:Tflite model is saved to ./model_int8.lite
INFO:root:TFLite interpreter successfully loaded from, ./model_int8.lite
INFO:root:MSE (Mean-Square-Error): 7.4524177762214094e-06	MAE (Mean-Absolute-Error): 0.0022050016559660435
