received_forward_module = None
received_backward_module = None

def set_forward_backward_module(forward_module, backward_module):
    global received_forward_module, received_backward_module
    received_forward_module = forward_module
    received_backward_module = backward_module

def split_and_compile(forward_hlo, backward_hlo):
    global received_forward_module, received_backward_module
    merged_hlo = merge(forward_hlo, backward_hlo)
    compile_with_auto_sharding(merged_hlo)  # this will set received_forward_module and received_backward_module
    forward_binary = compile_without_auto_sharding(received_forward_module)
    backward_binary = compile_without_auto_sharding(received_backward_module)
    return forward_binary, backward_binary
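The callback flow in this pseudocode can be exercised end to end with stand-in functions. The following is only a minimal sketch: HLO modules are modeled as plain strings, and `merge`, `compile_with_auto_sharding`, and `compile_without_auto_sharding` are hypothetical stubs, not the real XLA or parax APIs. In the real flow, the call to `set_forward_backward_module` would come from the C++ auto-sharding pass rather than from Python.

```python
# Hypothetical stand-ins: HLO modules are modeled as plain strings.
received_forward_module = None
received_backward_module = None


def set_forward_backward_module(forward_module, backward_module):
    """Callback through which the compiler pass hands back the split modules."""
    global received_forward_module, received_backward_module
    received_forward_module = forward_module
    received_backward_module = backward_module


def merge(forward_hlo, backward_hlo):
    # Stub: join the two modules so they can be sharded together.
    return forward_hlo + "|" + backward_hlo


def compile_with_auto_sharding(merged_hlo):
    # Stub: pretend the auto-sharding pass annotated the merged module, then
    # split it and report both halves through the callback. In the real flow
    # this callback would be invoked from C++.
    forward, backward = merged_hlo.split("|")
    set_forward_backward_module(forward + "+sharded", backward + "+sharded")


def compile_without_auto_sharding(module):
    # Stub: "compile" by wrapping the module name.
    return "binary(" + module + ")"


def split_and_compile(forward_hlo, backward_hlo):
    merged_hlo = merge(forward_hlo, backward_hlo)
    compile_with_auto_sharding(merged_hlo)  # sets the two globals via the callback
    if received_forward_module is None or received_backward_module is None:
        raise RuntimeError("auto-sharding pass did not report the split modules")
    forward_binary = compile_without_auto_sharding(received_forward_module)
    backward_binary = compile_without_auto_sharding(received_backward_module)
    return forward_binary, backward_binary


fwd, bwd = split_and_compile("fwd_hlo", "bwd_hlo")
print(fwd, bwd)
```

The explicit `None` check is an addition for illustration: if the pass never calls back, the failure surfaces immediately instead of as a confusing error in the second compilation.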
Tips
Call Python code from C++
Example: https://github.com/parax-project/tensorflow-parax/blob/b13493fd4429e9f6f4373c5d2e8503a3ad6d020f/tensorflow/compiler/xla/service/gpu/auto_sharding.cc#L1382-L1417
Pybind11 doc: https://pybind11.readthedocs.io/en/stable/advanced/embedding.html#
Where to get the HLO module? Right after the auto-sharding pass:
https://github.com/parax-project/tensorflow-parax/blob/b13493fd4429e9f6f4373c5d2e8503a3ad6d020f/tensorflow/compiler/xla/service/gpu/auto_sharding.cc#L1646-L1648
Export a C++ function to Python
Add your function here: https://github.com/parax-project/tensorflow-parax/blob/b13493fd4429e9f6f4373c5d2e8503a3ad6d020f/tensorflow/compiler/xla/python/xla_compiler.cc#L178
A possible implementation
In auto_sharding.cc (the last line)
In parax/auto_sharding.py