Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry pick] fix quant scale name #44903

Merged
merged 2 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -962,10 +962,10 @@ def _update_program(self):
else:
scale_dict = self._quantized_threshold
for key, val in scale_dict.items():
utils.set_variable_data(self._scope, self._place, key + ".scale",
utils.set_variable_data(self._scope, self._place, key + "@scale",
np.array([val], dtype=np.float32))
utils.set_variable_data(self._scope, self._place,
key + ".quant_dequant.scale",
key + ".quant_dequant@scale",
np.array([val], dtype=np.float32))

if not self._onnx_format:
Expand Down
27 changes: 17 additions & 10 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def _quantized_scale_name(self, var_name):
"""
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)

def _is_skip_quant(self, graph, op_node):
"""
Expand Down Expand Up @@ -1246,8 +1246,8 @@ def _original_var_name(self, var_name):
return var_name[:-len('.quantized')]
if var_name.endswith('.dequantized'):
return var_name[:-len('.dequantized')]
if var_name.endswith('.scale'):
return var_name[:-len('.scale')]
if var_name.endswith('@scale'):
return var_name[:-len('@scale')]
else:
return var_name

Expand Down Expand Up @@ -1440,11 +1440,18 @@ def apply(self, graph):
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue

scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=in_node.dtype())
try:
graph._find_node_by_name(
graph.all_var_nodes(),
self._scale_name(in_node.name()))
continue
except:
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=in_node.dtype())

data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(scale_node, np.ones([1], dtype=data_type),
Expand Down Expand Up @@ -1705,7 +1712,7 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_in_node = graph.create_persistable_node(
name="{}.quant_dequant.scale".format(var_node.name()),
name="{}.quant_dequant@scale".format(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.dtype())
Expand Down Expand Up @@ -1954,7 +1961,7 @@ def _quantized_scale_name(self, var_name):
"""
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)

def _zero_point_name(self, var_name):
"""
Expand Down