In [38]:
import csv, sys, time, requests
import obonet

import pandas as pd
import networkx as nx

from pathlib import Path

# Base URL of the UniChem REST API; "/sources" and "/compounds" are appended to it below.
UNICHEM_BASE = "https://www.ebi.ac.uk/unichem/api/v1"
# ChEMBL ATC-classification endpoint (JSON); queried below with a molecule_chembl_id parameter.
CHEMBL_ATC = "https://www.ebi.ac.uk/chembl/api/data/atc_class.json"
# Wikidata SPARQL endpoint used as the fallback source of ATC codes.
WD_SPARQL = "https://query.wikidata.org/sparql"

# Request headers: plain JSON for UniChem/ChEMBL, SPARQL-results JSON plus a User-Agent for Wikidata.
HEADERS_JSON = {"Accept": "application/json"}
WD_HEADERS = {"Accept": "application/sparql-results+json", "User-Agent": "open-atc-mapper/1.0"}



# Load Hetionet nodes

In [87]:
hetionet_nodes = pd.read_csv("../data/hetionet-v1.0-nodes.tsv", sep="\t")

# Node ids are "<kind>::<identifier>"; keep the identifier part of every Compound node.
# Vectorized mask + .str accessor instead of a Python-level iterrows() loop.
drugbank_ids = (
    hetionet_nodes.loc[hetionet_nodes["kind"] == "Compound", "id"]
    .str.split("::")
    .str[1]
    .tolist()
)
drugbank_ids[:3]

['DB00014', 'DB00035', 'DB00050']

# Link DrugBank ID with ATC codes

In [19]:
def get_unichem_sources():
    """Look up the UniChem source ids for DrugBank and ChEMBL.

    Fetches ``{UNICHEM_BASE}/sources`` and indexes the returned sources by their
    lower-cased ``nameLong``.

    Returns:
        tuple: ``(db_src, chembl_src)`` — the ``sourceID`` values of the DrugBank
        and ChEMBL sources.

    Raises:
        requests.HTTPError: if the sources request returns an error status.
        LookupError: if no source name matches DrugBank or ChEMBL.
    """
    r = requests.get(f"{UNICHEM_BASE}/sources", headers=HEADERS_JSON, timeout=30)
    r.raise_for_status()
    srcs = r.json()["sources"]
    s_by_name = {s["nameLong"].lower(): s["sourceID"] for s in srcs}

    def pick(label, accept):
        # An exact name wins: a substring test alone would also accept any
        # other source whose name merely contains the label (e.g. "surechembl").
        if label in s_by_name:
            return s_by_name[label]
        # Names can vary slightly, so fall back to the first substring match.
        for name, source_id in s_by_name.items():
            if accept(name):
                return source_id
        raise LookupError(
            f"No UniChem source matching {label!r}; available: {sorted(s_by_name)}"
        )

    db_src = pick("drugbank", lambda k: "drugbank" in k)
    chembl_src = pick("chembl", lambda k: "chembl" in k and "target" not in k)
    return db_src, chembl_src

def drugbank_to_chembl_via_unichem(dbid, db_src, chembl_src):
    """Map one DrugBank id to a ChEMBL compound id through UniChem.

    Args:
        dbid: compound identifier sent as the ``compound`` field of the query.
        db_src: UniChem source id sent as the ``sourceID`` field of the query.
        chembl_src: UniChem source id whose entry is searched for in the reply.

    Returns:
        The ``compoundId`` of the first source entry whose ``id`` equals
        ``chembl_src``, or ``None`` when the request returns 404, no compounds
        come back, or no entry matches.

    Raises:
        requests.HTTPError: for any non-404 error status.
    """
    response = requests.post(
        f"{UNICHEM_BASE}/compounds",
        json={"type": "sourceID", "sourceID": db_src, "compound": dbid},
        headers=HEADERS_JSON,
        timeout=60,
    )
    # 404 is treated as "compound unknown", not as a failure.
    if response.status_code == 404:
        return None
    response.raise_for_status()

    compounds = response.json().get("compounds", [])
    if not compounds:
        return None

    # Every compound carries a list of per-source entries; take the first one
    # that belongs to the requested ChEMBL source.
    matching_ids = (
        entry.get("compoundId")
        for compound in compounds
        for entry in compound.get("sources", [])
        if entry.get("id") == chembl_src
    )
    return next(matching_ids, None)

def chembl_atc_for_molecule(chembl_id):
    """Return the sorted, de-duplicated ATC codes ChEMBL reports for a molecule.

    Queries ``CHEMBL_ATC`` with ``molecule_chembl_id=chembl_id`` and reads the
    ``atc_class`` list of the JSON reply. For each record the first non-empty
    value among ``level5``, ``atc_code`` and ``who_code`` is used.

    Raises:
        requests.HTTPError: if the request returns an error status.
    """
    # NOTE(review): only the records in this single response are read; confirm
    # whether the endpoint paginates and honours this filter parameter.
    response = requests.get(
        CHEMBL_ATC,
        params={"molecule_chembl_id": chembl_id},
        headers=HEADERS_JSON,
        timeout=60,
    )
    response.raise_for_status()
    records = response.json().get("atc_class", [])

    candidates = (
        rec.get("level5") or rec.get("atc_code") or rec.get("who_code")
        for rec in records
    )
    unique_codes = {code for code in candidates if code}
    return sorted(unique_codes)

