Skip to content

Commit

Permalink
[oop] Fix application of wrappers cascade in OOP decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 5, 2023
1 parent bf00399 commit 23b8899
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/jaxsim/utils/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def wrapper(*args, **kwargs):
# computed by the functional transformation.
# We do so by iterating over the fields of the jax_dataclasses and ignoring
# all the fields that are marked as static.
# Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121.
with instance.mutable_context(
mutability=mutability_dict[instance._mutability()]
):
Expand Down Expand Up @@ -300,7 +301,7 @@ def wrap_fn(
f"Static argument '{arg_name}' cannot be mapped with vmap"
)

def fn_tf_vmap(function_to_vmap: Callable, *args, **kwargs):
def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs):
"""Wrapper applying the vmap transformation"""

# Canonicalize the arguments so that all of them are kwargs
Expand Down Expand Up @@ -340,7 +341,9 @@ def fn_tf_vmap(function_to_vmap: Callable, *args, **kwargs):
}

# Close the function over the unmapped arguments of vmap
fn_closed = functools.partial(function_to_vmap, **vmap_unmapped_args)
fn_closed = lambda *mapped_args: function_to_vmap(
**vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args))
)

# Create the in_axes tuple of only the mapped arguments
in_axes_mapped = tuple(
Expand All @@ -365,17 +368,17 @@ def fn_tf_vmap(function_to_vmap: Callable, *args, **kwargs):

# Apply the vmap transformation and call the function passing only the
# mapped arguments. The unmapped arguments have been closed over.
# Note: that we altered the "in_axes" tuple so that it does not have any
# Note: we altered the "in_axes" tuple so that it does not have any
# None elements.
# Note: if in_axes_mapped is a tuple, the following fails if we pass kwargs,
# Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs,
# we need to pass the unpacked args tuple instead.
return jax.vmap(
fn_closed,
in_axes=in_axes_mapped,
**dict(out_axes=out_axes) if out_axes is not None else {},
)(*list(vmap_mapped_args.values()))

def fn_tf_jit(function_to_jit: Callable, *args, **kwargs):
def fn_tf_jit(*args, function_to_jit: Callable, **kwargs):
"""Wrapper applying the jit transformation"""

# Canonicalize the arguments so that all of them are kwargs
Expand All @@ -390,20 +393,31 @@ def fn_tf_jit(function_to_jit: Callable, *args, **kwargs):

# First applied wrapper that executes fn in a mutable context
fn_mutable = functools.partial(
jax_tf.call_class_method_in_mutable_context, fn, jit, mutability
jax_tf.call_class_method_in_mutable_context,
fn=fn,
jit=jit,
mutability=mutability,
)

# Second applied wrapper that transforms fn with vmap
fn_vmap = fn_mutable if not vmap else functools.partial(fn_tf_vmap, fn_mutable)
fn_vmap = (
fn_mutable
if not vmap
else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable)
)

# Third applied wrapper that transforms fn with jit
fn_jit_vmap = fn_vmap if not jit else functools.partial(fn_tf_jit, fn_vmap)
fn_jit_vmap = (
fn_vmap
if not jit
else functools.partial(fn_tf_jit, function_to_jit=fn_vmap)
)

return fn_jit_vmap

@staticmethod
def call_class_method_in_mutable_context(
fn: Callable, jit: bool, mutability: Mutability, *args, **kwargs
*args, fn: Callable, jit: bool, mutability: Mutability, **kwargs
) -> tuple[Any, Vmappable]:
"""
Wrapper to call a method on an object with the desired mutable context.
Expand Down

0 comments on commit 23b8899

Please sign in to comment.