In [None]:
# default_exp converter.core

# Model Interconversion

> Convert models between PyTorch, ONNX and TensorFlow/Keras.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from chitra.utility.import_utils import INSTALLED_MODULES, is_installed

In [None]:
# export
import torch.onnx


def pytorch_to_onnx(model, tensor, export_path="temp.onnx", opset_version=10):
    """Export a PyTorch model to ONNX.

    Args:
        model: the ``torch.nn.Module`` to export.
        tensor: example input used to trace the model (or a tuple of inputs
            for multi-input models).
        export_path: where to save the ONNX graph (a path or file-like object).
        opset_version: ONNX opset to target. Defaults to 10, the value this
            function always used before it became configurable.

    Returns:
        ``export_path``, unchanged, so the result can be fed to a loader.
    """
    # No separate forward pass here: ``torch.onnx.export`` runs the model
    # itself while tracing. An extra call only wasted time and, for a model
    # left in train mode, updated BatchNorm running statistics before export.
    torch.onnx.export(
        model,
        tensor,
        export_path,
        export_params=True,  # embed the trained weights in the ONNX file
        opset_version=opset_version,
        do_constant_folding=True,  # pre-compute constant sub-graphs
        input_names=["input"],
        output_names=["output"],
        # Only the batch dimension is dynamic; other dims are fixed by `tensor`.
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )
    return export_path

In [None]:
# export
import onnx
import tf2onnx
from onnx2pytorch import ConvertModel


def onnx_to_pytorch(onnx_model):
    """Convert an ONNX model into a PyTorch module via ``onnx2pytorch``.

    Args:
        onnx_model: a path to an ONNX file (``str``), or an already-loaded
            ONNX model object (e.g. the ``ModelProto`` from ``onnx.load``).

    Returns:
        The module built by ``onnx2pytorch.ConvertModel``.
    """
    # Paths are loaded first; any non-str value is used as the model directly.
    proto = onnx.load(onnx_model) if isinstance(onnx_model, str) else onnx_model
    # Validate before converting so malformed graphs fail with ONNX's error.
    onnx.checker.check_model(proto)
    return ConvertModel(proto)


def tf2_to_onnx(model, opset=None, output_path=None, **kwargs):
    """Convert a TensorFlow 2 / Keras model to ONNX with ``tf2onnx``.

    Args:
        model: the Keras model to convert.
        opset: ONNX opset to target; ``None`` lets tf2onnx pick its default.
        output_path: if given, tf2onnx also writes the ONNX model there.
        **kwargs: only ``inputs_as_nchw`` is read (default ``"input0:0"``):
            the graph input name(s) to expose in NCHW layout, as a list or a
            single string. Pass ``None`` to disable the transpose. Other
            keyword arguments are ignored.

    Returns:
        Whatever ``tf2onnx.convert.from_keras`` returns; in current tf2onnx
        releases that is a ``(model_proto, external_tensor_storage)`` tuple.
    """
    inputs_as_nchw = kwargs.get("inputs_as_nchw", "input0:0")
    # tf2onnx expects a *list* of tensor names. A bare string would be tested
    # by substring/character membership instead of as one whole name.
    if isinstance(inputs_as_nchw, str):
        inputs_as_nchw = [inputs_as_nchw]
    onnx_model = tf2onnx.convert.from_keras(
        model, opset=opset, output_path=output_path, inputs_as_nchw=inputs_as_nchw
    )
    return onnx_model


def tf2_to_pytorch(model, opset=None, **kwargs):
    """Convert a TensorFlow 2 / Keras model to PyTorch by way of ONNX.

    The model is written to a temporary ONNX file with ``tf2_to_onnx``, and
    that file is read back with ``onnx_to_pytorch``. The temporary file is
    deleted before returning.

    Args:
        model: the Keras model to convert.
        opset: ONNX opset, passed through to ``tf2_to_onnx``.
        **kwargs: forwarded to ``tf2_to_onnx`` (e.g. ``inputs_as_nchw``).

    Returns:
        The converted PyTorch module.
    """
    import os
    import tempfile

    # Use a temporary *directory* rather than an open NamedTemporaryFile. The
    # ONNX file is written and re-read by path, and an open temp file cannot
    # be reopened by name on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_path = os.path.join(tmp_dir, "model.onnx")
        tf2_to_onnx(model, opset, output_path=onnx_path, **kwargs)
        torch_model = onnx_to_pytorch(onnx_path)
    return torch_model

