[SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed

### What changes were proposed in this pull request?
Made the DeepspeedTorchDistributor run() method use the _run() function as the backbone.
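
For illustration, a rough sketch of the dispatch that run() now delegates to. The actual `_run()` lives in the parent `TorchDistributor`, so everything here beyond the names visible in the diff below is an assumption, not the real implementation:

```python
# Hedged sketch only: approximates the dispatch run() now delegates to.
# Helper names other than those shown in the diff are assumptions.
def _run(self, train_object, file_wrapper_fn, *args, **kwargs):
    if isinstance(train_object, str):
        # A string is treated as the path to a training script.
        framework_wrapper_fn = file_wrapper_fn
    else:
        # Anything else is treated as a picklable training function.
        framework_wrapper_fn = self._run_training_on_pytorch_function
    if self.local_mode:
        return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
    return self._run_distributed_training(
        framework_wrapper_fn,
        train_object,
        spark_dataframe=None,
        *args,
        **kwargs,
    )
```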
### Why are the changes needed?
It allows the user to easily run distributed training of a function with DeepSpeed.

### Does this PR introduce _any_ user-facing change?
This adds the ability for the user to pass in a function as the train_object when calling DeepspeedTorchDistributor.run(). The user must have all necessary imports within the function itself, and the function must be picklable. An example use case can be found in the Python file linked in the JIRA ticket.
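
A minimal usage sketch of the new function-based path. The constructor argument names (numGpus, nnodes, localMode) are assumptions and may differ from the actual signature, and train_fn is a hypothetical example:

```python
from pyspark.ml.deepspeed.deepspeed_distributor import DeepspeedTorchDistributor


def train_fn(learning_rate: float) -> str:
    # All imports must live inside the function so that it stays picklable
    # and can be reconstructed on the worker processes.
    import torch

    model = torch.nn.Linear(10, 1)
    # ... initialize the DeepSpeed engine and run the training loop here ...
    return "done"


# Constructor parameters below are assumptions; check the class docstring.
distributor = DeepspeedTorchDistributor(numGpus=2, nnodes=2, localMode=False)
result = distributor.run(train_fn, 1e-3)  # extra args are forwarded to train_fn
```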

### How was this patch tested?
Tested manually with the notebook/file linked in the JIRA ticket. Formal e2e tests will come in a future PR.

### Next Steps/Timeline

- [ ] Add more e2e tests, both for running a regular PyTorch file and for running a training function
- [ ] Write more documentation

Closes #42067 from mathewjacob1002/add_func_deepspeed.

Authored-by: Mathew Jacob <mathew.jacob@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
mathewjacob1002 authored and HyukjinKwon committed Jul 19, 2023
1 parent 0879a25 commit 392f8d8
Showing 1 changed file with 2 additions and 16 deletions.
python/pyspark/ml/deepspeed/deepspeed_distributor.py (2 additions, 16 deletions)
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 import json
-import os
 import sys
 import tempfile
 from typing import (
@@ -135,19 +134,6 @@ def _run_training_on_pytorch_file(
     def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
         # If the "train_object" is a string, then we assume it's a filepath.
         # Otherwise, we assume it's a function.
-        if isinstance(train_object, str):
-            if os.path.exists(train_object) is False:
-                raise FileNotFoundError(f"The path to training file {train_object} does not exist.")
-            framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
-        else:
-            raise RuntimeError("Python training functions aren't supported as inputs at this time")
-
-        if self.local_mode:
-            return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
-        return self._run_distributed_training(
-            framework_wrapper_fn,
-            train_object,
-            spark_dataframe=None,
-            *args,
-            **kwargs,  # type:ignore[misc]
+        return self._run(
+            train_object, DeepspeedTorchDistributor._run_training_on_pytorch_file, *args, **kwargs
         )
