Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime][PipelineExecutor] Getting the asynchronous output #10723

Merged
merged 1 commit into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def __init__(self, module):
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
self._get_pipe_execute_count = self.module["get_execute_count"]

def run(self, sync=False):
def run(self):
"""Run the pipeline executor."""
self._run(sync)
self._run()

def get_input_pipeline_map(self, name):
"""Using the "name" to get the corresponding subgraph index and also get the "input name"
Expand Down
11 changes: 3 additions & 8 deletions src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(); });
} else if (name == "run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(args[0]); });
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
} else if (name == "get_execute_count") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExecutionCount(); });
Expand Down Expand Up @@ -140,13 +140,8 @@ int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) {
return param_connection_config[name];
}

/*!
* \brief Run the pipeline executor.
* \param serialized_mode Whether run the pipeline executor in serialized mode.
*/
void PipelineExecutor::Run(bool serialized_mode) {
pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_, serialized_mode);
}
/*!\brief Run the pipeline executor.*/
void PipelineExecutor::Run() { pipeline_scheduler_.PipelineRun(runtimes_, pipeline_config_); }
/*!
* \brief return A list of global output data.
*/
Expand Down
7 changes: 2 additions & 5 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
* \return The number of outputs.
*/
int NumOutputs() const { return num_outputs_; }
/*!
* \brief Run the pipeline executor.
* \param serialized_mode Whether run the pipeline executor in serialized mode.
*/
void Run(bool serialized_mode);
/*!\brief Run the pipeline executor.*/
void Run();
/*!
* \brief Get a list output data.
* \return A list of output data.
Expand Down
61 changes: 8 additions & 53 deletions src/runtime/pipeline/pipeline_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
const std::vector<Module>& modules, const ConfigPipelineExecution& pipeline_config) {
std::vector<std::shared_ptr<BackendRuntime>> runtimes;
graph_modules_ = modules;
global_runtime_ = std::make_shared<GlobalRuntime>(GLOBAL_MODULE_INDEX);
// Creating a list of runtimes.
for (size_t i = 0; i < graph_modules_.size(); i++) {
auto run_item = std::make_shared<BackendRuntime>(graph_modules_[i], i);
Expand All @@ -49,71 +50,25 @@ std::vector<std::shared_ptr<BackendRuntime>> PipelineScheduler::PipelineInit(
}
// Initializing and then running the worker thread.
for (auto runtime : runtimes) {
runtime->InitializePipeline(pipeline_config, &runtimes);
runtime->InitializePipeline(pipeline_config, &runtimes, global_runtime_);
}
return runtimes;
}
/*!
* \brief Running the pipeline logic in the sequential mode.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependent configuration of each runtime module.
*/
void PipelineScheduler::PipelineRunSequential(
const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config) {
for (size_t i = 0; i < runtimes.size(); i++) {
// The "runtimes" is a list of runtime sorted by the runtime index which should be
// contiguous ascend.
if (static_cast<int>(i) != runtimes[i]->GetModuleIndex()) {
LOG(FATAL) << "Runtime index " << runtimes[i]->GetModuleIndex()
<< " is not as same as vector offset value " << i;
}

if (!pipeline_config.FindModuleInConfig(i)) {
LOG(FATAL) << "Not find the configuration for the module " << i;
}

runtimes[i]->Run();
// Getting the output then forwarding into other module once it is configured as input of
// another module or storaging into the "output_array" when the output is a global one.
int outputs_num = runtimes[i]->NumOutputs();
for (int j = 0; j < outputs_num; j++) {
ConfigBindings& out_binding = pipeline_config[i][j];
std::unordered_map<int, std::string>& input_connections = out_binding.Get();
NDArray output = runtimes[i]->GetOutput(j);
for (auto bind : input_connections) {
// "bind.first < 0" means the bind is a global bind, by pass the forwarding for
// a global bind.
if (bind.first < 0) continue;
// Setting the output as an input data into the runtime module.
runtimes[bind.first]->SetInput(bind.second, const_cast<DLTensor*>(output.operator->()));
}
// Store the output.
if (out_binding.IsGlobalOutput()) {
int global_idx = out_binding.GetGlobalOutputIndex();
TVMArrayCopyFromTo(const_cast<DLTensor*>(output.operator->()),
const_cast<DLTensor*>(output_arrays_[global_idx].operator->()), nullptr);
}
}
}
}
/*!
* \brief Running pipeline logic.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependency configuration of each runtime module.
* \param sequential_mode Whether the execution is in a sequential mode.
*/
void PipelineScheduler::PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config, bool sequential_mode) {
if (!sequential_mode) {
runtimes.front()->RunPipeline();
} else {
PipelineRunSequential(runtimes, pipeline_config);
}
ConfigPipelineExecution pipeline_config) {
runtimes.front()->RunPipeline();
}
/*!
* \brief Get a list of output.
*/
Array<NDArray> PipelineScheduler::PipelineGetOutput() { return output_arrays_; }
Array<NDArray> PipelineScheduler::PipelineGetOutput() {
bool ret = global_runtime_->GetOutput(&output_arrays_);
return ret ? output_arrays_ : Array<NDArray>{};
}
} // namespace runtime
} // namespace tvm
12 changes: 3 additions & 9 deletions src/runtime/pipeline/pipeline_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,9 @@ class PipelineScheduler {
* \brief Running the pipeline logic.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependency configuration of each runtime module.
* \param sequential_mode Whether the execution is in a sequential mode.
*/
void PipelineRun(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config, bool sequential_mode = false);
/*!
* \brief Running the pipeline logic in the sequential mode.
* \param runtimes A list of backend runtime modules.
* \param pipeline_config The dependent configuration of each runtime module.
*/
void PipelineRunSequential(const std::vector<std::shared_ptr<BackendRuntime>>& runtimes,
ConfigPipelineExecution pipeline_config);
ConfigPipelineExecution pipeline_config);
/*!
* \brief Get a list of outputs.
*/
Expand All @@ -68,6 +60,8 @@ class PipelineScheduler {
std::vector<Module> graph_modules_;
/*!\brief A list of NDArray used to storage outputs.*/
Array<NDArray> output_arrays_;
/*!\brief The global runtime to represent the pipeline executor.*/
std::shared_ptr<GlobalRuntime> global_runtime_;
};
} // namespace runtime
} // namespace tvm
Expand Down
Loading