Skip to content

Latest commit

 

History

History
27 lines (16 loc) · 919 Bytes

shard_optimizer_cn.rst

File metadata and controls

27 lines (16 loc) · 919 Bytes

shard_optimizer

将单卡视角的优化器转变为分布式视角。可以通过指定 shard_fn 来定制化优化器状态的切分方式,否则会将参数的分布式信息传递给对应的优化器状态。

shard_fn 的函数签名为:def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator。

参数

  • optimizer (paddle.optimizer.Optimizer) - 单卡视角的优化器。
  • shard_fn (Callable,可选) - 用于切分优化器状态函数。如果没有指定,默认地我们将参数的分布式信息传递给对应的优化器状态。

返回

Optimizer:一个具有分布式视角的 Optimizer 对象。

代码示例

COPY-FROM: paddle.distributed.shard_optimizer