add warp layer #1452

Closed · kate-sann5100 opened this issue Jan 15, 2021 · 20 comments · Fixed by #1463 or #1470


@kate-sann5100 (Collaborator)

Is your feature request related to a problem? Please describe.
The current repository does not support warping based on a dense displacement field (DDF), which is indispensable for registration.

Describe the solution you'd like
Add a Warp layer.

Describe alternatives you've considered
N/A

Additional context
DeepReg implementation of Warp layer

@tvercaut (Member)

Note that this is related to #789, #189 and #95

PyTorch's grid_sample function provides a GPU implementation of linear interpolation for image resampling.

There is already an initial implementation of higher-order resampling here, but it needs more testing and use-case examples:
https://github.com/Project-MONAI/MONAI/tree/master/monai/csrc/resample
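For illustration, here is a minimal, self-contained sketch of what such a DDF-based warp boils down to with vanilla PyTorch (the tensor names and the one-pixel shift are purely illustrative, not taken from any existing MONAI or DeepReg code): a pixel-space DDF is converted to grid_sample's normalized [-1, 1] coordinates and added to an identity grid.

import torch
import torch.nn.functional as F

n, c, h, w = 1, 1, 8, 8
image = torch.rand(n, c, h, w)
ddf = torch.zeros(n, 2, h, w)  # displacement in pixels, channels ordered (dx, dy)
ddf[:, 0] = 1.0                # shift every pixel by +1 along x (width)

# Identity sampling grid in grid_sample's normalized [-1, 1] convention,
# shape (N, H, W, 2), last dimension ordered (x, y).
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
identity = F.affine_grid(theta, size=(n, c, h, w), align_corners=True)

# Convert the pixel-space DDF to normalized units and add it to the identity grid.
scale = torch.tensor([2.0 / (w - 1), 2.0 / (h - 1)])
grid = identity + ddf.permute(0, 2, 3, 1) * scale

warped = F.grid_sample(image, grid, mode="bilinear", align_corners=True)
print(warped.shape)  # torch.Size([1, 1, 8, 8])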

@kate-sann5100 (Collaborator, Author)

> Note that this is related to #789, #189 and #95
>
> PyTorch's grid_sample function provides a GPU implementation of linear interpolation for image resampling.
>
> There is already an initial implementation of higher-order resampling here, but it needs more testing and use-case examples:
> https://github.com/Project-MONAI/MONAI/tree/master/monai/csrc/resample

Thank you for the useful information. I will try to figure out the best way to integrate both the PyTorch and MONAI implementations into the layer.

@kate-sann5100 (Collaborator, Author) commented Jan 17, 2021

It would be super helpful if anyone could answer the following questions about the MONAI implementation of resampling:

  • Is it true that in grid_pull(input, grid), grid[n, d, h, w] specifies the (delta_z, delta_y, delta_x) displacements from the sample location to the target location?
  • What is the expected range of grid? Is it defined in pixel space, as in create_grid?

Thanks in advance. @tvercaut

@tvercaut (Member)

@brudfors is probably the best person to advise on this implementation, as it stemmed from https://github.com/balbasty/nitorch

@brudfors (Contributor)

Hi @kate-sann5100

I am happy to try to answer your questions.

The grid is defined in the target/output space and should specify the coordinates to sample in the input image. The grid coordinates should be defined in the voxel space of the input image. Here is some sample code that may, or may not, shed some more light on how it's used:

https://github.com/balbasty/nitorch/blob/master/demo/demo_spatial.ipynb
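To make the convention concrete, here is a small sketch (the grid_pull import path and the last-dimension ordering, which follows nitorch's identity_grid, are assumptions): the grid holds absolute voxel coordinates of the input image at every output location, so an identity grid plus a displacement field gives the locations to pull from.

import torch
from monai.networks.layers import grid_pull  # import path assumed

h, w = 4, 4
image = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)

# Identity voxel coordinates of the output space, shape (1, H, W, 2).
coords = torch.stack(torch.meshgrid(torch.arange(float(h)), torch.arange(float(w))), dim=-1)[None]

ddf = torch.zeros_like(coords)  # per-voxel displacement, in voxels
sample_grid = coords + ddf      # absolute input-image coordinates to sample

# Not normalized to [-1, 1] and not a displacement field: grid_pull reads these
# values directly as voxel positions in `image`.
out = grid_pull(image, sample_grid, interpolation=1, bound="zero", extrapolate=True)
print(out.reshape(h, w))  # with a zero DDF, linear interpolation reproduces the input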

@kate-sann5100 (Collaborator, Author)

@brudfors Thank you for the information, that completely answers my questions.

kate-sann5100 mentioned this issue Jan 18, 2021
@tvercaut (Member)

By the way, @brudfors, is there a computational time comparison somewhere for PyTorch's grid_sample vs MONAI/nitorch's grid_pull in, say, 3D linear interpolation mode?

@brudfors (Contributor)

Hi @tvercaut,

The Colab notebook below shows the speed-up of GPU resampling compared to CPU:

https://colab.research.google.com/drive/1qICvEDn-p8RnmaG-9v0sZu__6OQafM4A?usp=sharing

It is quite substantial.

@tvercaut (Member)

Thanks. I was more interested in GPU grid_sample (vanilla PyTorch) vs GPU grid_pull (nitorch-based).

@brudfors (Contributor) commented Jan 19, 2021

Oh yes, sorry, I added that comparison to the notebook as well: vanilla PyTorch and nitorch seem to perform similarly, but there is probably a more thorough validation that could be done.
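For anyone who wants to reproduce such a comparison locally, here is a rough timing sketch. It assumes a CUDA device, a MONAI build with the compiled extensions (BUILD_MONAI=1), and the grid_pull import path shown, and it only compares runtime: the two functions use different grid conventions, so the outputs are not numerically compared.

import torch
import torch.nn.functional as F
from monai.networks.layers import grid_pull  # import path assumed; needs compiled extensions

def time_fn(fn, iters=20):
    # CUDA events measure wall-clock time on the GPU stream.
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

n, c, d, h, w = 1, 1, 128, 128, 128
image = torch.rand(n, c, d, h, w, device="cuda")

# grid_sample expects normalized [-1, 1] coordinates, shape (N, D, H, W, 3).
grid_norm = torch.rand(n, d, h, w, 3, device="cuda") * 2 - 1
# grid_pull expects voxel coordinates of the input image, same shape.
grid_vox = torch.rand(n, d, h, w, 3, device="cuda") * (d - 1)

print("grid_sample:", time_fn(lambda: F.grid_sample(image, grid_norm, mode="bilinear", align_corners=True)), "ms")
print("grid_pull:  ", time_fn(lambda: grid_pull(image, grid_vox, interpolation=1, bound="zero", extrapolate=True)), "ms")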

@tvercaut (Member)

I am reopening this issue following the comment above and the note from @wyli here:
#1463 (review)

I would suggest only using grid_pull for interpolation orders strictly higher than linear, as these are not supported by vanilla grid_sample. For linear and nearest, unless I missed it, there is no benefit in using grid_pull over grid_sample, especially in view of the requirement for compilation.

@kate-sann5100: Can you adapt your warp layer accordingly?
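A sketch of that dispatch (this is not necessarily how #1463 implements it; the USE_COMPILED and grid_pull import paths are assumptions, and the two backends use different grid conventions, hence the two grid arguments):

import torch.nn.functional as F
from monai.config import USE_COMPILED  # import path assumed

def warp(image, norm_grid, vox_grid, order=1):
    # norm_grid: grid_sample's normalized [-1, 1] convention.
    # vox_grid: grid_pull's input-voxel-coordinate convention.
    if order in (0, 1):
        # Nearest and linear are covered by vanilla grid_sample; no compilation needed.
        mode = "nearest" if order == 0 else "bilinear"
        return F.grid_sample(image, norm_grid, mode=mode, align_corners=True)
    if not USE_COMPILED:
        raise RuntimeError("interpolation order > 1 requires MONAI built with BUILD_MONAI=1")
    from monai.networks.layers import grid_pull  # import path assumed
    return grid_pull(image, vox_grid, interpolation=order, bound="zero", extrapolate=True)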

tvercaut reopened this Jan 19, 2021
@brudfors (Contributor)

nitorch does differentiable splatting; I am not sure if PyTorch does?

@kate-sann5100 (Collaborator, Author)

@tvercaut That sounds promising. I will implement that. Thank you for the advice.

@kate-sann5100 (Collaborator, Author)

@brudfors I think there is no splatting function in PyTorch. But in the Warp layer the output should have the same shape as the input image, so splatting is not required.

@kate-sann5100 (Collaborator, Author)

May I ask when the USE_COMPILED flag is True? Both when I run the code on my local machine (CPU only) and in the quick-py3 & min-dep-py3 checks, the USE_COMPILED flag is False. @tvercaut @brudfors

@brudfors (Contributor)

You need to ensure that whatever terminal process compiles your MONAI sees the environment variable USE_COMPILED set to true, so it depends on how you are compiling MONAI, I guess. I use PyCharm and its built-in terminal, so I just set the environment variable within my PyCharm MONAI project.

@wyli (Contributor) commented Jan 19, 2021

> May I ask when the USE_COMPILED flag is True? Both when I run the code on my local machine (CPU only) and in the quick-py3 & min-dep-py3 checks, the USE_COMPILED flag is False.

the flag is defined here:

USE_COMPILED = HAS_EXT and os.getenv("BUILD_MONAI", "0") == "1"

It will be True when the environment variable BUILD_MONAI=1 is set both when installing and when running MONAI. For example, with MONAI installed via BUILD_MONAI=1 python setup.py develop --user, then:

  • run test: BUILD_MONAI=1 python -m tests.test_resampler (will have USE_COMPILED True)
  • run test: BUILD_MONAI=0 python -m tests.test_resampler (will have USE_COMPILED False)
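A quick way to confirm the flag at runtime (the import path here is an assumption):

# e.g. BUILD_MONAI=1 python -c "from monai.config import USE_COMPILED; print(USE_COMPILED)"
from monai.config import USE_COMPILED  # import path assumed
print(USE_COMPILED)  # True only if the extensions were built and BUILD_MONAI=1 is set at runtime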

@kate-sann5100 (Collaborator, Author)

@brudfors
Sorry for bothering you again.
I tried to run the following code, where identity_grid is copied from nitorch:

import torch
from monai.networks.layers import grid_pull  # import path assumed

if __name__ == '__main__':
    def identity_grid(shape, dtype=None, device=None):
        """Returns an identity deformation field.
        Parameters
        ----------
        shape : (dim,) sequence of int
            Spatial dimension of the field.
        dtype : torch.dtype, default=`get_default_dtype()`
            Data type.
        device torch.device, optional
            Device.
        Returns
        -------
        grid : (*shape, dim) tensor
            Transformation field
        """
        mesh1d = [torch.arange(float(s), dtype=dtype, device=device)
                  for s in shape]
        grid = torch.meshgrid(*mesh1d)
        grid = torch.stack(grid, dim=-1)
        return grid

    image = torch.tensor([1, 2, 3, 4], dtype=torch.float).reshape(1, 1, 1, 4)
    grid = identity_grid((1, 4))[None, ...]
    out = grid_pull(
        image,
        grid,
        interpolation=2,
        bound="zero",
        extrapolate=True)
    print(out)

I was expecting to get out == image but got out = tensor([[[[0.7500, 1.5000, 2.2500, 2.5312]]]]). Am I doing something wrong, or is tensor([[[[0.7500, 1.5000, 2.2500, 2.5312]]]]) the correct output?

@brudfors (Contributor)

Hi @kate-sann5100

Sorry for the slow reply. For interpolation orders greater than 1, a pre-filtering step is needed to determine the B-spline coefficients. This pre-filtering is not yet implemented in nitorch. For certain applications, these coefficients are not needed (e.g., when the resampling is part of some forward model, and the coefficients are implicitly found when inverting the model). So depending on your particular use case, interpolation > 1 might not work.
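To see the effect of that pre-filtering step independently of MONAI/nitorch, here is a small scipy-based sketch (scipy is used purely for illustration and is not part of the code paths discussed here): without pre-filtering, the data are treated as if they were already B-spline coefficients, so cubic interpolation at the original grid points does not reproduce the input.

import numpy as np
from scipy.ndimage import map_coordinates

data = np.array([0.0, 1.0, 0.0, 1.0])
coords = [np.arange(4.0)]  # identity sampling locations

# prefilter=False: `data` is interpreted as B-spline coefficients, so the
# cubic interpolant smooths the values and does not pass through the samples.
print(map_coordinates(data, coords, order=3, prefilter=False))

# prefilter=True: the coefficients are computed first, and the interpolant
# then reproduces the original samples (up to boundary handling).
print(map_coordinates(data, coords, order=3, prefilter=True))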

@tvercaut (Member)

For the record, in case we revisit higher-order interpolation: as the nitorch-based implementation only partially addresses the need, we have two options.

I think it makes sense to keep this issue closed and follow up on higher-order interpolation modes in #789.
