Skip to content

Commit 765e802

Browse files
committed
fix typing, formatting
1 parent 4b4cf13 commit 765e802

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

src/torchrunx/launcher.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,22 @@ def monitor_log(log_file: Path):
7878

7979
@dataclass
8080
class Launcher:
81-
hostnames: list[str] = (["localhost"],)
82-
workers_per_host: int | list[int] = (1,)
83-
ssh_config_file: str | os.PathLike | None = (None,)
84-
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = (None,)
85-
log_dir: os.PathLike | str = ("./logs",)
86-
env_vars: list[str] = (
87-
[
88-
"PATH",
89-
"LD_LIBRARY",
90-
"LIBRARY_PATH",
91-
"PYTHON*",
92-
"CUDA*",
93-
"TORCH*",
94-
"PYTORCH*",
95-
"NCCL*",
96-
],
97-
)
98-
env_file: str | os.PathLike | None = (None,)
81+
hostnames: list[str] = ["localhost"]
82+
workers_per_host: int | list[int] = 1
83+
ssh_config_file: str | os.PathLike | None = None
84+
backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
85+
log_dir: os.PathLike | str = "./logs"
86+
env_vars: list[str] = [
87+
"PATH",
88+
"LD_LIBRARY",
89+
"LIBRARY_PATH",
90+
"PYTHON*",
91+
"CUDA*",
92+
"TORCH*",
93+
"PYTORCH*",
94+
"NCCL*",
95+
]
96+
env_file: str | os.PathLike | None = None
9997

10098
def run(
10199
self,
@@ -137,7 +135,7 @@ def run(
137135
workers_per_host = [workers_per_host] * num_hosts
138136

139137
assert workers_per_host is not None
140-
assert len(workers_per_host) == num_hosts
138+
assert len(workers_per_host) == num_hosts # type: ignore
141139

142140
# launch command
143141

@@ -199,7 +197,7 @@ def run(
199197

200198
# build and sync payloads between launcher and agents
201199

202-
_cumulative_workers = [0] + list(itertools.accumulate(workers_per_host))
200+
_cumulative_workers = [0] + list(itertools.accumulate(workers_per_host)) # type: ignore
203201

204202
worker_world_size = _cumulative_workers[-1]
205203

@@ -211,7 +209,7 @@ def run(
211209
worker_log_files = [
212210
[
213211
log_dir / f"{timestamp}_{hostname}_{local_rank}.log"
214-
for local_rank in range(workers_per_host[i])
212+
for local_rank in range(workers_per_host[i]) # type: ignore
215213
]
216214
for i, hostname in enumerate(self.hostnames)
217215
]

src/torchrunx/slurm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import subprocess
33

44

5-
def slurm_hosts() -> list[str]:
5+
def slurm_hosts() -> "list[str]":
66
"""Retrieves hostnames of Slurm-allocated nodes.
77
88
:return: Hostnames of nodes in current Slurm allocation

0 commit comments

Comments (0)