diff --git a/src/maxtext/inference/inference_microbenchmark.py b/src/maxtext/inference/inference_microbenchmark.py index c00d57a910..1adb0f36d5 100644 --- a/src/maxtext/inference/inference_microbenchmark.py +++ b/src/maxtext/inference/inference_microbenchmark.py @@ -332,8 +332,9 @@ def run_benchmarks(config): rng_shape = jax.ShapeDtypeStruct([4], jax.numpy.dtype("uint32")) for prefill_length in prefill_lengths: + is_bos = tokenizer_model.bos_id is not None prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = tokenizer_model.encode( - text, is_bos=True, prefill_lengths=[prefill_length] + text, is_bos=is_bos, prefill_lengths=[prefill_length] ) key_shape = jax.ShapeDtypeStruct([prefill_length], jax.numpy.dtype("int32"))