In [12]:
import glob
import sys
import os
import json
import pickle
import logging
import ast

import openai
from sentence_transformers import SentenceTransformer
from dotenv import find_dotenv, load_dotenv
from sentence_transformers import CrossEncoder
from pinecone import Pinecone, ServerlessSpec
from helpers import chat

# Project root on the machine this notebook was run on; it is used below for both
# the module search path and the working directory.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
root_path = '/home/ec2-user/sarang/wiki_cheat'

# Put the project root first on sys.path and make it the working directory, so the
# relative model path defined later in this notebook resolves against it.
# NOTE(review): `from helpers import chat` in the import list above executes before
# this insert; if `helpers` lives under root_path, that import presumably only works
# when the kernel already starts in that directory — confirm on a fresh kernel.
sys.path.insert(0, os.path.abspath(root_path))
os.chdir(root_path)

### Load the environment file
# find_dotenv() is called after the chdir above.
# NOTE(review): no logging configuration is visible in this notebook, so the INFO
# record below may never be shown — confirm.
env_file_path = find_dotenv()
logging.info(f'env_file_path: {env_file_path}')
load_dotenv(env_file_path)

### API keys and tokens needed for interaction with external APIs
# Both keys are read from the environment; os.getenv returns None when a variable
# is unset, and nothing here checks for that.
openai.api_key = os.getenv('OPENAI_API_KEY')
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'), pool_threads=30)

## Model and index Paths

In [15]:
# Relative path handed to SentenceTransformer below; it resolves against the working
# directory set by os.chdir(root_path) in the setup cell.
encoder_model_path = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-27_20-14-10'
# Identifier handed to CrossEncoder below.
# NOTE(review): looks like a Hugging Face Hub model id rather than a local path — confirm.
reranker_model_path = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
# Model name passed to the `chat` helper by read_results.
reader_model = "gpt-4-0125-preview"
# Name of the Pinecone index opened with pc.Index below.
index_name = "wiki-all-mpnet-base-v2-trained"

## Load relevant Models

In [16]:
# Query encoder used by retrieve_results.
# NOTE(review): device is hard-coded to 'cuda' for both models; presumably this cell
# fails on a machine without a GPU — confirm.
encoder_model = SentenceTransformer(encoder_model_path, device='cuda')
# Cross-encoder used by rerank_results to score (query, passage) pairs.
reranker_model = CrossEncoder(reranker_model_path, max_length=512, device='cuda')
# Handle to the Pinecone index that retrieve_results queries.
index = pc.Index(index_name)

## Helper methods for retrieving, reranking and reading of results

In [32]:
def retrieve_results(query, top_k=100):
    """Embed ``query`` and fetch its nearest passages from the vector index.

    Args:
        query: Question text; passed unchanged to ``encoder_model.encode``.
        top_k: Number of matches to request from the index. Defaults to 100,
            the value that used to be hard-coded here.

    Returns:
        The ``'matches'`` entry of the index query response, with metadata
        included for every match.
    """
    # encode() and its options are unchanged; only the query size is now configurable.
    query_emb = encoder_model.encode(query, show_progress_bar=True, batch_size=512).tolist()
    response = index.query(
        vector=query_emb,
        top_k=top_k,
        include_metadata=True
    )
    return response['matches']

def rerank_results(query, results):
    """Re-order retrieved results by cross-encoder relevance to ``query``.

    Args:
        query: Question text paired with every passage.
        results: Retrieved results; each must expose ``['metadata']['text']``
            and accept item assignment.

    Returns:
        A new list holding the same result objects, highest score first.
        Results with equal scores keep their incoming relative order. Each
        result gets its score written under ``'rerank_score'``, so the items
        of ``results`` are modified in place.
    """
    # Nothing to score: return early so the reranker is never given an empty batch.
    if len(results) == 0:
        return []

    query_passage_pairs = [(query, res['metadata']['text']) for res in results]
    scores = reranker_model.predict(query_passage_pairs).tolist()

    # sorted() is stable, also with reverse=True, so ties preserve retrieval order.
    ranked = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)

    reranked_results = []
    for score, res in ranked:
        res['rerank_score'] = score
        reranked_results.append(res)
    return reranked_results

def get_reader_prompt():
    """Return the prompt template used for the reader step.

    The template contains two ``str.format`` placeholders, ``{input_question}``
    and ``{input_passages}``, and asks for a single dictionary with the keys
    'id', 'answer' and 'reason' (id -1 when no passage contains the answer).

    The wording says "10 passages"; it is not adjusted when a caller sends a
    different number of passages.
    """
    # CONSTRAINT 3 used to ask for "a list of dictionaries", contradicting CONSTRAINT 1;
    # it now asks for the same single dictionary.
    prompt = """You are working as a Question Answering assistant. You will be given a question along with 10 passages. 
    You need to find the id of the passage that contains the answer to the question asked, 
    answer the question and provide a step by step reasoning for your answer.
    --- INPUT 1: Input Question. This input is delimited by triple @.
    @@@{input_question}@@@
    --- INPUT 2: Set of passages. This input is delimited by triple @.
    @@@{input_passages}@@@
    --- TASK 1: You need to find the id of the passage that contains the answer to the question asked, 
    answer the question and provide a step by step reasoning for your answer.
    --- CONSTRAINT 1: You should return a dictionary. That dictionary should have 3 keys. 
    First key is the 'id' which is the id of the passage that contains the answer, second key is 'answer' which is the answer to the question, 
    and Third key is the 'reason' which represents a step by step 
    reasoning for your answer.
    --- CONSTRAINT 2: If the answer is not present in any of the input passages you should return a dictionary with id as -1, answer as None and reason as Not present.
    --- CONSTRAINT 3: You should not add any extra text before/after the json object that you return. Just return the dictionary without any text before or after.
    """
    return prompt
    
