In [None]:
import json
import os
import time

from openai import OpenAI
from pydantic import BaseModel, Field

# Shared OpenAI client; requires OPENAI_API_KEY in the environment (KeyError otherwise).
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# Vector store holding the files used for RAG. Set OPENAI_VECTOR_STORE_ID to
# point at your own OpenAI vector store; the hard-coded ID below is only the
# fallback used when that variable is unset.
# See Create_Vector_Store_Example.ipynb in the repository for how to create one.
VECTOR_STORE_ID = os.environ.get(
    "OPENAI_VECTOR_STORE_ID", "vs_67da9f09a6b48191a32189befe73c49e"
)
vector_store = client.vector_stores.retrieve(vector_store_id=VECTOR_STORE_ID)

class rag_format(BaseModel):
    # JSON schema for the retrieval stage's reply: the ranked, quoted passages
    # and the '.pdf' files they were taken from. Comments (not a docstring)
    # are used so the generated JSON schema is unchanged.
    Ranked_Relevant_Information: str = Field(
        ...,
        description="The ranked pieces of information that will be directly relevant for answering the query.",
    )
    File_Sources: str = Field(
        ...,
        description="The filenames of the files from which the information was retrieved, with the format '{...}.pdf'.",
    )

class answer_format(BaseModel):
    # JSON schema for the answering stage's reply: one free-text response.
    # Comments (not a docstring) are used so the generated JSON schema is unchanged.
    Response: str = Field(
        ...,
        description="The answer to the question/prompt using the given information only.",
    )


def _pick_text(data, keys, default):
    """Return the first non-empty value stored under any of *keys* in *data*.

    Args:
        data: Decoded JSON object (dict) from an assistant reply.
        keys: Candidate key spellings, tried in order.
        default: Text returned when no key holds a non-empty value.
    Returns:
        The value as a string; non-string JSON values are re-serialised with
        ``json.dumps`` so they can safely be concatenated into a prompt.
    """
    for key in keys:
        value = data.get(key)
        if value:
            if isinstance(value, str):
                return value
            return json.dumps(value, ensure_ascii=False)
    return default


def _last_json_reply(messages):
    """Return the last assistant text part that decodes to a JSON object.

    Args:
        messages: Thread messages in chronological order (objects with
            ``role`` and ``content`` attributes).
    Returns:
        The decoded dict, or ``None`` if no assistant reply contains one.
    """
    result = None
    for message in messages:
        # Skip the user's prompt: it may itself start with "{".
        if message.role != "assistant":
            continue
        for part in message.content:
            # Only text parts carry ``.text.value``; ignore images and the like.
            if part.type != "text":
                continue
            output = part.text.value.lstrip()
            if not output.startswith("{"):
                continue
            try:
                data = json.loads(output)
            except json.JSONDecodeError:
                print("Could not decode JSON reply", end="\r", flush=True)
                continue
            if isinstance(data, dict):
                result = data  # later replies win, as in a chronological scan
    return result


def _ask_json_assistant(name, instructions, content, model, schema,
                        tools=None, tool_resources=None, tool_choice=None):
    """Run one prompt through a throw-away assistant and return its JSON reply.

    A temporary assistant and thread are created, a single run is polled to
    completion, and both are deleted afterwards, even if a call raises.

    Args:
        name: Assistant name.
        instructions: System instructions for the assistant.
        content: Text of the single user message.
        model: Model name powering the assistant.
        schema: JSON schema used for the ``json_schema`` response format.
        tools, tool_resources: Optional assistant tool configuration.
        tool_choice: Optional run-level tool choice (e.g. force file search).
    Returns:
        The decoded JSON object of the last assistant reply, or ``None`` when
        the run did not complete or produced no parsable JSON object.
    """
    assistant_args = {
        "name": name,
        "instructions": instructions,
        "model": model,
        "temperature": 0,
        "top_p": 0.2,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "answer", "schema": schema},
        },
    }
    if tools is not None:
        assistant_args["tools"] = tools
    if tool_resources is not None:
        assistant_args["tool_resources"] = tool_resources
    assistant = client.beta.assistants.create(**assistant_args)

    thread = None
    try:
        thread = client.beta.threads.create(
            messages=[{"role": "user", "content": content}],
        )
        run_args = {
            "thread_id": thread.id,
            "assistant_id": assistant.id,
            "poll_interval_ms": 100,
        }
        if tool_choice is not None:
            run_args["tool_choice"] = tool_choice
        run = client.beta.threads.runs.create_and_poll(**run_args)
        if run.status != "completed":
            print(f"Run ended with status {run.status}", end="\r", flush=True)
            return None
        messages = client.beta.threads.messages.list(thread.id, order="asc")
        return _last_json_reply(messages.data)
    finally:
        # Nested so the assistant is deleted even if thread deletion fails.
        try:
            if thread is not None:
                client.beta.threads.delete(thread.id)
        finally:
            client.beta.assistants.delete(assistant_id=assistant.id)


