
Tensor fusion for data parallel (#634)
* tensor fusion for dp
* update the if else
* update readme
* update readme
* update yaml
* update
FeixLiu committed Aug 16, 2022
1 parent 3329919 commit 3899b47
Showing 6 changed files with 18 additions and 6 deletions.
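Before the per-file diffs, a quick picture of what tensor fusion buys in data-parallel training: instead of launching one all-reduce per parameter gradient, many gradients are packed into a few contiguous buffers and synchronized with far fewer collective calls. The following is a framework-free sketch of that idea in NumPy; the function names (`fuse`, `unfuse`, `simulated_all_reduce`) are invented for illustration and are not FleetX APIs.

```python
# Conceptual sketch of tensor fusion for data-parallel gradient sync (NumPy only).
# Not FleetX's implementation; all names here are illustrative.
import numpy as np

def fuse(grads):
    """Pack per-parameter gradients into one contiguous buffer."""
    shapes = [g.shape for g in grads]
    buffer = np.concatenate([g.ravel() for g in grads])
    return buffer, shapes

def unfuse(buffer, shapes):
    """Split the fused buffer back into per-parameter gradients."""
    out, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(buffer[offset:offset + size].reshape(shape))
        offset += size
    return out

def simulated_all_reduce(buffers):
    """Stand-in for a collective all-reduce: average the same buffer across ranks."""
    return np.mean(buffers, axis=0)

# Two "workers" holding gradients for the same three parameters.
rank0 = [np.ones((2, 3)), np.full((4,), 2.0), np.full((3, 3), 4.0)]
rank1 = [np.zeros((2, 3)), np.zeros((4,)), np.zeros((3, 3))]

buf0, shapes = fuse(rank0)
buf1, _ = fuse(rank1)

# One collective call on the fused buffer instead of one call per parameter.
reduced = simulated_all_reduce([buf0, buf1])
synced_grads = unfuse(reduced, shapes)
print([float(g.mean()) for g in synced_grads])  # [0.5, 1.0, 2.0]
```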
examples/gpt/gpt_module.py (7 changes: 4 additions & 3 deletions)

@@ -76,8 +76,9 @@ def training_step_end(self, loss, epoch, step, reader_cost, train_cost):
     def configure_optimizers(self):
         if self.args.decay_steps is None:
             self.args.decay_steps = self.args.max_steps
+        self.decay_fused_tensors, self.all_fused_tensors = None, None
         if self.args.tensor_fusion:
-            decay_fused_tensors, all_fused_tensors = fused_parameters(
+            self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(
                 self.model)
         warmup_step = self.args.warmup_rate * self.args.decay_steps
         lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
@@ -93,7 +94,7 @@ def configure_optimizers(self):
         # Generate parameter names needed to perform weight decay.
         # All bias and LayerNorm parameters are excluded.
         if self.args.tensor_fusion:
-            decay_params = [p.name for p in decay_fused_tensors]
+            decay_params = [p.name for p in self.decay_fused_tensors]
         else:
             decay_params = [
                 p.name for n, p in self.model.named_parameters()
@@ -105,7 +106,7 @@ def configure_optimizers(self):
             beta1=self.args.adam_beta1,
             beta2=self.args.adam_beta2,
             epsilon=self.args.adam_epsilon,
-            parameters=all_fused_tensors
+            parameters=self.all_fused_tensors
             if self.args.tensor_fusion else self.model.parameters(),
            weight_decay=self.args.weight_decay,
            grad_clip=clip,
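The net effect of these hunks: the fused tensor lists returned by `fused_parameters` are now kept on the module as `self.decay_fused_tensors` / `self.all_fused_tensors` (so the training engine can reuse them for gradient synchronization), and the AdamW optimizer receives either the fused tensors or the raw model parameters. The snippet below is a minimal, FleetX-free sketch of the non-fused branch using plain Paddle; the bias/LayerNorm exclusion rule is an assumption, since the diff truncates the actual filter.

```python
import paddle

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(16, 16)
        self.norm = paddle.nn.LayerNorm(16)

    def forward(self, x):
        return self.norm(self.linear(x))

model = Block()

# Mirrors the non-fused branch: weight decay applies to everything except
# bias and LayerNorm parameters (exclusion rule assumed, see above).
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ("bias", "norm"))
]

# With tensor_fusion enabled, the diff instead stores the helper's outputs on
# the module and passes self.all_fused_tensors as `parameters` here.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-8,
    parameters=model.parameters(),
    weight_decay=0.01,
    apply_decay_param_fun=lambda name: name in decay_params,
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
)
print(decay_params)  # only linear.weight's framework name survives the filter
```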

examples/gpt/hybrid_parallel/README.md (2 changes: 2 additions & 0 deletions)

@@ -117,6 +117,7 @@ By default, GPT training uses the AdamW optimizer with cosine learning-rate decay; here, via
 output_dir: ./output
 ckpt_dir:
 fused_linear: False
+tensor_fusion: False
 ```
 
 Parameter descriptions:
@@ -146,6 +147,7 @@ By default, GPT training uses the AdamW optimizer with cosine learning-rate decay; here, via
 | output_dir | Specifies the output location |
 | ckpt_dir | Directory from which checkpoints are loaded |
 | fused_linear | Whether to use fused_linear in place of the standard Linear to speed up training. Note: this feature requires Paddle compiled with CUDA 11.6 or later. |
+| tensor_fusion | Whether to use tensor_fusion to speed up training. Note: this option only supports the data-parallel mode. |
 
 
 ### Parallelism dimensions
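The README hunks document the new `tensor_fusion` switch next to the existing flags. Below is a small, hypothetical sketch of reading those flags with PyYAML; the real FleetX config loader is not shown in this commit, and nesting the keys under `PreTraining` simply follows the configs_1.3B_single_card.yaml hunk later in this diff.

```python
import yaml  # PyYAML

config_text = """
PreTraining:
  output_dir: ./output
  ckpt_dir:
  fused_linear: False
  tensor_fusion: True
"""

pretraining = yaml.safe_load(config_text)["PreTraining"]
# tensor_fusion may only be enabled for single-card or data-parallel training;
# the assert added in examples/gpt/tools.py later in this diff enforces that.
print("tensor_fusion:", pretraining["tensor_fusion"])
print("ckpt_dir:", pretraining["ckpt_dir"])  # the empty YAML value loads as None
```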

examples/gpt/single/README.md (1 change: 1 addition & 0 deletions)

@@ -107,6 +107,7 @@ By default, GPT training uses the AdamW optimizer with cosine learning-rate decay; here, via
 output_dir: ./output
 ckpt_dir:
 fused_linear: False
+tensor_fusion: False
 ```
 
 Parameter descriptions:

examples/gpt/single/configs_1.3B_single_card.yaml (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@ PreTraining:
   output_dir: ./output
   ckpt_dir:
   fused_linear: True
-  tensor_fusion: True
+  tensor_fusion: False
 
 Model:
   vocab_size: 50304

examples/gpt/tools.py (4 changes: 4 additions & 0 deletions)

@@ -98,6 +98,10 @@ def add_dict(config, k, v):
 
     args.test_iters = args.eval_iters * 10
 
+    if args.tensor_fusion:
+        assert args.mp_degree == 1 and args.pp_degree == 1 and args.sharding_degree == 1, \
+            "tensor_fusion only support single card train or data parallel train"
+
     if args.fused_linear and not is_fused_matmul_bias_supported():
         args.fused_linear = False
         logging.warning(
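The added guard only permits tensor fusion when the model-parallel, pipeline-parallel, and sharding degrees are all 1, i.e. single-card or pure data-parallel runs. A tiny standalone restatement of that check, using the same degree names as the assert above:

```python
def tensor_fusion_allowed(mp_degree, pp_degree, sharding_degree):
    # Only single-card / pure data-parallel setups pass.
    return mp_degree == 1 and pp_degree == 1 and sharding_degree == 1

for degrees in [(1, 1, 1), (2, 1, 1), (1, 4, 1), (1, 1, 2)]:
    status = "ok" if tensor_fusion_allowed(*degrees) else "rejected"
    print(degrees, "->", status)
# (1, 1, 1) -> ok         single card or data parallel
# (2, 1, 1) -> rejected   tensor (model) parallelism in play
# (1, 4, 1) -> rejected   pipeline parallelism in play
# (1, 1, 2) -> rejected   sharding in play
```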

fleetx/core/engine/eager_engine.py (8 changes: 6 additions & 2 deletions)

@@ -29,6 +29,7 @@
 from fleetx.utils import logger
 from fleetx.core.engine.basic_engine import BasicEngine
 from fleetx.core.module.basic_module import BasicModule
+from fleetx.utils.tensor_fusion_helper import all_reduce_parameters
 
 
 class EagerEngine(BasicEngine):
@@ -211,8 +212,11 @@ def _fit_impl(self, batch):
                                               paddle.DataParallel):
             with self._module.model.no_sync():
                 loss = self._model_forward_backward(batch)
-            fused_allreduce_gradients(
-                list(self._module.model.parameters()), None)
+            if not hasattr(self._module, "all_fused_tensors") or self._module.all_fused_tensors is None:
+                fused_allreduce_gradients(
+                    list(self._module.model.parameters()), None)
+            else:
+                all_reduce_parameters(self._module.all_fused_tensors, self._dp_group)
         else:
             loss = self._model_forward_backward(batch)
         self._optim_update_params()
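The engine now prefers a fused all-reduce path whenever the module exposes `all_fused_tensors`. The helper itself (`fleetx.utils.tensor_fusion_helper.all_reduce_parameters`) is not part of this diff; the sketch below shows one plausible shape for it, assuming the fused tensors are parameters whose gradients are summed across the data-parallel group and then averaged. Treat it as an illustration under those assumptions, not the actual implementation, and note it only makes sense inside an initialized distributed run (e.g. after `paddle.distributed.init_parallel_env()`).

```python
import paddle
import paddle.distributed as dist

def all_reduce_parameters_sketch(fused_tensors, group):
    """All-reduce each fused gradient buffer once over the data-parallel group."""
    world_size = group.nranks              # assumed Group attribute; see note above
    with paddle.no_grad():
        for tensor in fused_tensors:
            if tensor.grad is None:        # parameter never received a gradient
                continue
            dist.all_reduce(tensor.grad, group=group)  # defaults to a SUM reduction
            tensor.grad.scale_(1.0 / world_size)       # turn the sum into a mean
```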
