In [14]:
%%capture
!pip install sentence_transformers
!pip install faiss-cpu

In [15]:
import pickle
from collections import Counter

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

In [30]:
# Build an in-memory FAISS index over short category descriptions, together
# with a lookup from each vector's row position in the index back to the
# vendor category that the description belongs to.
model = SentenceTransformer("BAAI/bge-large-en-v1.5")

# Ask the model for its embedding width instead of hard-coding it, so the
# index stays consistent if the model name above is ever changed.
dimensions = model.get_sentence_embedding_dimension()
index = faiss.IndexFlatL2(dimensions)

# Category name -> short phrases describing what vendors in it sell.
# NOTE(review): "personal care products." appears under both
# "Grocery and Supermarkets" and "Health and Beauty"; identical text yields
# identical vectors, so those two rows are indistinguishable at search time.
# Confirm whether the duplication is intentional.
vendor_categories = {
    "Grocery and Supermarkets": [
        "food products.",
        "various household items.",
        "personal care products."
    ],
    "Restaurants and Food Services": [
        "places to eat.",
        "fast food.",
        "dining establishment.",
        "take-out, for-here, delivery."
    ],
    "Clothing and Apparel": [
        "Stores selling a variety of clothing.",
        "footwear.",
        "fashion.",
        "clothing styles."
    ],
    "Health and Beauty": [
        "pharmacies.",
        "beauty supply stores.",
        "health products.",
        "personal care products."
    ],
    "Electronics and Appliances": [
        "consumer electronics.",
        "household appliances.",
        "accessories for electronics and appliances."
    ],
    "Home and Garden": [
        "home improvement.",
        "gardening supplies.",
        "furniture.",
        "home decor."
    ],
    "Entertainment and Leisure": [
        "movie theaters.",
        "bookstores.",
        "hobby shops.",
        "various entertainment goods and services."
    ]
}

# Flatten to (category, description) pairs. The position of a pair in this
# list is the row id its vector receives in the index, because vectors are
# added in exactly this order.
labelled_descriptions = [
    (vendor, description)
    for vendor, descriptions in vendor_categories.items()
    for description in descriptions
]

# Encode every description in one batched call and add all rows at once,
# instead of one model call and one index.add per description.
if labelled_descriptions:
    embeddings = model.encode([description for _, description in labelled_descriptions])
    # FAISS expects a C-contiguous float32 matrix of shape (n, dimensions).
    index.add(np.ascontiguousarray(embeddings, dtype=np.float32))

# Row id in the index -> vendor category name.
vendor_embeddings_mapping = {
    row_id: vendor for row_id, (vendor, _) in enumerate(labelled_descriptions)
}
# Kept for backward compatibility: the next free row id.
current_index = len(vendor_embeddings_mapping)

# Every indexed vector must have exactly one mapping entry.
assert index.ntotal == len(vendor_embeddings_mapping)

# Optional persistence of the index and mapping, currently disabled:
#faiss.write_index(index, "embeddings.index")
#with open('vendor_mapping.pk1', 'wb') as f:
    #pickle.dump(vendor_embeddings_mapping, f)
    

In [32]:
# Sanity check: number of vectors stored in the index, then the vector
# dimensionality. Both values are good candidates for assertions in a test.
print(f"{index.ntotal}\n{index.d}")

26
1024


In [58]:
#search with KNN
# Embed a sample query string, fetch its nearest category descriptions from
# the index, and pick the category that occurs most often among them.
# `Counter` comes from the notebook's import cell.

TOP_K = 5  # number of neighbours that take part in the vote

test_embedding = model.encode("Longs CFRIO SF PEG BAG CR GYSR SPR WTR BOTTLE DEPOSIT")
# The index is an IndexFlatL2, so `distances` holds squared L2 distances.
distances, indices = index.search(np.array([test_embedding]), TOP_K)

# FAISS pads the result with label -1 when the index holds fewer than TOP_K
# vectors; skip those slots instead of failing on the mapping lookup.
vendors = [vendor_embeddings_mapping[idx] for idx in indices[0] if idx != -1]

# On a tie, Counter.most_common keeps first-seen order, so the category of the
# closer neighbour wins. With no neighbours at all the result is None.
most_common_vendor = Counter(vendors).most_common(1)[0][0] if vendors else None

print(f'Distance to nearest neighboor \n{distances}',
      f'Index of neighboor in database \n{indices}',
      vendors,
      most_common_vendor,
      sep="\n\n")

Distance to nearest neighboor 
[[0.7543362  0.81793356 0.82922924 0.8447415  0.88138545]]

Index of neighboor in database 
[[14  1 17 19 13]]

['Health and Beauty', 'Grocery and Supermarkets', 'Electronics and Appliances', 'Home and Garden', 'Health and Beauty']

Health and Beauty
