Skip to content

Commit

Permalink
[DNNL] 3D Fully-Connected (#21746)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sand3r- authored and luotao1 committed Jan 3, 2020
1 parent c1fea3e commit 6192108
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 129 deletions.
7 changes: 4 additions & 3 deletions paddle/fluid/framework/ir/fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,15 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
// This is to add padding for dimension 128 on concern of MKL performance
auto* scope = param_scope();
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
auto place = weight->place();
bool use_gpu = Get<bool>("use_gpu");
auto* weight_data = weight->data<float>();
auto weight_dims = weight->dims();
int weight_num = product(weight_dims);
int w_h = weight_dims[0];
int w_w = weight_dims[1];
if (!use_gpu) {
bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false;
bool use_fc_padding =
Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true;
if (!use_gpu && use_fc_padding) {
if (w_h % 128 == 0 && w_w % 128 == 0) {
auto* weight_data_tmp = new float[weight_num];
for (int i = 0; i < w_h; i++) {
Expand Down
38 changes: 37 additions & 1 deletion paddle/fluid/inference/analysis/ir_pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,47 @@ void IRPassManager::CreatePasses(Argument *argument,
}
}

bool IRPassManager::HasPass(const std::string &pass_type) {
if (passes_.empty()) return false;
auto it = std::find_if(
passes_.begin(), passes_.end(),
[&](std::unique_ptr<Pass> &pass) { return pass->Type() == pass_type; });
return it != passes_.end();
}

std::unique_ptr<Pass> &IRPassManager::GetPass(const std::string &pass_type) {
PADDLE_ENFORCE_EQ(passes_.empty(), false,
platform::errors::PreconditionNotMet(
"The list of passes cannot be empty."));
auto it = std::find_if(passes_.begin(), passes_.end(),
[&](const std::unique_ptr<Pass> &pass) {
return pass->Type() == pass_type;
});
PADDLE_ENFORCE_NE(it, passes_.end(),
platform::errors::PermissionDenied(
"You cannot get pass which was not added earlier."));
return *it;
}

// Some passes depend on each other. This method serves for exchanging
// information between them.
void IRPassManager::UpdatePasses() {
// Update padding settings for fc_fuse_pass. Skipp adding padding for
// MKL-DNN-based FC
bool use_fc_padding = !HasPass("fc_mkldnn_pass");
if (HasPass("fc_fuse_pass")) {
auto &fc_fuse_pass = GetPass("fc_fuse_pass");
fc_fuse_pass->Set<bool>("use_fc_padding", new bool(use_fc_padding));
}
}

std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE(graph.get());
PADDLE_ENFORCE_NOT_NULL(graph.get(), platform::errors::PreconditionNotMet(
"Graph cannot be NULL."));
UpdatePasses();
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/inference/analysis/ir_pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace inference {
namespace analysis {
using framework::ProgramDesc;
using framework::ir::Graph;
using framework::ir::Pass;

class IRPassManager final {
public:
Expand All @@ -53,9 +54,12 @@ class IRPassManager final {

private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
bool HasPass(const std::string &pass_type);
std::unique_ptr<Pass> &GetPass(const std::string &pass_type);
void UpdatePasses();

std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<framework::ir::Pass>> passes_;
std::vector<std::unique_ptr<Pass>> passes_;
bool disable_logs_{false};
};

Expand Down
Loading

0 comments on commit 6192108

Please sign in to comment.