Skip to content

Commit

Permalink
Add size and pointer version of Bucket::Get
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherHogan committed Apr 6, 2021
1 parent c0ca435 commit 6bac493
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 11 deletions.
20 changes: 15 additions & 5 deletions src/api/bucket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ size_t Bucket::GetBlobSize(Arena *arena, const std::string &name,
}

size_t Bucket::Get(const std::string &name, Blob &user_blob, Context &ctx) {
size_t ret = Get(name, user_blob.data(), user_blob.size(), ctx);

return ret;
}

size_t Bucket::Get(const std::string &name, void *user_blob, size_t blob_size,
Context &ctx) {
(void)ctx;

size_t ret = 0;
Expand All @@ -103,14 +110,17 @@ size_t Bucket::Get(const std::string &name, Blob &user_blob, Context &ctx) {
// TODO(chogan): Assumes scratch is big enough to hold buffer_ids
ScopedTemporaryMemory scratch(&hermes_->trans_arena_);

if (user_blob.size() == 0) {
ret = GetBlobSize(scratch, name, ctx);
} else {
if (user_blob && blob_size != 0) {
hermes::Blob blob = {};
blob.data = (u8 *)user_blob;
blob.size = blob_size;
LOG(INFO) << "Getting Blob " << name << " from bucket " << name_ << '\n';
BlobID blob_id = GetBlobId(&hermes_->context_, &hermes_->rpc_,
name, id_);
name, id_);
ret = ReadBlobById(&hermes_->context_, &hermes_->rpc_,
&hermes_->trans_arena_, user_blob, blob_id);
&hermes_->trans_arena_, blob, blob_id);
} else {
ret = GetBlobSize(scratch, name, ctx);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/api/bucket.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class Bucket {
/**
* \brief Retrieve a Blob into a user buffer.
*/
size_t Get(const std::string &name, void *user_blob, size_t buffer_size,
size_t Get(const std::string &name, void *user_blob, size_t blob_size,
Context &ctx);

/** get blob(s) on this bucket according to predicate */
Expand Down
17 changes: 12 additions & 5 deletions src/buffer_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,11 +1529,8 @@ size_t ReadBlobFromBuffers(SharedMemoryContext *context, RpcContext *rpc,
}

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
api::Blob &dest, BlobID blob_id) {
Blob blob, BlobID blob_id) {
size_t result = 0;
hermes::Blob blob = {};
blob.data = dest.data();
blob.size = dest.size();

BufferIdArray buffer_ids = {};
if (hermes::BlobIsInSwap(blob_id)) {
Expand All @@ -1543,14 +1540,24 @@ size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
} else {
u32 *buffer_sizes = 0;
buffer_ids = GetBufferIdsFromBlobId(arena, context, rpc, blob_id,
&buffer_sizes);
&buffer_sizes);
result = ReadBlobFromBuffers(context, rpc, &blob, &buffer_ids,
buffer_sizes);
}

return result;
}

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
api::Blob &dest, BlobID blob_id) {
hermes::Blob blob = {};
blob.data = dest.data();
blob.size = dest.size();
size_t result = ReadBlobById(context, rpc, arena, blob, blob_id);

return result;
}

int OpenSwapFile(SharedMemoryContext *context, u32 node_id) {
int result = 0;

Expand Down
3 changes: 3 additions & 0 deletions src/buffer_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,9 @@ size_t ReadBlobFromBuffers(SharedMemoryContext *context, RpcContext *rpc,
Blob *blob, BufferIdArray *buffer_ids,
u32 *buffer_sizes);

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
Blob blob, BlobID blob_id);

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
api::Blob &dest, BlobID blob_id);

Expand Down
10 changes: 10 additions & 0 deletions test/bucket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,16 @@ void TestMultiGet(std::shared_ptr<hapi::Hermes> hermes) {
Assert(blobs[i] == retrieved_blobs[i]);
}

// Test Get into user buffer
hermes::u8 user_buffer[blob_size] = {};
size_t b1_size = bucket.Get(blob_names[0], nullptr, 0, ctx);
Assert(b1_size == blob_size);
b1_size = bucket.Get(blob_names[0], user_buffer, b1_size, ctx);

for (size_t i = 0; i < b1_size; ++i) {
Assert(user_buffer[i] == blobs[0][i]);
}

bucket.Destroy(ctx);
}

Expand Down

0 comments on commit 6bac493

Please sign in to comment.