@@ -78,24 +78,22 @@ def monitor_log(log_file: Path):
78
78
79
79
@dataclass
80
80
class Launcher :
81
- hostnames : list [str ] = (["localhost" ],)
82
- workers_per_host : int | list [int ] = (1 ,)
83
- ssh_config_file : str | os .PathLike | None = (None ,)
84
- backend : Literal ["mpi" , "gloo" , "nccl" , "ucc" , None ] = (None ,)
85
- log_dir : os .PathLike | str = ("./logs" ,)
86
- env_vars : list [str ] = (
87
- [
88
- "PATH" ,
89
- "LD_LIBRARY" ,
90
- "LIBRARY_PATH" ,
91
- "PYTHON*" ,
92
- "CUDA*" ,
93
- "TORCH*" ,
94
- "PYTORCH*" ,
95
- "NCCL*" ,
96
- ],
97
- )
98
- env_file : str | os .PathLike | None = (None ,)
81
+ hostnames : list [str ] = ["localhost" ]
82
+ workers_per_host : int | list [int ] = 1
83
+ ssh_config_file : str | os .PathLike | None = None
84
+ backend : Literal ["mpi" , "gloo" , "nccl" , "ucc" , None ] = None
85
+ log_dir : os .PathLike | str = "./logs"
86
+ env_vars : list [str ] = [
87
+ "PATH" ,
88
+ "LD_LIBRARY" ,
89
+ "LIBRARY_PATH" ,
90
+ "PYTHON*" ,
91
+ "CUDA*" ,
92
+ "TORCH*" ,
93
+ "PYTORCH*" ,
94
+ "NCCL*" ,
95
+ ]
96
+ env_file : str | os .PathLike | None = None
99
97
100
98
def run (
101
99
self ,
@@ -137,7 +135,7 @@ def run(
137
135
workers_per_host = [workers_per_host ] * num_hosts
138
136
139
137
assert workers_per_host is not None
140
- assert len (workers_per_host ) == num_hosts
138
+ assert len (workers_per_host ) == num_hosts # type: ignore
141
139
142
140
# launch command
143
141
@@ -199,7 +197,7 @@ def run(
199
197
200
198
# build and sync payloads between launcher and agents
201
199
202
- _cumulative_workers = [0 ] + list (itertools .accumulate (workers_per_host ))
200
+ _cumulative_workers = [0 ] + list (itertools .accumulate (workers_per_host )) # type: ignore
203
201
204
202
worker_world_size = _cumulative_workers [- 1 ]
205
203
@@ -211,7 +209,7 @@ def run(
211
209
worker_log_files = [
212
210
[
213
211
log_dir / f"{ timestamp } _{ hostname } _{ local_rank } .log"
214
- for local_rank in range (workers_per_host [i ])
212
+ for local_rank in range (workers_per_host [i ]) # type: ignore
215
213
]
216
214
for i , hostname in enumerate (self .hostnames )
217
215
]
0 commit comments