In [6]:
import time
import random
import requests
import pandas as pd

# Shared HTTP session (connection reuse) for every request the notebook makes;
# the custom User-Agent identifies this notebook to the EBI services it calls.
SESSION = requests.Session()
SESSION.headers["User-Agent"] = "pdb-chain-to-afdb-notebook/0.1"

def http_get(url: str, timeout: int = 60, retries: int = 3, backoff: float = 1.0) -> requests.Response:
    """GET `url` through the shared SESSION, retrying transient failures.

    Network errors (requests.RequestException) and HTTP 429/500/502/503/504 are retried,
    sleeping ``backoff * attempt_number`` seconds between attempts (linear backoff).
    No sleep happens after the final attempt.

    Args:
        url: absolute URL to fetch.
        timeout: per-request timeout in seconds, passed to requests.
        retries: total number of attempts; values below 1 are treated as 1, so the
            function always makes at least one request.
        backoff: base delay in seconds for the linear backoff.

    Returns:
        The first response with a non-retryable status. If every attempt returned a
        retryable status, that last response is returned; callers must check status_code.

    Raises:
        requests.RequestException: the error from the final attempt, if it raised.
    """
    retryable = (429, 500, 502, 503, 504)
    attempts = max(1, retries)  # retries <= 0 would otherwise end in `raise None`
    last = None
    for i in range(attempts):
        try:
            r = SESSION.get(url, timeout=timeout)
        except requests.RequestException as e:
            last = e
        else:
            if r.status_code not in retryable:
                return r
            last = r
        if i < attempts - 1:
            # only wait when another attempt follows
            time.sleep(backoff * (i + 1))
    if isinstance(last, requests.Response):
        return last
    raise last

def extract_bfactors_from_pdb_text(pdb_text: str, ca_only: bool = False):
    """Return the B-factor values of the coordinate records in PDB-format text.

    Reads the fixed-width tempFactor field (columns 61-66, i.e. ``line[60:66]``) of every
    ATOM/HETATM line. Lines whose field is missing or not numeric are skipped.

    Args:
        pdb_text: full text of a PDB-format file.
        ca_only: if True, keep only ATOM records whose atom name (columns 13-16) is "CA".
            For an ordinary protein model this gives one value per residue (alternate
            locations would add extra CA lines). The default False keeps every
            ATOM/HETATM record, as before.

    Returns:
        list[float] of B-factors in file order.
    """
    bfs = []
    for line in pdb_text.splitlines():
        is_atom = line.startswith("ATOM")
        if not (is_atom or line.startswith("HETATM")):
            continue
        # ATOM-only, so HETATM calcium ions (whose atom name is also "CA") are excluded
        if ca_only and not (is_atom and line[12:16].strip() == "CA"):
            continue
        try:
            bfs.append(float(line[60:66].strip()))
        except ValueError:
            # float() on a str only raises ValueError: truncated line or non-numeric field
            continue
    return bfs

def guess_plddt_scale(bfactors):
    """Return the factor that puts `bfactors` on a 0..100 pLDDT scale.

    AFDB models usually carry pLDDT (0..100) in the B-factor column, while some other
    resources store a 0..1 confidence there. If every value is <= 1.5 the input is taken
    to be 0..1 and 100.0 is returned; otherwise (including empty input) 1.0.
    """
    if not bfactors:
        return 1.0
    if max(bfactors) <= 1.5:
        return 100.0
    return 1.0

def plddt_stats_from_bfactors(bfactors):
    """Summarize B-factor values as pLDDT statistics on a 0..100 scale.

    Args:
        bfactors: sequence of floats, e.g. from extract_bfactors_from_pdb_text.

    Returns:
        None for empty input. Otherwise a dict with:
            scale_applied  factor chosen by guess_plddt_scale (100.0 or 1.0)
            mean, median, min, max  of the scaled values
            frac_ge_90 / frac_ge_70  fraction of scaled values >= 90 / >= 70
            n_atoms  number of values

    Note: one value per input element, i.e. per atom record when fed a whole PDB file,
    so residues with more atoms weigh more than in a per-residue average.
    """
    import statistics  # local import keeps this notebook cell self-contained

    if not bfactors:
        return None
    scale = guess_plddt_scale(bfactors)
    vals = sorted(v * scale for v in bfactors)
    n = len(vals)

    return {
        "scale_applied": scale,
        "mean": sum(vals) / n,
        # statistics.median averages the two middle values for even n; a rounded-index
        # lookup (round() uses banker's rounding) would pick one of them inconsistently
        "median": statistics.median(vals),
        "min": vals[0],
        "max": vals[-1],
        "frac_ge_90": sum(v >= 90.0 for v in vals) / n,
        "frac_ge_70": sum(v >= 70.0 for v in vals) / n,
        "n_atoms": n,
    }


