[SPARK-44264][ML][PYTHON] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer #41973
Conversation
…raining_on_pytorch_file function pointer. Motivation is to make it easier for future developers to add their own distributed trainer for other accelerators that come out in the future
…ix the pickling error
…for running pytorch files
    return framework_wrapper(input_params, train_object, *args, **kwargs)
else:
    # We are doing training with a function, will call run_training_on_pytorch_function
    if not run_pytorch_file_fn:
Remove this and set the parameter to be run_pytorch_file_fn: Optional[Callable] = TorchDistributor._run...
This won't work because of the *args and **kwargs that come after; in my experience Python can't handle a default value there.
cc: @WeichenXu123 what are your thoughts? https://stackoverflow.com/questions/9872824/calling-a-python-function-with-args-kwargs-and-optional-default-arguments <- It seems possible
@staticmethod
def _run_training_on_pytorch_function(
    input_params: Dict[str, Any], train_fn: Callable, *args: Any, **kwargs: Any
    input_params: Dict[str, Any], train_fn: Callable, run_pytorch_file_fn: Optional[Callable], *args: Any, **kwargs: Any
self,
framework_wrapper_fn: Callable,
train_object: Union[Callable, str],
run_pytorch_file_fn: Optional[Callable],
Id probably move this variable before train_object, same with the other functions
do you mind if I ask why?
This is just a nit but we want to keep the code that relates to the training (train_object, *args, **kwargs) away from the utils stuff like framework_wrapper_fn and run_pytorch_file_fn for the sake of readability.
When you say other functions, do you mean all of them? Wouldn't that interfere with the default-args comment, since defaults have to come after positional args, iirc?
Yes, run_pytorch_file_fn isn't a keyword argument (yet); you can either do this comment or the other, but I'd prefer Weichen to weigh in first.
ok sounds good!
@staticmethod
def _get_output_from_framework_wrapper(framework_wrapper: Optional[Callable], input_params: Dict, train_object: Union[Callable, str], run_pytorch_file_fn: Optional[Callable], *args, **kwargs) -> Optional[Any]:
    if not framework_wrapper:
        raise RuntimeError("In the _get_output_from_framework_wrapper function, found a framework wrapper that is none. I wonder why this is...")
What does this error message mean?
If there is ever a point where framework_wrapper isn't a Callable, we want this error to be thrown, because that isn't supposed to happen. We typed it as Optional[Callable] because otherwise my linter complained a lot about assigning None to framework_wrapper.
We should make the error message clear.
Parameters
----------
framework_wrapper: Optional[Callable]
    Function pointer that will be invoked. Can either be the function that runs distributed training on
User provided function?
Could we add a coment to indicate which one is from the user input?
train_object is from the user - it's either a string representing a filepath or a function pointer that the user wants to run in a distributed fashion. I will try to make this more explicit in the docstring.
Returns
-------
Optional[Any]
    Returns the result of the framework_wrapper
Do we expect framework_wrapper return anything?
What it returns depends on the train_object. This is the same train_object as in the rest of the code: either a path to a file to execute or a function to run in a distributed fashion. What framework_wrapper returns follows from that.
framework_wrapper is the same meaning as before in the run method
    for functions if the train_object is a Callable
input_params: Dict
    A dictionary that maps parameter to arguments for the command to be created.
train_object: Union[Callable, str]
I cannot tell the difference between train_object and framework_wrapper from the comments.
Tried again to make it more obvious which is which. But in a nutshell, train_object is passed in from the user, and the framework_wrapper is something that DeepspeedTorchDistributor decides based on the type of train_object.
Co-authored-by: Lu Wang <38018689+lu-wang-dl@users.noreply.github.com>
rithwik-db left a comment:
since weichen is okay with this too, lgtm
Merged to master and branch-3.5.
…ustom "run_training_on_file" Function Pointer

### What Was Changed
We enable a custom function pointer to be passed around the private functions that allow for distributed training of a function.

### Why Do We Need This Change
By abstracting the "run_training_on_pytorch_file" function into something that can be passed in, it becomes much easier to create distributors that run on top of torch.distributed. Specifically, it makes it easy to implement distributed training of picklable functions in DeepspeedTorchDistributor. If accelerators built on top of torch.distributed come out in the future, supporting them in Spark will be straightforward. One can simply do the following:
1. Inherit from TorchDistributor and define a _run_training_on_pytorch_file function or equivalent for your class
2. When defining run(...), simply return _run() and pass in your custom _run_training_on_pytorch_file function as the respective argument

### Any User-Facing Changes?
No.

### How Is This Tested?
The existing tests for TorchDistributor.

Closes #41973 from mathewjacob1002/distributed_func_support_prototype.

Lead-authored-by: Mathew Jacob <mathew.jacob@databricks.com>
Co-authored-by: Mathew Jacob <134338709+mathewjacob1002@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit ee0e687)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
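The two-step extension recipe above can be sketched as follows. The base-class internals here are simplified stand-ins, not the real TorchDistributor implementation, and the subclass and runner names are hypothetical:

```python
from typing import Any, Callable, Dict

# Simplified stand-in for TorchDistributor; only the injection plumbing is shown.
class TorchDistributorSketch:
    @staticmethod
    def _run_training_on_pytorch_file(
        input_params: Dict[str, Any], train_path: str, *args: Any, **kwargs: Any
    ) -> str:
        return f"torch run of {train_path}"

    def _run(
        self,
        framework_wrapper_fn: Callable,
        train_object: Any,
        run_pytorch_file_fn: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        # _run hands the injected file-runner to the wrapper instead of
        # hard-coding _run_training_on_pytorch_file.
        return framework_wrapper_fn({}, train_object, run_pytorch_file_fn, *args, **kwargs)

class MyAcceleratorDistributor(TorchDistributorSketch):
    @staticmethod
    def _run_training_on_pytorch_file(input_params, train_path, *args, **kwargs):
        # Step 1: override the file-runner for the new accelerator backend.
        return f"accelerator run of {train_path}"

    def run(self, train_object: str, *args: Any, **kwargs: Any) -> Any:
        # Step 2: pass the custom file-runner into _run() as the respective argument.
        def wrapper(input_params, obj, file_fn, *a, **kw):
            return file_fn(input_params, obj, *a, **kw)

        return self._run(wrapper, train_object, self._run_training_on_pytorch_file, *args, **kwargs)

print(MyAcceleratorDistributor().run("train.py"))  # -> accelerator run of train.py
```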