From 42f3904ed63ac08ce76285e35ab836c5d4e333d4 Mon Sep 17 00:00:00 2001
From: vipandya
Date: Tue, 28 Apr 2026 08:18:42 +0000
Subject: [PATCH 1/2] onnx int4 - remove temp files even on exception

Signed-off-by: vipandya
---
 modelopt/onnx/quantization/int4.py | 831 +++++++++++++++--------------
 1 file changed, 423 insertions(+), 408 deletions(-)

diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py
index b17431fb9b..f5a33ba479 100644
--- a/modelopt/onnx/quantization/int4.py
+++ b/modelopt/onnx/quantization/int4.py
@@ -533,147 +533,153 @@ def _quantize_awq_clip(
     augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(augmented_onnx_file)
 
-    save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
-    logger.info(f"Saving the model took {time.time() - t} seconds")
-
-    # Creating inference session and preparing inputs for calibration
-    session = create_inference_session(augmented_onnx_path, calibration_eps, input_shapes_profile)
-    inputs = []
-    for inp_d in data_reader:
-        inputs.append(inp_d)
-        assert isinstance(inp_d, dict)
-    layer_info = get_layer_info(onnx_model, nodes_to_exclude, block_size, **kwargs)
-    # Apply AWQ clip on selected weights
-    t = time.time()
-    alphas = {}
-    for i in tqdm(range(len(wa_pack)), desc="Running clip search..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
-
-        # First capture all the activation values after calibration data sweep
-        output_dicts = {}
-        for inp_d in inputs:
-            np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()}
-            output = session.run([act_tensor.name], np_inp_d)
-            out = np.asarray(output[0])
-            output_dicts.setdefault(act_tensor.name, []).append(out)
-
-        # Concatenating the activation tensors over all calib data
-        x = np.concatenate(output_dicts[act_tensor.name], axis=0)  # n_token, ci
-        w = numpy_helper.to_array(
-            weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
-        ).copy()
-        if do_transpose:
-            w = w.T
-        w = np.asarray(w)
-        num_bits = get_num_bits(layer_info, weight_tensor.name)
-        # Updating the block size as for 8bit quantization, per-channel quantization is used.
-        block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w)
-        awq_clip = AWQClipHelper(w, block_size_updated, **kwargs)
-        _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs)
-        alphas[weight_tensor.name] = awq_clip.best_alpha
-
-    logger.info(f"Clip search for all weights took {time.time() - t} seconds")
+    try:
+        save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
+        logger.info(f"Saving the model took {time.time() - t} seconds")
 
-    del session
+        # Creating inference session and preparing inputs for calibration
+        session = create_inference_session(
+            augmented_onnx_path, calibration_eps, input_shapes_profile
+        )
+        inputs = []
+        for inp_d in data_reader:
+            inputs.append(inp_d)
+            assert isinstance(inp_d, dict)
+        layer_info = get_layer_info(onnx_model, nodes_to_exclude, block_size, **kwargs)
+        # Apply AWQ clip on selected weights
+        t = time.time()
+        alphas = {}
+        for i in tqdm(range(len(wa_pack)), desc="Running clip search..."):
+            act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
 
-    # Compute quantized weights and scales which are needed for DQ nodes
-    t = time.time()
-    for i in tqdm(range(len(wa_pack)), desc="Quantizing the weights..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
-        gemm_io_type = cast("onnx.TensorProto.DataType", gemm_io_type)
+            # First capture all the activation values after calibration data sweep
+            output_dicts = {}
+            for inp_d in inputs:
+                np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()}
+                output = session.run([act_tensor.name], np_inp_d)
+                out = np.asarray(output[0])
+                output_dicts.setdefault(act_tensor.name, []).append(out)
 
-    if force_fp16:
-        gemm_io_type = onnx.TensorProto.FLOAT16
+            # Concatenating the activation tensors over all calib data
+            x = np.concatenate(output_dicts[act_tensor.name], axis=0)  # n_token, ci
+            w = numpy_helper.to_array(
+                weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
+            ).copy()
+            if do_transpose:
+                w = w.T
+            w = np.asarray(w)
+            num_bits = get_num_bits(layer_info, weight_tensor.name)
+            # Updating the block size as for 8bit quantization, per-channel quantization is used.
-        block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w)
-        qw, scale, _ = quant_tensor(w, block_size_updated, alpha=alpha, num_bits=num_bits)
-        if has_cupy:
-            qw = np.asnumpy(qw)
-            scale = np.asnumpy(scale)
-        if do_transpose:
-            qw = qw.T
-            scale = scale.T
-        scales[weight_tensor.name] = scale.astype(
-            onnx.helper.tensor_dtype_to_np_dtype(gemm_io_type)
-        )
-        gemm_weights_quantized[weight_tensor.name] = numpy.asarray(qw).astype(numpy.int8)
+            block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w)
+            awq_clip = AWQClipHelper(w, block_size_updated, **kwargs)
+            _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs)
+            alphas[weight_tensor.name] = awq_clip.best_alpha
 
