# Scop3P

A comprehensive database of human phosphosites within their full context. Scop3P integrates sequences (UniProtKB/Swiss-Prot), structures (PDB), and uniformly reprocessed phosphoproteomics data (PRIDE) to annotate all known human phosphosites. 

Scop3P, available at https://iomics.ugent.be/scop3p, presents a unique resource for visualization and analysis of phosphosites and for understanding of phosphosite structure–function relationships.

Please cite: https://doi.org/10.1021/acs.jproteome.0c00306

# Scop3P-Biophysical prediction & mutation effects 

This notebook presents the analysis in three tabs:

1. **WT prediction** (interactive Bokeh plot + Scop3P PTM circles)  
2. **Mutant prediction** (WT overlay in light shades + mutant in dark shades + PTM circles)  
3. **Inference** (simple summary around mutation positions)

> **Notes**
> - Predictions are computed with b2bTools (DynaMine, DisoMine, EFoldMine).
> - The app keeps the WT and mutant predictions in memory; the Inference tab summarizes the effect of the mutations.


In [2]:

import tempfile
import requests
import pandas as pd

import ipywidgets as W
from IPython.display import display, HTML

# b2bTools
from b2bTools import SingleSeq
try:
    from b2bTools import constants
except Exception:
    constants = None

# Bokeh embed (robust in Voila)
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.embed import components
from bokeh.resources import CDN


# -----------------------------
# Fetchers
# -----------------------------
def fetch_uniprot_sequence(accession: str) -> str:
    """Download the UniProt FASTA record for ``accession`` and return its residues.

    Header lines (those starting with ``>``) and empty lines are skipped; every
    remaining line is stripped and the pieces are concatenated into one string.
    An error HTTP status is propagated through ``raise_for_status()``.
    """
    response = requests.get(
        f"https://rest.uniprot.org/uniprotkb/{accession}.fasta",
        timeout=30,
    )
    response.raise_for_status()

    residues = []
    for line in response.text.splitlines():
        if not line or line.startswith(">"):
            continue
        residues.append(line.strip())
    return "".join(residues)

def fetch_scop3p_modifications(accession: str) -> pd.DataFrame:
    """Fetch the Scop3P modifications of ``accession`` as a DataFrame.

    The JSON reply may be a dict or a list holding one dict; its
    ``"modifications"`` entry is turned into a DataFrame restricted to the
    known columns that are actually present. When a ``position`` column
    exists it is coerced to ``int`` and rows without a usable position are
    dropped.

    Returns:
        An empty DataFrame when there are no modifications. If the records
        carry no ``position`` field the remaining columns are returned as-is
        (callers already check for the column before plotting).

    Raises:
        Whatever ``requests`` raises for transport errors or error statuses.
    """
    url = "https://iomics.ugent.be/scop3p/api/modifications"
    r = requests.get(url, params={"accession": accession},
                     headers={"accept": "application/json"}, timeout=30)
    r.raise_for_status()
    payload = r.json()

    # payload might be a dict OR a list with one dict
    if isinstance(payload, list):
        payload = payload[0] if payload else {}
    # Anything else (e.g. a bare string or null) has no modifications to offer;
    # treat it as empty instead of failing on ``.get``.
    if not isinstance(payload, dict):
        payload = {}

    mods = payload.get("modifications") or []
    df = pd.DataFrame(mods)

    if df.empty:
        return df

    keep = [c for c in ["position", "residue", "name", "source", "evidence", "reference", "functionalScore"]
            if c in df.columns]
    df = df[keep].copy()

    # Without a position column there is nothing to normalise; previously this
    # fell through to ``df["position"]`` and raised KeyError.
    if "position" not in df.columns:
        return df

    df["position"] = pd.to_numeric(df["position"], errors="coerce")
    df = df.dropna(subset=["position"])
    df["position"] = df["position"].astype(int)

    return df

# -----------------------------
# Prediction
# -----------------------------
def predict_biophysical(accession: str, sequence: str) -> dict:
    """Write a one-record FASTA to a temp file, run b2bTools on it, return all predictions.

    If the optional ``constants`` module was imported, every tool constant from
    the candidate list that it defines is requested explicitly; when none is
    found (or ``constants`` is unavailable) ``predict()`` runs with its defaults.
    """
    candidate_names = (
        "TOOL_BACKBONE_DYNAMICS",
        "TOOL_DYNAMINE",
        "TOOL_DISOMINE",
        "TOOL_EFOLDMINE",
    )

    with tempfile.NamedTemporaryFile(prefix="seq_", suffix=".fasta", mode="w") as fasta:
        fasta.write(f">{accession}\n{sequence}\n")
        fasta.flush()
        fasta.seek(0)
        runner = SingleSeq(fasta.name)

        # Collect only the tool constants this b2bTools installation exposes.
        selected_tools = []
        if constants is not None:
            selected_tools = [
                getattr(constants, attr)
                for attr in candidate_names
                if hasattr(constants, attr)
            ]

        if selected_tools:
            outcome = runner.predict(tools=selected_tools)
        else:
            outcome = runner.predict()
        predictions = outcome.get_all_predictions()
    return predictions

