Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for torch models tracing #2658

Merged
merged 5 commits into from
Feb 17, 2023
Merged

Conversation

lostella
Copy link
Contributor

@lostella lostella commented Feb 16, 2023

Description of changes: Add test to check that tracing a model results in the same computation, i.e., that the model's behaviour does not depend on input values.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@lostella lostella added tests This item concerns improving tests torch This concerns the PyTorch side of GluonTS labels Feb 16, 2023
Comment on lines 25 to 26
torch.zeros(shape, dtype=model.input_types()[name])
for (name, shape) in model.input_shapes().items()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice to not have two different methods to get information on values.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I guess that can be addressed separately?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

Copy link
Contributor

@jaheba jaheba left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There also is to_torchscript which lightning provides.

from gluonts.torch.model.tft import TemporalFusionTransformerModel


def get_model_and_input(model):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this, why not just use .example_input_array?

Also, why should this be part of the parameter and not just be invoked in the test function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that using these completely dummy arrays is the best approach. Maybe for tracing the model yes, but then we should probably test the model vs scripted model using some other input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it's mostly a placeholder for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of this, why not just use .example_input_array?

That's not a property of torch.nn.Module

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Still looks like something which could be consolidated.

@lostella lostella marked this pull request as ready for review February 16, 2023 11:20
@lostella lostella requested a review from jaheba February 16, 2023 11:20
@lostella lostella changed the title Add test for model tracing Add test for torch models tracing Feb 17, 2023
@lostella lostella enabled auto-merge (squash) February 17, 2023 11:14
@lostella lostella merged commit 485f8c4 into awslabs:dev Feb 17, 2023
@lostella lostella deleted the torchscript-test branch February 17, 2023 11:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tests This item concerns improving tests torch This concerns the PyTorch side of GluonTS
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants