Commit

Merge remote-tracking branch 'origin/develop' into doc/api1
dzhwinter committed Jun 15, 2018
2 parents 4970414 + 0ddc5d8 commit f4a49cb
Showing 3 changed files with 62 additions and 37 deletions.
79 changes: 48 additions & 31 deletions paddle/fluid/framework/scope.cc
@@ -43,55 +43,37 @@ Scope& Scope::NewScope() const {
}

Variable* Scope::Var(const std::string& name) {
// acquire the lock when new var under this scope
std::unique_lock<std::mutex> lock(mutex_);
auto* v = FindVarLocally(name);
if (v != nullptr) return v;

v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
return VarInternal(name);
}

Variable* Scope::Var(std::string* name) {
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = var_name;
*name = new_name;
}
return Var(var_name);
return VarInternal(new_name);
}

Variable* Scope::FindVar(const std::string& name) const {
// acquire the lock when find var
std::unique_lock<std::mutex> lock(mutex_);
return FindVarInternal(name);
}

Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name);
if (var != nullptr) {
return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
}

const Scope* Scope::FindScope(const Variable* var) const {
for (auto& kv : vars_) {
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
std::unique_lock<std::mutex> lock(mutex_);
return FindScopeInternal(var);
}

void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s;
kids_.clear();
}

std::vector<std::string> Scope::LocalVarNames() const {
std::unique_lock<std::mutex> lock(mutex_);
std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
@@ -127,6 +109,39 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {

void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const {
std::unique_lock<std::mutex> lock(mutex_);
RenameInternal(origin_name, new_name);
}

std::string Scope::Rename(const std::string& origin_name) const {
std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
}

Variable* Scope::VarInternal(const std::string& name) {
auto* v = FindVarLocally(name);
if (v != nullptr) return v;

v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
}

const Scope* Scope::FindScopeInternal(const Variable* var) const {
for (auto& kv : vars_) {
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}

void Scope::RenameInternal(const std::string& origin_name,
const std::string& new_name) const {
auto origin_it = vars_.find(origin_name);
PADDLE_ENFORCE(origin_it != vars_.end(),
"Cannot find original variable with name %s", origin_name);
@@ -137,10 +152,12 @@ void Scope::Rename(const std::string& origin_name,
vars_.erase(origin_it);
}

std::string Scope::Rename(const std::string& origin_name) const {
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
Rename(origin_name, var_name);
return var_name;
Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name);
if (var != nullptr) {
return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
}

Variable* Scope::FindVarLocally(const std::string& name) const {
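The scope.cc changes follow a single pattern: each public method (Var, FindVar, FindScope, Rename) acquires mutex_ once at its entry point and delegates the real work to a lock-free private counterpart (VarInternal, FindVarInternal, FindScopeInternal, RenameInternal). A public method that needs several steps under the same lock, such as Var(std::string*) reading vars_.size() to generate a unique name and then creating the variable, can no longer forward to another locking public method without deadlocking on the non-recursive mutex, so it calls the Internal variant instead. Recursion into the parent scope still goes through the parent's public FindVar/FindScope, because the parent owns a separate mutex that is not held at that point. A minimal standalone sketch of the same layout, using a hypothetical Registry class rather than the real Scope:

```cpp
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical Registry class (not Paddle code) illustrating the layout this
// commit applies to Scope: public methods take the mutex exactly once, then
// delegate to private *Internal helpers that assume the lock is already held.
class Registry {
 public:
  int* GetOrCreate(const std::string& name) {
    std::unique_lock<std::mutex> lock(mutex_);  // single lock acquisition
    return GetOrCreateInternal(name);
  }

  // Mirrors Scope::Var(std::string*): the unique name must be generated from
  // items_.size() under the same lock, so forwarding to the public
  // GetOrCreate() would deadlock on the non-recursive mutex. The lock-free
  // Internal helper is used instead.
  int* GetOrCreateAnonymous(std::string* name_out) {
    std::unique_lock<std::mutex> lock(mutex_);
    std::string name = "anon_" + std::to_string(items_.size());
    if (name_out != nullptr) *name_out = name;
    return GetOrCreateInternal(name);
  }

 private:
  // Caller must already hold mutex_.
  int* GetOrCreateInternal(const std::string& name) {
    auto it = items_.find(name);
    if (it != items_.end()) return it->second.get();
    auto* value = new int(0);
    items_[name].reset(value);
    return value;
  }

  mutable std::mutex mutex_;
  std::map<std::string, std::unique_ptr<int>> items_;
};
```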
12 changes: 10 additions & 2 deletions paddle/fluid/framework/scope.h
@@ -88,12 +88,20 @@ class Scope {
// Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {}

// Called by Var.
Variable* VarInternal(const std::string& name);

// Called by FindScope.
const Scope* FindScopeInternal(const Variable* var) const;

// Called by Rename.
void RenameInternal(const std::string& origin_name,
const std::string& new_name) const;

// Called by FindVar recursively.
// Caller doesn't own the returned Variable.
Variable* FindVarInternal(const std::string& name) const;

// Called by FindVarInternal and Var.
// Caller doesn't own the returned Variable.
Variable* FindVarLocally(const std::string& name) const;

// Scope in `kids_` are owned by this class.
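In scope.h the new helpers are declared private, with comments recording which public method calls them and that callers do not own the returned Variable. A hedged usage sketch of the public API, illustrative only and assuming a build linked against the Paddle framework and that Scope is default-constructible as in this version of the header:

```cpp
#include <string>
#include <thread>
#include <vector>

#include "paddle/fluid/framework/scope.h"

// Illustrative only: every thread goes through the public Scope API, which
// after this commit serializes access to vars_ via the scope's mutex.
void Exercise(paddle::framework::Scope* scope, int tid) {
  std::string name = "var_" + std::to_string(tid);
  scope->Var(name);                    // locks, then delegates to VarInternal
  auto* found = scope->FindVar(name);  // locks, then FindVarInternal
  (void)found;  // the caller does not own the returned Variable
}

int main() {
  paddle::framework::Scope scope;
  std::vector<std::thread> workers;
  for (int tid = 0; tid < 4; ++tid) {
    workers.emplace_back(Exercise, &scope, tid);
  }
  for (auto& w : workers) w.join();
  return 0;
}
```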
8 changes: 4 additions & 4 deletions paddle/fluid/inference/tests/book/test_inference_nlp.cc
@@ -103,9 +103,9 @@ void ThreadRunInfer(
const int tid, paddle::framework::Scope* scope,
const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) {
// maybe framework:ProgramDesc is not thread-safe
paddle::platform::CPUPlace place;
paddle::framework::Executor executor(place);
auto& sub_scope = scope->NewScope();
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
auto inference_program =
paddle::inference::Load(&executor, scope, FLAGS_model_path);

@@ -182,8 +182,8 @@ TEST(inference, nlp) {
stop_ms = GetCurrentMs();
} else {
// 1. Define place, executor, scope
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
paddle::platform::CPUPlace place;
paddle::framework::Executor executor(place);

// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
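The test_inference_nlp.cc hunks only swap the construction style of the CPUPlace and Executor in the two code paths; the per-thread structure is unchanged: each worker owns its executor and runs inside a sub-scope of the shared parent scope, which is exactly the Scope traffic the new locking in scope.cc serializes. A hedged sketch of that per-thread layout, with program loading and feed/fetch plumbing omitted and include paths assumed from the Paddle source tree:

```cpp
#include <thread>
#include <vector>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

// Illustrative per-thread layout mirroring ThreadRunInfer: each worker owns
// its place and executor and works inside its own sub-scope of one shared
// parent scope.
void Worker(paddle::framework::Scope* scope) {
  paddle::platform::CPUPlace place;
  paddle::framework::Executor executor(place);
  auto& sub_scope = scope->NewScope();
  // ... load the inference program and run it with `executor` in `sub_scope`
  (void)executor;
  (void)sub_scope;
}

int main() {
  paddle::framework::Scope scope;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back(Worker, &scope);
  }
  for (auto& w : workers) w.join();
  scope.DropKids();  // the parent scope owns and deletes its sub-scopes
  return 0;
}
```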
