semantic segmentation #2

Open

garryz94 opened this issue Jan 24, 2022 · 2 comments
garryz94 commented Jan 24, 2022

Thanks for your brilliant work!
I've been working on semantic segmentation with noisy labels recently. Have you tried this idea on any segmentation datasets? And what changes would be needed for the APL losses to adapt to segmentation tasks?

HanxunH (Owner) commented Jan 24, 2022

Hi,

Thanks for your interest in our work. I have not tried this on segmentation datasets, but I'm very interested in the results. I don't think the current implementation works with a segmentation mask as the label; it needs a few modifications. For a segmentation mask [b,1,h,w], I think you can try normalizing along the channel dimension, then summing over h and w, then taking the mean across the batch. I think torch.nn.functional.one_hot() will not work on masks as-is, so you probably need to write a custom function to convert the mask into one-hot format.
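
For example, a minimal sketch of such a conversion (the helper name mask_to_one_hot is only illustrative, not something from this repo):

    import torch
    import torch.nn.functional as F

    def mask_to_one_hot(mask, num_classes):
        # mask: integer tensor [b, 1, h, w]; any ignore-label pixels must be
        # remapped to a valid class index before calling F.one_hot
        mask = mask.squeeze(1).long()               # [b, h, w]
        one_hot = F.one_hot(mask, num_classes)      # [b, h, w, c]
        return one_hot.permute(0, 3, 1, 2).float()  # [b, c, h, w]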

garryz94 (Author) commented Jan 25, 2022

Thanks for your reply!
I modified your code for semantic segmentation as follows. It would be appreciated if you could help review it.

    import torch
    import torch.nn.functional as F

    class NCELoss(torch.nn.Module):
        def __init__(self, ignore_label, class_weight=None):
            super(NCELoss, self).__init__()
            self.ignore_label = ignore_label

        def forward(self, pred, target, sample_weight=None):
            target = target.long()
            b, c, h, w = pred.shape
            logsoftmax = F.log_softmax(pred, dim=1)

            # clone so that remapping the ignore-label pixels does not mutate target
            ohot = target.clone()
            mask = torch.ones_like(target).float()
            mask[target == self.ignore_label] = 0
            ohot[target == self.ignore_label] = 0
            # F.one_hot returns [b, h, w, c]; permute (not reshape) to get BCHW
            ohot = F.one_hot(ohot, c).permute(0, 3, 1, 2).float()

            nce = (-1 * torch.sum(ohot * logsoftmax, dim=1)) / (-1 * logsoftmax.sum(dim=1))  # BHW
            nce = nce * mask  # set weights of ignore-label pixels to 0
            nce = torch.sum(nce) / target.ne(self.ignore_label).sum()  # mean over valid pixels
            return nce

    class RCELoss(torch.nn.Module):
        def __init__(self, ignore_label, class_weight=None):
            super(RCELoss, self).__init__()
            self.ignore_label = ignore_label

        def forward(self, pred, target, sample_weight=None):
            target = target.long()
            b, c, h, w = pred.shape
            prob = F.softmax(pred, dim=1)
            prob = torch.clamp(prob, min=1e-7, max=1.0)  # clamp the probabilities, not the logits

            ohot = target.clone()
            mask = torch.ones_like(target).float()
            mask[target == self.ignore_label] = 0
            ohot[target == self.ignore_label] = 0
            ohot = F.one_hot(ohot, c).permute(0, 3, 1, 2).float()
            ohot = torch.clamp(ohot, min=1e-4, max=1.0)  # zeros become 1e-4, so log(ohot) stays finite

            rce = -1 * torch.sum(prob * torch.log(ohot), dim=1)  # BHW
            rce = rce * mask  # set weights of ignore-label pixels to 0
            rce = torch.sum(rce) / target.ne(self.ignore_label).sum()  # mean over valid pixels
            return rce

    class APLLoss(torch.nn.Module):
        def __init__(self, alpha, beta, ignore_label, class_weight=None):
            super(APLLoss, self).__init__()
            print("APLLoss: nce + rce")
            self.a = alpha
            self.b = beta
            self.nce = NCELoss(ignore_label=ignore_label)
            self.rce = RCELoss(ignore_label=ignore_label)
    
        def forward(self, pred, target):
            nce = self.nce(pred, target)
            rce = self.rce(pred, target)
            return self.a * nce + self.b * rce
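
For reference, here is a quick smoke test of the combined loss on dummy tensors (the shapes, the 19-class setting, and ignore_label=255 are just assumptions for illustration):

    # hypothetical smoke test; shapes, class count, and ignore_label are assumed values
    pred = torch.randn(2, 19, 64, 64)           # [b, c, h, w] logits
    target = torch.randint(0, 19, (2, 64, 64))  # [b, h, w] labels
    target[0, :8, :8] = 255                     # mark some pixels as ignored

    criterion = APLLoss(alpha=1.0, beta=1.0, ignore_label=255)
    loss = criterion(pred, target)
    print(loss.item())  # scalar loss averaged over non-ignored pixels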

I'll post any progress here in the next few days!
