In [None]:
import jieba
import jieba.analyse
import jieba.posseg
import jieba.posseg as pseg
from neo4j import GraphDatabase
from gensim.models import Word2Vec
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

In [None]:
import os
from neo4j import GraphDatabase

# Neo4j connection settings come from the environment so that no credentials
# are stored in the notebook.
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")

# Fail fast, naming only the variables that are actually unset/empty.
_missing_env = [
    var_name
    for var_name, var_value in (
        ("NEO4J_URI", NEO4J_URI),
        ("NEO4J_USERNAME", NEO4J_USERNAME),
        ("NEO4J_PASSWORD", NEO4J_PASSWORD),
    )
    if not var_value
]
if _missing_env:
    raise ValueError("Missing env vars: " + ", ".join(_missing_env))

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

# Alias so that code referring to the Neo4j driver as `graph` resolves to the
# same driver instance instead of raising NameError.
graph = driver

In [None]:
# Method 1 (keyword extraction): vocabulary-filtered segmentation — keep only the jieba tokens that appear in the medical vocabulary.
from pathlib import Path

# Load every medical vocabulary file into a single Python set of terms.
def load_medical_vocab(vocab_dir):
    """Read the medical vocabulary files in ``vocab_dir`` and return their terms.

    Each file is read as UTF-8 with one term per line. Surrounding whitespace
    is stripped and blank lines are skipped.

    Args:
        vocab_dir: directory (str or path-like) that contains vocab.txt,
            symptom_vocab.txt, disease_vocab.txt, complications_vocab.txt and
            alias_vocab.txt.

    Returns:
        set[str]: the union of all non-empty terms across the files.

    Raises:
        OSError (e.g. FileNotFoundError): if one of the files cannot be opened.
    """
    vocab_files = (
        "vocab.txt",
        "symptom_vocab.txt",
        "disease_vocab.txt",
        "complications_vocab.txt",
        "alias_vocab.txt",
    )
    base_dir = Path(vocab_dir)
    terms = set()
    for file_name in vocab_files:
        with open(base_dir / file_name, "r", encoding="utf-8") as handle:
            cleaned = (raw.strip() for raw in handle)
            terms.update(word for word in cleaned if word)
    return terms

# Extract medical-related keywords from a sentence (Method 1).
# Vocabulary sets keyed by directory, so each vocabulary is read from disk once.
_MEDICAL_VOCAB_CACHE = {}


def key_word(sentence, vocab_dir="data/vocab"):
    """Return the jieba tokens of ``sentence`` that are known medical terms.

    On the first call for a given ``vocab_dir`` the vocabulary is loaded with
    ``load_medical_vocab`` and cached. Each term is also registered with
    jieba (``jieba.add_word``) so that multi-character medical terms are kept
    as single tokens; without this, the default dictionary may split them and
    they would never match the vocabulary filter.

    Args:
        sentence: the user question.
        vocab_dir: directory holding the vocabulary files (default "data/vocab").

    Returns:
        list[str]: matching tokens in the order they appear in ``sentence``
        (duplicates are kept).
    """
    medical_terms = _MEDICAL_VOCAB_CACHE.get(vocab_dir)
    if medical_terms is None:
        medical_terms = load_medical_vocab(vocab_dir)
        for term in medical_terms:
            jieba.add_word(term)
        _MEDICAL_VOCAB_CACHE[vocab_dir] = medical_terms
    tokens = jieba.lcut(sentence)
    return [w for w in tokens if w in medical_terms]

In [None]:
# Method 2 (keyword extraction): unsupervised keyword discovery that combines jieba's TF-IDF and TextRank extractors and returns the union of their results.

from pathlib import Path

# Register the medical vocabularies as jieba user dictionaries so that the
# TF-IDF / TextRank extractors used by `keyword` can treat multi-character
# medical terms as single words. The files are read from the same relative
# directory that Method 1 uses, not from the filesystem root.
USER_DICT_DIR = Path("data/vocab")
for _user_dict in (
    "vocab.txt",
    "symptom_vocab.txt",
    "disease_vocab.txt",
    "complications_vocab.txt",
    "alias_vocab.txt",
):
    jieba.load_userdict(str(USER_DICT_DIR / _user_dict))


