Merge branch 'window_name'
Hanbin Hu committed Jan 31, 2021
2 parents 43d03da + 83d63a3 commit 37519bd
Showing 5 changed files with 93 additions and 37 deletions.
2 changes: 1 addition & 1 deletion bluefog/torch/__init__.py
@@ -65,7 +65,7 @@
from bluefog.torch.mpi_ops import win_accumulate_nonblocking, win_accumulate
from bluefog.torch.mpi_ops import win_wait, win_poll
from bluefog.torch.mpi_ops import win_mutex
from bluefog.torch.mpi_ops import get_win_version
from bluefog.torch.mpi_ops import get_win_version, get_current_created_window_names

from bluefog.torch.mpi_ops import win_associated_p
from bluefog.torch.mpi_ops import turn_on_win_ops_with_associated_p
4 changes: 4 additions & 0 deletions bluefog/torch/mpi_ops.py
@@ -1281,6 +1281,10 @@ def win_wait(handle: int) -> bool:
return True


def get_current_created_window_names() -> List[str]:
"""Return the names of current created windows."""
return sorted(list(_win_map.keys()))

def get_win_version(name: str) -> Dict[int, int]:
""" Get the version of tensor stored in the win buffer.
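The new helper exposes the window registry kept in mpi_ops, so callers can check whether a window name is already registered before creating or freeing it. A minimal usage sketch (the tensor and the window name 'w0' are illustrative, and it assumes bf.init() has already been called):

import torch
import bluefog.torch as bf

bf.init()
bf.win_create(torch.zeros(4), name='w0')  # register a window named 'w0'
if 'w0' in bf.get_current_created_window_names():
    bf.win_free('w0')  # free it only if it was actually created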
44 changes: 34 additions & 10 deletions bluefog/torch/optimizers.py
@@ -843,19 +843,20 @@ def zero_grad(self):

class _DistributedWinOptimizer(torch.optim.Optimizer):

def __init__(self, params, model, num_steps_per_communication, pull_style):
def __init__(self, params, model, num_steps_per_communication, window_prefix, pull_style):
super(self.__class__, self).__init__(params)

if pull_style:
self.src_weights = None # use to control the behavior of win_get dynamically.
else:
self.dst_weights = None # use to control the behavior of win_put dynamically.
self.force_barrier = False
self.window_prefix = window_prefix+'.' if window_prefix is not None else ''

named_parameters, models = _check_named_parameters(self, model)
self._models = models
self._pull_style = pull_style
self._parameter_names = {v: k for k, v in sorted(named_parameters)}
self._parameter_names = {v: self.window_prefix+k for k, v in sorted(named_parameters)}
self._handles = {} # store parameter -> handle
self._synchronized = False
self._should_synchronize = True
@@ -871,6 +872,9 @@ def __init__(self, params, model, num_steps_per_communication, pull_style):
self._register_window()
self._register_hooks()

def __del__(self):
self.unregister_window()

def _register_hooks(self):
for model in self._models:
# The hook is added at model level instead of layer level, as it avoids triggering
@@ -888,7 +892,7 @@ def hook(model, *unused):
for name, p in layer.named_parameters():
if self._use_timeline:
# End forward computation timeline
bf.timeline_end_activity(parent_name+'.'+name)
bf.timeline_end_activity(self.window_prefix+parent_name+'.'+name)
if not layer.training:
continue
if p.requires_grad:
@@ -899,7 +903,7 @@ def hook(model, *unused):
self._bluefog_delay[p] -= 1
if self._bluefog_delay[p] == 0:
handle = bf.win_put_nonblocking(
tensor=p.data, name=parent_name+'.'+name,
tensor=p.data, name=self.window_prefix+parent_name+'.'+name,
dst_weights=self.dst_weights, require_mutex=False)
self._handles[p] = handle
return hook
@@ -910,7 +914,7 @@ def hook(model, *unused):
for name, p in layer.named_parameters():
if self._use_timeline:
# End forward computation timeline
bf.timeline_end_activity(parent_name+'.'+name)
bf.timeline_end_activity(self.window_prefix+parent_name+'.'+name)
if not layer.training:
continue
if p.requires_grad:
@@ -921,12 +925,14 @@ def hook(model, *unused):
self._bluefog_delay[p] -= 1
if self._bluefog_delay[p] == 0:
handle = bf.win_get_nonblocking(
name=parent_name+'.'+name, src_weights=self.src_weights,
require_mutex=True)
name=self.window_prefix+parent_name+'.'+name,
src_weights=self.src_weights, require_mutex=True)
self._handles[p] = handle
return hook

def _register_window(self):
if bf.size() <= 1:
return
for param_group in self.param_groups:
for p in param_group["params"]:
name = self._parameter_names.get(p)
@@ -937,6 +943,22 @@ def _register_window(self):
raise ValueError(
"Cannot allocate MPI window for the parameter {}".format(name))

def unregister_window(self):
''' Unregister MPI Window objects for the optimizer manually.
'''
if bf.size() <= 1:
return
for param_group in self.param_groups:
for p in param_group["params"]:
name = self._parameter_names.get(p)
if name is None:
raise KeyError(
"Cannot find parameter {} in the _parameter_names dictionary".format(name))
if name in bf.get_current_created_window_names():
if not bf.win_free(name):
raise ValueError(
"Cannot free MPI window for the parameter {}".format(name))

def turn_on_timeline(self):
handles = _register_timeline(self, self._models, self._parameter_names)
self._timeline_hook_handles.extend(handles)
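The destructor delegates to the new unregister_window, but window freeing is a collective operation across ranks, so the updated tests below also release the windows explicitly once a win.put optimizer is no longer needed rather than waiting for garbage collection. A rough teardown sketch (opt stands for any live DistributedWinPutOptimizer created with a window prefix):

# after training: free this optimizer's MPI windows on every rank
opt.unregister_window()
# its prefixed parameter windows should no longer appear in the registry
print(bf.get_current_created_window_names())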
@@ -1246,8 +1268,7 @@ def DistributedPullGetOptimizer(optimizer, model,
return cls(optimizer.param_groups, model, num_steps_per_communication, pull_style=True)


def DistributedWinPutOptimizer(optimizer, model,
num_steps_per_communication=1):
def DistributedWinPutOptimizer(optimizer, model, num_steps_per_communication=1, window_prefix=None):
"""An distributed optimizer that wraps another torch.optim.Optimizer with
pull model average through bf.win_put ops.
@@ -1258,6 +1279,8 @@ def DistributedWinPutOptimizer(optimizer, model,
communication. This allows local model parameter updates
per num_steps_per_communication before reducing them over
distributed computation resources.
window_prefix: A string that uniquely identifies this DistributedWinPutOptimizer; it is
used as the prefix for its window names.
Returned optimizer has two extra parameters `dst_weights` and `force_barrier`.
Set dst_weights dictionary as {rank: scaling} differently per iteration to achieve
@@ -1271,7 +1294,8 @@ def DistributedWinPutOptimizer(optimizer, model,
(optimizer.__class__,),
dict(_DistributedWinOptimizer.__dict__),
)
return cls(optimizer.param_groups, model, num_steps_per_communication, pull_style=False)
return cls(optimizer.param_groups, model, num_steps_per_communication,
window_prefix, pull_style=False)


def DistributedAllreduceOptimizer(optimizer, model,
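With distinct prefixes, several win.put optimizers can now be created in the same process, sequentially as in the tests below or side by side, without colliding on window names. A minimal sketch under that assumption (the linear models, learning rate, and prefixes 'a'/'b' are illustrative):

import torch
import bluefog.torch as bf

bf.init()
model_a, model_b = torch.nn.Linear(8, 1), torch.nn.Linear(8, 1)
opt_a = bf.DistributedWinPutOptimizer(
    torch.optim.SGD(model_a.parameters(), lr=0.01), model=model_a, window_prefix='a')
opt_b = bf.DistributedWinPutOptimizer(
    torch.optim.SGD(model_b.parameters(), lr=0.01), model=model_b, window_prefix='b')
# windows are registered as 'a.<parameter name>' and 'b.<parameter name>'
opt_a.unregister_window()
opt_b.unregister_window()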
62 changes: 36 additions & 26 deletions test/torch_optimizer_test.py
@@ -295,11 +295,8 @@ def evaluation(model, dataloader, isCUDA):
id="ATC Neighbor Allreduce on CPU"))
static_topo_scenarios.append(
pytest.param("CPU", "gradient.allreduce", {}, id="Gradient Allreduce on CPU"))
# TODO(hanbinhu): support multiple window put optimizer tests in the file. This issue may be due to
# duplicate name registration for the MPI window.
static_topo_scenarios.append(
pytest.param("CPU", "win.put", {}, id="Window put on CPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("CPU", "win.put", {'window_prefix': 'CPU'}, id="Window put on CPU"))
if TEST_ON_GPU:
static_topo_scenarios.append(
pytest.param("GPU", bf.CommunicationType.empty, {"ATC": False, "error_threshold": 2},
@@ -322,8 +319,7 @@ def evaluation(model, dataloader, isCUDA):
static_topo_scenarios.append(
pytest.param("GPU", "gradient.allreduce", {}, id="Gradient Allreduce on GPU"))
static_topo_scenarios.append(
pytest.param("GPU", "win.put", {}, id="Window put on GPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("GPU", "win.put", {'window_prefix': 'GPU'}, id="Window put on GPU"))

# device can be set to "GPU" or "CPU".
# communication_type can be selected from bf.CommunicationType, "gradient.allreduce" or "win.put".
@@ -332,6 +328,7 @@
def test_standard_optimizer(device, communication_type, kwargs):
atc_style = kwargs.get("ATC", False)
error_threshold = kwargs.get("error_threshold", 1.5)
window_prefix = kwargs.get("window_prefix", None)

problem_builder, train_dataloader, test_dataloader, model, optimizer, num_epochs = \
problem_setup()
@@ -344,7 +341,8 @@ def test_standard_optimizer(device, communication_type, kwargs):
optimizer = base_dist_optimizer(optimizer, model=model,
communication_type=communication_type)
elif communication_type == "win.put":
optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model)
optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model,
window_prefix=window_prefix)
elif communication_type == "gradient.allreduce":
optimizer = bf.DistributedGradientAllreduceOptimizer(
optimizer, model=model)
@@ -369,6 +367,8 @@ def test_standard_optimizer(device, communication_type, kwargs):
test_mse[-3:].max() < error_threshold*problem_builder.noise_level**2
), "Train MSE in the last three epochs doesn't coverge."

if communication_type == "win.put":
optimizer.unregister_window()

hierarchical_model_scenarios = []
hierarchical_model_scenarios.append(
@@ -383,8 +383,9 @@ def test_standard_optimizer(device, communication_type, kwargs):
pytest.param("CPU", "gradient.allreduce", {}, id="Gradient Allreduce on CPU",
marks=pytest.mark.skip(reason="GA may not converge for hierarchical model.")))
hierarchical_model_scenarios.append(
pytest.param("CPU", "win.put", {}, id="Window put on CPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("CPU", "win.put", {'window_prefix': 'CPU'}, id="Window put on CPU",
marks=pytest.mark.skip(reason="Win put may not converge for hierarchical model.")))

if TEST_ON_GPU:
hierarchical_model_scenarios.append(
pytest.param("GPU", bf.CommunicationType.neighbor_allreduce, {"ATC": False},
@@ -398,14 +399,16 @@ def test_standard_optimizer(device, communication_type, kwargs):
pytest.param("GPU", "gradient.allreduce", {}, id="Gradient Allreduce on GPU",
marks=pytest.mark.skip(reason="GA may not converge for hierarchical model.")))
hierarchical_model_scenarios.append(
pytest.param("GPU", "win.put", {}, id="Window put on GPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("GPU", "win.put", {'window_prefix', 'GPU'}, id="Window put on GPU",
marks=pytest.mark.skip(
reason="Win put may not converge for hierarchical model.")))


@pytest.mark.parametrize("device,communication_type,kwargs", hierarchical_model_scenarios)
def test_optimizer_for_hierarchical_model(device, communication_type, kwargs):
atc_style = kwargs.get("ATC", False)
error_threshold = kwargs.get("error_threshold", 1.5)
window_prefix = kwargs.get("window_prefix", None)

problem_builder, train_dataloader, test_dataloader, model, optimizer, num_epochs = \
problem_setup(HierarchicalLinearNet)
@@ -418,7 +421,8 @@ def test_optimizer_for_hierarchical_model(device, communication_type, kwargs):
optimizer = base_dist_optimizer(optimizer, model=model,
communication_type=communication_type)
elif communication_type == "win.put":
optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model)
optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model,
window_prefix=window_prefix)
elif communication_type == "gradient.allreduce":
optimizer = bf.DistributedGradientAllreduceOptimizer(
optimizer, model=model)
@@ -443,6 +447,8 @@ def test_optimizer_for_hierarchical_model(device, communication_type, kwargs):
test_mse[-3:].max() < error_threshold*problem_builder.noise_level**2
), "Train MSE in the last three epochs doesn't coverge."

if communication_type == "win.put":
optimizer.unregister_window()

# Neighbor allreduce dynamic tests
dynamic_neighbor_allreduce_scenarios = []
@@ -497,24 +503,24 @@ def test_dynamic_neighbor_allreduce_optimizer(device, atc_style, kwargs):
# Window put dynamic tests
dynamic_win_put_scenarios = []
dynamic_win_put_scenarios.append(
pytest.param("CPU", {}, id="Dynamic window put on CPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("CPU", {'window_prefix':'CPU'}, id="Dynamic window put on CPU"))
if TEST_ON_GPU:
dynamic_win_put_scenarios.append(
pytest.param("GPU", {}, id="Dynamic window put on GPU"))
pytest.param("GPU", {'window_prefix':'GPU'}, id="Dynamic window put on GPU"))


@pytest.mark.parametrize("device,kwargs", dynamic_win_put_scenarios)
def test_dynamic_win_put_optimizer(device, kwargs):
error_threshold = kwargs.get("error_threshold", 1.5)
window_prefix = kwargs.get("window_prefix", None)

problem_builder, train_dataloader, test_dataloader, model, optimizer, num_epochs = \
problem_setup()

isCUDA = pin_model_to_device(device, model)

optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model)

optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model, window_prefix=window_prefix)
# Train and test
train_mse = []
test_mse = []
@@ -533,6 +539,7 @@ def test_dynamic_win_put_optimizer(device, kwargs):
assert (
test_mse[-3:].max() < error_threshold*problem_builder.noise_level**2
), "Train MSE in the last three epochs doesn't coverge."
optimizer.unregister_window()


local_aggregation_scenarios = []
@@ -557,8 +564,7 @@ def test_dynamic_win_put_optimizer(device, kwargs):
local_aggregation_scenarios.append(
pytest.param("CPU", "gradient.allreduce", {}, id="Gradient Allreduce on CPU"))
local_aggregation_scenarios.append(
pytest.param("CPU", "win.put", {}, id="Window put on CPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("CPU", "win.put", {'window_prefix': 'CPU'}, id="Window put on CPU"))
local_aggregation_scenarios.append(
pytest.param("CPU", bf.CommunicationType.neighbor_allreduce, {"mini_batch_size": 4},
id="Neighbor allreduce AWC on CPU with a mini_batch_size of 4"))
@@ -590,15 +596,14 @@ def test_dynamic_win_put_optimizer(device, kwargs):
local_aggregation_scenarios.append(
pytest.param("GPU", "gradient.allreduce", {}, id="Gradient Allreduce on GPU"))
local_aggregation_scenarios.append(
pytest.param("GPU", "win.put", {}, id="Window put on GPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))

pytest.param("GPU", "win.put", {'window_prefix': 'GPU'}, id="Window put on GPU"))

@pytest.mark.parametrize("device,communication_type,kwargs", local_aggregation_scenarios)
def test_optimizer_local_aggregation(device, communication_type, kwargs):
atc_style = kwargs.get("ATC", False)
error_threshold = kwargs.get("error_threshold", 1.5)
mini_batch_size = kwargs.get("mini_batch_size", 16)
window_prefix = kwargs.get("window_prefix", None)

problem_builder, train_dataloader, test_dataloader, model, optimizer, num_epochs = \
problem_setup()
@@ -641,6 +646,8 @@ def test_optimizer_local_aggregation(device, communication_type, kwargs):
test_mse[-3:].max() < error_threshold*problem_builder.noise_level**2
), "Train MSE in the last three epochs doesn't coverge."

if communication_type == "win.put":
optimizer.unregister_window()

local_aggregation_duplicated_scenarios = []
local_aggregation_duplicated_scenarios.append(
@@ -650,8 +657,7 @@ def test_optimizer_local_aggregation(device, communication_type, kwargs):
pytest.param("CPU", bf.CommunicationType.neighbor_allreduce, {"ATC": True},
id="ATC Neighbor Allreduce on CPU"))
local_aggregation_duplicated_scenarios.append(
pytest.param("CPU", "win.put", {}, id="Win Put on CPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("CPU", "win.put", {'window_prefix': 'CPU'}, id="Win Put on CPU"))
local_aggregation_duplicated_scenarios.append(
pytest.param("CPU", "gradient.allreduce", {}, id="Gradient Allreduce on CPU"))
if TEST_ON_GPU:
@@ -662,8 +668,7 @@ def test_optimizer_local_aggregation(device, communication_type, kwargs):
pytest.param("GPU", bf.CommunicationType.neighbor_allreduce, {"ATC": True},
id="ATC Neighbor Allreduce on GPU"))
local_aggregation_duplicated_scenarios.append(
pytest.param("GPU", "win.put", {}, id="Win Put on GPU",
marks=pytest.mark.skip(reason="Multiple win_put optimizer tests will fail")))
pytest.param("GPU", "win.put", {'window_prefix': 'GPU'}, id="Win Put on GPU"))
local_aggregation_duplicated_scenarios.append(
pytest.param("GPU", "gradient.allreduce", {}, id="Gradient Allreduce on GPU"))

@@ -675,6 +680,7 @@ def test_optimizer_local_aggregation_duplicated(device, communication_type, kwargs):
# for local aggregation.
atc_style = kwargs.get("ATC", False)
mini_batch_size = kwargs.get("mini_batch_size", 16)
window_prefix = kwargs.get("window_prefix", None)

_, train_dataloader, test_dataloader, model, optimizer, num_epochs = \
problem_setup(DuplicatedLinearNet)
@@ -692,6 +698,7 @@ def test_optimizer_local_aggregation_duplicated(device, communication_type, kwargs):
num_steps_per_communication=J)
elif communication_type == "win.put":
optimizer = bf.DistributedWinPutOptimizer(optimizer, model=model,
window_prefix=window_prefix,
num_steps_per_communication=J)
elif communication_type == "gradient.allreduce":
optimizer = bf.DistributedGradientAllreduceOptimizer(optimizer, model=model,
@@ -705,3 +712,6 @@ def test_optimizer_local_aggregation_duplicated(device, communication_type, kwargs):
model, optimizer, train_dataloader, isCUDA, mini_batch_size)
evaluation(model, train_dataloader, isCUDA)
evaluation(model, test_dataloader, isCUDA)

if communication_type == "win.put":
optimizer.unregister_window()
