In [1]:
import os
import dotenv
# Load environment variables through python-dotenv first, then set the two
# LangChain variables this notebook uses in one call.
# NOTE(review): "false" presumably switches LangSmith tracing off - confirm.
dotenv.load_dotenv()
os.environ.update({
    "LANGCHAIN_TRACING_V2": "false",
    "LANGCHAIN_PROJECT": "dvd_evaluation",
})

In [2]:
import os
from typing import List, Dict, Any
from pydantic import BaseModel, Field, field_validator
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import StateGraph, END
import re

# Initialize the LLM
# Module-level chat-model client shared by the graph nodes below: MCQ
# generation, relevance scoring and question answering all call `llm.invoke`.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.7)

# Define data models
class MCQ(BaseModel):
    """A single multiple-choice question parsed from the LLM's reply."""

    # Question text (generate_mcqs removes the "Question: " prefix).
    question: str
    # Option texts in A, B, C, ... order, stored without the letter prefix.
    options: List[str]
    # Text of the correct option (the option's text, not its letter).
    correct_answer: str
    # Pertinence score written by evaluate_mcq_relevance; stays 0.0 until scored.
    relevance_score: float = 0.0  # Add relevance score to each MCQ

class Document(BaseModel):
    """A source document together with the MCQs attached to it."""

    # Full text of the document as returned by load_document.
    content: str
    # Questions associated with this document. Once every document has been
    # processed, generate_mcqs assigns the same combined list to all documents.
    mcqs: List[MCQ] = Field(default_factory=list)

class DVDState(BaseModel):
    """State object passed between the nodes of the evaluation graph."""

    # Documents under evaluation.
    documents: List[Document] = Field(default_factory=list)
    # Index of the document generate_mcqs handles next; reset to 0 once all
    # documents have been processed.
    current_document_index: int = 0
    # Set to True by generate_mcqs after the last document has been processed.
    mcqs_generated: bool = False
    # Set to True by evaluate_mcq_relevance after scoring.
    mcqs_evaluated: bool = False
    # One dict per answered question, built by present_mcqs.
    user_responses: List[Dict[str, Any]] = Field(default_factory=list)
    # Per-document correct/wrong/unknown counts, filled by evaluate_responses.
    evaluation_results: Dict[str, Any] = Field(default_factory=dict)

