
support gradient merge with recompute, test=develop #27834

Merged

mapingshuo merged 5 commits into PaddlePaddle:develop on Oct 13, 2020

Conversation

mapingshuo
Contributor

PR types

Bug fixes

PR changes

APIs

Describe

Support gradient merge combined with recompute.
Usage:

import os

os.environ['FLAGS_enable_parallel_graph'] = "0"
os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = "0.98"
os.environ['FLAGS_sync_nccl_allreduce'] = "1"
os.environ['FLAGS_eager_delete_tensor_gb'] = "0"
os.environ['FLAGS_fuse_parameter_memory_size'] = "32"
os.environ['FLAGS_fuse_parameter_groups_size'] = "50"
os.environ['FLAGS_allocator_strategy'] = "naive_best_fit"

import numpy as np
import fleetx as X
import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import time
paddle.enable_static()

# FleetX helps users focus on training large-scale models.
# If you want to learn how to write a model, FleetX is not for you;
# it focuses on the engineering side of distributed training.

use_gradient_merge = True
use_recompute = True
use_amp = False

configs = X.parse_train_configs()
if use_gradient_merge:
    batch_size = 3
else:
    batch_size = 12
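# Note (added for clarity, not part of the original script): with gradient merge
# the per-step micro-batch shrinks to 3; since k_steps is set to 4 below, the
# effective batch per optimizer update is presumably 3 * 4 = 12, matching the
# non-merged batch size.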

configs.lr = 1e-4

fleet.init(is_collective=True)
# load Bert_large / Bert_base model
model = X.applications.BertLarge()
downloader = X.utils.Downloader()
local_path = downloader.download_from_bos(
    fs_yaml='https://fleet.bj.bcebos.com/small_datasets/yaml_example/wiki_cn.yaml',
    local_path='./data')
data_loader = model.get_train_dataloader(data_dir='{}'.format(local_path))
place = fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 2
exec_strategy.num_iteration_per_drop_scope = 1
dist_strategy = fleet.DistributedStrategy()
dist_strategy.execution_strategy = exec_strategy
dist_strategy.nccl_comm_num = 3

if use_amp:
    dist_strategy.amp = True

if use_gradient_merge:
    dist_strategy.gradient_merge = True
    dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
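    # k_steps = 4: gradients of 4 consecutive micro-batches are accumulated
    # before a single parameter update; avg = True averages rather than sums
    # the accumulated gradients. (Explanatory comment added for this writeup.)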

if use_recompute:
    dist_strategy.recompute = True
    dist_strategy.recompute_configs = {"checkpoints": model.checkpoints}
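    # Recompute drops intermediate activations between the listed checkpoint
    # tensors and recomputes them during the backward pass, trading extra
    # compute for lower activation memory. (Explanatory comment added here.)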


print("base lr: ", configs.lr)
if use_gradient_merge:
    scheduled_lr = X.utils.linear_warmup_decay(configs.lr, warmup_steps=16000,
                                               num_train_steps=1000000)
else:
    scheduled_lr = X.utils.linear_warmup_decay(configs.lr, warmup_steps=4000,
                                               num_train_steps=1000000)
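# Note (added for clarity): with gradient merge a parameter update happens only
# every k_steps micro-batches, so warmup_steps is presumably scaled by k_steps
# (4 * 4000 = 16000) to keep the warmup schedule comparable.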

optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr)

# if use_recompute:
#     optimizer = fluid.optimizer.RecomputeOptimizer(optimizer)
#     optimizer._set_checkpoints(model.checkpoints)
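# Note (added for clarity): the manual RecomputeOptimizer wrapping above is left
# commented out; with dist_strategy.recompute = True, fleet.distributed_optimizer
# is expected to apply recompute (together with gradient merge) through its meta
# optimizers, which is the combination this PR adds support for.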

optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)

optimizer.minimize(model.loss)

exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

if not os.path.isdir("program_desc"):
    os.mkdir("program_desc")

with open("program_desc/main_program.txt.%d" % (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f:
    f.write(str(fluid.default_main_program()))

with open("program_desc/startup_program.txt.%d" % (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f:
    f.write(str(fluid.default_startup_program()))

skip = 0
gap = 20
total_time = 0
costs = []
ppls = []
next_sent_accs = []
for i, data in enumerate(data_loader()):
    if i == skip:
        start_time = time.time()
    fetch_list = [model.loss.name] + \
                 list(model.target.values()) + \
                 [scheduled_lr.name]
    cost_val, next_sent_acc, lm_loss, np_lr = exe.run(
        fluid.default_main_program(),
        feed=data,
        fetch_list=fetch_list,
        use_program_cache=True)
    costs.append(cost_val[0])
    ppls.append(np.exp(lm_loss[0]))
    next_sent_accs.append(next_sent_acc[0])

    if i >= skip and i % gap == 0:
        end_time = time.time()
        total_time = (end_time - start_time)
        print("learning rate: ", np_lr[0])
        print(
            "worker_index: %d, step%d cost = %f, "
            "ppl: %f, next_sent_acc: %f, "
            "total time cost = %f, "
            " %f steps/s"
            % (fleet.worker_index(), i, np.mean(costs),
               np.mean(ppls), np.mean(next_sent_accs),
               total_time,
               (i - skip + 1) / total_time))
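For readers unfamiliar with gradient merge, the idea behind dist_strategy.gradient_merge can be sketched framework-free in plain Python. This is a minimal illustration added for this writeup, not Paddle code; micro_batches and compute_grad are hypothetical placeholders for a data stream and a backward pass.

import numpy as np

def sgd_with_gradient_merge(param, micro_batches, compute_grad,
                            lr=1e-4, k_steps=4, avg=True):
    """Accumulate gradients over k_steps micro-batches, then apply one SGD update."""
    acc = np.zeros_like(param)
    for step, batch in enumerate(micro_batches, start=1):
        acc += compute_grad(param, batch)   # backward pass on one micro-batch
        if step % k_steps == 0:             # only every k_steps micro-batches ...
            if avg:
                acc /= k_steps              # avg=True: average instead of sum
            param -= lr * acc               # ... perform the actual update
            acc[:] = 0.0                    # reset the accumulator
    return param

Recompute is orthogonal: it only changes how each micro-batch's backward pass obtains activations (recomputing them from the declared checkpoints instead of keeping them in memory), so the two strategies compose, which is what this PR wires into the GradientMergeOptimizer meta optimizer.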

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@mapingshuo mapingshuo requested review from wangxicoding and guru4elephant and removed request for wangxicoding October 13, 2020 02:11
@@ -17,20 +17,26 @@

class GradientMergeOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        print("init Meta GradientMergeOptimizer with {}".format(optimizer))
Contributor

remove this print?

Contributor Author

Thanks, I will remove this soon.

Contributor

@wangxicoding wangxicoding left a comment

LGTM

@mapingshuo mapingshuo merged commit 8d2cb14 into PaddlePaddle:develop Oct 13, 2020
chen-zhiyu pushed a commit to chen-zhiyu/Paddle that referenced this pull request Oct 15, 2020
* support gradient merge with recompute, test=develop
test=develop