In [4]:
import arcadia_pycolor as apc
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ------------------------------------------------
# Load data
# ------------------------------------------------
results = pd.read_csv("../outputs/human/selected_mimics/gmmviro3d_benchmarking042925_detailed.csv")

apc.mpl.setup()

# ------------------------------------------------
# Pre-processing
# ------------------------------------------------
# Strip a trailing ".tsv" extension. The dot is escaped so that only a literal
# "." is matched; an unescaped "." would match any character before "tsv".
results["source_key"] = results["source_key"].str.replace(r"\.tsv$", "", regex=True)

# Break the key into its "_"-separated fields.
# NOTE(review): this assumes every key splits into exactly six fields; pandas
# raises on the assignment below if the number of columns differs.
split_cols = results["source_key"].str.split("_", expand=True)
results[
    [
        "control",
        "tool",
        "alignment_type_raw",
        "tmalign_fast_raw",
        "exact_tmscore_raw",
        "tmscore_threshold_raw",
    ]
] = split_cols

# Remove the parameter-name prefix from each field, leaving only its value.
results["alignment_type"] = results["alignment_type_raw"].str.replace(
    "alignmenttype", "", regex=False
)
results["tmalign_fast"] = results["tmalign_fast_raw"].str.replace(
    "tmalignfast", "", regex=False
)
results["exact_tmscore"] = results["exact_tmscore_raw"].str.replace(
    "exacttmscore", "", regex=False
)

# Translate the alignment-type codes into readable labels.
results["alignment_type"] = results["alignment_type"].replace({"1": "TM-align", "2": "3Di+AA"})

# Rows aligned with 3Di+AA get "none" as their tmalign_fast value.
is_3di = results["alignment_type"] == "3Di+AA"
results["tmalign_fast"] = results["tmalign_fast"].mask(is_3di, "none")

# Model type, in order of precedence:
#   1. "3Di+AA"   - the alignment type is 3Di+AA
#   2. "hybrid"   - the (optional) feature_set column mentions "evalue"
#   3. "TM-align" - everything else
if "feature_set" in results.columns:
    uses_evalue = results["feature_set"].astype(str).str.contains("evalue", regex=False)
else:
    uses_evalue = pd.Series(False, index=results.index)

results["model_type"] = "TM-align"
results.loc[uses_evalue, "model_type"] = "hybrid"
results.loc[is_3di, "model_type"] = "3Di+AA"

# ------------------------------------------------
# Join with metadata
# ------------------------------------------------
metadata = pd.read_csv("../benchmarking_data/controls/control_metadata.tsv", sep="\t")
metadata["structure_file"] = metadata["structure_file"].str.replace(".pdb", "", regex=False)

# Discard structures whose name starts with "AF-". ``na=False`` treats a
# missing name as "does not start with AF-" so the boolean negation cannot
# fail on NaN values.
is_af_structure = metadata["structure_file"].str.startswith("AF-", na=False)
metadata = metadata[~is_af_structure].drop_duplicates()

# De-duplicate on exactly the columns used in the join. Metadata rows that
# differ only in other columns would otherwise duplicate result rows in the
# merge below.
target_lookup = metadata[["structure_file", "target_uniprot"]].drop_duplicates()

# Left join: results without a metadata match are kept, with NaN in
# ``target_uniprot``.
results = results.merge(
    target_lookup,
    left_on="query",
    right_on="structure_file",
    how="left",
)

# ------------------------------------------------
# Label correctness
# ------------------------------------------------
# Each rule below overrides the ones before it, so the last matching rule wins.
hit_matches_target = results["target"] == results["target_uniprot"]
target_unknown = results["target_uniprot"].isna()
c1l_accepted_target = results["control"].eq("c1l") & results["target"].isin(
    ["Q8WXC3", "P10415"]
)

correct_labels = pd.Series("off-target hit", index=results.index)
correct_labels = correct_labels.mask(hit_matches_target, "correct hit")
correct_labels = correct_labels.mask(target_unknown, "unknown correct hit")
correct_labels = correct_labels.mask(c1l_accepted_target, "correct hit")

results["correct"] = correct_labels

# ------------------------------------------------
# Pad missing control × model_type combos
# ------------------------------------------------
all_controls = results["control"].unique()
all_model_types = ["3Di+AA", "TM-align", "hybrid"]

