In [None]:
# tools/impl.py
from __future__ import annotations
import os, json, requests, pymysql
from typing import List, Dict, Any, Tuple

# ───────────────────────────────────────────
# 1) PubMed
# ───────────────────────────────────────────
PUBMED = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

def pubmed_search(term: str, retmax: int = 10) -> List[str]:
    """Run a PubMed ESearch query and return the matching PMIDs.

    Args:
        term: Entrez query string sent as the ``term`` parameter.
        retmax: Maximum number of ids requested from the service.

    Returns:
        The ``idlist`` found under ``esearchresult`` in the JSON reply.

    Raises:
        requests.HTTPError: when the service answers with an error status.
    """
    query = dict(db="pubmed", term=term, retmode="json")
    query["sort"] = "pub+date"
    query["retmax"] = retmax

    endpoint = f"{PUBMED}/esearch.fcgi"
    response = requests.get(endpoint, params=query, timeout=20)
    response.raise_for_status()

    payload = response.json()
    return payload["esearchresult"]["idlist"]


In [2]:
from datasets import load_dataset

# Fetch the "full" split of the MAIA benchmark and display its first record.
ds = load_dataset("DiligentDing/MAIA", split="full")  # loads the entire benchmark
ds[0]

{'id': 'ret_cacfe0e74802',
 'question': 'What is the PMID of the article titled “[Practical guideline for short bowel syndrome]” by first author Dabsch S, published in Zeitschrift fur Gastroenterologie in 2025?',
 'tool_calls': {'tool': ['pubmed.search'],
  'params': ['{"term":"\\"[Practical guideline for short bowel syndrome].\\"[ti] AND Dabsch S[au] AND 2025[dp]","retmax":1}']},
 'answer': ['40360142'],
 'type': 'retrieval'}

In [None]:
CTGOV = "https://clinicaltrials.gov/api/v2/studies"

def ctgov_search(filter_expr: str, page_size: int = 100) -> List[str]:
    """Query ClinicalTrials.gov and return the NCT ids of the first result page.

    Args:
        filter_expr: Value forwarded verbatim as the ``filter`` query parameter.
        page_size: Value forwarded as ``pageSize``.

    Returns:
        One NCT id per entry of the reply's ``studies`` list (empty when absent).

    Raises:
        requests.HTTPError: when the service answers with an error status.
    """
    query = {"filter": filter_expr, "pageSize": page_size}
    response = requests.get(CTGOV, params=query, timeout=20)
    response.raise_for_status()

    nct_ids = []
    for study in response.json().get("studies", []):
        identification = study["protocolSection"]["identificationModule"]
        nct_ids.append(identification["nctId"])
    return nct_ids

In [None]:
OT_URL = "https://api.platform.opentargets.org/api/v4/graphql"
def _ot_query(query: str) -> Dict:
    r = requests.post(OT_URL, json={"query": query}, timeout=20)
    r.raise_for_status()
    return r.json()["data"]

def ot_associated_diseases(target_id: str, min_score: float = 0.5) -> List[Dict]:
    """Return the diseases associated with a target, filtered by score.

    Args:
        target_id: Identifier passed as ``ensemblId`` to the ``target`` query.
        min_score: Rows with ``score`` below this value are dropped.

    Returns:
        The ``associatedDiseases.rows`` entries (``disease {id name}``, ``score``)
        whose score is at least ``min_score``; an empty list when the service
        returns no target for the id.
    """
    # json.dumps produces a quoted, escaped string literal, so quotes or
    # backslashes in target_id cannot terminate the GraphQL string early.
    id_literal = json.dumps(target_id)
    q = f"""
    {{ target(ensemblId: {id_literal}) {{
        associatedDiseases {{ rows {{ disease {{id name}} score }} }} }} }}
    """
    target = _ot_query(q)["target"]
    if not target:
        # Unknown ids come back as `target: null` rather than as an error.
        return []
    rows = (target.get("associatedDiseases") or {}).get("rows") or []
    return [r for r in rows if r["score"] >= min_score]
def ot_tractability(target_id: str, value: bool = True) -> List[Dict]:
    """Return the tractability entries of a target whose ``value`` flag matches.

    Args:
        target_id: Identifier passed as ``ensemblId`` to the ``target`` query.
        value: Only entries whose ``value`` field is this exact boolean are kept.

    Returns:
        Matching ``{modality, label, value}`` dicts; an empty list when the
        service returns no target or no tractability data for the id.
    """
    # json.dumps produces a quoted, escaped string literal, so quotes or
    # backslashes in target_id cannot terminate the GraphQL string early.
    id_literal = json.dumps(target_id)
    q = f"""
    {{ target(ensemblId: {id_literal}) {{
        tractability {{ modality label value }} }} }}
    """
    target = _ot_query(q)["target"]
    if not target:
        # Unknown ids come back as `target: null` rather than as an error.
        return []
    rows = target.get("tractability") or []
    # Identity comparison keeps the original strictness: 1/0 do not match True/False.
    return [r for r in rows if r["value"] is value]


def ot_safety(symbol: str, event: str) -> Dict:
    """Look up one safety liability of a target found by free-text search.

    The function first runs an Open Targets ``search`` for ``symbol`` and takes
    the first hit whose ``entity`` is ``"target"``; it then fetches that target's
    ``safetyLiabilities`` and returns the first entry whose ``event`` equals
    ``event`` case-insensitively.

    Args:
        symbol: Search string (typically a gene symbol).
        event: Safety event name to match.

    Returns:
        ``{"biosamples": [...], "effects": [...]}`` for the matching liability, or
        an empty dict when no target hit, no target record, or no matching event
        exists.
    """
    # 1. Resolve the search string to a target id. json.dumps yields an escaped
    #    string literal so quotes in `symbol` cannot break the query.
    search_q = f'''
    {{
      search(queryString: {json.dumps(symbol)}) {{
        hits {{
          id
          entity
          description
        }}
      }}
    }}
    '''
    hits = _ot_query(search_q)["search"]["hits"] or []
    # Keep only hits that are targets; the first one wins.
    target_id = next((hit["id"] for hit in hits if hit["entity"] == "target"), None)
    if not target_id:
        return {}

    # 2. Fetch the safety liabilities recorded for that id.
    safety_q = f'''
    {{
      target(ensemblId: {json.dumps(target_id)}) {{
        safetyLiabilities {{
          event
          biosamples {{ tissueLabel tissueId }}
          effects {{ dosing direction }}
        }}
      }}
    }}
    '''
    target = _ot_query(safety_q)["target"] or {}
    rows = target.get("safetyLiabilities") or []
    wanted = event.lower()
    for r in rows:
        # A liability may carry a null event; treat it as non-matching.
        if (r.get("event") or "").lower() == wanted:
            return {"biosamples": r["biosamples"], "effects": r["effects"]}
    return {}

