diff --git a/convert_checkpoints.py b/convert_checkpoints.py index c3f83160..a06ec3d1 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -586,7 +586,7 @@ def main(argv) -> None: if FLAGS.quantize_weights: quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4 is_blockwise = "blockwise" in FLAGS.quantize_type - weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else 1 + weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else -1 start = time.perf_counter() state_dict = _quantize_state_dict( state_dict,