Skip to content

Commit

Permalink
Keep track of DeathMonitor cookies
Browse files Browse the repository at this point in the history
This change keeps track of the objects that the cookies points to so the
serviceDied callback knows when it can use the cookie.

Test: atest neuralnetworks_utils_hal_aidl_test
Tets: atest NeuralNetworksTest_static
Bug: 319210610
(cherry picked from https://googleplex-android-review.googlesource.com/q/commit:def7a3cf59fa17ba7faa9af14a24f4161bc276bd)
(cherry picked from https://googleplex-android-review.googlesource.com/q/commit:49859a3b5542270363efe42a56b9145142bbfa60)
Merged-In: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794
Change-Id: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794
  • Loading branch information
devinmoore-goog authored and Android Build Coastguard Worker committed Jun 6, 2024
1 parent d69eb65 commit a987ebb
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 10 deletions.
11 changes: 11 additions & 0 deletions neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,27 @@ class IProtectedCallback {
// Thread safe class
class DeathMonitor final {
public:
explicit DeathMonitor(uintptr_t cookieKey) : kCookieKey(cookieKey) {}

static void serviceDied(void* cookie);
void serviceDied();
// Precondition: `killable` must be non-null.
void add(IProtectedCallback* killable) const;
// Precondition: `killable` must be non-null.
void remove(IProtectedCallback* killable) const;

uintptr_t getCookieKey() const { return kCookieKey; }

~DeathMonitor();
DeathMonitor(const DeathMonitor&) = delete;
DeathMonitor(DeathMonitor&&) noexcept = delete;
DeathMonitor& operator=(const DeathMonitor&) = delete;
DeathMonitor& operator=(DeathMonitor&&) noexcept = delete;

private:
mutable std::mutex mMutex;
mutable std::vector<IProtectedCallback*> mObjects GUARDED_BY(mMutex);
const uintptr_t kCookieKey;
};

class DeathHandler final {
Expand Down
56 changes: 49 additions & 7 deletions neuralnetworks/aidl/utils/src/ProtectCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <vector>
Expand All @@ -33,15 +34,41 @@

namespace aidl::android::hardware::neuralnetworks::utils {

namespace {

// Only dereference the cookie if it's valid (if it's in this set)
// Only used with ndk
std::mutex sCookiesMutex;
uintptr_t sCookieKeyCounter GUARDED_BY(sCookiesMutex) = 0;
std::map<uintptr_t, std::weak_ptr<DeathMonitor>> sCookies GUARDED_BY(sCookiesMutex);

} // namespace

void DeathMonitor::serviceDied() {
std::lock_guard guard(mMutex);
std::for_each(mObjects.begin(), mObjects.end(),
[](IProtectedCallback* killable) { killable->notifyAsDeadObject(); });
}

void DeathMonitor::serviceDied(void* cookie) {
auto deathMonitor = static_cast<DeathMonitor*>(cookie);
deathMonitor->serviceDied();
std::shared_ptr<DeathMonitor> monitor;
{
std::lock_guard<std::mutex> guard(sCookiesMutex);
if (auto it = sCookies.find(reinterpret_cast<uintptr_t>(cookie)); it != sCookies.end()) {
monitor = it->second.lock();
sCookies.erase(it);
} else {
LOG(INFO)
<< "Service died, but cookie is no longer valid so there is nothing to notify.";
return;
}
}
if (monitor) {
LOG(INFO) << "Notifying DeathMonitor from serviceDied.";
monitor->serviceDied();
} else {
LOG(INFO) << "Tried to notify DeathMonitor from serviceDied but could not promote.";
}
}

void DeathMonitor::add(IProtectedCallback* killable) const {
Expand All @@ -57,21 +84,35 @@ void DeathMonitor::remove(IProtectedCallback* killable) const {
mObjects.erase(removedIter);
}

DeathMonitor::~DeathMonitor() {
// lock must be taken so object is not used in OnBinderDied"
std::lock_guard<std::mutex> guard(sCookiesMutex);
sCookies.erase(kCookieKey);
}

nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInterface> object) {
if (object == nullptr) {
return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
<< "utils::DeathHandler::create must have non-null object";
}
auto deathMonitor = std::make_shared<DeathMonitor>();

std::shared_ptr<DeathMonitor> deathMonitor;
{
std::lock_guard<std::mutex> guard(sCookiesMutex);
deathMonitor = std::make_shared<DeathMonitor>(sCookieKeyCounter++);
sCookies[deathMonitor->getCookieKey()] = deathMonitor;
}

auto deathRecipient = ndk::ScopedAIBinder_DeathRecipient(
AIBinder_DeathRecipient_new(DeathMonitor::serviceDied));

// If passed a local binder, AIBinder_linkToDeath will do nothing and return
// STATUS_INVALID_OPERATION. We ignore this case because we only use local binders in tests
// where this is not an error.
if (object->isRemote()) {
const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath(
object->asBinder().get(), deathRecipient.get(), deathMonitor.get()));
const auto ret = ndk::ScopedAStatus::fromStatus(
AIBinder_linkToDeath(object->asBinder().get(), deathRecipient.get(),
reinterpret_cast<void*>(deathMonitor->getCookieKey())));
HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed";
}

Expand All @@ -91,8 +132,9 @@ DeathHandler::DeathHandler(std::shared_ptr<ndk::ICInterface> object,

DeathHandler::~DeathHandler() {
if (kObject != nullptr && kDeathRecipient.get() != nullptr && kDeathMonitor != nullptr) {
const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath(
kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get()));
const auto ret = ndk::ScopedAStatus::fromStatus(
AIBinder_unlinkToDeath(kObject->asBinder().get(), kDeathRecipient.get(),
reinterpret_cast<void*>(kDeathMonitor->getCookieKey())));
const auto maybeSuccess = handleTransportError(ret);
if (!maybeSuccess.ok()) {
LOG(ERROR) << maybeSuccess.error().message;
Expand Down
9 changes: 6 additions & 3 deletions neuralnetworks/aidl/utils/test/DeviceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,8 @@ TEST_P(DeviceTest, prepareModelAsyncCrash) {
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto ret = [&device]() {
DeathMonitor::serviceDied(device->getDeathMonitor());
DeathMonitor::serviceDied(
reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));
return ndk::ScopedAStatus::ok();
};
EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
Expand Down Expand Up @@ -846,7 +847,8 @@ TEST_P(DeviceTest, prepareModelWithConfigAsyncCrash) {
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto ret = [&device]() {
DeathMonitor::serviceDied(device->getDeathMonitor());
DeathMonitor::serviceDied(
reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));
return ndk::ScopedAStatus::ok();
};
EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
Expand Down Expand Up @@ -970,7 +972,8 @@ TEST_P(DeviceTest, prepareModelFromCacheAsyncCrash) {
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice, kVersion).value();
const auto ret = [&device]() {
DeathMonitor::serviceDied(device->getDeathMonitor());
DeathMonitor::serviceDied(
reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));
return ndk::ScopedAStatus::ok();
};
EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))
Expand Down

0 comments on commit a987ebb

Please sign in to comment.