## CB with ThermoMPNN

In [None]:
from __future__ import annotations

import os
import sys
import requests

sys.path.append("../utilities/")

from cbutils import (
    aa_code,
    make_consensus_sequence,
    setup_aligner,
    alignment_to_mapping,
    add_scaled_outputs,
)

import matplotlib.pyplot as plt
import pandas as pd

### Select structures and chains

In [None]:
# Structure files to score, keyed by the label used for each state.
pdbs = dict(
    open="../pdbs/lpla/3a7r.pdb",
    closed="../pdbs/lpla/1x2g.pdb",
)

# Chain ID to read from each structure; keys must match those of `pdbs`.
chains = dict(
    open="A",
    closed="A",
)

### Align sequences, generate mutants, and score

In [None]:
from thermompnn.protein_mpnn_utils import alt_parse_PDB
from thermompnn import Mutation, ALPHABET
from thermompnn.analysis.thermompnn_benchmarking import load_model

import torch

# Read the sequence of the selected chain from every structure, as parsed by
# ThermoMPNN's own PDB reader, then merge them into one consensus sequence.
thermompnn_seqs = {}
for structure, pdb_path in pdbs.items():
    parsed_chain = alt_parse_PDB(pdb_path, chains[structure])
    thermompnn_seqs[structure] = parsed_chain[0]['seq']

con_seq = make_consensus_sequence(list(thermompnn_seqs.values()))

# Enumerate every single-residue substitution of the consensus sequence.
# Mutation names use 1-indexed positions, e.g. "A12G".
muts, mut_seqs = [], []
for idx, wt_aa in enumerate(con_seq):
    prefix, suffix = con_seq[:idx], con_seq[idx + 1 :]
    for sub_aa in aa_code:
        if sub_aa == wt_aa:
            continue  # skip the identity "mutation"
        mut_seqs.append(prefix + sub_aa + suffix)
        muts.append(f"{wt_aa}{idx+1}{sub_aa}")

# One row per mutant: its name and its full mutated sequence.
output_data = pd.DataFrame({"mut": muts, "seq": mut_seqs})

# Align the consensus sequence to each structure's sequence and turn the first
# returned alignment into a position mapping (looked up by 0-indexed consensus
# position further down).
aligner = setup_aligner()

thermompnn_alignments = {}
thermompnn_mappings = {}
for pdb, seq in thermompnn_seqs.items():
    first_alignment = aligner.align(con_seq, seq)[0]
    thermompnn_alignments[pdb] = first_alignment
    thermompnn_mappings[pdb] = alignment_to_mapping(first_alignment)

#load weights and download weights if not found
try:
    model = load_model("../weights/v_48_020.pt", "../weights/thermoMPNN_default.pt")
except FileNotFoundError:
    os.makedirs("../weights", exist_ok=True)
    tm_url = "https://raw.githubusercontent.com/andrewxue98/ThermoMPNN/main/weights/thermoMPNN_default.pt"
    pm_url = (
        "https://raw.githubusercontent.com/andrewxue98/ThermoMPNN/main/weights/v_48_020.pt"
    )

    # Fetch only the files that are actually missing. A failed download raises
    # immediately (requests.HTTPError / requests.Timeout) instead of printing a
    # message and then failing later with an unrelated FileNotFoundError.
    for weight_path, weight_url in (
        ("../weights/thermoMPNN_default.pt", tm_url),
        ("../weights/v_48_020.pt", pm_url),
    ):
        if os.path.exists(weight_path):
            continue

        response = requests.get(weight_url, timeout=60)
        response.raise_for_status()
        with open(weight_path, "wb") as f:
            f.write(response.content)
        print(f"Downloaded weights to {weight_path}")

    # Retry now that both weight files should be on disk.
    model = load_model("../weights/v_48_020.pt", "../weights/thermoMPNN_default.pt")

# Switch the model to evaluation mode and move it to the GPU when CUDA is
# available, otherwise to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = model.eval().to(device)