def read_results(query, reranked_results, read_top_k=10):
    """Ask the reader model which of the top passages answers ``query``.

    Args:
        query: Question text.
        reranked_results: Results ordered best-first; each must expose
            ``['metadata']['title']`` and ``['metadata']['text']``.
        read_top_k: How many leading results are shown to the reader.

    Returns:
        A ``(result, answer)`` tuple. ``result`` is the passage the reader
        picked and ``answer`` its answer text. When the reader call fails, its
        reply cannot be parsed, or the returned id does not refer to one of
        the passages shown (including the -1 "not present" id), the tuple is
        ``(reranked_results[0], "")``.

    Raises:
        IndexError: If ``reranked_results`` is empty.
    """
    if len(reranked_results) == 0:
        # Same exception type as before, but raised before the reader is called.
        raise IndexError("read_results needs at least one reranked result")

    candidates = reranked_results[:read_top_k]
    relevant_passages = [
        {'id': idx, 'text': res['metadata']['title'] + ": " + res['metadata']['text']}
        for idx, res in enumerate(candidates)
    ]
    prompt = get_reader_prompt().format(input_question=query, input_passages=relevant_passages)

    fallback = (reranked_results[0], "")
    try:
        completion = chat(prompt, model=reader_model, response_format="json_object")
        content = completion.choices[0].message.content
        try:
            # The reply is requested as JSON; literal_eval cannot parse null/true/false.
            parsed = json.loads(content)
        except ValueError:
            # Still accept a Python-literal style dict (e.g. single-quoted keys).
            parsed = ast.literal_eval(content)
        passage_id = int(parsed['id'])
        answer_text = parsed['answer']
    except Exception:
        # Best effort: keep the pipeline running, but leave a trace of the failure.
        logging.warning("Reader step failed for query %r; using the top reranked passage",
                        query, exc_info=True)
        return fallback

    # -1 means "not present"; as a list index it would silently select the last result.
    # Ids outside the passages actually shown to the reader are rejected as well.
    if not 0 <= passage_id < len(candidates):
        return fallback

    return candidates[passage_id], ("" if answer_text is None else answer_text)

def run_pipeline(query):
    """Run retrieve -> rerank -> read for one question and print a short report.

    Args:
        query: Question text.

    Returns:
        A dict with the keys "question", "answer", "wikipedia_page" and
        "paragraph"; the last two come from the selected result's metadata
        ('url' and 'text').
    """
    candidates = retrieve_results(query)
    ranked = rerank_results(query, candidates)
    result, answer = read_results(query, ranked)

    print(f"Question: {query}", "----------------", sep="\n")
    # The answer line is omitted when the reader produced no answer.
    if answer:
        print(f"Answer: {answer}")

    meta = result['metadata']
    print(f"Wikipedia page: {meta['url']}")
    print(f"Evidence paragraph: {meta['text']}")
    print("################")

    return dict(
        question=query,
        answer=answer,
        wikipedia_page=meta['url'],
        paragraph=meta['text'],
    )

## Perform Dry run

In [35]:
# Sample questions for an end-to-end smoke test of the pipeline.
queries = ["How did anne frank die?",
          "How are glaciers formed?",
          "How are glacier caves formed?",
          "What is a desktop computer and where is it typically placed?",
          "When did the Western Allies invade mainland Italy during World War II?",
          "What elements make up the Earth's core (approximately) ?"]