def prediction_to_df(pred: dict, accession: str) -> pd.DataFrame:
    """Flatten the predictions of one protein into a per-residue DataFrame.

    The protein entry is looked up under ``pred["proteins"][accession]`` first
    and under ``pred[accession]`` otherwise. The result has the columns
    ``seqpos`` (1-based), ``backbone``, ``disoMine`` and ``earlyFolding``;
    a missing track is filled with nulls and all tracks are coerced to numeric.

    Raises:
        ValueError: if no entry for ``accession`` can be located.
    """
    tracks = ("backbone", "disoMine", "earlyFolding")

    entry = None
    if isinstance(pred, dict):
        nested = "proteins" in pred and accession in pred["proteins"]
        if nested:
            entry = pred["proteins"][accession]
        elif accession in pred:
            entry = pred[accession]
    if entry is None:
        raise ValueError("Could not find protein predictions for the requested accession in b2bTools output.")

    # Length comes from the first non-empty field, checked in this order.
    n = 0
    for key, fallback in (("seq", ""), ("backbone", []), ("disoMine", []), ("earlyFolding", [])):
        n = len(entry.get(key, fallback))
        if n:
            break

    columns = {"seqpos": list(range(1, n + 1))}
    for track in tracks:
        columns[track] = entry.get(track, [None] * n)
    frame = pd.DataFrame(columns)

    for track in tracks:
        frame[track] = pd.to_numeric(frame[track], errors="coerce")
    return frame

def apply_mutations(sequence: str, positions_csv: str, aas_csv: str) -> str:
    """Return ``sequence`` with the requested point substitutions applied.

    Args:
        sequence: Wild-type amino-acid sequence.
        positions_csv: Comma-separated 1-indexed positions, e.g. ``"10,25"``.
        aas_csv: Comma-separated replacement residues (case-insensitive),
            e.g. ``"A,V"``; paired with the positions in order.

    Returns:
        The mutated sequence. Empty inputs yield the sequence unchanged.

    Raises:
        ValueError: if the two lists differ in length, a position is not an
            integer or lies outside the sequence, or a residue is not one of
            the 20 standard amino acids.
    """
    valid_residues = frozenset("ACDEFGHIKLMNPQRSTVWY")  # built once, not per mutation

    pos = [p.strip() for p in positions_csv.split(",") if p.strip()]
    aas = [a.strip().upper() for a in aas_csv.split(",") if a.strip()]
    if len(pos) != len(aas):
        raise ValueError("Positions and amino acids must have the same count (e.g., '10,25' and 'A,V').")

    seq = list(sequence)
    for p_str, aa in zip(pos, aas):
        try:
            p = int(p_str)
        except ValueError:
            # Same wording as parse_mutations, instead of int()'s
            # "invalid literal for int() with base 10" message.
            raise ValueError(f"Invalid position '{p_str}'.") from None
        if p < 1 or p > len(seq):
            raise ValueError(f"Position {p} is out of range for sequence length {len(seq)}.")
        if len(aa) != 1 or aa not in valid_residues:
            raise ValueError(f"Invalid amino acid '{aa}' at position {p}.")
        seq[p - 1] = aa
    return "".join(seq)

# -----------------------------
# Plotting (Bokeh)
# -----------------------------
def _embed_bokeh(fig) -> None:
    """Display ``fig`` as raw HTML: CDN resource tags, then the plot div, then its script."""
    js_script, html_div = components(fig)
    markup = "".join((CDN.render(), html_div, js_script))
    display(HTML(markup))

