Skip to content

Commit

Permalink
Add missing skip_synchronize in dist optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
bichengying committed Jun 1, 2020
1 parent 27edc87 commit 7e2a134
Showing 1 changed file with 81 additions and 22 deletions.
103 changes: 81 additions & 22 deletions bluefog/torch/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,26 @@ def synchronize(self):

self._synchronized = True

@contextmanager
def skip_synchronize(self):
    """Temporarily stop ``optimizer.step()`` from synchronizing.

    While the ``with`` body runs, ``self._should_synchronize`` is ``False``.
    When the body finishes, whether normally or by raising, the flag is
    restored to the value it had on entry, so nested uses do not switch
    synchronization back on before the outermost block has exited.

    It is typically used in the following pattern, where ``synchronize()``
    has already been called by hand:

    .. code-block:: python

        optimizer.synchronize()
        with optimizer.skip_synchronize():
            optimizer.step()
    """
    # Remember the incoming state rather than assuming it was True; fall back
    # to True when the attribute has not been set yet.
    previous = getattr(self, "_should_synchronize", True)
    self._should_synchronize = False
    try:
        yield
    finally:
        self._should_synchronize = previous

def step(self, closure=None):
# consensus style is the easiest way to implement it.
if self._should_synchronize:
Expand Down Expand Up @@ -398,6 +418,37 @@ def _register_window(self):
raise ValueError(
"Cannot allocate MPI window for the parameter {}".format(name))

def turn_on_timeline(self):
    """Register timeline hooks and mark the timeline as in use.

    Passes this optimizer, ``self._models`` and ``self._parameter_names`` to
    ``_register_timeline`` and appends what it returns to
    ``self._timeline_hook_handles`` so that ``turn_off_timeline`` can remove
    those entries later. ``self._use_timeline`` is set to ``True`` only after
    registration has returned.
    """
    new_handles = _register_timeline(
        self, self._models, self._parameter_names)
    self._timeline_hook_handles.extend(new_handles)
    self._use_timeline = True

def turn_off_timeline(self):
    """Remove every registered timeline hook and mark the timeline as off.

    Calls ``remove()`` on each entry of ``self._timeline_hook_handles`` in
    order, empties that collection in place, and sets
    ``self._use_timeline`` to ``False``.
    """
    registered = self._timeline_hook_handles
    for handle in registered:
        handle.remove()
    # Clear in place so the attribute keeps pointing at the same container.
    registered.clear()
    self._use_timeline = False

@contextmanager
def skip_synchronize(self):
    """Temporarily stop ``optimizer.step()`` from synchronizing.

    While the ``with`` body runs, ``self._should_synchronize`` is ``False``.
    When the body finishes, whether normally or by raising, the flag is
    restored to the value it had on entry, so nested uses do not switch
    synchronization back on before the outermost block has exited.

    It is typically used in the following pattern, where ``synchronize()``
    has already been called by hand:

    .. code-block:: python

        optimizer.synchronize()
        with optimizer.skip_synchronize():
            optimizer.step()
    """
    # Remember the incoming state rather than assuming it was True; fall back
    # to True when the attribute has not been set yet.
    previous = getattr(self, "_should_synchronize", True)
    self._should_synchronize = False
    try:
        yield
    finally:
        self._should_synchronize = previous

def synchronize(self):
# Here synchronize just to make sure win_put ops is finished
# in one iteration.
Expand All @@ -411,17 +462,6 @@ def synchronize(self):
self._handles.clear()
self._synchronized = True

def turn_on_timeline(self):
    """Register timeline hooks and mark the timeline as in use.

    Passes this optimizer, ``self._models`` and ``self._parameter_names`` to
    ``_register_timeline`` and appends what it returns to
    ``self._timeline_hook_handles`` so that ``turn_off_timeline`` can remove
    those entries later. ``self._use_timeline`` is set to ``True`` only after
    registration has returned.
    """
    new_handles = _register_timeline(
        self, self._models, self._parameter_names)
    self._timeline_hook_handles.extend(new_handles)
    self._use_timeline = True

