Tensor fusion for data parallel #634

Merged · 6 commits · Aug 16, 2022
examples/gpt/gpt_module.py (7 changes: 4 additions & 3 deletions)

@@ -76,8 +76,9 @@ def training_step_end(self, loss, epoch, step, reader_cost, train_cost):
     def configure_optimizers(self):
         if self.args.decay_steps is None:
             self.args.decay_steps = self.args.max_steps
+        self.decay_fused_tensors, self.all_fused_tensors = None, None
         if self.args.tensor_fusion:
-            decay_fused_tensors, all_fused_tensors = fused_parameters(
+            self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(
                 self.model)
         warmup_step = self.args.warmup_rate * self.args.decay_steps
         lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
@@ -93,7 +94,7 @@ def configure_optimizers(self):
         # Generate parameter names needed to perform weight decay.
         # All bias and LayerNorm parameters are excluded.
         if self.args.tensor_fusion:
-            decay_params = [p.name for p in decay_fused_tensors]
+            decay_params = [p.name for p in self.decay_fused_tensors]
         else:
             decay_params = [
                 p.name for n, p in self.model.named_parameters()
@@ -105,7 +106,7 @@
             beta1=self.args.adam_beta1,
             beta2=self.args.adam_beta2,
             epsilon=self.args.adam_epsilon,
-            parameters=all_fused_tensors
+            parameters=self.all_fused_tensors
             if self.args.tensor_fusion else self.model.parameters(),
             weight_decay=self.args.weight_decay,
             grad_clip=clip,
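The `fused_parameters` helper used above is not part of this diff. As a rough orientation only, the sketch below illustrates the contract that `configure_optimizers` relies on, assuming the helper splits parameters into decay and no-decay groups (biases and LayerNorm excluded from weight decay) and coalesces each group into a few fused tensors; `fused_parameters_sketch` is a hypothetical stand-in, not the fleetx implementation.

```python
# Hypothetical sketch of the contract behind fused_parameters(model).
# The real fleetx helper additionally coalesces each group into contiguous
# fused buffers; that step is omitted here.
def fused_parameters_sketch(model):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        # Biases and LayerNorm parameters are excluded from weight decay.
        if "bias" in name or "norm" in name:
            no_decay.append(param)
        else:
            decay.append(param)
    # Stand-ins for (decay_fused_tensors, all_fused_tensors).
    return decay, decay + no_decay
```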
examples/gpt/hybrid_parallel/README.md (2 changes: 2 additions & 0 deletions)

@@ -117,6 +117,7 @@
   output_dir: ./output
   ckpt_dir:
   fused_linear: False
+  tensor_fusion: False
 ```

 Parameter descriptions:
@@ -146,6 +147,7 @@
 | output_dir | Specifies the output directory |
 | ckpt_dir | Directory from which checkpoints are loaded |
 | fused_linear | Whether to use fused_linear in place of the regular Linear layer to speed up training. Note: this feature requires Paddle built with CUDA 11.6 or later. |
+| tensor_fusion | Whether to use the tensor_fusion feature to speed up training. Note: this option only supports data-parallel mode. |


 ### Parallelism dimensions
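As a usage note (not part of this PR): to satisfy the data-parallel-only constraint, every other parallelism degree must stay at 1. The excerpt below is hypothetical; the key names follow the `mp_degree`/`pp_degree`/`sharding_degree` arguments checked in `examples/gpt/tools.py`, and the actual config layout may differ.

```yaml
# Hypothetical excerpt: pure data parallelism (8 GPUs) with tensor fusion on.
PreTraining:
  fused_linear: True
  tensor_fusion: True    # only valid when mp/pp/sharding degrees are all 1

Distributed:
  dp_degree: 8
  mp_degree: 1
  pp_degree: 1
  sharding_degree: 1
```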
examples/gpt/single/README.md (1 change: 1 addition & 0 deletions)

@@ -107,6 +107,7 @@
   output_dir: ./output
   ckpt_dir:
   fused_linear: False
+  tensor_fusion: False
 ```

 Parameter descriptions:
examples/gpt/single/configs_1.3B_single_card.yaml (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@ PreTraining:
   output_dir: ./output
   ckpt_dir:
   fused_linear: True
-  tensor_fusion: True
+  tensor_fusion: False

 Model:
   vocab_size: 50304
examples/gpt/tools.py (4 changes: 4 additions & 0 deletions)

@@ -98,6 +98,10 @@ def add_dict(config, k, v):

     args.test_iters = args.eval_iters * 10

+    if args.tensor_fusion:
+        assert args.mp_degree == 1 and args.pp_degree == 1 and args.sharding_degree == 1, \
+            "tensor_fusion only support single card train or data parallel train"
+
     if args.fused_linear and not is_fused_matmul_bias_supported():
         args.fused_linear = False
         logging.warning(
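For clarity, the new check rejects any configuration that combines `tensor_fusion` with model, pipeline, or sharding parallelism. A hypothetical illustration, using `argparse.Namespace` as a stand-in for the parsed config:

```python
# Hypothetical illustration of the guard added in tools.py.
from argparse import Namespace

args = Namespace(tensor_fusion=True, mp_degree=2, pp_degree=1, sharding_degree=1)
if args.tensor_fusion:
    # Fails here: mp_degree=2 enables tensor (model) parallelism.
    assert args.mp_degree == 1 and args.pp_degree == 1 and args.sharding_degree == 1, \
        "tensor_fusion only support single card train or data parallel train"
```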
fleetx/core/engine/eager_engine.py (8 changes: 6 additions & 2 deletions)

@@ -29,6 +29,7 @@
 from fleetx.utils import logger
 from fleetx.core.engine.basic_engine import BasicEngine
 from fleetx.core.module.basic_module import BasicModule
+from fleetx.utils.tensor_fusion_helper import all_reduce_parameters


 class EagerEngine(BasicEngine):
@@ -211,8 +212,11 @@ def _fit_impl(self, batch):
                               paddle.DataParallel):
             with self._module.model.no_sync():
                 loss = self._model_forward_backward(batch)
-            fused_allreduce_gradients(
-                list(self._module.model.parameters()), None)
+            if not hasattr(self._module, "all_fused_tensors") or self._module.all_fused_tensors is None:
+                fused_allreduce_gradients(
+                    list(self._module.model.parameters()), None)
+            else:
+                all_reduce_parameters(self._module.all_fused_tensors, self._dp_group)
         else:
             loss = self._model_forward_backward(batch)
         self._optim_update_params()
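`all_reduce_parameters` itself is not shown in this diff. The sketch below is only an assumption of what such a helper might do: one all-reduce per fused buffer over the data-parallel group, followed by averaging, instead of one collective per parameter. The real `fleetx.utils.tensor_fusion_helper` implementation may differ.

```python
# Assumed behavior of an all_reduce_parameters-style helper
# (not the actual fleetx implementation).
import paddle
import paddle.distributed as dist


def all_reduce_parameters_sketch(fused_tensors, group):
    if group is None or group.nranks <= 1:
        return
    scale = 1.0 / group.nranks
    with paddle.no_grad():
        for tensor in fused_tensors:
            if tensor.grad is None:
                continue
            # One collective per fused buffer instead of one per parameter.
            dist.all_reduce(tensor.grad, group=group)
            tensor.grad.scale_(scale)
```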