-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
fleet.py
executable file
·1655 lines (1297 loc) · 54.9 KB
/
fleet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import paddle
from paddle.fluid import compiler
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.framework import _global_flags, in_dynamic_mode
from paddle.framework.ir import apply_build_strategy
from .base import topology as tp
from .base.distributed_strategy import DistributedStrategy
from .base.meta_optimizer_factory import MetaOptimizerFactory
from .base.role_maker import PaddleCloudRoleMaker, RoleMakerBase
from .base.runtime_factory import RuntimeFactory
from .base.strategy_compiler import StrategyCompiler
from .meta_parallel import model_parallel_random_seed
from .utils.log_util import logger, set_log_level
# Nothing is exported through ``from ... import *``; the public Fleet API is
# presumably re-exported elsewhere in the package (not visible here).
__all__ = []
def apply_ir_passes(main_program, startup_program, config):
    """
    Apply the user's build-strategy IR passes to a program pair.

    A copy of ``config._user_defined_strategy.build_strategy`` is used, so the
    user's own build strategy object is not mutated here.

    Args:
        main_program: Program to transform. If it carries a non-empty
            ``_pipeline_opt`` dict, its ``"section_program"`` entry is used
            instead.
        startup_program: Matching startup program. In the pipeline case its
            ``_pipeline_opt["startup_program"]`` entry is used instead.
        config: Object exposing ``_user_defined_strategy`` and
            ``_is_collective`` (presumably a Fleet instance — confirm).

    Returns:
        The copied build strategy unchanged when the
        ``FLAGS_apply_pass_to_program`` flag is off; otherwise whatever
        ``apply_build_strategy`` returns.
    """
    user_strategy = config._user_defined_strategy
    build_strategy = user_strategy.build_strategy._copy()

    # Passes are opt-in via a global flag; bail out early when disabled.
    if not _global_flags()['FLAGS_apply_pass_to_program']:
        return build_strategy

    pipeline_opt = getattr(main_program, "_pipeline_opt", {})
    if pipeline_opt:
        # Pipeline training stores the real programs inside _pipeline_opt.
        main_program = pipeline_opt["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    use_cuda = config._is_collective

    if (
        user_strategy.fuse_all_reduce_ops
        and build_strategy.fuse_all_optimizer_ops
    ):
        # FIXME(zjl): currently, fuse_all_optimizer_ops
        # have conflict with fuse_all_reduce_ops because
        # RawProgramOptimizer also inserts coalesce_tensor
        # into program. These two procedures may conflict
        # in which vars are to be fused.
        logger.warning(
            'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.'
        )
        build_strategy.fuse_all_optimizer_ops = False

    return apply_build_strategy(
        main_program, startup_program, build_strategy, {"use_cuda": use_cuda}
    )
def _inited_runtime_handler_(func):
def __impl__(*args, **kwargs):
cls = args[0]
if cls._runtime_handle is None:
raise ValueError("Fleet can not find suitable runtime handler")
return func(*args, **kwargs)
return __impl__
def _is_non_distributed_check_(func):
def __impl__(*args, **kwargs):
cls = args[0]
if (
cls._role_maker is not None
and cls._role_maker._is_non_distributed() is True
):
logger.warning(
"%s() function doesn't work when use non_distributed fleet."
% (func.__name__)
)
return
return func(*args, **kwargs)
return __impl__
# Decorator forms of the two wrapper factories defined above, produced with
# ``wrap_decorator`` and applied to the Fleet methods below.
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
class Fleet:
"""
Unified API for distributed training of PaddlePaddle
Please reference the https://github.com/PaddlePaddle/PaddleFleetX for details
Returns:
Fleet: A Fleet instance
Example for collective training:
.. code-block:: python
import paddle
paddle.enable_static()
import paddle.distributed.fleet as fleet
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# do distributed training
Example for parameter server training:
.. code-block:: python
import paddle
paddle.enable_static()
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
fleet.init(strategy=strategy)
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer)
if fleet.is_first_worker():
print("this is first worker")
print("current node index: {}".format(fleet.worker_index()))
print("total number of worker num: {}".format(fleet.worker_num()))
if fleet.is_worker():
print("this is worker")
print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
print("server num: {}".format(fleet.server_num()))
print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
if fleet.is_server():
print("this is server")
fleet.stop_worker()
"""
def __init__(self):
    """
    Create an un-initialized Fleet object.

    The role maker and runtime handle start as None; ``init`` fills in the
    role maker, strategy compiler and related state.
    """
    # State populated later by init() / the runtime.
    self._role_maker = None
    self._runtime_handle = None
    self._util = None
    self.strategy_compiler = None
    self._is_collective = False
    self._context = {}
    # Placeholder optimizer with learning rate 0.0; presumably replaced once
    # a real optimizer is supplied (not visible here).
    self.user_defined_optimizer = paddle.optimizer.Optimizer(0.0)
def init(
    self,
    role_maker=None,
    is_collective=False,
    strategy=None,
    log_level="INFO",
):
    """
    Initialize Fleet for distributed training.

    This sets up the role maker and the distributed strategy. Depending on the
    mode, it also sets up the communication environment the program runs in.

    Args:
        role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
            of environment variables related to distributed training. If you did not initialize
            the rolemaker by yourself, it will be automatically initialized to PaddleCloudRoleMaker.
            The default value is None.
        is_collective (Boolean, optional): A ``Boolean`` variable that determines whether the program
            runs in Collective mode or ParameterServer mode. True means the program runs in
            Collective mode, and False means it runs in ParameterServer mode. It is only used
            when ``role_maker`` is None; otherwise the mode is taken from ``role_maker``.
            The default value is False.
        strategy (DistributedStrategy, optional): Extra properties for distributed training.
            A deep copy is stored, so later changes to the passed object have no effect here.
            For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.
        log_level (Integer, String, optional): A ``Integer`` or ``String`` value determining how high
            the logging level is. Default is "INFO".

    Returns:
        Fleet: this Fleet instance (returned on every code path).

    Raises:
        ValueError: if ``is_collective`` is not a bool, if ``role_maker`` is not a
            ``RoleMakerBase``, or if a non-distributed collective job on a CUDA
            build sees a number of visible GPUs other than 1.
        AssertionError: in static-graph collective mode, if the sharding and
            tensor-parallel mp degrees disagree, or if the worker number is not
            divisible by the mp degree.

    Examples1:
        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

    Examples2:
        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init(is_collective=True)

    Examples3:
        .. code-block:: python

            import paddle.distributed.fleet as fleet
            role = fleet.PaddleCloudRoleMaker()
            fleet.init(role)

    Examples4:
        .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            fleet.init(strategy=strategy)

    Examples5:
        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init(log_level = "DEBUG")
    """
    from paddle.distributed import parallel_helper

    set_log_level(log_level)

    if strategy is None:
        strategy = DistributedStrategy()
    # Deep-copy so later edits to the caller's strategy do not leak in.
    self._user_defined_strategy = copy.deepcopy(strategy)

    if role_maker is None:
        if isinstance(is_collective, bool):
            self._is_collective = is_collective
            self._role_maker = PaddleCloudRoleMaker(
                is_collective=self._is_collective
            )
        else:
            raise ValueError(
                "`is_collective` should be instance of `bool`, but got {}".format(
                    type(is_collective)
                )
            )
    else:
        if isinstance(role_maker, RoleMakerBase):
            self._role_maker = role_maker
            self._is_collective = role_maker._is_collective
        else:
            raise ValueError(
                "`role_maker` should be subclass of `RoleMakerBase`, but got {}".format(
                    type(role_maker)
                )
            )
    self._role_maker._generate_role()

    from paddle.distributed import fleet

    fleet.util._set_role_maker(self._role_maker)

    self.strategy_compiler = StrategyCompiler()

    if self._role_maker._is_non_distributed() and self._is_collective:
        if paddle.framework.core.is_compiled_with_cuda():
            gpus_num = paddle.framework.core.get_cuda_device_count()
            if gpus_num != 1:
                raise ValueError(
                    "CUDA_VISIBLE_DEVICES should be set only 1 card if you use `python` to launch fleet program."
                )

    if in_dynamic_mode():
        if self.worker_num() == 1:
            # if worker_num is 1, should construct default topology & hcg
            self._topology = tp.CommunicateTopology()
            self._hcg = tp.HybridCommunicateGroup(self._topology)
            return self
        if parallel_helper._is_parallel_ctx_initialized():
            logger.warning(
                "The dygraph parallel environment has been initialized."
            )
        else:
            # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
            if "FLAGS_nccl_nrings" in os.environ:
                logger.warning(
                    "You have set the environment variable FLAGS_nccl_nrings "
                    "outside the program, so the nccl_comm_num in "
                    "DistributedStrategy will not take effect here."
                )
            else:
                os.environ["FLAGS_nccl_nrings"] = str(
                    self._user_defined_strategy.nccl_comm_num
                )
            paddle.distributed.init_parallel_env()

        # hybrid parallel not support for npu/xpu
        if not self._user_defined_strategy.heter_ccl_mode:
            # init hybrid parallel environment in dygraph
            if tp._HYBRID_PARALLEL_GROUP is None:
                self._init_hybrid_parallel_env()
            else:
                logger.warning(
                    "The dygraph hybrid parallel environment has been initialized."
                )
    elif self._is_collective:
        use_sharding = self._user_defined_strategy.sharding

        # global group
        global_rank = self.worker_index()
        global_world_size = self.worker_num()
        # NOTE(wangxi): see sharding_optimizer
        global_ring_id = 3 if use_sharding else 0
        global_ranks = list(range(global_world_size))

        if tp._HYBRID_PARALLEL_GROUP is None:
            tp._CommunicateGroup()
        cg = tp._HYBRID_PARALLEL_GROUP
        self._hcg = cg
        cg.set_comm_group(
            'global',
            global_rank,
            global_world_size,
            global_ring_id,
            global_ranks,
        )

        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
        use_mp = use_sharding or use_tensor_parallel

        # hybrid group
        if not use_mp:
            return self

        mp_degree_sharding = 1
        mp_degree_tensor_parallel = 1
        if use_sharding:
            sharding_configs = self._user_defined_strategy.sharding_configs
            mp_degree_sharding = int(sharding_configs['mp_degree'])

        if use_tensor_parallel:
            tensor_parallel_configs = (
                self._user_defined_strategy.tensor_parallel_configs
            )
            mp_degree_tensor_parallel = int(
                tensor_parallel_configs['tensor_parallel_degree']
            )

        if use_sharding and use_tensor_parallel:
            assert mp_degree_sharding == mp_degree_tensor_parallel, (
                "sharding_configs['mp_degree'] ({}) must equal "
                "tensor_parallel_configs['tensor_parallel_degree'] ({})".format(
                    mp_degree_sharding, mp_degree_tensor_parallel
                )
            )

        mp_degree = (
            mp_degree_sharding
            if use_sharding
            else mp_degree_tensor_parallel
        )

        if mp_degree > 1:
            assert global_world_size % mp_degree == 0, (
                "worker number ({}) must be divisible by mp_degree ({})".format(
                    global_world_size, mp_degree
                )
            )
            # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
            mp_ring_id = 0
            mp_rank = global_rank % mp_degree
            mp_group_id = global_rank // mp_degree
            # Ranks are grouped into consecutive blocks of size mp_degree.
            mp_group_ranks = [
                idx
                for idx in global_ranks
                if idx // mp_degree == mp_group_id
            ]
            cg.set_comm_group(
                'model', mp_rank, mp_degree, mp_ring_id, mp_group_ranks
            )
    return self
def _init_hybrid_parallel_env(self):
"""initialize the hybrid environment"""
self.hybrid_configs = self._user_defined_strategy.hybrid_configs
self.dp_degree = self.hybrid_configs["dp_degree"]
self.mp_degree = self.hybrid_configs["mp_degree"]
self.pp_degree = self.hybrid_configs["pp_degree"]
self.sharding_degree = self.hybrid_configs["sharding_degree"]
assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
assert (
self.sharding_degree >= 0
), "sharding_degree should be greater or equal to 0"
self.mp_degree = max(self.mp_degree, 1)
self.pp_degree = max(self.pp_degree, 1)
if self.dp_degree < 0:
nranks = paddle.distributed.get_world_size()
self.dp_degree = nranks // (self.mp_degree * self.pp_degree)
self.dp_degree = max(self.dp_degree, 1)
d_hybrid_degree = {
"dp": ["data", self.dp_degree],
"pp": ['pipe', self.pp_degree],
"sharding": ['sharding', self.sharding_degree],
"mp": ['model', self.mp_degree],
}
order = self._user_defined_strategy.hybrid_parallel_order
if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
raise AssertionError(
'The order of hybrid_config setting is incorrect.'
)
hybrid_group_names = []
dims = []
for h_name in order:
name, degree = d_hybrid_degree[h_name]
hybrid_group_names.append(name)
dims.append(degree)
self._topology = tp.CommunicateTopology(
hybrid_group_names=hybrid_group_names, dims=dims
)
self._hcg = tp.HybridCommunicateGroup(self._topology)
if self.mp_degree > 1:
tensor_parallel_configs = (
self._user_defined_strategy.tensor_parallel_configs
)
tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
if tensor_init_seed == -1:
model_parallel_random_seed()
else:
model_parallel_random_seed(tensor_init_seed)
def get_hybrid_communicate_group(self):
    """
    Return the hybrid communicate group created during initialization.

    Returns:
        The object stored in ``self._hcg``.

    Raises:
        AssertionError: if no group has been stored yet, for example because
            ``init`` has not been called or took a path that does not build
            one. Previously a missing attribute surfaced as an opaque
            AttributeError.
    """
    # _hcg is not assigned in __init__, so read it defensively.
    hcg = getattr(self, "_hcg", None)
    assert hcg is not None, (
        "The hybrid communicate group is not initialized, "
        "please call fleet.init() first."
    )
    return hcg
def get_hybrid_parallel_topology(self):
    """
    Return the hybrid-parallel topology created during initialization.

    Returns:
        The object stored in ``self._topology``.

    Raises:
        AssertionError: if no topology has been stored yet, for example
            because ``init`` has not been called or took a path that does not
            build one. Previously a missing attribute surfaced as an opaque
            AttributeError.
    """
    # _topology is not assigned in __init__, so read it defensively.
    topology = getattr(self, "_topology", None)
    assert topology is not None, (
        "The hybrid parallel topology is not initialized, "
        "please call fleet.init() first."
    )
    return topology
def is_first_worker(self):
    """
    Tell whether this node is the first worker instance.

    Returns:
        bool: True for the first worker node, False otherwise.

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.is_first_worker()
    """
    role_maker = self._role_maker
    return role_maker._is_first_worker()
def worker_index(self):
    """
    Return the index of the current worker.

    Returns:
        int: node id

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.worker_index()
    """
    role_maker = self._role_maker
    return role_maker._worker_index()
def worker_num(self):
    """
    Return the total number of workers.

    Returns:
        int: worker numbers

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.worker_num()
    """
    role_maker = self._role_maker
    return role_maker._worker_num()
def node_num(self):
    """
    Return the node number reported by the role maker.

    Returns:
        Whatever the role maker's ``_get_node_num()`` returns.
    """
    role_maker = self._role_maker
    return role_maker._get_node_num()
def local_rank(self):
    """
    Return the local rank reported by the role maker.

    Returns:
        Whatever the role maker's ``_get_local_rank()`` returns.
    """
    role_maker = self._role_maker
    return role_maker._get_local_rank()
def local_device_ids(self):
    """
    Return the local device ids reported by the role maker.

    Returns:
        Whatever the role maker's ``_get_local_device_ids()`` returns.
    """
    role_maker = self._role_maker
    return role_maker._get_local_device_ids()
def world_device_ids(self):
    """
    Return the world device ids reported by the role maker.

    Returns:
        Whatever the role maker's ``_get_world_device_ids()`` returns.
    """
    role_maker = self._role_maker
    return role_maker._get_world_device_ids()
def is_worker(self):
    """
    Tell whether this node is a worker instance.

    Returns:
        bool: True if this node is a worker, False otherwise.

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.is_worker()
    """
    role_maker = self._role_maker
    return role_maker._is_worker()
def is_coordinator(self):
    """
    Tell whether this node is a coordinator.

    Returns:
        Whatever the role maker's ``_is_coordinator()`` returns.
    """
    role_maker = self._role_maker
    return role_maker._is_coordinator()
def worker_endpoints(self, to_string=False):
    """
    Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

    Args:
        to_string (bool, optional): If True, return the endpoints joined into
            one comma-separated string instead of a list. Default: False.

    Returns:
        list/string: worker endpoints

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.worker_endpoints()
    """
    endpoints = self._role_maker._get_trainer_endpoints()
    return ",".join(endpoints) if to_string else endpoints
def server_num(self):
    """
    Get current total server number.

    The count is the number of parameter-server endpoints known to the role
    maker.

    Returns:
        int: server number

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.server_num()
    """
    return len(self._role_maker._get_pserver_endpoints())
def server_index(self):
    """
    Return the index of the current server.

    Returns:
        int: node id

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.server_index()
    """
    role_maker = self._role_maker
    return role_maker._server_index()
def server_endpoints(self, to_string=False):
    """
    Return the parameter-server endpoints, e.g. ["127.0.0.1:1001", "127.0.0.1:1002"].

    Args:
        to_string (bool, optional): When True, join the endpoints into a
            single comma-separated string. Default: False.

    Returns:
        list/string: server endpoints

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.server_endpoints()
    """
    endpoints = self._role_maker._get_pserver_endpoints()
    if not to_string:
        return endpoints
    return ",".join(endpoints)
def is_server(self):
    """
    Tell whether this node is a server instance.

    Returns:
        bool: True if this node is a server, False otherwise.

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.is_server()
    """
    role_maker = self._role_maker
    return role_maker._is_server()
def barrier_worker(self):
    """
    Synchronize all workers.

    Calls the role maker's ``_barrier`` with the ``"worker"`` group name.

    Returns:
        None
    """
    role_maker = self._role_maker
    role_maker._barrier("worker")
@is_non_distributed_check
@inited_runtime_handler
def init_worker(self, scopes=None):
    """
    Initialize the `Communicator` for parameter server training.

    Args:
        scopes (optional): Passed unchanged to the runtime handler's
            ``_init_worker``. Default: None.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            fleet.init_worker()
    """
    handle = self._runtime_handle
    handle._init_worker(scopes)
@is_non_distributed_check
@inited_runtime_handler
def init_coordinator(self, scopes=None):
    """
    Initialize the coordinator node.

    Args:
        scopes (optional): Passed unchanged to the runtime handler's
            ``_init_coordinator``. Default: None.
    """
    handle = self._runtime_handle
    handle._init_coordinator(scopes)
def make_fl_strategy(self):
    """
    Ask the runtime handler to build its FL strategy via ``_make_fl_strategy``.

    Unlike its neighbours, this method carries no runtime-handler or
    non-distributed guard decorator, so it calls straight through to
    ``self._runtime_handle``.
    """
    handle = self._runtime_handle
    handle._make_fl_strategy()
@is_non_distributed_check
@inited_runtime_handler
def get_fl_client(self):
    """
    Return the runtime handler's ``_worker`` object.

    The original documentation describes this as the worker (training node)
    pointer.
    """
    handle = self._runtime_handle
    return handle._worker
@is_non_distributed_check
@inited_runtime_handler
def init_server(self, *args, **kwargs):
    """
    Run the server-side initialization of the startup program.

    All positional and keyword arguments are forwarded unchanged to the
    runtime handler's ``_init_server``. According to this API's documentation,
    a non-empty ``args`` makes it load persistables for incremental training.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            fleet.init_server()
    """
    handle = self._runtime_handle
    handle._init_server(*args, **kwargs)
@is_non_distributed_check
@inited_runtime_handler
def load_model(self, path, mode):
    """
    Load a fleet model from ``path``.

    Args:
        path (str): Directory to load from.
        mode: Passed unchanged to the runtime handler's
            ``_load_persistables``; its meaning is defined there.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            fleet.load_model("path", mode=0)
    """
    handle = self._runtime_handle
    handle._load_persistables(path, mode)
@is_non_distributed_check
@inited_runtime_handler
def load_one_table(self, table_id, path, mode):
    """
    Load a single fleet table from ``path``.

    Args:
        table_id: Identifier of the table to load.
        path (str): Directory to load from.
        mode: Passed unchanged to the runtime handler's ``_load_one_table``;
            its meaning is defined there.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            fleet.load_one_table(0, "path", mode=0)
    """
    handle = self._runtime_handle
    handle._load_one_table(table_id, path, mode)
@is_non_distributed_check
@inited_runtime_handler
def load_inference_model(self, path, mode):
    """
    Load a fleet inference model from ``path``.

    Args:
        path (str): Directory to load from.
        mode: Passed unchanged to the runtime handler's
            ``_load_inference_model``; its meaning is defined there.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            fleet.load_inference_model("path", mode=1)
    """
    handle = self._runtime_handle
    handle._load_inference_model(path, mode)
@is_non_distributed_check
@inited_runtime_handler
def run_server(self):
    """
    Run the parameter-server main program with an executor.

    The work is delegated to the runtime handler's ``_run_server``.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            if fleet.is_server():
                fleet.init_server()
                fleet.run_server()
    """
    self._runtime_handle._run_server()
@is_non_distributed_check
@inited_runtime_handler
def stop_worker(self):
    """
    Stop the `Communicator` and tell the parameter server that training is complete.

    The work is delegated to the runtime handler's ``_stop_worker``.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            fleet.init_worker()
            # ... training loop ...
            fleet.stop_worker()
    """
    self._runtime_handle._stop_worker()
@is_non_distributed_check
@inited_runtime_handler
def save(self, dirname, feed=None, fetch=None, **configs):
    """
    Save the model to ``dirname``.

    If both ``feed`` and ``fetch`` are empty, all persistables are saved
    through the runtime handler's ``_save_persistables``. Otherwise an
    inference model is saved through ``_save_inference_model``. The fetch
    variables are looked up by name in
    ``paddle.static.default_main_program()``.

    Args:
        dirname (str): Target directory.
        feed (list[str|Variable], optional): Feed variables or their names.
            Default: None (treated as empty).
        fetch (list[str|Variable], optional): Fetch variables or their names.
            Default: None (treated as empty).
        **configs: Extra options. Only ``mode`` is read; it is converted with
            ``int`` and used as the persistables save mode (default 0).

    Returns:
        None

    Raises:
        ValueError: if an element of ``feed`` or ``fetch`` is neither a
            ``str`` nor a ``paddle.static.Variable``.
    """
    # None defaults avoid shared mutable default lists.
    feed = [] if feed is None else feed
    fetch = [] if fetch is None else fetch

    def _to_var_names(var_list, arg_name):
        # Normalize a list of names/Variables into a list of names.
        names = []
        for var in var_list:
            if isinstance(var, str):
                names.append(var)
            elif isinstance(var, paddle.static.Variable):
                names.append(var.name)
            else:
                raise ValueError(
                    "{} must be [str|Variable]".format(arg_name)
                )
        return names

    inference = bool(feed) or bool(fetch)

    place = paddle.CPUPlace()
    executor = paddle.static.Executor(place)

    if inference:
        feeded_var_names = _to_var_names(feed, "feed")
        fetch_var_names = _to_var_names(fetch, "fetch")

        global_block = paddle.static.default_main_program().global_block()
        fetch_vars = [global_block.var(name) for name in fetch_var_names]

        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, fetch_vars, None, True, 0
        )
    else:
        increment_mode = int(configs.get("mode", 0))
        self._runtime_handle._save_persistables(
            executor, dirname, main_program=None, mode=increment_mode
        )