def keyword(sentence, top_k=20):
    """Unsupervised keyword discovery (Method 2).

    Runs jieba's TF-IDF extractor (``jieba.analyse.extract_tags``) and its
    TextRank extractor (``jieba.analyse.textrank``) on ``sentence`` and
    returns the union of both keyword lists.

    Args:
        sentence: the user question.
        top_k: maximum number of keywords requested from each extractor
            (default 20).

    Returns:
        list[str]: de-duplicated keywords in a deterministic order. TF-IDF
        keywords come first in their ranked order, followed by TextRank
        keywords that were not already present.
    """
    tfidf_words = jieba.analyse.extract_tags(sentence, topK=top_k, withWeight=False)
    textrank_words = jieba.analyse.textrank(sentence, topK=top_k, withWeight=False)
    # dict.fromkeys drops duplicates while keeping first-seen order
    # (a plain set would make the output order vary between runs).
    return list(dict.fromkeys(list(tfidf_words) + list(textrank_words)))

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline


# Zero-shot intent classifier built from an NLI checkpoint.
# NOTE(review): facebook/bart-large-mnli is an English MNLI model; zero-shot
# accuracy on Chinese questions with Chinese candidate labels may be poor —
# consider evaluating a multilingual NLI checkpoint. TODO confirm on real queries.
model_dir = "facebook/bart-large-mnli"
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)

# (Chinese label shown to the zero-shot classifier, internal intent code).
# Keeping each pair on one line guarantees that the two derived lists below
# stay aligned index-wise, which find_intention relies on.
_INTENT_LABEL_PAIRS = (
    ("部门", "department"),
    ("身体部位", "part"),
    ("传染性", "infection"),
    ("人群", "age"),
    ("保险", "insurance"),
    ("药物", "drug"),
    ("别名", "alias"),
    ("花费", "cost"),
    ("治愈率", "rate"),
    ("症状", "symptom"),
    ("治疗方案", "treatment"),
    ("检查项目", "checklist"),
    ("并发症", "complication"),
    ("时间", "period"),
    ("病症名称", "Disease"),
    ("非医疗/其他", "out_of_scope"),
)

# Labels passed to the classifier.
candidate_labels = [zh_label for zh_label, _ in _INTENT_LABEL_PAIRS]
# Internal intent codes, same index as the matching label above.
intention = [code for _, code in _INTENT_LABEL_PAIRS]


def find_intention(sentence, intent_threshold=0.45, margin_threshold=0.08, oos_threshold=0.50):
    """Classify ``sentence`` into one of the internal intent codes.

    The zero-shot ``classifier`` ranks ``candidate_labels``. The winning label
    is mapped to the entry at the same index in ``intention``.

    Args:
        sentence: the user question.
        intent_threshold: minimum score of the top label for an in-scope answer.
        margin_threshold: minimum gap between the best and second-best scores.
        oos_threshold: score at or above which a winning "非医疗/其他" label
            is returned as out-of-scope immediately.

    Returns:
        str: an element of ``intention``. "out_of_scope" is returned when the
        out-of-scope label wins, or when the top score is below
        ``intent_threshold`` or the margin is below ``margin_threshold``.
        A winning "非医疗/其他" below ``oos_threshold`` also maps to
        "out_of_scope" via ``intention``.
    """
    prediction = classifier(sentence, candidate_labels)
    ranked_labels, ranked_scores = prediction["labels"], prediction["scores"]

    best_label = ranked_labels[0]
    best_score = float(ranked_scores[0])
    runner_up = 0.0
    if len(ranked_scores) > 1:
        runner_up = float(ranked_scores[1])

    # Case A: the classifier explicitly and confidently chose the out-of-scope label.
    if best_label == "非医疗/其他" and best_score >= oos_threshold:
        return "out_of_scope"

    # Translate the winning Chinese label into its internal code
    # (candidate_labels and intention are index-aligned).
    code = intention[candidate_labels.index(best_label)]

    # Case B: conservative fallback when the decision is weak or ambiguous.
    too_weak = best_score < intent_threshold
    too_close = best_score - runner_up < margin_threshold
    return "out_of_scope" if (too_weak or too_close) else code


