# Generating Q&A Pairs from UMLS Knowledge Graph Paths

This notebook generates English question-answer pairs using multi-hop paths from the UMLS knowledge graph, where:
- Questions require multi-hop professional reasoning to solve
- Answers are typically the final entity in the path (with some exceptions where appropriate)
- Answers include brief reasoning logic for evaluation purposes
- Generated Q&A pairs are saved in JSON format

In [1]:
import os
import json
import time
import random
import tqdm
from pathlib import Path
from openai import AzureOpenAI
from collections import defaultdict
import re

# Azure OpenAI settings.
# Credentials come from the environment instead of being hardcoded in the
# notebook: a key committed here leaks through version control, exports and
# shared copies. NOTE: the key previously embedded in this cell should be
# treated as compromised and rotated.
# Before starting the kernel, export AZURE_OPENAI_API_KEY (required) and,
# optionally, AZURE_OPENAI_ENDPOINT.
AZURE_OPENAI_ENDPOINT_DEFAULT = "https://ai-mistraleastus2753718354821.openai.azure.com/"

if not os.getenv("AZURE_OPENAI_API_KEY"):
    raise RuntimeError(
        "AZURE_OPENAI_API_KEY is not set; export it before running this notebook."
    )
# Only fill in the endpoint when the caller has not provided one.
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT_DEFAULT)

# Initialize the Azure OpenAI client
client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version="2024-12-01-preview",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)

# Configuration parameters
MODEL = "gpt-4.1-noah"  # Using advanced model to ensure high-quality questions
TEMPERATURE = 0.7       # Increased temperature for diversity
RATE_LIMIT_S = 1.2      # Request interval to avoid API rate limits

In [2]:
# Load merged_paths.json file
def load_paths_data(file_path):
    """Parse a UTF-8 JSON file and return the decoded value.

    Args:
        file_path: str or Path of the JSON file (e.g. merged_paths.json).

    Returns:
        Whatever the JSON document decodes to; later cells treat it as a
        mapping of template id -> list of path records.
    """
    raw_text = Path(file_path).read_text(encoding="utf-8")
    return json.loads(raw_text)

# Input path file and output Q&A file.
# DATA_DIR defaults to the original location but can be overridden with the
# MAIA_DATA_DIR environment variable, so the notebook runs on other machines.
DATA_DIR = Path(os.environ.get("MAIA_DATA_DIR", "/home/xinding/dingxin/Agent/MAIA/code"))
paths_file = DATA_DIR / "merged_paths.json"
output_file = DATA_DIR / "umls_qa_pairs_english.json"  # English version

# Load path data
paths_data = load_paths_data(paths_file)
print(f"Loaded {sum(len(paths) for paths in paths_data.values())} paths")

# View data structure
template_ids = list(paths_data.keys())
print(f"Template types: {template_ids}")

# Show the first path of a randomly chosen template to illustrate the record
# structure. Guard the empty case: random.choice([]) raises IndexError.
if not template_ids:
    print("No templates found in the paths file")
else:
    template = random.choice(template_ids)
    if paths_data[template]:
        sample_path = paths_data[template][0]
        print(f"\nSample path (Template: {template}):")
        if 'path_strs' in sample_path:
            print(f"path_strs: {sample_path['path_strs']}")
        else:
            print("Path structure:", sample_path.keys())
    else:
        print(f"Template {template} has no available paths")

Loaded 2068 paths
Template types: ['Disease_Drug_Target', 'Disease_Drug_moA']

Sample path (Template: Disease_Drug_moA):
path_strs: ['Granuloma inguinale', 'may_be_treated_by', 'Doxycycline anhydrous', 'has_mechanism_of_action', 'Protein Synthesis Inhibitors']


