
[FRONTEND][TFLITE] get input tensor information from graph #7400

Merged
merged 13 commits into apache:main on Feb 15, 2021

Conversation

euntaik
Contributor

@euntaik euntaik commented Feb 3, 2021

get input tensor information from graph

Contributor

@leandron leandron left a comment


In general I think it is an improvement to have a default function to come up with shapes in case we don't provide them, so thanks for the initiative @euntaik.

I'm pointing to some similar logic we have elsewhere in the code base that can be ported here. Please have a look.

In any case, it would be good to add some tests as well, to make sure this doesn't break in the future.

@@ -3539,7 +3539,62 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def get_tensor_shape(subgraph, tensor_idx):
"""Get the tensor shape.
Contributor


  • minor: the name of the argument doesn't match the actual argument
  • the types are not specified

Please review all docstrings being introduced here for the items above.

Comment on lines 3646 to 3653
if shape_dict:
    shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
else:
    shape = get_tensor_shape(subgraph, model_input)
if dtype_dict:
    dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
else:
    dtype = get_tensor_type(subgraph, model_input)
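The precedence rule in this hunk (a user-supplied dict wins; the graph is only consulted as a fallback) can be sketched as a small standalone function. This is a sketch under assumptions: `resolve_input_info` and its `graph_shape`/`graph_dtype` parameters are hypothetical stand-ins for the real frontend's `get_tensor_shape`/`get_tensor_type` lookups.

```python
def resolve_input_info(name, shape_dict, dtype_dict, graph_shape, graph_dtype):
    """Pick shape/dtype for one model input, preferring user overrides.

    Hypothetical sketch: graph_shape/graph_dtype stand in for values read
    from the TFLite graph via get_tensor_shape/get_tensor_type.
    """
    if shape_dict:
        # User gave explicit shapes; a missing entry falls back to None.
        shape = shape_dict.get(name)
    else:
        # No user shapes at all: use the shape stored in the graph.
        shape = graph_shape
    if dtype_dict:
        # User gave explicit dtypes; a missing entry defaults to float32.
        dtype = dtype_dict.get(name, "float32")
    else:
        dtype = graph_dtype
    return shape, dtype
```

For example, with a user-provided shape but no user dtypes, the shape override wins while the dtype still comes from the graph.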
Contributor


We have a similar function in TVMC that collects the same information being proposed here. I agree we should move what is in there, to unify the functionality here.

Can you have a look at the function I'm pointing to (below), spot why they are so different, and, in case you agree on what's the best approach, improve it here and remove it there?

@staticmethod
def _decode_type(n):
    return TFLiteFrontend._tflite_m[n]

@staticmethod
def _input_type(model):
    subgraph_count = model.SubgraphsLength()
    assert subgraph_count > 0
    shape_dict = {}
    dtype_dict = {}
    for subgraph_index in range(subgraph_count):
        subgraph = model.Subgraphs(subgraph_index)
        inputs_count = subgraph.InputsLength()
        assert inputs_count >= 1
        for input_index in range(inputs_count):
            input_ = subgraph.Inputs(input_index)
            assert subgraph.TensorsLength() > input_
            tensor = subgraph.Tensors(input_)
            input_shape = tuple(tensor.ShapeAsNumpy())
            tensor_type = tensor.Type()
            input_name = tensor.Name().decode("utf8")
            shape_dict[input_name] = input_shape
            dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type)
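To see what this walk over subgraphs and inputs does without loading a real model, here is a self-contained sketch against mock objects. The `Mock*` classes and the reduced type table are hypothetical test doubles; only the accessor names (`SubgraphsLength`, `Subgraphs`, `InputsLength`, `Inputs`, `Tensors`, `ShapeAsNumpy`, `Name`, `Type`) mirror the generated TFLite flatbuffer API used in the function above.

```python
class MockTensor:
    """Hypothetical stand-in for a tflite flatbuffer Tensor."""
    def __init__(self, name, shape, type_code):
        self._name, self._shape, self._type = name, shape, type_code
    def Name(self):
        return self._name.encode("utf8")  # real API returns bytes
    def ShapeAsNumpy(self):
        return self._shape
    def Type(self):
        return self._type

class MockSubgraph:
    """Hypothetical stand-in for a tflite flatbuffer SubGraph."""
    def __init__(self, tensors, inputs):
        self._tensors, self._inputs = tensors, inputs
    def InputsLength(self):
        return len(self._inputs)
    def Inputs(self, i):
        return self._inputs[i]  # index into the tensor table
    def TensorsLength(self):
        return len(self._tensors)
    def Tensors(self, i):
        return self._tensors[i]

class MockModel:
    """Hypothetical stand-in for a tflite flatbuffer Model."""
    def __init__(self, subgraphs):
        self._subgraphs = subgraphs
    def SubgraphsLength(self):
        return len(self._subgraphs)
    def Subgraphs(self, i):
        return self._subgraphs[i]

# Subset of the real TFLite TensorType table (FLOAT32=0, UINT8=3).
TFLITE_TYPE_MAP = {0: "float32", 3: "uint8"}

def input_info(model):
    """Collect per-input (shape, dtype) dicts across all subgraphs."""
    shape_dict, dtype_dict = {}, {}
    for si in range(model.SubgraphsLength()):
        subgraph = model.Subgraphs(si)
        for ii in range(subgraph.InputsLength()):
            tensor = subgraph.Tensors(subgraph.Inputs(ii))
            name = tensor.Name().decode("utf8")
            shape_dict[name] = tuple(tensor.ShapeAsNumpy())
            dtype_dict[name] = TFLITE_TYPE_MAP[tensor.Type()]
    return shape_dict, dtype_dict
```

Running `input_info` on a one-subgraph mock model with a single uint8 input tensor yields the same two dicts the TVMC helper builds.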

Contributor Author

@euntaik euntaik Feb 10, 2021


We have a similar function in TVMC that collects the same information being proposed here. I agree we should move what is in there, to unify the functionality here.

Oh, it was there all along. I think I missed your code since I was loading my models in a separate script to put the relay output into my compile passes.

Can you have a look at the function I'm pointing to (below) and spot why they are so different,

I don't see much difference except that your code accounts for models with more than one subgraph.

and in case you agree on what's the best approach, improve it here and remove it there?

My rationale for putting this code in the tflite.py file was:

  1. use the data in the graph, since it is already embedded there.
  2. place the code inside the frontend code, since it is frontend-dependent.

Contributor


Cool. I think we both agree that it is better to have the functionality only in tflite.py, and remove it from TVMC.

So I suggest we keep the one that accounts for multiple subgraphs, and move it from TVMC to the official frontend. If you agree, feel free to do it in this PR.

Contributor Author


Thanks, I will update the PR.

Contributor

@leandron leandron left a comment


Looks better now. A few comments below, mostly in the TVMC area.

@@ -3539,7 +3539,45 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def from_tflite(model, shape_dict, dtype_dict):
def _decode_type(n):
_tflite_m = {
Contributor


I see this is duplicated in tvmc/frontends.py - is there any reason why we can't reuse this one there?

Contributor Author


I missed it. Fixed.

mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
)
mod, params = relay.frontend.from_tflite(tflite_model)
Contributor


Since we merged #7366, users are able to provide shapes in tvmc from the outside. Can you have a look at that one and adjust?

Suggested change
mod, params = relay.frontend.from_tflite(tflite_model)
mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict)

cc @CircleSpin @hogepodge to help

Contributor Author


fixed it

Contributor


Is from_tflite() now duplicated? (I just looked quickly, I might be wrong.)

Contributor Author


You are right. Sorry for that.

@euntaik
Contributor Author

euntaik commented Feb 15, 2021

Any other comments?

@leandron
Contributor

Any other comments?

Sorry I forgot to check this again after CI.

@leandron
Contributor

leandron commented Feb 15, 2021

@mbaret @FrozenGene can you have a look at this one, and merge if you think it is ok?

Contributor

@mbaret mbaret left a comment


This looks like a good change to me.

@mbaret mbaret merged commit 6187e1c into apache:main Feb 15, 2021
Lokiiiiii pushed a commit to Lokiiiiii/tvm that referenced this pull request Mar 2, 2021
* [FRONTEND][TFLITE] get input tensor information from graph

* remove bare-except

* fix lint

* delete empty line

* comment change

* move some of the tflite frontend code from tvmc to tflite.py

* update shape and dtype when user provided them

* remove unused var. pass user provided shape_dict

* remove duplicate code
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Mar 2, 2021

3 participants