In [None]:
from datasets import load_dataset
import pandas

# Load the source passages with the Hugging Face "text" builder from `data_dir`.
# NOTE(review): `data_dir` is an empty string here — presumably a placeholder;
# point it at the directory holding the text files before running.
# The upsert cell further down consumes `dataset` as rows with a "text" column.
dataset = load_dataset("text", data_dir="")

In [None]:
import os

import pinecone

# Read the API key from the environment instead of embedding it in the
# notebook, so the secret never ends up in version control or a shared copy.
# Falls back to "" (the previously hard-coded value) when PINECONE_API_KEY
# is not set.
pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY", ""),
    environment="us-west1-gcp"
)

In [None]:
# Handle to the "chat-rpi" Pinecone index; the cells below upsert into it and query it.
# NOTE(review): presumably the index must already exist on the Pinecone side — confirm.
index = pinecone.Index("chat-rpi")

In [None]:
from sentence_transformers import SentenceTransformer

# load the retriever model from huggingface model hub
# (the same model embeds the passages at upsert time and the queries at
# search time, via `retriever.encode` in the cells below)
retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
# bare expression so the notebook displays the loaded model
retriever

In [None]:
from tqdm.auto import tqdm

# we will use batches of 64
batch_size = 64

# `load_dataset` called without `split=` returns a DatasetDict keyed by split
# name. That object has no `.iloc`, and `len()` on it counts splits rather
# than rows, so normalise to a pandas DataFrame once before batching by row.
# NOTE(review): assumes the text files end up in the default "train" split —
# confirm against the layout of `data_dir`.
if hasattr(dataset, "iloc"):
    # already a DataFrame (e.g. converted in an earlier cell)
    passages = dataset
elif hasattr(dataset, "to_pandas"):
    # a single split (datasets.Dataset)
    passages = dataset.to_pandas()
else:
    # a DatasetDict: pick the split, then convert
    passages = dataset["train"].to_pandas()

for i in tqdm(range(0, len(passages), batch_size)):
    # find end of batch (the last batch may be shorter than batch_size)
    i_end = min(i+batch_size, len(passages))
    # extract batch
    batch = passages.iloc[i:i_end]
    # generate embeddings for batch
    emb = retriever.encode(batch["text"].tolist()).tolist()
    # get metadata: one dict per row, so the passage text is stored with its vector
    meta = batch.to_dict(orient="records")
    # create unique IDs from the row positions
    ids = [f"{idx}" for idx in range(i, i_end)]
    # add all to upsert list as (id, vector, metadata) tuples
    to_upsert = list(zip(ids, emb, meta))
    # upsert/insert these records to pinecone
    _ = index.upsert(vectors=to_upsert)

# check that we have all vectors in index
index.describe_index_stats()

In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

# load bart tokenizer and model from huggingface
# (both are taken from the same 'vblagoje/bart_lfqa' checkpoint)
# NOTE(review): neither `tokenizer` nor `generator` is used in the cells
# visible here — presumably consumed by a later answer-generation step; confirm.
tokenizer = BartTokenizer.from_pretrained('vblagoje/bart_lfqa')
generator = BartForConditionalGeneration.from_pretrained('vblagoje/bart_lfqa')

In [None]:
def query_pinecone(query, top_k):
    """Embed ``query`` with the retriever and search the Pinecone index.

    Args:
        query: Question text to look up.
        top_k: Number of matches to request from the index.

    Returns:
        The response of ``index.query`` for the embedded query, requested
        with metadata included.
    """
    # the retriever encodes a batch, so the single query is wrapped in a list
    query_vectors = retriever.encode([query]).tolist()
    return index.query(query_vectors, top_k=top_k, include_metadata=True)

In [None]:
def format_query(query, context):
    """Build the generator prompt from a question and retrieved matches.

    Args:
        query: The question text.
        context: Iterable of match records; each is read as
            ``match['metadata']['text']``.

    Returns:
        A string of the form ``"question: <query> context: <P> t1 <P> t2 ..."``,
        where each passage is prefixed with a ``<P>`` tag and the passages are
        separated by single spaces.
    """
    # tag every passage with <P> and join them into one context string
    passages = " ".join(
        f"<P> {match['metadata']['text']}" for match in context
    )
    # prepend the question so the prompt carries both question and context
    return f"question: {query} context: {passages}"

In [None]:
# Smoke-test retrieval: ask for the single best-matching passage for a sample question.
query = "What does the acronym RPI stand for?"
result = query_pinecone(query, top_k=1)
# bare expression so the notebook displays the raw Pinecone response
result