diff --git a/libs/core/synchronization/include/hpx/synchronization/range_lock.hpp b/libs/core/synchronization/include/hpx/synchronization/range_lock.hpp index c09b536fd58..34b0af7c4d4 100644 --- a/libs/core/synchronization/include/hpx/synchronization/range_lock.hpp +++ b/libs/core/synchronization/include/hpx/synchronization/range_lock.hpp @@ -1,104 +1,109 @@ #pragma once -#include -#include -#include -#include #include -#include -#include + +#include +#include namespace hpx::synchronization { + using range_lock = + hpx::synchronization::detail::RangeLock; +} - template typename Guard> - class RangeLock - { - template - using MapTy = boost::container::flat_map; +// Lock guards for range_lock +namespace hpx::synchronization { - Lock mtx; - std::size_t counter = 0; - MapTy> rangeMap; - MapTy> waiting; + template + class range_guard + { + std::reference_wrapper lockRef; + std::size_t lockId = 0; public: - std::size_t lock(std::size_t begin, std::size_t end); - std::size_t try_lock(std::size_t begin, std::size_t end); - void unlock(std::size_t lockId); + range_guard(RangeLock& lock, std::size_t begin, std::size_t end) + : lockRef(lock) + { + lockId = lockRef.get().lock(begin, end); + } + ~range_guard() + { + lockRef.get().unlock(lockId); + } }; - template class Guard> - std::size_t RangeLock::lock(std::size_t begin, std::size_t end) +} // namespace hpx::synchronization + +namespace hpx::synchronization { + + template + class range_unique_lock { + std::reference_wrapper lockRef; std::size_t lockId = 0; - bool localFlag = false; - std::size_t blockIdx; - std::shared_ptr waitingFlag; + public: + range_unique_lock(RangeLock& lock, std::size_t begin, std::size_t end) + : lockRef(lock) + { + lockId = lockRef.get().lock(begin, end); + } - while (lockId == 0) + ~range_unique_lock() { - { - const Guard lock_guard(mtx); - for (auto const& it : rangeMap) - { - std::size_t b = it.second.first; - std::size_t e = it.second.second; - - if (!(e < begin) & !(end < b)) - { - blockIdx = it.first; - localFlag = true; - waitingFlag = waiting[blockIdx]; - break; - } - } - if (localFlag == false) - { - ++counter; - rangeMap[counter] = {begin, end}; - waiting[counter] = std::shared_ptr( - new std::atomic_bool(false)); - return counter; - } - localFlag = false; - } - while (waitingFlag->load() == false) - { - } + lockRef.get().unlock(lockId); } - return lockId; - } - template class Guard> - void RangeLock::unlock(std::size_t lockId) - { - const Guard lock_guard(mtx); + void operator=(range_unique_lock&& lock) + { + lockRef.get().unlock(lockId); + lockRef = lock.lockRef; + lockId = lock.lockRef.get().lock(); + } + + void lock(std::size_t begin, std::size_t end) + { + lockId = lockRef.get().lock(begin, end); + } - rangeMap.erase(lockId); + void try_lock(std::size_t begin, std::size_t end) + { + lockId = lockRef.get().try_lock(begin, end); + } - waiting[lockId]->store(true); + void unlock() + { + lockRef.get().unlock(lockId); + lockId = 0; + } - waiting.erase(lockId); - return; - } + void swap(std::unique_lock& uLock) + { + std::swap(lockRef, uLock.lockRef); + std::swap(lockId, uLock.lockId); + } - template class Guard> - std::size_t RangeLock::try_lock( - std::size_t begin, std::size_t end) - { - const Guard lock_guard(mtx); - for (auto const& it : rangeMap) + RangeLock* release() { - std::size_t b = it.second.first; - std::size_t e = it.second.second; + RangeLock* mtx = lockRef.get(); + lockRef = nullptr; + lockId = 0; + return mtx; + } - if (!(e < begin) && !(end < b)) - { - return 0; - } + operator bool() const + { + return lockId != 0; } - rangeMap[++counter] = {begin, end}; - return counter; - } -} // namespace hpx::synchronization \ No newline at end of file + + bool owns_lock() const + { + return lockId != 0; + } + + RangeLock* mutex() const + { + return lockRef.get(); + } + }; + +} // namespace hpx::synchronization diff --git a/libs/core/synchronization/tests/unit/range_lock.cpp b/libs/core/synchronization/tests/unit/range_lock.cpp index a83162023f5..7c5cf24959d 100644 --- a/libs/core/synchronization/tests/unit/range_lock.cpp +++ b/libs/core/synchronization/tests/unit/range_lock.cpp @@ -4,7 +4,25 @@ int main() { - hpx::synchronization::RangeLock rl; - std::size_t x = rl.lock(0, 10); - rl.unlock(x); + { + hpx::synchronization::range_lock rl; + std::size_t x = rl.lock(0, 10); + rl.unlock(x); + return 0; + } + + { + hpx::synchronization::range_lock rl; + + hpx::synchronization::range_guard rg( + rl, 0, 10); + } + + { + hpx::synchronization::range_lock rl; + + hpx::synchronization::range_unique_lock< + hpx::synchronization::range_lock> + rg(rl, 0, 10); + } }