We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 60c2fa5 commit 0572532Copy full SHA for 0572532
jetstream_pt/ray_worker.py
@@ -469,12 +469,12 @@ def prefill_ray(
469
prefix = Prefix(token, updated_caches, true_length)
470
self.prefix_queue.put(prefix, block=False)
471
472
- token_out = jnp.reshape(token, (1, 1))
473
- data = jnp.concatenate(
+ token_out = np.reshape(token, (1, 1))
+ data = np.concatenate(
474
[
475
token_out, # First token
476
- jnp.ones_like(token_out), # validity of first token
477
- jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
+ np.ones_like(token_out), # validity of first token
+ np.zeros((1, 1), dtype=np.int32), # length = 0
478
],
479
axis=-1,
480
)
0 commit comments