-        # Change the input activation type to the expected type, fp16 by default
-        # TODO: cast input C for Gemm
-        _change_input_type(onnx_model.graph, act_tensor.name, gemm_io_type)
+        logger.info(f"Clip search for all weights took {time.time() - t} seconds")
 
-    logger.info(f"Quantizing actual weights took {time.time() - t} seconds")
+        del session
 
-    graph_gs = gs.import_onnx(onnx_model)
+        # Compute quantized weights and scales which are needed for DQ nodes
+        t = time.time()
+        for i in tqdm(range(len(wa_pack)), desc="Quantizing the weights..."):
+            act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
+            gemm_io_type = cast("onnx.TensorProto.DataType", gemm_io_type)
 
-    gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE)
-    gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS)
-    gather_w_map = None
-    gather_s_map = None
-    if gather_quantize_axis is not None:
-        gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
-            graph_gs,
-            nodes_to_exclude,
-            use_zero_point=False,
-            dq_only=True,
-            layer_info=layer_info,
-        )
+            if force_fp16:
+                gemm_io_type = onnx.TensorProto.FLOAT16
 
+            w = numpy_helper.to_array(
+                weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
+            ).copy()
+            if do_transpose:
+                w = w.T
+            w = np.asarray(w)
+
+            alpha = alphas.get(weight_tensor.name, 1)
+            num_bits = get_num_bits(layer_info, weight_tensor.name)
+            # Updating the block size as for 8bit quantization, per-channel quantization is used.
+            block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w)
+            qw, scale, _ = quant_tensor(w, block_size_updated, alpha=alpha, num_bits=num_bits)
+            if has_cupy:
+                qw = np.asnumpy(qw)
+                scale = np.asnumpy(scale)
+            if do_transpose:
+                qw = qw.T
+                scale = scale.T
+            scales[weight_tensor.name] = scale.astype(
+                onnx.helper.tensor_dtype_to_np_dtype(gemm_io_type)
+            )
+            gemm_weights_quantized[weight_tensor.name] = numpy.asarray(qw).astype(numpy.int8)
+
+            # Change the input activation type to the expected type, fp16 by default
+            # TODO: cast input C for Gemm
+            _change_input_type(onnx_model.graph, act_tensor.name, gemm_io_type)
+
+        logger.info(f"Quantizing actual weights took {time.time() - t} seconds")
+
+        graph_gs = gs.import_onnx(onnx_model)
+
+        gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE)
+        gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS)
+        gather_w_map = None
+        gather_s_map = None
+        if gather_quantize_axis is not None:
+            gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
+                graph_gs,
+                nodes_to_exclude,
+                use_zero_point=False,
+                dq_only=True,
+                layer_info=layer_info,
+            )
 
-    t = time.time()
-    # Apply column-major optimization if flag is set
-    # Transposes the weights and scales in-place
-    use_column_major = kwargs.get("use_column_major", False)
-    if use_column_major:
-        qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
-        dq_node_attributes = {"axis": 1, "block_size": block_size}
-    else:
-        dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
-    qdq.insert_dq_nodes(
-        graph_gs,
-        scales,
-        quantized_weights=gemm_weights_quantized,
-        attributes=dq_node_attributes,
-        layer_info=layer_info,
-    )
-    # Add transpose nodes for column-major if needed
-    if use_column_major:
-        qdq.insert_transpose_nodes_for_column_major(graph_gs)
-    if gather_w_map is not None:
-        assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
-        gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size}
+        t = time.time()
+        # Apply column-major optimization if flag is set
+        # Transposes the weights and scales in-place
+        use_column_major = kwargs.get("use_column_major", False)
+        if use_column_major:
+            qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
+            dq_node_attributes = {"axis": 1, "block_size": block_size}
+        else:
+            dq_node_attributes = {"axis": 0, "block_size": block_size}
+        scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
         qdq.insert_dq_nodes(
             graph_gs,
-            gather_s_map,
-            quantized_weights=gather_w_map,
-            attributes=gather_dq_node_attributes,
+            scales,
+            quantized_weights=gemm_weights_quantized,
+            attributes=dq_node_attributes,
             layer_info=layer_info,
         )
-    logger.info(f"Inserting DQ nodes took {time.time() - t} seconds")
-
-    logger.info("Exporting the quantized graph")
-    t = time.time()
-    model = gs.export_onnx(graph_gs)
-    # Set ir_version to 10, remove it once ORT supports ir_version 11
-    model.ir_version = 10
-    logger.info(f"Exporting took {time.time() - t} seconds")
+        # Add transpose nodes for column-major if needed
+        if use_column_major:
+            qdq.insert_transpose_nodes_for_column_major(graph_gs)
+        if gather_w_map is not None:
+            assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
+            gather_dq_node_attributes = {
+                "axis": gather_quantize_axis,
+                "block_size": gather_block_size,
+            }
+            qdq.insert_dq_nodes(
+                graph_gs,
+                gather_s_map,
+                quantized_weights=gather_w_map,
+                attributes=gather_dq_node_attributes,
+                layer_info=layer_info,
+            )
+        logger.info(f"Inserting DQ nodes took {time.time() - t} seconds")
 