In [None]:
# ───────────────────────────────────────────
# Connection settings for the UMLS MySQL database.
# Every field can be overridden through UMLS_DB_* environment variables. The
# password is read ONLY from the environment (UMLS_DB_PASSWORD): credentials
# must not be committed with the notebook.
DB_CFG = dict(
    host=os.environ.get("UMLS_DB_HOST", "172.188.121.85"),
    port=int(os.environ.get("UMLS_DB_PORT", "3306")),
    user=os.environ.get("UMLS_DB_USER", "root"),
    password=os.environ.get("UMLS_DB_PASSWORD", ""),
    database=os.environ.get("UMLS_DB_NAME", "umls"),
    cursorclass=pymysql.cursors.DictCursor,  # rows come back as dicts keyed by column
    autocommit=True,
)
# Single module-level connection shared by the umls_* helpers below.
_conn = pymysql.connect(**DB_CFG)

def umls_concept_lookup(name: str) -> str:
    """Return one CUI whose MRCONSO ``STR`` equals ``name``, or ``""`` if none."""
    with _conn.cursor() as cursor:
        cursor.execute("SELECT cui FROM MRCONSO WHERE STR = %s LIMIT 1", (name,))
        match = cursor.fetchone()
    if not match:
        return ""
    return match["cui"]

def umls_get_related(from_cui: str, rela: str) -> List[str]:
    """Return every ``cui1`` in MRREL where ``cui2`` is ``from_cui`` and ``rela`` matches."""
    statement = "SELECT cui1 FROM MRREL WHERE cui2=%s AND rela=%s"
    with _conn.cursor() as cursor:
        cursor.execute(statement, (from_cui, rela))
        records = cursor.fetchall()
        related = []
        for record in records:
            related.append(record["cui1"])
        return related
def umls_cui_to_name(cui: str) -> str:
    """
    Map a CUI to an English name from MRCONSO.

    The last returned row whose TTY is "PF" or "PT" wins; if there is none (or
    its string is empty) the first returned row is used; with no rows at all
    the result is an empty string.
    """
    sql = """
        SELECT STR, TTY
        FROM   MRCONSO
        WHERE  LAT='ENG' AND CUI=%s
    """
    with _conn.cursor() as cursor:
        cursor.execute(sql, (cui,))
        records = cursor.fetchall()

    all_names = [record["STR"] for record in records]
    preferred = [record["STR"] for record in records if record["TTY"] in ("PF", "PT")]

    # Later preferred rows override earlier ones.
    chosen = preferred[-1] if preferred else None
    if chosen:
        return chosen
    return all_names[0] if all_names else ""


In [None]:
# Smoke check: resolve a single CUI to its display name (needs the live UMLS connection).
umls_cui_to_name("C1521863")

In [None]:
import requests
from typing import Dict, List

CTGOV_API_URL = "https://clinicaltrials.gov/api/v2/studies"

def parse_filter_expr(filter_expr: str) -> Dict[str, str]:
    """
    Parse an expression such as ``key1=val1;key2=val2`` into a dict.

    Segments without ``=`` are ignored, keys and values are stripped, only the
    first ``=`` of a segment separates key from value, and a repeated key keeps
    its last value.
    """
    pieces = (segment.partition("=") for segment in filter_expr.split(";"))
    return {key.strip(): value.strip() for key, sep, value in pieces if sep}

import requests

def ctgov_search(filter_expr: str, page_size: int = 100):
    """POST a structured filter payload to ClinicalTrials.gov and return ``studies``.

    ``filter_expr`` is a ``key=value;key=value`` string. ``interventions.name``
    values are collected into a list, ``locations.country`` and ``startDateFrom``
    are nested under ``locations`` / ``startDate``, and every other key is copied
    into the filter dict unchanged. Segments without ``=`` are skipped.
    """
    filters = {}
    intervention_names = []
    location_filter = {}

    # Translate each key=value segment into its place in the payload.
    for segment in filter_expr.split(";"):
        key, sep, value = segment.partition("=")
        if not sep:
            continue
        key = key.strip()
        value = value.strip()

        if key == "startDateFrom":
            filters.setdefault("startDate", {})["from"] = value
        elif key == "locations.country":
            location_filter["country"] = value
        elif key == "interventions.name":
            intervention_names.append(value)
        else:
            filters[key] = value

    if intervention_names:
        filters.setdefault("interventions", {})["name"] = intervention_names
    if location_filter:
        filters["locations"] = location_filter

    response = requests.post(
        "https://clinicaltrials.gov/api/v2/studies",
        json={"filters": filters, "pageSize": page_size},
        timeout=20,
    )
    response.raise_for_status()
    return response.json().get("studies", [])



In [None]:
# ctgov_search_fixed.py
import requests, urllib.parse
from typing import List, Dict, Optional

BASE_URL = "https://clinicaltrials.gov/api/v2/studies"

