Below is the Dockerfile. List every package imported by the conversion script (next cell) in requirements.txt so the image installs them.

In [None]:
FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime

# Conversion script copied into the image (override with --build-arg model_conv_file=...).
ARG model_conv_file='model_conv.py'

# System packages: git, protoc, and the X11/GL/GLib shared libraries
# (presumably required by OpenCV / cv2, which the script imports).
# Update and install in ONE layer so a cached "apt-get update" layer can never
# be paired with a newer install list; drop the package lists afterwards.
RUN apt-get update && \
    apt-get install -y \
        git \
        libgl1-mesa-glx \
        libglib2.0-0 \
        libsm6 \
        libxext6 \
        libxrender-dev \
        protobuf-compiler && \
    rm -rf /var/lib/apt/lists/*

# WORKDIR creates the directory if it does not exist.
WORKDIR /root/Documents/
COPY requirements.txt /root/Documents/requirements.txt
RUN python -m pip install --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

COPY ${model_conv_file} /root/Documents/model_conv.py

ENV PYTHONPATH="/root/Documents/"
# Arguments given to `docker run <image> ...` are forwarded to the script, e.g.
#   --saved-model mobilenetv2_model.pt --output mobilenetv2_model.onnx
ENTRYPOINT [ "python3", "model_conv.py" ]

In [None]:
from typing import Optional
import os

import sys
import shutil
import logging
import cv2
import numpy as np
import json

import tfcoreml
import argparse
import coremltools as ct
import onnxmltools
import onnx
from onnx2pytorch import ConvertModel
from onnx_tf.backend import prepare
import torch
# import keras_retinanet
# import keras_retinanet.layers, keras_retinanet.losses
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model

from torch.utils.mobile_optimizer import optimize_for_mobile
import inspect
from onnx import numpy_helper
from onnx2keras.layers import AVAILABLE_CONVERTERS

from tf2onnx.convert import get_args
from tf2onnx.tf_loader import from_saved_model
from tf2onnx.utils import save_protobuf, save_onnx_zip, set_debug_mode
from tf2onnx.constants import ENV_TF2ONNX_CATCH_ERRORS
from tf2onnx.optimizer import optimize_graph
from tf2onnx.tf_utils import compress_graph_def
from tf2onnx.graph import ExternalTensorStorage
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx.verbose_logging import get_verbosity_level


######################################## Common Functions ##################################################

def set_logging(name=None, verbose=True):
    """Configure root logging once and return the logger called *name*.

    INFO-level output is enabled only when *verbose* is true and the ``RANK``
    environment variable is unset or 0 (the primary process of a multi-GPU
    run); otherwise only WARNING and above are emitted.
    """
    world_rank = int(os.getenv('RANK', -1))  # -1 when RANK is not set
    is_primary = world_rank in (-1, 0)
    chosen_level = logging.INFO if verbose and is_primary else logging.WARNING
    logging.basicConfig(format="%(message)s", level=chosen_level)
    return logging.getLogger(name)

def make_default_custom_op_handler(domain):
    """Build a tf2onnx custom-op handler that tags a node with *domain*.

    The returned callable has the ``(ctx, node, name, args)`` handler signature;
    it ignores everything except *node*, sets ``node.domain`` and returns the
    same node object.
    """
    def default_custom_op_handler(ctx, node, name, args):
        setattr(node, 'domain', domain)
        return node

    handler = default_custom_op_handler
    return handler

def onnx_node_attributes_to_dict(args):
    """
    Convert a sequence of ONNX node attributes into a plain dictionary.

    :param args: iterable of ONNX attribute objects (e.g. ``node.attribute``)
    :return: dict mapping each attribute's ``name`` to its decoded value
    """
    def onnx_attribute_to_dict(onnx_attr):
        """
        Decode one ONNX attribute.

        Checked in order: the tensor field ``t`` (converted to a numpy array),
        the scalar fields ``f``, ``i``, ``s``, then the repeated fields
        ``floats``, ``ints``, ``strings`` (returned as a list when non-empty).
        Returns None when none of them is populated.

        :param onnx_attr: ONNX attribute
        :return: decoded Python / numpy value, or None
        """
        if onnx_attr.HasField('t'):
            return numpy_helper.to_array(onnx_attr.t)

        scalar_field = next(
            (field for field in ('f', 'i', 's') if onnx_attr.HasField(field)), None)
        if scalar_field is not None:
            return getattr(onnx_attr, scalar_field)

        for repeated_field in ('floats', 'ints', 'strings'):
            values = getattr(onnx_attr, repeated_field)
            if values:
                return list(values)
        return None

    decoded = {}
    for attr in args:
        decoded[attr.name] = onnx_attribute_to_dict(attr)
    return decoded

def _collect_onnx_initializers(onnx_weights, logger):
    """Map every ONNX graph initializer's name to its value as a numpy array."""
    weights = {}
    for onnx_w in onnx_weights:
        # TensorProto carries its name in the ``name`` field. The previous code
        # guessed the name's position inside ``ListFields()``, which shifts with
        # the set of populated optional fields (dims, float_data, raw_data, ...).
        weight_name = onnx_w.name
        weights[weight_name] = numpy_helper.to_array(onnx_w)
        logger.info('Found weight {0} with shape {1}.'.format(
            weight_name, weights[weight_name].shape))
    return weights


def _convert_onnx_nodes(onnx_nodes, layers, lambda_funcs, weights,
                        name_policy, change_ordering, logger):
    """Run the onnx2keras converter of every node, filling *layers* / *lambda_funcs* in place.

    :raises AttributeError: a non-Constant node has no inputs, or one of its
        inputs is neither a known layer/model input nor an initializer
    :raises KeyError: the node type has no entry in ``AVAILABLE_CONVERTERS``
    """
    node_names = []  # Keras names handed out so far (checked by the 'short' policy)
    for node_index, node in enumerate(onnx_nodes):
        node_type = node.op_type
        node_params = onnx_node_attributes_to_dict(node.attribute)

        # Add global converter info:
        node_params['change_ordering'] = change_ordering
        node_params['name_policy'] = name_policy

        node_name = str(node.output[0])
        keras_names = []
        for output_index, output in enumerate(node.output):
            if name_policy == 'short':
                # First 8 characters, de-duplicated with a numeric suffix.
                keras_name = keras_name_i = str(output)[:8]
                suffix = 1
                while keras_name_i in node_names:
                    keras_name_i = keras_name + '_' + str(suffix)
                    suffix += 1
                keras_names.append(keras_name_i)
            elif name_policy == 'renumerate':
                postfix = node_index if len(node.output) == 1 else "%s_%s" % (node_index, output_index)
                keras_names.append('LAYER_%s' % postfix)
            else:
                keras_names.append(output)

        if len(node.output) != 1:
            logger.warning('Trying to convert multi-output node')
            node_params['_outputs'] = list(node.output)
            node_names.extend(keras_names)
        else:
            # A single output is passed on as a plain name, not a one-item list.
            keras_names = keras_names[0]
            node_names.append(keras_names)

        logger.info('Check if all inputs are available:')
        if len(node.input) == 0 and node_type != 'Constant':
            raise AttributeError('Operation doesn\'t have an input. Aborting.')

        for input_index, node_input in enumerate(node.input):
            logger.info('Check input %i (name %s).', input_index, node_input)
            if node_input in layers:
                continue
            logger.info('The input not found in layers / model inputs.')
            if node_input not in weights:
                raise AttributeError('Current node is not in weights / model inputs / layers.')
            logger.info('Found in weights, add as a numpy constant.')
            layers[node_input] = weights[node_input]
        logger.info('... found all, continue')

        keras.backend.set_image_data_format('channels_first')
        if node_type == 'Clip':
            # Hard-coded override: every Clip is converted with min=0 and
            # max=None, whatever bounds the ONNX node carries. Add more node
            # types here manually if their conversion fails.
            node_params['min'] = 0
            node_params['max'] = None
        AVAILABLE_CONVERTERS[node_type](
            node,
            node_params,
            layers,
            lambda_funcs,
            node_name,
            keras_names
        )

        first_name = keras_names[0] if isinstance(keras_names, list) else keras_names
        if first_name in layers:
            logger.info('Output TF Layer -> ' + str(layers[first_name]))


def _reorder_model_to_channels_last(model, lambda_funcs):
    """Rebuild a channels-first Keras *model* as a channels-last one (experimental).

    Rewrites layer configs (batch input shapes, reshape targets, data_format,
    axis, PReLU shared axes, Lambda arguments), rebuilds the model under the
    'channels_last' image format and copies the weights across.
    NOTE(review): the axis rewrites assume 4D NCHW tensors.
    """
    # NCHW axis index -> NHWC axis index (N->0, C->3, H->1, W->2).
    axes_map = np.array([0, 3, 1, 2])

    def remap_axis(axis):
        # Ints and lists/tuples of ints in [0, 4) are remapped; negative axes
        # and anything else are left untouched.
        if isinstance(axis, (list, tuple)):
            return type(axis)(remap_axis(a) for a in axis)
        if isinstance(axis, (int, np.integer)) and not isinstance(axis, bool) and 0 <= axis < 4:
            return int(axes_map[axis])
        return axis

    conf = model.get_config()

    for layer in conf['layers']:
        layer_cfg = layer['config']
        if not layer_cfg:
            continue
        if 'shared_axes' in layer_cfg:
            # TODO: check axes first (if it's not 4D tensor)
            layer_cfg['shared_axes'] = [1, 2]

        if 'batch_input_shape' in layer_cfg:
            # (None, C, H, W) -> (None, H, W, C)
            shape = layer_cfg['batch_input_shape']
            layer_cfg['batch_input_shape'] = tuple(np.reshape(np.array(
                [[None] + list(shape[2:]) + [shape[1]]]), -1))

        if 'target_shape' in layer_cfg:
            target = layer_cfg['target_shape']
            if len(target) > 1:
                # Move the leading (channel) entry to the end.
                layer_cfg['target_shape'] = tuple(np.reshape(np.array(
                    list(target[1:]) + [target[0]]), -1))

        if 'data_format' in layer_cfg:
            layer_cfg['data_format'] = 'channels_last'

        if 'axis' in layer_cfg:
            # The previous two plain ``if``s mapped 3 -> 1 and then straight
            # back 1 -> 3, so axis 3 was never changed. Use the same NCHW ->
            # NHWC table that is applied to Lambda 'axis' arguments below.
            layer_cfg['axis'] = remap_axis(layer_cfg['axis'])

    for layer in conf['layers']:
        layer_cfg = layer['config']
        if 'function' not in layer_cfg or layer_cfg['function'][1] is None:
            continue
        kerasf = list(layer_cfg['function'])
        dargs = list(kerasf[1])
        func = lambda_funcs.get(layer['name'])

        if func:
            if len(dargs) > 1:
                params = inspect.signature(func).parameters
                param_names = list(params.keys())

                # Positional index of 'axes' among the extra args (the first
                # parameter of the lambda is the input tensor itself).
                i = param_names.index('axes') if 'axes' in params else -1
                if i > 0:
                    i -= 1
                    axes = list(range(len(dargs[i].shape)))
                    axes = axes[0:1] + axes[2:] + axes[1:2]
                    dargs[i] = np.transpose(dargs[i], axes)

                i = param_names.index('axis') if 'axis' in params else -1
                if i > 0:
                    i -= 1
                    dargs[i] = axes_map[np.array(dargs[i])]
            else:
                if dargs[0] == -1:
                    dargs = [1]
                elif dargs[0] == 3:
                    dargs = [1]

        kerasf[1] = tuple(dargs)
        layer_cfg['function'] = tuple(kerasf)

    keras.backend.set_image_data_format('channels_last')
    model_tf_ordering = keras.models.Model.from_config(conf)

    for dst_layer, src_layer, layer_conf in zip(model_tf_ordering.layers, model.layers, conf['layers']):
        layer_weights = src_layer.get_weights()
        # TODO: check axes first (if it's not 4D tensor)
        if layer_conf['config'] and 'shared_axes' in layer_conf['config']:
            layer_weights[0] = layer_weights[0].transpose(1, 2, 0)
        dst_layer.set_weights(layer_weights)

    return model_tf_ordering


def _onnx_to_keras(onnx_model, input_names,
                   input_shapes=None, name_policy=None, verbose=True, change_ordering=False):
    """
    Convert ONNX graph to Keras model format.

    Each ONNX node is translated by its entry in
    ``onnx2keras.layers.AVAILABLE_CONVERTERS`` and the resulting tensors are
    wired into a ``keras.models.Model``.

    :param onnx_model: loaded ONNX model
    :param input_names: list with input names
    :param input_shapes: override input shapes (experimental); one shape per
        entry of ``input_names``, without the batch dimension
    :param name_policy: override layer names. None, "short" or "renumerate" (experimental)
    :param verbose: verbose output
    :param change_ordering: change ordering to HWC (experimental)
    :return: Keras model
    :raises AttributeError: a node has no inputs (and is not a Constant) or uses
        a tensor that is neither a graph input, an initializer nor an earlier output
    """
    # getLogger(__name__) is the same object as the script's LOGGER when run as
    # __main__, and still exists when this module is imported.
    logger = logging.getLogger(__name__)

    # Use channels first format by default; the caller's format is restored in
    # ``finally`` even if a converter raises.
    keras_fmt = keras.backend.image_data_format()
    keras.backend.set_image_data_format('channels_first')
    try:
        if verbose:
            logging.basicConfig(level=logging.DEBUG)

        onnx_inputs = onnx_model.graph.input
        onnx_outputs = [o.name for o in onnx_model.graph.output]

        logger.info('List input shapes:')
        logger.info(input_shapes)

        logger.debug('List inputs:')
        for i, onnx_input in enumerate(onnx_inputs):
            logger.info('Input {0} -> {1}.'.format(i, onnx_input.name))

        logger.info('List outputs:')
        for i, output_name in enumerate(onnx_outputs):
            logger.info('Output {0} -> {1}.'.format(i, output_name))

        logger.info('Gathering weights to dictionary.')
        weights = _collect_onnx_initializers(onnx_model.graph.initializer, logger)

        layers = dict()
        lambda_funcs = dict()
        keras_inputs = []

        for i, input_name in enumerate(input_names):
            for onnx_i in onnx_inputs:
                if onnx_i.name != input_name:
                    continue
                if input_shapes:
                    input_shape = input_shapes[i]
                else:
                    # Dims from the ONNX graph, minus the leading batch dimension.
                    input_shape = [d.dim_value for d in onnx_i.type.tensor_type.shape.dim][1:]

                layers[input_name] = keras.layers.InputLayer(
                    input_shape=input_shape, name=input_name
                ).output
                keras_inputs.append(layers[input_name])
                logger.info('Found input {0} with shape {1}'.format(input_name, input_shape))

        _convert_onnx_nodes(onnx_model.graph.node, layers, lambda_funcs, weights,
                            name_policy, change_ordering, logger)

        # Terminal nodes: graph outputs that the converters actually produced.
        keras_outputs = [layers[name] for name in onnx_outputs if name in layers]
        model = keras.models.Model(inputs=keras_inputs, outputs=keras_outputs)

        if change_ordering:
            model = _reorder_model_to_channels_last(model, lambda_funcs)
    finally:
        keras.backend.set_image_data_format(keras_fmt)

    return model

def convert_common(frozen_graph, name="unknown", large_model=False, output_path=None,
                   output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, **kwargs):
    """Common processing for conversion: frozen TF GraphDef -> ONNX ModelProto.

    :param frozen_graph: frozen TensorFlow ``GraphDef`` to convert
    :param name: used in the description passed to ``make_model``
        ("converted from <name>")
    :param large_model: keep constants in external tensor storage and save the
        result with ``save_onnx_zip``
    :param output_path: where to write the ONNX model; nothing is written when falsy
    :param output_frozen_graph: optional path to also dump *frozen_graph*
    :param custom_ops: ``{op_type: domain}``; each listed op gets a default
        handler that only sets its node domain
    :param custom_op_handlers: extra tf2onnx custom-op handlers (not modified)
    :param kwargs: forwarded to ``process_tf_graph``
    :return: ``(model_proto, external_tensor_storage)``; the storage is None
        unless *large_model* is set
    """
    external_tensor_storage = None
    const_node_values = None

    if custom_ops is not None:
        # Work on a copy so the caller's dict is not modified in place.
        custom_op_handlers = dict(custom_op_handlers or {})
        custom_op_handlers.update(
            {op: (make_default_custom_op_handler(domain), []) for op, domain in custom_ops.items()})

    with tf.Graph().as_default() as tf_graph:
        if large_model:
            const_node_values = compress_graph_def(frozen_graph)
            external_tensor_storage = ExternalTensorStorage()
        if output_frozen_graph:
            save_protobuf(output_frozen_graph, frozen_graph)
        if not kwargs.get("tflite_path") and not kwargs.get("tfjs_path"):
            tf.import_graph_def(frozen_graph, name='')
        g = process_tf_graph(tf_graph, const_node_values=const_node_values,
                             custom_op_handlers=custom_op_handlers, **kwargs)
        if ENV_TF2ONNX_CATCH_ERRORS in os.environ:
            # Read the environment variable's VALUE. The previous code upper-cased
            # the constant holding the variable's NAME, which never equals "TRUE".
            catch_errors = os.environ[ENV_TF2ONNX_CATCH_ERRORS].upper() == "TRUE"
        else:
            catch_errors = not large_model
        onnx_graph = optimize_graph(g, catch_errors)
        model_proto = onnx_graph.make_model("converted from {}".format(name),
                                            external_tensor_storage=external_tensor_storage)

    if output_path:
        if large_model:
            save_onnx_zip(output_path, model_proto, external_tensor_storage)
        else:
            save_protobuf(output_path, model_proto)

    return model_proto, external_tensor_storage

def _hwc_to_nchw(image):
    """Return an H x W or H x W x C array as a 1 x C x H x W array."""
    if image.ndim == 2:
        # Grayscale images / 2-D shapes have no channel axis yet.
        image = image[:, :, np.newaxis]
    return np.transpose(image, (2, 0, 1))[np.newaxis, :, :, :]


def load_sample_input(
        file_path: Optional[str] = None,
        target_shape: tuple = (224, 224, 3),
        seed: int = 10,
        normalize: bool = True
):
    """Load (or synthesise) one example input in NCHW layout.

    :param file_path: image to read with OpenCV; when None, uniform random data
        in [0, 1) is generated instead
    :param target_shape: ``(height, width, channels)`` or ``(height, width)``
    :param seed: NumPy seed used for the random data
    :param normalize: when loading an image, multiply pixel values by 1/255
    :return: dict with ``'sample_data_np'`` (float32 array, shape
        ``(1, C, H, W)``) and ``'sample_data_torch'`` (``torch.from_numpy`` of it)

    Exits the process with status -1 if the image cannot be loaded.
    """
    if file_path is None:
        logging.info('Sample input file path not specified, random data will be generated')
        np.random.seed(seed)
        data = np.random.random(target_shape).astype(np.float32)
        sample_data_np = _hwc_to_nchw(data)
        sample_data_torch = torch.from_numpy(sample_data_np)
        logging.info('Sample input randomly generated')
        return {'sample_data_np': sample_data_np, 'sample_data_torch': sample_data_torch}

    if (len(target_shape) == 3 and target_shape[-1] == 1) or len(target_shape) == 2:
        imread_flags = cv2.IMREAD_GRAYSCALE
    elif len(target_shape) == 3 and target_shape[-1] == 3:
        imread_flags = cv2.IMREAD_COLOR
    else:
        imread_flags = cv2.IMREAD_ANYCOLOR + cv2.IMREAD_ANYDEPTH

    try:
        height, width = target_shape[:2]
        img = cv2.imread(file_path, imread_flags)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError('cv2.imread could not read {}'.format(file_path))
        # cv2.resize takes dsize as (width, height), the reverse of the shape order.
        img = cv2.resize(src=img, dsize=(width, height), interpolation=cv2.INTER_LINEAR)
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if normalize:
            img = img * 1. / 255
        img = img.astype(np.float32)

        sample_data_np = _hwc_to_nchw(img)
        sample_data_torch = torch.from_numpy(sample_data_np)
        logging.info(f'Sample input successfully loaded from, {file_path}')
    except Exception:
        # Keep the traceback so the cause (missing file, bad shape, ...) is visible.
        logging.exception(f'Can not load sample input from, {file_path}')
        sys.exit(-1)

    return {'sample_data_np': sample_data_np, 'sample_data_torch': sample_data_torch}

########################################## Class for conversion ################################################

class ModelConversion:
    """Convert a model file from one framework format to another.

    The source/destination formats come from the file extensions of
    ``saved_model`` and ``output`` (``ext_mappings``); the resulting pair picks a
    method from ``function_mappings``. Chained conversions write intermediate
    files (ONNX file, TF SavedModel, ...) into ``<cwd>/model_conv``.

    NOTE: ``<cwd>/model_conv`` is deleted and recreated every time an instance
    is created, so nothing else should be stored there.
    """

    # getLogger(__name__) is the same object as the script's LOGGER when run as
    # __main__, and keeps the class usable when imported from another module.
    _log = logging.getLogger(__name__)

    def __init__(
            self,
            saved_model: str,
            output: str,
            sample_file_path: Optional[str] = None,
            target_shape: tuple = (224, 224, 3),
            seed: int = 10,
            normalize: bool = True,
            opset: int = 13
    ):
        """
        :param saved_model: source model path; its extension selects the input format
        :param output: destination path; its extension selects the output format.
            A bare file name is placed inside the scratch directory.
        :param sample_file_path: optional image used as the example input for ONNX export
        :param target_shape: model input shape as (height, width, channels)
        :param seed: seed for the random example input when no image is given
        :param normalize: scale the example image's pixel values by 1/255
        :param opset: ONNX opset version used by the torch/tf -> ONNX exports
        """
        self.saved_model = saved_model  # initial model path
        self.output = output  # final model path

        self.sample_file_path = sample_file_path
        self.target_shape = tuple(target_shape)
        self.seed = seed
        self.normalize = normalize
        self.device = 'cpu'
        self.tmpdir = os.path.join(os.getcwd(), 'model_conv')
        self.__check_tmpdir()
        self.temp_torch_path = os.path.join(self.tmpdir, 'temp_torch.pt')
        self.temp_onnx_path = os.path.join(self.tmpdir, 'temp_onnx.onnx')
        self.temp_coreml_path = os.path.join(self.tmpdir, 'temp_coreml.mlmodel')
        self.temp_keras_path = os.path.join(self.tmpdir, 'temp_keras.h5')
        self.temp_tf_path = os.path.join(self.tmpdir, 'tf_temp_model')
        self.func = ''    # conversion function selected by convert()
        self.final = ''   # path written by the most recent conversion step
        self.infunc = ''  # name of the most recent single-step conversion
        self.opset = opset

        if os.path.dirname(self.output) == '':
            self.output = os.path.join(self.tmpdir, self.output)

    def log_info(self):
        """Log which conversion step finished and where its result was written."""
        self._log.info(f' -------- Conversion "{self.infunc}" is done and the model was saved at: {self.final} --------')

    def __check_tmpdir(self):
        """Delete and recreate the scratch directory; exit with -1 on failure."""
        try:
            if os.path.isdir(self.tmpdir):
                shutil.rmtree(self.tmpdir)
                logging.info('Old temp directory removed')
            os.makedirs(self.tmpdir, exist_ok=True)
            logging.info(f'Temp directory created at {self.tmpdir}')
        except Exception:
            logging.exception('Can not create temporary directory, exiting!')
            sys.exit(-1)

    @staticmethod
    def _channels_first_shape(target_shape):
        """Move the trailing (channel) entry to the front: (H, W, C) -> (C, H, W)."""
        return (target_shape[-1],) + tuple(target_shape[:-1])

    def load_torch_model(self, torch_model_path) -> torch.nn.Module:
        """Load a full pickled PyTorch model (``torch.save(model, PATH)``) onto the CPU.

        ``nn.Module`` results are switched to eval mode so tracing/export sees
        inference behaviour. Exits with status -1 for a non-.pt/.pth path or a
        load failure.
        """
        if not torch_model_path.endswith(('.pth', '.pt')):
            logging.error('Specified file path is not a PyTorch model (.pt / .pth), exiting!')
            sys.exit(-1)
        try:
            # torch.load unpickles arbitrary objects: only load trusted files.
            model = torch.load(torch_model_path, map_location=self.device)
        except Exception:
            logging.exception('Can not load PyTorch model. Please make sure '
                              'that model saved like `torch.save(model, PATH)`')
            sys.exit(-1)
        if isinstance(model, torch.nn.Module):
            model.eval()
        logging.info('PyTorch model successfully loaded and mapped to CPU')
        return model

    def torch_to_torchscript(self, model, output, optimize=False) -> None:
        """Trace the PyTorch model stored at *model* and save it as TorchScript.

        The trace input is a zero tensor of shape (1, C, H, W) from
        ``target_shape``; that shape is stored in the archive as ``config.txt``.
        With *optimize*, ``optimize_for_mobile`` is applied before saving.
        """
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        model = self.load_torch_model(model)
        im = torch.zeros(1, self.target_shape[-1], *self.target_shape[:-1]).to(self.device)
        self._log.info(f'\n starting export with torch {torch.__version__}...')
        ts = torch.jit.trace(model, im, strict=False)
        extra_files = {'config.txt': json.dumps({"shape": im.shape})}  # torch._C.ExtraFilesMap()
        (optimize_for_mobile(ts) if optimize else ts).save(str(output), _extra_files=extra_files)
        self.log_info()

    def torch_to_onnx(self, input, output) -> None:
        """Export the PyTorch model at *input* to ONNX with input/output names 'input'/'output'."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        model = self.load_torch_model(input)
        sample_data = load_sample_input(self.sample_file_path, self.target_shape, self.seed, self.normalize)
        torch.onnx.export(
            model=model,
            args=sample_data['sample_data_torch'],
            f=output,
            verbose=False,
            export_params=True,
            do_constant_folding=False,
            input_names=['input'],
            opset_version=self.opset,
            output_names=['output'],
        )
        self.log_info()

    def tf_to_onnx(self, input, output) -> None:
        """Convert the TF SavedModel containing the file *input* (e.g. saved_model.pb) to ONNX."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        saved_model_dir = os.path.dirname(input) or os.curdir

        # tf2onnx's get_args() parses sys.argv. Give it only options it knows so
        # this script's own flags (--target-shape, --sample-file) don't make it exit.
        original_argv = sys.argv
        sys.argv = [original_argv[0], '--saved-model', saved_model_dir,
                    '--output', output, '--opset', str(self.opset)]
        try:
            args = get_args()
        finally:
            sys.argv = original_argv

        logging.basicConfig(level=get_verbosity_level(args.verbose))
        if args.debug:
            set_debug_mode(True)

        args.saved_model = saved_model_dir
        args.output = output
        args.opset = self.opset

        graph_def, inputs, outputs, initialized_tables, tensors_to_rename = from_saved_model(
            args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function,
            args.large_model, return_initialized_tables=True, return_tensors_to_rename=True,
            use_graph_names=args.use_graph_names)
        # Renames can only be applied once the input/output tensor names are
        # known, i.e. after the SavedModel has been loaded.
        if args.rename_inputs:
            tensors_to_rename.update(zip(inputs, args.rename_inputs))
        if args.rename_outputs:
            tensors_to_rename.update(zip(outputs, args.rename_outputs))
        model_path = args.saved_model

        with tf.device("/cpu:0"):
            model_proto, _ = convert_common(
                graph_def,
                name=model_path,
                continue_on_error=args.continue_on_error,
                target=args.target,
                opset=args.opset,
                shape_override=args.shape_override,
                input_names=inputs,
                output_names=outputs,
                inputs_as_nchw=args.inputs_as_nchw,
                large_model=args.large_model,
                tensors_to_rename=tensors_to_rename,
                ignore_default=args.ignore_default,
                use_default=args.use_default,
                dequantize=args.dequantize,
                initialized_tables=initialized_tables,
                output_frozen_graph=args.output_frozen_graph,
                output_path=args.output)

        self.log_info()

        self._log.info("Model inputs: %s", [n.name for n in model_proto.graph.input])
        self._log.info("Model outputs: %s", [n.name for n in model_proto.graph.output])
        if args.output:
            if args.large_model:
                self._log.info("Zipped ONNX model is saved at %s. Unzip before opening in onnxruntime.", output)
            else:
                self._log.info("ONNX model is saved at %s", args.output)
        else:
            self._log.info("To export ONNX model to file, please run with `--output` option")

    def onnx_to_tf(self, input, output) -> None:
        """Export an ONNX model as a TF SavedModel directory (a '.pb' suffix on *output* is dropped)."""
        if os.path.dirname(output) == '':
            output = os.path.join(self.tmpdir, output)
        if output.endswith('.pb'):
            output = os.path.splitext(output)[0]
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        onnx_model = onnx.load(input)
        onnx.checker.check_model(onnx_model)
        tf_rep = prepare(onnx_model)
        tf_rep.export_graph(output)
        self.log_info()

    def tf_to_tflite(self, input, output) -> None:
        """Convert a TF SavedModel (directory, or a file inside it) to a .tflite flatbuffer."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        # Intermediate calls pass the SavedModel directory; users pass the .pb path.
        saved_model_dir = input if os.path.isdir(input) else os.path.dirname(input)
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        tflite_model = converter.convert()
        with open(output, 'wb') as f:
            f.write(tflite_model)
        self.log_info()

    def keras_to_onnx(self, input, output) -> None:
        """Convert a Keras .h5 model to ONNX with onnxmltools."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        keras_model = load_model(input)

        ## Enable this if the model contains custom objects like keras retinanet
        # keras_model = load_model(input, custom_objects={'UpsampleLike':keras_retinanet.layers.UpsampleLike,
        # '_smooth_l1':keras_retinanet.losses.smooth_l1(), '_focal': keras_retinanet.losses.focal()})

        onnx_model = onnxmltools.convert_keras(keras_model)
        onnxmltools.utils.save_model(onnx_model, output)
        self.log_info()

    def onnx_to_torch(self, input, output) -> None:
        """Convert ONNX to PyTorch via onnx2pytorch.

        Only the ``state_dict`` is saved, so the file cannot be reloaded by
        ``load_torch_model`` (which expects a full pickled model).
        """
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        onnx_model = onnx.load(input)
        pytorch_model = ConvertModel(onnx_model)
        torch.save(pytorch_model.state_dict(), output)
        self.log_info()

    def coreml_to_onnx(self, input, output) -> None:
        """Convert a Core ML .mlmodel spec to ONNX with onnxmltools."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        coreml_model = ct.utils.load_spec(input)
        onnx_model = onnxmltools.convert_coreml(coreml_model)
        onnxmltools.utils.save_model(onnx_model, output)
        self.log_info()

    def onnx_to_coreml(self, input, output) -> None:
        """Convert ONNX to Core ML with the coremltools ONNX converter."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        model = ct.converters.onnx.convert(model=input)
        model.save(output)
        self.log_info()

    def tf_to_coreml(self, input, output) -> None:
        """Convert a TF graph with tfcoreml.

        NOTE(review): tensor names 'input:0' / 'softmax:0' are hard-coded.
        """
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        tfcoreml.convert(tf_model_path=input,  ## check here if dir or .pb path
                         mlmodel_path=output,
                         output_feature_names=['softmax:0'],  # name of the output tensor (appended by ":0")
                         input_name_shape_dict={'input:0': [1, *self.target_shape]},  # placeholder name -> shape
                         )
        self.log_info()

    def onnx_to_keras(self, input, output) -> None:
        """Convert ONNX (single input named 'input') to a Keras model file."""
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        onnx_model = onnx.load(input)
        # The ONNX graph is channels-first: (H, W, C) -> (C, H, W). Reversing the
        # tuple gave (C, W, H), which is wrong for non-square inputs.
        input_shapes = self._channels_first_shape(self.target_shape)
        k_model = _onnx_to_keras(onnx_model=onnx_model, input_names=['input'], input_shapes=[input_shapes])
        keras.models.save_model(k_model, output, overwrite=True, include_optimizer=True)
        self.log_info()

    def keras_to_tf(self, input, output) -> None:
        """Re-save a Keras .h5 model as a TF SavedModel directory ('.pb' suffix dropped)."""
        if os.path.dirname(output) == '':
            output = os.path.join(self.tmpdir, output)
        if output.endswith('.pb'):
            output = os.path.splitext(output)[0]
        self.final = output
        self.infunc = inspect.currentframe().f_code.co_name
        model = load_model(input)
        model.save(output)
        self.log_info()

    ########### Combinations ###########
    def torch_to_coreml(self, input, output) -> None:
        self.final = output
        self.torch_to_onnx(input, self.temp_onnx_path)
        self.onnx_to_coreml(self.temp_onnx_path, output)

    def torch_to_tflite(self, input, output) -> None:
        self.final = output
        self.torch_to_tf(input, self.temp_tf_path)
        self.tf_to_tflite(self.temp_tf_path, output)

    def torch_to_keras(self, input, output) -> None:
        self.final = output
        self.torch_to_onnx(input, self.temp_onnx_path)
        self.onnx_to_keras(self.temp_onnx_path, output)

    def keras_to_torch(self, input, output) -> None:
        self.final = output
        self.keras_to_onnx(input, self.temp_onnx_path)
        self.onnx_to_torch(self.temp_onnx_path, output)

    def torch_to_tf(self, input, output) -> None:
        self.final = output
        self.torch_to_onnx(input, self.temp_onnx_path)
        self.onnx_to_tf(self.temp_onnx_path, output)

    def coreml_to_tf(self, input, output) -> None:
        self.final = output
        self.coreml_to_onnx(input, self.temp_onnx_path)
        self.onnx_to_tf(self.temp_onnx_path, output)
    ########### Combinations ###########

    function_mappings = {   # You can always add here new function and its body as above
            'torch_to_onnx': torch_to_onnx,  # working
            'tf_to_onnx': tf_to_onnx,    # working
            'keras_to_onnx': keras_to_onnx,  # working
            'coreml_to_onnx': coreml_to_onnx,  # working
            'onnx_to_tf': onnx_to_tf,   # working
            'onnx_to_keras': onnx_to_keras,  # working
            'onnx_to_torch': onnx_to_torch,   # working
            'coreml_to_tf': coreml_to_tf,    # model type issue, float 32 and 64
            'onnx_to_coreml': onnx_to_coreml,    # working
            'tf_to_tflite': tf_to_tflite,    # working
            'tf_to_coreml': tf_to_coreml,    # DOABLE, lib is old, it calls tf.graph() directly, so tf 2 doesnt work
            'torch_to_tflite': torch_to_tflite,  # working
            'torch_to_tf': torch_to_tf,  # working
            'torch_to_torchscript': torch_to_torchscript,  # working
            'torch_to_coreml': torch_to_coreml,  # working
            'keras_to_torch': keras_to_torch,  # working
            'torch_to_keras': torch_to_keras,  # working
            'keras_to_tf': keras_to_tf,  # working
            }

    ext_mappings = {    # Add new extensions below if not present
            'pb': 'tf',
            'h5': 'keras',
            'pt': 'torch',
            'pth': 'torch',
            'tflite': 'tflite',
            'mlmodel': 'coreml',
            'onnx': 'onnx',
            'torchscript': 'torchscript',
    }

    def convert(self) -> None:
        """Pick the conversion from the two file extensions and run it.

        Logs an error and exits with status 1 when either path has no extension,
        an extension is unknown, or no converter exists for the format pair.
        (Validation no longer uses ``assert``, which ``python -O`` strips.)
        """
        model_suffix = os.path.splitext(self.saved_model)[1][1:]
        output_suffix = os.path.splitext(self.output)[1][1:]

        if not model_suffix:
            self._log.error(f'ERROR: The given "saved-model" is not a file. It doesnt have an extension at the end: {self.saved_model}.')
            sys.exit(1)
        if not output_suffix:
            self._log.error(f'ERROR: The requested "output" is not a file. It doesnt have an extension at the end: {self.output}.')
            sys.exit(1)

        model_fmt = self.ext_mappings.get(model_suffix)
        output_fmt = self.ext_mappings.get(output_suffix)
        if model_fmt is None or output_fmt is None:
            self._log.error(f'ERROR: Invalid extension key given either for the saved-model or for the output "model_extension: {model_suffix}, output_extension: {output_suffix}". Valid extensions are: {self.ext_mappings.keys()}')
            sys.exit(1)

        func = self.function_mappings.get(model_fmt + '_to_' + output_fmt)
        if func is None:
            self._log.error(f'ERROR: Invalid conversion requested, try again. The valid conversions are: {self.function_mappings.keys()}')
            sys.exit(1)

        self.func = func
        self._log.info(f'          <--------> RUNNING THE CONVERSION: {self.func.__name__} <-------->          ')
        # function_mappings holds plain functions, so self is passed explicitly.
        self.func(self, self.saved_model, self.output)


############################################  MAIN  ##############################################
# Both --saved-model and --output must end in a file extension; the extensions select the source and destination formats. Each may be a bare file name or a path with directories (a bare --output name is written into ./model_conv).
# e.g. 'python model_conv.py --saved-model /path/to/saved_model_dir/saved_model.pb --output model.onnx'  --> converts the TF SavedModel in that .pb file's directory to ONNX

def build_arg_parser():
    """Return the command-line parser of the conversion script."""
    parser = argparse.ArgumentParser(
        description='Convert a model between PyTorch, TorchScript, ONNX, TensorFlow, TFLite, Keras and Core ML.')
    parser.add_argument('--saved-model', type=str, required=True, help='available formats are (.onnx, .mlmodel(coreml), .pt(torch), .pth(torch), .pb(tf), .tflite, .h5(keras) )')
    parser.add_argument('--output', type=str, required=True,
                        help='destination path; its extension selects the target format')
    # type=tuple would split each value into characters ('224' -> ('2', '2', '4')).
    parser.add_argument('--target-shape', type=int, nargs=3, default=(224, 224, 3),
                        metavar=('HEIGHT', 'WIDTH', 'CHANNELS'),
                        help='model input shape (default: 224 224 3)')
    parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
    parser.add_argument('--sample-file', type=str,
                        help='optional image used as the example input for ONNX export')
    return parser


if __name__ == '__main__':
    LOGGER = set_logging(__name__)

    args = build_arg_parser().parse_args()
    conv_model = ModelConversion(
        args.saved_model,
        args.output,
        sample_file_path=args.sample_file,
        target_shape=tuple(args.target_shape),
        opset=args.opset,
    )
    conv_model.convert()
    final_model = conv_model.final

    if os.path.isdir(final_model):
        # SavedModel outputs are directories; report the .pb file inside.
        pb_files = [fname for fname in os.listdir(final_model) if fname.endswith('.pb')]
        if not pb_files:
            raise Exception(f'ERROR: No model with required extension is found at: {final_model}')
        LOGGER.info(f'Final model is saved at: {os.path.join(final_model, pb_files[0])}')
    elif os.path.isfile(final_model) and os.path.splitext(args.output)[1][1:] == os.path.splitext(final_model)[1][1:]:
        LOGGER.info(f'Final model is saved at: {final_model}')
    else:
        raise Exception(f'ERROR: No model with required extension is found at: {final_model}')
    
