Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fused linear and selective recompute #620

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,4 @@ ENV/
.mypy_cache/

.DS_Store
.idea
4 changes: 4 additions & 0 deletions examples/gpt/hybrid_parallel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
num_train_epochs: 1
seed: 1024
use_recompute: False
recompute_granularity:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个要填full吗?

Copy link
Contributor Author

@FeixLiu FeixLiu Aug 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要,这个recompute是false,空着就行,在backend收到的是一个None

batch_size:
global_batch_size: 8
local_batch_size: 8
Expand All @@ -113,6 +114,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
save_steps: 1000
output_dir: ./output
ckpt_dir:
fused_linear: False
```

其中参数说明:
Expand All @@ -124,6 +126,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
| num_train_epochs | 训练的epoch数量 |
| seed | 随机种子,保证训练过程可复现 |
| use_recompute | 是否使用recompute训练 |
| recompute_granularity | recompute训练的粒度,可选 `full` `only_attn`,full即recompute全部transformer,only_attn表明只recompute self attention部分 |
| global_batch_size | 全局的batch size大小,即一次参数更新等效的batch size |
| local_batch_size | 每个进程训练的batch size大小 |
| micro_batch_size | 每次前向计算的batch size大小 |
Expand All @@ -138,6 +141,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
| save_steps | 保存模型间隔 |
| output_dir | 指定输出文件 |
| ckpt_dir | checkpoint的加载目录 |
| fused_linear | 是否使用fused_linear代替传统Linear加速训练 |


### 并行维度
Expand Down
4 changes: 3 additions & 1 deletion examples/gpt/hybrid_parallel/configs_1.3B_dp8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 64
local_batch_size: 8
Expand All @@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
Expand Down
4 changes: 3 additions & 1 deletion examples/gpt/hybrid_parallel/configs_175B_mp8_pp16.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 1536
local_batch_size: 1536
Expand All @@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 51200
Expand Down
4 changes: 3 additions & 1 deletion examples/gpt/hybrid_parallel/configs_6.7B_sharding16.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 128
local_batch_size: 8
Expand All @@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
Expand Down
4 changes: 4 additions & 0 deletions examples/gpt/single/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
num_train_epochs: 1
seed: 1024
use_recompute: False
recompute_granularity:
batch_size:
global_batch_size: 8
local_batch_size: 8
Expand All @@ -103,6 +104,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
save_steps: 1000
output_dir: ./output
ckpt_dir:
fused_linear: False
```

其中参数说明:
Expand All @@ -114,6 +116,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
| num_train_epochs | 训练的epoch数量 |
| seed | 随机种子,保证训练过程可复现 |
| use_recompute | 是否使用recompute训练 |
| recompute_granularity | recompute训练的粒度,可选 `full` `only_attn`,full即recompute全部transformer,only_attn表明只recompute self attention部分 |
| global_batch_size | 全局的batch size大小,即一次参数更新等效的batch size |
| local_batch_size | 每个进程训练的batch size大小 |
| micro_batch_size | 每次前向计算的batch size大小 |
Expand All @@ -128,6 +131,7 @@ GPT训练默认使用AdamW优化器以及cosine 学习率衰减,这里通过
| save_steps | 保存模型间隔 |
| output_dir | 指定输出文件 |
| ckpt_dir | checkpoint的加载目录 |
| fused_linear | 是否使用fused_linear代替传统Linear加速训练 |


## 运行方式
Expand Down
4 changes: 3 additions & 1 deletion examples/gpt/single/configs_1.3B_single_card.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: True
recompute_granularity: 'only_attn'
batch_size:
global_batch_size: 8
local_batch_size: 8
Expand All @@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
Expand Down
4 changes: 3 additions & 1 deletion examples/gpt/single/configs_345m_single_card.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ PreTraining:
num_train_epochs: 1
seed: 1024
use_recompute: False
recompute_granularity:
batch_size:
global_batch_size: 8
local_batch_size: 8
Expand All @@ -22,7 +23,8 @@ PreTraining:
save_load:
save_steps: 1000
output_dir: ./output
ckpt_dir:
ckpt_dir:
fused_linear: True

Model:
vocab_size: 50304
Expand Down
24 changes: 24 additions & 0 deletions examples/gpt/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import yaml
import paddle
import paddle.distributed as dist
from paddle.fluid import core
import argparse
from fleetx.datasets.gpt import create_pretrained_dataset, get_train_data_file

Expand Down Expand Up @@ -49,6 +50,13 @@ def process_batch_size(args):
assert args.local_batch_size % args.micro_batch_size == 0


def is_fused_matmul_bias_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(core.ops, 'fused_gemm_epilogue')
else:
return False


