Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions include/behaviortree_cpp_v3/controls/parallel_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,35 @@ class ParallelNode : public ControlNode
{
public:

ParallelNode(const std::string& name, unsigned threshold);
ParallelNode(const std::string& name, unsigned success_threshold,
unsigned failure_threshold = 1);

ParallelNode(const std::string& name, const NodeConfiguration& config);

static PortsList providedPorts()
{
return { InputPort<unsigned>(THRESHOLD_KEY) };
return { InputPort<unsigned>(THRESHOLD_SUCCESS),
InputPort<unsigned>(THRESHOLD_FAILURE) };
}

~ParallelNode() = default;

virtual void halt() override;

unsigned int thresholdM();
unsigned int thresholdFM();
void setThresholdM(unsigned int threshold_M);
void setThresholdFM(unsigned int threshold_M);

private:
unsigned int threshold_;
unsigned int success_threshold_;
unsigned int failure_threshold_;

std::set<int> skip_list_;

bool read_parameter_from_ports_;
static constexpr const char* THRESHOLD_KEY = "threshold";
static constexpr const char* THRESHOLD_SUCCESS = "success_threshold";
static constexpr const char* THRESHOLD_FAILURE = "failure_threshold";

virtual BT::NodeStatus tick() override;
};
Expand Down
53 changes: 40 additions & 13 deletions src/controls/parallel_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
namespace BT
{

constexpr const char* ParallelNode::THRESHOLD_KEY;
constexpr const char* ParallelNode::THRESHOLD_FAILURE;
constexpr const char* ParallelNode::THRESHOLD_SUCCESS;

ParallelNode::ParallelNode(const std::string& name, unsigned threshold)
ParallelNode::ParallelNode(const std::string& name, unsigned success_threshold,
unsigned failure_threshold)
: ControlNode::ControlNode(name, {} ),
threshold_(threshold),
success_threshold_(success_threshold),
failure_threshold_(failure_threshold),
read_parameter_from_ports_(false)
{
setRegistrationID("Parallel");
Expand All @@ -29,7 +32,8 @@ ParallelNode::ParallelNode(const std::string& name, unsigned threshold)
ParallelNode::ParallelNode(const std::string &name,
const NodeConfiguration& config)
: ControlNode::ControlNode(name, config),
threshold_(0),
success_threshold_(1),
failure_threshold_(1),
read_parameter_from_ports_(true)
{
}
Expand All @@ -38,9 +42,14 @@ NodeStatus ParallelNode::tick()
{
if(read_parameter_from_ports_)
{
if( !getInput(THRESHOLD_KEY, threshold_) )
if( !getInput(THRESHOLD_SUCCESS, success_threshold_) )
{
throw RuntimeError("Missing parameter [", THRESHOLD_KEY, "] in ParallelNode");
throw RuntimeError("Missing parameter [", THRESHOLD_SUCCESS, "] in ParallelNode");
}

if( !getInput(THRESHOLD_FAILURE, failure_threshold_) )
{
throw RuntimeError("Missing parameter [", THRESHOLD_FAILURE, "] in ParallelNode");
}
}

Expand All @@ -49,9 +58,14 @@ NodeStatus ParallelNode::tick()

const size_t children_count = children_nodes_.size();

if( children_count < threshold_)
if( children_count < success_threshold_)
{
throw LogicError("Number of children is less than threshold. Can never suceed.");
throw LogicError("Number of children is less than threshold. Can never succeed.");
}

if( children_count < failure_threshold_)
{
throw LogicError("Number of children is less than threshold. Can never fail.");
}

// Routing the tree according to the sequence node's logic:
Expand Down Expand Up @@ -80,7 +94,7 @@ NodeStatus ParallelNode::tick()
}
success_childred_num++;

if (success_childred_num == threshold_)
if (success_childred_num == success_threshold_)
{
skip_list_.clear();
haltChildren();
Expand All @@ -95,8 +109,11 @@ NodeStatus ParallelNode::tick()
skip_list_.insert(i);
}
failure_childred_num++;

if (failure_childred_num > children_count - threshold_)

// It fails if it is not possible to succeed anymore or if
// number of failures are equal to failure_threshold_
if ((failure_childred_num > children_count - success_threshold_)
|| (failure_childred_num == failure_threshold_))
{
skip_list_.clear();
haltChildren();
Expand Down Expand Up @@ -127,12 +144,22 @@ void ParallelNode::halt()

unsigned int ParallelNode::thresholdM()
{
return threshold_;
return success_threshold_;
}

unsigned int ParallelNode::thresholdFM()
{
return failure_threshold_;
}

void ParallelNode::setThresholdM(unsigned int threshold_M)
{
threshold_ = threshold_M;
success_threshold_ = threshold_M;
}

void ParallelNode::setThresholdFM(unsigned int threshold_M)
{
failure_threshold_ = threshold_M;
}

}
49 changes: 48 additions & 1 deletion tests/gtest_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ TEST_F(SimpleParallelTest, Threshold_3)
ASSERT_EQ(NodeStatus::SUCCESS, state);
}

TEST_F(SimpleParallelTest, Threshold_1)
TEST_F(SimpleParallelTest, Threshold_2)
{
root.setThresholdM(2);
BT::NodeStatus state = root.executeTick();
Expand Down Expand Up @@ -189,11 +189,57 @@ TEST_F(ComplexParallelTest, ConditionsTrue)
ASSERT_EQ(NodeStatus::SUCCESS, state);
}

TEST_F(ComplexParallelTest, ConditionsLeftFalse)
{
parallel_left.setThresholdFM(3);
parallel_left.setThresholdM(3);
condition_L1.setExpectedResult(NodeStatus::FAILURE);
condition_L2.setExpectedResult(NodeStatus::FAILURE);
BT::NodeStatus state = parallel_root.executeTick();

// It fails because Parallel Left it will never succeed (two already fail)
// even though threshold_failure == 3

ASSERT_EQ(NodeStatus::IDLE, parallel_left.status());
ASSERT_EQ(NodeStatus::IDLE, condition_L1.status());
ASSERT_EQ(NodeStatus::IDLE, condition_L2.status());
ASSERT_EQ(NodeStatus::IDLE, action_L1.status());
ASSERT_EQ(NodeStatus::IDLE, action_L2.status());

ASSERT_EQ(NodeStatus::IDLE, parallel_right.status());
ASSERT_EQ(NodeStatus::IDLE, condition_R.status());
ASSERT_EQ(NodeStatus::IDLE, action_R.status());

ASSERT_EQ(NodeStatus::FAILURE, state);
}

TEST_F(ComplexParallelTest, ConditionRightFalse)
{
condition_R.setExpectedResult(NodeStatus::FAILURE);
BT::NodeStatus state = parallel_root.executeTick();

// It fails because threshold_failure is 1 for parallel right and
// condition_R fails

ASSERT_EQ(NodeStatus::IDLE, parallel_left.status());
ASSERT_EQ(NodeStatus::IDLE, condition_L1.status());
ASSERT_EQ(NodeStatus::IDLE, condition_L2.status());
ASSERT_EQ(NodeStatus::IDLE, action_L1.status());
ASSERT_EQ(NodeStatus::IDLE, action_L2.status());

ASSERT_EQ(NodeStatus::IDLE, parallel_right.status());
ASSERT_EQ(NodeStatus::IDLE, condition_R.status());
ASSERT_EQ(NodeStatus::IDLE, action_R.status());

ASSERT_EQ(NodeStatus::FAILURE, state);
}

TEST_F(ComplexParallelTest, ConditionRightFalse_thresholdF_2)
{
parallel_right.setThresholdFM(2);
condition_R.setExpectedResult(NodeStatus::FAILURE);
BT::NodeStatus state = parallel_root.executeTick();

// All the actions are running

ASSERT_EQ(NodeStatus::RUNNING, parallel_left.status());
Expand Down Expand Up @@ -229,6 +275,7 @@ TEST_F(ComplexParallelTest, ConditionRightFalseAction1Done)
{
condition_R.setExpectedResult(NodeStatus::FAILURE);

parallel_right.setThresholdFM(2);
parallel_left.setThresholdM(4);

BT::NodeStatus state = parallel_root.executeTick();
Expand Down