-    try:
-        os.remove(augmented_onnx_path)
-        if use_external_data_format:
-            os.remove(augmented_onnx_path + "_data")
-    except OSError:
-        logger.warn("Augmented ONNX model or external data file was not found")
+        logger.info("Exporting the quantized graph")
+        t = time.time()
+        model = gs.export_onnx(graph_gs)
+        # Set ir_version to 10, remove it once ORT supports ir_version 11
+        model.ir_version = 10
+        logger.info(f"Exporting took {time.time() - t} seconds")
+    finally:
+        try:
+            os.remove(augmented_onnx_path)
+            if use_external_data_format:
+                os.remove(augmented_onnx_path + "_data")
+        except OSError:
+            logger.warn("Augmented ONNX model or external data file was not found")
 
     return model
 
@@ -1085,316 +1091,325 @@ def _quantize_awq_lite(
     augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(augmented_onnx_file)
 
-    save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
-    logger.info(f"Saving the model took {time.time() - t} seconds")
-
-    # Creating inference session and preparing inputs for calibration
-    session = create_inference_session(augmented_onnx_path, calibration_eps, input_shapes_profile)
-    inputs = []
-    for inp_d in data_reader:
-        inputs.append(inp_d)
-        assert isinstance(inp_d, dict)
-
-    gc.collect()
-
-    output_data = []
-
-    if enable_fast_path_using_high_sysram:
-        logger.info("Fast-path-using-high-sysram is enabled\n")
-
-        tensor_names_list = []
-        for i in tqdm(range(len(wa_pack)), desc="Getting tensor names..."):
-            act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
-            tensor_names_list.append(act_tensor.name)
+    try:
+        save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
+        logger.info(f"Saving the model took {time.time() - t} seconds")
 
-        for i in tqdm(range(len(inputs)), desc="Caching activations..."):
-            inp_d = inputs[i]
-            np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()}
-            output = session.run(tensor_names_list, np_inp_d)
-            output_data.append(output)
+        # Creating inference session and preparing inputs for calibration
+        session = create_inference_session(
+            augmented_onnx_path, calibration_eps, input_shapes_profile
+        )
+        inputs = []
+        for inp_d in data_reader:
+            inputs.append(inp_d)
+            assert isinstance(inp_d, dict)
 
-        del session
-        session = None
         gc.collect()
 
-    # Apply AWQ lite on selected weights
-    t = time.time()
-    awq_lite = [None] * len(wa_pack)
-    clip_alphas = {}
-
-    msg = "..."
-    if enable_weight_clipping:
-        msg = " and clip-range search..."
-
-    act_to_wa_pack_map, act_to_quant_nodes_weight_shape_map = (
-        get_act_to_weight_map_and_act_to_wa_pack_map(wa_pack)
-    )
-    if run_per_subgraph:
-        # TODO - add support for handling awq_lite mixed precision for per-subgraph implementation
-        awq_lite = run_awq_scale_search_per_subgraph(
-            wa_pack,
-            act_to_wa_pack_map,
-            act_to_quant_nodes_weight_shape_map,
-            augmented_onnx_path,
-            block_size,
-            use_zero_point,
-            session,
-            awq_lite,
-            inputs,
-            msg,
-            **kwargs,
-        )
-    else:
-        awq_lite, clip_alphas = run_awq_scale_search_per_node(
-            wa_pack,
-            augmented_onnx_path,
-            block_size,
-            use_zero_point,
-            session,
-            awq_lite,
-            inputs,
-            msg,
-            enable_weight_clipping,
-            enable_fast_path_using_high_sysram,
-            output_data,
-            clip_alphas,
-            layer_info,
-            **kwargs,
-        )
-    assert len(awq_lite) == len(wa_pack)
-    for i in range(len(awq_lite)):
-        assert awq_lite[i] is not None
-
-    if enable_weight_clipping:
-        assert len(clip_alphas.keys()) == len(wa_pack)
+        output_data = []
 
-    logger.info("AWQ scale search" + msg.strip(".") + f" took {time.time() - t} seconds")
+        if enable_fast_path_using_high_sysram:
+            logger.info("Fast-path-using-high-sysram is enabled\n")
 
-    if session is not None:
-        del session
-        session = None
-    if has_cupy:
-        np.get_default_memory_pool().free_all_blocks()
-    del output_data
-    gc.collect()
+            tensor_names_list = []
+            for i in tqdm(range(len(wa_pack)), desc="Getting tensor names..."):
+                act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
+                tensor_names_list.append(act_tensor.name)
 
