-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
linear_operator.h
163 lines (140 loc) · 6.35 KB
/
linear_operator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#pragma once
#include <exception>
#include <stdexcept>
#include <string>
#include <utility>

#include <Eigen/SparseCore>

#include "drake/common/default_scalars.h"
#include "drake/common/drake_assert.h"
#include "drake/common/drake_copyable.h"
#include "drake/common/eigen_types.h"
#include "drake/common/nice_type_name.h"
namespace drake {
namespace multibody {
namespace contact_solvers {
namespace internal {
// This abstract class provides a generic interface for linear operators
// A ∈ ℝⁿˣᵐ defined by their application from ℝᵐ into ℝⁿ, y = A⋅x.
// Derived classes must provide an implementation for this application
// with specifics that exploit the operator's structure, e.g. sparsity, block
// diagonal, etc.
//
// Since most solvers will only need the multiplication operator, subclasses
// are required to implement this operation in DoMultiply() for both dense and
// sparse multiplies.
//
// Some operators, typically contact Jacobians for instance, do require
// additional operations such as multiplication by their transpose. This will
// generally be a requirement documented by specific solvers but not enforced
// by this class. Therefore, the implementation of that operation in
// DoMultiplyByTranspose() is optional, with the default implementation
// provided by this class throwing a std::runtime_error. The same holds for
// DoAssembleMatrix().
//
// The public entry points follow the non-virtual interface (NVI) pattern:
// they validate their arguments with DRAKE_DEMAND and then dispatch to the
// corresponding protected Do*() virtual, so overrides never need to repeat
// those checks.
//
// @tparam_nonsymbolic_scalar
template <typename T>
class LinearOperator {
 public:
  DRAKE_NO_COPY_NO_MOVE_NO_ASSIGN(LinearOperator)

  // Creates an operator with a given `name`. The name is reported, together
  // with the concrete type, in the error raised when an optional operation is
  // not implemented by a derived class.
  explicit LinearOperator(std::string name) : name_(std::move(name)) {}

  virtual ~LinearOperator() = default;

  // Returns the name given at construction.
  const std::string& name() const { return name_; }

  // Number of rows n of the operator A ∈ ℝⁿˣᵐ, i.e. the size of y in y = A⋅x.
  virtual int rows() const = 0;

  // Number of columns m of the operator A ∈ ℝⁿˣᵐ, i.e. the size of x in
  // y = A⋅x.
  virtual int cols() const = 0;

  // Performs y = A⋅x for `this` operator A.
  // Derived classes must provide an implementation of the virtual interface
  // DoMultiply().
  // @pre y is not nullptr (enforced with DRAKE_DEMAND).
  // @pre x.size() == this->cols() and y->size() == this->rows() (enforced
  //      with DRAKE_DEMAND).
  void Multiply(const Eigen::Ref<const Eigen::SparseVector<T>>& x,
                Eigen::SparseVector<T>* y) const {
    DRAKE_DEMAND(y != nullptr);
    DRAKE_DEMAND(x.size() == cols());
    DRAKE_DEMAND(y->size() == rows());
    DoMultiply(x, y);
  }

  // Alternative signature that operates on dense vectors. Same preconditions
  // as the sparse overload.
  void Multiply(const Eigen::Ref<const VectorX<T>>& x, VectorX<T>* y) const {
    DRAKE_DEMAND(y != nullptr);
    DRAKE_DEMAND(x.size() == cols());
    DRAKE_DEMAND(y->size() == rows());
    DoMultiply(x, y);
  }

  // For `this` operator A, performs y = Aᵀ⋅x.
  // The default implementation throws a std::runtime_error exception.
  // Derived classes can provide an implementation through the virtual
  // interface DoMultiplyByTranspose().
  // @pre y is not nullptr (enforced with DRAKE_DEMAND).
  // @pre x.size() == this->rows() and y->size() == this->cols() (enforced
  //      with DRAKE_DEMAND).
  // @throws std::runtime_error if the derived class does not override
  //         DoMultiplyByTranspose().
  void MultiplyByTranspose(const Eigen::Ref<const Eigen::SparseVector<T>>& x,
                           Eigen::SparseVector<T>* y) const {
    DRAKE_DEMAND(y != nullptr);
    DRAKE_DEMAND(x.size() == rows());
    DRAKE_DEMAND(y->size() == cols());
    // NOTE: DoMultiplyByTranspose() takes a concrete SparseVector rather than
    // an Eigen::Ref, so `x` is converted (copied) into a temporary here. The
    // virtual signature is kept as-is because derived classes override it.
    DoMultiplyByTranspose(x, y);
  }

