diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index 4eb86cfb1..9cd5c34dc 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -33,8 +33,11 @@ #include #include +#include // max #include +#include // max_align_t #include // getenv +#include // memcpy #include // make_unique #include #include @@ -55,6 +58,25 @@ namespace fair::mq::shmem { +struct ShmPtr +{ + explicit ShmPtr(char* rPtr) + : realPtr(rPtr) + {} + + char* RealPtr() + { + return realPtr; + } + + char* UserPtr() + { + return realPtr + sizeof(uint16_t) + *(reinterpret_cast(realPtr)); + } + + char* realPtr; +}; + class Manager { public: @@ -618,9 +640,13 @@ class Manager return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId)); } - char* Allocate(const size_t size, size_t alignment = 0) + ShmPtr Allocate(size_t size, size_t alignment = 0) { + alignment = std::max(alignment, alignof(std::max_align_t)); + char* ptr = nullptr; + // [offset(uint16_t)][alignment][buffer] + size_t fullSize = sizeof(uint16_t) + alignment + size; // tools::RateLimiter rateLimiter(20); while (ptr == nullptr) { @@ -629,14 +655,15 @@ class Manager // char* hint = 0; // unused for boost::interprocess::allocate_new // ptr = fSegments.at(fSegmentId).allocation_command(boost::interprocess::allocate_new, size, actualSize, hint); size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId)); - if (size > segmentSize) { - throw MessageBadAlloc(tools::ToString("Requested message size (", size, ") exceeds segment size (", segmentSize, ")")); - } - if (alignment == 0) { - ptr = reinterpret_cast(boost::apply_visitor(SegmentAllocate{size}, fSegments.at(fSegmentId))); - } else { - ptr = reinterpret_cast(boost::apply_visitor(SegmentAllocateAligned{size, alignment}, fSegments.at(fSegmentId))); + if (fullSize > segmentSize) { + throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")")); } + + ptr = reinterpret_cast(boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId))); + assert(reinterpret_cast(ptr) % 2 == 0); + uint16_t offset = 0; + offset = alignment - ((reinterpret_cast(ptr) + sizeof(uint16_t)) % alignment); + std::memcpy(ptr, &offset, sizeof(offset)); } catch (boost::interprocess::bad_alloc& ba) { // LOG(warn) << "Shared memory full..."; if (ThrowingOnBadAlloc()) { @@ -645,7 +672,7 @@ class Manager // rateLimiter.maybe_sleep(); std::this_thread::sleep_for(std::chrono::milliseconds(50)); if (Interrupted()) { - return ptr; + return ShmPtr(ptr); } else { continue; } @@ -657,13 +684,13 @@ class Manager (*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc); } (*fMsgDebug).at(fSegmentId).emplace( - static_cast(GetHandleFromAddress(ptr, fSegmentId)), + static_cast(GetHandleFromAddress(ShmPtr(ptr).UserPtr(), fSegmentId)), MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count()) ); #endif } - return ptr; + return ShmPtr(ptr); } void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index d4af19f09..ded96cde1 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -187,7 +187,8 @@ class Message final : public fair::mq::Message if (fMeta.fRegionId == 0) { if (fMeta.fSize > 0) { fManager.GetSegment(fMeta.fSegmentId); - fLocalPtr = reinterpret_cast(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + ShmPtr shmPtr(reinterpret_cast(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId))); + fLocalPtr = shmPtr.UserPtr(); } else { fLocalPtr = nullptr; } @@ -202,7 +203,7 @@ class Message final : public fair::mq::Message } } - return fLocalPtr; + return static_cast(fLocalPtr); } size_t GetSize() const override { return fMeta.fSize; } @@ -217,7 +218,8 @@ class Message final : public fair::mq::Message } else if (newSize <= fMeta.fSize) { try { try { - fLocalPtr = fManager.ShrinkInPlace(newSize, fLocalPtr, fMeta.fSegmentId); + ShmPtr shmPtr(fManager.ShrinkInPlace(newSize, static_cast(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)), fMeta.fSegmentId)); + fLocalPtr = shmPtr.UserPtr(); fMeta.fSize = newSize; return true; } catch (boost::interprocess::bad_alloc& e) { @@ -225,12 +227,13 @@ class Message final : public fair::mq::Message // unused size >= 1000000 bytes: reallocate fully // unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction if (fMeta.fSize - newSize >= 1000000) { - char* newPtr = fManager.Allocate(newSize, fAlignment); - if (newPtr) { - std::memcpy(newPtr, fLocalPtr, newSize); + ShmPtr shmPtr = fManager.Allocate(newSize, fAlignment); + if (shmPtr.RealPtr()) { + char* userPtr = shmPtr.UserPtr(); + std::memcpy(userPtr, fLocalPtr, newSize); fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); - fLocalPtr = newPtr; - fMeta.fHandle = fManager.GetHandleFromAddress(fLocalPtr, fMeta.fSegmentId); + fLocalPtr = userPtr; + fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId); } else { LOG(debug) << "could not set used size: " << e.what(); return false; @@ -288,10 +291,11 @@ class Message final : public fair::mq::Message char* InitializeChunk(const size_t size, size_t alignment = 0) { - fLocalPtr = fManager.Allocate(size, alignment); - if (fLocalPtr) { - fMeta.fHandle = fManager.GetHandleFromAddress(fLocalPtr, fMeta.fSegmentId); + ShmPtr shmPtr = fManager.Allocate(size, alignment); + if (shmPtr.RealPtr()) { + fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId); fMeta.fSize = size; + fLocalPtr = shmPtr.UserPtr(); } return fLocalPtr; }