-    # Compute quantized weights and scales which are needed for DQ nodes
-    t = time.time()
-    # Use a common mean scale for weights within a sub-graph
-    if fuse_nodes and not run_per_subgraph:
-        for wa_pack_idx_list in act_to_wa_pack_map.values():
-            group_awq_scale = [
-                awq_lite[wa_pack_idx].best_scale[:, np.newaxis] for wa_pack_idx in wa_pack_idx_list
-            ]
-            mean_awq_scale = np.concatenate(group_awq_scale, axis=1)
-            mean_awq_scale = mean_awq_scale.mean(axis=1)
-            for wa_pack_idx in wa_pack_idx_list:
-                awq_lite[wa_pack_idx].best_scale = mean_awq_scale
+            for i in tqdm(range(len(inputs)), desc="Caching activations..."):
+                inp_d = inputs[i]
+                np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()}
+                output = session.run(tensor_names_list, np_inp_d)
+                output_data.append(output)
 
-    for i in tqdm(range(len(wa_pack)), desc="Quantizing the weights..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
-        gemm_io_type = cast("onnx.TensorProto.DataType", gemm_io_type)
+            del session
+            session = None
+            gc.collect()
 
-        if force_fp16:
-            gemm_io_type = onnx.TensorProto.FLOAT16
+        # Apply AWQ lite on selected weights
+        t = time.time()
+        awq_lite = [None] * len(wa_pack)
+        clip_alphas = {}
 
-        w = numpy_helper.to_array(
-            weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
-        ).copy()
-        if do_transpose:
-            w = w.T
-        w = np.asarray(w)
+        msg = "..."
+        if enable_weight_clipping:
+            msg = " and clip-range search..."
 
-    w_scaled = w * awq_lite[i].best_scale[:, np.newaxis]
-    alpha = clip_alphas.get(weight_tensor.name, 1)
-    assert enable_weight_clipping or (alpha == 1), (
-        "clip range enabled without enabling weight-clipping param"
-    )
-    # Updating the block size as for 8bit quantization, per-channel quantization is used.
-        num_bits = get_num_bits(layer_info, weight_tensor.name)
-        block_size_updated = update_block_size(
-            block_size, layer_info, weight_tensor.name, w=w_scaled
-        )
-        qw, scale, zp = quant_tensor(
-            w_scaled,
-            block_size_updated,
-            alpha=alpha,
-            use_zero_point=use_zero_point,
-            num_bits=num_bits,
+        act_to_wa_pack_map, act_to_quant_nodes_weight_shape_map = (
+            get_act_to_weight_map_and_act_to_wa_pack_map(wa_pack)
         )
+        if run_per_subgraph:
+            # TODO - add support for handling awq_lite mixed precision for per-subgraph implementation
+            awq_lite = run_awq_scale_search_per_subgraph(
+                wa_pack,
+                act_to_wa_pack_map,
+                act_to_quant_nodes_weight_shape_map,
+                augmented_onnx_path,
+                block_size,
+                use_zero_point,
+                session,
+                awq_lite,
+                inputs,
+                msg,
+                **kwargs,
+            )
+        else:
+            awq_lite, clip_alphas = run_awq_scale_search_per_node(
+                wa_pack,
+                augmented_onnx_path,
+                block_size,
+                use_zero_point,
+                session,
+                awq_lite,
+                inputs,
+                msg,
+                enable_weight_clipping,
+                enable_fast_path_using_high_sysram,
+                output_data,
+                clip_alphas,
+                layer_info,
+                **kwargs,
+            )
+        assert len(awq_lite) == len(wa_pack)
+        for i in range(len(awq_lite)):
+            assert awq_lite[i] is not None
 
-        assert use_zero_point is True or zp is None, "zp is not according to use-zero-point setting"
-        if do_transpose:
-            qw = qw.T
-            scale = scale.T
-            if zp is not None:
-                zp = zp.T
-        if has_cupy:
-            qw = np.asnumpy(qw)
-            scale = np.asnumpy(scale)
-            if zp is not None:
-                zp = np.asnumpy(zp)
-        scales[weight_tensor.name] = scale.astype(
-            onnx.helper.tensor_dtype_to_np_dtype(gemm_io_type)
-        )
-        weight_dtype = numpy.int8
-        if zp is not None:
-            zero_points[weight_tensor.name] = numpy.asarray(zp).astype(numpy.uint8)
-            weight_dtype = numpy.uint8
-        gemm_weights_quantized[weight_tensor.name] = numpy.asarray(qw).astype(weight_dtype)
-        input_tensors[weight_tensor.name] = act_tensor.name
-        pqs_value = (
-            awq_lite[i]
-            .best_scale[:, np.newaxis]
-            .astype(onnx.helper.tensor_dtype_to_np_dtype(gemm_io_type))
-        ).T
-        if has_cupy:
-            pqs_value = np.asnumpy(pqs_value)
-        pre_quant_scale[weight_tensor.name] = pqs_value
+        if enable_weight_clipping:
+            assert len(clip_alphas.keys()) == len(wa_pack)
 
-        # Change the input activation type to the expected type, fp16 by default
-        # TODO: cast input C for Gemm
-        _change_input_type(onnx_model.graph, act_tensor.name, gemm_io_type)
+        logger.info("AWQ scale search" + msg.strip(".") + f" took {time.time() - t} seconds")
 
-    logger.info(f"Quantizing actual weights took {time.time() - t} seconds")
+        if session is not None:
+            del session
+            session = None
+        if has_cupy:
+            np.get_default_memory_pool().free_all_blocks()
+        del output_data
+        gc.collect()
 
-    # Fuse Mul nodes with parent node if possible
-    if fuse_nodes:
-        logger.info("Fusing pre-quant scale Mul nodes with parent node")
+        # Compute quantized weights and scales which are needed for DQ nodes
         t = time.time()
-        updated_nodes = set()
-        name_to_node_map = {node.name: node for node in onnx_model.graph.node}
-        initializer_map = {
-            initializer.name: initializer for initializer in onnx_model.graph.initializer
-        }
-        for parent, child_nodes in parent_child_nodes_map.items():
-            if parent == "root_0":
-                continue
-            parent = name_to_node_map[parent]
-            if parent.name in updated_nodes:
-                continue
-            # When fuse_nodes or run_per_subgraph is True,
-            # scales computed for each child_nodes will be same.
-            # Hence, picking pre_quant_scale corresponding to any child_nodes is acceptable
-            input_scale = np.asarray(pre_quant_scale[child_nodes[0].input[1]])
-            weight_tensor_names = [node.input[1] for node in child_nodes]
-            if (
-                is_fusible_scaling_op(parent.op_type)
-                and not all(initializer_map.get(inp) is None for inp in parent.input)
-                and len(input_name_to_nodes[child_nodes[0].input[0]]) == len(child_nodes)
-            ):
-                for inp in parent.input:
-                    if initializer_map.get(inp) is not None:
-                        tensor = initializer_map[inp]
-                        old_dim = tensor.dims
-                        tensor_array = numpy_helper.to_array(
-                            tensor,
-                            base_dir=os.path.dirname(augmented_onnx_path),
-                        )
-                        new_tensor = np.asarray(tensor_array) / input_scale
-                        new_tensor = new_tensor.reshape(old_dim)
-                        new_tensor = numpy_helper.from_array(new_tensor.get(), tensor.name)
-                        # replace initializer with new scaled array
-                        tensor.CopyFrom(new_tensor)
-                for w_name in weight_tensor_names:
-                    del pre_quant_scale[w_name]
-                updated_nodes.add(parent.name)
-            else:
-                scale_tensor = onnx.helper.make_tensor(
-                    name=parent.output[0] + "_pre_quant_scale",
-                    data_type=onnx.helper.np_dtype_to_tensor_dtype(input_scale.dtype),
-                    dims=input_scale.shape,
-                    vals=(1.0 / input_scale).flatten().tolist(),
-                )
-                mul_op_name = parent.output[0] + "_pre_quant_scale_out"
-                mul_node = onnx.helper.make_node(
-                    "Mul",
-                    inputs=[child_nodes[0].input[0], scale_tensor.name],
-                    outputs=[mul_op_name],
-                    name=child_nodes[0].input[0] + "_pre_quant_scale_mul",
-                )
-                for node in child_nodes:
-                    node.input[0] = mul_node.output[0]
-                for w_name in weight_tensor_names:
-                    del pre_quant_scale[w_name]
-                onnx_model.graph.initializer.append(scale_tensor)
-                onnx_model.graph.node.append(mul_node)
-
-        logger.info(f"Fusing pre-quant scale Mul nodes took {time.time() - t} seconds")
+        # Use a common mean scale for weights within a sub-graph
+        if fuse_nodes and not run_per_subgraph:
+            for wa_pack_idx_list in act_to_wa_pack_map.values():
+                group_awq_scale = [
+                    awq_lite[wa_pack_idx].best_scale[:, np.newaxis]
+                    for wa_pack_idx in wa_pack_idx_list
+                ]
+                mean_awq_scale = np.concatenate(group_awq_scale, axis=1)
+                mean_awq_scale = mean_awq_scale.mean(axis=1)
+                for wa_pack_idx in wa_pack_idx_list:
+                    awq_lite[wa_pack_idx].best_scale = mean_awq_scale
+
+        for i in tqdm(range(len(wa_pack)), desc="Quantizing the weights..."):
+            act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
+            gemm_io_type = cast("onnx.TensorProto.DataType", gemm_io_type)
+
+            if force_fp16:
+                gemm_io_type = onnx.TensorProto.FLOAT16
+
+            w = numpy_helper.to_array(
+                weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
+            ).copy()
+            if do_transpose:
+                w = w.T
+            w = np.asarray(w)
+
+            w_scaled = w * awq_lite[i].best_scale[:, np.newaxis]
+            alpha = clip_alphas.get(weight_tensor.name, 1)
+            assert enable_weight_clipping or (alpha == 1), (
+                "clip range enabled without enabling weight-clipping param"
+            )
+            # Updating the block size as for 8bit quantization, per-channel quantization is used.
+            num_bits = get_num_bits(layer_info, weight_tensor.name)
+            block_size_updated = update_block_size(
+                block_size, layer_info, weight_tensor.name, w=w_scaled
+            )
+            qw, scale, zp = quant_tensor(
+                w_scaled,
+                block_size_updated,
+                alpha=alpha,
+                use_zero_point=use_zero_point,
+                num_bits=num_bits,
+            )
 
+            assert use_zero_point is True or zp is None, (
+                "zp is not according to use-zero-point setting"
+            )
+            if do_transpose:
+                qw = qw.T
+                scale = scale.T
+                if zp is not None:
+                    zp = zp.T
+            if has_cupy:
+                qw = np.asnumpy(qw)
+                scale = np.asnumpy(scale)
+                if zp is not None:
+                    zp = np.asnumpy(zp)
+            scales[weight_tensor.name] = scale.astype(
+                onnx.helper.tensor_dtype_to_np_dtype(gemm_io_type)
+            )
+            weight_dtype = numpy.int8
+            if zp is not None:
+                zero_points[weight_tensor.name] = numpy.asarray(zp).astype(numpy.uint8)
+                weight_dtype = numpy.uint8
+            gemm_weights_quantized[weight_tensor.name] = numpy.asarray(qw).astype(weight_dtype)
+            input_tensors[weight_tensor.name] = act_tensor.name
+            pqs_value = (
+                awq_lite[i]
+                .best_scale[:, np.newaxis]
+                .astype(onnx.helper.tensor_dtype_to_np_dtype(gemm_io_type))
+            ).T
+            if has_cupy:
+                pqs_value = np.asnumpy(pqs_value)
+            pre_quant_scale[weight_tensor.name] = pqs_value
+
+            # Change the input activation type to the expected type, fp16 by default
+            # TODO: cast input C for Gemm
+            _change_input_type(onnx_model.graph, act_tensor.name, gemm_io_type)
+
+        logger.info(f"Quantizing actual weights took {time.time() - t} seconds")
+
+        # Fuse Mul nodes with parent node if possible
+        if fuse_nodes:
+            logger.info("Fusing pre-quant scale Mul nodes with parent node")
+            t = time.time()
+            updated_nodes = set()
+            name_to_node_map = {node.name: node for node in onnx_model.graph.node}
+            initializer_map = {
+                initializer.name: initializer for initializer in onnx_model.graph.initializer
+            }
+            for parent, child_nodes in parent_child_nodes_map.items():
+                if parent == "root_0":
+                    continue
+                parent = name_to_node_map[parent]
+                if parent.name in updated_nodes:
+                    continue
+                # When fuse_nodes or run_per_subgraph is True,
+                # scales computed for each child_nodes will be same.
+                # Hence, picking pre_quant_scale corresponding to any child_nodes is acceptable
+                input_scale = np.asarray(pre_quant_scale[child_nodes[0].input[1]])
+                weight_tensor_names = [node.input[1] for node in child_nodes]
+                if (
+                    is_fusible_scaling_op(parent.op_type)
+                    and not all(initializer_map.get(inp) is None for inp in parent.input)
+                    and len(input_name_to_nodes[child_nodes[0].input[0]]) == len(child_nodes)
+                ):
+                    for inp in parent.input:
+                        if initializer_map.get(inp) is not None:
+                            tensor = initializer_map[inp]
+                            old_dim = tensor.dims
+                            tensor_array = numpy_helper.to_array(
+                                tensor,
+                                base_dir=os.path.dirname(augmented_onnx_path),
+                            )
+                            new_tensor = np.asarray(tensor_array) / input_scale
+                            new_tensor = new_tensor.reshape(old_dim)
+                            new_tensor = numpy_helper.from_array(new_tensor.get(), tensor.name)
+                            # replace initializer with new scaled array
+                            tensor.CopyFrom(new_tensor)
+                    for w_name in weight_tensor_names:
+                        del pre_quant_scale[w_name]
+                    updated_nodes.add(parent.name)
+                else:
+                    scale_tensor = onnx.helper.make_tensor(
+                        name=parent.output[0] + "_pre_quant_scale",
+                        data_type=onnx.helper.np_dtype_to_tensor_dtype(input_scale.dtype),
+                        dims=input_scale.shape,
+                        vals=(1.0 / input_scale).flatten().tolist(),
+                    )
+                    mul_op_name = parent.output[0] + "_pre_quant_scale_out"
+                    mul_node = onnx.helper.make_node(
+                        "Mul",
+                        inputs=[child_nodes[0].input[0], scale_tensor.name],
+                        outputs=[mul_op_name],
+                        name=child_nodes[0].input[0] + "_pre_quant_scale_mul",
+                    )
+                    for node in child_nodes:
+                        node.input[0] = mul_node.output[0]
+                    for w_name in weight_tensor_names:
+                        del pre_quant_scale[w_name]
+                    onnx_model.graph.initializer.append(scale_tensor)
+                    onnx_model.graph.node.append(mul_node)
 
-        logger.info(f"Fusing pre-quant scale Mul nodes took {time.time() - t} seconds")
+            logger.info(f"Fusing pre-quant scale Mul nodes took {time.time() - t} seconds")
 
-    logger.info(
-        "Inserting DQ nodes and input_pre_quant_scale node using quantized weights and scales"
-    )
+        logger.info(
+            "Inserting DQ nodes and input_pre_quant_scale node using quantized weights and scales"
+        )
 
-    graph_gs = gs.import_onnx(onnx_model)
-
-    gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE)
-    gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS)
-    gather_w_map = None
-    gather_s_map = None
-    gather_zp_map = None
-    if gather_quantize_axis is not None:
-        gather_w_map, gather_s_map, gather_zp_map = _quantize_gather_nodes(
-            graph_gs,
-            nodes_to_exclude,
-            use_zero_point=use_zero_point,
-            dq_only=True,
-            layer_info=layer_info,
-        )
+        graph_gs = gs.import_onnx(onnx_model)
+
+        gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE)
+        gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS)
+        gather_w_map = None
+        gather_s_map = None
+        gather_zp_map = None
+        if gather_quantize_axis is not None:
+            gather_w_map, gather_s_map, gather_zp_map = _quantize_gather_nodes(
+                graph_gs,
+                nodes_to_exclude,
+                use_zero_point=use_zero_point,
+                dq_only=True,
+                layer_info=layer_info,
+            )
 
