From dc1c1a63333ee8a7afc4bb705b26842d27b535e3 Mon Sep 17 00:00:00 2001 From: Xiang Si Date: Thu, 20 Jun 2024 17:59:43 +0000 Subject: [PATCH] fix mixtral quantization scaler axis when dimension > 2 --- convert_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_checkpoints.py b/convert_checkpoints.py index c3f83160..a06ec3d1 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -586,7 +586,7 @@ def main(argv) -> None: if FLAGS.quantize_weights: quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4 is_blockwise = "blockwise" in FLAGS.quantize_type - weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else 1 + weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else -1 start = time.perf_counter() state_dict = _quantize_state_dict( state_dict,