def _build_params(expr: str,
                  page_size: int,
                  page_token: Optional[str]) -> Dict[str, str]:
    params: Dict[str, str] = {
        "pageSize": str(page_size),
        "countTotal": "false",
        "markupFormat": "markdown",
    }
    interventions, advanced = [], []

    for seg in filter(None, (s.strip() for s in expr.split(";"))):
        k, v = map(str.strip, seg.split("=", 1))

        if k == "conditions":
            params["query.cond"] = v

        elif k == "interventions.name":
            interventions.append(v)

        elif k == "locations.country":
            params["query.locn"] = f'"{v}"'          # 加引号避免拆词

        elif k == "overallStatus":
            params["filter.overallStatus"] = v.upper()

        elif k == "studyType":
            advanced.append(f"AREA[StudyType]({v.upper()})")

        elif k == "startDateFrom":
            advanced.append(f"AREA[StartDate]RANGE[{v},MAX]")

        else:
            raise ValueError(f"Unsupported key: {k}")

    if interventions:
        params["query.intr"] = " AND ".join(interventions)
    if advanced:
        params["filter.advanced"] = " AND ".join(advanced)
    if page_token:
        params["pageToken"] = page_token
    return params


def ctgov_search(expr: str, page_size: int = 100) -> List[Dict]:
    """Fetch every page of matching studies from ClinicalTrials.gov.

    Parameters are built by ``_build_params``; pages are followed through
    ``nextPageToken`` until the service stops returning one. On a non-200
    reply the status, decoded URL and the first 500 characters of the body are
    printed before ``raise_for_status`` is called.
    """
    collected = []
    next_token = None
    has_more = True
    while has_more:
        query = _build_params(expr, page_size, next_token)
        response = requests.get(BASE_URL, params=query, timeout=30)
        if response.status_code != 200:
            # Print a readable URL and the reply text to help debugging.
            print("ERROR  :", response.status_code, response.reason)
            print("URL    :", urllib.parse.unquote(response.url))
            print("BODY   :", response.text[:500], "...")
            response.raise_for_status()

        page = response.json()
        collected.extend(page.get("studies", []))
        next_token = page.get("nextPageToken")
        has_more = bool(next_token)
    return collected


# ------------------- quick demo -------------------
if __name__ == "__main__":
    # Example filter: one key=value pair per segment; interventions.name is
    # given twice and both values are combined by _build_params.
    fx = (
        "overallStatus=COMPLETED;"
        "studyType=INTERVENTIONAL;"
        "conditions=Multiple Myeloma;"
        "interventions.name=CC-5013;"
        "interventions.name=Dexamethasone;"
        "startDateFrom=2003-01-01;"
        "locations.country=United States"
    )
    lst = ctgov_search(fx)
    print("Fetched:", len(lst))
    # Show id and title of the first five studies only.
    for s in lst[:5]:
        ident = s["protocolSection"]["identificationModule"]
        print(ident["nctId"], "-", ident["briefTitle"])


In [None]:
import requests, urllib.parse
from typing import Dict, List, Optional

CTGOV = "https://clinicaltrials.gov/api/v2/studies"

def _build_params(expr: str,
                  page_size: int = 100,
                  page_token: Optional[str] = None) -> Dict[str, str]:
    """
    将形如 'key=value;key=value' 的过滤串转为 v2 API 参数。
    支持的 key:
      overallStatus, studyType, conditions, interventions.name,
      locations.country, startDateFrom
    """
    params: Dict[str, str] = {
        "pageSize": str(page_size),
        "countTotal": "false",
        "markupFormat": "markdown",
    }
    interventions, advanced = [], []

    for seg in filter(None, (s.strip() for s in expr.split(";"))):
        k, v = map(str.strip, seg.split("=", 1))

        if k == "conditions":
            params["query.cond"] = v

        elif k == "interventions.name":
            interventions.append(v)

        elif k == "locations.country":
            params["query.locn"] = f'"{v}"'    # 用引号避免拆词

        elif k == "overallStatus":
            params["filter.overallStatus"] = v.upper()

        elif k == "studyType":
            advanced.append(f"AREA[StudyType]({v.upper()})")

        elif k == "startDateFrom":
            advanced.append(f"AREA[StartDate]RANGE[{v},MAX]")

        else:
            raise ValueError(f"Unsupported key: {k}")

    if interventions:
        params["query.intr"] = " AND ".join(interventions)
    if advanced:
        params["filter.advanced"] = " AND ".join(advanced)
    if page_token:
        params["pageToken"] = page_token
    return params


def ctgov_search(filter_expr: str, page_size: int = 100) -> List[str]:
    """Return the NCT ids of every study matching ``filter_expr``.

    Pages are requested until the reply carries no ``nextPageToken``.
    """
    fetched = []
    next_token = None
    has_more = True
    while has_more:
        query = _build_params(filter_expr, page_size, next_token)
        response = requests.get(CTGOV, params=query, timeout=30)
        response.raise_for_status()

        page = response.json()
        fetched.extend(page.get("studies", []))
        next_token = page.get("nextPageToken")
        has_more = bool(next_token)

    # Reduce each study record to its NCT id.
    nct_ids = []
    for study in fetched:
        nct_ids.append(study["protocolSection"]["identificationModule"]["nctId"])
    return nct_ids

In [None]:
# demo_agent.py
"""
示例：把所有 schema 注册给 Azure OpenAI，
收到 tool_call 请求后执行 tools.impl 中的真实函数，再回传。
"""
import os
import json, importlib
from openai import AzureOpenAI
from tools.schema import ALL_SCHEMAS
from tools import impl      as T  # 引入实现
from typing import Optional
from openai import OpenAI
import re
# API keys are intentionally NOT set here: export OPENAI_API_KEY (and
# AZURE_OPENAI_API_KEY if the Azure client is used) in the environment before
# running this cell. Secrets must never be committed with the notebook.
# Non-secret settings fall back to the defaults below unless already set.
os.environ.setdefault("OPENAI_BASE_URL", "https://api2.aigcbest.top/v1")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://ai-mistraleastus2753718354821.openai.azure.com/")
os.environ.setdefault("YOUR_DEPLOYMENT", "gpt-4.1-noah")  # replace with your deployment name
# client = AzureOpenAI(
#     api_key=os.getenv("AZURE_OPENAI_API_KEY"),
#     azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
#     api_version="2024-12-01-preview"
# )

