Skip to content

Commit

Permalink
pydrake math: Ensure that .multiply preserves input shape (#13886)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed Aug 19, 2020
1 parent 30ebf3c commit 1ea9bbf
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 5 deletions.
1 change: 1 addition & 0 deletions bindings/pydrake/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ drake_pybind_library(
"//bindings/pydrake/common:cpp_template_pybind",
"//bindings/pydrake/common:default_scalars_pybind",
"//bindings/pydrake/common:deprecation_pybind",
"//bindings/pydrake/common:eigen_pybind",
"//bindings/pydrake/common:type_pack",
"//bindings/pydrake/common:value_pybind",
],
Expand Down
1 change: 1 addition & 0 deletions bindings/pydrake/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ drake_pybind_library(
":cpp_template_pybind",
":default_scalars_pybind",
":eigen_geometry_pybind",
":eigen_pybind",
":type_pack",
":value_pybind",
],
Expand Down
30 changes: 30 additions & 0 deletions bindings/pydrake/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,31 @@
import collections
import functools
import inspect

import numpy as np

from ._module_py import *


def _wrap_to_match_input_shape(f):
# See docstring for `WrapToMatchInputShape` in `eigen_pybind.h` for more
# details.
# N.B. We cannot use `inspect.Signature` due to the fact that pybind11's
# instance method is not inspectable for overloads.
assert callable(f), f

@functools.wraps(f)
def wrapper(self, *args, **kwargs):
# Call the function first to permit it to raise the appropriate
# TypeError from pybind11 if the inputs are not correctly formatted.
out = f(self, *args, **kwargs)
if isinstance(out, np.ndarray):
arg_list = tuple(args) + tuple(kwargs.values())
assert len(arg_list) == 1
(arg,) = arg_list
in_shape = np.asarray(arg).shape
return out.reshape(in_shape)
else:
return out

return wrapper
5 changes: 5 additions & 0 deletions bindings/pydrake/common/eigen_geometry_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "drake/bindings/pydrake/common/cpp_template_pybind.h"
#include "drake/bindings/pydrake/common/default_scalars_pybind.h"
#include "drake/bindings/pydrake/common/eigen_geometry_pybind.h"
#include "drake/bindings/pydrake/common/eigen_pybind.h"
#include "drake/bindings/pydrake/common/type_pack.h"
#include "drake/bindings/pydrake/common/value_pybind.h"
#include "drake/bindings/pydrake/pydrake_pybind.h"
Expand Down Expand Up @@ -185,6 +186,7 @@ void DoScalarDependentDefinitions(py::module m, T) {
.def("inverse", [](const Class* self) { return self->inverse(); })
.def(py::pickle([](const Class& self) { return self.matrix(); },
[](const Matrix4<T>& matrix) { return Class(matrix); }));
cls.attr("multiply") = WrapToMatchInputShape(cls.attr("multiply"));
cls.attr("__matmul__") = cls.attr("multiply");
py::implicitly_convertible<Matrix4<T>, Class>();
DefCopyAndDeepCopy(&cls);
Expand Down Expand Up @@ -308,6 +310,7 @@ void DoScalarDependentDefinitions(py::module m, T) {
[py_class_obj](py::object wxyz) -> Class {
return py_class_obj(wxyz).cast<Class>();
}));
cls.attr("multiply") = WrapToMatchInputShape(cls.attr("multiply"));
cls.attr("__matmul__") = cls.attr("multiply");
DefCopyAndDeepCopy(&cls);
DefCast<T>(&cls, kCastDoc);
Expand Down Expand Up @@ -403,6 +406,8 @@ void DoScalarDependentDefinitions(py::module m, T) {
DRAKE_THROW_UNLESS(t.size() == 2);
return Class(t[0].cast<T>(), t[1].cast<Vector3<T>>());
}));
// N.B. This class does not support multiplication with vectors, so we do
// not use `WrapToMatchInputShape` here.
cls.attr("__matmul__") = cls.attr("multiply");
DefCopyAndDeepCopy(&cls);
DefCast<T>(&cls, kCastDoc);
Expand Down
38 changes: 33 additions & 5 deletions bindings/pydrake/common/eigen_pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ namespace drake {
namespace pydrake {

// TODO(eric.cousineau): Ensure that all C++ mutator call sites use `EigenPtr`.
/// Provides a mutable Ref<> for a pointer.
/// Meant to be used for decorating methods passed to `pybind11` (e.g. virtual
/// function dispatch).
/**
Provides a mutable Ref<> for a pointer.
Meant to be used for decorating methods passed to `pybind11` (e.g. virtual
function dispatch).
*/
template <typename Derived>
auto ToEigenRef(Eigen::VectorBlock<Derived>* derived) {
return Eigen::Ref<Derived>(*derived);
}

/// Converts a raw array to a numpy array.
/** Converts a raw array to a numpy array. */
template <typename T>
py::object ToArray(T* ptr, int size, py::tuple shape) {
// Create flat array to be reshaped in numpy.
Expand All @@ -27,7 +29,7 @@ py::object ToArray(T* ptr, int size, py::tuple shape) {
.attr("reshape")(shape);
}

/// Converts a raw array to a numpy array (`const` variant).
/** Converts a raw array to a numpy array (`const` variant). */
template <typename T>
py::object ToArray(const T* ptr, int size, py::tuple shape) {
// Create flat array to be reshaped in numpy.
Expand All @@ -37,5 +39,31 @@ py::object ToArray(const T* ptr, int size, py::tuple shape) {
.attr("reshape")(shape);
}

/**
Wraps a overload instance method to reshape the output to be the same as a
given input argument. The input should be the first and only argument to
trigger reshaping.
This preserves the original docstrings so that they still indicate the shapes
of the input and output arrays.
Example:
@code
cls // BR
.def("multiply", [](const Class& self, const Class& other) { ... })
.def("multiply", [](const Class& self, const Vector3<T>& p) { ... })
.def("multiply", [](const Class& self, const Matrix3X<T>& plist) { ... });
cls.attr("multiply") = WrapToMatchInputShape(cls.attr("multiply"));
@endcode
@sa @ref PydrakeReturnVectorsOrMatrices
*/
inline py::object WrapToMatchInputShape(py::handle func) {
py::handle wrap =
py::module::import("pydrake.common").attr("_wrap_to_match_input_shape");
return wrap(func);
}

} // namespace pydrake
} // namespace drake
15 changes: 15 additions & 0 deletions bindings/pydrake/common/test/eigen_geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def test_quaternion(self, T):
numpy_compare.assert_float_equal(
q_I.slerp(t=0, other=q_I).wxyz(), [1., 0, 0, 0])

# - Test shaping (#13885).
v = np.array([0., 0., 0.])
vs = np.array([[1., 2., 3.], [4., 5., 6.]]).T
self.assertEqual((q_AB @ v).shape, (3,))
self.assertEqual((q_AB @ v.reshape((3, 1))).shape, (3, 1))
self.assertEqual((q_AB @ vs).shape, (3, 2))

# Test `type_caster`s.
if T == float:
value = test_util.create_quaternion()
Expand Down Expand Up @@ -203,6 +210,12 @@ def test_isometry3(self, T):
numpy_compare.assert_float_equal(
(X_AB.inverse() @ X_AB).matrix(), X_I_np)
numpy_compare.assert_float_equal(X_AB @ p_BQ, p_AQ)
# - Test shaping (#13885).
v = np.array([0., 0., 0.])
vs = np.array([[1., 2., 3.], [4., 5., 6.]]).T
self.assertEqual((X_AB @ v).shape, (3,))
self.assertEqual((X_AB @ v.reshape((3, 1))).shape, (3, 1))
self.assertEqual((X_AB @ vs).shape, (3, 2))

assert_pickle(self, X_AB, Isometry3.matrix, T=T)

Expand Down Expand Up @@ -271,6 +284,8 @@ def test_angle_axis(self, T):
numpy_compare.assert_equal(value.angle(), -value_sym.angle())
numpy_compare.assert_equal(value.axis(), -value_sym.axis())

# N.B. AngleAxis does not support multiplication with vectors, so we
# need not test it here.
def get_vector(value):
return np.hstack((value.angle(), value.axis()))

Expand Down
3 changes: 3 additions & 0 deletions bindings/pydrake/math_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "drake/bindings/pydrake/common/cpp_template_pybind.h"
#include "drake/bindings/pydrake/common/default_scalars_pybind.h"
#include "drake/bindings/pydrake/common/deprecation_pybind.h"
#include "drake/bindings/pydrake/common/eigen_pybind.h"
#include "drake/bindings/pydrake/common/type_pack.h"
#include "drake/bindings/pydrake/common/value_pybind.h"
#include "drake/bindings/pydrake/documentation_pybind.h"
Expand Down Expand Up @@ -133,6 +134,7 @@ void DoScalarDependentDefinitions(py::module m, T) {
[](const Eigen::Matrix<T, 3, 4>& matrix) {
return Class(matrix);
}));
cls.attr("multiply") = WrapToMatchInputShape(cls.attr("multiply"));
cls.attr("__matmul__") = cls.attr("multiply");
DefCopyAndDeepCopy(&cls);
DefCast<T>(&cls, cls_doc.cast.doc);
Expand Down Expand Up @@ -207,6 +209,7 @@ void DoScalarDependentDefinitions(py::module m, T) {
.def("ToAngleAxis", &Class::ToAngleAxis, cls_doc.ToAngleAxis.doc)
.def(py::pickle([](const Class& self) { return self.matrix(); },
[](const Matrix3<T>& matrix) { return Class(matrix); }));
cls.attr("multiply") = WrapToMatchInputShape(cls.attr("multiply"));
cls.attr("__matmul__") = cls.attr("multiply");
DefCopyAndDeepCopy(&cls);
DefCast<T>(&cls, cls_doc.cast.doc);
Expand Down
26 changes: 26 additions & 0 deletions bindings/pydrake/pydrake_doxygen.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,32 @@ py::handle abstract_value_cls =
...
```
### Matrix-multiplication-like Methods
For objects that may be represented by matrices or vectors (e.g.
RigidTransform, RotationMatrix), the `*` operator (via `__mul__`) should *not*
be bound because the `*` operator in NumPy implies elemnt-wise multiplication
for arrays.
For simplicity, we instead bind the explicitly named `.multiply()` method, and
alias the `__matmul__` operator `@` to this function.
@anchor PydrakeReturnVectorsOrMatrices
#### Returning Vectors or Matrices
Certain bound methods, like `RigidTransform.multiply()`, will have overloads
that can multiply and return (a) other `RigidTransform` instances, (b) vectors,
or (c) matrices (representing a list of vectors).
In the cases of (a) and (c), `pybind11` provides sufficient mechanisms to
provide an unambiguous output return type. However, for (b), `pybind11` will
return `ndarray` with shape `(3,)`. This can cause an issue when users pass
a vector of shape `(3, 1)` as input. Nominally, pybind11 will return a `(3,)`
array, but the user may expect `(3, 1)` as an output. To accommodate this, you
should use the drake::pydrake::WrapToMatchInputShape function.
@sa https://github.com/RobotLocomotion/drake/issues/13885
## Python Subclassing of C++ Classes
In general, minimize the amount in which users may subclass C++ classes in
Expand Down
13 changes: 13 additions & 0 deletions bindings/pydrake/test/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def check_equality(X_actual, X_expected_matrix):
X.multiply(other=RigidTransform()), RigidTransform)
self.assertIsInstance(X @ RigidTransform(), RigidTransform)
self.assertIsInstance(X @ [0, 0, 0], np.ndarray)
# - Test shaping (#13885).
v = np.array([0., 0., 0.])
vs = np.array([[1., 2., 3.], [4., 5., 6.]]).T
self.assertEqual((X @ v).shape, (3,))
self.assertEqual((X @ v.reshape((3, 1))).shape, (3, 1))
self.assertEqual((X @ vs).shape, (3, 2))
print(help(RigidTransform.multiply))
# - Test vector multiplication.
R_AB = RotationMatrix([
[0., 1, 0],
Expand Down Expand Up @@ -273,6 +280,12 @@ def test_rotation_matrix(self, T):
vlist_B = np.array([v_B, v_B]).T
vlist_A = np.array([v_A, v_A]).T
numpy_compare.assert_float_equal(R_AB.multiply(v_B=vlist_B), vlist_A)
# - Test shaping (#13885).
v = np.array([0., 0., 0.])
vs = np.array([[1., 2., 3.], [4., 5., 6.]]).T
self.assertEqual((R_AB @ v).shape, (3,))
self.assertEqual((R_AB @ v.reshape((3, 1))).shape, (3, 1))
self.assertEqual((R_AB @ vs).shape, (3, 2))
# Matrix checks
numpy_compare.assert_equal(R.IsValid(), True)
R = RotationMatrix()
Expand Down

0 comments on commit 1ea9bbf

Please sign in to comment.