@is_non_distributed_check
@inited_runtime_handler
def save_inference_model(
    self,
    executor,
    dirname,
    feeded_var_names,
    target_vars,
    main_program=None,
    export_for_deployment=True,
    mode=0,
):
    """
    Save an inference model.

    All arguments are forwarded unchanged, in order, to the runtime handler's
    ``_save_inference_model``.

    Args:
        executor(Executor): Executor used to run the save.
        dirname(str): Directory to save into.
        feeded_var_names(list[str]): Names of the feed variables.
        target_vars(list[Variable]): Output variables of the inference model.
        main_program(Program, optional): Program to save. Default: None.
        export_for_deployment(bool, optional): Passed through to the runtime
            handler. Default: True.
        mode(int, optional): Passed through to the runtime handler; its
            meaning is defined there. Default: 0.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet
            fleet.init()

            # build net; `predict` is the network's output Variable
            # fleet.distributed_optimizer(...)

            exe = paddle.static.Executor(paddle.CPUPlace())
            fleet.save_inference_model(exe, "dirname", ["x"], [predict])
    """
    self._runtime_handle._save_inference_model(
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program,
        export_for_deployment,
        mode,
    )
@is_non_distributed_check
@inited_runtime_handler
def save_persistables(self, executor, dirname, main_program=None, mode=0):
    """
    Save all persistable tensors of ``main_program`` into the folder ``dirname``.

    The call is forwarded to the runtime handler's ``_save_persistables``.

    Args:
        executor(Executor): The executor to run for saving persistable tensors.
            You can refer to :ref:`api_guide_executor_en` for
            more details.
        dirname(str): The saving directory path. Passing None is documented
            as saving the parameters to memory; whether that is supported
            depends on the runtime handler.
        main_program(Program, optional): The program whose persistable tensors will
            be saved. Default: None.
        mode(int, optional): Save mode passed unchanged to the runtime
            handler; its meaning is defined there. Default: 0.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet

            fleet.init()

            # build net
            # fleet.distributed_optimizer(...)

            exe = paddle.static.Executor(paddle.CPUPlace())
            fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
    """
    self._runtime_handle._save_persistables(
        executor, dirname, main_program, mode
    )
