# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .proto import framework_pb2
from paddle.fluid import framework as framework
from paddle.fluid import program_guard
from . import core
import collections
import copy
import logging
from . import unique_name
from . import log_helper
import paddle.fluid
from .data_feeder import check_type
import warnings
from collections.abc import Sequence
import re
__all__ = [
'append_backward',
'gradients',
]
_logger = log_helper.get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class ProgramStats:
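    """
    Records dependency statistics over a block's ops: for every op, the ops
    that produce its inputs and the ops that consume its outputs (``op_deps``),
    and for every variable, the ops that read it and the ops that write it
    (``var_op_deps``). The recompute (checkpointing) logic below uses these
    statistics to locate the op segments between user-specified checkpoints.
    """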
def __init__(self, block, ops):
self.block = block
self.ops = ops
self.op_deps = {} # op-> in_ops, out_ops
self.var_op_deps = {} # var as input op, var as output op
def get_input_nodes(self):
input_names = []
for name in self.var_op_deps:
if (
len(self.var_op_deps[name]["var_as_output_ops"]) == 0
and len(self.var_op_deps[name]["var_as_input_ops"]) > 0
):
if self.block.var(name).persistable:
continue
input_names.append(name)
for op in self.ops:
if op.desc.type() == "read":
input_names.extend(op.desc.output_arg_names())
return input_names
def get_reserved_vars(self):
var_name = []
for op in self.ops:
if op.desc.type() == "seed":
var_name.extend(op.desc.output_arg_names())
return var_name
def get_out_of_subgraph_vars(self, begin_op_idx, end_op_idx):
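        """
        Collect the variables of ops[begin_op_idx:end_op_idx] that cross the
        segment boundary: outputs consumed by ops at or after end_op_idx, and
        inputs produced by ops before begin_op_idx.
        """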
var_name = []
for i in range(begin_op_idx, end_op_idx, 1):
for name in self.ops[i].desc.output_arg_names():
if name in self.var_op_deps:
for idx in self.var_op_deps[name]["var_as_input_ops"]:
if idx >= end_op_idx:
var_name.append(name)
for name in self.ops[i].desc.input_arg_names():
if name in self.var_op_deps:
for idx in self.var_op_deps[name]["var_as_output_ops"]:
if idx < begin_op_idx:
var_name.append(name)
return var_name
def is_subgraph(self, var_group1, var_group2):
# should traverse from var_group1 to var_group2
# max op idx in var_group2
# min op idx in var_group1
min_op_idx = len(self.ops)
max_op_idx = -1
for name in var_group1:
if name not in self.var_op_deps:
return False, min_op_idx, max_op_idx
for name in var_group2:
if name not in self.var_op_deps:
return False, min_op_idx, max_op_idx
for name in var_group1:
op_idx = self.var_op_deps[name]["var_as_input_ops"]
for idx in op_idx:
min_op_idx = min(min_op_idx, idx)
for name in var_group2:
op_idx = self.var_op_deps[name]["var_as_output_ops"]
for idx in op_idx:
max_op_idx = max(max_op_idx, idx)
if min_op_idx >= max_op_idx:
return False, min_op_idx, max_op_idx
return True, min_op_idx, max_op_idx
def _update_segment_start(self, min_idx, pre_segment_end_idx):
"""
        persistable vars of amp-related cast ops should be included in the recompute segment
"""
def is_amp_cast(op):
return (
op.desc.type() == 'cast'
and self.block.var(op.desc.input_arg_names()[0]).persistable
)
idx_ = min_idx - 1
updated_min_idx = min_idx
while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]):
_logger.info(
"found amp-cast op: {}, : {}".format(
self.ops[idx_].desc.type(),
self.ops[idx_].desc.input_arg_names()[0],
)
)
updated_min_idx = idx_
idx_ -= 1
else:
break
return updated_min_idx
def build_stats(self):
for i, op in enumerate(self.ops):
self.op_deps[i] = {"in_ops": [], "out_ops": []}
for j, name in enumerate(op.desc.input_arg_names()):
if name in self.var_op_deps:
self.op_deps[i]["in_ops"].extend(
self.var_op_deps[name]["var_as_output_ops"]
)
for j, name in enumerate(op.desc.input_arg_names()):
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_input_ops"].extend([i])
else:
self.var_op_deps[name] = {}
self.var_op_deps[name]["var_as_input_ops"] = [i]
self.var_op_deps[name]["var_as_output_ops"] = []
for j, name in enumerate(op.desc.output_arg_names()):
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_output_ops"].extend([i])
else:
self.var_op_deps[name] = {}
self.var_op_deps[name]["var_as_input_ops"] = []
self.var_op_deps[name]["var_as_output_ops"] = [i]
for op_idx in self.op_deps[i]["in_ops"]:
self.op_deps[op_idx]["out_ops"].extend([i])
def sort_checkpoints(self, checkpoints_name):
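        """
        Sort checkpoint names by the index of the last op that generates them.
        Names that are not used in the program are dropped (with a log
        message), and pure input nodes are given index -1.
        """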
sorted_checkpoints = []
for name in checkpoints_name:
if name not in self.var_op_deps:
_logger.info(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name
)
elif self.var_op_deps[name]["var_as_output_ops"] == []:
# input nodes
sorted_checkpoints.append((name, -1))
else:
sorted_checkpoints.append(
(name, max(self.var_op_deps[name]["var_as_output_ops"]))
)
sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1])
return [x[0] for x in sorted_checkpoints]
def modify_forward_desc_for_recompute(self):
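        """
        Insert a ``seed`` op before every ``dropout`` op that does not already
        take a Seed variable as input, so that the recomputed dropout in the
        backward pass can reproduce the same mask as the original forward op.
        """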
op_types = [op.desc.type() for op in self.ops]
if "dropout" not in op_types:
return
op_idx = 0
while op_idx < len(self.ops):
op = self.ops[op_idx]
if op.desc.type() != "dropout":
op_idx += 1
continue
            # a seed op has already been inserted before this dropout op
if op.input('Seed') is not None and len(op.input('Seed')) == 1:
op_idx += 1
continue
            # add a seed op so that the two dropout ops generate the same output
op_unique_name = unique_name.generate("seed")
var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp'])
)
added_var = self.block.create_var(
name=var_unique_name,
dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed"))
op_device_attr_name = (
core.op_proto_and_checker_maker.kOpDeviceAttrName()
)
op_device = ""
if op.desc.has_attr(op_device_attr_name):
op_device = op.desc.attr(op_device_attr_name)
            # Setting force_cpu of the seed op to True keeps its output in CPU memory,
            # which reduces the synchronous copy from GPU to CPU in dropout and helps avoid communication hangs
added_op = self.block._insert_op(
index=op.idx,
type='seed',
inputs={},
outputs={'Out': [added_var]},
attrs={'seed': seed, 'op_device': op_device, 'force_cpu': True},
)
self.ops.insert(op_idx, added_op)
            # modify the dropout op desc so that it accepts a seed var as input
op.desc.set_input("Seed", [var_unique_name])
op.desc.remove_attr("fix_seed")
op.desc.remove_attr("seed")
self.block._sync_with_cpp()
op_idx += 2
def _pretty_op_desc_(op_desc, prefix):
out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % (
prefix + "_op",
str(op_desc.type()),
prefix + "_input",
" ".join(op_desc.input_arg_names()),
prefix + "_output",
" ".join(op_desc.output_arg_names()),
)
return out_s
def _add_needed_descs_to_block(
descs, block, main_block, in_memory_vars, grad_op_id_to_fwd_op=None
):
if len(descs) == 0:
return []
result_descs = []
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
origin_desc = desc
origin_is_operator = False
if isinstance(desc, framework.Operator):
desc = desc.desc
origin_is_operator = True
if isinstance(desc, tuple):
desc = desc[0]
is_needed = False
for name in desc.output_arg_names():
if main_block.has_var(name) and main_block.var(name).persistable:
continue
if name not in in_memory_vars:
is_needed = True
if is_needed:
if origin_is_operator and grad_op_id_to_fwd_op is not None:
grad_op_id_to_fwd_op[desc.original_id()] = origin_desc
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None):
if len(descs) == 0:
return []
result_descs = []
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
if isinstance(desc, framework.Operator):
# for recompute, should record recompute ops
if grad_op_id_to_fwd_op is not None:
grad_op_id_to_fwd_op[desc.desc.original_id()] = desc
desc = desc.desc
if isinstance(desc, tuple):
desc = desc[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
def _find_loss_op_(loss):
for op in reversed(loss.block.ops):
assert isinstance(op, framework.Operator)
if (
len(op.output_arg_names) == 1
and op.output_arg_names[0] == loss.name
):
loss.op = op
break
if loss.op is None:
raise ValueError("loss.op is None. Should not happen")
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
"""
Traverse all ops in op_descs[begin_idx : end_idx],
if any op has inputs/outputs named "old_name", rename it as 'new_name'
"""
if begin_idx is None:
begin_idx = 0
if end_idx is None:
end_idx = len(op_descs)
if isinstance(op_descs, (list, tuple)):
for i in range(begin_idx, end_idx):
op_desc = op_descs[i]
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
if isinstance(op_descs, collections.OrderedDict):
for key, value in op_descs.items():
if isinstance(value, (list, tuple)):
for op_desc in value:
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
Create a C++ OpDesc object with specified inputs, outputs and attributes.
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for para, args in inputs.items():
op_desc.set_input(
para,
list(
map(
lambda arg: arg.decode() if isinstance(arg, bytes) else arg,
args,
)
),
)
for para, args in outputs.items():
op_desc.set_output(
para,
list(
map(
lambda arg: arg.decode() if isinstance(arg, bytes) else arg,
args,
)
),
)
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
op_device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if op_role_attr_name not in attrs:
attrs[
op_role_attr_name
] = core.op_proto_and_checker_maker.OpRole.Backward
if op_device_attr_name not in attrs:
attrs[op_device_attr_name] = ""
for name, val in attrs.items():
if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc)
else:
op_desc._set_attr(name, val)
return op_desc
def _create_loss_op_desc_(loss):
# 0-D Tensor or 0-Size Tensor
if len(loss.shape) == 0 or 0 in loss.shape:
create_shape = loss.shape
else:
create_shape = [1]
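    # Backpropagation starts from d(loss)/d(loss) = 1, so the gradient of the
    # loss is initialized with a constant tensor filled with 1.0.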
op_desc = _create_op_desc_(
"fill_constant",
{},
{"Out": [_append_grad_suffix_(loss.name)]},
{
"shape": create_shape,
"value": 1.0,
"dtype": loss.dtype,
"force_cpu": False,
core.op_proto_and_checker_maker.kOpRoleAttrName(): int(
core.op_proto_and_checker_maker.OpRole.Backward
)
| int(core.op_proto_and_checker_maker.OpRole.Loss),
core.op_proto_and_checker_maker.kOpDeviceAttrName(): loss.op.attr(
core.op_proto_and_checker_maker.kOpDeviceAttrName()
),
},
)
return op_desc
def _infer_var_data_type_shape_(grad_var_name, block):
"""
Infer the data type and shape of given grad variable
"""
grad_var = block.desc.find_var(grad_var_name.encode())
fwd_name = _strip_grad_suffix_(grad_var_name)
if block.desc.has_var_recursive(fwd_name.encode()):
fwd_var = block.desc.find_var_recursive(fwd_name.encode())
grad_var.set_dtype(fwd_var.dtype())
grad_var.set_shape(fwd_var.shape())
else:
        # TODO(jiabin): Maybe we should not do this, as it may cause unexpected errors on dtype
warnings.warn(
"Set grad var: {} dtype to default FP32, since we can't find its related forward var".format(
grad_var_name
)
)
grad_var.set_dtype(core.VarDesc.VarType.FP32)
def _all_in_set_(cands, s):
"""
Test if all elements of 'cands' are in set 's'
"""
if len(cands) == 0:
return False
for c in cands:
if not c in s:
return False
return True
def _some_in_set_(cands, s):
"""
Test if some elements of 'cands' are in set 's'
"""
if len(cands) == 0:
return False
for c in cands:
if c in s:
return True
return False
def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given variable name
e.g. x@GRAD ==> x
x@GRAD@GRAD ==> x
y@GRAD@RENAME@1 ==> y
z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0
grad/grad/z@GRAD@RENAME@block0@1@GRAD ==> z
"""
pos = re.search(f'{core.grad_var_suffix()}+@', name) or re.search(
f'{core.grad_var_suffix()}$', name
)
new_name = name[: pos.start()] if pos is not None else name
new_pos = name.rfind('grad/')
return new_name[new_pos + 5 :] if new_pos != -1 else new_name
def _append_grad_suffix_(name):
"""
Append grad suffix to the given variable name
e.g. x ==> x@GRAD
"""
return name + core.grad_var_suffix()
def _accumulate_gradients_by_sum_op_(
var_name, renamed_vars, pending_sum_ops, op_idx, op_device=""
):
"""
    Use a sum op to accumulate gradients; the gradients to be accumulated are stored in renamed_vars.
"""
if op_idx not in pending_sum_ops.keys():
pending_sum_ops[op_idx] = []
pending_sum_ops[op_idx].append(
_create_op_desc_(
"sum",
{"X": renamed_vars[var_name]},
{"Out": [var_name]},
{"use_mkldnn": False, "op_device": op_device},
)
)
renamed_vars[var_name] = [var_name]
def _accumulate_gradients_by_add_ops_(
var_name, renamed_vars, pending_sum_ops, op_idx, op_device=""
):
"""
    Use several in-place add ops to accumulate gradients; the gradients to be accumulated are stored in renamed_vars.
"""
if op_idx not in pending_sum_ops.keys():
pending_sum_ops[op_idx] = []
out_name = renamed_vars[var_name][0]
for i in range(1, len(renamed_vars[var_name])):
x_name = out_name
y_name = renamed_vars[var_name][i]
if i != len(renamed_vars[var_name]) - 1:
out_name = var_name + '@ADD@' + str(i)
else:
out_name = var_name
pending_sum_ops[op_idx].append(
_create_op_desc_(
"grad_add",
{"X": [x_name], "Y": [y_name]},
{"Out": [out_name]},
{"use_mkldnn": False, "op_device": op_device},
)
)
renamed_vars[var_name] = [var_name]
def _addup_repetitive_outputs_(
op_descs, block_idx, grad_var_to_var=None, grad_op_id_to_fwd_op=None
):
"""
    In the backward part, a variable may be the output of more than one op,
    and one op may write several of its outputs to the same variable.
    In these cases, the variable should be the accumulation of all those outputs.
    `sum_op`s are added to implement the accumulation.
Args:
grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
Only for auto parallel.
"""
_MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
# pending_sum_ops = []
pending_sum_ops = collections.OrderedDict()
var_rename_count = collections.defaultdict(int)
renamed_vars = collections.defaultdict(list)
renamed_var_start_idx = collections.defaultdict(list)
var_device = collections.defaultdict(str)
for idx, op_desc in enumerate(op_descs):
op_device_attr_name = (
core.op_proto_and_checker_maker.kOpDeviceAttrName()
)
op_device = ""
if op_desc.has_attr(op_device_attr_name):
op_device = op_desc.attr(op_device_attr_name)
for var_name in op_desc.input_arg_names():
if "@GRAD" not in var_name:
continue
if len(renamed_vars[var_name]) > 1:
if len(renamed_vars[var_name]) > _MAX_ADD_NUM_:
_accumulate_gradients_by_sum_op_(
var_name,
renamed_vars,
pending_sum_ops,
idx,
var_device[var_name],
)
else:
_accumulate_gradients_by_add_ops_(
var_name,
renamed_vars,
pending_sum_ops,
idx,
var_device[var_name],
)
for param_idx, param_name in enumerate(op_desc.output_names()):
arg_names = op_desc.output(param_name)
for arg_idx, var_name in enumerate(arg_names):
if "@GRAD" not in var_name:
continue
# if "@RENAME@" in var_name:
# continue
if (
var_name == core.empty_var_name()
or var_name in op_desc.input_arg_names()
):
# empty variable or inplace op
continue
if len(renamed_vars[var_name]) == 0:
# it's the first time we get the variable
renamed_vars[var_name] = [var_name]
renamed_var_start_idx[var_name] = idx
else:
if len(renamed_vars[var_name]) == 1:
new_name = (
var_name
+ "@RENAME@block"
+ str(block_idx)
+ "@"
+ str(var_rename_count[var_name])
)
var_rename_count[var_name] += 1
# Build the mapping between the new_name and var_name (Only for auto parallel)
if grad_var_to_var is not None:
if var_name in grad_var_to_var:
grad_var_to_var[new_name] = grad_var_to_var[
var_name
]
else:
grad_var_to_var[new_name] = var_name
# rename original var_name
renamed_vars[var_name][0] = new_name
# before change: _rename_arg_(op_descs, var_name,
# new_name, 0, idx)
# rename arg from idx of the first appearance
# in backward, not always from 0
_rename_arg_(
op_descs,
var_name,
new_name,
renamed_var_start_idx[var_name],
idx,
)
_rename_arg_(pending_sum_ops, var_name, new_name)
for p in op_desc.output_names()[:param_idx]:
p_arg_names = op_desc.output(p)
if var_name in p_arg_names:
op_desc.set_output(
p,
[
new_name if x == var_name else x
for x in p_arg_names
],
)
arg_names = [
new_name if x == var_name else x
for x in arg_names[:arg_idx]
] + arg_names[arg_idx:]
new_name = (
var_name
+ "@RENAME@block"
+ str(block_idx)
+ "@"
+ str(var_rename_count[var_name])
)
var_rename_count[var_name] += 1
# Build the mapping between the new_name and var_name (Only for auto parallel)
if grad_var_to_var is not None:
if var_name in grad_var_to_var:
grad_var_to_var[new_name] = grad_var_to_var[
var_name
]
else:
grad_var_to_var[new_name] = var_name
arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name)
# record the latest device
var_device[var_name] = op_device
for var_name, inputs in renamed_vars.items():
if len(renamed_vars[var_name]) > 1:
if len(renamed_vars[var_name]) > _MAX_ADD_NUM_:
_accumulate_gradients_by_sum_op_(
var_name,
renamed_vars,
pending_sum_ops,
len(op_descs),
var_device[var_name],
)
else:
_accumulate_gradients_by_add_ops_(
var_name,
renamed_vars,
pending_sum_ops,
len(op_descs),
var_device[var_name],
)
op_descs_len = len(op_descs)
# sum_op descs are sorted according to their insert position
for key, value in collections.OrderedDict(
reversed(list(pending_sum_ops.items()))
).items():
        # NOTE(zhiqiu): Since reversed, the idx of op_descs to be inserted remains correct.
# For example, [0, 1, 2], and we want to insert 'a' at idx 1, 'b' at idx 2, and the expected result is [0, 1, 'a', 2, 'b'].
# If reversed, we first insert 'b' at idx 2, it becomes [0, 1, 2, 'b'], and then insert 'a' at idx 1, it becomes [0, 1, 'a', 2, 'b'].
# If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2].
idx = key
for i, op in enumerate(value):
# update the mapping between fwd and bwd
target_idx = idx - 1 if idx == op_descs_len else idx + i
if (
grad_op_id_to_fwd_op is not None
and grad_op_id_to_fwd_op.get(
op_descs[target_idx].original_id(), None
)
is not None
):
grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[
op_descs[target_idx].original_id()
]
op_descs.insert(idx + i, op)
return op_descs
def _remove_no_grad_branch_(
op_descs, no_grad_set, grad_op_id_to_fwd_op=None, target_vars=[]
):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
1. all outputs of the grad op are in 'no_grad_set'
2. all grad inputs of the grad op are in 'no_grad_set'
    NOTE: we will skip the grad names of target_vars.
"""
def _op_can_be_removed_(op_desc, no_grad_set):
out_arg_names = op_desc.output_arg_names()
if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
return True
if _all_in_set_(
[
name
for name in op_desc.input_arg_names()
if name.find(core.grad_var_suffix()) != -1
],
no_grad_set,
):
no_grad_set.update(set(out_arg_names) - target_grad_var_names)
return True
return False
    # Remove ops whose outputs are all in no_grad_set
target_grad_var_names = set(
[var.name + core.grad_var_suffix() for var in target_vars]
)
op_descs = [
op_desc
for op_desc in op_descs
if not _op_can_be_removed_(op_desc, no_grad_set)
]
# Insert fill_any_like_op with value 0
to_insert = []
if not core._is_bwd_prim_enabled():
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
# arg is a gradient var name and arg should not have gradient
if core.grad_var_suffix() in arg and arg in no_grad_set:
x_in = _strip_grad_suffix_(arg)
# the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op
new_op_desc = _create_op_desc_(
"fill_any_like",
{"X": [x_in]},
{"Out": [arg]},
{'value': 0, 'dtype': -1},
)
# update the mapping between fwd and bwd
if (
grad_op_id_to_fwd_op is not None
and grad_op_id_to_fwd_op.get(
op_desc.original_id(), None
)
is not None
):
grad_op_id_to_fwd_op[
new_op_desc.original_id()
] = grad_op_id_to_fwd_op[op_desc.original_id()]
to_insert.append((new_op_desc, idx))
    for p in reversed(to_insert):
        op_descs.insert(p[1], p[0])
return op_descs
def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
"""
    Prune the program using a structural analysis of the computational graph.
    The nodes of the computational graph composed of backward ops should be
    interconnected. If there are unconnected sub-graphs in the computational
    graph, these sub-graphs should be cut off.
Args:
grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs.
forward_ops(list[Operator]): The forward ops.
        input_grad_names_set(set): stores the names of gradients generated by
            backward ops; it helps to prune unnecessary backward ops.
Return:
(set[core.OpDesc]): A set of OpDescs which should be pruned.
"""
class Var:
def __init__(self, var_name):
self.var_name = var_name
self.gen_op = None
self.pendding_ops = []
def set_gen_op(self, gen_op):
assert isinstance(gen_op, Op)
assert self.gen_op is None
self.gen_op = gen_op
def add_pending_op(self, op):
assert isinstance(op, Op)
self.pendding_ops.append(op)
class Op:
def __init__(self, op_desc):
self.op_desc = op_desc
self.inputs = []
self.outputs = []
def insert_input(self, var):
assert isinstance(var, Var)
self.inputs.append(var)
def insert_output(self, var):
assert isinstance(var, Var)
self.outputs.append(var)
var_versions = dict()
def _create_node(name):
if name not in var_versions.keys():
var_versions[name] = [Var(name)]
else:
var_versions[name].append(Var(name))
return var_versions[name][-1]
def _create_or_get_last_version_node(name):
if name not in var_versions.keys():
var_versions[name] = [Var(name)]
return var_versions[name][-1]
def _create_op_node(op_desc):
op_node = Op(op_desc)
for input in op_desc.input_arg_names():
var = _create_or_get_last_version_node(name=input)
var.add_pending_op(op_node)
op_node.insert_input(var)
for output in op_desc.output_arg_names():
var = _create_node(name=output)
var.set_gen_op(op_node)
op_node.insert_output(var)
return op_node
# Record the forward vars
forward_vars_set = (
set() if input_grad_names_set is None else set(input_grad_names_set)
)
for op in forward_ops:
forward_vars_set.update(op.desc.input_arg_names())
forward_vars_set.update(op.desc.output_arg_names())
    # Record the vars which are created during backward and are not generated by any op.
backward_vars_set = set()
    # special_op_nodes are the candidate sub-graph head nodes.
special_op_nodes = set()
for op_desc in grad_op_descs:
input_set = set(op_desc.input_arg_names())
        # new_vars are created during backward and are not generated by any op.
new_vars = input_set - forward_vars_set - backward_vars_set
backward_vars_set.update(op_desc.output_arg_names())
op_node = _create_op_node(op_desc)
if len(new_vars) == len(input_set):
special_op_nodes.add(op_node)
not_need_op_descs = []
# Start traversing all candidate sub-graph headers to check whether
# they are connected to backward computational graphs, and if they are
# not, list them in not_need_op_descs
for special_op_node in special_op_nodes:
op_list = [special_op_node]
ready_vars = set(special_op_node.inputs)
remove_ops = True
candidate_ops = [special_op_node]
while len(candidate_ops) > 0:
op_node = candidate_ops.pop(0)
if _all_in_set_(op_node.inputs, ready_vars):
for out_var in op_node.outputs:
candidate_ops.extend(out_var.pendding_ops)
op_list.extend(out_var.pendding_ops)
ready_vars.update(op_node.outputs)
else:
remove_ops = False
break
if remove_ops:
not_need_op_descs.extend([node.op_desc for node in op_list])
not_need_op_descs_set = set(not_need_op_descs)
grad_op_descs_set = set(grad_op_descs)
    # If the backward computational graph is simply one sub-graph header, then
    # not_need_op_descs would be the whole graph; this IF clause avoids that.
if grad_op_descs_set == not_need_op_descs_set:
return set()
return not_need_op_descs_set
def serialize_op_decs(op_desc):
protostr = op_desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(bytes(protostr))
return proto.__str__()
def _append_backward_ops_with_checkpoints_(
block,
ops,
target_vars,
target_block,
no_grad_dict,
grad_to_var,
checkpoints,
grad_op_id_to_fwd_op=None,
):
"""
    Create grad ops from the forward ops, and insert them into the given block
Args:
block(Block): the block where forward ops are
        ops(Op): the forward operators whose backward ops (with forward recomputation) need to be added
        target_vars(list[Tensor]): the loss vars we want to calculate gradients for.
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
            key(int): block index
            val(set[str]): a set of variable names that have no gradient
checkpoints: variables that a user defined as checkpoint for forward recomputation
Algorithms:
0) deal with forward recomputing program descs
1) find ops between checkpoints, i.e. recompute_segments
        2) go through all forward ops and collect all variables that will be held in memory
a. variables that are used across segments will be held in memory
b. output of dropout op will be held in memory
c. input variables will be held in memory
3) go through each recompute_segments, add backward ops with forward recomputation
a. add ops in current recompute_segment as forward recomputation ops
b. rename all non-checkpoint variables in recomputation ops
c. add backward ops of current recomputation ops
d. add sum op for repetitive_outputs
4) remove no grad branch as it is in _remove_no_grad_branch_
5) Note1: all appended ops' OpRole are Backward
6) Note2: all variables with new name should be returned so that _append_backward_vars_ can be called
        7) Note3: current forward recomputation backpropagation does not handle programs with subblocks
"""
checkpoints_name = [x.name for x in checkpoints]
checkpoints_name = list(set(checkpoints_name))
local_block = block.program._create_block()
buffer_block = block.program._create_block()
# 0) deal with forward recomputing program descs
program_stat = ProgramStats(block, ops)
program_stat.modify_forward_desc_for_recompute()
program_stat.build_stats()
# 1) find ops between checkpoints, i.e. recompute_segments
checkpoints_name = program_stat.sort_checkpoints(checkpoints_name)
segments = []
if len(checkpoints_name) == 1:
# only one checkpoint
max_op_idx = -1
var_group = [checkpoints_name[0]]
for name in var_group:
if name not in program_stat.var_op_deps:
break
op_idx = program_stat.var_op_deps[name]["var_as_output_ops"]
            # only count the last op that generates it
for idx in op_idx:
max_op_idx = max(max_op_idx, idx)
if max_op_idx > 0:
segments.append([0, max_op_idx + 1])
else:
start_idx = 0
pre_segment_end_idx = -1
while True:
if start_idx >= len(checkpoints_name) - 1:
break
            # min_idx: checkpoint_1's input op
            # max_idx: checkpoint_2's output op
flag, min_idx, max_idx = program_stat.is_subgraph(
[checkpoints_name[start_idx]], [checkpoints_name[start_idx + 1]]
)
if flag:
# max_idx + 1 since the exact and used segment end idx is max_idx
min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])