In [1]:
import pymysql, random, logging
import networkx as nx
from typing import Set, Union,Dict, List, Iterable
import pandas as pd
# ---------- 0. Data source ----------
# Connection settings for the UMLS MySQL database, passed as keyword arguments
# to pymysql.connect() by the loader functions in this notebook.
# Every value can be overridden through an environment variable so the notebook
# can be pointed at another server without editing credentials in place.
# NOTE(review): the fallback values are the credentials that used to be
# hard-coded here; they are kept only so existing runs keep working. Rotate the
# password and remove the fallbacks once the environment variables are set.
import os

DB_CFG = dict(
    host=os.environ.get("UMLS_DB_HOST", "172.188.121.85"),
    port=int(os.environ.get("UMLS_DB_PORT", "3306")),
    user=os.environ.get("UMLS_DB_USER", "root"),
    password=os.environ.get("UMLS_DB_PASSWORD", "1qaz0plm"),
    database=os.environ.get("UMLS_DB_NAME", "umls"),
    cursorclass=pymysql.cursors.DictCursor,
)

In [2]:
# Path template: Symptom -> Disease -> Drug -> Protein target.
# For every step after the first, sample_one() follows an out-edge whose "rela"
# label is listed in the step's "rela" and whose target node carries at least
# one of the TUIs listed in "stype". The first step only constrains the start
# node, hence its empty "rela". "as" is a human-readable label for the step;
# it is not read by the code visible in this notebook.
TEMPLATE_LIST = [
    {
        "id": "Symptom_Disease_Drug_Target",
        "steps": [
            # 1) start node: clinical presentation (Sign or Symptom)
            {"stype": ["T184"], "as": "Symptom", "rela": []},
            # 2) disease suggested by that symptom (Disease or Syndrome)
            {
                "rela": ["manifestation_of"],
                "stype": ["T047"],
                "as": "Disease or Syndrome",
            },
            # 3) pharmacologic substances that may treat the disease
            {
                "rela": ["may_be_treated_by"],
                "stype": ["T121", "T109"],
                "as": "Pharmacologic Substance",
            },
            # 4) main molecular target of the drug (Amino Acid, Peptide, or Protein)
            {
                "rela": ["has_target"],
                "stype": ["T116", "T126", "T192"],
                "as": "Protein Target",
            },
        ],
    },
]

In [108]:
# Path template: Disease -> Drug -> Protein target.
# NOTE: this assignment rebinds TEMPLATE_LIST, replacing whatever list was
# bound to that name before this cell ran.
# The trailing numbers are the original author's notes (presumably node counts
# per semantic type; not verified here).
TEMPLATE_LIST = [
    {
        "id": "Disease_Drug_Target",
        "steps": [
            # start node: Disease or Syndrome (note: 5854)
            {"stype": ["T047"], "as": "Disease or Syndrome", "rela": []},
            # pharmacologic substances that may treat the disease (note: 7101)
            {
                "rela": ["may_be_treated_by"],
                "stype": ["T121", "T109"],
                "as": "Pharmacologic Substance",
            },
            # main molecular target of the drug:
            # Amino Acid, Peptide, or Protein (note: 4894+1859+1699)
            {
                "rela": ["has_target"],
                "stype": ["T116", "T126", "T192"],
                "as": "Protein Target",
            },
        ],
    },
]




In [109]:
# Path template: Disease -> Drug -> Mechanism of action.
# NOTE: this assignment rebinds TEMPLATE_LIST, replacing whatever list was
# bound to that name before this cell ran.
# The trailing numbers are the original author's notes (presumably node counts
# per semantic type; not verified here).
TEMPLATE_LIST = [
    {
        "id": "Disease_Drug_moA",
        "steps": [
            # start node: Disease or Syndrome (note: 5854)
            {"stype": ["T047"], "as": "Disease or Syndrome", "rela": []},
            # pharmacologic substances that may treat the disease (note: 7101)
            {
                "rela": ["may_be_treated_by"],
                "stype": ["T121", "T109"],
                "as": "Pharmacologic Substance",
            },
            # mechanism of action of the drug (Molecular Function)
            {
                "rela": ["has_mechanism_of_action"],
                "stype": ["T044"],
                "as": "Molecular Function",
            },
        ],
    },
]