# Function to load document content from file
def load_document(filename: str) -> str:
    """Read a UTF-8 text file and return its content with surrounding whitespace removed.

    Args:
        filename: Path of the file to read.

    Returns:
        The stripped file content, or an empty string when the file is
        missing, cannot be read, or is not valid UTF-8. In each failure case
        a message is printed instead of raising.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            return file.read().strip()
    except FileNotFoundError:
        print(f"Error: File '{filename}' not found.")
    except UnicodeDecodeError:
        # UnicodeDecodeError derives from ValueError, not OSError, so without
        # this clause a badly encoded file would crash the caller instead of
        # following the "return empty string" contract.
        print(f"Error: File '{filename}' is not valid UTF-8 text.")
    except OSError:
        # Covers permission problems, directories, and other I/O failures.
        print(f"Error: Unable to read file '{filename}'.")
    return ""

# Matches the option letter inside a "Correct Answer:" value, tolerating
# decorations such as "B.", "(C)", "**D**" or lowercase letters.
_ANSWER_LETTER_RE = re.compile(r"\W*([A-Ea-e])\b")


# Helper: turn the LLM's plain-text reply into MCQ objects
def _parse_mcqs(text: str) -> List[MCQ]:
    """Parse blank-line-separated MCQ blocks from *text*.

    A block is kept only when it has at least seven non-empty lines, exactly
    five option lines directly after the question line, and a line starting
    with "Correct Answer:". When the answer letter cannot be recognised, the
    last option is used as the correct answer (the original best-effort
    fallback).
    """
    mcqs = []
    # Split on blank lines; the regex also accepts separators that contain
    # only whitespace or carriage returns, which a literal "\n\n" would miss.
    for mcq_text in re.split(r"\n\s*\n", text):
        lines = [line.strip() for line in mcq_text.split("\n") if line.strip()]
        if len(lines) < 7:
            continue  # Skip incomplete MCQs

        question = lines[0].replace("Question: ", "").strip()
        options = [line.split(". ", 1)[1].strip() for line in lines[1:6] if ". " in line]
        correct_answer_line = next(
            (line for line in lines if line.lower().startswith("correct answer:")),
            None,
        )
        if correct_answer_line is None or len(options) != 5:
            continue

        # Extract the letter with a regex instead of calling ord() on the raw
        # value: ord() raises TypeError for anything that is not exactly one
        # character (e.g. "B." or an empty value).
        answer_value = correct_answer_line.split(":", 1)[1]
        letter_match = _ANSWER_LETTER_RE.match(answer_value)
        if letter_match:
            # The letter is within A-E and there are exactly five options,
            # so the index is always in range.
            correct_answer = options[ord(letter_match.group(1).upper()) - ord('A')]
        else:
            correct_answer = options[-1]  # Default to last option if the letter is invalid
        mcqs.append(MCQ(question=question, options=options, correct_answer=correct_answer))
    return mcqs


# Function to generate MCQs
def generate_mcqs(state: DVDState) -> DVDState:
    """Generate MCQs for the current document and advance to the next one.

    Asks the LLM for questions about ``state.documents[state.current_document_index]``,
    parses the reply and stores the result on a copy of the state. After the
    last document has been processed, the MCQs of all documents are
    concatenated, truncated to 20, and assigned to every document so that each
    document is later tested on the same question set.

    Args:
        state: Current graph state; it is not modified.

    Returns:
        A new DVDState with the MCQs stored and the document index advanced
        (or reset to 0 with ``mcqs_generated`` set once all documents are done).
    """
    current_doc = state.documents[state.current_document_index]

    system_message = """
    You are an expert in creating challenging multiple-choice questions (MCQs) based on medical notes. 
    Generate a maximum of 10 MCQs that are diverse and directly related to the content of the given medical note. 
    Each MCQ should have 5 answer choices (A, B, C, D, E), including "I don't know" as the last option.
    Ensure that the questions are not obvious and require a good factual grasp of the note's content.
    Format each MCQ as follows exactly with no additional text or formatting:
    Question: [Question text]
    A. [Option A]
    B. [Option B]
    C. [Option C]
    D. [Option D]
    E. I don't know
    Correct Answer: [Correct option letter]
    """

    human_message = f"Generate a maximum of 10 diverse and relevant MCQs based on this medical note:\n\n{current_doc.content}"

    response = llm.invoke([
        SystemMessage(content=system_message),
        HumanMessage(content=human_message)
    ])
    print(f"Generating MCQs for document {state.current_document_index + 1}")

    # Parse the response and create MCQ objects
    mcqs = _parse_mcqs(response.content)

    # Work on a copy built from the dumped state so the input stays untouched.
    new_state = DVDState(**state.model_dump())
    new_state.documents[state.current_document_index].mcqs = mcqs
    new_state.current_document_index += 1

    if new_state.current_document_index >= len(new_state.documents):
        new_state.mcqs_generated = True
        new_state.current_document_index = 0  # Reset for next steps

        # Combine MCQs from all documents and limit to 20
        all_mcqs = []
        for doc in new_state.documents:
            all_mcqs.extend(doc.mcqs)
        combined_mcqs = all_mcqs[:20]

        # Give every document the same combined question set
        for doc in new_state.documents:
            doc.mcqs = combined_mcqs

    return new_state




# Function to evaluate MCQ relevance
def evaluate_mcq_relevance(state: DVDState) -> DVDState:
    """Score every MCQ for its pertinence to a patient discharge summary.

    Each distinct question text is sent to the LLM once with the relevance
    prompt; the first integer found in the reply is validated to lie in 0-10
    and stored in ``relevance_score``. Questions that cannot be scored (LLM
    call failed or reply unparseable) get a score of 0 and the error is
    printed.

    Args:
        state: Current graph state; it is not modified.

    Returns:
        A copy of the state with ``relevance_score`` filled in on every MCQ
        and ``mcqs_evaluated`` set to True.
    """
    print("Evaluating MCQ relevance")

    new_state = DVDState(**state.model_dump())

    relevance_prompt = """
    Please read the following question and assess how pertinent it is to the content typically included in a patient discharge summary. A patient discharge summary encompasses critical information about a patient's hospital stay and instructions for post-discharge care, excluding personal identifiers. The pertinent content areas are as follows:

    1. Hospital Admission and Discharge Details
    2. Reason for Hospitalization
    3. Hospital Course Summary
    4. Discharge Diagnosis
    5. Procedures Performed
    6. Medications at Discharge
    7. Discharge Instructions
    8. Follow-Up Care
    9. Patient's Condition at Discharge
    10. Patient Education and Counseling
    11. Pending Results
    12. Advance Directives and Legal Considerations
    13. Healthcare Provider Information
    14. Additional Notes

    Evaluation Criteria:

    Pertinence Score (0-10):
    0: Not pertinent at all to a patient discharge summary
    5: Moderately pertinent
    10: Highly pertinent and directly related to the content of a discharge summary

    Consider how closely the question relates to the above content areas. Provide a score from 0 to 10 and a brief explanation if necessary.

    Question: {question}

    Your Response:

    Provide only the Pertinence Score (0-10) in number format
    """

    class RelevanceResponse(BaseModel):
        """Validates that the LLM reply contains an integer score in 0-10."""

        pertinence_score: int = Field(..., ge=0, le=10)

        @field_validator('pertinence_score', mode='before')
        @classmethod
        def extract_score(cls, v):
            # Strings are reduced to the first run of digits they contain;
            # non-string values are passed through to the normal int validation.
            if isinstance(v, str):
                numbers = re.findall(r'\d+', v)
                if not numbers:
                    raise ValueError("Invalid score format")
                return int(numbers[0])
            return v

    # The prompt depends only on the question text, and the same questions are
    # attached to every document, so each distinct question is scored once and
    # the score is reused. This also keeps the score for a question identical
    # across documents.
    scores_by_question: Dict[str, int] = {}

    for doc in new_state.documents:
        for mcq in doc.mcqs:
            if mcq.question in scores_by_question:
                mcq.relevance_score = scores_by_question[mcq.question]
                print(f"Question: {mcq.question}")
                print(f"Relevance Score: {mcq.relevance_score}")
                continue

            # Reset per question so the error handler never reads an unbound
            # name or a stale reply from a previous iteration.
            response = None
            try:
                response = llm.invoke([HumanMessage(content=relevance_prompt.format(question=mcq.question))])
                relevance = RelevanceResponse(pertinence_score=response.content)
            except Exception as e:
                # Deliberately broad: scoring is best-effort and one failing
                # question must not abort the whole evaluation.
                print(f"Error processing question: {mcq.question}")
                print(f"Error details: {str(e)}")
                if response is not None:
                    print(f"LLM response: {response.content}")
                mcq.relevance_score = 0  # Default to 0 if the call or parsing fails
                continue

            mcq.relevance_score = relevance.pertinence_score
            scores_by_question[mcq.question] = relevance.pertinence_score
            print(f"Question: {mcq.question}")
            print(f"Relevance Score: {mcq.relevance_score}")

    new_state.mcqs_evaluated = True
    return new_state

# Helper: normalise the LLM's free-text reply to a single option letter
def _extract_option_letter(raw_answer: str) -> str:
    """Return the option letter (A-E) named in *raw_answer*, or 'E' if none is found.

    Accepts replies such as "B", "b", "B.", "(C)" or "Answer: D". Anything
    that does not name an option falls back to 'E' ("I don't know").
    """
    text = str(raw_answer).strip()
    # Letter at the very start, optionally wrapped in punctuation/markdown.
    leading = re.match(r"\W*([A-Ea-e])\b", text)
    if leading:
        return leading.group(1).upper()
    # Otherwise look for a standalone capital letter anywhere in the reply.
    standalone = re.search(r"\b([A-E])\b", text)
    if standalone:
        return standalone.group(1)
    return 'E'  # Default to "I don't know" if there's any issue


# Function to present MCQs and collect automated responses
def present_mcqs(state: DVDState) -> DVDState:
    """Have the LLM answer every MCQ of every document using that document's text.

    For each document and each of its MCQs, the LLM is shown the document
    content and the question and is asked for an option letter. The reply is
    normalised to a letter and recorded together with the correct answer.

    Args:
        state: Current graph state; it is not modified. Each MCQ is expected
            to carry five options (generate_mcqs only keeps such MCQs).

    Returns:
        A copy of the state whose ``user_responses`` holds one dict per
        (document, question) pair with the keys ``document``,
        ``document_index``, ``question``, ``user_answer``, ``correct_answer``
        and ``relevance_score``.
    """
    print("Presenting MCQs and collecting automated responses")

    user_responses = []
    for doc_index, doc in enumerate(state.documents):
        for mcq in doc.mcqs:
            # Create the prompt for answering each question
            answer_prompt = f"""
            You are given the following multiple-choice question based on the provided document content. Provide the best answer based on the given choices.
            Document Content: {doc.content}
            Question: {mcq.question}
            A. {mcq.options[0]}
            B. {mcq.options[1]}
            C. {mcq.options[2]}
            D. {mcq.options[3]}
            E. I don't know
            Respond with the option letter (A, B, C, D, or E) that you think is the correct answer.
            Do not include any other text in your response.
            """
            response = llm.invoke([HumanMessage(content=answer_prompt)])

            # A strict equality check would turn harmless variations such as
            # "B." or "b" into "I don't know" and skew the scores.
            answer = _extract_option_letter(response.content)

            user_responses.append({
                "document": doc,
                # Positional index, so later steps can attribute the response
                # without relying on object equality of the document.
                "document_index": doc_index,
                "question": mcq.question,
                "user_answer": mcq.options[ord(answer) - ord('A')],
                "correct_answer": mcq.correct_answer,
                "relevance_score": mcq.relevance_score
            })

    new_state = DVDState(**state.model_dump())
    new_state.user_responses = user_responses
    return new_state

# Helper: find which document a recorded response belongs to
def _response_document_position(state: DVDState, response: Dict[str, Any]) -> int:
    """Return the index in ``state.documents`` of the document *response* refers to.

    Prefers an explicit ``document_index`` entry when present and in range.
    Otherwise it looks the stored document up by equality, and finally by
    matching ``content`` (which also works if the stored document is a dict).

    Raises:
        ValueError: If the document cannot be matched to any state document.
    """
    position = response.get("document_index")
    if isinstance(position, int) and 0 <= position < len(state.documents):
        return position

    document = response["document"]
    try:
        return state.documents.index(document)
    except ValueError:
        pass

    if isinstance(document, dict):
        content = document.get("content")
    else:
        content = getattr(document, "content", None)
    for candidate_position, candidate in enumerate(state.documents):
        if candidate.content == content:
            return candidate_position
    raise ValueError("Response refers to a document that is not part of the state")


# Function to evaluate user responses
def evaluate_responses(state: DVDState) -> DVDState:
    """Tally correct / wrong / unknown answers per document and print the totals.

    An answer equal to "I don't know" counts as unknown; an answer equal to
    the recorded correct answer counts as correct; anything else is wrong.

    Args:
        state: Current graph state with ``user_responses`` populated; it is
            not modified.

    Returns:
        A copy of the state whose ``evaluation_results`` maps "document1",
        "document2", ... to dicts with ``correct``, ``wrong`` and ``unknown``
        counts.
    """
    print("Evaluating responses")

    # One bucket per document. "document1" and "document2" are always present
    # so that existing consumers of the result keep finding both keys.
    bucket_count = max(2, len(state.documents))
    results = {
        f"document{number}": {"correct": 0, "wrong": 0, "unknown": 0}
        for number in range(1, bucket_count + 1)
    }

    for response in state.user_responses:
        doc_key = f"document{_response_document_position(state, response) + 1}"
        if response["user_answer"] == "I don't know":
            results[doc_key]["unknown"] += 1
        elif response["user_answer"] == response["correct_answer"]:
            results[doc_key]["correct"] += 1
        else:
            results[doc_key]["wrong"] += 1

    # Print final scores for every document
    print("\nFinal Scores:")
    for doc, scores in results.items():
        total_questions = scores['correct'] + scores['wrong'] + scores['unknown']
        print(f"\n{doc.capitalize()}:")
        print(f"  Total Questions: {total_questions}")
        print(f"  Correct: {scores['correct']}")
        print(f"  Wrong: {scores['wrong']}")
        print(f"  Unknown: {scores['unknown']}")

    new_state = DVDState(**state.model_dump())
    new_state.evaluation_results = results
    return new_state

# Build the state graph
def build_dvd_graph() -> StateGraph:
    """Assemble and compile the evaluation workflow.

    Node order: ``generate_mcqs`` (repeated until every document has been
    processed) -> ``evaluate_mcq_relevance`` -> ``present_mcqs`` ->
    ``evaluate_responses`` -> END.

    Returns:
        The object produced by ``workflow.compile()``.
        NOTE(review): in current LangGraph releases this is a compiled graph
        rather than a ``StateGraph`` instance; the annotation is left as it
        was to keep the signature unchanged - confirm before tightening it.
    """

    def route_after_generation(state: DVDState) -> str:
        # generate_mcqs handles one document per invocation, so loop back to
        # it until it reports that all documents are done.
        if not state.mcqs_generated:
            return "generate_mcqs"
        return "evaluate_mcq_relevance"

    workflow = StateGraph(DVDState)

    workflow.add_node("generate_mcqs", generate_mcqs)
    workflow.add_node("evaluate_mcq_relevance", evaluate_mcq_relevance)
    workflow.add_node("present_mcqs", present_mcqs)
    workflow.add_node("evaluate_responses", evaluate_responses)

    workflow.set_entry_point("generate_mcqs")

    workflow.add_conditional_edges("generate_mcqs", route_after_generation)
    workflow.add_edge("evaluate_mcq_relevance", "present_mcqs")
    workflow.add_edge("present_mcqs", "evaluate_responses")
    workflow.add_edge("evaluate_responses", END)

    return workflow.compile()

# Helper: pull evaluation results out of one streamed graph event
def _find_evaluation_results(event: Any):
    """Return the non-empty ``evaluation_results`` carried by *event*, or None.

    Checks the event itself and, when it is a dict, each of its values. Every
    candidate may be a DVDState or a plain dict.
    NOTE(review): LangGraph's default stream mode is understood to yield
    ``{node_name: update}`` mappings - confirm against the installed version.
    """
    candidates = [event]
    if isinstance(event, dict):
        candidates.extend(event.values())

    for candidate in candidates:
        if isinstance(candidate, DVDState):
            found = candidate.evaluation_results
        elif isinstance(candidate, dict):
            found = candidate.get("evaluation_results")
        else:
            found = None
        if found:
            return found
    return None


# Run the DVD evaluation
def run_dvd_evaluation(doc1_filename: str, doc2_filename: str):
    """Run the full MCQ-based comparison of two documents.

    Loads both files, builds the graph and streams it to completion. The
    per-document scores are printed by the ``evaluate_responses`` node.

    Args:
        doc1_filename: Path of the first document.
        doc2_filename: Path of the second document.

    Returns:
        The ``evaluation_results`` dict found in the streamed events, or None
        when a document could not be loaded or no results were produced.
    """
    doc1_content = load_document(doc1_filename)
    doc2_content = load_document(doc2_filename)

    if not doc1_content or not doc2_content:
        print("Error: Unable to load one or both documents. Exiting.")
        return None

    initial_state = DVDState(documents=[
        Document(content=doc1_content),
        Document(content=doc2_content)
    ])
    graph = build_dvd_graph()

    results = None
    for event in graph.stream(initial_state, {"configurable": {"thread_id": "dvd_evaluation"}}):
        # Checking only isinstance(event, DVDState) misses dict-shaped events,
        # so inspect the event and its values for the results.
        found = _find_evaluation_results(event)
        if found:
            results = found

    return results

# Run the evaluation
# Entry point when executed as a script: runs the evaluation with "human.txt"
# as the first document and "ai.txt" as the second.
if __name__ == "__main__":
    run_dvd_evaluation("human.txt", "ai.txt")

Generating MCQs for document 1
Question: What was the patient's most recent imaging finding related to her kidney condition?
A. Left-sided hydrouretronephrosis
B. Right-sided kidney stones
C. Normal kidney function
D. Stable bone metastasis
E. I don't know
Correct Answer: B

Question: Which medication was the patient recently started on for weight loss?
A. Tamoxifen
B. Femara
C. Ozempic
D. Faslodex
E. I don't know
Correct Answer: C

Question: What was the result of the urine analysis regarding the presence of bacteria?
A. Positive
B. Negative
C. Not specified
D. Trace
E. I don't know
Correct Answer: B

Question: Which cancer treatment was discontinued due to toxicity in this patient?
A. Xeloda
B. Tamoxifen
C. Paclitaxel
D. Ribociclib
E. I don't know
Correct Answer: A

Question: What does the patient's recent CT scan indicate regarding her bone metastasis?
A. Local disease recurrence
B. Stable bone metastasis
C. New visceral metastasis
D. Complete resolution of metastasis
E. I don't kno