Skip to content

Commit

Permalink
Fix ServiceTracker race (#558)
Browse files Browse the repository at this point in the history
* Fix ServiceTracker race

The race occurs when a customizer is still running and the ServiceTracker object is in the process of being destroyed.

Signed-off-by: The MathWorks, Inc. <jdicleme@mathworks.com>

* Fix trackingCount member initialization

Signed-off-by: The MathWorks, Inc. <jdicleme@mathworks.com>

* Fix race

Signed-off-by: The MathWorks, Inc. <jdicleme@mathworks.com>

* Fix broken build

Signed-off-by: The MathWorks, Inc. <jdicleme@mathworks.com>
  • Loading branch information
jeffdiclemente committed May 6, 2021
1 parent a5a6be0 commit 6337ceb
Show file tree
Hide file tree
Showing 18 changed files with 406 additions and 162 deletions.
4 changes: 2 additions & 2 deletions compendium/DeclarativeServices/src/SCRActivator.cpp
Expand Up @@ -38,7 +38,7 @@
#include "cppmicroservices/servicecomponent/runtime/dto/ComponentDescriptionDTO.hpp"
#include "cppmicroservices/servicecomponent/runtime/dto/ReferenceDTO.hpp"

#include "cppmicroservices/util/ScopeGuard.h"
#include "cppmicroservices/detail/ScopeGuard.h"

using cppmicroservices::logservice::SeverityLevel;
using cppmicroservices::service::component::ComponentConstants::SERVICE_COMPONENT;
Expand Down Expand Up @@ -79,7 +79,7 @@ void SCRActivator::Stop(cppmicroservices::BundleContext context)
{
try
{
cppmicroservices::util::ScopeGuard joinThreadPool{ [this]() {
cppmicroservices::detail::ScopeGuard joinThreadPool{ [this]() {
if (threadpool) {
try {
threadpool->join();
Expand Down
99 changes: 0 additions & 99 deletions compendium/DeclarativeServices/src/manager/ConcurrencyUtil.hpp
Expand Up @@ -93,105 +93,6 @@ class Guarded
}
};

/**
* A utility class similar to std::latch except it allows for incrementing the count as well.
* It is useful when the number of threads entering a block may vary dynamically. This class
* can be used to synchronize two blocks of code where the first block can be executed by
* 'n' number of threads simultaneously and the second block needs to wait for all the threads
* to exit the first block.
*
* Example code:
* function1()
* {
* if(latch.CountUp())
* {
* // do something critical here
* latch.CountDown();
* }
* }
* function2()
* {
* latch.Wait(); // this blocks the thread until all threads that have exited function1
* // do something here.
* }
*/
class CounterLatch
{
public:
CounterLatch() : count(0) {}
CounterLatch(const CounterLatch&) = delete;
CounterLatch(CounterLatch&&) = delete;
CounterLatch& operator=(const CounterLatch&) = delete;
CounterLatch& operator=(CounterLatch&&) = delete;
~CounterLatch() = default;

/**
* Increments the count of the latch, if current count is not negative
* \return \c true if the count was incremented, \c false otherwise
*/
bool CountUp()
{
std::lock_guard<std::mutex> lock{mtx};
if(count >= 0)
{
++count;
return true;
}
return false;
}

/**
* Decrements the count of the latch, releasing all waiting threads if the
* count reaches zero.
* If the current count is greater than zero then it is decremented. If
* the new count is zero then all waiting threads are notified.
* If the current count equals zero then nothing happens.
*/
void CountDown()
{
std::lock_guard<std::mutex> lock{mtx};
if(count > 0)
{
if(--count == 0)
{
// notify waiting threads
cond.notify_all();
}
}
}

/**
* Waits until the counter reaches 0. The value of the counter after this
* method returns is invalid (negative). This method is designed for a
* one-time use only.
*
* \throws std::runtime_error if the current count is negative.
*/
void Wait()
{
std::unique_lock<std::mutex> lock{mtx};
if(count < 0)
{
throw std::runtime_error("CounterLatch is in invalid state.");
}
cond.wait(lock, [&]() { return count == 0; });
count = std::numeric_limits<long>::min(); // makes the latch unusable for other threads
}

/**
* Returns the current count of the latch
*/
long GetCount()
{
std::unique_lock<std::mutex> lock{mtx};
return count;
}

private:
std::mutex mtx; ///< mutex to protect access to the counter
std::condition_variable cond; ///< used to notify the waiting thread
long count{0}; ///< latch counter
};
}
}

Expand Down
Expand Up @@ -25,21 +25,11 @@
#include "CCUnsatisfiedReferenceState.hpp"
#include "cppmicroservices/SharedLibraryException.h"

#include "cppmicroservices/detail/ScopeGuard.h"

namespace cppmicroservices {
namespace scrimpl {

class LatchScopeGuard
{
public:
LatchScopeGuard(std::function<void()> cleanupFcn)
: _cleanupFcn(std::move(cleanupFcn))
{}
~LatchScopeGuard() { _cleanupFcn(); }

private:
std::function<void()> _cleanupFcn;
};

CCActiveState::CCActiveState() = default;

std::shared_ptr<ComponentInstance> CCActiveState::Activate(
Expand All @@ -51,7 +41,7 @@ std::shared_ptr<ComponentInstance> CCActiveState::Activate(
auto logger = mgr.GetLogger();
if (latch.CountUp()) {
{
LatchScopeGuard sg([this, logger]() {
detail::ScopeGuard sg([this, logger]() {
// By using try/catch here, we ensure that this lambda function doesn't
// throw inside LatchScopeGuard's dtor.
try {
Expand Down
Expand Up @@ -26,6 +26,8 @@
#include "CCSatisfiedState.hpp"
#include "../ConcurrencyUtil.hpp"

#include "cppmicroservices/detail/CounterLatch.h"

using cppmicroservices::service::component::runtime::dto::ComponentState;

namespace cppmicroservices {
Expand Down Expand Up @@ -77,7 +79,7 @@ class CCActiveState final
latch.Wait();
}
private:
CounterLatch latch;
detail::CounterLatch latch;
};
}
}
Expand Down
1 change: 0 additions & 1 deletion compendium/DeclarativeServices/test/CMakeLists.txt
Expand Up @@ -63,7 +63,6 @@ set(_declarativeservices_tests
TestComponentManagerEnabledState.cpp
TestComponentManagerImpl.cpp
TestComponentRegistry.cpp
TestCounterLatch.cpp
TestMetadataParserFactory.cpp
TestMetadataParserImplV1.cpp
TestReferenceManagerImpl.cpp
Expand Down
2 changes: 2 additions & 0 deletions framework/include/CMakeLists.txt
Expand Up @@ -59,4 +59,6 @@ set(_public_headers
cppmicroservices/detail/BundleAbstractTracked.h
cppmicroservices/detail/BundleAbstractTracked.tpp
cppmicroservices/detail/BundleResourceBuffer.h
cppmicroservices/detail/ScopeGuard.h
cppmicroservices/detail/CounterLatch.h
)
Expand Up @@ -64,7 +64,7 @@ class BundleAbstractTracked
/**
* BundleAbstractTracked constructor.
*/
BundleAbstractTracked(BundleContext* bc);
BundleAbstractTracked(BundleContext bc);

virtual ~BundleAbstractTracked();

Expand Down Expand Up @@ -284,7 +284,7 @@ class BundleAbstractTracked
*/
std::atomic<int> trackingCount;

BundleContext* const bc;
BundleContext bc;

bool CustomizerAddingFinal(S item,
const std::shared_ptr<TrackedParamType>& custom);
Expand Down
28 changes: 14 additions & 14 deletions framework/include/cppmicroservices/detail/BundleAbstractTracked.tpp
Expand Up @@ -30,8 +30,8 @@ namespace cppmicroservices {
namespace detail {

template<class S, class TTT, class R>
BundleAbstractTracked<S,TTT,R>::BundleAbstractTracked(BundleContext* bc)
: closed(false), bc(bc)
BundleAbstractTracked<S,TTT,R>::BundleAbstractTracked(BundleContext bc)
: closed(false), trackingCount(0), bc(bc)
{
}

Expand All @@ -44,12 +44,12 @@ void BundleAbstractTracked<S,TTT,R>::SetInitial(const std::vector<S>& initiallis
{
std::copy(initiallist.begin(), initiallist.end(), std::back_inserter(initial));

if (bc->GetLogSink()->Enabled())
if (bc.GetLogSink()->Enabled())
{
for(typename std::list<S>::const_iterator item = initial.begin();
item != initial.end(); ++item)
{
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::setInitial: " << (*item);
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::setInitial: " << (*item);
}
}
}
Expand Down Expand Up @@ -78,20 +78,20 @@ void BundleAbstractTracked<S,TTT,R>::TrackInitial()
if (tracked[item])
{
/* if we are already tracking this item */
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::trackInitial[already tracked]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::trackInitial[already tracked]: " << item;
continue; /* skip this item */
}
if (std::find(adding.begin(), adding.end(), item) != adding.end())
{
/*
* if this item is already in the process of being added.
*/
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::trackInitial[already adding]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::trackInitial[already adding]: " << item;
continue; /* skip this item */
}
adding.push_back(item);
}
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::trackInitial: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::trackInitial: " << item;
TrackAdding(item, R());
/*
* Begin tracking it. We call trackAdding
Expand Down Expand Up @@ -123,14 +123,14 @@ void BundleAbstractTracked<S,TTT,R>::Track(S item, R related)
if (std::find(adding.begin(), adding.end(),item) != adding.end())
{
/* if this item is already in the process of being added. */
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::track[already adding]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::track[already adding]: " << item;
return;
}
adding.push_back(item); /* mark this item is being added */
}
else
{ /* we are currently tracking this item */
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::track[modified]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::track[modified]: " << item;
Modified(); /* increment modification count */
}
}
Expand Down Expand Up @@ -162,7 +162,7 @@ void BundleAbstractTracked<S,TTT,R>::Untrack(S item, R related)
{ /* if this item is already in the list
* of initial references to process
*/
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::untrack[removed from initial]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::untrack[removed from initial]: " << item;
return; /* we have removed it from the list and it will not be
* processed
*/
Expand All @@ -174,7 +174,7 @@ void BundleAbstractTracked<S,TTT,R>::Untrack(S item, R related)
{ /* if the item is in the process of
* being added
*/
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::untrack[being added]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::untrack[being added]: " << item;
return; /*
* in case the item is untracked while in the process of
* adding
Expand All @@ -192,7 +192,7 @@ void BundleAbstractTracked<S,TTT,R>::Untrack(S item, R related)
}
Modified(); /* increment modification count */
}
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::untrack[removed]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::untrack[removed]: " << item;
/* Call customizer outside of synchronized region */
CustomizerRemoved(item, related, object);
/*
Expand Down Expand Up @@ -280,7 +280,7 @@ bool BundleAbstractTracked<S,TTT,R>::CustomizerAddingFinal(S item, const std::sh
template<class S, class TTT, class R>
void BundleAbstractTracked<S,TTT,R>::TrackAdding(S item, R related)
{
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::trackAdding:" << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::trackAdding:" << item;
std::shared_ptr<TrackedParamType> object;
bool becameUntracked = false;
/* Call customizer outside of synchronized region */
Expand All @@ -304,7 +304,7 @@ void BundleAbstractTracked<S,TTT,R>::TrackAdding(S item, R related)
*/
if (becameUntracked && object)
{
DIAG_LOG(*bc->GetLogSink()) << "BundleAbstractTracked::trackAdding[removed]: " << item;
DIAG_LOG(*bc.GetLogSink()) << "BundleAbstractTracked::trackAdding[removed]: " << item;
/* Call customizer outside of synchronized region */
CustomizerRemoved(item, related, object);
/*
Expand Down

0 comments on commit 6337ceb

Please sign in to comment.