Add fleet for transformer benchmark (#5164)
* add fleet, test=develop
lilong12 authored Jan 8, 2021
1 parent 8510560 commit 969939e
Showing 3 changed files with 65 additions and 19 deletions.
7 changes: 7 additions & 0 deletions PaddleNLP/benchmark/transformer/configs/transformer.big.yaml
@@ -96,4 +96,11 @@ dropout: 0.1
# Vocabularies in source and target should be same for weight sharing.
weight_sharing: True

# Use amp or not
use_amp: False
scale_loss: 1.0

# Whether to use multi-card/multi-node distributed training.
is_distributed: True

max_iter: None
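
Note on the new keys: YAML parses the bare token None as the string "None" rather than a null, so consumers of this config normalize it the same way train.py below handles random_seed (eval(str(...))). A minimal sketch, assuming the config is read with PyYAML (the repo's actual loader may differ):

    import yaml

    # Hedged sketch: load the benchmark config and normalize the new keys.
    with open("configs/transformer.big.yaml") as f:
        cfg = yaml.safe_load(f)

    max_iter = eval(str(cfg["max_iter"]))        # the string "None" becomes None
    is_distributed = bool(cfg["is_distributed"])
    use_amp = bool(cfg["use_amp"])
    scale_loss = float(cfg["scale_loss"])
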
4 changes: 4 additions & 0 deletions PaddleNLP/benchmark/transformer/static/run_pretrain.sh
@@ -0,0 +1,4 @@

python -m paddle.distributed.launch \
--gpus="0,1" \
train.py
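
The launcher starts one train.py process per entry in --gpus and exports per-worker environment variables; the FLAGS_selected_gpus value read in train.py below comes from it. A minimal sketch of the per-worker view (illustrative only, not part of the commit):

    import os

    # With --gpus="0,1", paddle.distributed.launch spawns two train.py
    # processes; each sees its own FLAGS_selected_gpus ("0" or "1").
    gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    print("this worker drives GPU", gpu_id)
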
73 changes: 54 additions & 19 deletions PaddleNLP/benchmark/transformer/static/train.py
@@ -10,6 +10,7 @@
from pprint import pprint

import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed as dist

from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
@@ -36,8 +37,14 @@ def parse_args():

def do_train(args):
paddle.enable_static()
places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
trainer_count = len(places)
if args.is_distributed:
fleet.init(is_collective=True)
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
places = paddle.CUDAPlace(gpu_id) if args.use_gpu else paddle.static.cpu_places()
trainer_count = 1 if args.use_gpu else len(places)
else:
places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
trainer_count = len(places)

# Set seed for CE
random_seed = eval(str(args.random_seed))
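
In the collective mode added here, each launched process owns a single GPU, which is why the place becomes one CUDAPlace and trainer_count drops to 1. A minimal sketch of the per-process view after fleet.init, assuming the script runs under paddle.distributed.launch:

    import paddle
    import paddle.distributed as dist
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    # Rank and world size come from the environment set up by the launcher;
    # every process builds the same program and trains on its own device.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
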
@@ -88,19 +95,38 @@ def do_train(args):
epsilon=float(args.eps),
parameters=transformer.parameters())

if args.is_distributed:
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
dist_strategy = fleet.DistributedStrategy()
dist_strategy.build_strategy = build_strategy
dist_strategy.execution_strategy = exec_strategy
dist_strategy.fuse_grad_size_in_MB = 16

if args.use_amp:
dist_strategy.amp = True
dist_strategy.amp_configs = {
'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
'init_loss_scaling': args.scale_loss,
}

optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
optimizer.minimize(avg_cost)

exe = paddle.static.Executor()
if args.is_distributed:
exe = paddle.static.Executor(places)
else:
exe = paddle.static.Executor()
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()

compiled_train_program = paddle.static.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
exe.run(startup_program)

build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()

compiled_train_program = paddle.static.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)

# the best cross-entropy value with label smoothing
loss_normalizer = -(
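
In the distributed branch the raw train_program is handed to the Executor and fleet.distributed_optimizer takes care of inserting the collective communication, while the single-process branch keeps the CompiledProgram/with_data_parallel route. A self-contained sketch of the fleet pattern on a toy network (names and sizes are illustrative, not the transformer):

    import os
    import numpy as np
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 10], dtype="float32")
        y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
        loss = paddle.mean(
            paddle.nn.functional.square_error_cost(
                paddle.static.nn.fc(x, size=1), y))

        strategy = fleet.DistributedStrategy()
        strategy.fuse_grad_size_in_MB = 16
        optimizer = fleet.distributed_optimizer(
            paddle.optimizer.Adam(learning_rate=1e-3), strategy=strategy)
        optimizer.minimize(loss)

    gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    exe = paddle.static.Executor(paddle.CUDAPlace(gpu_id))
    exe.run(startup_prog)

    # As in the distributed branch above, the rewritten main program is run directly.
    feed = {"x": np.random.rand(4, 10).astype("float32"),
            "y": np.random.rand(4, 1).astype("float32")}
    print(exe.run(main_prog, feed=feed, fetch_list=[loss.name]))
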
@@ -127,13 +153,22 @@ def do_train(args):
data = [data]
train_reader_cost = time.time() - batch_start

outs = exe.run(compiled_train_program,
feed=[{
'src_word': data[i][0],
'trg_word': data[i][1],
'lbl_word': data[i][2],
} for i in range(trainer_count)],
fetch_list=[sum_cost.name, token_num.name])
if args.is_distributed:
outs = exe.run(train_program,
feed=[{
'src_word': data[i][0],
'trg_word': data[i][1],
'lbl_word': data[i][2],
} for i in range(trainer_count)],
fetch_list=[sum_cost.name, token_num.name])
else:
outs = exe.run(compiled_train_program,
feed=[{
'src_word': data[i][0],
'trg_word': data[i][1],
'lbl_word': data[i][2],
} for i in range(trainer_count)],
fetch_list=[sum_cost.name, token_num.name])
scheduler.step()

train_batch_cost = time.time() - batch_start
@@ -176,7 +211,7 @@ def do_train(args):
batch_ips_avg.reset()

if step_idx % args.save_step == 0 and step_idx != 0:
if args.save_model:
if args.save_model and dist.get_rank() == 0:
model_path = os.path.join(
args.save_model, "step_" + str(step_idx), "transformer")
paddle.static.save(train_program, model_path)
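
Only rank 0 writes checkpoints so that workers sharing an output directory do not race on the same files. A checkpoint produced by paddle.static.save can later be restored into the same static program; a hedged sketch with a hypothetical step number:

    # Illustrative only; "step_1000" stands in for any saved step directory.
    exe = paddle.static.Executor(paddle.CUDAPlace(0))
    exe.run(startup_program)
    paddle.static.load(train_program,
                       os.path.join(args.save_model, "step_1000", "transformer"),
                       exe)
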
