In [None]:
from tdc.resource import PrimeKG
import networkx as nx
import pandas as pd

# Load PrimeKG through the TDC resource wrapper, using ./data as the local data directory.
data = PrimeKG(path='./data')
# Convert the resource into a networkx graph; this binds the module-level name `G`.
G = data.to_nx()

# Fetch the per-type feature tables for disease and drug nodes.
disease_feature = data.get_features('disease') 
drug_feature    = data.get_features('drug')

# Show which feature columns each table exposes.
print(disease_feature.columns)
print(drug_feature.columns)

Found local copy...
Loading...
Found local copy...
Loading...
Found local copy...
Loading...


Index(['node_index', 'mondo_id', 'mondo_name', 'group_id_bert',
       'group_name_bert', 'mondo_definition', 'umls_description',
       'orphanet_definition', 'orphanet_prevalence', 'orphanet_epidemiology',
       'orphanet_clinical_description', 'orphanet_management_and_treatment',
       'mayo_symptoms', 'mayo_causes', 'mayo_risk_factors',
       'mayo_complications', 'mayo_prevention', 'mayo_see_doc'],
      dtype='object')
Index(['node_index', 'description', 'half_life', 'indication',
       'mechanism_of_action', 'protein_binding', 'pharmacodynamics', 'state',
       'atc_1', 'atc_2', 'atc_3', 'atc_4', 'category', 'group', 'pathway',
       'molecular_weight', 'tpsa', 'clogp'],
      dtype='object')


In [None]:
# One record per graph node: the node key under "node_id", followed by whatever
# attributes the node carries (type, name, ...), then a single DataFrame build.
node_rows = [
    {"node_id": node_key, **node_attrs}
    for node_key, node_attrs in G.nodes(data=True)
]

nodes_df = pd.DataFrame(node_rows)
print(nodes_df.columns)
nodes_df.head()

Index(['node_id'], dtype='object')


Unnamed: 0,node_id
0,PHYHIP
1,KIF15
2,GPANK1
3,PNMA1
4,ZRSR2


In [10]:
%pip install pandas networkx openai fuzzywuzzy python-levenshtein

Defaulting to user installation because normal site-packages is not writeable
Collecting openai
  Downloading openai-2.9.0-py3-none-any.whl.metadata (29 kB)
Collecting python-levenshtein
  Downloading python_levenshtein-0.27.3-py3-none-any.whl.metadata (3.9 kB)
Collecting jiter<1,>=0.10.0 (from openai)
  Downloading jiter-0.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.2 kB)
