Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scatter lowering #273

Merged
merged 45 commits into from
Sep 22, 2023
Merged

Scatter lowering #273

merged 45 commits into from
Sep 22, 2023

Conversation

rmoyard
Copy link
Contributor

@rmoyard rmoyard commented Sep 6, 2023

Description of the Change:

In the quantum dialect, I add a lowering for MHLO scatter to the MLIR tensor dialect. The reference for the implementation can be found here https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter

Benefits:

This unlocks Jax array index with at and apply functions on it, for example:

    @qjit
    def add_multiply(l: jax.core.ShapedArray((3,), dtype=float), idx: int):
        res = l.at[idx].multiply(3)
        res2 = l.at[idx].add(2)
        return res + res2

    res = add_multiply(jnp.array([0, 1, 2]), 2)
    >>> [0, 2, 10]

Possible Drawbacks:

Some edge cases not properly covered.

@dime10
Copy link
Collaborator

dime10 commented Sep 19, 2023

We might want to wait until #216 is merged, since that PR is fairly heavy and close to being merged. It shouldn't be too much work to adapt this PR to the new compiler driver.

@rmoyard rmoyard requested a review from dime10 September 20, 2023 23:46
@dime10
Copy link
Collaborator

dime10 commented Sep 21, 2023

[sc-41325]

Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still looking at it. Some small comments first.

.github/workflows/check-catalyst.yaml Outdated Show resolved Hide resolved
mlir/test/Catalyst/ScatterTest.mlir Outdated Show resolved Hide resolved
frontend/test/pytest/test_scatter.py Show resolved Hide resolved
Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! Great work @rmoyard!

@rmoyard rmoyard merged commit 7ba2158 into main Sep 22, 2023
19 checks passed
@rmoyard rmoyard deleted the scatter_lowering branch September 22, 2023 19:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants