-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Frontend] Split jax_extras into several modules (#505)
This is an utility PR split `./utils/jax_extras.py` into three sub-modules. The goal is to make a complex #370 PR simpler. --------- Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
- Loading branch information
1 parent
fd54689
commit 3b6f3ac
Showing
11 changed files
with
443 additions
and
333 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright 2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" Catalyst additions to the Jax library """ | ||
|
||
from catalyst.jax_extras.lowering import custom_lower_jaxpr_to_module, jaxpr_to_mlir | ||
from catalyst.jax_extras.patches import ( | ||
_gather_shape_rule_dynamic, | ||
_no_clean_up_dead_vars, | ||
get_aval2, | ||
) | ||
from catalyst.jax_extras.tracing import ( | ||
ClosedJaxpr, | ||
DynamicJaxprTrace, | ||
DynamicJaxprTracer, | ||
Jaxpr, | ||
PyTreeDef, | ||
PyTreeRegistry, | ||
ShapedArray, | ||
ShapeDtypeStruct, | ||
_abstractify, | ||
_extract_implicit_args, | ||
_initial_style_jaxpr, | ||
_input_type_to_tracers, | ||
convert_constvars_jaxpr, | ||
convert_element_type, | ||
deduce_avals, | ||
eval_jaxpr, | ||
get_implicit_and_explicit_flat_args, | ||
infer_lambda_input_type, | ||
initial_style_jaxprs_with_common_consts1, | ||
initial_style_jaxprs_with_common_consts2, | ||
jaxpr_remove_implicit, | ||
make_jaxpr2, | ||
make_jaxpr_effects, | ||
new_dynamic_main2, | ||
new_inner_tracer, | ||
sort_eqns, | ||
transient_jax_config, | ||
tree_flatten, | ||
tree_structure, | ||
tree_unflatten, | ||
treedef_is_leaf, | ||
unzip2, | ||
wrap_init, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" Jax extras module containing functions related to the StableHLO lowering """ | ||
|
||
from __future__ import annotations | ||
|
||
import jax | ||
from jax._src.dispatch import jaxpr_replicas | ||
from jax._src.effects import ordered_effects as jax_ordered_effects | ||
from jax._src.interpreters.mlir import _module_name_regex | ||
from jax._src.lax.lax import xla | ||
from jax._src.sharding_impls import ReplicaAxisContext | ||
from jax._src.source_info_util import new_name_stack | ||
from jax._src.util import wrap_name | ||
from jax.core import ClosedJaxpr | ||
from jax.interpreters.mlir import ( | ||
AxisContext, | ||
LoweringParameters, | ||
ModuleContext, | ||
ir, | ||
lower_jaxpr_to_fun, | ||
lowerable_effects, | ||
) | ||
|
||
from catalyst.utils.patching import Patcher | ||
|
||
# pylint: disable=protected-access | ||
|
||
__all__ = ("jaxpr_to_mlir", "custom_lower_jaxpr_to_module") | ||
|
||
from catalyst.jax_extras.patches import _no_clean_up_dead_vars, get_aval2 | ||
|
||
|
||
def jaxpr_to_mlir(func_name, jaxpr): | ||
"""Lower a Jaxpr into an MLIR module. | ||
Args: | ||
func_name(str): function name | ||
jaxpr(Jaxpr): Jaxpr code to lower | ||
Returns: | ||
module: the MLIR module corresponding to ``func`` | ||
context: the MLIR context corresponding | ||
""" | ||
|
||
with Patcher( | ||
(jax._src.interpreters.partial_eval, "get_aval", get_aval2), | ||
(jax._src.core, "clean_up_dead_vars", _no_clean_up_dead_vars), | ||
): | ||
nrep = jaxpr_replicas(jaxpr) | ||
effects = jax_ordered_effects.filter_in(jaxpr.effects) | ||
axis_context = ReplicaAxisContext(xla.AxisEnv(nrep, (), ())) | ||
name_stack = new_name_stack(wrap_name("ok", "jit")) | ||
module, context = custom_lower_jaxpr_to_module( | ||
func_name="jit_" + func_name, | ||
module_name=func_name, | ||
jaxpr=jaxpr, | ||
effects=effects, | ||
platform="cpu", | ||
axis_context=axis_context, | ||
name_stack=name_stack, | ||
) | ||
|
||
return module, context | ||
|
||
|
||
# pylint: disable=too-many-arguments | ||
def custom_lower_jaxpr_to_module( | ||
func_name: str, | ||
module_name: str, | ||
jaxpr: ClosedJaxpr, | ||
effects, | ||
platform: str, | ||
axis_context: AxisContext, | ||
name_stack, | ||
replicated_args=None, | ||
arg_shardings=None, | ||
result_shardings=None, | ||
): | ||
"""Lowers a top-level jaxpr to an MHLO module. | ||
Handles the quirks of the argument/return value passing conventions of the | ||
runtime. | ||
This function has been modified from its original form in the JAX project at | ||
https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version | ||
released under the Apache License, Version 2.0, with the following copyright notice: | ||
Copyright 2021 The JAX Authors. | ||
""" | ||
|
||
if any(lowerable_effects.filter_not_in(jaxpr.effects)): # pragma: no cover | ||
raise ValueError(f"Cannot lower jaxpr with effects: {jaxpr.effects}") | ||
|
||
assert platform == "cpu" | ||
assert arg_shardings is None | ||
assert result_shardings is None | ||
|
||
# MHLO channels need to start at 1 | ||
channel_iter = 1 | ||
# Create a keepalives list that will be mutated during the lowering. | ||
keepalives = [] | ||
host_callbacks = [] | ||
lowering_params = LoweringParameters() | ||
ctx = ModuleContext( | ||
backend_or_name=None, | ||
platforms=[platform], | ||
axis_context=axis_context, | ||
name_stack=name_stack, | ||
keepalives=keepalives, | ||
channel_iterator=channel_iter, | ||
host_callbacks=host_callbacks, | ||
lowering_parameters=lowering_params, | ||
) | ||
ctx.context.allow_unregistered_dialects = True | ||
with ctx.context, ir.Location.unknown(ctx.context): | ||
# register_dialect() | ||
# Remove module name characters that XLA would alter. This ensures that | ||
# XLA computation preserves the module name. | ||
module_name = _module_name_regex.sub("_", module_name) | ||
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) | ||
lower_jaxpr_to_fun( | ||
ctx, | ||
func_name, | ||
jaxpr, | ||
effects, | ||
public=True, | ||
create_tokens=True, | ||
replace_tokens_with_dummy=True, | ||
replicated_args=replicated_args, | ||
arg_shardings=arg_shardings, | ||
result_shardings=result_shardings, | ||
) | ||
|
||
for op in ctx.module.body.operations: | ||
func_name = str(op.name) | ||
is_entry_point = func_name.startswith('"jit_') | ||
if is_entry_point: | ||
continue | ||
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>") | ||
|
||
return ctx.module, ctx.context |
Oops, something went wrong.