# Setup

In [41]:
from operator import itemgetter
from _global import path_to_resources, hf_embed
import templates
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import langsmith
from langsmith import traceable, trace
from langsmith.evaluation import LangChainStringEvaluator, evaluate
from langchain.callbacks.tracers import LangChainTracer

In [26]:
# Retriever setup: open the Chroma collection stored under the resources
# directory and expose it as a similarity retriever returning 4 docs per query.
db = Chroma(
    collection_name="main_collection",
    persist_directory=f"{path_to_resources}/db_wiki",
    embedding_function=hf_embed,
)
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

In [31]:
# LangSmith setup: project name passed to `trace(...)` further down in this notebook
project_name = "ED-handout"

# RAG Class

In [38]:
class RagBot:
    """Retrieval-augmented pipeline that turns a physician's assessment into a patient handout.

    Steps: extract the diagnosis from the assessment, expand it into one
    retrieval query per handout section, fetch documents for each query, and
    ask the chat model to write the handout from the retrieved context.
    """

    def __init__(self, retriever, templates, model: str = "gpt-3.5-turbo-1106"):
        """
        Args:
            retriever: object exposing ``invoke(query)`` that returns documents.
            templates: object exposing the prompt strings ``extract_diagnosis``
                and ``discharge_instructions_2`` as attributes.
            model: OpenAI chat model name used for extraction and generation.
        """
        self._retriever = retriever
        self._llm_gpt = ChatOpenAI(model_name=model, temperature=0)
        # Local Llama model is constructed here but not used by any method below.
        self._llm_llama = Ollama(model="llama2:13b", temperature=0)
        self.templates = templates
        # One query template per handout section; {diagnosis} is filled in by make_queries().
        self.queries = {
            "definition": "definition of {diagnosis}",
            "presentation": "manifestations of {diagnosis}",
            "course": "natural history of {diagnosis}",
            "management": "treatment and management for {diagnosis}",
            "follow_up": "follow-up plan for {diagnosis}",
            "redflags": "signs and symptoms that indicate the need for urgent medical attention for patients with {diagnosis}",
        }

    @traceable
    def diagnosis_extraction(self, assessment):
        """Extracts the diagnosis from the physician's assessment of the patient.

        Returns:
            The text content of the chat model's reply.
        """
        prompt_extract_diagnosis = PromptTemplate.from_template(self.templates.extract_diagnosis)
        chain = prompt_extract_diagnosis | self._llm_gpt
        return chain.invoke({"assessment": assessment}).content

    def make_queries(self, diagnosis):
        """Uses the diagnosis to populate the dict of queries used to retrieve context from the db.

        Returns:
            dict mapping each section key of ``self.queries`` to its formatted query string.
        """
        return {key: value.format(diagnosis=diagnosis) for key, value in self.queries.items()}

    @traceable(run_type="retriever")
    def _retrieve_docs(self, query):
        """Runs a single query through the retriever (traced as a retriever run)."""
        return self._retriever.invoke(query)

    def get_contexts(self, queries):
        """Retrieves documents for every query.

        Args:
            queries: dict of section key -> query string (as built by make_queries).

        Returns:
            dict keyed by the *query string* (not the section key) mapping to the
            retrieved documents; the evaluation helpers rely on this shape.
        """
        return {query: self._retrieve_docs(query) for query in queries.values()}

    @traceable()
    def retrieval_steps(self, assessment):
        """All the steps to prep the contexts for final handout generation.

        Returns:
            ``{"contexts": {query: docs}, "diagnosis": diagnosis}``
        """
        diagnosis = self.diagnosis_extraction(assessment)
        queries = self.make_queries(diagnosis)
        contexts = self.get_contexts(queries)

        return {"contexts": contexts, "diagnosis": diagnosis}

    @traceable()
    def make_handout(self, assessment, md_plan):
        """Generates the patient handout from the assessment and the physician's plan.

        Args:
            assessment: physician's assessment text; drives diagnosis extraction
                and retrieval.
            md_plan: physician's plan, offered to the prompt as ``md_plan``.

        Returns:
            dict with ``"diagnosis"``, ``"handout"`` (generated text) and
            ``"contexts"`` (list with the string form of every retrieved document).

        Raises:
            KeyError: if the handout template uses a placeholder this method
                does not know how to fill.
        """
        retrieved = self.retrieval_steps(assessment)
        diagnosis = retrieved["diagnosis"]
        contexts = retrieved["contexts"]

        # Flatten {query: [docs]} into one list of strings for the prompt and the evaluators.
        context_strings = [str(doc) for docs in contexts.values() for doc in docs]
        all_context = "\n\n".join(context_strings)

        prompt = ChatPromptTemplate.from_messages(
            [("system", self.templates.discharge_instructions_2)]
        )

        # NOTE(review): the placeholders used by templates.discharge_instructions_2
        # are not visible from this class. The names offered below are assumptions;
        # only those the template actually declares are passed on. Confirm against
        # the template.
        available = {
            "assessment": assessment,
            "md_plan": md_plan,
            "diagnosis": diagnosis,
            "context": all_context,
            "contexts": all_context,
        }
        # Also offer the context per section key (definition, presentation, ...).
        for section, query in self.make_queries(diagnosis).items():
            available.setdefault(
                section, "\n\n".join(str(doc) for doc in contexts.get(query, []))
            )

        missing = [name for name in prompt.input_variables if name not in available]
        if missing:
            raise KeyError(
                f"discharge_instructions_2 uses placeholders make_handout cannot fill: {missing}"
            )

        chain = prompt | self._llm_gpt
        response = chain.invoke({name: available[name] for name in prompt.input_variables})

        # NOTE(review): an earlier comment said evaluators expect "answer" and
        # "contexts"; the key returned here has always been "handout" — confirm
        # which one the evaluators read.
        return {
            "diagnosis": diagnosis,
            "handout": response.content,
            "contexts": context_strings,
        }



