# Setup

In [1]:
from operator import itemgetter
from _global import path_to_resources, hf_embed
import templates
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import langsmith
from langsmith import traceable, trace
from langsmith.evaluation import LangChainStringEvaluator, evaluate
from langchain.callbacks.tracers import LangChainTracer

In [2]:
# Set up the retriever: a Chroma collection stored under the resources directory,
# exposed as a similarity-search retriever returning the top 4 matches per query.
db = Chroma(
    collection_name="main_collection",
    persist_directory=f"{path_to_resources}/db_wiki",
    embedding_function=hf_embed,
)
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

In [3]:
# langsmith setup
# Project name handed to the LangSmith `trace(...)` context manager further down.
project_name = "ED-handout"

# RAG Class

In [4]:
class RagBot:
    """Retrieval-augmented pipeline that turns a physician's assessment into a handout.

    The flow is: extract the diagnosis from the assessment, build one retrieval
    query per handout section, retrieve documents for each query, compress each
    section's documents with the LLM, then generate the handout from the
    compressed contexts plus the physician's plan.
    """

    def __init__(self, retriever, templates, model: str = "gpt-3.5-turbo-1106"):
        """Store the retriever and prompt templates and build the LLM clients.

        Args:
            retriever: Object whose ``invoke(query)`` returns the retrieved documents.
            templates: Object exposing the prompt strings as attributes
                (``extract_diagnosis_system``, ``compress_context_system``,
                ``compress_context_human``, ``handout_generation_system``,
                ``handout_generation_human``).
            model: OpenAI chat model name used for every LLM call in this class.
        """
        self.templates = templates
        self._retriever = retriever
        self._llm_gpt = ChatOpenAI(model_name=model, temperature=0)
        # Built here but not used by any method of this class.
        self._llm_llama = Ollama(model="llama2:13b", temperature=0)
        # One query pattern per handout section; `{diagnosis}` is filled in by make_queries.
        self.queries = {
            "definition": "definition of {diagnosis}",
            "presentation": "manifestations of {diagnosis}",
            "course": "natural history of {diagnosis}",
            "management": "treatment and management for {diagnosis}",
            "follow_up": "follow-up plan for {diagnosis}",
            "redflags": "signs and symptoms that indicate the need for urgent medical attention for patients with {diagnosis}",
        }

    @traceable
    def diagnosis_extraction(self, assessment):
        """Ask the GPT model for the diagnosis contained in the physician's assessment.

        Returns the ``content`` of the model's reply.
        """
        messages = [
            ("system", self.templates.extract_diagnosis_system),
            ("human", "{assessment}"),
        ]
        extractor = ChatPromptTemplate.from_messages(messages) | self._llm_gpt
        reply = extractor.invoke({"assessment": assessment})
        return reply.content

    def make_queries(self, diagnosis):
        """Fill every query pattern with ``diagnosis``; returns a dict of section -> query string."""
        filled = {}
        for section, pattern in self.queries.items():
            filled[section] = pattern.format(diagnosis=diagnosis)
        return filled

    @traceable(run_type="retriever")
    def _retrieve_docs(self, query):
        """Run a single query through the retriever (traced as a retriever run)."""
        return self._retriever.invoke(query)

    def get_contexts(self, queries):
        """Retrieve documents for each query.

        Returns a dict of section -> ``(query, retrieved_docs)`` tuple.
        """
        return {
            section: (query, self._retrieve_docs(query))
            for section, query in queries.items()
        }

    def compress_contexts(self, q_c):
        """Run the compression prompt on one ``(query, retrieved_docs)`` pair.

        Returns the ``content`` of the model's reply.
        """
        messages = [
            ("system", self.templates.compress_context_system),
            ("human", self.templates.compress_context_human),
        ]
        compressor = ChatPromptTemplate.from_messages(messages) | self._llm_gpt
        reply = compressor.invoke({"query": q_c[0], "context": q_c[1]})
        return reply.content

    @traceable()
    def retrieval_steps(self, assessment):
        """Run every step that prepares the contexts for handout generation.

        Returns a dict with ``"contexts"`` (section -> ``(query, docs)``) and
        ``"diagnosis"``.
        """
        diagnosis = self.diagnosis_extraction(assessment)
        contexts = self.get_contexts(self.make_queries(diagnosis))
        return {"contexts": contexts, "diagnosis": diagnosis}

    @traceable()
    def make_handout(self, assessment, md_plan):
        """Generate the handout for ``assessment`` using ``md_plan`` as the physician's plan.

        Returns a dict with ``"diagnosis"``, ``"contexts"`` (the compressed
        section contexts joined with newlines) and ``"handout"``.
        """
        retrieved = self.retrieval_steps(assessment)
        diagnosis = retrieved["diagnosis"]

        # Compression: one LLM call per section, in the order of self.queries.
        compressed = {
            section: self.compress_contexts(pair)
            for section, pair in retrieved["contexts"].items()
        }

        # Handout generation from the compressed contexts and the plan.
        messages = [
            ("system", self.templates.handout_generation_system),
            ("human", self.templates.handout_generation_human),
        ]
        generator = ChatPromptTemplate.from_messages(messages) | self._llm_gpt
        payload = {
            "context_definition": compressed["definition"],
            "context_presentation": compressed["presentation"],
            "context_course": compressed["course"],
            "context_management": compressed["management"],
            "context_follow_up": compressed["follow_up"],
            "context_redflags": compressed["redflags"],
            "context_md_plan": md_plan,
        }
        reply = generator.invoke(payload)

        # Keys consumed downstream are "handout" and "contexts".
        return {
            "diagnosis": diagnosis,
            "contexts": "\n".join(compressed.values()),
            "handout": reply.content,
        }