## Example

In [None]:

import numpy as np
import timm

# Round-trip an (untrained) ResNet-18 PyTorch -> ONNX -> PyTorch and check
# that both models produce the same outputs for the same input.
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All

model1 = timm.create_model("resnet18")
model1.eval()

model_inter_path = pytorch_to_onnx(model1, torch.randn(1, 3, 224, 224))
# Put the converted model in eval mode as well, so that any BatchNorm/Dropout
# layers it contains behave like model1's.
model2 = onnx_to_pytorch(model_inter_path).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # inference only; no autograd graph needed
    out1 = model1(x).numpy()
    out2 = model2(x).numpy()

# Fail loudly rather than just displaying a bool that is easy to overlook.
assert np.allclose(out1, out2, rtol=1e-4, atol=1e-5), "round-trip outputs differ"
np.abs(out1 - out2).max()

In [None]:
import tensorflow as tf
import torch

In [None]:
# Show the TensorFlow version used for the conversion example below.
tf.__version__

In [None]:
# Build the Keras reference model and its PyTorch conversion. Later cells use
# both `tf_model` and `my_model`, so this cell has to run, not stay commented.
tf_model = tf.keras.applications.MobileNetV2()
# NOTE(review): with inputs_as_nchw=None the converted model presumably keeps
# Keras' NHWC input layout, while `x2` below is permuted to NCHW. Confirm
# which layout `my_model` expects.
my_model = tf2_to_pytorch(tf_model, inputs_as_nchw=None, opset=13).eval()

In [None]:
import numpy as np
from chitra.image import Chitra

# Sample image used to compare the TF and converted PyTorch predictions.
SAMPLE_IMAGE_URL = "https://c.files.bbci.co.uk/957C/production/_111686283_pic1.png"

image = Chitra(SAMPLE_IMAGE_URL)
image.image = (
    image.image
    .resize((224, 224))  # MobileNetV2's default input size
    .convert("RGB")  # force 3 channels (e.g. drop a PNG alpha channel)
)
image.imshow()

In [None]:
# Both models must see identical pixels. MobileNetV2's preprocessing scales to
# [-1, 1], so build the PyTorch input from the TF one instead of normalising
# it a second, different way (x / 255 gave the PyTorch model other inputs).
x1 = tf.cast(image.to_tensor("tf"), tf.float32) / 127.5 - 1.0
x1 = tf.expand_dims(x1, 0)  # add the batch dimension

# NHWC -> NCHW for PyTorch; torch.tensor copies, so x2 owns its memory.
x2 = torch.tensor(x1.numpy()).permute(0, 3, 1, 2)

In [None]:
# Inspect the PyTorch input's dimensions (expected batch-first).
x2.size()

In [None]:
# Undo the [-1, 1] scaling to check the preprocessed input visually.
# Round before casting: (x + 1) * 127.5 can land just below an integer after
# the float32 round trip, and astype("uint8") would truncate it.
x1_pixels = np.clip(np.rint(((x1[0] + 1) * 127.5).numpy()), 0, 255).astype("uint8")
Chitra(x1_pixels).imshow()

In [None]:
from chitra.core import IMAGENET_LABELS

# keras.applications.MobileNetV2() already ends in a softmax by default
# (classifier_activation="softmax"), so use its probabilities directly
# instead of applying softmax a second time.
res1 = tf_model.predict(x1)
IMAGENET_LABELS[int(np.argmax(res1, axis=1)[0])]

In [None]:
# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    res2 = my_model(x2)
# Same per-row argmax as the TF cell, so the two labels can be compared.
IMAGENET_LABELS[torch.argmax(res2, dim=1)[0].item()]

In [None]:
# Show the structure of the converted PyTorch model.
my_model

In [None]:
# Input and output dimensions of the PyTorch prediction, side by side.
(x2.size(), res2.size())