In [None]:
!pip3 install tensorflow

In [3]:
import tensorflow as tf

In [None]:
# Show the TensorFlow version that was imported above (rendered as the cell output).
tf.__version__
# Optional check, disabled by default: print the protobuf compiler version.
# !protoc --version

In [None]:
# Download the TensorFlow model archive, unpack it, and delete the archive.
# NOTE: the final %cd changes the kernel's working directory to the extracted
# folder; the later cells open "frozen_inference_graph.pb" and write the swapped
# graph using paths relative to that directory.
# NOTE: re-running this cell without restarting the kernel starts from inside
# the extracted folder, so the download and %cd would nest one level deeper.
!wget 'http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz'
!tar -xzf ssd_mobilenet_v1_coco_2018_01_28.tar.gz
!rm ssd_mobilenet_v1_coco_2018_01_28.tar.gz
%cd ssd_mobilenet_v1_coco_2018_01_28/

In [6]:
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.platform import gfile
from tensorflow.python.framework import ops
from tensorflow.python.framework import importer

# Parse the protobuf graph definition from the frozen graph file.
# The GraphDef is created and filled in one place (no wrongly-typed "" placeholder),
# so a failed read/parse raises here instead of leaving a str behind that would
# only fail later when the nodes are iterated.
with gfile.GFile("frozen_inference_graph.pb", "rb") as f:
  graph_def_proto = graph_pb2.GraphDef()
  graph_def_proto.ParseFromString(f.read())

# Load the parsed definition into a fresh, throwaway graph to check that
# TensorFlow accepts it. The file is already closed at this point; only the
# in-memory `graph_def_proto` is used by the following cells.
with session.Session(graph=ops.Graph()) as sess:
  importer.import_graph_def(graph_def_proto)

In [7]:
import numpy as np

# Byte-swap helpers: reverse the byte order of every fixed-width element in a raw buffer.
def swap16(x):
  """Reverse the byte order of each 16-bit element in a raw buffer.

  Args:
    x: bytes-like object, interpreted as a flat sequence of 2-byte elements
      (np.frombuffer rejects buffers whose size is not a multiple of 2).

  Returns:
    A new ``bytes`` object of the same length with every 2-byte group reversed.
  """
  halfwords = np.frombuffer(x, dtype=np.uint16)
  flipped = halfwords.byteswap()
  return flipped.tobytes()

def swap32(x):
  """Reverse the byte order of each 32-bit element in a raw buffer.

  Args:
    x: bytes-like object, interpreted as a flat sequence of 4-byte elements
      (np.frombuffer rejects buffers whose size is not a multiple of 4).

  Returns:
    A new ``bytes`` object of the same length with every 4-byte group reversed.
  """
  words = np.frombuffer(x, dtype=np.uint32)
  flipped = words.byteswap()
  return flipped.tobytes()

def swap64(x):
  """Reverse the byte order of each 64-bit element in a raw buffer.

  Args:
    x: bytes-like object, interpreted as a flat sequence of 8-byte elements
      (np.frombuffer rejects buffers whose size is not a multiple of 8).

  Returns:
    A new ``bytes`` object of the same length with every 8-byte group reversed.
  """
  quadwords = np.frombuffer(x, dtype=np.uint64)
  flipped = quadwords.byteswap()
  return flipped.tobytes()

def byte_swap_array(content, bytes_per_elem):
    """Byte-swap a raw buffer whose elements are ``bytes_per_elem`` bytes wide.

    Args:
        content: raw buffer to convert.
        bytes_per_elem: width in bytes of the unit to swap.

    Returns:
        The swapped bytes for widths 2, 4 and 8. For width 1, and for any other
        width, ``content`` is returned unchanged.
    """
    # Single-byte data has no byte order, and unknown widths are passed through.
    if bytes_per_elem not in (2, 4, 8):
        return content

    # Width -> helper that reverses each group of that many bytes.
    swap_by_width = {2: swap16, 4: swap32, 8: swap64}
    return swap_by_width[bytes_per_elem](content)

