
[SPARK-44264][ML][PYTHON] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer #41973

Closed

mathewjacob1002 wants to merge 14 commits into apache:master from mathewjacob1002:distributed_func_support_prototype

Conversation

mathewjacob1002 (Contributor) commented Jul 12, 2023

What Was Changed

We allow a custom function pointer to be passed through the private functions that perform distributed training.

Why Do We Need This Change

By abstracting "run_training_on_pytorch_file" into a function that can be passed in, we make it much easier to create distributors that run on top of torch.distributed. Specifically, it makes it easy to implement distributed training of picklable functions in DeepspeedTorchDistributor. Likewise, if accelerators built on top of torch.distributed come out in the future, supporting them in Spark becomes straightforward. One can simply do the following (see the sketch after this list):

  1. Inherit from TorchDistributor and define a _run_training_on_pytorch_file function (or equivalent) for your class.
  2. When defining run(...), simply return _run() and pass your custom _run_training_on_pytorch_file function in as the respective argument.
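A minimal sketch of those two steps. The subclass name and the `_run_training_on_my_accelerator_file` runner are illustrative assumptions, not Spark internals; the `_run(...)` argument order follows the signature discussed in the review below:

```python
from typing import Any, Callable, Dict, Optional, Union

from pyspark.ml.torch.distributor import TorchDistributor


class MyAcceleratorDistributor(TorchDistributor):
    """Hypothetical distributor for a future accelerator built on torch.distributed."""

    @staticmethod
    def _run_training_on_my_accelerator_file(
        input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs: Any
    ) -> None:
        # Step 1: launch the training file with the accelerator's own
        # torchrun-style launcher instead of the stock PyTorch one.
        ...

    def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
        # Step 2: choose the wrapper based on the type of train_object and hand
        # the custom file-runner to the shared _run(...) machinery.
        framework_wrapper: Callable = (
            MyAcceleratorDistributor._run_training_on_my_accelerator_file
            if isinstance(train_object, str)
            else TorchDistributor._run_training_on_pytorch_function
        )
        return self._run(
            framework_wrapper,
            train_object,
            MyAcceleratorDistributor._run_training_on_my_accelerator_file,
            *args,
            **kwargs,
        )
```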

Any User-Facing Changes?

No.

How Is This Tested?

The existing tests for TorchDistributor.

…raining_on_pytorch_file function pointer. Motivation is to make it easier for future developers to add their own distributed trainer for other accelerators that come out in the future
@mathewjacob1002 mathewjacob1002 changed the title [DO NOT MERGE/REVIEW] PROTOTYPING: refactoring the TorchDistributor code to take in a run_t… [Spark Ticket] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer Jul 14, 2023
@mathewjacob1002 mathewjacob1002 changed the title [Spark Ticket] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer [SPARK-44264] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer Jul 14, 2023
        return framework_wrapper(input_params, train_object, *args, **kwargs)
    else:
        # We are doing training with a function; will call run_training_on_pytorch_function
        if not run_pytorch_file_fn:
Contributor:
Remove this and set the parameter to be run_pytorch_file_fn: Optional[Callable] = TorchDistributor._run...

Contributor Author:
This won't work because of the *args and **kwargs after it: a parameter with a default placed before *args is still filled by the first positional argument, so in my experience Python can't make the default value behave here.
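To illustrate the pitfall with a hypothetical toy signature (not the actual code): the default never applies once callers pass extra positional arguments.

```python
def _run(framework_wrapper_fn, run_pytorch_file_fn=None, *args):
    # run_pytorch_file_fn sits before *args, so any extra positional
    # argument binds to it first, not to args.
    print(run_pytorch_file_fn, args)

_run("wrapper", 1, 2, 3)  # prints: 1 (2, 3) -- the default was silently clobbered
```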

Contributor:
@staticmethod
def _run_training_on_pytorch_function(
-    input_params: Dict[str, Any], train_fn: Callable, *args: Any, **kwargs: Any
+    input_params: Dict[str, Any], train_fn: Callable, run_pytorch_file_fn: Optional[Callable], *args: Any, **kwargs: Any
Contributor:
Same as here

    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
rithwik-db (Contributor) commented Jul 14, 2023:
I'd probably move this variable before train_object, same as with the other functions.

Contributor Author:
Do you mind if I ask why?

Contributor:
This is just a nit, but we want to keep the code that relates to the training (train_object, *args, **kwargs) away from the utility arguments like framework_wrapper_fn and run_pytorch_file_fn, for the sake of readability.

Contributor Author:
When you say other functions, do you mean all of them? Wouldn't that interfere with the default-args comment, since defaults have to come after positional args, IIRC?

Contributor:
Yes, run_pytorch_file_fn isn't a keyword argument (yet); you can either do this comment or the other, but I'd prefer Weichen to weigh in first.

Contributor Author:
OK, sounds good!

@staticmethod
def _get_output_from_framework_wrapper(
    framework_wrapper: Optional[Callable],
    input_params: Dict,
    train_object: Union[Callable, str],
    run_pytorch_file_fn: Optional[Callable],
    *args,
    **kwargs,
) -> Optional[Any]:
    if not framework_wrapper:
        raise RuntimeError(
            "In the _get_output_from_framework_wrapper function, found a framework wrapper that is none. I wonder why this is..."
        )
Contributor:
What does this error message mean?

Contributor Author:
If the framework_wrapper is ever not a Callable, we want this error to be thrown, because that isn't supposed to happen. We only typed it as Optional[Callable] because the linter complained that otherwise nothing could be assigned to framework_wrapper.

Contributor:
We should make the error message clear.
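For example, something along these lines might be clearer (an illustrative wording, not necessarily the message that landed):

```python
if framework_wrapper is None:
    raise RuntimeError(
        "_get_output_from_framework_wrapper expected a callable framework_wrapper "
        "(the file-based or function-based training runner chosen from "
        "train_object), but got None."
    )
```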

Contributor Author:
How does the new one sound @lu-wang-dl?

Parameters
----------
framework_wrapper: Optional[Callable]
    Function pointer that will be invoked. Can either be the function that runs distributed training on
Contributor:
User-provided function?

Contributor:
Could we add a comment to indicate which one comes from the user input?

Contributor Author:
train_object is from the user - it's either a string representing a filepath or a function pointer that the user wants to run in a distributed fashion. I will try to make this more explicit in the docstring.

Returns
-------
Optional[Any]
    Returns the result of the framework_wrapper
Contributor:
Do we expect framework_wrapper to return anything?

Contributor Author:
What it returns depends on train_object. This is the same train_object as in the rest of the code: either a path to a file to execute or a function to run in a distributed fashion, and what framework_wrapper returns follows from which one it is.

Contributor Author:
framework_wrapper has the same meaning as before, in the run method.

    for functions if the train_object is a Callable
input_params: Dict
    A dictionary that maps parameters to arguments for the command to be created.
train_object: Union[Callable, str]
Contributor:
I cannot tell the difference between train_object and framework_wrapper from the comments.

Contributor Author:
Tried again to make it more obvious which is which. In a nutshell, train_object is passed in by the user, and framework_wrapper is something that DeepspeedTorchDistributor decides based on the type of train_object. See the sketch below.
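In other words, the dispatch looks roughly like this (an illustrative sketch of the selection logic, not the exact source):

```python
# train_object comes from the user: either a path to a training file (str)
# or a picklable training function (Callable).
if isinstance(train_object, str):
    framework_wrapper = TorchDistributor._run_training_on_pytorch_file
else:
    framework_wrapper = TorchDistributor._run_training_on_pytorch_function
```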

lu-wang-dl (Contributor) left a comment:
LGTM on my side. I will let Ricky do the final approval.

@mathewjacob1002 mathewjacob1002 marked this pull request as ready for review July 17, 2023 20:04
WeichenXu123 (Contributor) left a comment:
LGTM

rithwik-db (Contributor) left a comment:
Since Weichen is okay with this too, LGTM.

@HyukjinKwon HyukjinKwon changed the title [SPARK-44264] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer [SPARK-44264][ML][PYTHON] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer Jul 19, 2023
HyukjinKwon (Member) commented:
Merged to master and branch-3.5.

HyukjinKwon pushed a commit that referenced this pull request Jul 19, 2023
…ustom "run_training_on_file" Function Pointer

Closes #41973 from mathewjacob1002/distributed_func_support_prototype.

Lead-authored-by: Mathew Jacob <mathew.jacob@databricks.com>
Co-authored-by: Mathew Jacob <134338709+mathewjacob1002@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit ee0e687)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>