diff --git a/install_everything.sh b/install_everything.sh
index 0a2b21c0..1a542efb 100644
--- a/install_everything.sh
+++ b/install_everything.sh
@@ -19,9 +19,16 @@
 pip show libtpu-nightly && pip uninstall -y libtpu-nightly
 pip show tensorflow && pip uninstall -y tensorflow
 pip show ray && pip uninstall -y ray
 pip show flax && pip uninstall -y flax
+pip show keras && pip uninstall -y keras
+pip show tensorboard && pip uninstall -y tensorboard
+pip show tensorflow-text && pip uninstall -y tensorflow-text
+pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install flax==0.8.3
 pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install tensorflow-text
+pip install tensorflow
+
 pip install ray[default]==2.22.0
 # torch cpu
 pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py
index 1b394829..2d65ba15 100644
--- a/jetstream_pt/ray_engine.py
+++ b/jetstream_pt/ray_engine.py
@@ -1,7 +1,7 @@
+from collections import defaultdict
 from typing import Any, Iterable, Optional, Union
 
 import numpy as np
-import jax
 import ray
 from ray.util.accelerators import tpu
 
@@ -11,6 +11,7 @@
 Params = Any
 Prefix = Any
 DecodeState = Any
+NpPrefix = Any
 
 
 class PyTorchRayEngine(engine_api.Engine):
@@ -28,11 +29,15 @@ def __init__(
       tokenizer_path: str,
       context_length: int,
       batch_size: int,
+      is_disaggregated: bool = False,
+      pod_slice_name: str = None,
   ):
     self.engine_workers = engine_workers
     self.tokenizer_path = tokenizer_path
     self.context_length = context_length
     self.batch_size = batch_size
+    self.is_disaggregated = is_disaggregated
+    self.pod_slice_name = pod_slice_name
 
   # pylint: disable-next=all
   def load_params(self) -> Params:
@@ -64,17 +69,33 @@ def prefill(
   ) -> Prefix:
     all_outputs = []
     for worker in self.engine_workers:
-      output = worker.prefill_ray.remote(
+      prefill_func = (
+          worker.prefill_ray_disaggregation
+          if self.is_disaggregated
+          else worker.prefill_ray
+      )
+      output = prefill_func.remote(
           params=params,
           existing_prefix=existing_prefix,
           padded_tokens=padded_tokens,
           true_length=true_length,
       )
       all_outputs.append(output)
-    _ = ray.get(all_outputs)
-    # The prefill function does not return any values;
-    # the worker itself manages and maintains the prefill states.
-    return None
+    results = ray.get(all_outputs)
+    # All workers compute the same prefill result;
+    # surface the first worker's value to the caller.
+    return results[0]
+
+  def transfer(self, np_prefix: NpPrefix) -> Any:
+    """Store the prefill result in the object store, then transfer it to the decode engine workers."""
+    all_outputs = []
+    np_prefix_ref = ray.put(np_prefix)
+    for worker in self.engine_workers:
+      output = worker.transfer.remote(np_prefix_ref)
+      all_outputs.append(output)
+    results = ray.get(all_outputs)
+
+    return results[0]
 
   def insert(
       self,
@@ -126,7 +147,8 @@ def max_prefill_length(self) -> int:
 
   @property
   def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
-    return jax.devices("cpu")[0]
+    # The Ray head doesn't load any parameters.
+    return None
 
   def get_prefix_destination_sharding(self) -> Prefix:
     "No implementation"
@@ -153,8 +175,12 @@ def create_pytorch_ray_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
-) -> PyTorchRayEngine:
+    is_disaggregated: bool = False,
+    num_hosts: int = 0,
+    decode_pod_slice_name: str = None,
+) -> Any:
+
+  # Return a tuple as the response for disaggregated serving: issues/107
   supported_models = ["llama-2", "llama-3", "gemma"]
   if model_name not in supported_models:
     raise NotImplementedError(
@@ -162,7 +188,9 @@
   )
   ray.init(ignore_reinit_error=True)
   pod_name = tpu.get_current_pod_name()
