In [None]:
# data should be from the output of first_pass_gen.py
work_on_data = "./data_with_reasonings_0.pkl"
# should be the output of the infer.py
infer_records = "./recrods.pkl"
# should be the output of generate_new_reasoning.py
new_reasoning = "./all_rs.json"

In [1]:
from typing import Annotated, Literal, TypedDict, Any
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
import pandas as pd
import pickle
from tqdm import tqdm
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import json
import uuid
from langchain_core.documents import Document
from time import time
import chromadb
from langgraph.graph.message import add_messages
from pydantic import Field, BaseModel
import loguru
import pickle
import requests as rq
from bs4 import BeautifulSoup
from urllib.parse import quote, quote_plus, urlparse, parse_qs, urlunparse, urlencode
from concurrent.futures import ThreadPoolExecutor
from langchain_openai import ChatOpenAI
import regex
import re
import time
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from threading import Lock
import nest_asyncio
import random
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
nest_asyncio.apply()
logger = loguru.logger
try:
    logger.remove(0)
except:
    pass
logger.add("a.log")
from concurrent.futures import ThreadPoolExecutor
class MedicalAskContent(BaseModel):
    
    disease: list[str] = Field(
        default_factory=list,
        description="医生提问中的疾病对象"
    )
    symptom: list[str] = Field(
        default_factory=list,
        description="医生提问中的症状对象"
    )
    medcine: list[str] = Field(
        default_factory=list,
        description="医生提问中的药物对象"
    )
    surgery: list[str] = Field(
        default_factory=list,
        description="医生提问中的手术对象"
    )
    body_part: list[str] = Field(
        default_factory=list,
        description="医生提问中的身体部位对象"
    )
    medical_check: list[str] = Field(
        default_factory=list,
        description="医生提问中的检查项目对象"
    )
    concept: list[str] = Field(
        default_factory=list,
        description="医生提问中的问诊医学概念对象"
    )
class RetrievalAction(BaseModel):
    """The action of retrievaling from the knowledge base."""
    reasoning: str = Field(
        description="The reasong why you are doing the following query."
    )
    queries: list[str] = Field(description="The list of queries to execute. Each should be a question.")

class RetrievalItem(TypedDict):
    
    query: str
    refined_result: str

