Skip to content

Commit

Permalink
Enable generating code for a given subgraph. (#21126)
Browse files Browse the repository at this point in the history
* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearrangement of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop
  • Loading branch information
Xreki committed Nov 20, 2019
1 parent 3ff5cc2 commit 6b1e1f0
Show file tree
Hide file tree
Showing 15 changed files with 945 additions and 223 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif()
endif()

Expand Down
201 changes: 169 additions & 32 deletions paddle/fluid/framework/ir/fusion_group/code_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set>
#include <sstream>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"

namespace paddle {
namespace framework {
Expand All @@ -30,69 +31,205 @@ CodeGenerator::CodeGenerator() {
code_templates_[0] = elementwise_t;
}

// Generates the code for a subgraph: the subgraph is first converted to a
// list of operation expressions, which are then handed, together with
// subgraph->func_name, to the expression-based Generate overload.
std::string CodeGenerator::Generate(SubGraph* subgraph) {
  const std::vector<OperationExpression> exprs =
      ConvertToExpressions(subgraph);
  return Generate(subgraph->func_name, exprs);
}

// Converts every op node of the subgraph, visited in the order given by
// subgraph->SortedNodes(), into an OperationExpression that carries the op
// node's name plus the integer ids (as assigned by EncodeVarNodes) of its
// input and output variables.
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
    SubGraph* subgraph) {
  std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
  std::vector<OperationExpression> expressions;
  for (auto* node : subgraph->SortedNodes()) {
    // Only op nodes that carry an op description produce an expression.
    if (!node || !node->IsOp() || !node->Op()) {
      continue;
    }
    auto* op = node->Op();

    // Input ids follow the fixed order registered in OperationMap, like:
    //  - x, y in forward operations
    //  - x, y, out, out@GRAD in backward operations
    std::vector<std::string> input_names =
        OperationMap::Instance().Get(op->Type()).input_names;
    std::vector<int> input_ids;
    for (auto& name : input_names) {
      // TODO(liuyiqun): support duplicated input.
      if (op->Input(name).size() < 1U) {
        // Some input vars are not used in grad ops, such as
        // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
        // Such a slot is recorded with the placeholder id -1.
        input_ids.push_back(-1);
        continue;
      }
      PADDLE_ENFORCE_NE(var_ids.find(op->Input(name)[0]), var_ids.end(),
                        "Input(%s) of operation %s should be set.", name,
                        op->Type());
      input_ids.push_back(var_ids[op->Input(name)[0]]);
    }

    // Output ids follow the fixed order registered in OperationMap, like:
    //  - dx, dy in backward operations
    std::vector<std::string> output_names =
        OperationMap::Instance().Get(op->Type()).output_names;
    std::vector<int> output_ids;
    for (auto& name : output_names) {
      // Every registered output must hold exactly one encoded variable.
      PADDLE_ENFORCE_EQ(op->Output(name).size(), 1U,
                        "Output(%s) of operation %s should be set.", name,
                        op->Type());
      PADDLE_ENFORCE_NE(var_ids.find(op->Output(name)[0]), var_ids.end(),
                        "Output(%s) of operation %s should be set.", name,
                        op->Type());
      output_ids.push_back(var_ids[op->Output(name)[0]]);
    }

    expressions.emplace_back(node->Name(), input_ids, output_ids);
  }
  return expressions;
}

// In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector.
std::string CodeGenerator::GenerateCode(
std::string CodeGenerator::Generate(
std::string func_name, std::vector<OperationExpression> expressions) {
// Check whether all expressions are elementwise operations.
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::string dtype = "float";
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);

TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(expressions, "float"));
template_var.Add("compute_body", EmitComputeBody(expressions));
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtype));
return predefined_cuda_functions + code_templates_[0].Format(template_var);
}

// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(
std::vector<OperationExpression> expressions, std::string dtype) {
std::set<int> CodeGenerator::DistilInputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> input_ids;
std::set<int> output_ids;
// Remove the repeated ids and get an ordered list.
// Use std::set to remove the repeated ids and get an ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetInputIds()) {
input_ids.insert(id);
if (id >= 0) {
input_ids.insert(id);
}
}
}
return input_ids;
}

std::set<int> CodeGenerator::DistilOutputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> output_ids;
// Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) {
output_ids.insert(id);
}
}
return output_ids;
}

// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype) {
std::stringstream ret;
ret << "int N, ";

// If a id is in the input and output list at the same time, then remove it
// from the input list.
for (auto iter = input_ids.begin(); iter != input_ids.end();) {
if (output_ids.find(*iter) != output_ids.end()) {
input_ids.erase(iter++);
} else {
iter++;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtype << "* " << ArgName(id) << ", ";
}
}

std::stringstream ret;
ret << "int N, ";
for (auto iter = input_ids.begin(); iter != input_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter) << ", ";
}

size_t count_index = 0;
for (auto iter = output_ids.begin(); iter != output_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter);
if (count_index != output_ids.size() - 1) {
size_t index = 0;
for (auto id : output_ids) {
ret << dtype << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
count_index++;
index++;
}

return ret.str();
}

std::string CodeGenerator::EmitComputeBody(
std::vector<OperationExpression> expressions) {
// get the right experssion code using suffix expression
std::stringstream ret;
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype) {
std::ostringstream compute;
std::unordered_set<int> used;
for (size_t i = 0; i < expressions.size(); i++) {
ret << expressions[i].GetExpression();
VLOG(3) << DebugString(expressions[i]);
compute << expressions[i].GetExpression(dtype, &used);
}
return ret.str();

// Load inputs to temporary variables.
std::ostringstream load;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
}
}

// Store temporary variables to memory.
std::ostringstream store;
for (auto id : output_ids) {
store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
}

return load.str() + compute.str() + store.str();
}

// Assigns an integer id to the name of every input and output var node of the
// subgraph. Ids start from 0; input vars are numbered first, then output
// vars, and a name that already has an id keeps it. A var node found in
// subgraph->SortedNodes() that is neither an input nor an output (i.e. an
// internal var) triggers an enforce failure, since that is not supported yet.
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
    SubGraph* subgraph) {
  const auto& input_var_nodes = subgraph->GetInputVarNodes();
  const auto& output_var_nodes = subgraph->GetOutputVarNodes();

  std::unordered_map<std::string, int> var_ids;
  int id = 0;

  // Numbering input vars.
  for (auto* in : input_var_nodes) {
    VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
    // emplace only inserts when the name has no id yet.
    if (var_ids.emplace(in->Name(), id).second) {
      ++id;
    }
  }

  // Checking internal vars: every var node must be an input or an output.
  for (auto* node : subgraph->SortedNodes()) {
    if (!node || !node->IsVar() || !node->Var()) {
      continue;
    }

    bool is_input = false;
    for (auto* in : input_var_nodes) {
      if (in == node) {
        is_input = true;
        break;
      }
    }
    if (is_input) {
      continue;
    }

    bool is_found = false;
    for (auto* out : output_var_nodes) {
      if (out == node) {
        is_found = true;
        break;
      }
    }
    PADDLE_ENFORCE_EQ(
        is_found, true,
        "Subgraph with internal var nodes (%s) is not supported yet.",
        node->Name());
  }

  // Encoding output vars.
  for (auto* out : output_var_nodes) {
    VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
    if (var_ids.emplace(out->Name(), id).second) {
      ++id;
    }
  }
  return var_ids;
}

} // namespace fusion_group
Expand Down
28 changes: 22 additions & 6 deletions paddle/fluid/framework/ir/fusion_group/code_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ limitations under the License. */

#pragma once

#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"

namespace paddle {
namespace framework {
Expand All @@ -27,18 +30,31 @@ class CodeGenerator {
public:
CodeGenerator();

std::string GenerateCode(std::string func_name,
std::vector<OperationExpression> expressions);
std::string Generate(std::string func_name,
std::vector<OperationExpression> expressions);

// TODO(wangchao): add a more general interface
// std::string Generate(const std::string name, const SubGraph& subgraph);
std::string Generate(SubGraph* subgraph);

std::vector<OperationExpression> ConvertToExpressions(SubGraph* subgraph);

private:
std::set<int> DistilInputIds(
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);

// we get the parameter list code for the expression information
std::string EmitParameters(std::vector<OperationExpression> expressions,
std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype);

std::string EmitComputeBody(std::vector<OperationExpression> expressions);
std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype);

// Encode all var nodes in the subgraph with an unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);

private:
std::vector<CodeTemplate> code_templates_;
Expand Down
23 changes: 14 additions & 9 deletions paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ static T StringTo(const std::string& str) {
return value;
}

std::string OperationExpression::GetRHS(size_t i) {
auto rhs = OperationMap::Instance().Get(op_).exprs[i];
std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t i) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
Expand All @@ -47,29 +48,33 @@ std::string OperationExpression::GetRHS(size_t i) {
PADDLE_ENFORCE_LT(index, input_ids_.size(),
"Only %d inputs are provided, but need %d.",
input_ids_.size(), index + 1);
rhs.replace(pos, length + 3, VarName(input_ids_[index]) + R"([idx])");
PADDLE_ENFORCE_GE(input_ids_[index], 0,
"Input id should be no less than 0.");
rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
used->insert(input_ids_[index]);
}
}
return rhs;
}

std::string OperationExpression::GetLHS(size_t i) {
std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret;
ret << VarName(output_ids_[i]) << R"([idx])";
ret << TmpName(output_ids_[i]);
return ret.str();
}

bool OperationExpression::IsSupport() {
return OperationMap::Instance().Has(op_);
bool OperationExpression::IsSupport() const {
return OperationMap::Instance().Has(op_type_);
}

// We traverse the graph and get the group; all input ids and output ids are
// unique for the nodes which belong to the group.
std::string OperationExpression::GetExpression() {
std::string OperationExpression::GetExpression(
std::string dtype, std::unordered_set<int>* used) const {
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << GetLHS(i) << " = " << GetRHS(i) << ";";
ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
}
}

Expand Down
Loading

0 comments on commit 6b1e1f0

Please sign in to comment.