# 映射工具名 → Python 函数
# TOOLS = {
#     "pubmed.search":        T.pubmed_search,
#     "ctgov_search":         T.ctgov_search,
#     "opentargets.search":   T.ot_associated_diseases,
#     "opentargets.tractability": T.ot_tractability,
#     "opentargets.safety":   T.ot_safety,
#     "umls.concept_lookup":  T.umls_concept_lookup,
#     "umls.get_related":     T.umls_get_related,
#     "oncology.path_query":  T.oncology_path_query,
# }
# Dispatch table: tool name requested by the model -> implementing function.
# Names use underscores to match the names produced by make_safe_schemas.
# NOTE(review): oncology_path_query is not among the impl functions shown in
# this notebook — confirm it exists in tools.impl.
TOOLS = {
    "pubmed_search":            T.pubmed_search,
    "ctgov_search":             T.ctgov_search,
    "opentargets_search":       T.ot_associated_diseases,
    "opentargets_tractability": T.ot_tractability,
    "opentargets_safety":       T.ot_safety,
    "umls_concept_lookup":      T.umls_concept_lookup,
    "umls_get_related":         T.umls_get_related,
    "oncology_path_query":      T.oncology_path_query,
}
# demo_agent.py  顶部 import 之后加
from copy import deepcopy
def make_safe_schemas(schemas):
    """Wrap bare function schemas as tool entries with dot-free names.

    Each schema is deep-copied, every ``.`` in its ``name`` becomes ``_``, the
    resulting name is asserted to contain only letters, digits, ``_`` and ``-``,
    and the copy is wrapped as ``{"type": "function", "function": ...}``.
    The input schemas are left untouched.
    """
    wrapped = []
    for schema in schemas:
        spec = deepcopy(schema)
        safe_name = spec["name"].replace(".", "_")
        spec["name"] = safe_name
        assert re.match(r"^[a-zA-Z0-9_-]+$", safe_name)
        wrapped.append(dict(type="function", function=spec))
    return wrapped

def ensure_function_wrapping(schemas: list[dict]):
    """Make sure every schema has the {"type": "function", "function": ...} shape.

    Entries that already carry a ``"type"`` key are passed through as the same
    object; all others are deep-copied and wrapped.
    """
    return [
        schema if "type" in schema
        else {"type": "function", "function": deepcopy(schema)}
        for schema in schemas
    ]


# Keep the imported schemas under a separate name before ALL_SCHEMAS is rebound below.
RAW_SCHEMAS = ALL_SCHEMAS
SAFE_SCHEMAS = make_safe_schemas(RAW_SCHEMAS)          # names rewritten with underscores
ALL_SCHEMAS   = ensure_function_wrapping(SAFE_SCHEMAS) # <- overwrites the global used by chat_once


def chat_once(
        user_text: str,
        model_name: Optional[str] = None
    ) -> str:
    """Answer one user message, executing any tools the model asks for.

    The conversation is sent with ``ALL_SCHEMAS`` as the available tools. While
    the model replies with tool calls, each call is dispatched through
    ``TOOLS``, its JSON-encoded result is appended as a ``tool`` message, and
    the model is queried again. The loop ends with the first reply that has no
    tool calls.

    Args:
        user_text: The user's question.
        model_name: Model / deployment name; defaults to the ``YOUR_DEPLOYMENT``
            environment variable.

    Returns:
        The final assistant text, stripped (empty string if the reply has no text).

    Raises:
        KeyError: if the model requests a tool name that is not in ``TOOLS``.
    """
    if model_name is None:
        model_name = os.getenv("YOUR_DEPLOYMENT")

    messages = [
        {"role": "system", "content": "You are an assistant."},
        {"role": "user",   "content": user_text}
    ]

    # One client for the whole exchange; credentials and base URL come from
    # the environment (OPENAI_API_KEY / OPENAI_BASE_URL).
    client = OpenAI()

    while True:
        rsp = client.chat.completions.create(
            model=model_name,
            messages=messages,
            tools=ALL_SCHEMAS,
            tool_choice="auto"
        )
        msg = rsp.choices[0].message
        if not msg.tool_calls:
            # content may be None on an empty reply
            return (msg.content or "").strip()

        # The assistant turn goes in exactly once, followed by one tool
        # message per call; repeating it between tool results corrupts the
        # transcript when the model issues several calls in one turn.
        messages.append(msg)
        for tc in msg.tool_calls:
            fn   = TOOLS[tc.function.name]
            args = json.loads(tc.function.arguments)
            result = fn(**args)
            messages.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "name": tc.function.name,
                "content": json.dumps(result, ensure_ascii=False)
            })
        # loop again so the model can integrate the tool results


if __name__ == "__main__":
    # Single end-to-end call with an explicit model name instead of YOUR_DEPLOYMENT.
    print(chat_once("Which PMIDs match “A rare case of rectal malignant melanoma with long-term survival ...”?","deepseek-ai/DeepSeek-R1"))


In [None]:
# evaluate_retrieval.py
import json
from pathlib import Path
from tqdm import tqdm
# from demo_agent import chat_once

# ---------- 配置 ----------
IN_FILE   = Path("MAIA/MAIA_retrieval.json")   # questions to answer
OUT_FILE  = Path("retrieval_answers.json")     # answers, keyed by item index
MODEL     = "deepseek-r1"          # <- replace with the real deployment name
BATCH_SIZE = 10                    # write to disk after every 10 new answers

# ---------- 读数据 ----------
items = json.loads(IN_FILE.read_text(encoding="utf-8"))
# The file is either a bare list of items or a dict holding them under "dataset".
items = items["dataset"] if isinstance(items, dict) else items

