Merged
3 changes: 3 additions & 0 deletions DGraph/CommunicatorBase.py
@@ -26,3 +26,6 @@ def get_rank(self) -> int:

def get_world_size(self) -> int:
raise NotImplementedError

def barrier(self):
raise NotImplementedError
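For context, a concrete communicator would override these hooks. A minimal sketch, assuming an mpi4py-backed subclass (the `CommunicatorBase` class name and import path are inferred from the file name above, not taken from the PR):

```python
# Hypothetical subclass implementing the new barrier() hook via mpi4py.
from mpi4py import MPI

from DGraph.CommunicatorBase import CommunicatorBase  # path assumed from this diff


class MPICommunicator(CommunicatorBase):
    def __init__(self):
        self._comm = MPI.COMM_WORLD

    def get_rank(self) -> int:
        return self._comm.Get_rank()

    def get_world_size(self) -> int:
        return self._comm.Get_size()

    def barrier(self):
        # Block until every rank reaches this point.
        self._comm.Barrier()
```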
24 changes: 20 additions & 4 deletions DGraph/data/ogbn_datasets.py
@@ -174,9 +174,25 @@ def __init__(
self._rank = self.comm_object.get_rank()
self._world_size = self.comm_object.get_world_size()

self.dataset = NodePropPredDataset(
name=dname,
)
comm_object.barrier()
# Load the dataset on rank 0
if comm_object.get_rank() == 0:
self.dataset = NodePropPredDataset(
name=dname,
)
        # Block until rank 0 loads and processes the data.
        # On the first run, the code downloads and processes the data;
        # doing that on all ranks causes a race condition.
comm_object.barrier()
# Load the dataset on all other ranks
Contributor:
@szaman19 What is the issue with the race here? Is it that the first rank will download the dataset to local disk and then the rest will load from local disk (i.e., the race is concurrent downloads?)

Contributor (Author):
Yup, for the first run, OGB will download the raw data, unzip it, and then delete the raw data. Subsequent runs search for the processed files and use them. If we don't lock around it, the concurrent downloads and processing calls result in OS errors.
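Consolidated, the pattern under discussion looks like this (a sketch using the `comm_object` interface from this diff; the helper name is illustrative):

```python
# Sketch of the rank-0-first loading pattern described in this thread.
from ogb.nodeproppred import NodePropPredDataset


def load_dataset_rank0_first(comm, dname: str) -> NodePropPredDataset:
    # comm is any communicator exposing get_rank() and barrier().
    if comm.get_rank() == 0:
        # First run: rank 0 downloads, unzips, and processes the raw data.
        dataset = NodePropPredDataset(name=dname)
    comm.barrier()  # everyone else waits until the processed files exist
    if comm.get_rank() != 0:
        # These loads find the processed files and skip the download.
        dataset = NodePropPredDataset(name=dname)
    comm.barrier()  # keep ranks in step before the dataset is used
    return dataset
```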

        # The other ranks now load the processed data that rank 0 generated,
        # which avoids the race condition.

if comm_object.get_rank() != 0:
self.dataset = NodePropPredDataset(
name=dname,
)
comm_object.barrier()
graph_data, labels = self.dataset[0]

self.split_idx = self.dataset.get_idx_split()
@@ -185,7 +201,7 @@ def __init__(
dir_name = dir_name if dir_name is not None else os.getcwd() + "/data"

if not os.path.exists(dir_name):
os.makedirs(dir_name)
os.makedirs(dir_name, exist_ok=True)

cached_graph_file = f"{dir_name}/{dname}_graph_data_{self._world_size}.pt"

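The `exist_ok=True` change above matters for a related reason: the check-then-create pattern is racy when several ranks share a filesystem. A minimal illustration (the path is illustrative):

```python
# Two ranks can both pass the exists() check before either creates the
# directory; the slower makedirs() then raises FileExistsError.
import os

dir_name = "data"  # illustrative

if not os.path.exists(dir_name):
    os.makedirs(dir_name)  # racy across concurrent ranks

os.makedirs(dir_name, exist_ok=True)  # idempotent, safe on every rank
```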
14 changes: 6 additions & 8 deletions DGraph/distributed/RankLocalOps.py
@@ -69,8 +69,8 @@ def OptimizedRankLocalMaskedGather(
num_features = src.shape[-1]
local_masked_gather(
src,
indices,
rank_mapping,
indices.cuda(),
rank_mapping.cuda(),
output,
bs,
num_src_rows,
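The change above moves the index tensors onto the GPU at the call site. A small sketch of the same device-alignment idea, assuming `local_masked_gather` is a CUDA kernel that requires on-device inputs (the helper below is an assumption, not DGraph code):

```python
# Hypothetical helper; avoids redundant copies for tensors already on the GPU.
import torch


def ensure_on_gpu(*tensors: torch.Tensor) -> tuple:
    return tuple(t if t.is_cuda else t.cuda() for t in tensors)


# Illustrative usage before the kernel call:
# indices, rank_mapping = ensure_on_gpu(indices, rank_mapping)
```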
@@ -137,13 +137,11 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping):
"""
    Removes duplicates from the indices tensor and returns the renumbered
    indices, the unique indices, and the rank mapping of each unique index.
"""
unique_indices = torch.unique(_indices).to(_indices.device)
unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
rank_mapping = rank_mapping.to(_indices.device)
renumbered_indices = torch.zeros_like(_indices)
unique_rank_mapping = torch.zeros_like(unique_indices)
for i, idx in enumerate(unique_indices):
renumbered_indices[_indices == idx] = i
unique_rank_mapping[i] = rank_mapping[_indices == idx][0]
renumbered_indices = inverse_indices
unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)

return renumbered_indices, unique_indices, unique_rank_mapping

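A small worked example of the vectorized renumbering above (values are illustrative):

```python
import torch

_indices = torch.tensor([7, 3, 7, 5, 3])
rank_mapping = torch.tensor([1, 0, 1, 2, 0])  # owning rank of each index

unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True)
# unique_indices -> [3, 5, 7]; inverse_indices -> [2, 0, 2, 1, 0]

unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype)
unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping)
# unique_rank_mapping -> [0, 2, 1]: one entry per unique index, and
# inverse_indices doubles as the renumbered (0..n_unique-1) index tensor,
# replacing the old Python loop with a single vectorized pass.
```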
14 changes: 11 additions & 3 deletions DGraph/distributed/nccl/NCCLBackendEngine.py
@@ -689,10 +689,18 @@ def gather(
return output_tensor # type: ignore

def destroy(self) -> None:
if self._initialized:
if NCCLBackendEngine._is_initialized:
Copilot AI (Jul 23, 2025):
The code references 'NCCLBackendEngine._is_initialized' but based on the context, this appears to be changing from 'self._initialized'. This could cause issues if '_is_initialized' is not properly defined as a class variable or if other methods still reference 'self._initialized'.

# dist.destroy_process_group()
self._initialized = False
NCCLBackendEngine._is_initialized = False

def finalize(self) -> None:
if self._initialized:
if NCCLBackendEngine._is_initialized:
dist.barrier()

def barrier(self) -> None:
if NCCLBackendEngine._is_initialized:
dist.barrier()
else:
raise RuntimeError(
"NCCLBackendEngine is not initialized, cannot call barrier"
)
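On the Copilot comment above: the intent is a class-level flag, so every instance (and `destroy`/`finalize`/`barrier`) observes the same state. A minimal sketch of that pattern (the class and method names below are assumptions, not the PR's API):

```python
# Illustrative class-level initialization flag; not the actual DGraph engine.
import torch.distributed as dist


class EngineSketch:
    _is_initialized: bool = False  # shared across all instances

    @classmethod
    def init_engine(cls) -> None:
        if not cls._is_initialized:
            dist.init_process_group(backend="nccl")
            cls._is_initialized = True

    def barrier(self) -> None:
        if not EngineSketch._is_initialized:
            raise RuntimeError("engine is not initialized, cannot call barrier")
        dist.barrier()
```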
164 changes: 164 additions & 0 deletions experiments/OGB/GenerateCache.py
@@ -0,0 +1,164 @@
# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. See the top-level LICENSE file for details.
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LBANN and https://github.com/LLNL/LBANN.
#
# SPDX-License-Identifier: (Apache-2.0)

from DGraph.data.ogbn_datasets import process_homogenous_data
from ogb.nodeproppred import NodePropPredDataset
from fire import Fire
import os
import torch
from DGraph.distributed.nccl._nccl_cache import (
NCCLGatherCacheGenerator,
NCCLScatterCacheGenerator,
)
from time import perf_counter
from tqdm import tqdm
from multiprocessing import get_context


cache_prefix = {
"ogbn-arxiv": "arxiv",
"ogbn-products": "products",
"ogbn-papers100M": "papers100M",
}


def generate_cache_file(
dist_graph,
src_indices,
dst_indices,
edge_placement,
edge_src_placement,
edge_dest_placement,
cache_prefix_str: str,
rank: int,
world_size: int,
):
print(f"Generating cache for rank {rank}...")
local_node_features = dist_graph.get_local_node_features(rank).unsqueeze(0)
num_input_rows = local_node_features.size(1)

print(
f"Rank {rank} has {num_input_rows} input rows with shape {local_node_features.shape}"
)
gather_cache = NCCLGatherCacheGenerator(
dst_indices,
edge_placement,
edge_dest_placement,
num_input_rows,
rank,
world_size,
)

nodes_per_rank = dist_graph.get_nodes_per_rank()
nodes_per_rank = int(nodes_per_rank[rank].item())

scatter_cache = NCCLScatterCacheGenerator(
src_indices,
edge_placement,
edge_src_placement,
nodes_per_rank,
rank,
world_size,
)
print(f"Rank {rank} completed cache generation")
with open(
f"{cache_prefix_str}_gather_cache_rank_{world_size}_{rank}.pt", "wb"
) as f:
torch.save(gather_cache, f)

with open(
f"{cache_prefix_str}_scatter_cache_rank_{world_size}_{rank}.pt", "wb"
) as f:
torch.save(scatter_cache, f)
return 0


def main(dset: str, world_size: int, node_rank_placement_file: str):
assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"]

assert world_size > 0
assert os.path.exists(
node_rank_placement_file
), "Node rank placement file does not exist."

node_rank_placement = torch.load(node_rank_placement_file)

dataset = NodePropPredDataset(
dset,
)

split_index = dataset.get_idx_split()
assert split_index is not None, "Split index is None."

graph, labels = dataset[0]

    num_edges = graph["edge_index"].shape[1]  # edge_index is (2, num_edges)
    print(f"Number of edges: {num_edges}")

dist_graph = process_homogenous_data(
graph_data=graph,
labels=labels,
world_Size=world_size,
split_idx=split_index,
node_rank_placement=node_rank_placement,
rank=0,
)

edge_indices = dist_graph.get_global_edge_indices()
rank_mappings = dist_graph.get_global_rank_mappings()

print("Edge indices shape:", edge_indices.shape)
print("Rank mappings shape:", rank_mappings.shape)

edge_indices = edge_indices.unsqueeze(0)
src_indices = edge_indices[:, 0, :]
dst_indices = edge_indices[:, 1, :]

edge_placement = rank_mappings[0]
edge_src_placement = rank_mappings[0]
edge_dest_placement = rank_mappings[1]

start_time = perf_counter()
cache_prefix_str = f"cache/{cache_prefix[dset]}"
with get_context("spawn").Pool(min(world_size, 8)) as pool:
args = [
(
dist_graph,
src_indices,
dst_indices,
edge_placement,
edge_src_placement,
edge_dest_placement,
cache_prefix_str,
rank,
world_size,
)
for rank in range(world_size)
]

out = pool.starmap(generate_cache_file, args)

end_time = perf_counter()
print(f"Cache generation time: {end_time - start_time:.4f} seconds")
print("Cache files generated successfully.")
print(
f"Gather cache file: {cache_prefix_str}_gather_cache_rank_{world_size}_<rank>.pt"
)
print(
f"Scatter cache file: {cache_prefix_str}_scatter_cache_rank_{world_size}_<rank>.pt"
)


if __name__ == "__main__":
Fire(main)
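A hypothetical consumer would load the per-rank files this script writes; the file-name pattern below matches `generate_cache_file`, while the loader itself is an assumption:

```python
# Hypothetical per-rank cache loader; only the naming pattern is taken
# from GenerateCache.py above.
import torch


def load_rank_caches(cache_prefix_str: str, rank: int, world_size: int):
    gather = torch.load(f"{cache_prefix_str}_gather_cache_rank_{world_size}_{rank}.pt")
    scatter = torch.load(f"{cache_prefix_str}_scatter_cache_rank_{world_size}_{rank}.pt")
    return gather, scatter
```

Since the script is driven by Fire, an invocation such as `python GenerateCache.py --dset ogbn-arxiv --world_size 4 --node_rank_placement_file placement.pt` (paths illustrative) should work; note that the `cache/` directory must exist before the run, since the script writes to `cache/<prefix>_...` without creating it.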
7 changes: 6 additions & 1 deletion experiments/OGB/Readme.md
@@ -23,8 +23,13 @@ DGraph supports distributed training using the `nccl`, `nvshmem`, and `mpi` back
In order to run the experiments with the `nccl` backend, run the following command:

```bash
torchrun --nnodes <nodes> --nproc-per-node <gpus> main.py --backend nccl --lr lr --epochs epochs --runs runs --log_dir log-dir
torchrun-hpc -N <nodes> -n <gpus> main.py --backend nccl --lr lr --epochs epochs --runs runs --node_rank_placement_file <file_dir> --log_dir log-dir
```
You may have to pass ``--xargs=--mpibind=off`` and ``--xargs=--gpu-bind=none`` in your Slurm script to avoid binding issues.

**Note that we use `torchrun-hpc` instead of `torchrun`**; the exact run command may vary based on your environment.



### Additional Notes
The experiments use some additional libraries. Use the [ogb] option