#run thermompnn scoring
# For every structure: translate each consensus-numbered mutant into a
# ThermoMPNN Mutation at the structure's own position, run the model once on
# the whole list, and store the scores in a new `thermompnn_<structure>` column.
for structure in pdbs:
    print(f'Running ThermoMPNN on {structure}...')

    pdb_fp = pdbs[structure]
    chain = chains[structure]

    mut_pdb = alt_parse_PDB(pdb_fp, chain)
    pdb_name = mut_pdb[0]["name"]
    position_map = thermompnn_mappings[structure]
    mutation_list = []

    #try to map mutants that we previously generated to thermompnn mutants (should be identical, but might slightly mismatch due to structure handling)
    for mut in output_data['mut']:
        if mut is None:
            mutation_list.append(None)
            continue

        #only accepts single mutants, custom uploads are skipped
        wtAA, mutAA = str(mut[0]), str(mut[-1])
        position = int(str(mut[1:-1])) - 1 #convert from 1-indexed to zero indexed

        # Consensus positions with no counterpart in this structure cannot be
        # scored; keep a None placeholder so rows stay aligned with output_data.
        if position not in position_map:
            mutation_list.append(None)
            continue

        assert (
            wtAA in ALPHABET
        ), f"Wild type residue {wtAA} invalid, please try again with one of the following options: {ALPHABET}"
        assert (
            mutAA in ALPHABET
        ), f"Mutant residue {mutAA} invalid, please try again with one of the following options: {ALPHABET}"

        #create thermompnn mutation objects
        mutation_list.append(
            Mutation(
                position=position_map[position],
                wildtype=wtAA,
                mutation=mutAA,
                ddG=None,
                pdb=pdb_name,
            )
        )

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred, _ = model(mut_pdb, mutation_list)

    # The predicted ddG is sign-flipped before being stored.
    # NOTE(review): presumably so that higher means more favourable - confirm.
    thermompnn_scores = [
        -1 * out["ddG"].cpu().item() if mut is not None else None
        for mut, out in zip(mutation_list, pred)
    ]

    output_data[f"thermompnn_{structure}"] = thermompnn_scores

### Analysis

In [None]:
model = "thermompnn"
frac_mutants = 0.05

# scale columns and calculate bias
# NOTE(review): the return value is unused and the code below reads the
# `<model>_state1_bias` / `<model>_state{1,2}_scaled` columns, so this
# presumably adds them to output_data in place - confirm in cbutils.
add_scaled_outputs(output_data, model, state1_col="open", state2_col="closed")

# Drop mutants without a bias value and order the rest from most
# state-1-biased to least.
bias_col = f"{model}_state1_bias"
output_data = output_data.dropna(subset=[bias_col])
output_data = output_data.sort_values(by=bias_col, ascending=False)

# filter mutants by low scores
# A mutant passes when its scaled score is positive in at least one state.
passes_filter = (output_data[f"{model}_state1_scaled"] > 0) | (
    output_data[f"{model}_state2_scaled"] > 0
)
passing_mutants = output_data[passes_filter]
nonpassing = output_data[~passes_filter]

