
Add direct operator calls in debug mode #3734

Merged: 24 commits into NVIDIA:main on Apr 5, 2022

Conversation

@ksztenderski (Contributor) commented Mar 15, 2022

Category:

New feature (non-breaking change which adds functionality)

Description:

It adds backend implementation of direct operators and replaces minipipelines in debug mode with direct operators.
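For context, this is roughly how the feature looks from the user's side: with debug=True, each fn call inside the pipeline body runs immediately through an eager operator, so intermediate outputs can be inspected as regular data. A minimal sketch (the dataset path is a placeholder):

from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn

@pipeline_def(batch_size=8, num_threads=3, device_id=0, debug=True)
def debug_pipeline():
    # In debug mode each call below executes eagerly via a backend operator
    # instead of being traced into a per-operator mini-pipeline.
    jpegs, labels = fn.readers.file(file_root="/path/to/images")  # placeholder path
    images = fn.decoders.image(jpegs, device="mixed")
    return images, labels

pipe = debug_pipeline()
pipe.build()
images, labels = pipe.run()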

Additional information:

Affected modules and functionalities:

Debug mode pipeline

Key points relevant for the review:

WIP, no review needed yet.

Checklist

Tests

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

      const std::unordered_map<std::string, std::shared_ptr<TensorList<CPUBackend>>> &kwargs,
      cudaStream_t cuda_stream) {
    ws.set_stream(cuda_stream);
    CUDA_CALL(cudaStreamSynchronize(cuda_stream));
Contributor:
Long term we can make this synchronization optional.

Contributor:

Why would we ever need to synchronize before?

ksztenderski (author):

Removed the synchronization before the run and left only the synchronization after. For now (especially in debug mode) it shouldn't bother us, but in the future it seems unnecessary to synchronize immediately after the run; we should only synchronize when we actually need the data (the synchronization before the run was meant to hint at that concept).

* Basic implementation of calls in debug mode
* Basic exposing of direct operators to ops.experimental
* TODO: Create debug pipeline class in C++ to keep thread pool

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
Moved workspace clear before setting cuda stream

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
@lgtm-com (bot) commented Mar 16, 2022:

This pull request fixes 1 alert when merging 596bafe into 02d04aa - view on LGTM.com

fixed alerts:

  • 1 for Unused import

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
@ksztenderski ksztenderski marked this pull request as ready for review March 16, 2022 18:24
@ksztenderski (author):

!build

@dali-automaton (Collaborator):

CI MESSAGE: [4164026]: BUILD STARTED

@lgtm-com (bot) commented Mar 16, 2022:

This pull request fixes 1 alert when merging cc7a4dc into fd6a8b9 - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@dali-automaton (Collaborator):

CI MESSAGE: [4164026]: BUILD FAILED

 * @brief Direct operator providing eager execution of an operator in Run.
 */
template <typename Backend>
class DLL_PUBLIC DirectOperator {
Contributor:

Suggested change
- class DLL_PUBLIC DirectOperator {
+ class DLL_PUBLIC ImmediateOperator {

or

Suggested change
- class DLL_PUBLIC DirectOperator {
+ class DLL_PUBLIC EagerOperator {

?
"Direct" doesn't really convey this meaning. It's merely an abbreviation for "DirectlyCalledOperator", which is lengthy.

ksztenderski (author):

done, went with EagerOperator

    CUDA_CALL(cudaStreamSynchronize(cuda_stream));
    auto output = RunImpl<GPUBackend, GPUBackend, TensorList<GPUBackend>, TensorList<GPUBackend>>(
        inputs, kwargs);
    CUDA_CALL(cudaStreamSynchronize(cuda_stream));
Contributor:

We could get rid of this one, too - we could (and should) expose the associated stream in TensorXxxGPU in Python and just tell the user that the data is available for that stream. We can (and already do) synchronize D2H copies.

ksztenderski (author):

Creating an API for stream exposure seems like a good idea for a follow-up, as it'll probably generate a lot of code and isn't really necessary for debug mode. But I fully agree with supporting it for eager operators.

Comment on lines 32 to 43
template <typename Backend>
std::shared_ptr<TensorList<Backend>> AsTensorList(std::shared_ptr<TensorList<Backend>> input) {
  return input;
}

template <typename Backend>
std::shared_ptr<TensorList<Backend>> AsTensorList(std::shared_ptr<TensorVector<Backend>> input) {
  // TODO(ksztenderski): Remove copy.
  auto tl = std::make_shared<TensorList<Backend>>();
  tl->Copy(*input);
  return tl;
}
Contributor:

We have a very similar thing in workspace_policy.h - PresentAsTensorList. Maybe it should be unified and moved to a common utility.

Contributor:

We are doing a copy here (I guess for a prototype), which will always work, but PresentAsTensorList requires a contiguous TV to work.

Contributor:

It can be another flavor or option of that function.

ksztenderski (author):

done

    DALI_FAIL("Unsupported backends in DirectOperator.Run().");
  }

  // Runs operator using specified thread pool.
Contributor:

Suggested change
- // Runs operator using specified thread pool.
+ // Runs operator using specified thread pool and shared CUDA stream.

ksztenderski (author):

That's kind of the point: it supports only CPU operators, so the CUDA stream is not set.

Contributor:

Understood.

Contributor:

On the other hand, this is a template, so it can be used for any kind of op, including mixed and GPU ones.

    DALI_FAIL("Unsupported backends in DirectOperator.Run() with thread pool.");
  }

  // Runs operator using specified CUDA stream.
Contributor:

Suggested change
- // Runs operator using specified CUDA stream.
+ // Runs operator using shared thread and specified CUDA stream.

ksztenderski (author):

Here it's the opposite: it supports only GPU operators, and the thread pool is not set.

Contributor:

I don't think the template prevents creating a CPU-only run function with such a signature.

ksztenderski (author):

Well yes, but the thread pool won't be set anyway.

    DALI_FAIL("Unsupported backends in DirectOperator.Run() with CUDA stream");
  }

  // Set shared thread pool used for all direct operators.
Contributor:

Suggested change
- // Set shared thread pool used for all direct operators.
+ // Creates thread pool used for all direct operators.

ksztenderski (author):

I think that "creates" might suggest that by default there is no thread pool, but I agree that "set" is not perfect either. Maybe "update"?

Contributor:

Sounds better.

ksztenderski (author):

done

    shared_thread_pool = std::make_unique<ThreadPool>(num_threads, device_id, set_affinity);
  }

  // Set shared CUDA stream used for all direct operators.
Contributor:

Suggested change
- // Set shared CUDA stream used for all direct operators.
+ // Creates shared CUDA stream used for all direct operators.

I would rather expect a Set function to accept the value it should set.

ksztenderski (author):

done (changed to "update")

  }

  // Set shared thread pool used for all direct operators.
  DLL_PUBLIC inline static void SetThreadPool(int num_threads, int device_id, bool set_affinity) {
Contributor:

Suggested change
- DLL_PUBLIC inline static void SetThreadPool(int num_threads, int device_id, bool set_affinity) {
+ DLL_PUBLIC inline static void CreateThreadPool(int num_threads, int device_id, bool set_affinity) {

ksztenderski (author):

done (UpdateThreadPool)

  }

  // Set shared CUDA stream used for all direct operators.
  DLL_PUBLIC inline static void SetCudaStream(int device_id) {
Contributor:

Suggested change
- DLL_PUBLIC inline static void SetCudaStream(int device_id) {
+ DLL_PUBLIC inline static void CreateCudaStream(int device_id) {

ksztenderski (author):

done (UpdateCudaStream)

  OpSpec op_spec;
  std::unique_ptr<OperatorBase> op;

  static cudaStream_t shared_cuda_stream;
@JanuszL (Contributor) commented Mar 17, 2022:

@mzient - I think it should be

Suggested change
- static cudaStream_t shared_cuda_stream;
+ static CUDAStreamLease shared_cuda_stream_;

I'm not sure if having it static works well with SetCudaStream that can change it for all instances of this class.

ksztenderski (author):

Changed the type to CUDAStreamLease. Is having a static cuda_stream a problem?

Contributor:

I'm just afraid of weird issues when the library is wrapping up and something is still using the given stream.

@lgtm-com (bot) commented Mar 29, 2022:

This pull request fixes 1 alert when merging 2da6e95 into 568826f - view on LGTM.com

fixed alerts:

  • 1 for Unused import

@dali-automaton (Collaborator):

CI MESSAGE: [4261738]: BUILD FAILED

@klecki (Contributor) left a comment:

Looks ok, small nitpick.

I am also wondering if we can add a simple check for the current batch size and raise an error that variable batch size is not supported whenever we encounter a batch smaller than max_batch_size, until we start supporting it (see the sketch below).
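A hedged sketch of what such a guard could look like (the helper name and call site are hypothetical, for illustration only, not the code that eventually landed):

def _check_batch_size(inputs, max_batch_size):
    # Debug mode currently assumes every input batch has exactly
    # max_batch_size samples; reject anything smaller until variable
    # batch size is supported.
    for i, batch in enumerate(inputs):
        if len(batch) != max_batch_size:
            raise RuntimeError(
                f"Input {i} has batch size {len(batch)}, expected {max_batch_size}. "
                "Variable batch size is not supported in debug mode.")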


return op_helper, init_args, inputs_classification, kwargs_classification, len(inputs)

self._operators[key] = _OperatorManager(
    op_class, self._seed_generator.integers(0, 2**32), inputs, kwargs)
Contributor:

This change calls the seed generator every time, instead of only when the op didn't have the seed argument. I guess it doesn't really matter, but I just wanted to check if it's intended.

ksztenderski (author):

Yes, it is intended. I wanted to set the seed in the operator after classification, and I didn't want to pass the seed generator to the OperatorManager. The reason I wanted the seed set after classification is that when we later run the operator and check whether its current arguments are correct, we skip the seed argument (as intended), but the expected classification would still contain it. And, as you pointed out, it doesn't really matter, because the result is still deterministic.
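To make the determinism point concrete, a minimal sketch of the scheme (the pipeline_seed value and surrounding names are illustrative, not DALI internals): one generator, seeded once, hands out a sub-seed per operator in creation order, so the drawn sequence is identical on every run whether or not a given op consumes its seed.

import numpy as np

pipeline_seed = 1234  # hypothetical pipeline-level seed
seed_generator = np.random.default_rng(pipeline_seed)

# One sub-seed per operator, drawn unconditionally in creation order;
# the sequence depends only on that order, so results stay reproducible.
op_seeds = [int(seed_generator.integers(0, 2**32)) for _ in range(3)]
print(op_seeds)  # the same three seeds on every run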

@dali-automaton (Collaborator):

CI MESSAGE: [4261738]: BUILD PASSED

@JanuszL mentioned this pull request Mar 30, 2022
* Add error for variable batch_size

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
@ksztenderski (author):

!build

@dali-automaton (Collaborator):

CI MESSAGE: [4272690]: BUILD STARTED

@lgtm-com (bot) commented Mar 30, 2022:

This pull request fixes 2 alerts when merging 92bfc65 into 9b24277 - view on LGTM.com

fixed alerts:

  • 2 for Unused import



@pipeline_def(batch_size=8, num_threads=3, device_id=0, debug=True)
def incorrect_input_Sets_pipeline():
@JanuszL (Contributor) commented Mar 30, 2022:

Suggested change
- def incorrect_input_Sets_pipeline():
+ def incorrect_input_sets_pipeline():

ksztenderski (author):

done

    return tuple(output)


@raises(ValueError, glob="All argument lists for Multpile Input Sets used with operator 'Cat' must have the same length")
Contributor:

Can the error message have the name of the operator consistent with the API used - so Cat for ops and cat for fn?

ksztenderski (author):

As most of these error messages are for functionality specific to debug mode, and the only way to use operators in debug mode is the fn API, I guess we can just change it to snake_case.

Contributor:

So I would go for fn names.
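For reference, a CamelCase-to-snake_case conversion along these lines is a common recipe (the actual _to_snake_case helper in DALI may treat acronyms differently):

import re

def to_snake_case(name):
    # Insert an underscore before each capitalized word, then lowercase.
    s = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s).lower()

assert to_snake_case('Cat') == 'cat'
assert to_snake_case('ArithmeticGenericOp') == 'arithmetic_generic_op'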

@dali-automaton (Collaborator):

CI MESSAGE: [4272690]: BUILD FAILED

@@ -79,7 +79,7 @@ def fn_wrapper(*inputs, **kwargs):
     from nvidia.dali._debug_mode import _PipelineDebug
     current_pipeline = _PipelineDebug.current()
     if getattr(current_pipeline, '_debug_on', False):
-        return current_pipeline._wrap_op_call(op_wrapper, inputs, kwargs)
+        return current_pipeline._wrap_op_call(op_class, *inputs, **kwargs)
Contributor:

Suggested change
- return current_pipeline._wrap_op_call(op_class, *inputs, **kwargs)
+ return current_pipeline._wrap_op_call(op_class, *inputs, name=_to_snake_case(op_class.__name__), **kwargs)

or does it not make any sense?

ksztenderski (author):

It makes sense (that's the way the names are created here), but I think it's even better to just pass the name to _wrap_op_call and have the name generation in one place.

     else:
-        raise RuntimeError(f"Unexpected operator '{op_wrapper.__name__}'. Debug mode does not support"
+        raise RuntimeError(f"Unexpected operator '{op_class}'. Debug mode does not support"
Contributor:

Why change this? Won't we get Ops.Xyz-style names even if we use the fn API?

ksztenderski (author):

Reverted to fn style names.

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
@lgtm-com (bot) commented Mar 30, 2022:

This pull request fixes 2 alerts when merging ea94e0e into 65616c5 - view on LGTM.com

fixed alerts:

  • 2 for Unused import

@klecki (Contributor) left a comment:

One nitpick, and I think there is a small issue in the MIS validation. Otherwise looks ok.


aritm_fn_name = _to_snake_case(_ops.ArithmeticGenericOp.__name__)
Contributor:

Nitpick, shouldn't this be a private member or something?

ksztenderski (author):

done (static member)

if input_set_len == 1:
    input_set_len = len(classification.is_batch)
else:
    raise ValueError("All argument lists for Multipile Input Sets used "
Contributor:

Won't we raise the error for the second input? If we save input_set_len from the first iteration, we will hit the else branch in the second one, I think.

ksztenderski (author):

done

classification = _Classification(input, f'Input {i}')

if isinstance(classification.is_batch, list):
    if input_set_len == 1:
Contributor:

Wouldn't a -1 or something make more sense for the initial value?

ksztenderski (author):

done
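Putting the two review points together, the fixed-up check could look roughly like this (a sketch under assumed names; a list-valued is_batch marks a multiple-input-set argument, as in the snippets above):

def check_input_set_lengths(classifications, op_name):
    # -1 is the "no input set seen yet" sentinel; every list-valued
    # is_batch encountered afterwards must match the recorded length.
    input_set_len = -1
    for c in classifications:
        if isinstance(c.is_batch, list):
            if input_set_len == -1:
                input_set_len = len(c.is_batch)
            elif len(c.is_batch) != input_set_len:
                raise ValueError(
                    "All argument lists for Multiple Input Sets used with "
                    f"operator '{op_name}' must have the same length")
    return input_set_len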

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
@ksztenderski (author):

!build

@dali-automaton (Collaborator):

CI MESSAGE: [4330022]: BUILD STARTED

@lgtm-com (bot) commented Apr 1, 2022:

This pull request fixes 2 alerts when merging b3846c4 into 999379b - view on LGTM.com

fixed alerts:

  • 2 for Unused import

@dali-automaton (Collaborator):

CI MESSAGE: [4330022]: BUILD PASSED

@ksztenderski merged commit 98c2a36 into NVIDIA:main Apr 5, 2022
cyyever pushed a commit to cyyever/DALI that referenced this pull request May 13, 2022
* Add base backend implementation of eager operators
* Add backend implementation of PipelineDebug managing backend operators
* Add OperatorManager util class for debug mode
* Replace minipipelines in debug mode by eager operators

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
cyyever pushed a commit to cyyever/DALI that referenced this pull request Jun 7, 2022