def rag_agent(question: str, vector_store, rag_model: str) -> str:
    """
    Runs the OpenAI RAG, returning the answer to the inputted question, given the documents in the vector_store.

    Stage 1 (retrieval): a temporary assistant is forced to use the
    ``file_search`` tool on *vector_store* and returns ranked, quoted passages
    plus their source filenames as JSON shaped like ``rag_format``.
    Stage 2 (answering): a second temporary assistant, without tools, answers
    *question* from those passages and sources, returning JSON shaped like
    ``answer_format``.

    Args:
        question: Question to be answered
        vector_store: OpenAI vector store containing the documents used to search for answers
        rag_model: LLM that can be used to power OpenAI RAG
    Returns:
        Answer to the inputted question, or "No information." when the
        answering stage produced no usable reply.
    Raises:
        Errors from the ``openai`` client propagate unchanged; temporary
        assistants and threads are still deleted.
    """

    rag_message="""You are a retrieval agent tasked with performing file searches to find information for the purpose of providing answers.
        Find pieces of information that will be directly relevant for answering the query and rank these pieces of information from most relevant to least relevant
        You must quote the passages from the files directly. Do not paraphrase or change the text in any way.
        Do not include information unless you have a source for that piece of information. 
        If no information is relevant, you must return a single piece of information, where you state "No information found".
        Ideally, these pieces of information will be sentences, phrases, data points or sets of data points, but you have limited flexibility to include other pieces of information if you think they are appropriate.
        
        You must use tool call (i.e., file search).
        
        You know about the content of the code-base.
        """
    retrieved = _ask_json_assistant(
        name="rag_test",
        instructions=rag_message,
        content=question,
        model=rag_model,
        schema=rag_format.model_json_schema(),
        tools=[
            {
                "type": "file_search",
                "file_search": {
                    "max_num_results": 10,
                    "ranking_options": {
                        "ranker": "auto",
                        "score_threshold": 0.6,
                    },
                },
            }
        ],
        tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
        # Documented form for forcing the built-in file search tool.
        tool_choice={"type": "file_search"},
    ) or {}
    # Models sometimes emit the keys with spaces instead of underscores.
    answer = _pick_text(
        retrieved,
        ("Ranked_Relevant_Information", "Ranked Relevant Information"),
        "No relevant information.",
    )
    sources = _pick_text(
        retrieved, ("File_Sources", "File Sources"), "No relevant sources."
    )

    answer_message="""
    You are an answering agent tasked with answering a question or providing a summary only using the relevant information or prompts that are given to you, via the "Ranked Relevant Information".
    Generate a logical and reasoned response to the question or prompts only using the ranked relevant information.
    Use the question to provide context to the information before deciding if the information is relevant or not.
    If no file sources are given, you must answer "No information.".
    """
    # The sources are passed on because the answering instructions depend on them.
    replied = _ask_json_assistant(
        name="answer_test",
        instructions=answer_message,
        content=(
            f"Question: {question}"
            f"\nRanked Relevant Information: {answer}"
            f"\nFile Sources: {sources}"
        ),
        model=rag_model,
        schema=answer_format.model_json_schema(),
    ) or {}
    return _pick_text(replied, ("Response",), "No information.")