In [59]:
import pandas as pd
import numpy as np
import torch
from datasets import Dataset

from transformers import AutoTokenizer, AutoModel

# Hugging Face checkpoint used for both the tokenizer and the encoder.
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The encoder must live on the same device the inputs are moved to at
# inference time; otherwise the forward pass fails on a CUDA machine.
# eval() makes inference mode explicit (dropout disabled). Assigning the
# result avoids printing the module repr as the cell output.
model = AutoModel.from_pretrained(model_ckpt)
model = model.to(device).eval()


In [88]:
class Embeddings:
    """Turn text into dense vectors using the module-level tokenizer and model.

    Relies on the notebook globals ``tokenizer``, ``model`` and ``device``.
    """

    def cls_pooling(self, model_output):
        """Return the last-layer hidden state of the first token of each sequence.

        Args:
            model_output: object exposing ``last_hidden_state``; indexing
                ``[:, 0]`` selects position 0 along the second axis.

        Returns:
            ``model_output.last_hidden_state[:, 0]``.
        """
        return model_output.last_hidden_state[:, 0]

    def get_embeddings(self, text_list):
        """Tokenize ``text_list``, run the encoder and return pooled vectors.

        Args:
            text_list: text (or list of texts) accepted by the tokenizer.

        Returns:
            A NumPy array holding the pooled output, one row per encoded
            sequence.
        """
        encoded_input = tokenizer(
            text_list, padding=True, truncation=True, return_tensors="pt"
        )
        # Inputs must sit on the same device as the model.
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            model_output = model(**encoded_input)
        return self.cls_pooling(model_output).cpu().numpy()

    def makeEmbeddings(self, dataset):
        """Embed every item of ``dataset`` and stack the results.

        Each item is encoded on its own and the first row of its result is
        kept, so the output has one row per dataset item.

        Args:
            dataset: iterable of texts.

        Returns:
            ``np.ndarray`` built from the per-item vectors, suitable as the
            database matrix for a vector index.
        """
        embeddings = []
        for data in dataset:
            # Wrap a bare string so the tokenizer always sees a batch; any
            # other item type is passed through unchanged, as before.
            batch = [data] if isinstance(data, str) else data
            embeddings.append(self.get_embeddings(batch)[0])
        return np.array(embeddings)

    def getQueryEmbedding(self, query):
        """Embed a single query string.

        Returns:
            The result of ``get_embeddings([query])`` (a one-row batch).
        """
        return self.get_embeddings([query])
    
class Faiss:
    """Stateless helper that builds a FAISS HNSW index and searches it."""

    def __init__(self):
        # Nothing to store: the index is returned to the caller, not kept here.
        pass

    def faiss(self, xb, M=32, efConstruction=40, efSearch=16):
        """Build an ``IndexHNSWFlat`` over the database vectors ``xb``.

        Args:
            xb: 2-D array-like of database vectors, one row per item.
            M: HNSW connectivity parameter passed to ``IndexHNSWFlat``.
            efConstruction: value assigned to ``index.hnsw.efConstruction``.
            efSearch: value assigned to ``index.hnsw.efSearch``.

        Returns:
            The populated index.

        Raises:
            ValueError: if ``xb`` is not two-dimensional.
        """
        # Local import keeps this cell runnable on a fresh kernel even when no
        # earlier cell imported the library.
        import faiss

        xb = np.ascontiguousarray(xb, dtype=np.float32)
        if xb.ndim != 2:
            raise ValueError("xb must be a 2-D array of shape (n_vectors, dim)")

        # Take the dimensionality from the data instead of hard-coding 768.
        d = xb.shape[1]
        # NOTE(review): the index is created with its default metric; the
        # checkpoint name suggests a dot-product model -- confirm the metric
        # is the intended one.
        index = faiss.IndexHNSWFlat(d, M)
        index.hnsw.efConstruction = efConstruction
        index.hnsw.efSearch = efSearch
        index.add(xb)
        return index

    def query(self, index, xq, k=3):
        """Search ``index`` for the ``k`` nearest neighbours of ``xq``.

        Args:
            index: object exposing ``search(xq, k)``.
            xq: query vector(s); a 1-D vector is treated as a single query.
            k: number of neighbours to return per query (default 3).

        Returns:
            ``(D, I)`` exactly as returned by ``index.search``.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be >= 1")
        xq = np.ascontiguousarray(xq, dtype=np.float32)
        if xq.ndim == 1:
            xq = xq.reshape(1, -1)
        D, I = index.search(xq, k)
        return D, I

In [89]:
# Toy corpus to index, plus a query that matches one of its entries.
values = ["julia", "vedha", "isabelle"]

# Database matrix for the index; `xb` is a second name for the same array.
xb = embeddings_dataset = Embeddings().makeEmbeddings(values)

# Query matrix for the search.
xq = Embeddings().getQueryEmbedding("julia")

In [92]:
# Build the index over the corpus vectors, then look up the query.
index = Faiss().faiss(xb)
search_result = Faiss().query(index, xq)

# D holds the distances, I the matching corpus row indices.
D, I = search_result
I

array([[0, 2, 1]])