Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

skip sync when abnormally exiting #6025

Merged
merged 7 commits into from
Aug 24, 2021
55 changes: 44 additions & 11 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
limitations under the License.
"""

import sys
import collections

import oneflow._oneflow_internal
Expand Down Expand Up @@ -92,19 +93,51 @@ def is_deprecated(func_or_class):
del register_python_callback


def _SyncOnMasterFn():
if not oneflow._oneflow_internal.IsEnvInited():
return
if oneflow.framework.distribute.is_multi_client():
oneflow._oneflow_internal.eager.multi_client.Sync()
elif oneflow.framework.distribute.get_rank() == 0:
oneflow._oneflow_internal.eager.single_client.Sync()
class ExitHook:
def __init__(self):
self.exit_code = None
self.exception = None

self._orig_exit = sys.exit
self._orig_excepthook = sys.excepthook

atexit.register(oneflow._oneflow_internal.SetShuttingDown)
atexit.register(oneflow._oneflow_internal.DestroyEnv)
atexit.register(oneflow.framework.session_context.TryCloseDefaultSession)
atexit.register(_SyncOnMasterFn)
def exit(code=0):
self.exit_code = code
self._orig_exit(code)
sys.exit = exit

def exc_handler(exc_type, exc, *args):
self.exception = exc
self._orig_excepthook(exc_type, exc, *args)

sys.excepthook = exc_handler

def is_normal_exit(self):
if self.exit_code is not None:
return self.exit_code == 0
return self.exception is None


hook = ExitHook()


def atexit_hook(hook):
if hook.is_normal_exit():
if oneflow._oneflow_internal.IsEnvInited():
if oneflow.framework.distribute.is_multi_client():
oneflow._oneflow_internal.eager.multi_client.Sync()
elif oneflow.framework.distribute.get_rank() == 0:
oneflow._oneflow_internal.eager.single_client.Sync()
oneflow.framework.session_context.TryCloseDefaultSession()
if hook.is_normal_exit():
oneflow._oneflow_internal.DestroyEnv()
oneflow._oneflow_internal.SetShuttingDown()


atexit.register(atexit_hook, hook)
del atexit_hook
del hook
del ExitHook
del atexit
del oneflow
import oneflow.framework.docstr as docstr
Expand Down