Improve the loading statements in generated code.
test=develop
Xreki committed Nov 14, 2019
1 parent fe3d886 commit 82e5670
Showing 4 changed files with 42 additions and 31 deletions.
23 changes: 13 additions & 10 deletions paddle/fluid/framework/ir/fusion_group/code_generator.cc
@@ -155,26 +155,29 @@ std::string CodeGenerator::EmitComputeBody(
     const std::vector<OperationExpression>& expressions,
     const std::set<int>& input_ids, const std::set<int>& output_ids,
     std::string dtype) {
-  std::stringstream ret;
+  std::ostringstream compute;
+  std::unordered_set<int> used;
+  for (size_t i = 0; i < expressions.size(); i++) {
+    VLOG(3) << DebugString(expressions[i]);
+    compute << expressions[i].GetExpression(dtype, &used);
+  }
 
   // Load input to temporal variables.
+  std::ostringstream load;
   for (auto id : input_ids) {
-    if (output_ids.find(id) == output_ids.end()) {
-      ret << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
+    if (output_ids.find(id) == output_ids.end() &&
+        used.find(id) != used.end()) {
+      load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
     }
   }
 
-  for (size_t i = 0; i < expressions.size(); i++) {
-    VLOG(3) << DebugString(expressions[i]);
-    ret << expressions[i].GetExpression(dtype);
-  }
-
   // Store temporal variables to memory.
+  std::ostringstream store;
   for (auto id : output_ids) {
-    ret << ArgName(id) << "[idx] = " << TmpName(id) << ";";
+    store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
  }
 
-  return ret.str();
+  return load.str() + compute.str() + store.str();
 }
 
 std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
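For reference, a minimal sketch of the kernel body this function now assembles, using hypothetical ids, dtype "float", and an invented multiply expression, and assuming TmpName(id) and ArgName(id) expand to names of the form tmp0 and arg0:

    // input_ids = {0, 1}, output_ids = {2}; the expression computes tmp2 = tmp0 * tmp1.
    // The compute block is generated first so that `used` is populated, but the
    // concatenation load.str() + compute.str() + store.str() places loads ahead of it.
    float tmp0 = arg0[idx];    // load: emitted only for ids recorded in `used`
    float tmp1 = arg1[idx];    // an input no expression references gets no load at all
    float tmp2 = tmp0 * tmp1;  // compute: one statement per OperationExpression
    arg2[idx] = tmp2;          // store: every output id is written back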
11 changes: 8 additions & 3 deletions paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
@@ -33,7 +33,8 @@ static T StringTo(const std::string& str) {
   return value;
 }
 
-std::string OperationExpression::GetRHS(size_t i) const {
+std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
+                                        size_t i) const {
   auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
   for (size_t i = 0; i < rhs.size(); i++) {
     size_t pos = i;
@@ -47,7 +48,10 @@ std::string OperationExpression::GetRHS(size_t i) const {
       PADDLE_ENFORCE_LT(index, input_ids_.size(),
                         "Only %d inputs are provided, but need %d.",
                         input_ids_.size(), index + 1);
+      PADDLE_ENFORCE_GE(input_ids_[index], 0,
+                        "Input id should be no less than 0.");
       rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
+      used->insert(input_ids_[index]);
     }
   }
   return rhs;
@@ -65,11 +69,12 @@ bool OperationExpression::IsSupport() const {
 
 // we Traverse the graph and get the group , all input id and output id is
 // unique for the node which belong the group
-std::string OperationExpression::GetExpression(std::string dtype) const {
+std::string OperationExpression::GetExpression(
+    std::string dtype, std::unordered_set<int>* used) const {
   std::stringstream ret;
   if (IsSupport()) {
     for (size_t i = 0; i < output_ids_.size(); ++i) {
-      ret << dtype << " " << GetLHS(i) << " = " << GetRHS(i) << ";";
+      ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
     }
   }
 
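A standalone sketch of the substitution GetRHS performs may clarify this hunk. It assumes the expression templates stored in OperationMap use ${i}-style placeholders, which is consistent with the `length + 3` passed to replace (three framing characters around the index digits); the function name and the simplified error handling are hypothetical:

    #include <string>
    #include <unordered_set>
    #include <vector>

    // Replace each ${i} placeholder with the temporary name of input_ids[i],
    // and record every input id that the expression actually references.
    std::string SubstitutePlaceholders(std::string rhs,
                                       const std::vector<int>& input_ids,
                                       std::unordered_set<int>* used) {
      size_t pos = rhs.find("${");
      while (pos != std::string::npos) {
        size_t end = rhs.find('}', pos);
        int index = std::stoi(rhs.substr(pos + 2, end - pos - 2));
        // The real code also validates index < input_ids.size() and
        // input_ids[index] >= 0 via PADDLE_ENFORCE_LT / PADDLE_ENFORCE_GE.
        rhs.replace(pos, end - pos + 1, "tmp" + std::to_string(input_ids[index]));
        used->insert(input_ids[index]);
        pos = rhs.find("${");
      }
      return rhs;
    }

For example, with input_ids = {5, 7}, the template "${0} * ${1}" becomes "tmp5 * tmp7" and `used` ends up holding {5, 7}.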
7 changes: 4 additions & 3 deletions paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
@@ -14,10 +14,10 @@ limitations under the License. */
 
 #pragma once
 
-#include <set>
 #include <sstream>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "paddle/fluid/platform/enforce.h"
@@ -47,11 +47,12 @@
   // Check whether this operation type is supported in OperationMap.
   bool IsSupport() const;
 
-  std::string GetExpression(std::string dtype) const;
+  std::string GetExpression(std::string dtype,
+                            std::unordered_set<int>* used) const;
 
  private:
   // TODO(wangchao): make offset more flexible we add stride and basic offset
-  std::string GetRHS(size_t i = 0) const;
+  std::string GetRHS(std::unordered_set<int>* used, size_t i = 0) const;
   std::string GetLHS(size_t i = 0) const;
 
  private:
32 changes: 17 additions & 15 deletions paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
@@ -176,20 +176,22 @@ void TestMainImpl(std::string func_name, std::string code_str,
   std::vector<void*> args;
   args.push_back(&n);
 
-  for (size_t i = 0; i < input_ids.size(); ++i) {
-    gpu_ptrs[input_ids[i]] = gpu_tensors[input_ids[i]].mutable_data<float>(
-        cpu_tensors[input_ids[i]].dims(), place);
-    args.push_back(&gpu_ptrs[input_ids[i]]);
-
-    fusion_group::SetupRandomCPUTensor<float>(&cpu_tensors[input_ids[i]]);
-    TensorCopySync(cpu_tensors[input_ids[i]], place,
-                   &gpu_tensors[input_ids[i]]);
+  for (auto id : input_ids) {
+    if (id >= 0) {
+      gpu_ptrs[id] =
+          gpu_tensors[id].mutable_data<float>(cpu_tensors[id].dims(), place);
+      fusion_group::SetupRandomCPUTensor<float>(&cpu_tensors[id]);
+      TensorCopySync(cpu_tensors[id], place, &gpu_tensors[id]);
+    } else {
+      gpu_ptrs[id] = nullptr;
+    }
+    args.push_back(&gpu_ptrs[id]);
   }
 
-  for (size_t i = 0; i < output_ids.size(); ++i) {
-    gpu_ptrs[output_ids[i]] = gpu_tensors[output_ids[i]].mutable_data<float>(
-        cpu_tensors[output_ids[i]].dims(), place);
-    args.push_back(&gpu_ptrs[output_ids[i]]);
+  for (auto id : output_ids) {
+    gpu_ptrs[id] =
+        gpu_tensors[id].mutable_data<float>(cpu_tensors[id].dims(), place);
+    args.push_back(&gpu_ptrs[id]);
   }
 
   device_code.SetNumThreads(1024);
@@ -200,9 +202,9 @@
       paddle::platform::DeviceContextPool::Instance().Get(place));
   dev_ctx->Wait();
 
-  for (size_t i = 0; i < output_ids.size(); ++i) {
-    TensorCopySync(gpu_tensors[output_ids[i]], paddle::platform::CPUPlace(),
-                   &cpu_tensors[output_ids[i]]);
+  for (auto id : output_ids) {
+    TensorCopySync(gpu_tensors[id], paddle::platform::CPUPlace(),
+                   &cpu_tensors[id]);
   }
 }
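The test harness now mirrors the generator's new convention: a negative input id denotes an input the fused kernel never reads, so its slot in the argument list is filled with a null device pointer while the list stays aligned with the kernel signature. A minimal sketch of that marshalling pattern, with a hypothetical AllocateAndFillDeviceBuffer standing in for the mutable_data / SetupRandomCPUTensor / TensorCopySync sequence above:

    #include <cstddef>
    #include <vector>

    // Hypothetical helper: allocates a device buffer for tensor `id`, fills the
    // matching CPU tensor with random data, and copies it to the device.
    void* AllocateAndFillDeviceBuffer(int id);

    void MarshalInputArgs(const std::vector<int>& input_ids,
                          std::vector<void*>* device_ptrs,
                          std::vector<void*>* args) {
      device_ptrs->assign(input_ids.size(), nullptr);
      for (size_t i = 0; i < input_ids.size(); ++i) {
        if (input_ids[i] >= 0) {
          (*device_ptrs)[i] = AllocateAndFillDeviceBuffer(input_ids[i]);
        }  // a negative id keeps its slot, but the pointer stays nullptr
        // The launch API receives the address of each argument value.
        args->push_back(&(*device_ptrs)[i]);
      }
    }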
