This repository has been archived by the owner on May 12, 2024. It is now read-only.

Commit: Densify
PINTO0309 committed Sep 23, 2021
1 parent ce307d3 commit f032b31
Showing 4 changed files with 81 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -7,7 +7,7 @@ ARG CPVER=cp38
ARG OPENVINOVER=2021.4.582
ARG OPENVINOROOTDIR=/opt/intel/openvino_2021
ARG TENSORRTVER=cuda11.3-trt8.0.1.6-ga-20210626
-ARG APPVER=v1.11.6
+ARG APPVER=v1.11.7
ARG wkdir=/home/user

# dash -> bash
1 change: 1 addition & 0 deletions README.md
@@ -139,6 +139,7 @@ Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite, ONNX, Ope
|122|FlexRoll|tf.roll|Flex OP|
|123|CONV_3D|tf.keras.layers.Conv3D||
|124|CONV_3D_TRANSPOSE|tf.nn.conv3d_transpose||
+|125|Densify|Const||

## 2. Environment
- Python3.6+
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
setup(
name="tflite2tensorflow",
scripts=scripts,
version="1.11.6",
version="1.11.7",
description="Generate saved_model, tfjs, tf-trt, EdgeTPU, CoreML, quantized tflite, ONNX, OpenVINO, Myriad Inference Engine blob and .pb from .tflite.",
long_description=long_description,
long_description_content_type="text/markdown",
83 changes: 78 additions & 5 deletions tflite2tensorflow/tflite2tensorflow.py
@@ -274,7 +274,7 @@ def parse_json(jsonfile_path):
json_tensor_details = j['subgraphs'][0]['tensors']
print('num of ops:', len(ops))
pprint.pprint(ops)
-return ops, json_tensor_details, op_types
+return ops, json_tensor_details, op_types, j


def read_int(buffer, offset, bit_size):
@@ -413,6 +413,7 @@ def read_flexbuffer(buffer, decode_strings=True):
def make_graph(
ops,
json_tensor_details,
+full_json,
op_types,
interpreter,
replace_swish_and_hardswish,
@@ -443,6 +444,19 @@ def make_graph(
'BFLOAT16': tf.bfloat16
}

+cast_type_np = {
+    'UINT8' : np.uint8,
+    'UINT16' : np.uint16,
+    'UINT32' : np.uint32,
+    'UINT64' : np.uint64,
+    'INT8' : np.int8,
+    'INT16' : np.int16,
+    'INT32' : np.int32,
+    'INT64' : np.int64,
+    'FLOAT16' : np.float16,
+    'FLOAT32' : np.float32
+}

class MaxUnpooling2D(Layer):
def __init__(self):
super(MaxUnpooling2D,self).__init__()
@@ -1552,9 +1566,15 @@ def upsampling2d_nearrest(x, size_height, size_width, align_corners, half_pixel_
tensors[output_detail['index']] = output_tensor

elif op_type == 'DEQUANTIZE':
-weights_detail = interpreter._get_tensor_details(op['inputs'][0])
-weights = interpreter.get_tensor(weights_detail['index'])
-output_tensor = weights.astype(np.float32)
+input_tensor1 = None
+input_detail = interpreter._get_tensor_details(op['inputs'][0])
+try:
+    input_tensor1 = tensors[op['inputs'][0]]
+except:
+    input_tensor1 = interpreter.get_tensor(input_detail['index'])
+    input_tensor1 = backward_quantization(input_detail, input_tensor1)
+
+output_tensor = input_tensor1.astype(np.float32)
output_detail = interpreter._get_tensor_details(op['outputs'][0])
tensors[output_detail['index']] = output_tensor
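For reference, a minimal sketch of the affine dequantization a helper such as backward_quantization is assumed to perform on the fallback path above (per-tensor case only; the quantization_parameters keys come from the TFLite Python interpreter's get_tensor_details(), while the helper's exact behavior is an assumption, not the project's verified implementation):

import numpy as np

def dequantize_sketch(tensor_detail, quantized_values):
    # Affine dequantization: real = scale * (quantized - zero_point).
    qp = tensor_detail['quantization_parameters']
    scales = np.asarray(qp['scales'], dtype=np.float32)
    zero_points = np.asarray(qp['zero_points'], dtype=np.float32)
    if scales.size == 0:
        # Not quantized: just cast to float32, as the DEQUANTIZE branch does.
        return quantized_values.astype(np.float32)
    # Per-tensor case; per-channel weights would need reshaping along
    # qp['quantized_dimension'] before broadcasting.
    return (quantized_values.astype(np.float32) - zero_points) * scales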

@@ -4122,6 +4142,58 @@ def complexabs_(x, tout):

tensors[output_detail['index']] = output_tensor

+elif op_type == 'DENSIFY':
+    json_tensor_info = searh_json_tensor_detail(interpreter._get_tensor_details(op['outputs'][0])['name'][:-1])
+    output_detail = interpreter._get_tensor_details(op['outputs'][0])
+
+    shape = json_tensor_info['shape']
+    dtype = cast_type_np[json_tensor_info['type']]
+    dim_metadata = json_tensor_info['sparsity']['dim_metadata']
+    traversal_order = json_tensor_info['sparsity']['traversal_order']
+    array_segments = None
+    array_indices = None
+    for dict_data in dim_metadata:
+        if 'format' in dict_data:
+            if dict_data['format'] == 'SPARSE_CSR':
+                array_segments = dict_data['array_segments']['values']
+                array_indices = dict_data['array_indices']['values']
+
+    denj = full_json['buffers'][output_detail['index']]['data']
+    b = np.array(denj).astype(np.uint8).tobytes()
+
+    dense_list = []
+    for i in range(len(b))[::2]:
+        dense_list.append(struct.unpack_from('<e', bytes(b[i:i+2]))[0])
+
+    starting_point = 0
+    dense_shape = np.asarray(shape)
+    groups = dense_shape[-1]
+    total_number_of_elements = dense_shape[0]
+    for i in dense_shape[1:]:
+        total_number_of_elements *= i
+
+    densify_values = np.zeros((total_number_of_elements))
+    sidx = 0
+    aidx = 0
+    didx = 0
+    addition = 0
+
+    for idx in range(total_number_of_elements):
+        if array_segments[sidx] == aidx:
+            addition = sidx * groups
+            sidx += 1
+        if array_indices[aidx] + addition == idx:
+            densify_values[idx] = dense_list[didx]
+            didx += 1
+            if aidx < len(array_indices) - 1:
+                aidx += 1
+
+    densify_values = densify_values.reshape(dense_shape)
+    output_tensor = densify_values.astype(dtype)
+
+    tensors[output_detail['index']] = output_tensor


elif op_type == 'CUSTOM':
"""
Convolution2DTransposeBias
@@ -4881,7 +4953,7 @@ def main():

jsonfile_path = f'./{model}.json'
gen_model_json(flatc_path, model_output_path, jsonfile_path, schema_path, model_path)
-ops, json_tensor_details, op_types = parse_json(jsonfile_path)
+ops, json_tensor_details, op_types, full_json = parse_json(jsonfile_path)

interpreter = tflite_interpreter(model_path)
interpreter.allocate_tensors()
@@ -4901,6 +4973,7 @@
TFLite_Detection_PostProcess_flg = make_graph(
ops,
json_tensor_details,
+full_json,
op_types,
interpreter,
replace_swish_and_hardswish,
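For reference, a minimal standalone sketch of the SPARSE_CSR densification the new DENSIFY branch performs: the segment and index arrays come from the tensor's sparsity dim_metadata in the generated JSON, and the raw buffer bytes are decoded as little-endian float16 just as the handler does. The assumption that the last dimension is the sparse one and that traversal_order is plain row-major mirrors the handler above; this is a sketch, not a general sparse-tensor decoder.

import numpy as np

def densify_csr(shape, array_segments, array_indices, values):
    # Expand a TFLite SPARSE_CSR tensor back into a dense array.
    #   shape          : dense shape from the model JSON, e.g. [out_ch, in_ch]
    #   array_segments : row pointers, length rows + 1
    #   array_indices  : position of each non-zero value within its row
    #   values         : non-zero values in traversal (row-major) order
    rows = int(np.prod(shape[:-1]))   # leading dimensions collapsed into rows
    cols = int(shape[-1])             # the sparse (last) dimension
    dense = np.zeros((rows, cols), dtype=np.float32)
    for r in range(rows):
        start, end = array_segments[r], array_segments[r + 1]
        dense[r, array_indices[start:end]] = values[start:end]
    return dense.reshape(shape)

# Hypothetical usage with values read the same way as in the handler above
# (buffer_index, shape, array_segments, array_indices are placeholders):
# raw = np.array(full_json['buffers'][buffer_index]['data'], dtype=np.uint8).tobytes()
# values = np.frombuffer(raw, dtype='<f2')   # '<e' / little-endian float16
# dense = densify_csr(shape, array_segments, array_indices, values)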