In [3]:
# Upper bound on the number of placeholders sent in a single query.
CHUNK_SIZE = 1000         # 1000 x ~20 bytes is roughly 20 KB, far below max_allowed_packet
import math
def load_cui2tui(cui_set: Set[str], chunk: int = CHUNK_SIZE) -> Dict[str, List[str]]:
    """
    Look up the semantic types (TUIs) of the given CUIs in the MRSTY table.

    Large sets are split into batches of at most ``chunk`` CUIs so that a
    single statement never exceeds max_allowed_packet.

    Args:
        cui_set: CUIs to look up. An empty set returns ``{}`` without opening
            a database connection.
        chunk: maximum number of CUIs per query; must be at least 1.

    Returns:
        Mapping CUI -> list of TUIs, one entry per row returned by the
        database. CUIs for which no row came back are absent from the mapping.

    Raises:
        ValueError: if ``chunk`` is smaller than 1. Previously ``chunk=0``
            raised ZeroDivisionError and a negative value silently returned
            an empty mapping.
    """
    if not cui_set:
        return {}
    if chunk < 1:
        raise ValueError(f"chunk must be a positive integer, got {chunk!r}")

    tui_map: Dict[str, List[str]] = {}

    cui_list = list(cui_set)
    n_chunks = math.ceil(len(cui_list) / chunk)

    with pymysql.connect(**DB_CFG) as conn, conn.cursor(pymysql.cursors.DictCursor) as cur:
        for i in range(n_chunks):
            batch = cui_list[i * chunk : (i + 1) * chunk]
            # one "%s" per CUI; values are bound by the driver, not interpolated
            fmt   = ",".join(["%s"] * len(batch))
            sql   = f"SELECT CUI, TUI FROM MRSTY WHERE CUI IN ({fmt})"
            cur.execute(sql, batch)

            for r in cur.fetchall():
                tui_map.setdefault(r["CUI"], []).append(r["TUI"])

    return tui_map

In [4]:
def load_cui2str(cui_set: Set[str], chunk: int = 1000) -> Dict[str, str]:
    """
    Look up one English display string per CUI in the MRCONSO table.

    The CUIs are queried in batches of at most ``chunk`` values, so the size of
    the ``IN (...)`` list stays bounded no matter how many CUIs are requested
    (the same batching load_cui2tui uses; previously a single statement carried
    every CUI).

    Selection rule per CUI (unchanged): the first row seen is kept, and any row
    whose TTY is "PF" or "PT" overwrites the current value. Rows are taken in
    the order the database returns them; no ORDER BY is applied.

    Args:
        cui_set: CUIs to look up. An empty set returns ``{}`` without opening
            a database connection.
        chunk: maximum number of CUIs per query; must be at least 1.

    Returns:
        Mapping CUI -> STR. CUIs without an English row are absent.

    Raises:
        ValueError: if ``chunk`` is smaller than 1.
    """
    if not cui_set:
        return {}
    if chunk < 1:
        raise ValueError(f"chunk must be a positive integer, got {chunk!r}")

    cui_list = list(cui_set)
    str_map: Dict[str, str] = {}

    with pymysql.connect(**DB_CFG) as conn, conn.cursor(pymysql.cursors.DictCursor) as cur:
        for start in range(0, len(cui_list), chunk):
            batch = cui_list[start:start + chunk]
            placeholders = ",".join(["%s"] * len(batch))
            sql = (
                "SELECT CUI, STR, TTY FROM MRCONSO "
                f"WHERE LAT='ENG' AND CUI IN ({placeholders})"
            )
            cur.execute(sql, batch)

            for row in cur.fetchall():
                cui, str_, tty = row["CUI"], row["STR"], row["TTY"]
                # Preferred term types win; otherwise the first row is kept.
                # NOTE(review): "PF" looks like an STT value rather than a TTY
                # value in the standard MRCONSO layout - confirm whether a
                # filter on STT/TS/ISPREF was intended.
                if tty in ("PF", "PT") or cui not in str_map:
                    str_map[cui] = str_

    return str_map


