Skip to content
Merged
7 changes: 3 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
pixi-version: v0.27.1
frozen: true
cache: true
environments: default
activate-environment: default
environments: extra
activate-environment: extra
- run: pyright
if: success() || failure()

Expand Down Expand Up @@ -86,5 +86,4 @@ jobs:
cache: false
environments: default
activate-environment: default

- run: pytest tests
- run: pytest tests/test_CI.py
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ logs/
test_logs/
_build/
out/
output/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6,839 changes: 2,365 additions & 4,474 deletions pixi.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ pytest = "*"
build = "*"
twine = "*"

[feature.extra.pypi-dependencies]
transformers = "*"
submitit = "*"
setuptools = "*"
accelerate = "*"

[environments]
default = { features = ["package", "dev"], solve-group = "default" }
dev = { features = ["dev"], solve-group = "default" }
extra = { features = ["package", "dev", "extra"], solve-group = "default"}
1 change: 1 addition & 0 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def run(
raise
finally:
print_process.kill()
dist.destroy_process_group()

return_values: dict[int, Any] = dict(ChainMap(*[s.return_values for s in agent_statuses]))
return return_values
Expand Down
48 changes: 23 additions & 25 deletions tests/test_CI.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import shutil
import sys
import tempfile

import pytest
import torch
import torch.distributed as dist

sys.path.append("../src")

import torchrunx # noqa: I001
import torchrunx as trx


def test_simple_localhost():
Expand All @@ -30,38 +28,27 @@ def dist_func():

return o.detach()

r = torchrunx.launch(
func=dist_func,
func_kwargs={},
workers_per_host=2,
backend="gloo",
r = trx.launch(
func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
)

assert torch.all(r[0] == r[1])

dist.destroy_process_group()


def test_logging():
def dist_func():
rank = int(os.environ["RANK"])
print(f"worker rank: {rank}")

try:
shutil.rmtree("./test_logs", ignore_errors=True)
except FileNotFoundError:
pass

torchrunx.launch(
func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir="./test_logs"
)
tmp = tempfile.mkdtemp()
trx.launch(func=dist_func, func_kwargs={}, workers_per_host=2, backend="gloo", log_dir=tmp)

log_files = next(os.walk("./test_logs"), (None, None, []))[2]
log_files = next(os.walk(tmp), (None, None, []))[2]

assert len(log_files) == 3

for file in log_files:
with open("./test_logs/" + file, "r") as f:
with open(f"{tmp}/{file}", "r") as f:
if file.endswith("0.log"):
assert f.read() == "worker rank: 0\n"
elif file.endswith("1.log"):
Expand All @@ -71,7 +58,18 @@ def dist_func():
assert "worker rank: 0" in contents
assert "worker rank: 1" in contents

# clean up
shutil.rmtree("./test_logs", ignore_errors=True)

dist.destroy_process_group()
def test_error():
def error_func():
raise ValueError("abcdefg")

with pytest.raises(RuntimeError) as excinfo:
trx.launch(
func=error_func,
func_kwargs={},
workers_per_host=1,
backend="gloo",
log_dir=tempfile.mkdtemp(),
)

assert "abcdefg" in str(excinfo.value)
17 changes: 8 additions & 9 deletions examples/slurm_poc.py → tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,21 @@
import torch
import torch.distributed as dist

import torchrunx

# this is not a pytest test, but a functional test designed to be run on a slurm allocation
import torchrunx as trx


def test_launch():
result = torchrunx.launch(
result = trx.launch(
func=simple_matmul,
hostnames=torchrunx.slurm_hosts(),
workers_per_host=torchrunx.slurm_workers(),
hostnames=trx.slurm_hosts(),
workers_per_host=trx.slurm_workers(),
)

t = True
for i in range(len(result)):
assert torch.all(result[i] == result[0]), "Not all tensors equal"
print(result[0])
print("PASS")
t = t and torch.all(result[i] == result[0])

assert t, "Not all tensors equal"


def simple_matmul():
Expand Down
41 changes: 30 additions & 11 deletions examples/submitit_train.py → tests/test_submitit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,54 @@ def __getitem__(self, index):
"labels": self.labels[index],
}


def main():
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
train_dataset = DummyDataset()

## Training

training_arguments = TrainingArguments(
output_dir = "output",
do_train = True,
per_device_train_batch_size = 16,
max_steps = 20,
output_dir="output",
do_train=True,
per_device_train_batch_size=16,
max_steps=20,
)

trainer = Trainer(
model=model, # type: ignore
model=model, # type: ignore
args=training_arguments,
train_dataset=train_dataset
train_dataset=train_dataset,
)

trainer.train()


def launch():
trx.launch(
func=main,
func_kwargs={},
hostnames=trx.slurm_hosts(),
workers_per_host=trx.slurm_workers()
func=main, func_kwargs={}, hostnames=trx.slurm_hosts(), workers_per_host=trx.slurm_workers()
)


def test_submitit():
executor = submitit.SlurmExecutor(folder="logs")

executor.update_parameters(
time=60,
nodes=1,
ntasks_per_node=1,
mem="32G",
cpus_per_task=4,
gpus_per_node=2,
constraint="geforce3090",
partition="3090-gcondo",
stderr_to_stdout=True,
use_srun=False,
)

executor.submit(launch).result()


if __name__ == "__main__":
executor = submitit.SlurmExecutor(folder="logs")

Expand All @@ -68,4 +87,4 @@ def launch():
use_srun=False,
)

executor.submit(launch)
executor.submit(launch)
28 changes: 8 additions & 20 deletions examples/distributed_train.py → tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
import socket
import subprocess

import torchrunx
import torchrunx as trx


def worker():
Expand Down Expand Up @@ -34,24 +32,14 @@ def forward(self, x):
loss.sum().backward()


def resolve_node_ips(nodelist):
# Expand the nodelist into individual hostnames
hostnames = (
subprocess.check_output(["scontrol", "show", "hostnames", nodelist])
.decode()
.strip()
.split("\n")
def test_distributed_train():
trx.launch(
worker,
hostnames=trx.slurm_hosts(),
workers_per_host=trx.slurm_workers(),
backend="nccl",
)
# Resolve each hostname to an IP address
ips = [socket.gethostbyname(hostname) for hostname in hostnames]
return ips


if __name__ == "__main__":
torchrunx.launch(
worker,
{},
hostnames=torchrunx.slurm_hosts(),
workers_per_host=torchrunx.slurm_workers(),
backend="nccl",
)
test_distributed_train()