Merge pull request #620 from FeixLiu/fused_linear_and_selective_recompute

fused linear and selective recompute
ForFishes committed Aug 11, 2022
2 parents da262d7 + 49ad9e0 commit 6c12050
Showing 11 changed files with 159 additions and 46 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -101,3 +101,4 @@ ENV/
.mypy_cache/

.DS_Store
.idea
4 changes: 4 additions & 0 deletions examples/gpt/hybrid_parallel/README.md
@@ -95,6 +95,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
num_train_epochs: 1
seed: 1024
use_recompute: False
recompute_granularity:
batch_size:
global_batch_size: 8
local_batch_size: 8
@@ -113,6 +114,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
save_steps: 1000
output_dir: ./output
ckpt_dir:
fused_linear: False
```

Parameter descriptions:
@@ -124,6 +126,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
| num_train_epochs | Number of training epochs |
| seed | Random seed, keeps the training run reproducible |
| use_recompute | Whether to train with recompute |
| recompute_granularity | Granularity of recompute during training; one of `full` or `only_attn`. `full` recomputes every transformer layer, while `only_attn` recomputes only the self-attention blocks (see the sketch after this file's diff) |
| global_batch_size | Global batch size, i.e. the effective batch size of one parameter update |
| local_batch_size | Batch size handled by each training process |
| micro_batch_size | Batch size of each forward pass |
@@ -138,6 +141,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
| save_steps | Interval (in steps) between model checkpoints |
| output_dir | Directory for output files |
| ckpt_dir | Directory to load checkpoints from |
| fused_linear | Whether to replace the standard Linear with fused_linear to speed up training. Note: this requires Paddle compiled with CUDA 11.6 or later |


### Parallelism dimensions
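The `recompute_granularity` option above only takes effect when `use_recompute` is enabled. As a rough illustration of the run-time difference between the two values, here is a hedged sketch (not code from this commit; `run_decoder_layer` is a hypothetical helper, simplified from the `modeling.py` changes later in this diff):

```python
# Hedged sketch: what "full" vs "only_attn" means for a single decoder layer.
# Assumes Paddle is installed; `layer` is any paddle.nn.Layer taking `hidden`.
from paddle.distributed.fleet.utils import recompute


def run_decoder_layer(layer, hidden, use_recompute, recompute_granularity):
    if use_recompute and recompute_granularity == "full":
        # Recompute the whole transformer layer during the backward pass
        # (largest memory saving, most recomputation).
        return recompute(layer, hidden)
    # With "only_attn", the layer itself wraps just its self-attention block
    # in recompute(), so it is called normally here (smaller memory saving,
    # less recomputation).
    return layer(hidden)
```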
4 changes: 3 additions & 1 deletion examples/gpt/hybrid_parallel/configs_1.3B_dp8.yaml
@@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 64
local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
4 changes: 3 additions & 1 deletion examples/gpt/hybrid_parallel/configs_175B_mp8_pp16.yaml
@@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 1536
local_batch_size: 1536
@@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 51200
4 changes: 3 additions & 1 deletion examples/gpt/hybrid_parallel/configs_6.7B_sharding16.yaml
@@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 128
local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
4 changes: 4 additions & 0 deletions examples/gpt/single/README.md
@@ -85,6 +85,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
num_train_epochs: 1
seed: 1024
use_recompute: False
recompute_granularity:
batch_size:
global_batch_size: 8
local_batch_size: 8
@@ -103,6 +104,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
save_steps: 1000
output_dir: ./output
ckpt_dir:
fused_linear: False
```

Parameter descriptions:
@@ -114,6 +116,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
| num_train_epochs | Number of training epochs |
| seed | Random seed, keeps the training run reproducible |
| use_recompute | Whether to train with recompute |
| recompute_granularity | Granularity of recompute during training; one of `full` or `only_attn`. `full` recomputes every transformer layer, while `only_attn` recomputes only the self-attention blocks |
| global_batch_size | Global batch size, i.e. the effective batch size of one parameter update |
| local_batch_size | Batch size handled by each training process |
| micro_batch_size | Batch size of each forward pass |
@@ -128,6 +131,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
| save_steps | Interval (in steps) between model checkpoints |
| output_dir | Directory for output files |
| ckpt_dir | Directory to load checkpoints from |
| fused_linear | Whether to replace the standard Linear with fused_linear to speed up training. Note: this requires Paddle compiled with CUDA 11.6 or later (a quick capability check is sketched after this section) |


## How to run
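Because `fused_linear` relies on the `fused_gemm_epilogue` kernel (present in Paddle builds compiled against CUDA 11.6 or later), a quick way to check whether a given installation supports it is sketched below. This mirrors the check added to `examples/gpt/tools.py` in this commit and uses only Paddle APIs that appear in the diff.

```python
# Hedged capability check: does this Paddle build support fused_linear?
import paddle
from paddle.fluid import core

print("compiled with CUDA :", paddle.is_compiled_with_cuda())
print("CUDA version       :", paddle.version.cuda())
print("fused_gemm_epilogue:", hasattr(core.ops, "fused_gemm_epilogue"))
```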
4 changes: 3 additions & 1 deletion examples/gpt/single/configs_1.3B_single_card.yaml
@@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 8
local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
4 changes: 3 additions & 1 deletion examples/gpt/single/configs_345m_single_card.yaml
@@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: False
recompute_granularity:
batch_size:
global_batch_size: 8
local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
25 changes: 25 additions & 0 deletions examples/gpt/tools.py
@@ -16,12 +16,14 @@
from __future__ import division
from __future__ import print_function

import logging
import os
import sys

import yaml
import paddle
import paddle.distributed as dist
from paddle.fluid import core
import argparse
from fleetx.datasets.gpt import create_pretrained_dataset, get_train_data_file

@@ -49,6 +51,13 @@ def process_batch_size(args):
assert args.local_batch_size % args.micro_batch_size == 0


def is_fused_matmul_bias_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(core.ops, 'fused_gemm_epilogue')
else:
return False


def model_size(args):
"""
get model size for transformer
@@ -84,6 +93,22 @@ def add_dict(config, k, v):

args.test_iters = args.eval_iters * 10

if args.fused_linear and not is_fused_matmul_bias_supported():
args.fused_linear = False
logging.warning("The flag fused_linear only valid for cuda version higher than 11.6, "
"but the paddle is compiled with cuda " + paddle.version.cuda())

if args.recompute:
assert args.recompute_granularity is None or \
isinstance(args.recompute_granularity, str), \
"recompute_granularity must be a None or a string object"
if args.recompute_granularity is None:
args.recompute_granularity = "full"
else:
assert args.recompute_granularity in ["full", "only_attn"], \
"recompute_granularity can be only chosen from " \
"full or only_attn, but received " + args.recompute_granularity

# process batch size
process_batch_size(args)

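To make the option handling added to `tools.py` easy to try in isolation, here is a hedged, standalone re-implementation of the same rules (illustrative only; `resolve_options` and the `SimpleNamespace` arguments are not part of the repository):

```python
# Standalone sketch of the fallback/defaulting rules added to tools.py.
import logging
from types import SimpleNamespace


def resolve_options(args, fused_supported):
    # fused_linear silently degrades to the plain Linear when the fused
    # GEMM-epilogue kernel is unavailable (builds with CUDA < 11.6).
    if args.fused_linear and not fused_supported:
        args.fused_linear = False
        logging.warning("fused_linear disabled: fused_gemm_epilogue not available")
    # recompute_granularity defaults to "full" and accepts only two values.
    if args.recompute:
        if args.recompute_granularity is None:
            args.recompute_granularity = "full"
        else:
            assert args.recompute_granularity in ("full", "only_attn")
    return args


args = SimpleNamespace(fused_linear=True, recompute=True, recompute_granularity=None)
print(resolve_options(args, fused_supported=False))
# namespace(fused_linear=False, recompute=True, recompute_granularity='full')
```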
65 changes: 46 additions & 19 deletions fleetx/models/gpt_model/modeling.py
@@ -15,6 +15,8 @@
# limitations under the License.

import collections
import logging

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -24,6 +26,7 @@
import paddle.incubate as incubate
from paddle.distributed.fleet.utils import recompute
from .config import configurable
from paddle.incubate.nn import FusedLinear


class MultiHeadAttention(nn.Layer):
@@ -46,7 +49,8 @@ def __init__(self,
need_weights=False,
weight_attr=None,
bias_attr=None,
fuse=True):
fuse=True,
fused_linear=False):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -59,19 +63,21 @@ def __init__(self,
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

Linear = FusedLinear if fused_linear else nn.Linear

if self.fuse:
assert self.kdim == embed_dim
assert self.vdim == embed_dim
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.out_proj = Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)

def _fuse_prepare_qkv(self, query):
@@ -221,13 +227,15 @@ def __init__(self,
num_layers,
norm=None,
hidden_size=None,
use_recompute=False):
use_recompute=False,
recompute_granularity="full"):
super(TransformerDecoder, self).__init__()

self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
if norm == "LayerNorm":
self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
elif norm is not None:
@@ -258,9 +266,10 @@ def forward(self,
cache=cache)
new_caches.append(new_cache)
else:
output = recompute(mod, output, memory, tgt_mask, use_cache, cache) if self.use_recompute \
else mod(output, memory, tgt_mask, use_cache, cache)

if self.use_recompute and self.recompute_granularity == "full":
output = recompute(mod, output, memory, tgt_mask, use_cache, cache)
else:
output = mod(output, memory, tgt_mask, use_cache, cache)
else:
output, new_cache = mod(output,
memory,
@@ -304,7 +313,9 @@ def __init__(self,
act_dropout=None,
normalize_before=True,
weight_attr=None,
bias_attr=None):
bias_attr=None,
fused_linear=False,
recompute_attn=False):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
@@ -313,19 +324,23 @@
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.recompute_attn = recompute_attn

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

Linear = FusedLinear if fused_linear else nn.Linear

self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0])
self.linear1 = nn.Linear(
bias_attr=bias_attrs[0],
fused_linear=fused_linear)
self.linear1 = Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(
self.linear2 = Linear(
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])

self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
Expand All @@ -341,7 +356,10 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
tgt = self.norm1(tgt)

if use_cache is False:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
if self.recompute_attn:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
use_cache, cache)
@@ -421,10 +439,14 @@ def __init__(self,
max_position_embeddings=512,
type_vocab_size=16,
use_recompute=False,
initializer_range=0.02):
initializer_range=0.02,
fused_linear=False,
recompute_granularity="full"):

super(GPTModel, self).__init__()

recompute_attn = use_recompute and recompute_granularity == "only_attn"

self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
Expand All @@ -447,14 +469,17 @@ def __init__(self,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None))
bias_attr=None,
fused_linear=fused_linear,
recompute_attn=recompute_attn))

self.decoder = TransformerDecoder(
decoder_layers,
num_layers,
norm="LayerNorm",
hidden_size=hidden_size,
use_recompute=use_recompute)
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)

@classmethod
def from_config(cls, cfg):
Expand All @@ -469,7 +494,9 @@ def from_config(cls, cfg):
"max_position_embeddings": cfg.max_position_embeddings,
"type_vocab_size": cfg.type_vocab_size,
"initializer_range": cfg.initializer_range,
"use_recompute": cfg.use_recompute
"use_recompute": cfg.use_recompute,
"fused_linear": cfg.fused_linear,
"recompute_granularity": cfg.recompute_granularity
}

def forward(self,
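The recurring `Linear = FusedLinear if fused_linear else nn.Linear` pattern in the diff above swaps the implementation once while leaving every call site unchanged. A minimal, hedged sketch of the same idea outside the model code (assumes a Paddle install; the example runs with `fused_linear=False` on any build, and should only be switched to `True` on CUDA 11.6+ builds):

```python
# Drop-in substitution sketch: choose the Linear implementation once.
import paddle
import paddle.nn as nn
from paddle.incubate.nn import FusedLinear


def make_proj(in_features, out_features, fused_linear=False):
    # FusedLinear fuses the bias add into the matmul epilogue on supported
    # builds; its leading constructor arguments mirror nn.Linear.
    Linear = FusedLinear if fused_linear else nn.Linear
    return Linear(in_features, out_features)


x = paddle.randn([2, 8, 1024])
proj = make_proj(1024, 4096, fused_linear=False)
print(proj(x).shape)  # [2, 8, 4096]
```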
