7 changes: 7 additions & 0 deletions install_everything.sh
@@ -19,9 +19,16 @@ pip show libtpu-nightly && pip uninstall -y libtpu-nightly
pip show tensorflow && pip uninstall -y tensorflow
pip show ray && pip uninstall -y ray
pip show flax && pip uninstall -y flax
pip show keras && pip uninstall -y keras
pip show tensorboard && pip uninstall -y tensorboard
pip show tensorflow-text && pip uninstall -y tensorflow-text
pip show torch_xla2 && pip uninstall -y torch_xla2

pip install flax==0.8.3
pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install tensorflow-text
pip install tensorflow

pip install ray[default]==2.22.0
# torch cpu
pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
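
For reference, a minimal post-install sanity check — a sketch assuming the script above ran to completion; the package names and version pins simply mirror the ones added in this file:

```python
# Minimal post-install sanity check; version pins mirror install_everything.sh.
from importlib.metadata import version

expected = {
    "flax": "0.8.3",
    "jax": "0.4.28",
    "ray": "2.22.0",
    "torch": "2.2.1",  # the CPU wheel reports itself as "2.2.1+cpu"
}

for pkg, want in expected.items():
  got = version(pkg)
  assert got.startswith(want), f"{pkg}: expected {want}*, got {got}"
print("pinned versions look consistent")
```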
72 changes: 62 additions & 10 deletions jetstream_pt/ray_engine.py
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Any, Iterable, Optional, Union

import numpy as np
import jax
import ray
from ray.util.accelerators import tpu

@@ -11,6 +11,7 @@
Params = Any
Prefix = Any
DecodeState = Any
NpPrefix = Any


class PyTorchRayEngine(engine_api.Engine):
@@ -28,11 +29,15 @@ def __init__(
tokenizer_path: str,
context_length: int,
batch_size: int,
is_disaggregated: bool = False,
pod_slice_name: str = None,
):
self.engine_workers = engine_workers
self.tokenizer_path = tokenizer_path
self.context_length = context_length
self.batch_size = batch_size
self.is_disaggregated = is_disaggregated
self.pod_slice_name = pod_slice_name

# pylint: disable-next=all
def load_params(self) -> Params:
@@ -64,17 +69,33 @@ def prefill(
) -> Prefix:
all_outputs = []
for worker in self.engine_workers:
output = worker.prefill_ray.remote(
prefill_func = (
worker.prefill_ray_disaggregation
if self.is_disaggregated
else worker.prefill_ray
)
output = prefill_func.remote(
params=params,
existing_prefix=existing_prefix,
padded_tokens=padded_tokens,
true_length=true_length,
)
all_outputs.append(output)
_ = ray.get(all_outputs)
results = ray.get(all_outputs)
# The prefill function does not return any values;
# the worker itself manages and maintains the prefill states.
return None
return results[0]

def transfer(self, np_prefix: NpPrefix) -> Any:
"""Store prefill result into object store, then transfer to decode engine workers."""
all_outputs = []
np_prefix_ref = ray.put(np_prefix)
for worker in self.engine_workers:
output = worker.transfer.remote(np_prefix_ref)
all_outputs.append(output)
results = ray.get(all_outputs)

return results[0]

def insert(
self,
@@ -126,7 +147,8 @@ def max_prefill_length(self) -> int:

@property
def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
return jax.devices("cpu")[0]
# ray head doesn't load any parameters
return None

def get_prefix_destination_sharding(self) -> Prefix:
"No implementation"
@@ -153,16 +175,22 @@ def create_pytorch_ray_engine(
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
) -> PyTorchRayEngine:
is_disaggregated: bool = False,
num_hosts: int = 0,
decode_pod_slice_name: str = None,
) -> Any:

# Return tuple as response: issues/107
supported_models = ["llama-2", "llama-3", "gemma"]
if model_name not in supported_models:
raise NotImplementedError(
f"Model name should be one of{','.join(supported_models)}"
)
ray.init(ignore_reinit_error=True)
pod_name = tpu.get_current_pod_name()
num_hosts = tpu.get_current_pod_worker_count()
num_hosts = (
num_hosts if is_disaggregated else tpu.get_current_pod_worker_count()
)
print(f"pod_name:{pod_name}, number of host: {num_hosts}")
assert (
pod_name is not None
@@ -192,10 +220,34 @@ def create_pytorch_ray_engine(
sharding_config=sharding_config,
)
engine_workers.append(engine_worker)
engine_master = PyTorchRayEngine(
engine_workers=engine_workers,

if not is_disaggregated:
return PyTorchRayEngine(
engine_workers=engine_workers,
tokenizer_path=tokenizer_path,
context_length=context_length,
batch_size=batch_size,
)

workers_dict = defaultdict(list)
for worker in engine_workers:
pod_slice_name = ray.get(worker.pod_slice_name.remote())
workers_dict[pod_slice_name].append(worker)

prefill_engine = PyTorchRayEngine(
engine_workers=workers_dict[pod_name],
tokenizer_path=tokenizer_path,
context_length=context_length,
batch_size=batch_size,
is_disaggregated=is_disaggregated,
pod_slice_name=pod_name,
)
decode_engine = PyTorchRayEngine(
engine_workers=workers_dict[decode_pod_slice_name],
tokenizer_path=tokenizer_path,
context_length=context_length,
batch_size=batch_size,
is_disaggregated=is_disaggregated,
pod_slice_name=decode_pod_slice_name,
)
return engine_master
return (prefill_engine, decode_engine)
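
Taken together, the ray_engine.py changes make `create_pytorch_ray_engine` return a `(prefill_engine, decode_engine)` pair when `is_disaggregated=True`, with prefill producing an `NpPrefix` that `transfer` ships to the decode slice via the Ray object store. A minimal caller-side sketch of that flow, assuming hypothetical pod-slice names, host count, and tokenizer path; engine arguments not visible in this diff are omitted:

```python
# Caller-side sketch of the disaggregated flow added in this PR.
# Pod-slice name, host count, and tokenizer path are hypothetical; other
# create_pytorch_ray_engine arguments elided from this diff are omitted.
from jetstream_pt.ray_engine import create_pytorch_ray_engine


def run_disaggregated(padded_tokens, true_length):
  prefill_engine, decode_engine = create_pytorch_ray_engine(
      model_name="llama-2",
      tokenizer_path="/tmp/tokenizer.model",     # hypothetical path
      context_length=1024,
      batch_size=32,
      is_disaggregated=True,
      num_hosts=4,                               # hosts per slice (hypothetical)
      decode_pod_slice_name="tpu-decode-slice",  # hypothetical slice name
  )

  params = prefill_engine.load_params()
  # Prefill on the prefill slice; in disaggregated mode this returns an
  # NpPrefix (first token plus KV caches gathered to NumPy).
  np_prefix = prefill_engine.prefill(
      params=params, padded_tokens=padded_tokens, true_length=true_length
  )
  # Hand the prefill result to the decode slice via the Ray object store.
  decode_engine.transfer(np_prefix)
  return np_prefix
```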
57 changes: 57 additions & 0 deletions jetstream_pt/ray_worker.py
@@ -23,6 +23,7 @@
import jax
import numpy as np
import ray
from ray.util.accelerators import tpu
import safetensors
import torch
import torch_xla2
@@ -57,6 +58,14 @@ class Prefix:
seq_len: int # true seqlen front pad


@struct.dataclass
# pylint: disable-next=all
class NpPrefix:
token: jax.Array # [1, seqlen]
caches: List[Tuple[jax.Array, jax.Array]]
seq_len: int # true seqlen front pad


@struct.dataclass
# pylint: disable-next=all
class DecodeState:
@@ -461,6 +470,50 @@ def prefill_ray(

return token

def _convert_to_np_caches(
self, caches: List[Tuple[jax.Array, jax.Array]]
) -> List[Tuple[np.ndarray, np.ndarray]]:
return [(np.asarray(tup[0]), np.asarray(tup[1])) for tup in caches]

def _convert_to_jax_caches(
self, np_caches: List[Tuple[np.ndarray, np.ndarray]]
) -> List[Tuple[jax.Array, jax.Array]]:
return [(jnp.asarray(tup[0]), jnp.asarray(tup[1])) for tup in np_caches]

def prefill_ray_disaggregation(
self,
*,
params: Any, # Weights
existing_prefix: Optional[Prefix] = None,
padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray],
true_length: int,
) -> Any:
"""Do prefill in ray worker"""
logits, updated_caches = self.prefill(
params=params,
existing_prefix=existing_prefix,
padded_tokens=padded_tokens,
true_length=true_length,
)
if len(logits.shape) == 3: # b, seqlen, num words
logits = logits[0]

token = np.argmax(logits[true_length - 1])
updated_caches = multihost_utils.process_allgather(
updated_caches, tiled=True
)
np_update_caches = self._convert_to_np_caches(updated_caches)
np_prefix = NpPrefix(token, np_update_caches, true_length)

return np_prefix

def transfer(self, np_prefix: NpPrefix) -> Any:
"""Transfer prefill result from object store to HBM"""
updated_caches = self._convert_to_jax_caches(np_prefix.caches)
prefix = Prefix(np_prefix.token, updated_caches, np_prefix.seq_len)
self.prefix_queue.put(prefix, block=False)
return True

def shrink_prefix(
self,
prefix: Prefix,
@@ -884,3 +937,7 @@ def max_decode_length(self) -> int:
def mesh(self):
"""return mesh"""
return None

def pod_slice_name(self):
"""pod slice name"""
return tpu.get_current_pod_name()
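
The `_convert_to_np_caches` / `_convert_to_jax_caches` helpers above are the heart of the hand-off: device caches are materialised as NumPy arrays so they can cross the Ray object store, then re-uploaded with `jnp.asarray` on the decode side. A self-contained sketch of that round trip, with made-up cache shapes:

```python
# Standalone illustration of the NumPy <-> JAX cache round trip used by
# prefill_ray_disaggregation and transfer; shapes are made up for the example.
from typing import List, Tuple

import jax
import jax.numpy as jnp
import numpy as np


def to_np_caches(
    caches: List[Tuple[jax.Array, jax.Array]]
) -> List[Tuple[np.ndarray, np.ndarray]]:
  return [(np.asarray(k), np.asarray(v)) for k, v in caches]


def to_jax_caches(
    np_caches: List[Tuple[np.ndarray, np.ndarray]]
) -> List[Tuple[jax.Array, jax.Array]]:
  return [(jnp.asarray(k), jnp.asarray(v)) for k, v in np_caches]


# Dummy single-layer KV cache: (key, value), each [batch, heads, seq, dim].
caches = [(jnp.ones((1, 8, 16, 4)), jnp.zeros((1, 8, 16, 4)))]
restored = to_jax_caches(to_np_caches(caches))
assert all(
    jnp.array_equal(k, k2) and jnp.array_equal(v, v2)
    for (k, v), (k2, v2) in zip(caches, restored)
)
```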