Skip to content

Commit

Permalink
[Frontend] Split jax_extras into several modules (#505)
Browse files Browse the repository at this point in the history
This is an utility PR split `./utils/jax_extras.py` into three
sub-modules. The goal is to make a complex
#370 PR simpler.

---------

Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
  • Loading branch information
2 people authored and rauletorresc committed Feb 26, 2024
1 parent fd54689 commit 3b6f3ac
Show file tree
Hide file tree
Showing 11 changed files with 443 additions and 333 deletions.
2 changes: 1 addition & 1 deletion frontend/catalyst/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@

import catalyst
from catalyst.ag_utils import AutoGraphError
from catalyst.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.tracing.contexts import EvaluationContext
from catalyst.utils.jax_extras import DynamicJaxprTracer, ShapedArray
from catalyst.utils.patching import Patcher

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/compiled_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
make_zero_d_memref_descriptor,
)

from catalyst.jax_extras import get_implicit_and_explicit_flat_args
from catalyst.tracing.type_signatures import filter_static_args
from catalyst.utils import wrapper # pylint: disable=no-name-in-module
from catalyst.utils.c_template import get_template, mlir_type_to_numpy_type
from catalyst.utils.jax_extras import get_implicit_and_explicit_flat_args


class SharedObjectManager:
Expand Down
56 changes: 56 additions & 0 deletions frontend/catalyst/jax_extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Catalyst additions to the Jax library """

from catalyst.jax_extras.lowering import custom_lower_jaxpr_to_module, jaxpr_to_mlir
from catalyst.jax_extras.patches import (
_gather_shape_rule_dynamic,
_no_clean_up_dead_vars,
get_aval2,
)
from catalyst.jax_extras.tracing import (
ClosedJaxpr,
DynamicJaxprTrace,
DynamicJaxprTracer,
Jaxpr,
PyTreeDef,
PyTreeRegistry,
ShapedArray,
ShapeDtypeStruct,
_abstractify,
_extract_implicit_args,
_initial_style_jaxpr,
_input_type_to_tracers,
convert_constvars_jaxpr,
convert_element_type,
deduce_avals,
eval_jaxpr,
get_implicit_and_explicit_flat_args,
infer_lambda_input_type,
initial_style_jaxprs_with_common_consts1,
initial_style_jaxprs_with_common_consts2,
jaxpr_remove_implicit,
make_jaxpr2,
make_jaxpr_effects,
new_dynamic_main2,
new_inner_tracer,
sort_eqns,
transient_jax_config,
tree_flatten,
tree_structure,
tree_unflatten,
treedef_is_leaf,
unzip2,
wrap_init,
)
153 changes: 153 additions & 0 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Jax extras module containing functions related to the StableHLO lowering """

from __future__ import annotations

import jax
from jax._src.dispatch import jaxpr_replicas
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.lax.lax import xla
from jax._src.sharding_impls import ReplicaAxisContext
from jax._src.source_info_util import new_name_stack
from jax._src.util import wrap_name
from jax.core import ClosedJaxpr
from jax.interpreters.mlir import (
AxisContext,
LoweringParameters,
ModuleContext,
ir,
lower_jaxpr_to_fun,
lowerable_effects,
)

from catalyst.utils.patching import Patcher

# pylint: disable=protected-access

__all__ = ("jaxpr_to_mlir", "custom_lower_jaxpr_to_module")

from catalyst.jax_extras.patches import _no_clean_up_dead_vars, get_aval2


def jaxpr_to_mlir(func_name, jaxpr):
"""Lower a Jaxpr into an MLIR module.
Args:
func_name(str): function name
jaxpr(Jaxpr): Jaxpr code to lower
Returns:
module: the MLIR module corresponding to ``func``
context: the MLIR context corresponding
"""

with Patcher(
(jax._src.interpreters.partial_eval, "get_aval", get_aval2),
(jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars),
):
nrep = jaxpr_replicas(jaxpr)
effects = jax_ordered_effects.filter_in(jaxpr.effects)
axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ()))
name_stack = new_name_stack(wrap_name("ok", "jit"))
module, context = custom_lower_jaxpr_to_module(
func_name="jit_" + func_name,
module_name=func_name,
jaxpr=jaxpr,
effects=effects,
platform="cpu",
axis_context=axis_context,
name_stack=name_stack,
)

return module, context


# pylint: disable=too-many-arguments
def custom_lower_jaxpr_to_module(
func_name: str,
module_name: str,
jaxpr: ClosedJaxpr,
effects,
platform: str,
axis_context: AxisContext,
name_stack,
replicated_args=None,
arg_shardings=None,
result_shardings=None,
):
"""Lowers a top-level jaxpr to an MHLO module.
Handles the quirks of the argument/return value passing conventions of the
runtime.
This function has been modified from its original form in the JAX project at
https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version
released under the Apache License, Version 2.0, with the following copyright notice:
Copyright 2021 The JAX Authors.
"""

if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover
raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}")

assert platform == "cpu"
assert arg_shardings is None
assert result_shardings is None

# MHLO channels need to start at 1
channel_iter = 1
# Create a keepalives list that will be mutated during the lowering.
keepalives = []
host_callbacks = []
lowering_params = LoweringParameters()
ctx = ModuleContext(
backend_or_name=None,
platforms=[platform],
axis_context=axis_context,
name_stack=name_stack,
keepalives=keepalives,
channel_iterator=channel_iter,
host_callbacks=host_callbacks,
lowering_parameters=lowering_params,
)
ctx.context.allow_unregistered_dialects = True
with ctx.context, ir.Location.unknown(ctx.context):
# register_dialect()
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)
lower_jaxpr_to_fun(
ctx,
func_name,
jaxpr,
effects,
public=True,
create_tokens=True,
replace_tokens_with_dummy=True,
replicated_args=replicated_args,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
)

for op in ctx.module.body.operations:
func_name = str(op.name)
is_entry_point = func_name.startswith('"jit_')
if is_entry_point:
continue
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")

return ctx.module, ctx.context
Loading

0 comments on commit 3b6f3ac

Please sign in to comment.