# Add one score-less "padding" row for every control/model-type pair that has
# no result rows.
pad_rows = []
for ctl in all_controls:
    present = set(results.loc[results["control"] == ctl, "model_type"])
    pad_rows.extend(
        dict(
            control=ctl,
            model_type=mt,
            # A float NaN (not None) keeps this column numeric. An all-None
            # object column is what triggers pandas' FutureWarning about
            # concatenating all-NA entries.
            qtmscore=np.nan,
            correct="padding",
            host_gene_names_primary="",
            query="",
        )
        for mt in all_model_types
        if mt not in present
    )

if pad_rows:
    results = pd.concat([results, pd.DataFrame(pad_rows)], ignore_index=True)

# ------------------------------------------------
# Plot definitions
# ------------------------------------------------
# Marker symbol for each correctness label.
symbol_map = {
    "correct hit": "circle",
    "off-target hit": "square-open",
    "unknown correct hit": "circle-open",
}

# Marker colour for each model type.
model_color_map = {
    "3Di+AA": apc.vital,
    "TM-align": apc.canary,
    "hybrid": apc.seaweed,
}

# Model types in left-to-right axis order; the position is the numeric x value.
model_axis_order = ["3Di+AA", "hybrid", "TM-align"]
model_to_num = {name: position for position, name in enumerate(model_axis_order)}
tickvals = list(range(len(model_axis_order)))
ticktext = ["3Di+AA", "Hybrid", "TM-align"]

# Control key -> panel title, listed in panel order (row by row). Keeping both
# in one mapping stops the key list and the title list drifting apart.
panel_titles = {
    "bcl2": "Bcl-2",
    "c1l": "C1L",
    "c1lpt1": "C1L pt1",
    "c1lpt2": "C1L pt2",
    "c4bp": "C4BP",
    "ccr1": "CCR1",
    "cxcr2": "CXCR2",
    "cd47": "CD47",
    "eif2a": "eIF2a",
    "ifngr": "IFNgR1",
    "il10": "IL-10",
    "il18bp": "IL-18BP",
    "lfg4": "TMBIM4",
    "chemokine": "Chemokine",
    "helicase": "Helicase",
    "kinase": "Kinase",
    "nsp5": "Protease",
    "nsp16": "RNA methylase",
}
controls = list(panel_titles)
titles = list(panel_titles.values())

# Subplot grid dimensions.
n_cols = 6
n_rows = 3
total_subplots = n_cols * n_rows

fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    shared_xaxes=True,
    shared_yaxes=True,
    subplot_titles=titles,
    horizontal_spacing=0.02,
    vertical_spacing=0.07,
)

# ------------------------------------------------
# Invisible markers so every legend key shows
# ------------------------------------------------
# One data-less black trace per correctness label. It plots no points and
# exists only to provide a legend entry for that label's marker symbol.
for label, marker_symbol in symbol_map.items():
    legend_marker = dict(color=apc.black, symbol=marker_symbol, size=10, opacity=1.0)
    legend_key = go.Scatter(
        x=[None],
        y=[None],
        mode="markers",
        marker=legend_marker,
        name=label,
        legendgroup=label,
        showlegend=True,
        hoverinfo="skip",
    )
    fig.add_trace(legend_key, row=1, col=1)

def _jittered_positions(center, count, half_width=0.25):
    """Return ``count`` x positions spread evenly around ``center``.

    A single point sits exactly on ``center``. Two or more points are spaced
    evenly from ``center - half_width`` to ``center + half_width``.
    """
    if count == 1:
        return np.array([float(center)])
    return center + np.linspace(-half_width, half_width, count)