In [None]:
def filter_valid_paths(paths_data, min_path_length=3):
    """Keep only well-formed paths, grouped by template.

    A path survives when it has a ``path_strs`` key whose value is a list of
    at least ``min_path_length`` items. Templates left with no surviving paths
    are omitted from the result.

    Args:
        paths_data: mapping of template id -> list of path records.
        min_path_length: minimum number of items required in ``path_strs``.

    Returns:
        dict of template id -> list of surviving path records, in input order.
    """
    def _is_well_formed(record):
        # Check for the key first so records lacking it are simply rejected.
        if 'path_strs' not in record:
            return False
        strs = record['path_strs']
        return isinstance(strs, list) and len(strs) >= min_path_length

    kept_by_template = {}
    for tmpl_id, records in paths_data.items():
        survivors = [rec for rec in records if _is_well_formed(rec)]
        if survivors:
            kept_by_template[tmpl_id] = survivors
    return kept_by_template
# Filter valid paths and report how many survive per template
valid_paths = filter_valid_paths(paths_data)
total_valid = sum(map(len, valid_paths.values()))
print(f"After filtering, there are {total_valid} valid paths")
for tmpl_name, tmpl_paths in valid_paths.items():
    print(f"Template {tmpl_name}: {len(tmpl_paths)} valid paths")

过滤后共有 2068 条有效路径
模板 Disease_Drug_Target: 568 条有效路径
模板 Disease_Drug_moA: 1500 条有效路径


In [None]:
def create_qa_prompt(path_info, template_id):
    """Create a high-quality English prompt for generating complex medical Q-A pairs.

    Args:
        path_info: path record with a "path_strs" list (entity and relation
            names along the path; its last item is proposed as the answer).
        template_id: template key such as "Disease_Drug_Target"; keys missing
            from the description map are shown verbatim.

    Returns:
        The user prompt string. The JSON skeleton the model is asked to copy
        embeds ``umls_path`` and ``template_id`` via ``json.dumps`` so that it is
        valid JSON. Previously the path was rendered as a Python list repr with
        single quotes, which ``json.loads`` rejects if the model copies it
        verbatim.
    """
    path_strs = path_info["path_strs"]
    template_map = {
        "Disease_Drug_Target": "Disease → Drug → Target",
        "Disease_Drug_moA":    "Disease → Drug → Mechanism-of-Action"
    }
    template_desc = template_map.get(template_id, template_id)

    path_text = " -> ".join(path_strs)
    # JSON-encode everything placed inside the JSON skeleton (or quoted in it),
    # so quotes/backslashes in entity names cannot break the structure.
    answer_entity = json.dumps(path_strs[-1], ensure_ascii=False)
    path_json = json.dumps(path_strs, ensure_ascii=False)
    template_json = json.dumps(template_id, ensure_ascii=False)

    prompt = f"""
You are a senior medical educator who must craft challenging Q&A pairs from UMLS multi-hop reasoning paths.

Path type: {template_desc}
UMLS path: {path_text}

**Instructions**

1. Write a question that requires **multi-step professional reasoning** along this path.
2. You may reveal **at most one or two** intermediate concepts as clues, but do NOT expose the full path.
3. Frame the question in a realistic clinical, pharmacological, or research scenario—professional and concise, not overly detailed.
4. The answer should normally be the terminal entity **{answer_entity}**; if another node is more clinically sensible, use it and explain why.
5. After the entity, add a **succinct (≤ 40 words) clinical/pharmacological rationale** introduced by a semicolon, so evaluators can easily judge correctness.
6. Provide a short (20–40 words) “reasoning_path” summary showing the key medical logic without disclosing every node.
7. Output **only** the following JSON structure—no extra text:

{{
  "question": "...",
  "answer": "<Entity>; <≤40-word rationale>",
  "reasoning_path": "<20–40-word reasoning summary>",
  "umls_path": {path_json},
  "template_id": {template_json}
}}
"""
    return prompt


