Skip to content

Commit

Permalink
Pass use_custom_all_reduce in test_nemo_export.py
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
  • Loading branch information
janekl committed Jun 6, 2024
1 parent c43c07b commit be70812
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/export/test_nemo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def run_trt_llm_inference(
max_batch_size=8,
max_input_token=128,
max_output_token=128,
use_custom_all_reduce=True,
ptuning=False,
p_tuning_checkpoint=None,
lora=False,
Expand Down Expand Up @@ -212,6 +213,7 @@ def run_trt_llm_inference(
max_output_token=max_output_token,
max_batch_size=max_batch_size,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
use_custom_all_reduce=use_custom_all_reduce,
use_lora_plugin=use_lora_plugin,
lora_target_modules=lora_target_modules,
save_nemo_model_config=True,
Expand Down Expand Up @@ -415,6 +417,12 @@ def get_args():
type=int,
default=128,
)
parser.add_argument(
"--use_custom_all_reduce",
type=str,
default="True",
choices=["True", "False"],
)
parser.add_argument(
"--p_tuning_checkpoint",
type=str,
Expand Down Expand Up @@ -507,6 +515,9 @@ def run_inference_tests(args):
else:
args.run_accuracy = False

if args.use_custom_all_reduce == "True":
args.use_custom_all_reduce = True

if args.run_accuracy:
if args.test_data_path is None:
raise Exception("test_data_path param cannot be None.")
Expand Down Expand Up @@ -551,6 +562,7 @@ def run_inference_tests(args):
max_batch_size=args.max_batch_size,
max_input_token=args.max_input_token,
max_output_token=args.max_output_token,
use_custom_all_reduce=args.use_custom_all_reduce,
ptuning=args.ptuning,
p_tuning_checkpoint=args.p_tuning_checkpoint,
lora=args.lora,
Expand Down

0 comments on commit be70812

Please sign in to comment.