In [1]:
import numpy as np

import faiss

In [2]:
# Load the precomputed embedding; each row is one sample (the recorded output
# shows 70000 rows x 10 columns).
embedding_path = "./data/embedding_10.npy"
embedding = np.load(embedding_path)

print(np.shape(embedding))
embedding

(70000, 10)


array([[ 1.0809438 ,  4.5842113 ,  8.188814  , ...,  4.3815236 ,
         4.5348864 ,  3.8039389 ],
       [10.242349  ,  1.8976094 ,  5.18953   , ...,  5.2431035 ,
         3.8557544 ,  3.9528236 ],
       [10.297947  ,  3.2953832 ,  5.1501646 , ...,  5.2841597 ,
         3.912732  ,  4.7057314 ],
       ...,
       [ 9.052287  ,  2.36858   ,  5.025304  , ...,  5.19613   ,
         3.960764  ,  4.592086  ],
       [10.50785   , 10.595664  ,  4.872868  , ...,  5.2744155 ,
         3.4071124 ,  3.2184975 ],
       [ 0.59059334,  4.7471175 ,  3.3372943 , ...,  5.6807604 ,
         4.609006  ,  3.4885309 ]], dtype=float32)

In [3]:
# Load the per-sample class labels; row i presumably labels row i of `embedding`
# (both files have 70000 rows in the recorded output).
labels = np.load("./data/labels.npy")

print(labels.shape)

# Later cells index `labels`/`new_labels` and `embedding` with the same row indices,
# so the two files must describe the same samples. Fail loudly if they do not, or if
# `embedding` has already been subsampled by a later cell (stale kernel state).
assert labels.shape == (embedding.shape[0],), (
    f"labels shape {labels.shape} does not match embedding rows {embedding.shape[0]}"
)

labels

(70000,)


array([9, 0, 0, ..., 8, 1, 5])

In [4]:
# Groups of original class ids that can be treated as the positive targets
# (presumably Fashion-MNIST class ids, given the output file name).
# Select a group by its key below.
TARGET_GROUPS = {
    "tops": [0, 2, 6],
    "footwear": [5, 7, 9],
}
target_labels = TARGET_GROUPS["footwear"]

In [5]:
# Re-label: 0 for every class not in `target_labels`, and 1..len(target_labels)
# for the target classes, numbered in the order they are listed.
new_labels = np.zeros_like(labels)
for new_code, original_class in enumerate(target_labels, start=1):
    new_labels[labels == original_class] = new_code

In [6]:
# Sub-sample the positives (target classes) so that they are rare.
# NOTE: `prevalence` is scaled by the ORIGINAL dataset size (labels.size), not by the
# size of the subsampled result. The realised prevalence is therefore higher than
# `prevalence * len(target_labels)`: it is num_positives / (num_negatives + num_positives).
# The positives are drawn from the pooled target classes, so per-class counts are
# not forced to be equal.
prevalence = 0.01
num_positives = int(labels.size * prevalence * len(target_labels))

# This cell overwrites `new_labels` and `embedding` with their subsampled versions.
# Re-running it alone would silently re-draw from the already-subsampled arrays
# (with the current settings it just reshuffles the kept positives, changing every
# downstream index). Refuse unless both arrays are still full-size.
assert new_labels.shape[0] == embedding.shape[0] == labels.shape[0], (
    "new_labels/embedding are already subsampled; "
    "re-run the loading and re-labelling cells before this one"
)

negative_mask = (new_labels == 0)
negative_ind = np.where(negative_mask)[0]
positive_ind = np.where(~negative_mask)[0]

np.random.seed(0)  # fixed seed so the same positives are drawn on every run
# Row order of the result: every negative first, then the sampled positives.
to_keep = np.concatenate(
    [
        negative_ind,
        np.random.choice(positive_ind, size=num_positives, replace=False),
    ]
)

new_labels = new_labels[to_keep]
embedding = embedding[to_keep]
print(f"positive prevalence after subsampling: {num_positives / to_keep.size:.4f}")

new_labels.shape, embedding.shape

((51100,), (51100, 10))

In [7]:
# Number of neighbours to keep per point. The search asks for k + 1 because every
# point is also returned as (normally) its own nearest hit.
k = 500