# take top n biased mutants in each direction
# passing_mutants is ordered by state-1 bias (descending), so its head holds
# the most state-1-biased mutants and its tail the most state-2-biased ones.
n_mutants_passing_filter = len(passing_mutants)
n_biased = round((frac_mutants / 2) * n_mutants_passing_filter)
# Never let the head and tail slices overlap.
n_biased = min(n_biased, n_mutants_passing_filter // 2)

# Use explicit positional bounds. With negative-index slices, n_biased == 0
# would turn `[-0:]` into "every row" and `[0:-0]` into "no rows", labelling
# every passing mutant as state-2 biased and none as neutral.
tail_start = n_mutants_passing_filter - n_biased
state1_biased = passing_mutants.iloc[:n_biased]
neutral = passing_mutants.iloc[n_biased:tail_start]
state2_biased = passing_mutants.iloc[tail_start:]

# Build each membership set once; rebuilding them inside the loop below would
# cost a full pass over the group for every mutant.
s1_set = set(state1_biased["mut"])
s2_set = set(state2_biased["mut"])
neutral_set = set(neutral["mut"])
nonpassing_set = set(nonpassing["mut"])

# Checked in this order, so the first matching group decides the label.
labelled_sets = (
    (s1_set, "state1"),
    (s2_set, "state2"),
    (neutral_set, "neutral"),
    (nonpassing_set, "low"),
)

# A mutant that belongs to none of the groups gets None.
assignments = [
    next((label for members, label in labelled_sets if m in members), None)
    for m in output_data["mut"]
]

# label mutants
output_data[f"{model}_assignment"] = assignments

# Marker colour for each assignment label.
cmap = dict(state1="red", state2="blue", neutral="grey", low="lightgrey")

# Split the table into low-scoring mutants and everything else.
assignment_labels = output_data[f"{model}_assignment"]
is_low = assignment_labels == "low"
passing = output_data[~is_low]
nonpassing = output_data[is_low]

# Smallest bias value found among the mutants selected for each state; used
# below to position the cutoff lines on the plot.
state1_cutoff = output_data.loc[
    assignment_labels == "state1", f"{model}_state1_bias"
].min()
state2_cutoff = output_data.loc[
    assignment_labels == "state2", f"{model}_state2_bias"
].min()

plt.figure(figsize=(10, 10))
# Derive the percentage from frac_mutants so the title stays correct when the
# fraction is changed (0.05 -> "Top 5% mutants").
plt.title(f"Conformational Design Mutants (Top {frac_mutants * 100:g}% mutants)")

# Scatter state-1 vs state-2 scaled scores, coloured by assignment.
# Low-scoring mutants are drawn second and more transparent.
for subset, opacity in ((passing, 0.7), (nonpassing, 0.25)):
    plt.scatter(
        subset[f"{model}_state1_scaled"],
        subset[f"{model}_state2_scaled"],
        marker="o",
        alpha=opacity,
        edgecolor="black",
        c=[cmap[label] for label in subset[f"{model}_assignment"]],
    )

# set limits to be equal on both axes
(xmin, xmax), (ymin, ymax) = plt.xlim(), plt.ylim()
umin = min(xmin, ymin)
umax = max(xmax, ymax)
plt.xlim(umin, umax)
plt.ylim(umin, umax)

# show cutoffs
# Each entry is (x coordinates, y coordinates) of one black boundary segment:
# the two axis-aligned edges of the low-scoring corner, then the two diagonal
# lines offset by the state-2 and state-1 bias cutoffs.
cutoff_segments = (
    ([umin, 0], [0, 0]),
    ([0, 0], [umin, 0]),
    ([-state2_cutoff, umax - state2_cutoff], [0, umax]),
    ([0, umax], [-state1_cutoff, umax - state1_cutoff]),
)
for seg_x, seg_y in cutoff_segments:
    plt.plot(seg_x, seg_y, color="black")

plt.xlabel(f"State 1 {model} Score")
plt.ylabel(f"State 2 {model} Score")

# label each section
# One caption per corner of the plot, inset by text_offset from the axis limits.
text_offset = 0.1
near_min = umin + text_offset
near_max = umax - text_offset

# (x, y, caption, horizontal alignment, vertical alignment)
corner_labels = (
    (near_max, near_max, "Neutral Mutants", "right", "top"),
    (near_max, near_min, "State 1 Bias Predicted Mutants", "right", "bottom"),
    (near_min, near_max, "State 2 Bias Predicted Mutants", "left", "top"),
    (near_min, near_min, "Low Scoring Mutants", "left", "bottom"),
)
for label_x, label_y, caption, h_align, v_align in corner_labels:
    plt.text(
        label_x,
        label_y,
        caption,
        horizontalalignment=h_align,
        verticalalignment=v_align,
    )