In [8]:
# TensorFlow dtypes grouped by the width (in bytes) of the unit that has to be
# byte-swapped. Single-byte types need no swapping at all.
set_allowed = {tf.string, tf.qint8, tf.quint8, tf.bool, tf.int8, tf.uint8}
set_16bits = {tf.bfloat16, tf.half, tf.qint16, tf.quint16, tf.uint16, tf.int16}
set_32bits = {tf.float32, tf.int32, tf.qint32, tf.uint32}
set_64bits = {tf.int64, tf.uint64, tf.double}
# Complex values are swapped per component (real and imaginary part separately):
# complex64 is two 32-bit floats, complex128 is two 64-bit floats.
# These two sets were referenced below but never defined, which raised NameError
# for every dtype that did not match one of the sets above.
set_complex64 = {tf.complex64}
set_complex128 = {tf.complex128}

# Separate the tensors according to tf.dtype of their tensor_content
def byte_swap_tensor(orig_tensor) -> bytes:
    """Return the byte-swapped ``tensor_content`` of a tensor message.

    The tensor itself is not modified; the caller is responsible for storing
    the returned bytes.

    Args:
        orig_tensor: object exposing ``dtype`` and ``tensor_content``; in this
            notebook it is the ``tensor`` field of a graph node attribute.
            NOTE(review): membership in the dtype sets relies on ``dtype``
            comparing equal to the ``tf.DType`` members - confirm for the
            TensorFlow version in use.

    Returns:
        The content with every element's byte order reversed. Content of
        single-byte dtypes is returned unchanged, and so is content of
        unsupported dtypes (after printing a notice).
    """
    datatype = orig_tensor.dtype
    content = orig_tensor.tensor_content

    if datatype in set_allowed:
        bytes_per_elem = 1
    elif datatype in set_16bits:
        bytes_per_elem = 2
    elif datatype in set_32bits:
        bytes_per_elem = 4
    elif datatype in set_64bits:
        bytes_per_elem = 8
    elif datatype in set_complex64:
        # Swap each 4-byte component, not the 8-byte pair as a whole.
        bytes_per_elem = 4
    elif datatype in set_complex128:
        # Swap each 8-byte component, not the 16-byte pair as a whole.
        bytes_per_elem = 8
    else:
        print("Byteswapping not supported for datatype")
        return content

    return byte_swap_array(content, bytes_per_elem)

In [None]:
# Iterate over the nodes of the top-level graph to find every attribute that holds a tensor.
# The `tensor_content` field has the byte order of the machine it was created on
# (an x86 machine with little-endian byte order), so it is converted from
# little-endian to big-endian, which makes it possible for big-endian
# architectures to load the TF frozen model.
#
# NOTE: this cell rewrites `graph_def_proto` in place. Run it exactly once per
# load of the graph - running it a second time swaps the bytes back.

# Only a short prefix of each buffer is printed; dumping the complete raw
# content of every constant floods the notebook output for a model this size.
PREVIEW_BYTES = 16

for n in graph_def_proto.node:
    for attr_name in n.attr:
        attr_value = n.attr[attr_name]
        if not attr_value.HasField("tensor"):
            continue

        tensor_value = attr_value.tensor
        original_content = tensor_value.tensor_content
        # Tensors whose data is not stored in tensor_content are left alone.
        if len(original_content) == 0:
            continue

        swapped_tensor_content = byte_swap_tensor(tensor_value)
        tensor_value.tensor_content = swapped_tensor_content

        total_bytes = len(original_content)
        print(f"       Before: {attr_name} => {original_content[:PREVIEW_BYTES]} ({total_bytes} bytes)")
        print(f"       After: {attr_name} => {swapped_tensor_content[:PREVIEW_BYTES]} ({total_bytes} bytes)")


In [None]:
# Serialize the graph definition (as modified by the byte-swap loop) back to its
# binary protobuf form; the next cell writes these bytes to disk.
binary_saved_model_proto = graph_def_proto.SerializeToString()
# Show the serialized size in bytes as the cell output.
len(binary_saved_model_proto)

29110448

In [None]:
# Save the rewritten frozen graph next to the original one
# (relative path, so it lands in the current working directory).
with open("swapped_frozen_inference_graph.pb", "wb") as f:
    f.write(binary_saved_model_proto)