In [5]:
def generate_qa_pair(path_info, template_id):
    """Use the chat model to generate one Q&A pair for a UMLS path.

    Args:
        path_info: path record containing "path_strs".
        template_id: template key passed to the prompt and stored on the result.

    Returns:
        A dict containing at least "question" and "answer" (plus any other keys
        the model returned), with "umls_path" and "template_id" overwritten from
        the source path. Returns None if the API call fails, the reply is empty,
        the reply is not a JSON object, or required fields are missing.
    """
    prompt = create_qa_prompt(path_info, template_id)

    try:
        response = client.chat.completions.create(
            model=MODEL,
            temperature=TEMPERATURE,
            messages=[
                {"role": "system", "content": "You are a medical education expert specializing in creating high-quality medical reasoning questions in English."},
                {"role": "user", "content": prompt}
            ]
        )
        content = response.choices[0].message.content
    except Exception as e:
        print(f"Error generating Q&A pair: {e}")
        return None

    # content can be None (e.g. a filtered or empty completion); calling
    # .strip() on it would raise AttributeError.
    if not content:
        print("Error generating Q&A pair: empty response content")
        return None
    result = content.strip()

    # The model may wrap the JSON in extra prose or code fences; keep the
    # span from the first '{' to the last '}'.
    json_match = re.search(r'({[\s\S]*})', result)
    if json_match:
        result = json_match.group(1)

    try:
        qa_pair = json.loads(result)
    except json.JSONDecodeError:
        print(f"Failed to parse JSON: {result[:100]}...")
        return None

    # json.loads may yield a list/str/number; only a JSON object is usable.
    if not isinstance(qa_pair, dict):
        return None

    # Ensure required fields are present
    if "question" not in qa_pair or "answer" not in qa_pair:
        return None

    # Trust the source path over whatever the model echoed back.
    qa_pair["umls_path"] = path_info["path_strs"]
    qa_pair["template_id"] = template_id

    return qa_pair

In [None]:


# Generate one Q&A pair for every valid path (no sampling), checkpointing to
# `output_file` periodically so an interruption does not lose finished work.
SAVE_EVERY = 10  # checkpoint interval, counted in successfully generated pairs


def _save_qa_pairs(qa_pairs, path):
    """Write qa_pairs to path as UTF-8 JSON via a temp file + rename, so an
    interrupted write never leaves a truncated checkpoint behind."""
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(qa_pairs, f, ensure_ascii=False, indent=2)
    tmp_path.replace(path)


output_qa_pairs = []

# Iterate through all the templates and their associated paths
for template_id, paths in valid_paths.items():
    print(f"Generating Q&A pairs for template '{template_id}'...")

    for path in tqdm.tqdm(paths, desc=f"Template: {template_id}"):
        qa_pair = generate_qa_pair(path, template_id)
        # Throttle after every request, not only successful ones: failed calls
        # (including rate-limit errors) also count against the API quota, and
        # retrying immediately would just trigger more rate-limit errors.
        time.sleep(RATE_LIMIT_S)

        if qa_pair is None:
            continue
        output_qa_pairs.append(qa_pair)

        if len(output_qa_pairs) % SAVE_EVERY == 0:
            _save_qa_pairs(output_qa_pairs, output_file)
            print(f"Saved {len(output_qa_pairs)} Q&A pairs")

# Save final results
_save_qa_pairs(output_qa_pairs, output_file)
print(f"Finished generating and saving {len(output_qa_pairs)} Q&A pairs.")


In [14]:
# Persist the full Q&A list (overwriting any checkpoint), then preview up to
# three randomly chosen pairs.
with open(output_file, "w", encoding="utf-8") as out_fh:
    json.dump(output_qa_pairs, out_fh, ensure_ascii=False, indent=2)

print(f"Successfully generated and saved {len(output_qa_pairs)} Q&A pairs to {output_file}")

