Skip to content

Commit

Permalink
[systems] Add DiagramBuilder::RemoveSystem
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri committed Apr 22, 2023
1 parent 585aba3 commit 31ffb92
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 1 deletion.
2 changes: 2 additions & 0 deletions bindings/pydrake/systems/framework_py_semantics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,8 @@ void DoScalarDependentDefinitions(py::module m) {
py::keep_alive<1, 0>(),
// Keep alive, ownership: `system` keeps `self` alive.
py::keep_alive<3, 1>(), doc.DiagramBuilder.AddNamedSystem.doc)
.def("RemoveSystem", &DiagramBuilder<T>::RemoveSystem, py::arg("system"),
doc.DiagramBuilder.RemoveSystem.doc)
.def("empty", &DiagramBuilder<T>::empty, doc.DiagramBuilder.empty.doc)
.def("already_built", &DiagramBuilder<T>::already_built,
doc.DiagramBuilder.already_built.doc)
Expand Down
10 changes: 10 additions & 0 deletions bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,16 @@ def test_generate_html(self):
html = GenerateHtml(system, initial_depth=2)
self.assertRegex(html, r'key: "zoh"')

def test_diagram_builder_remove(self):
builder = DiagramBuilder()
source = builder.AddSystem(ConstantVectorSource([0.0]))
adder = builder.AddSystem(Adder(1, 1))
builder.ExportOutput(source.get_output_port())
builder.RemoveSystem(adder) # N.B. Deletes 'adder'; don't use after!
diagram = builder.Build()
self.assertEqual(diagram.num_input_ports(), 0)
self.assertEqual(diagram.num_output_ports(), 1)

def test_diagram_fan_out(self):
builder = DiagramBuilder()
adder = builder.AddSystem(Adder(7, 1))
Expand Down
7 changes: 7 additions & 0 deletions common/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <optional>
Expand Down Expand Up @@ -80,6 +81,12 @@ hash_append(
hasher(std::addressof(item), sizeof(item));
}

/// Provides @ref hash_append for bare pointers.
template <class HashAlgorithm, class T>
void hash_append(HashAlgorithm& hasher, const T* item) noexcept {
hash_append(hasher, reinterpret_cast<std::uintptr_t>(item));
};

/// Provides @ref hash_append for enumerations.
template <class HashAlgorithm, class T>
std::enable_if_t<std::is_enum_v<T>>
Expand Down
12 changes: 12 additions & 0 deletions common/test/hash_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,17 @@ GTEST_TEST(HashTest, HashAppendOptional) {
EXPECT_NE(hash_empty1.record().size(), hash_nonempty1.record().size());
}

