In [1]:
import os
import json
import re
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm, trange
from huggingface_hub import login, hf_hub_download, HfApi
from google.colab import userdata
from transformers import AutoTokenizer, AutoModel
import shutil

# register tqdm's pandas integration so DataFrame.progress_apply is available
# (used when building the disease text columns)
tqdm.pandas()

In [2]:
hf_token = userdata.get("HF_TOKEN")

# treat an empty secret the same as a missing one instead of passing "" to login()
# NOTE(review): userdata.get may itself raise when the secret is absent — confirm
if not hf_token:
    raise ValueError("HF_TOKEN not found.")
login(token=hf_token)

HF_DATASET_REPO = "aekn/dr-dataset"
def dl_hf(filepath: str) -> str:
    """
    Download `filepath` from the HF_DATASET_REPO dataset repository.

    Returns the local path reported by hf_hub_download.
    """
    download_kwargs = {
        "repo_id": HF_DATASET_REPO,
        "repo_type": "dataset",
        "filename": filepath,
    }
    return hf_hub_download(**download_kwargs)
api = HfApi()

DATA_DIR = "data"
PROCESSED_DIR = "processed"
MAP_DIR = os.path.join(PROCESSED_DIR, "mappings")
FEATURE_DIR = os.path.join(PROCESSED_DIR, "features")
SPLIT_DIR = os.path.join(PROCESSED_DIR, "splits")