In [5]:
def make_handout_with_context(example: dict):
    """Use this for evaluation of retrieved documents and hallucinations.

    Runs the module-level ``bot`` on one dataset example and returns only the
    fields the evaluators read.

    Args:
        example: Inputs of one dataset example. ``example["question"]`` is
            passed as the physician's assessment. ``example["md_plan"]`` is
            passed as the physician's plan when present; otherwise an empty
            plan is used. NOTE(review): the ``"md_plan"`` key name is an
            assumption — confirm against the dataset schema.

    Returns:
        dict with ``"handout"`` (generated handout text) and ``"contexts"``
        (the contexts the handout was generated from).

    Raises:
        KeyError: if ``example`` has no ``"question"`` key.
    """
    # `bot.make_handout` needs both the assessment and the physician's plan.
    response = bot.make_handout(example["question"], example.get("md_plan", ""))
    return {"handout": response["handout"], "contexts": response["contexts"]}

In [6]:
# Shared RagBot instance used by the cells below; uses the default model from RagBot.__init__.
bot = RagBot(retriever, templates)

In [9]:
# test that extraction works and works with langsmith

# Define the inputs once so the traced inputs and the call argument cannot diverge
# (the call previously read `inputs["assessment"]` without `inputs` being defined).
inputs = {"assessment": "5yo M with viral-triggered asthma"}

with trace("Diagnosis extraction", "chain", project_name=project_name, inputs=inputs) as rt:
    output = bot.diagnosis_extraction(inputs["assessment"])
    rt.end(outputs={"output": output})

# Eval

## Doc grader

In [10]:
### OpenAI Grader

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

# Data model
class GradeDocuments(BaseModel):
    """Three-level relevance score (0, 1 or 2) for a retrieved document."""

    # NOTE(review): the class docstring and this description are expected to be
    # sent to the model as part of the structured-output schema, so they must
    # agree with the 0/1/2 scale defined in the grading prompt — confirm.
    score: str = Field(
        description="Documents grade based on correct diagnosis and relevant information: '0', '1' or '2'"
    )

# Grading prompt: the system message defines the 0/1/2 scale, the human message
# carries the retrieved document and the query it was retrieved for.
system = """
    You are a grader assessing relevance of a retrieved document to a user question. \n 
    The content of the document can be found in page_content. Give a score for the document using the scoring system below. 
    Scoring: (0: irrelevant diagnosis), (1: correct diagnosis, but does not contain information to anser the user question), (2: correct diagnosis and contains information to answer the user question). \n
    
    
"""
grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "RETRIEVED DOCUMENT: \n\n {document} \n\n USER QUESTION: {query}"),
])

# Grader model (temperature 0) constrained to return a GradeDocuments object.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Full grading chain: prompt -> structured-output LLM.
retrieval_grader_oai = grade_prompt | structured_llm_grader

def predict_oai(inputs: dict) -> dict:
    """Grade one (query, document) example with the OpenAI retrieval grader.

    Args:
        inputs: Example inputs with a ``"query"`` key and a ``"context"`` key
            (the retrieved document to grade).

    Returns:
        ``{"grade": <score>}`` where the score is read from the grader's
        structured result.
    """
    payload = {"query": inputs["query"], "document": inputs["context"]}
    verdict = retrieval_grader_oai.invoke(payload)  # object exposing `.score`
    return {"grade": verdict.score}


- given a diagnosis
- create dataset of query + doc for each doc retrieved from each query
- run experiment on the dataset