def make_wt_plot(wt_df: pd.DataFrame, mods_df: pd.DataFrame):
    """Build the wild-type figure: three prediction lines plus optional PTM markers.

    ``wt_df`` supplies ``seqpos`` and the three track columns. When ``mods_df``
    is non-empty and has a ``position`` column, one grey circle per row is drawn
    at y=0.5 with its own hover tool. Returns the Bokeh figure.
    """
    plot = figure(
        width=1000,
        height=300,
        tools="pan,box_zoom,reset,save",
        toolbar_location="below",
        toolbar_sticky=False,
    )
    plot.title.text = "Biophysical properties (WT)"

    # (column, colour, legend label) in drawing order.
    track_specs = (
        ("backbone", "blue", "backbone_dynamics"),
        ("disoMine", "red", "disorder"),
        ("earlyFolding", "grey", "earlyFolding"),
    )
    track_lines = []
    for column, colour, label in track_specs:
        track_lines.append(
            plot.line(wt_df["seqpos"], wt_df[column], line_width=2, color=colour, alpha=0.8,
                      muted_color=colour, muted_alpha=0.2, legend_label=label)
        )
    plot.add_tools(HoverTool(tooltips="Seqpos:@x, value:@y", renderers=track_lines))

    # PTM circles (Scop3P)
    has_sites = mods_df is not None and not mods_df.empty and "position" in mods_df.columns
    if has_sites:
        n_sites = len(mods_df)

        def _text_column(column_name):
            # Missing annotation columns fall back to empty strings.
            return mods_df.get(column_name, pd.Series([""] * n_sites)).astype(str).tolist()

        site_source = ColumnDataSource(dict(
            x=mods_df["position"].tolist(),
            y=[0.5] * n_sites,
            residue=_text_column("residue"),
            name=_text_column("name"),
            source=_text_column("source"),
        ))
        site_glyph = plot.scatter(x="x", y="y", source=site_source, marker="circle", size=10,
                                  fill_alpha=0.6, line_alpha=0.8, color="grey", legend_label="P-sites")
        site_tooltips = [("Seqpos", "@x"), ("residue", "@residue"), ("mod", "@name"), ("source", "@source")]
        plot.add_tools(HoverTool(tooltips=site_tooltips, renderers=[site_glyph]))

    plot.legend.click_policy = "mute"
    plot.add_layout(plot.legend[0], "right")
    return plot

def make_mut_plot(wt_df: pd.DataFrame, mut_df: pd.DataFrame, mods_df: pd.DataFrame):
    """Build the WT-vs-mutant figure.

    The three WT tracks are drawn first in light colours, then the three mutant
    tracks in dark colours; all six share one hover tool. When ``mods_df`` is
    non-empty and has a ``position`` column, grey PTM circles are added at
    y=0.5 with their own hover tool. Returns the Bokeh figure.
    """
    plot = figure(
        width=1000,
        height=300,
        tools="pan,box_zoom,reset,save",
        toolbar_location="below",
        toolbar_sticky=False,
    )
    plot.title.text = "Biophysical properties (WT vs Mutant)"

    def _draw(frame, specs):
        # One line glyph per (column, colour, legend label), in the given order.
        return [
            plot.line(frame["seqpos"], frame[column], line_width=2, color=colour, alpha=0.8,
                      muted_color=colour, muted_alpha=0.2, legend_label=label)
            for column, colour, label in specs
        ]

    # WT (light shades)
    wt_lines = _draw(wt_df, (
        ("backbone", "skyblue", "backbone_dynamics (WT)"),
        ("disoMine", "salmon", "disorder (WT)"),
        ("earlyFolding", "grey", "earlyFolding (WT)"),
    ))
    # Mutant (dark shades)
    mut_lines = _draw(mut_df, (
        ("backbone", "blue", "backbone_mut"),
        ("disoMine", "red", "disorder_mut"),
        ("earlyFolding", "black", "earlyFolding_mut"),
    ))

    # Hover list interleaves WT/mutant per track: b1, b2, d1, d2, e1, e2.
    hover_targets = [glyph for pair in zip(wt_lines, mut_lines) for glyph in pair]
    plot.add_tools(HoverTool(tooltips="Seqpos:@x, value:@y", renderers=hover_targets))

    # PTM circles
    has_sites = mods_df is not None and not mods_df.empty and "position" in mods_df.columns
    if has_sites:
        n_sites = len(mods_df)

        def _text_column(column_name):
            # Missing annotation columns fall back to empty strings.
            return mods_df.get(column_name, pd.Series([""] * n_sites)).astype(str).tolist()

        site_source = ColumnDataSource(dict(
            x=mods_df["position"].tolist(),
            y=[0.5] * n_sites,
            residue=_text_column("residue"),
            name=_text_column("name"),
            source=_text_column("source"),
        ))
        site_glyph = plot.scatter(x="x", y="y", source=site_source, marker="circle", size=10,
                                  fill_alpha=0.6, line_alpha=0.8, color="grey", legend_label="P-sites")
        site_tooltips = [("Seqpos", "@x"), ("residue", "@residue"), ("mod", "@name"), ("source", "@source")]
        plot.add_tools(HoverTool(tooltips=site_tooltips, renderers=[site_glyph]))

    plot.legend.click_policy = "mute"
    plot.add_layout(plot.legend[0], "right")
    return plot


# -----------------------------
# Inference helpers (label shifts)
# -----------------------------
def label_backbone(x: float) -> str:
    """Classify a backbone-dynamics value; missing values give ``"NA"``.

    Cut-offs are strict: >1.0 membrane-spanning, >0.8 rigid,
    >0.69 context-dependent, anything else flexible.
    """
    if pd.isna(x):
        return "NA"
    descending_cutoffs = (
        (1.0, "membrane-spanning"),
        (0.8, "rigid"),
        (0.69, "context-dependent"),
    )
    for cutoff, label in descending_cutoffs:
        if x > cutoff:
            return label
    return "flexible"

