From 208425ab3c5e2f510975b6863183795c400bfb1d Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 28 May 2024 23:48:31 +0000 Subject: [PATCH 01/11] Ray disaggregated MVP support --- install_everything.sh | 8 ++ jetstream_pt/ray_engine.py | 105 +++++++++++++++-- jetstream_pt/ray_worker.py | 48 ++++++++ run_interactive_disaggregated.py | 195 +++++++++++++++++++++++++++++++ 4 files changed, 349 insertions(+), 7 deletions(-) create mode 100644 run_interactive_disaggregated.py diff --git a/install_everything.sh b/install_everything.sh index 0a2b21c0..35cff401 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -19,9 +19,17 @@ pip show libtpu-nightly && pip uninstall -y libtpu-nightly pip show tensorflow && pip uninstall -y tensorflow pip show ray && pip uninstall -y ray pip show flax && pip uninstall -y flax +pip show keras && pip uninstall -y keras +pip show tensorboard && pip uninstall -y tensorboard +pip show tensorflow-text && pip uninstall -y tensorflow-text +pip show torch_xla2 && pip uninstall -y torch_xla2 pip install flax==0.8.3 pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install tensorflow-text +pip install tensorflow +pip install flax==0.8.3 + pip install ray[default]==2.22.0 # torch cpu pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 1b394829..bdfb039a 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, Optional, Union +from collections import defaultdict +from typing import Any, Iterable, Optional, Union, Tuple import numpy as np import jax @@ -11,6 +12,7 @@ Params = Any Prefix = Any DecodeState = Any +NpPrefix = Any class PyTorchRayEngine(engine_api.Engine): @@ -28,11 +30,15 @@ def __init__( tokenizer_path: str, context_length: int, batch_size: int, + is_disaggregated: bool = False, + pod_slice_name: str = None, ): self.engine_workers = engine_workers self.tokenizer_path = tokenizer_path self.context_length = context_length self.batch_size = batch_size + self.is_disaggregated = is_disaggregated + self.pod_slice_name = pod_slice_name # pylint: disable-next=all def load_params(self) -> Params: @@ -54,7 +60,7 @@ def init_decode_state( _ = ray.get(all_outputs) return None - def prefill( + def interleave_prefill( self, *, params: Any, # Weights @@ -76,6 +82,63 @@ def prefill( # the worker itself manages and maintains the prefill states. return None + def disaggregated_prefill( + self, + *, + params: Any, # Weights + existing_prefix: Optional[Prefix] = None, + padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], + true_length: int, + ) -> Prefix: + all_outputs = [] + for worker in self.engine_workers: + output = worker.prefill_ray_disaggregation.remote( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) + all_outputs.append(output) + results = ray.get(all_outputs) + return results[0] + + def prefill( + self, + *, + params: Any, # Weights + existing_prefix: Optional[Prefix] = None, + padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], + true_length: int, + ) -> Prefix: + result = None + if self.is_disaggregated: + result = self.disaggregated_prefill( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) + else: + result= self.interleave_prefill( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) + + return result + + + def transfer(self, np_prefix: NpPrefix) -> Any: + all_outputs = [] + np_prefix_ref = ray.put(np_prefix) + for worker in self.engine_workers: + output = worker.transfer.remote(np_prefix_ref) + all_outputs.append(output) + results = ray.get(all_outputs) + + return results[0] + def insert( self, prefix: Prefix, @@ -126,7 +189,8 @@ def max_prefill_length(self) -> int: @property def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]: - return jax.devices("cpu")[0] + # return jax.devices("cpu")[0] + return None def get_prefix_destination_sharding(self) -> Prefix: "No implementation" @@ -153,7 +217,10 @@ def create_pytorch_ray_engine( quantize_kv=False, max_cache_length=1024, sharding_config=None, -) -> PyTorchRayEngine: + is_disaggregated: bool = False, + num_hosts: int = 0, + decode_pod_slice_name: str = None, +) -> Any: supported_models = ["llama-2", "llama-3", "gemma"] if model_name not in supported_models: @@ -162,7 +229,7 @@ def create_pytorch_ray_engine( ) ray.init(ignore_reinit_error=True) pod_name = tpu.get_current_pod_name() - num_hosts = tpu.get_current_pod_worker_count() + num_hosts = num_hosts if is_disaggregated else tpu.get_current_pod_worker_count() print(f"pod_name:{pod_name}, number of host: {num_hosts}") assert ( pod_name is not None @@ -192,10 +259,34 @@ def create_pytorch_ray_engine( sharding_config=sharding_config, ) engine_workers.append(engine_worker) - engine_master = PyTorchRayEngine( + + if not is_disaggregated: + return PyTorchRayEngine( engine_workers=engine_workers, tokenizer_path=tokenizer_path, context_length=context_length, batch_size=batch_size, + ) + + workers_dict = defaultdict(list) + for worker in engine_workers: + pod_slice_name = ray.get(worker.pod_slice_name.remote()) + workers_dict[pod_slice_name].append(worker) + + prefill_engine = PyTorchRayEngine( + engine_workers=workers_dict[pod_name], + tokenizer_path=tokenizer_path, + context_length=context_length, + batch_size=batch_size, + is_disaggregated=is_disaggregated, + pod_slice_name=pod_name, + ) + decode_engine = PyTorchRayEngine( + engine_workers=workers_dict[decode_pod_slice_name], + tokenizer_path=tokenizer_path, + context_length=context_length, + batch_size=batch_size, + is_disaggregated=is_disaggregated, + pod_slice_name=decode_pod_slice_name, ) - return engine_master + return (prefill_engine, decode_engine) diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index d1a16b41..6da7b217 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -23,6 +23,7 @@ import jax import numpy as np import ray +from ray.util.accelerators import tpu import safetensors import torch import torch_xla2 @@ -56,6 +57,12 @@ class Prefix: caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad +@struct.dataclass +# pylint: disable-next=all +class NpPrefix: + token: jax.Array # [1, seqlen] + caches: List[Tuple[jax.Array, jax.Array]] + seq_len: int # true seqlen front pad @struct.dataclass # pylint: disable-next=all @@ -460,6 +467,44 @@ def prefill_ray( self.prefix_queue.put(prefix, block=False) return token + + def _convert_to_np_caches(self, caches: List[Tuple[jax.Array, jax.Array]]) -> List[Tuple[np.ndarray, np.ndarray]]: + return [(np.asarray(tup[0]), np.asarray(tup[1])) for tup in caches] + + def _convert_to_jax_caches(self, np_caches: List[Tuple[np.ndarray, np.ndarray]]) -> List[Tuple[jax.Array, jax.Array]]: + return [(jnp.asarray(tup[0]), jnp.asarray(tup[1])) for tup in np_caches] + + def prefill_ray_disaggregation( + self, + *, + params: Any, # Weights + existing_prefix: Optional[Prefix] = None, + padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray], + true_length: int, + ) -> Any: + """Do prefill in ray worker""" + logits, updated_caches = self.prefill( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) + if len(logits.shape) == 3: # b, seqlen, num words + logits = logits[0] + + token = np.argmax(logits[true_length - 1]) + updated_caches = multihost_utils.process_allgather(updated_caches, tiled=True) + np_update_caches = self._convert_to_np_caches(updated_caches) + np_prefix = NpPrefix(token, np_update_caches, true_length) + + return np_prefix + + + def transfer(self, np_prefix: NpPrefix) -> Any: + updated_caches = self._convert_to_jax_caches(np_prefix.caches) + prefix = Prefix(np_prefix.token, updated_caches, np_prefix.seq_len) + self.prefix_queue.put(prefix, block=False) + return True def shrink_prefix( self, @@ -884,3 +929,6 @@ def max_decode_length(self) -> int: def mesh(self): """return mesh""" return None + + def pod_slice_name(self): + return tpu.get_current_pod_name() diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py new file mode 100644 index 00000000..82414229 --- /dev/null +++ b/run_interactive_disaggregated.py @@ -0,0 +1,195 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import time + +from typing import List +from absl import app +from absl import flags +from colorama import Fore, Style + +import numpy as np +import jax + +from jetstream.engine import token_utils +from jetstream_pt import ray_engine + +FLAGS = flags.FLAGS + +_TOKENIZER_PATH = flags.DEFINE_string( + "tokenizer_path", + "tokenizer.model", + "The tokenizer model path", + required=False, +) +_CKPT_PATH = flags.DEFINE_string( + "checkpoint_path", None, "Directory for .pth checkpoints", required=False +) +_BF16_ENABLE = flags.DEFINE_bool( + "bf16_enable", False, "Whether to enable bf16", required=False +) +_CONTEXT_LENGTH = flags.DEFINE_integer( + "context_length", 1024, "The context length", required=False +) +_BATCH_SIZE = flags.DEFINE_integer( + "batch_size", 32, "The batch size", required=False +) +_PROFILING_OUTPUT = flags.DEFINE_string( + "profiling_output", + "", + "The profiling output", + required=False, +) + +_SIZE = flags.DEFINE_string("size", "tiny", "size of model") + +_QUANTIZE_WEIGHTS = flags.DEFINE_bool( + "quantize_weights", False, "weight quantization" +) +_QUANTIZE_KV_CACHE = flags.DEFINE_bool( + "quantize_kv_cache", False, "kv_cache_quantize" +) +_MAX_CACHE_LENGTH = flags.DEFINE_integer( + "max_cache_length", 1024, "kv_cache_quantize" +) + +_MODEL_NAME = flags.DEFINE_string( + "model_name", None, "model type", required=False +) + +_SHARDING_CONFIG = flags.DEFINE_string( + "sharding_config", "", "config file for sharding" +) + + +_IS_DISAGGREGATED = flags.DEFINE_bool( + "is_disaggregated", False, "Disaggregated serving if it's True" +) + +_NUM_HOSTS = flags.DEFINE_integer( + "num_hosts", 4, "Number of TPU host", required=False +) + +_DECODE_POD_SLICE_NAME = flags.DEFINE_string( + "decode_pod_slice_name", "", "Decode pod slice name" +) + + +def create_disaggregated_engines(): + """create a pytorch engine""" + # jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + start = time.perf_counter() + prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine( + model_name=_MODEL_NAME.value, + tokenizer_path=_TOKENIZER_PATH.value, + ckpt_path=_CKPT_PATH.value, + bf16_enable=True, + param_size=_SIZE.value, + context_length=_CONTEXT_LENGTH.value, + batch_size=_BATCH_SIZE.value, + quantize_weights=_QUANTIZE_WEIGHTS.value, + quantize_kv=_QUANTIZE_KV_CACHE.value, + max_cache_length=_MAX_CACHE_LENGTH.value, + sharding_config=_SHARDING_CONFIG.value, + is_disaggregated=_IS_DISAGGREGATED.value, + num_hosts=_NUM_HOSTS.value, + decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value, + ) + + print("Initialize engine", time.perf_counter() - start) + return (prefill_engine, decode_engine) + + +# pylint: disable-next=all +def main(argv): + + print("start the test") + prefill_engine, decode_engine = create_disaggregated_engines() + + start = time.perf_counter() + prefill_engine.load_params() + decode_engine.load_params() + print("Load params ", time.perf_counter() - start) + + metadata = prefill_engine.get_tokenizer() + tokenizer = prefill_engine.build_tokenizer(metadata) + vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) + stop_tokens = [vocab.eos_id, vocab.pad_id] + max_output_length = 1024 + + if _PROFILING_OUTPUT.value: + jax.profiler.start_trace(_PROFILING_OUTPUT.value) + + decode_engine.init_decode_state() + prompts: List[str] = [ + "I believe the meaning of life is", + # pylint: disable-next=all + "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", + # pylint: disable-next=all + "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", + # pylint: disable-next=all + "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", + # pylint: disable-next=all + "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]", + ] + for prompt in prompts: + slot = random.randint(0, _BATCH_SIZE.value - 1) + tokens, true_length = token_utils.tokenize_and_pad( + prompt, vocab, is_bos=True, jax_padding=False + ) + print(f"---- Input prompts are: {prompt}") + print(f"---- Encoded tokens are: {tokens}") + + # pylint: disable-next=all + print(f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}") + prefill_result = prefill_engine.prefill( + params=None, padded_tokens=tokens, true_length=true_length + ) + print(f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}") + decode_engine.transfer(prefill_result) + # pylint: disable-next=all + print(f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}") + decode_state = decode_engine.insert(prefill_result, None, slot=slot) + sampled_tokens_list = [] + while True: + # pylint: disable-next=all + decode_state, result_tokens = decode_engine.generate(None, decode_state) + result_tokens = result_tokens.convert_to_numpy() + + slot_data = result_tokens.get_result_at_slot(slot) + slot_tokens = slot_data.tokens + slot_lengths = slot_data.lengths + + token_id = slot_tokens[slot, 0].item() + if slot_lengths > max_output_length or token_id in stop_tokens: + break + + sampled_tokens_list.append(token_id) + + print("---- All output tokens.") + print(sampled_tokens_list) + print("---- All output text.") + print(vocab.tokenizer.decode(sampled_tokens_list)) + + if _PROFILING_OUTPUT.value: + jax.profiler.stop_trace() + + +if __name__ == "__main__": + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + app.run(main) From d80e893f34d1935d0af6346c931ca0d4a7872829 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 00:43:19 +0000 Subject: [PATCH 02/11] add jax cpu --- jetstream_pt/ray_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index bdfb039a..f92b460f 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -189,8 +189,7 @@ def max_prefill_length(self) -> int: @property def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]: - # return jax.devices("cpu")[0] - return None + return jax.devices("cpu")[0] def get_prefix_destination_sharding(self) -> Prefix: "No implementation" From 32562b93791ef62849372b216972d28ba52e965b Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 00:44:50 +0000 Subject: [PATCH 03/11] add comments --- jetstream_pt/ray_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index f92b460f..91ef1a7a 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -189,7 +189,8 @@ def max_prefill_length(self) -> int: @property def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]: - return jax.devices("cpu")[0] + # ray head doesn't load any parameters + return None def get_prefix_destination_sharding(self) -> Prefix: "No implementation" From 0eb09068c342f7e7b015d12df3f2e0176ed3a327 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 01:01:02 +0000 Subject: [PATCH 04/11] format --- jetstream_pt/ray_engine.py | 54 ++++++++++++++++++++------------------ jetstream_pt/ray_worker.py | 27 ++++++++++++------- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 91ef1a7a..4e2d9744 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -1,8 +1,7 @@ from collections import defaultdict -from typing import Any, Iterable, Optional, Union, Tuple +from typing import Any, Iterable, Optional, Union import numpy as np -import jax import ray from ray.util.accelerators import tpu @@ -60,6 +59,7 @@ def init_decode_state( _ = ray.get(all_outputs) return None + # pylint: disable-next=all def interleave_prefill( self, *, @@ -82,6 +82,7 @@ def interleave_prefill( # the worker itself manages and maintains the prefill states. return None + # pylint: disable-next=all def disaggregated_prefill( self, *, @@ -100,7 +101,7 @@ def disaggregated_prefill( ) all_outputs.append(output) results = ray.get(all_outputs) - return results[0] + return results[0] def prefill( self, @@ -113,23 +114,24 @@ def prefill( result = None if self.is_disaggregated: result = self.disaggregated_prefill( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, - ) + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) else: - result= self.interleave_prefill( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, + self.interleave_prefill( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, ) + # Return None if interleave return result - - + def transfer(self, np_prefix: NpPrefix) -> Any: + """Store prefill result into object store, then transfer to decode engine workers.""" all_outputs = [] np_prefix_ref = ray.put(np_prefix) for worker in self.engine_workers: @@ -137,8 +139,8 @@ def transfer(self, np_prefix: NpPrefix) -> Any: all_outputs.append(output) results = ray.get(all_outputs) - return results[0] - + return results[0] + def insert( self, prefix: Prefix, @@ -229,7 +231,9 @@ def create_pytorch_ray_engine( ) ray.init(ignore_reinit_error=True) pod_name = tpu.get_current_pod_name() - num_hosts = num_hosts if is_disaggregated else tpu.get_current_pod_worker_count() + num_hosts = ( + num_hosts if is_disaggregated else tpu.get_current_pod_worker_count() + ) print(f"pod_name:{pod_name}, number of host: {num_hosts}") assert ( pod_name is not None @@ -262,16 +266,16 @@ def create_pytorch_ray_engine( if not is_disaggregated: return PyTorchRayEngine( - engine_workers=engine_workers, - tokenizer_path=tokenizer_path, - context_length=context_length, - batch_size=batch_size, - ) + engine_workers=engine_workers, + tokenizer_path=tokenizer_path, + context_length=context_length, + batch_size=batch_size, + ) - workers_dict = defaultdict(list) + workers_dict = defaultdict(list) for worker in engine_workers: pod_slice_name = ray.get(worker.pod_slice_name.remote()) - workers_dict[pod_slice_name].append(worker) + workers_dict[pod_slice_name].append(worker) prefill_engine = PyTorchRayEngine( engine_workers=workers_dict[pod_name], diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 6da7b217..b386bb35 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -57,6 +57,7 @@ class Prefix: caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad + @struct.dataclass # pylint: disable-next=all class NpPrefix: @@ -64,6 +65,7 @@ class NpPrefix: caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad + @struct.dataclass # pylint: disable-next=all class DecodeState: @@ -467,12 +469,16 @@ def prefill_ray( self.prefix_queue.put(prefix, block=False) return token - - def _convert_to_np_caches(self, caches: List[Tuple[jax.Array, jax.Array]]) -> List[Tuple[np.ndarray, np.ndarray]]: + + def _convert_to_np_caches( + self, caches: List[Tuple[jax.Array, jax.Array]] + ) -> List[Tuple[np.ndarray, np.ndarray]]: return [(np.asarray(tup[0]), np.asarray(tup[1])) for tup in caches] - def _convert_to_jax_caches(self, np_caches: List[Tuple[np.ndarray, np.ndarray]]) -> List[Tuple[jax.Array, jax.Array]]: - return [(jnp.asarray(tup[0]), jnp.asarray(tup[1])) for tup in np_caches] + def _convert_to_jax_caches( + self, np_caches: List[Tuple[np.ndarray, np.ndarray]] + ) -> List[Tuple[jax.Array, jax.Array]]: + return [(jnp.asarray(tup[0]), jnp.asarray(tup[1])) for tup in np_caches] def prefill_ray_disaggregation( self, @@ -493,14 +499,16 @@ def prefill_ray_disaggregation( logits = logits[0] token = np.argmax(logits[true_length - 1]) - updated_caches = multihost_utils.process_allgather(updated_caches, tiled=True) + updated_caches = multihost_utils.process_allgather( + updated_caches, tiled=True + ) np_update_caches = self._convert_to_np_caches(updated_caches) np_prefix = NpPrefix(token, np_update_caches, true_length) - return np_prefix - + return np_prefix def transfer(self, np_prefix: NpPrefix) -> Any: + """Transfer prefill result from object store to HBM""" updated_caches = self._convert_to_jax_caches(np_prefix.caches) prefix = Prefix(np_prefix.token, updated_caches, np_prefix.seq_len) self.prefix_queue.put(prefix, block=False) @@ -929,6 +937,7 @@ def max_decode_length(self) -> int: def mesh(self): """return mesh""" return None - + def pod_slice_name(self): - return tpu.get_current_pod_name() + """pod slice name""" + return tpu.get_current_pod_name() From 0584f2160f05ac980a54e4d6e87a09e479649463 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 01:31:08 +0000 Subject: [PATCH 05/11] format --- run_interactive_disaggregated.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py index 82414229..b086d365 100644 --- a/run_interactive_disaggregated.py +++ b/run_interactive_disaggregated.py @@ -118,7 +118,7 @@ def create_disaggregated_engines(): # pylint: disable-next=all def main(argv): - print("start the test") + print("start the test") prefill_engine, decode_engine = create_disaggregated_engines() start = time.perf_counter() @@ -156,14 +156,20 @@ def main(argv): print(f"---- Encoded tokens are: {tokens}") # pylint: disable-next=all - print(f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}") + print( + f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}" + ) prefill_result = prefill_engine.prefill( params=None, padded_tokens=tokens, true_length=true_length ) - print(f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}") + print( + f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}" + ) decode_engine.transfer(prefill_result) # pylint: disable-next=all - print(f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}") + print( + f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}" + ) decode_state = decode_engine.insert(prefill_result, None, slot=slot) sampled_tokens_list = [] while True: From b5c1764fb9b5ecbc9ae1d270841129bd76166b82 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 17:24:25 +0000 Subject: [PATCH 06/11] assign call prefill in one line --- jetstream_pt/ray_engine.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 4e2d9744..fa2fc765 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -111,23 +111,18 @@ def prefill( padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], true_length: int, ) -> Prefix: - result = None - if self.is_disaggregated: - result = self.disaggregated_prefill( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, - ) - else: - self.interleave_prefill( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, - ) - - # Return None if interleave + call_prefill = ( + self.disaggregated_prefill + if self.is_disaggregated + else self.interleave_prefill + ) + # pylint: disable-next=all + result = call_prefill( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) return result def transfer(self, np_prefix: NpPrefix) -> Any: From 8e7d8b184aafffde2051c54c68f131ae1b9635d4 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 20:26:44 +0000 Subject: [PATCH 07/11] refactor prefill in ray engine --- jetstream_pt/ray_engine.py | 40 ++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index fa2fc765..6a9e907e 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -70,17 +70,22 @@ def interleave_prefill( ) -> Prefix: all_outputs = [] for worker in self.engine_workers: - output = worker.prefill_ray.remote( + prefill_func = ( + worker.prefill_ray + if self.is_disaggregated + else worker.prefill_ray_disaggregation + ) + output = prefill_func.remote( params=params, existing_prefix=existing_prefix, padded_tokens=padded_tokens, true_length=true_length, ) all_outputs.append(output) - _ = ray.get(all_outputs) + results = ray.get(all_outputs) # The prefill function does not return any values; # the worker itself manages and maintains the prefill states. - return None + return results[0] # pylint: disable-next=all def disaggregated_prefill( @@ -111,19 +116,24 @@ def prefill( padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], true_length: int, ) -> Prefix: - call_prefill = ( - self.disaggregated_prefill + all_outputs = [] + for worker in self.engine_workers: + prefill_func = ( + worker.prefill_ray if self.is_disaggregated - else self.interleave_prefill - ) - # pylint: disable-next=all - result = call_prefill( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, - ) - return result + else worker.prefill_ray_disaggregation + ) + output = prefill_func.remote( + params=params, + existing_prefix=existing_prefix, + padded_tokens=padded_tokens, + true_length=true_length, + ) + all_outputs.append(output) + results = ray.get(all_outputs) + # The prefill function does not return any values; + # the worker itself manages and maintains the prefill states. + return results[0] def transfer(self, np_prefix: NpPrefix) -> Any: """Store prefill result into object store, then transfer to decode engine workers.""" From 4a0215afcaf90f412737c2cc43e3bd364ae0fb25 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 20:46:37 +0000 Subject: [PATCH 08/11] format --- jetstream_pt/ray_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 6a9e907e..5681c5e6 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -119,9 +119,9 @@ def prefill( all_outputs = [] for worker in self.engine_workers: prefill_func = ( - worker.prefill_ray + worker.prefill_ray_disaggregation if self.is_disaggregated - else worker.prefill_ray_disaggregation + else worker.prefill_ray ) output = prefill_func.remote( params=params, From 6e5abb7e79ade930ae0484000f7fac6cab0a4684 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 20:50:20 +0000 Subject: [PATCH 09/11] clean up ray prefill --- jetstream_pt/ray_engine.py | 55 +++----------------------------------- 1 file changed, 3 insertions(+), 52 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 5681c5e6..e9026896 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -59,55 +59,6 @@ def init_decode_state( _ = ray.get(all_outputs) return None - # pylint: disable-next=all - def interleave_prefill( - self, - *, - params: Any, # Weights - existing_prefix: Optional[Prefix] = None, - padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], - true_length: int, - ) -> Prefix: - all_outputs = [] - for worker in self.engine_workers: - prefill_func = ( - worker.prefill_ray - if self.is_disaggregated - else worker.prefill_ray_disaggregation - ) - output = prefill_func.remote( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, - ) - all_outputs.append(output) - results = ray.get(all_outputs) - # The prefill function does not return any values; - # the worker itself manages and maintains the prefill states. - return results[0] - - # pylint: disable-next=all - def disaggregated_prefill( - self, - *, - params: Any, # Weights - existing_prefix: Optional[Prefix] = None, - padded_tokens: np.ndarray, # PrefillInputs[np.ndarray], - true_length: int, - ) -> Prefix: - all_outputs = [] - for worker in self.engine_workers: - output = worker.prefill_ray_disaggregation.remote( - params=params, - existing_prefix=existing_prefix, - padded_tokens=padded_tokens, - true_length=true_length, - ) - all_outputs.append(output) - results = ray.get(all_outputs) - return results[0] - def prefill( self, *, @@ -119,9 +70,9 @@ def prefill( all_outputs = [] for worker in self.engine_workers: prefill_func = ( - worker.prefill_ray_disaggregation - if self.is_disaggregated - else worker.prefill_ray + worker.prefill_ray_disaggregation + if self.is_disaggregated + else worker.prefill_ray ) output = prefill_func.remote( params=params, From a84cfbcfefc537a09997e4c9b3d3a32b0ab0112a Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 20:53:32 +0000 Subject: [PATCH 10/11] remove duplicated flax installation --- install_everything.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/install_everything.sh b/install_everything.sh index 35cff401..1a542efb 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -28,7 +28,6 @@ pip install flax==0.8.3 pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install tensorflow-text pip install tensorflow -pip install flax==0.8.3 pip install ray[default]==2.22.0 # torch cpu From 01b808fe0ee3f8b38f6dc1ac82fde5408cbe35fb Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 29 May 2024 22:09:50 +0000 Subject: [PATCH 11/11] add tuple as todo --- jetstream_pt/ray_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index e9026896..2d65ba15 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -180,6 +180,7 @@ def create_pytorch_ray_engine( decode_pod_slice_name: str = None, ) -> Any: + # Return tuple as reponse: issues/107 supported_models = ["llama-2", "llama-3", "gemma"] if model_name not in supported_models: raise NotImplementedError(