
support offload in sharding stage2 #37904

Merged: 11 commits merged into PaddlePaddle:develop on Dec 9, 2021
Conversation

haohongxiang (Contributor) commented Dec 7, 2021

PR types

Function optimization

PR changes

Others

Describe

Support offload, grad_clip and loss_scaler in dygraph sharding stage2.
The performance of offload is further optimized in PR #38064.

1. User manual

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

# instantiate the model and optimizer
model = model_class(...)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=0.001,
        weight_decay=0.00001,
        grad_clip=clip,
        multi_precision=True)

# convert to a pure fp16 model
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler, group)

# convert to sharding_stage2 model and optimizer
optimizer = ShardingOptimizerStage2(params=model.parameters(), optim=optimizer, group=group, offload=True)
model = ShardingStage2(model, optimizer, group=group)

# forward, backward and optimization
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
model.clear_gradients()
```
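This script has to be launched with one process per GPU, e.g. `python -m paddle.distributed.launch --gpus "0,1" train.py` for the single-node, two-GPU setup used in the measurements below; `train.py` is a placeholder name for the user's training script, just as `model_class(...)` and `data` above stand in for the user's own model and input batch.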
2. Accuracy verification

PaddleNLP GPT-3 model, sharding stage2 + pure fp16, with vs. without offload:

[screenshot: accuracy comparison with and without offload]

3. GPU memory optimization

1) PaddleNLP GPT-3 model, 0.31B parameters
Single node, 2 GPUs, sharding stage2 + pure fp16 without offload: peak GPU memory is 5319 MiB. Memory usage curve:

[screenshot: GPU memory curve without offload]

Single node, 2 GPUs, sharding stage2 + pure fp16 with offload: peak GPU memory is 3137 MiB (saving 2182 MiB, about 41%). Memory usage curve:

[screenshot: GPU memory curve with offload]

2) PaddleNLP GPT-3 model, 1.02B parameters
Single node, 2 GPUs, sharding stage2 + pure fp16 without offload: peak GPU memory is 11941 MiB.
Single node, 2 GPUs, sharding stage2 + pure fp16 with offload: peak GPU memory is 5369 MiB (saving 6572 MiB, about 55%).

paddle-bot-old (bot) commented Dec 7, 2021

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

```python
if param.name not in self._master_params.keys():
    self._master_params[param.name] = core.VarBase(
        name=param.name,
        value=param.cast(dtype=Type.fp32.value).numpy(),
```
Contributor:

Please change this one to use .value().get_tensor() as well.

Contributor Author:

OK.
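For readers outside the review thread, here is a standalone sketch (not the PR's code) of the distinction the comment points at: `.numpy()` materializes an extra host copy of the cast parameter, whereas `.value().get_tensor()` hands back the dense tensor that already backs the VarBase in the pre-eager dygraph API of this Paddle version.

```python
# Standalone illustration only, not the PR's code; relies on the pre-eager
# dygraph API of the PR's era, where a Tensor is a core.VarBase.
import paddle

x = paddle.ones([2, 3], dtype='float16')
fp32 = x.cast('float32')

host_copy = fp32.numpy()                  # extra round trip through a numpy array copy
inner_tensor = fp32.value().get_tensor()  # the tensor already backing the VarBase, no numpy copy

print(type(host_copy))     # numpy.ndarray
print(type(inner_tensor))  # the core dense tensor type (LoDTensor in this Paddle version)
```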


```python
for param in self._local_params:
    if param.name in self._master_params.keys():
        param.set_value(self._master_params[param.name].cuda(dev_id))
```
Member:

This will increase GPU memory usage here. You need to release param first, and then share the data of the master parameter.

Contributor Author:

OK.
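To make the concern concrete, a standalone sketch (not the PR's code) of why copying the offloaded master weights back with `set_value` before releasing the parameter's existing storage briefly doubles the device memory held for that parameter:

```python
# Standalone illustration only, not the PR's code. `param` stands in for a
# sharded parameter on the GPU, `master` for its offloaded copy in host memory.
import paddle

param = paddle.randn([4096, 4096])   # existing GPU buffer owned by the parameter
master = param.cpu()                 # offloaded master copy on the CPU

# Naive copy-back: .cuda() allocates a brand-new GPU buffer for `master` while
# `param` still owns its old one, so both buffers coexist until set_value()
# finishes and the temporary is freed. The review asks to release param first
# and then share the master parameter's data instead, avoiding this peak.
param.set_value(master.cuda())
```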

ForFishes (Member) left a comment:

LGTM

ForFishes merged commit dfed4a6 into PaddlePaddle:develop on Dec 9, 2021