In [21]:
def create_dataset(diagnosis, context_dict, dataset_name):
    """Create a LangSmith dataset with one example per retrieved document.

    Args:
        diagnosis: Diagnosis the documents were retrieved for; only used in the
            dataset description.
        context_dict: Mapping of section -> ``(query, docs)`` as produced by
            ``RagBot.retrieval_steps(...)["contexts"]``.
        dataset_name: Name of the dataset to create.

    Returns:
        The dataset object returned by ``client.create_dataset``.
    """
    client = langsmith.Client()

    dataset = client.create_dataset(
        dataset_name=dataset_name,
        description=f"Test context relevance for docs retreiived for {diagnosis}",
    )

    # Each retrieved document becomes its own example, paired with the query
    # that retrieved it.
    example_inputs = [
        {"query": query, "context": doc}
        for query, docs in context_dict.values()
        for doc in docs
    ]

    # Upload all examples in a single request instead of one request per
    # document; skip the call entirely when nothing was retrieved.
    if example_inputs:
        client.create_examples(inputs=example_inputs, dataset_id=dataset.id)

    return dataset

In [22]:
from datetime import datetime

def context_relevance(rag_bot, assessment):
    """Grade the relevance of every document retrieved for ``assessment``.

    Runs the bot's retrieval steps, uploads each (query, document) pair to a
    new LangSmith dataset, then evaluates ``predict_oai`` on that dataset.

    Args:
        rag_bot: Object whose ``retrieval_steps(assessment)`` returns a dict
            with ``"contexts"`` and ``"diagnosis"``.
        assessment: Physician's assessment text.

    Returns:
        Whatever ``evaluate`` returns for the experiment.
    """
    retrieved = rag_bot.retrieval_steps(assessment)
    context_dict = retrieved["contexts"]  # section -> (query, docs)
    diagnosis = retrieved["diagnosis"]

    # Timestamp suffix keeps the dataset name unique, so re-running the cell for
    # the same diagnosis does not try to create a dataset that already exists.
    dataset_name = f"Queries_Docs_{diagnosis}_{datetime.now():%Y%m%d-%H%M%S-%f}"
    create_dataset(diagnosis, context_dict, dataset_name)

    return evaluate(
        predict_oai,
        data=dataset_name,
        experiment_prefix="Context-relevance-",
        # Any experiment metadata can be specified here
        metadata={
            "model": "oai",
            "diagnosis": diagnosis,
        },
    )

In [23]:
# Run the context-relevance experiment for a sample assessment.
context_relevance(bot, "5yo M, viral triggered asthma")

View the evaluation results for experiment: 'Context-relevance--d321b186' at:
https://smith.langchain.com/o/edfbc8bb-c3a3-5c1e-8b48-11b5a8cfd8ac/datasets/2c62c230-5ffb-4be1-a160-67b532c3d537/compare?selectedSessions=4a781120-8ba7-45ee-ac5f-031e0c2ddd8d




0it [00:00, ?it/s]

## Ground truth checker

In [None]:
from langsmith.evaluation import LangChainStringEvaluator, evaluate

# Scores (normalized to 0-1) how well the generated handout is grounded in the
# contexts it was generated from. Both the prediction and the reference are read
# from the run's own outputs, so the target function must return "handout" and
# "contexts".
answer_hallucination_evaluator = LangChainStringEvaluator(
    "labeled_score_string",
    config={
        "criteria": {
            "accuracy": """Is the Assistant's Answer grounded in the Ground Truth documentation? A score of [[1]] means that the
            Assistant answer contains is not at all based upon / grounded in the Groun Truth documentation. A score of [[5]] means 
            that the Assistant answer contains some information (e.g., a hallucination) that is not captured in the Ground Truth 
            documentation. A score of [[10]] means that the Assistant answer is fully based upon the in the Ground Truth documentation."""
        },
        # If you want the score to be saved on a scale from 0 to 1
        "normalize_by": 10,
    },
    prepare_data=lambda run, example: {
        "prediction": run.outputs["handout"],
        "reference": run.outputs["contexts"],
        "input": example.inputs,
    },
)

dataset_name = "RAG_test_LCEL"
experiment_results = evaluate(
    # Target defined earlier in this notebook; it returns the "handout" and
    # "contexts" keys that prepare_data reads.
    make_handout_with_context,
    data=dataset_name,
    evaluators=[answer_hallucination_evaluator],
    experiment_prefix="rag-qa-oai-hallucination",
    # Any experiment metadata can be specified here
    metadata={
        "variant": "LCEL context, gpt-3.5-turbo",
    },
)

## LLM grading based on custom metrics
- jargon
- reference list
- template format

## Human feedback of output