In [None]:
# Path of the pickled input that the rest of the notebook converts.
# The data should be the output of first_pass_gen.py.
work_on_data = "./data_with_reasonings_0.pkl"

In [1]:
from typing import Annotated, Literal, TypedDict, Any
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama
from langchain_core.output_parsers import StrOutputParser
import pandas as pd
import pickle
from tqdm import tqdm
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import json
import uuid
from langchain_core.documents import Document
from time import time
import chromadb
from langgraph.graph.message import add_messages
from pydantic import Field, BaseModel
from langchain_community.retrievers import BM25Retriever
import loguru
import pickle
import requests as rq
from bs4 import BeautifulSoup
from urllib.parse import quote, quote_plus, urlparse, parse_qs, urlunparse, urlencode
from concurrent.futures import ThreadPoolExecutor
from langchain_core.utils.function_calling import convert_to_openai_function
import json

In [2]:
class RetrievalRequest(BaseModel):
    # One retrieval step: the reasoning behind it and the questions to look up.
    # The conversion loop further down reads `.reasoning` and `.queries` from
    # objects of this type.

    reasoning: str = Field(
        description="The process of how you arrived on the following queries you are going to make based on the input. Reasoning should be concise and contain only important points."
    )
    queries: list[str] = Field(
        description="The queries you want to make. Each should be a question."
    )

class RefinedResult(BaseModel):
    # Holds a short, query-focused summary of a retrieved message.
    # NOTE(review): not referenced anywhere else in the visible notebook;
    # presumably kept so that objects pickled by first_pass_gen.py can be
    # loaded -- confirm before removing.

    refined_result: str = Field(
        description="The refined result of the given message. Should be in no more than 2 to 3 sentences. Should contain no more content than what's related to the query."
    )

class ReasoningCheck(BaseModel):
    # Verdict on whether the doctor's answer can be deduced from the gathered
    # information, together with the justification for that verdict.
    # NOTE(review): not referenced anywhere else in the visible notebook;
    # presumably kept so that objects pickled by first_pass_gen.py can be
    # loaded -- confirm before removing.

    reasoning: str = Field(
        description="The reason why you assert that you can deduce the answer of the professional doctor from the given information, or why you cannot deduce it. Reasoning should be concise and contain only important points."
    )
    deducible: bool = Field(description="Whether you can deduce the answer of the professional doctor from the given information.")

