Run a sample query for a quantized model conditionally

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
janekl committed Jul 12, 2024
1 parent 3e2bb21 commit f599645
Showing 3 changed files with 6 additions and 2 deletions.
.github/workflows/cicd-main.yml (2 additions, 0 deletions)
@@ -256,6 +256,7 @@ jobs:
           quantization.num_calib_size=8 \
           inference.batch_size=2 \
           export.inference_tensor_parallel=2 \
+          export.sample_output=False \
           export.save_path=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo
       AFTER_SCRIPT: |
         rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo
@@ -273,6 +274,7 @@ jobs:
           quantization.algorithm=int8_sq \
           quantization.num_calib_size=8 \
           inference.batch_size=2 \
+          export.sample_output=False \
           export.save_path=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo
       AFTER_SCRIPT: |
         rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo
examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml (2 additions, 1 deletion)
@@ -44,4 +44,5 @@ export:
   inference_pipeline_parallel: 1 # Default using 1 PP for inference
   dtype: ${trainer.precision} # Default precision data type
   save_path: llama2-7b-${quantization.algorithm}.qnemo # Path where the quantized model will be saved
-  compress: false # Wheter save_path should be a tarball or a directory
+  compress: false # Whether save_path should be a tarball or a directory
+  sample_output: true # Whether to run a sample prompt before saving
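
With the key in the example config, the flag can also be toggled per run as a command-line override, which is what the CI jobs above do. A minimal invocation sketch, assuming the PTQ example script that consumes this config (the script path is an assumption; the argument values are taken from the CI job above for illustration):

    python examples/nlp/language_modeling/megatron_gpt_ptq.py \
        quantization.algorithm=int8_sq \
        quantization.num_calib_size=8 \
        inference.batch_size=2 \
        export.sample_output=False \
        export.save_path=llama2-7b-int8_sq.qnemo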
nemo/export/quantize/quantizer.py (2 additions, 1 deletion)
@@ -225,7 +225,8 @@ def export(self, model: MegatronGPTModel):
         assert self.export_config is not None, "Export config is not set"
         torch_dtype = torch_dtype_from_precision(self.export_config.dtype)

-        self._sample_output(model)
+        if self.export_config.get("sample_output", True):
+            self._sample_output(model)

         if model.cfg.megatron_amp_O2:
             model.model = unwrap_model(model.model, Float16Module)
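
Note the lookup via export_config.get("sample_output", True): configs written before this change lack the key and therefore keep the old behavior of always running the sample prompt, while new configs can opt out explicitly. A minimal, self-contained sketch of that fallback pattern, using an OmegaConf config as a stand-in for self.export_config (keys and values are illustrative):

    from omegaconf import OmegaConf

    # Stand-in for self.export_config; note there is no "sample_output" key yet.
    export_config = OmegaConf.create({"dtype": "bf16", "compress": False})

    # A missing key falls back to True, so legacy configs still run the sample prompt.
    if export_config.get("sample_output", True):
        print("Running a sample prompt before saving the quantized model")

    # Explicit opt-out, as the CI jobs set via export.sample_output=False.
    export_config.sample_output = False
    if export_config.get("sample_output", True):
        print("This branch is now skipped")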