In [None]:
# KG labels that either are a Disease or can be traced back to one.
# discover_labels reports "situation 3" when none of the matched labels is listed here.
DISEASE_ANCHOR_LABELS = ["Disease", "alias", "part", "department", "symptom", "drug"]

# Retrieve every label carried by nodes whose `name` property equals the entity.
def fetch_labels(tx, name):
    """Collect the distinct labels of all nodes named ``name``.

    Args:
        tx: transaction object exposing ``run(query, **params)``; used here
            through ``session.execute_read``.
        name: entity string to look up.

    Returns:
        list[str] of distinct labels across all matching nodes (order not
        significant), or None when no node has that name.
    """
    records = tx.run("MATCH (n {name: $name}) RETURN labels(n) AS labels", name=name)
    distinct = set()
    matched = False
    for rec in records:
        matched = True
        distinct.update(rec["labels"])
    if not matched:
        return None
    return list(distinct)


# Ground extracted entities in the KG and infer their semantic labels.
# Returns situation codes for ambiguity/empty/attribute-only cases.
def discover_labels(entities):
    """Look up each extracted entity in the knowledge graph and group it by label.

    Args:
        entities: iterable of keyword strings (e.g. the output of ``key_word``).

    Returns:
        dict[str, list[str]] mapping each KG label to the entities carrying it,
        or one of these status strings:
          - "situation 1": an entity's matching node(s) carry more than one
            label in total (ambiguous type). This is returned as soon as it is seen.
          - "situation 2": no entity matched any labelled KG node.
          - "situation 3": entities matched, but none of their labels is in
            ``DISEASE_ANCHOR_LABELS``, so no disease can be located.
    """
    label_dic = {}
    hit_any = False

    # `driver` is the module-level Neo4j driver created in the connection cell.
    with driver.session() as session:
        for entity in entities:
            labels = session.execute_read(fetch_labels, entity)

            # Not in the KG (or only unlabelled nodes): nothing to route on.
            if not labels:
                continue

            hit_any = True

            # situation 1: the entity cannot be uniquely typed for routing.
            if len(labels) > 1:
                return "situation 1"

            # Exactly one label: group the entity under it.
            label_dic.setdefault(labels[0], []).append(entity)

    # situation 2: no entity hits the KG.
    if not hit_any:
        return "situation 2"

    # situation 3: only attribute-like labels, which cannot anchor disease resolution.
    if label_dic and not any(label in DISEASE_ANCHOR_LABELS for label in label_dic):
        return "situation 3"

    return label_dic

In [None]:
# Reverse traversal: names of `start_label` nodes that point, via `relationship`,
# at the `end_label` node named `end_node_name`.
def fetch_related_nodes(tx, end_label, relationship, start_label, end_node_name):
    """Return the names of ``start_label`` nodes linked by ``relationship`` to a named end node.

    ``end_label``, ``relationship`` and ``start_label`` are interpolated into
    the Cypher text and must come from trusted code. ``end_node_name`` is
    passed as a query parameter.

    Returns:
        list of ``start.name`` values, or None when nothing matched.
    """
    cypher = (
        f"MATCH (start:{start_label})-[r:{relationship}]->"
        f"(end:{end_label}{{name: $end_node_name}}) "
        "RETURN start.name AS related_node_name"
    )
    names = [row["related_node_name"] for row in tx.run(cypher, end_node_name=end_node_name)]
    return names or None


def get_related_nodes(driver, end_label, relationship, start_label, end_node_name):
    """Open a session on ``driver`` and run :func:`fetch_related_nodes` in a read transaction."""
    query_args = (end_label, relationship, start_label, end_node_name)
    with driver.session() as session:
        result = session.execute_read(fetch_related_nodes, *query_args)
    return result

