Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Change-Id: I6a60915bf56c14cff2ed8e989b44a0d6a25cf8ff
  • Loading branch information
DmitriyKorchemkin committed Jan 21, 2023
1 parent 390bf78 commit 2e2ff3e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 97 deletions.
28 changes: 15 additions & 13 deletions internal/ceres/parallel_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,20 @@ CERES_NO_EXPORT int MaxNumThreadsAvailable();
// When distributing workload between threads, it is assumed that each loop
// iteration takes approximately equal time to complete.
template <typename F>
void ParallelFor(ContextImpl* context,
int start,
int end,
int num_threads,
const F& function) {
void ParallelFor(
ContextImpl* context, int start, int end, int num_threads, F&& function) {
CHECK_GT(num_threads, 0);
if (start >= end) {
return;
}

if (num_threads == 1 || end - start == 1) {
InvokeOnSegment<F>(0, std::make_tuple(start, end), function);
InvokeOnSegment(0, std::make_tuple(start, end), function);
return;
}

CHECK(context != nullptr);
ParallelInvoke<F>(context, start, end, num_threads, function);
ParallelInvoke(context, start, end, num_threads, function);
}

// Execute function for every element in the range [start, end) with at most
Expand All @@ -99,7 +96,7 @@ void ParallelFor(ContextImpl* context,
int start,
int end,
int num_threads,
const F& function,
F&& function,
const std::vector<int>& partitions) {
CHECK_GT(num_threads, 0);
if (start >= end) {
Expand All @@ -119,11 +116,16 @@ void ParallelFor(ContextImpl* context,
num_threads,
[&function, &partitions](int thread_id,
std::tuple<int, int> partition_ids) {
// partition_ids is a range of partition indices
const auto [partition_start, partition_end] = partition_ids;
// Execution over several adjacent segments is equivalent
// to execution over union of those segments (which is also a
// contiguous segment)
const int range_start = partitions[partition_start];
const int range_end = partitions[partition_end];
// Range of original loop indices
const auto range = std::make_tuple(range_start, range_end);
InvokeOnSegment<F>(thread_id, range, function);
InvokeOnSegment(thread_id, range, function);
});
}

Expand All @@ -148,9 +150,9 @@ void ParallelFor(ContextImpl* context,
int start,
int end,
int num_threads,
const F& function,
F&& function,
const CumulativeCostData* cumulative_cost_data,
const CumulativeCostFun& cumulative_cost_fun) {
CumulativeCostFun&& cumulative_cost_fun) {
CHECK_GT(num_threads, 0);
if (start >= end) {
return;
Expand All @@ -161,9 +163,9 @@ void ParallelFor(ContextImpl* context,
}
// Creating several partitions allows us to tolerate imperfections of
// partitioning and user-supplied iteration costs up to a certain extent
const int kNumPartitionsPerThread = 4;
constexpr int kNumPartitionsPerThread = 4;
const int kMaxPartitions = num_threads * kNumPartitionsPerThread;
const std::vector<int> partitions = PartitionRangeForParallelFor(
const auto& partitions = PartitionRangeForParallelFor(
start, end, kMaxPartitions, cumulative_cost_data, cumulative_cost_fun);
CHECK_GT(partitions.size(), 1);
ParallelFor(context, start, end, num_threads, function, partitions);
Expand Down
2 changes: 1 addition & 1 deletion internal/ceres/parallel_for_synchronization.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2022 Google Inc. All rights reserved.
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
Expand Down
109 changes: 41 additions & 68 deletions internal/ceres/parallel_invoke.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2022 Google Inc. All rights reserved.
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -47,74 +47,50 @@ std::tuple<std::decay_t<Args>...> args_of(void (F::*)(Args...) const);
template <typename F>
using args_of_t = decltype(args_of(&F::operator()));

// Parallelizable functions might require passing thread_id as the first
// argument. This class supplies thread_id argument to functions that
// support it and ignores it otherwise.
template <typename F, typename Args>
struct InvokeImpl;

// For parallel for iterations of type [](int i) -> void
template <typename F>
struct InvokeImpl<F, std::tuple<int>> {
static void InvokeOnSegment(int thread_id,
std::tuple<int, int> range,
const F& function) {
(void)thread_id;
auto [start, end] = range;
for (int i = start; i < end; ++i) {
function(i);
}
}
};
struct FunctionTraits {
using Function = std::remove_reference_t<F>;

// For parallel for iterations of type [](int thread_id, int i) -> void
template <typename F>
struct InvokeImpl<F, std::tuple<int, int>> {
static void InvokeOnSegment(int thread_id,
std::tuple<int, int> range,
const F& function) {
auto [start, end] = range;
for (int i = start; i < end; ++i) {
function(thread_id, i);
}
}
};
using Args = args_of_t<Function>;
static constexpr int NumArgs = std::tuple_size_v<Args>;
using FirstArg = typename std::tuple_element<0, Args>::type;
using LastArg = typename std::tuple_element<NumArgs - 1, Args>::type;

// For parallel for iterations of type [](tuple<int, int> range) -> void
template <typename F>
struct InvokeImpl<F, std::tuple<std::tuple<int, int>>> {
static void InvokeOnSegment(int thread_id,
std::tuple<int, int> range,
const F& function) {
(void)thread_id;
function(range);
}
static constexpr bool FirstArgIsInt = std::is_same_v<FirstArg, int>;
static constexpr bool LastArgIsInt = std::is_same_v<LastArg, int>;

static constexpr bool PassThreadId = NumArgs > 1 && FirstArgIsInt;
static constexpr bool AddOuterLoop =
LastArgIsInt && FirstArgIsInt && NumArgs <= 2;
};

// For parallel for iterations of type [](int thread_id, tuple<int, int> range)
// -> void
template <typename F>
struct InvokeImpl<F, std::tuple<int, std::tuple<int, int>>> {
static void InvokeOnSegment(int thread_id,
std::tuple<int, int> range,
const F& function) {
function(thread_id, range);
template <typename F,
typename... Args,
std::enable_if_t<FunctionTraits<F>::PassThreadId, bool> = true>
void InvokeWithThreadId(int thread_id, F&& function, Args&&... args) {
function(thread_id, args...);
}

// Overload for functors that do not take a thread_id: the thread_id
// parameter is accepted and discarded so that both kinds of functor can be
// invoked through an identical call site.
template <typename F,
          typename... Args,
          std::enable_if_t<!FunctionTraits<F>::PassThreadId, bool> = true>
void InvokeWithThreadId(int /* thread_id */, F&& function, Args&&... args) {
  // Forward to preserve the value category of the arguments.
  function(std::forward<Args>(args)...);
}

// Invoke a single-index functor ([](int i) or [](int thread_id, int i)) once
// per index of the contiguous segment [start, end): the outer loop over the
// segment is supplied here.
template <typename F,
          std::enable_if_t<FunctionTraits<F>::AddOuterLoop, bool> = true>
void InvokeOnSegment(int thread_id, std::tuple<int, int> range, F&& function) {
  const auto [start, end] = range;
  for (int i = start; i != end; ++i) {
    // function is deliberately passed as an lvalue: forwarding an rvalue
    // functor inside a loop would risk use-after-move on later iterations.
    InvokeWithThreadId(thread_id, function, i);
  }
}

// Invoke function on indices from contiguous range according to function
// signature. The following signatures are supported:
// - Functions processing single index per call:
// - [](int index) -> void
// - [](int thread_id, int index) -> void
// - Functions processing contiguous range [start, end) of indices per call:
// - [](std::tuple<int, int> range) -> void
// Function arguments might have reference type and const qualifier
template <typename F>
void InvokeOnSegment(int thread_id,
std::tuple<int, int> range,
const F& function) {
InvokeImpl<F, args_of_t<F>>::InvokeOnSegment(thread_id, range, function);
template <typename F,
std::enable_if_t<!FunctionTraits<F>::AddOuterLoop, bool> = true>
void InvokeOnSegment(int thread_id, std::tuple<int, int> range, F&& function) {
InvokeWithThreadId(thread_id, function, range);
}

// This implementation uses a fixed size max worker pool with a shared task
Expand All @@ -136,11 +112,8 @@ void InvokeOnSegment(int thread_id,
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
template <typename F>
void ParallelInvoke(ContextImpl* context,
int start,
int end,
int num_threads,
const F& function) {
void ParallelInvoke(
ContextImpl* context, int start, int end, int num_threads, F&& function) {
CHECK(context != nullptr);

// Maximal number of work items scheduled for a single thread
Expand Down Expand Up @@ -214,7 +187,7 @@ void ParallelInvoke(ContextImpl* context,
(block_id < num_base_p1_sized_blocks ? 1 : 0);
// Perform each task in current block
const auto range = std::make_tuple(curr_start, curr_end);
InvokeOnSegment<F>(thread_id, range, function);
InvokeOnSegment(thread_id, range, function);
}
shared_state->block_until_finished.Finished(num_jobs_finished);
};
Expand Down
22 changes: 7 additions & 15 deletions internal/ceres/partition_range_for_parallel_for.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2022 Google Inc. All rights reserved.
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -115,7 +115,11 @@ std::vector<int> PartitionRangeForParallelFor(
// Upper bound is inclusive
int partition_cost_upper_bound = total_cost;

std::vector<int> partition, partition_upper_bound;
std::vector<int> partition;
// Range partition corresponding to the latest evaluated upper bound.
// A single segment covering the whole input interval [start, end) corresponds
// to minimal maximal partition cost of total_cost.
std::vector<int> partition_upper_bound = {start, end};
// Binary search over partition cost, returning the lowest admissible cost
while (partition_cost_upper_bound - partition_cost_lower_bound > 1) {
partition.reserve(max_num_partitions + 1);
Expand All @@ -138,19 +142,7 @@ std::vector<int> PartitionRangeForParallelFor(
}
}

// After binary search over partition cost, interval
// (partition_cost_lower_bound, partition_cost_upper_bound] contains the only
// admissible partition cost value - partition_cost_upper_bound
//
// Partition for this cost value might have been already computed
if (partition_upper_bound.empty() == false) {
return partition_upper_bound;
}
// Partition for upper bound is not computed if and only if upper bound was
// never updated This is a simple case of a single interval containing all
// values, which we were not able to break into pieces
partition = {start, end};
return partition;
return partition_upper_bound;
}
} // namespace ceres::internal

Expand Down

0 comments on commit 2e2ff3e

Please sign in to comment.