diff --git a/README.md b/README.md index 1c72eca3..15bf4fd3 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,8 @@ Need to manually modify the `config.json` in the checkpoint folder to make it a export input_ckpt_dir=Original llama weights directory export output_ckpt_dir=The output directory export model_name="llama-3" # or "llama-2", "gemma" -export quantize_type="int8_per_channel" # Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, setting this will quantize the weights +export quantize_weights=True # Whether to quantize weights +export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, "int8_per_channel" is the default option if not specified. python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type ``` diff --git a/convert_checkpoints.py b/convert_checkpoints.py index 96501ec4..4f1ade16 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -391,7 +391,7 @@ def main(argv) -> None: llama_model.Transformer.get_quantized_embedding_weight_to_scaler_map() ) - if FLAGS.quantize_type: + if FLAGS.quantize_weights: quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4 is_blockwise = "blockwise" in FLAGS.quantize_type weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else 1