-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
group_sharded.py
251 lines (218 loc) · 10.3 KB
/
group_sharded.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import paddle
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
GroupShardedOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
GroupShardedStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
GroupShardedStage3,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
GroupShardedScaler,
)
from paddle.distributed.utils.log_utils import get_logger
from paddle.optimizer import Optimizer
# Module-level logger shared by the functions in this file, obtained from
# Paddle's get_logger helper with logging.WARNING as its argument.
# NOTE(review): presumably that argument is the logger's level, in which case
# the logger_.info(...) calls below are filtered out by default — confirm.
logger_ = get_logger(logging.WARNING)
def group_sharded_parallel(
    model,
    optimizer,
    level,
    scaler=None,
    group=None,
    offload=False,
    sync_buffers=False,
    buffer_max_size=2**23,
    segment_size=2**20,
    sync_comm=False,
    dp_group=None,
    exclude_layer=None,
):
    """
    Apply the group sharded configuration to the model, optimizer and GradScaler.

    Level has three string options, 'os', 'os_g' and 'p_g_os', which correspond to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
    Usually, optimizer state + gradient segmentation is actually a re-optimization of optimizer state segmentation, so optimizer state + gradient segmentation is used to realize optimizer state segmentation ('os' and 'os_g' take the same code path).

    Args:
        model (Layer): The layer to be wrapped with group_sharded_parallel.
        optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel.
        level (str): The different level of the group sharded. Such as `os`, `os_g`, `p_g_os`.
        scaler (GradScaler, optional): If AMP is used, you need to pass GradScaler. Defaults to None, indicating that GradScaler is not used.
        group (Group, optional): The group instance. Defaults to None, indicating that the default environment group is used.
        offload (bool, optional): Whether to use the offload function. Defaults to False, which means that the offload function is not used.
        sync_buffers (bool, optional): Whether to broadcast model buffers. It is generally used when there are registered model buffers. Defaults to False, indicating that model buffers are not broadcast.
        buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23.
        segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20.
        sync_comm (bool, optional): Whether to use synchronous communication, only used in `p_g_os`. Defaults to False, indicating that asynchronous communication is used.
        dp_group (Group, optional): dp communication group, support to combine stage2 or stage3 with dp hybrid communication. Defaults to None.
        exclude_layer (list, optional): exclude some layers for slicing for sharding stage3, for example, exclude_layer=["GroupNorm", id(model.gpt.linear)], exclude_layer must contain the layers' name or one layer's id. Defaults to None.

    Returns:
        model: A wrapper for group sharded given model.
        optimizer: A wrapper for group sharded given optimizer. For level `p_g_os` the optimizer object is returned as passed in (it is not re-wrapped by this function).
        scaler: A wrapper for group sharded given scaler. Only an instance of paddle.amp.GradScaler is wrapped; any other value (including None) is returned unchanged.

    Raises:
        AssertionError: If the current device is neither gpu nor xpu, if `model` is not a paddle.nn.Layer, if `optimizer` is not a paddle.optimizer.Optimizer, or if `level` is not one of `os`, `os_g`, `p_g_os`.
        ValueError: If `level` is invalid and assertions are disabled (python -O).

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.nn import Linear
            from paddle.distributed import fleet
            from paddle.distributed.sharding import group_sharded_parallel

            fleet.init(is_collective=True)
            group = paddle.distributed.new_group([0, 1])
            model = Linear(1000, 1000)

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
            scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

            # wrap sharding model, optimizer and scaler
            model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)

            # `data` stands for one (img, label) batch from your own data pipeline
            img, label = data
            label.stop_gradient = True
            img.stop_gradient = True

            out = model(img)
            loss = paddle.nn.functional.cross_entropy(input=out, label=label)

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
    """
    # Keep only the device kind, e.g. "gpu:0" -> "gpu".
    device = paddle.get_device().split(":")[0]
    assert device in [
        "gpu",
        "xpu",
    ], "group_sharded_parallel only support gpu and xpu now"

    # check option type
    assert isinstance(
        model, paddle.nn.Layer
    ), "The model must be the instance of paddle.nn.Layer."
    assert isinstance(
        optimizer, Optimizer
    ), "The optimizer must be the instance of paddle.optimizer.Optimizer."
    assert level in [
        'os',
        'os_g',
        'p_g_os',
    ], "The level must be os, os_g or p_g_os."

    # Warn when the model holds float16 parameters but no scaler was handed
    # in, because the caller's own scaling logic is then not wrapped here.
    has_fp16_params = any(
        param.dtype == paddle.float16 for param in model.parameters()
    )
    if scaler is None and has_fp16_params:
        logger_.warning(
            "the input of scaler is None, please ensure the logic of your scaler outside is same as GroupShardedScaler."
        )

    # convert model/optimizer/scaler
    if level in ['os', 'os_g']:
        # 'os' is served by the same stage-2 wrappers as 'os_g'.
        logger_.info("*" * 30)
        logger_.info("Sharded level os uses sharded level os_g achieved now.")
        logger_.info("*" * 30)
        optimizer = GroupShardedOptimizerStage2(
            params=optimizer._parameter_list,
            optim=optimizer,
            group=group,
            offload=offload,
            dp_group=dp_group,
            device=device,
        )
        # The model wrapper receives the already-wrapped optimizer.
        model = GroupShardedStage2(
            model,
            optimizer,
            group=group,
            sync_buffers=sync_buffers,
            buffer_max_size=buffer_max_size,
            dp_group=dp_group,
            device=device,
        )
    elif level == 'p_g_os':
        # Only the model is wrapped here; `optimizer` is handed to the wrapper
        # and returned to the caller as-is.
        model = GroupShardedStage3(
            model,
            optimizer=optimizer,
            group=group,
            sync_buffers=sync_buffers,
            segment_size=segment_size,
            offload=offload,
            sync_comm=sync_comm,
            dp_group=dp_group,
            device=device,
            exclude_layer=exclude_layer,
        )
    else:
        # Only reachable when the level assertion above is stripped (python -O).
        raise ValueError("Please enter the correct level.")

    if isinstance(scaler, paddle.amp.GradScaler):
        scaler = GroupShardedScaler(scaler)

    logger_.info("*" * 30)
    logger_.info(
        "If there is a communication hang using group sharded, please check whether the communication operations of each process are unified."
    )
    logger_.info("*" * 30)

    return model, optimizer, scaler