print("\nSample Q&A pairs:")
preview_count = min(3, len(output_qa_pairs))
for example_no, pair in enumerate(random.sample(output_qa_pairs, preview_count), start=1):
    path_text = ' -> '.join(pair['umls_path'])
    print(f"\nExample {example_no}:")
    print(f"Question: {pair['question']}")
    print(f"Answer: {pair['answer']}")
    print(f"Reasoning Path: {pair.get('reasoning_path', 'N/A')}")
    print(f"UMLS Path: {path_text}")

Successfully generated and saved 100 Q&A pairs to /home/xinding/dingxin/Agent/MAIA/code/umls_qa_pairs_english.json

Sample Q&A pairs:

Example 1:
Question: In patients with neutropenia who require oral prophylactic antibiotics, which class of antimicrobial agents is utilized due to its ability to inhibit bacterial DNA replication by targeting a specific bacterial enzyme?
Answer: DNA Gyrase Inhibitors; because ciprofloxacin lactate, commonly used to prevent infections in neutropenic patients, acts by inhibiting bacterial DNA gyrase, thereby blocking DNA replication.
Reasoning Path: Neutropenia increases infection risk, often necessitating prophylactic antibiotics such as ciprofloxacin lactate. This drug's efficacy is due to its mechanism of action as a DNA gyrase inhibitor, which prevents bacterial DNA replication.
UMLS Path: neutropenia -> may_be_treated_by -> ciprofloxacin lactate -> has_mechanism_of_action -> DNA Gyrase Inhibitors

Example 2:
Question: Which molecular target is inhib

In [15]:
# Analyze generated Q&A pairs
def analyze_qa_pairs(qa_pairs):
    """Print summary statistics for generated Q&A pairs.

    Reports the total count and how many answers carry a rationale. An answer
    counts as having one if it contains a ';', matching the "<Entity>; <rationale>"
    format requested in the prompt. Also reports the per-template distribution
    and question/answer lengths in characters. Missing keys are tolerated.

    Args:
        qa_pairs: list of Q&A dicts.

    Returns:
        None; results are printed. An empty list prints a short notice instead
        of raising ZeroDivisionError.
    """
    template_counts = defaultdict(int)
    question_lengths = []
    answer_lengths = []
    has_reasoning = 0

    for qa in qa_pairs:
        template_counts[qa.get('template_id', 'unknown')] += 1
        question_lengths.append(len(qa.get('question', '')))
        answer = qa.get('answer', '')
        answer_lengths.append(len(answer))

        # Check if answer includes reasoning (contains semicolon)
        if ';' in answer:
            has_reasoning += 1

    total = len(qa_pairs)
    print("=== Q&A Data Analysis ===")
    print(f"Total Q&A pairs: {total}")
    if total == 0:
        # No denominator for percentages and no values for min/max.
        print("No Q&A pairs to analyze.")
        return

    print(f"Answers with reasoning: {has_reasoning} ({has_reasoning/total:.1%})")

    print("\nTemplate distribution:")
    for template, count in template_counts.items():
        print(f"  {template}: {count} ({count/total:.1%})")

    # total > 0 guarantees both length lists are non-empty.
    print(f"\nQuestion length: Average {sum(question_lengths)/total:.1f} characters, "
          f"Range {min(question_lengths)}-{max(question_lengths)} characters")
    print(f"Answer length: Average {sum(answer_lengths)/total:.1f} characters, "
          f"Range {min(answer_lengths)}-{max(answer_lengths)} characters")

# Summarize the Q&A pairs generated above
analyze_qa_pairs(qa_pairs=output_qa_pairs)

=== Q&A Data Analysis ===
Total Q&A pairs: 100
Answers with reasoning: 100 (100.0%)

Template distribution:
  Disease_Drug_Target: 50 (50.0%)
  Disease_Drug_moA: 50 (50.0%)

Question length: Average 202.9 characters, Range 125-313 characters
Answer length: Average 194.7 characters, Range 140-279 characters