# One fresh, C-contiguous float32 copy, which is the layout faiss's numpy
# wrappers take. Centering does not change L2 distances; presumably it is done to
# reduce floating-point roundoff in the distance computation.
embedding = np.array(embedding, dtype=np.float32, order="C", copy=True)
embedding -= embedding.mean(axis=0)

# Exact (brute-force) squared-L2 search of every point against all points.
n_dims = embedding.shape[1]
index = faiss.IndexFlatL2(n_dims)
index.add(embedding)
sq_distances, nearest_neighbors = index.search(embedding, k + 1)

In [8]:
# Inspect the (k + 1)-column neighbour ids, ordered by increasing distance. In the
# recorded output, row i begins with i itself (the self-match).
nearest_neighbors

array([[    0, 22018, 26175, ..., 17837, 10626, 24903],
       [    1, 28108, 37210, ..., 48028, 43639, 32450],
       [    2, 10818, 23250, ..., 32839, 45373, 40296],
       ...,
       [51097, 49858, 49720, ..., 49469, 49846, 50894],
       [51098,  3960, 51080, ..., 51008, 49741, 49191],
       [51099, 50344, 49833, ..., 50180, 50465, 49616]])

In [9]:
# Convert squared L2 distances into similarities with a Gaussian kernel:
#     s = exp(-d^2 / bandwidth^2)
# For non-negative distances this lies in (0, 1]. bandwidth = 1.0 reproduces the
# plain exp(-d^2). Larger values make similarity decay more slowly with distance.
# Alternative considered: the inverse kernel 1 / (1 + d^2).
bandwidth = 1.0
similarities = np.exp(-sq_distances / bandwidth**2)

print(similarities.min())
similarities

0.0007124537


array([[1.        , 0.99529237, 0.99514043, ..., 0.83368146, 0.83366877,
        0.83366555],
       [0.99999815, 0.99900866, 0.9982268 , ..., 0.6645135 , 0.6645028 ,
        0.66423035],
       [1.        , 0.99930984, 0.99891144, ..., 0.58166945, 0.58061975,
        0.5803042 ],
       ...,
       [0.99998474, 0.96465933, 0.9535226 , ..., 0.02088873, 0.02085115,
        0.02029059],
       [0.99998474, 0.9881838 , 0.9743936 , ..., 0.17039986, 0.17007515,
        0.16999212],
       [1.        , 0.99652714, 0.99616224, ..., 0.03699579, 0.03689544,
        0.03663849]], dtype=float32)

In [10]:
from scipy.io import savemat

# Remove each point's own entry from its neighbour list; k neighbours remain because
# the search requested k + 1. Self is usually in column 0. However, its computed
# squared distance is not exactly 0 (floating-point roundoff), so a near-duplicate
# point can outrank it and push self to a later column. A plain `[:, 1:]` would then
# keep self and drop a real neighbour. If self is missing from a row entirely, drop
# that row's last (farthest) hit instead, so every row keeps exactly k entries.
num_points, num_hits = nearest_neighbors.shape
is_self = nearest_neighbors == np.arange(num_points)[:, None]
is_self[~is_self.any(axis=1), -1] = True
keep = ~is_self
# Boolean-mask selection is row-major, so each row's survivors stay together and in order.
neighbors_wo_self = nearest_neighbors[keep].reshape(num_points, num_hits - 1)
similarities_wo_self = similarities[keep].reshape(num_points, num_hits - 1)

# Derive the file name from the chosen target group and k. For the current
# configuration (footwear, k = 500) it is unchanged. Switching `target_labels` or `k`
# no longer writes data under a misleading name.
group_names = {(0, 2, 6): "tops", (5, 7, 9): "footwear"}
group_name = group_names.get(
    tuple(target_labels), "_".join(str(t) for t in target_labels)
)
out_path = f"fashion_{group_name}_nearest_neighbors_{k}.mat"

# +1 shifts 0-based labels/indices to 1-based, presumably for a MATLAB consumer.
savemat(
    out_path,
    {
        "labels": new_labels + 1,
        "nearest_neighbors": neighbors_wo_self + 1,
        "similarities": similarities_wo_self,
    },
)