In [5]:
import torch 
import numpy as np
from torch.nn import functional as F

def euclidean_distance(qf, gf):
    m = qf.shape[0]
    n = gf.shape[0]
    dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
               torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist_mat.addmm_(1, -2, qf, gf.t())
    return dist_mat.cpu().numpy()

'''
A quantitative metric for Identity density to replce visualization tools 
like t-SNE, which is random and only focus on few samples.
'''
def ID2(feats, pid):
    feats = F.normalize(feats , dim=1, p=2)
    pids = np.asarray(pid)

    id_set = set(pids)
    id_list = list(id_set)
    id_list.sort()
    id_center = []
    for i in id_list:
        mask = pids == i
        x = feats[mask].mean(dim=0)
        id_center.append(x)

    density = torch.zeros(feats.size(0))
    idx = 0
    for i in id_list:
        mask = pids == i
        center = id_center[idx].unsqueeze(0)
        density[mask] = torch.tensor(euclidean_distance(feats[mask], center)).squeeze(1)
        idx += 1
    return density


In [6]:
# Simulate fake feature vectors (10 people, 512-dimensional features)
num_samples = 10
feature_dim = 512
torch.manual_seed(0)  # for reproducibility
feats = torch.randn(num_samples, feature_dim)

# Simulate person IDs: 5 classes, 2 samples per person
pid = [i // 2 for i in range(num_samples)]  # [0,0,1,1,2,2,3,3,4,4]

# Test the ID2 function
density_scores = ID2(feats, pid)

# Print results
print("Feature vector shape:", feats.shape)
print("Person IDs:", pid)
print("Density scores:", density_scores)


Feature vector shape: torch.Size([10, 512])
Person IDs: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
Density scores: tensor([0.5327, 0.5327, 0.4876, 0.4876, 0.5300, 0.5300, 0.5116, 0.5116, 0.4855,
        0.4855])


	addmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)
Consider using one of the following signatures instead:
	addmm_(Tensor mat1, Tensor mat2, *, Number beta = 1, Number alpha = 1) (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\python_arg_parser.cpp:1661.)
  dist_mat.addmm_(1, -2, qf, gf.t())