# Each call prints its own report and returns a dict; the dicts are collected here.
# Every query goes through retrieval, reranking and the reader model.
results = [ run_pipeline(query) for query in queries]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Question: How did anne frank die?
----------------
Answer: Anne and her older sister, Margot, died at Bergen-Belsen concentration camp from typhus.
Wikipedia page: https://simple.wikipedia.org/wiki?curid=6312
Evidence paragraph: However, that was not to be. Anne's father, Otto Frank, lived through the war and came back to Amsterdam. He hoped that his family had survived too - but they had not. Of all the family, only he survived. His wife was killed at Auschwitz. Anne and her older sister, Margot, died at Bergen-Belsen concentration camp from typhus, a disease - only a month before the camp was freed by the Allied forces. When he got out, he found Anne's diary and published it.
################


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Question: How are glaciers formed?
----------------
Answer: Glaciers are formed because the snow in an area does not all melt in summer. Each winter, more snow is added, and the weight of all the snow creates pressure. This pressure turns the lower parts of the snow into ice. After many years, the glacier starts growing large and moves due to gravity.
Wikipedia page: https://simple.wikipedia.org/wiki?curid=34576
Evidence paragraph: A glacier is a large body of ice and snow. It forms because the snow in an area does not all melt in summer. Each winter, more snow is added. The weight of all the snow creates pressure. This pressure turns the lower parts of the snow into ice. After this happens for many years, the glacier will start growing large. It becomes so heavy that gravity causes the ice to move. It flows downwards like water but very slowly. A glacier only moves about per year. New snowfalls replace the parts that flow away.
################


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Question: How are glacier caves formed?
----------------
Answer: Glacier caves are formed by ice and glaciers.
Wikipedia page: https://simple.wikipedia.org/wiki?curid=214942
Evidence paragraph: A cave is a natural underground hollow space. They can have narrow passageways (corridors) and chambers (caverns). They are usually formed when underground acidic (sour) water wears away softer stones, such as limestone. Only the hard rock, such as granite, is left. Caves can also be formed during natural catastrophes, such as earthquakes, or by ice and glaciers.
################


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Question: What is a desktop computer and where is it typically placed?
----------------
Answer: A desktop computer is a small machine that has a screen (which is not part of the computer) and is typically placed on top of a desk.
Wikipedia page: https://simple.wikipedia.org/wiki?curid=112
Evidence paragraph: A "desktop computer" is a small machine that has a screen (which is not part of the computer). Most people keep them on top of a desk, which is why they are called "desktop computers." "Laptop computers" are computers small enough to fit on your lap. This makes them easy to carry around. Both laptops and desktops are called personal computers, because one person at a time uses them for things like playing music, surfing the web, or playing video games.
################


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Question: When did the Western Allies invade mainland Italy during World War II?
----------------
Answer: 3 September 1943
Wikipedia page: https://simple.wikipedia.org/wiki?curid=429305
Evidence paragraph: The Allied invasion of Italy was the invasion of mainland Italy by the Allies during World War II. The Allies landed on the mainland on 3 September 1943. The invasion followed the successful invasion of Sicily during the Italian Campaign.
################


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Question: What elements make up the Earth's core (approximately) ?
----------------
Answer: Iron (88.8%), nickel (5.8%), sulfur (4.5%), and less than 1% other things
Wikipedia page: https://simple.wikipedia.org/wiki?curid=219
Evidence paragraph: The structure of Earth changes from the inside to the outside. The center of earth (Earth's core) is mostly iron (88.8%), nickel (5.8%), sulfur (4.5%), and less than 1% other things. The Earth's crust is largely oxygen (47%). Oxygen is normally a gas but it can join with other chemicals to make compounds like water and rocks. 99.22% of rocks have oxygen in them. The most common oxygen-having rocks are silica (made with silicon), alumina (made with aluminium), rust (made with iron), lime (made with calcium), magnesia (made with magnesium), potash (made with potassium), and sodium oxide, and there are others as well.
################


## Some Limitations and Future improvement scope

### Query Reformulation and Dataset generation
#### 1. We can rewrite the user's queries and further decompose them to find answers that could be spread 
####    across multiple passages and articles.
#### 2. We can generate a synthetic dataset that requires the model to perform multi-hop reasoning.
#### 3. We can generate a much larger dataset.
#### 4. There could be other interesting ways of chunking the Wikipedia articles, instead of just using passages. 
####    We can keep an overlap between adjacent passages to further enhance context. 

### RAG pipeline
#### 1. We can introduce an augment step where we can generate a set of questions that need to be answered first
####    in order to find the answer to the original question
#### 2. Instead of showing the complete paragraph as the evidence, we can show the exact part of the paragraph which 
####    contains the answer 
#### 3. In the reader stage, instead of sending in 10 passages from k different articles, we can send all passages 
####    of these k different articles to provide a better context to the LLM
#### 4. We could compute embeddings of entire Wikipedia articles and store them in our vector store. We could then 
####    use a 2-step process to first narrow down to the correct article and then to the right passage 
####    in that article.

### Modelling
#### 1. Experiments with other Embedding models, loss functions, exhaustive hyper-parameter search
#### 2. Benchmarking using other commercial embedders. Ex - OpenAI just released their new v3 embedding model
#### 3. We can parse the Wikipedia articles, extract their headings and use the section hierarchy to enhance the 
####    context of the passages.
#### 4. Instead of using encoder-based models for generating sentence embeddings, we can use an LLM (decoder-only)
####    to generate embeddings. There is an e5-mistral-7b-instruct model which was trained in this manner and is sitting
####    at the top of the embedding benchmarks.
#### 5. We can instruction fine-tune a smaller LLM like phi-2 to perform the reader task instead of using GPT-4 there. 

### This field is evolving at such a rapid pace that we are bombarded with new ideas and content every other day. 
### So let's keep our heads down, learn from the community and keep growing. Thank you. 