In [8]:
import re
import pandas as pd

CSV_PATH = "../evaluation/chains_filtered.csv"
df = pd.read_csv(CSV_PATH)

# Column detection: prefer the canonical names, else fall back to position
# (column 0 = PDB id, column 1 = chain id).
pdb_col = "pdb_id" if "pdb_id" in df.columns else df.columns[0]
if "chain_id" in df.columns:
    chain_col = "chain_id"
elif len(df.columns) > 1 and df.columns[1] != pdb_col:
    chain_col = df.columns[1]
else:
    # column 1 is missing or is the PDB-id column itself (when "pdb_id" sits there):
    # never reuse the PDB column as the chain column, take any other column instead
    chain_col = next((c for c in df.columns if c != pdb_col), None)
if chain_col is None:
    raise ValueError("No chain_id column found (and no 2nd column to use).")

def norm_pdb_id(x):
    """Return the lowercase 4-character PDB id embedded in `x`.

    Takes the first run of one digit followed by three lowercase letters/digits in the
    stripped, lowercased text, which handles "8P0E", "8P0E:A" and "8p0e_A". If there is
    no such run, the first four characters of that text are returned instead.
    """
    text = str(x).strip().lower()
    found = re.search(r'([0-9][a-z0-9]{3})', text)
    if found is None:
        return text[:4]
    return found.group(1)

def norm_chain_id(pdb_id_4, x):
    """Extract the chain identifier from `x`, dropping any PDB-id prefix.

    Handles "8P0E:A", "8p0e_A" and plain "A". It keeps the text after the last ':',
    then the text after the last '_', removes every case-insensitive occurrence of
    `pdb_id_4`, and strips surrounding whitespace. The chain's own case is preserved,
    and multi-character chain ids pass through unchanged.
    """
    tail = str(x).strip().rpartition(":")[2].rpartition("_")[2]
    tail = re.sub(re.escape(pdb_id_4), "", tail, flags=re.IGNORECASE)
    return tail.strip()

# Overwrite the key columns with their normalized forms (pdb id first: chain cleanup needs it).
df["pdb_id"] = df[pdb_col].map(norm_pdb_id)
df["chain_id"] = [
    norm_chain_id(pid, raw_chain)
    for pid, raw_chain in zip(df["pdb_id"], df[chain_col])
]

# Unique (pdb_id, chain_id) tuples in first-seen order.
pairs = list(
    df[["pdb_id", "chain_id"]]
    .drop_duplicates()
    .itertuples(index=False, name=None)
)

print("Unique PDB-chain pairs:", len(pairs))
print("Examples:", pairs[:10])


Unique PDB-chain pairs: 1337
Examples: [('8p0e', 'A'), ('8px8', 'A'), ('8b2e', 'A'), ('8hoe', 'A'), ('8tce', 'A'), ('8ssx', 'A'), ('8pjs', 'A'), ('8j3n', 'A'), ('8shy', 'C'), ('8oe5', 'A')]


In [9]:
# PDBe mapping: (PDB, chain) to UniProt(s)

# PDBe REST endpoint with the UniProt mappings of one PDB entry; "{}" is filled with the
# lowercased PDB id by pdbe_get_uniprot_map_for_pdb().
PDBe_UNIPROT_MAP = "https://www.ebi.ac.uk/pdbe/api/mappings/uniprot/{}"

_pdbe_cache = {}  # pdb_id -> UniProt mapping block ({} when no mapping was obtained)

def pdbe_get_uniprot_map_for_pdb(pdb_id: str):
    """Return the PDBe "UniProt" mapping block for one PDB entry, memoized in _pdbe_cache.

    Args:
        pdb_id: PDB identifier; stripped and lowercased before lookup.

    Returns:
        dict mapping UniProt accession -> mapping info (as found under
        ``response[pdb_id]["UniProt"]``), or {} when nothing usable was obtained.

    Caching policy:
        Successful lookups and deterministic failures (other 4xx, undecodable or
        unexpected JSON) are cached. Rate limiting (429) and server errors (5xx) that
        remain after http_get's retries are NOT cached, so a later call retries instead
        of reporting "no mapping" for the rest of the session.

    Raises:
        requests.RequestException: propagated from http_get when the network fails.
    """
    pdb_id = pdb_id.strip().lower()
    if pdb_id in _pdbe_cache:
        return _pdbe_cache[pdb_id]

    r = http_get(PDBe_UNIPROT_MAP.format(pdb_id), timeout=60, retries=3, backoff=1.0)
    if r.status_code != 200:
        transient = r.status_code == 429 or r.status_code >= 500
        if not transient:
            _pdbe_cache[pdb_id] = {}
        return {}

    try:
        j = r.json()
    except Exception:
        _pdbe_cache[pdb_id] = {}
        return {}

    # guard every level: a non-dict payload would otherwise raise AttributeError
    block = j.get(pdb_id) if isinstance(j, dict) else None
    uni = block.get("UniProt") if isinstance(block, dict) else None
    if not isinstance(uni, dict):
        uni = {}
    _pdbe_cache[pdb_id] = uni
    return uni

