In [2]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from IPython import embed
from local_distance import batch_local_dist
from hard_example_mining import hard_example_mining

In [3]:
class AlignedTripletLoss(nn.Module):
    def __init__(self, margin=0.3):
        super().__init__()
        self.margin=margin
        self.ranking_loss=nn.MarginRankingLoss(margin=margin)
        
    def forward(self, inputs, local_features, targets):
        n=inputs.size(0)
        local_features=local_features.squeeze()
        distance=torch.pow(inputs,2).sum(dim=1, keepdim=True).expand(n,n)
        distance=distance+distance.t()
        #下式等于distance-2inputs*inputs.T
        distance.addmm(1,-2,inputs,inputs.t())
        distance=distance.clamp(min=1e-12).sqrt()
        #获取正负样本对距离，使用最短路径算法
        dist_ap,dist_an,p_inds,n_inds=hard_example_mining(distance,targets,return_inds=True)
        p_inds,n_inds=p_inds.long(),n_inds.long()
        #print(local_features.size())
        #local_features=local_features.permute(0,2,1)
        
        p_local_features=local_features[p_inds]
        n_local_features=local_features[n_inds]
        local_dist_ap=batch_local_dist(local_features,p_local_features)
        local_dist_an=batch_local_dist(local_features,n_local_features)
        print(local_dist_an.size())
        
        y=torch.ones_like(dist_an)
        global_loss=self.ranking_loss(dist_an, dist_ap, y)
        local_loss=self.ranking_loss(local_dist_an,local_dist_ap,y)
        
        return global_loss, local_loss

In [5]:
if __name__=='__main__':
    target=[1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6,7,7,7,7,8,8,8,8]
    target=torch.Tensor(target)
    features=torch.rand(32,2048)
    local_features=torch.randn(32,128,3,1)
    a=AlignedTripletLoss()
    g_loss, l_loss=a.forward(features, local_features, target)
    print(g_loss)
    print(l_loss)

torch.Size([32])
tensor(0.9207)
tensor(4.1440)
