[pytorch][ao] force weight observer/fake_quant to be on the same device as the weight tensor (pytorch#106755)

Summary:
As title.
There's a corner case where both CPU and GPU are available: although the model has been moved to CPU, the newly created PTQ weight observer is still on GPU. As a result, this line will fail during convert: https://fburl.com/4rhipfvb
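
For illustration, here is a minimal sketch of the pattern this change applies, not the library's exact code path. Assumptions: `_module_device` is a hypothetical stand-in for `assert_and_get_unique_device`, and `default_weight_observer` is used only as an example PTQ weight observer. The idea is to create the observer, move it onto the module's device, and only then call it.

```python
import torch
from torch.ao.quantization import default_weight_observer

def _module_device(module: torch.nn.Module):
    # Hypothetical helper: return the unique device of the module's
    # parameters/buffers, or None if it has neither.
    devices = {p.device for p in module.parameters()} | {b.device for b in module.buffers()}
    assert len(devices) <= 1, f"expected a single device, got {devices}"
    return next(iter(devices)) if devices else None

float_module = torch.nn.Linear(8, 4)             # model already moved to its target device
weight_post_process = default_weight_observer()  # freshly created PTQ weight observer

# Move the observer to the weight's device before calling it, so convert
# does not end up mixing CPU and GPU tensors.
device = _module_device(float_module)
if device:
    weight_post_process.to(device)
weight_post_process(float_module.weight)
```

In the actual change, `assert_and_get_unique_device(float_module)` plays the role of the helper above.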

Test Plan: CI

Differential Revision: D48141494

Pull Request resolved: pytorch#106755
Approved by: https://github.com/jerryzh168
jiaxuzhu92 authored and Cyril-Anto committed Aug 17, 2023
1 parent 6583bd6 commit bf3d98b
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch/ao/quantization/fx/convert.py
@@ -51,6 +51,7 @@
     _get_module,
     _is_custom_module_lstm,
     _is_custom_module_mha,
+    assert_and_get_unique_device,
     get_custom_module_class_keys,
     create_getattr_from_value,
     collect_producer_nodes,
@@ -733,6 +734,9 @@ def convert_weighted_module(
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
+        device = assert_and_get_unique_device(float_module)
+        if device:
+            weight_post_process.to(device)
 
     # Call weight observer/fake_quant at least once to ensure the scales and zero points
     # have the right shapes. Note: there are two cases where we don't have to do this:
