Skip to content

Commit

Permalink
Separate load and store from computation expressions.
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
Xreki committed Nov 14, 2019
1 parent 90a7b37 commit fe3d886
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 54 deletions.
83 changes: 55 additions & 28 deletions paddle/fluid/framework/ir/fusion_group/code_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set>
#include <sstream>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
Expand Down Expand Up @@ -88,65 +87,93 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
// Store the expressions as suffix expressions using a vector.
std::string CodeGenerator::Generate(
std::string func_name, std::vector<OperationExpression> expressions) {
// Check whether all expressions are elementwise operations.
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::string dtype = "float";
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);

TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(expressions, "float"));
template_var.Add("compute_body", EmitComputeBody(expressions));
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtype));
return predefined_cuda_functions + code_templates_[0].Format(template_var);
}

// Get the parameter list code for the expression information.
std::string CodeGenerator::EmitParameters(
std::vector<OperationExpression> expressions, std::string dtype) {
std::set<int> CodeGenerator::DistilInputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> input_ids;
std::set<int> output_ids;
// Remove the repeated ids and get an ordered list.
// Use std::set to remove the repeated ids and get an ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetInputIds()) {
input_ids.insert(id);
}
}
return input_ids;
}

std::set<int> CodeGenerator::DistilOutputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> output_ids;
// Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) {
output_ids.insert(id);
}
}
return output_ids;
}

// Get the parameter list code for the expression information.
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype) {
std::stringstream ret;
ret << "int N, ";

// If a id is in the input and output list at the same time, then remove it
// from the input list.
for (auto iter = input_ids.begin(); iter != input_ids.end();) {
if (output_ids.find(*iter) != output_ids.end()) {
input_ids.erase(iter++);
} else {
iter++;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtype << "* " << ArgName(id) << ", ";
}
}

std::stringstream ret;
ret << "int N, ";
for (auto iter = input_ids.begin(); iter != input_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter) << ", ";
}

size_t count_index = 0;
for (auto iter = output_ids.begin(); iter != output_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter);
if (count_index != output_ids.size() - 1) {
size_t index = 0;
for (auto id : output_ids) {
ret << dtype << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
count_index++;
index++;
}

return ret.str();
}

std::string CodeGenerator::EmitComputeBody(
std::vector<OperationExpression> expressions) {
// Get the right expression code using the suffix expression.
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype) {
std::stringstream ret;

// Load input to temporal variables.
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
}
}

for (size_t i = 0; i < expressions.size(); i++) {
VLOG(3) << DebugString(expressions[i]);
ret << expressions[i].GetExpression();
ret << expressions[i].GetExpression(dtype);
}

// Store temporal variables to memory.
for (auto id : output_ids) {
ret << ArgName(id) << "[idx] = " << TmpName(id) << ";";
}

return ret.str();
}

Expand Down
14 changes: 12 additions & 2 deletions paddle/fluid/framework/ir/fusion_group/code_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <set>
#include <string>
#include <unordered_map>
#include <vector>
Expand All @@ -37,11 +38,20 @@ class CodeGenerator {
std::vector<OperationExpression> ConvertToExpressions(SubGraph* subgraph);

private:
std::set<int> DistilInputIds(
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);

// we get the parameter list code for the expression information
std::string EmitParameters(std::vector<OperationExpression> expressions,
std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype);

std::string EmitComputeBody(std::vector<OperationExpression> expressions);
std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype);

// Encode all var nodes in the subgraph with an unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
Expand Down
14 changes: 7 additions & 7 deletions paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static T StringTo(const std::string& str) {
return value;
}

std::string OperationExpression::GetRHS(size_t i) {
std::string OperationExpression::GetRHS(size_t i) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
Expand All @@ -47,29 +47,29 @@ std::string OperationExpression::GetRHS(size_t i) {
PADDLE_ENFORCE_LT(index, input_ids_.size(),
"Only %d inputs are provided, but need %d.",
input_ids_.size(), index + 1);
rhs.replace(pos, length + 3, VarName(input_ids_[index]) + R"([idx])");
rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
}
}
return rhs;
}

std::string OperationExpression::GetLHS(size_t i) {
std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret;
ret << VarName(output_ids_[i]) << R"([idx])";
ret << TmpName(output_ids_[i]);
return ret.str();
}

