ACCon is complementary to conventional imbalanced-learning techniques. The following code snippet shows the implementation of the ACCon loss for the task of age estimation:
def forward(self, features, labels=None):
    """Compute the ACCon angular-contrastive loss for a batch of regression samples.

    Positive pairs (identical labels) are scored with their plain cosine
    logits. Negative pairs get an angle-shifted logit
    ``cos(theta) * cos(phi) - sin(theta) * |sin(phi)|``, where ``phi`` is
    derived from the label gap, so the penalty on a negative depends on how
    far apart the two labels are.

    Args:
        features: Tensor of shape ``[bsz, n_views, ...]``; any trailing
            dimensions are flattened into one feature dimension. Similarities
            are plain dot products clamped to [-1, 1], so the features are
            presumably L2-normalised by the caller (TODO confirm).
        labels: One regression target per sample, shape ``[bsz]`` or
            ``[bsz, 1]``. Required; the ``None`` default is kept only for
            signature compatibility.

    Returns:
        Scalar loss tensor averaged over all ``n_views * bsz`` anchors.

    Raises:
        ValueError: If ``features`` has fewer than 3 dimensions, ``labels`` is
            missing, or the number of labels differs from ``bsz``.
    """
    # Follow the input's device instead of hard-coding .cuda(), so the loss
    # also runs on CPU or on a non-default GPU.
    device = features.device
    if len(features.shape) < 3:
        raise ValueError('`features` needs to be [bsz, n_views, ...],'
                         'at least 3 dimensions are required')
    if len(features.shape) > 3:
        # reshape (unlike view) also works for non-contiguous inputs.
        features = features.reshape(features.shape[0], features.shape[1], -1)
    if labels is None:
        raise ValueError('`labels` is required for the ACCon loss')
    batch_size = features.shape[0]

    # Force a column vector so labels - labels.T broadcasts to [bsz, bsz];
    # a 1-D labels tensor would otherwise yield an all-zero 1-D "distance".
    labels = labels.contiguous().view(-1, 1).to(device)
    if labels.shape[0] != batch_size:
        raise ValueError('Num of labels does not match num of features')

    contrast_count = features.shape[1]
    # Stack the views along the batch axis: [n_views * bsz, dim].
    contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
    anchor_feature = contrast_feature
    anchor_count = contrast_count

    # Pairwise (signed) label gaps and same-label mask, tiled to cover every view pair.
    dist = (labels - labels.T).float()
    mask = torch.eq(labels, labels.T).float()
    dist = dist.repeat(anchor_count, contrast_count)
    mask = mask.repeat(anchor_count, contrast_count)

    # Map the label gap to an angle (attribute name `max_inernal` kept as-is).
    # The sign of the gap does not matter: cos is even, and |sin| is used below.
    phi = (1 - dist / self.max_inernal) * np.pi
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)

    cos_theta = torch.matmul(anchor_feature, contrast_feature.T)
    cos_theta = torch.clamp(cos_theta, -1, 1)
    # self.tau keeps sqrt away from 0 (finite gradient when |cos_theta| == 1).
    sin_theta = torch.sqrt(1 - cos_theta ** 2 + self.tau)

    logits = torch.div(cos_theta, self.temperature)
    logits_max, _ = torch.max(logits, dim=1, keepdim=True)
    logits_max = logits_max.detach()
    logits = logits - logits_max

    # Exclude each anchor's similarity with itself.
    n_anchors = batch_size * anchor_count
    logits_mask = torch.scatter(
        torch.ones_like(mask), 1,
        torch.arange(n_anchors, device=device).view(-1, 1), 0)
    mask = mask * logits_mask

    # Angular negative logits. They must be shifted by the same per-row
    # logits_max as the positive logits; otherwise every negative term in the
    # denominator is silently scaled by exp(logits_max).
    neg_logit = torch.div(cos_theta * cos_phi - sin_theta * torch.abs(sin_phi),
                          self.temperature) - logits_max
    positive = mask == 1
    neg_logit[positive] = logits[positive]
    exp_logits = torch.exp(neg_logit) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

    # self.tau smooths the average so anchors without positives do not divide by zero.
    mean_log_prob_pos = ((mask * log_prob).sum(1) + self.tau) / (mask.sum(1) + self.tau)
    loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
    loss = loss.view(anchor_count, batch_size).mean()
    return loss

# The code is based on Yang et al., Delving into Deep Imbalanced Regression,
# ICML 2021; Ren et al., Balanced MSE for Imbalanced Visual Regression,
# CVPR 2022; and Keramati et al., ConR: Contrastive Regularizer for Deep
# Imbalanced Regression.