class AskAction(BaseModel):
    """The action of asking the patient for extra information."""
    reasoning: str = Field(
        description="The reason why you are asking the following question."
    )
    disease: list[str] = Field(
        default_factory=list,
        description="你提问的疾病对象"
    )
    symptom: list[str] = Field(
        default_factory=list,
        description="你提问的症状对象"
    )
    medcine: list[str] = Field(
        default_factory=list,
        description="你提问的药物对象"
    )
    surgery: list[str] = Field(
        default_factory=list,
        description="你提问的手术对象"
    )
    body_part: list[str] = Field(
        default_factory=list,
        description="你提问的身体部位对象"
    )
    medical_check: list[str] = Field(
        default_factory=list,
        description="你提问的检查项目对象"
    )
    concept: list[str] = Field(
        default_factory=list,
        description="你提问的问诊医学概念对象"
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class TellAction(BaseModel):
    """The action of telling the patient something."""
    reasoning: str = Field(
        description="The reason why you are telling the user the following text."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class Conversation(TypedDict):
    
    patient: str
    doctor: str

class RagMeta(TypedDict):

    vector_store: list[Document]
    baidu_search: tuple[
        Annotated[str, "question"],
        Annotated[str, "answer"],
        Annotated[str, "failed"],
        Annotated[str, "src"]
    ]

class RetrievalRequest(BaseModel):

    reasoning: str = Field(
        description="The process of how you arrived on the following queries you are going to make based on the input. Reasoning should be concise and contain only important points."
    )
    queries: list[str] = Field(
        description="The queries you want to make. Each should be a question."
    )

class RefinedResult(BaseModel):

    refined_result: str = Field(
        description="The refined result of the given message. Should be in no more than 2 to 3 sentences. Should contain no more content than what's related to the query."
    )

class ReasoningCheck(BaseModel):

    reasoning: str = Field(
        description="The reason why you assert that you can deduce the answer of the professional doctor from the given information, or why you cannot deduce it. Reasoning should be concise and contain only important points."
    )
    deducible: bool = Field(description="Whether you can deduce the answer of the professional doctor from the given information.")

class Response(BaseModel):

    reasoning: str = Field(
        description="The reason why you ask the user for more information or respond to them. Reasoning should be concise and contain only important points."
    )
    is_asking: bool = Field(
        description="Wehther your sentence asks the user for more information or tell them something. If this is true, this response will be marked as ask, or else it will be marked as tell."
    )
    entities: list[str] = Field(
        description="The medical entities related to this conversation. Only the important ones should be included."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class RetrievalItem(TypedDict):
    
    query: str
    refined_result: str

class StructuredOutputWithRaw(TypedDict):
    
    raw: AIMessage
    parsed: Any
    parsing_error: Any

In [2]:
lded = pickle.load(open(work_on_data, "rb"))

In [3]:
from typing import Annotated, Literal, TypedDict, Any
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_ollama import ChatOllama
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
import pandas as pd
import pickle
from tqdm import tqdm
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import json
import uuid
from langchain_core.documents import Document
from time import time
import chromadb
from langgraph.graph.message import add_messages
from pydantic import Field, BaseModel
import loguru
import pickle
import requests as rq
from bs4 import BeautifulSoup
from urllib.parse import quote, quote_plus, urlparse, parse_qs, urlunparse, urlencode
from concurrent.futures import ThreadPoolExecutor
from langchain_openai import ChatOpenAI
import regex
import re
import time
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from threading import Lock
logger = loguru.logger
try:
    logger.remove(0)
except:
    pass
logger.add("a.log")
from concurrent.futures import ThreadPoolExecutor
class MedicalAskContent(BaseModel):
    
    disease: list[str] = Field(
        default_factory=list,
        description="医生提问中的疾病对象"
    )
    symptom: list[str] = Field(
        default_factory=list,
        description="医生提问中的症状对象"
    )
    medcine: list[str] = Field(
        default_factory=list,
        description="医生提问中的药物对象"
    )
    surgery: list[str] = Field(
        default_factory=list,
        description="医生提问中的手术对象"
    )
    body_part: list[str] = Field(
        default_factory=list,
        description="医生提问中的身体部位对象"
    )
    medical_check: list[str] = Field(
        default_factory=list,
        description="医生提问中的检查项目对象"
    )
    concept: list[str] = Field(
        default_factory=list,
        description="医生提问中的问诊医学概念对象"
    )
class RetrievalAction(BaseModel):
    """The action of retrievaling from the knowledge base."""
    reasoning: str = Field(
        description="The reasong why you are doing the following query."
    )
    queries: list[str] = Field(description="The list of queries to execute. Each should be a question.")

class RetrievalItem(TypedDict):
    
    query: str
    refined_result: str

class AskAction(BaseModel):
    """The action of asking the patient for extra information."""
    reasoning: str = Field(
        description="The reason why you are asking the following question."
    )
    disease: list[str] = Field(
        default_factory=list,
        description="你提问的疾病对象"
    )
    symptom: list[str] = Field(
        default_factory=list,
        description="你提问的症状对象"
    )
    medcine: list[str] = Field(
        default_factory=list,
        description="你提问的药物对象"
    )
    surgery: list[str] = Field(
        default_factory=list,
        description="你提问的手术对象"
    )
    body_part: list[str] = Field(
        default_factory=list,
        description="你提问的身体部位对象"
    )
    medical_check: list[str] = Field(
        default_factory=list,
        description="你提问的检查项目对象"
    )
    concept: list[str] = Field(
        default_factory=list,
        description="你提问的问诊医学概念对象"
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class TellAction(BaseModel):
    """The action of telling the patient something."""
    reasoning: str = Field(
        description="The reason why you are telling the user the following text."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class Conversation(TypedDict):
    
    patient: str
    doctor: str

class RagMeta(TypedDict):

    vector_store: list[Document]
    baidu_search: tuple[
        Annotated[str, "question"],
        Annotated[str, "answer"],
        Annotated[str, "failed"],
        Annotated[str, "src"]
    ]

class RetrievalRequest(BaseModel):

    reasoning: str = Field(
        description="The process of how you arrived on the following queries you are going to make based on the input. Reasoning should be concise and contain only important points."
    )
    queries: list[str] = Field(
        description="The queries you want to make. Each should be a question."
    )

class RefinedResult(BaseModel):

    refined_result: str = Field(
        description="The refined result of the given message. Should be in no more than 2 to 3 sentences. Should contain no more content than what's related to the query."
    )

class ReasoningCheck(BaseModel):

    reasoning: str = Field(
        description="The reason why you assert that you can deduce the answer of the professional doctor from the given information, or why you cannot deduce it. Reasoning should be concise and contain only important points."
    )
    deducible: bool = Field(description="Whether you can deduce the answer of the professional doctor from the given information.")

class Response(BaseModel):

    reasoning: str = Field(
        description="The reason why you ask the user for more information or respond to them. Reasoning should be concise and contain only important points."
    )
    is_asking: bool = Field(
        description="Wehther your sentence asks the user for more information or tell them something. If this is true, this response will be marked as ask, or else it will be marked as tell."
    )
    entities: list[str] = Field(
        description="The medical entities related to this conversation. Only the important ones should be included."
    )
    text: str = Field(
        description="The text sent to the patient."
    )

class RetrievalItem(TypedDict):
    
    query: str
    refined_result: str

class StructuredOutputWithRaw(TypedDict):
    
    raw: AIMessage
    parsed: Any
    parsing_error: Any

In [4]:
records = pickle.load(open(infer_records, "rb"))

In [5]:
len(records)

7000

In [7]:
rss = json.load(open(new_reasoning, "r"))

In [8]:
len(rss)

13196

In [9]:
pmp_start = """<|im_start|>system

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "RetrievalAction", "description": "The action of retrievaling from the knowledge base.", "parameters": {"properties": {"reasoning": {"description": "The reasong why you are doing the following query.", "type": "string"}, "queries": {"description": "The list of queries to execute. Each should be a question.", "items": {"type": "string"}, "type": "array"}}, "required": ["reasoning", "queries"], "type": "object"}}}
{"type": "function", "function": {"name": "AskAction", "description": "The action of asking the patient for extra information.", "parameters": {"properties": {"reasoning": {"description": "The reason why you are asking the following question.", "type": "string"}, "disease": {"description": "你提问的疾病对象", "items": {"type": "string"}, "type": "array"}, "symptom": {"description": "你提问的症状对象", "items": {"type": "string"}, "type": "array"}, "medcine": {"description": "你提问的药物对象", "items": {"type": "string"}, "type": "array"}, "surgery": {"description": "你提问的手术对象", "items": {"type": "string"}, "type": "array"}, "body_part": {"description": "你提问的身体部位对象", "items": {"type": "string"}, "type": "array"}, "medical_check": {"description": "你提问的检查项目对象", "items": {"type": "string"}, "type": "array"}, "concept": {"description": "你提问的问诊医学概念对象", "items": {"type": "string"}, "type": "array"}, "text": {"description": "The text sent to the patient.", "type": "string"}}, "required": ["reasoning", "text"], "type": "object"}}}
{"type": "function", "function": {"name": "TellAction", "description": "The action of telling the patient something.", "parameters": {"properties": {"reasoning": {"description": "The reason why you are telling the user the following text.", "type": "string"}, "text": {"description": "The text sent to the patient.", "type": "string"}}, "required": ["reasoning", "text"], "type": "object"}}}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"arguments": <args-json-object>, "name": <function-name>}
</tool_call><|im_end|>
<|im_start|>user
### 指令
是一个医学专家级的AI助手。现在你需要根据已有信息，给出下一个Action，如果信息不足，可以使用RetrievalAction查询知识库，或使用AskAction询问患者。如果信息充足，可以使用TellAction告知患者信息。注意，你一次只能执行一个Action。

### 医患历史对话
<|[conversation]|>

<|im_end|>"""

In [10]:
# <|[conversation]|>
# <|[query_reasoning]|>
# <|[queries_list_json]|>
# <|[refined_result_s]|>
# <|[tell_action_args_json_obj]|>

In [11]:
query_rq_and_resp = """<|im_start|>assistant
<tool_call>
{"arguments": {"reasoning": "<|[query_reasoning]|>", "queries": <|[queries_list_json]|>}, "name": "RetrievalAction"}
</tool_call><|im_end|><|im_start|>user
<tool_response>
<|[refined_result_s]|>
</tool_response><|im_end|>"""

In [12]:
tell_action = """<|im_start|>assistant
<tool_call>
{"arguments": <|[tell_action_args_json_obj]|>, "name": "TellAction"}
</tool_call><|im_end|>"""

In [13]:
ask_action = """<|im_start|>assistant
<tool_call>
{"arguments": <|[ask_action_args_json_obj]|>, "name": "AskAction"}
</tool_call><|im_end|>"""

In [14]:
class DPOData(BaseModel):

    prompt: str = Field(
        description="The prompt to generate the response for. Should be in no more than 2 to 3 sentences."
    )
    reject: list[str] = Field(
        description="The reason why the prompt is rejected. Should be in no more than 2 to 3 sentences."
    )
    chosen: list[str] = Field(
        description="The response to the prompt. Should be in no more than 2 to 3 sentences."
    )

In [15]:
def get_prompt(
    conversations: list[Conversation],
    query_requests: list[RetrievalRequest],
    query_history_by_turn: list[list[RetrievalItem]],
    model_response: TellAction | AskAction,
    actual_reasoning: str,
    actual_objects: MedicalAskContent,
    model_is_asking: bool,
    actual_is_asking: bool
) -> DPOData:
    fmt_conv = []
    for each in conversations:
        fmt_conv.append(
            f"医生：{each['doctor']}"
        )
        fmt_conv.append(
            f"患者：{each['patient']}"
        )
    last_doc = conversations[-1]['doctor']
    start_part = pmp_start.replace("<|[conversation]|>", "\n".join(fmt_conv[:-1]))
    q_sequence = []
    for i in range(len(query_requests)):
        this_query = query_rq_and_resp.replace(
            "<|[query_reasoning]|>", query_requests[i].reasoning,
        ).replace(
            "<|[queries_list_json]|>", json.dumps(query_requests[i].queries, ensure_ascii=False)
        )
        q_results = []
        for each in query_history_by_turn[i]:
            q_results.append(
                json.dumps(each, ensure_ascii=False)
            )
        this_query = this_query.replace(
            "<|[refined_result_s]|>", "\n".join(q_results)
        )
        q_sequence.append(
            this_query
        )
    prompt = start_part + "\n" + "".join(q_sequence)
    reject = ""
    chosen = ""
    if model_is_asking:
        reject = ask_action.replace(
            "<|[ask_action_args_json_obj]|>", json.dumps(model_response.dict(), ensure_ascii=False)
        )
    else:
        reject = tell_action.replace(
            "<|[tell_action_args_json_obj]|>", json.dumps(model_response.model_dump(), ensure_ascii=False)
        )
    if not actual_is_asking:
        accept = tell_action.replace(
            "<|[tell_action_args_json_obj]|>", json.dumps(
                TellAction(   
                    reasoning=actual_reasoning,
                    text=last_doc
                ).model_dump()
                , ensure_ascii=False)
        )
    else:
        accept = ask_action.replace(
            "<|[ask_action_args_json_obj]|>", json.dumps(
                AskAction(   
                    reasoning=actual_reasoning,
                    body_part=actual_objects.body_part,
                    concept=actual_objects.concept,
                    disease=actual_objects.disease,
                    medcine=actual_objects.medcine,
                    medical_check=actual_objects.medical_check,
                    surgery=actual_objects.surgery,
                    symptom=actual_objects.symptom,
                    text=last_doc,
                ).model_dump()
                , ensure_ascii=False)
        )
    return {
        "prompt": prompt,
        "rejected": reject,
        "chosen": accept,
    }

In [16]:
dpo_data = []

In [18]:
records[0]

[{'conversation': [{'patient': '雀斑。小时候长的现在还可以治疗吗，不是遗传的。',
    'doctor': '是单侧还是双侧？'},
   {'patient': '单侧，费用大概要多少，需要几次。',
    'doctor': '你这个不是雀斑。应该是雀斑样痣。需要三四次，总费用要一万元左右。'}],
  'response_structured_with_raw': {'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_EcINfsOOAk0RyrcqtrSADTsl', 'function': {'arguments': '{"entities":["雀斑","激光治疗","费用","次数"],"is_asking":false,"reasoning":"患者询问单侧雀斑的治疗方法及相关费用，因此提供了详细的治疗选项和费用信息以帮助患者决策。","text":"雀斑的治疗方法包括美白产品（如维生素C）、激光治疗或化学剥脱等。具体来说，激光治疗的费用大约在1000元左右，通常需要2-3次治疗，每次治疗间隔约3周。建议您在当地医院咨询详细情况，以获得适合您的治疗方案。"}', 'name': 'Response'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 136, 'prompt_tokens': 794, 'total_tokens': 930, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'gpt-4o-mini', 'system_fingerprint': 'fp_04751d0b65', 'finish_reason': 'stop', 'logprobs': None}, id='run-f4ac6b15-5647-45ae-be13-cbed70154f81-0', tool_calls=[{'name': 'Response', 'args': 

In [19]:
for i in range(len(records)):
    # records[i]['doc_reasoning'] = rss[i]
    if "<tool_call>" in rss[i]:
        continue
    if len(rss[i]) < 3:
        continue
    if bool(
        bool(records[i][0]['all_meta']['is_asking']) ^ bool(isinstance(records[i][2], AskAction)) 
    ):
        continue
    try:
        p = get_prompt(
            records[i][0]['all_meta']['conversation'],
            records[i][0]['all_meta']['query_requests'],
            records[i][0]['all_meta']['query_history_by_turn'],
            records[i][2],
            rss[i],
            records[i][0]['all_meta']['objs'] if 'objs' in records[i][0]['all_meta'] else None,
            isinstance(records[i][2], AskAction),
            records[i][0]['all_meta']['is_asking']
        )
        dpo_data.append(p)
    except:
        continue

/tmp/ipykernel_786557/1171835073.py:44: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  "<|[ask_action_args_json_obj]|>", json.dumps(model_response.dict(), ensure_ascii=False)


In [20]:
len(dpo_data)

8824

In [21]:
rej_dpo = []
pmp_dpo = []
chosen_dpo = []

In [22]:
for each in dpo_data:
    rej_dpo.append(each['rejected'])
    pmp_dpo.append(each['prompt'])
    chosen_dpo.append(each['chosen'])

In [23]:
json.dump(dpo_data, open("./dpo_data.json", "w"), ensure_ascii=False)

In [24]:
json.dump(dpo_data[:10], open("./dpo_data_10.json", "w"), indent=2, ensure_ascii=False)