In [None]:
# Retrieve the value of a specified property from nodes matched by label and a name-like property.
def fetch_property_value(tx, label, name_property, name_value, property_key):
    """Return ``n.<property_key>`` for every ``label`` node whose ``name_property`` equals ``name_value``.

    ``name_value`` is sent as a Cypher query parameter instead of being
    spliced into the query text, so values containing quotes (or Cypher
    syntax) cannot break or alter the query. ``label``, ``name_property`` and
    ``property_key`` are still interpolated, because Cypher does not accept
    parameters in those positions. They must come from trusted code, never
    from user input.

    Returns:
        list with one value per matched node (a value may be None if that node
        lacks the property), or None when no node matched.
    """
    query = (
        f"MATCH (n:{label}) WHERE n.{name_property} = $name_value "
        f"RETURN n.{property_key} AS property_values"
    )
    result = tx.run(query, name_value=name_value)
    property_values = [record["property_values"] for record in result]
    return property_values if property_values else None


def get_property_value(driver, label, name_property, name_value, property_key):
    """Open a session on ``driver`` and run :func:`fetch_property_value` in a read transaction."""
    with driver.session() as session:
        return session.execute_read(fetch_property_value, label, name_property, name_value, property_key)

In [None]:
# Forward traversal: starting from a known node (normally a Disease), follow
# `relationship` and list the names of the `end_label` nodes it points to.
def fetch_end_nodes(tx, start_label, relationship, end_label, start_node_name):
    """Return the names of ``end_label`` nodes reached from the named start node.

    ``start_label``, ``relationship`` and ``end_label`` are interpolated into
    the Cypher text and must come from trusted code. ``start_node_name`` is
    passed as a query parameter.

    Returns:
        list of ``end.name`` values, or None when nothing matched.
    """
    pattern = f"(start:{start_label} {{name: $start_node_name}})-[r:{relationship}]->(end:{end_label})"
    cypher = "MATCH " + pattern + " RETURN end.name AS related_node_name"
    collected = []
    for rec in tx.run(cypher, start_node_name=start_node_name):
        collected.append(rec["related_node_name"])
    if not collected:
        return None
    return collected


def get_end_nodes(driver, start_label, relationship, end_label, start_node_name):
    """Open a session on ``driver`` and run :func:`fetch_end_nodes` in a read transaction."""
    with driver.session() as session:
        result = session.execute_read(
            fetch_end_nodes, start_label, relationship, end_label, start_node_name
        )
    return result


In [None]:
# Map grounded entities to candidate diseases.
def disease_search(key_labels):
    """Resolve grounded entities to the candidate disease name(s).

    Args:
        key_labels: dict mapping a KG label to the keywords grounded under it,
            as produced by ``discover_labels``.

    Returns:
        Sorted list of disease names:
          - If a "Disease" label with non-empty keywords is present, those
            keywords (de-duplicated) are returned directly. An explicitly
            named disease takes precedence over indirect evidence; if there
            are several, the caller asks the user to choose.
          - Otherwise, the diseases linked to each keyword whose label appears
            in ``ENTITY_TO_DISEASE_REL`` are looked up. The sets are
            intersected to keep only diseases supported by every such keyword.
          - An empty list when the input is empty, no usable label is
            present, or the intersection is empty.

    Design notes:
    - Complications are modeled as Disease nodes (no separate 'complication' label).
    - Non-disease entities (symptom/drug/department/etc.) are mapped back to
      Disease via KG relationships (Disease -[rel]-> entity).
    - Labels without a mapping (attribute-only entities such as
      age/infection/insurance/cost) cannot infer a disease and are skipped.
    """
    # Relationship type linking a Disease to an entity of each traceable label.
    ENTITY_TO_DISEASE_REL = {
        "alias": "病症别名",
        "part": "病痛的部位",
        "department": "疾病所属部门",
        "symptom": "疾病症状",
        "drug": "疾病所需药物",
    }

    if not key_labels:
        return []

    # An explicitly mentioned disease is used as-is.
    named_diseases = [k for k in key_labels.get("Disease", []) if k]
    if named_diseases:
        return sorted(set(named_diseases))

    # Build disease sets per keyword across all traceable labels, then hard-intersect them.
    intersection = None
    for label, keywords in key_labels.items():
        rel = ENTITY_TO_DISEASE_REL.get(label)
        if rel is None:
            continue  # attribute-only (or Disease with no usable keyword): cannot lead to a disease

        for kw in keywords:
            if not kw:
                continue
            # Reverse lookup: Disease -[rel]-> (label {name: kw}).
            related = get_related_nodes(driver, label, rel, "Disease", kw)
            disease_set = set(related) if related else set()

            if intersection is None:
                intersection = disease_set
            else:
                intersection &= disease_set

            # Early stop: once empty, no disease satisfies all keywords.
            if not intersection:
                return []

    return sorted(intersection) if intersection is not None else []