def label_disorder(x: float) -> str:
    """Classify a disorder score: ``"disordered"`` above 0.50, else ``"ordered"``; missing gives ``"NA"``."""
    if pd.isna(x):
        return "NA"
    if x > 0.50:
        return "disordered"
    return "ordered"

def label_earlyfold(x: float) -> str:
    """Classify an early-folding score: ``"early-folding"`` above 0.169, else ``"non-early-folding"``; missing gives ``"NA"``."""
    if pd.isna(x):
        return "NA"
    if x > 0.169:
        return "early-folding"
    return "non-early-folding"

# Registry of the features the inference step can summarise.
# Key: prediction column name in the WT/mutant DataFrames.
# Value: (function turning a numeric value into a class label, display name).
LABEL_FUNCS = {
    "backbone": (label_backbone, "Backbone dynamics"),
    "disoMine": (label_disorder, "Disorder"),
    "earlyFolding": (label_earlyfold, "Early folding"),
}

def parse_mutations(pos_str: str, aa_str: str):
    """Parse comma-separated positions and target residues into ``[(pos, aa), ...]``.

    Blank items are ignored, residues are upper-cased, and the two lists are
    paired in order. Range checking against a sequence is not done here.

    Raises:
        ValueError: on a count mismatch, a position that is not all digits,
            or a residue outside the 20 standard amino acids.
    """
    positions = [token.strip() for token in pos_str.split(",") if token.strip()]
    residues = [token.strip().upper() for token in aa_str.split(",") if token.strip()]
    if len(positions) != len(residues):
        raise ValueError("Number of positions must match number of amino acids.")

    parsed = []
    for raw_pos, residue in zip(positions, residues):
        if not raw_pos.isdigit():
            raise ValueError(f"Invalid position '{raw_pos}'.")
        site = int(raw_pos)
        is_standard = len(residue) == 1 and residue in set("ACDEFGHIKLMNPQRSTVWY")
        if not is_standard:
            raise ValueError(f"Invalid amino acid '{residue}' at position {site}.")
        parsed.append((site, residue))
    return parsed

def mutation_effect_table_with_label_shift(wt_df: pd.DataFrame,
                                          mut_df: pd.DataFrame,
                                          feature: str,
                                          mutations,
                                          window: int = 5) -> pd.DataFrame:
    """Summarise, per mutation, how ``feature`` and its class label change.

    For every ``(pos, aa_to)`` in ``mutations`` that exists in both frames, one
    row reports the WT and mutant value at the site, the mean over positions
    ``pos - window`` .. ``pos + window`` (clipped to 1 and the last WT
    position), the differences, and the label before/after for both. Both
    frames must provide ``seqpos``, ``seq`` and the ``feature`` column.
    Mutations whose position is absent from either frame are skipped.

    Raises:
        ValueError: if ``feature`` is not a key of ``LABEL_FUNCS``.
    """
    if feature not in LABEL_FUNCS:
        raise ValueError(f"feature must be one of {list(LABEL_FUNCS.keys())}")
    classify, display_name = LABEL_FUNCS[feature]

    wt_by_pos = wt_df.set_index("seqpos", drop=False)
    mut_by_pos = mut_df.set_index("seqpos", drop=False)
    last_pos = int(wt_by_pos["seqpos"].max())

    def _window_mean(table, start, stop):
        # .loc slicing on the seqpos index is inclusive at both ends.
        return float(table.loc[start:stop, feature].astype(float).mean())

    def _shift_text(before, after):
        if before != after:
            return f"{before} → {after}"
        return f"{before} (no class change)"

    records = []
    for site, target_aa in mutations:
        present_in_both = site in wt_by_pos.index and site in mut_by_pos.index
        if not present_in_both:
            continue

        wt_value = float(wt_by_pos.loc[site, feature])
        mut_value = float(mut_by_pos.loc[site, feature])
        delta_site = mut_value - wt_value

        start = max(1, site - window)
        stop = min(last_pos, site + window)
        wt_avg = _window_mean(wt_by_pos, start, stop)
        mut_avg = _window_mean(mut_by_pos, start, stop)
        delta_avg = mut_avg - wt_avg

        site_shift = _shift_text(classify(wt_value), classify(mut_value))
        avg_shift = _shift_text(classify(wt_avg), classify(mut_avg))

        original_aa = str(wt_by_pos.loc[site, "seq"])
        observed_aa = str(mut_by_pos.loc[site, "seq"])
        # Flag a mutant frame whose residue is not the requested one.
        remarks = []
        if observed_aa.upper() != target_aa.upper():
            remarks.append(f"Mut seq AA={observed_aa} (expected {target_aa})")

        summary = (
            f"{display_name}: {site_shift} at site (Δ {delta_site:+.3f}); "
            f"window mean: {avg_shift} (Δ {delta_avg:+.3f})"
        )

        records.append({
            "pos": site,
            "WT_AA": original_aa,
            "Mut_AA": target_aa,
            "mutation": f"{original_aa}{site}{target_aa}",
            f"{feature}_WT@pos": wt_value,
            f"{feature}_Mut@pos": mut_value,
            "Δ@pos": delta_site,
            f"{feature}_WT_mean": wt_avg,
            f"{feature}_Mut_mean": mut_avg,
            "Δ_mean": delta_avg,
            "label_shift@pos": site_shift,
            "label_shift_mean": avg_shift,
            "inference": summary,
            "note": "; ".join(remarks),
        })

    table = pd.DataFrame(records)

    # Round every numeric value/delta column to three decimals.
    numeric_markers = ["_WT@pos", "_Mut@pos", "_WT_mean", "_Mut_mean", "Δ@pos", "Δ_mean"]
    numeric_columns = [
        column for column in table.columns
        if any(marker in column for marker in numeric_markers)
    ]
    for column in numeric_columns:
        table[column] = pd.to_numeric(table[column], errors="coerce").round(3)
    return table