In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
merge_umls_paths_get_related_w_lookup.py
---------------------------------------
• reads  umls_qa_pairs.json
• reads  merged_paths.json
• writes enriched_umls_qa.json
  - id           sha256(question)[:16]
  - tool_calls   [umls.concept_lookup, umls.get_related, ...]
"""

import json, hashlib
from pathlib import Path
from typing import List, Dict, Tuple

# Input/output file names, resolved against the current working directory.
# NOTE(review): the generation cells earlier in this notebook write
# `umls_qa_pairs_english.json`, not `umls_qa_pairs.json`. Confirm which Q&A
# file is intended here.
QA_FILE, PATH_FILE, OUT_FILE = (
    "umls_qa_pairs.json",
    "merged_paths.json",
    "enriched_umls_qa.json",
)

# ========== util ==========
def sha16(txt: str) -> str:
    """Return a short stable ID: the first 16 hex chars of SHA-256 over the UTF-8 bytes of txt."""
    digest = hashlib.sha256(txt.encode("utf-8")).digest()
    # 8 raw bytes -> 16 lowercase hex characters, same as hexdigest()[:16].
    return digest[:8].hex()

def build_tool_calls(entity_name: str,
                     cuis: List[str],
                     relas: List[str]) -> List[Dict]:
    """
    Build the tool-call sequence that replays a UMLS path:

      0) umls.concept_lookup  (entity_name -> CUI0)
      1) umls.get_related     (from cuis[0] via relas[0])
      2) umls.get_related     (from cuis[1] via relas[1]) ...
      N) umls.cui_to_name     (cuis[-1] -> name)

    Args:
        entity_name: name of the first entity on the path (looked up by name).
        cuis: CUIs along the path; cuis[i] is the source CUI of hop i.
        relas: relation names, one per hop.

    Returns:
        A list of {"tool": ..., "params": {...}} dicts: one concept_lookup,
        len(relas) get_related calls, then one cui_to_name.

    Raises:
        ValueError: if cuis is empty or has fewer entries than relas. These
            inputs previously failed with a bare IndexError.
    """
    if not cuis or len(cuis) < len(relas):
        raise ValueError(
            f"need at least {max(len(relas), 1)} CUIs for {len(relas)} "
            f"relations, got {len(cuis)}"
        )

    # Look up the starting concept by name.
    calls: List[Dict] = [
        {"tool": "umls.concept_lookup", "params": {"name": entity_name}}
    ]
    # One get_related hop per relation; zip stops after len(relas) hops.
    calls.extend(
        {"tool": "umls.get_related", "params": {"from_cui": cui, "rela": rela}}
        for cui, rela in zip(cuis, relas)
    )
    # Finally resolve the last CUI back to a name.
    calls.append({"tool": "umls.cui_to_name", "params": {"cui": cuis[-1]}})
    return calls


# ========== 1. Load ==========
qa_pairs  = json.loads(Path(QA_FILE).read_text(encoding="utf-8"))
path_meta = json.loads(Path(PATH_FILE).read_text(encoding="utf-8"))

#  (template_id, tuple(path_strs)) -> (cuis, relas)
path_dict: Dict[Tuple[str, tuple], Tuple[List[str], List[str]]] = {}
for tmpl, lst in path_meta.items():
    for entry in lst:
        path_dict[(tmpl, tuple(entry["path_strs"]))] = (
            entry["cuis"], entry["relas"]
        )

# ========== 2. Merge ==========
enriched = []
skipped = 0
for qa in qa_pairs:
    key = (qa["template_id"], tuple(qa["umls_path"]))
    meta = path_dict.get(key)
    if meta is None:
        # No matching path in PATH_FILE: count it rather than dropping it silently.
        skipped += 1
        continue

    cuis, relas = meta
    first_entity_name = qa["umls_path"][0]        # e.g. "Gynecomastia"

    enriched.append({
        # NOTE(review): identical questions would get identical ids.
        "id":         sha16(qa["question"]),
        "question":   qa["question"],
        "tool_calls": build_tool_calls(first_entity_name, cuis, relas),
        "answer":     qa["answer"],
        "reasoning_path": qa["reasoning_path"],
        "umls_path":      qa["umls_path"],
        "template_id":    qa["template_id"]
    })

# ========== 3. Write ==========
# Explicit UTF-8: with ensure_ascii=False the JSON contains raw non-ASCII
# characters, and Path.write_text otherwise uses the locale's default encoding
# (inconsistent with the explicit UTF-8 reads above).
Path(OUT_FILE).write_text(
    json.dumps({"dataset": enriched}, ensure_ascii=False, indent=2),
    encoding="utf-8",
)
print(f"✅ 生成 {len(enriched)} 条 → {OUT_FILE}")
if skipped:
    print(f"⚠️ skipped {skipped} Q&A pairs with no matching path in {PATH_FILE}")


✅ 生成 2068 条 → enriched_umls_qa.json


In [1]:
from tqdm import tqdm
import json
import pymysql as mysql

class DataBase:
    """Thin wrapper around a pymysql connection for inserting rows into `new_ccfa`.

    Credentials come from environment variables instead of being hardcoded in
    the notebook, because committed credentials leak via version control and
    shared copies. The ones previously embedded here should be rotated.

        ACEMAP_DB_USER, ACEMAP_DB_PASSWORD   required
        ACEMAP_DB_HOST                       default "server.acemap.cn"
        ACEMAP_DB_PORT                       default 13306
        ACEMAP_DB_NAME                       default "mag-new-160205"
    """

    def __init__(self):
        import os  # local import: this cell may run in a freshly restarted kernel

        # Pre-set both attributes so a failed connection leaves explicit None
        # values instead of missing attributes (which surfaced later as AttributeError).
        self.DB = None
        self.cursor = None
        try:
            self.DB = mysql.connect(
                host=os.environ.get("ACEMAP_DB_HOST", "server.acemap.cn"),
                user=os.environ["ACEMAP_DB_USER"],
                passwd=os.environ["ACEMAP_DB_PASSWORD"],
                port=int(os.environ.get("ACEMAP_DB_PORT", "13306")),
                database=os.environ.get("ACEMAP_DB_NAME", "mag-new-160205"),
                charset="utf8"
            )
            print('数据库连接成功!')
            self.cursor = self.DB.cursor()
        except mysql.MySQLError as e:
            # Connection failed: report it and leave DB/cursor as None.
            print('数据库连接失败原因:' + str(e))

    def commit(self):
        """Commit the current transaction."""
        self.DB.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.DB.rollback()

    def close(self):
        """Close the connection; a no-op when not connected or already closed."""
        if self.DB is not None:
            self.DB.close()
            self.DB = None
            self.cursor = None

    def insert_paper(self, paper_id, title, keywords, authors, abstract, year, publication, new_acemap_id):
        """Insert one row into `new_ccfa` and commit it; roll back on a MySQL error.

        Raises:
            RuntimeError: if the connection was never established.
        """
        # Checked outside the try so this is not confused with a MySQL error.
        if self.cursor is None:
            raise RuntimeError("DataBase is not connected")
        try:
            sql = """
            INSERT INTO new_ccfa (paper_id, title, keywords, authors, abstract, year, publication, new_acemap_id)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """
            # Parameterized query: the driver escapes the values.
            self.cursor.execute(sql, (paper_id, title, keywords, authors, abstract, year, publication, new_acemap_id))
            self.commit()
        except mysql.MySQLError as e:
            print('插入数据失败原因:' + str(e))
            self.rollback()

# Open a database connection and bind it to `db`.
# NOTE(review): this cell never calls db.close(); the next cell rebinds `db` to
# a new connection, so this one stays open unless closed explicitly.
db = DataBase()

数据库连接成功!


In [12]:
import pandas as pd
from tqdm import tqdm
import pymysql as mysql

class DataBase:
    """Thin wrapper around a pymysql connection for inserting rows into `new_ccfa`.

    NOTE(review): an identical class is also defined in an earlier cell. This
    later definition shadows it; keep the two in sync or move the class into a
    module.

    Credentials come from environment variables instead of being hardcoded in
    the notebook, because committed credentials leak via version control and
    shared copies. The ones previously embedded here should be rotated.

        ACEMAP_DB_USER, ACEMAP_DB_PASSWORD   required
        ACEMAP_DB_HOST                       default "server.acemap.cn"
        ACEMAP_DB_PORT                       default 13306
        ACEMAP_DB_NAME                       default "mag-new-160205"
    """

    def __init__(self):
        import os  # local import: this cell may run in a freshly restarted kernel

        # Pre-set both attributes so a failed connection leaves explicit None
        # values instead of missing attributes (which surfaced later as AttributeError).
        self.DB = None
        self.cursor = None
        try:
            self.DB = mysql.connect(
                host=os.environ.get("ACEMAP_DB_HOST", "server.acemap.cn"),
                user=os.environ["ACEMAP_DB_USER"],
                passwd=os.environ["ACEMAP_DB_PASSWORD"],
                port=int(os.environ.get("ACEMAP_DB_PORT", "13306")),
                database=os.environ.get("ACEMAP_DB_NAME", "mag-new-160205"),
                charset="utf8"
            )
            print('数据库连接成功!')
            self.cursor = self.DB.cursor()
        except mysql.MySQLError as e:
            # Connection failed: report it and leave DB/cursor as None.
            print('数据库连接失败原因:' + str(e))

    def commit(self):
        """Commit the current transaction."""
        self.DB.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.DB.rollback()

    def close(self):
        """Close the connection; a no-op when not connected or already closed."""
        if self.DB is not None:
            self.DB.close()
            self.DB = None
            self.cursor = None

    def insert_paper(self, paper_id, title, keywords, authors, abstract, year, publication, new_acemap_id):
        """Insert one row into `new_ccfa` and commit it; roll back on a MySQL error.

        Raises:
            RuntimeError: if the connection was never established.
        """
        # Checked outside the try so this is not confused with a MySQL error.
        if self.cursor is None:
            raise RuntimeError("DataBase is not connected")
        try:
            sql = """
            INSERT INTO new_ccfa (paper_id, title, keywords, authors, abstract, year, publication, new_acemap_id)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """
            # Parameterized query: the driver escapes the values.
            self.cursor.execute(sql, (paper_id, title, keywords, authors, abstract, year, publication, new_acemap_id))
            self.commit()
        except mysql.MySQLError as e:
            print('插入数据失败原因:' + str(e))
            self.rollback()

# Read the CSV. utf-8-sig strips a leading BOM so the first header parses as
# "paper_id" rather than "\ufeffpaper_id". All columns are read as strings.
PAPER_COLUMNS = ["paper_id", "title", "keywords", "authors", "abstract",
                 "year", "publication", "new_acemap_id"]
df = pd.read_csv('tmp.csv', encoding='utf-8-sig', dtype=str)
df = df.fillna('')  # empty strings instead of NaN for missing cells

# Select exactly the 8 columns insert_paper expects, in its parameter order.
# Extra CSV columns are ignored. A missing column raises KeyError here, before
# any connection is opened.
paper_rows = df[PAPER_COLUMNS].itertuples(index=False, name=None)

db = DataBase()
try:
    for row in tqdm(paper_rows, total=len(df)):
        db.insert_paper(*row)
finally:
    # Release the connection even if the loop is interrupted or raises.
    db.close()


数据库连接成功!


100%|██████████| 3/3 [00:00<00:00, 56.48it/s]