-    t = time.time()
-    # Apply column-major optimization if flag is set
-    # Transposes the weights and scales in-place
-    use_column_major = kwargs.get("use_column_major", False)
-    if use_column_major:
-        qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
-        dq_node_attributes = {"axis": 1, "block_size": block_size}
-    else:
-        dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
-    qdq.insert_dq_nodes(
-        graph_gs,
-        scales,
-        quantized_weights=gemm_weights_quantized,
-        attributes=dq_node_attributes,
-        zero_points=zero_points if use_zero_point else None,
-        layer_info=layer_info,
-    )
-    # Add transpose nodes for column-major if needed
-    if use_column_major:
-        qdq.insert_transpose_nodes_for_column_major(graph_gs)
-    if gather_w_map is not None:
-        assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
-        assert not use_zero_point or gather_zp_map, (
-            "zero-point setting and zero-point map not in sync for quantizable gather nodes"
-        )
-        gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size}
+        t = time.time()
+        # Apply column-major optimization if flag is set
+        # Transposes the weights and scales in-place
+        use_column_major = kwargs.get("use_column_major", False)
+        if use_column_major:
+            qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
+            dq_node_attributes = {"axis": 1, "block_size": block_size}
+        else:
+            dq_node_attributes = {"axis": 0, "block_size": block_size}
+        scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
         qdq.insert_dq_nodes(
             graph_gs,
-            gather_s_map,
-            quantized_weights=gather_w_map,
-            attributes=gather_dq_node_attributes,
-            zero_points=gather_zp_map if use_zero_point else None,
+            scales,
+            quantized_weights=gemm_weights_quantized,
+            attributes=dq_node_attributes,
+            zero_points=zero_points if use_zero_point else None,
             layer_info=layer_info,
         )