# ---------- load answers from a previous run ----------
answered: dict[str, str] = {}
if OUT_FILE.exists():
    answered = json.loads(OUT_FILE.read_text(encoding="utf-8"))

pending = {}   # answers added in this run and not yet flushed to disk

# ---------- main loop ----------
for idx, item in enumerate(tqdm(items, desc="QA")):
    key = str(idx)
    if key in answered:            # already answered -> skip
        continue

    try:
        reply = chat_once(item["question"], MODEL)
    except Exception as e:
        # Best effort: record the failure as the answer and keep going.
        # NOTE(review): such entries count as answered on the next run and are
        # therefore never retried — confirm this is intended.
        reply = f"[ERROR] {e!s}"

    answered[key] = reply
    pending[key]  = reply

    # -- periodic flush --
    if len(pending) >= BATCH_SIZE:
        # Write with the same encoding used for reading; ensure_ascii=False
        # emits raw non-ASCII text, which the platform default may not encode.
        OUT_FILE.write_text(json.dumps(answered, ensure_ascii=False, indent=2),
                            encoding="utf-8")
        pending.clear()

# ---------- final flush of anything not yet written ----------
if pending:
    OUT_FILE.write_text(json.dumps(answered, ensure_ascii=False, indent=2),
                        encoding="utf-8")

print(f"✅  共 {len(answered)} answers saved to {OUT_FILE}")


In [None]:
# demo_agent.py
"""
示例：把所有 schema 注册给 Azure OpenAI，
收到 tool_call 请求后执行 tools.impl 中的真实函数，再回传。
"""
import os
import json, importlib
from openai import AzureOpenAI
from tools.schema import ALL_SCHEMAS
from tools import impl      as T  # 引入实现
from typing import Optional

# The Azure API key is intentionally NOT set here: export AZURE_OPENAI_API_KEY
# in the environment before running this cell. Secrets must never be committed
# with the notebook. Non-secret settings fall back to the defaults below.
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://ai-mistraleastus2753718354821.openai.azure.com/")
os.environ.setdefault("YOUR_DEPLOYMENT", "gpt-4.1-noah")  # replace with your deployment name
client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_version="2024-12-01-preview"
)

# 映射工具名 → Python 函数
# Dispatch table: tool name requested by the model -> implementing function.
# NOTE(review): most keys here are dotted, and this cell passes ALL_SCHEMAS to
# the API without renaming. The other agent cell in this notebook asserts that
# names match ^[a-zA-Z0-9_-]+$ — verify dotted names are accepted by the API.
# NOTE(review): oncology_path_query is not among the impl functions shown in
# this notebook — confirm it exists in tools.impl.
TOOLS = {
    "pubmed.search":        T.pubmed_search,
    "ctgov_search":         T.ctgov_search,
    "opentargets.search":   T.ot_associated_diseases,
    "opentargets.tractability": T.ot_tractability,
    "opentargets.safety":   T.ot_safety,
    "umls.concept_lookup":  T.umls_concept_lookup,
    "umls.get_related":     T.umls_get_related,
    "oncology.path_query":  T.oncology_path_query,
}

# demo_agent.py  顶部 import 之后加
from copy import deepcopy

def ensure_function_wrapping(schemas: list[dict]):
    """Return the schemas with each one in {"type": "function", "function": ...} form.

    A schema that already has a ``"type"`` key is kept as the same object; any
    other schema is deep-copied into the ``"function"`` slot of a new wrapper.
    """
    def _as_tool(schema):
        if "type" in schema:
            return schema
        return {"type": "function", "function": deepcopy(schema)}

    return list(map(_as_tool, schemas))

ALL_SCHEMAS = ensure_function_wrapping(ALL_SCHEMAS)   # <- rebinds the imported name to the wrapped list

def chat_once(
        user_text: str,
        model_name: Optional[str] = None
    ) -> str:
    """Answer one user message through the module-level ``client``, running tools on request.

    The conversation is sent with ``ALL_SCHEMAS`` as the available tools. While
    the model replies with tool calls, each call is dispatched through
    ``TOOLS``, its JSON-encoded result is appended as a ``tool`` message, and
    the model is queried again. The loop ends with the first reply that has no
    tool calls.

    Args:
        user_text: The user's question.
        model_name: Deployment name; defaults to the ``YOUR_DEPLOYMENT``
            environment variable.

    Returns:
        The final assistant text, stripped (empty string if the reply has no text).

    Raises:
        KeyError: if the model requests a tool name that is not in ``TOOLS``.
    """
    if model_name is None:
        model_name = os.getenv("YOUR_DEPLOYMENT")

    messages = [
        {"role": "system", "content": "You are an assistant."},
        {"role": "user",   "content": user_text}
    ]

    while True:
        rsp = client.chat.completions.create(
            model=model_name,
            messages=messages,
            tools=ALL_SCHEMAS,
            tool_choice="auto"
        )
        msg = rsp.choices[0].message
        if not msg.tool_calls:
            # content may be None on an empty reply
            return (msg.content or "").strip()

        # The assistant turn goes in exactly once, followed by one tool
        # message per call; repeating it between tool results corrupts the
        # transcript when the model issues several calls in one turn.
        messages.append(msg)
        for tc in msg.tool_calls:
            fn   = TOOLS[tc.function.name]
            args = json.loads(tc.function.arguments)
            result = fn(**args)
            messages.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "name": tc.function.name,
                "content": json.dumps(result, ensure_ascii=False)
            })
        # loop again so the model can integrate the tool results


if __name__ == "__main__":
    # Single end-to-end call using the default deployment from YOUR_DEPLOYMENT.
    print(chat_once("In ClinicalTrials.gov, what interventional studies for relapsing-remitting multiple sclerosis involving Cladribine 5.25 mg/kg were started in April 2005?"))


In [2]:

import random
import numpy as np
import json
import time
import difflib
import copy
import os
import pymysql
import torch
from multiprocessing import Pool
import networkx as nx
from community import community_louvain