def model_size(args):
"""
get model size for transformer
Expand Down Expand Up @@ -84,6 +92,22 @@ def add_dict(config, k, v):

args.test_iters = args.eval_iters * 10

if args.fused_linear:
assert is_fused_matmul_bias_supported(), \
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
"The flag fused_linear only valid for cuda version higher than 11.6, "\
"but the paddle is compiled with cuda " + paddle.version.cuda()

if args.recompute:
assert args.recompute_granularity is None or \
isinstance(args.recompute_granularity, str), \
"recompute_granularity must be a None or a string object"
if args.recompute_granularity is None:
args.recompute_granularity = "full"
else:
assert args.recompute_granularity in ["full", "only_attn"], \
"recompute_granularity can be only chosen from " \
"full or only_attn, but received " + args.recompute_granularity

# process batch size
process_batch_size(args)

Expand Down
65 changes: 46 additions & 19 deletions fleetx/models/gpt_model/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# limitations under the License.

import collections
import logging

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
Expand All @@ -24,6 +26,7 @@
import paddle.incubate as incubate
from paddle.distributed.fleet.utils import recompute
from .config import configurable
from paddle.incubate.nn import FusedLinear


class MultiHeadAttention(nn.Layer):
Expand All @@ -46,7 +49,8 @@ def __init__(self,
need_weights=False,
weight_attr=None,
bias_attr=None,
fuse=True):
fuse=True,
fused_linear=False):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
Expand All @@ -59,19 +63,21 @@ def __init__(self,
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

Linear = FusedLinear if fused_linear else nn.Linear

if self.fuse:
assert self.kdim == embed_dim
assert self.vdim == embed_dim
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.out_proj = Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)

def _fuse_prepare_qkv(self, query):
Expand Down Expand Up @@ -221,13 +227,15 @@ def __init__(self,
num_layers,
norm=None,
hidden_size=None,
use_recompute=False):
use_recompute=False,
recompute_granularity="full"):
super(TransformerDecoder, self).__init__()

self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
if norm == "LayerNorm":
self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
elif norm is not None:
Expand Down Expand Up @@ -258,9 +266,10 @@ def forward(self,
cache=cache)
new_caches.append(new_cache)
else:
output = recompute(mod, output, memory, tgt_mask, use_cache, cache) if self.use_recompute \
else mod(output, memory, tgt_mask, use_cache, cache)

if self.use_recompute and self.recompute_granularity == "full":
output = recompute(mod, output, memory, tgt_mask, use_cache, cache)
else:
output = mod(output, memory, tgt_mask, use_cache, cache)
else:
output, new_cache = mod(output,
memory,
Expand Down Expand Up @@ -304,7 +313,9 @@ def __init__(self,
act_dropout=None,
normalize_before=True,
weight_attr=None,
bias_attr=None):
bias_attr=None,
fused_linear=False,
recompute_attn=False):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
Expand All @@ -313,19 +324,23 @@ def __init__(self,
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
self.recompute_attn = recompute_attn

weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

Linear = FusedLinear if fused_linear else nn.Linear

self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0])
self.linear1 = nn.Linear(
bias_attr=bias_attrs[0],
fused_linear=fused_linear)
self.linear1 = Linear(
d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(
self.linear2 = Linear(
dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])

self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
Expand All @@ -341,7 +356,10 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
tgt = self.norm1(tgt)

if use_cache is False:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
if self.recompute_attn:
tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
else:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
use_cache, cache)
Expand Down Expand Up @@ -421,10 +439,14 @@ def __init__(self,
max_position_embeddings=512,
type_vocab_size=16,
use_recompute=False,
initializer_range=0.02):
initializer_range=0.02,
fused_linear=False,
recompute_granularity="full"):

super(GPTModel, self).__init__()

recompute_attn = use_recompute and recompute_granularity == "only_attn"

self.initializer_range = initializer_range
self.hidden_size = hidden_size
self.vocab_size = vocab_size
Expand All @@ -447,14 +469,17 @@ def __init__(self,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None))
bias_attr=None,
fused_linear=fused_linear,
recompute_attn=recompute_attn))

self.decoder = TransformerDecoder(
decoder_layers,
num_layers,
norm="LayerNorm",
hidden_size=hidden_size,
use_recompute=use_recompute)
use_recompute=use_recompute,
recompute_granularity=recompute_granularity)

@classmethod
def from_config(cls, cfg):
Expand All @@ -469,7 +494,9 @@ def from_config(cls, cfg):
"max_position_embeddings": cfg.max_position_embeddings,
"type_vocab_size": cfg.type_vocab_size,
"initializer_range": cfg.initializer_range,
"use_recompute": cfg.use_recompute
"use_recompute": cfg.use_recompute,
"fused_linear": cfg.fused_linear,
"recompute_granularity": cfg.recompute_granularity
}

def forward(self,
Expand Down