In [None]:
# Build the templated answer, e.g. 'We found that you are asking about disease XXX; the information you want to know about it is XXX.'
def _missing():
    # EN: Sorry, the answer is missing in the knowledge graph.
    return "抱歉，答案缺失"

def _format_list(items):
    # Join list items using Chinese separator
    return "、".join(items) + "。"

# Relation-based intents: answered by following an edge out of the Disease node.
# Each value is (relationship type, label of the target node, answer template);
# the template is filled with d=<disease name> and a=<formatted answer>.
RELATION_INTENTS = dict(
    alias=("病症别名", "alias", "{d}的别名是{a}"),                # disease aliases (synonyms)
    part=("病痛的部位", "part", "{d}病痛的部位是{a}"),             # affected body parts
    department=("疾病所属部门", "department", "{d}应该去{a}挂号"), # clinical department to visit
    symptom=("疾病症状", "symptom", "{d}的症状有{a}"),            # disease symptoms
    drug=("疾病所需药物", "drug", "{d}的治疗药物是{a}"),          # commonly used medications
    complication=("疾病并发症", "Disease", "{d}的并发症是{a}"),   # complications (Disease nodes)
)

# Property-based intents: answered from a property of the Disease node.
# The key doubles as the Disease property name that is looked up; the template
# is filled with d=<disease name> and a=<formatted answer>.
PROPERTY_INTENTS = dict(
    age="{a}是得{d}的高风险人群。",   # high-risk age groups / populations
    infection="{d}{a}",               # whether the disease is infectious
    insurance="{d}{a}",               # whether it is covered by health insurance
    treatment="{d}的治疗方案是{a}",   # treatment approaches
    period="{d}的治疗时长是{a}",      # expected treatment duration
    rate="{d}的治愈率是{a}",          # cure / remission rate
    checklist="{d}需要检查{a}",       # recommended medical tests
    cost="治疗{d}的花费是{a}",        # treatment cost
)


def find_final_answer(given_entity, want_to_know):
    # EN: Detected disease: X.
    prefix = f"检测到您提问的疾病为：{given_entity}。"

    # Intent: only confirm the disease name
    if want_to_know == "Disease":
        return prefix

    # A) Relation-based intents
    if want_to_know in RELATION_INTENTS:
        rel_type, end_label, template = RELATION_INTENTS[want_to_know]

        # Query neighbor nodes
        answer = get_end_nodes(graph, "Disease", rel_type, end_label, given_entity)
        if not answer:
            return _missing()

        # Formatting: department reads better without an extra "。"
        if want_to_know == "department":
            formatted = "、".join(answer)
        else:
            formatted = _format_list(answer)

        return prefix + template.format(d=given_entity, a=formatted)

    # B) Property-based intentions
    elif want_to_know in PROPERTY_INTENTS:
        template = PROPERTY_INTENTS[want_to_know]

        # Retrieve property value from the Disease node
        answer = get_property_value(graph, "Disease", "name", given_entity, want_to_know)
        if not answer:
            return _missing()

        # Unified formatting for all properties
        formatted = _format_list(answer)

        return prefix + template.format(d=given_entity, a=formatted)


    # EN: This intent is not supported yet. Please ask about symptoms/treatment/tests/department, etc.
    return "暂不支持该类型的问题。请尝试询问：症状、治疗、检查、科室等。"