In [5]:
def load_edges(rela_set: Iterable[str]):
    """
    Fetch only the MRREL rows whose RELA is one of the requested labels.

    Args:
        rela_set: relation labels to select. Any iterable is accepted,
            including one-shot iterators such as generators (it is
            materialised once before use).

    Returns:
        The rows returned by the database, each exposing the columns
        CUI1, CUI2, REL and RELA. When no label is given an empty list is
        returned without contacting the database, because ``IN ()`` is not
        valid SQL.
    """
    relas = list(rela_set)
    if not relas:
        return []

    fmt = ",".join(["%s"] * len(relas))
    sql = f"""
        SELECT CUI1, CUI2, REL, RELA
        FROM MRREL
        WHERE RELA IN ({fmt})
    """
    with pymysql.connect(**DB_CFG) as conn, conn.cursor(pymysql.cursors.DictCursor) as cur:
        cur.execute(sql, tuple(relas))
        return cur.fetchall()

In [6]:
def build_graph(rela_set: Set[str]) -> nx.MultiDiGraph:
    """
    Build a MultiDiGraph from the MRREL rows selected by ``rela_set``.

    Every node receives two attributes:
        - tui: list of TUIs from load_cui2tui (empty list when none was found)
        - str: display string from load_cui2str (the CUI itself when none was found)

    Every edge carries a ``rela`` attribute holding the row's RELA, or its REL
    when RELA is empty.
    """
    # 1) load all requested edges in one go
    edge_rows = load_edges(rela_set)

    # 2) every CUI that appears at either end of an edge
    endpoints: Set[str] = set()
    for row in edge_rows:
        endpoints.add(row["CUI1"])
        endpoints.add(row["CUI2"])

    # 3) fetch semantic types and display strings for those CUIs
    semtypes_by_cui = load_cui2tui(endpoints)
    name_by_cui = load_cui2str(endpoints)

    # 4) assemble the graph; edges run from CUI2 to CUI1
    # NOTE(review): presumably because MRREL states the relationship that the
    # second concept has to the first - confirm against the UMLS documentation.
    G = nx.MultiDiGraph()
    for row in edge_rows:
        label = row["RELA"] or row["REL"]
        G.add_edge(row["CUI2"], row["CUI1"], rela=label)

    # 5) attach node attributes, falling back to [] / the CUI itself
    for cui in G.nodes:
        node_attrs = G.nodes[cui]
        node_attrs["tui"] = semtypes_by_cui.get(cui, [])
        node_attrs["str"] = name_by_cui.get(cui, cui)

    return G


In [7]:
def has_semtype(node_tuis: List[str], allowed_tui: List[str]) -> bool:
    """Return True when at least one entry of ``node_tuis`` is contained in ``allowed_tui``."""
    for tui in node_tuis:
        if tui in allowed_tui:
            return True
    return False

