Skip to content

Commit 8c7534b

Browse files
Fix offset calculation during batching
1 parent e88f96c commit 8c7534b

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

include/proteus/core/predict_api.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class InferenceRequestInput {
214214
parameters_ = parameters;
215215
}
216216

217-
/// Set the tensor's size
217+
/// Get the tensor's size (number of elements)
218218
size_t getSize() const;
219219

220220
/// Provide an implementation to print the class with std::cout to an ostream

src/proteus/clients/http_internal.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ InferenceRequestPtr RequestBuilder::build(
467467

468468
auto input =
469469
InputBuilder::build(std::make_shared<Json::Value>(i), buffer, offset);
470-
offset += input.getSize();
470+
offset += (input.getSize() * input.getDatatype().size());
471471

472472
request->inputs_.push_back(std::move(input));
473473
}

src/proteus/clients/native_internal.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ class InferenceRequestBuilder<InferenceRequest> {
9595

9696
request->inputs_.push_back(
9797
InputBuilder::build(input, buffer, offset));
98-
offset += request->inputs_.back().getSize();
98+
const auto &last_input = request->inputs_.back();
99+
offset += (last_input.getSize() * last_input.getDatatype().size());
99100
}
100101
} catch (const std::invalid_argument &e) {
101102
throw;

src/proteus/servers/grpc_server.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,8 @@ class InferenceRequestBuilder<CallDataModelInfer*> {
385385

386386
request->inputs_.push_back(
387387
InputBuilder::build(input, buffer, offset));
388-
offset += request->inputs_.back().getSize();
388+
const auto& last_input = request->inputs_.back();
389+
offset += (last_input.getSize() * last_input.getDatatype().size());
389390
}
390391
} catch (const std::invalid_argument& e) {
391392
throw;

0 commit comments

Comments
 (0)