GeneralEmbodiedAI/ACCon-for-deep-regression

ACCon: Angle-Compensated Contrastive Regularizer for Deep Regression

Quick Preview

ACCon is complementary to conventional imbalanced learning techniques. The following code snippet shows the implementation of the ACCon loss for the task of age estimation:

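    import numpy as np
    import torch

    # Method of the ACCon loss module; the enclosing class is not shown here.
    def forward(self, features, labels=None):
        # features: [bsz, n_views, ...], presumably L2-normalized so that dot
        # products are cosines. labels: [bsz, 1], so that labels - labels.T
        # broadcasts to a [bsz, bsz] matrix.
        batch_size = features.size()[0]
        device = features.device
        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        # Stack all views along the batch axis: [n_views * bsz, dim].
        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        # Every view of every sample acts as an anchor.
        anchor_feature = contrast_feature
        anchor_count = contrast_count

        # Pairwise label differences and same-label (positive) mask,
        # tiled to cover all view combinations.
        dist = (labels - labels.T).float().to(device)
        mask = torch.eq(labels, labels.T).float().to(device)
        dist = dist.repeat(anchor_count, contrast_count)
        mask = mask.repeat(anchor_count, contrast_count)

        # Compensation angle derived from the label distance
        # (`max_inernal` is the attribute name as spelled in the source).
        phi = (1 - dist / self.max_inernal) * np.pi

        cos_phi = torch.cos(phi)
        sin_phi = torch.sin(phi)

        # Pairwise cosine similarity; tau keeps the square root away from zero.
        cos_theta = torch.matmul(anchor_feature, contrast_feature.T)
        cos_theta = torch.clamp(cos_theta, -1, 1)
        sin_theta = torch.sqrt(1 - cos_theta ** 2 + self.tau)

        # Temperature-scaled logits, shifted by the row maximum for stability.
        logits = torch.div(cos_theta, self.temperature)
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()

        # Zero the diagonal so an anchor is never contrasted with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask), 1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1), 0)

        mask = mask * logits_mask

        # Angle-compensated logits for negatives; positives keep plain logits.
        neg_logit = torch.div(cos_theta * cos_phi - sin_theta * torch.abs(sin_phi), self.temperature)
        neg_logit[mask == 1] = logits[mask == 1]
        exp_logits = torch.exp(neg_logit) * logits_mask

        # Mean log-likelihood over the positives of each anchor.
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        mean_log_prob_pos = ((mask * log_prob).sum(1) + self.tau) / (mask.sum(1) + self.tau)

        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss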
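
The negative logit above has the form of the cosine addition formula: for a compensation angle in [0, pi], `cos_theta * cos_phi - sin_theta * |sin_phi|` equals `cos(theta + phi)`. The following is a minimal scalar sketch of that step in plain Python, not the repository's code. It assumes a hypothetical label range of 100 (standing in for `self.max_inernal`), a hypothetical cosine similarity of 0.5, and it omits the `tau` stabilizer; the helper names are illustrative only.

```python
import math

def compensation_angle(label_a, label_b, max_interval):
    """phi = (1 - d / max_interval) * pi, with d the signed label difference."""
    return (1.0 - (label_a - label_b) / max_interval) * math.pi

def compensated_cosine(cos_theta, phi):
    """cos(theta) * cos(phi) - sin(theta) * |sin(phi)|, without the tau term."""
    sin_theta = math.sqrt(max(0.0, 1.0 - cos_theta ** 2))
    return cos_theta * math.cos(phi) - sin_theta * abs(math.sin(phi))

max_interval = 100.0  # hypothetical label range, e.g. ages spanning 100 years
cos_theta = 0.5       # hypothetical cosine similarity between two features

# Labels at maximal distance: phi = 0, so the similarity is left unchanged.
far = compensated_cosine(cos_theta, compensation_angle(100, 0, max_interval))

# Labels 10 apart: phi = 0.9 * pi, so the similarity is rotated to cos(theta + phi).
near = compensated_cosine(cos_theta, compensation_angle(30, 20, max_interval))

# Swapping the labels flips the sign of d, but |sin(phi)| keeps the result symmetric.
near_swapped = compensated_cosine(cos_theta, compensation_angle(20, 30, max_interval))

print(f"far={far:.4f} near={near:.4f} near_swapped={near_swapped:.4f}")
```

In this sketch, a pair with maximally distant labels keeps its plain cosine similarity as a negative logit, while a pair with nearby labels receives a much lower value and therefore contributes less to the denominator of the contrastive loss.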

Acknowledgment

The code is based on Yang et al., Delving into Deep Imbalanced Regression, ICML 2021; Ren et al., Balanced MSE for Imbalanced Visual Regression, CVPR 2022; and Keramati et al., ConR: Contrastive Regularizer for Deep Imbalanced Regression.