def pdbe_uniprots_for_chain(pdb_id: str, chain_id: str):
    """
    Returns a list of UniProt accessions mapped to this pdb_id + chain_id.

    An accession counts as a hit when at least one of its PDBe mapping segments names
    the requested chain. The chain id is read from the first truthy field among
    chain_id / chainId / pdb_chain_id / pdbChainId, to tolerate key variants. The result
    is sorted and free of duplicates.
    """
    target = str(chain_id).strip()
    uni_map = pdbe_get_uniprot_map_for_pdb(pdb_id)

    def _mapping_chain(mapping):
        # Same result as `a or b or c or d`: the first truthy value, otherwise the
        # value of the last key.
        value = None
        for key in ("chain_id", "chainId", "pdb_chain_id", "pdbChainId"):
            value = mapping.get(key)
            if value:
                break
        return value

    def _on_target_chain(mapping):
        value = _mapping_chain(mapping)
        return value is not None and str(value).strip() == target

    accessions = {
        accession
        for accession, info in (uni_map or {}).items()
        if any(_on_target_chain(m) for m in (info.get("mappings", []) or []))
    }
    return sorted(accessions)


In [10]:
# Quick check of the PDBe mapping on the first few pairs (at most 5 accessions shown each).
for pdb_id, chain_id in pairs[:5]:
    mapped = pdbe_uniprots_for_chain(pdb_id, chain_id)
    print(pdb_id, chain_id, "->", mapped[:5])

8p0e A -> ['G3M8F4']
8px8 A -> ['P25440']
8b2e A -> ['A0AA82WPC8']
8hoe A -> ['A0A085GHR3']
8tce A -> ['P08519']


In [11]:
# AFDB lookup + download model PDB + compute pLDDT stats

# AlphaFold DB prediction API; "{}" is filled with a UniProt accession by afdb_get_record().
AFDB_API = "https://alphafold.ebi.ac.uk/api/prediction/{}"

def afdb_get_record(uniprot):
    """Fetch the AFDB prediction record for a UniProt accession.

    Returns the decoded JSON object. When the API answers with a list, the first entry
    is returned. Returns None on a non-200 status, undecodable JSON, an empty list, or
    any other payload type.
    """
    resp = http_get(AFDB_API.format(uniprot), timeout=60, retries=3, backoff=1.0)
    if resp.status_code != 200:
        return None
    try:
        payload = resp.json()
    except Exception:
        return None
    if isinstance(payload, dict):
        return payload
    if isinstance(payload, list) and payload:
        return payload[0]
    return None

def afdb_pick_pdb_url(rec):
    """Return the first PDB-format model URL found in an AFDB record, or None.

    Walks the record (nested dicts and lists) in order and collects string values stored
    under dict keys that start with "http". The first one ending in ".pdb"
    (case-insensitive) is returned.

    When no ".pdb" URL exists this returns None rather than some other link
    (.cif/.bcif/.png/.json). The caller parses the download with a fixed-column
    PDB B-factor reader, which would silently misread those formats.
    """
    def scan(obj):
        # yields http strings found as dict values, depth-first in insertion order
        if isinstance(obj, dict):
            for v in obj.values():
                if isinstance(v, str) and v.startswith("http"):
                    yield v
                else:
                    yield from scan(v)
        elif isinstance(obj, list):
            for v in obj:
                yield from scan(v)

    return next((u for u in scan(rec) if u.lower().endswith(".pdb")), None)