In [None]:
# Handle special cases (ambiguity, irrelevance, multiple or no diseases) or produce the answer.
def get_answer(sentence):
    """Answer one user question end to end.

    Pipeline: ``key_word`` (keyword extraction) -> ``discover_labels`` (KG
    grounding) -> ``disease_search`` (disease resolution) -> ``find_intention``
    (intent) -> ``find_final_answer`` (templated reply).

    Args:
        sentence: the user question.

    Returns:
        str: either a clarification / refusal message or the final answer.
    """
    entities = key_word(sentence)
    key_labels = discover_labels(entities)

    if key_labels == 'situation 1':
        # English: Some detected keywords correspond to multiple semantic types in the knowledge graph and require disambiguation. Please provide additional context or clarify whether you are referring to a disease, symptom, medication, or medical examination.
        return '检测到部分关键词在知识图谱中存在多种语义类型（需要消歧）。请补充上下文或明确您指的是疾病/症状/药物/检查中的哪一种。'
    if key_labels == 'situation 2':
        # English: No medical-related keywords were identified in the question. Please provide a disease name, symptom, or medical test (e.g., fever, cough, ALS).
        return '未识别到与医疗相关的关键词。请提供疾病名称、症状或检查项目（例如：发热、咳嗽、渐冻症）。'
    if key_labels == 'situation 3':
        # English: Medical-related information was detected, but the identified keywords correspond only to attribute-level information and cannot be used to locate a specific disease. Please provide a disease name, symptom, or medication to continue.
        return '已识别到部分医疗相关信息，但这些关键词仅对应属性类信息，无法定位到具体疾病。请补充疾病名称、症状或用药信息以便继续查询。'

    t_disease = disease_search(key_labels)
    if len(t_disease) > 1:
        # English: Your question appears to involve multiple diseases: … Please specify which disease you would like to focus on, and I will continue querying and answering based on that disease.
        return '检测到您的问题涉及多个疾病，分别是' + '，'.join(t_disease) + '，请告知您想优先了解哪一个疾病，我将为该疾病继续查询并回答。'
    if not t_disease:
        # English: Your question cannot be resolved to a single disease, possibly because it involves multiple diseases.
        return '检测到您的问题无法锁定一个疾病，可能因为涉及多个疾病。'

    # Local name `intent` avoids shadowing the module-level `intention` list.
    intent = find_intention(sentence)
    if intent == 'out_of_scope':
        # English: The question is considered out of scope because it does not involve medical-related information that the system can answer.
        return '您的问题与医疗健康无关，当前系统仅支持回答疾病、症状、用药及相关医疗问题。'
    return find_final_answer(t_disease[0], intent)

In [None]:
# Interactive question-answering loop.
def main():
    """Read questions from stdin and print answers until the user types "退出".

    Blank input is rejected with a prompt to try again. End of input (EOF) or
    an interrupt during ``input`` also ends the loop. An exception raised
    while answering one question is logged with its traceback, and the user
    sees a generic message. This way a single bad question does not end the
    session and the failure is not silently hidden.
    """
    import logging

    while True:
        try:
            question = input("请问您有什么问题？")  # English: "What question would you like to ask?"
        except (EOFError, KeyboardInterrupt):
            # Input stream closed or user interrupted: leave the loop cleanly.
            print("再见！！！")  # English: "Goodbye!"
            break
        stripped = question.strip()
        if not stripped:
            print("请输入有效的问题。")  # English: "Please enter a valid question."
            continue
        if stripped == "退出":  # English: "exit / quit"
            print("再见！！！")  # English: "Goodbye!"
            break
        try:
            answer = get_answer(question)
            print(answer)
        except Exception:
            logging.exception("Failed to answer question: %r", question)
            print("系统暂时无法处理该问题，请稍后再试。")  # English: "The system cannot process this question at the moment. Please try again later."


if __name__ == "__main__":
    main()