[Dy2static] Refactor ProgramTranslator save_inference_model API (#24989)
* experimental refactoring, test=develop

* add TranslatedLayer & remove StaticModelRunner, test=develop

* revert tracedlayer change, test=develop

* fix test_mnist unittest error, test=develop

* add doc & examples, test=develop

* polish doc details, test=develop

* add imperative.jit module, test=develop

* change TranslatedLayer pos, test=develop

* adjust jit module import path, test=develop

* polish doc based on review feedback

* add SaveLoadConfig.separate_params to save params separately

* add Layer.buffer support, test=develop

* polish doc details based on review feedback, test=develop

* polish details based on review comments, test=develop

* add empty str check for param, test=develop

* add unittests, test=develop

* polish details based on review comments, test=develop

* remove blanks in comment, test=develop

* polish doc details, test=develop

* update imperative doc link, test=develop

* add api attr for load, test=develop
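In short, this commit replaces the StaticModelRunner save path with a jit-style save API plus a TranslatedLayer that is returned on load. A minimal usage sketch, assuming the API shape described by the bullets above (exact names such as the model_path argument and the fluid.dygraph.jit entry points are assumptions and may differ from the released API):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear, declarative

class SimpleNet(fluid.dygraph.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self._linear = Linear(10, 3)

    @declarative  # dygraph-to-static translation entry point
    def forward(self, x):
        return self._linear(x)

with fluid.dygraph.guard():
    net = SimpleNet()
    x = fluid.dygraph.to_variable(np.random.rand(4, 10).astype('float32'))
    net(x)  # run once so the translated program is cached

    # Save the translated program and parameters for inference.
    fluid.dygraph.jit.save(net, model_path="./simple_net")

    # Load it back as a TranslatedLayer that behaves like a dygraph Layer.
    loaded = fluid.dygraph.jit.load(model_path="./simple_net")
    pred = loaded(x)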
chenwhql committed Jul 14, 2020
1 parent 43f9f18 commit 41d2247
Showing 11 changed files with 1,917 additions and 549 deletions.
16 changes: 8 additions & 8 deletions paddle/fluid/operators/run_program_op.h
@@ -49,13 +49,13 @@ static void CheckInputVarStatus(const Variable &var,
       var.IsType<LoDTensor>(), true,
       platform::errors::InvalidArgument(
           "The input variable %s of "
-          "RunProgram(Grad)Op(StaticModelRunner) holds "
+          "RunProgram(Grad)Op holds "
           "wrong type. Expect type is LoDTensor, but receive type is %s.",
           var_name, platform::demangle(framework::ToTypeName(var.Type()))));
   PADDLE_ENFORCE_EQ(
       var.Get<LoDTensor>().IsInitialized(), true,
       platform::errors::InvalidArgument("The tensor in input variable %s of "
-                                        "RunProgram(Grad)Op(StaticModelRunner) "
+                                        "RunProgram(Grad)Op "
                                         "is not initialized.",
                                         var_name));
 }
@@ -68,35 +68,35 @@ static void CheckOutputVarStatus(const Variable &src_var,
         src_var.IsType<LoDTensor>(), true,
         platform::errors::InvalidArgument(
             "The output variable %s get from "
-            "RunProgram(Grad)Op(StaticModelRunner)'s internal scope holds "
+            "RunProgram(Grad)Op's internal scope holds "
             "wrong type. Expect type is LoDTensor, but receive type is %s.",
             var_name,
             platform::demangle(framework::ToTypeName(src_var.Type()))));
     PADDLE_ENFORCE_EQ(src_var.Get<LoDTensor>().IsInitialized(), true,
                       platform::errors::InvalidArgument(
                           "The tensor in output variable %s get from "
-                          "RunProgram(Grad)Op(StaticModelRunner)'s internal "
+                          "RunProgram(Grad)Op's internal "
                           "scope is not initialized.",
                           var_name));
   } else if (dst_var.IsType<SelectedRows>()) {
     PADDLE_ENFORCE_EQ(
         src_var.IsType<SelectedRows>(), true,
         platform::errors::InvalidArgument(
             "The output variable %s get from "
-            "RunProgram(Grad)Op(StaticModelRunner)'s internal scope holds "
+            "RunProgram(Grad)Op's internal scope holds "
             "wrong type. Expect type is SelectedRows, but receive type is %s.",
             var_name,
             platform::demangle(framework::ToTypeName(src_var.Type()))));
     PADDLE_ENFORCE_EQ(src_var.Get<SelectedRows>().value().IsInitialized(), true,
                       platform::errors::InvalidArgument(
                           "The tensor in output variable %s get from "
-                          "RunProgram(Grad)Op(StaticModelRunner)'s "
+                          "RunProgram(Grad)Op's "
                           "internal scope is not initialized.",
                           var_name));

   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "The RunProgram(Grad)Op(StaticModelRunner) only support output "
+        "The RunProgram(Grad)Op only support output "
         "variable of type LoDTensor or SelectedRows, "
         "but received variable %s's type is %s",
         var_name, platform::demangle(framework::ToTypeName(dst_var.Type()))));
@@ -143,7 +143,7 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
     auto *var = scope->FindVar(var_names[i]);
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::NotFound("The output variable %s is not in "
-                                        "RunProgram(Grad)Op(StaticModelRunner)'"
+                                        "RunProgram(Grad)Op'"
                                         "s internal scope.",
                                         var_names[i]));
     CheckOutputVarStatus(*var, *vars[i], var_names[i]);
4 changes: 4 additions & 0 deletions python/paddle/fluid/dygraph/__init__.py
@@ -44,6 +44,9 @@
 from . import jit
 from .jit import *

+from . import io
+from .io import *
+
 from . import static_runner
 from .static_runner import StaticModelRunner

@@ -63,5 +66,6 @@
 __all__ += learning_rate_scheduler.__all__
 __all__ += backward_strategy.__all__
 __all__ += jit.__all__
+__all__ += io.__all__
 __all__ += rnn.__all__
 __all__ += ['ProgramTranslator']
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -36,6 +36,7 @@
 from paddle.fluid.dygraph.base import param_guard
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
+from paddle.fluid.annotations import deprecated

 __all__ = ['ProgramTranslator', 'convert_to_static']

@@ -125,6 +126,9 @@ def __init__(self, func, args, kwargs):
         self._args = args
         self._kwargs = kwargs

+        dyfunc = getattr(func, '__wrapped__', func)
+        self._dyfunc_code = inspect.getsource(dyfunc)
+
     def is_method(self):
         return self._args and isinstance(self._args[0], layers.Layer)

Expand Down Expand Up @@ -198,7 +202,9 @@ def __key(self):
# Note: if dygraph function is a method of class,
# consider instance info as hash key.
if self.is_method():
return self._dyfunc, self._args[0]
# NOTE: we can use Layer's (instance + function code) as hash key.
# An instance will not hold two identical methods
return self._dyfunc_code, self._args[0]
else:
return self._dyfunc
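The effect of the new key: cache entries are per-instance but tolerant of decorator wrapping, since the source of the unwrapped function identifies the method and the bound instance separates otherwise identical Layers. A standalone toy sketch of the idea (illustrative only, not Paddle code):

import inspect

class FakeSpec:
    """Toy stand-in for FunctionSpec hashing: key = (source code, instance)."""

    def __init__(self, func, instance):
        dyfunc = getattr(func, '__wrapped__', func)  # unwrap decorated functions
        self._code = inspect.getsource(dyfunc)
        self._instance = instance

    def __hash__(self):
        return hash((self._code, id(self._instance)))

class Net:
    def forward(self, x):
        return x

a, b = Net(), Net()
spec_a, spec_b = FakeSpec(Net.forward, a), FakeSpec(Net.forward, b)
assert spec_a._code == spec_b._code  # same method source text
assert hash(spec_a) != hash(spec_b)  # but a distinct cache entry per instance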

@@ -312,6 +318,17 @@ def __getitem__(self, item):
             self._caches[item] = self._build_once(item)
         return self._caches[item]

+    def get_program(self, item):
+        if not isinstance(item, FunctionSpec):
+            raise ValueError(
+                "Input item's type should be FunctionSpec, but received %s" %
+                type(item))
+        if item not in self._caches:
+            raise RuntimeError(
+                "Failed to find program for input item, please decorate input function by `@declarative`."
+            )
+        return self._caches[item]
+
     def last(self):
         assert len(
             self._caches) >= 1, "No valid cached program in ProgramCache."
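A hedged sketch of how get_program is consumed on the save path, continuing the SimpleNet example: build a FunctionSpec for the already-translated forward and fetch its cached entry (treating _program_cache as the ProgramCache attribute of ProgramTranslator is an assumption here):

# Hypothetical lookup on the save path; get_program raises RuntimeError
# unless net.forward was decorated with @declarative and run at least once.
prog_trans = ProgramTranslator()
spec = FunctionSpec(net.forward, (net, x), {})  # (func, args, kwargs) per the diff
cached = prog_trans._program_cache.get_program(spec)  # attribute name assumed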
@@ -633,6 +650,7 @@ def func(x):
         source_code = ast_to_source_code(root_wrapper.node)
         return source_code

+    @deprecated(since='2.0', instead="paddle.imperative.jit.save")
     @switch_to_static_graph
     def save_inference_model(self, dirname, feed=None, fetch=None):
         """