def afdb_plddt_stats(uniprot, sleep_s=0.05):
    """Download the AFDB model for `uniprot` and summarize its B-factor (pLDDT) column.

    Args:
        uniprot: UniProt accession to look up.
        sleep_s: courtesy pause in seconds after the call, to throttle AFDB traffic.

    Returns:
        The dict from plddt_stats_from_bfactors plus "model_url", or None when there is
        no record, no PDB-format URL, a non-200 download, or no parseable B-factors.

    Note: the statistics cover the whole downloaded model for the accession, not only
    the residues observed in a particular PDB chain.
    """
    try:
        rec = afdb_get_record(uniprot)
        if not rec:
            return None
        pdb_url = afdb_pick_pdb_url(rec)
        if not pdb_url:
            return None

        r = http_get(pdb_url, timeout=120, retries=3, backoff=1.0)
        if r.status_code != 200:
            return None

        st = plddt_stats_from_bfactors(extract_bfactors_from_pdb_text(r.text))
        if not st:
            return None
        st["model_url"] = pdb_url
        return st
    finally:
        # Throttle on every path: misses must not hit AFDB back-to-back.
        if sleep_s > 0:
            time.sleep(sleep_s)


In [12]:
# Full pipeline: (PDB, chain) -> UniProt via PDBe -> AFDB model -> pLDDT stats -> CSV

from tqdm.notebook import tqdm

SLEEP_S = 0.05
MAX_PAIRS = None   # set e.g. 200 for a quick test; None = all

# Stats copied from afdb_plddt_stats() into each hit row, in output-column order.
STAT_KEYS = (
    "mean", "median", "min", "max", "frac_ge_90", "frac_ge_70",
    "n_atoms", "scale_applied", "model_url",
)
# Fixed CSV schema: the stats columns exist (as NaN) even when no row is an AFDB hit,
# and "af_found" exists even when `rows` is empty (e.g. MAX_PAIRS = 0).
OUT_COLUMNS = ["pdb_id", "chain_id", "uniprot", "af_found", *STAT_KEYS]

af_cache = {}   # uniprot -> stats dict or None; accessions shared by several chains are fetched once
rows = []

run_pairs = pairs if MAX_PAIRS is None else pairs[:MAX_PAIRS]

for pdb_id, chain_id in tqdm(run_pairs):
    uniprots = pdbe_uniprots_for_chain(pdb_id, chain_id)

    if not uniprots:
        rows.append({"pdb_id": pdb_id, "chain_id": chain_id, "uniprot": None, "af_found": False})
        continue

    # One row per mapped accession: a chain with several UniProt mappings yields several rows.
    for u in uniprots:
        if u not in af_cache:
            af_cache[u] = afdb_plddt_stats(u, sleep_s=SLEEP_S)

        st = af_cache[u]
        row = {"pdb_id": pdb_id, "chain_id": chain_id, "uniprot": u, "af_found": bool(st)}
        if st:
            row.update({key: st[key] for key in STAT_KEYS})
        rows.append(row)

df_out = pd.DataFrame(rows, columns=OUT_COLUMNS)
print("Output rows:", len(df_out))
print("AF hits:", int(df_out["af_found"].sum()))

OUT_PATH = "../evaluation/pdbchain_to_af_plddt_stats.csv"
df_out.to_csv(OUT_PATH, index=False)
print("Wrote:", OUT_PATH)

df_out.head()


  0%|          | 0/1337 [00:00<?, ?it/s]

Output rows: 1359
AF hits: 1033
Wrote: ../evaluation/pdbchain_to_af_plddt_stats.csv


Unnamed: 0,pdb_id,chain_id,uniprot,af_found,mean,median,min,max,frac_ge_90,frac_ge_70,n_atoms,scale_applied,model_url
0,8p0e,A,G3M8F4,False,,,,,,,,,
1,8px8,A,P25440,True,65.700076,54.47,29.78,98.62,0.339528,0.458306,6188.0,1.0,https://alphafold.ebi.ac.uk/files/AF-P25440-F1...
2,8b2e,A,A0AA82WPC8,True,97.513289,98.5,81.56,98.94,0.959472,1.0,1061.0,1.0,https://alphafold.ebi.ac.uk/files/AF-A0AA82WPC...
3,8hoe,A,A0A085GHR3,True,89.278117,96.12,31.56,98.69,0.830069,0.873026,1583.0,1.0,https://alphafold.ebi.ac.uk/files/AF-A0A085GHR...
4,8tce,A,P08519,True,61.116525,70.5,21.16,97.31,0.11711,0.506615,15874.0,1.0,https://alphafold.ebi.ac.uk/files/AF-P08519-F1...


In [17]:
# One cell: load AFDB stats + evaluation table, normalize keys, merge, compare, save (the Spearman, plot and preview sections are currently commented out).

import pandas as pd
import matplotlib.pyplot as plt

# Input 1: per-(chain, UniProt) AFDB stats, the CSV written by the pipeline cell.
AFDB_PATH = "../evaluation/pdbchain_to_af_plddt_stats.csv"
# Input 2: evaluation table; it supplies the AF_average_pLDDT column compared below.
EVAL_PATH = "../evaluation/chains_evaluation_filtered_with_gdt.csv"
# Output goes to the working directory. NOTE: rebinds OUT_PATH, which the pipeline cell also defines.
OUT_PATH  = "./afdb_vs_evaluation_plddt_comparison.csv"

