From 36243c164a75a1e7ba3c5eecf5f5ec995836eda1 Mon Sep 17 00:00:00 2001 From: Dmitriy Korchemkin Date: Thu, 19 Jan 2023 14:33:59 +0300 Subject: [PATCH] Fixes #6 Change-Id: I6a60915bf56c14cff2ed8e989b44a0d6a25cf8ff --- internal/ceres/parallel_for.h | 28 ++-- internal/ceres/parallel_for_synchronization.h | 2 +- internal/ceres/parallel_invoke.h | 128 ++++++++---------- .../ceres/partition_range_for_parallel_for.h | 22 +-- 4 files changed, 76 insertions(+), 104 deletions(-) diff --git a/internal/ceres/parallel_for.h b/internal/ceres/parallel_for.h index 0234835c0f..b4ce68c3aa 100644 --- a/internal/ceres/parallel_for.h +++ b/internal/ceres/parallel_for.h @@ -70,23 +70,20 @@ CERES_NO_EXPORT int MaxNumThreadsAvailable(); // When distributing workload between threads, it is assumed that each loop // iteration takes approximately equal time to complete. template -void ParallelFor(ContextImpl* context, - int start, - int end, - int num_threads, - const F& function) { +void ParallelFor( + ContextImpl* context, int start, int end, int num_threads, F&& function) { CHECK_GT(num_threads, 0); if (start >= end) { return; } if (num_threads == 1 || end - start == 1) { - InvokeOnSegment(0, std::make_tuple(start, end), function); + InvokeOnSegment(0, std::make_tuple(start, end), function); return; } CHECK(context != nullptr); - ParallelInvoke(context, start, end, num_threads, function); + ParallelInvoke(context, start, end, num_threads, function); } // Execute function for every element in the range [start, end) with at most @@ -99,7 +96,7 @@ void ParallelFor(ContextImpl* context, int start, int end, int num_threads, - const F& function, + F&& function, const std::vector& partitions) { CHECK_GT(num_threads, 0); if (start >= end) { @@ -119,11 +116,16 @@ void ParallelFor(ContextImpl* context, num_threads, [&function, &partitions](int thread_id, std::tuple partition_ids) { + // partition_ids is a range of partition indices const auto [partition_start, partition_end] = partition_ids; + // Execution over several adjacent segments is equivalent + // to execution over union of those segments (which is also a + // contiguous segment) const int range_start = partitions[partition_start]; const int range_end = partitions[partition_end]; + // Range of original loop indices const auto range = std::make_tuple(range_start, range_end); - InvokeOnSegment(thread_id, range, function); + InvokeOnSegment(thread_id, range, function); }); } @@ -148,9 +150,9 @@ void ParallelFor(ContextImpl* context, int start, int end, int num_threads, - const F& function, + F&& function, const CumulativeCostData* cumulative_cost_data, - const CumulativeCostFun& cumulative_cost_fun) { + CumulativeCostFun&& cumulative_cost_fun) { CHECK_GT(num_threads, 0); if (start >= end) { return; @@ -161,9 +163,9 @@ void ParallelFor(ContextImpl* context, } // Creating several partitions allows us to tolerate imperfections of // partitioning and user-supplied iteration costs up to a certain extent - const int kNumPartitionsPerThread = 4; + constexpr int kNumPartitionsPerThread = 4; const int kMaxPartitions = num_threads * kNumPartitionsPerThread; - const std::vector partitions = PartitionRangeForParallelFor( + const auto& partitions = PartitionRangeForParallelFor( start, end, kMaxPartitions, cumulative_cost_data, cumulative_cost_fun); CHECK_GT(partitions.size(), 1); ParallelFor(context, start, end, num_threads, function, partitions); diff --git a/internal/ceres/parallel_for_synchronization.h b/internal/ceres/parallel_for_synchronization.h index 9fadc54f28..4beafd71e0 100644 --- a/internal/ceres/parallel_for_synchronization.h +++ b/internal/ceres/parallel_for_synchronization.h @@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2022 Google Inc. All rights reserved. +// Copyright 2023 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without diff --git a/internal/ceres/parallel_invoke.h b/internal/ceres/parallel_invoke.h index 68390f1308..7f35173dd4 100644 --- a/internal/ceres/parallel_invoke.h +++ b/internal/ceres/parallel_invoke.h @@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2022 Google Inc. All rights reserved. +// Copyright 2023 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without @@ -40,81 +40,62 @@ namespace ceres::internal { -// Get arguments of callable as a tuple -template -std::tuple...> args_of(void (F::*)(Args...) const); - template -using args_of_t = decltype(args_of(&F::operator())); - -// Parallelizable functions might require passing thread_id as the first -// argument. This class supplies thread_id argument to functions that -// support it and ignores it otherwise. -template -struct InvokeImpl; - -// For parallel for iterations of type [](int i) -> void -template -struct InvokeImpl> { - static void InvokeOnSegment(int thread_id, - std::tuple range, - const F& function) { - (void)thread_id; - auto [start, end] = range; - for (int i = start; i < end; ++i) { - function(i); - } - } +struct InvokeOnSegmentTraits { + // Get arguments of callable as a tuple + template + std::tuple...> args_of(void (F::*)(Args...) const); + + template + using args_of_t = decltype(args_of(&F::operator())); + + using Args = args_of_t>; + static constexpr int NumArgs = std::tuple_size_v; + using FirstArg = typename std::tuple_element<0, Args>::type; + using LastArg = typename std::tuple_element::type; + + static constexpr bool FirstArgIsInt = std::is_same_v; + static constexpr bool LastArgIsInt = std::is_same_v; + + // Functions having at least 2 arguments with the first of them being int + // will be passed thread_id as the first argument. + static constexpr bool PassThreadId = NumArgs > 1 && FirstArgIsInt; + // Loop over indexes corresponding to segment is added for functions with 1 or + // 2 int arguments + static constexpr bool AddLoop = LastArgIsInt && FirstArgIsInt && NumArgs <= 2; }; -// For parallel for iterations of type [](int thread_id, int i) -> void -template -struct InvokeImpl> { - static void InvokeOnSegment(int thread_id, - std::tuple range, - const F& function) { - auto [start, end] = range; - for (int i = start; i < end; ++i) { - function(thread_id, i); - } - } -}; +// InvokeWithThreadId handles passing thread_id to the function +template ::PassThreadId, bool> = true> +void InvokeWithThreadId(int thread_id, F&& function, Args&&... args) { + function(thread_id, std::forward(args)...); +} -// For parallel for iterations of type [](tuple range) -> void -template -struct InvokeImpl>> { - static void InvokeOnSegment(int thread_id, - std::tuple range, - const F& function) { - (void)thread_id; - function(range); - } -}; +template < + typename F, + typename... Args, + std::enable_if_t::PassThreadId, bool> = true> +void InvokeWithThreadId(int /* thread_id */, F&& function, Args&&... args) { + function(std::forward(args)...); +} -// For parallel for iterations of type [](int thread_id, tuple range) -// -> void -template -struct InvokeImpl>> { - static void InvokeOnSegment(int thread_id, - std::tuple range, - const F& function) { - function(thread_id, range); +// InvokeOnSegment either runs a loop over segment indices or passes it to the +// function +template ::AddLoop, bool> = true> +void InvokeOnSegment(int thread_id, std::tuple range, F&& function) { + auto [start, end] = range; + for (int i = start; i != end; ++i) { + InvokeWithThreadId(thread_id, function, i); } -}; +} -// Invoke function on indices from contiguous range according to function -// signature. The following signatures are supported: -// - Functions processing single index per call: -// - [](int index) -> void -// - [](int thread_id, int index) -> void -// - Functions processing contiguous range [start, end) of indices per call: -// - [](std::tuple range) -> void -// Function arguments might have reference type and const qualifier -template -void InvokeOnSegment(int thread_id, - std::tuple range, - const F& function) { - InvokeImpl>::InvokeOnSegment(thread_id, range, function); +template ::AddLoop, bool> = true> +void InvokeOnSegment(int thread_id, std::tuple range, F&& function) { + InvokeWithThreadId(thread_id, function, range); } // This implementation uses a fixed size max worker pool with a shared task @@ -136,11 +117,8 @@ void InvokeOnSegment(int thread_id, // A performance analysis has shown this implementation is on par with OpenMP // and TBB. template -void ParallelInvoke(ContextImpl* context, - int start, - int end, - int num_threads, - const F& function) { +void ParallelInvoke( + ContextImpl* context, int start, int end, int num_threads, F&& function) { CHECK(context != nullptr); // Maximal number of work items scheduled for a single thread @@ -214,7 +192,7 @@ void ParallelInvoke(ContextImpl* context, (block_id < num_base_p1_sized_blocks ? 1 : 0); // Perform each task in current block const auto range = std::make_tuple(curr_start, curr_end); - InvokeOnSegment(thread_id, range, function); + InvokeOnSegment(thread_id, range, function); } shared_state->block_until_finished.Finished(num_jobs_finished); }; diff --git a/internal/ceres/partition_range_for_parallel_for.h b/internal/ceres/partition_range_for_parallel_for.h index 8d81fa2fd6..d1c0029503 100644 --- a/internal/ceres/partition_range_for_parallel_for.h +++ b/internal/ceres/partition_range_for_parallel_for.h @@ -1,5 +1,5 @@ // Ceres Solver - A fast non-linear least squares minimizer -// Copyright 2022 Google Inc. All rights reserved. +// Copyright 2023 Google Inc. All rights reserved. // http://ceres-solver.org/ // // Redistribution and use in source and binary forms, with or without @@ -115,7 +115,11 @@ std::vector PartitionRangeForParallelFor( // Upper bound is inclusive int partition_cost_upper_bound = total_cost; - std::vector partition, partition_upper_bound; + std::vector partition; + // Range partition corresponding to the latest evaluated upper bound. + // A single segment covering the whole input interval [start, end) corresponds + // to minimal maximal partition cost of total_cost. + std::vector partition_upper_bound = {start, end}; // Binary search over partition cost, returning the lowest admissible cost while (partition_cost_upper_bound - partition_cost_lower_bound > 1) { partition.reserve(max_num_partitions + 1); @@ -138,19 +142,7 @@ std::vector PartitionRangeForParallelFor( } } - // After binary search over partition cost, interval - // (partition_cost_lower_bound, partition_cost_upper_bound] contains the only - // admissible partition cost value - partition_cost_upper_bound - // - // Partition for this cost value might have been already computed - if (partition_upper_bound.empty() == false) { - return partition_upper_bound; - } - // Partition for upper bound is not computed if and only if upper bound was - // never updated This is a simple case of a single interval containing all - // values, which we were not able to break into pieces - partition = {start, end}; - return partition; + return partition_upper_bound; } } // namespace ceres::internal