# RAG receiver notebook

This notebook serves as the interface for experimental interaction with the RAG application.

In [1]:
import requests
import json
import pandas as pd
import os

Provide the address of the endpoint - in our case we used an ngrok tunnel to access the hosted models.

In [3]:
# Base address of the hosted service, taken from the NG_LINK environment
# variable; an empty string is used when the variable is not set.
ng_link = os.environ.get("NG_LINK", "")

In [4]:
# Full URLs of the two endpoints used below, built from the base address.
request_post_completion = f"{ng_link}/v1/completions"
request_post_search = f"{ng_link}/v1/search"

In [None]:
# Load the context table from disk. Later cells look rows up by the
# "context_uid" column and read the neighbouring rows' "context" column.
df = pd.read_csv("data_cleaned.csv")

In [None]:
def get_search_results(question):
    """
    Retrieves search results for a given question using a POST request.

    The question is sent as ``{"query": question}`` to the search endpoint
    and the ``"result"`` object of the JSON response is inspected.

    Args:
        question (str): The question to search for.

    Returns:
        tuple: A tuple ``(context, context_uid)`` read from the ``"result"``
            object of the response. Each element falls back to ``""`` when
            the corresponding key is missing.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        KeyError: If the JSON response has no ``"result"`` key.
    """
    payload = {"query": question}
    response = requests.post(request_post_search, json=payload)
    # Surface HTTP failures directly instead of letting an error page reach
    # the JSON parsing below and fail with a misleading decode/KeyError.
    response.raise_for_status()
    results = response.json()["result"]
    return results.get("context", ""), results.get("context_uid", "")

def generate_model_response(user_prompt, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95):
    """
    Generates a model's response based on the user prompt using completion API.

    The prompt is wrapped as ``<|role|>content<|end|>`` and posted to the
    completion endpoint. The reply's ``"text"`` field is cut at the first
    ``<|end|>`` marker and any ``"<|assistant|> "`` tag is removed.

    Args:
        user_prompt (dict): Dictionary with "role" and "content" for the prompt.
        max_new_tokens (int, optional): The maximum number of new tokens to generate
            (sent as ``max_tokens``). Defaults to 256.
        temperature (float, optional): Sampling temperature, higher values make output
            more random. Defaults to 0.7.
        top_k (int, optional): Number of highest probability tokens considered for
            sampling. Defaults to 50. Currently NOT sent to the endpoint.
        top_p (float, optional): Cumulative probability threshold for token selection.
            Defaults to 0.95. Currently NOT sent to the endpoint.

    Returns:
        str: The generated response from the model.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        KeyError: If the JSON response has no ``"text"`` key.
    """
    role = user_prompt["role"]
    content = user_prompt["content"]
    prompt_full = f"<|{role}|>{content}<|end|>"
    # NOTE(review): top_k / top_p are accepted for interface compatibility but
    # are not part of the payload; confirm the endpoint supports them before
    # forwarding.
    payload = {"prompt": prompt_full, "max_tokens": max_new_tokens, "temperature": temperature}
    # The whole body is read at once, so a streamed request brings nothing.
    response = requests.post(request_post_completion, json=payload)
    # Surface HTTP failures directly instead of failing later while parsing.
    response.raise_for_status()
    model_output = response.json()["text"]
    # Keep only what precedes the first end marker, then drop the role tag.
    initial_response = model_output.split("<|end|>")[0]
    return initial_response.replace("<|assistant|> ", "")

def get_summary(context):
    """
    Asks the model for a summary of the given context.

    The context is embedded in a fixed summarisation instruction and sent
    through ``generate_model_response`` with ``max_new_tokens=512`` and
    ``temperature=0.15``. Double newlines in the reply are collapsed to
    single newlines.

    Args:
        context (str): The context to summarize.

    Returns:
        str: The generated summary.
    """
    instruction = "Given the following text, generate a summary getting the most important points within 6 to 11 sentences.\nText: "
    prompt = {
        "role": "user",
        "content": "".join((instruction, context, "\nSummary:")),
    }
    raw_summary = generate_model_response(prompt, max_new_tokens=512, temperature=0.15)
    return raw_summary.replace("\n\n", "\n")