# -----------------------------
# Display helpers (tables under plots)
# -----------------------------
def _non_runtime_pred_cols(df: pd.DataFrame):
    """List the columns of ``df`` that are neither ``seqpos``/``seq`` nor contain "runtime" (case-insensitive)."""
    structural = {"seqpos", "seq"}
    return [
        column
        for column in df.columns
        if column not in structural and "runtime" not in str(column).lower()
    ]

def _add_ptm_flag(df: pd.DataFrame, mods_df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``PTMs`` column of "yes"/"no".

    A row is "yes" when its ``seqpos`` (as int) appears among the numeric
    values of ``mods_df["position"]``. A missing, empty or position-less
    ``mods_df`` marks every row "no". The input frame is not modified.
    """
    flagged = df.copy()

    known_sites = set()
    usable = mods_df is not None and not mods_df.empty and "position" in mods_df.columns
    if usable:
        numeric_positions = pd.to_numeric(mods_df["position"], errors="coerce")
        known_sites = set(numeric_positions.dropna().astype(int).tolist())

    def _flag(value):
        return "yes" if int(value) in known_sites else "no"

    flagged["PTMs"] = flagged["seqpos"].apply(_flag)
    return flagged

def make_wt_table(wt_df: pd.DataFrame, mods_df: pd.DataFrame) -> pd.DataFrame:
    """Assemble the WT table: ``seqpos``, ``seq``, the non-runtime prediction columns, then the PTMs flag."""
    wanted = ["seqpos", "seq", *_non_runtime_pred_cols(wt_df)]
    return _add_ptm_flag(wt_df[wanted].copy(), mods_df)

def make_wt_mut_merged_table(wt_df: pd.DataFrame, mut_df: pd.DataFrame, mods_df: pd.DataFrame) -> pd.DataFrame:
    """Join the WT and mutant predictions on ``seqpos`` into one side-by-side table.

    Only prediction columns present in both frames are kept. The residue
    column becomes ``WT_AA`` / ``Mut_AA`` and every shared feature appears as an
    adjacent ``<feature>_WT`` / ``<feature>_Mut`` pair, followed by the PTMs
    flag. The join key is the position alone, so mutated residues stay aligned.
    """
    mut_features = _non_runtime_pred_cols(mut_df)
    # Shared features, in WT column order.
    shared = [feat for feat in _non_runtime_pred_cols(wt_df) if feat in mut_features]

    selection = ["seqpos", "seq"] + shared
    wt_side = wt_df[selection].copy().rename(columns={"seq": "WT_AA"})
    mut_side = mut_df[selection].copy().rename(columns={"seq": "Mut_AA"})

    joined = pd.merge(wt_side, mut_side, on="seqpos", how="inner", suffixes=("_WT", "_Mut"))

    # Layout: position, both residues, then the WT/Mut pair of each feature.
    layout = ["seqpos", "WT_AA", "Mut_AA"]
    for feat in shared:
        layout.extend((f"{feat}_WT", f"{feat}_Mut"))
    joined = joined[layout]

    return _add_ptm_flag(joined, mods_df)

def display_track_guide():
    """Display the static HTML legend explaining the thresholds of each plotted track."""
    guide_html = """
<div style="margin:6px 0 10px 0; padding:8px 10px; border:1px solid #ddd; border-radius:8px;">
  <div style="margin-bottom:6px;"><b>How to read the tracks</b></div>

  <div style="margin:2px 0;">
    <span style="color:blue; font-weight:600;">Backbone dynamics</span>:
    &gt;1.0 membrane-spanning, 0.8–1.0 rigid, 0.69–0.80 context-dependent, &lt;0.69 flexible
  </div>

  <div style="margin:2px 0;">
    <span style="color:red; font-weight:600;">Disorder (DisoMine)</span>:
    values &gt;0.50 indicate disordered regions
  </div>

  <div style="margin:2px 0;">
    <span style="color:grey; font-weight:600;">Early folding</span>:
    values &gt;0.169 suggest early-folding propensity
  </div>

  <div style="margin:2px 0;">
    <span style="color:grey; font-weight:600;">P-sites</span>:
    phosphorylation positions (grey dots)
  </div>
</div>
"""
    display(HTML(guide_html))


from uuid import uuid4

def display_scrollable_table(
    df: pd.DataFrame,
    *,
    title: str = "",
    height_px: int = 420,
    width: str = "100%",
    sticky_cols: int = 3,
    col_widths_px=None,
    highlight_seqpos=None,
    seqpos_col: str = "seqpos",
):
    """Display ``df`` as a scrollable HTML table with a sticky header.

    Args:
        df: Table to show. ``None`` or an empty frame shows "(no rows)".
        title: Optional heading. It is inserted as raw HTML (not escaped).
        height_px: Height of the scroll container in pixels.
        width: CSS width of the scroll container.
        sticky_cols: Number of leading columns frozen while scrolling sideways.
        col_widths_px: Pixel widths for the sticky columns; padded with 90 or
            truncated so that exactly ``sticky_cols`` widths are used.
        highlight_seqpos: Positions whose rows get a highlighted background.
        seqpos_col: Column holding the position compared against
            ``highlight_seqpos``.

    Header and cell text is HTML-escaped on both rendering paths, and the
    ``<colgroup>`` with the sticky widths is injected on both paths.
    """
    from html import escape  # local import: only this function needs it

    if df is None or df.empty:
        display(HTML(f"<div><b>{title}</b><br><i>(no rows)</i></div>" if title else "<i>(no rows)</i>"))
        return

    # Unique id so the CSS of several tables on one page cannot collide.
    table_id = f"tbl_{uuid4().hex}"

    # Exactly one width per sticky column (default 90px).
    if col_widths_px is None:
        col_widths_px = [90] * sticky_cols
    else:
        col_widths_px = (list(col_widths_px) + [90] * sticky_cols)[:sticky_cols]

    # Left offset of each sticky column = sum of the widths before it.
    left_offsets = []
    acc = 0
    for w in col_widths_px:
        left_offsets.append(acc)
        acc += int(w)

    hl = {int(x) for x in highlight_seqpos} if highlight_seqpos else set()

    # Same opening tag as DataFrame.to_html so the colgroup injection below
    # matches both rendering paths.
    table_open = "<table border=\"1\" class=\"dataframe\">"

    if hl and seqpos_col in df.columns:
        # Manual HTML so a class can be attached to individual rows.
        cols = list(df.columns)
        seqpos_idx = cols.index(seqpos_col)
        thead = "<thead><tr>" + "".join(f"<th>{escape(str(c))}</th>" for c in cols) + "</tr></thead>"

        rows_html = []
        # itertuples keeps each column's own dtype; iterrows would upcast
        # integers to floats in an all-numeric frame.
        for values in df.itertuples(index=False, name=None):
            try:
                pos_val = int(values[seqpos_idx])
            except (TypeError, ValueError, OverflowError):
                pos_val = None
            cls = " class='row-hl'" if (pos_val is not None and pos_val in hl) else ""
            tds = "".join(f"<td>{'' if pd.isna(v) else escape(str(v))}</td>" for v in values)
            rows_html.append(f"<tr{cls}>{tds}</tr>")

        tbody = "<tbody>" + "".join(rows_html) + "</tbody>"
        html_table = f"{table_open}{thead}{tbody}</table>"
    else:
        html_table = df.to_html(index=False, escape=True)

    # CSS for the sticky left columns. Note: nth-child is 1-indexed.
    sticky_css_cols = []
    for i in range(1, sticky_cols + 1):
        left = left_offsets[i - 1]
        z = 4 + (sticky_cols - i)  # keep leftmost on top
        sticky_css_cols.append(f"""
#{table_id} table th:nth-child({i}),
#{table_id} table td:nth-child({i}) {{
  position: sticky;
  left: {left}px;
  z-index: {z};
  background: #fff;
}}
""")

    # The row-highlight rule comes after the sticky-column rules: both have the
    # same specificity, so the later one decides the background of sticky cells.
    css = f"""
