[AutoParallel] add 'to_static' in engine api #44202

Merged
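For context, a minimal usage sketch of the flow this PR enables, distilled from the new test_to_static.py further down in this diff. The RandomDataset, the Linear model, and the shapes are illustrative placeholders, not part of this PR; Engine, InputSpec, prepare/fit/predict are the names used in the diff below.

import numpy as np
import paddle
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.distributed.auto_parallel.engine import Engine

class RandomDataset(Dataset):
    # Placeholder data, mirroring MyDataset in the new test_to_static.py below.
    def __init__(self, num_samples):
        super(RandomDataset, self).__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        x = np.random.uniform(size=8).astype("float32")
        label = np.random.randint(0, 2, dtype="int64")
        return x, label

    def __len__(self):
        return self.num_samples

model = paddle.nn.Linear(8, 2)  # any dygraph-defined paddle.nn.Layer
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.SGD(learning_rate=1e-5, parameters=model.parameters())

# The model and optimizer are authored in dynamic graph mode; Engine converts
# the model with `to_static` during prepare() and switches to static mode.
engine = Engine(model=model,
                inputs_spec=InputSpec([4, 8], 'float32', 'x'),
                labels_spec=InputSpec([4], 'int64', 'label'),
                strategy=None)
engine.prepare(optimizer=optimizer, loss=loss, metrics=paddle.metric.Accuracy())
engine.fit(RandomDataset(32), batch_size=4)
engine.predict(RandomDataset(32), batch_size=4)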
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -125,6 +125,9 @@ def __init__(self,
# A flag indicating whether the parallelism in use is data parallel
self._data_parallel = False

# A flag indicating whether `to_static` is used
self._dygraph_mode = True

@property
def serial_main_program(self):
return self._serial_main_program
159 changes: 137 additions & 22 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -21,14 +21,15 @@

from paddle import fluid, static
from paddle.io import Dataset
from paddle.jit import to_static
from paddle.metric import Metric
from paddle.static import InputSpec
from paddle.fluid import core
from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope, _to_name_str
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
@@ -82,6 +83,7 @@ def __init__(self,
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {}
self._dygraph_mode = False

def prepare(self,
optimizer=None,
@@ -131,27 +133,110 @@ def prepare(self,

def _build(self, mode):

serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return

losses = []
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))

if mode != "predict":
for metric in self._metrics:
metrics.extend(to_list(metric.compute(*(outputs + labels))))
if _non_static_mode() or self._dygraph_mode:
self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.")

# build forward main program
self.static_model = to_static(self.model,
input_spec=self.inputs_spec)
inputs = self.static_model.forward.inputs
outputs = self.static_model.forward.outputs
forward_main_prog = self.static_model.forward.main_program
forward_startup_prog = self.static_model.forward.concrete_program.startup_program
self.concrete_program = self.static_model.forward.concrete_program

# build loss main program
outputs_spec = []
outputs_name = []
for out in outputs:
outputs_spec.append(InputSpec(out.shape, out.dtype, out.name))
outputs_name.append(out.name)
if isinstance(self._loss, paddle.nn.Layer):
self.static_loss = to_static(self._loss.forward,
input_spec=outputs_spec +
self.labels_spec)
loss_main_prog = self.static_loss.main_program
elif callable(self._loss):
self.static_loss = to_static(self._loss,
input_spec=outputs_spec +
self.labels_spec)
loss_main_prog = self.static_loss.main_program

# build startup program
for param in self.concrete_program.parameters:
Parameter(name=param.name,
desc=param,
type=param.type,
shape=param.shape,
dtype=param.dtype,
stop_gradient=param.stop_gradient,
block=forward_startup_prog.global_block())

paddle.enable_static()

# NOTE: pruning the program would lose dist_attr
# feeded_var_names = [var.name for var in inputs]
# main_prog_0 = main_prog_0._prune_with_input(
# feeded_var_names=feeded_var_names, targets=outputs)

labels = []
losses = []
metrics = []
# concatenate the forward and loss programs
if mode != 'predict' and self._loss:
forward_block = forward_main_prog.global_block()
loss_block = loss_main_prog.global_block()
for idx, op in enumerate(loss_block.ops):
op_desc = forward_block.desc.append_op()
op_desc.copy_from(op.desc)
for in_name in op.input_arg_names:
if in_name in outputs_name:
continue
in_var = forward_block._clone_variable(
loss_block.vars[in_name], force_persistable=False)
if loss_block.vars[in_name].is_data:
labels.append(in_var)
for out_name in op.output_arg_names:
out_var = forward_block._clone_variable(
loss_block.vars[out_name], force_persistable=False)
if idx == len(loss_block.ops) - 1:
losses.append(out_var)
forward_block._sync_with_cpp()
serial_main_prog = forward_main_prog
serial_startup_prog = forward_startup_prog
# update metrics op in program
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
if mode != "predict":
for metric in self._metrics:
metrics.extend(
to_list(metric.compute(*(outputs + labels))))

else:
# build program in static mode
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return

losses = []
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))

if mode != "predict":
for metric in self._metrics:
metrics.extend(
to_list(metric.compute(*(outputs + labels))))

default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation:
@@ -172,6 +257,7 @@ def _build(self, mode):
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode

def _plan(self, mode):
if self._planned_mode is None:
@@ -236,6 +322,35 @@ def _initialize(self, mode):
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)

if self._dygraph_mode:
paddle.disable_static()
main_program = self._dist_main_progs[mode][self._cur_rank]
for param in self.concrete_program.parameters:
# create var in scope and share parameters to scope
if param.name not in main_program.global_block().vars:
continue
# get param_var's dist_attr
var = main_program.global_block().vars[param.name]
var_dist_attr = self._dist_contexts[
mode].get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
from .converter import Converter
param_tensor = global_scope().var(param.name).get_tensor()
sliced_param = Converter.slice_with_dist_attr(
param.numpy(), dist_attr)
shared_tensor = paddle.to_tensor(sliced_param,
place=self._place)
param_tensor._share_data_with(
shared_tensor.value().get_tensor())
paddle.enable_static()

if self._executor is None:
self._executor = paddle.static.Executor(self._place)
uninitialized = []
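As a pointer for the engine.py changes above, a minimal sketch of the `paddle.jit.to_static` handles that `_build` reads after converting the dygraph model. The tiny Linear layer and the InputSpec are illustrative; the attribute accesses (forward.inputs/outputs/main_program/concrete_program) mirror the diff above.

import paddle
from paddle.jit import to_static
from paddle.static import InputSpec

layer = paddle.nn.Linear(8, 2)  # illustrative dygraph model
static_layer = to_static(layer, input_spec=[InputSpec([None, 8], 'float32', 'x')])

# With an input_spec supplied, the concrete program can be traced without an
# actual call, which is what Engine._build relies on.
inputs = static_layer.forward.inputs
outputs = static_layer.forward.outputs
forward_main_prog = static_layer.forward.main_program
forward_startup_prog = static_layer.forward.concrete_program.startup_program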
11 changes: 9 additions & 2 deletions python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -15,8 +15,10 @@
import copy
from collections import defaultdict

import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.passes import new_pass

from .reshard import Resharder
@@ -110,9 +112,14 @@ def _generate_backward(self, main_program, startup_program, loss):

def _generate_optimizer(self, main_program, startup_program, optimizer,
params_grads):
if self._dist_context._dygraph_mode:
paddle.disable_static()
optimizer = copy.deepcopy(optimizer)
paddle.enable_static()
else:
optimizer = copy.deepcopy(optimizer)
with program_guard(main_program, startup_program):
optimizer_ops = copy.deepcopy(optimizer).apply_gradients(
params_grads)
optimizer_ops = optimizer.apply_gradients(params_grads)
self._completer.complete_update_annotation(main_program)
return optimizer_ops

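A short sketch of the mode toggle added to `_generate_optimizer` above, with hedged rationale not stated in the diff: an optimizer built in dygraph mode holds eager parameters, and deep-copying those is only supported in dynamic mode, so the copy is wrapped in disable_static/enable_static. The helper name `clone_optimizer` and its flag are hypothetical.

import copy
import paddle

def clone_optimizer(optimizer, built_in_dygraph):
    # Deep-copy the user optimizer before applying gradients so the original
    # object stays untouched; switch to dynamic mode when the optimizer was
    # constructed against dygraph (eager) parameters.
    if built_in_dygraph:
        paddle.disable_static()
        optimizer_copy = copy.deepcopy(optimizer)
        paddle.enable_static()
    else:
        optimizer_copy = copy.deepcopy(optimizer)
    return optimizer_copy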
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -53,4 +53,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
endif()
122 changes: 122 additions & 0 deletions python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py
@@ -0,0 +1,122 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import os
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
import paddle.distributed.fleet as fleet

from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine

batch_size = 4
batch_num = 30
hidden_size = 1024
class_num = 10


class MyDataset(Dataset):

def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples

def __getitem__(self, index):
input = np.random.uniform(size=hidden_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label

def __len__(self):
return self.num_samples


class MLPLayer(nn.Layer):

def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))

self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=None)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=None)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=None)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)

return out


class TestToStatic(unittest.TestCase):

def test_to_static(self):

mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
parameters=mlp.parameters())

dataset = MyDataset(batch_num * batch_size)

inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels = InputSpec([batch_size], 'int64', 'label')

engine = Engine(model=mlp,
inputs_spec=inputs,
labels_spec=labels,
strategy=None)
assert _non_static_mode() == True

engine.prepare(optimizer=optimizer,
loss=loss,
metrics=paddle.metric.Accuracy())

assert _non_static_mode() == False
engine.fit(dataset, batch_size=batch_size)
engine.evaluate(dataset, batch_size=batch_size)
engine.predict(dataset, batch_size=batch_size)


if __name__ == "__main__":
unittest.main()