Pallas+scan: NotImplementedError when num_extensive is True #21190

Open
AllanYangZhou opened this issue May 12, 2024 · 1 comment
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

Description

I tried to write a simple TPU Pallas kernel that implements a cumulative sum, where the input is chunked along the summed dimension and the kernel computes one chunk at a time. I currently get the error below. It is triggered because the num_extensive variable evaluates to True, though I don't fully understand what "extensive" means here (my best guess is sketched after the traceback).

Traceback (most recent call last):
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 549, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1516, in _scan_lowering_rule
    if num_extensive: raise NotImplementedError
NotImplementedError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 41, in <module>
    carry_d, cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 24, in cumsum
    return pl.pallas_call(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 452, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 385, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 15, in cumsum_kernel
    carry_d, cs_Cxd = jax.lax.scan(inner, carry_d, x_Cxd, length=x_Cxd.shape[0])
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.pallas.mosaic.lowering.LoweringException: Exception while lowering eqn:
  a:f32[256] b:f32[128,256] = scan[
  jaxpr={ lambda ; c:f32[256] d:f32[256]. let e:f32[256] = add c d in (e, e) }
  length=128
  linear=(False, False)
  num_carry=1
  num_consts=0
  reverse=False
  unroll=1
] f g
With context:
  LoweringRuleContext(lowering_context=LoweringContext(ir_context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f971e902f70>, grid_indices=(<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f971e8e1b70>,), block_shapes=[(128, 256), (256,), (128, 256)], name_stack=NameStack(stack=()), mesh_context=None), avals_in=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], avals_out=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], block_shapes=[None, None])
With inval shapes=[None, None]
With inval types=[VectorType(vector<256xf32>), VectorType(vector<128x256xf32>)]
In jaxpr:
{ lambda ; a:Ref{float32[128,256]} b:Ref{float32[256]} c:Ref{float32[128,256]}. let
    d:i32[] = program_id[axis=0] 
    e:bool[] = eq d 0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    cond[
      branches=(
        { lambda ; g_:Ref{float32[256]}. let  in () }
        { lambda ; h:Ref{float32[256]}. let
            i:f32[256] = broadcast_in_dim[broadcast_dimensions=() shape=(256,)] 0.0
            h[:] <- i
          in () }
      )
      linear=(False,)
    ] f b
    j:f32[256] <- b[:]
    k:f32[128,256] <- a[:,:]
    l:f32[256] m:f32[128,256] = scan[
      jaxpr={ lambda ; n:f32[256] o:f32[256]. let
          p:f32[256] = add n o
        in (p, p) }
      length=128
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] j k
    b[:] <- l
    c[:,:] <- m
  in () }

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 41, in <module>
    carry_d, cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 75, in pallas_call_tpu_lowering_rule
    mosaic_module, extra_args = lowering.lower_jaxpr_to_module(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 340, in lower_jaxpr_to_module
    func_op = lower_jaxpr_to_func(ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 478, in lower_jaxpr_to_func
    body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jaxlib/mlir/dialects/func.py", line 195, in decorator
    return_values = f(*func_args, **func_kwargs)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 474, in body_func
    return jaxpr_subcomp(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 555, in jaxpr_subcomp
    raise LoweringException(
jax._src.pallas.mosaic.lowering.LoweringException: Exception while lowering eqn:
  a:f32[256] b:f32[128,256] = scan[
  jaxpr={ lambda ; c:f32[256] d:f32[256]. let e:f32[256] = add c d in (e, e) }
  length=128
  linear=(False, False)
  num_carry=1
  num_consts=0
  reverse=False
  unroll=1
] f g
With context:
  LoweringRuleContext(lowering_context=LoweringContext(ir_context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f971e902f70>, grid_indices=(<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f971e8e1b70>,), block_shapes=[(128, 256), (256,), (128, 256)], name_stack=NameStack(stack=()), mesh_context=None), avals_in=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], avals_out=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], block_shapes=[None, None])
With inval shapes=[None, None]
With inval types=[VectorType(vector<256xf32>), VectorType(vector<128x256xf32>)]
In jaxpr:
{ lambda ; a:Ref{float32[128,256]} b:Ref{float32[256]} c:Ref{float32[128,256]}. let
    d:i32[] = program_id[axis=0] 
    e:bool[] = eq d 0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    cond[
      branches=(
        { lambda ; g_:Ref{float32[256]}. let  in () }
        { lambda ; h:Ref{float32[256]}. let
            i:f32[256] = broadcast_in_dim[broadcast_dimensions=() shape=(256,)] 0.0
            h[:] <- i
          in () }
      )
      linear=(False,)
    ] f b
    j:f32[256] <- b[:]
    k:f32[128,256] <- a[:,:]
    l:f32[256] m:f32[128,256] = scan[
      jaxpr={ lambda ; n:f32[256] o:f32[256]. let
          p:f32[256] = add n o
        in (p, p) }
      length=128
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] j k
    b[:] <- l
    c[:,:] <- m
  in () }
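
From skimming _scan_lowering_rule in jax/_src/pallas/mosaic/lowering.py, my best guess is that "extensive" refers to the arguments that are scanned over (the xs, and likewise the per-step ys that get stacked), as opposed to the carry, and that the TPU lowering currently only handles fori_loop-style scans that thread a carry without consuming an xs array. Roughly (my own sketch, not from the docs):

# Carry-only scan, no xs and no stacked ys -- this is the fori_loop-like
# form that I think the lowering supports.
carry, _ = jax.lax.scan(lambda c, _: (c + 1.0, None), 0.0, None, length=128)

# Scan over an xs array with stacked ys, which is what my kernel does.
# The xs/ys are the "extensive" arguments, so num_extensive > 0 and the
# lowering raises NotImplementedError.
# (carry0 and x_Cxd stand in for the values in my kernel below.)
carry, ys = jax.lax.scan(lambda c, x: (c + x, c + x), carry0, x_Cxd)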

The code to reproduce is below. Note that if I set interpret=True, the code runs without error.

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def cumsum_kernel(xref_Cxd, carryref_d, oref_Cxd):
    @pl.when(pl.program_id(axis=0) == 0)
    def _():
        carryref_d[...] = jnp.zeros_like(carryref_d)
    carry_d = carryref_d[...]
    x_Cxd = xref_Cxd[...]
    def inner(_carry_d, x_d):
        _carry_d = _carry_d + x_d
        return _carry_d, _carry_d
    carry_d, cs_Cxd = jax.lax.scan(inner, carry_d, x_Cxd, length=x_Cxd.shape[0])
    carryref_d[...] = carry_d
    oref_Cxd[...] = cs_Cxd


@jax.jit
def cumsum(x_Txd):
    T, d = x_Txd.shape
    C = 128
    return pl.pallas_call(
        cumsum_kernel,
        grid=T // C,
        in_specs=[pl.BlockSpec(lambda i: (i, 0), (C, d))],
        out_specs=[
            pl.BlockSpec(lambda i: 0, (d,)),  # carry
            pl.BlockSpec(lambda i: (i, 0), (C, d))  # out
        ],
        out_shape=[
            jax.ShapeDtypeStruct((d,), x_Txd.dtype),
            jax.ShapeDtypeStruct(x_Txd.shape, x_Txd.dtype),
        ],
        interpret=False
    )(x_Txd)

key = jax.random.PRNGKey(0)
x_Txd = jax.random.normal(key, (1280, 256))
carry_d, cs_Txd = cumsum(x_Txd)
realcs_Txd = jnp.cumsum(x_Txd, axis=0)
print(f"Error is {jnp.max(jnp.abs(cs_Txd - realcs_Txd))}")

System info (python version, jaxlib version, accelerator, etc.)

I am using a Cloud TPU v3-8 with JAX version 0.4.23.

AllanYangZhou added the bug label on May 12, 2024
AllanYangZhou (Author) commented:

I also tried a simpler implementation that uses a plain Python loop instead of scan:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def cumsum_kernel(xref_Cxd, carryref_d, oref_Cxd):
    carry_d = carryref_d[...]
    x_Cxd = xref_Cxd[...]
    C, d = x_Cxd.shape
    for i in range(C):
        carry_d += x_Cxd[i]
        oref_Cxd[i, :] = carry_d
    carryref_d[...] = carry_d


@jax.jit
def cumsum(x_Txd):
    T, d = x_Txd.shape
    C = 128
    carry_d = jnp.zeros((d,), dtype=x_Txd.dtype)
    return pl.pallas_call(
        cumsum_kernel,
        grid=T // C,
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (C, d)),
            pl.BlockSpec(lambda i: 0, (d,)),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (C, d)),
        out_shape=jax.ShapeDtypeStruct(x_Txd.shape, x_Txd.dtype),
        interpret=False,
    )(x_Txd, carry_d)

key = jax.random.PRNGKey(0)
x_Txd = jax.random.normal(key, (1280, 64))
cs_Txd = cumsum(x_Txd)
realcs_Txd = jnp.cumsum(x_Txd, axis=0)
print(f"Error is {jnp.max(jnp.abs(cs_Txd - realcs_Txd))}")

This produces a different error:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 35, in <module>
    cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 21, in cumsum
    return pl.pallas_call(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 456, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: Only zero-offset slices supported.

The MLIR operation involved:
  %137 = "vector.extract_strided_slice"(%132) <{offsets = [1, 0], sizes = [1, 64], strides = [1, 1]}> : (vector<128x64xf32>) -> vector<1x64xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 35, in <module>
    cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 95, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 82, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 406, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 333, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "post-infer-vector-layout")
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 260, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Only zero-offset slices supported.

The MLIR operation involved:
  %137 = "vector.extract_strided_slice"(%132) <{offsets = [1, 0], sizes = [1, 64], strides = [1, 1]}> : (vector<128x64xf32>) -> vector<1x64xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
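
For completeness, one chunk-at-a-time formulation that avoids both the scan and the per-row slices is to materialize the prefix sums of each chunk with a lower-triangular matmul, and carry only the running total between grid steps. A minimal sketch of the kernel (it reuses the pallas_call wrapper from my first snippet; the iota/mask construction is my own and I haven't verified that it lowers on every jax/jaxlib version):

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def cumsum_matmul_kernel(xref_Cxd, carryref_d, oref_Cxd):
    @pl.when(pl.program_id(axis=0) == 0)
    def _():
        carryref_d[...] = jnp.zeros_like(carryref_d)
    carry_d = carryref_d[...]
    x_Cxd = xref_Cxd[...]
    C = x_Cxd.shape[0]
    # Inclusive lower-triangular mask from 2D broadcasted iotas:
    # row i of (mask @ x) is sum_{j <= i} x[j].
    row = jax.lax.broadcasted_iota(jnp.int32, (C, C), 0)
    col = jax.lax.broadcasted_iota(jnp.int32, (C, C), 1)
    mask = (col <= row).astype(x_Cxd.dtype)
    cs_Cxd = jnp.dot(mask, x_Cxd, preferred_element_type=jnp.float32) + carry_d[None, :]
    oref_Cxd[...] = cs_Cxd
    # New running carry: old carry plus the total of this chunk.
    carryref_d[...] = carry_d + jnp.sum(x_Cxd, axis=0)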

AllanYangZhou changed the title from "Pallas scan on TPU" to "Pallas+scan: NotImplementedError when num_extensive is True" on May 16, 2024
superbobry added the pallas label on May 29, 2024