In [None]:
# !pip install -q torch torchvision torchaudio transformers datasets accelerate bitsandbytes langchain sentence-transformers faiss-gpu openpyxl pacmap ragatouille

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# First Part
import os
import json 
import re
import pickle
import jsonlines
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Second Part
import torch
from tqdm import tqdm
from typing import Optional, List, Tuple
from langchain.vectorstores import FAISS
from ragatouille import RAGPretrainedModel
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.docstore.document import Document as LangchainDocument

pd.set_option("display.max_colwidth", None)  # This will be helpful when visualizing retriever outputs

In [None]:
# Project root folder. Keep the trailing slash: the dataset path below is
# formed by plain string concatenation, not by a path-joining helper.
# CHANGE THIS TO YOUR PATH - [NICOLAS]
path = "/content/drive/MyDrive/mnlpredators-project/"
full_preference_pairs_path = f"{path}data/full_preference_pairs.json"

### Preference Pairs Dataset - Questions Extraction

In [None]:
# Read the preference-pair records; the file is parsed as one JSON document
# in "records" orientation (lines=False, i.e. not JSON Lines).
full_preference_pairs = pd.read_json(
    full_preference_pairs_path,
    orient="records",
    lines=False,
)
# Print the first row only, as a quick look at what was loaded.
print(full_preference_pairs.head(1))

In [None]:
# print('Number of questions:', len(full_preference_pairs))

In [None]:
# print 3 full questions
# print(full_preference_pairs[['question_id','course_id','question_complete']].iloc[10])
# print(full_preference_pairs[['question_id','course_id','question_complete']].iloc[20])
# print(full_preference_pairs[['question_id','course_id','question_complete']].iloc[30])

### Initialization of Reranker and Embedding Models

In [None]:
# Checkpoint name handed to HuggingFaceEmbeddings when the embedding model is built
EMBEDDING_MODEL_NAME = "thenlper/gte-small"

# Pretrained "colbert-ir/colbertv2.0" model loaded through RAGatouille;
# passed as the `reranker` argument of the retrieval helper
RERANKER = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU so this
# cell does not fail on CPU-only runtimes.
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): multi_process=True is kept exactly as before; confirm it is
# still wanted now that this notebook only embeds one query at a time.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    multi_process=True,
    model_kwargs={"device": EMBEDDING_DEVICE},
    encode_kwargs={"normalize_embeddings": True},  # Set `True` for cosine similarity
)

### Loading of the Embedding Vector Database

In [None]:
# Load the prebuilt vector index from the local "faiss_index_8_keywords" folder,
# using the same embedding model for query encoding.
# SECURITY NOTE: allow_dangerous_deserialization=True opts in to unpickling the
# saved index data, so only load index folders that come from a trusted source.
KNOWLEDGE_VECTOR_DATABASE = FAISS.load_local(
    folder_path="faiss_index_8_keywords",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True,
)

### Getting the Most Relevant Documents

In [None]:
def get_most_relevant_document(
    question: str,
    knowledge_index: FAISS,
    reranker: Optional[RAGPretrainedModel] = None,
    num_retrieved_docs: int = 15,
    num_docs_final: int = 5,
) -> Tuple[List[str], str]:
    """Retrieve passages for ``question`` and format them as a prompt context.

    Args:
        question: Query text used for the similarity search (and for reranking).
        knowledge_index: Vector store exposing ``similarity_search(query=..., k=...)``
            and returning objects with a ``page_content`` attribute.
        reranker: Optional object exposing ``rerank(question, documents, k=...)``
            that returns a list of dicts with a ``"content"`` key. Skipped when
            falsy or when the retriever returned nothing.
        num_retrieved_docs: Number of candidates requested from the index.
        num_docs_final: Maximum number of passages kept in the result.

    Returns:
        A ``(relevant_docs, context)`` pair: the list of selected passage texts
        and a single string starting with ``"\\nExtracted documents:\\n"``
        followed by one ``"\\nDocument <i>:::\\n<text>"`` section per passage.

    Note:
        When more than ``num_docs_final`` passages remain (which is the case
        without a reranker), the final passages are drawn with
        ``random.sample``, so both the selection and its order vary between
        calls unless the ``random`` module is seeded.
    """
    # Gather candidate passages with the retriever and keep only their text.
    retrieved = knowledge_index.similarity_search(query=question, k=num_retrieved_docs)
    relevant_docs = [doc.page_content for doc in retrieved]

    # Optionally rerank. Skip the call when there is nothing to rank, and never
    # ask the reranker for more results than there are candidates.
    if reranker and relevant_docs:
        top_k = min(num_docs_final, len(relevant_docs))
        reranked = reranker.rerank(question, relevant_docs, k=top_k)
        relevant_docs = [doc["content"] for doc in reranked]

    # Randomly sample num_docs_final documents when too many are left.
    if len(relevant_docs) > num_docs_final:
        relevant_docs = random.sample(relevant_docs, num_docs_final)

    # Build the context section that is inserted into the final prompt.
    context = "\nExtracted documents:\n" + "".join(
        f"\nDocument {i}:::\n{doc}" for i, doc in enumerate(relevant_docs)
    )

    return relevant_docs, context

