
[Bug] convert_points_from_homogeneous - NaN gradients in backward pass #367

Closed
poxyu opened this issue Dec 6, 2019 · 19 comments · Fixed by #369
Labels: bug 🐛 Something isn't working

@poxyu

poxyu commented Dec 6, 2019

I just experienced a NaN-gradient problem while doing a backward pass here:

torch.tensor(1.) / z_vec,

torch.where works absolutely fine in the forward pass, but if you have zero divisions you end up with NaN gradients for sure 💩

Here is a toy example:

import torch

eps = 1e-8

z_vec: torch.Tensor = torch.tensor([4., 6., 0., -3., 1e-9], requires_grad=True)

scale: torch.Tensor = torch.where(
    torch.abs(z_vec) > eps,
    torch.tensor(1.) / z_vec,
    torch.ones_like(z_vec)
)
scale.backward(torch.ones_like(scale))

And these are z_vec gradients:
tensor([-0.0625, -0.0278, nan, -0.1111, -0.0000])
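
For what it's worth, this looks like standard torch.where behaviour rather than anything kornia-specific: both branches are evaluated, so the discarded 1 / z_vec branch still produces inf at z == 0, and the chain rule then multiplies the zero mask by that infinite derivative, which gives NaN. A minimal stand-alone repro (plain PyTorch only):

import torch

# both branches of torch.where are evaluated, so the masked-out division
# still hits 1 / 0 = inf; its backward contribution becomes 0 * (-inf) = nan
z = torch.tensor([0.], requires_grad=True)
out = torch.where(z.abs() > 1e-8, 1. / z, torch.ones_like(z))
out.backward()
print(z.grad)  # tensor([nan])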

For now my little hack is:

...
    # we check for points at infinity
    z_vec: torch.Tensor = points[..., -1:]
    if z_vec.requires_grad:
        def z_vec_backward_hook(grad: torch.Tensor) -> torch.Tensor:
            grad[grad != grad] = 0.
            return grad
        z_vec.register_hook(z_vec_backward_hook)
...

But not sure if it's good enough.
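
In case it helps, this is how the hook behaves on the toy example above; returning a fresh tensor from the hook (instead of editing grad in place) is what the autograd docs suggest, so the sketch below does that:

import torch

z_vec = torch.tensor([4., 6., 0., -3., 1e-9], requires_grad=True)

def zero_nan_grads(grad: torch.Tensor) -> torch.Tensor:
    # return a new gradient with the nan entries zeroed out
    return torch.where(grad != grad, torch.zeros_like(grad), grad)

z_vec.register_hook(zero_nan_grads)

scale = torch.where(
    torch.abs(z_vec) > 1e-8,
    torch.tensor(1.) / z_vec,
    torch.ones_like(z_vec))
scale.backward(torch.ones_like(scale))
print(z_vec.grad)  # the nan entry is replaced by 0., the rest is unchanged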

@edgarriba
Member

@poxyu thanks for reporting, I will investigate it

@edgarriba
Member

with eps equal to 1e-5 it seems to work on my computer

@poxyu
Author

poxyu commented Dec 7, 2019

Even in this toy example?
In my case it happened during homography regression. I initialized my homography module with a "strong" transformation, and after several iterations the grads became NaN (inside kornia.HomographyWarper).
As I told you, there are no problems during the forward pass (I tested all the homographies manually); shit happens only in the backward pass, when one of the z_vec values becomes inf after the division.

@poxyu
Author

poxyu commented Dec 7, 2019

I will try to give you a real example instead of this toy one.

@poxyu
Author

poxyu commented Dec 7, 2019

@edgarriba this is just an example. After about 13-14 iterations dst_homo_src() has three NaNs in it:

...

# imports assumed by the snippet
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import kornia

class MyHomography(nn.Module):

    def __init__(self, init_homo: torch.Tensor) -> None:
        super().__init__()
        self.homo = nn.Parameter(init_homo.clone().detach())

    def forward(self) -> torch.Tensor:
        return torch.unsqueeze(self.homo, dim=0)


num_iterations = 400
device = torch.device("cuda")

img_src: np.ndarray = cv2.imread("img_src.png", 0)
img_dst: np.ndarray = cv2.imread("img_dst.png", 0)
img_src_t: torch.Tensor = (kornia.image_to_tensor(img_src).float() / 255.).to(device)
img_dst_t: torch.Tensor = (kornia.image_to_tensor(img_dst).float() / 255.).to(device)

init_homo: torch.Tensor = torch.from_numpy(
    np.array([
        [0.0415, 1.2731, -1.1731],
        [-0.9094, 0.5072, 0.4272],
        [0.0762, 1.3981, 1.0646]
    ])
).float()

height, width = img_dst_t.shape[-2:]
warper = kornia.HomographyWarper(height, width)
dst_homo_src = MyHomography(init_homo=init_homo).to(device)

learning_rate = 1e-3
optimizer = optim.Adam(dst_homo_src.parameters(), lr=learning_rate)

for iter_idx in range(num_iterations):
    # warp the reference image to the destination with the current homography
    img_src_to_dst = warper(img_src_t, dst_homo_src())

    # compute the photometric loss
    loss = F.l1_loss(img_src_to_dst, img_dst_t)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

...
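
(A hypothetical watchdog, not part of the original run, that one could append inside the loop above to catch the exact step where the homography parameters go NaN; names match the snippet above:)

    # hypothetical check inside the training loop: stop at the first
    # iteration where the homography parameters contain NaNs
    if torch.isnan(dst_homo_src.homo).any():
        print("homography went NaN at iteration", iter_idx)
        break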

img_src.png: [attached image]

img_dst.png: [attached image]

@poxyu
Author

poxyu commented Dec 7, 2019

tell me if you need more information

@edgarriba
Member

I've been trying a couple of things so far. First, I created this small unit test:

def test_gradcheck_zeros(self, device):
    points_h = torch.tensor([[1., 2., 3., 4., 5.],
                             [4., 6., 0., -3., 1e-9]]).t().to(device)

    # evaluate function gradient
    points_h = tensor_to_gradcheck_var(points_h)  # to var
    assert gradcheck(kornia.convert_points_from_homogeneous, (points_h,),
                     raise_exception=True)

with your fix it doesn't pass

what I've tried is the following:

scale: torch.Tensor = torch.where(
    torch.abs(z_vec) > eps,
    torch.tensor(1.) / z_vec.clamp(min=eps),  # instead of torch.tensor(1.) / z_vec
    torch.ones_like(z_vec))

plus reducing the epsilon to 1e-5, and it passes all the tests except those where the points have negative values (the clamp pushes them close to positive zero).

The other trick is to directly add an epsilon value to z_vec, which also breaks other tests related to geometry transforms.

Not really sure what the trade-off should be here, but it's an issue that has been around for a while and was discussed at an early stage of the library. Check this: https://github.com/tensorflow/graphics/issues/17
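
For the negative-values problem, one possible trade-off (just a sketch, not necessarily what #369 ends up doing) is to clamp the magnitude away from zero while keeping the sign, so negative z values are not pushed towards +eps:

import torch

eps = 1e-8
z_vec = torch.tensor([4., 6., 0., -3., 1e-9], requires_grad=True)

# hypothetical sign-aware clamp: push tiny values away from zero without
# flipping the sign of negative entries (a plain clamp(min=eps) would do that)
z_safe = torch.where(z_vec >= 0., z_vec.clamp(min=eps), z_vec.clamp(max=-eps))

scale = torch.where(
    torch.abs(z_vec) > eps,
    torch.tensor(1.) / z_safe,
    torch.ones_like(z_vec))
scale.backward(torch.ones_like(scale))
print(z_vec.grad)  # finite everywhere; the entries at/near zero get zero gradient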

@edgarriba
Member

another thing to consider is a possible bug in torch.where, since it should be able to bypass the gradient correctly based on the generated binary mask

@poxyu
Author

poxyu commented Dec 8, 2019

https://discuss.pytorch.org/t/gradients-of-torch-where/26835/2

my current workaround is:

scale: torch.Tensor = torch.where(
    torch.abs(z_vec) > eps,
    # torch.tensor(1.) / z_vec,
    torch.tensor(1.) / z_vec.masked_fill(z_vec == 0., eps),
    torch.ones_like(z_vec))
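
For reference, a quick sanity check of this version on the toy z_vec from the first comment: the z == 0 entry now gets a zero gradient instead of nan, because masked_fill cuts it out of the division.

eps = 1e-8
z_vec = torch.tensor([4., 6., 0., -3., 1e-9], requires_grad=True)
scale = torch.where(
    torch.abs(z_vec) > eps,
    torch.tensor(1.) / z_vec.masked_fill(z_vec == 0., eps),
    torch.ones_like(z_vec))
scale.backward(torch.ones_like(scale))
print(z_vec.grad)  # no nan anywhere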

thank you, @edgarriba

@edgarriba
Member

@poxyu great! Check #369; I had to increase the epsilon value, otherwise the numerical test didn't pass

@poxyu
Author

poxyu commented Dec 9, 2019

@edgarriba I love everything about your PR except this big epsilon 😂

@edgarriba
Member

well, it's the only way I found to pass the provided gradient-check test without breaking the other existing tests of the whole framework. I've also been playing a bit with the gradcheck tolerance, but it seems that eps 1e-5 is the magic number that makes gradcheck happy :D
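
For reference, the knobs I mean are the tolerance arguments of torch.autograd.gradcheck; the values below are only illustrative, with points_h as in the unit test above:

# note: eps here is the finite-difference step used by gradcheck,
# not the eps inside convert_points_from_homogeneous
ok = torch.autograd.gradcheck(
    kornia.convert_points_from_homogeneous, (points_h,),
    eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=False)
print(ok)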

@ducha-aiki
Member

ducha-aiki commented Dec 9, 2019

Guys, what about this?

eps=1e-8
mask: torch.Tensor = z_vec.abs() > eps
scale: torch.Tensor = torch.ones_like(z_vec).masked_scatter_(mask, 1.0 / z_vec[mask])

Although it is basically the same :)

@edgarriba
Member

@ducha-aiki the test I provided above still fails with your solution when the epsilon is smaller than 1e-5

@ducha-aiki
Member

Well, actually, the gradient is not correct here, as we are putting 1 instead of the correct huge number/inf.
So I am in favor of dropping your test and just using @poxyu's version, if it works in practice and doesn't generate NaNs.

@edgarriba
Member

sure, then I believe we should somehow include @poxyu's test (or a small version of it) in order to ensure the correctness of this function, since it can be a bit critical for other functions that rely on it.

@edgarriba
Member

@poxyu @ducha-aiki I've updated #369 with a real-case test (using random data). Check it out.

edgarriba added the bug 🐛 Something isn't working label Dec 9, 2019
edgarriba self-assigned this Dec 9, 2019
@poxyu
Author

poxyu commented Dec 11, 2019

A similar problem (zero division and NaN gradients) occurs in the backward pass here:

theta = torch.sqrt(theta2)

Super simple toy example:

with torch.autograd.detect_anomaly():
    r_vec = torch.tensor([[0., 0., 0.]], requires_grad=True)
    r_mat = kornia.angle_axis_to_rotation_matrix(r_vec)
    r_mat.backward(torch.ones_like(r_mat))

I guess it has to be a separate issue, doesn't it? 🙂

P.S. I found it several minutes ago and haven't fixed it myself yet.
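
If it helps, a minimal stand-alone version of what I assume goes wrong (taking theta2 to be the squared norm of the angle-axis vector):

import torch

# d sqrt(theta2) / d r_vec = r_vec / sqrt(theta2), which is 0 / 0 = nan at r_vec = 0
r_vec = torch.tensor([[0., 0., 0.]], requires_grad=True)
theta = torch.sqrt((r_vec * r_vec).sum(dim=1))
theta.backward(torch.ones_like(theta))
print(r_vec.grad)  # tensor([[nan, nan, nan]])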

@edgarriba
Member

Yes, please open a separate issue. I will close this one once we merge #369.
Regarding this new issue, I would give a try to this code from tf graphics,
https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py since at least it's cleaner
