## CB with ESM-IF1

In [None]:
from __future__ import annotations

import sys

sys.path.append("../utilities/")

from cbutils import (
    aa_code,
    make_consensus_sequence,
    setup_aligner,
    alignment_to_mapping,
    mapping_to_sequence,
    AA_ALPHABET,
    add_scaled_outputs,
)

from tqdm.notebook import tqdm

import torch
import esm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from esm.inverse_folding.util import CoordBatchConverter
from scipy.special import softmax

### Select structures and chains

In [None]:
# Input structure files, keyed by conformational state label.
pdbs = dict(
    open="../pdbs/lpla/3a7r.pdb",
    closed="../pdbs/lpla/1x2g.pdb",
)

# Chain ID to read from each structure (same keys as ``pdbs``).
chains = dict(open="A", closed="A")

### Align sequences, make mutants, and score

In [None]:
# load model and alphabet
# Run on the first CUDA GPU when one is available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()

# Move the weights to ``device`` (instead of an unconditional ``.cuda()``) so
# the notebook also runs on CPU-only machines. eval() switches off
# training-mode layers such as dropout.
model = model.to(device)
model = model.eval()

# set up scoring functions
def get_esmif_aaprobs(model, alphabet, coords, seq):
    """
    Compute per-position amino acid probabilities for a sequence on a structure with ESM-IF1.

    The decoder is fed the tokenised sequence minus its last token
    (``tokens[:, :-1]``). Each output column is therefore the model's
    prediction for one residue, given the backbone and the preceding tokens.

    Args:
        model: The ESM-IF1 model. Inputs are placed on the device of its parameters.
        alphabet: The ESM alphabet object matching ``model``.
        coords: Backbone coordinates for the structure, e.g. as returned by
            ``esm.inverse_folding.util.extract_coords_from_structure``.
        seq: Amino acid sequence (string).

    Returns:
        aa_probs: DataFrame with one row per amino acid in ``AA_ALPHABET`` and
        one column per output position (0-based integer labels). Each column is
        renormalised to sum to 1 over those amino acids.
    """
    # Put all inputs on the same device as the model weights.
    device = next(model.parameters()).device

    # Tokenise the sequence and pad/batch the coordinates (batch of one).
    batch_converter = CoordBatchConverter(alphabet)
    batch_coords, confidence, _, tokens, padding_mask = batch_converter(
        [(coords, None, seq)], device=device
    )

    # Decoder input: the token sequence without its last token.
    prev_output_tokens = tokens[:, :-1].to(device)

    # Inference only: no_grad avoids building an autograd graph on every call.
    # This matters because the function is run once per mutant.
    with torch.no_grad():
        logits, _ = model.forward(
            batch_coords, padding_mask, confidence, prev_output_tokens
        )
    # NOTE(review): assumes logits are (1, vocab, L), so after squeeze the rows
    # index the alphabet, matching the DataFrame index below.
    logits = logits.cpu().numpy().squeeze()

    # Softmax over the vocabulary axis (rows). scipy's default axis=None would
    # normalise over the whole matrix. The per-column renormalisation below
    # cancels that mathematically, but a global max shift can underflow whole
    # columns to 0/0.
    aa_probs = pd.DataFrame(
        softmax(logits, axis=0),
        index=list(alphabet.to_dict().keys()),
    )
    # Restrict to standard amino acids.
    aa_probs = aa_probs.loc[list(AA_ALPHABET), :]
    # Renormalise so each position's probabilities sum to 1 over those residues.
    aa_probs = aa_probs / aa_probs.sum(axis=0)

    return aa_probs


