Skip to content

Commit

Permalink
iter-to-iter variable batch size (#2481)
Browse files Browse the repository at this point in the history
* Update C API to support i2i variable batch size

* Update nvJpegDecoder and its flavours to support variable BS

* Update Constant op for the same reason

* Partially split Executor into h and cc file, to make development convenient

* Also add variable batch size routines to Executor (batch size queues, PreRun with BS inferring from the graph)

* Introduce BatchSizeProvider interface and add it to ExternalSource

* Add a Prophet to CachingList (inner ExternalSource memory), so that it can traverse over the data list asynchronously w.r.t. the current head and tail of this list

* Modify the constraint that stipulates a constant batch size to instead stipulate a uniform batch size within a single iteration

* Python API adjustment

* Adding enormous test for i2i variable batch size (every op is tested)

Signed-off-by: Michał Szołucha <mszolucha@nvidia.com>
  • Loading branch information
szalpal committed Jan 22, 2021
1 parent a2f39ba commit 3dd70d6
Show file tree
Hide file tree
Showing 26 changed files with 1,703 additions and 426 deletions.
75 changes: 53 additions & 22 deletions dali/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,37 @@ namespace {

bool dali_initialized = false;

template<typename Backend>
/**
 * Maps an operator name to the batch size that was set before the
 * corresponding daliSetExternal... call. Typically, the mapped operator is a
 * BatchSizeProvider. A negative value denotes the max batch size
 * (the default state).
 * Typical usage:
 * auto batch_sizes_map = reinterpret_cast<batch_size_map_t*>(handle->batch_sizes_map);
 */
using batch_size_map_t = std::unordered_map<std::string /* op_name */, int /* batch_size */>;


/**
 * Fetches the batch size registered for @p op_name and resets its entry back
 * to the default state (-1), so a stale value is never reused by a later call.
 * When no entry exists, or the stored value is negative, @p max_batch_size is
 * returned instead.
 */
int PopCurrBatchSize(batch_size_map_t *batch_size_map, int max_batch_size,
                     const std::string &op_name) {
  const auto it = batch_size_map->find(op_name);
  if (it == batch_size_map->end()) {
    return max_batch_size;  // never registered: fall back to the default
  }
  const int stored = it->second;
  it->second = -1;  // consume the value: revert to the "max batch size" state
  return stored < 0 ? max_batch_size : stored;
}


template <typename Backend>
void SetExternalInput(daliPipelineHandle *pipe_handle, const char *name, const void *data_ptr,
dali_data_type_t data_type, const int64_t *shapes, int sample_dim,
const char *layout_str, cudaStream_t stream = 0, unsigned int flags = 0) {
dali::Pipeline *pipeline = reinterpret_cast<dali::Pipeline *>(pipe_handle->pipe);
std::vector<int64_t> shapes_tmp(shapes, shapes + sample_dim * pipeline->batch_size());
dali::TensorListShape<> tl_shape(std::move(shapes_tmp), pipeline->batch_size(), sample_dim);
auto bs_map = reinterpret_cast<batch_size_map_t *>(pipe_handle->batch_sizes_map);
auto curr_batch_size = PopCurrBatchSize(bs_map, pipeline->max_batch_size(), name);
std::vector<int64_t> shapes_tmp(shapes, shapes + sample_dim * curr_batch_size);
dali::TensorListShape<> tl_shape(std::move(shapes_tmp), curr_batch_size, sample_dim);
dali::TensorLayout layout{};
if (layout_str != nullptr) {
layout = dali::TensorLayout(layout_str);
Expand All @@ -68,16 +92,18 @@ void SetExternalInputTensors(daliPipelineHandle *pipe_handle, const char *name,
const int64_t *shapes, int64_t sample_dim, const char *layout_str,
cudaStream_t stream = 0, unsigned int flags = 0) {
dali::Pipeline *pipeline = reinterpret_cast<dali::Pipeline *>(pipe_handle->pipe);
std::vector<int64_t> shapes_tmp(shapes, shapes + sample_dim * pipeline->batch_size());
dali::TensorListShape<> tl_shape(std::move(shapes_tmp), pipeline->batch_size(), sample_dim);
auto bs_map = reinterpret_cast<batch_size_map_t *>(pipe_handle->batch_sizes_map);
auto curr_batch_size = PopCurrBatchSize(bs_map, pipeline->max_batch_size(), name);
std::vector<int64_t> shapes_tmp(shapes, shapes + sample_dim * curr_batch_size);
dali::TensorListShape<> tl_shape(std::move(shapes_tmp), curr_batch_size, sample_dim);
dali::TensorLayout layout{};
if (layout_str != nullptr) {
layout = dali::TensorLayout(layout_str);
}
dali::TensorVector<Backend> data(pipeline->batch_size());
dali::TensorVector<Backend> data(curr_batch_size);
const auto &type_info = dali::TypeTable::GetTypeInfo(static_cast<dali::DALIDataType>(data_type));
auto elem_sizeof = type_info.size();
for (int i = 0; i < pipeline->batch_size(); i++) {
for (int i = 0; i < curr_batch_size; i++) {
// We cast away the const from data_ptr, as there is no other way of passing it to the
// Tensor as we must also set the shape and type metadata.
// The vector that we pass to pipeline is const.
Expand Down Expand Up @@ -113,22 +139,15 @@ void daliInitialize() {
std::call_once(init_flag, init);
}


void daliCreatePipeline(daliPipelineHandle *pipe_handle,
const char *serialized_pipeline,
int length,
int batch_size,
int num_threads,
int device_id,
int separated_execution,
int prefetch_queue_depth,
int cpu_prefetch_queue_depth,
int gpu_prefetch_queue_depth,
void daliCreatePipeline(daliPipelineHandle *pipe_handle, const char *serialized_pipeline,
int length, int max_batch_size, int num_threads, int device_id,
int separated_execution, int prefetch_queue_depth,
int cpu_prefetch_queue_depth, int gpu_prefetch_queue_depth,
int enable_memory_stats) {
bool se = separated_execution != 0;
auto pipeline = std::make_unique<dali::Pipeline>(std::string(serialized_pipeline, length),
batch_size, num_threads, device_id, true,
prefetch_queue_depth, true);
auto pipeline =
std::make_unique<dali::Pipeline>(std::string(serialized_pipeline, length), max_batch_size,
num_threads, device_id, true, prefetch_queue_depth, true);
pipeline->SetExecutionTypes(true, se, true);
if (se) {
pipeline->SetQueueSizes(cpu_prefetch_queue_depth, gpu_prefetch_queue_depth);
Expand All @@ -143,6 +162,9 @@ void daliCreatePipeline(daliPipelineHandle *pipe_handle,
pipe_handle->ws = ws.release();
pipe_handle->copy_stream = stream.release();
pipe_handle->pipe = pipeline.release();

auto bs_map = std::make_unique<batch_size_map_t>();
pipe_handle->batch_sizes_map = bs_map.release();
}


Expand Down Expand Up @@ -183,14 +205,20 @@ void daliPrefetchSeparate(daliPipelineHandle *pipe_handle,
}


/**
 * Registers @p batch_size for the operator @p name; the value is consumed
 * (and reset to the default) by the next daliSetExternalInput... call that
 * targets the same operator.
 */
void daliSetExternalInputBatchSize(daliPipelineHandle *pipe_handle, const char *name,
                                   int batch_size) {
  auto &bs_map = *reinterpret_cast<batch_size_map_t *>(pipe_handle->batch_sizes_map);
  bs_map[name] = batch_size;
}


/**
 * Synchronous flavour of daliSetExternalInputAsync: performs the copy on the
 * handle's dedicated copy stream and forces a sync before returning.
 */
void daliSetExternalInput(daliPipelineHandle *pipe_handle, const char *name, device_type_t device,
                          const void *data_ptr, dali_data_type_t data_type, const int64_t *shapes,
                          int sample_dim, const char *layout_str, unsigned int flags) {
  const unsigned int sync_flags = flags | DALI_ext_force_sync;
  daliSetExternalInputAsync(pipe_handle, name, device, data_ptr, data_type, shapes, sample_dim,
                            layout_str, pipe_handle->copy_stream, sync_flags);
}


void daliSetExternalInputAsync(daliPipelineHandle *pipe_handle, const char *name,
device_type_t device, const void *data_ptr,
dali_data_type_t data_type, const int64_t *shapes,
Expand Down Expand Up @@ -479,15 +507,18 @@ void daliCopyTensorListNTo(daliPipelineHandle *pipe_handle, void *dst, int outpu
void daliDeletePipeline(daliPipelineHandle* pipe_handle) {
  // Recover the owned objects stashed in the opaque handle.
  auto *pipeline = reinterpret_cast<dali::Pipeline *>(pipe_handle->pipe);
  auto *ws = reinterpret_cast<dali::DeviceWorkspace *>(pipe_handle->ws);
  auto *bs_map = reinterpret_cast<batch_size_map_t *>(pipe_handle->batch_sizes_map);
  DALI_ENFORCE(pipeline != nullptr && ws != nullptr, "Pipeline already deleted");
  // The handle also owns the copy stream - release it first.
  if (pipe_handle->copy_stream) {
    CUDA_CALL(cudaStreamDestroy(pipe_handle->copy_stream));
  }
  pipe_handle->copy_stream = nullptr;
  // Destroy in the same order as before (workspace, pipeline, batch-size map)
  // and null out each field so a repeated call trips the enforce above.
  delete ws;
  pipe_handle->ws = nullptr;
  delete pipeline;
  pipe_handle->pipe = nullptr;
  delete bs_map;
  pipe_handle->batch_sizes_map = nullptr;
}

void daliLoadLibrary(const char* lib_path) {
Expand Down
60 changes: 58 additions & 2 deletions dali/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ std::unique_ptr<Pipeline> GetTestPipeline(bool is_file_reader, const std::string
// Allows only for uint8_t CPU/GPU output data to be compared
template <typename Backend>
void ComparePipelinesOutputs(daliPipelineHandle &handle, Pipeline &baseline,
unsigned int copy_output_flags = DALI_ext_default) {
unsigned int copy_output_flags = DALI_ext_default,
int batch_size = dali::batch_size) {
dali::DeviceWorkspace ws;
baseline.Outputs(&ws);
daliOutput(&handle);
Expand Down Expand Up @@ -230,6 +231,7 @@ TYPED_TEST(CApiTest, ExternalSourceSingleAllocPipe) {
// Unnecessary copy in case of CPUBackend, makes the code generic across Backends
input.Copy(input_cpu, cuda_stream);
pipe_ptr->SetExternalInput(input_name, input);
daliSetExternalInputBatchSize(&handle, input_name.c_str(), input_shape.num_samples());
daliSetExternalInputAsync(&handle, input_name.c_str(), backend_to_device_type<TypeParam>::value,
input.raw_data(), dali_data_type_t::DALI_UINT8, input_shape.data(),
input_shape.sample_dim(), nullptr, cuda_stream, DALI_ext_default);
Expand Down Expand Up @@ -261,8 +263,62 @@ TYPED_TEST(CApiTest, ExternalSourceSingleAllocPipe) {
}


// Checks that an ExternalSource-fed pipeline driven through the C API accepts
// a different batch size on every iteration (i2i variable batch size): half,
// quarter, and then the full reference batch are fed in turn.
TYPED_TEST(CApiTest, ExternalSourceSingleAllocVariableBatchSizePipe) {
  TensorListShape<> reference_input_shape = {{37, 23, 3}, {12, 22, 3}, {42, 42, 3}, {8, 8, 3},
                                             {64, 32, 3}, {32, 64, 3}, {20, 20, 3}, {64, 64, 3},
                                             {10, 10, 3}, {60, 50, 3}, {10, 15, 3}, {48, 48, 3}};
  int max_batch_size = reference_input_shape.num_samples();
  // Per-iteration batches, each a prefix of the reference shape list.
  std::vector<TensorListShape<>> trimmed_input_shapes = {
      subshape(reference_input_shape, 0, max_batch_size / 2),
      subshape(reference_input_shape, 0, max_batch_size / 4),
      subshape(reference_input_shape, 0, max_batch_size),
  };

  auto pipe_ptr = GetTestPipeline<TypeParam>(false, this->output_device_);
  auto serialized = pipe_ptr->SerializeToProtobuf();

  daliPipelineHandle handle;
  // NOTE(review): created with the global `batch_size`, not max_batch_size --
  // presumably the two are equal (12 samples); confirm against the fixture.
  daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
                     device_id, false, prefetch_queue_depth, prefetch_queue_depth,
                     prefetch_queue_depth, false);

  for (auto &input_shape : trimmed_input_shapes) {
    // Fresh baseline pipeline for every batch size, compared against the
    // C-API-driven pipeline in `handle`.
    pipe_ptr = GetTestPipeline<TypeParam>(false, this->output_device_);
    pipe_ptr->Build();

    TensorList<CPUBackend> input_cpu;
    TensorList<TypeParam> input;
    input_cpu.Resize(input_shape, TypeInfo::Create<uint8_t>());

    // Fill the whole prefetch queue of both pipelines with identical data.
    for (int i = 0; i < prefetch_queue_depth; i++) {
      SequentialFill(view<uint8_t>(input_cpu), 42 * i);
      // Unnecessary copy in case of CPUBackend, makes the code generic across Backends
      input.Copy(input_cpu, cuda_stream);
      pipe_ptr->SetExternalInput(input_name, input);
      // The batch size must be registered before the matching
      // daliSetExternalInputAsync call, which consumes it.
      daliSetExternalInputBatchSize(&handle, input_name.c_str(), input_shape.num_samples());
      daliSetExternalInputAsync(&handle, input_name.c_str(),
                                backend_to_device_type<TypeParam>::value, input.raw_data(),
                                dali_data_type_t::DALI_UINT8, input_shape.data(),
                                input_shape.sample_dim(), nullptr, cuda_stream, DALI_ext_default);
    }

    // Run the baseline for every queued batch; the C API side prefetches once.
    for (int i = 0; i < prefetch_queue_depth; i++) {
      pipe_ptr->RunCPU();
      pipe_ptr->RunGPU();
    }
    daliPrefetchUniform(&handle, prefetch_queue_depth);

    dali::DeviceWorkspace ws;
    for (int i = 0; i < prefetch_queue_depth; i++) {
      ComparePipelinesOutputs<TypeParam>(handle, *pipe_ptr, DALI_ext_default,
                                         input_shape.num_samples());
    }
  }
}


TYPED_TEST(CApiTest, ExternalSourceMultipleAllocPipe) {
TensorListShape<> input_shape = {{37, 23, 3}, {12, 22, 3}, {42, 42, 3}, {8, 8, 3},
TensorListShape<> input_shape = {{37, 23, 3}, {12, 22, 3}, {42, 42, 3}, {8, 8, 3},
{64, 32, 3}, {32, 64, 3}, {20, 20, 3}, {64, 64, 3},
{10, 10, 3}, {60, 50, 3}, {10, 15, 3}, {48, 48, 3}};
TensorList<CPUBackend> input_cpu;
Expand Down
19 changes: 19 additions & 0 deletions dali/core/tensor_shape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1293,5 +1293,24 @@ TEST(AppendTest, AppendSingleTLS) {
EXPECT_EQ(tls_tested_dyn, ref);
}


TEST(TensorListShapeTest, SubshapeTest) {
  // Nine 2-element samples: the i-th sample is {i, i+1}.
  TensorListShape<> tls = {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 9}};

  // Expected results of subshape(tls, begin, end, step) for each call below.
  TensorListShape<> whole_step1 = {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5},
                                   {5, 6}, {6, 7}, {7, 8}, {8, 9}};
  TensorListShape<> prefix8_step1 = {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}, {5, 6}, {6, 7}, {7, 8}};
  TensorListShape<> prefix5_step2 = {{0, 1}, {2, 3}, {4, 5}};
  TensorListShape<> from1_to8_step3 = {{1, 2}, {4, 5}, {7, 8}};
  TensorListShape<> whole_step2 = {{0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}};
  TensorListShape<> from1_step2 = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};

  EXPECT_EQ(subshape(tls, 0, 9, 1), whole_step1);
  EXPECT_EQ(subshape(tls, 0, 8, 1), prefix8_step1);
  EXPECT_EQ(subshape(tls, 0, 5, 2), prefix5_step2);
  EXPECT_EQ(subshape(tls, 1, 8, 3), from1_to8_step3);
  EXPECT_EQ(subshape(tls, 0, 9, 2), whole_step2);
  EXPECT_EQ(subshape(tls, 1, 9, 2), from1_step2);
}

} // namespace kernels
} // namespace dali
Loading

0 comments on commit 3dd70d6

Please sign in to comment.