def sample_one(
        G: "nx.MultiDiGraph",
        tpl_steps: List[Dict],
        max_attempts: int = 50_000
    ) -> Dict[str, List[str]]:
    """
    Randomly sample one path through ``G`` that satisfies the template.

    A start node matching the first step is drawn at random, then the path is
    extended step by step with a randomized depth-first search: at each node
    the admissible out-edges are tried in random order, and when a node turns
    out to be a dead end the search backs up ONE step and tries the next
    untried edge of the previous node. A new start node is drawn only after
    every continuation from the current one has been exhausted.

    (The previous implementation popped the current node on a dead end and then
    moved on to the next template step, so it never actually retried a step and
    every dead end cost a whole attempt.)

    Node attributes read from the graph:
        - tui : List[str]
        - str : readable name (the CUI is used when the attribute is missing)
    Edge attribute read from the graph:
        - rela : relation label

    Args:
        G: graph to sample from.
        tpl_steps: template steps; step 0 constrains the start node through
            "stype", every later step constrains the edge label through "rela"
            and the target node through "stype".
        max_attempts: maximum number of random start nodes to try.

    Returns:
        {
          "cuis":      [CUI0, CUI1, ...],
          "relas":     [rela0, rela1, ...],            # len = len(cuis) - 1
          "path_strs": [STR0, rela0, STR1, ...]        # interleaved, ready to print
        }

    Raises:
        RuntimeError: if no node matches the first step, or if no complete path
            was found within ``max_attempts`` start nodes.
    """
    n_steps = len(tpl_steps)

    # ---------- 1) collect every admissible start node ----------
    first_step = tpl_steps[0]
    cand_start = [
        n for n, d in G.nodes(data=True)
        if has_semtype(d["tui"], first_step["stype"])
    ]
    if not cand_start:
        raise RuntimeError("No start node matches TUI constraint.")

    def _shuffled_candidates(node, step):
        """Return the (target, rela) out-edges of ``node`` that satisfy ``step``, in random order."""
        matches = [
            (v, edata["rela"])
            for _, v, edata in G.out_edges(node, data=True)
            if edata["rela"] in step["rela"]
            and has_semtype(G.nodes[v]["tui"], step["stype"])
        ]
        random.shuffle(matches)
        return matches

    # ---------- 2) randomized depth-first search from random starts ----------
    for _attempt in range(max_attempts):
        src = random.choice(cand_start)
        path = [src]    # CUI sequence
        relas = []      # relation sequence, always len(path) - 1 long
        # pending[i] holds the not-yet-tried edges leaving path[i]; while the
        # path is incomplete, len(pending) == len(path).
        pending = [_shuffled_candidates(src, tpl_steps[1])] if n_steps > 1 else []

        while pending and len(path) < n_steps:
            untried = pending[-1]
            if untried:
                nxt, rela = untried.pop()
                path.append(nxt)
                relas.append(rela)
                if len(path) < n_steps:
                    # tpl_steps[len(path)] describes the node that comes next
                    pending.append(_shuffled_candidates(nxt, tpl_steps[len(path)]))
            else:
                # dead end at path[-1]: back up one step and resume with the
                # remaining untried edges of the previous node
                pending.pop()
                path.pop()
                if relas:
                    relas.pop()

        if len(path) == n_steps:
            # interleave node names and relation labels for display
            path_strs = []
            for i, cui in enumerate(path):
                path_strs.append(G.nodes[cui].get("str", cui))
                if i < len(relas):
                    path_strs.append(relas[i])

            return {
                "cuis": path,
                "relas": relas,
                "path_strs": path_strs
            }

    raise RuntimeError(
        f"No path found after {max_attempts} random starts. "
        "Consider relaxing template or increasing attempts."
    )


In [8]:
import json
# Gather every relation label used by any template so the graph is built once.
# The first step of a template is its start node and needs no incoming edge, so
# only steps from index 1 onward are inspected; missing or empty "rela" lists
# and empty labels are ignored.
needed_rela = set()
for tpl in TEMPLATE_LIST:
    for step in tpl["steps"][1:]:
        needed_rela.update(
            label for label in (step.get("rela") or []) if label
        )


print(f"Building graph with relationships: {needed_rela}")
G = build_graph(needed_rela)

Building graph with relationships: {'manifestation_of', 'may_be_treated_by', 'has_target'}


In [116]:
# Display the size of graph G as (node count, edge count).
(G.number_of_nodes(), G.number_of_edges())

(7271, 22421)

In [11]:
import json, logging, random

UNIQUE_TRIES = 100    # maximum number of sampling attempts per template
N_PATHS       = 20  # number of unique paths to keep per template

