diff --git a/DGraph/CommunicatorBase.py b/DGraph/CommunicatorBase.py index e233420..502c841 100644 --- a/DGraph/CommunicatorBase.py +++ b/DGraph/CommunicatorBase.py @@ -26,3 +26,6 @@ def get_rank(self) -> int: def get_world_size(self) -> int: raise NotImplementedError + + def barrier(self): + raise NotImplementedError diff --git a/DGraph/data/ogbn_datasets.py b/DGraph/data/ogbn_datasets.py index 874ef16..8125d58 100644 --- a/DGraph/data/ogbn_datasets.py +++ b/DGraph/data/ogbn_datasets.py @@ -174,9 +174,25 @@ def __init__( self._rank = self.comm_object.get_rank() self._world_size = self.comm_object.get_world_size() - self.dataset = NodePropPredDataset( - name=dname, - ) + comm_object.barrier() + # Load the dataset on rank 0 + if comm_object.get_rank() == 0: + self.dataset = NodePropPredDataset( + name=dname, + ) + # Block until rank 0 loads and processe the data + # For the first time, the code downloads and processes the data + # doing that on all ranks causes a race condition + comm_object.barrier() + # Load the dataset on all other ranks + # This is to use the processed data that was generated by rank 0 + # This should account for a race condition + + if comm_object.get_rank() != 0: + self.dataset = NodePropPredDataset( + name=dname, + ) + comm_object.barrier() graph_data, labels = self.dataset[0] self.split_idx = self.dataset.get_idx_split() @@ -185,7 +201,7 @@ def __init__( dir_name = dir_name if dir_name is not None else os.getcwd() + "/data" if not os.path.exists(dir_name): - os.makedirs(dir_name) + os.makedirs(dir_name, exist_ok=True) cached_graph_file = f"{dir_name}/{dname}_graph_data_{self._world_size}.pt" diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index 3cafb0a..c4b6de0 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -69,8 +69,8 @@ def OptimizedRankLocalMaskedGather( num_features = src.shape[-1] local_masked_gather( src, - indices, - rank_mapping, + indices.cuda(), + 
rank_mapping.cuda(), output, bs, num_src_rows, @@ -137,13 +137,11 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): """ This function removes duplicates from the indices tensor. """ - unique_indices = torch.unique(_indices).to(_indices.device) + unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) - renumbered_indices = torch.zeros_like(_indices) - unique_rank_mapping = torch.zeros_like(unique_indices) - for i, idx in enumerate(unique_indices): - renumbered_indices[_indices == idx] = i - unique_rank_mapping[i] = rank_mapping[_indices == idx][0] + renumbered_indices = inverse_indices + unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) + unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index eb94b6c..b3ea11a 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -689,10 +689,18 @@ def gather( return output_tensor # type: ignore def destroy(self) -> None: - if self._initialized: + if NCCLBackendEngine._is_initialized: # dist.destroy_process_group() - self._initialized = False + NCCLBackendEngine._is_initialized = False def finalize(self) -> None: - if self._initialized: + if NCCLBackendEngine._is_initialized: dist.barrier() + + def barrier(self) -> None: + if NCCLBackendEngine._is_initialized: + dist.barrier() + else: + raise RuntimeError( + "NCCLBackendEngine is not initialized, cannot call barrier" + ) diff --git a/experiments/OGB/GenerateCache.py b/experiments/OGB/GenerateCache.py new file mode 100644 index 0000000..a16e795 --- /dev/null +++ b/experiments/OGB/GenerateCache.py @@ -0,0 +1,164 @@ +# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. 
+# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +from DGraph.data.ogbn_datasets import process_homogenous_data +from ogb.nodeproppred import NodePropPredDataset +from fire import Fire +import os +import torch +from DGraph.distributed.nccl._nccl_cache import ( + NCCLGatherCacheGenerator, + NCCLScatterCacheGenerator, +) +from time import perf_counter +from tqdm import tqdm +from multiprocessing import get_context + + +cache_prefix = { + "ogbn-arxiv": "arxiv", + "ogbn-products": "products", + "ogbn-papers100M": "papers100M", +} + + +def generate_cache_file( + dist_graph, + src_indices, + dst_indices, + edge_placement, + edge_src_placement, + edge_dest_placement, + cache_prefix_str: str, + rank: int, + world_size: int, +): + print(f"Generating cache for rank {rank}...") + local_node_features = dist_graph.get_local_node_features(rank).unsqueeze(0) + num_input_rows = local_node_features.size(1) + + print( + f"Rank {rank} has {num_input_rows} input rows with shape {local_node_features.shape}" + ) + gather_cache = NCCLGatherCacheGenerator( + dst_indices, + edge_placement, + edge_dest_placement, + num_input_rows, + rank, + world_size, + ) + + nodes_per_rank = dist_graph.get_nodes_per_rank() + nodes_per_rank = int(nodes_per_rank[rank].item()) + + scatter_cache = NCCLScatterCacheGenerator( + src_indices, + edge_placement, + edge_src_placement, + nodes_per_rank, + rank, + world_size, + ) + print(f"Rank {rank} completed cache generation") + with open( + f"{cache_prefix_str}_gather_cache_rank_{world_size}_{rank}.pt", "wb" + ) as f: + 
torch.save(gather_cache, f) + + with open( + f"{cache_prefix_str}_scatter_cache_rank_{world_size}_{rank}.pt", "wb" + ) as f: + torch.save(scatter_cache, f) + return 0 + + +def main(dset: str, world_size: int, node_rank_placement_file: str): + assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"] + + assert world_size > 0 + assert os.path.exists( + node_rank_placement_file + ), "Node rank placement file does not exist." + + node_rank_placement = torch.load(node_rank_placement_file) + + dataset = NodePropPredDataset( + dset, + ) + + split_index = dataset.get_idx_split() + assert split_index is not None, "Split index is None." + + graph, labels = dataset[0] + + num_edges = graph["edge_index"].shape + print(num_edges) + + dist_graph = process_homogenous_data( + graph_data=graph, + labels=labels, + world_Size=world_size, + split_idx=split_index, + node_rank_placement=node_rank_placement, + rank=0, + ) + + edge_indices = dist_graph.get_global_edge_indices() + rank_mappings = dist_graph.get_global_rank_mappings() + + print("Edge indices shape:", edge_indices.shape) + print("Rank mappings shape:", rank_mappings.shape) + + edge_indices = edge_indices.unsqueeze(0) + src_indices = edge_indices[:, 0, :] + dst_indices = edge_indices[:, 1, :] + + edge_placement = rank_mappings[0] + edge_src_placement = rank_mappings[0] + edge_dest_placement = rank_mappings[1] + + start_time = perf_counter() + cache_prefix_str = f"cache/{cache_prefix[dset]}" + with get_context("spawn").Pool(min(world_size, 8)) as pool: + args = [ + ( + dist_graph, + src_indices, + dst_indices, + edge_placement, + edge_src_placement, + edge_dest_placement, + cache_prefix_str, + rank, + world_size, + ) + for rank in range(world_size) + ] + + out = pool.starmap(generate_cache_file, args) + + end_time = perf_counter() + print(f"Cache generation time: {end_time - start_time:.4f} seconds") + print("Cache files generated successfully.") + print( + f"Gather cache file: 
{cache_prefix_str}_gather_cache_rank_{world_size}_<rank>.pt" + ) + print( + f"Scatter cache file: {cache_prefix_str}_scatter_cache_rank_{world_size}_<rank>.pt" + ) + + +if __name__ == "__main__": + Fire(main) diff --git a/experiments/OGB/Readme.md b/experiments/OGB/Readme.md index 37265cd..600f26f 100644 --- a/experiments/OGB/Readme.md +++ b/experiments/OGB/Readme.md @@ -23,8 +23,13 @@ DGraph supports distributed training using the `nccl`, `nvshmem`, and `mpi` back In order to run the experiments with the `nccl` backend, run the following command: ```bash -torchrun --nnodes --nproc-per-node main.py --backend nccl --lr lr --epochs epochs --runs runs --log_dir log-dir +torchrun-hpc -N -n main.py --backend nccl --lr lr --epochs epochs --runs runs --node_rank_placement_file --log_dir log-dir ``` +You may have to pass ``--xargs=--mpibind=off`` and ``--xargs=--gpu-bind=none`` in your Slurm script to avoid binding issues. + +**Note that we use `torchrun-hpc` instead of `torchrun`**; the run command may vary based on your environment. + + ### Additional Notes The experiments use some additional libraries. 
Use the [ogb] option diff --git a/experiments/OGB/main.py b/experiments/OGB/main.py index 31a52f6..3ccd58b 100644 --- a/experiments/OGB/main.py +++ b/experiments/OGB/main.py @@ -81,6 +81,10 @@ def rank_cuda_device(self): device = torch.cuda.current_device() return device + def barrier(self): + # No-op for single process + pass + def _run_experiment( dataset, @@ -88,16 +92,18 @@ def _run_experiment( lr: float, epochs: int, log_prefix: str, + in_dim: int = 128, hidden_dims: int = 128, num_classes: int = 40, use_cache: bool = False, + dset_name: str = "arxiv", ): local_rank = comm.get_rank() % torch.cuda.device_count() print(f"Rank: {local_rank} Local Rank: {local_rank}") torch.cuda.set_device(local_rank) device = torch.cuda.current_device() model = GCN( - in_channels=128, hidden_dims=hidden_dims, num_classes=num_classes, comm=comm + in_channels=in_dim, hidden_dims=hidden_dims, num_classes=num_classes, comm=comm ) rank = comm.get_rank() model = model.to(device) @@ -114,9 +120,9 @@ def _run_experiment( node_features, edge_indices, rank_mappings, labels = dataset[0] node_features = node_features.to(device).unsqueeze(0) - edge_indices = edge_indices.to(device)[:, :-1].unsqueeze(0) + edge_indices = edge_indices.to(device).unsqueeze(0) labels = labels.to(device).unsqueeze(0) - rank_mappings = rank_mappings[:, :-1] + rank_mappings = rank_mappings if rank == 0: print("*" * 80) @@ -125,7 +131,8 @@ def _run_experiment( print(f"Rank: {rank} Mapping: {rank_mappings.shape}") print(f"Rank: {rank} Node Features: {node_features.shape}") print(f"Rank: {rank} Edge Indices: {edge_indices.shape}") - dist.barrier() + + comm.barrier() criterion = torch.nn.CrossEntropyLoss() train_mask = dataset.graph_obj.get_local_mask("train", rank) @@ -150,6 +157,16 @@ def _run_experiment( # This says where the edges are located edge_placement = rank_mappings[0] + + cache_prefix = f"cache/{dset_name}" + scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt" + gather_cache_file = 
f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt" + + if os.path.exists(gather_cache_file): + gather_cache = torch.load(gather_cache_file, weights_only=False) + + if os.path.exists(scatter_cache_file): + scatter_cache = torch.load(scatter_cache_file, weights_only=False) # These say where the source and destination nodes are located edge_src_placement = rank_mappings[ @@ -169,6 +186,9 @@ def _run_experiment( rank, world_size, ) + with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: + torch.save(gather_cache, f) + if scatter_cache is None: nodes_per_rank = dataset.graph_obj.get_nodes_per_rank() @@ -180,6 +200,8 @@ def _run_experiment( rank, world_size, ) + with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: + torch.save(scatter_cache, f) # Sanity checks for the cache for key, value in gather_cache.gather_send_local_placement.items(): @@ -208,16 +230,16 @@ def _run_experiment( end_time = perf_counter() print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s") - if rank == 0: - with open(f"{log_prefix}_gather_cache_{world_size}.pt", "wb") as f: - torch.save(gather_cache, f) - with open(f"{log_prefix}_scatter_cache_{world_size}.pt", "wb") as f: - torch.save(scatter_cache, f) - print(f"Rank: {rank} Cache Generated") + + #with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f: + # torch.save(gather_cache, f) + #with open(f"{log_prefix}_scatter_cache_{world_size}_{rank}.pt", "wb") as f: + # torch.save(scatter_cache, f) + #print(f"Rank: {rank} Cache Generated") training_times = [] for i in range(epochs): - dist.barrier() + comm.barrier() torch.cuda.synchronize() start_time = torch.cuda.Event(enable_timing=True) end_time = torch.cuda.Event(enable_timing=True) @@ -234,7 +256,7 @@ def _run_experiment( dist_print_ephemeral(f"Epoch {i} \t Loss: {loss.item()}", rank) optimizer.step() - dist.barrier() + comm.barrier() end_time.record(stream) torch.cuda.synchronize() 
training_times.append(start_time.elapsed_time(end_time)) @@ -313,7 +335,7 @@ def main( use_cache: bool = False, ): _communicator = backend.lower() - + dset_name = dataset assert _communicator.lower() in [ "single", "nccl", @@ -321,6 +343,8 @@ def main( "mpi", ], "Invalid backend" + in_dims = {"arxiv": 128, "products": 100} + assert dataset in ["arxiv", "products"], "Invalid dataset" node_rank_placement = None @@ -366,6 +390,8 @@ def main( log_prefix, use_cache=use_cache, num_classes=num_classes, + dset_name=dset_name, + in_dim=in_dims[dset_name] ) training_trajectores[i] = training_traj validation_trajectores[i] = val_traj diff --git a/experiments/OGB/utils.py b/experiments/OGB/utils.py index 377a5e7..fef5d8f 100644 --- a/experiments/OGB/utils.py +++ b/experiments/OGB/utils.py @@ -56,4 +56,7 @@ def safe_create_dir(directory, rank): def calculate_accuracy(pred, labels): pred = pred.argmax(dim=1) correct = pred.eq(labels).sum().item() - return correct / len(labels) * 100 + if len(labels) > 0: + return correct / len(labels) * 100 + else: + return 0.0