def wikidata_atc_for_drugbanks(batch_dbids):
    """Fetch ATC codes for a batch of DrugBank ids from Wikidata.

    Runs one SPARQL query against ``WD_SPARQL`` that joins the ids (property
    ``wdt:P715``) to ATC codes (property ``wdt:P267``) through a ``VALUES``
    clause.

    Args:
        batch_dbids: iterable of DrugBank id strings.

    Returns:
        dict: id -> sorted list of ATC codes, containing only ids for which at
        least one code was returned. An empty batch yields ``{}`` without any
        network request.

    Raises:
        requests.HTTPError: if the SPARQL endpoint returns an error status.
    """
    ids = list(batch_dbids)
    if not ids:
        # Nothing to look up; skip the round trip entirely.
        return {}

    def as_literal(value):
        # Escape backslashes and double quotes so an unusual id cannot break
        # out of the quoted SPARQL string literal.
        text = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{text}"'

    # batch SPARQL with VALUES
    vals = " ".join(as_literal(x) for x in ids)
    q = f"""
    SELECT ?drugbank ?atc WHERE {{
      VALUES ?drugbank {{ {vals} }}
      ?drug wdt:P715 ?drugbank .
      ?drug wdt:P267 ?atc .
    }}
    """
    r = requests.get(WD_SPARQL, params={"query": q, "format":"json"}, headers=WD_HEADERS, timeout=60)
    r.raise_for_status()
    res = r.json()["results"]["bindings"]

    # One binding per (drugbank id, ATC code) pair; group the codes per id.
    out = {}
    for b in res:
        db = b["drugbank"]["value"]
        atc = b["atc"]["value"]
        out.setdefault(db, set()).add(atc)
    return {k: sorted(v) for k, v in out.items()}


In [None]:
out_csv = "../data/drugbank_atc.csv"
sleep = 0.3

db_src, chembl_src = get_unichem_sources()

results = {}  # dbid -> set(atc)
lookup_errors = []  # (stage, key, exception) for every failure we skipped over

# 1) UniChem → ChEMBL → ATC
for dbid in drugbank_ids:
    try:
        chembl = drugbank_to_chembl_via_unichem(dbid, db_src, chembl_src)
        if chembl:
            atcs = chembl_atc_for_molecule(chembl)
            if atcs:
                results[dbid] = set(atcs)
    except Exception as e:
        # Best effort: remember the failure and let the Wikidata pass retry this id.
        lookup_errors.append(("unichem/chembl", dbid, e))
    finally:
        # be gentle — also after a failed request, so errors do not turn into a burst
        time.sleep(sleep)

# 2) Wikidata fallback for those still missing
missing = [d for d in drugbank_ids if d not in results]
B = 200  # WDQS likes modest batches
for i in range(0, len(missing), B):
    batch = missing[i:i+B]
    try:
        wd_map = wikidata_atc_for_drugbanks(batch)
        for dbid, atcs in wd_map.items():
            results.setdefault(dbid, set()).update(atcs)
    except Exception as e:
        # Best effort: the ids of this batch simply stay without ATC codes.
        lookup_errors.append(("wikidata", f"batch starting at {i}", e))
    finally:
        time.sleep(1.0)

# Surface what was skipped instead of hiding it; the run itself still completes.
if lookup_errors:
    print(f"{len(lookup_errors)} lookup(s) failed; showing up to 5:", file=sys.stderr)
    for stage, key, err in lookup_errors[:5]:
        print(f"  [{stage}] {key}: {err!r}", file=sys.stderr)

# write rows: one (dbid, atc) per line
with open(out_csv, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["drugbank_id","atc_code"])
    for dbid in drugbank_ids:
        codes = results.get(dbid)
        if codes:
            for code in sorted(codes):
                w.writerow([dbid, code])
        else:
            w.writerow([dbid, ""])  # no ATC found

# Create ATC drug hierarchy

In [78]:
# Read the mapping back as strings; rows written with an empty atc_code are
# parsed as missing values and removed by dropna().
drug_atc_df = (
    pd.read_csv(out_csv, dtype=str)
    .dropna()
)
drug_atc_df.describe()

Unnamed: 0,drugbank_id,atc_code
count,1742,1742
unique,1324,1741
top,DB01234,J01XB01
freq,11,2


In [29]:
ATC_ROOT = "drug"

def atc_parent(code: str) -> str:
    """Return the parent of an ATC code, decided purely by the code's length.

    Surrounding whitespace is ignored. Codes of length 0 or 1 map to
    ``ATC_ROOT``; lengths 3, 4 and 5 are cut to 1, 3 and 4 characters; length 7
    or more is cut to 5 characters; any other length (2 or 6) drops its last
    character.
    """
    trimmed = code.strip()
    size = len(trimmed)
    if size <= 1:
        return ATC_ROOT

    # Prefix length of the parent for the lengths that have a fixed rule.
    parent_len = {3: 1, 4: 3, 5: 4}.get(size)
    if parent_len is None:
        parent_len = 5 if size >= 7 else size - 1
    return trimmed[:parent_len]