-    if pre_quant_scale:
-        qdq.insert_pre_quant_scale_nodes(graph_gs, input_tensors, pre_quant_scale)
-
-    logger.info(f"Inserting nodes took {time.time() - t} seconds")
+        # Add transpose nodes for column-major if needed
+        if use_column_major:
+            qdq.insert_transpose_nodes_for_column_major(graph_gs)
+        if gather_w_map is not None:
+            assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
+            assert not use_zero_point or gather_zp_map, (
+                "zero-point setting and zero-point map not in sync for quantizable gather nodes"
+            )
+            gather_dq_node_attributes = {
+                "axis": gather_quantize_axis,
+                "block_size": gather_block_size,
+            }
+            qdq.insert_dq_nodes(
+                graph_gs,
+                gather_s_map,
+                quantized_weights=gather_w_map,
+                attributes=gather_dq_node_attributes,
+                zero_points=gather_zp_map if use_zero_point else None,
+                layer_info=layer_info,
+            )
+        if pre_quant_scale:
+            qdq.insert_pre_quant_scale_nodes(graph_gs, input_tensors, pre_quant_scale)
 
-    logger.info("Exporting the quantized graph")
-    t = time.time()
-    model = gs.export_onnx(graph_gs)
-    # Set ir_version to 10, remove it once ORT supports ir_version 11
-    model.ir_version = 10
-    logger.info(f"Exporting took {time.time() - t} seconds")
+        logger.info(f"Inserting nodes took {time.time() - t} seconds")
 