def sample_unique_paths(G, tpl_steps, want=N_PATHS, max_tries=UNIQUE_TRIES):
    """
    Collect up to ``want`` distinct paths by calling sample_one repeatedly.

    Two paths count as the same when their ``path_strs`` sequences are equal.
    Sampling stops once ``want`` paths were collected or ``max_tries`` calls
    were made; a RuntimeError raised by sample_one is logged as a warning and
    counts as one try. A warning is also logged when fewer than ``want`` paths
    were found.

    Returns:
        List of the dicts returned by sample_one, in the order they were first
        seen.
    """
    collected = []
    known_keys = set()
    attempts = 0

    while attempts < max_tries and len(collected) < want:
        attempts += 1
        try:
            candidate = sample_one(G, tpl_steps)
        except RuntimeError as e:
            logging.warning("sample_one failed: %s", e)
        else:
            # de-duplicate on the readable, interleaved path
            fingerprint = tuple(candidate["path_strs"])
            if fingerprint in known_keys:
                continue
            known_keys.add(fingerprint)
            collected.append(candidate)

    if len(collected) < want:
        logging.warning(
            "Only got %d unique paths (wanted %d) after %d attempts",
            len(collected), want, attempts
        )
    return collected

# Sample paths for every template (keyed by template id) and persist them as
# pretty-printed JSON; ensure_ascii=False keeps non-ASCII text readable.
out = {
    tpl["id"]: sample_unique_paths(G, tpl["steps"])
    for tpl in TEMPLATE_LIST
}

with open("sampled_paths.json", "w", encoding="utf8") as f:
    f.write(json.dumps(out, indent=2, ensure_ascii=False))



In [118]:
# Completeness check and summary statistics for the sampled paths
def analyze_paths(out):
    """
    Print a per-template report about the sampled paths.

    For every template id in ``out`` the report contains the number of paths,
    the distribution of path lengths (number of CUIs), a completeness check
    against the number of steps of the matching entry in TEMPLATE_LIST, and up
    to three example paths.

    A template id that has no entry in the current TEMPLATE_LIST (for example
    because TEMPLATE_LIST was rebound after sampling) no longer aborts the
    report with StopIteration: the completeness check is skipped for it and a
    notice is printed instead.

    Args:
        out: mapping template id -> list of path dicts with the keys
            "cuis", "relas" and "path_strs".
    """
    from collections import Counter  # local import: only needed here

    for tpl_id, paths in out.items():
        print(f"\n\n===== 模板: {tpl_id} =====\n")

        # path-length distribution, counted in one pass
        length_counts = Counter(len(p['cuis']) for p in paths)
        distinct_lengths = sorted(length_counts)
        print(f"Total paths: {len(paths)}")
        print(f"Path lengths distribution: {distinct_lengths}")
        for length in distinct_lengths:
            count = length_counts[length]
            print(f"Length {length}: {count} paths ({count/len(paths):.2%})")

        # expected length = number of steps of the first template with this id,
        # or None when the template is not defined
        expected_len = next(
            (len(tpl['steps']) for tpl in TEMPLATE_LIST if tpl['id'] == tpl_id),
            None
        )

        if expected_len is None:
            print(f"\nTemplate {tpl_id!r} is not defined in TEMPLATE_LIST; skipping the completeness check.")
            print("\nExample paths:")
            examples = paths
        else:
            incomplete = [p for p in paths if len(p['cuis']) != expected_len]
            if incomplete:
                print(f"\n\u8b66告: 发现 {len(incomplete)} 条不完整路径 (期望长度 {expected_len})!")
                # show at most three incomplete paths
                for i, p in enumerate(incomplete[:3]):
                    print(f"\n不完整路径示例 {i+1}:")
                    print(f"CUIs: {p['cuis']}")
                    print(f"Relationships: {p['relas']}")
                    print(f"Path as text: {' -> '.join(p['path_strs'])}")
            else:
                print(f"\n✅ 所有路径均完整 (长度 = {expected_len})")

            # examples of complete paths
            print("\n成功路径示例:")
            examples = [p for p in paths if len(p['cuis']) == expected_len]

        for i, p in enumerate(examples[:3]):
            print(f"\n路径 {i+1}:")
            print(' -> '.join(p['path_strs']))