#### Testing the Relevance of the Document on a Simple Question

In [None]:
# Smoke test on one hand-written question: retrieve, rerank with RERANKER,
# and print the resulting context block.
question = "What is a good distance metric to be used when you want to compute the similarity between documents independent of their length?"
relevant_docs, context = get_most_relevant_document(
    question=question,
    knowledge_index=KNOWLEDGE_VECTOR_DATABASE,
    reranker=RERANKER,
)
print(f"The context is: {context}")

### Generation of an Answer with GPT-3.5

In [None]:
# Client library whose Chat class is used below to create chats and send prompts
import gpt_wrapper
from gpt_wrapper.chat import Chat
from dotenv import load_dotenv
# Load variables from a .env file into the process environment
# (presumably the gpt_wrapper credentials; TODO confirm which keys are expected).
load_dotenv()

In [None]:
# Generation settings forwarded unchanged as `model_args` to each `chat.ask` call.
model_args = {
    "temperature": 0.7,
    "top_p": 0.7,
    "presence_penalty": 0.0,
    "frequency_penalty": 0.0,
    "max_new_tokens": 1024,
}

In [None]:
def initial_prompt(question, context):
    """Build the user prompt sent to the chat model.

    The prompt has four lines: the quoted question, the quoted retrieved
    context, an instruction to name the option ID (A, B, C or D) when the
    question has options, and a request for step-by-step reasoning.

    Every line after the first starts with eight spaces, and the context line
    ends with a single trailing space; both are part of the emitted text.
    """
    continuation_indent = " " * 8
    prompt_lines = [
        f'Answer the following question: "{question}".',
        f'{continuation_indent}Use the following context if you deem necessary: "{context}". ',
        f"{continuation_indent}If the question has options, specify the ID of the correct answer (A, B, C or D).",
        f"{continuation_indent}Think step by step and explain your reasoning",
    ]
    return "\n".join(prompt_lines)

In [None]:
def generate_predictions_zero_shot(
    questions,
    model_args,
    output_path="data_wikipedia/rag_dataset_gpt3.5.jsonl",
    verbose=True,
):
    """Answer every question with retrieved context and log the answers to JSONL.

    For each question a new chat is created, context is retrieved from
    KNOWLEDGE_VECTOR_DATABASE without a reranker, the prompt is built with
    ``initial_prompt`` and sent through ``chat.ask``. One JSON object per
    question is written to ``output_path`` as soon as its answer arrives, so
    the rows written before an interruption are kept on disk.

    Args:
        questions: Iterable of dicts with the keys ``"question_complete"``,
            ``"course_id"`` and ``"question_id"``.
        model_args: Generation settings forwarded unchanged to ``chat.ask``.
        output_path: JSONL file to write. It is opened in ``"w"`` mode, so an
            existing file at that path is overwritten. Missing parent
            directories are created.
        verbose: When true (the default), print each prompt and each answer.

    Returns:
        The list of stripped answers in question order, with ``"none"``
        standing in for an empty answer. The JSONL ``"answer"`` field keeps
        the stripped answer as received (possibly an empty string).
    """
    instruction = "You are a helpful educational AI bot that answers questions for a student. Keep your response truthful and concise"

    # Create the output folder up front so a missing directory does not abort
    # the run before the first question is processed.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    predictions = []
    with jsonlines.open(output_path, mode="w") as writer:
        for question_dict in tqdm(questions):
            question = question_dict["question_complete"]  # Extract question text

            # Random chat name; drawn from 2**16 values, so not guaranteed unique.
            chat_id = random.randrange(0, 2**16)
            chat = Chat.create(name=f"{chat_id}")

            # No reranker, so the final passages are sampled at random from the
            # retrieved candidates and differ between questions and runs.
            _, context = get_most_relevant_document(question, KNOWLEDGE_VECTOR_DATABASE, reranker=None)
            prompt = initial_prompt(question, context)
            if verbose:
                print("The final prompt is:\n", prompt)

            message = chat.ask(prompt, instruction=instruction, model_args=model_args)
            answer = message.content.strip()
            if verbose:
                print("Predicted answer:", answer)

            predictions.append(answer or "none")

            writer.write(
                {
                    "course_id": question_dict["course_id"],
                    "question_id": question_dict["question_id"],
                    "question_body": question,
                    "answer": answer,
                    "chat_id": chat_id,
                }
            )

    return predictions


### Final Generation - Full Dataset of 1522 Questions (~2h30-3h00)

In [None]:
# Turn the DataFrame into one dict per row, which is what the generation loop iterates over.
questions = full_preference_pairs.to_dict(orient="records")
# Bind the call's result to a name so the cell never echoes it as output.
predictions = generate_predictions_zero_shot(questions, model_args)