# create every working directory (no-op when it already exists)
for _out_dir in (DATA_DIR, PROCESSED_DIR, MAP_DIR, FEATURE_DIR, SPLIT_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# fetch the raw files from the dataset repo, in order: kg, disease, drug
kg_path, disease_path, drug_path = (
    dl_hf(f"data/{_fname}")
    for _fname in ("kg.csv", "disease_features.tab", "drug_features.tab")
)

# mirror the downloaded files into DATA_DIR; everything below reads from there
for _src, _fname in (
    (kg_path, "kg.csv"),
    (disease_path, "disease_features.tab"),
    (drug_path, "drug_features.tab"),
):
    shutil.copy(_src, os.path.join(DATA_DIR, _fname))

kg_df = pd.read_csv(os.path.join(DATA_DIR, "kg.csv"), low_memory=False)
disease_features_raw = pd.read_csv(os.path.join(DATA_DIR, "disease_features.tab"), sep="\t")
drug_features_raw = pd.read_csv(os.path.join(DATA_DIR, "drug_features.tab"), sep="\t")

print("KG shape:", kg_df.shape)
print("Disease features shape:", disease_features_raw.shape)
print("Drug features shape:", drug_features_raw.shape)

data/kg.csv:   0%|          | 0.00/982M [00:00<?, ?B/s]

data/disease_features.tab:   0%|          | 0.00/114M [00:00<?, ?B/s]

drug_features.tab:   0%|          | 0.00/10.0M [00:00<?, ?B/s]

KG shape: (8100498, 12)
Disease features shape: (44133, 18)
Drug features shape: (7957, 18)


In [3]:
#@title Build Directed KG

# symmetric relations to dedupe
# dedupe_symmetric_relations collapses these to one edge per unordered
# (x_global, y_global) pair, but only on rows where x_type == y_type
SYMM_RELS = {
    "protein_protein",
    "drug_drug",
    "phenotype_phenotype",
    "disease_disease",
    "bioprocess_bioprocess",
    "molfunc_molfunc",
    "cellcomp_cellcomp",
    "exposure_exposure",
    "pathway_pathway",
    "anatomy_anatomy",
}

# drug–disease rels to orient
# build_kg_directed flips disease -> drug rows of these relations to drug -> disease
# and excludes them from dedupe_reverse_hetero
DD_RELS = {"indication", "off-label use", "contraindication"}


# plain decimal strings only: optional sign, digits, optional "." + digits
_NUMERIC_ID_RE = re.compile(r"([+-]?\d+)(?:\.(\d*))?")


def convert2str(x):
    """
    Normalize a raw KG identifier to a clean string.

    - missing values (NaN / None) -> ""
    - strings: surrounding whitespace and double quotes are stripped; a plain
      decimal whose fractional part is all zeros ("12.0", "12.", "0002")
      collapses to its integer form ("12", "12", "2"); any other string is
      returned verbatim (after stripping)
    - integers (Python or numpy) -> their exact decimal form
    - other numbers -> integer form when integral, else str(x)

    float() is deliberately NOT applied to strings: it accepts PEP 515 digit
    separators, so a grouped id like "11954_11955" would become "1195411955",
    and it silently loses precision beyond 2**53.
    """
    if pd.isna(x):
        return ""
    if isinstance(x, str):
        s = x.strip().strip('"')
        m = _NUMERIC_ID_RE.fullmatch(s)
        # collapse only when there is no fractional part or it is all zeros
        if m and not (m.group(2) or "").strip("0"):
            return str(int(m.group(1)))
        return s
    if isinstance(x, (int, np.integer)):
        # exact conversion; going through float would round large ints
        return str(int(x))
    try:
        f = float(x)
        return str(int(f)) if f.is_integer() else str(x)
    except Exception:
        return str(x)


def dedupe_symmetric_relations(
    kg_directed: pd.DataFrame,
    symmetric_relations: set,
) -> pd.DataFrame:
    """
    Keep a single edge per unordered endpoint pair for symmetric relations.

    Only rows whose relation is in `symmetric_relations` AND whose x_type
    equals y_type are collapsed. Among those, the first row (in index order)
    of each (relation, {x_global, y_global}) pair survives.

    Returns a new frame with the surviving symmetric rows first, then all other
    rows. Exact duplicate triples are dropped and the RangeIndex is fresh. When
    no row qualifies, a copy of the input is returned untouched.
    """
    frame = kg_directed.copy()

    is_symm = frame["relation"].isin(symmetric_relations) & (frame["x_type"] == frame["y_type"])
    if not is_symm.any():
        return frame

    symm_rows = frame[is_symm].sort_index()
    rest_rows = frame[~is_symm].copy()

    # orientation-free key: (relation, min endpoint, max endpoint)
    ends = np.sort(
        np.stack([symm_rows["x_global"].values, symm_rows["y_global"].values], axis=1),
        axis=1,
    )
    unordered_key = pd.DataFrame(
        {"relation": symm_rows["relation"].values, "u": ends[:, 0], "v": ends[:, 1]}
    )
    symm_unique = symm_rows[~unordered_key.duplicated(keep="first").to_numpy()]

    triple_cols = [
        "relation",
        "x_type", "x_id", "x_global",
        "y_type", "y_id", "y_global",
    ]
    combined = pd.concat([symm_unique, rest_rows], ignore_index=True)
    return combined.drop_duplicates(subset=triple_cols).reset_index(drop=True)


def dedupe_reverse_hetero(
    kg_directed: pd.DataFrame,
    exclude_relations: set | None = None,
) -> pd.DataFrame:
    """
    Collapse reverse-duplicate heterogeneous edges:
      - only edges with x_type != y_type
      - skip relations in exclude_relations
      - for each (relation, {u, v}) unordered pair, keep one edge
        this essentially removes duped (undirected) edges
    """
    if exclude_relations is None:
        exclude_relations = set()

    df = kg_directed.copy()

    mask = (df["x_type"] != df["y_type"]) & (~df["relation"].isin(exclude_relations))
    het_df = df[mask].copy()
    other_df = df[~mask].copy()

    if het_df.empty:
        return df

    pair = np.stack(
        [het_df["x_global"].values, het_df["y_global"].values],
        axis=1,
    )
    pair_sorted = np.sort(pair, axis=1)
    het_df["u"] = pair_sorted[:, 0]
    het_df["v"] = pair_sorted[:, 1]

    het_df = het_df.sort_index()
    het_df = het_df.drop_duplicates(
        subset=["relation", "u", "v"],
        keep="first",
    )
    het_df = het_df.drop(columns=["u", "v"])

    df_dedup = pd.concat([het_df, other_df], ignore_index=True)

    df_dedup = df_dedup.drop_duplicates(
        subset=[
            "relation",
            "x_type", "x_id", "x_global",
            "y_type", "y_id", "y_global",
        ]
    ).reset_index(drop=True)

    return df_dedup


def build_kg_directed(primekg_df: pd.DataFrame) -> pd.DataFrame:
    """
    Turn the raw PrimeKG edge list into the framework's canonical directed KG.

    Pipeline:
      1. keep the endpoint / relation columns, renaming x_index/y_index to
         x_global/y_global
      2. normalize ids with convert2str and cast global indices to int64
      3. flip DD_RELS rows stored as disease -> drug so they read drug -> disease
      4. drop exact duplicate (relation, x, y) triples
      5. collapse symmetric same-type relations (SYMM_RELS)
      6. collapse reverse-duplicate heterogeneous edges, excluding DD_RELS

    Returns a new DataFrame; the input frame is not modified.
    """
    triple_cols = [
        "relation",
        "x_type", "x_id", "x_global",
        "y_type", "y_id", "y_global",
    ]
    kept_cols = [
        "x_type", "x_id", "x_index", "x_name",
        "relation", "display_relation",
        "y_type", "y_id", "y_index", "y_name",
    ]

    kg = primekg_df[kept_cols].copy().rename(
        columns={"x_index": "x_global", "y_index": "y_global"}
    )

    for side in ("x", "y"):
        kg[f"{side}_id"] = kg[f"{side}_id"].map(convert2str)
    for side in ("x", "y"):
        kg[f"{side}_global"] = kg[f"{side}_global"].astype(np.int64)

    # drug–disease rows stored as disease -> drug get every endpoint column swapped
    flip = (
        kg["relation"].isin(DD_RELS)
        & kg["x_type"].eq("disease")
        & kg["y_type"].eq("drug")
    )
    if flip.any():
        swap_pairs = (
            ("x_type", "y_type"),
            ("x_id", "y_id"),
            ("x_global", "y_global"),
            ("x_name", "y_name"),
        )
        for left, right in swap_pairs:
            kg.loc[flip, [left, right]] = kg.loc[flip, [right, left]].to_numpy()

    kg = kg.drop_duplicates(subset=triple_cols).reset_index(drop=True)
    kg = dedupe_symmetric_relations(kg, SYMM_RELS)
    return dedupe_reverse_hetero(kg, exclude_relations=DD_RELS)


# build the directed KG from the raw PrimeKG frame and persist it as parquet
# (the node-table cell reads it back from this same file)
kg_directed = build_kg_directed(kg_df)
kg_directed_path = os.path.join(PROCESSED_DIR, "kg_directed.parquet")
kg_directed.to_parquet(kg_directed_path, index=False)

print("kg_directed shape:", kg_directed.shape)
# last expression -> rendered preview of the first rows
kg_directed.head()

kg_directed shape: (4050064, 10)


Unnamed: 0,x_type,x_id,x_global,x_name,relation,display_relation,y_type,y_id,y_global,y_name
0,drug,DB09130,14012,Copper,drug_protein,carrier,gene/protein,2157,7183,F8
1,drug,DB09130,14012,Copper,drug_protein,carrier,gene/protein,2153,8256,F5
2,drug,DB09140,14013,Oxygen,drug_protein,carrier,gene/protein,3040,4107,HBA2
3,drug,DB00180,14014,Flunisolide,drug_protein,carrier,gene/protein,866,1424,SERPINA6
4,drug,DB00240,14015,Alclometasone,drug_protein,carrier,gene/protein,866,1424,SERPINA6


In [4]:
#@title Checking node counts
# expect most relations' edge counts to drop by ~1/2 (reverse duplicates removed); node counts should not change
def node_counts_by_type(
    df: pd.DataFrame,
    x_idx_col: str,
    y_idx_col: str,
    x_type_col: str = "x_type",
    y_type_col: str = "y_type",
) -> pd.DataFrame:
    """
    Count distinct nodes per node type across both edge endpoints.

    A node of type t is any `x_idx_col` value on rows whose x type is t, or any
    `y_idx_col` value on rows whose y type is t.

    Returns a DataFrame with columns ``type`` and ``n_nodes``, one row per type
    in sorted order.
    """
    all_types = sorted(set(df[x_type_col].unique()) | set(df[y_type_col].unique()))

    def distinct_nodes(node_type) -> set:
        as_x = df.loc[df[x_type_col] == node_type, x_idx_col]
        as_y = df.loc[df[y_type_col] == node_type, y_idx_col]
        return set(as_x).union(as_y)

    return pd.DataFrame(
        [{"type": t, "n_nodes": len(distinct_nodes(t))} for t in all_types]
    )

# distinct nodes per type before (raw kg_df) and after (kg_directed) dedup
orig_counts = node_counts_by_type(
    kg_df, x_idx_col="x_index", y_idx_col="y_index"
).rename(columns={"n_nodes": "n_nodes_orig"})

dir_counts = node_counts_by_type(
    kg_directed, x_idx_col="x_global", y_idx_col="y_global"
).rename(columns={"n_nodes": "n_nodes_directed"})

# a type missing on one side counts as 0 nodes there
counts_compare = (
    orig_counts.merge(dir_counts, on="type", how="outer")
    .fillna(0)
    .astype({"n_nodes_orig": int, "n_nodes_directed": int})
)
counts_compare["delta"] = counts_compare["n_nodes_directed"] - counts_compare["n_nodes_orig"]

print("node counts kg vs kg_directed")
print(counts_compare.sort_values("type").to_string(index=False))

def relation_edge_counts(df, label):
    """
    Number of edges per relation.

    Returns a Series named `label`, indexed by relation (index name
    "relation"), sorted by relation.
    """
    counts = df["relation"].value_counts().sort_index()
    return counts.rename_axis("relation").rename(label)

rel_orig = relation_edge_counts(kg_df, "orig_edges")
rel_dir = relation_edge_counts(kg_directed, "directed_edges")

# align per-relation counts; a relation absent on one side counts as 0 edges
rel_compare = (
    pd.concat([rel_orig, rel_dir], axis=1)
    .fillna(0)
    .astype(int)
    .assign(
        delta=lambda t: t["directed_edges"] - t["orig_edges"],
        frac_kept=lambda t: t["directed_edges"] / t["orig_edges"],
    )
)

print("\nfraction of relations kept")
print(rel_compare.sort_values("frac_kept").to_string())

node counts kg vs kg_directed
              type  n_nodes_orig  n_nodes_directed  delta
           anatomy         14035             14035      0
biological_process         28642             28642      0
cellular_component          4176              4176      0
           disease         17080             17080      0
              drug          7957              7957      0
  effect/phenotype         15311             15311      0
          exposure           818               818      0
      gene/protein         27671             27671      0
molecular_function         11169             11169      0
           pathway          2516              2516      0

fraction of relations kept
                            orig_edges  directed_edges    delta  frac_kept
relation                                                                  
drug_protein                     51306           25468   -25838   0.496394
anatomy_anatomy                  28064           14032   -14032   0.500000
anat

In [5]:
kgd_df = kg_directed.copy()

# unordered endpoint pair per edge via one vectorized np.sort; the previous
# row-wise .apply(axis=1) is very slow on millions of edges
pair_sorted = np.sort(kgd_df[["x_global", "y_global"]].to_numpy(), axis=1)

# (min_global, max_global) tuples, same values the old pair_key column held
kgd_df["pair_key"] = list(zip(pair_sorted[:, 0], pair_sorted[:, 1]))

# every edge whose (relation, unordered pair) occurs more than once
dup_mask = (
    pd.DataFrame(
        {
            "relation": kgd_df["relation"].to_numpy(),
            "u": pair_sorted[:, 0],
            "v": pair_sorted[:, 1],
        }
    )
    .duplicated(keep=False)
    .to_numpy()
)
reversed_dupes = kgd_df[dup_mask]

print("[num reversed duplicates]", len(reversed_dupes))

[num reversed duplicates] 0


In [6]:
#@title Build Node Table

# reload the saved directed KG; path built from PROCESSED_DIR (matches where it was written)
kg_directed = pd.read_parquet(os.path.join(PROCESSED_DIR, "kg_directed.parquet"))

# stack both endpoint sides into one long (type, id, global, name) frame
nodes_x = kg_directed[["x_type", "x_id", "x_global", "x_name"]].rename(
    columns={
        "x_type": "node_type",
        "x_id": "node_id",
        "x_global": "global_id",
        "x_name": "node_name",
    }
)

nodes_y = kg_directed[["y_type", "y_id", "y_global", "y_name"]].rename(
    columns={
        "y_type": "node_type",
        "y_id": "node_id",
        "y_global": "global_id",
        "y_name": "node_name",
    }
)

nodes = pd.concat([nodes_x, nodes_y], ignore_index=True)

nodes["node_type"] = nodes["node_type"].astype(str)
nodes["global_id"] = nodes["global_id"].astype(np.int64)
nodes["node_id"] = nodes["node_id"].fillna("").astype(str)
nodes["node_name"] = nodes["node_name"].fillna("").astype(str)

print(nodes.shape)

# for each (node_type, global_id) keep the occurrence with the longest name
nodes["name_len"] = nodes["node_name"].str.len().fillna(0)
nodes = nodes.sort_values(
    ["node_type", "global_id", "name_len"],
    ascending=[True, True, False]
).reset_index(drop=True)

node_table = (
    nodes
    .drop_duplicates(subset=["node_type", "global_id"], keep="first")
    .drop(columns=["name_len"])
    .reset_index(drop=True)
)

# local_id = 0-based position of the node within its type, ordered by global_id
node_table = node_table.sort_values(["node_type", "global_id"]).reset_index(drop=True)
node_table["local_id"] = (
    node_table.groupby("node_type").cumcount().astype(np.int64)
)

node_table = node_table[
    ["node_type", "node_id", "node_name", "global_id", "local_id"]
]

print(node_table.shape)

node_table_path = os.path.join(MAP_DIR, "node_table.parquet")
node_table.to_parquet(node_table_path, index=False)

(8100128, 4)
(129375, 5)


In [7]:
#@title Check that node table counts match the KG

print("node_table shape:", node_table.shape)
print(node_table.head())

print("\nnode counts by type:")
print(node_table["node_type"].value_counts().sort_index())

# every (node_type, global_id) endpoint on each side of kg_directed must resolve
# to a node_table row; collect misses per side and fail loudly if any exist
missing_by_side = {}
for side, col_type, col_global in [
    ("x", "x_type", "x_global"),
    ("y", "y_type", "y_global"),
]:
    side_nodes = (
        kg_directed[[col_type, col_global]]
        .drop_duplicates()
        .rename(columns={col_type: "node_type", col_global: "global_id"})
        .merge(
            node_table[["node_type", "global_id", "local_id"]],
            on=["node_type", "global_id"],
            how="left",
        )
    )
    missing = side_nodes["local_id"].isna().sum()
    missing_by_side[side] = int(missing)
    print(f"{side}-side unique nodes: {len(side_nodes)}, missing in node_table: {missing}")

assert not any(missing_by_side.values()), f"KG endpoints missing from node_table: {missing_by_side}"

node_table shape: (129375, 5)
  node_type node_id            node_name  global_id  local_id
0   anatomy       2       uterine cervix      63112         0
1   anatomy       3                naris      63113         1
2   anatomy       4                 nose      63114         2
3   anatomy       5   chemosensory organ      63115         3
4   anatomy       6  islet of Langerhans      63116         4

node counts by type:
node_type
anatomy               14035
biological_process    28642
cellular_component     4176
disease               17080
drug                   7957
effect/phenotype      15311
exposure                818
gene/protein          27671
molecular_function    11169
pathway                2516
Name: count, dtype: int64
x-side unique nodes: 83605, missing in node_table: 0
y-side unique nodes: 103813, missing in node_table: 0


In [8]:
# @title utils for building disease and drug labels and descriptions

# words ignored by tokenize() (in addition to its length > 2 filter); includes
# generic clinical filler so Jaccard overlap is driven by content words
STOPWORDS = {
    "the", "a", "an", "and", "or", "of", "in", "on", "for", "to",
    "with", "without", "by", "from", "as", "at", "that", "this",
    "is", "are", "was", "were", "be", "been", "it", "its", "their",
    "may", "can", "could", "also", "often", "usually",
    "patient", "patients", "condition", "disease", "disorder"
}

# lowercase boilerplate phrases; is_generic_advice() flags any sentence that
# contains one, and select_diverse_sentences() drops such sentences
GENERIC_PATTERNS = [
    "see your doctor",
    "see a doctor",
    "talk to your doctor",
    "contact your doctor",
    "call your doctor",
]

def is_generic_advice(sent: str) -> bool:
    """
    Return True when `sent` contains any GENERIC_PATTERNS phrase
    (case-insensitive substring match). Non-strings are never generic.
    """
    if not isinstance(sent, str):
        return False
    lowered = sent.lower()
    for phrase in GENERIC_PATTERNS:
        if phrase in lowered:
            return True
    return False


def tokenize(text: str):
    """
    Lowercase alphanumeric tokens of `text`, dropping STOPWORDS and tokens of
    length <= 2. Returns [] for non-string input.
    """
    if isinstance(text, str):
        return [
            tok
            for tok in re.findall(r"[a-z0-9]+", text.lower())
            if len(tok) > 2 and tok not in STOPWORDS
        ]
    return []


def split_into_sentences(text: str):
    """
    Split `text` into sentences on whitespace that follows . ! ? or ;.

    Newlines are treated as spaces. Leading/trailing spaces, tabs and bullet
    characters (- • *) are trimmed from each piece, and empty pieces are
    dropped. Returns [] for non-string input.
    """
    if not isinstance(text, str):
        return []
    flattened = text.replace("\n", " ").strip()
    pieces = (
        piece.strip(" \t-•*")
        for piece in re.split(r'(?<=[.!?;])\s+', flattened)
    )
    return [piece for piece in pieces if piece]


def concat_unique(series: pd.Series) -> str:
    """
    Join the distinct non-empty values of `series` with " [SEP] ".

    Values are stripped and kept in first-seen order. Missing values
    (None / NaN / pd.NA) and empty strings are skipped. Non-string scalars are
    converted with str(); previously they were silently dropped, so a numeric
    column aggregated with this function came out as "".
    """
    seen = set()
    vals = []
    for v in series:
        if isinstance(v, str):
            s = v.strip()
        elif v is None or (np.ndim(v) == 0 and pd.isna(v)):
            continue
        else:
            s = str(v).strip()
        if s and s not in seen:
            seen.add(s)
            vals.append(s)
    return " [SEP] ".join(vals)


def sentences_from_columns(row: pd.Series, cols) -> list:
    """
    Collect sentences from the given columns of `row`.

    Columns absent from the row or holding non-string / blank values are
    skipped. Each value is split on "[SEP]" and every non-empty block is
    passed through split_into_sentences. Sentences are returned in column
    order, then block order.
    """
    sentences = []
    present_cols = (c for c in cols if c in row.index)
    for col in present_cols:
        value = row.get(col, "")
        if not (isinstance(value, str) and value.strip()):
            continue
        for raw_block in value.split("[SEP]"):
            block = raw_block.strip()
            if block:
                sentences.extend(split_into_sentences(block))
    return sentences


def strip_leading_name(text: str, name: str) -> str:
    """
    Drop a leading "<name> is" / "<name> are" (case-insensitive) from `text`.

    Mainly for drug descriptions, where the sentence usually restates the drug
    name. Leftover leading spaces, commas, dots and dashes are trimmed too.

    - non-string `text` or `name`: `text` is returned unchanged
    - blank `name`, no match, or a match that would leave nothing: the stripped
      `text` is returned
    """
    if not isinstance(text, str) or not isinstance(name, str):
        return text
    stripped_text = text.strip()
    stripped_name = name.strip()
    if stripped_name == "":
        return stripped_text

    lead = re.compile(rf"^{re.escape(stripped_name)}\s+(is|are)\b", flags=re.IGNORECASE)
    remainder = lead.sub("", stripped_text).lstrip(" ,.-")
    return remainder or stripped_text


def sentences_from_columns_drug(row: pd.Series, cols, drug_name: str) -> list:
    """
    Drug variant of sentences_from_columns.

    Each "[SEP]" block first goes through strip_leading_name(block, drug_name),
    which removes a leading "<drug_name> is/are", and is then split into
    sentences.
    """
    sentences = []
    for col in filter(lambda c: c in row.index, cols):
        value = row.get(col, "")
        if not isinstance(value, str) or not value.strip():
            continue
        for raw_block in value.split("[SEP]"):
            block = raw_block.strip()
            if not block:
                continue
            sentences.extend(split_into_sentences(strip_leading_name(block, drug_name)))
    return sentences


def select_diverse_sentences(
    sentences,
    max_sentences: int,
    jac_threshold: float = 0.6,
    max_candidates: int = 80,
) -> str:
    """
    Pick up to `max_sentences` diverse sentences and join them with spaces.

    Steps:
      1. keep non-empty string sentences that are not generic advice
         (is_generic_advice), truncated to the first `max_candidates`
      2. drop any sentence whose tokenize() set has Jaccard similarity
         >= `jac_threshold` with an already-kept sentence
      3. if more than `max_sentences` remain, select greedily: first the
         sentence with the most tokens, then repeatedly the one adding the most
         not-yet-covered tokens, with the gain damped for sentences over 30
         tokens

    If step 3 is not needed, sentences keep their input order; otherwise they
    appear in selection order.

    Returns "[NO_TEXT]" when nothing usable remains or `max_sentences` <= 0.
    """
    # without this guard, max_sentences <= 0 still returned the greedy seed sentence
    if max_sentences <= 0:
        return "[NO_TEXT]"

    cleaned = []
    for s in sentences:
        if isinstance(s, str):
            s = s.strip()
            if s and not is_generic_advice(s):
                cleaned.append(s)

    if not cleaned:
        return "[NO_TEXT]"

    if len(cleaned) > max_candidates:
        cleaned = cleaned[:max_candidates]

    token_sets = [set(tokenize(s)) for s in cleaned]

    # dedupe by jaccard similarity against every sentence kept so far
    dedup_sentences = []
    dedup_tokens = []
    for sent, toks in zip(cleaned, token_sets):
        if not dedup_sentences:
            dedup_sentences.append(sent)
            dedup_tokens.append(toks)
            continue
        too_similar = False
        for existing in dedup_tokens:
            inter = len(toks & existing)
            union = len(toks | existing)
            js = inter / union if union > 0 else 0.0
            if js >= jac_threshold:
                too_similar = True
                break
        if not too_similar:
            dedup_sentences.append(sent)
            dedup_tokens.append(toks)

    if not dedup_sentences:
        return "[NO_TEXT]"

    if len(dedup_sentences) <= max_sentences:
        return " ".join(dedup_sentences)

    # greedy selection: seed with the sentence that has the most content tokens
    first_idx = max(
        range(len(dedup_sentences)),
        key=lambda i: len(dedup_tokens[i])
    )
    selected_idx = [first_idx]
    selected_set = {first_idx}  # O(1) membership; selected_idx keeps the order
    covered = set(dedup_tokens[first_idx])

    while len(selected_idx) < max_sentences:
        best_i = None
        best_score = 0.0

        for i in range(len(dedup_sentences)):
            if i in selected_set:
                continue
            raw_gain = len(dedup_tokens[i] - covered)
            if raw_gain <= 0:
                continue
            length = len(dedup_tokens[i])

            # length penalty so the best gain isn't simply the longest sentence
            gain = raw_gain / (1 + 0.5 * max(0, length - 30))
            if gain > best_score:
                best_score = gain
                best_i = i

        # no remaining sentence adds new tokens
        if best_i is None:
            break

        selected_idx.append(best_i)
        selected_set.add(best_i)
        covered |= dedup_tokens[best_i]

    return " ".join(dedup_sentences[i] for i in selected_idx)


def select_definition_blocks(def_texts, max_blocks=2, jac_threshold=0.7):
    """
    Choose up to `max_blocks` mutually dissimilar definition texts.

    Non-empty strings are kept (stripped); anything else is ignored. A text is
    dropped when its tokenize() set has Jaccard similarity >= `jac_threshold`
    with an earlier kept text.

    If more than `max_blocks` survive, the text with the most tokens is taken
    first, then greedily the one adding the most uncovered tokens. Selection
    stops early once no text adds anything new.

    Returns the chosen texts as a list; empty when there is no usable input or
    `max_blocks` <= 0.
    """
    # without this guard, max_blocks <= 0 still returned the greedy seed block
    if max_blocks <= 0:
        return []

    cleaned = []
    for t in def_texts:
        if isinstance(t, str):
            s = t.strip()
            if s:
                cleaned.append(s)
    if not cleaned:
        return []

    token_sets = [set(tokenize(t)) for t in cleaned]

    # dedupe by jaccard sim against every block kept so far
    unique_idx = []
    for i, toks in enumerate(token_sets):
        if not unique_idx:
            unique_idx.append(i)
            continue
        too_similar = False
        for j in unique_idx:
            inter = len(toks & token_sets[j])
            union = len(toks | token_sets[j])
            js = inter / union if union > 0 else 0.0
            if js >= jac_threshold:
                too_similar = True
                break
        if not too_similar:
            unique_idx.append(i)

    if len(unique_idx) <= max_blocks:
        return [cleaned[i] for i in unique_idx]

    # greedy selection: seed with the block that has the most content tokens
    first_i = max(unique_idx, key=lambda i: len(token_sets[i]))
    selected = [first_i]
    covered = set(token_sets[first_i])
    remaining = [i for i in unique_idx if i != first_i]

    while len(selected) < max_blocks and remaining:
        best_i = None
        best_gain = 0
        for i in remaining:
            gain = len(token_sets[i] - covered)
            if gain > best_gain:
                best_gain = gain
                best_i = i
        # nothing left adds new tokens
        if best_i is None:
            break
        selected.append(best_i)
        covered |= token_sets[best_i]
        remaining = [i for i in remaining if i != best_i]

    return [cleaned[i] for i in selected]

In [9]:
#@title build labels and descriptions for Diseases

disease_nodes = node_table.loc[
    node_table["node_type"] == "disease",
    ["node_type", "node_id", "node_name", "global_id", "local_id"],
].copy()
print("num disease nodes:", len(disease_nodes))

# candidate feature columns; only those present in disease_features_raw are used
requested_disease_cols = [
    "mondo_id",
    "mondo_name",
    "group_name_bert",
    "mondo_definition",
    "umls_description",
    "orphanet_definition",
    "orphanet_clinical_description",
    "mayo_symptoms",
    "mayo_causes",
    "mayo_risk_factors",
    "mayo_complications",
    "orphanet_prevalence",
    "orphanet_epidemiology",
]

disease_cols = [c for c in requested_disease_cols if c in disease_features_raw.columns]

# id / name / group columns keep their first value; free-text columns are
# merged across duplicate rows with concat_unique
agg_dict = {}
for col in disease_cols:
    agg_dict[col] = (
        "first" if col in ("mondo_id", "mondo_name", "group_name_bert") else concat_unique
    )


print("[aggregating disease_features.tab by node_index]")
disease_agg = (
    disease_features_raw
    .groupby("node_index", as_index=False)
    .agg(agg_dict)
    .rename(columns={"node_index": "global_id"})
)
print("aggregated disease rows:", len(disease_agg))

# left join anchors on the KG's disease nodes; disease_agg has one row per
# global_id, so the row count matches disease_nodes
disease_full = disease_nodes.merge(disease_agg, how="left", on="global_id")
print("disease_full rows (# disease nodes):", len(disease_full))


def _with_terminal_punct(text: str) -> str:
    """Append '.' unless `text` already ends with '.', '!' or '?'."""
    return text if text.endswith((".", "!", "?")) else text + "."


# definition sources, in priority order
_DISEASE_DEF_COLS = [
    "mondo_definition",
    "umls_description",
    "orphanet_definition",
    "orphanet_clinical_description",
]

# (section label, source columns, max sentences), in output order
_DISEASE_SENTENCE_SECTIONS = [
    ("Symptoms", ["mayo_symptoms"], 3),
    ("Causes", ["mayo_causes"], 2),
    ("Risk factors", ["mayo_risk_factors"], 2),
    ("Complications", ["mayo_complications"], 2),
    ("Prevalence and epidemiology", ["orphanet_prevalence", "orphanet_epidemiology"], 2),
]


def build_disease_structured_text(row: pd.Series, jac_threshold: float = 0.6) -> str:
    """
    Longer text for MedEmbed containing the important disease information.

    Sections are ordered by importance, aiming to stay under 512 tokens:
      (1) name
      (2) group
      (3) core definitions
      (4) symptoms
      (5) causes
      (6) risk factors
      (7) complications
      (8) prevalence and epidemiology

    Empty sections are omitted. Returns "[NO_TEXT]" when no section has content.
    """
    parts = []

    # name = first non-empty of mondo_name -> node_name -> node_id
    name = row.get("mondo_name", "")
    if not (isinstance(name, str) and name.strip()):
        name = row.get("node_name", "")
    if not (isinstance(name, str) and name.strip()):
        name = row.get("node_id", "")
    if isinstance(name, str) and name.strip():
        parts.append(f"Disease name: {_with_terminal_punct(name.strip())}")

    # group label if present
    group_name = row.get("group_name_bert", "")
    if isinstance(group_name, str) and group_name.strip():
        parts.append(f"Disease group: {_with_terminal_punct(group_name.strip())}")

    # up to 2 dissimilar definition blocks, condensed to at most 4 diverse sentences
    core_defs = []
    for col in _DISEASE_DEF_COLS:
        if col not in row.index:
            continue
        raw = row.get(col, "")
        if isinstance(raw, str) and raw.strip():
            core_defs.extend(b.strip() for b in raw.split("[SEP]") if b.strip())

    selected_defs = select_definition_blocks(core_defs, max_blocks=2, jac_threshold=jac_threshold)
    if selected_defs:
        def_sentences = [s for d in selected_defs for s in split_into_sentences(d)]
        def_text = select_diverse_sentences(def_sentences, max_sentences=4, jac_threshold=jac_threshold)
        if def_text != "[NO_TEXT]":
            parts.append(f"Core definitions: {def_text}")

    # remaining sections share one shape: sentences from columns -> diverse subset
    for label, cols, max_sent in _DISEASE_SENTENCE_SECTIONS:
        section_text = select_diverse_sentences(
            sentences_from_columns(row, cols),
            max_sentences=max_sent,
            jac_threshold=jac_threshold,
        )
        if section_text != "[NO_TEXT]":
            parts.append(f"{label}: {section_text}")

    if not parts:
        return "[NO_TEXT]"
    return " ".join(parts)

def build_disease_label_text(row: pd.Series) -> str:
    """
    Short label text for a disease, used for SAPBert embedding.

    Joins the distinct non-empty values of mondo_name, node_name and
    group_name_bert (in that order, stripped) with " [SEP] ". mondo_name and
    node_name are often identical, so exact repeats are dropped. Falls back to
    node_id when none is available.

    Returns "[NO_TEXT]" if nothing usable exists.
    """
    labels = []
    for col in ["mondo_name", "node_name", "group_name_bert"]:
        val = row.get(col, "")
        if isinstance(val, str):
            val = val.strip()
            if val and val not in labels:
                labels.append(val)

    if not labels:
        nid = row.get("node_id", "")
        if isinstance(nid, str) and nid.strip():
            labels.append(nid.strip())

    if not labels:
        return "[NO_TEXT]"
    return " [SEP] ".join(labels)


print("[building structured disease_text for all disease nodes]")
disease_full["disease_text"] = disease_full.progress_apply(
    build_disease_structured_text,
    axis=1,
)

disease_full["label_text"] = disease_full.progress_apply(
    build_disease_label_text,
    axis=1,
)

# sort by local_id so that row idx = local_id
disease_text_df = disease_full.sort_values("local_id").reset_index(drop=True)

# row idx == local_id only holds if local_ids are exactly 0..n-1; fail loudly
# rather than silently misaligning rows and ids
assert np.array_equal(
    disease_text_df["local_id"].to_numpy(), np.arange(len(disease_text_df))
), "disease local_id values are not contiguous 0..n-1"

disease_cols_to_keep = [
    "node_type", "node_id", "node_name",
    "global_id", "local_id",
    "mondo_id", "mondo_name", "group_name_bert",
    "label_text", "disease_text",
]

disease_out_path = os.path.join(FEATURE_DIR, "disease_text.parquet")
disease_text_df[disease_cols_to_keep].to_parquet(disease_out_path, index=False)
print("saved disease text features to:", disease_out_path)
print("disease_text.parquet shape:", disease_text_df[disease_cols_to_keep].shape)

num disease nodes: 17080
[aggregating disease_features.tab by node_index]
aggregated disease rows: 17080
disease_full rows (# disease nodes): 17080
[building structured disease_text for all disease nodes]


  0%|          | 0/17080 [00:00<?, ?it/s]

  0%|          | 0/17080 [00:00<?, ?it/s]

saved disease text features to: processed/features/disease_text.parquet
disease_text.parquet shape: (17080, 10)


In [10]:
# spot-check one generated disease description (index is a RangeIndex after reset_index)
disease_text_df.loc[12442, "disease_text"]

"Disease name: schizotypal personality disorder. Core definitions: A disorder characterized by an enduring pattern of inability to establish close relationships coupled with cognitive or perceptual distortions, odd beliefs and speech, and eccentric behavior and appearance. A personality disorder in which there are oddities of thought , perception , speech , and behavior that are not severe enough to characterize schizophrenia. Symptoms: Schizotypal personality disorder typically includes five or more of these signs and symptoms: Being a loner and lacking close friends outside of the immediate family, Flat emotions or limited or inappropriate emotional responses, Persistent and excessive social anxiety, Incorrect interpretation of events, such as a feeling that something that is actually harmless or inoffensive has a direct personal meaning, Peculiar, eccentric or unusual thinking, beliefs or mannerisms, Suspicious or paranoid thoughts and constant doubts about the loyalty of others, Be

In [11]:
#@title build labels and descriptions for Drugs

# gather the drug nodes present in the KG
drug_nodes = node_table.loc[
    node_table["node_type"] == "drug",
    ["node_type", "node_id", "node_name", "global_id", "local_id"],
].copy()
print("num drug nodes:", len(drug_nodes))

# candidate feature columns; only those present in drug_features_raw are used
# NOTE(review): molecular_weight / tpsa / clogp go through concat_unique as well;
# confirm their values survive it if the columns are parsed as numbers
requested_drug_cols = [
    "description",
    "half_life",
    "indication",
    "mechanism_of_action",
    "protein_binding",
    "pharmacodynamics",
    "state",
    "atc_1", "atc_2", "atc_3", "atc_4",
    "category",
    "group",
    "pathway",
    "molecular_weight",
    "tpsa",
    "clogp",
]

drug_cols = [c for c in requested_drug_cols if c in drug_features_raw.columns]

# collapse duplicate drug rows (there shouldn't be any): group / category keep
# their first value, every other column is merged with concat_unique
agg_dict = {}
for col in drug_cols:
    agg_dict[col] = "first" if col in ("group", "category") else concat_unique

print("[aggregating drug_features.tab by node_index]")
drug_agg = (
    drug_features_raw
    .groupby("node_index", as_index=False)
    .agg(agg_dict)
    .rename(columns={"node_index": "global_id"})
)
print("aggregated drug rows:", len(drug_agg))

# left join anchors on the KG's drug nodes
drug_full = drug_nodes.merge(drug_agg, how="left", on="global_id")
print("drug_full rows (# drug nodes):", len(drug_full))


def _ensure_terminal_punct(text: str) -> str:
    """Return `text` with a trailing period unless it already ends in '.', '!' or '?'."""
    return text if text.endswith((".", "!", "?")) else text + "."


def build_drug_structured_text(
    row: pd.Series,
    jac_threshold: float = 0.6,
    include_indications: bool = True,
) -> str:
    """
    Longer text for MedEmbed. Contains important drug information.
    Sections are ordered by importance so that whatever gets cut off past
    MedEmbed's 512-token limit is the least informative part:

      (1) name
      (2) group / category
      (3) mechanism, pharmacodynamics, description
      (4) indications / uses            (only if include_indications)
      (5) PK & binding
      (6) classification & pathways
      (7) physicochemical properties

    Args:
        row: one row of drug_full (KG node columns + aggregated drug features).
        jac_threshold: forwarded to select_diverse_sentences (presumably a
            Jaccard cutoff for dropping near-duplicate sentences).
        include_indications: append the `indication` column text (default
            True, the original behavior).
            NOTE(review): indication text names the diseases a drug treats,
            i.e. the same drug-disease pairs used as positive "indication"
            labels. In the disease zero-shot split this can leak val/test
            labels into the drug's initial embedding; pass False for a
            leakage-free evaluation.

    Returns:
        The non-empty sections joined by single spaces, or "[NO_TEXT]".
    """
    parts = []

    # name: prefer node_name, fall back to node_id. The resolved value is also
    # passed to the helpers so a leading repeat of the name can be stripped
    name = row.get("node_name", "")
    if not (isinstance(name, str) and name.strip()):
        name = row.get("node_id", "")

    if isinstance(name, str) and name.strip():
        parts.append(f"Drug name: {_ensure_terminal_punct(name.strip())}")

    # group and category (raw values read like "<name> is approved.")
    for col, prefix in (("group", "Drug group"), ("category", "Drug category")):
        val = row.get(col, "")
        if isinstance(val, str) and val.strip():
            cleaned = strip_leading_name(val, name).strip()
            if cleaned:
                parts.append(f"{prefix}: {_ensure_terminal_punct(cleaned)}")

    # (section header, candidate columns, max sentences kept), in priority order
    sections = [
        ("Mechanism and pharmacodynamics",
         ["mechanism_of_action", "pharmacodynamics", "description"], 4),
    ]
    if include_indications:
        sections.append(("Indications and uses", ["indication"], 2))
    sections += [
        ("Pharmacokinetics", ["half_life", "protein_binding"], 1),
        ("Classification and pathways",
         ["atc_1", "atc_2", "atc_3", "atc_4", "pathway"], 1),
        ("Physicochemical properties",
         ["molecular_weight", "tpsa", "clogp", "state"], 2),
    ]

    for header, cols, max_sentences in sections:
        cols = [c for c in cols if c in row.index]
        if not cols:
            # every section skips cleanly when none of its columns exist
            continue
        sentences = sentences_from_columns_drug(row, cols, drug_name=name)
        text = select_diverse_sentences(
            sentences,
            max_sentences=max_sentences,
            jac_threshold=jac_threshold,
        )
        if text != "[NO_TEXT]":
            parts.append(f"{header}: {text}")

    return " ".join(parts) if parts else "[NO_TEXT]"


def build_drug_label_text(row: pd.Series) -> str:
    """
    Short label text for a drug: name, group and category joined by " [SEP] ".
    Used for SapBERT embedding.

    The name falls back to node_id when node_name is missing, the same way
    build_drug_structured_text resolves it. A label therefore never consists of
    group/category text alone, and strip_leading_name never receives a NaN name.

    Returns:
        The joined label, or "[NO_TEXT]" when nothing usable is present.
    """
    name = row.get("node_name", "")
    if not (isinstance(name, str) and name.strip()):
        name = row.get("node_id", "")

    labels = []
    if isinstance(name, str) and name.strip():
        labels.append(name.strip())

    # add cleaned group and category (leading drug name stripped)
    for col in ("group", "category"):
        val = row.get(col, "")
        if isinstance(val, str) and val.strip():
            cleaned = strip_leading_name(val, name).strip()
            if cleaned:
                labels.append(cleaned)

    return " [SEP] ".join(labels) if labels else "[NO_TEXT]"


print("[building structured drug_text for all drug nodes]")
drug_full["drug_text"] = drug_full.progress_apply(
    build_drug_structured_text,
    axis=1,
)

print("[building drug label_text for all drug nodes]")
drug_full["label_text"] = drug_full.progress_apply(
    build_drug_label_text,
    axis=1,
)

# sort by local_id so that row idx = local_id; verify it rather than assume it,
# because the embedding rows computed from this file are aligned by position
drug_text_df = drug_full.sort_values("local_id").reset_index(drop=True)
assert (drug_text_df["local_id"].to_numpy() == np.arange(len(drug_text_df))).all(), \
    "drug local_id must be exactly 0..n_drug-1 so that row idx == local_id"

drug_cols_to_keep = [
    "node_type", "node_id", "node_name",
    "global_id", "local_id",
    "group", "category",
    "label_text", "drug_text",
]
drug_text_out = drug_text_df[drug_cols_to_keep]

drug_out_path = os.path.join(FEATURE_DIR, "drug_text.parquet")
drug_text_out.to_parquet(drug_out_path, index=False)
print("saved drug text features to:", drug_out_path)
print("drug_text.parquet shape:", drug_text_out.shape)

num drug nodes: 7957
[aggregating drug_features.tab by node_index]
aggregated drug rows: 7957
drug_full rows (# drug nodes): 7957
[building structured drug_text for all drug nodes]


  0%|          | 0/7957 [00:00<?, ?it/s]

[building drug label_text for all drug nodes]


  0%|          | 0/7957 [00:00<?, ?it/s]

saved drug text features to: processed/features/drug_text.parquet
drug_text.parquet shape: (7957, 9)


In [12]:
# reload the saved drug text features (same FEATURE_DIR the writer cell used)
drug_text = pd.read_parquet(os.path.join(FEATURE_DIR, "drug_text.parquet"))

In [13]:
# DataFrame.info() prints its own report and returns None; wrapping it in
# print() only appended a stray "None" line to the output
drug_text.info()
display(drug_text.head())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7957 entries, 0 to 7956
Data columns (total 9 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   node_type   7957 non-null   object
 1   node_id     7957 non-null   object
 2   node_name   7957 non-null   object
 3   global_id   7957 non-null   int64 
 4   local_id    7957 non-null   int64 
 5   group       7957 non-null   object
 6   category    5431 non-null   object
 7   label_text  7957 non-null   object
 8   drug_text   7957 non-null   object
dtypes: int64(2), object(7)
memory usage: 559.6+ KB
None


Unnamed: 0,node_type,node_id,node_name,global_id,local_id,group,category,label_text,drug_text
0,drug,DB09130,Copper,14012,0,Copper is approved and investigational.,Copper is part of Copper-containing Intrauteri...,Copper [SEP] approved and investigational. [SE...,Drug name: Copper. Drug group: approved and in...
1,drug,DB09140,Oxygen,14013,1,Oxygen is approved and vet_approved.,Oxygen is part of Chalcogens ; Elements ; Gase...,Oxygen [SEP] approved and vet_approved. [SEP] ...,Drug name: Oxygen. Drug group: approved and ve...
2,drug,DB00180,Flunisolide,14014,2,Flunisolide is approved and investigational.,Flunisolide is part of Adrenal Cortex Hormones...,Flunisolide [SEP] approved and investigational...,Drug name: Flunisolide. Drug group: approved a...
3,drug,DB00240,Alclometasone,14015,3,Alclometasone is approved.,Alclometasone is part of Adrenal Cortex Hormon...,Alclometasone [SEP] approved. [SEP] part of Ad...,Drug name: Alclometasone. Drug group: approved...
4,drug,DB00253,Medrysone,14016,4,Medrysone is approved.,Medrysone is part of Adrenal Cortex Hormones ;...,Medrysone [SEP] approved. [SEP] part of Adrena...,Drug name: Medrysone. Drug group: approved. Dr...


In [14]:
MEDEMBED = "abhinand/MedEmbed-base-v0.1"
# the tokenizer reports 512 as this model's maximum sequence length; longer
# texts are truncated at embedding time
MEDEMBED_MAX_LEN = 512
med_tok = AutoTokenizer.from_pretrained(MEDEMBED)


def count_medembed_tokens(text: str) -> int:
    """Number of MedEmbed token ids for `text`, without truncation.

    None becomes ""; other non-strings are converted with str(). verbose=False
    silences the "sequence length is longer than the specified maximum"
    warning, since measuring over-length texts is the whole point here.
    """
    if not isinstance(text, str):
        text = "" if text is None else str(text)
    return len(med_tok(text, truncation=False, verbose=False)["input_ids"])


def report_medembed_lengths(df: pd.DataFrame, text_col: str,
                            max_len: int = MEDEMBED_MAX_LEN) -> None:
    """Add a `medembed_len` column to `df` (in place), then print length
    percentiles and how many texts exceed `max_len` tokens."""
    print(f"[computing MedEmbed token lengths for {text_col}]")
    df["medembed_len"] = df[text_col].apply(count_medembed_tokens)
    print(df["medembed_len"].describe(percentiles=[0.5, 0.9, 0.95, 0.99]))

    total = len(df)
    n_over = int((df["medembed_len"] > max_len).sum())
    # guard against an empty frame instead of raising ZeroDivisionError
    share = f"{n_over / total:.2%}" if total else "n/a"
    print(f"\nNum > {max_len} tokens: {n_over}/{total} ({share})")


report_medembed_lengths(disease_text_df, "disease_text")
report_medembed_lengths(drug_text_df, "drug_text")

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (581 > 512). Running this sequence through the model will result in indexing errors


[computing MedEmbed token lengths for disease_text]
count    17080.000000
mean       183.953454
std        176.800480
min          8.000000
50%        128.000000
90%        461.000000
95%        545.000000
99%        702.000000
max       1237.000000
Name: medembed_len, dtype: float64

Num > 512 tokens: 1163/17080 (6.81%)
[computing MedEmbed token lengths for drug_text]
count    7957.000000
mean      232.525449
std       189.526286
min        14.000000
50%       167.000000
90%       500.000000
95%       595.000000
99%       812.440000
max      1338.000000
Name: medembed_len, dtype: float64

Num > 512 tokens: 728/7957 (9.15%)


In [15]:
def get_device():
    """Pick the torch device string: "cuda" when a GPU is usable, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = get_device()
print("Device:", device)

def mean_pool(last_hidden_state: torch.Tensor,
              attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Average token vectors over the real (non-padding) positions.

    last_hidden_state: [B, L, D] token states.
    attention_mask:    [B, L], nonzero for real tokens, 0 for padding.
    Returns [B, D]. The per-row token count is clamped to at least 1, so a row
    with no real tokens comes out as all zeros instead of NaN.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, L, 1]
    n_tokens = weights.sum(dim=1).clamp(min=1.0)                        # [B, 1]
    token_sum = (last_hidden_state * weights).sum(dim=1)                # [B, D]
    return token_sum / n_tokens


def embed_texts_with_model(
    texts: pd.Series,
    model_name: str,
    max_length: int = 512,
    batch_size: int = 32,
    device: str = None,
    pooling: str = "mean",
) -> torch.Tensor:
    """
    Encode texts with a Hugging Face encoder and return one L2-normalized
    vector per text as a CPU tensor of shape [N, D], in the order of `texts`.

    Args:
        texts: texts to embed. A pandas Series, or any other sequence of
            strings. NaN/None entries are replaced by "[NO_TEXT]".
        model_name: Hugging Face model id, loaded with AutoTokenizer/AutoModel.
        max_length: per-text token budget; longer texts are truncated.
        batch_size: number of texts per forward pass.
        device: torch device string; defaults to get_device().
        pooling: "mean" (default, original behavior) averages non-padding
            token states via mean_pool. "cls" takes the first ([CLS]) token state.
            NOTE(review): SapBERT's model card embeds with the [CLS] token, and
            MedEmbed is presumably CLS-pooled like the BGE model it derives
            from. Confirm each model's pooling config before switching.

    Returns:
        Tensor [N, D] on CPU. An empty [0, hidden_size] tensor is returned
        when `texts` is empty; previously torch.cat([]) raised here.

    Raises:
        ValueError: if `pooling` is not "mean" or "cls".
    """
    if pooling not in ("mean", "cls"):
        raise ValueError(f"pooling must be 'mean' or 'cls', got {pooling!r}")
    if device is None:
        device = get_device()

    if not isinstance(texts, pd.Series):
        texts = pd.Series(list(texts), dtype=object)
    # just in case: the tokenizer cannot handle NaN/None
    texts = texts.fillna("[NO_TEXT]")
    n = len(texts)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    if n == 0:
        return torch.empty(0, model.config.hidden_size)

    all_embs = []
    with torch.no_grad():
        for start in trange(0, n, batch_size, desc=f"Embedding with {model_name}"):
            batch_texts = texts.iloc[start:start + batch_size].tolist()

            enc = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}

            last_hidden = model(**enc).last_hidden_state           # [B, L, D]
            if pooling == "cls":
                emb = last_hidden[:, 0, :]                           # [B, D]
            else:
                emb = mean_pool(last_hidden, enc["attention_mask"])  # [B, D]

            # should always be [B, D]
            assert emb.dim() == 2, f"Unexpected emb shape {emb.shape}"

            # L2 normalize so dot products are cosine similarities
            emb = torch.nn.functional.normalize(emb, p=2, dim=-1)  # [B, D]
            all_embs.append(emb.cpu())

    return torch.cat(all_embs, dim=0)



Device: cuda


In [16]:
#@title make disease and drug embeddings

def load_text_features(filename: str) -> pd.DataFrame:
    """Read FEATURE_DIR/<filename> and sort rows so that row position == local_id."""
    df = (
        pd.read_parquet(os.path.join(FEATURE_DIR, filename))
        .sort_values("local_id")
        .reset_index(drop=True)
    )
    assert (df["local_id"].values == np.arange(len(df))).all(), \
        f"{filename}: local_id is not a contiguous 0..n-1 range"
    return df


disease_text_df = load_text_features("disease_text.parquet")
print("Disease text df shape:", disease_text_df.shape)

drug_text_df = load_text_features("drug_text.parquet")
print("Drug text df shape:", drug_text_df.shape)

# SapBERT embeds the short label_text; MedEmbed embeds the long descriptive text
SAPBERT = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
MEDEMBED = "abhinand/MedEmbed-base-v0.1"

print("\n[embedding disease label_text with SapBERT]")
disease_sapbert_emb = embed_texts_with_model(
    disease_text_df["label_text"],
    model_name=SAPBERT,
    max_length=256,
    batch_size=128,
    device=device,
)
print("disease_sapbert_emb shape:", disease_sapbert_emb.shape)

print("\n[embedding drug label_text with SapBERT]")
drug_sapbert_emb = embed_texts_with_model(
    drug_text_df["label_text"],
    model_name=SAPBERT,
    max_length=256,
    batch_size=128,
    device=device,
)
print("drug_sapbert_emb shape:", drug_sapbert_emb.shape)

print("\n[embedding disease_text with MedEmbed]")
disease_medembed_emb = embed_texts_with_model(
    disease_text_df["disease_text"],
    model_name=MEDEMBED,
    max_length=512,
    batch_size=64,
    device=device,
)
print("disease_medembed_emb shape:", disease_medembed_emb.shape)

print("\n[embedding drug_text with MedEmbed]")
drug_medembed_emb = embed_texts_with_model(
    drug_text_df["drug_text"],
    model_name=MEDEMBED,
    max_length=512,
    batch_size=64,
    device=device,
)
print("drug_medembed_emb shape:", drug_medembed_emb.shape)

torch.save(disease_sapbert_emb, os.path.join(FEATURE_DIR, "disease_sapbert_emb.pt"))
torch.save(drug_sapbert_emb, os.path.join(FEATURE_DIR, "drug_sapbert_emb.pt"))
torch.save(disease_medembed_emb, os.path.join(FEATURE_DIR, "disease_medembed_emb.pt"))
torch.save(drug_medembed_emb, os.path.join(FEATURE_DIR, "drug_medembed_emb.pt"))

Disease text df shape: (17080, 10)
Drug text df shape: (7957, 9)

[embedding disease label_text with SapBERT]


tokenizer_config.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Embedding with cambridgeltl/SapBERT-from-PubMedBERT-fulltext:   0%|          | 0/134 [00:00<?, ?it/s]

disease_sapbert_emb shape: torch.Size([17080, 768])

[embedding drug label_text with SapBERT]


Embedding with cambridgeltl/SapBERT-from-PubMedBERT-fulltext:   0%|          | 0/63 [00:00<?, ?it/s]

drug_sapbert_emb shape: torch.Size([7957, 768])

[embedding disease_text with MedEmbed]


config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Embedding with abhinand/MedEmbed-base-v0.1:   0%|          | 0/267 [00:00<?, ?it/s]

disease_medembed_emb shape: torch.Size([17080, 768])

[embedding drug_text with MedEmbed]


Embedding with abhinand/MedEmbed-base-v0.1:   0%|          | 0/125 [00:00<?, ?it/s]

drug_medembed_emb shape: torch.Size([7957, 768])


In [17]:
def _load_feature(filename: str) -> torch.Tensor:
    """torch.load a tensor previously saved under FEATURE_DIR."""
    return torch.load(os.path.join(FEATURE_DIR, filename))


def _fuse_embeddings(label_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    """Concatenate [SapBERT | MedEmbed] vectors on the last axis, then re-L2-normalize each row."""
    fused = torch.cat([label_emb, text_emb], dim=-1)
    return torch.nn.functional.normalize(fused, p=2, dim=-1)


disease_sapbert_emb = _load_feature("disease_sapbert_emb.pt")
disease_medembed_emb = _load_feature("disease_medembed_emb.pt")
drug_sapbert_emb = _load_feature("drug_sapbert_emb.pt")
drug_medembed_emb = _load_feature("drug_medembed_emb.pt")

for label, tensor in (
    ("disease_sapbert_emb:", disease_sapbert_emb),
    ("disease_medembed_emb:", disease_medembed_emb),
    ("drug_sapbert_emb:", drug_sapbert_emb),
    ("drug_medembed_emb:", drug_medembed_emb),
):
    print(label, tensor.shape)

disease_init_text_emb = _fuse_embeddings(disease_sapbert_emb, disease_medembed_emb)
drug_init_text_emb = _fuse_embeddings(drug_sapbert_emb, drug_medembed_emb)

print("disease_init_text_emb:", disease_init_text_emb.shape)
print("drug_init_text_emb:", drug_init_text_emb.shape)

# persist both as torch tensors and as numpy arrays
torch.save(disease_init_text_emb, os.path.join(FEATURE_DIR, "disease_text_init_emb.pt"))
torch.save(drug_init_text_emb, os.path.join(FEATURE_DIR, "drug_text_init_emb.pt"))

np.save(os.path.join(FEATURE_DIR, "disease_text_init_emb.npy"), disease_init_text_emb.numpy())
np.save(os.path.join(FEATURE_DIR, "drug_text_init_emb.npy"), drug_init_text_emb.numpy())

disease_sapbert_emb: torch.Size([17080, 768])
disease_medembed_emb: torch.Size([17080, 768])
drug_sapbert_emb: torch.Size([7957, 768])
drug_medembed_emb: torch.Size([7957, 768])
disease_init_text_emb: torch.Size([17080, 1536])
drug_init_text_emb: torch.Size([7957, 1536])


In [18]:
#@title add local id to kg_directed.parquet

node_table = pd.read_parquet(os.path.join(MAP_DIR, "node_table.parquet"))
kg_directed_path = os.path.join(PROCESSED_DIR, "kg_directed.parquet")
kg_directed = pd.read_parquet(kg_directed_path)

# This cell writes its result back over its own input file. On a re-run the
# file already has x_local / y_local, and merging again would create
# x_local_x / x_local_y and then fail with a KeyError, so drop them first.
kg_directed = kg_directed.drop(columns=["x_local", "y_local"], errors="ignore")

# (node_type, global_id) -> per-type local_id
node_idx = node_table[["node_type", "global_id", "local_id"]].copy()

for side in ("x", "y"):
    kg_directed = kg_directed.merge(
        node_idx.rename(columns={
            "node_type": f"{side}_type",
            "global_id": f"{side}_global",
            "local_id": f"{side}_local",
        }),
        on=[f"{side}_type", f"{side}_global"],
        how="left",
        # a duplicated (node_type, global_id) in node_table would silently
        # multiply edges; fail loudly instead
        validate="many_to_one",
    )

assert kg_directed["x_local"].isna().sum() == 0, "Missing x_local for some edges"
assert kg_directed["y_local"].isna().sum() == 0, "Missing y_local for some edges"

kg_directed["x_local"] = kg_directed["x_local"].astype(int)
kg_directed["y_local"] = kg_directed["y_local"].astype(int)

kg_directed.to_parquet(kg_directed_path, index=False)
print("kg_directed shape:", kg_directed.shape)

kg_directed shape: (4050064, 12)


In [19]:
#@title mappings for relation types

# sorted, so rel_id assignment does not depend on edge order
relations = sorted(kg_directed["relation"].unique())
rel_table = pd.DataFrame({"relation": relations})
rel_table.insert(0, "rel_id", np.arange(len(relations), dtype=np.int32))

rel_table_path = os.path.join(MAP_DIR, "relation_table.parquet")
rel_table.to_parquet(rel_table_path, index=False)
print("[saved relation_table]", rel_table_path)

[saved relation_table] processed/mappings/relation_table.parquet


In [20]:
#@title mappings for node types

node_types = sorted(node_table["node_type"].unique())
type_counts = node_table["node_type"].value_counts()

node_type_table = pd.DataFrame({
    "type_id": np.arange(len(node_types), dtype=np.int16),
    "node_type": node_types,
    # node count per type, aligned with the sorted node_types list
    "n_nodes": [int(type_counts[nt]) for nt in node_types],
})

node_type_path = os.path.join(MAP_DIR, "node_type_table.parquet")
node_type_table.to_parquet(node_type_path, index=False)
print("[saved node_type_table]", node_type_path)

[saved node_type_table] processed/mappings/node_type_table.parquet


In [21]:
#@title labels for dd relations

DD_RELS = {"indication", "off-label use", "contraindication"}

dd_edges = kg_directed.loc[kg_directed["relation"].isin(DD_RELS)].copy()

# every drug-disease edge is expected to point drug -> disease
assert (dd_edges["x_type"] == "drug").all()
assert (dd_edges["y_type"] == "disease").all()

# +1 for therapeutic relations, -1 for contraindications, 0 otherwise
is_positive = dd_edges["relation"].isin(["indication", "off-label use"])
is_negative = dd_edges["relation"] == "contraindication"
dd_edges["label_sign"] = np.select([is_positive, is_negative], [1, -1], default=0)

dd_edges_path = os.path.join(PROCESSED_DIR, "drug_disease_edges.parquet")
dd_edges.to_parquet(dd_edges_path, index=False)
print("[saved drug_disease_edges]", dd_edges_path)
print(dd_edges.shape)

[saved drug_disease_edges] processed/drug_disease_edges.parquet
(42631, 13)


In [22]:
#@title choose diseases for train/val/test

disease_nodes = node_table.loc[
    node_table["node_type"] == "disease",
    ["global_id", "local_id", "node_id", "node_name"],
].copy()

N = len(disease_nodes)
rng = np.random.default_rng(42)  # fixed seed -> reproducible disease split

perm = rng.permutation(N)
n_train = int(0.7 * N)
n_val = int(0.1 * N)
n_test = N - n_train - n_val

# ~70 / 10 / 20 over a random permutation of row positions
train_idx, val_idx, test_idx = np.split(perm, [n_train, n_train + n_val])

split_labels = np.full(N, "train", dtype=object)
split_labels[val_idx] = "val"
split_labels[test_idx] = "test"
disease_nodes["split"] = split_labels

print(disease_nodes["split"].value_counts())

disease_split_path = os.path.join(SPLIT_DIR, "disease_zero_shot_splits.parquet")
disease_nodes.to_parquet(disease_split_path, index=False)
print("[saved disease_zero_shot_splits]", disease_split_path)

split
train    11956
test      3416
val       1708
Name: count, dtype: int64
[saved disease_zero_shot_splits] processed/splits/disease_zero_shot_splits.parquet


In [23]:
#@title dd edges for train/val/test

disease_split_df = pd.read_parquet(disease_split_path)
dis2split = dict(zip(disease_split_df["global_id"], disease_split_df["split"]))

dd_edges = pd.read_parquet(dd_edges_path)
dd_edges["disease_split"] = dd_edges["y_global"].map(dis2split)

# an edge whose disease is missing from the split table maps to NaN and would
# silently vanish from all three splits below
n_unsplit = int(dd_edges["disease_split"].isna().sum())
assert n_unsplit == 0, f"{n_unsplit} drug-disease edges have a disease with no split"

# NOTE(review): only the drug-disease label files are split by disease here;
# processed/kg_directed.parquet still holds every indication / off-label /
# contraindication edge, including those of val/test diseases. Confirm the
# training graph drops them before message passing, otherwise the zero-shot
# evaluation leaks its own labels.
dd_train = dd_edges[dd_edges["disease_split"] == "train"].copy()
dd_val = dd_edges[dd_edges["disease_split"] == "val"].copy()
dd_test = dd_edges[dd_edges["disease_split"] == "test"].copy()

print("dd_train:", dd_train.shape)
print("dd_val:", dd_val.shape)
print("dd_test:", dd_test.shape)

dd_train_path = os.path.join(SPLIT_DIR, "dd_edges_zero_shot_train.parquet")
dd_val_path = os.path.join(SPLIT_DIR, "dd_edges_zero_shot_val.parquet")
dd_test_path = os.path.join(SPLIT_DIR, "dd_edges_zero_shot_test.parquet")

dd_train.to_parquet(dd_train_path, index=False)
dd_val.to_parquet(dd_val_path, index=False)
dd_test.to_parquet(dd_test_path, index=False)
print("[saved dd_edges zero-shot splits]")

dd_train: (29803, 14)
dd_val: (4455, 14)
dd_test: (8373, 14)
[saved dd_edges zero-shot splits]


In [24]:
#@title random train/val/test splits

# Dedicated, explicitly seeded generator. Drawing from the shared `rng` of the
# disease-split cell made this split depend on every earlier draw, so
# re-running this cell alone (or cells out of order) silently changed it.
# NOTE: the permutation differs from split files produced before this change.
RANDOM_SPLIT_SEED = 42
random_split_rng = np.random.default_rng(RANDOM_SPLIT_SEED)

dd_edges_all = dd_edges.copy()
n_edges = len(dd_edges_all)
perm = random_split_rng.permutation(n_edges)
n_train = int(0.8 * n_edges)
n_val = int(0.1 * n_edges)
n_test = n_edges - n_train - n_val

train_idx = perm[:n_train]
val_idx = perm[n_train:n_train + n_val]
test_idx = perm[n_train + n_val:]

dd_edges_all["random_split"] = "train"
dd_edges_all.iloc[val_idx, dd_edges_all.columns.get_loc("random_split")] = "val"
dd_edges_all.iloc[test_idx, dd_edges_all.columns.get_loc("random_split")] = "test"

dd_edges_all.to_parquet(
    os.path.join(SPLIT_DIR, "dd_edges_random_splits.parquet"),
    index=False,
)
print("[saved random dd_edges splits]")

[saved random dd_edges splits]


In [25]:
disease_init = torch.load(os.path.join(FEATURE_DIR, "disease_text_init_emb.pt"))
drug_init = torch.load(os.path.join(FEATURE_DIR, "drug_text_init_emb.pt"))

node_type_col = node_table["node_type"]
zero_shot_frames = (("train", dd_train), ("val", dd_val), ("test", dd_test))

# dataset summary; key order is kept stable so meta.json diffs cleanly
meta = {
    "n_nodes_total": int(len(node_table)),
    "n_edges_total": int(len(kg_directed)),
    "node_types": node_types,
    "relations": relations,
    "n_disease": int(node_type_col.eq("disease").sum()),
    "n_drug": int(node_type_col.eq("drug").sum()),
    "disease_text_init_dim": int(disease_init.shape[1]),
    "drug_text_init_dim": int(drug_init.shape[1]),
    "dd_zero_shot_counts": {
        split_name: int(len(frame)) for split_name, frame in zero_shot_frames
    },
}

meta_path = os.path.join(PROCESSED_DIR, "meta.json")
with open(meta_path, "w") as fh:
    json.dump(meta, fh, indent=2)
print("[saved meta.json]")

[saved meta.json]


In [26]:
# push everything under PROCESSED_DIR to the HF dataset repo as "processed/"
api.upload_folder(
    repo_id=HF_DATASET_REPO,
    repo_type="dataset",
    folder_path=PROCESSED_DIR,
    path_in_repo="processed",
)

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  .../disease_text_init_emb.pt:   1%|          |  552kB /  105MB            

  ...s_zero_shot_train.parquet: 100%|##########|  309kB /  309kB            

  ...s/disease_medembed_emb.pt:  77%|#######6  | 40.3MB / 52.5MB            

  ...es_zero_shot_test.parquet: 100%|##########|  113kB /  113kB            

  ...res/drug_text_init_emb.pt:  30%|###       | 14.8MB / 48.9MB            

  ...rug_disease_edges.parquet: 100%|##########|  414kB /  414kB            

  ...es/drug_text_init_emb.npy:  30%|###       | 14.8MB / 48.9MB            

  ...gs/relation_table.parquet: 100%|##########| 2.16kB / 2.16kB            

  ...ures/disease_text.parquet: 100%|##########| 8.38MB / 8.38MB            

  ..._zero_shot_splits.parquet: 100%|##########|  701kB /  701kB            

CommitInfo(commit_url='https://huggingface.co/datasets/aekn/dr-dataset/commit/0a02ae3b1dbb41dd26d8cf4ace0165c26c61d537', commit_message='Upload folder using huggingface_hub', commit_description='', oid='0a02ae3b1dbb41dd26d8cf4ace0165c26c61d537', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/aekn/dr-dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='aekn/dr-dataset'), pr_revision=None, pr_num=None)