
Use WIT for model trained in tfx #37

Open
orenkobo opened this issue Dec 31, 2019 · 7 comments

orenkobo commented Dec 31, 2019

Hi
I trained a model with tfx and it was exported as saved_model.pb.
Now, I want to reload it and visualize it using WIT.
How can I do this?

I couldn't find a way to do it, since when reloading the model with
imported = tf.saved_model.load(export_dir=trained_model_path)
I get an object of type
<tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7f3d71e456a0>
instead of an estimator.

Thanks

jameswex commented Jan 2, 2020

Looking at the official documentation (https://www.tensorflow.org/guide/saved_model#savedmodels_from_estimators), it seems that when you load a saved model from disk, what you get back is not an estimator. But you should still be able to call predict on that object by defining your own custom prediction function, as is done in that documentation, and then providing that custom predict function to the WitConfigBuilder.

Let me know if an approach similar to the predict(x) function in that link works for you.
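
For reference, the predict(x) pattern from that guide looks roughly like this (a sketch; the "predict" signature key and the "x" feature name come from the guide's estimator example, not from the TFX model in this issue):

import tensorflow as tf

imported = tf.saved_model.load(export_dir)

def predict(x):
    # Wrap the raw input in a tf.Example, serialize it, and call the
    # restored model's serving signature directly.
    example = tf.train.Example()
    example.features.feature["x"].float_list.value.extend([x])
    return imported.signatures["predict"](
        examples=tf.constant([example.SerializeToString()]))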

orenkobo commented Jan 5, 2020

@jameswex When using the predict function:

def predict(x):
    example = tf.train.Example()
    example.features.feature["x"].float_list.value.extend([x])
    return imported.signatures["predict"](
        examples=tf.constant([example.SerializeToString()]))

config_builder = WitConfigBuilder(test_examples, feats + ['level']).set_estimator_and_feature_spec(predict, feature_spec=[])
WitWidget(config_builder, height=1600)

(with imported being
imported = tf.saved_model.load(export_dir=trained_model_path)
which is of type <tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7f3d71e456a0>)

I get the error:
"<_Rendezvous of RPC that terminated with: status = StatusCode.UNAVAILABLE details = "DNS resolution failed" debug_error_string = "{"created":"@1578211571.031196087","description":"Failed to pick subchannel","file":"src/core/ext/filters/client_channel/client_channel.cc","file_line":3818,"referenced_errors":[{"created":"@1578211571.031189371","description":"Resolver transient failure","file":"src/core/ext/filters/client_channel/resolving_lb_policy.cc","file_line":268,"referenced_errors":[{"created":"@1578211571.031187685","description":"DNS resolution failed","file":"src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc","file_line":357,"grpc_status":14,"referenced_errors":[{"created":"@1578211571.031167691","description":"C-ares status is not ARES_SUCCESS: Domain name not found","file":"src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc","file_line":244}]}]}]}" >"

jameswex commented Jan 5, 2020

Since you have defined your own custom prediction function, instead of using a tf.Estimator, you want to change your code to something like:

config_builder = WitConfigBuilder(test_examples, feats + ['level']).set_custom_predict_fn(predict)
WitWidget(config_builder, height=1600)

orenkobo commented Jan 7, 2020

@jameswex OK, this is better now, but I have a problem: my features are of list types:

features {
  feature {
    key: "b_number"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "c_type"
    value {
      bytes_list {
        value: "motor"
      }
    }
  }

So I get the error:
[features { feature { key: "bearing_number" value { int64_list { value: 1 has type list, but expected one of: int, long, float

I have a total of more than 30 features, all of type float_list / int64_list / bytes_list. What is the best way to convert them all to int / long / float?
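
For context, every value in a tf.Example is wrapped in a repeated field, so getting plain scalars back means indexing into whichever list the feature holds. A minimal sketch (feature_to_scalar is a hypothetical helper; the feature names come from the dump above):

def feature_to_scalar(example, key):
    # Each tf.Example feature holds exactly one of int64_list,
    # float_list, or bytes_list; WhichOneof reports which one is set.
    feature = example.features.feature[key]
    kind = feature.WhichOneof("kind")
    return getattr(feature, kind).value[0]

# e.g. feature_to_scalar(ex, "b_number") -> 1
#      feature_to_scalar(ex, "c_type")   -> b'motor'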

jameswex commented Jan 8, 2020

Are you able to share a colab notebook with your code that loads up your saved model, so I could see the issue? I'm imagining that perhaps the reloaded saved model wants the examples in a very different format than the tf.Example format, so some conversion function will be necessary, but it's hard to know what that will need to be without playing with it myself.

orenkobo commented Jan 8, 2020

@jameswex It's internal code, so it will be problematic to share. I'll try to play with it and make it work. Thanks!

jameswex commented Jan 8, 2020

Looking at the example in the link I sent above, it seems your custom predict fn might need to take the provided tf.Examples, serialize them, and wrap them in a tf.constant, like:

def predict(examples):
    return imported.signatures["predict"](
        examples=tf.constant([ex.SerializeToString() for ex in examples]))

That would be due to how the restored saved model accepts inputs. But I haven't directly worked with this type of restored model before.
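
Putting the pieces from this thread together, the whole cell would then look roughly like this (a sketch, not a tested setup: trained_model_path, test_examples, and feats are the names used in earlier comments, and the "predict" signature key still depends on how the TFX model was exported):

import tensorflow as tf
from witwidget.notebook.visualization import WitConfigBuilder, WitWidget

imported = tf.saved_model.load(export_dir=trained_model_path)

def predict(examples):
    # WIT passes a list of tf.Example protos; serialize the whole batch
    # and feed it to the restored serving signature in one call.
    return imported.signatures["predict"](
        examples=tf.constant([ex.SerializeToString() for ex in examples]))

config_builder = WitConfigBuilder(test_examples, feats + ['level']).set_custom_predict_fn(predict)
WitWidget(config_builder, height=1600)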
