In [1]:
#Install Packages for vector DataBases
!pip install faiss-cpu
!pip install sentence-transformers



In [2]:
pip install tf-keras

Note: you may need to restart the kernel to use updated packages.


In [1]:
import pandas as pd

pd.set_option('display.max_colwidth', 100)

In [2]:
df = pd.read_csv("sample_text.csv")
df.shape

(8, 2)

In [3]:
df.head()

Unnamed: 0,text,category
0,Meditation and yoga can improve mental health,Health
1,"Fruits, whole grains and vegetables helps control blood pressure",Health
2,These are the latest fashion trends for this week,Fashion
3,Vibrant color jeans for male are becoming a trend,Fashion
4,The concert starts at 7 PM tonight,Event


## Step 1 : Create source embeddings for the text column

In [5]:
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-mpnet-base-v2")
vectors = encoder.encode(df.text)


In [6]:
print(vectors.shape)

(8, 768)


In [7]:
dim = vectors.shape[1]
dim

768

## Step 2 : Build a FAISS Index for vectors

In [8]:
import faiss

index = faiss.IndexFlatL2(dim)
index

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x000001F3CD17F780> >

## Step 3 : Normalize the source vectors (as we are using L2 distance to measure similarity) and add to the index

In [9]:
index.add(vectors)

In [10]:
index

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x000001F3CD17F780> >

## Step 4 : Encode search text using same encorder and normalize the output vector

In [11]:
search_query = "I want to buy a polo t-shirt"
# search_query = "looking for places to visit during the holidays"
# search_query = "An apple a day keeps the doctor away"
vec = encoder.encode(search_query)
vec.shape

(768,)

In [12]:
import numpy as np
svec = np.array(vec).reshape(1,-1)
svec.shape

(1, 768)

## Step 5: Search for similar vector in the FAISS index created

In [14]:
distances, I = index.search(svec, k=2)
distances

array([[1.384484 , 1.4039094]], dtype=float32)

In [20]:
I

# means in the df 3,2 index are close to this query

array([[3, 2]], dtype=int64)

In [21]:
I.tolist()

[[3, 2]]