GTEST_TEST(HashTest, HashAppendPointer) {
const std::pair<int, const int*> foo{22, nullptr};
MockHasher foo_hash;
hash_append(foo_hash, foo);

MockHasher expected;
hash_append(expected, 22);
hash_append(expected, std::uintptr_t{});

EXPECT_EQ(foo_hash.record(), expected.record());
}

} // namespace
} // namespace drake
4 changes: 4 additions & 0 deletions systems/framework/diagram.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,13 @@ class OwnedSystems {
decltype(auto) begin() const { return vec_.begin(); }
decltype(auto) end() const { return vec_.end(); }
decltype(auto) operator[](size_t i) const { return vec_[i]; }
decltype(auto) operator[](size_t i) { return vec_[i]; }
void push_back(std::unique_ptr<System<T>>&& sys) {
vec_.push_back(std::move(sys));
}
void pop_back() {
vec_.pop_back();
}

private:
std::vector<std::unique_ptr<System<T>>> vec_;
Expand Down
134 changes: 134 additions & 0 deletions systems/framework/diagram_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,107 @@

namespace drake {
namespace systems {
namespace {

/* Erases the i'th element of vec, shifting everything after it over by one. */
template <typename StdVector>
void VectorErase(StdVector* vec, size_t i) {
DRAKE_DEMAND(vec != nullptr);
const size_t size = vec->size();
DRAKE_DEMAND(i < size);
for (size_t hole = i; (hole + 1) < size; ++hole) {
(*vec)[hole] = std::move((*vec)[hole + 1]);
}
vec->pop_back();
}

} // namespace

template <typename T>
DiagramBuilder<T>::DiagramBuilder() {}

template <typename T>
DiagramBuilder<T>::~DiagramBuilder() {}

template <typename T>
void DiagramBuilder<T>::RemoveSystem(const System<T>& system) {
ThrowIfAlreadyBuilt();
if (systems_.count(&system) == 0) {
throw std::logic_error(fmt::format(
"Cannot RemoveSystem on {} because it has not been added to this "
"DiagramBuilder",
system.GetSystemPathname()));
}
const size_t system_index = std::distance(
registered_systems_.begin(),
std::find_if(registered_systems_.begin(), registered_systems_.end(),
[&system](const std::unique_ptr<System<T>>& item) {
return item.get() == &system;
}));
DRAKE_DEMAND(system_index < registered_systems_.size());

// Un-export any input ports associated with this system.
// First, undo the ConnectInput.
std::set<std::string> disconnected_diagram_input_port_names;
for (size_t i = 0; i < input_port_ids_.size();) {
const InputPortLocator& locator = input_port_ids_[i];
if (locator.first == &system) {
const size_t num_erased = diagram_input_set_.erase(locator);
DRAKE_DEMAND(num_erased == 1);
disconnected_diagram_input_port_names.insert(
std::move(input_port_names_[i]));
VectorErase(&input_port_ids_, i);
VectorErase(&input_port_names_, i);
} else {
++i;
}
}
// Second, undo the DeclareInput (iff it was the last connected system).
for (const auto& name : disconnected_diagram_input_port_names) {
const bool num_connections =
std::count(input_port_names_.begin(), input_port_names_.end(), name);
if (num_connections == 0) {
const auto iter = diagram_input_indices_.find(name);
DRAKE_DEMAND(iter != diagram_input_indices_.end());
const InputPortIndex removed_index = iter->second;
VectorErase(&diagram_input_data_, removed_index);
diagram_input_indices_.erase(iter);
for (auto& [_, index] : diagram_input_indices_) {
if (index > removed_index) {
--index;
}
}
}
}

// Un-export any output ports associated with this system.
for (OutputPortIndex i{0}; i < output_port_ids_.size();) {
const OutputPortLocator& locator = output_port_ids_[i];
if (locator.first == &system) {
VectorErase(&output_port_ids_, i);
VectorErase(&output_port_names_, i);
} else {
++i;
}
}

// Disconnect any internal connections associated with this system.
for (auto iter = connection_map_.begin(); iter != connection_map_.end();) {
const auto& [input_locator, output_locator] = *iter;
if ((input_locator.first == &system) || (output_locator.first == &system)) {
iter = connection_map_.erase(iter);
} else {
++iter;
}
}

// Delete the system.
systems_.erase(&system);
VectorErase(&registered_systems_, system_index);

DRAKE_ASSERT_VOID(CheckInvariants());
}

template <typename T>
std::vector<const System<T>*> DiagramBuilder<T>::GetSystems() const {
ThrowIfAlreadyBuilt();
Expand Down Expand Up @@ -518,11 +612,51 @@ void DiagramBuilder<T>::ThrowIfAlgebraicLoopsExist() const {
}
}

template <typename T>
void DiagramBuilder<T>::CheckInvariants() const {
auto has_system = [this](const System<T>* system) {
return std::count(systems_.begin(), systems_.end(), system) > 0;
};

// The systems_ and registered_systems_ are identical sets.
DRAKE_DEMAND(systems_.size() == registered_systems_.size());
for (const auto& item : registered_systems_) {
DRAKE_DEMAND(has_system(item.get()));
}

// The connection_map_ only refers to registered systems.
for (const auto& [input, output] : connection_map_) {
DRAKE_DEMAND(has_system(input.first));
DRAKE_DEMAND(has_system(output.first));
}

// The input_port_ids_ and output_port_ids_ only refer to registered systems.
for (const auto& [system, _] : input_port_ids_) {
DRAKE_DEMAND(has_system(system));
}
for (const auto& [system, _] : output_port_ids_) {
DRAKE_DEMAND(has_system(system));
}

// The input_port_ids_ and diagram_input_set_ are identical sets.
DRAKE_DEMAND(input_port_ids_.size() == diagram_input_set_.size());
for (const auto& item : input_port_ids_) {
DRAKE_DEMAND(diagram_input_set_.find(item) != diagram_input_set_.end());
}

// The diagram_input_indices_ is the inverse of diagram_input_data_.
DRAKE_DEMAND(diagram_input_data_.size() == diagram_input_indices_.size());
for (const auto& [name, index] : diagram_input_indices_) {
DRAKE_DEMAND(diagram_input_data_.at(index).name == name);
}
}

template <typename T>
std::unique_ptr<typename Diagram<T>::Blueprint> DiagramBuilder<T>::Compile() {
if (registered_systems_.size() == 0) {
throw std::logic_error("Cannot Compile an empty DiagramBuilder.");
}
DRAKE_ASSERT_VOID(CheckInvariants());
ThrowIfAlgebraicLoopsExist();

auto blueprint = std::make_unique<typename Diagram<T>::Blueprint>();
Expand Down
17 changes: 16 additions & 1 deletion systems/framework/diagram_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "drake/common/default_scalars.h"
#include "drake/common/drake_copyable.h"
#include "drake/common/hash.h"
#include "drake/common/pointer_cast.h"
#include "drake/systems/framework/diagram.h"
#include "drake/systems/framework/system.h"
Expand Down Expand Up @@ -243,6 +244,18 @@ class DiagramBuilder {
name, std::make_unique<S<T>>(std::forward<Args>(args)...));
}

/// Removes the given system from this builder and disconnects any connections
/// or exported ports associated with it.
///
/// Note that un-exporting this system's ports might have a ripple effect on
/// other exported input or output port index assignments. The relative order
/// will remain intact, but index "holes" created by this removal will be
/// filled in by subtracting from the indices of ports that remain.
///
/// @warning Because a DigramBuilder owns the objects it contains, the system
/// will be deleted.
void RemoveSystem(const System<T>& system);

/// Returns whether any Systems have been added yet.
bool empty() const {
ThrowIfAlreadyBuilt();
Expand Down Expand Up @@ -440,6 +453,8 @@ class DiagramBuilder {

void ThrowIfAlgebraicLoopsExist() const;

void CheckInvariants() const;

// Produces the Blueprint that has been described by the calls to
// Connect, ExportInput, and ExportOutput. Throws std::exception if the
// graph is empty or contains algebraic loops.
Expand All @@ -457,7 +472,7 @@ class DiagramBuilder {
std::vector<std::string> output_port_names_;

// For fast membership queries: has this input port already been wired?
std::set<InputPortLocator> diagram_input_set_;
std::unordered_set<InputPortLocator, DefaultHash> diagram_input_set_;

// A vector of data about exported input ports.
struct ExportedInputData {
Expand Down
16 changes: 16 additions & 0 deletions systems/framework/test/diagram_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ GTEST_TEST(DiagramBuilderTest, AddNamedSystem) {
EXPECT_EQ(c->get_name(), "c");
}

// Tests ::RemoveSystem.
GTEST_TEST(DiagramBuilderTest, Remove) {
DiagramBuilder<double> builder;
const auto& adder1 = *builder.AddSystem<Adder>(1 /* inputs */, 1 /* size */);
builder.ExportInput(adder1.get_input_port());
builder.ExportOutput(adder1.get_output_port());
const auto& adder2 = *builder.AddSystem<Adder>(1 /* inputs */, 2 /* size */);
builder.ExportInput(adder2.get_input_port());
builder.ExportOutput(adder2.get_output_port());
builder.RemoveSystem(adder1);
auto diagram = builder.Build();
ASSERT_EQ(diagram->num_input_ports(), 1);
ASSERT_EQ(diagram->num_output_ports(), 1);
EXPECT_EQ(diagram->get_output_port().size(), 2);
}

// Tests already_built() and one example of ThrowIfAlreadyBuilt().
GTEST_TEST(DiagramBuilderTest, AlreadyBuilt) {
DiagramBuilder<double> builder;
Expand Down

0 comments on commit 31ffb92

Please sign in to comment.