def get_esmif_score(seq, model, alphabet, coords):
    """
    Score a sequence on a structure as its summed ESM-IF1 log-probability.

    Args:
        seq: Amino acid sequence (string) to score.
        model: The ESM-IF1 model.
        alphabet: The ESM alphabet object.
        coords: Backbone coordinates for the structure.

    Returns:
        score: Sum over positions ``i`` of ``log P(seq[i])``, read from
        column ``i`` of the table returned by ``get_esmif_aaprobs``.
        It is 0.0 for an empty sequence.
    """
    probs = get_esmif_aaprobs(model, alphabet, coords, seq)
    # Accumulate left-to-right, starting from 0.0.
    return sum(
        (np.log(probs.loc[residue, position]) for position, residue in enumerate(seq)),
        0.0,
    )

# Parse the selected chain of each structure and collect its native sequence
# and backbone coordinates, keyed by state label.
seqs, coords = {}, {}
for pdb in pdbs:
    structure = esm.inverse_folding.util.load_structure(pdbs[pdb], chains[pdb])
    coord, native_seq = esm.inverse_folding.util.extract_coords_from_structure(
        structure
    )
    coords[pdb] = coord
    seqs[pdb] = native_seq

# One consensus sequence across all states; mutants are defined on it.
con_seq = make_consensus_sequence([*seqs.values()])

# Align the consensus against each native sequence and keep the first
# alignment returned. Then convert each alignment with alignment_to_mapping
# (presumably consensus -> native positions; see cbutils).
aligner = setup_aligner()
alignments = {
    state: aligner.align(con_seq, native)[0] for state, native in seqs.items()
}
mappings = dict(zip(alignments, map(alignment_to_mapping, alignments.values())))

# Enumerate every single-residue substitution of the consensus sequence.
# Names are wild-type residue + 1-based position + new residue
# (e.g. "A12G" when aa_code holds one-letter codes).
muts, mut_seqs = [], []
for i, aa in enumerate(con_seq):
    for aa_new in aa_code:
        if aa_new == aa:
            continue  # identity substitution is not a mutation
        muts.append(f"{aa}{i+1}{aa_new}")
        mut_seqs.append(con_seq[:i] + aa_new + con_seq[i + 1 :])

output_data = pd.DataFrame(dict(mut=muts, seq=mut_seqs))

# For each state, score every mutant relative to the consensus ("wild type")
# mapped onto that state's structure. The result is stored in the column
# "esmif1_<state>".
for structure in pdbs:
    state_native, state_mapping, state_coords = (
        seqs[structure],
        mappings[structure],
        coords[structure],
    )

    wt_seq = mapping_to_sequence(con_seq, state_native, state_mapping)
    wt_score = get_esmif_score(wt_seq, model, alphabet, state_coords)

    output_scores = [
        get_esmif_score(
            mapping_to_sequence(mut_seq, state_native, state_mapping),
            model,
            alphabet,
            state_coords,
        )
        - wt_score
        for mut_seq in tqdm(mut_seqs)
    ]

    output_data["esmif1_" + structure] = output_scores

# Analysis

In [None]:
# Prefix shared by the score columns ("esmif1_open" / "esmif1_closed") and by
# the derived columns read below.
# NOTE(review): this rebinds ``model``, which previously held the ESM-IF1
# network; re-run the model-loading cell before scoring again.
model = "esmif1"
# Fraction of passing mutants to call biased (split between the two states).
frac_mutants = 0.05

# scale columns and calculate bias
add_scaled_outputs(output_data, model, state1_col="open", state2_col="closed")

# Drop rows without a bias value, then rank by state-1 bias, highest first.
output_data = output_data.dropna(subset=[f"{model}_state1_bias"])
output_data = output_data.sort_values(by=f"{model}_state1_bias", ascending=False)

# filter mutants by low scores: a mutant passes if its scaled score is
# positive in at least one state.
passes_filter = (output_data[f"{model}_state1_scaled"] > 0) | (
    output_data[f"{model}_state2_scaled"] > 0
)
passing_mutants = output_data[passes_filter]
nonpassing = output_data[~passes_filter]

# take top n biased mutants in each direction
# ``passing_mutants`` is ordered by descending state-1 bias. Its head is
# labelled state-1 biased, its tail state-2 biased, and the rest neutral.
n_mutants_passing_filter = len(passing_mutants)
n_biased = round((frac_mutants / 2) * n_mutants_passing_filter)