<style>
#{table_id} .tbl-title {{
  font-weight: 700;
  margin: 8px 0 6px 0;
}}

#{table_id} .tbl-wrap {{
  width: {width};
  max-width: {width};
  height: {height_px}px;
  overflow: auto;
  border: 1px solid #ddd;
  border-radius: 8px;
}}

#{table_id} table {{
  border-collapse: collapse;
  width: max-content; /* allow horizontal scrolling */
  min-width: 100%;
  font-size: 12px;
}}

#{table_id} th, #{table_id} td {{
  border: 1px solid #eee;
  padding: 6px 8px;
  white-space: nowrap;
}}

#{table_id} thead th {{
  position: sticky;
  top: 0;
  z-index: 10;
  background: #f7f7f7;
}}

{''.join(sticky_css_cols)}
#{table_id} tr.row-hl td {{
  background: #fff6cc;  /* light highlight */
}}
</style>
"""

    # Fixed widths for the sticky columns keep their left offsets stable.
    colgroup = "<colgroup>" + "".join(f"<col style='width:{int(w)}px;'>" for w in col_widths_px) + "</colgroup>"
    html_table = html_table.replace(table_open, f"{table_open}{colgroup}", 1)

    title_html = f'<div class="tbl-title">{title}</div>' if title else ''
    block = f"""
