In [2]:
import torch
import torch.nn.functional as F

In [44]:
def PCA(X, X_center):# X N by Dim; X_center Dim
    """Principal directions and variances of X measured around X_center.

    Args:
        X: 2-D tensor with one sample per row (N by Dim).
        X_center: tensor expandable to X's shape; it is subtracted from
            every row before the decomposition.

    Returns:
        A tuple (components, variances). `components` is the transpose of
        the left singular vectors of the centered data laid out Dim by N,
        so each row is one direction. `variances` is the squared singular
        values divided by (N - 1).
    """
    centered = X - X_center.expand_as(X)
    # Decompose the Dim-by-N layout so the left singular vectors live in
    # feature space.
    left_vecs, sing_vals, _ = torch.svd(centered.t())
    n_samples = centered.size(0)
    variances = sing_vals * sing_vals / (n_samples - 1)
    return left_vecs.t(), variances

# Pairwise similarity matrix (dot products; equals cosine similarity only for L2-normalized rows)
def distMC(Mat_A, Mat_B, norm=1, cpu=False, sq=True):#N by F
    """Pairwise dot-product matrix between rows of Mat_A and rows of Mat_B.

    Args:
        Mat_A: 2-D tensor, one feature vector per row.
        Mat_B: 2-D tensor, one feature vector per row.
        norm: the main diagonal of the result is overwritten with -norm.
        cpu: accepted but not used by this function.
        sq: accepted but not used by this function.

    Returns:
        Tensor of shape (rows of Mat_A, rows of Mat_B) whose main diagonal
        has been set to -norm in place.
    """
    similarity = torch.mm(Mat_A, Mat_B.t())
    # Overwrite the diagonal (each row paired with the same-index row).
    similarity.fill_diagonal_(-norm)
    return similarity

def mahalanobis(x, y, cov):
    """Mahalanobis-style distance sqrt((x - y) * cov^-1 * (x - y)^T).

    Args:
        x: first tensor.
        y: second tensor, subtracted from x.
        cov: square matrix that is inverted and used as the metric.

    Returns:
        The element-wise square root of diff @ inverse(cov) @ diff.t().
    """
    delta = x - y
    precision = cov.inverse()
    quadratic_form = delta.matmul(precision.matmul(delta.t()))
    return quadratic_form.sqrt()

def cov(m):# input Dim by N
    """Sample covariance of a Dim-by-N matrix whose columns are observations.

    Args:
        m: 2-D tensor laid out Dim by N.

    Returns:
        Dim-by-Dim tensor: the row-centered matrix times its transpose,
        divided by (N - 1).
    """
    centered = m - m.mean(dim=1, keepdim=True)
    n_obs = centered.size(1)
    return centered.mm(centered.t()) / (n_obs - 1)


In [2]:
# Load cached artifacts from disk. torch.load unpickles its input, so only
# point these paths at trusted files.
# NOTE(review): names suggest training / validation feature vectors — confirm.
tra_fvec = torch.load('data/39traFvecs.pth')
# val_fvec is indexed as a 2-D tensor (one feature vector per row) by acc().
val_fvec = torch.load('data/39valFvecs.pth')
# val_dset must expose `idx_to_class`, which acc() indexes by sample position.
val_dset = torch.load('data/valdsets.pth')

In [65]:
def resort(X_NN, X, rt=0.8, n_steps=5, decay=0.1):
    """Greedily re-order a neighbor list using offsets from the query.

    Starting from neighbor 0, each step scores the remaining neighbors by the
    dot product between their offset (neighbor - query) and the offset of the
    most recently selected neighbor, plus `decay` times the previous scores.
    Depending on a membership test against the top `rt` fraction of scores,
    the current head of the remaining list is either kept as the next pick or
    the remaining list is re-sorted by score and its new head is picked.

    Args:
        X_NN: 2-D tensor with one neighbor feature vector per row.
        X: query feature vector, broadcast against the rows of X_NN.
        rt: fraction of the score-sorted positions used in the membership test.
        n_steps: maximum number of selection steps (previously hard-coded 5).
        decay: weight applied to the previous step's scores (previously 0.1).

    Returns:
        LongTensor of positions into X_NN: position 0 followed by one entry
        per step. It has min(n_steps, N - 1) + 1 entries, or is empty when
        X_NN has no rows.
    """
    N = X_NN.size(0)
    if N == 0:
        # Nothing to rank; the original code raised IndexError here.
        return torch.LongTensor([])

    # Broadcasting subtracts the query from every neighbor row.
    Delta = X_NN - X
    baseD = Delta[0, :].view(-1, 1)
    NN_list = [0]
    NN_remn = torch.arange(N)
    cosine = torch.zeros(N)

    # Each step consumes one remaining neighbor, so at most N - 1 steps are
    # possible; capping avoids indexing into an empty tensor for small N.
    for i in range(min(n_steps, N - 1)):
        Delta = Delta[1:, :]
        NN_remn = NN_remn[1:]

        # NOTE(review): when the previous step re-sorted Delta, `cosine` was
        # not re-ordered with it, so these carried-over scores may be
        # misaligned with the rows — confirm whether that is intended.
        cosine = Delta.mm(baseD).view(-1) + cosine[1:] * decay
        val, idx = cosine.sort(0, descending=True)

        # NOTE(review): this tests whether *position* i+1 of the remaining
        # list falls in the top `rt` fraction, not the current head
        # (position 0) — kept as-is, confirm the intended index.
        top_positions = idx[:int(len(idx) * rt)].tolist()
        if (i + 1) not in top_positions:
            Delta = Delta[idx, :]
            NN_remn = NN_remn[idx]

        # Always store a plain int so the final list is homogeneous.
        NN_list.append(NN_remn[0].item())
        baseD = Delta[0, :].view(-1, 1)

    return torch.LongTensor(NN_list)

