from collections import defaultdict
import threading
from typing import Any, Iterable, Optional, Union, Tuple, List

import numpy as np
import ray
from ray.runtime_env import RuntimeEnv
from ray.util.accelerators import tpu

from jetstream.engine import engine_api, tokenizer_pb2
from jetstream_pt.ray_worker import PyTorchRayWorker

Params = Any
Prefix = Any
DecodeState = Any
NpPrefix = Any


class PyTorchRayEngine(engine_api.Engine):
  """Ray PyTorch Engine Implementation for Multi-Host Inference Serving.

  Key Features:
  1. Manages all Ray workers.
  2. Initializes model parameters for each Ray worker.
  3. Routes incoming inference requests to Ray workers.
  4. Collects token responses from the Ray workers.
  """

  def __init__(
      self,
      engine_workers: Iterable[PyTorchRayWorker],
      tokenizer_path: str,
      context_length: int,
      batch_size: int,
      is_disaggregated: bool = False,
      pod_slice_name: Optional[str] = None,
  ):
    self.engine_workers = engine_workers
    self.tokenizer_path = tokenizer_path
    self.context_length = context_length
    self.batch_size = batch_size
    self.is_disaggregated = is_disaggregated
    self.pod_slice_name = pod_slice_name
    # In the interleaved (non-disaggregated) mode, prefill and generate share
    # the same workers, so access is serialized with a lock.
    if not self.is_disaggregated:
      self._lock = threading.Lock()

  # pylint: disable-next=all
  def load_params(self) -> Params:
    # Kick off parameter loading on every worker and block until all finish.
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.load_params_ray.remote()
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # Workers hold the loaded parameters themselves; nothing to return.
    return None

  # pylint: disable-next=all
  def init_decode_state(
      self,
  ) -> DecodeState:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.init_decode_state_ray.remote()
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # Workers hold their own decode state; nothing to return.
    return None

  def prefill(
      self,
      *,
      params: Any,  # Weights
      existing_prefix: Optional[Prefix] = None,
      padded_tokens: np.ndarray,  # PrefillInputs[np.ndarray],
      true_length: int,
      sampler=None,
  ) -> Tuple[Prefix, engine_api.ResultTokens]:
    # In disaggregated mode the prefill workers are dedicated, so no lock is
    # needed; otherwise serialize against concurrent generate calls.
    if self.is_disaggregated:
      return self.prefill_impl(
          params=params,
          existing_prefix=existing_prefix,
          padded_tokens=padded_tokens,
          true_length=true_length,
      )
    with self._lock:
      return self.prefill_impl(
          params=params,
          existing_prefix=existing_prefix,
          padded_tokens=padded_tokens,
          true_length=true_length,
      )

  # pylint: disable-next=all
  def prefill_impl(
      self,
      *,
      params: Any,  # Weights
      existing_prefix: Optional[Prefix] = None,
      padded_tokens: np.ndarray,  # PrefillInputs[np.ndarray],
      true_length: int,
  ) -> Tuple[Prefix, engine_api.ResultTokens]:
    all_outputs = []
    for worker in self.engine_workers:
      prefill_func = (
          worker.prefill_ray_disaggregation
          if self.is_disaggregated
          else worker.prefill_ray
      )
      output = prefill_func.remote(
          params=params,
          existing_prefix=existing_prefix,
          padded_tokens=padded_tokens,
          true_length=true_length,
      )
      all_outputs.append(output)
    results = ray.get(all_outputs)
    # Each worker manages and maintains its own prefill state; the returned
    # values are identical across workers, so worker 0's result is used.
    return results[0]

  def transfer(self, np_prefix: NpPrefix) -> Any:
    """Store the prefill result in the object store, then transfer it to the decode engine workers."""
    all_outputs = []
    # Put the prefill result in the Ray object store once so every decode
    # worker can fetch it without re-serializing.
    np_prefix_ref = ray.put(np_prefix)
    for worker in self.engine_workers:
      output = worker.transfer.remote(np_prefix_ref)
      all_outputs.append(output)
    results = ray.get(all_outputs)
    return results[0]
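
  # A hedged sketch (not part of the engine API) of the disaggregated handoff
  # transfer() enables, assuming the prefill engine's result is the numpy
  # prefix consumed here; the engine pair comes from create_pytorch_ray_engine
  # below and the names are illustrative:
  #
  #   np_prefix, first_token = prefill_engine.prefill(
  #       params=params, padded_tokens=tokens, true_length=length)
  #   decode_engine.transfer(np_prefix)  # one ray.put, fetched by each worker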

  def insert(
      self,
      prefix: Prefix,
      decode_state: DecodeState,
      slot: int,
  ) -> DecodeState:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.insert_ray.remote(
          prefix=prefix, decode_state=decode_state, slot=slot
      )
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # The insert function does not return any values;
    # the worker itself manages and maintains the DecodeState.
    return None

  def generate(
      self,
      params: Any,
      decode_state: DecodeState,
      sampler=None,
  ) -> tuple[None, engine_api.ResultTokens]:
    if self.is_disaggregated:
      return self.generate_impl(params=params, decode_state=decode_state)
    with self._lock:
      return self.generate_impl(params=params, decode_state=decode_state)

  # pylint: disable-next=all
  def generate_impl(
      self, params: Any, decode_state: DecodeState
  ) -> tuple[None, engine_api.ResultTokens]:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.generate_ray.remote(
          params=params, decode_state=decode_state
      )
      all_outputs.append(output)
    # All workers performed an all_gather operation. Since the results are
    # identical across all workers, the result from worker 0 is returned.
    state, result_tokens = ray.get(all_outputs)[0]
    return state, result_tokens

  # pylint: disable-next=all
  def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
    # pylint: disable-next=all
    return tokenizer_pb2.TokenizerParameters(path=self.tokenizer_path)

  @property
  def max_concurrent_decodes(self) -> int:
    return self.batch_size

  @property
  def samples_per_slot(self) -> int:
    return 1

  @property
  def max_prefill_length(self) -> int:
    return self.context_length

  @property
  def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
    # The Ray head doesn't load any parameters.
    return None

  def get_prefix_destination_sharding(self) -> Prefix:
    """No implementation."""
    return None

  @property
  def mesh(self):
    """No implementation."""
    return None
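

# A minimal sketch (not part of this module's API) of the per-request flow the
# class above implements, assuming an interleaved `engine` built by
# create_pytorch_ray_engine below; the token inputs are illustrative:
#
#   params = engine.load_params()              # every worker loads weights
#   decode_state = engine.init_decode_state()  # workers hold the state
#   prefix, first_token = engine.prefill(
#       params=params, padded_tokens=padded, true_length=length)
#   engine.insert(prefix, decode_state, slot=0)
#   _, result_tokens = engine.generate(params, decode_state)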


# pylint: disable-next=all
def create_pytorch_ray_engine(
    tokenizer_path: str,
    ckpt_path: Optional[str] = None,
    samples_per_slot: int = 1,
    bf16_enable: bool = False,
    param_size: str = "7b",
    context_length: int = 1024,
    batch_size: int = 1,
    max_decode_length: int = 4096,
    model_name: str = "llama",
    quantize_weights: bool = False,
    quantize_kv: bool = False,
    max_cache_length: int = 1024,
    sharding_config=None,
    is_disaggregated: bool = False,
    num_hosts: int = 0,
    worker_chips: int = 0,
    tpu_chips: int = 0,
    decode_pod_slice_name: Optional[str] = None,
    enable_jax_profiler: bool = False,
    jax_profiler_port: int = 9999,
    **kwargs,
) -> Union[
    PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]
]:
  # Return a tuple as the response in the disaggregated case: issues/107
  supported_models = ["llama-2", "llama-3", "gemma"]
  if model_name not in supported_models:
    raise NotImplementedError(
        f"Model name should be one of {', '.join(supported_models)}"
    )
  ray.init(ignore_reinit_error=True)
  pod_name = tpu.get_current_pod_name()
  num_hosts = num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count()
  worker_chips = worker_chips if worker_chips > 0 else 4
  print(f"pod_name: {pod_name}, number of hosts: {num_hosts}")
  assert (
      pod_name is not None
  ), f"TPU pod name (current value: {pod_name}) cannot be None"
  assert (
      num_hosts > 0
  ), f"num_hosts (current value: {num_hosts}) should be a positive number"
  assert (
      num_hosts * worker_chips == tpu_chips
  ), f"num_hosts: {num_hosts} * worker_chips: {worker_chips} not equal to tpu_chips: {tpu_chips}"
  # Reserve TPU chips for each worker and force JAX onto the TPU backend.
  # pylint: disable-next=all
  engine_worker_with_tpu_resource = PyTorchRayWorker.options(
      resources={"TPU": worker_chips},
      runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}),
  )
  engine_workers = []
  for _ in range(num_hosts):
    engine_worker = engine_worker_with_tpu_resource.remote(
        tokenizer_path=tokenizer_path,
        ckpt_path=ckpt_path,
        samples_per_slot=samples_per_slot,
        bf16_enable=bf16_enable,
        param_size=param_size,
        context_length=context_length,
        batch_size=batch_size,
        max_decode_length=max_decode_length,
        model_name=model_name,
        quantize_weights=quantize_weights,
        quantize_kv=quantize_kv,
        max_cache_length=max_cache_length,
        sharding_config=sharding_config,
        enable_jax_profiler=enable_jax_profiler,
        jax_profiler_port=jax_profiler_port,
    )
    engine_workers.append(engine_worker)
  # Interleaved mode: a single engine drives both prefill and generate.
  if not is_disaggregated:
    return PyTorchRayEngine(
        engine_workers=engine_workers,
        tokenizer_path=tokenizer_path,
        context_length=context_length,
        batch_size=batch_size,
    )
  # Disaggregated mode: group workers by TPU pod slice, then build separate
  # prefill and decode engines.
  workers_dict = defaultdict(list)
  for worker in engine_workers:
    pod_slice_name = ray.get(worker.pod_slice_name.remote())
    workers_dict[pod_slice_name].append(worker)

  prefill_engine = PyTorchRayEngine(
      engine_workers=workers_dict[pod_name],
      tokenizer_path=tokenizer_path,
      context_length=context_length,
      batch_size=batch_size,
      is_disaggregated=is_disaggregated,
      pod_slice_name=pod_name,
  )
  decode_engine = PyTorchRayEngine(
      engine_workers=workers_dict[decode_pod_slice_name],
      tokenizer_path=tokenizer_path,
      context_length=context_length,
      batch_size=batch_size,
      is_disaggregated=is_disaggregated,
      pod_slice_name=decode_pod_slice_name,
  )
  return ([prefill_engine], [decode_engine])
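

# A hedged usage sketch: the paths, chip counts, and slice names below are
# placeholders, not values from this repository, and running it requires a
# Ray cluster with TPU hosts plus a real checkpoint and tokenizer.
if __name__ == "__main__":
  # Interleaved mode: a single engine serves both prefill and generate.
  engine = create_pytorch_ray_engine(
      tokenizer_path="/path/to/tokenizer.model",  # placeholder
      ckpt_path="/path/to/checkpoint",  # placeholder
      model_name="llama-2",
      tpu_chips=4,  # must equal num_hosts * worker_chips
  )
  engine.load_params()
  engine.init_decode_state()

  # Disaggregated mode instead returns ([prefill_engines], [decode_engines]),
  # and prefill results are handed to decode workers via transfer():
  #
  #   prefill_engines, decode_engines = create_pytorch_ray_engine(
  #       ...,
  #       is_disaggregated=True,
  #       decode_pod_slice_name="decode-slice",  # placeholder
  #   )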