Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Single-node Multi-GPU training throws CUDA failure: an illegal memory access was encountered. #61

Closed
Elnifio opened this issue Jul 20, 2023 · 3 comments

Comments

@Elnifio
Copy link

Elnifio commented Jul 20, 2023

🐛 Describe the bug

Hello, we're trying to modify examples/igbh/train_rgnn.py so that it supports single-node multi-GPU training. However, after following the OGBN single-node multi-GPU training example, we get `CUDA failure /workspace/graphlearn/graphlearn_torch/csrc/cuda/unified_tensor.cu:351: 'an illegal memory access was encountered'` errors when loading the first batch of data. Here is a minimal reproduction of the issue (we removed the validation and test datasets for simplicity):

import argparse
import torch
import warnings
import os

import graphlearn_torch as glt

from dataset import IGBHeteroDataset
from rgnn import RGNN

# Fix the global torch RNG so runs are reproducible. torch.multiprocessing.spawn
# uses the 'spawn' start method, which re-imports this module in every child,
# so these statements also execute in each training process.
_SEED = 42
torch.manual_seed(_SEED)
# Silence all Python warnings (note: this also hides potentially useful ones).
warnings.filterwarnings(action="ignore")
    

def _shard_train_idx(train_idx, world_size, rank):
    """Return this rank's equal-length slice of the training node ids.

    Each of the ``world_size`` ranks gets exactly
    ``train_idx.size(0) // world_size`` consecutive ids. The trailing remainder
    (fewer than ``world_size`` ids) is dropped, as in the upstream multi-GPU
    example. Equal-length shards give every rank the same number of batches.
    DistributedDataParallel relies on this: each backward pass all-reduces
    gradients across all ranks, so a rank with an extra batch could hang.

    Raises:
        ValueError: if ``train_idx`` has fewer entries than there are ranks.
    """
    per_rank = train_idx.size(0) // world_size
    if per_rank == 0:
        raise ValueError(
            f"train_idx has {train_idx.size(0)} entries, fewer than the "
            f"{world_size} training processes")
    start = rank * per_rank
    return train_idx[start:start + per_rank]


def run(proc_id, devices, glt_dataset, train_idx, etypes, node_features, args):
    """Train one DDP replica of the RGNN on a single GPU.

    Launched once per GPU by ``torch.multiprocessing.spawn``, which passes the
    process index as ``proc_id``.

    Args:
        proc_id: Rank of this process, in ``[0, len(devices))``.
        devices: Flat list of CUDA device ids; rank ``proc_id`` runs on
            ``devices[proc_id]``.
        glt_dataset: Shared ``glt.data.Dataset`` holding graph, features, labels.
        train_idx: All training ``paper`` node ids; each rank trains on its
            own equal-length shard (see ``_shard_train_idx``).
        etypes: Edge types passed to ``RGNN``.
        node_features: Input feature width passed to ``RGNN``.
        args: Parsed command-line arguments.
    """
    # examples/multi_gpu/train_sage_ogbn_papers100m.py line 35-39
    os.environ['MASTER_ADDR'] = "127.0.0.1"
    os.environ['MASTER_PORT'] = "12365"
    world_size = len(devices)
    torch.distributed.init_process_group(
        "nccl", rank=proc_id, world_size=world_size
    )
    # Map the rank to the GPU id the user requested (e.g. --gpu_devices 2,3)
    # instead of assuming rank == CUDA ordinal, so the loader/model run on the
    # same GPUs the feature DeviceGroups were built for.
    device = torch.device('cuda', devices[proc_id])
    torch.cuda.set_device(device)

    # examples/multi_gpu/train_sage_ogbn_papers100m.py line 41: splitting train_idx according to devices
    train_idx = _shard_train_idx(train_idx, world_size, proc_id)

    train_loader = glt.loader.NeighborLoader(
        glt_dataset,
        [int(fanout) for fanout in args.fan_out.split(",")],
        ("paper", train_idx),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        device=device
    )

    model = RGNN(
        etypes,
        node_features,
        args.hidden_channels,
        args.num_classes,
        num_layers=args.num_layers,
        dropout=0.2,
        model=args.model,
        heads=args.num_heads,
        node_type='paper').to(device)

    # examples/multi_gpu/train_sage_ogbn_papers100m.py line 55: torch DDP
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device.index], find_unused_parameters=True)

    loss_fcn = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.epochs):
        model.train()
        for batch in train_loader:
            # Only the first `batch_size` 'paper' rows are seed nodes; the rest
            # are sampled neighbours, so loss is computed on the seeds only.
            batch_size = batch['paper'].batch_size
            out = model(batch.x_dict, batch.edge_index_dict)[:batch_size]
            y = batch['paper'].y[:batch_size]
            loss = loss_fcn(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.cuda.synchronize()
        torch.distributed.barrier()

    # Release NCCL resources held by this rank.
    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    # Parse CLI options, build the GLT dataset once in the parent process,
    # then spawn one training process per listed GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default="/data",
            help='path containing the datasets')
    parser.add_argument('--dataset_size', type=str, default='tiny',
            choices=['tiny', 'small', 'medium', 'large', 'full'],
            help='size of the datasets')
    parser.add_argument('--num_classes', type=int, default=19,
            choices=[19, 2983], help='number of classes')
    parser.add_argument('--in_memory', type=int, default=0,
            choices=[0, 1], help='0:read only mmap_mode=r, 1:load into memory')
    # Model
    parser.add_argument('--model', type=str, default='rgat',
                                            choices=['rgat', 'rsage'])
    # Model parameters
    parser.add_argument('--fan_out', type=str, default='10,10')
    parser.add_argument('--batch_size', type=int, default=5120)
    parser.add_argument('--hidden_channels', type=int, default=128)
    # float, not int: with type=int any fractional learning rate passed on
    # the command line (e.g. 0.001) was rejected by argparse.
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--log_every', type=int, default=5)
    parser.add_argument("--edge_dir", type=str, default='in')

    parser.add_argument("--gpu_devices", type=str, default="0,1")
    args = parser.parse_args()

    # in specifying GPU groups with NVLinks,
    # we separate each GPU id by comma, and separate the GPU groups by semicolon
    gpu_groups = [[int(idx) for idx in group.split(",")]
                  for group in args.gpu_devices.split(";")]
    # Flatten once; one training process is spawned per entry, in this order.
    devices = [gpu for group in gpu_groups for gpu in group]
    num_gpus = len(devices)

    igbh_dataset = IGBHeteroDataset(
        args.path,
        args.dataset_size,
        args.in_memory,
        args.num_classes==2983)
    # init graphlearn_torch Dataset.
    glt_dataset = glt.data.Dataset(edge_dir=args.edge_dir)

    glt_dataset.init_graph(
        edge_index=igbh_dataset.edge_dict,
        graph_mode='ZERO_COPY'
        # examples/multi_gpu/train_sage_ogbn_papers100m.py line 99: graph_mode = zero copy
    )

    glt_dataset.init_node_features(
        node_feature_data=igbh_dataset.feat_dict,
        # examples/multi_gpu/train_sage_ogbn_papers100m.py line 102, default with_gpu is True
        with_gpu=True,
        # examples/multi_gpu/train_sage_ogbn_papers100m.py line 105
        split_ratio=0.15 * min(num_gpus, 4),
        # examples/multi_gpu/train_sage_ogbn_papers100m.py line 106, create DeviceGroups
        device_group_list=[
            glt.data.DeviceGroup(idx, group)
            for idx, group in enumerate(gpu_groups)
        ]
    )

    # no change to initializing node labels
    glt_dataset.init_node_labels(node_label_data={'paper': igbh_dataset.label})

    etypes = igbh_dataset.etypes
    node_features = igbh_dataset.feat_dict['paper'].shape[1]

    # examples/multi_gpu/train_sage_ogbn_papers100m.py line 111: train_idx.share_memory_()
    train_idx = igbh_dataset.train_idx
    train_idx.share_memory_()

    torch.multiprocessing.spawn(
        run,
        args=(devices, glt_dataset, train_idx, etypes, node_features, args),
        nprocs=num_gpus)

After placing the above code under examples/igbh and running it with the command `python3 min_rep.py --path /data --dataset_size small --num_classes 2983 --epochs 3 --log_every 1`, it outputs the following error messages:

CUDA failure /workspace/graphlearn/graphlearn_torch/csrc/cuda/unified_tensor.cu:351: 'an illegal memory access was encountered'
CUDA failure /workspace/graphlearn/graphlearn_torch/csrc/cuda/unified_tensor.cu:351: 'an illegal memory access was encountered'
Traceback (most recent call last):
  File "min_rep.py", line 158, in <module>
    torch.multiprocessing.spawn(run,
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 149, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with exit code 1

Could you look into this issue and share some insights on how to fix it? We have been using the default dataset.py and rgnn.py provided under examples/igbh to run this code.

Environment

  • GLT version: v0.2.0, built from source based on the current main branch
  • PyG version: 2.3.1
  • PyTorch version: 1.14.0
  • OS: Linux (Docker image)
  • Python version: 3.8.10
  • CUDA/cuDNN version: 11.8
  • Any other relevant information
@Elnifio
Copy link
Author

Elnifio commented Jul 20, 2023

Additionally, we tried using with_gpu=False for glt_dataset.init_node_features, and it throws an IndexError instead of the above CUDA failure when loading the first batch from the train dataloader. To reproduce this issue, we use the same bash command as above (with 2 GPUs by default) and change the glt_dataset.init_node_features(...) call in the above min rep code to:

    # Variant of the script's init_node_features call with with_gpu=False.
    # split_ratio and device_group_list are commented out; presumably they only
    # matter when features are cached on GPUs (with_gpu=True) — confirm.
    glt_dataset.init_node_features(
        node_feature_data=igbh_dataset.feat_dict,
        with_gpu=False,
        # split_ratio=0.15 * min(num_gpus, 4),
        # device_group_list=[
        #     glt.data.DeviceGroup(idx, group) 
        #     for idx, group in enumerate(gpu_groups)
        # ]
    )

And the following error is thrown:

Traceback (most recent call last):
  File "min_rep.py", line 158, in <module>
    torch.multiprocessing.spawn(run,
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/repository/min_rep.py", line 75, in run
    for batch in train_loader:
  File "/usr/local/lib/python3.8/dist-packages/graphlearn_torch/loader/neighbor_loader.py", line 104, in __next__
    result = self._collate_fn(out)
  File "/usr/local/lib/python3.8/dist-packages/graphlearn_torch/loader/node_loader.py", line 99, in _collate_fn
    x_dict = {ntype : self.data.get_node_feature(ntype)[ids] for ntype, ids in sampler_out.node.items()}
  File "/usr/local/lib/python3.8/dist-packages/graphlearn_torch/loader/node_loader.py", line 99, in <dictcomp>
    x_dict = {ntype : self.data.get_node_feature(ntype)[ids] for ntype, ids in sampler_out.node.items()}
  File "/usr/local/lib/python3.8/dist-packages/graphlearn_torch/data/feature.py", line 145, in __getitem__
    return self.cpu_get(ids)
  File "/usr/local/lib/python3.8/dist-packages/graphlearn_torch/data/feature.py", line 163, in cpu_get
    return self.feature_tensor[ids]
IndexError: index 1004440 is out of bounds for dimension 0 with size 1000000

We have checked the train indices and they are all below 1000000 (min 0, max 599999), so we're not sure where index 1004440 is coming from.

@husimplicity
Copy link
Collaborator

I think this has been solved by #62. Could you try it again?

@LiSu
Copy link
Collaborator

LiSu commented Jul 25, 2023

We have fixed the bug in #62 and added a single-node multi-GPU example.

@LiSu LiSu closed this as completed Jul 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants