
MetaTensor and DistributedDataParallel. bug (SyncBatchNormBackward is a view and is being modified inplace) #5283

Open
myron opened this issue Oct 7, 2022 · 3 comments
myron (Collaborator) commented Oct 7, 2022

When upgrading from MONAI 0.9.0 to 1.0.0, my 3D segmentation code fails when using DistributedDataParallel (multi-GPU), most likely due to the new MetaTensor returned by the transforms.

The error is:
RuntimeError: Output 0 of SyncBatchNormBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

The same issue was reported (but for 2D MIL classification) in #5081 and #5198.

I've traced it down to commit 63e36b6 (prior to it, the code works fine).

It seems the issue is that the dataloader now returns data as MetaTensor (and not torch.Tensor as before); e.g. here https://github.com/Project-MONAI/tutorials/blob/main/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py#L51 both data and target are MetaTensor.

If I convert explicitly (on GPU or CPU):

data = torch.Tensor(data)
target = torch.Tensor(target)

then the code runs fine, although a bit slower. It seems there is something wrong with MetaTensor.
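
For context, a minimal sketch of how such a conversion might sit in a typical DDP training step (the loop structure and the loader, model, loss_function, optimizer and device names are illustrative placeholders, not from the original code; MetaTensor.as_tensor() is another way to obtain a plain torch.Tensor):

from monai.data import MetaTensor

for batch_data in loader:
    data, target = batch_data["image"], batch_data["label"]
    # strip the MetaTensor subclass before the DDP/SyncBatchNorm forward pass
    if isinstance(data, MetaTensor):
        data = data.as_tensor()
    if isinstance(target, MetaTensor):
        target = target.as_tensor()
    data, target = data.to(device), target.to(device)

    optimizer.zero_grad()
    loss = loss_function(model(data), target)
    loss.backward()
    optimizer.step()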

@myron myron added the bug Something isn't working label Oct 7, 2022
@myron myron added this to the Auto3D Seg framework [P0 v1.0] milestone Oct 7, 2022
@myron myron assigned wyli and Nic-Ma Oct 7, 2022
@wyli wyli removed this from the Auto3D Seg framework [P0 v1.0] milestone Oct 7, 2022
wyli (Contributor) commented Oct 7, 2022

Thanks for reporting. I'm able to reproduce it with torchrun --nnodes=1 --nproc_per_node=2 test.py using this test.py:

import torch.distributed as dist

import torch
from torchvision import models
from monai.data import MetaTensor

# surface the exact operation that triggers the view/inplace error during backward
torch.autograd.set_detect_anomaly(True)

def run():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"rank {rank}")
    device = rank

    mod = models.resnet50(pretrained=True).to(device)
    optim = torch.optim.Adam(mod.parameters(), lr=1e-3)
    z1 = MetaTensor(torch.zeros(1, 3, 128, 128)).to(device)  # MetaTensor input, as returned by MONAI 1.0 dataloaders

    mod = torch.nn.SyncBatchNorm.convert_sync_batchnorm(mod)  # replace BatchNorm layers with SyncBatchNorm
    mod = torch.nn.parallel.DistributedDataParallel(mod, device_ids=[rank], output_device=rank)

    out = mod(z1)
    print(out.shape)
    loss = (out**2).mean()

    optim.zero_grad()
    loss.backward()
    optim.step()

    print("Stepped.")

if __name__ == "__main__":
    run()
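
For comparison (not part of the original script; same setup assumed), swapping in a plain torch.Tensor input helps confirm that the MetaTensor subclass is the trigger:

# replace the forward pass above with either of these to compare behaviour
out = mod(z1.as_tensor())                              # strip the MetaTensor subclass
out = mod(torch.zeros(1, 3, 128, 128, device=device))  # plain tensor from the start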

I'll submit a PR to fix this.

wyli (Contributor) commented Oct 7, 2022

This looks like a PyTorch issue; I created an upstream bug report (pytorch/pytorch#86456).

@wyli wyli changed the title MetaTensor and DistributedDataParallel. bug MetaTensor and DistributedDataParallel. bug (SyncBatchNormBackward is a view and is being modified inplace) Oct 7, 2022
@vikashg vikashg closed this as completed Dec 19, 2023
KumoLiu (Contributor) commented Dec 20, 2023

Since the upstream bug has not yet been fixed, this ticket should be kept open.

@KumoLiu KumoLiu reopened this Dec 20, 2023