diff --git a/jetstream_pt/__init__.py b/jetstream_pt/__init__.py index e69de29b..77d2a3df 100644 --- a/jetstream_pt/__init__.py +++ b/jetstream_pt/__init__.py @@ -0,0 +1 @@ +from jetstream_pt.engine import create_pytorch_engine \ No newline at end of file diff --git a/run_server.py b/run_server.py index 05e7ed09..883c48f6 100644 --- a/run_server.py +++ b/run_server.py @@ -6,8 +6,9 @@ from absl import flags from jetstream.core import server_lib +import jetstream_pt from jetstream_pt import config -from jetstream_pt import engine as je +from jetstream.core.config_lib import ServerConfig _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -61,7 +62,6 @@ _QUANTIZE_KV_CACHE = flags.DEFINE_bool('quantize_kv_cache', False, 'kv_cache_quantize') _MAX_CACHE_LENGTH = flags.DEFINE_integer('max_cache_length', 1024, 'kv_cache_quantize') -from jetstream.core.config_lib import ServerConfig def main(argv: Sequence[str]): del argv @@ -69,7 +69,7 @@ def main(argv: Sequence[str]): # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() print(f"devices: {devices}") - engine = je.create_pytorch_engine( + engine = jetstream_pt.create_pytorch_engine( devices=devices, tokenizer_path=_TOKENIZER_PATH.value, ckpt_path=_CKPT_PATH.value,