[with_data_parallel][install_check] remove with_data_parallel install_check (#50866)

* modify install check: change static graph parallel training to dynamic graph parallel training

* remove test code

* fix cyclic import

* fix typo
kangguangli committed Feb 28, 2023
1 parent 1e02769 commit 16a1b4a
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions python/paddle/utils/install_check.py
@@ -39,24 +39,14 @@ def _simple_network():
     return input, out, weight


-def _prepare_data(device_count):
+def _prepare_data():
     """
-    Prepare feeding data for simple network. The shape is [device_count, 2, 2].
+    Prepare feeding data for simple network. The shape is [1, 2, 2].
-    Args:
-        device_count (int): The number of devices.
     """
     # Prepare the feeding data.
     np_input_single = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
-    if device_count == 1:
-        return np_input_single.reshape(device_count, 2, 2)
-    else:
-        input_list = []
-        for i in range(device_count):
-            input_list.append(np_input_single)
-        np_input_muti = np.array(input_list)
-        np_input_muti = np_input_muti.reshape(device_count, 2, 2)
-        return np_input_muti
+    return np_input_single.reshape(1, 2, 2)


 def _is_cuda_available():
@@ -134,7 +124,7 @@ def _run_dygraph_single(use_cuda, use_xpu, use_npu):
     linear = paddle.nn.Linear(
         2, 4, weight_attr=weight_attr, bias_attr=bias_attr
     )
-    input_np = _prepare_data(1)
+    input_np = _prepare_data()
     input_tensor = paddle.to_tensor(input_np)
     linear_out = linear(input_tensor)
     out = paddle.tensor.sum(linear_out)
@@ -178,13 +168,55 @@ def _run_static_single(use_cuda, use_xpu, use_npu):
         exe.run(startup_prog)
         exe.run(
             train_prog,
-            feed={input.name: _prepare_data(1)},
+            feed={input.name: _prepare_data()},
             fetch_list=[out.name, param_grads[1].name],
         )
     paddle.disable_static()


-def _run_static_parallel(use_cuda, use_xpu, use_npu, device_list):
+def train_for_run_parallel():
+    """
+    train script for parallel training check
+    """
+
+    # to avoid cyclic import
+    class LinearNet(paddle.nn.Layer):
+        """
+        simple fc network for parallel training check
+        """
+
+        def __init__(self):
+            super(LinearNet, self).__init__()
+            self._linear1 = paddle.nn.Linear(10, 10)
+            self._linear2 = paddle.nn.Linear(10, 1)
+
+        def forward(self, x):
+            """
+            forward
+            """
+            return self._linear2(self._linear1(x))
+
+    paddle.distributed.init_parallel_env()
+
+    layer = LinearNet()
+    dp_layer = paddle.DataParallel(layer)
+
+    loss_fn = paddle.nn.MSELoss()
+    adam = paddle.optimizer.Adam(
+        learning_rate=0.001, parameters=dp_layer.parameters()
+    )
+
+    inputs = paddle.randn([10, 10], 'float32')
+    outputs = dp_layer(inputs)
+    labels = paddle.randn([10, 1], 'float32')
+    loss = loss_fn(outputs, labels)
+
+    loss.backward()
+    adam.step()
+    adam.clear_grad()
+
+
+def _run_parallel(device_list):
     """
     Testing the simple network in data parallel mode, using multiple CPU/GPU.
@@ -194,39 +226,7 @@ def _run_static_parallel(use_cuda, use_xpu, use_npu, device_list):
         use_npu (bool): Whether running with NPU.
         device_list (int): The specified devices.
     """
-    paddle.enable_static()
-    with paddle.static.scope_guard(paddle.static.Scope()):
-        train_prog = paddle.static.Program()
-        startup_prog = paddle.static.Program()
-        with paddle.static.program_guard(train_prog, startup_prog):
-            input, out, _ = _simple_network()
-            loss = paddle.tensor.mean(out)
-            loss.persistable = True
-            paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
-
-            compiled_prog = paddle.static.CompiledProgram(
-                train_prog
-            ).with_data_parallel(loss_name=loss.name, places=device_list)
-
-        if use_cuda:
-            place = paddle.CUDAPlace(0)
-        elif use_xpu:
-            place = paddle.XPUPlace(0)
-            compiled_prog = train_prog
-        elif use_npu:
-            place = paddle.NPUPlace(0)
-            compiled_prog = train_prog
-        else:
-            place = paddle.CPUPlace()
-
-        exe = paddle.static.Executor(place)
-        exe.run(startup_prog)
-        exe.run(
-            compiled_prog,
-            feed={input.name: _prepare_data(len(device_list))},
-            fetch_list=[loss.name],
-        )
-    paddle.disable_static()
+    paddle.distributed.spawn(train_for_run_parallel, nprocs=len(device_list))


 def run_check():
@@ -280,7 +280,7 @@ def run_check():
print("PaddlePaddle works well on 1 {}.".format(device_str))

try:
_run_static_parallel(use_cuda, use_xpu, use_npu, device_list)
_run_parallel(device_list)
print(
"PaddlePaddle works well on {} {}s.".format(
device_count, device_str

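For context, the code path this diff rewires is reached through the public paddle.utils.run_check() entry point defined in this file. A minimal usage sketch follows; the comments are an editorial summary of the post-commit behaviour, not lines taken from the diff:

import paddle

# Runs the single-device dygraph and static checks first; for the devices it
# detects, it then runs the data-parallel check by spawning one process per
# device with paddle.distributed.spawn and training the small DataParallel
# model defined in train_for_run_parallel.
paddle.utils.run_check()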