From 133ece296db3e3e667beb801053d76d5b1a823e9 Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Thu, 7 Apr 2022 12:32:52 +0000 Subject: [PATCH 1/2] update --- python/paddle/distributed/collective.py | 27 ++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index a5ea528d13450..91a292030985a 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -138,7 +138,7 @@ def _get_global_env(): # Name of the default group for init_parallel_env _default_group_name = "_default_pg" -_valid_backend_list = ['nccl', 'gloo', 'hccl'] +_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter'] _default_store = None # the default tcp store _default_backend = None @@ -234,6 +234,31 @@ def _new_process_group_impl(backend, pg = core.ProcessGroupNCCL(store, rank, world_size, group_id) elif backend == "hccl": pg = core.ProcessGroupHCCL(store, rank, world_size, group_id) + elif backend == "heter": + cluster_id = int(os.getenv("CLUSTER_ID", "-1")) + assert cluster_id >= 0, "please set the CLUSTER_ID variable." + cluster_size = os.getenv("CLUSTER_SIZE", None) + assert cluster_size, "please set the CLUSTER_SIZE variable." + cluster_size = cluster_size.split(",") + cluster_size = [int(s) for s in cluster_size] + switch_ep = os.getenv("CLUSTER_SWITCH", None) + assert switch_ep, "please set the CLUSTER_SWITCH variable." + cluster_size_cumsum = np.cumsum(cluster_size) + cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[ + cluster_id - 1] + global_rank = cluster_offset + rank + global_world_size = cluster_size_cumsum[-1] + pg = core.ProcessGroupHeter( + store, + rank=global_rank, + world_size=global_world_size, + gid=0, + local_rank=rank, + local_size=world_size, + gloo_rank=cluster_id, + gloo_size=len(cluster_sizes), + with_switch=True, + switch_endpoint=switch_ep) return pg From 9f74c12dbfa87abba152292d8c075a66728a539a Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Fri, 8 Apr 2022 05:10:16 +0000 Subject: [PATCH 2/2] update --- python/paddle/distributed/collective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 91a292030985a..fbad470cb3f13 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -256,7 +256,7 @@ def _new_process_group_impl(backend, local_rank=rank, local_size=world_size, gloo_rank=cluster_id, - gloo_size=len(cluster_sizes), + gloo_size=len(cluster_size), with_switch=True, switch_endpoint=switch_ep)