[pallas] align interpreter load/store with masked behaviour #21298

Open: wants to merge 1 commit into main

oliverdutton (Contributor) commented May 18, 2024

(This PR also adds stride support.)

Fixes #21143

Implements a jittable masked gather/scatter so that, for load/store/swap, masked indices are never actually accessed.

For load, any masked index is redirected to the first element of the array instead.

For swap (and store), masked indices are likewise redirected to the first element, with special rules to make sure the first element itself still ends up with the correct value.

The dynamic_slice calls currently used are replaced with explicit index materialisation and gathers/scatters.
The advantage of doing it this way is that you can combine it with checkify(f, errors=checkify.index_checks) in interpreter mode to check for any unmasked out-of-bounds (OOB) indexing, which is (I think, and believe should be) undefined behaviour.
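As a rough illustration of "explicit index materialisation" with strides, here is a minimal sketch in plain jax.numpy, outside Pallas. The helper name strided_gather is hypothetical and is not the PR's code: it builds one index vector per axis (start + arange(size) * stride), broadcasts them with meshgrid(indexing="ij"), and performs a single gather.

```python
import jax
import jax.numpy as jnp


def strided_gather(ref, starts, sizes, strides):
    """Hypothetical helper: gather a strided block from `ref`.

    `starts` may be traced integers; `sizes` and `strides` must be static
    Python ints so that the output shape is known at trace time.
    """
    # One index vector per axis: start, start + stride, start + 2*stride, ...
    axes = [start + jnp.arange(size) * stride
            for start, size, stride in zip(starts, sizes, strides)]
    # Broadcast the per-axis vectors into a full index grid ("ij" keeps axis order).
    grids = jnp.meshgrid(*axes, indexing="ij")
    # Integer-array indexing lowers to a single gather.
    return ref[tuple(grids)]


ref = jnp.arange(16).reshape(4, 4)
block = jax.jit(strided_gather, static_argnums=(2, 3))(ref, (0, 1), (2, 2), (2, 2))
# block holds the same values as ref[0:4:2, 1:4:2]
```

Unlike dynamic_slice, nothing here shifts the start index when the block runs past the edge of the array, which is what makes it possible to mask (or checkify) each index individually.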

[Apologies: this reopens a previous request I'd done badly, having not checked contributing.md.]

@oliverdutton force-pushed the pallas_interpreter_indexing_fix branch from 3750448 to 341e249 May 18, 2024 21:26

oliverdutton (Contributor, Author) commented May 24, 2024

@justinjfu this corrects issues similar to #21180, though here they relate to indexing into MemRefs rather than to block shapes that do not evenly divide the arrays being chunked.

@justinjfu self-assigned this May 29, 2024

justinjfu (Collaborator) left a comment

Thanks for the PR! Really appreciate the fixes.

indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
slice_strides = tuple(s.stride if isinstance(s, Slice) else 1 for s in indices)
indices = [start + lax.iota(jnp.int32, size) * stride for (start, size, stride)
Collaborator

The explicit indexing here could end up being very slow for large slices. I don't think this would be a big problem as kernels generally have small block sizes, but is the only advantage here (over padding) the checkify support as mentioned in the PR description?

Contributor Author

Agreed, it could be slow. The key advantage is that the interpreter code (here), running on the host, produces the same behaviour as the compiled device code.

dynamic_slice (and dynamic_update_slice) lead to incorrect behaviour because of XLA's start-index shifting rules, documented as:

the potentially surprising behavior for the case where the requested slice overruns the bounds of the array; in this case the start index is adjusted to return a slice of the requested size:
(https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html)

An example is given in #21143.

Using a gather instead of the dynamic slice fixes that incorrect behaviour.

[As a side bonus, it aligns the slice and individual indexing code paths]
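To make the difference concrete, here is a minimal sketch in plain JAX (not Pallas) contrasting the shifted start of lax.dynamic_slice with a gather on explicitly materialised indices. The fill mode and fill value of -1 are illustrative choices; in a masked load the out-of-bounds lane would be masked away anyway.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4)  # [0, 1, 2, 3]

# Ask for 2 elements starting at index 3. This overruns the end of `x`,
# so the start index is shifted back to 2 to keep the slice in bounds.
shifted = lax.dynamic_slice(x, (3,), (2,))

# A gather on explicit indices keeps each element at its requested position;
# mode="fill" makes the out-of-bounds position return fill_value instead.
idx = 3 + jnp.arange(2)  # [3, 4]
gathered = x.at[idx].get(mode="fill", fill_value=-1)
```

Here shifted is [2, 3] (silently wrong for a kernel expecting element 3 first), while gathered is [3, -1].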

indices = tuple(jnp.meshgrid(*indices, indexing='ij'))
if mask is not None:
  # masked loads set to index first element in array
  indices = tuple(jnp.where(mask, indexs, 0) for indexs in indices)
Collaborator

Why do we need to set masked indices to the first element? Wouldn't these values be masked away anyway on line 328?

Contributor Author

It's true these values will be masked away on line 328. Setting the masked indices to the first element is meant to match the documented behaviour of masked loads more closely; the documentation says

Triton avoids the read from/write to memory if it’s masked

as part of

Masking is important when doing out-of-bounds loads/stores. The operational semantics of masking can be compiler-determined (if we understand the documentation properly, Triton avoids the read from/write to memory if it’s masked).
(https://jax.readthedocs.io/en/latest/pallas/design.html)

Those masked indices might be OOB and should not be indexed at all. Diverting them to the first index is the best jittable approximation I could make.

The link to checkify is that this change makes it possible to distinguish between unmasked OOB indexing, which will actually happen on device, and masked OOB indexing, which will not. So the masked load is compatible with JAX's debugging tools, e.g. NaN checking via any variant of https://jax.readthedocs.io/en/latest/debugging/flags.html

e.g. for a Ref of shape (4,), indices [1,2,3,7] with mask [True,True,True,False] is a completely valid load. However, without this fix, running checkify's checks on the interpreted code would throw an OOB error, because the indexing at 7 would 'physically' occur; diverting it to 0 avoids this.
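That scenario can be reproduced in plain JAX (outside Pallas) as a minimal sketch. The two helpers below are hypothetical stand-ins for the interpreter's load path, and checkify.index_checks covers index errors only (not NaNs); the point is that redirecting masked lanes to index 0 leaves checkify with nothing to report, while the naive version is flagged.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def naive_masked_load(ref, idx, mask):
    # Hypothetical: reads ref[idx] for every lane, masked or not, then masks.
    return jnp.where(mask, ref[idx], 0.0)


def diverted_masked_load(ref, idx, mask):
    # Hypothetical: masked lanes are redirected to index 0 (always in bounds),
    # so only unmasked indices are actually read.
    safe_idx = jnp.where(mask, idx, 0)
    return jnp.where(mask, ref[safe_idx], 0.0)


ref = jnp.array([10.0, 11.0, 12.0, 13.0])    # shape (4,)
idx = jnp.array([1, 2, 3, 7])                # 7 is out of bounds ...
mask = jnp.array([True, True, True, False])  # ... but masked

naive_err, _ = checkify.checkify(
    naive_masked_load, errors=checkify.index_checks)(ref, idx, mask)
safe_err, out = checkify.checkify(
    diverted_masked_load, errors=checkify.index_checks)(ref, idx, mask)
# naive_err.get() describes an out-of-bounds index; safe_err.get() is None.
```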

Collaborator

@justinjfu May 30, 2024

I see - OOB indexes that are masked being legal does make the problem more complicated.

The cleaner way to handle this for checkify is probably to add a custom checkify rule for the load and swap primitives (for example, see https://github.com/google/jax/blob/main/jax/_src/checkify.py#L608). The interpret mode discharge rule is primarily responsible for returning the correct values.

What I propose is this change:

  1. In the discharge rule for load and swap, simply pad the inputs to the maximum shape with jax.numpy.pad at the beginning of the call, then you can leave the rest of the logic unchanged.
  2. Move the explicit indexing logic to a custom checkify rule for load & swap which will check if there are any unmasked indices that are OOB. If they exist, return the OOB error. I believe checkify only checks for NaNs for a subset of arithmetic ops, so you wouldn't need to do NaN checking here.

I think this would keep the discharge implementation simple and fast but also be compatible with checkify & jit. Does that sound reasonable or are there other use cases which wouldn't be covered?
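A minimal sketch of proposal 1 for a single block, in plain JAX: pad the end of each dimension by the slice size so that lax.dynamic_slice never has to shift its start index, then apply the mask as before. The helper names are hypothetical, strides are ignored, and starts are assumed to be non-negative and inside the unpadded array.

```python
import jax.numpy as jnp
from jax import lax


def padded_masked_load(ref, starts, sizes, mask, other=0):
    # Pad every axis by the slice size: any in-bounds start plus the full
    # slice then fits, so dynamic_slice never shifts the start index.
    padded = jnp.pad(ref, [(0, s) for s in sizes])
    block = lax.dynamic_slice(padded, starts, sizes)
    return jnp.where(mask, block, other)


def padded_masked_store(ref, starts, sizes, val, mask):
    padded = jnp.pad(ref, [(0, s) for s in sizes])
    current = lax.dynamic_slice(padded, starts, sizes)
    # Masked lanes write back whatever is already there.
    padded = lax.dynamic_update_slice(
        padded, jnp.where(mask, val, current), starts)
    # Drop the padding to recover the original shape.
    return padded[tuple(slice(0, n) for n in ref.shape)]


x = jnp.array([0.0, 1.0, 2.0, 3.0])
mask = jnp.array([True, False])  # the second lane would touch index 4 (OOB)
loaded = padded_masked_load(x, (3,), (2,), mask)
stored = padded_masked_store(x, (3,), (2,), jnp.array([9.0, 9.0]), mask)
```

Here loaded is [3., 0.] and stored is [0., 1., 2., 9.]; without the padding, the start would be shifted to 2 and the load would return [2., 3.] instead.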

Contributor Author

Correctness

  1. In the discharge rule for load and swap, simply pad the inputs to the maximum shape with jax.numpy.pad at the beginning of the call, then you can leave the rest of the logic unchanged.

This equates to padding by the slice size in each dimension. Agreed, this will give the correct output without a performance hit. Will implement.

Debugging interaction with masked OOBs

  2. Move the explicit indexing logic to a custom checkify rule for load & swap which will check if there are any unmasked indices that are OOB. If they exist, return the OOB error. I believe checkify only checks for NaNs for a subset of arithmetic ops, so you wouldn't need to do NaN checking here.

I think it makes sense for that to be handled by the NaN/OOB checkers. I'll look into how to do that and push it as a separate PR from this one.

This PR will just fix #21143 and do 1.

Will put it together on Monday evening.


x = random.normal(random.key(0), (m, n))
y = random.normal(random.key(1), (m, n))
mask = jnp.zeros((m, n), dtype=bool)
Collaborator

Is the mask the only difference between this test and test_masked_swap? If so, it would be cleaner to merge them together with a parameterized test or subtests.

Contributor Author

Yep, it's a special case of the mask that could cause problems.

Will merge them together.

y_new = y.at[unmasked_idx].set(x[mask])
np.testing.assert_array_equal(out[0], x_new)
np.testing.assert_array_equal(out[1], y_new)

Collaborator

Can you add a test for load (similar to the example you posted in #21143)?

Contributor Author

Will do


if mask is not None:
  # masked indices are never accessed; we simulate this by using the first element of the array as a dummy
is_first_element = (jnp.stack(indices) == 0).all(0)
Collaborator

Same question here: why do we need to add special case handling to swap using the first element? Could we not solely rely on a jnp.where(mask, ...) to control which values get swapped as done in the original code?

Contributor Author

To mirror device behaviour, the Ref must not be accessed at masked indices. The best jittable approximation I found was to divert masked indices to the first element, so the masked indexing never occurs (masked indices are often OOB, while the first element should never be OOB).
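A minimal 1-D sketch of that idea in plain JAX, assuming unmasked indices are unique (duplicate unmasked writes are ill-defined anyway). The helper masked_swap_1d is hypothetical and not the PR's code. Masked lanes are diverted to index 0, and because they then scatter into element 0, they must all write whatever element 0 should hold afterwards: the new value if an unmasked lane targets index 0, the old value otherwise. This is presumably the special case the is_first_element check in the diff is handling.

```python
import jax.numpy as jnp


def masked_swap_1d(ref, idx, val, mask):
    """Hypothetical helper: returns (new_ref, old_values) for a masked swap."""
    # Divert masked lanes to index 0, which is always in bounds.
    safe_idx = jnp.where(mask, idx, 0)
    # Values returned by the swap; masked lanes get a dummy 0.
    old = jnp.where(mask, ref[safe_idx], 0.0)
    # Which unmasked lanes (if any) legitimately write element 0?
    writes_first = mask & (idx == 0)
    first_new = jnp.where(
        writes_first.any(),
        val[jnp.argmax(writes_first.astype(jnp.int32))],  # first such lane
        ref[0],  # otherwise element 0 keeps its current value
    )
    # Masked lanes write the final value of element 0, so all duplicate
    # index-0 writes agree and the scatter result is well defined.
    new_vals = jnp.where(mask, val, first_new)
    return ref.at[safe_idx].set(new_vals), old


ref = jnp.array([10.0, 11.0, 12.0, 13.0])
# An unmasked lane writes index 0 while a masked (OOB) lane is diverted there.
new_ref, old = masked_swap_1d(
    ref, jnp.array([0, 2, 7]), jnp.array([1.0, 2.0, 3.0]),
    jnp.array([True, True, False]))
```

Here new_ref is [1., 11., 2., 13.] and old is [10., 12., 0.]: the masked lane's diverted write to element 0 agrees with the unmasked write of 1.0 instead of clobbering it.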

Development

Successfully merging this pull request may close these issues.

[pallas] Interpreter mismatch for masked OOB indexing
2 participants