From 22d79a84b549e7da52132da5b79d350221a892e0 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 4 Apr 2024 17:31:46 +0000 Subject: [PATCH] Move create_pytorch_engine to init. Makes the UX a little bit cleaner --- jetstream_pt/__init__.py | 1 + run_server.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jetstream_pt/__init__.py b/jetstream_pt/__init__.py index e69de29b..77d2a3df 100644 --- a/jetstream_pt/__init__.py +++ b/jetstream_pt/__init__.py @@ -0,0 +1 @@ +from jetstream_pt.engine import create_pytorch_engine \ No newline at end of file diff --git a/run_server.py b/run_server.py index 05e7ed09..883c48f6 100644 --- a/run_server.py +++ b/run_server.py @@ -6,8 +6,9 @@ from absl import flags from jetstream.core import server_lib +import jetstream_pt from jetstream_pt import config -from jetstream_pt import engine as je +from jetstream.core.config_lib import ServerConfig _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -61,7 +62,6 @@ _QUANTIZE_KV_CACHE = flags.DEFINE_bool('quantize_kv_cache', False, 'kv_cache_quantize') _MAX_CACHE_LENGTH = flags.DEFINE_integer('max_cache_length', 1024, 'kv_cache_quantize') -from jetstream.core.config_lib import ServerConfig def main(argv: Sequence[str]): del argv @@ -69,7 +69,7 @@ def main(argv: Sequence[str]): # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() print(f"devices: {devices}") - engine = je.create_pytorch_engine( + engine = jetstream_pt.create_pytorch_engine( devices=devices, tokenizer_path=_TOKENIZER_PATH.value, ckpt_path=_CKPT_PATH.value,