/
main.py
1079 lines (907 loc) · 45.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .context import Context
# Module-level launch Context. It starts as None; ``launch()`` declares
# ``global ctx`` and assigns a fresh ``Context()`` here on every call. In
# auto-tuner mode it is later rebound to deep copies of the original context
# (one per tuning task).
ctx = None
def launch():
"""
Paddle distribution training entry ``python -m paddle.distributed.launch``.
Usage:
.. code-block:: bash
:name: code-block-bash1
python -m paddle.distributed.launch [-h] [--master MASTER] [--rank RANK]
[--log_level LOG_LEVEL] [--nnodes NNODES]
[--nproc_per_node NPROC_PER_NODE] [--log_dir LOG_DIR]
[--run_mode RUN_MODE] [--job_id JOB_ID] [--devices DEVICES]
[--host HOST] [--servers SERVERS] [--trainers TRAINERS]
[--trainer_num TRAINER_NUM] [--server_num SERVER_NUM]
[--gloo_port GLOO_PORT] [--with_gloo WITH_GLOO]
[--max_restart MAX_RESTART] [--elastic_level ELASTIC_LEVEL]
[--elastic_timeout ELASTIC_TIMEOUT]
training_script ...
Base Parameters:
- ``--master``: The master/rendezvous server, support ``http://`` and ``etcd://``, default with ``http://``. e.g., ``--master=127.0.0.1:8080``. Default ``--master=None``.
- ``--rank``: The rank of the node, can be auto assigned by master. Default ``--rank=-1``.
- ``--log_level``: The log level to set for logging.setLevel which can be CRITICAL/ERROR/WARNING/INFO/DEBUG/NOTSET, case insensitive. Default ``--log_level=INFO``.
- ``--nnodes``: The number of nodes for a distributed job, it can be a range in elastic mode, e.g., ``--nnodes=2:3``. Default ``--nnodes=1``.
- ``--nproc_per_node``: The number of processes to launch on a node. In gpu training, it should be less or equal to the gpus number of you system. e.g., ``--nproc_per_node=8``
- ``--log_dir``: The path for each process's log. e.g., ``--log_dir=output_dir``. Default ``--log_dir=log``.
- ``--run_mode``: The run mode of job, can be:collective/ps/ps-heter/rpc. e.g., ``--run_mode=ps``. Default ``--run_mode=collective``.
- ``--job_id``: The job unique id, it affects the log files' name. e.g., ``--job_id=job1``. Default ``--job_id=default``.
- ``--devices``: The selected accelerate devices on nodes, can be gpu/xpu etc.. e.g., ``--devices=0,1,2,3`` will launch four training processes each bound to one device.
- ``training_script``: The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script. e.g., ``training.py``
- ``training_script_args``: The args of training_script. e.g., ``--lr=0.1``
Collective Parameters:
- ``--ips``: [DEPRECATED] Paddle cluster nodes ips, e.g., ``--ips=192.168.0.16,192.168.0.17``. Default ``--ips=127.0.0.1``.
Parameter-Server Parameters:
- ``--servers``: User defined servers ip:port, e.g., ``--servers="192.168.0.16:6170,192.168.0.17:6170"``
- ``--trainers``: User defined trainers ip:port, e.g., ``--trainers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"``
- ``--workers``: [DEPRECATED] The same as trainers.
- ``--trainer_num``: Number of trainers on each node, can be 0.
- ``--worker_num``: [DEPRECATED] The same as trainer_num.
- ``--server_num``: Number of servers on each node, can be 0.
- ``--heter_workers``: User defined heter workers ip1:port1;ip2:port2, e.g., ``--heter_workers="192.168.0.16:6172;192.168.0.17:6172"``
- ``--heter_worker_num``: Number of heter_workers in each stage (It recommend to set when in the emulated distributed environment using single node)
- ``--heter_devices``: Type of heter_device in each stage
- ``--gloo_port``: Gloo http Port. Default ``--gloo_port=6767``.
- ``--with_gloo``: Using gloo or not. Default ``--with_gloo=0``.
Elastic Parameters:
- ``--max_restart``: The maximum restart times for an elastic job. Default ``--max_restart=3``.
- ``--elastic_level``: The elastic level: -1: disable, 0: failed exit, peers hold, 1: internal restart. Default ``--elastic_level=-1``.
- ``--elastic_timeout``: Seconds to wait before elastic job begin to train. Default ``--elastic_timeout=30``.
IPU Parameters:
IPU distributed launch only requires and allowes three arguments ``--devices``, ``training_script`` and ``training_script_args``.
The ``--devices`` is the number of IPU devices. e.g., ``--devices=4`` will launch the training program with four IPU devices.
The ``training_script`` is only allowed to set as ``ipu``.
The ``training_script_args`` includes arguments required by IPU distributed launch and illustrated as below.
``Examples 10`` has provided a example of paddle.distributed.launch with IPUs.
- ``--hosts``: The hosts for IPU distributd training. Each host is able to include multiple processes.
- ``--nproc_per_host``: The number of processes launched per host. Each process is able to include multiple replicas.
- ``--ipus_per_replica``: The number of IPUs requested per replica. Each replica is able to include multiple IPUs.
- ``--ipu_partition``: The partition name of IPU devices.
- ``--vipu_server``: The ip of the IPU device manager.
- ``training_script``: The full path to the IPU distributed training program/script to be launched in parallel. e.g., ``training.py``.
- ``training_script_args``: The args of the IPU distributed training program/script. e.g., ``--lr=0.1``.
Returns:
- ``None``
Examples 0 (master, ip/port auto detection):
.. code-block:: bash
:name: code-block-example-bash0
# For training on multi node, run the following command in one of the nodes
python -m paddle.distributed.launch --nnodes 2 train.py
# Then the following info will be print
# Copy the following command to other nodes to run.
# --------------------------------------------------------------------------------
# python -m paddle.distributed.launch --master 10.0.0.1:38714 --nnodes 2 train.py
# --------------------------------------------------------------------------------
# Follow the instruction above and paste the command in other nodes can launch a multi nodes training job.
# There are two ways to launch a job with the same command for multi nodes training
# 1) using the following command in every nodes, make sure the ip is one of the training node and the port is available on that node
# python -m paddle.distributed.launch --master 10.0.0.1:38714 --nnodes 2 train.py
# 2) using the following command in every nodes with a independent etcd service
# python -m paddle.distributed.launch --master etcd://10.0.0.1:2379 --nnodes 2 train.py
# This functionality works will for both collective and ps mode and even with other arguments.
Examples 1 (collective, single node):
.. code-block:: bash
:name: code-block-example-bash1
# For training on single node using 4 gpus.
python -m paddle.distributed.launch --devices=0,1,2,3 train.py --lr=0.01
Examples 2 (collective, multi node):
.. code-block:: bash
:name: code-block-example-bash2
# For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17
# On 192.168.0.16:
python -m paddle.distributed.launch --devices=0,1,2,3 --master=192.168.0.16:8090 train.py --lr=0.01
# On 192.168.0.17:
python -m paddle.distributed.launch --devices=0,1,2,3 --master=192.168.0.16:8090 train.py --lr=0.01
Examples 3 (ps, cpu, single node):
.. code-block:: bash
:name: code-block-example-bash3
# To simulate distributed environment using single node, e.g., 2 servers and 4 workers.
python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
Examples 4 (ps, cpu, multi node):
.. code-block:: bash
:name: code-block-example-bash4
# For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 1 server and 2 workers.
# On 192.168.0.16:
python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
# On 192.168.0.17:
python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
# Or with master, the following command run 2 server and 2 trainer on each node.
python -m paddle.distributed.launch --master 192.168.0.16:9090 --server_num=2 --trainer_num=2 --nnodes 2 train.py
Examples 5 (ps, gpu, single node):
.. code-block:: bash
:name: code-block-example-bash5
# To simulate distributed environment using single node, e.g., 2 servers and 4 workers, each worker use single gpu.
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
Examples 6 (ps, gpu, multi node):
.. code-block:: bash
:name: code-block-example-bash6
# For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 1 server and 2 workers.
# On 192.168.0.16:
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
# On 192.168.0.17:
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
Examples 7 (ps-heter, cpu + gpu, single node):
.. code-block:: bash
:name: code-block-example-bash7
# To simulate distributed environment using single node, e.g., 2 servers and 4 workers, two workers use gpu, two workers use cpu.
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch --server_num=2 --worker_num=2 --heter_worker_num=2 train.py --lr=0.01
Examples 8 (ps-heter, cpu + gpu, multi node):
.. code-block:: bash
:name: code-block-example-bash8
# For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 1 server, 1 gpu worker, 1 cpu worker.
# On 192.168.0.16:
export CUDA_VISIBLE_DEVICES=0
python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01
# On 192.168.0.17:
export CUDA_VISIBLE_DEVICES=0
python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01
Examples 9 (elastic):
.. code-block:: bash
:name: code-block-example-bash9
# With the following command, the job will begin to run immediately if 4 nodes are ready,
# or it will run after elastic_timeout if only 2 or 3 nodes ready
python -m paddle.distributed.launch --master etcd://10.0.0.1:2379 --nnodes 2:4 train.py
# once the number of nodes changes between 2:4 during training, the strategy holds
Examples 10 (ipu):
.. code-block:: bash
:name: code-block-example-bash10
# With the following command, the job will begin to run the distributhed program with IPUs
# Require `devices` as the number of IPUs
# Require `training_script` to be set as `ipu`
# Require `training_script_args` as the arguments of IPU distributed training instead of the arguments of the training program/script
# Please Check the `IPU Parameters` for details
python -m paddle.distributed.launch --devices 4 ipu --hosts=localhost --nproc_per_host=2 --ipus_per_replica=1 --ipu_partition=pod16 --vipu_server=127.0.0.1 train.py
Examples 11 (rpc, cpu, single node):
.. code-block:: bash
:name: code-block-example-bash11
# Training on single node with two local servers
python -m paddle.distributed.launch --master 127.0.0.1:8765 --nnodes 1 --nproc_per_node 2 --rank 0 --run_mode rpc train.py
Examples 12 (rpc, cpu, multi node):
.. code-block:: bash
:name: code-block-example-bash12
# For training on multiple nodes, e.g., 192.168.0.16, 192.168.0.17 where each node with 2 servers.
# On 192.168.0.16
python -m paddle.distributed.launch --master 192.168.0.16:8765 --nnodes 2 --nproc_per_node 2 --rank 0 --run_mode rpc train.py
# On 192.168.0.17
python -m paddle.distributed.launch --master 192.168.0.16:8765 --nnodes 2 --nproc_per_node 2 --rank 1 --run_mode rpc train.py
"""
# initialize the context to run
global ctx
ctx = Context()
if ctx.is_legacy_mode():
# legacy mode
from paddle.distributed.fleet import launch
launch.launch()
elif ctx.is_auto_tuner_mode():
import copy
import json
import logging
import os
import sys
import time
from ..auto_tuner.recorder import HistoryRecorder
from ..auto_tuner.tuner import AutoTuner
from ..auto_tuner.utils import (
add_overlap_performance,
gen_new_args,
gen_new_ctx,
read_log,
read_step_time_log,
)
from . import controllers
start_time = time.time()
# read user defined tuner config json
if not ctx.args.auto_tuner_json.endswith(".json"):
raise ValueError("Please use '.json' as the file name suffix.")
try:
with open(ctx.args.auto_tuner_json, "r") as f:
tuner_cfg = json.load(f)
except:
raise ValueError("Please check your auto tuner json whether valid.")
logger = logging.getLogger('auto_tuner')
logger.setLevel(logging.INFO)
auto_tuner_log_path = os.path.join(
os.path.dirname(ctx.args.auto_tuner_json),
f'{os.path.basename(ctx.args.auto_tuner_json).split(".")[0]}_auto_tuner.log',
)
handler = logging.FileHandler(auto_tuner_log_path, mode="w")
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
# copy training script args
if ctx.args.training_script.endswith('.py'):
if os.environ.get("WITH_COVERAGE") == "ON":
entrypoint = [
sys.executable,
"-u",
"-m",
"coverage",
"run",
"--branch",
"-p",
ctx.args.training_script,
]
else:
entrypoint = [sys.executable, "-u", ctx.args.training_script]
else:
entrypoint = [ctx.args.training_script]
entrypoint.extend(ctx.args.training_script_args)
raw_args = copy.deepcopy(ctx.args.training_script_args)
# get nodes and gpus from args
if not ctx.args.devices:
gpus_per_node = 8
else:
gpus_per_node = len(ctx.args.devices.split(","))
nnodes = ctx.args.nnodes
if isinstance(nnodes, str):
nnodes = int(nnodes.split(":")[0])
else:
nnodes = int(nnodes)
tuner_cfg["nodes"] = nnodes
tuner_cfg["gpus_per_node"] = gpus_per_node
tuner_cfg["num_gpus"] = gpus_per_node * tuner_cfg["nodes"]
if not tuner_cfg.get("search_algo", None):
tuner_cfg["search_algo"] = {"name": "grid"}
mode = tuner_cfg.get("mode", None)
history_file_path = os.path.join(
os.path.dirname(ctx.args.auto_tuner_json),
f'{os.path.basename(ctx.args.auto_tuner_json).split(".")[0]}_history.csv',
)
sorted_ips = []
ip = None
if nnodes > 1:
from .utils.etcd_client import ETCDClient
assert "etcd://" in ctx.args.master
master_ip, port = ctx.args.master.strip("etcd://").split(':')
client = ETCDClient(host=master_ip, port=port)
client.delete("best_cfg")
client.delete_prefix("auto_tuner")
import socket
try:
hostname = socket.gethostname()
ip = socket.gethostbyname(socket.getfqdn(hostname))
except:
ip = '127.0.0.1'
assert ip != '127.0.0.1'
if tuner_cfg["search_algo"].get("estimated_num_gpus", None):
# get all machine ips and sort them
# to avoid etcd deleting key and adding key at the same time
time.sleep(5)
path = f"auto_tuner/ip/{ip}"
while not client.put(path, f"{ip}".encode('latin-1')):
time.sleep(1)
ips = list(client.get_prefix("auto_tuner/ip/"))
size = len(ips)
while size != nnodes:
time.sleep(1)
client.put(path, f"{ip}".encode('latin-1'))
ips = list(client.get_prefix("auto_tuner/ip/"))
size = len(ips)
sorted_ips = sorted([i[0].decode() for i in ips])
logger.info(
f"The total count of nodes is {len(sorted_ips)} and sorted ips are {sorted_ips}."
)
# get max time per task run
max_time_per_task = tuner_cfg.get("max_time_per_task", 1800)
tuner_cfg["max_time_per_task"] = max_time_per_task
ctx.max_time_per_task = max_time_per_task
# warmup
warmup_time = (
max_time_per_task
if "warmup_time" not in tuner_cfg
else tuner_cfg.get("warmup_time")
)
# max_search_time
max_search_time = tuner_cfg.get("max_search_time", None)
is_first_task = True
# build history recorder
recorder = HistoryRecorder()
job_id = 0
error_task_nums = 0
ctx.args.max_restart = -1
raw_ctx = copy.deepcopy(ctx)
# gbs search
if (
tuner_cfg.get('model_cfg', {}).get('global_batch_size', 'auto')
== "auto"
):
# adjust micron batch size until out of memory to get best global batch size
gbs_tuner_cfg = copy.deepcopy(tuner_cfg)
gbs_tuner_cfg["search_algo"] = "gbs"
gbs_tuner = AutoTuner(gbs_tuner_cfg)
gbs_cur_cfg = gbs_tuner.search_once()
best_gbs = None
while gbs_cur_cfg:
ctx = copy.deepcopy(raw_ctx)
log_dir = "GBSSearch/GBS{}_DP{}_MP{}_PP{}_Sharding_degree_{}_stage_{}_MBS{}_Recompute_{}_granularity_{}".format(
gbs_cur_cfg["global_batch_size"],
gbs_cur_cfg["dp_degree"],
gbs_cur_cfg["mp_degree"],
gbs_cur_cfg["pp_degree"],
gbs_cur_cfg["sharding_degree"],
gbs_cur_cfg["sharding_stage"],
gbs_cur_cfg["micro_batch_size"],
gbs_cur_cfg["use_recompute"],
gbs_cur_cfg["recompute_granularity"],
)
ctx.args.log_dir = log_dir
# every task has own job id
job_id += 1
task_job_id = "gbs_tuner_" + str(job_id)
ctx.args.job_id = task_job_id
# generate script args of task
gbs_new_args = gen_new_args(
raw_args, gbs_cur_cfg, gbs_tuner_cfg
)
ctx.args.training_script_args = gbs_new_args
# launch task
ctx.logger.info(
"Launch task from auto tuner: job_id {}, log_dir {}, config {}".format(
task_job_id, log_dir, gbs_cur_cfg
)
)
logger.info(
"Launch task from auto tuner: job_id {}, log_dir {}, config {}".format(
task_job_id, log_dir, gbs_cur_cfg
)
)
c = controllers.init(ctx)
c.run()
# process generated result
# TODO diffentiate out of memory and no loss(maybe over time)
# TODO integragte memory and metric read
metric, mem, err = read_log(
path=ctx.args.log_dir,
metric_file="workerlog.0",
target_metric=tuner_cfg["metric_cfg"]["name"],
memory_file=f"{ctx.args.job_id}.gpu.log",
)
if err & (1 << 0):
ctx.logger.warning(
f"Read metric failed for parameters: {log_dir}"
)
logger.warning(
f"Read metric failed for parameters: {log_dir}"
)
# for pruner use
gbs_cur_cfg['time'] = -1
gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = None
gbs_cur_cfg["max_mem_usage"] = mem
if err & (1 << 1):
ctx.logger.warning(
f"Out of memory for parameters: {log_dir}"
)
logger.warning(f"Out of memory for parameters: {log_dir}")
# for pruner use
gbs_cur_cfg['time'] = -1
gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = None
gbs_cur_cfg["max_mem_usage"] = "OOM"
# not err & (1 << 1): do not record memory usage when out of memory
if err & (1 << 2) and not err & (1 << 1):
ctx.logger.warning(
f"Read memory usage failed for parameters: {log_dir}"
)
logger.warning(
f"Read memory usage failed for parameters: {log_dir}"
)
gbs_cur_cfg["max_mem_usage"] = None
if not err:
# for pruner use
gbs_cur_cfg['time'] = metric
gbs_cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
gbs_cur_cfg["max_mem_usage"] = mem
if err & (1 << 0) or err & (1 << 1):
# no metric or out of memory, end gbs search
break
# store and update args for next round
gbs_cur_cfg["job_id"] = job_id
best_gbs = gbs_cur_cfg["global_batch_size"]
recorder.add_cfg(**gbs_cur_cfg)
c.finalize(exit=False)
recorder.store_history("./tuner_gbs_history.csv")
# new cfgs for next round
gbs_new_cfg = gbs_tuner.search_once()
gbs_cur_cfg = copy.deepcopy(gbs_new_cfg)
gbs_tuner.add_cfg(gbs_cur_cfg)
# per task launch interval
time.sleep(3)
# prevent no valid global batch size found
if best_gbs is None:
raise ValueError(
"No valid global batch size found, check memory or valid search time. cur_tuner_cfg{}".format(
gbs_tuner_cfg
)
)
# set best global batch size to tuner cfg
tuner_cfg["model_cfg"]["global_batch_size"] = best_gbs
recorder.store_history("./tuner_gbs_history.csv")
recorder.clean_history()
end_time = time.time()
ctx.logger.info(
f"AtuoTuner for GBS search ends in {end_time-start_time}s."
)
logger.info(
f"AtuoTuner for GBS search ends in {end_time-start_time}s."
)
# build AutoTuner to get new config
auto_tuner = AutoTuner(tuner_cfg)
logger.info(
f"Launch {len(auto_tuner.algo.all_tasks)} tasks by auto tuner: "
)
cur_cfg = auto_tuner.search_once()
auto_tuner.add_cfg(cur_cfg)
assert cur_cfg is not None, "No config can run."
while cur_cfg:
task_start_time = time.time()
ctx = copy.deepcopy(raw_ctx)
if is_first_task:
ctx.max_time_per_task = warmup_time
is_first_task = False
# auto tuner supports dp, mp, pp, micro batch size, sharding, recompute by default and every task has own log dir
global_batch_size = (
cur_cfg["global_batch_size"]
if "global_batch_size" in cur_cfg
else tuner_cfg["model_cfg"]["global_batch_size"]
)
acc_steps = (
global_batch_size
// cur_cfg["dp_degree"]
// cur_cfg["sharding_degree"]
// cur_cfg["micro_batch_size"]
)
cur_cfg["acc_steps"] = acc_steps
cur_cfg["global_batch_size"] = global_batch_size
if "sharding_overlap" in cur_cfg:
log_dir = "GBS{}_DP{}_MP{}_PP{}_VPP{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}_AccStep{}_Overlap_{}".format(
global_batch_size,
cur_cfg["dp_degree"],
cur_cfg["mp_degree"],
cur_cfg["pp_degree"],
cur_cfg["vpp_degree"],
cur_cfg["sharding_degree"],
cur_cfg["sharding_stage"],
cur_cfg["micro_batch_size"],
cur_cfg["use_recompute"],
cur_cfg["recompute_granularity"],
cur_cfg["acc_steps"],
cur_cfg["sharding_overlap"],
)
else:
log_dir = "GBS{}_DP{}_MP{}_PP{}_VPP{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}_AccStep{}".format(
global_batch_size,
cur_cfg["dp_degree"],
cur_cfg["mp_degree"],
cur_cfg["pp_degree"],
cur_cfg["vpp_degree"],
cur_cfg["sharding_degree"],
cur_cfg["sharding_stage"],
cur_cfg["micro_batch_size"],
cur_cfg["use_recompute"],
cur_cfg["recompute_granularity"],
cur_cfg["acc_steps"],
)
ctx.args.log_dir = os.path.join(
os.path.dirname(ctx.args.auto_tuner_json), log_dir
)
# every task has own job id
job_id += 1
task_job_id = "auto_tuner_" + str(job_id)
ctx.args.job_id = task_job_id
# generate script args of task
new_args = gen_new_args(raw_args, cur_cfg, tuner_cfg)
ctx.args.training_script_args = new_args
# launch task
ctx.logger.info(
f"Launch task: job_id {task_job_id}, log_dir {log_dir}"
)
logger.info(f"Launch task: job_id {task_job_id}, log_dir {log_dir}")
# in single dp estimation scene, just some nodes not all nodes run
ctx = gen_new_ctx(ctx, cur_cfg, tuner_cfg)
actual_nnodes = int(ctx.args.nnodes.split(":")[0])
if sorted_ips:
actual_exec_ips = sorted_ips[:actual_nnodes]
if ip not in actual_exec_ips:
cur_cfg = client.get(f"auto_tuner/{log_dir}")[0]
wait_start_time = time.time()
while not cur_cfg:
wait_end_time = time.time()
if (
wait_end_time - wait_start_time
> tuner_cfg["max_time_per_task"] + 30
):
raise ValueError(f"Wait {log_dir} failed")
time.sleep(3)
cur_cfg = client.get(f"auto_tuner/{log_dir}")[0]
logger.info(
f"Receive that task {log_dir} has ended by etcd."
)
ctx.logger.info(
f"Receive that task {log_dir} has ended by etcd."
)
cur_cfg = json.loads(cur_cfg.decode())
auto_tuner.history_cfgs.pop(-1)
auto_tuner.add_cfg(cur_cfg)
recorder.add_cfg(**cur_cfg)
cur_best_cfgs, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg'][
'OptimizationDirection'
],
)
if not err:
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
logger.info(f"Current best config: {cur_best_cfgs}")
else:
ctx.logger.info(
"Get best config failed. Currently no config can be run."
)
logger.info(
"Get best config failed. Currently no config can be run."
)
if (
"sharding_overlap" in cur_cfg
and cur_cfg["sharding_overlap"]
):
add_overlap_performance(
cur_cfg, tuner_cfg, recorder.history
)
has_error = cur_cfg["has_error"]
if has_error:
error_task_nums += 1
error_info = cur_cfg["error_info"]
task_nums = len(auto_tuner.algo.all_tasks)
cur_task_id = auto_tuner.algo.idx
ctx.logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id)
* max_time_per_task
/ 60,
2,
),
)
)
logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id)
* max_time_per_task
/ 60,
2,
),
)
)
recorder.store_history(history_file_path)
# generate a new config
new_cfg = auto_tuner.search_once()
cur_cfg = copy.deepcopy(new_cfg)
auto_tuner.add_cfg(cur_cfg)
continue
c = controllers.init(ctx)
# for single dp estimation and not run sharding overlap
if tuner_cfg["search_algo"]["name"] != "grid":
# estimated_num_gpus means need single dp estimation
if "estimated_num_gpus" in tuner_cfg["search_algo"]:
if cur_cfg["sharding_degree"] == 1:
os.environ["FLAGS_shard_bypass_dygraph_optimizer"] = "1"
else:
os.environ["FLAGS_shard_bypass_dygraph_optimizer"] = "0"
c.run()
task_end_time = time.time()
cur_cfg["exec_time"] = round(task_end_time - task_start_time, 2)
ctx.logger.info(
"Task: job_id {}, log_dir {} ended in {}s".format(
task_job_id, log_dir, cur_cfg["exec_time"]
)
)
logger.info(
"Task: job_id {}, log_dir {} ended in {}s".format(
task_job_id, log_dir, cur_cfg["exec_time"]
)
)
# process generated result
metric, mem, err = read_log(
path=ctx.args.log_dir,
metric_file="workerlog.0",
target_metric=tuner_cfg["metric_cfg"]["name"],
memory_file=f"{ctx.args.job_id}.gpu.log",
)
# sync sigint
timeout_flag = True
OOM_flag = err & (1 << 1)
if actual_nnodes > 1:
path = f"auto_tuner/{job_id}/{ip}"
if OOM_flag:
while not client.put(path, "OOM".encode('latin-1')):
time.sleep(1)
ctx.logger.info(f"Put OOM to {path}")
logger.info(f"Put OOM to {path}")
elif hasattr(c, 'sigint') and c.sigint == 14:
while not client.put(path, "OK".encode('latin-1')):
time.sleep(1)
ctx.logger.info(f"Put OK to {path}")
logger.info(f"Put OK to {path}")
elif not hasattr(c, 'sigint') and c.pod.exit_code == 0:
while not client.put(path, "OK".encode('latin-1')):
time.sleep(1)
ctx.logger.info(f"Put OK to {path}")
logger.info(f"Put OK to {path}")
else:
while not client.put(path, "Error".encode('latin-1')):
time.sleep(1)
ctx.logger.info(f"Put Error to {path}")
logger.info(f"Put Error to {path}")
result = list(client.get_prefix(f"auto_tuner/{job_id}/"))
size = len(result)
while size != actual_nnodes:
time.sleep(1)
result = list(client.get_prefix(f"auto_tuner/{job_id}/"))
size = len(result)
status = [i[0].decode() for i in result]
ctx.logger.info(f"Status of auto_tuner/{job_id}/: {status}")
logger.info(f"Status of auto_tuner/{job_id}/: {status}")
if "OOM" in status:
timeout_flag = False
OOM_flag = True
elif "OK" not in status:
timeout_flag = False
has_error = False
if err & (1 << 0):
ctx.logger.warning(f"Read metric of {log_dir} failed.")
logger.warning(f"Read metric of {log_dir} failed.")
# for pruner use
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM"
has_error = True
if err & (1 << 1):
ctx.logger.warning(f"{log_dir} OOM.")
logger.warning(f"{log_dir} OOM.")
# for pruner use
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = "OOM"
has_error = True
# not err & (1 << 1): do not record memory usage when out of memory
if err & (1 << 2) and not err & (1 << 1):
ctx.logger.warning(f"Read memory usage of {log_dir} failed.")
logger.warning(f"Read memory usage of {log_dir} failed.")
cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM"
if not has_error and timeout_flag:
# for pruner use
cur_cfg['time'] = metric
cur_cfg[tuner_cfg['metric_cfg']['name']] = metric
cur_cfg["max_mem_usage"] = mem if not OOM_flag else "OOM"
if not has_error and not timeout_flag:
cur_cfg['time'] = -1
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg["max_mem_usage"] = None if not OOM_flag else "OOM"
if tuner_cfg['metric_cfg']['name'] not in cur_cfg:
cur_cfg[tuner_cfg['metric_cfg']['name']] = None
cur_cfg['job_id'] = job_id
# multi dp conversion
if (
"conversion" in tuner_cfg["search_algo"]
and "step_time" in tuner_cfg["search_algo"]["conversion"]
and "sharding_overlap" not in cur_cfg
):
single_dp_performance = cur_cfg[tuner_cfg['metric_cfg']['name']]
step_time_metric = tuner_cfg["search_algo"]["conversion"][
"step_time"
]
step_time = read_step_time_log(
path=ctx.args.log_dir,
file="workerlog.0",
target_metric=step_time_metric,
)
# set default
comm_bw = tuner_cfg["search_algo"]["conversion"].get(
"comm_bw", [100]
)
model_size_b = int(
tuner_cfg["search_algo"]["conversion"].get(
"model_size_b", 7
)
)
amp = tuner_cfg["search_algo"]["conversion"].get("amp", False)
for bw in comm_bw:
if amp:
comm_time = model_size_b * (4 + 2) / bw
else:
comm_time = model_size_b * 4 / bw
multi_dp_performace = (
round(
step_time
/ (step_time + comm_time)
* single_dp_performance,
5,
)
if single_dp_performance and step_time
else None
)
cur_cfg[
f"bw_{bw}_{tuner_cfg['metric_cfg']['name']}"
] = multi_dp_performace
cur_cfg["has_error"] = has_error
if has_error:
error_task_nums += 1
error_info = None
cur_cfg["error_info"] = error_info
task_nums = len(auto_tuner.algo.all_tasks)
cur_task_id = auto_tuner.algo.idx
ctx.logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id) * max_time_per_task / 60,
2,
),
)
)
logger.info(
"Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format(
cur_task_id,
task_nums,
cur_task_id - job_id,
error_task_nums,
error_info,
round(
(task_nums - cur_task_id) * max_time_per_task / 60,
2,
),
)
)
# sync for single dp
if sorted_ips:
master_ip = sorted_ips[0]
if ip == master_ip:
while not client.put(
f"auto_tuner/{log_dir}",
json.dumps(cur_cfg).encode('latin-1'),
):
time.sleep(1)
logger.info(f"{ip} put auto_tuner/{log_dir} successfully.")
recorder.add_cfg(**cur_cfg)
cur_best_cfgs, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
)
if not err:
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
logger.info(f"Current best config: {cur_best_cfgs}")
else:
ctx.logger.info("Get best config failed, no config can be run.")
logger.info("Get best config failed, no config can be run.")
# record history
if "sharding_overlap" in cur_cfg and cur_cfg["sharding_overlap"]:
add_overlap_performance(cur_cfg, tuner_cfg, recorder.history)
recorder.store_history(history_file_path)
c.finalize(exit=False)
# generate a new config
new_cfg = auto_tuner.search_once()
cur_cfg = copy.deepcopy(new_cfg)
auto_tuner.add_cfg(cur_cfg)
# per task launch interval
self_pid = str(os.getpid())
processes = os.popen(
"fuser -v /dev/nvidia* |awk '{for(i=1;i<=NF;i++) print $i;}'"
).readlines()
for process in processes:
pid = str(process.strip())
if pid != self_pid:
os.system("kill -9 " + pid)
time.sleep(3)
end_time = time.time()
if max_search_time and (end_time - start_time) > int(
max_search_time
):
break
recorder.store_history(history_file_path)
# get best config to run
best_cfg = None
ctx = copy.deepcopy(raw_ctx)