  // Alternative signature that operates on dense vectors. Same preconditions
  // and default behavior as the sparse overload.
  void MultiplyByTranspose(const Eigen::Ref<const VectorX<T>>& x,
                           VectorX<T>* y) const {
    DRAKE_DEMAND(y != nullptr);
    DRAKE_DEMAND(x.size() == rows());
    DRAKE_DEMAND(y->size() == cols());
    // NOTE: as above, binding the Ref `x` to `const VectorX<T>&` materializes
    // a temporary copy of x.
    DoMultiplyByTranspose(x, y);
  }

  // Assembles the explicit matrix form for `this` operator into matrix A.
  // Some solvers might require this operation in order to use direct methods.
  // Particularly useful for debugging sessions.
  // Derived classes can provide an implementation through the virtual
  // interface DoAssembleMatrix().
  // @pre A is not nullptr (enforced with DRAKE_DEMAND).
  // @pre A->rows() == this->rows() and A->cols() == this->cols() (enforced
  //      with DRAKE_DEMAND).
  // @throws std::runtime_error if the derived class does not override
  //         DoAssembleMatrix().
  void AssembleMatrix(Eigen::SparseMatrix<T>* A) const {
    DRAKE_DEMAND(A != nullptr);
    DRAKE_DEMAND(A->rows() == rows());
    DRAKE_DEMAND(A->cols() == cols());
    DoAssembleMatrix(A);
  }

  // TODO(amcastro-tri): expand operations as needed, e.g.:
  //  - MultiplyAndAdd(): z = y + A * x
  //  - AXPY(): Y = Y + a * X
  //  - Norm(): Matrix norm for a particular norm type.
  //  - GetDiagonal(): e.g. to be used in iterative solvers with
  //    preconditioning.

 protected:
  // Performs y = A⋅x for `this` operator A.
  // Its NVI already performed checks for valid arguments.
  virtual void DoMultiply(const Eigen::Ref<const Eigen::SparseVector<T>>& x,
                          Eigen::SparseVector<T>* y) const = 0;

  // Alternate signature to operate on dense vectors.
  // Its NVI already performed checks for valid arguments.
  virtual void DoMultiply(const Eigen::Ref<const VectorX<T>>& x,
                          VectorX<T>* y) const = 0;

  // For `this` operator A, performs y = Aᵀ⋅x.
  // The default implementation throws a std::runtime_error exception.
  // Its NVI already performed checks for valid arguments.
  virtual void DoMultiplyByTranspose(const Eigen::SparseVector<T>& x,
                                     Eigen::SparseVector<T>* y) const;

  // Alternate signature to operate on dense vectors.
  // Its NVI already performed checks for valid arguments.
  virtual void DoMultiplyByTranspose(const VectorX<T>& x, VectorX<T>* y) const;

  // Assembles `this` operator into a sparse matrix A.
  // The default implementation throws a std::runtime_error exception.
  // Its NVI already performed checks for a valid non-null pointer to a matrix
  // of the proper size.
  // TODO(amcastro-tri): A default implementation for this method based on
  // repeated multiplies by unit vectors could be provided.
  virtual void DoAssembleMatrix(Eigen::SparseMatrix<T>* A) const;

 private:
  std::string name_;

  // Helper to throw a specific exception when a given override was not
  // implemented. Never returns; the message names the calling method, this
  // instance's name and its concrete (derived) type.
  [[noreturn]] void ThrowIfNotImplemented(const char* source_method) const {
    throw std::runtime_error(std::string(source_method) + "(): Instance '" +
                             name_ + "' of type '" + NiceTypeName::Get(*this) +
                             "' must provide an implementation.");
  }
};
} // namespace internal
} // namespace contact_solvers
} // namespace multibody
} // namespace drake
// Declares the instantiations of LinearOperator on Drake's default
// nonsymbolic scalar types (matching @tparam_nonsymbolic_scalar above).
// NOTE(review): presumably this expands to `extern template` declarations
// paired with a corresponding DRAKE_DEFINE_... macro in linear_operator.cc,
// so the out-of-line default Do*() implementations are compiled there —
// confirm against the .cc file.
DRAKE_DECLARE_CLASS_TEMPLATE_INSTANTIATIONS_ON_DEFAULT_NONSYMBOLIC_SCALARS(
class ::drake::multibody::contact_solvers::internal::LinearOperator)