# Analyse the sampled paths once sampling has produced `out`.
# The name is checked up front instead of wrapping the call in
# `except NameError`, so that a NameError raised inside analyze_paths is no
# longer misreported as "sampling has not been run yet".
if "out" in globals():
    analyze_paths(out)
else:
    print("\n请先运行采样代码 (sample_unique_paths) 生成路径数据")



===== 模板: Disease_Drug_moA =====

Total paths: 1500
Path lengths distribution: [3]
Length 3: 1500 paths (100.00%)

✅ 所有路径均完整 (长度 = 3)

成功路径示例:

路径 1:
Granuloma inguinale -> may_be_treated_by -> Doxycycline anhydrous -> has_mechanism_of_action -> Protein Synthesis Inhibitors

路径 2:
Ascariasis -> may_be_treated_by -> Pyrantel-containing product -> has_mechanism_of_action -> Cholinesterase Inhibitors

路径 3:
Infection by Fascioloides -> may_be_treated_by -> Emetine Hydrochloride -> has_mechanism_of_action -> Unknown Cellular or Molecular Interaction


In [121]:
#!/usr/bin/env python3
"""
merge_json_files.py

将 sampled_1.json 与 sampled_2.json 合并为 merged_paths.json。
保持各自的结构，合并相同键下的数据，并去除重复项。
"""

import json
from pathlib import Path
from collections import defaultdict

def load_json(path: Path):
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded object."""
    text = path.read_text(encoding="utf-8")
    return json.loads(text)

def main(in1: Path = Path("sampled_1.json"),
         in2: Path = Path("sampled_2.json"),
         out: Path = Path("merged_paths.json")):
    """
    Merge two sampled-path JSON files into one file, removing duplicates.

    Both inputs map a template id to a list of path dicts. Paths are merged per
    template id, first-file entries first, each in its original order. A path
    is a duplicate when its ("cuis", "relas", "path_strs") triple was already
    seen for the same template - whether in the other file or earlier in the
    same file. (Previously duplicates inside the first file, and inside
    templates that only occur in the second file, were kept.)

    Args:
        in1: first input file.
        in2: second input file.
        out: output file; overwritten if it exists.
    """
    # read both inputs before doing any work, so a missing file fails early
    data1 = load_json(Path(in1))
    data2 = load_json(Path(in2))
    out = Path(out)

    merged_data = defaultdict(list)
    seen_by_template = defaultdict(set)   # template id -> keys already merged

    for data in (data1, data2):
        for template_id, paths in data.items():
            bucket = merged_data[template_id]
            seen_paths = seen_by_template[template_id]
            for path in paths:
                # lists are not hashable, so the key is built from tuples
                key = (
                    tuple(path.get("cuis", [])),
                    tuple(path.get("relas", [])),
                    tuple(path.get("path_strs", []))
                )
                if key not in seen_paths:
                    seen_paths.add(key)
                    bucket.append(path)

    # write the merged result
    with out.open("w", encoding="utf-8") as f:
        json.dump(merged_data, f, ensure_ascii=False, indent=2)

    # summary statistics
    total_paths = sum(len(paths) for paths in merged_data.values())
    print(f"合并完成，共 {len(merged_data)} 个模板，{total_paths} 条路径 -> {out}")

    # number of paths per template
    for template_id, paths in merged_data.items():
        print(f"  - {template_id}: {len(paths)} 条路径")

if __name__ == "__main__":
    main()

合并完成，共 2 个模板，2068 条路径 -> merged_paths.json
  - Symptom_Disease_Drug_Target: 568 条路径
  - Disease_Drug_moA: 1500 条路径
