[Cherry pick] fix quant scale name (#44903)
* fix quant scale name (#44116)

* fix acc diff problem caused by pr #44116 (#44311)

Co-authored-by: handiz <35895648+ZhangHandi@users.noreply.github.com>
ceci3 and ZhangHandi committed Aug 10, 2022
1 parent 2676281 commit cbab018
Showing 2 changed files with 19 additions and 12 deletions.
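
Taken together, the two squashed commits rename the quantization scale variables from a ".scale" suffix to "@scale", presumably to match Paddle's existing "@"-style reserved suffixes (e.g. x@GRAD) and keep the marker distinct from ordinary dotted variable names, and they guard against re-creating scale nodes that already exist in the graph. A minimal sketch of the new naming convention; the tensor names are hypothetical and the helpers are illustrations, not the Paddle API:

# Sketch of the renamed scale-variable convention. Tensor names are
# hypothetical and these helpers are illustrations, not the Paddle API.
def quantized_scale_name(var_name):
    return "%s@scale" % var_name  # was "%s.scale" before this commit

def quant_dequant_scale_name(var_name):
    return "%s.quant_dequant@scale" % var_name  # was "%s.quant_dequant.scale"

assert quantized_scale_name("conv2d_0.w_0") == "conv2d_0.w_0@scale"
assert quant_dequant_scale_name("relu_0.tmp_0") == "relu_0.tmp_0.quant_dequant@scale"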
@@ -962,10 +962,10 @@ def _update_program(self):
         else:
             scale_dict = self._quantized_threshold
         for key, val in scale_dict.items():
-            utils.set_variable_data(self._scope, self._place, key + ".scale",
+            utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
-                                    key + ".quant_dequant.scale",
+                                    key + ".quant_dequant@scale",
                                     np.array([val], dtype=np.float32))
 
         if not self._onnx_format:
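
The hunk above writes each calibrated threshold into two persistable scale variables under the new names. A framework-free sketch of the same bookkeeping, with a plain dict standing in for the Paddle scope that utils.set_variable_data writes to, and hypothetical threshold values:

import numpy as np

# Hypothetical calibration results: tensor name -> threshold.
scale_dict = {"conv2d_0.w_0": 0.017, "relu_0.tmp_0": 5.31}

scope = {}  # stands in for the Paddle scope
for key, val in scale_dict.items():
    scope[key + "@scale"] = np.array([val], dtype=np.float32)
    scope[key + ".quant_dequant@scale"] = np.array([val], dtype=np.float32)

assert set(scope) == {
    "conv2d_0.w_0@scale", "conv2d_0.w_0.quant_dequant@scale",
    "relu_0.tmp_0@scale", "relu_0.tmp_0.quant_dequant@scale",
}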
27 changes: 17 additions & 10 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -906,7 +906,7 @@ def _quantized_scale_name(self, var_name):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
     def _is_skip_quant(self, graph, op_node):
         """
@@ -1246,8 +1246,8 @@ def _original_var_name(self, var_name):
             return var_name[:-len('.quantized')]
         if var_name.endswith('.dequantized'):
             return var_name[:-len('.dequantized')]
-        if var_name.endswith('.scale'):
-            return var_name[:-len('.scale')]
+        if var_name.endswith('@scale'):
+            return var_name[:-len('@scale')]
         else:
             return var_name
 
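The rename only holds up end to end because `_original_var_name` strips the same suffix that `_quantized_scale_name` now appends; if either side still used ".scale", scale variables would no longer map back to their source tensors. A self-contained roundtrip sketch of the suffix handling shown in this hunk:

def quantized_scale_name(var_name):
    return "%s@scale" % var_name

def original_var_name(var_name):
    # Strip the quantization suffixes handled in this hunk.
    for suffix in ('.quantized', '.dequantized', '@scale'):
        if var_name.endswith(suffix):
            return var_name[:-len(suffix)]
    return var_name

assert original_var_name(quantized_scale_name("fc_0.w_0")) == "fc_0.w_0"
assert original_var_name("fc_0.w_0") == "fc_0.w_0"  # non-suffixed names pass through
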
@@ -1440,11 +1440,18 @@ def apply(self, graph):
                     [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
                 continue
 
-            scale_node = graph.create_persistable_node(
-                name=self._scale_name(in_node.name()),
-                var_type=core.VarDesc.VarType.LOD_TENSOR,
-                shape=[1],
-                var_dtype=in_node.dtype())
+            try:
+                graph._find_node_by_name(
+                    graph.all_var_nodes(),
+                    self._scale_name(in_node.name()))
+                continue
+            except:
+                scale_node = graph.create_persistable_node(
+                    name=self._scale_name(in_node.name()),
+                    var_type=core.VarDesc.VarType.LOD_TENSOR,
+                    shape=[1],
+                    var_dtype=in_node.dtype())
 
             data_type = 'float64' if in_node.dtype() \
                 == core.VarDesc.VarType.FP64 else 'float32'
             _init_var_node(scale_node, np.ones([1], dtype=data_type),
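
This hunk looks like the fix for the "acc diff" regression (#44311): if a scale node for the input already exists, the pass now skips it instead of unconditionally creating a new node and re-initializing it to ones, which would clobber a previously computed scale. A framework-free sketch of that get-or-create guard, with a dict standing in for the graph rather than Paddle's IrGraph API:

import numpy as np

nodes = {"conv2d_0.w_0@scale": np.array([0.017], dtype=np.float32)}  # pre-existing

def ensure_scale_node(name):
    if name in nodes:
        return  # already present: keep its value (the `continue` in the diff)
    # Mirrors create_persistable_node followed by _init_var_node with ones.
    nodes[name] = np.ones([1], dtype=np.float32)

ensure_scale_node("conv2d_0.w_0@scale")  # kept at 0.017, not reset to 1.0
ensure_scale_node("fc_0.w_0@scale")      # newly created, initialized to 1.0
assert nodes["fc_0.w_0@scale"][0] == 1.0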
@@ -1705,7 +1712,7 @@ def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
             shape=var_node.shape(),
             var_dtype=var_node.dtype())
         scale_in_node = graph.create_persistable_node(
-            name="{}.quant_dequant.scale".format(var_node.name()),
+            name="{}.quant_dequant@scale".format(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
@@ -1954,7 +1961,7 @@ def _quantized_scale_name(self, var_name):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
     def _zero_point_name(self, var_name):
         """
