add sharding for gpt3 #1064
Conversation
Please shorten the import path of DygraphShardingOptimizer, since it's not developer-friendly.
@@ -30,26 +30,20 @@
import lr
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
This import is too long. Is it OK to update the __init__ file to shorten the import path?
done.
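For reference, a minimal sketch of the re-export approach discussed here; the exact package layout is assumed from the import path above:

# paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py
# Re-exporting the class gives callers a shorter import path (hypothetical layout).
from .dygraph_sharding_optimizer import DygraphShardingOptimizer

__all__ = ["DygraphShardingOptimizer"]

Callers can then import it one level up:

from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer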
paddle.seed(basic_seed + dp_rank)

# local_seed/ global_seed is used to control dropout in ModelParallel
local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
Can all the seed-related setup be moved in here? @zhaoyuchen2018
 MODEL_CLASSES = {
     "gpt": (GPTForPretraining, GPTTokenizer),
     "gpt-cn": (GPTForPretraining, GPTChineseTokenizer),
 }


-def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
+def set_hyrbid_parallel_seed(basic_seed, idx):
     assert args.device != "cpu"
Right, this should emit a warning instead. @ForFishes can add it later.
OK, will fix in the next PR.
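A minimal sketch of the warning suggested above, replacing the bare assert; the message wording is an assumption (logger is the one already used in this script):

if args.device == "cpu":
    # Hypothetical message: hybrid-parallel seed control assumes GPU devices.
    logger.warning("set_hyrbid_parallel_seed expects a non-CPU device; "
                   "dropout seed control may not work as intended on CPU.")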
@@ -174,7 +184,7 @@ def do_train(args):
clip = None
if args.grad_clip > 0:
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)
Why was ClipGradByGlobalNorm changed?
ClipGradByGlobalNorm is now supported.
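For context, a minimal sketch of how the clip object from the diff is typically consumed; the AdamW call and the lr_scheduler name are illustrative, not the PR's exact code:

clip = None
if args.grad_clip > 0:
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,      # assumed scheduler variable
    parameters=model.parameters(),
    grad_clip=clip)                  # global-norm clipping applied each step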
@@ -227,8 +253,8 @@ def do_train(args):
     args,
     data_file,
     local_rank=local_rank,
-    data_world_size=args.dp_degree,
-    data_world_rank=dp_rank,
+    data_world_size=worker_num,
This can't be changed like this; MP has to be taken into account.
data_world_size only needs to cover the ranks that read different data, right? The dp and sharding groups read different data, while mp and pp read the same data, so data_world_size = dp_degree * sharding_degree.
I misread it; you redefined worker_index and worker_num. Sorry!
sharding_rank = hcg.get_sharding_parallel_rank()
sharding_size = hcg.get_sharding_parallel_world_size()
worker_index = dp_rank * sharding_size + sharding_rank
How about renaming: worker_index -> data_world_rank, worker_num -> data_world_size?
done.
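A minimal sketch of the renamed arithmetic, assuming fleet's HybridCommunicateGroup accessors (the data-parallel getters mirror the sharding getters quoted above):

hcg = fleet.get_hybrid_communicate_group()
dp_rank = hcg.get_data_parallel_rank()
dp_size = hcg.get_data_parallel_world_size()
sharding_rank = hcg.get_sharding_parallel_rank()
sharding_size = hcg.get_sharding_parallel_world_size()

# dp and sharding groups read different data; mp and pp read the same data,
# so only dp and sharding ranks define the data-reading world.
data_world_rank = dp_rank * sharding_size + sharding_rank
data_world_size = dp_size * sharding_size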
 if (global_step % args.save_steps == 0 or
-        global_step >= args.max_steps) and dp_rank == 0:
+        global_step >= args.max_steps) and worker_index == 0:
The static-graph version has a dedicated save_persistable for saving. Does the sharding here need its own dedicated save support?
done.
LGTM
}

strategy.pipeline_configs = {
    "accumulate_steps": args.local_batch_size // args.micro_batch_size,
    "micro_batch_size": args.micro_batch_size
}

strategy.tensor_parallel_configs = {"tensor_init_seed": 123, }
What is tensor_init_seed for?
It provides seed control in tensor parallel.
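A minimal sketch of where tensor_init_seed sits in the strategy setup, assuming fleet's DistributedStrategy and the degree arguments used elsewhere in this PR:

from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": args.dp_degree,
    "mp_degree": args.mp_degree,
    "pp_degree": args.pp_degree,
    "sharding_degree": args.sharding_degree,
}
# tensor_init_seed gives all tensor-parallel ranks a common seed for weight
# initialization, keeping the shards of one layer consistent with each other.
strategy.tensor_parallel_configs = {"tensor_init_seed": 123}
fleet.init(is_collective=True, strategy=strategy)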
@@ -322,24 +348,25 @@ def do_train(args):
     logger.info("Save model to %s" % output_dir)

     if args.pp_degree > 1:
         model_to_save.save_state_dict(output_dir)
-        if mp_rank * pp_rank == 1:
+        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
So the current strategy is that all parameters are saved under dp_rank == 0.
For now it's fine not to add parameter loading, but please confirm whether loading a checkpoint would cause any problems.
Also, record a TODO for merging the parameters.
The logic of saving all parameters under dp_rank == 0 is fine, but model saving hasn't been tested yet, so there may be issues, e.g. whether the full set of parameters is actually stored.
 local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

 # seed control in hybrid parallel
-set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)
+set_hyrbid_parallel_seed(args.seed, data_world_rank)
Let's set all the seeds here.
Updated.
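A minimal sketch of consolidating all seed setup in one function, following the snippets quoted in this thread; how mp_rank and pp_rank are obtained inside is an assumption:

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

def set_hyrbid_parallel_seed(basic_seed, data_world_rank):
    assert args.device != "cpu"
    hcg = fleet.get_hybrid_communicate_group()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()

    # Global seed differs per data-reading worker so each loads different data.
    paddle.seed(basic_seed + data_world_rank)

    # local_seed controls dropout inside ModelParallel: it must differ across
    # mp/pp ranks, while the global seed stays shared within a model replica.
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    tracker = get_rng_state_tracker()
    tracker.add("global_seed", basic_seed + data_world_rank)
    tracker.add("local_seed", local_seed)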
tokenizer.save_pretrained(output_dir)
model_to_save.save_state_dict(output_dir)
Model saving here hasn't been tested in detail; record a TODO for it.
Updated.
LGTM
PR types
New features
PR changes
Models
Description
Add sharding parallel for GPT-3.