Skip to content

Commit

Permalink
Merge branch 'NVIDIA:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
JanuszL committed Apr 20, 2023
2 parents 581f755 + 8f64d02 commit e3b0f7c
Show file tree
Hide file tree
Showing 17 changed files with 622 additions and 309 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ Additional Resources
.. |talkAdvanced2022| replace:: event
.. _talkAdvanced2022: https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41442/
.. |talkGeospatial2023| replace:: event
.. _talkGeospatial2023: https://www.nvidia.com/gtc/session-catalog/?search=S51796
.. _talkGeospatial2023: https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51796/
.. |breakoutDALITRITON| replace:: event
.. _breakoutDALITRITON: https://www.nvidia.com/gtc/session-catalog/?search=SE52140
.. _breakoutDALITRITON: https://www.nvidia.com/en-us/on-demand/session/gtcspring23-se52140/
----

Contributing to DALI
Expand Down
6 changes: 6 additions & 0 deletions dali/operators/input/video_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ void VideoInput<Backend, FramesDecoder>::VideoInputRunImpl(Workspace &ws) {
}
}

// If true, this operator can be run again, after this Run.
bool will_return_next = true;

// There won't be any more output using the current input.
bool input_sample_depleted = !full_sequence || frames_decoders_[0]->NextFrameIdx() == -1;

Expand All @@ -402,12 +405,15 @@ void VideoInput<Backend, FramesDecoder>::VideoInputRunImpl(Workspace &ws) {
* "next_output_data_id" trace.
*/
LoadDataFromInputOperator(GetThreadPool(ws));
} else {
will_return_next = false;
}
}

if (data_id_) {
SetNextDataIdTrace(ws, *data_id_);
}
InputOperator<Backend>::SetDepletedOperatorTrace(ws, !will_return_next);
}

} // namespace dali
Expand Down
68 changes: 48 additions & 20 deletions dali/operators/input/video_input_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,22 @@ class VideoInputNextOutputDataIdTest : public ::testing::Test {
daliRun(h);
daliOutput(h);

DoesOperatorTraceExist(h, i, test_file_idx);
IsOperatorTraceCorrect(h, i, test_file_idx);
AssertDataIdTraceExist(h, i, test_file_idx);
CheckDataIdTraceValue(h, i, test_file_idx);
AssertDepletedTraceExists(h);
CheckDepletedTraceValue(h, false);
}
/*
* The last iteration of the pipeline shall carry a different result.
* Since this function tests a single file, after the last iteration there shouldn't
* be a "next_output_data_id" trace available.
* "depleted" trace shall always be available.
*/
daliRun(h);
daliOutput(h);
EXPECT_EQ(daliHasOperatorTrace(h, video_input_name_.c_str(), trace_name_.c_str()), 0);
EXPECT_EQ(daliHasOperatorTrace(h, video_input_name_.c_str(), data_id_trace_name_.c_str()), 0);
AssertDepletedTraceExists(h);
CheckDepletedTraceValue(h, true);
}