In [3]:
def network_analysis(network_json):
    """Compute Louvain communities and basic centrality metrics for a graph.

    Args:
        network_json: Dict with an ``"edges"`` list whose items have
            ``"source"`` and ``"target"`` keys. Nodes are taken from the
            edges only.

    Returns:
        Dict with keys ``partition`` (Louvain community per node),
        ``degree_centrality``, ``degree`` (the object returned by
        ``nx.degree``), ``clustering`` and ``betweenness``.
    """
    graph = nx.Graph()
    for edge in network_json['edges']:
        graph.add_edge(edge['source'], edge['target'])

    return {'partition': community_louvain.best_partition(graph),
            'degree_centrality': nx.degree_centrality(graph),
            'degree': nx.degree(graph),
            'clustering': nx.clustering(graph),
            'betweenness': nx.betweenness_centrality(graph), }


In [6]:
# Read one line from stdin and parse it as whitespace-separated integers.
a=list(map(int,input().split()))
print(a)

[1, 2, 3, 4]


In [9]:
# Read two lines from stdin; each becomes one row of whitespace-separated integers.
mat=[list(map(int,input().split())) for _ in range(2)]
print(mat)

[[1, 2, 3], [4, 5, 6]]


In [11]:
# Read one line from stdin with surrounding whitespace removed.
s=input().strip()
print(s)

aaaad


In [1]:
def quicksort(nums):
    """Return the elements of ``nums`` in ascending order (not in place).

    The middle element is the pivot; elements are split into smaller, equal
    and larger groups and the outer groups are sorted recursively. A sequence
    of length 0 or 1 is returned as-is.
    """
    if len(nums) <= 1:
        return nums

    pivot = nums[len(nums) // 2]
    smaller, equal, larger = [], [], []
    for value in nums:
        # Three independent tests, mirroring three separate filters.
        if value < pivot:
            smaller.append(value)
        if value == pivot:
            equal.append(value)
        if value > pivot:
            larger.append(value)
    return quicksort(smaller) + equal + quicksort(larger)
# Example run of quicksort on a list that contains a duplicate value.
nums=[3, 6, 8, 10, 1, 2, 1]
print(quicksort(nums))

[1, 1, 2, 3, 6, 8, 10]


In [7]:
# Approximate sqrt(x) by gradient descent on f(y) = (y*y - x)**2.
x=5
y=x/2          # starting guess
a=0.001        # step size
# Iterate until y*y is within 1e-6 of x.
while abs(y*y-x)>1e-6:
    grad=4*y*(y*y-x)   # derivative of (y*y - x)**2 with respect to y
    y=y-a*grad
print(y)

2.2360681944690217


In [None]:
from functools import cache
def minfunc(s,t):
    """Return the edit (Levenshtein) distance between ``s`` and ``t``.

    Allowed operations are inserting, deleting or replacing one character,
    each costing 1. The table is filled iteratively one row at a time, so
    long inputs cannot exceed Python's recursion limit.
    """
    n=len(t)
    # prev[j] holds the distance between the processed prefix of s and t[:j];
    # for the empty prefix that is j insertions.
    prev=list(range(n+1))
    for i,ch in enumerate(s,start=1):
        cur=[i]+[0]*n          # cur[0]: delete all i characters of s[:i]
        for j in range(1,n+1):
            if ch==t[j-1]:
                cur[j]=prev[j-1]                            # characters match: no cost
            else:
                cur[j]=min(prev[j-1],prev[j],cur[j-1])+1    # replace / delete / insert
        prev=cur
    return prev[n]
# Read the two words from stdin (one per line) and print their edit distance.
word1=input().strip()
word2=input().strip()
res=minfunc(word1,word2)
print(res)

3


In [16]:

def maxprofit(nums):
    """Return the best profit from one buy followed by one later sell.

    Args:
        nums: Prices in chronological order.

    Returns:
        The largest ``nums[j] - nums[i]`` with ``i <= j``, never below 0;
        0 for an empty sequence (no trade is possible).
    """
    if not nums:
        return 0
    min_price=nums[0]
    max_profit=0
    for x in nums:
        # Best sale today uses the cheapest price seen so far.
        max_profit=max(max_profit,x-min_price)
        min_price=min(min_price,x)
    return max_profit

In [19]:
# Read whitespace-separated prices from stdin and print the best single-trade profit.
nums=list(map(int,input().split()))
print(maxprofit(nums))

0


In [22]:
def maxnum(nums):
    """Return the largest sum of a non-empty contiguous run of ``nums``.

    Works on prefix sums: the best run ending at a position is the current
    prefix sum minus the smallest prefix sum seen before it (0 for the empty
    prefix). An empty ``nums`` raises ``IndexError``.
    """
    best = nums[0]
    running = 0
    lowest_prefix = 0
    for value in nums:
        running += value
        candidate = running - lowest_prefix
        if candidate > best:
            best = candidate
        if running < lowest_prefix:
            lowest_prefix = running
    return best

In [23]:
# Example: the best run is [4, -1, 2, 1], which sums to 6.
maxnum([-2,1,-3,4,-1,2,1,-5,4])

6

In [26]:
def longestPalindrome(s):
    """Return the longest palindromic substring of ``s``.

    Uses a table where entry [left][right] records whether s[left..right] is a
    palindrome. Substrings are visited by increasing end index, then
    increasing start index, and a candidate replaces the current best only if
    it is strictly longer — so among equally long palindromes the first one
    visited is returned.
    """
    size = len(s)
    if size < 2:
        return s

    # Single characters (the diagonal) are palindromes; everything else starts False.
    is_pal = [[row == col for col in range(size)] for row in range(size)]
    best_start, best_len = 0, 1

    for right in range(1, size):
        for left in range(right):
            if s[left] != s[right]:
                continue  # entry stays False
            span = right - left + 1
            # Length-2 spans need only matching ends; longer ones need a palindromic interior.
            inner_ok = span <= 2 or is_pal[left + 1][right - 1]
            is_pal[left][right] = inner_ok
            if inner_ok and span > best_len:
                best_len, best_start = span, left
    return s[best_start:best_start + best_len]

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input ``x``.
    """

    def __init__(self,d_model,num_heads):
        """
        Args:
            d_model: Size of the last dimension of the input and output.
            num_heads: Number of attention heads; must divide ``d_model``.

        Raises:
            ValueError: if ``d_model`` is not a multiple of ``num_heads``
                (the per-head reshape in ``forward`` would otherwise fail).
        """
        super().__init__()
        if num_heads<=0 or d_model%num_heads!=0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of num_heads ({num_heads})")
        self.d_model=d_model
        self.num_heads=num_heads
        self.d_k=d_model//num_heads   # width of each head

        self.W_q=nn.Linear(d_model,d_model)
        self.W_k=nn.Linear(d_model,d_model)
        self.W_v=nn.Linear(d_model,d_model)
        self.W_o=nn.Linear(d_model,d_model)

    def forward(self,x,mask=None):
        """
        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, num_heads, seq_len, seq_len); positions where it equals
                0 are excluded from attention. A query row that is masked
                everywhere yields NaN after the softmax.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size,seq_len,_=x.size()

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return projected.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)

        Q=split_heads(self.W_q(x))
        K=split_heads(self.W_k(x))
        V=split_heads(self.W_v(x))

        # Scale by sqrt(d_k) with a plain float; no tensor is allocated per call.
        scores=torch.matmul(Q,K.transpose(-2,-1))/(self.d_k**0.5)
        if mask is not None:
            scores=scores.masked_fill(mask==0,float('-inf'))

        attention_weights=F.softmax(scores,dim=-1)

        context=torch.matmul(attention_weights,V)
        # (batch, heads, seq, d_k) -> (batch, seq, d_model)
        context=context.transpose(1,2).contiguous()
        context=context.view(batch_size,seq_len,self.d_model)

        output=self.W_o(context)
        return output
