From 6bac4935eaec2a54aa76c115387d93816793255a Mon Sep 17 00:00:00 2001 From: Chris Hogan Date: Tue, 6 Apr 2021 15:09:21 -0500 Subject: [PATCH] Add size and pointer version of Bucket::Get --- src/api/bucket.cc | 20 +++++++++++++++----- src/api/bucket.h | 2 +- src/buffer_pool.cc | 17 ++++++++++++----- src/buffer_pool.h | 3 +++ test/bucket_test.cc | 10 ++++++++++ 5 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/api/bucket.cc b/src/api/bucket.cc index 27cc78be4..03addea3f 100644 --- a/src/api/bucket.cc +++ b/src/api/bucket.cc @@ -95,6 +95,13 @@ size_t Bucket::GetBlobSize(Arena *arena, const std::string &name, } size_t Bucket::Get(const std::string &name, Blob &user_blob, Context &ctx) { + size_t ret = Get(name, user_blob.data(), user_blob.size(), ctx); + + return ret; +} + +size_t Bucket::Get(const std::string &name, void *user_blob, size_t blob_size, + Context &ctx) { (void)ctx; size_t ret = 0; @@ -103,14 +110,17 @@ size_t Bucket::Get(const std::string &name, Blob &user_blob, Context &ctx) { // TODO(chogan): Assumes scratch is big enough to hold buffer_ids ScopedTemporaryMemory scratch(&hermes_->trans_arena_); - if (user_blob.size() == 0) { - ret = GetBlobSize(scratch, name, ctx); - } else { + if (user_blob && blob_size != 0) { + hermes::Blob blob = {}; + blob.data = (u8 *)user_blob; + blob.size = blob_size; LOG(INFO) << "Getting Blob " << name << " from bucket " << name_ << '\n'; BlobID blob_id = GetBlobId(&hermes_->context_, &hermes_->rpc_, - name, id_); + name, id_); ret = ReadBlobById(&hermes_->context_, &hermes_->rpc_, - &hermes_->trans_arena_, user_blob, blob_id); + &hermes_->trans_arena_, blob, blob_id); + } else { + ret = GetBlobSize(scratch, name, ctx); } } diff --git a/src/api/bucket.h b/src/api/bucket.h index e8155be8e..37f16f179 100644 --- a/src/api/bucket.h +++ b/src/api/bucket.h @@ -144,7 +144,7 @@ class Bucket { /** * \brief Retrieve a Blob into a user buffer. */ - size_t Get(const std::string &name, void *user_blob, size_t buffer_size, + size_t Get(const std::string &name, void *user_blob, size_t blob_size, Context &ctx); /** get blob(s) on this bucket according to predicate */ diff --git a/src/buffer_pool.cc b/src/buffer_pool.cc index 52ce933bd..3e3f4b1f6 100644 --- a/src/buffer_pool.cc +++ b/src/buffer_pool.cc @@ -1529,11 +1529,8 @@ size_t ReadBlobFromBuffers(SharedMemoryContext *context, RpcContext *rpc, } size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena, - api::Blob &dest, BlobID blob_id) { + Blob blob, BlobID blob_id) { size_t result = 0; - hermes::Blob blob = {}; - blob.data = dest.data(); - blob.size = dest.size(); BufferIdArray buffer_ids = {}; if (hermes::BlobIsInSwap(blob_id)) { @@ -1543,7 +1540,7 @@ size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena, } else { u32 *buffer_sizes = 0; buffer_ids = GetBufferIdsFromBlobId(arena, context, rpc, blob_id, - &buffer_sizes); + &buffer_sizes); result = ReadBlobFromBuffers(context, rpc, &blob, &buffer_ids, buffer_sizes); } @@ -1551,6 +1548,16 @@ size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena, return result; } +size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena, + api::Blob &dest, BlobID blob_id) { + hermes::Blob blob = {}; + blob.data = dest.data(); + blob.size = dest.size(); + size_t result = ReadBlobById(context, rpc, arena, blob, blob_id); + + return result; +} + int OpenSwapFile(SharedMemoryContext *context, u32 node_id) { int result = 0; diff --git a/src/buffer_pool.h b/src/buffer_pool.h index b7aefba25..ed9dd30d9 100644 --- a/src/buffer_pool.h +++ b/src/buffer_pool.h @@ -435,6 +435,9 @@ size_t ReadBlobFromBuffers(SharedMemoryContext *context, RpcContext *rpc, Blob *blob, BufferIdArray *buffer_ids, u32 *buffer_sizes); +size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena, + Blob blob, BlobID blob_id); + size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena, api::Blob &dest, BlobID blob_id); diff --git a/test/bucket_test.cc b/test/bucket_test.cc index 32d505c28..e7fe5d224 100644 --- a/test/bucket_test.cc +++ b/test/bucket_test.cc @@ -232,6 +232,16 @@ void TestMultiGet(std::shared_ptr hermes) { Assert(blobs[i] == retrieved_blobs[i]); } + // Test Get into user buffer + hermes::u8 user_buffer[blob_size] = {}; + size_t b1_size = bucket.Get(blob_names[0], nullptr, 0, ctx); + Assert(b1_size == blob_size); + b1_size = bucket.Get(blob_names[0], user_buffer, b1_size, ctx); + + for (size_t i = 0; i < b1_size; ++i) { + Assert(user_buffer[i] == blobs[0][i]); + } + bucket.Destroy(ctx); }