-    try:
-        os.remove(augmented_onnx_path)
-        if use_external_data_format:
-            os.remove(augmented_onnx_path + "_data")
-    except OSError:
-        logger.error("Augmented ONNX model or external data file was not found")
+        logger.info("Exporting the quantized graph")
+        t = time.time()
+        model = gs.export_onnx(graph_gs)
+        # Set ir_version to 10, remove it once ORT supports ir_version 11
+        model.ir_version = 10
+        logger.info(f"Exporting took {time.time() - t} seconds")
+    finally:
+        try:
+            os.remove(augmented_onnx_path)
+            if use_external_data_format:
+                os.remove(augmented_onnx_path + "_data")
+        except OSError:
+            logger.error("Augmented ONNX model or external data file was not found")
 
     return model

From ce5214b5d29d4ca5b6485ff78a2e43e9e44ae7b4 Mon Sep 17 00:00:00 2001
From: vipandya
Date: Tue, 28 Apr 2026 11:41:54 +0000
Subject: [PATCH 2/2] clear session before onnx file removal, handle removal of both onnx and data file

Signed-off-by: vipandya
---
 modelopt/onnx/quantization/int4.py | 42 ++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py
index f5a33ba479..d680b47cfc 100644
--- a/modelopt/onnx/quantization/int4.py
+++ b/modelopt/onnx/quantization/int4.py
@@ -480,6 +480,23 @@ def _augment_graph(
         augmented_outputs.add(act_tensor.name)
 
 
+def _remove_augmented_onnx(onnx_path: str, use_external_data_format: bool) -> None:
+    """Remove the augmented ONNX temp file and its external data companion (if any)."""
+    try:
+        os.remove(onnx_path)
+    except FileNotFoundError:
+        pass
+    except OSError as e:
+        logger.warning("Failed to remove augmented ONNX file: %s", e)
+    if use_external_data_format:
+        try:
+            os.remove(onnx_path + "_data")
+        except FileNotFoundError:
+            pass
+        except OSError as e:
+            logger.warning("Failed to remove augmented ONNX data file: %s", e)
+
+
 def _change_input_type(
     graph: onnx.GraphProto, input_name: str, gemm_io_type: onnx.TensorProto.DataType
 ):
@@ -533,6 +550,7 @@ def _quantize_awq_clip(
     augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(augmented_onnx_file)
 
+    session = None
     try:
         save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
         logger.info(f"Saving the model took {time.time() - t} seconds")
@@ -577,7 +595,7 @@ def _quantize_awq_clip(
 
         logger.info(f"Clip search for all weights took {time.time() - t} seconds")
 
-        del session
+        session = None
 
         # Compute quantized weights and scales which are needed for DQ nodes
         t = time.time()
@@ -674,12 +692,10 @@ def _quantize_awq_clip(
         model.ir_version = 10
         logger.info(f"Exporting took {time.time() - t} seconds")
     finally:
-        try:
-            os.remove(augmented_onnx_path)
-            if use_external_data_format:
-                os.remove(augmented_onnx_path + "_data")
-        except OSError:
-            logger.warn("Augmented ONNX model or external data file was not found")
+        if session is not None:
+            session = None
+            gc.collect()
+        _remove_augmented_onnx(augmented_onnx_path, use_external_data_format)
 
     return model
 
@@ -1091,6 +1107,7 @@ def _quantize_awq_lite(
     augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx")
     os.close(augmented_onnx_file)
 
+    session = None
    try:
         save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
         logger.info(f"Saving the model took {time.time() - t} seconds")
@@ -1180,7 +1197,6 @@ def _quantize_awq_lite(
         logger.info("AWQ scale search" + msg.strip(".") + f" took {time.time() - t} seconds")
 
         if session is not None:
-            del session
             session = None
         if has_cupy:
             np.get_default_memory_pool().free_all_blocks()
         del output_data
         gc.collect()
@@ -1404,12 +1420,10 @@ def _quantize_awq_lite(
         model.ir_version = 10
         logger.info(f"Exporting took {time.time() - t} seconds")
     finally:
-        try:
-            os.remove(augmented_onnx_path)
-            if use_external_data_format:
-                os.remove(augmented_onnx_path + "_data")
-        except OSError:
-            logger.error("Augmented ONNX model or external data file was not found")
+        if session is not None:
+            session = None
+            gc.collect()
+        _remove_augmented_onnx(augmented_onnx_path, use_external_data_format)
 
     return model