def build_atc_tree(atc_codes: set[str]):
    """Build the set of (child, parent) edges linking ATC codes up to ``ATC_ROOT``.

    Starting from ``atc_codes``, every code is linked to ``atc_parent(code)``,
    and each newly discovered parent is processed the same way until the root
    is reached.

    Args:
        atc_codes: ATC codes to start from. The argument is not modified.

    Returns:
        set[tuple[str, str]]: ``(child, parent)`` pairs.
    """
    # Work on a private copy so parents discovered along the way are not
    # added to the caller's set.
    known = set(atc_codes)
    pending = set(known)
    tree = set()
    while pending:
        c = pending.pop()
        p = atc_parent(c)
        if p and p != c:
            tree.add((c, p))
            if p not in known and p != ATC_ROOT:
                known.add(p)
                pending.add(p)
    # connect single-letter tops to the root
    for c in known:
        if len(c) == 1:
            tree.add((c, ATC_ROOT))
    return tree

In [89]:
# ATC-internal edges (code -> parent code).
atc_tree = build_atc_tree(set(drug_atc_df["atc_code"]))
atc_hierarchy = pd.DataFrame(atc_tree, columns=["child", "parent"])

# Compound -> ATC code edges, with compound ids prefixed as "Compound::<id>".
hetionet_atc = (
    drug_atc_df
    .rename(columns={"drugbank_id": "child", "atc_code": "parent"})
    .assign(child=lambda frame: "Compound::" + frame["child"])
)

# Stack both edge lists and persist them as a tab-separated file.
drug_hierarchy = pd.concat([atc_hierarchy, hetionet_atc])
drug_path = Path("../data/atc_drug_hierarchy.tsv")
drug_hierarchy.to_csv(drug_path, sep="\t", index=False)

# NOTE(review): this prints every row except the last three — confirm whether
# a head/tail preview was intended instead.
print(drug_hierarchy[:-3])


                  child   parent
0               N01BB01    N01BB
1               C01CA11    C01CA
2               C02CA04    C02CA
3                 S01AB     S01A
4                 R03BB     R03B
...                 ...      ...
1961  Compound::DB09009  N01BB08
1962  Compound::DB09014  N05BB02
1964  Compound::DB09017  N05CD09
1965  Compound::DB09018  A03FA04
1966  Compound::DB09019  R05CB02

[4226 rows x 2 columns]


# Create disease hierarchy

Download DOID OBO from https://obofoundry.org/ontology/doid.html

In [70]:
def build_doid_edges(
    doid_obo_path: str,
    keep_only: set[str] = None,
    root_id: str = "DOID:4",
    root_label: str = "disease"
    ):
    """Extract the ``is_a`` hierarchy of an OBO ontology as (child, parent) edges.

    Args:
        doid_obo_path: path of the OBO file to read with ``obonet``.
        keep_only: optional set of term ids. When given (and non-empty) the
            result is restricted to these terms plus all of their superterms.
        root_id: id of the ontology root term.
        root_label: name of the extra node the root term is attached to.

    Returns:
        set[tuple[str, str]]: ``(child, parent)`` pairs, plus the single edge
        ``(root_id, root_label)``.
    """
    g = obonet.read_obo(doid_obo_path)
    G = nx.DiGraph()
    # obonet returns a MultiDiGraph whose edges point child -> parent and whose
    # edge *key* is the relationship type; the edge data dict does not hold it,
    # so the filter must look at the key.
    for child, parent, relation in g.edges(keys=True):
        if relation == "is_a":
            G.add_edge(child, parent)  # child -> parent

    # Filter and keep only the specified diseases
    if keep_only:
        # Keep the requested terms and their superterms. Because edges point
        # child -> parent, superterms are what networkx calls *descendants*
        # (nodes reachable from n); nx.ancestors would return subterms.
        keep = set()
        for n in keep_only:
            if n in G:
                keep.add(n)
                keep |= nx.descendants(G, n)
        G = G.subgraph(keep).copy()

    # Create edges
    edges = set(G.edges())

    # Add root node
    edges.add((root_id, root_label))

    return edges

In [None]:
disease_hierarchy_path = "../data/doid.obo"

# Node ids are "<kind>::<identifier>"; keep the identifier part of every Disease node.
hetionet_doid = {
    node_id.split("::")[1]
    for node_id in hetionet_nodes.query('kind == "Disease"')["id"]
}

doid_edges = build_doid_edges(disease_hierarchy_path, keep_only=hetionet_doid)
doid_edges_df = pd.DataFrame(doid_edges, columns=["child", "parent"])
# Tab-separated to match the .tsv extension and the ATC hierarchy export.
doid_edges_df.to_csv("../data/hetionet_doid_hierarchy.tsv", sep="\t", index=False)