From bce380d87166530c354b0a376481fcc5e16bdd10 Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Wed, 14 Jul 2021 10:46:12 +0200 Subject: [PATCH] Implement shmem msg zero-copy --- fairmq/Message.h | 4 + fairmq/shmem/Common.h | 17 ++-- fairmq/shmem/Manager.h | 97 ++++++++++++++++------- fairmq/shmem/Message.h | 162 +++++++++++++++++++++++++------------- test/CMakeLists.txt | 2 +- test/message/_message.cxx | 148 ++++++++++++++++++++++++++++++++-- test/region/_region.cxx | 1 - 7 files changed, 332 insertions(+), 99 deletions(-) diff --git a/fairmq/Message.h b/fairmq/Message.h index 686314cf3..5f862e31a 100644 --- a/fairmq/Message.h +++ b/fairmq/Message.h @@ -47,6 +47,10 @@ struct Message TransportFactory* GetTransport() { return fTransport; } void SetTransport(TransportFactory* transport) { fTransport = transport; } + /// Copy the message buffer from another message + /// Transport may choose not to physically copy the buffer, but to share across the messages. + /// Modifying the buffer after a call to Copy() is undefined behaviour. + /// @param msg message to copy the buffer from. virtual void Copy(const Message& msg) = 0; virtual ~Message() = default; diff --git a/fairmq/shmem/Common.h b/fairmq/shmem/Common.h index e6668c799..b9fabe66a 100644 --- a/fairmq/shmem/Common.h +++ b/fairmq/shmem/Common.h @@ -146,9 +146,10 @@ struct MetaHeader { size_t fSize; size_t fHint; - uint16_t fRegionId; - uint16_t fSegmentId; boost::interprocess::managed_shared_memory::handle_t fHandle; + mutable boost::interprocess::managed_shared_memory::handle_t fShared; + uint16_t fRegionId; + mutable uint16_t fSegmentId; }; #ifdef FAIRMQ_DEBUG_MODE @@ -271,22 +272,22 @@ struct SegmentHandleFromAddress : public boost::static_visitor +struct SegmentAddressFromHandle : public boost::static_visitor { SegmentAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t _handle) : handle(_handle) {} template - void* operator()(S& s) const { return s.get_address_from_handle(handle); } + char* operator()(S& s) const { return reinterpret_cast(s.get_address_from_handle(handle)); } const boost::interprocess::managed_shared_memory::handle_t handle; }; -struct SegmentAllocate : public boost::static_visitor +struct SegmentAllocate : public boost::static_visitor { SegmentAllocate(const size_t _size) : size(_size) {} template - void* operator()(S& s) const { return s.allocate(size); } + char* operator()(S& s) const { return reinterpret_cast(s.allocate(size)); } const size_t size; }; @@ -322,12 +323,12 @@ struct SegmentBufferShrink : public boost::static_visitor struct SegmentDeallocate : public boost::static_visitor<> { - SegmentDeallocate(void* _ptr) : ptr(_ptr) {} + SegmentDeallocate(char* _ptr) : ptr(_ptr) {} template void operator()(S& s) const { return s.deallocate(ptr); } - void* ptr; + char* ptr; }; } // namespace fair::mq::shmem diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index 9cd5c34dc..5ace4c0f7 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -52,29 +52,77 @@ #include // getuid #include // getuid - #include // mlock namespace fair::mq::shmem { -struct ShmPtr +// ShmHeader stores user buffer alignment and the reference count in the following structure: +// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer] +// The alignment of Hdr depends on the alignment of std::atomic and is stored in the first entry +struct ShmHeader { - explicit ShmPtr(char* rPtr) - : realPtr(rPtr) - {} + struct Hdr + { + uint16_t userOffset; + std::atomic refCount; + }; + + static Hdr* HdrPtr(char* ptr) + { + // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer] + // ^ + return reinterpret_cast(ptr + sizeof(uint16_t) + *(reinterpret_cast(ptr))); + } - char* RealPtr() + static uint16_t HdrPartSize() // [HdrOffset(uint16_t)][Hdr alignment][Hdr] { - return realPtr; + // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer] + // <---------------------------------------> + return sizeof(uint16_t) + alignof(Hdr) + sizeof(Hdr); } - char* UserPtr() + static std::atomic& RefCountPtr(char* ptr) { - return realPtr + sizeof(uint16_t) + *(reinterpret_cast(realPtr)); + // get the ref count ptr from the Hdr + return HdrPtr(ptr)->refCount; } - char* realPtr; + static char* UserPtr(char* ptr) + { + // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer] + // ^ + return ptr + HdrPartSize() + HdrPtr(ptr)->userOffset; + } + + static uint16_t RefCount(char* ptr) { return RefCountPtr(ptr).load(); } + static uint16_t IncrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_add(1); } + static uint16_t DecrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_sub(1); } + + static size_t FullSize(size_t size, size_t alignment) + { + // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer] + // <---------------------------------------------------------------------------> + return HdrPartSize() + alignment + size; + } + + static void Construct(char* ptr, size_t alignment) + { + // place the Hdr in the aligned location, fill it and store its offset to HdrOffset + + // the address alignment should be at least 2 + assert(reinterpret_cast(ptr) % 2 == 0); + + // offset to the beginning of the Hdr. store it in the beginning + uint16_t hdrOffset = alignof(Hdr) - ((reinterpret_cast(ptr) + sizeof(uint16_t)) % alignof(Hdr)); + memcpy(ptr, &hdrOffset, sizeof(hdrOffset)); + + // offset to the beginning of the user buffer, store in Hdr together with the ref count + uint16_t userOffset = alignment - ((reinterpret_cast(ptr) + HdrPartSize()) % alignment); + new(ptr + sizeof(uint16_t) + hdrOffset) Hdr{ userOffset, std::atomic(1) }; + } + + static void Destruct(char* ptr) { RefCountPtr(ptr).~atomic(); } }; class Manager @@ -635,44 +683,35 @@ class Manager { return boost::apply_visitor(SegmentHandleFromAddress(ptr), fSegments.at(segmentId)); } - void* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const + char* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const { return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId)); } - ShmPtr Allocate(size_t size, size_t alignment = 0) + char* Allocate(size_t size, size_t alignment = 0) { alignment = std::max(alignment, alignof(std::max_align_t)); char* ptr = nullptr; - // [offset(uint16_t)][alignment][buffer] - size_t fullSize = sizeof(uint16_t) + alignment + size; - // tools::RateLimiter rateLimiter(20); + size_t fullSize = ShmHeader::FullSize(size, alignment); while (ptr == nullptr) { try { - // boost::interprocess::managed_shared_memory::size_type actualSize = size; - // char* hint = 0; // unused for boost::interprocess::allocate_new - // ptr = fSegments.at(fSegmentId).allocation_command(boost::interprocess::allocate_new, size, actualSize, hint); size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId)); if (fullSize > segmentSize) { throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")")); } - ptr = reinterpret_cast(boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId))); - assert(reinterpret_cast(ptr) % 2 == 0); - uint16_t offset = 0; - offset = alignment - ((reinterpret_cast(ptr) + sizeof(uint16_t)) % alignment); - std::memcpy(ptr, &offset, sizeof(offset)); + ptr = boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId)); + ShmHeader::Construct(ptr, alignment); } catch (boost::interprocess::bad_alloc& ba) { // LOG(warn) << "Shared memory full..."; if (ThrowingOnBadAlloc()) { throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId)))); } - // rateLimiter.maybe_sleep(); std::this_thread::sleep_for(std::chrono::milliseconds(50)); if (Interrupted()) { - return ShmPtr(ptr); + throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId)))); } else { continue; } @@ -684,18 +723,20 @@ class Manager (*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc); } (*fMsgDebug).at(fSegmentId).emplace( - static_cast(GetHandleFromAddress(ShmPtr(ptr).UserPtr(), fSegmentId)), + static_cast(GetHandleFromAddress(ShmHeader::UserPtr(ptr), fSegmentId)), MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count()) ); #endif } - return ShmPtr(ptr); + return ptr; } void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) { - boost::apply_visitor(SegmentDeallocate(GetAddressFromHandle(handle, segmentId)), fSegments.at(segmentId)); + char* ptr = GetAddressFromHandle(handle, segmentId); + ShmHeader::Destruct(ptr); + boost::apply_visitor(SegmentDeallocate(ptr), fSegments.at(segmentId)); #ifdef FAIRMQ_DEBUG_MODE boost::interprocess::scoped_lock lock(fShmMtx); DecrementShmMsgCounter(segmentId); diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index ded96cde1..0e674e005 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -38,7 +38,7 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fManager(manager) , fQueued(false) - , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()} , fRegionPtr(nullptr) , fLocalPtr(nullptr) { @@ -49,7 +49,7 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fManager(manager) , fQueued(false) - , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()} , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) @@ -61,7 +61,7 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fManager(manager) , fQueued(false) - , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()} , fRegionPtr(nullptr) , fLocalPtr(nullptr) { @@ -73,7 +73,7 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fManager(manager) , fQueued(false) - , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()} , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) @@ -86,7 +86,7 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fManager(manager) , fQueued(false) - , fMeta{0, 0, 0, fManager.GetSegmentId(), -1} + , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()} , fRegionPtr(nullptr) , fLocalPtr(nullptr) { @@ -105,7 +105,7 @@ class Message final : public fair::mq::Message : fair::mq::Message(factory) , fManager(manager) , fQueued(false) - , fMeta{size, reinterpret_cast(hint), static_cast(region.get())->fRegionId, fManager.GetSegmentId(), -1} + , fMeta{size, reinterpret_cast(hint), -1, -1, static_cast(region.get())->fRegionId, fManager.GetSegmentId()} , fRegionPtr(nullptr) , fLocalPtr(static_cast(data)) { @@ -187,8 +187,7 @@ class Message final : public fair::mq::Message if (fMeta.fRegionId == 0) { if (fMeta.fSize > 0) { fManager.GetSegment(fMeta.fSegmentId); - ShmPtr shmPtr(reinterpret_cast(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId))); - fLocalPtr = shmPtr.UserPtr(); + fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); } else { fLocalPtr = nullptr; } @@ -218,8 +217,8 @@ class Message final : public fair::mq::Message } else if (newSize <= fMeta.fSize) { try { try { - ShmPtr shmPtr(fManager.ShrinkInPlace(newSize, static_cast(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)), fMeta.fSegmentId)); - fLocalPtr = shmPtr.UserPtr(); + char* ptr = fManager.ShrinkInPlace(newSize, fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId), fMeta.fSegmentId); + fLocalPtr = ShmHeader::UserPtr(ptr); fMeta.fSize = newSize; return true; } catch (boost::interprocess::bad_alloc& e) { @@ -227,17 +226,12 @@ class Message final : public fair::mq::Message // unused size >= 1000000 bytes: reallocate fully // unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction if (fMeta.fSize - newSize >= 1000000) { - ShmPtr shmPtr = fManager.Allocate(newSize, fAlignment); - if (shmPtr.RealPtr()) { - char* userPtr = shmPtr.UserPtr(); - std::memcpy(userPtr, fLocalPtr, newSize); - fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); - fLocalPtr = userPtr; - fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId); - } else { - LOG(debug) << "could not set used size: " << e.what(); - return false; - } + char* ptr = fManager.Allocate(newSize, fAlignment); + char* userPtr = ShmHeader::UserPtr(ptr); + std::memcpy(userPtr, fLocalPtr, newSize); + fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); + fLocalPtr = userPtr; + fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId); } fMeta.fSize = newSize; return true; @@ -254,33 +248,65 @@ class Message final : public fair::mq::Message Transport GetType() const override { return fair::mq::Transport::SHM; } - void Copy(const fair::mq::Message& msg) override + uint16_t GetRefCount() const { if (fMeta.fHandle < 0) { - boost::interprocess::managed_shared_memory::handle_t otherHandle = static_cast(msg).fMeta.fHandle; - if (otherHandle) { - if (InitializeChunk(msg.GetSize())) { - std::memcpy(GetData(), msg.GetData(), msg.GetSize()); - } + return 1; + } + + if (fMeta.fRegionId == 0) { // managed segment + fManager.GetSegment(fMeta.fSegmentId); + return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + } else { // unmanaged region + if (fMeta.fShared < 0) { // UR msg is not yet shared + return 1; } else { - LOG(error) << "copy fail: source message not initialized!"; + fManager.GetSegment(fMeta.fSegmentId); + return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId)); } - } else { - LOG(error) << "copy fail: target message already initialized!"; } } - ~Message() override + void Copy(const fair::mq::Message& other) override { - try { + const Message& otherMsg = static_cast(other); + if (otherMsg.fMeta.fHandle < 0) { + // if the other message is not initialized, close this one too and return CloseMessage(); - } catch(SharedMemoryError& sme) { - LOG(error) << "error closing message: " << sme.what(); - } catch(boost::interprocess::lock_exception& le) { - LOG(error) << "error closing message: " << le.what(); + return; + } + + if (fMeta.fHandle >= 0) { + // if this msg is already initialized, close it first + CloseMessage(); + } + + if (otherMsg.fMeta.fRegionId == 0) { // managed segment + fMeta = otherMsg.fMeta; + fManager.GetSegment(fMeta.fSegmentId); + ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + } else { // unmanaged region + if (otherMsg.fMeta.fShared < 0) { // if UR msg is not yet shared + // TODO: minimize the size to 0 and don't create extra space for user buffer alignment + char* ptr = fManager.Allocate(2, 0); + // point the fShared in the unmanaged region message to the refCount holder + otherMsg.fMeta.fShared = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId); + // the message needs to be able to locate in which segment the refCount is stored + otherMsg.fMeta.fSegmentId = fMeta.fSegmentId; + // point this message to the same content as the unmanaged region message + fMeta = otherMsg.fMeta; + // increment the refCount + ShmHeader::IncrementRefCount(ptr); + } else { // if the UR msg is already shared + fMeta = otherMsg.fMeta; + fManager.GetSegment(fMeta.fSegmentId); + ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId)); + } } } + ~Message() override { CloseMessage(); } + private: Manager& fManager; bool fQueued; @@ -291,44 +317,70 @@ class Message final : public fair::mq::Message char* InitializeChunk(const size_t size, size_t alignment = 0) { - ShmPtr shmPtr = fManager.Allocate(size, alignment); - if (shmPtr.RealPtr()) { - fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId); - fMeta.fSize = size; - fLocalPtr = shmPtr.UserPtr(); + if (size == 0) { + fMeta.fSize = 0; + return fLocalPtr; } + char* ptr = fManager.Allocate(size, alignment); + fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId); + fMeta.fSize = size; + fLocalPtr = ShmHeader::UserPtr(ptr); return fLocalPtr; } void Deallocate() { if (fMeta.fHandle >= 0 && !fQueued) { - if (fMeta.fRegionId == 0) { + if (fMeta.fRegionId == 0) { // managed segment fManager.GetSegment(fMeta.fSegmentId); - fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); - fMeta.fHandle = -1; - } else { - if (!fRegionPtr) { - fRegionPtr = fManager.GetRegion(fMeta.fRegionId); + uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + if (refCount == 1) { + fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); } - - if (fRegionPtr) { - fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint}); + } else { // unmanaged region + if (fMeta.fShared >= 0) { + // make sure segment is initialized in this transport + fManager.GetSegment(fMeta.fSegmentId); + // release unmanaged region block if ref count is one + uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId)); + if (refCount == 1) { + fManager.Deallocate(fMeta.fShared, fMeta.fSegmentId); + ReleaseUnmanagedRegionBlock(); + } } else { - LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack"; + ReleaseUnmanagedRegionBlock(); } } } + fMeta.fHandle = -1; fLocalPtr = nullptr; fMeta.fSize = 0; } - void CloseMessage() + void ReleaseUnmanagedRegionBlock() { - Deallocate(); - fAlignment = 0; + if (!fRegionPtr) { + fRegionPtr = fManager.GetRegion(fMeta.fRegionId); + } + + if (fRegionPtr) { + fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint}); + } else { + LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack"; + } + } - fManager.DecrementMsgCounter(); + void CloseMessage() + { + try { + Deallocate(); + fAlignment = 0; + fManager.DecrementMsgCounter(); + } catch(SharedMemoryError& sme) { + LOG(error) << "error closing message: " << sme.what(); + } catch(boost::interprocess::lock_exception& le) { + LOG(error) << "error closing message: " << le.what(); + } } }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d22dbb42e..530e48472 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -89,7 +89,7 @@ add_testsuite(Message ${CMAKE_CURRENT_BINARY_DIR}/runner.cxx message/_message.cxx - LINKS FairMQ + LINKS FairMQ PicoSHA2 INCLUDES ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/message ${CMAKE_CURRENT_BINARY_DIR} diff --git a/test/message/_message.cxx b/test/message/_message.cxx index 672cba29c..08684a379 100644 --- a/test/message/_message.cxx +++ b/test/message/_message.cxx @@ -6,19 +6,23 @@ * copied verbatim in the file "LICENSE" * ********************************************************************************/ -#include -#include -#include #include #include #include -#include +#include #include #include +#include +#include + #include + +#include +#include +#include #include -#include #include +#include #include namespace @@ -190,7 +194,6 @@ auto EmptyMessage(string const& transport, string const& _address) -> void push.Bind(address); pull.Connect(address); - { auto outMsg(push.NewMessage()); ASSERT_EQ(outMsg->GetData(), nullptr); @@ -227,6 +230,129 @@ auto EmptyMessage(string const& transport, string const& _address) -> void } } +// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed. +// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports. +auto ZeroCopy() -> void +{ + ProgOptions config; + config.SetProperty("session", tools::Uuid()); + auto factory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config)); + + unique_ptr str(make_unique("asdf")); + const size_t size = 2; + MessagePtr original(factory->CreateMessage(size)); + memcpy(original->GetData(), "AB", size); + { + MessagePtr copy(factory->CreateMessage()); + copy->Copy(*original); + EXPECT_EQ(original->GetSize(), copy->GetSize()); + EXPECT_EQ(original->GetData(), copy->GetData()); + EXPECT_EQ(static_cast(*original).GetRefCount(), 2); + EXPECT_EQ(static_cast(*copy).GetRefCount(), 2); + + // buffer must be still intact + ASSERT_EQ(AsStringView(*original)[0], 'A'); + ASSERT_EQ(AsStringView(*original)[1], 'B'); + ASSERT_EQ(AsStringView(*copy)[0], 'A'); + ASSERT_EQ(AsStringView(*copy)[1], 'B'); + } + EXPECT_EQ(static_cast(*original).GetRefCount(), 1); +} + +// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed. +// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports. +auto ZeroCopyFromUnmanaged(string const& address) -> void +{ + ProgOptions config1; + ProgOptions config2; + string session(tools::Uuid()); + config1.SetProperty("session", session); + config2.SetProperty("session", session); + // ref counts should be accessible accross different segments + config2.SetProperty("shm-segment-id", 2); + auto factory1(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config1)); + auto factory2(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config2)); + + const size_t msgSize{100}; + const size_t regionSize{1000000}; + tools::Semaphore blocker; + + auto region = factory1->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) { + blocker.Signal(); + }); + + { + FairMQChannel push("Push", "push", factory1); + FairMQChannel pull("Pull", "pull", factory2); + + push.Bind(address); + pull.Connect(address); + + const size_t offset = 100; + auto msg1(push.NewMessage(region, static_cast(region->GetData()), msgSize, nullptr)); + auto msg2(push.NewMessage(region, static_cast(region->GetData()) + offset, msgSize, nullptr)); + const size_t contentSize = 2; + memcpy(msg1->GetData(), "AB", contentSize); + memcpy(msg2->GetData(), "CD", contentSize); + EXPECT_EQ(static_cast(*msg1).GetRefCount(), 1); + + { + auto copyFromOriginal(push.NewMessage()); + copyFromOriginal->Copy(*msg1); + EXPECT_EQ(static_cast(*msg1).GetRefCount(), 2); + EXPECT_EQ(static_cast(*msg1).GetRefCount(), static_cast(*copyFromOriginal).GetRefCount()); + { + auto copyFromCopy(push.NewMessage()); + copyFromCopy->Copy(*copyFromOriginal); + EXPECT_EQ(static_cast(*msg1).GetRefCount(), 3); + EXPECT_EQ(static_cast(*msg1).GetRefCount(), static_cast(*copyFromCopy).GetRefCount()); + + EXPECT_EQ(msg1->GetSize(), copyFromOriginal->GetSize()); + EXPECT_EQ(msg1->GetData(), copyFromOriginal->GetData()); + EXPECT_EQ(msg1->GetSize(), copyFromCopy->GetSize()); + EXPECT_EQ(msg1->GetData(), copyFromCopy->GetData()); + EXPECT_EQ(copyFromOriginal->GetSize(), copyFromCopy->GetSize()); + EXPECT_EQ(copyFromOriginal->GetData(), copyFromCopy->GetData()); + + // messing with the ref count should not have affected the user buffer + ASSERT_EQ(AsStringView(*msg1)[0], 'A'); + ASSERT_EQ(AsStringView(*msg1)[1], 'B'); + + push.Send(copyFromCopy); + push.Send(msg2); + + auto incomingCopiedMsg(pull.NewMessage()); + auto incomingOriginalMsg(pull.NewMessage()); + pull.Receive(incomingCopiedMsg); + pull.Receive(incomingOriginalMsg); + + EXPECT_EQ(static_cast(*incomingCopiedMsg).GetRefCount(), 3); + EXPECT_EQ(static_cast(*incomingOriginalMsg).GetRefCount(), 1); + + ASSERT_EQ(AsStringView(*incomingCopiedMsg)[0], 'A'); + ASSERT_EQ(AsStringView(*incomingCopiedMsg)[1], 'B'); + + { + // copying on a different segment should work + auto copyFromIncoming(pull.NewMessage()); + copyFromIncoming->Copy(*incomingOriginalMsg); + EXPECT_EQ(static_cast(*copyFromIncoming).GetRefCount(), 2); + + ASSERT_EQ(AsStringView(*incomingOriginalMsg)[0], 'C'); + ASSERT_EQ(AsStringView(*incomingOriginalMsg)[1], 'D'); + } + + EXPECT_EQ(static_cast(*incomingOriginalMsg).GetRefCount(), 1); + } + EXPECT_EQ(static_cast(*msg1).GetRefCount(), 2); + } + EXPECT_EQ(static_cast(*msg1).GetRefCount(), 1); + } + + blocker.Wait(); + blocker.Wait(); +} + TEST(Resize, zeromq) // NOLINT { RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize"); @@ -267,4 +393,14 @@ TEST(EmptyMessage, shmem) // NOLINT EmptyMessage("shmem", "ipc://test_empty_message"); } +TEST(ZeroCopy, shmem) // NOLINT +{ + ZeroCopy(); +} + +TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT +{ + ZeroCopyFromUnmanaged("ipc://test_zerocopy_unmanaged"); +} + } // namespace diff --git a/test/region/_region.cxx b/test/region/_region.cxx index d31de36da..b2a63ee36 100644 --- a/test/region/_region.cxx +++ b/test/region/_region.cxx @@ -199,7 +199,6 @@ void RegionCallbacks(const string& transport, const string& _address) }); ptr2 = region2->GetData(); - { FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get())); FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));