From 01eb67f3761c6f5b1903b35c7ce1cfab6c78755c Mon Sep 17 00:00:00 2001 From: renan028 Date: Thu, 6 Aug 2020 16:02:41 -0300 Subject: [PATCH] add failure threshold to parallel node with tests --- .../controls/parallel_node.h | 14 +++-- src/controls/parallel_node.cpp | 53 ++++++++++++++----- tests/gtest_parallel.cpp | 49 ++++++++++++++++- 3 files changed, 98 insertions(+), 18 deletions(-) diff --git a/include/behaviortree_cpp_v3/controls/parallel_node.h b/include/behaviortree_cpp_v3/controls/parallel_node.h index 0c6c959fe..e49d176ec 100644 --- a/include/behaviortree_cpp_v3/controls/parallel_node.h +++ b/include/behaviortree_cpp_v3/controls/parallel_node.h @@ -24,13 +24,15 @@ class ParallelNode : public ControlNode { public: - ParallelNode(const std::string& name, unsigned threshold); + ParallelNode(const std::string& name, unsigned success_threshold, + unsigned failure_threshold = 1); ParallelNode(const std::string& name, const NodeConfiguration& config); static PortsList providedPorts() { - return { InputPort(THRESHOLD_KEY) }; + return { InputPort(THRESHOLD_SUCCESS), + InputPort(THRESHOLD_FAILURE) }; } ~ParallelNode() = default; @@ -38,15 +40,19 @@ class ParallelNode : public ControlNode virtual void halt() override; unsigned int thresholdM(); + unsigned int thresholdFM(); void setThresholdM(unsigned int threshold_M); + void setThresholdFM(unsigned int threshold_M); private: - unsigned int threshold_; + unsigned int success_threshold_; + unsigned int failure_threshold_; std::set skip_list_; bool read_parameter_from_ports_; - static constexpr const char* THRESHOLD_KEY = "threshold"; + static constexpr const char* THRESHOLD_SUCCESS = "success_threshold"; + static constexpr const char* THRESHOLD_FAILURE = "failure_threshold"; virtual BT::NodeStatus tick() override; }; diff --git a/src/controls/parallel_node.cpp b/src/controls/parallel_node.cpp index 0ab8cb29a..18ff9e7f6 100644 --- a/src/controls/parallel_node.cpp +++ b/src/controls/parallel_node.cpp @@ -16,11 +16,14 @@ namespace BT { -constexpr const char* ParallelNode::THRESHOLD_KEY; +constexpr const char* ParallelNode::THRESHOLD_FAILURE; +constexpr const char* ParallelNode::THRESHOLD_SUCCESS; -ParallelNode::ParallelNode(const std::string& name, unsigned threshold) +ParallelNode::ParallelNode(const std::string& name, unsigned success_threshold, + unsigned failure_threshold) : ControlNode::ControlNode(name, {} ), - threshold_(threshold), + success_threshold_(success_threshold), + failure_threshold_(failure_threshold), read_parameter_from_ports_(false) { setRegistrationID("Parallel"); @@ -29,7 +32,8 @@ ParallelNode::ParallelNode(const std::string& name, unsigned threshold) ParallelNode::ParallelNode(const std::string &name, const NodeConfiguration& config) : ControlNode::ControlNode(name, config), - threshold_(0), + success_threshold_(1), + failure_threshold_(1), read_parameter_from_ports_(true) { } @@ -38,9 +42,14 @@ NodeStatus ParallelNode::tick() { if(read_parameter_from_ports_) { - if( !getInput(THRESHOLD_KEY, threshold_) ) + if( !getInput(THRESHOLD_SUCCESS, success_threshold_) ) { - throw RuntimeError("Missing parameter [", THRESHOLD_KEY, "] in ParallelNode"); + throw RuntimeError("Missing parameter [", THRESHOLD_SUCCESS, "] in ParallelNode"); + } + + if( !getInput(THRESHOLD_FAILURE, failure_threshold_) ) + { + throw RuntimeError("Missing parameter [", THRESHOLD_FAILURE, "] in ParallelNode"); } } @@ -49,9 +58,14 @@ NodeStatus ParallelNode::tick() const size_t children_count = children_nodes_.size(); - if( children_count < threshold_) + if( children_count < success_threshold_) { - throw LogicError("Number of children is less than threshold. Can never suceed."); + throw LogicError("Number of children is less than threshold. Can never succeed."); + } + + if( children_count < failure_threshold_) + { + throw LogicError("Number of children is less than threshold. Can never fail."); } // Routing the tree according to the sequence node's logic: @@ -80,7 +94,7 @@ NodeStatus ParallelNode::tick() } success_childred_num++; - if (success_childred_num == threshold_) + if (success_childred_num == success_threshold_) { skip_list_.clear(); haltChildren(); @@ -95,8 +109,11 @@ NodeStatus ParallelNode::tick() skip_list_.insert(i); } failure_childred_num++; - - if (failure_childred_num > children_count - threshold_) + + // It fails if it is not possible to succeed anymore or if + // number of failures are equal to failure_threshold_ + if ((failure_childred_num > children_count - success_threshold_) + || (failure_childred_num == failure_threshold_)) { skip_list_.clear(); haltChildren(); @@ -127,12 +144,22 @@ void ParallelNode::halt() unsigned int ParallelNode::thresholdM() { - return threshold_; + return success_threshold_; +} + +unsigned int ParallelNode::thresholdFM() +{ + return failure_threshold_; } void ParallelNode::setThresholdM(unsigned int threshold_M) { - threshold_ = threshold_M; + success_threshold_ = threshold_M; +} + +void ParallelNode::setThresholdFM(unsigned int threshold_M) +{ + failure_threshold_ = threshold_M; } } diff --git a/tests/gtest_parallel.cpp b/tests/gtest_parallel.cpp index 16d2f5140..08f79f8fe 100644 --- a/tests/gtest_parallel.cpp +++ b/tests/gtest_parallel.cpp @@ -145,7 +145,7 @@ TEST_F(SimpleParallelTest, Threshold_3) ASSERT_EQ(NodeStatus::SUCCESS, state); } -TEST_F(SimpleParallelTest, Threshold_1) +TEST_F(SimpleParallelTest, Threshold_2) { root.setThresholdM(2); BT::NodeStatus state = root.executeTick(); @@ -189,11 +189,57 @@ TEST_F(ComplexParallelTest, ConditionsTrue) ASSERT_EQ(NodeStatus::SUCCESS, state); } +TEST_F(ComplexParallelTest, ConditionsLeftFalse) +{ + parallel_left.setThresholdFM(3); + parallel_left.setThresholdM(3); + condition_L1.setExpectedResult(NodeStatus::FAILURE); + condition_L2.setExpectedResult(NodeStatus::FAILURE); + BT::NodeStatus state = parallel_root.executeTick(); + + // It fails because Parallel Left it will never succeed (two already fail) + // even though threshold_failure == 3 + + ASSERT_EQ(NodeStatus::IDLE, parallel_left.status()); + ASSERT_EQ(NodeStatus::IDLE, condition_L1.status()); + ASSERT_EQ(NodeStatus::IDLE, condition_L2.status()); + ASSERT_EQ(NodeStatus::IDLE, action_L1.status()); + ASSERT_EQ(NodeStatus::IDLE, action_L2.status()); + + ASSERT_EQ(NodeStatus::IDLE, parallel_right.status()); + ASSERT_EQ(NodeStatus::IDLE, condition_R.status()); + ASSERT_EQ(NodeStatus::IDLE, action_R.status()); + + ASSERT_EQ(NodeStatus::FAILURE, state); +} + TEST_F(ComplexParallelTest, ConditionRightFalse) { condition_R.setExpectedResult(NodeStatus::FAILURE); BT::NodeStatus state = parallel_root.executeTick(); + // It fails because threshold_failure is 1 for parallel right and + // condition_R fails + + ASSERT_EQ(NodeStatus::IDLE, parallel_left.status()); + ASSERT_EQ(NodeStatus::IDLE, condition_L1.status()); + ASSERT_EQ(NodeStatus::IDLE, condition_L2.status()); + ASSERT_EQ(NodeStatus::IDLE, action_L1.status()); + ASSERT_EQ(NodeStatus::IDLE, action_L2.status()); + + ASSERT_EQ(NodeStatus::IDLE, parallel_right.status()); + ASSERT_EQ(NodeStatus::IDLE, condition_R.status()); + ASSERT_EQ(NodeStatus::IDLE, action_R.status()); + + ASSERT_EQ(NodeStatus::FAILURE, state); +} + +TEST_F(ComplexParallelTest, ConditionRightFalse_thresholdF_2) +{ + parallel_right.setThresholdFM(2); + condition_R.setExpectedResult(NodeStatus::FAILURE); + BT::NodeStatus state = parallel_root.executeTick(); + // All the actions are running ASSERT_EQ(NodeStatus::RUNNING, parallel_left.status()); @@ -229,6 +275,7 @@ TEST_F(ComplexParallelTest, ConditionRightFalseAction1Done) { condition_R.setExpectedResult(NodeStatus::FAILURE); + parallel_right.setThresholdFM(2); parallel_left.setThresholdM(4); BT::NodeStatus state = parallel_root.executeTick();