def turn_off_timeline(self):
    """Remove every registered timeline hook and mark the timeline as off.

    Calls ``remove()`` on each entry of ``self._timeline_hook_handles`` in
    order, empties that collection in place, and sets
    ``self._use_timeline`` to ``False``.
    """
    registered = self._timeline_hook_handles
    for handle in registered:
        handle.remove()
    # Clear in place so the attribute keeps pointing at the same container.
    registered.clear()
    self._use_timeline = False

def step(self, closure=None):
if self.force_barrier:
bf.barrier()
Expand Down Expand Up @@ -510,6 +550,36 @@ def hook(module, *unused):
self._handles[p] = handle
return hook

def turn_on_timeline(self):
    """Register timeline hooks and mark the timeline as in use.

    Passes this optimizer, ``self._models`` and ``self._parameter_names`` to
    ``_register_timeline`` and appends what it returns to
    ``self._timeline_hook_handles`` so that ``turn_off_timeline`` can remove
    those entries later. ``self._use_timeline`` is set to ``True`` only after
    registration has returned.
    """
    new_handles = _register_timeline(
        self, self._models, self._parameter_names)
    self._timeline_hook_handles.extend(new_handles)
    self._use_timeline = True

def turn_off_timeline(self):
    """Remove every registered timeline hook and mark the timeline as off.

    Calls ``remove()`` on each entry of ``self._timeline_hook_handles`` in
    order, empties that collection in place, and sets
    ``self._use_timeline`` to ``False``.
    """
    registered = self._timeline_hook_handles
    for handle in registered:
        handle.remove()
    # Clear in place so the attribute keeps pointing at the same container.
    registered.clear()
    self._use_timeline = False

@contextmanager
def skip_synchronize(self):
    """Temporarily stop ``optimizer.step()`` from synchronizing.

    While the ``with`` body runs, ``self._should_synchronize`` is ``False``.
    When the body finishes, whether normally or by raising, the flag is
    restored to the value it had on entry, so nested uses do not switch
    synchronization back on before the outermost block has exited.

    It is typically used in the following pattern, where ``synchronize()``
    has already been called by hand:

    .. code-block:: python

        optimizer.synchronize()
        with optimizer.skip_synchronize():
            optimizer.step()
    """
    # Remember the incoming state rather than assuming it was True; fall back
    # to True when the attribute has not been set yet.
    previous = getattr(self, "_should_synchronize", True)
    self._should_synchronize = False
    try:
        yield
    finally:
        self._should_synchronize = previous

def synchronize(self):
# Here synchronize just to make sure win_put ops is finished
Expand All @@ -530,17 +600,6 @@ def synchronize(self):
self._handles.clear()
self._synchronized = True

def turn_on_timeline(self):
    """Register timeline hooks and mark the timeline as in use.

    Passes this optimizer, ``self._models`` and ``self._parameter_names`` to
    ``_register_timeline`` and appends what it returns to
    ``self._timeline_hook_handles`` so that ``turn_off_timeline`` can remove
    those entries later. ``self._use_timeline`` is set to ``True`` only after
    registration has returned.
    """
    new_handles = _register_timeline(
        self, self._models, self._parameter_names)
    self._timeline_hook_handles.extend(new_handles)
    self._use_timeline = True

def turn_off_timeline(self):
    """Remove every registered timeline hook and mark the timeline as off.

    Calls ``remove()`` on each entry of ``self._timeline_hook_handles`` in
    order, empties that collection in place, and sets
    ``self._use_timeline`` to ``False``.
    """
    registered = self._timeline_hook_handles
    for handle in registered:
        handle.remove()
    # Clear in place so the attribute keeps pointing at the same container.
    registered.clear()
    self._use_timeline = False

def step(self, closure=None):
if self.force_barrier:
bf.barrier()
Expand Down

0 comments on commit 7e2a134

Please sign in to comment.