Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Split jax_extras into several modules #505

Merged
merged 13 commits into from
Feb 23, 2024
2 changes: 1 addition & 1 deletion frontend/catalyst/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@

import catalyst
from catalyst.ag_utils import AutoGraphError
from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.utils.contexts import EvaluationContext
from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.utils.patching import Patcher

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/compilation_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import catalyst
from catalyst.ag_utils import run_autograph
from catalyst.compiler import CompileOptions, Compiler
from catalyst.jax_extras import get_aval2, get_implicit_and_explicit_flat_args
from catalyst.jax_tracer import trace_to_mlir
from catalyst.pennylane_extensions import QFunc
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
Expand All @@ -52,7 +53,6 @@
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import WorkspaceManager
from catalyst.utils.gen_mlir import inject_functions
from catalyst.utils.jax_extras import get_aval2, get_implicit_and_explicit_flat_args
from catalyst.utils.patching import Patcher

# Required for JAX tracer objects as PennyLane wires.
Expand Down
55 changes: 55 additions & 0 deletions frontend/catalyst/jax_extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Catalyst additions to the Jax library """

from catalyst.jax_extras.lowering import custom_lower_jaxpr_to_module, jaxpr_to_mlir
from catalyst.jax_extras.patches import (
_gather_shape_rule_dynamic,
_no_clean_up_dead_vars,
get_aval2,
)
from catalyst.jax_extras.tracing import (
ClosedJaxpr,
DynamicJaxprTrace,
DynamicJaxprTracer,
Jaxpr,
PyTreeDef,
PyTreeRegistry,
ShapedArray,
ShapeDtypeStruct,
_abstractify,
_extract_implicit_args,
_initial_style_jaxpr,
_input_type_to_tracers,
convert_constvars_jaxpr,
convert_element_type,
deduce_avals,
eval_jaxpr,
get_implicit_and_explicit_flat_args,
infer_lambda_input_type,
initial_style_jaxprs_with_common_consts1,
initial_style_jaxprs_with_common_consts2,
jaxpr_remove_implicit,
make_jaxpr2,
make_jaxpr_effects,
new_dynamic_main2,
new_inner_tracer,
sort_eqns,
tree_flatten,
tree_structure,
tree_unflatten,
treedef_is_leaf,
unzip2,
wrap_init,
)
144 changes: 144 additions & 0 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Jax extras module containing functions related to the StableHLO lowering """

from __future__ import annotations

import jax
from jax._src.dispatch import jaxpr_replicas
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.lax.lax import xla
from jax._src.sharding_impls import ReplicaAxisContext
from jax._src.source_info_util import new_name_stack
from jax._src.util import wrap_name
from jax.core import ClosedJaxpr
from jax.interpreters.mlir import (
AxisContext,
ModuleContext,
ir,
lower_jaxpr_to_fun,
lowerable_effects,
)

from catalyst.utils.patching import Patcher

# pylint: disable=protected-access

__all__ = ("jaxpr_to_mlir", "custom_lower_jaxpr_to_module")

from catalyst.jax_extras.patches import _no_clean_up_dead_vars, get_aval2


def jaxpr_to_mlir(func_name, jaxpr):
"""Lower a Jaxpr into an MLIR module.

Args:
func_name(str): function name
jaxpr(Jaxpr): Jaxpr code to lower

Returns:
module: the MLIR module corresponding to ``func``
context: the MLIR context corresponding
"""

with Patcher(
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
(jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars),
):
nrep = jaxpr_replicas(jaxpr)
effects = jax_ordered_effects.filter_in(jaxpr.effects)
axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ()))
name_stack = new_name_stack(wrap_name("ok", "jit"))
module, context = custom_lower_jaxpr_to_module(
func_name="jit_" + func_name,
module_name=func_name,
jaxpr=jaxpr,
effects=effects,
platform="cpu",
axis_context=axis_context,
name_stack=name_stack,
)

return module, context


# pylint: disable=too-many-arguments
def custom_lower_jaxpr_to_module(
func_name: str,
module_name: str,
jaxpr: ClosedJaxpr,
effects,
platform: str,
axis_context: AxisContext,
name_stack,
replicated_args=None,
arg_shardings=None,
result_shardings=None,
):
"""Lowers a top-level jaxpr to an MHLO module.

Handles the quirks of the argument/return value passing conventions of the
runtime.

This function has been modified from its original form in the JAX project at
https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version
released under the Apache License, Version 2.0, with the following copyright notice:

Copyright 2021 The JAX Authors.
"""

if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover
raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}")

assert platform == "cpu"
assert arg_shardings is None
assert result_shardings is None

# MHLO channels need to start at 1
channel_iter = 1
# Create a keepalives list that will be mutated during the lowering.
keepalives = []
host_callbacks = []
ctx = ModuleContext(
None, platform, axis_context, name_stack, keepalives, channel_iter, host_callbacks
)
ctx.context.allow_unregistered_dialects = True
with ctx.context, ir.Location.unknown(ctx.context):
# register_dialect()
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
lower_jaxpr_to_fun(
ctx,
func_name,
jaxpr,
effects,
public=True,
create_tokens=True,
replace_tokens_with_dummy=True,
replicated_args=replicated_args,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
)

for op in ctx.module.body.operations:
func_name = str(op.name)
is_entry_point = func_name.startswith('"jit_')
if is_entry_point:
continue
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")

return ctx.module, ctx.context
Loading
Loading