Skip to content

Commit

Permalink
py math: Support pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed Aug 28, 2019
1 parent 9ffc3e9 commit 5f51039
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
9 changes: 9 additions & 0 deletions bindings/pydrake/math_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ void DoScalarDependentDefinitions(py::module m, T) {
doc_rigid_transform_linear_matrix_deprecation)
.def("linear", &RigidTransform<T>::linear, py_reference_internal,
doc_rigid_transform_linear_matrix_deprecation);
DefPickle(&cls,
[](const Class& self) { return self.GetAsMatrix34(); },
[](const Eigen::Matrix<T, 3, 4>& matrix) { return Class(matrix); });
cls.attr("__matmul__") = cls.attr("multiply");
DefCopyAndDeepCopy(&cls);
DefCast<T>(&cls, cls_doc.cast.doc);
Expand Down Expand Up @@ -202,6 +205,9 @@ void DoScalarDependentDefinitions(py::module m, T) {
.def("ToQuaternion",
overload_cast_explicit<Eigen::Quaternion<T>>(&Class::ToQuaternion),
cls_doc.ToQuaternion.doc_0args);
DefPickle(&cls,
[](const Class& self) { return self.matrix(); },
[](const Matrix3<T>& matrix) { return Class(matrix); });
cls.attr("__matmul__") = cls.attr("multiply");
DefCopyAndDeepCopy(&cls);
DefCast<T>(&cls, cls_doc.cast.doc);
Expand Down Expand Up @@ -255,6 +261,9 @@ void DoScalarDependentDefinitions(py::module m, T) {
&Class::CalcRpyDDtFromAngularAccelInChild, py::arg("rpyDt"),
py::arg("alpha_AD_D"),
cls_doc.CalcRpyDDtFromAngularAccelInChild.doc);
DefPickle(&cls,
[](const Class& self) { return self.vector(); },
[](const Vector3<T>& rpy) { return Class(rpy); });
DefCopyAndDeepCopy(&cls);
// N.B. `RollPitchYaw::cast` is not defined in C++.
}
Expand Down
21 changes: 21 additions & 0 deletions bindings/pydrake/test/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import pydrake.common.test_utilities.numpy_compare as numpy_compare

import copy
import pickle
from io import BytesIO
import math
import unittest

Expand Down Expand Up @@ -107,6 +109,23 @@ def check_cast(self, template, T):
for U in U_list:
self.assertIsInstance(value.cast[U](), template[U], U)

def check_pickle(self, T, input, value_to_compare):
if six.PY2:
# Pickling explicitly disabled in Python 2.
with self.assertRaises(RuntimeError) as cm:
pickle.dump(input, BytesIO())
return
if T == Expression:
# Pickling not enabled for Expression.
return
f = BytesIO()
pickle.dump(input, f)
f.seek(0)
output = pickle.load(f)
input_value = value_to_compare(input)
output_value = value_to_compare(output)
numpy_compare.assert_equal(input_value, output_value)

@numpy_compare.check_all_types
def test_rigid_transform(self, T):
RigidTransform = mut.RigidTransform_[T]
Expand Down Expand Up @@ -184,6 +203,8 @@ def check_equality(X_actual, X_expected_matrix):
p_AQlist = np.array([p_AQ, p_AQ]).T
numpy_compare.assert_float_equal(
X_AB.multiply(p_BoQ_B=p_BQlist), p_AQlist)
# Test pickling.
self.check_pickle(T, X_AB, RigidTransform.GetAsMatrix4)

@numpy_compare.check_all_types
def test_isometry_implicit(self, T):
Expand Down

0 comments on commit 5f51039

Please sign in to comment.