Skip to content

Commit

Permalink
Topo service (#86)
Browse files Browse the repository at this point in the history
* Add construct_topology function

* Add doc and test for infer_destination_source_ranks

* Address comments

* Split infer_topo into two functions

* Delete construct_topo.py

* Fix the none case in _infer_topo func
  • Loading branch information
bichengying committed Mar 29, 2021
1 parent dcf56ba commit 8bde896
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 27 deletions.
44 changes: 37 additions & 7 deletions bluefog/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import torch
Expand All @@ -30,10 +26,10 @@
DistributedWinPutOptimizer,
DistributedAllreduceOptimizer,
DistributedNeighborAllreduceOptimizer,
DistributedHierarchicalNeighborAllreduceOptimizer
DistributedHierarchicalNeighborAllreduceOptimizer,
)

check_extension('bluefog.torch', __file__, 'mpi_lib')
check_extension("bluefog.torch", __file__, "mpi_lib")

from bluefog.torch.mpi_ops import init, shutdown
from bluefog.torch.mpi_ops import size, local_size, rank, local_rank
Expand Down Expand Up @@ -74,4 +70,38 @@

from bluefog.torch.mpi_ops import timeline_start_activity, timeline_end_activity
from bluefog.torch.mpi_ops import timeline_context
from bluefog.torch.utility import broadcast_optimizer_state, broadcast_parameters, allreduce_parameters

from bluefog.torch.utility import (
broadcast_optimizer_state,
broadcast_parameters,
allreduce_parameters,
)

from bluefog.common.topology_util import (
GetRecvWeights,
GetSendWeights,
IsRegularGraph,
IsTopologyEquivalent,
)

from bluefog.common.topology_util import (
ExponentialTwoGraph,
ExponentialGraph,
FullyConnectedGraph,
MeshGrid2DGraph,
RingGraph,
StarGraph,
SymmetricExponentialGraph,
)

from bluefog.common.topology_util import (
GetDynamicOnePeerSendRecvRanks,
GetExp2DynamicSendRecvMachineRanks,
GetInnerOuterRingDynamicSendRecvRanks,
GetInnerOuterExpo2DynamicSendRecvRanks,
)

from bluefog.torch.topology_util import (
InferSourceFromDestinationRanks,
InferDestinationFromSourceRanks,
)
108 changes: 108 additions & 0 deletions bluefog/torch/topology_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Any, List, Optional, Tuple, Union
import collections

import numpy as np
import torch
import bluefog.torch as bf


def _check_ranks(rank_list: List[Any], self_rank: int, size: int) -> [bool, str]:
for rank in rank_list:
if not isinstance(rank, int):
return False, "contain element that is not integer."
if (rank < 0) or (rank >= size):
return False, "contain element that is not between 0 and size-1."
if len(set(rank_list)) != len(rank_list):
return False, "contain duplicated elements."
if self_rank in rank_list:
return False, "contain self rank."
return True, ""


def InferSourceFromDestinationRanks(
dst_ranks: List[int], construct_adjacency_matrix: bool = False,
) -> Union[List[int], Tuple[List[int], np.array]]:
"""Infer the source ranks from destination ranks. This is collective communication call.
Args:
dst_ranks: A list of destination ranks.
construct_adjacency_matrix: If true, adjacency matrix will be return instead.
Element w_{ij} represents the weights sending from node i to node j.
We use column normalized style, i.e. the sum of receiving weight is 1.
Raises:
ValueError: If dst_ranks or src_ranks does not contain integer from 0 to size-1.
Returns:
If construct_adjacency_matrix is false, returns the source ranks list.
If construct_adjacency_matrix is true, returns the the source ranks list
and a 2-D numpy array.
"""
is_valid, error_msg = _check_ranks(dst_ranks, bf.rank(), bf.size())
assert is_valid, f"The format of dst_ranks is wrong: {error_msg}"
return _infer_topo(
dst_ranks,
transpose=False,
construct_adjacency_matrix=construct_adjacency_matrix,
)