In [39]:
# Bot under test: default OpenAI model, shared retriever and prompt templates
bot = RagBot(retriever=retriever, templates=templates)

In [30]:
# Quick check of diagnosis extraction; the assessment is passed inline so the
# cell does not depend on a variable defined in another (or deleted) cell.
bot.diagnosis_extraction("5yo M with viral-triggered asthma")

'viral-triggered asthma'

In [14]:
# Test that extraction works and that the run is recorded in LangSmith.

app_inputs = {"assessment": "5yo M with viral-triggered asthma"}

# Use the dict defined above for both the trace inputs and the call itself
# (the cell previously read an undefined `inputs` name).
with trace("Diagnosis extraction", "chain", project_name=project_name, inputs=app_inputs) as rt:
    output = bot.diagnosis_extraction(app_inputs["assessment"])
    rt.end(outputs={"output": output})

# Eval

tasks
- doc grader to use custom metrics

## Doc grader

In [60]:
### OpenAI Grader

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

# Data model
# The class docstring is part of the schema handed to the model by
# `with_structured_output`, so it must agree with the 0/1/2 rubric in the
# system prompt (it previously said "Binary score").
class GradeDocuments(BaseModel):
    """Relevance score (0, 1 or 2) for a retrieved document, following the scoring system in the instructions."""

    score: str = Field(description="Documents grade based on correct diagnosis and relevant information")

# LLM with function call: temperature 0, replies are returned as GradeDocuments objects
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt: the system message carries the 0/1/2 scoring rubric; the human
# message has the {document} and {query} placeholders filled in at invoke time.
system = """
    You are a grader assessing relevance of a retrieved document to a user question. \n 
    The content of the document can be found in page_content. Give a score for the document using the scoring system below. 
    Scoring: (0: irrelevant diagnosis), (1: correct diagnosis, but does not contain information to anser the user question), (2: correct diagnosis and contains information to answer the user question). \n
    
    
"""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "RETRIEVED DOCUMENT: \n\n {document} \n\n USER QUESTION: {query}"),
    ]
)

# Grading chain: prompt piped into the structured-output LLM
retrieval_grader_oai = grade_prompt | structured_llm_grader

def predict_oai(inputs: dict) -> dict:
    """Grade one dataset example with the OpenAI retrieval grader.

    Args:
        inputs: example inputs holding a "query" and the retrieved "context".

    Returns:
        ``{"grade": <score>}`` where the score is read from the grader's
        structured (pydantic) reply.
    """
    payload = {"query": inputs["query"], "document": inputs["context"]}
    graded = retrieval_grader_oai.invoke(payload)
    return {"grade": graded.score}


- given a diagnosis
- create dataset of query + context for each query
- run experiment on the dataset

In [61]:
def create_dataset(diagnosis, q_c, dataset_name):
    """Create a LangSmith dataset for evaluating the relevance of retrieved context.

    Every retrieved document becomes its own example with inputs
    ``{"query": ..., "context": ...}``.

    Args:
        diagnosis: diagnosis the queries were built for (used in the description).
        q_c: dict mapping each query string to the documents retrieved for it.
        dataset_name: name of the dataset to create.
            NOTE(review): presumably this fails if a dataset with this name
            already exists (e.g. on a re-run) — confirm and decide on a policy.
    """
    client = langsmith.Client()

    dataset = client.create_dataset(
        dataset_name=dataset_name,
        description=f"Test context relevance for docs retreiived for {diagnosis}",
    )

    # Flatten to one example per (query, document) pair and upload them in a
    # single request instead of one request per document.
    example_inputs = [
        {"query": query, "context": doc}
        for query, context in q_c.items()
        for doc in context
    ]
    if example_inputs:  # nothing retrieved -> leave the dataset empty
        client.create_examples(inputs=example_inputs, dataset_id=dataset.id)

In [63]:
from datetime import datetime

def context_relevance(rag_bot, assessment):
    """Grade how relevant the documents retrieved for an assessment are.

    Runs the bot's retrieval steps, stores every (query, document) pair in a
    LangSmith dataset named after the diagnosis, then runs the OpenAI grader
    over that dataset as an experiment.

    Args:
        rag_bot: bot exposing ``retrieval_steps(assessment)``.
        assessment: physician's assessment text.
    """
    retrieval = rag_bot.retrieval_steps(assessment)
    query_to_docs = retrieval["contexts"]  # {query: [docs]}
    diagnosis = retrieval["diagnosis"]

    dataset_name = f"Queries_Docs_{diagnosis}"
    create_dataset(diagnosis, query_to_docs, dataset_name)

    # Experiment metadata shown alongside the run in LangSmith
    experiment_metadata = {"model": "oai", "diagnosis": diagnosis}

    evaluate(
        predict_oai,
        data=dataset_name,
        experiment_prefix="Context-relevance-",
        metadata=experiment_metadata,
    )

In [64]:
# Run the context-relevance experiment for a sample assessment
context_relevance(rag_bot=bot, assessment="5yo M, viral triggered asthma")

View the evaluation results for experiment: 'Context-relevance--d016fe16' at:
https://smith.langchain.com/o/edfbc8bb-c3a3-5c1e-8b48-11b5a8cfd8ac/datasets/0b528c40-ebd7-4abc-90a1-0e4440a5c5e6/compare?selectedSessions=29dae9fb-9c5e-426e-b538-4055b0695e21




0it [00:00, ?it/s]

## Ground truth checker

## LLM grading based on custom metrics
- jargon
- reference list
- template format

## Human feedback of output