Add jasp -> QC conversion#18
Conversation
|
I believe I have adressed all issues. @burgholzer, I feel the same about the dependency. Unfortunately, it seems we will also need their passes and not just a dialect, so I see no way around the issue for now, unless we ask qrisp people nicely to get rid of it (they already lower some |
This is exactly my point of view. If catalyst really sees itself as an extension of StableHLO this might make sense for them. But for us it does not. This is also one of the reasons why I am not particularly happy with our dependency on MQT Core. The exact same problem as with StableHLO. In case of problems this will always tie us to a version "close enough" to the one MQT is relying on. The StableHLO dependency just makes the situation even worse, because now we have to live with the possibility of unsatisfiable constraints which might e.g. keep us from updating the one of our deps (e.g. MQT) if the other one (e.g. StableHLO) is incompatible (typical dependency hell). |
I agree with everything you said, but this one. |
burgholzer
left a comment
There was a problem hiding this comment.
LGTM except for one small typo that I will quickly apply.
Signed-off-by: Lukas Burgholzer <burgholzer@me.com>
Of course we appreciate that, but it is a bad example at the same time. If we would follow my advise and not depend on MQT Core we would not have had the problem in the first place. It was just a one line change after all. @julfarn could have done it with a single commit to the repo instead of reaching out and waiting for a day or two for this to be landed in some other repository. |
He could have also just contributed the one-line fix to a fully open-source repository, which would have also just been a single commit. 😉 |
@burgholzer, I think there’s a misunderstanding. My comments were strictly about the technical workflow and dependency structure, not a personal attack. I have no intention of 'bashing' or 'slandering'. I’m just sharing my perspective on what would be most efficient for the repo. It seems we have very different viewpoints on this, so I’ll leave it at that for now. |
I may just be putting a bit too much emphasis on some of the wording, I'll admit that. |
Qrisp dev here :) |
|
Thanks for chiming in Raphael!
It would be phenomenal, if the IR that we receive only contained
I am not yet fully sure if that would solve the problem, as it seems likely that qrisp, the compiler, or both, would then have to depend on that separate package; which carries a dependency on StableHLO. But that's just my 2ct's here. I do not have a clear picture of a solution myself (yet). |
|
The (transitive) depedency would still be there but it would be optional, which mitigates the compatibility headaches a bit. If someone wants to run Qrisp on their Smart-Fridge they most likely don't need StableHLO for their usecase. |
|
Hi, thanks for your input!
I don't quite understand. The |
|
Thank you for challenging this, @julfarn! I was about to dismiss it but I did a brief check-in with Claude and it indeed found a way to invoke the Jax-shipped binary. The following script removes all StableHLO operations without introducing StableHLO dependency (just Jax). (Also tagging @superlopuh since I think he also was looking for ways to circumvent StableHLO) """
Lower a Jaspr quantum program to MLIR with StableHLO ops lowered to
linalg / arith / tensor, while keeping the Jasp dialect ops intact.
Uses only JAX-bundled infrastructure — no external stablehlo or mlir packages.
Pipeline:
1. Trace a quantum function → Jaspr
2. Lower Jaspr → JAX MLIR module (StableHLO + Jasp dialect)
3. Run symbol-dce + stablehlo-legalize-to-linalg on the JAX module
4. Print to generic MLIR → xDSL → rewrite StableHLO control flow to SCF
5. Print final MLIR (linalg + arith + tensor + SCF + Jasp)
"""
from jaxlib.mlir import ir, passmanager
from jaxlib.mlir._mlir_libs import _stablehlo, _mlirHlo, _chlo, _jax_mlir_ext
from qrisp import QuantumFloat, h, x, cx, measure, control
from qrisp.jasp import make_jaspr
from qrisp.jasp.mlir.jaxpr_lowering import lower_jaxpr_to_MLIR
from qrisp.jasp.mlir.jasp_lowering_rules import jasp_lowering_rules
from qrisp.jasp.mlir.quantum_control_flow import fix_quantum_control_flow
def _linalg_mlir_to_xdsl(mlir_text: str):
"""Parse post-linalg generic MLIR into an xDSL module.
Like ``generic_mlir_to_xdsl`` but also registers linalg, arith, tensor,
scf, and math dialects that appear after the StableHLO → linalg pass.
"""
from xdsl.context import Context
from xdsl.dialects import builtin, func, linalg, arith, tensor, scf
from xdsl.parser import Parser
from qrisp.jasp.mlir.xdsl_dialect import JaspDialect
ctx = Context()
ctx.allow_unregistered = True
ctx.load_dialect(builtin.Builtin)
ctx.load_dialect(func.Func)
ctx.load_dialect(linalg.Linalg)
ctx.load_dialect(arith.Arith)
ctx.load_dialect(tensor.Tensor)
ctx.load_dialect(scf.Scf)
ctx.load_dialect(JaspDialect)
parser = Parser(ctx, mlir_text)
return parser.parse_module()
def jaspr_to_linalg_mlir(jaspr) -> str:
"""Lower a Jaspr to an MLIR string with StableHLO→linalg applied.
The pipeline:
1. Legalize StableHLO arithmetic → linalg (on the original JAX module)
2. Then rewrite remaining StableHLO control flow → SCF (via xDSL)
"""
# --- Step 1: Lower Jaspr to a JAX MLIR module ---------------------------
mlir_module = lower_jaxpr_to_MLIR(jaspr, lowering_rules=jasp_lowering_rules)
# --- Step 2: Register passes and run inside the module's context ---------
# lower_jaxpr_to_MLIR exits its `with ctx.context:` block, so we must
# re-enter the MLIR context to use the PassManager.
ctx = mlir_module.context
with ctx:
_stablehlo.register_dialect(ctx)
_stablehlo.register_stablehlo_passes()
_mlirHlo.register_mhlo_dialect(ctx)
_mlirHlo.register_mhlo_passes()
_chlo.register_dialect(ctx)
# symbol-dce removes unused private shadow functions that JAX emits.
# stablehlo-legalize-to-linalg converts arithmetic/data ops to linalg.
# stablehlo control-flow ops (case, while) are left untouched here.
pipeline = "builtin.module(" \
"symbol-dce," \
"stablehlo-convert-to-signless," \
"stablehlo-legalize-to-linalg" \
")"
pm = passmanager.PassManager.parse(pipeline)
# Disable verifier: stablehlo.case carries !jasp.QuantumState which
# fails StableHLO type constraints. The legalizer itself only touches
# arithmetic ops and leaves case/while alone, so skipping verification
# is safe here. The control-flow rewrite to SCF happens next via xDSL.
pm.enable_verifier(False)
pm.run(mlir_module.operation)
# --- Step 3: Print to generic MLIR text ------------------------------
generic_mlir = mlir_module.operation.get_asm(
print_generic_op_form=True
)
# --- Step 4: xDSL round-trip to rewrite remaining HLO control flow -------
# stablehlo.case/while carry Jasp quantum types which are fine in SCF
# but not in StableHLO. xDSL rewrites them to scf.if/while.
xdsl_module = _linalg_mlir_to_xdsl(generic_mlir)
fix_quantum_control_flow(xdsl_module)
# --- Step 5: Emit the final MLIR text ------------------------------------
return str(xdsl_module)
# ── Example usage ────────────────────────────────────────────────────────────
def adaptive_bell_state():
qf = QuantumFloat(3)
h(qf[0])
cx(qf[0], qf[1])
meas_res = measure(qf)
with control(meas_res == 0):
x(qf[2])
return measure(qf)
if __name__ == "__main__":
jaspr = make_jaspr(adaptive_bell_state)()
print("=" * 72)
print(" Jaspr → MLIR (linalg + arith + tensor + Jasp)")
print("=" * 72)
print(jaspr_to_linalg_mlir(jaspr)) |
|
We've been working on this with the Xanadu folks: https://github.com/xdslproject/xdsl-jax Am I right in understanding that you already have a custom frontend to StableHLO? If so, it might be worth your time to port the parts of lowering to linalg that you need for your use-case, which would let you drop the dependency. |
|
This is amazing, @positr0nium! Then we can work under the assumption that the IR reaching us has already been converted to @superlopuh: Not much of a frontend, we would simply call the stablehlo-to-linalg conversion pass on our side. As soon as we don't need to do that, we can drop the dependency. |
|
@julfarn yes, you can expect this to become part of Qrisp (possibly a kwarg of |
The second step towards #13.
Currently blocked by this PR in MQT core and subsequent update of the dependency version on our side.
Added a conversion pass from the
jasptoqcdialects. In addition, introduced a dependency onstablehloso that output fromqrispcan be handled.A few todos are still present in the code, but the most urgent features are implemented.
The remaining steps are now the assembly of a larger pipeline, lowering the classical side to MQT-compatible dialects as well, and thorough testing.