diff --git a/experimental/jax/inference/config/config.py b/experimental/jax/inference/config/config.py index 178e381a..cafe36c8 100644 --- a/experimental/jax/inference/config/config.py +++ b/experimental/jax/inference/config/config.py @@ -19,6 +19,7 @@ class ModelId: llama_2_7b_chat_hf = "meta-llama/Llama-2-7b-chat-hf" + llama_2_70b_chat_hf = "meta-llama/Llama-2-70b-chat-hf" @dataclasses.dataclass @@ -43,6 +44,15 @@ class Config: page_size=128, hbm_utilization=0.875, ), + ModelId.llama_2_70b_chat_hf: InferenceParams( + model_id=ModelId.llama_2_70b_chat_hf, + batch_size=100, + max_seq_length=2048, + max_input_length=1024, + prefill_chunk_sizes=[128, 256, 512, 1024], + page_size=128, + hbm_utilization=0.875, + ), } @classmethod