Validate SRUN variables when launching in SLURM (#15011)
awaelchli authored Oct 19, 2022
1 parent 26f632c commit 576757f
Showing 8 changed files with 97 additions and 7 deletions.
48 changes: 47 additions & 1 deletion src/lightning_lite/plugins/environments/slurm.py
@@ -16,10 +16,15 @@
import os
import re
import signal
import subprocess
import sys
from typing import Optional

from lightning_utilities.core.rank_zero import rank_zero_warn

from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.imports import _IS_WINDOWS
from lightning_lite.utilities.warnings import PossibleUserWarning

log = logging.getLogger(__name__)

@@ -40,6 +45,8 @@ def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Si
if requeue_signal is None and not _IS_WINDOWS:
requeue_signal = signal.SIGUSR1
self.requeue_signal = requeue_signal
self._validate_srun_used()
self._validate_srun_variables()

@property
def creates_processes_externally(self) -> bool:
@@ -89,7 +96,8 @@ def detect() -> bool:
This then avoids ``SLURMEnvironment`` being detected, so that another environment can be detected
automatically.
"""
return "SLURM_NTASKS" in os.environ and SLURMEnvironment.job_name() != "bash"
SLURMEnvironment._validate_srun_used()
return _is_srun_used()

@staticmethod
def job_name() -> Optional[str]:
@@ -141,3 +149,41 @@ def resolve_root_node_address(nodes: str) -> str:
nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range
nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number
return nodes.split(" ")[0].split(",")[0]
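# Hedged illustration (not part of this diff): expected outputs of
# resolve_root_node_address, inferred from the regexes above.
#   resolve_root_node_address("alpha,beta,gamma")  -> "alpha"
#   resolve_root_node_address("node[18-19,21,24]") -> "node18"
#   resolve_root_node_address("node[5]")           -> "node5"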

@staticmethod
def _validate_srun_used() -> None:
"""Checks if the `srun` command is available and used.
Parallel jobs (multi-GPU, multi-node) in SLURM are launched by prepending `srun` in front of the Python command.
Not doing so will result in processes hanging, which is a frequent user error. Lightning will emit a warning if
`srun` is found but not used.
"""
if _IS_WINDOWS:
return
srun_exists = subprocess.call(["which", "srun"]) == 0
if srun_exists and not _is_srun_used():
hint = " ".join(["srun", os.path.basename(sys.executable), *sys.argv])[:64]
rank_zero_warn(
"The `srun` command is available on your system but is not used. HINT: If your intention is to run"
f" Lightning on SLURM, prepend your python command with `srun` like so: {hint} ...",
category=PossibleUserWarning,
)

@staticmethod
def _validate_srun_variables() -> None:
"""Checks for conflicting or incorrectly set variables set through `srun` and raises a useful error
message.
Right now, we only check for the most common user errors. See `the srun docs
<https://slurm.schedmd.com/srun.html>`_ for a complete list of supported srun variables.
"""
ntasks = int(os.environ.get("SLURM_NTASKS", "1"))
if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ:
raise RuntimeError(
f"You set `--ntasks={ntasks}` in your SLURM bash script, but this variable is not supported."
f" HINT: Use `--ntasks-per-node={ntasks}` instead."
)


def _is_srun_used() -> bool:
return "SLURM_NTASKS" in os.environ and SLURMEnvironment.job_name() != "bash"
10 changes: 4 additions & 6 deletions src/pytorch_lightning/CHANGELOG.md
@@ -8,13 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `LightningLite.no_backward_sync` for control over efficient gradient accumulation with distributed strategies ([#14966](https://github.com/Lightning-AI/lightning/pull/14966))



### Changed

- Moved the warning about saving nn.Module in `save_hyperparameters()` to before the deepcopy ([#15132](https://github.com/Lightning-AI/lightning/pull/15132))



## [1.8.0] - 2022-MM-DD
@@ -48,6 +45,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/pull/14709))
- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))
- Added `inference_mode` flag to Trainer to let users enable/disable inference mode during evaluation ([#15034](https://github.com/Lightning-AI/lightning/pull/15034))
- Added `LightningLite.no_backward_sync` for control over efficient gradient accumulation with distributed strategies ([#14966](https://github.com/Lightning-AI/lightning/pull/14966))
- Added a sanity check that scripts are executed with the `srun` command in SLURM and that environment variables are not conflicting ([#15011](https://github.com/Lightning-AI/lightning/pull/15011))


### Changed
@@ -73,8 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- It is no longer needed to call `model.double()` when using `precision=64` in Lightning Lite ([#14827](https://github.com/Lightning-AI/lightning/pull/14827))
- HPC checkpoints are now loaded automatically only in slurm environment when no specific value for `ckpt_path` has been set ([#14911](https://github.com/Lightning-AI/lightning/pull/14911))
- The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))


- Moved the warning about saving nn.Module in `save_hyperparameters()` to before the deepcopy ([#15132](https://github.com/Lightning-AI/lightning/pull/15132))
- To avoid issues with forking processes, from PyTorch 1.13 and higher, Lightning will directly use the PyTorch NVML-based check for `torch.cuda.device_count` and from PyTorch 1.14 and higher, Lightning will configure PyTorch to use a NVML-based check for `torch.cuda.is_available`. ([#15110](https://github.com/Lightning-AI/lightning/pull/15110), [#15133](https://github.com/Lightning-AI/lightning/pull/15133))


34 changes: 34 additions & 0 deletions tests/tests_lite/plugins/environments/test_slurm.py
@@ -13,11 +13,15 @@
# limitations under the License.
import logging
import os
import sys
from unittest import mock

import pytest
from tests_lite.helpers.runif import RunIf
from tests_lite.helpers.utils import no_warning_call

from lightning_lite.plugins.environments import SLURMEnvironment
from lightning_lite.utilities.warnings import PossibleUserWarning


@mock.patch.dict(os.environ, {}, clear=True)
@@ -47,6 +51,7 @@ def test_default_attributes():
"SLURM_NODELIST": "1.1.1.1, 1.1.1.2",
"SLURM_JOB_ID": "0001234",
"SLURM_NTASKS": "20",
"SLURM_NTASKS_PER_NODE": "10",
"SLURM_LOCALID": "2",
"SLURM_PROCID": "1",
"SLURM_NODEID": "3",
@@ -112,3 +117,32 @@ def test_detect():

with mock.patch.dict(os.environ, {"SLURM_JOB_NAME": "bash"}):
assert not SLURMEnvironment.detect()


@RunIf(skip_windows=True)
def test_srun_available_and_not_used(monkeypatch):
"""Test that a warning is emitted if Lightning suspects the user forgot to run their script with `srun`."""
monkeypatch.setattr(sys, "argv", ["train.py", "--lr", "0.01"])
expected = "`srun` .* available .* but is not used. HINT: .* srun python train.py --lr 0.01"

# pretend `srun` is available
with mock.patch("lightning_lite.plugins.environments.slurm.subprocess.call", return_value=0):
with pytest.warns(PossibleUserWarning, match=expected):
SLURMEnvironment()

with pytest.warns(PossibleUserWarning, match=expected):
SLURMEnvironment.detect()

# no warning if `srun` is unavailable
with no_warning_call(PossibleUserWarning, match=expected):
SLURMEnvironment()
assert not SLURMEnvironment.detect()


def test_srun_variable_validation():
"""Test that we raise useful errors when `srun` variables are misconfigured."""
with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}):
SLURMEnvironment()
with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}):
with pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"):
SLURMEnvironment()
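The `no_warning_call` helper used above asserts that a block does not emit a matching warning. For readers unfamiliar with it, a minimal sketch of what such a helper could look like (an illustrative assumption, not the repository's actual implementation):

import re
import warnings
from contextlib import contextmanager

@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    # Record every warning raised inside the block, then fail if any matches.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        matches_category = issubclass(w.category, expected_warning)
        matches_message = match is None or re.search(match, str(w.message))
        if matches_category and matches_message:
            raise AssertionError(f"Unexpected warning emitted: {w.message!r}")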
4 changes: 4 additions & 0 deletions tests/tests_lite/test_connector.py
@@ -136,6 +136,7 @@ def creates_processes_externally(self) -> bool:
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
@@ -203,6 +204,7 @@ class Strat(DDPStrategy):
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
@@ -496,6 +498,7 @@ def test_strategy_choice_ddp_slurm(_, strategy, job_name, expected_env):
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": job_name,
"SLURM_NODEID": "0",
"SLURM_PROCID": "1",
@@ -596,6 +599,7 @@ def test_strategy_choice_ddp_cpu_kubeflow():
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
1 change: 1 addition & 0 deletions tests/tests_pytorch/models/test_amp.py
@@ -128,6 +128,7 @@ def test_amp_gpus(tmpdir, strategy, precision, devices):
os.environ,
{
"SLURM_NTASKS": "1",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
1 change: 1 addition & 0 deletions tests/tests_pytorch/plugins/test_amp_plugins.py
@@ -40,6 +40,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
1 change: 1 addition & 0 deletions tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -38,6 +38,7 @@ def environment_combinations():
"SLURM_NODEID": "1",
"SLURM_PROCID": "3",
"SLURM_NTASKS": "4",
"SLURM_NTASKS_PER_NODE": "2",
}
environment = SLURMEnvironment()
yield environment, variables, expected
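For context on why the tests pin these variables: `SLURMEnvironment` derives the cluster topology from them. A hedged sketch of the standard mapping (`SLURM_LOCALID` added here for completeness; method names per Lightning's `ClusterEnvironment` interface):

import os
from unittest import mock

from lightning_lite.plugins.environments import SLURMEnvironment

variables = {
    "SLURM_NTASKS": "4",           # total processes -> world_size()
    "SLURM_NTASKS_PER_NODE": "2",  # processes per node (validated by this PR)
    "SLURM_PROCID": "3",           # global rank -> global_rank()
    "SLURM_LOCALID": "1",          # rank within the node -> local_rank()
    "SLURM_NODEID": "1",           # node index -> node_rank()
}
with mock.patch.dict(os.environ, variables):
    env = SLURMEnvironment()
    assert env.world_size() == 4
    assert env.global_rank() == 3
    assert env.local_rank() == 1
    assert env.node_rank() == 1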
5 changes: 5 additions & 0 deletions tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -92,6 +92,7 @@ def _test_strategy_choice_ddp_and_cpu(tmpdir, ddp_strategy_class):
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
@@ -128,6 +129,7 @@ def creates_processes_externally(self) -> bool:
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
@@ -195,6 +197,7 @@ class Strat(DDPStrategy):
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
@@ -473,6 +476,7 @@ def test_strategy_choice_ddp_slurm(cuda_count_2, strategy, job_name, expected_env):
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": job_name,
"SLURM_NODEID": "0",
"SLURM_PROCID": "1",
@@ -575,6 +579,7 @@ def test_strategy_choice_ddp_cpu_kubeflow(cuda_count_0):
os.environ,
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
