Skip to content

Commit

Permalink
Merge c785944 into 11e269e
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherHogan committed Apr 7, 2021
2 parents 11e269e + c785944 commit d370185
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 10 deletions.
36 changes: 31 additions & 5 deletions src/api/bucket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ size_t Bucket::GetBlobSize(Arena *arena, const std::string &name,
}

size_t Bucket::Get(const std::string &name, Blob &user_blob, Context &ctx) {
size_t ret = Get(name, user_blob.data(), user_blob.size(), ctx);

return ret;
}

size_t Bucket::Get(const std::string &name, void *user_blob, size_t blob_size,
Context &ctx) {
(void)ctx;

size_t ret = 0;
Expand All @@ -103,20 +110,39 @@ size_t Bucket::Get(const std::string &name, Blob &user_blob, Context &ctx) {
// TODO(chogan): Assumes scratch is big enough to hold buffer_ids
ScopedTemporaryMemory scratch(&hermes_->trans_arena_);

if (user_blob.size() == 0) {
ret = GetBlobSize(scratch, name, ctx);
} else {
if (user_blob && blob_size != 0) {
hermes::Blob blob = {};
blob.data = (u8 *)user_blob;
blob.size = blob_size;
LOG(INFO) << "Getting Blob " << name << " from bucket " << name_ << '\n';
BlobID blob_id = GetBlobId(&hermes_->context_, &hermes_->rpc_,
name, id_);
name, id_);
ret = ReadBlobById(&hermes_->context_, &hermes_->rpc_,
&hermes_->trans_arena_, user_blob, blob_id);
&hermes_->trans_arena_, blob, blob_id);
} else {
ret = GetBlobSize(scratch, name, ctx);
}
}

return ret;
}

std::vector<size_t> Bucket::Get(const std::vector<std::string> &names,
std::vector<Blob> &blobs, Context &ctx) {
std::vector<size_t> result(names.size(), 0);
if (names.size() == blobs.size()) {
for (size_t i = 0; i < result.size(); ++i) {
result[i] = Get(names[i], blobs[i], ctx);
}
} else {
LOG(ERROR) << "names.size() != blobs.size() in Bucket::Get ("
<< names.size() << " != " << blobs.size() << ")"
<< std::endl;
}

return result;
}

template<class Predicate>
Status Bucket::GetV(void *user_blob, Predicate pred, Context &ctx) {
(void)user_blob;
Expand Down
12 changes: 12 additions & 0 deletions src/api/bucket.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ class Bucket {
/** use provides buffer */
size_t Get(const std::string &name, Blob& user_blob, Context &ctx);

/**
* \brief Retrieve multiple Blobs in one call.
*/
std::vector<size_t> Get(const std::vector<std::string> &names,
std::vector<Blob> &blobs, Context &ctx);

/**
* \brief Retrieve a Blob into a user buffer.
*/
size_t Get(const std::string &name, void *user_blob, size_t blob_size,
Context &ctx);

/** get blob(s) on this bucket according to predicate */
/** use provides buffer */
template<class Predicate>
Expand Down
17 changes: 12 additions & 5 deletions src/buffer_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,11 +1529,8 @@ size_t ReadBlobFromBuffers(SharedMemoryContext *context, RpcContext *rpc,
}

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
api::Blob &dest, BlobID blob_id) {
Blob blob, BlobID blob_id) {
size_t result = 0;
hermes::Blob blob = {};
blob.data = dest.data();
blob.size = dest.size();

BufferIdArray buffer_ids = {};
if (hermes::BlobIsInSwap(blob_id)) {
Expand All @@ -1543,14 +1540,24 @@ size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
} else {
u32 *buffer_sizes = 0;
buffer_ids = GetBufferIdsFromBlobId(arena, context, rpc, blob_id,
&buffer_sizes);
&buffer_sizes);
result = ReadBlobFromBuffers(context, rpc, &blob, &buffer_ids,
buffer_sizes);
}

return result;
}

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
api::Blob &dest, BlobID blob_id) {
hermes::Blob blob = {};
blob.data = dest.data();
blob.size = dest.size();
size_t result = ReadBlobById(context, rpc, arena, blob, blob_id);

return result;
}

int OpenSwapFile(SharedMemoryContext *context, u32 node_id) {
int result = 0;

Expand Down
3 changes: 3 additions & 0 deletions src/buffer_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,9 @@ size_t ReadBlobFromBuffers(SharedMemoryContext *context, RpcContext *rpc,
Blob *blob, BufferIdArray *buffer_ids,
u32 *buffer_sizes);

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
Blob blob, BlobID blob_id);

size_t ReadBlobById(SharedMemoryContext *context, RpcContext *rpc, Arena *arena,
api::Blob &dest, BlobID blob_id);

Expand Down
49 changes: 49 additions & 0 deletions test/bucket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,54 @@ void TestCompressionTrait(std::shared_ptr<hapi::Hermes> hermes) {
Assert(my_bucket.Destroy(ctx).Succeeded());
}

void TestMultiGet(std::shared_ptr<hapi::Hermes> hermes) {
const size_t num_blobs = 4;
const int blob_size = KILOBYTES(4);

std::vector<std::string> blob_names(num_blobs);
for (size_t i = 0; i < num_blobs; ++i) {
blob_names[i]= "Blob" + std::to_string(i);
}

std::vector<hapi::Blob> blobs(num_blobs);
for (size_t i = 0; i < num_blobs; ++i) {
blobs[i] = hapi::Blob(blob_size, (char)i);
}

hapi::Context ctx;
const std::string bucket_name = "b1";
hapi::Bucket bucket(bucket_name, hermes, ctx);

for (size_t i = 0; i < num_blobs; ++i) {
Assert(bucket.Put(blob_names[i], blobs[i], ctx).Succeeded());
}

std::vector<hapi::Blob> retrieved_blobs(num_blobs);
std::vector<size_t> sizes = bucket.Get(blob_names, retrieved_blobs, ctx);

for (size_t i = 0; i < num_blobs; ++i) {
retrieved_blobs[i].resize(sizes[i]);
}

sizes = bucket.Get(blob_names, retrieved_blobs, ctx);
for (size_t i = 0; i < num_blobs; ++i) {
Assert(blobs[i] == retrieved_blobs[i]);
Assert(sizes[i] == retrieved_blobs[i].size());
}

// Test Get into user buffer
hermes::u8 user_buffer[blob_size] = {};
size_t b1_size = bucket.Get(blob_names[0], nullptr, 0, ctx);
Assert(b1_size == blob_size);
b1_size = bucket.Get(blob_names[0], user_buffer, b1_size, ctx);

for (size_t i = 0; i < b1_size; ++i) {
Assert(user_buffer[i] == blobs[0][i]);
}

bucket.Destroy(ctx);
}

int main(int argc, char **argv) {
int mpi_threads_provided;
MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &mpi_threads_provided);
Expand All @@ -217,6 +265,7 @@ int main(int argc, char **argv) {
TestCompressionTrait(hermes_app);
TestBucketPersist(hermes_app);
TestPutOverwrite(hermes_app);
TestMultiGet(hermes_app);
} else {
// Hermes core. No user code here.
}
Expand Down

0 comments on commit d370185

Please sign in to comment.