def get_answered_question(question, summaries=None):
    """
    Provides an answer to the given question, optionally using additional summaries.

    The question is first sent to the search endpoint to obtain a context and
    its UID. The page is taken as the part of the UID before the first
    underscore. When summaries are given, their values are appended to the
    context (newline-separated) before the model is prompted.

    Args:
        question (str): The question to be answered.
        summaries (dict, optional): Additional summaries to include in the context;
            only the values are used. Defaults to None.

    Returns:
        tuple: A tuple containing the answer, context, context page, and context UID.
            The returned context includes the appended summaries, if any.
    """
    context, context_uid = get_search_results(question)
    context_page = context_uid.split("_")[0]
    if summaries:
        context = context + "\n" + "\n".join(summaries.values())
    instructions = (
        "You are going to receive a description of a medical task from the user. "
        "It is going to be combined with a text of reference providing you with source of truth. "
        "Give suggestions basing you answer only on this text. "
        "Do not hallucinate! "
        "Write only the answer of the question without more information."
    )
    # Use the `question` argument, not a notebook-level global, so the
    # function works for any caller.
    content = instructions + "\nQuestion: " + question + "\nText: " + context + "\nSuggestions:"
    user_prompt = {"role": "user", "content": content}
    answer = generate_model_response(user_prompt, max_new_tokens=512, temperature=0.15)
    return answer, context, context_page, context_uid

def get_followup_question(question, context_uid, df=df):
    """
    Generates a follow-up question based on the given question and context UID.

    The row whose ``context_uid`` matches is located in ``df`` and its
    ``context`` is joined (newline-separated) with the contexts of the row
    before and the row after, where those exist. This broad context and the
    question are sent to the model, and only the text before the first blank
    line of the reply is kept.

    Args:
        question (str): The initial question.
        context_uid (str): Context UID to find detailed context in the DataFrame.
        df (pandas.DataFrame): DataFrame containing the context data; must have
            "context_uid" and "context" columns.

    Returns:
        tuple: A tuple containing the follow-up question and broad context.

    Raises:
        IndexError: If no row of ``df`` has the given ``context_uid``.
    """
    # Work with positions (not index labels) because the slice below is
    # positional; this stays correct for a DataFrame without a default index.
    positions = (df["context_uid"] == context_uid).to_numpy().nonzero()[0]
    if positions.size == 0:
        raise IndexError(f"context_uid {context_uid!r} not found in the DataFrame")
    position = int(positions[0])
    # Clamp at 0: a negative start would count from the end and yield an
    # empty slice when the match is the first row.
    start = max(position - 1, 0)
    broad_context = "\n".join(df.iloc[start:position + 2]["context"].values)
    followup_prompt = "Given the initial question and the text that you have received, generated a follow-up question to better understand the case. Ask questions only related to the task for clarification. Do not ask exam style question, but ask questions requiring contextual information from the user."
    # Use the `question` argument, not a notebook-level global, so the
    # function works for any caller.
    content = (
        followup_prompt
        + " Do not hallucinate! Write only the follow-up question without more information.\nQuestion: "
        + question
        + "\nText: "
        + broad_context
        + "\nFollow-up question:"
    )
    user_prompt = {"role": "user", "content": content}
    followup_question = generate_model_response(user_prompt, max_new_tokens=128, temperature=0.15)
    # Keep only the first paragraph of the reply.
    followup_question = followup_question.split("\n\n")[0]
    return followup_question, broad_context

In [None]:
# Interactive session: answer the user's request, then keep asking a
# follow-up question and re-answering until the user is satisfied.
initial_situation = input("Describe what you need my help with.\n")
visited_pages = {}  # context page -> number of answers taken from it
passes = 0  # number of feedback rounds so far
summaries = {}  # context_uid -> summary of the broad context around it
# A later version would have a protection, query splitter, domain checker etc. here.
question_txt = initial_situation
answer, context, context_page, context_uid = get_answered_question(question_txt)
while True:
    # Record and show the current answer (first answer and every retry alike).
    # TODO: add the context chunk index too
    visited_pages[context_page] = visited_pages.get(context_page, 0) + 1
    final_ouput = answer + "\nTaken from page: " + str(context_page)
    print("Context: " + context)
    print("============================")
    print("Answer: " + final_ouput)

    passes += 1
    reply = input("Are you happy with the answer?\n")
    # Accept "Yes" regardless of case or surrounding whitespace (and "y").
    satisfactory_answer = reply.strip().lower() in ("yes", "y")
    if satisfactory_answer:
        print("Pleasure to help!")
        print("The pages visited were the following:")
        for key, value in visited_pages.items():
            print("Page number " + str(key) + " was visited " + str(value) + " times")
        break

    # Not satisfied: ask a clarifying question based on the neighbouring
    # context, summarise that context for later rounds, and answer again
    # using the user's reply as the new query.
    followup, broad_context = get_followup_question(question_txt, context_uid)
    print("Follow-up question: " + followup)
    user_answer = input("Your answer: ")
    summaries[context_uid] = get_summary(broad_context)
    question_txt = user_answer
    answer, context, context_page, context_uid = get_answered_question(question_txt, summaries=summaries)