if __name__=="__main__":
    # Shape check: the module should map (batch, seq_len, d_model) to the same shape.
    d_model=512
    num_heads=8
    batch_size=1
    seq_len=5

    mha=MultiHeadAttention(d_model,num_heads)
    x=torch.randn(batch_size,seq_len,d_model)
    output=mha(x)
    print(x.shape)
    print(output.shape)

torch.Size([1, 5, 512])
torch.Size([1, 5, 512])


In [2]:
import networkx as nx
from community import community_louvain
from collections import Counter, defaultdict
import logging, pprint
# DEBUG level makes the logging.debug calls in GraphAnalyzer.community_analysis_new visible.
logging.basicConfig(level=logging.DEBUG)

class GraphAnalyzer:
    """Louvain community detection plus per-node metrics for a JSON graph.

    The input is a dict with ``"nodes"`` (dicts carrying at least ``"id"``) and
    ``"edges"`` (dicts with ``"source"``, ``"target"`` and optional
    ``"weight"``). Only the edges are loaded into the NetworkX graph, so a node
    that appears in no edge is absent from every metric.
    """

    def __init__(self, graph):
        """
        :param graph: a ``{"nodes": [...], "edges": [...]}`` structure, passed in directly
        """
        self.graph_json = graph  # original graph data
        self.G = nx.Graph()

        # Build the NetworkX graph; edges without a weight default to 1.
        for e in graph["edges"]:
            self.G.add_edge(e["source"], e["target"], weight=e.get("weight", 1))

        # Cache of the most recent analyze_existing_graph result.
        self.last_analysis = {}

    def analyze_existing_graph(self, add_community_name=True, sort_by="degree", sort_reverse=True):
        """Run community detection and centrality metrics and cache the result.

        :param add_community_name: when true, each community is named after its
            highest-degree member (its ``name``, else ``label``, else
            ``Community_<id>``) and every node gets a ``community_name`` field
        :param sort_by: node field to sort the returned list by (missing values
            count as 0); a falsy value disables sorting
        :param sort_reverse: sort descending when true
        :return: ``(nodes_list, metrics)`` — one enriched dict per graph node,
            and a dict of the raw metric mappings

        Edge endpoints that have no entry in ``graph["nodes"]`` are still
        reported, carrying only the computed fields.
        """
        colors = ['#065758', '#25767B', '#44949E', '#62B3C1', '#81D2E4', '#BCEDD8']
        partition = community_louvain.best_partition(self.G)
        community_size = Counter(partition.values())
        degree = dict(self.G.degree())
        degree_centrality = nx.degree_centrality(self.G)
        betweenness = nx.betweenness_centrality(self.G)
        clustering = nx.clustering(self.G)

        # Single id -> node index, shared by the hub lookup and the per-node
        # merge below (replaces a linear scan of the node list per community).
        id2node = {n["id"]: n for n in self.graph_json["nodes"]}

        rep_name = {}
        if add_community_name:
            members_by_cid = defaultdict(list)
            for nid, cid in partition.items():
                members_by_cid[cid].append(nid)
            for cid, members in members_by_cid.items():
                hub = max(members, key=lambda n: degree[n])
                # An edge may reference an id that is missing from "nodes".
                node_obj = id2node.get(hub, {})
                rep_name[cid] = node_obj.get("name", node_obj.get("label", f"Community_{cid}"))

        nodes_list = []
        for nid, cid in partition.items():
            n = id2node.get(nid, {})
            node_data = {
                "id": nid,
                "community_id": cid,
                "community_size": community_size[cid],
                "degree": degree[nid],
                "degree_centrality": degree_centrality[nid],
                "betweenness": betweenness[nid],
                "clustering": clustering[nid],
                "style": {"fill": colors[cid % len(colors)]}
            }

            if add_community_name:
                node_data["community_name"] = rep_name[cid]

            # Copy the original node attributes without overriding computed fields.
            for k, v in n.items():
                if k not in node_data:
                    node_data[k] = v

            nodes_list.append(node_data)

        metrics = {
            "partition": partition,
            "community_size": community_size,
            "degree": degree,
            "degree_centrality": degree_centrality,
            "betweenness": betweenness,
            "clustering": clustering,
        }

        if sort_by:
            nodes_list.sort(key=lambda x: x.get(sort_by, 0), reverse=sort_reverse)

        # Cache the analysis for the community_* methods.
        self.last_analysis = {
            "nodes_list": nodes_list,
            "metrics": metrics,
            "rep_name": rep_name,
            "colors": colors
        }

        return nodes_list, metrics

    def community_analysis(self):
        """
        Return one summary dict per community from the latest analyze_existing_graph run.

        Each dict has ``community_id``, ``community_size``, ``color`` and, when
        available, ``community_name``; the list is ordered by community id.

        :raises ValueError: if analyze_existing_graph has not been called yet
        """
        if not self.last_analysis:
            raise ValueError("请先调用 analyze_existing_graph")

        metrics = self.last_analysis["metrics"]
        rep_name = self.last_analysis.get("rep_name")
        colors = self.last_analysis.get("colors")

        community_size = metrics["community_size"]

        communities = []
        for cid in sorted(community_size.keys()):
            community = {
                "community_id": cid,
                "community_size": community_size[cid],
                "color": colors[cid % len(colors)] if colors else "#000000"
            }
            if rep_name and cid in rep_name:
                community["community_name"] = rep_name[cid]
            communities.append(community)

        return communities

    def community_analysis_new(self, classify_by_order=False):
        """
        Return detailed communities from the latest analyze_existing_graph run,
        sorted by community size (largest first).

        For each community the highest-degree member is the center. A community
        is "second order" when at least one of its members is neither the
        center nor a direct neighbor of it, and "first order" otherwise.
        Note that each community's ``nodes`` list holds the center together
        with its one- and two-hop neighbors in the graph, while
        ``community_size`` counts the Louvain members.

        :param classify_by_order: when true, return a dict with the top 10
            first- and second-order communities plus node and community
            counts for each class; otherwise return the full list
        :raises ValueError: if analyze_existing_graph has not been called yet
        """
        if not self.last_analysis:
            raise ValueError("请先调用 analyze_existing_graph")

        nodes_list = self.last_analysis["nodes_list"]
        rep_name = self.last_analysis.get("rep_name")
        colors = self.last_analysis.get("colors")

        # Step 1: group the analysed nodes by community.
        communities = defaultdict(list)
        for node in nodes_list:
            cid = node["community_id"]
            communities[cid].append(node)

        # Step 2: pick each community's center (its highest-degree member).
        center_map = {}
        for cid, members in communities.items():
            center = max(members, key=lambda n: n["degree"])
            center_map[cid] = center["id"]  # {cid: center_node_id}

        # Step 3: build the detailed records and classify them.
        detailed_communities = []
        first_order_communities = []
        second_order_communities = []

        for cid, members in communities.items():
            center = center_map[cid]

            # One-hop and two-hop neighbors of the center.
            first_neighbors = set(self.G.neighbors(center))
            second_neighbors = set()
            for neighbor in first_neighbors:
                second_neighbors.update(self.G.neighbors(neighbor))
            logging.debug(f"{center=}")
            logging.debug(f"(before cleanup) {second_neighbors=}")

            # Cleanup: two-hop means neither the center itself nor a direct neighbor.
            second_neighbors -= first_neighbors
            second_neighbors.discard(center)

            logging.debug(f"(after  cleanup) {second_neighbors=}")

            # Neighborhood of the center: center + one-hop + two-hop nodes.
            community_members = {center} | first_neighbors | second_neighbors
            community_nodes = [node for node in nodes_list if node["id"] in community_members]

            # Actual members according to the Louvain partition.
            actual_community_member_ids = {node["id"] for node in members}

            # Any member beyond the center's direct neighbors makes it second order.
            non_first_order_members = actual_community_member_ids - ({center} | first_neighbors)
            is_second_order = len(non_first_order_members) > 0

            detailed_community = {
                "community_id": cid,
                "community_size": len(members),
                "center": center,
                "color": colors[cid % len(colors)] if colors else "#000000",
                "nodes": community_nodes
            }
            if rep_name and cid in rep_name:
                detailed_community["community_name"] = rep_name[cid]

            detailed_communities.append(detailed_community)

            if is_second_order:
                second_order_communities.append(detailed_community)
            else:
                first_order_communities.append(detailed_community)

        # Largest communities first.
        detailed_communities.sort(key=lambda x: x["community_size"], reverse=True)
        first_order_communities.sort(key=lambda x: x["community_size"], reverse=True)
        second_order_communities.sort(key=lambda x: x["community_size"], reverse=True)

        # Totals per class: member count and number of communities.
        first_order_total_nodes = sum(comm["community_size"] for comm in first_order_communities)
        second_order_total_nodes = sum(comm["community_size"] for comm in second_order_communities)
        first_order_community_count = len(first_order_communities)
        second_order_community_count = len(second_order_communities)

        if classify_by_order:
            return {
                "first_level_community": first_order_communities[:10],  # top 10
                "second_level_community": second_order_communities[:10],
                "first_order_total_nodes": first_order_total_nodes,
                "second_order_total_nodes": second_order_total_nodes,
                "first_order_community_count": first_order_community_count,
                "second_order_community_count": second_order_community_count
            }

        return detailed_communities

    def edge_analysis(self, top_n=10):
        """
        Return the total edge count and the ``top_n`` heaviest edges.

        Edges come from the original JSON; an edge without ``weight`` counts as
        weight 1, and ties keep their original order.
        """
        edges = sorted(
            self.graph_json["edges"],
            key=lambda e: e.get("weight", 1),
            reverse=True
        )[:top_n]

        return {
            "total_edges": len(self.graph_json["edges"]),
            "top_weighted_edges": edges
        }
