186 changes: 183 additions & 3 deletions python/pyspark/ml/torch/distributor.py
@@ -15,7 +15,16 @@
# limitations under the License.
#

import collections
import ctypes
import math
import os
import random
import re
import signal
import sys
import subprocess
import time
from typing import Union, Callable, Optional, Any
import warnings

@@ -34,8 +43,8 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

    Parameters
    ----------
    sc : SparkContext
        The SparkContext for the distributor.
    sc : :class:`SparkContext`
        The :class:`SparkContext` for the distributor.
    key : str
        string for conf name
    default_value : str
@@ -64,6 +73,42 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
    )


def get_gpus_owned(sc: SparkContext) -> list[str]:
    """Gets the GPU addresses that Spark scheduled to the calling task.

    Parameters
    ----------
    sc : :class:`SparkContext`
        The :class:`SparkContext` that has GPUs available.

    Returns
    -------
    list
        The GPU addresses owned by the calling task, remapped through
        CUDA_VISIBLE_DEVICES when that variable is already set.

    Raises
    ------
    ValueError
        Raised if the input addresses were not found.
    """
    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
    pattern = re.compile("^[1-9][0-9]*|0$")
    addresses = sc.resources["gpu"].addresses
    if any(not pattern.match(address) for address in addresses):
        raise ValueError(
            f"Found GPU addresses {addresses} which "
            "are not all in the correct format "
            "for CUDA_VISIBLE_DEVICES, which requires "
            "integers with no zero padding."
        )
    if CUDA_VISIBLE_DEVICES in os.environ:
        gpu_indices = list(map(int, addresses))
        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
        gpu_owned = [gpu_list[i] for i in gpu_indices]
        return gpu_owned
    return addresses
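
To make the remapping concrete, here is a small editorial sketch (the values are hypothetical, not from the PR): when CUDA_VISIBLE_DEVICES is already set, the addresses Spark hands out are treated as indices into that list.

import os

# Hypothetical values, illustration only.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
addresses = ["0", "2"]  # what sc.resources["gpu"].addresses might contain
gpu_list = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
print([gpu_list[int(a)] for a in addresses])  # -> ['4', '6']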


class Distributor:
    """
    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
@@ -85,6 +130,12 @@ def __init__(
        self.num_tasks = self._get_num_tasks()
        self.ssl_conf = None

    def _create_input_params(self) -> dict[str, Any]:
Contributor:
Why do we need this instead of passing kwargs?

Contributor (Author):
This is just to pass the input params (specifically num_processes and local_mode) for the static methods to use. We will have to add kwargs support in a future PR in case users want a max_restarts option.

        input_params = self.__dict__.copy()
        for unneeded_param in ["spark", "sc", "ssl_conf"]:
            del input_params[unneeded_param]
        return input_params
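
To make the author's explanation above concrete, a rough editorial sketch of what input_params might hold (this assumes __init__ stores num_processes, local_mode, use_gpu, and num_tasks as attributes; only part of __init__ is visible in this diff):

# Roughly, for a distributor created with num_processes=2, local_mode=True, use_gpu=True:
input_params = {
    "num_processes": 2,
    "local_mode": True,
    "use_gpu": True,
    "num_tasks": 2,  # whatever _get_num_tasks() returned
}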

    def _get_num_tasks(self) -> int:
        """
        Returns the number of Spark tasks to use for distributed training
@@ -261,6 +312,130 @@ def __init__(
        super().__init__(num_processes, local_mode, use_gpu)
        self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
        self._validate_input_params()
        self.input_params = self._create_input_params()

    @staticmethod
    def _create_torchrun_command(
        input_params: dict[str, Any], path_to_train_file: str, *args: Any
    ) -> list[str]:
        local_mode = input_params["local_mode"]
        num_processes = input_params["num_processes"]

        if local_mode:
            torchrun_args = ["--standalone", "--nnodes=1"]
            processes_per_node = num_processes
        else:
            pass
            # TODO(SPARK-41592): Handle distributed training
Comment on lines +324 to +329
Contributor (@harupy, Jan 13, 2023):
Not an issue if local_mode is always True, but code like below leaves x undefined when condition is False and causes an error.

if condition:
    x = ...
else:
    ...

# Do something with x
print(x)

Contributor (Author):
local_mode will always be True for this PR, but the follow-up PR addresses the else section.

Contributor:
Makes sense!


        args_string = list(map(str, args))  # converting all args to strings

        return (
            [sys.executable, "-m", "pyspark.ml.torch.distributor.torch_run_process_wrapper"]
            + torchrun_args
            + [f"--nproc_per_node={processes_per_node}"]
            + [path_to_train_file, *args_string]
Comment on lines +334 to +337
Contributor:
            [
              sys.executable,
              "-m",
              "pyspark.ml.torch.distributor.torch_run_process_wrapper",
              *torchrun_args,
              f"--nproc_per_node={processes_per_node}",
              path_to_train_file,
              *args_string,
            ]

This might be simpler than concatenating four lists.

        )
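
For reference, an editorial sketch of the command this builds in local mode; the input values and training script path are hypothetical:

# Hypothetical inputs, illustration only.
params = {"local_mode": True, "num_processes": 2}
cmd = TorchDistributor._create_torchrun_command(params, "/path/to/train.py", "--epochs", "5")
# cmd == [sys.executable, "-m", "pyspark.ml.torch.distributor.torch_run_process_wrapper",
#         "--standalone", "--nnodes=1", "--nproc_per_node=2",
#         "/path/to/train.py", "--epochs", "5"]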

    @staticmethod
    def _execute_command(
        cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
    ) -> None:
        _TAIL_LINES_TO_KEEP = 100

        def sigterm_on_parent_death() -> None:
            """
            Uses prctl to automatically send SIGTERM to the command process when its parent is dead.
            This handles the case when the parent is a PySpark worker process.
            If a user cancels the PySpark job, the worker process gets killed, regardless of
            PySpark daemon and worker reuse settings.
            """
            if _prctl:
                try:
                    libc = ctypes.CDLL("libc.so.6")
                    # Set the parent process death signal of the command process to SIGTERM.
                    libc.prctl(1, signal.SIGTERM)
                except OSError:
                    pass

        task = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE,
            env=os.environ,
            preexec_fn=sigterm_on_parent_death,
        )
        task.stdin.close()  # type: ignore
        tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
        try:
            for line in task.stdout:  # type: ignore
                decoded = line.decode()
                tail.append(decoded)
                if redirect_to_stdout:
                    sys.stdout.write(decoded)
            task.wait()
        finally:
            if task.poll() is None:
                try:
                    task.terminate()  # SIGTERM
                    time.sleep(0.5)
                    if task.poll() is None:
                        task.kill()  # SIGKILL
                except OSError:
                    pass
        if task.returncode != os.EX_OK:
            if len(tail) == _TAIL_LINES_TO_KEEP:
                last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
            else:
                last_n_msg = "task output is"
            task_output = "".join(tail)
            raise RuntimeError(
                f"Command {cmd} failed with return code {task.returncode}."
Contributor:
I think this line is missing a space after the period.

f"The {last_n_msg} included below: {task_output}"
)
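
A small editorial sketch of how the helper above behaves (not code from the PR): a failing command raises RuntimeError carrying the tail of the captured output.

# Illustration only: run a command that prints something and exits non-zero.
try:
    TorchDistributor._execute_command([sys.executable, "-c", "print('boom'); raise SystemExit(1)"])
except RuntimeError as e:
    print(e)  # mentions return code 1 and includes the captured "boom" line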

    def _run_local_training(
        self,
        framework_wrapper_fn: Optional[Callable],
        train_object: Union[Callable, str],
        *args: Any,
    ) -> Optional[Any]:
        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
        cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
        old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
        try:
            if self.use_gpu:
                gpus_owned = get_gpus_owned(self.sc)

                if self.num_processes > len(gpus_owned):
                    raise ValueError(
                        f"""{self.num_processes} processes were requested
                        for local training with GPU training but only
                        {len(gpus_owned)} GPUs were available."""
                    )
Comment on lines +412 to +416
Contributor (@harupy, Jan 13, 2023):
I think the error message here would look like:

2 processes ...
             for local ...
             3 GPUs were ...

Contributor:
Triple-quoted strings in _check_encryption also need to be fixed.
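
An editorial sketch of one way to address this (implicit string concatenation instead of a triple-quoted literal, so the continuation lines carry no indentation); this is an illustration, not the fix made in the PR:

# num_processes and num_gpus stand in for self.num_processes and len(gpus_owned).
num_processes, num_gpus = 2, 1
message = (
    f"{num_processes} processes were requested for local training "
    f"with GPU training but only {num_gpus} GPUs were available."
)
print(message)  # a single line, no stray leading whitespace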

                random.seed(hash(train_object))
                selected_gpus = [str(e) for e in random.sample(gpus_owned, self.num_processes)]
Contributor:
Q: Why randomly pick GPUs?

Contributor (Author):
If multiple runs happen concurrently, there is less chance that they will impact the same GPUs.

Contributor (@lu-wang-dl, Jan 5, 2023):
But this may cause more flaky issues if the user reruns the function, and it makes things more difficult to debug.

Contributor (Author):
So should we choose just the first self.num_processes GPUs owned then? I'm not sure what a good alternative approach here would be...

Contributor (@WeichenXu123, Jan 10, 2023):
@lu-wang-dl

We can set a deterministic randomizer seed: e.g., generate the seed from the hash of the torch program code, then use the randomizer to generate the selected_gpus list.

Contributor:
Sure. It is fine if we add a deterministic random seed.
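
A minimal editorial sketch of the seeding idea discussed above (not code from the PR): within one driver process, the same training program maps to the same GPU selection.

import random

def select_gpus(gpus_owned, num_processes, train_object):
    random.seed(hash(train_object))  # seed derived from the training program
    return [str(g) for g in random.sample(gpus_owned, num_processes)]

gpus = ["0", "1", "2", "3"]
# Repeated calls with the same train file give the same (reproducible) choice.
assert select_gpus(gpus, 2, "/path/to/train.py") == select_gpus(gpus, 2, "/path/to/train.py")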

                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)

            output = framework_wrapper_fn(self.input_params, train_object, *args)  # type: ignore
        finally:
            if cuda_state_was_set:
                os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
            else:
                if CUDA_VISIBLE_DEVICES in os.environ:
                    del os.environ[CUDA_VISIBLE_DEVICES]

        return output

    @staticmethod
    def _run_training_on_pytorch_file(
        input_params: dict[str, Any], train_path: str, *args: Any
    ) -> None:
        training_command = TorchDistributor._create_torchrun_command(
            input_params, train_path, *args
        )
        TorchDistributor._execute_command(training_command)

    def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
        """Runs distributed training.
@@ -278,4 +453,9 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
        Returns the output of train_object called with args if train_object is a
        Callable with an expected output.
        """
        pass
        framework_wrapper_fn = None
        if isinstance(train_object, str):
            framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file
        if self.local_mode:
            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
        return output
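
For context, a minimal usage sketch of the API this PR adds (local, file-based training only at this stage); the script path and its arguments are placeholders:

from pyspark.ml.torch.distributor import TorchDistributor

# Run a PyTorch training script with two local processes on the driver.
distributor = TorchDistributor(num_processes=2, local_mode=True, use_gpu=True)
distributor.run("/path/to/train.py", "--learning-rate", "1e-3")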