Collecting Levenshtein==0.27.3 (from python-levenshtein)
  Downloading levenshtein-0.27.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (3.7 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein==0.27.3->python-levenshtein)
  Downloading rapidfuzz-3.14.3-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading openai-2.9.0-py3-none-any.whl (1.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m12.0 MB/s[0m  [33m0:00:00[0m
[?25hDownloading jiter-0.12.0-cp310-cp310-manylinux_2_17_x86_64.many

In [None]:
#  Layer 1: Basic imports & configuration
import pandas as pd
import networkx as nx
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer


GPT2_PATH     = "/LLM/gpt2"  # Local GPT-2 directory handed to from_pretrained by load_gpt2
PRIMEKG_PATH  = "kg.csv"   # PrimeKG edge-list CSV read by load_kg
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer the GPU when torch can see one

print("Config OK. DEVICE =", DEVICE)


Config OK. DEVICE = cuda


In [None]:
#  Layer 2: model (GPT-2)

def load_gpt2(model_path: str = GPT2_PATH):
    """
    Load the GPT-2 tokenizer and language model stored under `model_path`.

    The model is moved to the module-level DEVICE.

    Returns:
        (tokenizer, model) tuple.
    """
    print(f" Loading GPT-2 from {model_path}...", flush=True)

    gpt2_tok = GPT2Tokenizer.from_pretrained(model_path)
    # Fall back to the EOS token when the tokenizer has no pad token configured.
    if gpt2_tok.pad_token is None:
        gpt2_tok.pad_token = gpt2_tok.eos_token

    gpt2_lm = GPT2LMHeadModel.from_pretrained(model_path)
    gpt2_lm = gpt2_lm.to(DEVICE)

    print(f"  GPT-2 loaded successfully, running on device: {DEVICE}")
    return gpt2_tok, gpt2_lm

# Binds the module-level `tokenizer` / `model` names that gpt2_rewrite_answer reads.
tokenizer, model = load_gpt2()


def gpt2_rewrite_answer(summary: str) -> str:
    """
    Ask GPT-2 to polish the English of `summary` while "trying not to change facts".

    Risk: GPT-2 may still alter content, so treat the result strictly as an
    auxiliary, cosmetic rewrite of text whose facts were fixed upstream.

    Args:
        summary: Already-factual text to be rephrased.

    Returns:
        Only the text GPT-2 generated after the prompt, stripped of surrounding
        whitespace. The prompt itself is never part of the return value, even
        when `summary` happens to contain the "Rewritten answer:" marker.
    """
    prompt = (
        "You are a rewriting assistant.\n"
        "You will be given an answer that is already factually correct.\n"
        "Rewrite it in fluent English, but DO NOT change or add any medical facts, "
        "names, or relationships.\n\n"
        f"Original answer:\n{summary}\n\n"
        "Rewritten answer:\n"
    )

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    # Keep the TAIL of an over-long prompt so the trailing "Rewritten answer:" cue survives.
    # presumably 900 leaves room for the 64 new tokens inside GPT-2's context window — TODO confirm
    if input_ids.shape[1] > 900:
        input_ids = input_ids[:, -900:]
    prompt_len = input_ids.shape[1]

    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=64,
            do_sample=False,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=attention_mask
        )

    # generate() returns prompt + continuation for decoder-only models; slicing by the
    # prompt length is more reliable than searching the decoded text for a marker
    # string that the summary itself may contain.
    generated_ids = outputs[0][prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

In [None]:
#  Layer 3: Graph Layer (PrimeKG from CSV)

def load_kg(path: str = PRIMEKG_PATH) -> nx.Graph:
    """
    Load a PrimeKG edge list from CSV into an undirected networkx graph.

    Assumes the CSV has at least: x_name, y_name, relation, display_relation,
    x_type, y_type. Nodes are keyed by x_name / y_name.

    Every CSV column, including x_name and y_name, is stored on each edge.
    `edge_attr=True` would skip the source/target columns, leaving the retrieval
    code unable to tell which endpoint is "x" and which is "y" and therefore
    unable to pair a neighbour with the correct `*_type`.

    Args:
        path: CSV file to read.

    Returns:
        The constructed graph.
    """
    print(f" Loading knowledge graph from {path}...", flush=True)
    df = pd.read_csv(path, low_memory=False)
    print("Column names:", list(df.columns))

    G = nx.from_pandas_edgelist(
        df,
        source="x_name",
        target="y_name",
        # Explicit list (not True) so x_name / y_name are kept as edge attributes too.
        edge_attr=list(df.columns)
    )
    print(f"  Graph loaded: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges.")
    return G

# Rebinds the module-level `G` (replacing the graph built from data.to_nx() in the first cell).
G = load_kg()

 正在从 kg.csv 加载知识图谱...
列名: ['relation', 'display_relation', 'x_index', 'x_id', 'x_type', 'x_name', 'x_source', 'y_index', 'y_id', 'y_type', 'y_name', 'y_source']


In [None]:
#  Layer 4: Retrieval Layer (Entity Matching + Structured Extraction)

# Build a "lower-cased name -> original node name" index for case-insensitive matching.
# Node names that differ only by case share one key, so the last one iterated wins.
node_index = {str(n).lower(): n for n in G.nodes()}

def resolve_node(name: str):
    """
    Look up the graph node whose name matches `name`, ignoring case and
    surrounding whitespace.

    Returns the original node name from `node_index`, or None when `name` is
    empty/None or has no match.
    """
    if name:
        lookup_key = name.strip().lower()
        return node_index.get(lookup_key)
    return None


def summarize_entity_from_edges(entity: str) -> str | None:
    """
    Extract one-hop relationships for a specific entity from the KG and organize them into a summary by relationship type.
    Uses pure Python logic to ensure GPT-2 does not determine the facts.
    """
    node = resolve_node(entity)
    if node is None:
        return None

    edges = list(G.edges(node, data=True))
    if not edges:
        return None

    carriers = set()
    enzymes = set()
    targets = set()
    contraindications = set()
    others = []

    for u, v, attr in edges:
        x_name = attr.get("x_name", u)
        y_name = attr.get("y_name", v)
        x_type = attr.get("x_type", "")
        y_type = attr.get("y_type", "")
        rel     = str(attr.get("relation", "")).lower()
        disp_rel = attr.get("display_relation", attr.get("relation", "related_to"))

        # Unify neighbor / type
        if node == x_name:
            neighbor, n_type = y_name, y_type
        else:
            neighbor, n_type = x_name, x_type

        # Classify by relation
        if "carrier" in rel:
            carriers.add(neighbor)
        elif "enzyme" in rel:
            enzymes.add(neighbor)
        elif "target" in rel:
            targets.add(neighbor)
        elif "contraindication" in rel:
            contraindications.add(neighbor)
        else:
            others.append(f"{entity} is {disp_rel} {neighbor} ({n_type}).")

    parts = []

    if carriers:
        parts.append(
            f"According to the knowledge graph, {entity} is carried by: "
            + ", ".join(sorted(carriers)) + "."
        )
    if enzymes:
        parts.append(
            f"{entity} is metabolized by the enzymes: "
            + ", ".join(sorted(enzymes)) + "."
        )
    if targets:
        parts.append(
            f"{entity} acts on targets such as: "
            + ", ".join(sorted(targets)) + "."
        )
    if contraindications:
        parts.append(
            f"{entity} has contraindications including: "
            + ", ".join(sorted(contraindications)) + "."
        )

    # If nothing fits into the categories, fall back to the original sentences
    if not parts and others:
        parts.append(" ".join(others[:5]))

    if not parts:
        return None

    return " ".join(parts)

In [None]:
#  Layer 5: QA API Layer

def answer_with_kg_gpt2(entity: str, question: str | None = None):
    """
    Unified external interface (GPT-2 version):
    1) Retrieve 1-hop neighbors from the graph using the entity to construct Facts
    2) Answer the question using GPT-2 + Facts
    """
    context = get_knowledge_context(entity, question=question)
    if not context:
        print(f"  Entity not found in graph or has no neighbors: {entity}")
        return

    print("[Retrieved Graph Facts]:")
    print(context)
    print()

    if not question:
        question = f"What is known about {entity} from these facts?"

    print(" GPT-2 is generating answer...\n")
    answer = generate_answer_gpt2(context, question)
    print("[Answer]:")
    print(answer)

In [None]:
#  Layer 6: Usage Examples (Ready to use in Notebook)

# Example: Ask about Warfarin's metabolism and targets.
# Results are printed by answer_with_kg_gpt2; nothing is returned.
answer_with_kg_gpt2(
    "Warfarin",
    "Based on these facts, what enzymes metabolize Warfarin and what targets does it act on?"
)

# You can also switch entities and questions freely:
# answer_with_kg_gpt2("UBC", "Summarize what types of entities are related to UBC in this graph.")

GPT2

In [None]:
import pandas as pd
import networkx as nx
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# 1. Configuration
MODEL_PATH = '/LLM/gpt2'  # local GPT-2 directory passed to from_pretrained in load_local_llm
PRIMEKG_PATH = 'kg.csv'  # edge-list CSV passed to load_kg
MAX_KNOWLEDGE_EDGES = 10  # cap on the number of edges get_knowledge_context turns into facts

# 2. Load Model (Notebook specific)
def load_local_llm():
    """
    Load the GPT-2 tokenizer and model from MODEL_PATH and place the model on
    the GPU when one is available, otherwise on the CPU.

    Returns:
        (tokenizer, model, device) where device is the string "cuda" or "cpu".
    """
    gpt2_tok = GPT2Tokenizer.from_pretrained(MODEL_PATH)
    # The EOS token is used as the pad token (set unconditionally).
    gpt2_tok.pad_token = gpt2_tok.eos_token

    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    gpt2_lm = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
    gpt2_lm = gpt2_lm.to(run_device)

    print(f"Model loaded successfully! Running device: {run_device}")
    return gpt2_tok, gpt2_lm, run_device


# 3. Load Knowledge Graph (Notebook specific)
def load_kg(path):
    """
    Read the PrimeKG edge-list CSV at `path` into an undirected networkx graph
    whose nodes are the x_name / y_name values.

    Each edge keeps relation, display_relation, x_type and y_type, plus x_name
    and y_name. The two name columns are stored as well so that retrieval code
    can tell which endpoint of an undirected edge is "x" and which is "y";
    without them a neighbour can be paired with the wrong `*_type`.

    Returns:
        The constructed graph.
    """
    df = pd.read_csv(path, low_memory=False)
    G = nx.from_pandas_edgelist(
        df, 'x_name', 'y_name',
        edge_attr=['x_name', 'y_name', 'relation', 'display_relation', 'x_type', 'y_type']
    )
    print(f"PrimeKG loaded! Total {G.number_of_nodes()} nodes, {G.number_of_edges()} edges.")
    return G

# Actual Loading
# NOTE: rebinds the module-level `tokenizer`, `model` and `G` that earlier cells also assign.
tokenizer, model, device = load_local_llm()
G = load_kg(PRIMEKG_PATH)

模型加载成功！运行设备: cuda
PrimeKG 加载完毕！共 129262 个节点，4049405 条边。


In [None]:
# Case-insensitive entity lookup: lower-cased node name -> original node name.
# Names that differ only by case share one key, so the last one iterated wins.
node_index = {str(n).lower(): n for n in G.nodes()}

def resolve_node(name):
    """
    Map a user-supplied entity name to the matching graph node name, ignoring
    case and surrounding whitespace.

    Returns None for empty/None input (instead of raising) and for names that
    are not in `node_index`, matching the other resolve_node definitions in
    this notebook.
    """
    if not name:
        return None
    # str() mirrors how node_index keys are built (str(n).lower()).
    return node_index.get(str(name).strip().lower())

def get_knowledge_context(entity_name):
    """
    Turn up to MAX_KNOWLEDGE_EDGES edges touching the entity into one
    space-joined string of "<node> is <relation> <neighbor> (<type>)." facts.

    Returns None when the entity cannot be resolved to a graph node; returns
    an empty string when the node has no edges.
    """
    node = resolve_node(entity_name)
    if node is None:
        return None

    def _fact(src, dst, props):
        # The neighbour is the endpoint that is not the resolved node.
        if node == props.get("x_name", src):
            other = props.get("y_name", dst)
            other_type = props.get("y_type", "")
        else:
            other = props.get("x_name", src)
            other_type = props.get("x_type", "")
        label = props.get("display_relation", props.get("relation", "related_to"))
        return f"{node} is {label} {other} ({other_type})."

    selected = list(G.edges(node, data=True))[:MAX_KNOWLEDGE_EDGES]
    return " ".join(_fact(src, dst, props) for src, dst, props in selected)

def generate_answer(context, question):
    """
    Answer `question` with GPT-2, prompted to rely only on the graph facts in
    `context`.

    Args:
        context: Facts text produced by get_knowledge_context.
        question: The question to answer.

    Returns:
        The text GPT-2 generated after the prompt, stripped. Only the newly
        generated tokens are decoded, so the result never contains prompt text
        and is not cut short when the model itself emits another "Answer:".
    """
    text = (
        "You are a medical assistant. Use ONLY the following facts to answer.\n"
        "If the question cannot be answered from the facts, reply: 'I don't know based on the given facts.'\n"
        f"Facts: {context}\n\n"
        f"Question: {question}\n\n"
        "Answer in one short paragraph. Do not invent new drugs or diseases.\n"
        "Answer:"
    )

    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
    # Integer mask with the same dtype/device as the ids (every position is a real token).
    attention_mask = torch.ones_like(input_ids)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=60,
            do_sample=False,   # Greedy/deterministic; temperature is omitted because it only applies when sampling
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=attention_mask
        )

    # The returned sequence starts with the prompt; keep only the continuation.
    new_tokens = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def ask_once(entity, question=None):
    """
    Run one retrieve-then-generate round for `entity` and print the outcome.

    Prints a not-found message and returns early when no facts are retrieved.
    When `question` is empty, "What is <entity>?" is asked instead.
    """
    facts = get_knowledge_context(entity)
    if not facts:
        print(f"  Entity not found in graph: {entity}")
        return

    print(f"Retrieved Knowledge:\n{facts}\n")

    effective_question = question if question else f"What is {entity}?"

    print(" Generating answer...\n")
    print("Answer:", generate_answer(facts, effective_question))

In [None]:
# Example query; ask_once prints the retrieved facts and the generated answer.
ask_once("Warfarin", "What is Warfarin used for?")

LLAMA2

In [None]:
#  Layer 1: Basic Imports & Configuration
import pandas as pd
import networkx as nx
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

LLAMA2_PATH   = "/LLM/llama2"  # Local LLaMA2 model directory (default for load_llama2)
PRIMEKG_PATH  = "kg.csv"   # Graph edge-list CSV (default for load_kg)
MAX_K_EDGES   = 12         # Default number of edges get_knowledge_context turns into facts (too many will exceed context limit)
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer the GPU when torch can see one

print("Config OK. DEVICE =", DEVICE)

Config OK. DEVICE = cuda


In [None]:
#  Layer 2: Model Layer (LLaMA2)
def load_llama2(model_path: str = LLAMA2_PATH):
    """
    Load the LLaMA2 tokenizer and causal-LM weights stored under `model_path`.

    On CUDA the model is loaded in float16 with device_map="auto"; otherwise
    it is loaded with the library defaults.

    Returns:
        (tokenizer, model) tuple.
    """
    print(f" Loading LLaMA2 from {model_path}...", flush=True)

    llama_tok = AutoTokenizer.from_pretrained(model_path)
    # Fall back to the EOS token when no pad token is configured.
    if llama_tok.pad_token is None:
        llama_tok.pad_token = llama_tok.eos_token

    # Only the keyword arguments differ between the GPU and CPU paths.
    if DEVICE == "cuda":
        load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    else:
        load_kwargs = {}
    llama_lm = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    print(f"  LLaMA2 loaded successfully, running device: {DEVICE}")
    return llama_tok, llama_lm

# NOTE(review): rebinds the module-level `tokenizer` / `model` to LLaMA2. The GPT-2 helpers
# earlier in the notebook read these same global names, so calling them after this cell
# would use the LLaMA2 objects.
tokenizer, model = load_llama2()


def generate_answer_llama2(context: str, question: str) -> str:
    """
    Generate an answer with LLaMA2 based on graph facts, using the LLaMA2
    chat-style [INST] ... [/INST] instruction format.

    Args:
        context: Facts text built from the knowledge graph.
        question: The question to answer.

    Returns:
        The text generated after the prompt, stripped. Only the newly generated
        tokens are decoded, so the result is correct even when `context` or
        `question` themselves contain the "[/INST]" marker.
    """
    prompt = f"""[INST]<<SYS>>
You are a helpful and careful medical assistant.
Use ONLY the following facts from a biomedical knowledge graph to answer the question.
If the facts are not enough, reply exactly: "I don't know based on the given facts."
Do NOT invent new drugs, diseases, or genes.
<</SYS>>

Facts:
{context}

Question: {question}
Answer in English in 2-3 sentences.
[/INST]"""

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=160,
            do_sample=False,                # Deterministic decoding; temperature is omitted because it only applies when sampling
            pad_token_id=tokenizer.eos_token_id
        )

    # The returned sequence starts with the prompt tokens; keep only the continuation
    # rather than searching the decoded text for "[/INST]".
    generated_ids = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

 正在从 /LLM/llama2 加载 LLaMA2...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device device because they were offloaded to the cpu.


  LLaMA2 加载成功，运行设备: cuda


In [None]:
#  Layer 3: Knowledge Graph Layer (PrimeKG from CSV)
def load_kg(path: str = PRIMEKG_PATH) -> nx.Graph:
    """
    Load a PrimeKG edge list from CSV into an undirected networkx graph
    (assuming columns: x_name, y_name, relation, display_relation, x_type, y_type).

    Every CSV column, including x_name and y_name, is stored on each edge.
    `edge_attr=True` would skip the source/target columns, so the retrieval
    layer could not tell which endpoint is "x" and which is "y" and could pair
    a neighbour with the wrong `*_type`.

    Args:
        path: CSV file to read.

    Returns:
        The constructed graph.
    """
    print(f" Loading knowledge graph from {path}...", flush=True)
    df = pd.read_csv(path, low_memory=False)
    print("Column names:", list(df.columns))

    # Conservative approach: keep ALL columns as edge attributes. An explicit list is
    # required because edge_attr=True leaves out the source/target name columns.
    G = nx.from_pandas_edgelist(
        df,
        source="x_name",
        target="y_name",
        edge_attr=list(df.columns)
    )
    print(f"  Graph loaded: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges.")
    return G

# Rebinds the module-level `G` that the retrieval layer below reads.
G = load_kg()

In [None]:
#  Layer 4: Retrieval Layer (Entity Matching + Context Construction)

# Build a simple "lower-cased name -> original node name" index for case-insensitive matching.
# Names that differ only by case share one key, so the last one iterated wins.
node_index = {str(n).lower(): n for n in G.nodes()}

def resolve_node(name: str):
    """
    Return the graph node name matching `name` (case-insensitive, surrounding
    whitespace ignored), or None when `name` is empty/None or unknown.
    """
    return node_index.get(name.strip().lower()) if name else None


def get_knowledge_context(entity_name: str, max_edges: int = MAX_K_EDGES) -> str | None:
    """
    Build an English "Facts" string from the edges directly connected to an entity.

    At most `max_edges` edges are used; each becomes
    "<node> is <relation> <neighbor> (<neighbor type>)." and the sentences are
    joined with single spaces. A missing neighbor type is shown as "Entity".

    Returns:
        The facts text, or None when the entity is not in the graph or has no edges.
    """
    node = resolve_node(entity_name)
    if node is None:
        return None

    incident = list(G.edges(node, data=True))
    if not incident:
        return None

    facts = []
    for src, dst, props in incident[:max_edges]:
        x_name = props.get("x_name", src)
        y_name = props.get("y_name", dst)
        x_type = props.get("x_type", "")
        y_type = props.get("y_type", "")
        label = props.get("display_relation", props.get("relation", "related_to"))

        # Pick the endpoint that is not the resolved node, along with its type.
        if node == x_name:
            other, other_type = y_name, y_type
        elif node == y_name:
            other, other_type = x_name, x_type
        else:
            # Stored names match neither side: fall back to the edge tuple itself.
            other = dst if node == src else src
            other_type = x_type or y_type

        facts.append(f"{node} is {label} {other} ({other_type or 'Entity'}).")

    return " ".join(facts)

In [None]:
#  Layer 5: QA API Layer

def answer_with_kg(entity: str, question: str | None = None):
    """
    Unified external interface:
    1) Retrieve 1-hop neighbors from the graph using the entity to construct Facts
    2) Answer the question using LLaMA2 + Facts
    """
    context = get_knowledge_context(entity)
    if not context:
        print(f"  Entity not found in graph or has no neighbors: {entity}")
        return

    print("[Retrieved Graph Facts]:")
    print(context)
    print()

    if not question:
        question = f"What is the biomedical role of {entity} according to these facts?"

    print(" LLaMA2 is generating answer...\n")
    answer = generate_answer_llama2(context, question)
    print("[Answer]:")
    print(answer)

In [None]:
#  Layer 6: Usage Examples (You can run this block repeatedly in the Notebook)
# Example 1: Ask about Warfarin's metabolism / targets.
# answer_with_kg prints the retrieved facts and the generated answer; it returns nothing.
answer_with_kg("Warfarin", "Based on these facts, what enzymes metabolize Warfarin and what targets does it act on?")

# Example 2: Switch to another entity you know exists in kg.csv
# answer_with_kg("UBC", "Summarize what types of entities are related to UBC in this graph.")