[SPARK-41591][PYTHON][ML] Training PyTorch Files on Single Node Multi GPU #39188
@@ -15,7 +15,16 @@
# limitations under the License.
#

import collections
import ctypes
import math
import os
import random
import re
import signal
import sys
import subprocess
import time
from typing import Union, Callable, Optional, Any
import warnings
@@ -34,8 +43,8 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

    Parameters
    ----------
    sc : SparkContext
        The SparkContext for the distributor.
    sc : :class:`SparkContext`
        The :class:`SparkContext` for the distributor.
    key : str
        string for conf name
    default_value : str
@@ -64,6 +73,42 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
    )

def get_gpus_owned(sc: SparkContext) -> list[str]:
    """Gets the GPU addresses that Spark scheduled to the calling task.

    Parameters
    ----------
    sc : :class:`SparkContext`
        The :class:`SparkContext` that has GPUs available.

    Returns
    -------
    list
        The correct mapping of addresses to workers.

    Raises
    ------
    ValueError
        Raised if the input addresses are not in the format expected by
        CUDA_VISIBLE_DEVICES.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    # CUDA_VISIBLE_DEVICES expects plain integers with no zero padding.
    pattern = re.compile("^[1-9][0-9]*|0$")
    addresses = sc.resources["gpu"].addresses
    if any(not pattern.match(address) for address in addresses):
        raise ValueError(
            f"Found GPU addresses {addresses} which "
            "are not all in the correct format "
            "for CUDA_VISIBLE_DEVICES, which requires "
            "integers with no zero padding."
        )
    if CUDA_VISIBLE_DEVICES in os.environ:
        gpu_indices = list(map(int, addresses))
        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
        gpu_owned = [gpu_list[i] for i in gpu_indices]
        return gpu_owned
    return addresses
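Editor's illustration of the remapping above (the device numbers are assumptions, not values from the PR): if the worker process was launched with CUDA_VISIBLE_DEVICES already restricting it to physical GPUs 3, 4 and 5, and Spark assigned this task the resource addresses "0" and "2", the function returns the physical devices the task actually owns.

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"  # assumed worker-level setting
    addresses = ["0", "2"]                        # assumed Spark GPU resource addresses
    gpu_list = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    print([gpu_list[int(a)] for a in addresses])  # -> ['3', '5']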

class Distributor:
    """
    The parent class for TorchDistributor. This class shouldn't be instantiated directly.

@@ -85,6 +130,12 @@ def __init__(
        self.num_tasks = self._get_num_tasks()
        self.ssl_conf = None

    def _create_input_params(self) -> dict[str, Any]:
        input_params = self.__dict__.copy()
        for unneeded_param in ["spark", "sc", "ssl_conf"]:
            del input_params[unneeded_param]
        return input_params

    def _get_num_tasks(self) -> int:
        """
        Returns the number of Spark tasks to use for distributed training

@@ -261,6 +312,130 @@ def __init__(
        super().__init__(num_processes, local_mode, use_gpu)
        self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
        self._validate_input_params()
        self.input_params = self._create_input_params()

    @staticmethod
    def _create_torchrun_command(
        input_params: dict[str, Any], path_to_train_file: str, *args: Any
    ) -> list[str]:
        local_mode = input_params["local_mode"]
        num_processes = input_params["num_processes"]

        if local_mode:
            torchrun_args = ["--standalone", "--nnodes=1"]
            processes_per_node = num_processes
        else:
            pass
            # TODO(SPARK-41592): Handle distributed training

Comment on lines +324 to +329

Contributor:
Not an issue if:

    if condition:
        x = ...
    else:
        ...
    # Do something with x
    print(x)

Contributor (Author):

Contributor:
Makes sense!

        args_string = list(map(str, args))  # converting all args to strings

        return (
            [sys.executable, "-m", "pyspark.ml.torch.distributor.torch_run_process_wrapper"]
            + torchrun_args
            + [f"--nproc_per_node={processes_per_node}"]
            + [path_to_train_file, *args_string]

Comment on lines +334 to +337

Contributor:

    [
        sys.executable,
        "-m",
        "pyspark.ml.torch.distributor.torch_run_process_wrapper",
        *torch_args,
        f"--nproc_per_node={processes_per_node}",
        path_to_train_file,
        *args_string,
    ]

This might be simpler than concatenating four lists.

        )
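For illustration only (the training file name and flag values below are assumptions, not from the PR), a local-mode call such as

    TorchDistributor._create_torchrun_command(
        {"local_mode": True, "num_processes": 2}, "train.py", "--lr", 0.001
    )

would produce a command along the lines of

    [sys.executable, "-m", "pyspark.ml.torch.distributor.torch_run_process_wrapper",
     "--standalone", "--nnodes=1", "--nproc_per_node=2",
     "train.py", "--lr", "0.001"]

which _execute_command below then runs as a subprocess.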

    @staticmethod
    def _execute_command(
        cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
    ) -> None:
        _TAIL_LINES_TO_KEEP = 100

        def sigterm_on_parent_death() -> None:
            """
            Uses prctl to automatically send SIGTERM to the command process when its parent is dead.
            This handles the case when the parent is a PySpark worker process.
            If a user cancels the PySpark job, the worker process gets killed, regardless of
            PySpark daemon and worker reuse settings.
            """
            if _prctl:
                try:
                    libc = ctypes.CDLL("libc.so.6")
                    # Set the parent-death signal of the command process to SIGTERM
                    # (option 1 is PR_SET_PDEATHSIG on Linux).
                    libc.prctl(1, signal.SIGTERM)
                except OSError:
                    pass

        task = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE,
            env=os.environ,
            preexec_fn=sigterm_on_parent_death,
        )
        task.stdin.close()  # type: ignore
        tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
        try:
            for line in task.stdout:  # type: ignore
                decoded = line.decode()
                tail.append(decoded)
                if redirect_to_stdout:
                    sys.stdout.write(decoded)
            task.wait()
        finally:
            if task.poll() is None:
                try:
                    task.terminate()  # SIGTERM
                    time.sleep(0.5)
                    if task.poll() is None:
                        task.kill()  # SIGKILL
                except OSError:
                    pass
        if task.returncode != os.EX_OK:
            if len(tail) == _TAIL_LINES_TO_KEEP:
                last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
            else:
                last_n_msg = "task output is"
            task_output = "".join(tail)
            raise RuntimeError(
                f"Command {cmd} failed with return code {task.returncode}."

Contributor:
I think this line is missing a space after the trailing period.

| f"The {last_n_msg} included below: {task_output}" | ||
| ) | ||
|
|
||
    def _run_local_training(
        self,
        framework_wrapper_fn: Optional[Callable],
        train_object: Union[Callable, str],
        *args: Any,
    ) -> Optional[Any]:
        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
        cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
        old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
        try:
            if self.use_gpu:
                gpus_owned = get_gpus_owned(self.sc)

                if self.num_processes > len(gpus_owned):
                    raise ValueError(
                        f"""{self.num_processes} processes were requested
                        for local training with GPU training but only
                        {len(gpus_owned)} GPUs were available."""
                    )

Comment on lines +412 to +416

Contributor:
I think the error message here would look like:

Contributor:
Triple-quoted strings in
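For context on the two comments above: a triple-quoted string keeps the literal newlines and leading spaces of the source, so the raised message contains the surrounding indentation, whereas adjacent string literals are concatenated into one clean line. A minimal, self-contained illustration (the values are made up):

    num_processes, num_gpus = 4, 2

    triple_quoted = f"""{num_processes} processes were requested
        for local training with GPU training but only
        {num_gpus} GPUs were available."""

    concatenated = (
        f"{num_processes} processes were requested "
        "for local training with GPU training but only "
        f"{num_gpus} GPUs were available."
    )

    print(triple_quoted)  # spans three lines with embedded leading spaces
    print(concatenated)   # prints as a single line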
                random.seed(hash(train_object))
                selected_gpus = [str(e) for e in random.sample(gpus_owned, self.num_processes)]
                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)

            output = framework_wrapper_fn(self.input_params, train_object, *args)  # type: ignore
        finally:
            if cuda_state_was_set:
                os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
            else:
                if CUDA_VISIBLE_DEVICES in os.environ:
                    del os.environ[CUDA_VISIBLE_DEVICES]

        return output

    @staticmethod
    def _run_training_on_pytorch_file(
        input_params: dict[str, Any], train_path: str, *args: Any
    ) -> None:
        training_command = TorchDistributor._create_torchrun_command(
            input_params, train_path, *args
        )
        TorchDistributor._execute_command(training_command)

    def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
        """Runs distributed training.

@@ -278,4 +453,9 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
        Returns the output of train_object called with args if train_object is a
        Callable with an expected output.
        """
        pass
        framework_wrapper_fn = None
        if isinstance(train_object, str):
            framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file
        if self.local_mode:
            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
        return output
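As a usage sketch (the import path follows the module path used above; the file path, flags, and values are illustrative assumptions, not from the PR), running a PyTorch training file locally on multiple GPUs would look roughly like:

    from pyspark.ml.torch.distributor import TorchDistributor

    distributor = TorchDistributor(num_processes=2, local_mode=True, use_gpu=True)
    distributor.run("/path/to/train.py", "--learning-rate", "1e-3", "--batch-size", "64")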
Contributor:
Why do we need this instead of passing kwargs?

Contributor (Author):
This is just to pass the input params (specifically num_processes and local_mode) for the staticmethods to use. We will have to have kwargs support in a future PR in case the users want to have a max_restarts option.