<div id="{table_id}">
  {title_html}
  <div class="tbl-wrap">
    {html_table}
  </div>
</div>
"""
    display(HTML(css + block))


# -----------------------------
# UI
# -----------------------------
# --- WT tab inputs: accession box, shared status banner, run button, output area ---
acc_in = W.Text(value="P07949", description="UniProt:", layout=W.Layout(width="260px"))
status = W.HTML(value="<div style='padding:6px;border:1px solid #ddd;border-radius:6px;'>Ready.</div>")

run_wt_btn = W.Button(description="Fetch & Predict WT", button_style="success")
wt_out = W.Output()

# --- Mutant tab inputs: comma-separated positions and target residues (paired in order) ---
pos_in = W.Text(value="606", description="Positions:", placeholder="e.g. 10,25,100", layout=W.Layout(width="320px"))
aa_in  = W.Text(value="A", description="To AA:", placeholder="e.g. A,V,G", layout=W.Layout(width="320px"))
run_mut_btn = W.Button(
    description="Apply & Predict",
    button_style="warning",
    layout=W.Layout(width="320px"))
mut_out = W.Output()

# --- Inference tab: run button and output area ---
run_inf_btn = W.Button(description="Run inference", button_style="info")
inf_out = W.Output()


# In-memory state shared by the three callbacks. do_wt fills accession/sequence/
# mods_df/wt_df and resets mut_df/mutations to None; do_mut fills mut_df and
# mutations (the raw position and residue strings it was run with).
APP_STATE = {"accession": None, "sequence": None, "mods_df": None, "wt_df": None, "mut_df": None, "mutations": None}

def _set_status(msg: str):
    """Show ``msg`` (inserted as HTML) inside the bordered status banner."""
    box_style = "padding:6px;border:1px solid #ddd;border-radius:6px;"
    status.value = f"<div style='{box_style}'>{msg}</div>"

def do_wt(_=None):
    """Button callback: fetch the sequence and PTMs, predict WT properties, render plot and table.

    Everything is rendered into ``wt_out``. On success APP_STATE is refreshed
    (any earlier mutant result is reset to None). On failure the status banner
    shows the error and the exception is re-raised.
    """
    with wt_out:
        wt_out.clear_output()
        try:
            acc = acc_in.value.strip()
            _set_status("Fetching UniProt sequence + Scop3P PTMs…")
            seq = fetch_uniprot_sequence(acc)
            mods = fetch_scop3p_modifications(acc)

            _set_status("Running WT biophysical prediction…")
            pred = predict_biophysical(acc, seq)
            wt_df = prediction_to_df(pred, acc)
            # add per-position AA (required for inference table)
            wt_df["seq"] = list(seq)


            # New WT run invalidates any previous mutant prediction.
            APP_STATE.update({"accession": acc, "sequence": seq, "mods_df": mods, "wt_df": wt_df, "mut_df": None, "mutations": None})

            _set_status(f"WT prediction ready. (PTMs: {0 if mods is None else len(mods)})")
            fig = make_wt_plot(wt_df, mods)
            display_track_guide()
            _embed_bokeh(fig)

            # ---- extra display under plot (WT table) ----
            display(HTML("<hr style='margin:10px 0;'>"))
            wt_tbl = make_wt_table(wt_df, mods)
            
            display_scrollable_table(
                wt_tbl,
                title="WT predicted features (per residue)",
                height_px=420,
                width="100%",
                sticky_cols=0,                 # 0 = no frozen columns; raise it to pin leading columns
                col_widths_px=[90, 70, 60],    # only the first `sticky_cols` widths are used, so none while it is 0
            )

        except Exception as e:
            # Surface the failure in the banner, then let the traceback show in the output area.
            _set_status(f"<b>Error:</b> {e}")
            raise

def do_mut(_=None):
    """Button callback: apply the requested mutations, predict, and render WT vs mutant.

    Requires a prior WT run. The position/residue fields are read once and
    validated with ``parse_mutations`` before the prediction starts, so
    malformed input fails fast and never leaves a half-updated APP_STATE.
    Everything is rendered into ``mut_out``. On failure the status banner shows
    the error and the exception is re-raised.
    """
    with mut_out:
        mut_out.clear_output()
        try:
            if APP_STATE["wt_df"] is None or APP_STATE["sequence"] is None:
                raise ValueError("Run WT prediction first.")

            acc = APP_STATE["accession"]
            seq = APP_STATE["sequence"]
            mods = APP_STATE["mods_df"]

            # Snapshot the inputs so state, labels and table all describe the
            # same request even if the fields are edited meanwhile.
            pos_text = pos_in.value
            aa_text = aa_in.value

            # Validate once, before the expensive prediction.
            muts = parse_mutations(pos_text, aa_text)  # [(pos, 'P'), ...]

            _set_status("Applying mutations + predicting…")
            mut_seq = apply_mutations(seq, pos_text, aa_text)
            pred_mut = predict_biophysical(acc, mut_seq)
            mut_df = prediction_to_df(pred_mut, acc)
            # add per-position AA (required for inference table)
            mut_df["seq"] = list(mut_seq)

            APP_STATE.update({"mut_df": mut_df, "mutations": (pos_text, aa_text)})

            _set_status("Mutant prediction ready.")

            # ---- mutant summary message (above plot) ----
            # apply_mutations has already rejected out-of-range positions, so
            # every position indexes into the WT sequence.
            labels = [f"{seq[p - 1]}{p}{aa_to}" for p, aa_to in muts]
            msg = ", ".join(labels) if labels else "(none)"
            display(HTML(f"<div style='margin:6px 0 10px 0;'><b>Predicted properties for mutants:</b> {msg}</div>"))

            fig = make_mut_plot(APP_STATE["wt_df"], mut_df, mods)
            display_track_guide()
            _embed_bokeh(fig)

            # ---- extra display under plot (WT vs Mut merged table) ----
            display(HTML("<hr style='margin:10px 0;'>"))
            merged_tbl = make_wt_mut_merged_table(APP_STATE["wt_df"], mut_df, mods)

            display_scrollable_table(
                merged_tbl,
                title="WT vs Mutant predicted features (aligned by seqpos)",
                height_px=420,
                width="100%",
                sticky_cols=0,
                highlight_seqpos=[p for p, _ in muts],
                seqpos_col="seqpos",
            )

        except Exception as e:
            # Surface the failure in the banner, then let the traceback show in the output area.
            _set_status(f"<b>Error:</b> {e}")
            raise

from IPython.display import Markdown, display

def do_inf(_=None):
    """Button callback: show one label-shift table per feature for the mutant run.

    Requires both a WT and a mutant prediction. The mutations are taken from
    APP_STATE (the strings the mutant prediction was run with), so editing the
    input fields afterwards cannot make the summary disagree with ``mut_df``.
    The live fields are used only if no stored value exists. Everything is
    rendered into ``inf_out``. On failure the status banner shows the error
    and the exception is re-raised.
    """
    with inf_out:
        inf_out.clear_output()
        try:
            if APP_STATE["wt_df"] is None or APP_STATE["mut_df"] is None:
                raise ValueError("Run WT prediction and Mutant prediction first.")

            wt_df = APP_STATE["wt_df"]
            mut_df = APP_STATE["mut_df"]

            # Prefer the mutations stored alongside mut_df over the live fields.
            stored = APP_STATE.get("mutations")
            if stored:
                pos_text, aa_text = stored
            else:
                pos_text, aa_text = pos_in.value, aa_in.value
            muts = parse_mutations(pos_text, aa_text)

            _set_status("Running inference…")

            display(Markdown("## Label-shift summary (±5 AA window)"))
            for feat, (_, pretty) in LABEL_FUNCS.items():
                display(Markdown(f"### {pretty}"))
                df = mutation_effect_table_with_label_shift(
                    wt_df=wt_df,
                    mut_df=mut_df,
                    feature=feat,
                    mutations=muts,
                    window=5
                )
                display(df)

            _set_status("Inference ready.")
        except Exception as e:
            # Surface the failure in the banner, then let the traceback show in the output area.
            _set_status(f"<b>Error:</b> {e}")
            raise



# Connect each button to its callback.
run_wt_btn.on_click(do_wt)
run_mut_btn.on_click(do_mut)
run_inf_btn.on_click(do_inf)

# One container per tab; the status banner is placed in the WT tab only.
wt_box = W.VBox([W.HBox([acc_in, run_wt_btn]), status, wt_out])
mut_box = W.VBox([W.HBox([pos_in, aa_in, run_mut_btn]), mut_out])
inf_box = W.VBox([run_inf_btn, inf_out])

# Assemble the three tabs in workflow order and show the app.
tabs = W.Tab(children=[wt_box, mut_box, inf_box])
tabs.set_title(0, "WT prediction")
tabs.set_title(1, "Mutant prediction")
tabs.set_title(2, "Inference")

display(tabs)


Tab(children=(VBox(children=(HBox(children=(Text(value='P07949', description='UniProt:', layout=Layout(width='…