diff --git a/README.md b/README.md index 5292e0c7..ca6ec4ba 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,7 @@ Note: Get address ip and port information from ray head. Here is an example to run the server with ray for llama2 7B model: ```bash +export DISABLE_XLA2_PJRT_TEST="true" python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` diff --git a/deps/xla b/deps/xla index c2753715..fb2d4e14 160000 --- a/deps/xla +++ b/deps/xla @@ -1 +1 @@ -Subproject commit c27537153f3ea983a7ba9b0e1bfdae4b37ca5e9e +Subproject commit fb2d4e1464dfd96f38a343c0e6f512629e28b48c diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 24b27987..5d72768b 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -57,7 +57,7 @@ def create_engine(): sharding_config=FLAGS.sharding_config, num_hosts=_NUM_HOSTS.value, worker_chips=_WORKER_CHIPS.value, - tpu_chips=_TPU_CHIPS, + tpu_chips=_TPU_CHIPS.value, ) print("Initialize engine", time.perf_counter() - start)