Skip to content

Commit

Permalink
[oop] Disable vmapping default (scalar) arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 5, 2023
1 parent 23b8899 commit 872d4e7
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/jaxsim/utils/oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,12 @@ def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs):
case _:
raise ValueError(in_axes)

# Disable mapping of non-vectorized default arguments
for arg, value in argname_to_mapped_axis.items():
if value == sig.parameters[arg].default:
logging.debug(f"Disabling vmapping of default argument '{arg}'")
argname_to_mapped_axis[arg] = None

# Build a dictionary (argument_name -> argument) for all mapped arguments.
# Note that a mapped argument is an argument whose axis is not None and
# is not a static jit argument.
Expand Down

0 comments on commit 872d4e7

Please sign in to comment.