def InferDestinationFromSourceRanks(
src_ranks: List[int], construct_adjacency_matrix: bool = False,
) -> Union[List[int], np.array]:
"""Infer the destination ranks from source ranks. This is collective communication call.
Args:
src_ranks: A list of destination ranks.
construct_adjacency_matrix: If true, adjacency matrix will be return instead.
Element w_{ij} represents the weights sending from node i to node j.
We use column normalized style, i.e. the sum of receiving weight is 1.
Raises:
ValueError: If dst_ranks or src_ranks does not contain integer from 0 to size-1.
Returns:
If construct_adjacency_matrix is false, returns the destination ranks list.
If construct_adjacency_matrix is true, returns the the sodestinationrce ranks
list and a 2-D numpy array.
"""
is_valid, error_msg = _check_ranks(src_ranks, bf.rank(), bf.size())
assert is_valid, f"The format of src_ranks is wrong: {error_msg}"
return _infer_topo(
src_ranks,
transpose=True,
construct_adjacency_matrix=construct_adjacency_matrix,
)


def _infer_topo(
rank_list: List[int], transpose: bool, construct_adjacency_matrix: bool
):
degree = len(rank_list)
all_degree_list = bf.allgather(torch.tensor([degree], dtype=torch.int32)).numpy()
all_rank_list = bf.allgather(torch.tensor(rank_list, dtype=torch.int32)).numpy()
adjacency_dict = dict()
displacement = 0
for i, degree in enumerate(all_degree_list):
adjacency_dict[i] = sorted(all_rank_list[displacement : displacement + degree])
displacement += degree

inv_adjacency_dict = collections.defaultdict(list)
for k, adj in adjacency_dict.items():
for v in adj:
inv_adjacency_dict[v].append(k)
return_list = inv_adjacency_dict.get(bf.rank())
if return_list is None:
return_list = []

if not construct_adjacency_matrix:
return return_list

# construct_adjacency_matrix
W = np.eye(bf.size())
for k, adj in adjacency_dict.items():
W[k, adj] = 1
if transpose:
W = W.T

return return_list, W / W.sum(axis=1)
2 changes: 2 additions & 0 deletions bluefog/torch/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# limitations under the License.
# ==============================================================================

from typing import Any, List, Optional
import collections

import numpy as np
import torch
import bluefog.torch as bf

Expand Down
102 changes: 82 additions & 20 deletions test/torch_basics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,18 @@

from common import mpi_env_rank_and_size
import bluefog.torch as bf
from bluefog.common.topology_util import ExponentialGraph, RingGraph, RingGraph
from bluefog.common.topology_util import IsTopologyEquivalent
from bluefog.torch import (
ExponentialGraph,
RingGraph,
StarGraph,
MeshGrid2DGraph,
FullyConnectedGraph,
)
from bluefog.torch import (
IsTopologyEquivalent,
InferDestinationFromSourceRanks,
InferSourceFromDestinationRanks,
)

warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
Expand Down Expand Up @@ -75,10 +85,12 @@ def test_set_topology_fail_with_win_create(self):

if size == 1:
expected_topology = nx.from_numpy_array(
np.array([[0.5]]), create_using=nx.DiGraph)
np.array([[0.5]]), create_using=nx.DiGraph
)
elif size == 2:
expected_topology = nx.from_numpy_array(
np.array([[0, 0.2], [0.2, 0]]), create_using=nx.DiGraph)
np.array([[0, 0.2], [0.2, 0]]), create_using=nx.DiGraph
)
else:
expected_topology = RingGraph(size)

