
output coords are negative floats #6

Closed
YuliangXiu opened this issue Jan 6, 2019 · 15 comments

YuliangXiu commented Jan 6, 2019

The network is defined by:

import torch.nn as nn
import torchvision.models as models
import dsntnn


class Net(nn.Module):

    def __init__(self, layers):
        super(Net, self).__init__()
        if layers == 18:
            model = models.resnet18(pretrained=True)
        elif layers == 34:
            model = models.resnet34(pretrained=True)
        else:
            raise ValueError('layers must be 18 or 34')
        # change the first layer to receive a five-channel image
        model.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=True)
        # change the last layer to output 32 coordinates
        # model.fc = nn.Linear(512, 32)
        # remove the final two layers (avgpool, fc)
        model = nn.Sequential(*(list(model.children())[:-2]))
        for param in model.parameters():
            param.requires_grad = True
        self.resnet = model

    def forward(self, x):
        pose_out = self.resnet(x)
        return pose_out

class CoordRegressionNetwork(nn.Module):
    def __init__(self, n_locations, layers):
        super(CoordRegressionNetwork, self).__init__()
        self.resnet = Net(layers)
        self.hm_conv = nn.Conv2d(512, n_locations, kernel_size=1, bias=False)

    def forward(self, images):
        # 1. Run the images through our Resnet
        resnet_out = self.resnet(images)
        # 2. Use a 1x1 conv to get one unnormalized heatmap per location
        unnormalized_heatmaps = self.hm_conv(resnet_out)
        # 3. Normalize the heatmaps
        heatmaps = dsntnn.flat_softmax(unnormalized_heatmaps)
        # 4. Calculate the coordinates
        coords = dsntnn.dsnt(heatmaps)

        return coords, heatmaps
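Since the thread revolves around the range of the predicted coordinates, here is a framework-free sketch of what the last two steps of forward compute for a single heatmap: a softmax over all pixels (the idea behind dsntnn.flat_softmax), followed by the expected (x, y) position under that distribution (the idea behind dsntnn.dsnt). The helper names are hypothetical, and the pixel-centre grid x_j = (2j + 1)/W - 1 is an assumption about dsntnn's coordinate convention, not a statement of its implementation. Because the output is a weighted average of grid positions, it always lies strictly inside (-1, 1).

```python
import math


def flat_softmax(heatmap):
    """Softmax over all pixels of a 2-D heatmap given as a list of rows."""
    m = max(max(row) for row in heatmap)  # subtract the max for numerical stability
    exps = [[math.exp(v - m) for v in row] for row in heatmap]
    total = sum(sum(row) for row in exps)
    return [[v / total for v in row] for row in exps]


def pixel_centres(n):
    """Normalized coordinates of n pixel centres (assumed pixel-centre convention)."""
    return [(2 * i + 1) / n - 1 for i in range(n)]


def dsnt(probs):
    """Expected (x, y) position under a normalized heatmap."""
    h, w = len(probs), len(probs[0])
    xs, ys = pixel_centres(w), pixel_centres(h)
    x = sum(p * xs[j] for row in probs for j, p in enumerate(row))
    y = sum(p * ys[i] for i, row in enumerate(probs) for p in row)
    return x, y
```

For example, dsnt(flat_softmax(heatmap)) on a 3x3 heatmap with a strong peak in the bottom-right pixel returns a point just inside (2/3, 2/3), the centre of that pixel.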

And the training code is as follows:

for i, data in enumerate(tqdm(train_dataloader)):
    # training
    images, poses = data['image'], data['pose']
    images, poses = images.to(device), poses.to(device)
    coords, heatmaps = net(images)

    # Per-location euclidean losses
    euc_losses = dsntnn.euclidean_losses(coords, poses)
    # Per-location regularization losses
    reg_losses = dsntnn.js_reg_losses(heatmaps, poses, sigma_t=1.0)
    # Combine losses into an overall loss
    loss = dsntnn.average_loss(euc_losses + reg_losses)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss_epoch.append(loss.item())
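The js_reg_losses call in the loop regularizes each heatmap towards a Gaussian centred on the ground-truth location by penalizing the Jensen-Shannon divergence between the two. The plain-Python sketch below illustrates that idea for a single heatmap. It is not dsntnn's implementation, the helper names are hypothetical, and it assumes that sigma_t is measured in pixels and that the target Gaussian is evaluated on the pixel-centre grid x_j = (2j + 1)/W - 1.

```python
import math


def gaussian_target(mu, height, width, sigma_px):
    """Normalized 2-D Gaussian evaluated at pixel centres.

    mu is (x, y) in normalized units; sigma_px is in pixels (assumed convention).
    """
    mx, my = mu
    # A span of 2 normalized units covers `width` (or `height`) pixels.
    sx, sy = sigma_px * 2 / width, sigma_px * 2 / height
    xs = [(2 * j + 1) / width - 1 for j in range(width)]
    ys = [(2 * i + 1) / height - 1 for i in range(height)]
    g = [[math.exp(-(x - mx) ** 2 / (2 * sx ** 2) - (y - my) ** 2 / (2 * sy ** 2))
          for x in xs] for y in ys]
    total = sum(sum(row) for row in g)
    return [[v / total for v in row] for row in g]


def kl_divergence(p, q):
    """KL(p || q) for flat probability lists; terms with p == 0 contribute 0."""
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def js_divergence(p_map, q_map):
    """Jensen-Shannon divergence between two normalized heatmaps (natural log)."""
    p = [v for row in p_map for v in row]
    q = [v for row in q_map for v in row]
    m = [(a + b) / 2 for a, b in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
```

With natural logarithms the divergence lies in [0, ln 2]: it is 0 only when the heatmap matches the target exactly and reaches ln 2 when the two distributions do not overlap at all.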

I normalized the keypoint ground truth to floats in (-1, 1), but all the predicted coordinates are negative:

tensor([[[-0.1286, -0.0830],
         [-0.1169, -0.0810],
         [-0.1205, -0.1476],
         ...,
         [-0.1767, -0.3881],
         [-0.1970, -0.2403],
         [-0.3226, -0.3909]],

        [[-0.0694, -0.0165],
         [-0.0744, -0.0288],
         [-0.1027, -0.0873],
         ...,
         [-0.0766, -0.3926],
         [-0.1146, -0.2482],
         [-0.0907, -0.1812]],

        [[-0.4647, -0.3639],
         [-0.4430, -0.3409],
         [-0.2485, -0.2339],
         ...,
         [-0.2906, -0.4541],
         [-0.3648, -0.3034],
         [-0.4190, -0.3880]],

And the heatmap looks strange:

[image: predicted heatmap]

The visualization results:

[image: visualization results]
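The post does not say exactly how the ground truth was mapped to (-1, 1). A common choice, and the one assumed here, treats pixel j as centred at normalized coordinate (2j + 1)/W - 1, so the leftmost and rightmost pixel centres map to -(W - 1)/W and (W - 1)/W. If the targets and the DSNT grid use different conventions, predictions and targets disagree by a fraction of a pixel. A minimal sketch with hypothetical helper names:

```python
def to_normalized(x_px, size):
    """Map a pixel coordinate (pixel centres at 0, 1, ..., size - 1) into (-1, 1)."""
    return (2 * x_px + 1) / size - 1


def to_pixels(x_norm, size):
    """Inverse of to_normalized."""
    return ((x_norm + 1) * size - 1) / 2


def keypoints_to_normalized(keypoints, width, height):
    """Convert a list of (x, y) pixel keypoints to normalized coordinates."""
    return [(to_normalized(x, width), to_normalized(y, height)) for x, y in keypoints]
```

Applying to_pixels to the network's predictions gives positions that can be drawn directly on the input image.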

YuliangXiu changed the title from "coords output are negative floats" to "output coords are negative floats" on Jan 6, 2019.

YuliangXiu (Author) commented Jan 10, 2019

I think this is caused by a PyTorch version incompatibility.

I use PyTorch 0.4, and the following test fails:

======================================================================
FAIL: test_cuda (tests.test_dsnt.TestDSNT)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/yuliang/code/dsntnn/tests/test_dsnt.py", line 71, in test_cuda
    self.assertEqual(in_var.grad, expected_grad)
  File "/home/yuliang/code/dsntnn/tests/common.py", line 106, in assertEqual
    assertTensorsEqual(x, y)
  File "/home/yuliang/code/dsntnn/tests/common.py", line 98, in assertTensorsEqual
    self.assertLessEqual(max_err, prec, message)
AssertionError: tensor(0.4800) not less than or equal to 1e-05 :

When I change the assertion to self.assertEqual(in_var.grad/2.0, expected_grad), the test in test_dsnt passes.

anibali (Owner) commented Jan 13, 2019

Yes, this code was written to work with PyTorch 0.2, and 0.4 introduced some breaking changes. I will attempt to update the project for newer versions of PyTorch.

anibali (Owner) commented Jan 14, 2019

I have updated the code so that it should work with 0.4 and 1.0. The branch is https://github.com/anibali/dsntnn/tree/pytorch-1.0 and I have uploaded an alpha release to PyPI (https://pypi.org/project/dsntnn/0.4.0a0/). Please try this version and see if it helps.

YuliangXiu (Author) commented:

The new version works! The final results look pretty good. Thank you very much for the update.

I am working on multi-person pose estimation. Do you have any plans to extend the dsntnn module to regression problems with a variable number of keypoints?

anibali (Owner) commented Jan 15, 2019

Glad to hear that the new version works.

I have put some thought into how DSNT might be applied to multi-person pose estimation, but unfortunately I don't think it's a trivial extension. If you come up with any ideas yourself, please let me know because I am definitely interested.

simonhessner commented:

> I have updated the code so that it should work with 0.4 and 1.0. The branch is https://github.com/anibali/dsntnn/tree/pytorch-1.0 and I have uploaded an alpha release to PyPI (https://pypi.org/project/dsntnn/0.4.0a0/). Please try this version and see if it helps.

This branch no longer exists. I saw that you merged it into master; is that correct?
However, version 0.4.0 on PyPI seems to be different from 0.4.0a0. Has PyPI not been updated yet?

anibali (Owner) commented Jan 30, 2019

Those changes were merged into master and released again as dsntnn version 0.4.0, so you can now just use 0.4.0 from PyPI instead of the alpha release.

simonhessner commented Jan 30, 2019

Thanks for your response. Maybe I am doing something wrong, but when I install dsntnn via pip install --upgrade dsntnn, I get version 0.4.0:

Installing collected packages: dsntnn
Found existing installation: dsntnn 0.4.0a0
Uninstalling dsntnn-0.4.0a0:
Successfully uninstalled dsntnn-0.4.0a0
Successfully installed dsntnn-0.4.0

Running my training script with this version gives me values outside (-1, 1).

When I use pip install dsntnn==0.4.0a0 instead, I get version 0.4.0a0 as expected, and the script runs perfectly with values inside (-1, 1):

Installing collected packages: dsntnn
Found existing installation: dsntnn 0.4.0
Uninstalling dsntnn-0.4.0:
Successfully uninstalled dsntnn-0.4.0
Successfully installed dsntnn-0.4.0a0

So I assume that the package source used by pip was not updated? I am not very familiar with how pip obtains packages. Does it pull them directly from GitHub, or do you, as the author, have to update the package sources manually?

If I understand your last answer correctly, there should be no difference between pip install --upgrade dsntnn and pip install dsntnn==0.4.0a0, because the changes from the alpha were merged into master.

anibali (Owner) commented Jan 30, 2019

Interesting, thanks for letting me know. I'll check it out more closely tomorrow; it sounds like something might be wrong with the 0.4.0 release.

anibali (Owner) commented Jan 30, 2019

I just checked, and the module source file dsntnn/__init__.py is identical in both 0.4.0a0 and 0.4.0, as it should be. This is strange, since you say that those versions behave differently.

Could you try running the following commands after using pip to install each version, and post the output here? This will show me whether the versions you are installing are somehow different.

$ pip list | grep dsntnn
$ md5sum `python -c "import dsntnn; print(dsntnn.__file__)"`

Here's what I get:

aiden@boba:~$ pip uninstall dsntnn                                      
aiden@boba:~$ pip install dsntnn==0.4.0                                 
aiden@boba:~$ pip list | grep dsntnn                                    
dsntnn                            0.4.0      
aiden@boba:~$ md5sum `python -c "import dsntnn; print(dsntnn.__file__)"`
f941c99e633cdb7258698131557d06bc  /home/aiden/Software/miniconda3/lib/python3.6/site-packages/dsntnn/__init__.py

aiden@boba:~$ pip uninstall dsntnn                                      
aiden@boba:~$ pip install dsntnn==0.4.0a0                               
aiden@boba:~$ pip list | grep dsntnn                                    
dsntnn                            0.4.0a0    
aiden@boba:~$ md5sum `python -c "import dsntnn; print(dsntnn.__file__)"`
f941c99e633cdb7258698131557d06bc  /home/aiden/Software/miniconda3/lib/python3.6/site-packages/dsntnn/__init__.py

So as you can see, the releases appear identical from my end, since the only source file has the same checksum of f941c99e633cdb7258698131557d06bc.
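On systems without md5sum (macOS, for instance, ships md5 instead), the same check can be done from Python. This sketch, with a hypothetical helper name, uses importlib.util.find_spec to locate the installed module file without importing it and then hashes the file with hashlib:

```python
import hashlib
import importlib.util


def module_md5(module_name):
    """Return (path, md5 hex digest) of the source file backing an installed module."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        raise ImportError(f"cannot locate module {module_name!r}")
    # For a package, spec.origin points at its __init__.py.
    with open(spec.origin, "rb") as f:
        digest = hashlib.md5(f.read()).hexdigest()
    return spec.origin, digest
```

Running print(module_md5("dsntnn")) after installing each version should print the same digest whenever the installed sources are identical.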

simonhessner commented:

The output is the same as what you posted: I get the same hash as you for both versions, so there seems to be no problem with the release. I have now switched to version 0.4.0, and it seems to work in most cases, but not always.

I ran my script several times to check whether I always get coordinates in (-1, 1), and after five tries I got 1.0102 in one of the first mini-batches of the first epoch. Do you have an idea how that could happen? As I understand it, this could only occur if the heatmap were not normalized to form a probability distribution, but I am using dsntnn.flat_softmax(unnormalized_heatmaps) as the input to the DSNT layer.

I get out-of-range values (below -1 or above 1) with both versions 0.4.0 and 0.4.0a0, which is expected, since we have seen that they are the same.
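The intuition in this comment can be made precise with a short bound, assuming the DSNT layer computes the expected x-coordinate over fixed grid positions x_j (the same argument applies to y):

```latex
\hat{x} = \sum_{i=1}^{H} \sum_{j=1}^{W} p_{ij}\, x_j ,
\qquad p_{ij} \ge 0 , \qquad \sum_{i=1}^{H} \sum_{j=1}^{W} p_{ij} = 1
\quad\Longrightarrow\quad
\min_j x_j \;\le\; \hat{x} \;\le\; \max_j x_j .
```

With the pixel-centre grid x_j = (2j - 1)/W - 1 for j = 1, ..., W (an assumed convention), this gives |x̂| ≤ (W - 1)/W < 1, and with any grid contained in [-1, 1] it still gives |x̂| ≤ 1. A value such as 1.0102 therefore points either to a heatmap that is not a valid probability distribution or to a transformation applied after the DSNT layer.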

anibali (Owner) commented Jan 31, 2019

Could you please write code to save the unnormalised heatmaps with torch.save when you detect out-of-bounds coordinates? It's difficult for me to solve the problem without being able to reproduce it, but if I had a concrete heatmap tensor to work with it would be much easier.
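A minimal sketch of such a check, written so that the detection logic does not depend on PyTorch. It assumes coords has shape [batch, n_locations, 2] and has been converted to nested lists with .tolist(), and it takes the save function as an argument; the function names are hypothetical.

```python
def out_of_bounds_indices(coords, bound=1.0):
    """Return (sample, location) pairs whose x or y lies outside (-bound, bound).

    coords: nested list [batch][location][2], e.g. from coords.tolist().
    """
    bad = []
    for b, sample in enumerate(coords):
        for k, (x, y) in enumerate(sample):
            if not (-bound < x < bound and -bound < y < bound):
                bad.append((b, k))
    return bad


def dump_if_out_of_bounds(coords, heatmaps, path, save):
    """Call save(heatmaps, path) if any coordinate is out of range; return offenders."""
    bad = out_of_bounds_indices(coords)
    if bad:
        save(heatmaps, path)
    return bad
```

In a training loop this might be called as dump_if_out_of_bounds(coords.detach().cpu().tolist(), unnormalized_heatmaps.detach().cpu(), 'bad_heatmaps.pt', torch.save), assuming the model is changed to also return its unnormalized heatmaps (the CoordRegressionNetwork in this thread returns only the normalized ones).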

simonhessner commented:

I saved the unnormalized heatmaps and wanted to see what the output would be if I loaded them again and applied dsntnn.flat_softmax and dsntnn.dsnt. To my surprise, the coordinates then looked fine. After a bit of debugging, I found that the error is caused not by the DSNT layer but by an affine transformation (STN) that I apply to the coordinates after I get them from the DSNT layer.

So my assumption that there is a bug in dsntnn was wrong. Sorry for the false alarm!
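For anyone hitting the same symptom: a DSNT output that lies safely inside (-1, 1) can leave that range once an affine transform is applied to it. The sketch below uses hypothetical helper names and an illustrative 2x3 matrix [[a, b, tx], [c, d, ty]], not the poster's actual STN parameters. A 10% scale-up moves x = 0.95 to about 1.045. Clamping is one possible mitigation, with the drawback that in autograd frameworks clamped values pass no gradient.

```python
def apply_affine(theta, point):
    """Apply a 2x3 affine matrix [[a, b, tx], [c, d, ty]] to an (x, y) point."""
    (a, b, tx), (c, d, ty) = theta
    x, y = point
    return a * x + b * y + tx, c * x + d * y + ty


def clamp_point(point, lo=-1.0, hi=1.0):
    """Clamp each coordinate of a point into [lo, hi]."""
    return tuple(min(max(v, lo), hi) for v in point)
```

For example, apply_affine([[1.1, 0.0, 0.0], [0.0, 1.1, 0.0]], (0.95, -0.2)) yields an x-coordinate slightly above 1 even though the input point was in range.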

anibali (Owner) commented Jan 31, 2019

No worries, I'm glad that you found your true problem!

YuliangXiu (Author) commented:

I have merged the dsntnn module into my earlier single-person pose estimator, MobilePose-pytorch.

3 participants