def norm_pdb_id(x):
    """Normalize a PDB identifier to its lowercase 4-character form.

    Uses the same rule as the loader cell that built the pairs behind the AFDB table:
    the first "digit + 3 lowercase letters/digits" run in the stripped, lowercased text.
    So "8P0E", "8P0E:A" and "pdb_8P0E" all give "8p0e". Without such a run it falls back
    to the first 4 characters. Keeping both sides of the merge on one rule avoids silent
    key mismatches.
    """
    import re  # local import: this cell does not import `re` itself

    s = str(x).strip().lower()
    m = re.search(r"([0-9][a-z0-9]{3})", s)
    return m.group(1) if m else s[:4]

def norm_chain_id(x):
    """Return the chain part of `x`: the text after the last ':' and then after the last '_'.

    Handles "8P0E:A", "8P0E_A" and plain "A"; surrounding whitespace is stripped and the
    chain's case is preserved.
    """
    text = str(x).strip()
    for sep in (":", "_"):
        text = text.rpartition(sep)[2]
    return text.strip()

# Load
df_afdb = pd.read_csv(AFDB_PATH)
df_eval = pd.read_csv(EVAL_PATH)

# Normalize keys
df_afdb["pdb_id_norm"] = df_afdb["pdb_id"].apply(norm_pdb_id)
df_afdb["chain_id_norm"] = df_afdb["chain_id"].apply(norm_chain_id)

df_eval["pdb_id_norm"] = df_eval["pdb_id"].apply(norm_pdb_id)
df_eval["chain_id_norm"] = df_eval["chain_id"].apply(norm_chain_id)

# Keep only AFDB hits
df_afdb_ok = df_afdb[df_afdb["af_found"]].copy()

# Merge. The AFDB table can hold several rows per chain (one per mapped UniProt
# accession), but each chain must appear at most once in the evaluation table.
# validate="many_to_one" makes pandas raise MergeError instead of silently
# duplicating rows (and skewing the diff statistics) if that assumption breaks.
df_merged = df_afdb_ok.merge(
    df_eval,
    on=["pdb_id_norm", "chain_id_norm"],
    how="inner",
    suffixes=("_afdb", "_eval"),
    validate="many_to_one",
)

# Compare
# NOTE(review): "mean" is the per-atom mean over the whole AFDB model for the accession,
# which may span a different residue range than AF_average_pLDDT; confirm before
# reading large diffs as model disagreement.
df_merged["plddt_diff"] = df_merged["mean"] - df_merged["AF_average_pLDDT"]
df_merged["plddt_abs_diff"] = df_merged["plddt_diff"].abs()

# Report
print("AFDB rows:", len(df_afdb), "| AF hits:", int(df_afdb["af_found"].sum()))
print("Eval rows:", len(df_eval))
print("Merged rows:", len(df_merged))

#if len(df_merged) > 1:
#    spearman = df_merged["mean"].corr(df_merged["AF_average_pLDDT"], method="spearman")
#    print("Spearman corr(mean_AFDB, AF_average_pLDDT):", spearman)

print("Abs diff mean:", df_merged["plddt_abs_diff"].mean())
print("Abs diff median:", df_merged["plddt_abs_diff"].median())
#print("Diff (AFDB mean - eval) mean:", df_merged["plddt_diff"].mean())
#print("Diff (AFDB mean - eval) median:", df_merged["plddt_diff"].median())

# Plot
#plt.figure()
#plt.scatter(df_merged["AF_average_pLDDT"], df_merged["mean"], s=10, alpha=0.6)
#plt.xlabel("AF_average_pLDDT (evaluation)")
#plt.ylabel("Mean pLDDT from AFDB PDB (B-factor)")
#plt.title("AF pLDDT comparison")
#plt.show()

# Save
df_merged.to_csv(OUT_PATH, index=False)
print("Saved:", OUT_PATH)

# Preview key columns
#display(df_merged[[
#    "pdb_id_norm","chain_id_norm",
#    "uniprot","mean","AF_average_pLDDT",
#    "plddt_diff","plddt_abs_diff",
#    "model_url"
#]].head(20))


AFDB rows: 1359 | AF hits: 1033
Eval rows: 1337
Merged rows: 1033
Abs diff mean: 7.916355751925823
Abs diff median: 4.128486934188302
Saved: ./afdb_vs_evaluation_plddt_comparison.csv