# ------------------------------------------------
# Add data traces
# ------------------------------------------------
for i, ctl in enumerate(controls):
    sub_df = results[results["control"] == ctl]
    # Panels are filled row by row; plotly grid positions are 1-based.
    row, col = i // n_cols + 1, i % n_cols + 1

    # Fully transparent, zero-size placeholder at each model-type position, so
    # every panel carries a trace at all three x values even when it has no
    # hits for one of them.
    for mt in all_model_types:
        fig.add_trace(
            go.Scatter(
                x=[model_to_num[mt]],
                y=[0],
                mode="markers",
                marker=dict(color="rgba(0,0,0,0)", size=0),
                showlegend=False,
                hoverinfo="skip",
            ),
            row=row,
            col=col,
        )

    # One trace per (correctness, model type) pair: the symbol encodes
    # correctness and the colour encodes model type.
    for correctness, symbol in symbol_map.items():
        for model_type, color in model_color_map.items():
            grp = sub_df[
                (sub_df["correct"] == correctness) & (sub_df["model_type"] == model_type)
            ]
            if grp.empty:
                continue

            # Every row in ``grp`` has the same model type, so all points share
            # one x position and only need spreading sideways around it.
            jittered_x = _jittered_positions(model_to_num[model_type], len(grp))
            hover_mts = np.full(len(grp), model_type)

            fig.add_trace(
                go.Scatter(
                    x=jittered_x,
                    y=grp["qtmscore"],
                    mode="markers",
                    marker=dict(
                        color=color,
                        symbol=symbol,
                        size=10,
                        opacity=0.95,
                    ),
                    # Data traces stay out of the legend; the legend keys come
                    # from the dedicated black marker traces.
                    showlegend=False,
                    # customdata columns: 0 = host gene, 1 = query ID,
                    # 2 = GenBank name, 3 = model type.
                    customdata=np.column_stack(
                        (
                            grp[["host_gene_names_primary", "query", "genbank_name"]].values,
                            hover_mts,
                        )
                    ),
                    hovertemplate=(
                        "Model type: %{customdata[3]}<br>"
                        + "Query TM-score: %{y}<br>"
                        + "Host gene: %{customdata[0]}<br>"
                        + "Viral query GenBank name: %{customdata[2]}<br>"
                        + "Viral query ID: %{customdata[1]}<extra></extra>"
                    ),
                ),
                row=row,
                col=col,
            )

# ------------------------------------------------
# Layout
# ------------------------------------------------
parchment = apc.parchment

# Horizontal legend, centred below the plotting area.
legend_style = {
    "orientation": "h",
    "xanchor": "center",
    "x": 0.5,
    "yanchor": "bottom",
    "y": -0.2,
    "font": {"size": 14},
    "itemsizing": "constant",
}

fig.update_layout(
    width=1800,
    height=800,
    paper_bgcolor=parchment,
    plot_bgcolor="rgba(0,0,0,0)",
    margin={"t": 50, "b": 100, "l": 150},
    legend=legend_style,
)


# Apply the same range, ticks and frame settings to every subplot axis. Tick
# labels and axis lines are shown only on the bottom row (x) and the left
# column (y).
for panel_idx in range(total_subplots):
    grid_row, grid_col = divmod(panel_idx, n_cols)
    # Plotly names the first axis pair "xaxis"/"yaxis" and later ones
    # "xaxis2"/"yaxis2", "xaxis3"/"yaxis3", ...
    axis_suffix = "" if panel_idx == 0 else str(panel_idx + 1)
    x_key = f"xaxis{axis_suffix}"
    y_key = f"yaxis{axis_suffix}"

    if x_key in fig.layout:
        on_bottom_row = grid_row == n_rows - 1
        fig.layout[x_key].update(
            range=[-0.5, 2.5],
            tickvals=tickvals,
            ticktext=ticktext,
            tickfont=dict(size=14),
            showgrid=False,
            zeroline=False,
            linecolor="black",
            showline=on_bottom_row,
            showticklabels=on_bottom_row,
        )

    if y_key in fig.layout:
        on_left_column = grid_col == 0
        fig.layout[y_key].update(
            range=[0, 1],
            tickmode="array",
            tickvals=[0, 0.25, 0.5, 0.75],
            ticktext=["0", "0.25", "0.5", "0.75"],
            showgrid=False,
            zeroline=False,
            linecolor="black",
            showline=on_left_column,
            showticklabels=on_left_column,
        )

# Resize every annotation currently on the figure and nudge it upwards. This
# runs before the axis titles are added below, so those are not shifted.
# NOTE(review): presumably these are the subplot titles created by
# make_subplots - confirm.
for title_ann in fig.layout.annotations:
    title_ann.font = dict(size=14)
    title_ann.y += 0.03


# Apply the arcadia_pycolor plotly styling helpers.
apc.plotly.style_plot(fig, categorical_axes="x", monospaced_axes="y")
apc.plotly.set_xaxis_categorical(fig, row=1, col=1)

# Figure-wide axis titles, positioned in paper coordinates: the rotated y title
# sits left of the grid and the x title sits below it.
shared_axis_titles = (
    dict(text="Query TM-score", x=-0.05, y=0.5, textangle=-90),
    dict(text="Model type", x=0.5, y=-0.12),
)
for axis_title in shared_axis_titles:
    fig.add_annotation(
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(size=16, family="SuisseIntl-Medium"),
        **axis_title,
    )

fig.write_html("figure5.html")


The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.

