-
Notifications
You must be signed in to change notification settings - Fork 169
Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f17131f
42519cc
264adbb
7cbe5b9
1f7d17e
71a9f7a
d02365c
5a572da
fc0bb88
95da832
34c11ef
10e3e2b
9f0691f
fa8f4c8
d1fac44
22b8b73
ca7c0e8
3f857a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis | |
return | ||
|
||
def sync_quantizer_amax_across_dp(quantizer, parallel_state): | ||
"""Synchronize the amax across all ranks in the data parallel group.""" | ||
if isinstance(quantizer, SequentialQuantizer): | ||
for _q in quantizer: | ||
sync_quantizer_amax_across_dp(_q, parallel_state) | ||
|
@@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state): | |
for child in module.children(): | ||
if isinstance(child, (TensorQuantizer, SequentialQuantizer)): | ||
sync_quantizer_amax_across_dp(child, module.parallel_state) | ||
|
||
# TP sync: | ||
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same | ||
|
||
|
@@ -116,6 +116,7 @@ def sync_quantizer_amax_across_tp( | |
): | ||
if isinstance(quantizer, SequentialQuantizer): | ||
for _q in quantizer: | ||
"Syncing amax across TP for sequential quantizer" | ||
sync_quantizer_amax_across_tp( | ||
_q, linear_name, quantizer_type, axes_for_sync, parallel_state | ||
) | ||
|
@@ -598,19 +599,32 @@ def forward(self, input, *args, **kwargs): | |
# This will also perform distributed amax sync for input_quantizers | ||
max_calibrate(model, lambda model: None) | ||
|
||
def sync_act_scale_across_dp(module, data_parallel_group): | ||
"""Sync activation scale across Data Parallel (DP).""" | ||
if data_parallel_group.is_initialized(): | ||
dist.all_reduce( | ||
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group | ||
) | ||
|
||
for name, module in model.named_modules(): | ||
if ( | ||
is_quantized_linear(module) | ||
and hasattr(module, "awq_lite") | ||
and module.awq_lite.num_cache_steps > 0 | ||
): | ||
# Hack: MoEs forward all tokens through all experts if _if_calib is True | ||
module._if_calib = True | ||
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps | ||
|
||
if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( | ||
torch.isnan(module.awq_lite.weight_scale) | ||
): | ||
module.awq_lite.is_enabled = False | ||
# Hack: MoEs forward all tokens through all experts if _if_calib is True | ||
module._if_calib = True | ||
else: | ||
sync_act_scale_across_dp( | ||
module, | ||
module.parallel_state.data_parallel_group, | ||
) | ||
Comment on lines
+615
to
+627
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainVerify: NaN check placement may cause collective operation mismatch. The current implementation places the sync call in the
Since Past review comments suggest this should be resolved by either:
Consider adding a collective NaN check before the conditional: # Check if ANY rank has NaN
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
dist.all_reduce(has_nan, op=dist.ReduceOp.MAX,
group=module.parallel_state.data_parallel_group.group)
if has_nan.item() > 0:
module.awq_lite.is_enabled = False
else:
sync_act_scale_across_dp(module, module.parallel_state.data_parallel_group) Based on past review comments. 🏁 Script executed: #!/bin/bash
rg -n "sync_act_scale_across_dp" -C3 Length of output: 1217 🏁 Script executed: #!/bin/bash
sed -n '580,650p' modelopt/torch/quantization/model_calib.py Length of output: 3165 Collectively sync NaN detection before calling has_nan_local = (
torch.any(torch.isnan(module.awq_lite.act_scale))
or torch.any(torch.isnan(module.awq_lite.weight_scale))
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
dist.all_reduce(has_nan, op=dist.ReduceOp.MAX,
group=module.parallel_state.data_parallel_group.group)
if has_nan.item() > 0:
module.awq_lite.is_enabled = False
else:
sync_act_scale_across_dp(
module,
module.parallel_state.data_parallel_group,
) This ensures every rank participates in the collective operation and prevents deadlock. 🤖 Prompt for AI Agents
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jenchen13 this is a good point. We should make sure that module.awq_lite.is_enabled across the relevant parallel groups. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
AWQLiteHelper.cache_mode = False | ||
print_rank_0("awq_lite: Searching parameters...") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like we dont need separate methods |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,6 +13,7 @@ | |||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||
import copy | ||||||||||||||||||||||||||
from unittest.mock import patch | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import pytest | ||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||
|
@@ -22,7 +23,9 @@ | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import modelopt.torch.opt as mto | ||||||||||||||||||||||||||
import modelopt.torch.quantization as mtq | ||||||||||||||||||||||||||
import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite | ||||||||||||||||||||||||||
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm | ||||||||||||||||||||||||||
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer | ||||||||||||||||||||||||||
from modelopt.torch.quantization.utils import is_quantized_linear | ||||||||||||||||||||||||||
from modelopt.torch.utils import torch_to | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -116,40 +119,160 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N | |||||||||||||||||||||||||
mto.restore_from_modelopt_state(model_ref, state_dict) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def tensor_parallel_test_helper(model, config, tp_group, dp_group): | ||||||||||||||||||||||||||
# The input to fist layer, the column parallel should be the same across all tp ranks | ||||||||||||||||||||||||||
def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Can we rename this to _distributed_attr_check
Suggested change
|
||||||||||||||||||||||||||
quantizer_attr = getattr(quantizer, attr).clone() | ||||||||||||||||||||||||||
print("quantizer.attr before reduce", getattr(quantizer, attr)) | ||||||||||||||||||||||||||
dist.all_reduce(quantizer_attr, op=op, group=group) | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about we iterate through groups here:
Suggested change
|
||||||||||||||||||||||||||
print("quantizer.attr after reduce", getattr(quantizer, attr)) | ||||||||||||||||||||||||||
print("quantizer_attr after reduce", quantizer_attr) | ||||||||||||||||||||||||||
assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
original_awq_lite = model_calib_module.awq_lite | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): | ||||||||||||||||||||||||||
"""Function to mock awq_lite function to always use debug=True for testing""" | ||||||||||||||||||||||||||
return original_awq_lite(model, forward_loop, alpha_step, debug=True) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Comment on lines
+134
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forward the AWQ-Lite kwargs in the patch The -def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
- """Function to mock awq_lite function to always use debug=True for testing"""
- return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+ """Force awq_lite debug mode during tests without dropping optional args."""
+ return original_awq_lite(
+ model,
+ forward_loop,
+ alpha_step=alpha_step,
+ debug=True,
+ **kwargs,
+ ) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree +1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @realAsma Thank you for confirming! The fix will ensure that all optional parameters passed to |
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||
def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite): | ||||||||||||||||||||||||||
# The input to first layer, the column parallel should be the same across all tp ranks | ||||||||||||||||||||||||||
calib_data = model.get_dummy_input().cuda() | ||||||||||||||||||||||||||
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def forward_loop(model): | ||||||||||||||||||||||||||
model(calib_data) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Sanity check | ||||||||||||||||||||||||||
forward_loop(model) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: | ||||||||||||||||||||||||||
# Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks | ||||||||||||||||||||||||||
activation_amax = model.fc2.input_quantizer.amax.clone() | ||||||||||||||||||||||||||
dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
# Lets check the row parallel weight amax; it should be the same across all tp ranks | ||||||||||||||||||||||||||
weight_amax = model.fc2.weight_quantizer.amax.clone() | ||||||||||||||||||||||||||
dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
# Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks | ||||||||||||||||||||||||||
input_quantizer = model.fc1.input_quantizer | ||||||||||||||||||||||||||
pre_quant_scale = input_quantizer.pre_quant_scale.clone() | ||||||||||||||||||||||||||
dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
# Check activation scale for AWQ lite | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @realAsma For TP, I only test fc1 (column parallel) act scale during awq lite, because fc2 row parallel will fail. For DP/CP I can test both column + row parallel act scale. I'm assuming row parallel fails because it's split across the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||
model.fc1.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
group=tp_group, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
dist.destroy_process_group() | ||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||
def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): | ||||||||||||||||||||||||||
calib_data = model.get_dummy_input().cuda() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def forward_loop(model): | ||||||||||||||||||||||||||
model(calib_data) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Sanity check | ||||||||||||||||||||||||||
forward_loop(model) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Input quantizer amax | ||||||||||||||||||||||||||
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Weight quantizer amax | ||||||||||||||||||||||||||
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc1.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc2.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
# Check act scale | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc1.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
group=group, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc2.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
group=group, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||
def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): | ||||||||||||||||||||||||||
# Calib data should be same across each DP rank | ||||||||||||||||||||||||||
dp_rank = dist.get_rank(group=dp_group) | ||||||||||||||||||||||||||
calib_data = model.get_dummy_input(seed=dp_rank).cuda() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def forward_loop(model): | ||||||||||||||||||||||||||
model(calib_data) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): | ||||||||||||||||||||||||||
quantizer_attr = getattr(quantizer, attr).clone() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Perform all-reduce operations | ||||||||||||||||||||||||||
dist.all_reduce(quantizer_attr, op=op, group=tp_group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
dist.all_reduce(quantizer_attr, op=op, group=dp_group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) | ||||||||||||||||||||||||||
Comment on lines
+235
to
+243
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to define it again here?
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Input quantizer amax | ||||||||||||||||||||||||||
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks | ||||||||||||||||||||||||||
# Channel-wise (INT8) only expects same amax across row parallel ranks | ||||||||||||||||||||||||||
# Block-wise quantization does not expect same amax across row and column parallel ranks | ||||||||||||||||||||||||||
if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: | ||||||||||||||||||||||||||
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc1.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: | ||||||||||||||||||||||||||
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc2.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Check act scale | ||||||||||||||||||||||||||
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc1.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
jenchen13 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def auto_quantize_helper(model): | ||||||||||||||||||||||||||
model, search_state = mtq.auto_quantize( | ||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.