
This issue was moved to a discussion.



PatchInferer with AvgMerger and filter_fn leads to NaNs #7743

Closed
nicholas-greig opened this issue May 6, 2024 · 4 comments

Comments

@nicholas-greig

Describe the bug
On the current master branch, when PatchInferer is used with an AvgMerger (the default merger class) and a filter_fn, the counts are zero everywhere the filter_fn filters out a region. When AvgMerger.finalize() is then called, the self.values attribute of AvgMerger is divided in place by the self.counts tensor. Both tensors are initialised to zero, and filtered regions are never written to, so the division there is 0/0, which produces NaN. As a result, every region that the filter_fn filters out comes back as NaN in the output.
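A minimal torch-only sketch of the failure mode described above. The two buffers are hypothetical 1-D stand-ins for an average merger's values and counts (not MONAI's actual code); positions that never receive a patch stay at 0/0 and turn into NaN after the in-place division.

```python
import torch

# Hypothetical stand-ins for an average merger's buffers: both start at zero.
values = torch.zeros(4)
counts = torch.zeros(4)

# Only the first two positions receive a patch; the last two were "filtered".
values[:2] += torch.tensor([2.0, 4.0])
counts[:2] += 1

# In-place average, mirroring the values / counts step in finalize().
values.div_(counts)
print(values)  # tensor([2., 4., nan, nan])
```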

A quick in-place assignment to counts (setting the zero counts to 1, for example) would make all of these values zero after the in-place division. If the output is supposed to be real-valued/continuous, however, it might be better to overwrite these values in place with the smallest representable value (torch.finfo(self.values.dtype).min or similar). Patching the NaNs in the Inferer's outputs afterwards is not a good solution either: a network can produce NaNs because of exploding weights or overflow during training, and overwriting every NaN with zero would merely hide that problem.
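Both options from the paragraph above can be sketched as a single finalize step. The helper `finalize_with_fill` is hypothetical (it is not part of MONAI): it averages only where at least one patch was merged and fills the uncovered positions with either a chosen value or, by default, `torch.finfo(dtype).min`.

```python
import torch

def finalize_with_fill(values, counts, fill_value=None):
    """Hypothetical finalize step: average where counts > 0, fill elsewhere.

    fill_value defaults to the most negative finite value of values.dtype.
    """
    if fill_value is None:
        fill_value = torch.finfo(values.dtype).min
    covered = counts > 0
    # clamp_min(1) keeps the division finite; uncovered entries are replaced below.
    averaged = values / counts.clamp_min(1)
    return torch.where(covered, averaged, torch.full_like(values, fill_value))

values = torch.tensor([2.0, 4.0, 0.0, 0.0])
counts = torch.tensor([1.0, 2.0, 0.0, 0.0])
print(finalize_with_fill(values, counts, fill_value=0.0))  # tensor([2., 2., 0., 0.])
print(finalize_with_fill(values, counts))  # last two entries are torch.finfo(torch.float32).min
```

Because only positions with a zero count are filled, a NaN produced by the network inside a covered region still propagates to the output, which avoids the masking concern raised above.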

@KumoLiu
Contributor

KumoLiu commented May 6, 2024

Hi @nicholas-greig, could you please share a small piece of code with which I can reproduce the issue?

Thanks.

@nicholas-greig
Author

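from monai.inferers.splitter import SlidingWindowSplitter
from monai.inferers.inferer import PatchInferer
import matplotlib.pyplot as plt
import torch

H, W = 512, 512

def filter_fn(patch, location):
    # location is the (row, col) offset of the patch; drop every patch
    # whose column offset lies beyond the horizontal midpoint.
    if location[1] > H // 2:
        return False
    return True

# Non-overlapping 128x128 patches; filter_fn decides which patches are inferred.
splitter = SlidingWindowSplitter(
    (128, 128),
    overlap=0,
    offset=0,
    filter_fn=filter_fn,
)

# PatchInferer uses AvgMerger by default.
inferer = PatchInferer(splitter)

inputs = torch.randn((1, 1, H, W))
# Identity "network", so any NaN in the output comes from the merging step.
outputs = inferer(inputs=inputs, network=lambda x: x)

# Count the NaNs and show where they are.
print(torch.sum(torch.isnan(outputs[0])))
plt.imshow(torch.isnan(outputs[0]).squeeze())
plt.show()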

@nicholas-greig
Author

@KumoLiu bump

@KumoLiu
Contributor

KumoLiu commented Jul 2, 2024

Hi @nicholas-greig, sorry for the late response. After taking a look at your code, I think the problem is that you set filter_fn in the Splitter, where it is used to filter out patches. If you set it to None, it will work as you expect.

filter_fn: a callable to filter patches. It should accepts exactly two parameters (patch, location), and

Hope it helps, thanks.
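To illustrate why filter_fn=None avoids the NaNs, here is a torch-only simulation of the reproduction above. `simulate_patch_average` and `keep_left` are hypothetical names; the sketch assumes non-overlapping patches (overlap=0, offset=0), spatial sizes that are multiples of the patch size, and an identity network, and it only mimics the split/average/divide behaviour rather than reproducing MONAI's implementation.

```python
import torch

def simulate_patch_average(image, patch_size, filter_fn=None):
    """Hypothetical stand-in for splitter + identity network + average merger."""
    _, _, h, w = image.shape
    values = torch.zeros_like(image)
    counts = torch.zeros_like(image)
    for row in range(0, h, patch_size):
        for col in range(0, w, patch_size):
            window = (..., slice(row, row + patch_size), slice(col, col + patch_size))
            patch = image[window]
            if filter_fn is not None and not filter_fn(patch, (row, col)):
                continue  # filtered patch: never merged, so its counts stay 0
            values[window] += patch  # identity "network" output
            counts[window] += 1
    return values.div_(counts)  # 0 / 0 -> NaN wherever nothing was merged

def keep_left(patch, location):
    # Same rule as the reproduction: drop patches whose column offset exceeds 512 // 2.
    return location[1] <= 512 // 2

image = torch.randn(1, 1, 512, 512)
print(int(torch.isnan(simulate_patch_average(image, 128, keep_left)).sum()))  # 65536
print(int(torch.isnan(simulate_patch_average(image, 128)).sum()))  # 0
```

With the filter, only the rightmost column of patches (offset 384) is skipped, giving 512 x 128 = 65536 NaN pixels. With filter_fn=None every patch is merged at least once, so no count is zero, at the cost of running the network on every patch.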

@Project-MONAI Project-MONAI locked and limited conversation to collaborators Jul 2, 2024
@KumoLiu KumoLiu converted this issue into discussion #7898 Jul 2, 2024

