From cf8f7c8ecf3f40687498f674ff99f3736fc5d3c7 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 31 Jul 2024 16:21:56 +0000 Subject: [PATCH 1/3] add xla2 fix --- README.md | 1 + deps/xla | 2 +- run_interactive_multiple_host.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5292e0c7..ca6ec4ba 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,7 @@ Note: Get address ip and port information from ray head. Here is an example to run the server with ray for llama2 7B model: ```bash +export DISABLE_XLA2_PJRT_TEST="true" python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml" ``` diff --git a/deps/xla b/deps/xla index c2753715..fb2d4e14 160000 --- a/deps/xla +++ b/deps/xla @@ -1 +1 @@ -Subproject commit c27537153f3ea983a7ba9b0e1bfdae4b37ca5e9e +Subproject commit fb2d4e1464dfd96f38a343c0e6f512629e28b48c diff --git a/run_interactive_multiple_host.py b/run_interactive_multiple_host.py index 24b27987..5d72768b 100644 --- a/run_interactive_multiple_host.py +++ b/run_interactive_multiple_host.py @@ -57,7 +57,7 @@ def create_engine(): sharding_config=FLAGS.sharding_config, num_hosts=_NUM_HOSTS.value, worker_chips=_WORKER_CHIPS.value, - tpu_chips=_TPU_CHIPS, + tpu_chips=_TPU_CHIPS.value, ) print("Initialize engine", time.perf_counter() - start) From 996dde51e8fbb2546ff81cf182855610280f7029 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 31 Jul 2024 18:30:56 +0000 Subject: [PATCH 2/3] update jax version --- install_everything.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install_everything.sh b/install_everything.sh index f404ef6b..57d21a92 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -39,5 +39,5 @@ git submodule update --init --recursive pip show google-jetstream && pip uninstall -y google-jetstream pip show torch_xla2 && pip uninstall -y torch_xla2 pip install -e . -pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U jax[tpu]==0.4.31 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu From 9d242794949bdfafa657ff14e646c3ec17fa41d0 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 31 Jul 2024 20:49:13 +0000 Subject: [PATCH 3/3] revert jax TPU version --- install_everything.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install_everything.sh b/install_everything.sh index 57d21a92..f404ef6b 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -39,5 +39,5 @@ git submodule update --init --recursive pip show google-jetstream && pip uninstall -y google-jetstream pip show torch_xla2 && pip uninstall -y torch_xla2 pip install -e . -pip install -U jax[tpu]==0.4.31 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U jax[tpu]==0.4.30 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install -U torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu