Pallas+scan: NotImplementedError when num_extensive is True #21190
Comments
I also tried a simpler implementation that uses a plain Python loop inside the kernel instead of scan:

```python
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def cumsum_kernel(xref_Cxd, carryref_d, oref_Cxd):
    # Running sum of all previous chunks, intended to persist across grid steps.
    carry_d = carryref_d[...]
    x_Cxd = xref_Cxd[...]
    C, d = x_Cxd.shape
    # Python loop: unrolled at trace time over the C rows of the chunk.
    for i in range(C):
        carry_d += x_Cxd[i]
        oref_Cxd[i, :] = carry_d
    carryref_d[...] = carry_d


@jax.jit
def cumsum(x_Txd):
    T, d = x_Txd.shape
    C = 128  # chunk size along the summed (T) axis
    carry_d = jnp.zeros((d,), dtype=x_Txd.dtype)
    return pl.pallas_call(
        cumsum_kernel,
        grid=T // C,
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (C, d)),
            pl.BlockSpec(lambda i: 0, (d,)),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (C, d)),
        out_shape=jax.ShapeDtypeStruct(x_Txd.shape, x_Txd.dtype),
        interpret=False,
    )(x_Txd, carry_d)


key = jax.random.PRNGKey(0)
x_Txd = jax.random.normal(key, (1280, 64))
cs_Txd = cumsum(x_Txd)
realcs_Txd = jnp.cumsum(x_Txd, axis=0)
print(f"Error is {jnp.max(jnp.abs(cs_Txd - realcs_Txd))}")
```

This produces a different error:
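On what "extensive" means: `jax.lax.scan` splits its operands into constants, the loop carry, and the per-iteration slices `xs` (with the stacked per-step outputs `ys`). My assumption, not confirmed here, is that `num_extensive` counts those sliced operands, and that the TPU lowering only accepts scans without them, i.e. loops shaped like `fori_loop`. The plain-JAX sketch below (no Pallas; the function names are hypothetical) computes one chunk's running sum both ways so the two formulations can be compared.

```python
import jax
import jax.numpy as jnp
from jax import lax


def chunk_cumsum_fori(x_Cxd, carry_d):
    """Running sum over one (C, d) chunk using only loop-carried state.

    The chunk is closed over (a constant) and read with the loop index,
    so the loop has no per-iteration sliced operands.
    """
    C = x_Cxd.shape[0]

    def body(i, state):
        carry_d, out_Cxd = state
        carry_d = carry_d + x_Cxd[i]
        out_Cxd = out_Cxd.at[i].set(carry_d)
        return carry_d, out_Cxd

    init = (carry_d, jnp.zeros_like(x_Cxd))
    carry_d, out_Cxd = lax.fori_loop(0, C, body, init)
    return out_Cxd, carry_d


def chunk_cumsum_scan(x_Cxd, carry_d):
    """Same computation with lax.scan: the rows of x_Cxd are the scanned
    xs and the stacked per-step outputs are the ys."""

    def step(carry_d, x_d):
        carry_d = carry_d + x_d
        return carry_d, carry_d

    carry_d, out_Cxd = lax.scan(step, carry_d, x_Cxd)
    return out_Cxd, carry_d


x_Cxd = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
carry0_d = jnp.zeros((64,))
out_fori, _ = chunk_cumsum_fori(x_Cxd, carry0_d)
out_scan, _ = chunk_cumsum_scan(x_Cxd, carry0_d)
print(jnp.max(jnp.abs(out_fori - out_scan)))
```

Both versions compute the same values. With static bounds, `fori_loop` is itself implemented via `scan`, but with the chunk passed as a constant rather than as a sliced `xs` operand, which is the distinction the assumption above relies on.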
AllanYangZhou changed the title from "Pallas scan on TPU" to "Pallas+scan: NotImplementedError when num_extensive is True" on May 16, 2024.
Description
I tried to write a simple TPU Pallas kernel that implements cumsum: the input is chunked along the summed dimension, and the kernel processes one chunk at a time. I currently get the error below. It is triggered because the `num_extensive` variable evaluates to `True`, though I don't understand what "extensive" means here. The code to reproduce is below. Note that if I set `interpret=True`, the code runs without error.

System info (python version, jaxlib version, accelerator, etc.)
I am using a Cloud TPU v3-8 with JAX version 0.4.23.
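For reference, the chunked algorithm described above can be sketched in plain JAX without Pallas, which gives an expected result to compare against on any backend. This is an illustrative sketch, not a workaround for the TPU lowering error: `chunked_cumsum` is a hypothetical name, it assumes `T` is a multiple of the chunk size `C`, and it runs `lax.scan` over whole chunks with the running sum as the carry, mirroring the kernel's grid and its carry input.

```python
import jax
import jax.numpy as jnp
from jax import lax


def chunked_cumsum(x_Txd, C=128):
    """Cumulative sum along axis 0, computed one (C, d) chunk at a time.

    Each step sees one chunk plus the running sum (carry) of all previous
    chunks. Assumes T is divisible by C.
    """
    T, d = x_Txd.shape
    assert T % C == 0, "T must be a multiple of the chunk size C"
    chunks_NxCxd = x_Txd.reshape(T // C, C, d)

    def step(carry_d, x_Cxd):
        # Within-chunk cumulative sum, offset by the total of earlier chunks.
        out_Cxd = jnp.cumsum(x_Cxd, axis=0) + carry_d
        # The last row is the running total handed to the next chunk.
        return out_Cxd[-1], out_Cxd

    init_d = jnp.zeros((d,), dtype=x_Txd.dtype)
    _, out_NxCxd = lax.scan(step, init_d, chunks_NxCxd)
    return out_NxCxd.reshape(T, d)


x_Txd = jax.random.normal(jax.random.PRNGKey(0), (1280, 64))
err = jnp.max(jnp.abs(chunked_cumsum(x_Txd) - jnp.cumsum(x_Txd, axis=0)))
print(f"Error is {err}")
```

Outside a Pallas kernel, scanning over the chunks as `xs` like this is ordinary JAX; the small printed error comes only from float32 summation order.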