-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add test for torch models tracing #2658
Conversation
test/torch/model/test_jit.py
Outdated
torch.zeros(shape, dtype=model.input_types()[name]) | ||
for (name, shape) in model.input_shapes().items() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be nice to not have two different methods to get information on values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I guess that can be addressed separately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There also is to_torchscript
which lightning provides.
test/torch/model/test_jit.py
Outdated
from gluonts.torch.model.tft import TemporalFusionTransformerModel | ||
|
||
|
||
def get_model_and_input(model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the purpose of this, why not just use .example_input_array
?
Also, why should this be part of the parameter and not just be invoked in the test function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that using these completely dummy arrays is the best approach. Maybe for tracing the model yes, but then we should probably test the model vs scripted model using some other input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it's mostly a placeholder for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the purpose of this, why not just use .example_input_array?
That's not a property of torch.nn.Module
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right. Still looks like something which could be consolidated.
Description of changes: Add test to check that tracing a model results in the same computation, i.e., that the model's behaviour does not depend on input values.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup