Cannot run V-Net on medical decathlon data #7852

Open
linnabraham opened this issue Jun 15, 2024 · 6 comments

@linnabraham

Describe the bug
PyTorch raises a size-mismatch error when running V-Net on Medical Decathlon data.

To Reproduce

import monai
from monai.apps import DecathlonDataset
from monai.transforms import LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord, Compose
from monai.networks.nets import VNet
from monai.losses.dice import DiceLoss
import torch

def train_one_epoch(train_loader, loss_fn, optimizer, epoch):
    running_loss = 0.
    example_ct = 0

    for batch_idx, dict_item in enumerate(train_loader):
        images = dict_item['image']
        labels = dict_item['label']
        print("Shape of images", images.shape)
        print("Shape of labels", labels.shape)
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        example_ct += images.size(0)
        metrics = {
            "train/train_loss": loss.item(),
            "train/epoch": epoch,
            "train/example_ct": example_ct,
        }
        print(metrics)
    return running_loss/example_ct

def train_loop(train_loader, val_loader):
    loss_fn = DiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_vloss = 1_000_000.
    for epoch in range(3):
        print(f"Epoch:{epoch+1}")
        model.train()
        avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
        print("train loss", avg_train_loss)

if __name__=="__main__":

    transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    train_data = DecathlonDataset(
        root_dir="./", task="Task04_Hippocampus", transform=transform, section="validation", seed=12345, download=False
    )
    model = VNet(spatial_dims=3, in_channels=1, out_channels=1, act='elu')
    train_loader = monai.data.DataLoader(
        train_data, batch_size=1, num_workers=2, persistent_workers=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")

    model.to(device)

    train_loop(train_loader, val_loader=None)

Expected behavior
Training runs to completion without errors.

Screenshots

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 5 for tensor number 1 in the list.

Complete Traceback

Traceback (most recent call last):
  File "/home/linn/vnet/train.py", line 65, in <module>
    train_loop(train_loader, val_loader=None)
  File "/home/linn/vnet/train.py", line 39, in train_loop
    avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linn/vnet/train.py", line 19, in train_one_epoch
    outputs = model(images)
              ^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 274, in forward
    x = self.up_tr256(out256, out128)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 165, in forward
    xcat = torch.cat((out, skipxdo), 1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/_tensor.py", line 1443, in __torch_function__
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
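
The pair of sizes in this error (4 and 5) is consistent with how an encoder-decoder with stride-2 stages treats sizes that are not multiples of 16. The following is a minimal sketch, assuming (this is not confirmed anywhere in the thread) that each of V-Net's four down-transitions halves a spatial dimension with floor division and each up-transition doubles it. `first_skip_mismatch` is a hypothetical helper, not part of MONAI, and the input sizes are made-up examples.

```python
def first_skip_mismatch(size, n_down=4):
    """Trace one spatial dimension through the encoder and back up.

    Returns (upsampled_size, skip_size) for the first decoder stage where the
    two sizes disagree, or None if every skip connection lines up.
    """
    # Encoder: each stride-2 stage is assumed to floor-divide the size by 2.
    sizes = [size]
    for _ in range(n_down):
        sizes.append(sizes[-1] // 2)

    # Decoder: each stage doubles the current size, then concatenates the
    # result with the encoder feature map saved one level higher.
    current = sizes[-1]
    for level in range(n_down - 1, -1, -1):
        current *= 2
        if current != sizes[level]:
            return current, sizes[level]
    return None


# 43 -> 21 -> 10 -> 5 -> 2 on the way down; 2 * 2 = 4 cannot be joined with 5.
print(first_skip_mismatch(43))
# 64 is a multiple of 16, so every level lines up.
print(first_skip_mismatch(64))
```

Under these assumptions, a dimension of 43 fails at the first decoder stage with the same sizes as in the RuntimeError, and any dimension that is not a multiple of 16 fails at some stage.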

Environment

Ensuring you use the relevant python executable, please paste the output of:

python -c "import monai; monai.config.print_debug_info()"
================================
Printing MONAI config...
================================
MONAI version: 1.3.1
Numpy version: 1.26.4
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda00c6bd290297f5e3514ea227c6be4d08b4
MONAI __file__: /data/<username>/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.4
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies


================================
Printing system config...
================================
`psutil` required for `print_system_info`

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 8902
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A100-PCIE-40GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 39.4
GPU 0 CUDA capability (maj.min): 8.0

@KumoLiu
Contributor

KumoLiu commented Jun 25, 2024

Hi @linnabraham, this looks like a shape mismatch issue. Did you check your input data shape before passing it to the model?
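
One way to act on this suggestion is to test whether each spatial dimension is a multiple of 16 and, if not, to work out how much padding would make it one. This is only a sketch: the factor 16 assumes four stride-2 stages in V-Net, `pad_to_multiple` is a hypothetical helper, and the example shape is made up rather than read from the dataset.

```python
def pad_to_multiple(spatial_shape, k=16):
    """Return the padded shape and symmetric (before, after) pads per dimension."""
    padded_shape = []
    pads = []
    for size in spatial_shape:
        target = -(-size // k) * k  # round up to the next multiple of k
        total = target - size
        # Split the padding as evenly as possible; the extra voxel goes after.
        pads.append((total // 2, total - total // 2))
        padded_shape.append(target)
    return tuple(padded_shape), pads


shape = (35, 51, 43)  # made-up spatial shape, not taken from the dataset
padded, pads = pad_to_multiple(shape)
print("padded shape:", padded)
print("pads per dimension:", pads)
```

MONAI also ships padding transforms such as `DivisiblePadd` that may fit better inside a `Compose` pipeline; check the documentation for the installed version before relying on a particular signature.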

@linnabraham
Author

@KumoLiu I checked just now, and it seems the Decathlon data shape is not compatible with V-Net. I did not expect that, since I had earlier used the same data with the TensorFlow implementation of V-Net (https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Segmentation/VNet).

However, I found an issue with the V-Net implementation. It seems that out_channels is hard-coded as 16, which implies that in_channels can only be 16 or 1. I have reopened a bug report that was closed without a proper fix: #4896

@KumoLiu
Contributor

KumoLiu commented Jul 8, 2024

> Which implied that the in_channels could only be 16 or 1. I have re-opened a bug report that was closed without the proper fix here #4896

The in_channels can be any value that divides 16 evenly (1, 2, 4, 8, or 16).

in_channels: number of input channels for the network. Defaults to 1.

The TensorFlow implementation also sets out_channels to 16:
https://github.com/NVIDIA/DeepLearningExamples/blob/729963dd47e7c8bd462ad10bfac7a7b0b604e6dd/TensorFlow/Segmentation/VNet/model/vnet.py#L34
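
For reference, here is a sketch of the in_channels constraint, on the assumption that the MONAI docstring asks for `16 % in_channels == 0` and that the input is tiled along the channel axis so it can be added to the 16-channel output of the first convolution. This is worth verifying against the installed `vnet.py`. `channel_repeat` is a hypothetical helper that only mirrors that arithmetic.

```python
def channel_repeat(in_channels, first_stage_channels=16):
    """How many times the input would be tiled to reach the first stage's channels."""
    if first_stage_channels % in_channels != 0:
        raise ValueError(
            f"{first_stage_channels} should be divisible by in_channels, "
            f"got in_channels={in_channels}."
        )
    return first_stage_channels // in_channels


# Channel counts that satisfy the assumed rule.
valid = [c for c in range(1, 17) if 16 % c == 0]
print(valid)
print(channel_repeat(1))  # a single-channel image would be tiled 16 times
```

Under this rule a value such as 64 is rejected, because 16 is not divisible by 64.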

@linnabraham
Author

Thanks for pointing out the TensorFlow code, but I am still confused. My input has shape (64, 128, 128). I edited the source code so that 16 is no longer hard-coded, but whatever I pass as out_channels (1, 16, 64, or 128), I get a shape mismatch error. What should I do?

@KumoLiu
Contributor

KumoLiu commented Jul 8, 2024

If your shape is (64, 128, 128), then your spatial_dims should be 2 since 64 is the channel dim.

@linnabraham
Author

Thanks for pointing that out. I set spatial_dims to 2. I could not use 16 as out_channels, so I tried 64 instead. Now I get this error:

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 64, 128, 128, 1]
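
The message says that a 5-dimensional tensor reached a 2D convolution. Below is a small sketch of the rank rule, assuming the usual PyTorch layout of (batch, channel, spatial...) for batched input. `check_input_rank` is a hypothetical helper, and the first shape is the one from the error message.

```python
def check_input_rank(shape, spatial_dims):
    """Return a description of the problem, or None if the rank is as expected."""
    expected = spatial_dims + 2  # batch + channel + spatial dimensions
    if len(shape) != expected:
        return (
            f"expected {expected} dims (batch, channel, {spatial_dims} spatial), "
            f"got {len(shape)}: {tuple(shape)}"
        )
    return None


print(check_input_rank([32, 64, 128, 128, 1], spatial_dims=2))  # reports a problem
print(check_input_rank([32, 64, 128, 128], spatial_dims=2))     # None
```

The trailing dimension of size 1 is the extra one here. Whether it should be dropped or moved to the channel position depends on how the data is loaded, which the thread does not show.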
