diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 5e64b8f5..bd68257e 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -187,6 +187,7 @@ def run( jax_profiler_port: int = 9999, enable_model_warmup: bool = False, multi_sampling: bool = False, + lora_input_adapters_path: str | None = None, ) -> JetStreamServer: """Runs a server with a specified config. @@ -203,10 +204,16 @@ def run( jax_profiler_port: The port JAX profiler server (default to 9999). enable_model_warmup: The flag to enable model server warmup. multi_sampling: The flag to enable multi-sampling. + lora_input_adapters_path: Path to define the location of all lora adapters. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. """ + # TODO: Deleting the lora_input_adapters_path for now. + # Planning to use it in the next big PR. Currently accommodating it + # to fix the params mismatch between MaxText and JetStream + del lora_input_adapters_path + server_start_time = time.time() logging.info("Kicking off gRPC server.") # Setup Prometheus server