@is_non_distributed_check
@inited_runtime_handler
def save_cache_model(self, dirname, **configs):
    """
    Save the cache model to ``dirname``.

    Keyword options are forwarded unchanged to the runtime handler's
    ``_save_cache_model``.

    Returns:
        Whatever the runtime handler returns.
    """
    handle = self._runtime_handle
    return handle._save_cache_model(dirname, **configs)
@is_non_distributed_check
@inited_runtime_handler
def check_save_pre_patch_done(self):
    """
    Query the runtime handler's ``_check_save_pre_patch_done``.

    Returns:
        Whatever the runtime handler returns.
    """
    handle = self._runtime_handle
    return handle._check_save_pre_patch_done()
@is_non_distributed_check
@inited_runtime_handler
def save_cache_table(
    self, table_id, pass_id, mem_cache_key_threshold=4000000000
):
    """
    Save the cache of one table through the runtime handler.

    Args:
        table_id: Identifier of the table.
        pass_id: Pass identifier, passed through unchanged.
        mem_cache_key_threshold (int, optional): Passed through unchanged to
            the runtime handler's ``_save_cache_table``; its meaning is
            defined there. Default: 4000000000.

    Returns:
        Whatever the runtime handler returns.
    """
    handle = self._runtime_handle
    return handle._save_cache_table(
        table_id, pass_id, mem_cache_key_threshold
    )
@is_non_distributed_check
@inited_runtime_handler
def save_one_table(self, table_id, path, mode):
"""
save fleet one table from path