Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support conv1d quantization and skip calibration of zero-size tensors #48912

Merged
merged 4 commits into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ def __init__(
self._best_calibration_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}
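# If a tensor is zero-size during any calibration step, its name is
# recorded in self._zero_size_var_names so that later stages can skip it.
# NOTE: detection uses `not var_tensor.any()`, which is also true for a
# non-empty tensor whose elements are all zero.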
self._zero_size_var_names = set()
yghstill marked this conversation as resolved.
Show resolved Hide resolved
self._same_scale_tensor_list = same_scale_tensor_list
self._freeze_model = freeze_model
self._scale_dict = scale_dict
Expand Down Expand Up @@ -465,9 +468,12 @@ def quantize(self):

if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
if var_name not in self._quantized_var_avg:
continue
self._quantized_threshold[var_name] = np.array(
self._quantized_var_avg[var_name]
).mean()

if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()

Expand Down Expand Up @@ -741,6 +747,9 @@ def _sample_mse(self):
_logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
Expand Down Expand Up @@ -792,6 +801,9 @@ def _sample_emd(self):
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
Expand Down Expand Up @@ -845,6 +857,9 @@ def _sample_avg(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
if var_name not in self._quantized_var_avg:
self._quantized_var_avg[var_name] = []
Expand All @@ -857,7 +872,6 @@ def _sample_avg(self):
)
)
self._quantized_var_avg[var_name].append(abs_avg_value)
continue

def _sample_abs_max(self):
if self._quantized_threshold == {}:
Expand All @@ -884,6 +898,9 @@ def _sample_abs_max(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_threshold) or (
abs_max_value > self._quantized_threshold[var_name]
Expand Down Expand Up @@ -916,6 +933,9 @@ def _sample_min_max(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if (var_name not in self._quantized_var_min) or (
Expand All @@ -930,6 +950,11 @@ def _sample_min_max(self):
def _sample_histogram(self):
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if (not var_tensor.any()) or (
var_name not in self._sampling_act_histogram
):
self._zero_size_var_names.add(var_name)
continue
var_tensor_abs = np.abs(var_tensor)
bins = self._sampling_act_histogram[var_name][1]
hist, _ = np.histogram(var_tensor_abs, bins=bins)
Expand Down Expand Up @@ -964,6 +989,9 @@ def _sample_ptf(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
q_max = 2 ** (self._activation_bits - 1) - 1
scale8 = abs_max_value / q_max
Expand Down Expand Up @@ -1020,6 +1048,9 @@ def _collect_activation_abs_min_max(self):
'''
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = np.abs(var_tensor)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
Expand All @@ -1039,6 +1070,10 @@ def _init_sampling_act_histogram(self):
Based on the min/max value, init the sampling_act_histogram.
'''
for var_name in self._quantized_act_var_name:
if (var_name in self._zero_size_var_names) and (
var_name not in self._sampling_act_abs_min_max
):
continue
if var_name not in self._sampling_act_histogram:
min_val = self._sampling_act_abs_min_max[var_name][0]
max_val = self._sampling_act_abs_min_max[var_name][1]
Expand Down Expand Up @@ -1077,6 +1112,10 @@ def _calculate_kl_hist_threshold(self):
self._quantized_var_threshold[var_name] = weight_threshold

for var_name in self._quantized_act_var_name:
if (var_name in self._zero_size_var_names) and (
var_name not in self._sampling_act_histogram
):
continue
hist, hist_edeges = self._sampling_act_histogram[var_name]
if self._algo == "KL":
bin_width = hist_edeges[1] - hist_edeges[0]
Expand Down Expand Up @@ -1162,7 +1201,6 @@ def _update_program(self):
if self._same_scale_tensor_list is not None:
for tensor_list in self._same_scale_tensor_list:
max_scale = None
tmp_tensor_list = []
for tensor_name in tensor_list:
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
Expand Down Expand Up @@ -1261,21 +1299,40 @@ def _save_output_threshold(self):
self._calibration_scales = {}

def save_info(
op_node, out_var_name, threshold_map, out_info_name, quantized_type
op_node,
out_var_name,
threshold_map,
out_info_name,
argname_index,
quantized_type,
):
assert (
out_var_name in threshold_map
), "The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type
)
if (out_var_name in self._zero_size_var_names) and (
out_var_name not in threshold_map
):
_logger.warning(
"{} is zero-size tensor and unable to calibrate, so skip quant it.".format(
out_var_name
)
)
return
else:
assert (
out_var_name in threshold_map
), "The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type
)
if self._onnx_format:
# For easy extension, each var_node gets its own dict that stores its quantization parameters.
self._calibration_scales[var_name] = {}
self._calibration_scales[var_name]['scale'] = threshold_map[
var_name
self._calibration_scales[out_var_name] = {}
self._calibration_scales[out_var_name]['scale'] = threshold_map[
out_var_name
]
else:
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr(out_info_name, threshold_map[out_var_name])
op_node._set_attr(
argname_index[0] + str(argname_index[1]) + "_threshold",
threshold_map[out_var_name],
)
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
Expand All @@ -1285,52 +1342,23 @@ def analysis_and_save_info(op_node, out_var_name):
assert argname_index is not None, (
out_var_name + " is not the output of the op"
)
if self._algo == "KL":
# For compatibility, we save output threshold by two methods.
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
"out_threshold",
"post_kl",
)
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl",
)
elif self._algo == "hist":
if self._algo in ["KL", "hist"]:
# For compatibility, we save output threshold by two methods.
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
"out_threshold",
"post_hist",
)
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_hist",
argname_index,
"post_" + str(self._algo).lower(),
)

elif self._algo in ["avg", "abs_max", "mse", "emd", "ptf"]:
save_info(
op_node,
out_var_name,
self._quantized_threshold,
"out_threshold",
"post_" + str(self._algo),
)
save_info(
op_node,
out_var_name,
self._quantized_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
argname_index,
"post_" + str(self._algo),
)
elif self._algo == "min_max":
Expand Down
Loading