In [46]:
def accnew(N_topk=10):
    """Print the label hit rate at each of the first four re-sorted ranks.

    Loads the validation features and dataset from disk, ranks every sample's
    neighbors with distMC, re-orders the first N_topk of them with resort, and
    prints, for ranks 1 to 4, the fraction of samples whose neighbor at that
    rank has the same `idx_to_class` entry as the sample itself.

    Args:
        N_topk: number of highest-ranked neighbors handed to resort.
    """
    features = torch.load('data/39valFvecs.pth')
    dataset = torch.load('data/valdsets.pth')
    n_samples = features.size(0)

    # Neighbor ranking: highest similarity first.
    _, ranked = distMC(features, features).sort(1, descending=True)

    hits = [0, 0, 0, 0]
    for query in range(n_samples):
        query_label = dataset.idx_to_class[query]
        candidates = ranked[query, :N_topk]
        order = resort(features[candidates, :], features[query, :])
        reranked = candidates[order]

        for rank in range(4):
            if query_label == dataset.idx_to_class[reranked[rank].item()]:
                hits[rank] += 1

    print(*(count / n_samples for count in hits))

In [66]:
# Run the re-sorted neighbor evaluation with the default N_topk; prints the
# hit rate at ranks 1-4.
accnew()

0.7311523797810847 0.6533021768540155 0.628213011929652 0.596113639158775


In [None]:
0.7311523797810847 0.6636330094699299 0.6259992620833845 0.6037387775181404

In [71]:
def acc(N_top = 2, N_pca = 1, th = 0.7):
    """Print top-1 label hit rate before and after a local-PCA re-ranking.

    Reads the notebook globals `val_fvec` (2-D feature tensor) and `val_dset`
    (object exposing `idx_to_class`). For every sample, the leading `N_pca`
    principal directions of its first N_top + 1 ranked neighbors (centered on
    the sample) are computed. Candidates are then re-ranked by the norm of the
    query-minus-candidate offset projected onto the candidate's own
    directions, largest first.

    Args:
        N_top: number of candidates that are re-ranked.
        N_pca: number of principal directions kept per sample; should not
            exceed the neighborhood size N_top + 1.
        th: accepted for backward compatibility; not used by this function.

    Prints:
        Two fractions: the hit rate of the first-ranked neighbor, and the hit
        rate of the first candidate after re-ranking.
    """
    N = val_fvec.size(0)

    # get the NN list
    Dist = distMC(val_fvec, val_fvec)
    _, NN_idx = Dist.sort(1, descending=True)

    # local pca for major direction
    val_fvec_pca = []
    for i in range(N):
        NN_list_i = NN_idx[i, :N_top+1]
        # PCA returns (components, variances) and takes no `k` argument, so
        # keep the leading N_pca component rows here instead.
        components, _ = PCA(val_fvec[NN_list_i, :], val_fvec[i, :])
        comp = components[:N_pca, :]
        val_fvec_pca.append(F.normalize(comp, p = 2, dim = 1))
    val_fvec_pca_tensor = torch.stack(val_fvec_pca, 0)

    # Re-rank the candidates of every sample and count label matches.
    tp1 = 0
    tp2 = 0
    for i in range(N):
        lab = val_dset.idx_to_class[i]
        # NOTE(review): candidates start at column 1, so the first-ranked
        # neighbor (column 0, scored by tp1) is excluded from re-ranking —
        # confirm this is intended.
        NN_list_i = NN_idx[i, 1:N_top+1]

        NN_cos_i = []
        for j in range(N_top):
            NN_id = NN_list_i[j]
            # Project the offset onto the candidate's local directions.
            rel_proj = torch.mm(val_fvec_pca_tensor[NN_id, :, :], (val_fvec[i, :]-val_fvec[NN_id, :]).view(-1, 1))
            rel_dist = rel_proj.pow(2).sum().sqrt()
            NN_cos_i.append(rel_dist)

        NN_cos_i = torch.stack(NN_cos_i, 0)

        re_idx = NN_cos_i.sort(0, descending=True)[1]
        re_rank_list = NN_list_i[re_idx].view(-1)

        if lab == val_dset.idx_to_class[NN_idx[i, 0].item()]:
            tp1 += 1
        if lab == val_dset.idx_to_class[re_rank_list[0].item()]:
            tp2 += 1

    print(tp1/N, tp2/N)

In [76]:
# Single evaluation run. `th` is given explicitly (the function default)
# because the threshold sweep that used to define `t` is disabled, and
# referencing `t` fails on a fresh kernel.
acc(N_top = 2, N_pca = 1, th = 0.7)

0.7311523797810847 0.6630180789570779


In [None]:
# Sweep every (N_top, N_pca) pair with 1 <= N_pca < N_top <= 4.
for j, i in ((n_top, n_pca) for n_top in range(5) for n_pca in range(1, n_top)):
    print(j, i)
    acc(N_top=j, N_pca=i)

In [None]:
0.7316443241913664 0.6637559955725003