Expand to_torchscript to also support TorchScript's trace method #4140
There is a problem with your PR regarding the inputs. The example inputs passed to the trace function are not the same as the LightningModule.example_input_array.
@awaelchli I addressed your comment on the pull request review: #4142 (comment)
Moving the discussion from #4142 (comment) to here, in reply to @ananthsub:
Could you elaborate on why scripting is strongly recommended? From the TorchScript documentation, scripting and tracing simply focus on different use cases, without a preference for either one. Scripting is positioned as useful when you need control flow, while tracing is simpler to use (no code changes).
What flexibility would tracing need? Unlike scripting, where you build in some logic, tracing is just throwing in an example batch and getting the resulting TorchScript module back. About "determine how you want to export", could you elaborate?
Imagine a team with a data scientist who is good at training a model (Python) but knows little about engineering (in this example, C++), and an engineer who is good at C++ but has no idea what the model does internally. The engineer just wants the TorchScript model, and the data scientist has never brought a model to production but has heard that TorchScript is useful for transferring a model to production. In the case of scripting, the data scientist is likely the one who needs to figure out how TorchScript works and how to decorate his/her model for proper scripting. Tracing, on the other hand, requires no extra knowledge: just call the trace method with an example batch.
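To make that contrast concrete, here is a minimal sketch (the model and shapes are invented for illustration) of what each workflow asks of the user:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """A small control-flow-free model, safe for both workflows."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

model = TinyNet().eval()
example = torch.randn(1, 4)

# Tracing: supply an example batch, no model changes needed.
traced = torch.jit.trace(model, example)

# Scripting: compiles the Python source itself (may require annotations
# for models with non-Tensor arguments or complex control flow).
scripted = torch.jit.script(model)

# Both yield a serializable TorchScript module with identical behaviour here.
assert torch.allclose(traced(example), scripted(example))
```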
@suo is there guidance that PyTorch/JIT can share for tracing vs scripting? My understanding is that more use cases should adopt scripting, but maybe that's too naive on my part.
With the PR #4360 being merged, passing a dict as example input still fails, even though the [TorchScript docs (under Automatic Trace Checking)] describe tracing such inputs:

```
AttributeError                            Traceback (most recent call last)
<ipython-input-11-1f9f6fbe4f6c> in <module>()
----> 1 test_x(tmpdir)

10 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1670     if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
   1671         return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1672     if input.dim() == 2 and bias is not None:
   1673         # fused op is marginally faster
   1674         ret = torch.addmm(bias, input, weight.t())

AttributeError: 'dict' object has no attribute 'dim'
```

Whether Lightning, PyTorch, or TorchScript causes this error still needs to be determined. Therefore, it's better to keep this issue open until dicts are supported (or it is determined that dicts cannot be used). Related issue: #4378
The forward function above doesn't accept a dict; that's why that error comes up. It would need to be something like:

```python
def forward(self, batch):
    x = batch['x']
    return self.model(x)
```
@ananthsub Scripting is generally recommended. Since the tracer can only record observed tensor operations, there are a number of corner cases where the resulting graph may not generalize in surprising ways. Common pitfalls include: control flow, device-specific code, different sizes, etc. That said, if tracing works (and tracing generally works well for a large class of models that don't have any control flow), then it is perfectly fine to use.
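One of those pitfalls can be demonstrated with a toy model (invented for illustration) that branches on the data: the trace records only the branch taken for the example input, while the scripted version keeps both branches:

```python
import torch
import torch.nn as nn

class Gated(nn.Module):
    """Toy model with data-dependent control flow, a classic tracing pitfall."""
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x * -1

m = Gated()
pos = torch.ones(3)
neg = -torch.ones(3)

traced = torch.jit.trace(m, pos)   # only the positive branch is recorded
scripted = torch.jit.script(m)     # both branches are compiled

assert torch.equal(scripted(neg), m(neg))       # script generalizes correctly
assert not torch.equal(traced(neg), m(neg))     # trace silently runs the wrong branch
```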
🚀 Feature
Allow for the user to easily choose between TorchScript's script or trace method to create a module.
Motivation
While TorchScript's `script` method will work for simple models, it will not always work out of the box when models rely on Python variables to be set. This requires the user to manually annotate the model to not run into issues with `script()`. TorchScript's `trace` method on the other hand creates a traced module that is determined by running a Tensor through the network and tracking what happens during this process. This always works, but loses design choices (such as control flow) if present in the model. Both `script` and `trace` have their use cases, and with a minimal extension of this function, both methods can be used.
Pitch
- Add a `method` argument that can be set to either `script` or `trace` (defaulting to `script`, which results in the current behaviour).
- Add an `example_inputs` argument that defaults to `None` and can be set to any Tensor. If `None` is provided, this function will automatically try to use `self.example_input_array`. The example input is automatically sent to the correct device.

Note: the name `example_inputs` cannot be changed, as this is the name of the argument `trace()` expects. If named otherwise, there can be a conflict with `kwargs`.

This change should not break any older scripts, as it uses `script` by default.
Alternatives
Make no change and require the user to overwrite this function to use `trace`.
Additional context
Please assign me for this request.
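As additional context, the proposed extension could be sketched roughly like this. This is a minimal illustration, not the actual Lightning implementation: a plain nn.Module stands in for the LightningModule, the class and variable names are invented, and device handling is omitted.

```python
import torch
import torch.nn as nn

class DemoModule(nn.Module):
    """Stand-in for a LightningModule; only the parts relevant to the pitch."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self.example_input_array = torch.randn(1, 4)

    def forward(self, x):
        return self.layer(x)

    def to_torchscript(self, method: str = "script", example_inputs=None):
        mode = self.training
        self.eval()
        if method == "script":
            # default: current behaviour
            module = torch.jit.script(self)
        elif method == "trace":
            if example_inputs is None:
                # fall back to the module's example input, as proposed
                example_inputs = self.example_input_array
            module = torch.jit.trace(self, example_inputs)
        else:
            raise ValueError(f"method must be 'script' or 'trace', got {method!r}")
        self.train(mode)
        return module

m = DemoModule()
scripted = m.to_torchscript()                # default: script
traced = m.to_torchscript(method="trace")    # falls back to example_input_array
x = torch.randn(5, 4)
assert torch.allclose(scripted(x), traced(x))
```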