# Use explicit positional bounds rather than ``[-n_biased:]``. When n_biased
# rounds to 0, ``[-0:]`` would select *every* passing mutant as state-2
# biased and leave the neutral set empty.
state2_start = n_mutants_passing_filter - n_biased
state1_biased, neutral, state2_biased = (
    passing_mutants.iloc[:n_biased],
    passing_mutants.iloc[n_biased:state2_start],
    passing_mutants.iloc[state2_start:],
)

# Build each membership set once, so the per-mutant lookups below are O(1).
s1_set, s2_set, neutral_set, nonpassing_set = (
    set(state1_biased["mut"]),
    set(state2_biased["mut"]),
    set(neutral["mut"]),
    set(nonpassing["mut"]),
)

assignments = []
for m in output_data["mut"]:
    if m in s1_set:
        assignment = "state1"
    elif m in s2_set:
        assignment = "state2"
    elif m in neutral_set:
        assignment = "neutral"
    elif m in nonpassing_set:
        assignment = "low"
    else:
        assignment = None

    assignments.append(assignment)

# label mutants
output_data[f"{model}_assignment"] = assignments

# Colour used for each assignment label in the scatter plot.
cmap = {"state1": "red", "state2": "blue", "neutral": "grey", "low": "lightgrey"}

passing = output_data[output_data[f"{model}_assignment"] != "low"]
nonpassing = output_data[output_data[f"{model}_assignment"] == "low"]

# Smallest bias among the mutants selected in each direction (NaN if none were).
state1_cutoff = output_data[output_data[f"{model}_assignment"] == "state1"][
    f"{model}_state1_bias"
].min()
state2_cutoff = output_data[output_data[f"{model}_assignment"] == "state2"][
    f"{model}_state2_bias"
].min()

plt.figure(figsize=(10, 10))
# Derive the percentage from frac_mutants so the title tracks the actual setting.
plt.title(f"Conformational Design Mutants (Top {frac_mutants * 100:g}% mutants)")

# Passing mutants first (more opaque), then low-scoring mutants (faded).
for subset, point_alpha in ((passing, 0.7), (nonpassing, 0.25)):
    plt.scatter(
        subset[f"{model}_state1_scaled"],
        subset[f"{model}_state2_scaled"],
        marker="o",
        alpha=point_alpha,
        edgecolor="black",
        c=[cmap[x] for x in subset[f"{model}_assignment"]],
    )

# set limits to be equal on both axes
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()

umin, umax = min(xmin, ymin), max(xmax, ymax)
plt.xlim(umin, umax)
plt.ylim(umin, umax)

# show cutoffs: the negative half-axes bound the low-scoring quadrant.
plt.plot([umin, 0], [0, 0], color="black")
plt.plot([0, 0], [umin, 0], color="black")

# Diagonal thresholds: y = x + state2_cutoff and y = x - state1_cutoff.
plt.plot([-state2_cutoff, umax - state2_cutoff], [0, umax], color="black")
plt.plot([0, umax], [-state1_cutoff, umax - state1_cutoff], color="black")

plt.xlabel(f"State 1 {model} Score")
plt.ylabel(f"State 2 {model} Score")

# label each section
text_offset = 0.1
for text_x, text_y, label, h_align, v_align in (
    (umax - text_offset, umax - text_offset, "Neutral Mutants", "right", "top"),
    (
        umax - text_offset,
        umin + text_offset,
        "State 1 Bias Predicted Mutants",
        "right",
        "bottom",
    ),
    (
        umin + text_offset,
        umax - text_offset,
        "State 2 Bias Predicted Mutants",
        "left",
        "top",
    ),
    (umin + text_offset, umin + text_offset, "Low Scoring Mutants", "left", "bottom"),
):
    plt.text(
        text_x,
        text_y,
        label,
        horizontalalignment=h_align,
        verticalalignment=v_align,
    )