In [None]:
import os
import dspy
import cohere

from datetime import datetime
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from cohere.responses.rerank import RerankResult

from dspy.predict import Retry
from dspy.primitives.assertions import assert_transform_module, backtrack_handler

from typing import Dict, Optional, List
from tqdm import tqdm

from utils import pp, ppj, ppjf
from utils import exists, defaults
from citations_utils import create_reference_nodes
from rag_logger import logger
from rag_config import RAGConfig
from chroma_db_retriever import ChromadbRM
from rag_utils import * # change
from signatures import * # change
from rag import RAG, load_rag

from dotenv import load_dotenv
# Load variables from a local .env file into the process environment; the API
# keys used further down are read from os.environ.
load_dotenv()

# Project RAG settings object. This notebook reads the embedding model name,
# Chroma collection/path, retriever top-k, LLM name and LLM kwargs from it.
rag_config = RAGConfig()


In [None]:
# --- Embeddings -------------------------------------------------------------
# OpenAI embedding function handed to the Chroma retriever below.
embedding_function = OpenAIEmbeddingFunction(
    api_key=os.getenv('OPENAI_API_KEY'),
    model_name=rag_config.default_embed_model,
)

# --- Retriever --------------------------------------------------------------
# Chroma-backed retriever; collection, storage path and top-k all come from
# the RAG config.
retriever_model = ChromadbRM(
    collection_name=rag_config.collection_name,
    persist_directory=rag_config.default_client_path,
    embedding_function=embedding_function,
    k=rag_config.retrieve_top_k,
)

# --- Language model and reranker ---------------------------------------------
lm_model = dspy.OpenAI(
    model=rag_config.default_llm_name,
    **rag_config.llm_kwargs,
)
rerank_model = cohere.Client(
    api_key=os.getenv('COHERE_API_KEY'),
)

# --- Global DSPy settings ----------------------------------------------------
# TODO: Remove random number
# NOTE(review): the configured temperature is three times the LLM kwargs
# temperature; the reason for the factor is not documented here.
dspy.settings.configure(
    lm=lm_model,
    trace=[],
    temperature=rag_config.llm_kwargs["temperature"]*3,
    rm=retriever_model,
)

In [None]:
def gen_rag_case(name: str, tool_results: str = "tool_results", rag_results: str = "rag_results"):
    """Run one patient case through the agent and then through RAG, saving both outputs.

    Steps:
      1. Look up ``name`` in ``patient_cases.cases``; each entry unpacks into a
         ``(context, question)`` pair.
      2. Run the agent with ``use_rag=False`` and write its output to
         ``<tool_results>/<name>.txt``.
      3. Call the agent's RAG module with the question, the context, the agent
         output from step 2 and the agent's tools, then write the response to
         ``<rag_results>/<name>.txt``.

    Args:
        name: Key of the case in ``patient_cases.cases``; also used as the
            stem of the output file names.
        tool_results: Directory for the agent's tool output (created if missing).
        rag_results: Directory for the RAG response (created if missing).

    Returns:
        The object returned by the RAG call; its ``response`` attribute is the
        text written to the RAG output file.

    Raises:
        Whatever ``cases[name]`` raises for an unknown case name (``KeyError``
        if ``cases`` is a plain dict).
    """
    # Project imports stay function-local, as in the original cell.
    from patient_cases import cases
    from med_agent import MedOpenAIAgent
    # Swap the two lines below to switch between the real and the dummy tool set.
    #from agent_tools import openai_agent_tools
    from agent_tools_dummy import openai_agent_tools

    # Resolve the case first so an unknown name fails before the RAG/agent setup.
    context, question = cases[name]

    rag = load_rag()
    agent = MedOpenAIAgent.from_tools(tools=openai_agent_tools, rag=rag)
    # Captured before the chat call, matching the original order of evaluation.
    agent_tools = agent.tools

    # use_rag=False: save the plain agent output first, RAG is applied afterwards.
    tool_result = agent.chat_ext(context=context, question=question, use_rag=False)
    pp(tool_result)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(tool_results, exist_ok=True)
    with open(os.path.join(tool_results, f"{name}.txt"), "w", encoding="utf-8") as f:
        f.write(tool_result)

    # Debug aid: shows the citation-check setting of the RAG module in use.
    print(agent._rag.check_citations)

    # Pass the agent's actual output (tool_result), not the output directory
    # name (tool_results), to the RAG step.
    patient_result = agent._rag(question, context, tool_result, agent_tools, rerank_model=rerank_model)

    os.makedirs(rag_results, exist_ok=True)
    with open(os.path.join(rag_results, f"{name}.txt"), "w", encoding="utf-8") as f:
        f.write(patient_result.response)

    pp(patient_result.response)

    return patient_result

In [None]:
# Run a single case end to end. "<<case_name>>" is a placeholder, not a real
# case: it is used as the lookup key into patient_cases.cases and as the stem
# of the output file names.
patient_result = gen_rag_case("<<case_name>>") # replace with the name of the case you want to run

In [None]:
# Re-display the RAG response from the run above without re-running the case.
pp(patient_result.response)