Improved the BatchedPipeline #459

Merged · 3 commits · Sep 8, 2021
72 changes: 42 additions & 30 deletions src/util/BatchedPipeline.h
@@ -8,6 +8,7 @@
 #include <future>
 #include <utility>
 
+#include "./Log.h"
 #include "./Timer.h"
 #include "./TupleHelpers.h"

@@ -23,10 +24,10 @@ namespace detail {
  */
 template <class T>
 struct Batch {
-  bool m_isPipelineGood;  // if set to false, this was the last (and possibly
-                          // incomplete) batch, else there might be more content
-                          // waiting in the pipeline.
-  std::vector<T> m_content;  // the actual payload
+  bool m_isPipelineGood = true;  // if set to false, this was the last (and
+                                 // possibly incomplete) batch, else there might
+                                 // be more content waiting in the pipeline.
+  std::vector<T> m_content;  // the actual payload
 };

/*
@@ -66,8 +67,10 @@ class Batcher {
   */
  Batch<ValueT> pickupBatch() {
    try {
+      _timer->cont();
      auto res = _fut.get();
      orderNextBatch();
+      _timer->stop();
      return res;
    } catch (const std::future_error& e) {
      throw std::runtime_error(
@@ -95,34 +98,43 @@
    // since the unique_ptr _creator owns the creator,
    // the captured pointer will stay valid even while this
    // class is moved.
-    _fut =
-        std::async(std::launch::async,
-                   [bs = _batchSize, ptr = _creator.get(), t = _timer.get()]() {
-                     return produceBatchInternal(bs, t, ptr);
-                   });
+    _fut = std::async(std::launch::async,
+                      [bs = _batchSize, ptr = _creator.get()]() {
+                        return produceBatchInternal(bs, ptr);
+                      });
  }

  /* retrieve values from the creator and store them in the Batch result.
   * Once we have reached batchSize Elements or the creator returns std::nullopt
   * we return. In the latter case, result.first is false
   */
  static detail::Batch<ValueT> produceBatchInternal(size_t batchSize,
-                                                    ad_utility::Timer* timer,
                                                    Creator* creator) {
-    timer->cont();
    detail::Batch<ValueT> res;
-    res.m_isPipelineGood = true;
-    res.m_content.reserve(batchSize);
-    for (size_t i = 0; i < batchSize; ++i) {
-      auto opt = (*creator)();
-      if (!opt) {
-        res.m_isPipelineGood = false;
-        return res;
-      }
-      res.m_content.push_back(std::move(opt.value()));
+    // If the Creator type has a method `getBatch`, use this method to produce
+    // the batch in one step, otherwise produce the batch value by value.
+    if constexpr (requires { creator->getBatch(); }) {
+      auto opt = creator->getBatch();
+      if (!opt) {
+        res.m_isPipelineGood = false;
+        return res;
+      }
+      res.m_isPipelineGood = true;
+      res.m_content = std::move(*opt);
+      return res;
+    } else {
+      res.m_isPipelineGood = true;
+      res.m_content.reserve(batchSize);
+      for (size_t i = 0; i < batchSize; ++i) {
+        auto opt = (*creator)();
+        if (!opt) {
+          res.m_isPipelineGood = false;
+          return res;
+        }
+        res.m_content.push_back(std::move(opt.value()));
+      }
+      return res;
    }
-    timer->stop();
-    return res;
  }
};
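A note on the capture in `orderNextBatch` above: the lambda captures `_creator.get()` by value, and the code comment argues this stays valid even while the `Batcher` is moved. A minimal standalone sketch (not from this PR) of why that holds:

```cpp
// Standalone sketch, not part of this PR: moving a std::unique_ptr
// transfers ownership of the *same* heap object, so a raw pointer obtained
// via get() before the move still points at a live object afterwards.
#include <cassert>
#include <memory>
#include <utility>

int main() {
  auto owner = std::make_unique<int>(42);
  int* raw = owner.get();            // like [ptr = _creator.get()] above
  auto newOwner = std::move(owner);  // like moving the whole Batcher object
  assert(raw == newOwner.get());     // same object, still alive
  assert(*raw == 42);
}
```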

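The `if constexpr (requires { creator->getBatch(); })` branch is the core of this change: a creator that can hand over a whole batch at once skips the value-by-value loop. A hypothetical creator taking that fast path might look as follows (a sketch; the name `CountingBatchCreator` and its batching policy are assumptions, not part of the PR):

```cpp
// Hypothetical example, not from this PR: a creator with a getBatch()
// method, which produceBatchInternal detects via the requires-expression
// and uses to move an entire vector into the Batch in one step.
#include <cstddef>
#include <optional>
#include <vector>

class CountingBatchCreator {
 public:
  CountingBatchCreator(size_t limit, size_t batchSize)
      : limit_{limit}, batchSize_{batchSize} {}

  // Returning std::nullopt signals that the pipeline is exhausted
  // (m_isPipelineGood will then be set to false).
  std::optional<std::vector<size_t>> getBatch() {
    if (next_ >= limit_) return std::nullopt;
    std::vector<size_t> batch;
    batch.reserve(batchSize_);
    while (next_ < limit_ && batch.size() < batchSize_) {
      batch.push_back(next_++);
    }
    return batch;
  }

 private:
  size_t next_ = 0;
  size_t limit_;
  size_t batchSize_;
};
```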
@@ -176,8 +188,10 @@ class BatchedPipeline {
  // _____________________________________________________________________
  Batch<ResT> pickupBatch() {
    try {
+      _timer->cont();
      auto res = _fut.get();
      orderNextBatch();
+      _timer->stop();
      return res;
    } catch (std::future_error& e) {
      throw std::runtime_error(
@@ -188,14 +202,14 @@

  // asynchronously prepare the next Batch in a different thread
  void orderNextBatch() {
-    auto lambda = [p = _previousStage.get(),
-                   batchSize = _previousStage->getBatchSize(),
-                   t = _timer.get()](auto... transformerPtrs) {
-      return std::async(
-          std::launch::async, [p, batchSize, t, transformerPtrs...]() {
-            return produceBatchInternal(p, batchSize, t, transformerPtrs...);
-          });
-    };
+    auto lambda =
+        [p = _previousStage.get(),
+         batchSize = _previousStage->getBatchSize()](auto... transformerPtrs) {
+          return std::async(
+              std::launch::async, [p, batchSize, transformerPtrs...]() {
+                return produceBatchInternal(p, batchSize, transformerPtrs...);
+              });
+        };
    _fut = std::apply(lambda, _rawTransformers);
  }
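For readers unfamiliar with the `std::apply` line above: it unpacks the tuple `_rawTransformers` into the variadic lambda, which then re-captures the pointer pack into the async task. A self-contained sketch of the same mechanism (illustrative only; the transformers here are made up):

```cpp
// Standalone sketch, not from this PR: std::apply unpacks a tuple of
// transformer pointers into a variadic lambda, which captures the pack
// into a std::async task, mirroring orderNextBatch above.
#include <future>
#include <iostream>
#include <tuple>

int main() {
  auto twice = [](int x) { return 2 * x; };
  auto plusOne = [](int x) { return x + 1; };
  std::tuple transformers{&twice, &plusOne};

  auto lambda = [](auto... transformerPtrs) {
    return std::async(std::launch::async, [transformerPtrs...]() {
      // Apply every transformer to the same input and sum the results.
      return ((*transformerPtrs)(10) + ...);
    });
  };

  auto fut = std::apply(lambda, transformers);
  std::cout << fut.get() << '\n';  // prints 31 (20 + 11)
}
```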

@@ -225,10 +239,8 @@
  template <typename... TransformerPtrs>
  static Batch<ResT> produceBatchInternal(PreviousStage* previousStage,
                                          size_t inBatchSize,
-                                         ad_utility::Timer* timer,
                                          TransformerPtrs... transformers) {
    auto inBatch = previousStage->pickupBatch();
-    timer->cont();
    Batch<ResT> result;
    result.m_isPipelineGood = inBatch.m_isPipelineGood;
    // currently each of the <parallelism> threads first creates its own Batch
@@ -246,7 +258,6 @@
          std::make_move_iterator(vec.begin()),
          std::make_move_iterator(vec.end()));
    }
-    timer->stop();
    return result;
  }

@@ -256,8 +267,9 @@
  // size batchSize, then there will be an incomplete batch with the remaining
  // range and all other ranges will be empty.
  template <typename It>
-  static std::pair<It, It> getBatchRange(It beg, It end, const size_t batchSize,
+  static std::pair<It, It> getBatchRange(It beg, It end, size_t batchSize,
                                         const size_t idx) {
+    batchSize = std::max(size_t(1), size_t((end - beg) / Parallelism));
    std::pair<It, It> res;
    res.first = std::min(beg + idx * batchSize, end);
    res.second = idx < Parallelism - 1
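The new first line of `getBatchRange` discards the caller's `batchSize` and recomputes it from the actual input size, so each of the `Parallelism` worker threads gets a contiguous, near-equal share. The tail of the function is collapsed in the diff above, so the following standalone sketch fills it in under an assumed "last thread takes the remainder" completion, with a small worked example:

```cpp
// Standalone sketch, not a verbatim copy: the end of getBatchRange is
// collapsed in the diff, so the last-thread handling below ("the last
// range extends to end") is an assumed completion.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

constexpr size_t Parallelism = 3;

template <typename It>
std::pair<It, It> getBatchRange(It beg, It end, size_t batchSize,
                                const size_t idx) {
  // Recompute the per-thread share from the actual input size.
  batchSize = std::max(size_t(1), size_t((end - beg) / Parallelism));
  std::pair<It, It> res;
  res.first = std::min(beg + idx * batchSize, end);
  res.second =
      idx < Parallelism - 1 ? std::min(res.first + batchSize, end) : end;
  return res;
}

int main() {
  std::vector<int> v(7);  // 7 elements, 3 threads -> batchSize = max(1, 7/3) = 2
  for (size_t idx = 0; idx < Parallelism; ++idx) {
    auto [a, b] = getBatchRange(v.begin(), v.end(), /*ignored*/ 100, idx);
    std::cout << "thread " << idx << ": [" << (a - v.begin()) << ", "
              << (b - v.begin()) << ")\n";  // [0, 2), [2, 4), [4, 7)
  }
}
```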
6 changes: 1 addition & 5 deletions test/BatchedPipelineTest.cpp
@@ -175,11 +175,7 @@ TEST(BatchedPipelineTest, BranchedParallelism) {

  size_t j = 0;
  while (auto opt = pipeline.getNextValue()) {
-    if (j % 20 < 10) {
-      ASSERT_EQ(opt.value(), j * 3);
-    } else {
-      ASSERT_EQ(opt.value(), j * 2);
-    }
+    ASSERT_TRUE(opt.value() == j * 3 || opt.value() == j * 2);
    j++;
  }
  ASSERT_EQ(j, 67u);