class Response(BaseModel):
    # Structured final reply of one record. The conversion loop further down
    # reads `reasoning`, `is_asking`, `entities` and `text` from objects of
    # this type.

    reasoning: str = Field(
        description="The reason why you ask the user for more information or respond to them. Reasoning should be concise and contain only important points."
    )
    # True -> the reply is exported as an AskAction, False -> as a TellAction.
    is_asking: bool = Field(
        description="Whether your sentence asks the user for more information or tells them something. If this is true, this response will be marked as ask, or else it will be marked as tell."
    )
    # Only exported for AskAction; TellAction has no `entities` field.
    entities: list[str] = Field(
        description="The medical entities related to this conversation. Only the important ones should be included."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class Conversation(TypedDict):
    """One exchange of a consultation: the patient's message and the doctor's reply."""

    patient: str
    doctor: str

class RetrievalItem(TypedDict):
    """A single knowledge-base lookup: the query and its refined result."""

    query: str
    refined_result: str

class StructuredOutputWithRaw(TypedDict):
    """Structured-output result bundle: the raw message, the parsed value and any parsing error.

    The conversion loop in this notebook skips records whose ``parsed`` entry is ``None``.
    """

    raw: AIMessage
    parsed: Any
    parsing_error: Any

In [3]:
# Load the records to convert. The context manager closes the file handle,
# which the previous bare open() call left to the garbage collector.
# NOTE: pickle.load can execute arbitrary code contained in the file -- only
# load pickles that were produced by a trusted run of the generation script.
# NOTE(review): unpickling presumably relies on the model classes defined in
# the cells above being importable under the same names -- confirm.
with open(work_on_data, "rb") as pickle_file:
    ld = pickle.load(pickle_file)

In [4]:
# Number of records loaded from the pickle.
len(ld)

65019

In [5]:
# Peek at a single record to inspect its structure.
# NOTE(review): the notebook does not say why index -28 was chosen.
ld[-28]

{'conversation': [{'patient': '扁桃体发炎，咽痛，拖久了会不会要做手术啊？',
   'doctor': '扁桃体炎需要及时药物治疗控制，拖久了会引起扁桃体周围炎或脓肿，或可引起风湿性疾病。'},
  {'patient': '咽口水都觉得有东西在喉咙管，疼', 'doctor': '需要明确扁桃体炎还是急性咽炎，病情查血常规，看血象有多高'}],
 'query_messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_fpgYnrlIypSQh09nPIL3j9XT', 'function': {'arguments': '{"reasoning":"患者提到扁桃体发炎和咽痛，询问是否需要手术，说明需要了解扁桃体炎的治疗方案和可能的并发症。患者也提到吞咽时喉咙有异物感和疼痛，因此需要了解咽痛的原因和相关症状。","queries":["扁桃体炎的治疗方法有哪些？","咽痛伴随异物感的常见原因是什么？"]}', 'name': 'RetrievalRequest'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 105, 'prompt_tokens': 344, 'total_tokens': 449, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_0705bf87c0', 'finish_reason': 'stop', 'logprobs': None}, id='run-c22bdbb4-c

In [6]:
class RetrievalAction(BaseModel):
    """The action of retrievaling from the knowledge base."""
    # This class is passed to convert_to_openai_function further down, so the
    # docstring and the field descriptions presumably end up verbatim in the
    # exported tool definition (confirm against the langchain docs).
    # NOTE(review): "retrievaling" (docstring) and "reasong" (description)
    # look like typos. Fixing them would change the generated tool schema, so
    # confirm that nothing downstream depends on the current wording first.
    reasoning: str = Field(
        description="The reasong why you are doing the following query."
    )
    queries: list[str] = Field(description="The list of queries to execute. Each should be a question.")

class AskAction(BaseModel):
    """The action of asking the patient for extra information."""
    # This class is passed to convert_to_openai_function further down, so the
    # docstring and field descriptions are runtime text, not just documentation.
    # The conversion loop emits this action when the parsed response has
    # is_asking set to True.
    reasoning: str = Field(
        description="The reason why you are asking the following question."
    )
    entities: list[str] = Field(
        description="The medical entities related to this conversation. Only the important ones should be included."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class TellAction(BaseModel):
    """The action of telling the patient something."""
    # This class is passed to convert_to_openai_function further down, so the
    # docstring and field descriptions are runtime text, not just documentation.
    # The conversion loop emits this action when the parsed response has
    # is_asking set to False. Unlike AskAction it has no `entities` field.
    reasoning: str = Field(
        description="The reason why you are telling the user the following text."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

In [7]:
# Tool definitions attached to every exported example under the "Tools" key;
# filled by the cell that converts the three action models.
tools = []

In [8]:
# Build the tool list in one assignment instead of appending to the existing
# list, so that re-running this cell cannot leave duplicate definitions in
# `tools`.
tools = [
    {
        "Type": "function",
        # The function schema is stored as a JSON *string*, not as a nested
        # object; non-ASCII characters are kept as-is.
        "Function": json.dumps(
            convert_to_openai_function(action_model), ensure_ascii=False
        ),
    }
    for action_model in (RetrievalAction, AskAction, TellAction)
]

In [9]:
# One {"Messages": ..., "Tools": ...} dict per exported example; filled by the
# conversion loop over `ld`.
format_objects = []

In [10]:
# Prompt template for the first (user) message of every example. The single
# `{}` placeholder is filled through str.format with the formatted
# patient/doctor history.
# Gist of the Chinese text: act as a medical-expert AI assistant and give the
# next Action -- RetrievalAction to query the knowledge base or AskAction to
# ask the patient when information is missing, TellAction to inform the
# patient when it is sufficient; only one Action may be executed at a time.
# NOTE(review): the first sentence has no subject ("是一个…" rather than
# "你是一个…"), which looks like a typo. Changing it alters every rendered
# prompt, so confirm before editing.
instruction = """### 指令
是一个医学专家级的AI助手。现在你需要根据已有信息，给出下一个Action，如果信息不足，可以使用RetrievalAction查询知识库，或使用AskAction询问患者。如果信息充足，可以使用TellAction告知患者信息。注意，你一次只能执行一个Action。

### 医患历史对话
{}

"""

In [None]:
# Maps the integer index of an exported example to a raw record from `ld`;
# filled by the cell that writes the per-example cache files.
idx_to_raw = {}

In [11]:
def _tool_call_message(name: str, arguments: dict) -> dict:
    """Build an assistant message that carries exactly one tool call.

    Args:
        name: Name of the action being invoked, e.g. "AskAction".
        arguments: Arguments of the call. They are serialised to a JSON string
            with non-ASCII characters kept as-is.

    Returns:
        A dict with "Role" set to "assistant" and a single entry in "ToolCalls".
    """
    return {
        "Role": "assistant",
        "ToolCalls": [
            {
                "Function": {
                    "Name": name,
                    "Arguments": json.dumps(arguments, ensure_ascii=False),
                }
            }
        ],
    }


# Start from an empty list so that re-running this cell does not append a
# second copy of every example.
format_objects = []

for each in ld:
    parsed_response: Response = each["response_structured_with_raw"]["parsed"]
    if parsed_response is None:
        # No structured final response is available for this record, so no
        # example is produced for it. Checked first to avoid building messages
        # that would be thrown away.
        continue

    conversation = each["conversation"]
    conversation_formatted = []
    for e in conversation:
        conversation_formatted.append(f"患者：{e['patient']}")
        conversation_formatted.append(f"医生：{e['doctor']}")

    # The last formatted line (the final doctor reply) is left out of the
    # history that goes into the prompt.
    messages = [
        {
            "Role": "user",
            "Content": instruction.format("\n".join(conversation_formatted[:-1])),
        }
    ]

    # Each retrieval request is paired, by position, with the results that were
    # recorded for the same turn.
    query_history_by_turn = each["query_history_by_turn"]
    for turn_idx, query_request in enumerate(each["query_requests"]):
        query_results_this_turn = query_history_by_turn[turn_idx]
        messages.append(
            _tool_call_message(
                "RetrievalAction",
                {
                    "reasoning": query_request.reasoning,
                    "queries": query_request.queries,
                },
            )
        )
        messages.append(
            {
                "Role": "tool",
                "Content": json.dumps(query_results_this_turn, ensure_ascii=False),
            }
        )

    # The final action uses the text of the parsed response, not the original
    # doctor reply stored in the conversation.
    if parsed_response.is_asking:
        final_message = _tool_call_message(
            "AskAction",
            {
                "reasoning": parsed_response.reasoning,
                "entities": parsed_response.entities,
                "text": parsed_response.text,
            },
        )
    else:
        final_message = _tool_call_message(
            "TellAction",
            {
                "reasoning": parsed_response.reasoning,
                "text": parsed_response.text,
            },
        )
    messages.append(final_message)

    format_objects.append({
        "Messages": messages,
        "Tools": tools
    })

In [12]:
# Write all examples to a single JSON file. The context manager guarantees the
# file is flushed and closed, and the explicit UTF-8 encoding keeps the
# non-ASCII text (ensure_ascii=False) independent of the platform default.
with open("./objs.json", "w", encoding="utf-8") as objs_file:
    json.dump(format_objects, objs_file, ensure_ascii=False)

In [None]:
import shutil

# Clear the previous export. os.rmdir only removes *empty* directories, so the
# old try/except silently kept stale <i>.json files from earlier runs.
try:
    shutil.rmtree("./cache/values")
except FileNotFoundError:
    pass  # nothing to clean up on the first run
os.makedirs("./cache/values", exist_ok=True)
os.makedirs("./cache/results", exist_ok=True)

# The conversion loop skips records whose parsed response is None, so position
# i in format_objects is NOT ld[i] once a record has been skipped. Rebuild the
# list of records that were actually converted and pair by position with that.
kept_records = [
    record
    for record in ld
    if record["response_structured_with_raw"]["parsed"] is not None
]
assert len(kept_records) == len(format_objects), (
    f"{len(format_objects)} examples but {len(kept_records)} convertible records; "
    "re-run the conversion cell on a fresh format_objects list"
)

idx_to_raw.clear()  # drop mappings left over from an earlier run
for i, (o, raw_record) in enumerate(zip(format_objects, kept_records)):
    idx_to_raw[i] = raw_record
    # One pretty-printed UTF-8 JSON file per example, closed after writing.
    with open(f"./cache/values/{i}.json", "w", encoding="utf-8") as value_file:
        json.dump(o, value_file, indent=2, ensure_ascii=False)

In [14]:
# Run the Go helper. Its source is not part of this notebook; the next cell
# expects it to have written ./rendered_prompts.json.
exit_status = os.system("go run ./work.go")
if exit_status != 0:
    # Fail loudly: otherwise the following cells would silently read a stale
    # rendered_prompts.json from an earlier run.
    raise RuntimeError(f"'go run ./work.go' failed with status {exit_status}")
exit_status

Rendered prompts saved to rendered_prompts.json


0

In [15]:
import json

In [16]:
# Load the rendered prompts. The file handle is closed by the context manager
# and decoded as UTF-8 regardless of the platform default encoding.
with open("./rendered_prompts.json", "r", encoding="utf-8") as rendered_file:
    dd = json.load(rendered_file)

In [17]:
# Replace every '<no value>' marker in the rendered prompts with the
# '<|im_end|>' token.
# NOTE(review): '<no value>' is presumably what the Go template emits for a
# field that was not set -- confirm against work.go.
trimmed = [prompt.replace('<no value>', '<|im_end|>') for prompt in dd]

In [18]:
# Overwrite the rendered prompts with the cleaned-up version. The context
# manager guarantees the file is flushed and closed; UTF-8 is set explicitly
# because ensure_ascii=False writes the non-ASCII text verbatim.
with open('./rendered_prompts.json', "w", encoding="utf-8") as rendered_file:
    json.dump(trimmed, rendered_file, ensure_ascii=False)

In [24]:
# Dump prompts 40..49 to a text file for manual inspection, each followed by a
# line of ten '=' characters. UTF-8 is set explicitly so that the non-ASCII
# prompt text does not depend on the platform default encoding.
with open("res_40.txt", "w", encoding="utf-8") as f:
    for each in trimmed[40:50]:
        f.write(each)
        f.write("=" * 10 + "\n")

In [20]:
# Number of rendered prompts after clean-up.
len(trimmed)

65018