diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index a5ea528d13450..fbad470cb3f13 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -138,7 +138,7 @@ def _get_global_env(): # Name of the default group for init_parallel_env _default_group_name = "_default_pg" -_valid_backend_list = ['nccl', 'gloo', 'hccl'] +_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter'] _default_store = None # the default tcp store _default_backend = None @@ -234,6 +234,31 @@ def _new_process_group_impl(backend, pg = core.ProcessGroupNCCL(store, rank, world_size, group_id) elif backend == "hccl": pg = core.ProcessGroupHCCL(store, rank, world_size, group_id) + elif backend == "heter": + cluster_id = int(os.getenv("CLUSTER_ID", "-1")) + assert cluster_id >= 0, "please set the CLUSTER_ID variable." + cluster_size = os.getenv("CLUSTER_SIZE", None) + assert cluster_size, "please set the CLUSTER_SIZE variable." + cluster_size = cluster_size.split(",") + cluster_size = [int(s) for s in cluster_size] + switch_ep = os.getenv("CLUSTER_SWITCH", None) + assert switch_ep, "please set the CLUSTER_SWITCH variable." + cluster_size_cumsum = np.cumsum(cluster_size) + cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[ + cluster_id - 1] + global_rank = cluster_offset + rank + global_world_size = cluster_size_cumsum[-1] + pg = core.ProcessGroupHeter( + store, + rank=global_rank, + world_size=global_world_size, + gid=0, + local_rank=rank, + local_size=world_size, + gloo_rank=cluster_id, + gloo_size=len(cluster_size), + with_switch=True, + switch_endpoint=switch_ep) return pg