bool OperationExpression::IsSupport() {
bool OperationExpression::IsSupport() const {
return OperationMap::Instance().Has(op_type_);
}

// We traverse the graph and get the group; all input ids and output ids
// are unique for the nodes which belong to the group.
std::string OperationExpression::GetExpression() {
std::string OperationExpression::GetExpression(std::string dtype) const {
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << GetLHS(i) << " = " << GetRHS(i) << ";";
ret << dtype << " " << GetLHS(i) << " = " << GetRHS(i) << ";";
}
}

Expand Down
16 changes: 11 additions & 5 deletions paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ namespace framework {
namespace ir {
namespace fusion_group {

static std::string VarName(int index) { return "var" + std::to_string(index); }
// Builds the name "argN" for the kernel argument with encoded id `index`.
static inline std::string ArgName(int index) {
  std::string name{"arg"};
  name += std::to_string(index);
  return name;
}
// Builds the name "tmpN" for the temporary variable with encoded id `index`.
static inline std::string TmpName(int index) {
  std::string name{"tmp"};
  name += std::to_string(index);
  return name;
}

class OperationExpression {
public:
Expand All @@ -40,13 +45,14 @@ class OperationExpression {
std::vector<int> GetOutputIds() const { return output_ids_; }

// Check whether this operation type is supported in OperationMap.
bool IsSupport();
bool IsSupport() const;

std::string GetExpression();
std::string GetExpression(std::string dtype) const;

private:
// TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(size_t i = 0);
std::string GetLHS(size_t i = 0);
std::string GetRHS(size_t i = 0) const;
std::string GetLHS(size_t i = 0) const;

private:
std::string op_type_;
Expand Down
38 changes: 26 additions & 12 deletions paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ limitations under the License. */
#include "paddle/fluid/platform/init.h"

#ifdef PADDLE_WITH_CUDA
namespace fusion_group = paddle::framework::ir::fusion_group;

namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {

// relu: reference implementation of the rectified linear unit used to
// check the generated kernel's results on the CPU.
inline float relu(float x) {
  if (x > 0) {
    return x;
  }
  return 0.;
}
Expand Down Expand Up @@ -81,11 +85,10 @@ inline float elementwise_mul_grad_dy(float x, float y, float out, float dout) {
return dout * x;
}

void CheckOutput(
const std::vector<fusion_group::OperationExpression>& expressions,
const std::vector<paddle::framework::LoDTensor> cpu_tensors,
const std::vector<int> input_ids_of_subgraph,
const std::vector<int> output_ids_of_subgraph, int i) {
void CheckOutput(const std::vector<OperationExpression>& expressions,
const std::vector<LoDTensor> cpu_tensors,
const std::vector<int> input_ids_of_subgraph,
const std::vector<int> output_ids_of_subgraph, int i) {
std::vector<float> var(cpu_tensors.size());
for (auto id : input_ids_of_subgraph) {
var[id] = cpu_tensors[id].data<float>()[i];
Expand Down Expand Up @@ -139,7 +142,7 @@ void CheckOutput(
}

template <typename T>
void SetupRandomCPUTensor(paddle::framework::LoDTensor* tensor) {
void SetupRandomCPUTensor(LoDTensor* tensor) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
Expand All @@ -152,6 +155,13 @@ void SetupRandomCPUTensor(paddle::framework::LoDTensor* tensor) {
}
}

} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle

namespace fusion_group = paddle::framework::ir::fusion_group;

void TestMainImpl(std::string func_name, std::string code_str,
std::vector<paddle::framework::LoDTensor> cpu_tensors, int n,
std::vector<int> input_ids, std::vector<int> output_ids) {
Expand All @@ -171,7 +181,7 @@ void TestMainImpl(std::string func_name, std::string code_str,
cpu_tensors[input_ids[i]].dims(), place);
args.push_back(&gpu_ptrs[input_ids[i]]);

SetupRandomCPUTensor<float>(&cpu_tensors[input_ids[i]]);
fusion_group::SetupRandomCPUTensor<float>(&cpu_tensors[input_ids[i]]);
TensorCopySync(cpu_tensors[input_ids[i]], place,
&gpu_tensors[input_ids[i]]);
}
Expand Down Expand Up @@ -260,7 +270,8 @@ TEST(code_generator, elementwise) {

// Check the results
for (int i = 0; i < n; i++) {
CheckOutput(expressions, cpu_tensors, input_ids, output_ids, i);
fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
i);
}
}

Expand Down Expand Up @@ -294,7 +305,8 @@ TEST(code_generator, elementwise_grad) {

// Check the results
for (int i = 0; i < n; i++) {
CheckOutput(expressions, cpu_tensors, input_ids, output_ids, i);
fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
i);
}
}

Expand Down Expand Up @@ -412,7 +424,8 @@ TEST(code_generator, subgraph) {

// Check the results
for (int i = 0; i < n; i++) {
CheckOutput(expressions, cpu_tensors, input_ids, output_ids, i);
fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
i);
}
}

Expand Down Expand Up @@ -443,7 +456,8 @@ TEST(code_generator, subgraph_grad) {

// Check the results
for (int i = 0; i < n; i++) {
CheckOutput(expressions, cpu_tensors, input_ids, output_ids, i);
fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids,
i);
}
}
#endif

0 comments on commit fe3d886

Please sign in to comment.