From 22ca368994cc132c901ce30e555a978896e8676b Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 14 Jun 2024 04:32:23 +0000 Subject: [PATCH] Update submodules, prepare for leasing v0.2.4 --- benchmarks/prefill_offline.py | 3 +++ benchmarks/run_offline.py | 3 +++ deps/JetStream | 2 +- deps/xla | 2 +- install_everything.sh | 6 +++--- run_interactive.py | 2 ++ run_server_with_ray.py | 2 ++ 7 files changed, 15 insertions(+), 5 deletions(-) diff --git a/benchmarks/prefill_offline.py b/benchmarks/prefill_offline.py index 2d38b97c..8de5119d 100644 --- a/benchmarks/prefill_offline.py +++ b/benchmarks/prefill_offline.py @@ -16,6 +16,9 @@ import os import time +# import torch_xla2 first! +# pylint: disable-next=all +import torch_xla2 import humanize import jax import numpy as np diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index 72a41bd6..ef83f9e9 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -16,6 +16,9 @@ import os import time +# import torch_xla2 first! +# pylint: disable-next=all +import torch_xla2 import jax import jax.numpy as jnp # pylint: disable-next=all diff --git a/deps/JetStream b/deps/JetStream index 8a1e3132..26872c3c 160000 --- a/deps/JetStream +++ b/deps/JetStream @@ -1 +1 @@ -Subproject commit 8a1e31322e8e953909482b71f2689f82dbf4572f +Subproject commit 26872c3c6e726f52f5bac1cb63e60a9a2a0bbe8a diff --git a/deps/xla b/deps/xla index 961c22ae..c216d26c 160000 --- a/deps/xla +++ b/deps/xla @@ -1 +1 @@ -Subproject commit 961c22ae03bbc3fc53641efd85427ed1f0f38be0 +Subproject commit c216d26c23a37eb85dd8f8152ffe1acdb6b484a0 diff --git a/install_everything.sh b/install_everything.sh index 1a542efb..220e6df2 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -24,14 +24,13 @@ pip show tensorboard && pip uninstall -y tensorboard pip show tensorflow-text && pip uninstall -y tensorflow-text pip show torch_xla2 && pip uninstall -y torch_xla2 -pip install flax==0.8.3 -pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install flax pip install tensorflow-text pip install tensorflow pip install ray[default]==2.22.0 # torch cpu -pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu +pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage pip install safetensors colorama coverage humanize @@ -39,3 +38,4 @@ git submodule update --init --recursive pip show google-jetstream && pip uninstall -y google-jetstream pip show torch_xla2 && pip uninstall -y torch_xla2 pip install -e . +pip install -U jax[tpu]==0.4.29 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html diff --git a/run_interactive.py b/run_interactive.py index ccddc9c3..1527e311 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -17,6 +17,8 @@ import time from typing import List +# import torch_xla2 first! +import torch_xla2 # pylint: disable import jax import numpy as np from absl import app, flags diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 75c41164..de3bdf21 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -18,6 +18,8 @@ from typing import Sequence from absl import app, flags +# import torch_xla2 first! +import torch_xla2 # pylint: disable import jax from jetstream.core import server_lib from jetstream.core.config_lib import ServerConfig