-  num_hosts = tpu.get_current_pod_worker_count()
+  num_hosts = (
+      num_hosts if is_disaggregated else tpu.get_current_pod_worker_count()
+  )
   print(f"pod_name:{pod_name}, number of host: {num_hosts}")
   assert (
       pod_name is not None
@@ -192,10 +220,34 @@
         sharding_config=sharding_config,
     )
     engine_workers.append(engine_worker)
-  engine_master = PyTorchRayEngine(
-      engine_workers=engine_workers,
+
+  if not is_disaggregated:
+    return PyTorchRayEngine(
+        engine_workers=engine_workers,
+        tokenizer_path=tokenizer_path,
+        context_length=context_length,
+        batch_size=batch_size,
+    )
+
+  workers_dict = defaultdict(list)
+  for worker in engine_workers:
+    pod_slice_name = ray.get(worker.pod_slice_name.remote())
+    workers_dict[pod_slice_name].append(worker)
+
+  prefill_engine = PyTorchRayEngine(
+      engine_workers=workers_dict[pod_name],
+      tokenizer_path=tokenizer_path,
+      context_length=context_length,
+      batch_size=batch_size,
+      is_disaggregated=is_disaggregated,
+      pod_slice_name=pod_name,
+  )
+  decode_engine = PyTorchRayEngine(
+      engine_workers=workers_dict[decode_pod_slice_name],
       tokenizer_path=tokenizer_path,
       context_length=context_length,
       batch_size=batch_size,
+      is_disaggregated=is_disaggregated,
+      pod_slice_name=decode_pod_slice_name,
   )
-  return engine_master
+  return (prefill_engine, decode_engine)
diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py
index d1a16b41..b386bb35 100644
--- a/jetstream_pt/ray_worker.py
+++ b/jetstream_pt/ray_worker.py
@@ -23,6 +23,7 @@
 import jax
 import numpy as np
 import ray
+from ray.util.accelerators import tpu
 import safetensors
 import torch
 import torch_xla2
@@ -57,6 +58,14 @@ class Prefix:
   seq_len: int  # true seqlen front pad
 
 
+@struct.dataclass
+# pylint: disable-next=all
+class NpPrefix:
+  token: np.ndarray  # [1, seqlen]
+  caches: List[Tuple[np.ndarray, np.ndarray]]
+  seq_len: int  # true seqlen front pad
+
+
 @struct.dataclass
 # pylint: disable-next=all
 class DecodeState:
@@ -461,6 +470,50 @@ def prefill_ray(
 
     return token
 
+  def _convert_to_np_caches(
+      self, caches: List[Tuple[jax.Array, jax.Array]]
+  ) -> List[Tuple[np.ndarray, np.ndarray]]:
+    return [(np.asarray(tup[0]), np.asarray(tup[1])) for tup in caches]
+
+  def _convert_to_jax_caches(
+      self, np_caches: List[Tuple[np.ndarray, np.ndarray]]
+  ) -> List[Tuple[jax.Array, jax.Array]]:
+    return [(jnp.asarray(tup[0]), jnp.asarray(tup[1])) for tup in np_caches]
+
+  def prefill_ray_disaggregation(
+      self,
+      *,
+      params: Any,  # Weights
+      existing_prefix: Optional[Prefix] = None,
+      padded_tokens: PrefillInputs,  # PrefillInputs[np.ndarray],
+      true_length: int,
+  ) -> Any:
+    """Do prefill in a ray worker for disaggregated serving."""
+    logits, updated_caches = self.prefill(
+        params=params,
+        existing_prefix=existing_prefix,
+        padded_tokens=padded_tokens,
+        true_length=true_length,
+    )
+    if len(logits.shape) == 3:  # b, seqlen, num words
+      logits = logits[0]
+
+    token = np.argmax(logits[true_length - 1])
+    updated_caches = multihost_utils.process_allgather(
+        updated_caches, tiled=True
+    )
+    np_update_caches = self._convert_to_np_caches(updated_caches)
+    np_prefix = NpPrefix(token, np_update_caches, true_length)
+
+    return np_prefix
+
+  def transfer(self, np_prefix: NpPrefix) -> Any:
+    """Transfer the prefill result from the object store to HBM."""
+    updated_caches = self._convert_to_jax_caches(np_prefix.caches)
+    prefix = Prefix(np_prefix.token, updated_caches, np_prefix.seq_len)
+    self.prefix_queue.put(prefix, block=False)
+    return True
+
   def shrink_prefix(
       self,
       prefix: Prefix,
@@ -884,3 +937,7 @@ def max_decode_length(self) -> int:
   def mesh(self):
     """return mesh"""
     return None
+
+  def pod_slice_name(self):
+    """Return the TPU pod slice name this worker runs on."""
+    return tpu.get_current_pod_name()
diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py
new file mode 100644
index 00000000..b086d365
--- /dev/null
+++ b/run_interactive_disaggregated.py
@@ -0,0 +1,201 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+import time
+
+from typing import List
+from absl import app
+from absl import flags
+from colorama import Fore, Style
+
+import numpy as np
+import jax
+
+from jetstream.engine import token_utils
+from jetstream_pt import ray_engine
+
+FLAGS = flags.FLAGS
+
+_TOKENIZER_PATH = flags.DEFINE_string(
+    "tokenizer_path",
+    "tokenizer.model",
+    "The tokenizer model path",
+    required=False,
+)
+_CKPT_PATH = flags.DEFINE_string(
+    "checkpoint_path", None, "Directory for .pth checkpoints", required=False
+)
+_BF16_ENABLE = flags.DEFINE_bool(
+    "bf16_enable", False, "Whether to enable bf16", required=False
+)
+_CONTEXT_LENGTH = flags.DEFINE_integer(
+    "context_length", 1024, "The context length", required=False
+)
+_BATCH_SIZE = flags.DEFINE_integer(
+    "batch_size", 32, "The batch size", required=False
+)
+_PROFILING_OUTPUT = flags.DEFINE_string(
+    "profiling_output",
+    "",
+    "The profiling output",
+    required=False,
+)
+
+_SIZE = flags.DEFINE_string("size", "tiny", "size of model")
+
+_QUANTIZE_WEIGHTS = flags.DEFINE_bool(
+    "quantize_weights", False, "weight quantization"
+)
+_QUANTIZE_KV_CACHE = flags.DEFINE_bool(
+    "quantize_kv_cache", False, "Whether to quantize the kv cache"
+)
+_MAX_CACHE_LENGTH = flags.DEFINE_integer(
+    "max_cache_length", 1024, "Maximum length of the kv cache"
+)
+
+_MODEL_NAME = flags.DEFINE_string(
+    "model_name", None, "model type", required=False
+)
+
+_SHARDING_CONFIG = flags.DEFINE_string(
+    "sharding_config", "", "config file for sharding"
+)
+
+
+_IS_DISAGGREGATED = flags.DEFINE_bool(
+    "is_disaggregated", False, "Disaggregated serving if it's True"
+)
+
+_NUM_HOSTS = flags.DEFINE_integer(
+    "num_hosts", 4, "Number of TPU hosts", required=False
+)
+
+_DECODE_POD_SLICE_NAME = flags.DEFINE_string(
+    "decode_pod_slice_name", "", "Decode pod slice name"
+)
+
+
+def create_disaggregated_engines():
+  """Create disaggregated prefill and decode engines."""
+  # jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+
+  start = time.perf_counter()
+  prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine(
+      model_name=_MODEL_NAME.value,
+      tokenizer_path=_TOKENIZER_PATH.value,
+      ckpt_path=_CKPT_PATH.value,
+      bf16_enable=True,
+      param_size=_SIZE.value,
+      context_length=_CONTEXT_LENGTH.value,
+      batch_size=_BATCH_SIZE.value,
+      quantize_weights=_QUANTIZE_WEIGHTS.value,
+      quantize_kv=_QUANTIZE_KV_CACHE.value,
+      max_cache_length=_MAX_CACHE_LENGTH.value,
+      sharding_config=_SHARDING_CONFIG.value,
+      is_disaggregated=_IS_DISAGGREGATED.value,
+      num_hosts=_NUM_HOSTS.value,
+      decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
+  )
+
+  print("Initialize engine", time.perf_counter() - start)
+  return (prefill_engine, decode_engine)
+
+
+# pylint: disable-next=all
+def main(argv):
+
+  print("start the test")
+  prefill_engine, decode_engine = create_disaggregated_engines()
+
+  start = time.perf_counter()
+  prefill_engine.load_params()
+  decode_engine.load_params()
+  print("Load params ", time.perf_counter() - start)
+
+  metadata = prefill_engine.get_tokenizer()
+  tokenizer = prefill_engine.build_tokenizer(metadata)
+  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
+  stop_tokens = [vocab.eos_id, vocab.pad_id]
+  max_output_length = 1024
+
+  if _PROFILING_OUTPUT.value:
+    jax.profiler.start_trace(_PROFILING_OUTPUT.value)
+
+  decode_engine.init_decode_state()
+  prompts: List[str] = [
+      "I believe the meaning of life is",
+      # pylint: disable-next=all
+      "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList<Person>`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList<Person> peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
+      # pylint: disable-next=all
+      "[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<</SYS>>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
+      # pylint: disable-next=all
+      "[INST] <<SYS>>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<</SYS>>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
+      # pylint: disable-next=all
+      "[INST] <<SYS>>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<</SYS>>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
+  ]
+  for prompt in prompts:
+    slot = random.randint(0, _BATCH_SIZE.value - 1)
+    tokens, true_length = token_utils.tokenize_and_pad(
+        prompt, vocab, is_bos=True, jax_padding=False
+    )
+    print(f"---- Input prompt is: {prompt}")
+    print(f"---- Encoded tokens are: {tokens}")
+
+    # pylint: disable-next=all
+    print(
+        f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
+    )
+    prefill_result = prefill_engine.prefill(
+        params=None, padded_tokens=tokens, true_length=true_length
+    )
+    print(
+        f"---- Transfer prefill result to decode engine pod_slice_name: {decode_engine.pod_slice_name}"
+    )
+    decode_engine.transfer(prefill_result)
+    # pylint: disable-next=all
+    print(
+        f"---- Do insert in decode engine pod_slice_name: {decode_engine.pod_slice_name}"
+    )
+    decode_state = decode_engine.insert(prefill_result, None, slot=slot)
+    sampled_tokens_list = []
+    while True:
+      # pylint: disable-next=all
+      decode_state, result_tokens = decode_engine.generate(None, decode_state)
+      result_tokens = result_tokens.convert_to_numpy()
+
+      slot_data = result_tokens.get_result_at_slot(slot)
+      slot_tokens = slot_data.tokens
+      slot_lengths = slot_data.lengths
+
+      token_id = slot_tokens[slot, 0].item()
+      if slot_lengths > max_output_length or token_id in stop_tokens:
+        break
+
+      sampled_tokens_list.append(token_id)
+
+    print("---- All output tokens.")
+    print(sampled_tokens_list)
+    print("---- All output text.")
+    print(vocab.tokenizer.decode(sampled_tokens_list))
+
+  if _PROFILING_OUTPUT.value:
+    jax.profiler.stop_trace()
+
+
+if __name__ == "__main__":
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+  app.run(main)