Expand Down Expand Up @@ -171,29 +176,51 @@ class VideoInputNextOutputDataIdTest : public ::testing::Test {
/**
* Check, if the "next_output_data_id" trace exists, provided it should.
*/
void DoesOperatorTraceExist(daliPipelineHandle *h, int iteration_idx, int test_file_idx) {
bool has_data_id = daliHasOperatorTrace(h, video_input_name_.c_str(), trace_name_.c_str());
void AssertDataIdTraceExist(daliPipelineHandle *h, int iteration_idx, int test_file_idx) {
bool has_data_id = daliHasOperatorTrace(h, video_input_name_.c_str(),
data_id_trace_name_.c_str());
ASSERT_EQ(
has_data_id,
!test_files_[test_file_idx].data_id.empty())
<< "Failed at iteration " << iteration_idx << " of file with index " << test_file_idx;
<< "Failed at iteration " << iteration_idx << " of file with index "
<< test_file_idx;
}


/**
* Verify, if the operator trace has a correct value (provided it should exist).
* Verify, if the "next_output_data_id" trace has a correct value (provided it should exist).
*/
void IsOperatorTraceCorrect(daliPipelineHandle *h, int iteration_idx, int test_file_idx) {
bool has_data_id = daliHasOperatorTrace(h, video_input_name_.c_str(), trace_name_.c_str());
void CheckDataIdTraceValue(daliPipelineHandle *h, int iteration_idx, int test_file_idx) {
bool has_data_id = daliHasOperatorTrace(h, video_input_name_.c_str(),
data_id_trace_name_.c_str());
if (has_data_id) {
EXPECT_STREQ(
daliGetOperatorTrace(h, video_input_name_.c_str(), trace_name_.c_str()),
daliGetOperatorTrace(h, video_input_name_.c_str(), data_id_trace_name_.c_str()),
test_files_[test_file_idx].data_id.c_str())
<< "Failed at iteration " << iteration_idx << " of file with index " << test_file_idx;
}
}


/**
* Check, if the "depleted" trace exists.
*/
void AssertDepletedTraceExists(daliPipelineHandle *h) {
// The "depleted" trace should always exist.
ASSERT_TRUE(daliHasOperatorTrace(h, video_input_name_.c_str(), depleted_trace_name_.c_str()));
}


/**
* Verify the value of "depleted" trace.
* @param shall_be_depleted Expected value.
*/
void CheckDepletedTraceValue(daliPipelineHandle *h, bool shall_be_depleted) {
EXPECT_STREQ(daliGetOperatorTrace(h, video_input_name_.c_str(), depleted_trace_name_.c_str()),
shall_be_depleted ? "true" : "false");
}


const int batch_size_ = 3;
const int num_threads_ = 2;
const int device_id_ = 0;
Expand All @@ -216,7 +243,8 @@ class VideoInputNextOutputDataIdTest : public ::testing::Test {
},
};
const int frames_per_sequence_ = 4;
const std::string trace_name_ = "next_output_data_id";
const std::string data_id_trace_name_ = "next_output_data_id";
const std::string depleted_trace_name_ = "depleted";

std::string serialized_pipeline_;
};
Expand Down Expand Up @@ -342,29 +370,29 @@ TYPED_TEST(VideoInputNextOutputDataIdTest, MultipleInputFilesParallelTest) {
iteration_idx < num_iterations_per_input[test_file_idx] - 1; iteration_idx++) {
daliRun(&h);
daliOutput(&h);
this->DoesOperatorTraceExist(&h, iteration_idx, test_file_idx);
this->IsOperatorTraceCorrect(&h, iteration_idx, test_file_idx);
this->AssertDataIdTraceExist(&h, iteration_idx, test_file_idx);
this->CheckDataIdTraceValue(&h, iteration_idx, test_file_idx);
}
daliRun(&h);
daliOutput(&h);
this->DoesOperatorTraceExist(&h, num_iterations_per_input[test_file_idx] - 1,
next_test_file_idx);
this->IsOperatorTraceCorrect(&h, num_iterations_per_input[test_file_idx] - 1,
this->AssertDataIdTraceExist(&h, num_iterations_per_input[test_file_idx] - 1,
next_test_file_idx);
this->CheckDataIdTraceValue(&h, num_iterations_per_input[test_file_idx] - 1,
next_test_file_idx);
}
// The last test file should just clear the "next_output_data_id" trace after it's done.
auto test_file_idx = test_files_order.back();
for (int iteration_idx = 0;
iteration_idx < num_iterations_per_input[test_file_idx] - 1; iteration_idx++) {
daliRun(&h);
daliOutput(&h);
this->DoesOperatorTraceExist(&h, iteration_idx, test_file_idx);
this->IsOperatorTraceCorrect(&h, iteration_idx, test_file_idx);
this->AssertDataIdTraceExist(&h, iteration_idx, test_file_idx);
this->CheckDataIdTraceValue(&h, iteration_idx, test_file_idx);
}
daliRun(&h);
daliOutput(&h);
EXPECT_EQ(daliHasOperatorTrace(&h, this->video_input_name_.c_str(), this->trace_name_.c_str()),
0);
EXPECT_EQ(daliHasOperatorTrace(&h, this->video_input_name_.c_str(),
this->data_id_trace_name_.c_str()), 0);
daliDeletePipeline(&h);
}

Expand Down
1 change: 0 additions & 1 deletion dali/pipeline/executor/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,6 @@ int Executor<WorkspacePolicy, QueuePolicy>::InferBatchSize(
int batch_size;
try {
batch_size = bsps[0]->NextBatchSize();
assert(batch_size > 0);
for (auto &bsp : bsps) {
bsp->Advance();
}
Expand Down
4 changes: 3 additions & 1 deletion dali/pipeline/operator/builtin/caching_list.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,6 +62,8 @@ class CachingList {
assert(!full_data_.empty()); // Can't pop from an empty list
std::list<T> tmp;
tmp.splice(tmp.begin(), full_data_, full_data_.begin());
if (tmp.begin() == prophet_)
prophet_ = full_data_.begin();
return tmp;
}

Expand Down
8 changes: 6 additions & 2 deletions dali/pipeline/operator/builtin/external_source.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -18,11 +18,12 @@

namespace dali {

template <>
template<>
void ExternalSource<CPUBackend>::RunImpl(Workspace &ws) {
auto &output = ws.Output<CPUBackend>(0);
auto &thread_pool = ws.GetThreadPool();
ForwardCurrentData(output, data_id_, thread_pool);
SetDepletedOperatorTrace(ws, !(repeats_last_ || HasDataInQueue()));
}


Expand All @@ -31,6 +32,7 @@ void ExternalSource<GPUBackend>::RunImpl(Workspace &ws) {
auto &output = ws.Output<GPUBackend>(0);
cudaStream_t stream_used = ws.has_stream() ? ws.stream() : 0;
ForwardCurrentData(output, data_id_, stream_used);
SetDepletedOperatorTrace(ws, !(repeats_last_ || HasDataInQueue()));
}


Expand Down Expand Up @@ -66,6 +68,8 @@ of dimensions in the layout.
Specifying the input dimensionality will be required starting from DALI 2.0)code", nullptr)
.AddOptionalArg<TensorLayout>("layout",
"If provided, sets the layout of the data.", nullptr)
.AddOptionalArg("repeat_last", R"(If set, the last batch is re-fed when running
the operator and no new data was provided since the previous run.)", false)
.AddParent("InputOperatorBase");


Expand Down
2 changes: 2 additions & 0 deletions dali/pipeline/operator/builtin/external_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ExternalSource : public InputOperator<Backend> {
public:
explicit ExternalSource(const OpSpec &spec)
: InputOperator<Backend>(spec),
repeats_last_(spec.GetArgument<bool>("repeat_last")),
previous_dtype_(DALIDataType::DALI_NO_TYPE),
ndim_(-1),
layout_() {
Expand Down Expand Up @@ -168,6 +169,7 @@ class ExternalSource : public InputOperator<Backend> {
layout_ = batch.GetLayout();
}

const bool repeats_last_;

string output_name_;

Expand Down
Loading

0 comments on commit e3b0f7c

Please sign in to comment.