In [28]:
%pip install -qU datasets pinecone-client sentence-transformers torch

Note: you may need to restart the kernel to use updated packages.


In [29]:
import torch

In [30]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
device

'cuda'

In [31]:
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used to encode questions for the vector search.
# It is loaded straight onto the device selected above.
retriver = SentenceTransformer(
    'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
    device=device,
)
retriver

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [32]:
# Size of the vectors the encoder produces; the Pinecone index below is
# created with this same dimension.
retriver.get_sentence_embedding_dimension()

768

In [33]:
import os
import warnings

# SECURITY: credentials are read from the environment so that no secret is
# stored in the notebook. The key that used to be hardcoded in this cell has
# been exposed and should be revoked/rotated in the Pinecone console.
#   export PINECONE_API_KEY=...          (required)
#   export PINECONE_ENVIRONMENT=...      (optional, defaults to 'gcp-starter')
API_KEY = os.environ.get('PINECONE_API_KEY')
ENVIRONMENT = os.environ.get('PINECONE_ENVIRONMENT', 'gcp-starter')

if not API_KEY:
  warnings.warn(
      'PINECONE_API_KEY is not set; Pinecone calls later in this notebook '
      'are expected to fail until it is provided.'
  )

In [34]:
import pinecone
# Configure the Pinecone client with the credentials defined in the previous cell.
pinecone.init(api_key=API_KEY, environment=ENVIRONMENT)

In [35]:
index_name = 'generative-text-comprehension-qa'

# Create the index only if it is not already listed. Its vector size is taken
# from the encoder so the two always agree, and similarity is cosine.
if index_name not in pinecone.list_indexes():
  pinecone.create_index(index_name, dimension=retriver.get_sentence_embedding_dimension(), metric='cosine')

# Handle used by the query helper defined further down.
index = pinecone.Index(index_name)

In [36]:
from transformers import BartTokenizer, BartForConditionalGeneration

# Tokenizer and seq2seq model are loaded from the same checkpoint so their
# vocabularies match.
tokenizer = BartTokenizer.from_pretrained(
    pretrained_model_name_or_path='vblagoje/bart_lfqa'
)
generator = BartForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path='vblagoje/bart_lfqa'
)

In [37]:
def query_pinecone(query, top_k):
  """Embed `query` and return the index's response for its nearest matches.

  The question is encoded as a one-item batch with the module-level
  `retriver`, converted to a nested list, and passed to the module-level
  `index`, asking for `top_k` results with their metadata included.

  Args:
    query: The question text to search for.
    top_k: Number of matches to request from the index.

  Returns:
    Whatever `index.query` returns, unmodified.
  """
  query_embedding = retriver.encode([query]).tolist()
  return index.query(query_embedding, top_k=top_k, include_metadata=True)

In [38]:
def format_query(query, context):
  """Build the generator prompt from a question and its retrieved matches.

  Args:
    query: The question text.
    context: Iterable of matches; each must provide
      `match["metadata"]["passage_text"]`.

  Returns:
    A string of the form `question: <query> context: <P> p1 <P> p2 ...`,
    with every passage prefixed by `<P> ` and passages separated by one space.
  """
  passages = " ".join(
      f'<P> {match["metadata"]["passage_text"]}' for match in context
  )
  return f'question: {query} context: {passages}'

In [39]:
from pprint import pprint

In [40]:
def generate_answer(query, num_beams=2, min_length=20, max_length=40):
  """Generate an answer for a prompt with the module-level BART `generator`.

  Args:
    query: Prompt string, e.g. as produced by `format_query`.
    num_beams: Beam width passed to `generator.generate`.
    min_length: Minimum generated length passed to `generator.generate`.
    max_length: Maximum generated length passed to `generator.generate`.
      Raise it if answers are cut off mid-sentence.

  Returns:
    The decoded text of the first generated sequence, with special tokens
    (such as `<s>` / `</s>`) removed.
  """
  # truncation=True is what makes max_length take effect; without it the
  # tokenizer returns over-long prompts unshortened.
  inputs = tokenizer([query], max_length=1024, truncation=True, return_tensors='pt')
  ids = generator.generate(
      inputs['input_ids'],
      attention_mask=inputs['attention_mask'],
      num_beams=num_beams,
      min_length=min_length,
      max_length=max_length,
  )
  # The keyword is `skip_special_tokens` (plural). The misspelled singular
  # form left the special tokens in the decoded text.
  answers = tokenizer.batch_decode(
      ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
  )
  return answers[0]

In [44]:
# Ask a question and fetch the two closest matches from the index.
query = 'when was Michael Jackson born'
context = query_pinecone(query, top_k=2)
# NOTE: `query` is rebound here from the bare question to the full generator
# prompt (question + retrieved passages); the next cell relies on that.
query = format_query(query, context['matches'])

In [45]:
# Run the generator on the formatted prompt held in `query`.
generated_answer = generate_answer(query)

In [46]:
# Remove any literal '</s>' markers from the decoded text before printing it.
ans = generated_answer.replace('</s>','')
print(ans)

Michael Jackson was born on January 21, 1981 in Los Angeles, California. He was the youngest child of Michael Jackson and Katherine Jackson, and the son of Michael Jackson's first wife, Katherine