Expand All @@ -96,10 +108,16 @@ def test_set_and_load_topology(self):
bf.init()
size = bf.size()
if size == 4:
expected_topology = nx.DiGraph(np.array(
[[1/3., 1/3., 1/3., 0.], [0., 1/3., 1/3., 1/3.],
[1/3., 0., 1/3., 1/3.], [1/3., 1/3., 0., 1/3.]]
))
expected_topology = nx.DiGraph(
np.array(
[
[1 / 3.0, 1 / 3.0, 1 / 3.0, 0.0],
[0.0, 1 / 3.0, 1 / 3.0, 1 / 3.0],
[1 / 3.0, 0.0, 1 / 3.0, 1 / 3.0],
[1 / 3.0, 1 / 3.0, 0.0, 1 / 3.0],
]
)
)
elif size == 1:
expected_topology = nx.DiGraph(np.array([[1.0]]))
else:
Expand All @@ -113,37 +131,81 @@ def test_in_out_neighbors_expo2(self):
rank = bf.rank()
size = bf.size()
assert bf.set_topology(ExponentialGraph(size))
in_neighobrs = bf.in_neighbor_ranks()
in_neighbors = bf.in_neighbor_ranks()
out_neighbors = bf.out_neighbor_ranks()

degree = int(np.ceil(np.log2(size)))
expected_in_neighbors = sorted([(rank - 2**i) %
size for i in range(degree)])
expected_out_neighbors = sorted([(rank + 2**i) %
size for i in range(degree)])
assert sorted(in_neighobrs) == expected_in_neighbors
expected_in_neighbors = sorted([(rank - 2 ** i) % size for i in range(degree)])
expected_out_neighbors = sorted([(rank + 2 ** i) % size for i in range(degree)])
assert sorted(in_neighbors) == expected_in_neighbors
assert sorted(out_neighbors) == expected_out_neighbors

def test_in_out_neighbors_biring(self):
bf.init()
rank = bf.rank()
size = bf.size()
assert bf.set_topology(RingGraph(size))
in_neighobrs = bf.in_neighbor_ranks()
in_neighbors = bf.in_neighbor_ranks()
out_neighbors = bf.out_neighbor_ranks()

expected_in_neighbors = list(set(
map(lambda x: x % size, [rank - 1, rank + 1])))
expected_out_neighbors = list(set(
map(lambda x: x % size, [rank - 1, rank + 1])))
expected_in_neighbors = list(set(map(lambda x: x % size, [rank - 1, rank + 1])))
expected_out_neighbors = list(
set(map(lambda x: x % size, [rank - 1, rank + 1]))
)

if size <= 1:
expected_in_neighbors = []
expected_out_neighbors = []

assert sorted(in_neighobrs) == expected_in_neighbors
assert sorted(in_neighbors) == expected_in_neighbors
assert sorted(out_neighbors) == expected_out_neighbors


@pytest.mark.parametrize(
"topo_func",
[ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph],
)
def test_infer_destination_from_source_ranks(topo_func):
bf.init()
size = bf.size()
bf.set_topology(topo_func(size))
topo = bf.load_topology()
in_neighbors = bf.in_neighbor_ranks()
out_neighbors = bf.out_neighbor_ranks()

# Make the W into average rule.
expected_W = (nx.to_numpy_array(topo) > 0).astype(float)
expected_W /= expected_W.sum(axis=0)

src_ranks, W = InferDestinationFromSourceRanks(
src_ranks=in_neighbors, construct_adjacency_matrix=True
)
assert sorted(src_ranks) == out_neighbors
np.testing.assert_allclose(W, expected_W)


@pytest.mark.parametrize(
"topo_func",
[ExponentialGraph, RingGraph, StarGraph, MeshGrid2DGraph, FullyConnectedGraph],
)
def test_infer_source_from_destination_ranks(topo_func):
bf.init()
size = bf.size()
bf.set_topology(topo_func(size))
topo = bf.load_topology()
in_neighbors = bf.in_neighbor_ranks()
out_neighbors = bf.out_neighbor_ranks()

# Make the W into average rule.
expected_W = (nx.to_numpy_array(topo) > 0).astype(float)
expected_W /= expected_W.sum(axis=0)

dst_ranks, W = InferSourceFromDestinationRanks(
dst_ranks=out_neighbors, construct_adjacency_matrix=True
)
assert sorted(dst_ranks) == in_neighbors
np.testing.assert_allclose(W, expected_W)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8bde896

Please sign in to comment.