def save_group_sharded_model(model, output, optimizer=None):
    """
    Group sharded encapsulated model and optimizer state saving module.

    The model state dict is written to ``<output>/model.pdmodel`` and, when an
    optimizer is given, its state dict is written to ``<output>/model.pdopt``.
    The output directory is created if it does not exist.

    Note:
        If the model is saved with save_group_sharded_model, then when loading again you need to set the model or optimizer state before using group_sharded_parallel.

    Args:
        model (Layer): A wrapper for group sharded given model.
        output (str): Save directory.
        optimizer (Optimizer, optional): Group sharded encapsulated optimizer. Defaults to None, indicating that the optimizer state is not saved.

    Raises:
        AssertionError: If `output` is an existing regular file, or if `optimizer` is given but has no `_optim` attribute (i.e. it was not wrapped with group_sharded_parallel).
        ValueError: If `model` is neither a GroupShardedStage2 nor a GroupShardedStage3 instance.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.nn import Linear
            from paddle.distributed import fleet
            from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

            fleet.init(is_collective=True)
            group = paddle.distributed.new_group([0, 1])
            model = Linear(1000, 1000)

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
            scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

            # wrap sharding model, optimizer and scaler
            model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)

            # `data` stands for one (img, label) batch from your own data pipeline
            img, label = data
            label.stop_gradient = True
            img.stop_gradient = True

            out = model(img)
            loss = paddle.nn.functional.cross_entropy(input=out, label=label)

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            # save model and optimizer state_dict
            output_dir = "./group_sharded_ckpt"  # any directory path
            save_group_sharded_model(model, output=output_dir, optimizer=optimizer)
    """
    logger_.info(
        "==========Begin to save group sharded model and optimizer=========="
    )
    assert not os.path.isfile(
        output
    ), f"Saving directory ({output}) should be a directory, not a file"
    # Validate the optimizer before anything is written, so that a wrong
    # optimizer does not leave a model-only checkpoint behind.
    if optimizer is not None:
        assert hasattr(
            optimizer, "_optim"
        ), "Please use the optimizer which is wrapped with group_sharded_parallel."

    os.makedirs(output, exist_ok=True)
    output_model = os.path.join(output, "model.pdmodel")
    if isinstance(model, GroupShardedStage2):
        paddle.save(model._layer.state_dict(), output_model)
    elif isinstance(model, GroupShardedStage3):
        # Ask the stage-3 wrapper for all parameters first, converted to CPU
        # when offload is enabled.
        # NOTE(review): presumably this gathers the sharded parameter slices so
        # the state dict below is complete — confirm against GroupShardedStage3.
        model.get_all_parameters(convert2cpu=bool(model._offload))
        paddle.save(model._layer.state_dict(), output_model)
    else:
        raise ValueError(
            "Please use the layer which is wrapped with group_sharded_parallel."
        )

    if optimizer is not None:
        # Save the state of the inner optimizer held by the wrapper.
        output_opt = os.path.join(output, "model.pdopt")
        paddle.save(optimizer._optim.state_dict(), output_opt)
    logger_.info(
        "==========End to save group sharded model and optimizer=========="
    )