[pallas] Interpreter mismatch for masked OOB indexing #21143

Open · oliverdutton opened this issue May 9, 2024 · 0 comments · May be fixed by #21298
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

oliverdutton commented May 9, 2024

Description

In Triton (if I have read this correctly), masked-out lanes of a load/store are never accessed, so you can request a load/store at an index that is out of bounds for the ref as long as that lane is masked. The current interpreter instead uses dynamic_slice/dynamic_update_slice and applies the mask afterwards. In line with JAX's 'always be in bounds' design, a slice that overruns the edge of the array has its start index shifted back so it stays valid (where possible). This leads to a mismatch between interpreter and compiled Pallas outputs.
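
If I've followed the interpreter code correctly, the shift comes from the index clamping that jax.lax.dynamic_slice performs; here is a minimal sketch of that behaviour (just for illustration):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(16)
# Requesting 16 elements starting at index 3 clamps the start back to 0,
# so the whole array comes back rather than an out-of-bounds slice.
print(lax.dynamic_slice(x, (3,), (16,)))  # [ 0  1  2 ... 15]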

I know Triton is not Pallas; have you deliberately changed the desired behaviour for these cases in Pallas? If so, this isn't a bug, but it needs documenting.

I've opened a pull request fixing this, with tests: #21298.
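
For comparison, the Triton-style semantics can be emulated in plain JAX with a clipped gather plus a select (just an illustration of the intended behaviour, not necessarily how the PR implements it):

import jax.numpy as jnp

x = jnp.arange(16)
idx = 3 + jnp.arange(16)                      # requested (possibly OOB) indices
safe_idx = jnp.clip(idx, 0, x.shape[0] - 1)   # never actually read out of bounds
mask = idx < x.shape[0]
print(jnp.where(mask, x[safe_idx], -1))       # [ 3  4 ... 15 -1 -1 -1]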

Here is a minimal reproduction (run in Colab) showing the shift in the load indices:

from functools import partial

import jax
from jax import numpy as jnp, jit
from jax.experimental import pallas as pl

def masked_load_pallas_kernel(x_ref, o_ref):
  # Load a full-width block starting at offset 3; lanes that would fall past
  # the end of the ref are masked out and replaced with -1.
  i = jnp.array(3)
  mask = jnp.arange(x_ref.shape[0]) + i < x_ref.shape[0]
  x = pl.load(x_ref, pl.dslice(i, mask.shape[0]), mask=mask, other=-1)
  o_ref[:] = x

@partial(jit, static_argnames=('interpret',))
def masked_load(x: jax.Array, interpret: bool = True):
  return pl.pallas_call(masked_load_pallas_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
                        interpret=interpret,
                        )(x)

x = jnp.arange(16)
print(f'Input:\nx:\n{x}\n\nOutput:')
for interpret in (True, False):
  print(f'Interpret: {interpret}\n{masked_load(x, interpret=interpret)}')

Input:
x:
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Output:
Interpret: True
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 -1 -1 -1]
Interpret: False
[ 3  4  5  6  7  8  9 10 11 12 13 14 15 -1 -1 -1]

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='b70fe499e42d', release='6.1.58+', version='#1 SMP PREEMPT_DYNAMIC Sat Nov 18 15:31:17 UTC 2023', machine='x86_64')


$ nvidia-smi
Thu May  9 08:55:06 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   63C    P0              30W /  72W |  17235MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

(Problem persists in 0.4.28)

oliverdutton added the bug label on May 9, 2024
superbobry added the pallas label on May 29, 2024