## CB with Frame2seq

In [None]:
from __future__ import annotations

import sys

sys.path.append("../utilities/")

from cbutils import (
    aa_code,
    get_chain_seq,
    get_chain_seq_for_scoring,
    make_consensus_sequence,
    setup_aligner,
    alignment_to_mapping,
    mapping_to_sequence,
    add_scaled_outputs,
)

import pandas as pd
from tqdm.notebook import tqdm
import torch
import matplotlib.pyplot as plt

from frame2seq import Frame2seqRunner
from frame2seq.utils import residue_constants
from frame2seq.utils.util import get_neg_pll
from frame2seq.utils.pdb2input import get_inference_inputs

### Select structures and chains

In [None]:
# Structure file for each conformational state, keyed by state label.
pdbs = dict(
    open="../pdbs/lpla/3a7r.pdb",
    closed="../pdbs/lpla/1x2g.pdb",
)

# Chain identifier to use from each structure, keyed by the same state labels.
chains = dict(open="A", closed="A")

### Align sequences, make mutants, and score

In [None]:
# set up frame2seq scoring function
def frame2seq_score(runner, pdb_file: str, chain_id: str, input_seqs: list[str]):
    """
    Score input sequences against a structure with a Frame2seq model ensemble.

    The structure is run once through every model in ``runner.models`` with all
    sequence positions masked. The model outputs are averaged, turned into
    per-position amino acid probabilities with a softmax, and each input
    sequence is then scored against those probabilities with ``get_neg_pll``.

    Args:
        runner: Frame2seqRunner.
            An initialized Frame2seqRunner object. Its ``models`` and ``device``
            attributes are used.
        pdb_file: str
            Path to a PDB file containing the desired protein structure.
        chain_id: str
            Chain identifier (e.g., 'A') corresponding to the chain of interest in the PDB file.
        input_seqs: list of str
            List of amino acid sequences to be evaluated against the structure.
            Every character must be a key of ``residue_constants.AA_TO_ID``.
            Sequences are expected to be length matched to the structure; this
            is not verified here.

    Returns:
        scores: list of float
            One score per input sequence, in input order. Each score is the
            sign-flipped ``avg_neg_pll`` value returned by ``get_neg_pll``,
            i.e. a pseudo-log-likelihood (not its negative). Higher (less
            negative) values indicate sequences more compatible with the structure.
    """
    device = runner.device

    # Get structure-based input tensors for inference and put them on the same
    # device as the one-hot input and the encoded sequences built below.
    seq_mask, backbone_seq_tokenized, X = get_inference_inputs(pdb_file, chain_id)
    seq_mask = seq_mask.to(device)
    X = X.to(device)

    # Decode the backbone sequence and one-hot encode it. Only the shape of
    # this array is needed: the model input is fully masked below.
    backbone_seq = [
        residue_constants.ID_TO_AA[int(i)] for i in backbone_seq_tokenized[0]
    ]
    onehot_template = residue_constants.sequence_to_onehot(
        sequence=backbone_seq,
        mapping=residue_constants.AA_TO_ID,
    )

    # Mask all positions in sequence by setting them to 'X' (unknown amino acid)
    masked_onehot = torch.zeros(
        (1, *onehot_template.shape), dtype=torch.float32, device=device
    )
    masked_onehot[:, :, 20] = 1  # 20 = 'X', mask all positions

    # Convert each input sequence to a tensor of residue IDs on the runner device
    encoded_seqs = [
        torch.tensor(
            [residue_constants.AA_TO_ID[aa] for aa in seq],
            dtype=torch.long,
            device=device,
        )
        for seq in input_seqs
    ]

    scores = []  # one score per input sequence

    with torch.no_grad():
        # Average the raw (pre-softmax) outputs of every ensemble model
        logits = sum(
            model.forward(X, seq_mask, masked_onehot) for model in runner.models
        ) / len(runner.models)

        # Apply softmax to obtain amino acid probability distributions
        aaprobs = torch.nn.functional.softmax(logits, dim=-1)

        # Only keep probabilities at valid sequence mask positions
        aaprobs = aaprobs[seq_mask]

        # Score every sequence against the shared probability table
        for seq_ids in tqdm(encoded_seqs):
            _neg_pll, avg_neg_pll = get_neg_pll(aaprobs, seq_ids)
            scores.append(-1 * avg_neg_pll)  # multiply by -1 to return PLL

    return scores

# Read each chain's sequence twice: once via get_chain_seq (used to build the
# consensus) and once via get_chain_seq_for_scoring (used for alignment/scoring).
seqs = {
    state: get_chain_seq(path, chains[state])
    for state, path in pdbs.items()
}
scoring_seqs = {
    state: get_chain_seq_for_scoring(path, chains[state])
    for state, path in pdbs.items()
}
con_seq = make_consensus_sequence([*seqs.values()])

# Align the consensus against every scoring sequence, keeping the first
# alignment the aligner returns, then convert each alignment to a mapping.
aligner = setup_aligner()
alignments = {
    state: aligner.align(con_seq, target)[0]
    for state, target in scoring_seqs.items()
}
mappings = {
    state: alignment_to_mapping(aln)
    for state, aln in alignments.items()
}

# Enumerate every single-residue substitution of the consensus sequence.
# `muts` holds labels of the form <original><1-based position><replacement>;
# `mut_seqs` holds the matching full-length mutant sequences, in the same order.
muts, mut_seqs = [], []
for i, aa in enumerate(con_seq):
    upstream, downstream = con_seq[:i], con_seq[i + 1 :]
    for aa_new in aa_code:
        if aa_new == aa:
            # substituting a residue with itself is not a mutant
            continue
        mut_seqs.append(upstream + aa_new + downstream)
        muts.append(f"{aa}{i+1}{aa_new}")

# load frame2seq runner object
runner = Frame2seqRunner()

# score sequences: one row per mutant, one score column per structure
output_data = pd.DataFrame({"mut": muts, "seq": mut_seqs})

for structure in pdbs:
    # Map the consensus (wild type) and every mutant onto this structure's
    # scoring sequence using the alignment-derived mapping.
    wt_seq = mapping_to_sequence(con_seq, scoring_seqs[structure], mappings[structure])
    output_seqs = [
        mapping_to_sequence(mut_seq, scoring_seqs[structure], mappings[structure])
        for mut_seq in mut_seqs
    ]

    # Save the mapped mutants as FASTA records (previously the file was
    # created but left empty).
    with open(f"{structure}.fasta", "w", encoding="utf-8") as f:
        f.writelines(
            f">{mut}\n{mapped_seq}\n" for mut, mapped_seq in zip(muts, output_seqs)
        )

    # Score wild type and mutants in one call so the model ensemble only runs
    # once per structure; the wild type is the first entry.
    all_scores = frame2seq_score(
        runner, pdbs[structure], chains[structure], [wt_seq, *output_seqs]
    )
    wt_score, outs = all_scores[0], all_scores[1:]

    # Report each mutant relative to the wild-type score on this structure
    output_data[f"frame2seq_{structure}"] = [x - wt_score for x in outs]

## Analysis

In [None]:
model = "frame2seq"
frac_mutants = 0.05  # total fraction of passing mutants labelled as biased (split between both states)

# scale columns and calculate bias (adds columns to output_data in place)
add_scaled_outputs(output_data, model, state1_col="open", state2_col="closed")

bias_col = f"{model}_state1_bias"
state1_scaled_col = f"{model}_state1_scaled"
state2_scaled_col = f"{model}_state2_scaled"

# drop mutants without a bias value and order from most to least state1 bias
output_data = output_data.dropna(subset=[bias_col]).sort_values(
    by=bias_col, ascending=False
)

# filter mutants by low scores: a mutant passes if it scores above zero in
# at least one state
passes_filter = (output_data[state1_scaled_col] > 0) | (
    output_data[state2_scaled_col] > 0
)
passing_mutants = output_data[passes_filter]
nonpassing = output_data[~passes_filter]

# take top n biased mutants in each direction
n_mutants_passing_filter = len(passing_mutants)
n_biased = round((frac_mutants / 2) * n_mutants_passing_filter)

# Use an explicit end index: a negative slice bound of -0 is just 0, which
# would put every passing mutant in the state2 group when n_biased == 0.
state2_start = n_mutants_passing_filter - n_biased
state1_biased = passing_mutants.iloc[:n_biased]
neutral = passing_mutants.iloc[n_biased:state2_start]
state2_biased = passing_mutants.iloc[state2_start:]

s1_set, s2_set, neutral_set, nonpassing_set = (
    set(state1_biased["mut"]),
    set(state2_biased["mut"]),
    set(neutral["mut"]),
    set(nonpassing["mut"]),
)

# Label every mutant using the sets built once above.
assignments = []
for m in output_data["mut"]:
    if m in s1_set:
        assignment = "state1"
    elif m in s2_set:
        assignment = "state2"
    elif m in neutral_set:
        assignment = "neutral"
    elif m in nonpassing_set:
        assignment = "low"
    else:
        assignment = None

    assignments.append(assignment)

# label mutants
output_data[f"{model}_assignment"] = assignments

# Point colour for each assignment label
cmap = {"state1": "red", "state2": "blue", "neutral": "grey", "low": "lightgrey"}

assignment_col = f"{model}_assignment"
state1_scaled_col = f"{model}_state1_scaled"
state2_scaled_col = f"{model}_state2_scaled"

passing = output_data[output_data[assignment_col] != "low"]
nonpassing = output_data[output_data[assignment_col] == "low"]

# Smallest bias value among the mutants assigned to each state; used to draw
# the diagonal cutoff lines below.
state1_cutoff = output_data[output_data[assignment_col] == "state1"][
    f"{model}_state1_bias"
].min()
state2_cutoff = output_data[output_data[assignment_col] == "state2"][
    f"{model}_state2_bias"
].min()

plt.figure(figsize=(10, 10))
# Derive the percentage from frac_mutants so the title tracks the setting
plt.title(f"Conformational Design Mutants (Top {frac_mutants * 100:g}% mutants)")

# Passing mutants are drawn first and more opaque; low-scoring ones are faded
for subset, alpha in ((passing, 0.7), (nonpassing, 0.25)):
    plt.scatter(
        subset[state1_scaled_col],
        subset[state2_scaled_col],
        marker="o",
        alpha=alpha,
        edgecolor="black",
        c=[cmap[label] for label in subset[assignment_col]],
    )

# set limits to be equal on both axes
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()

umin, umax = min(xmin, ymin), max(xmax, ymax)
plt.xlim(umin, umax)
plt.ylim(umin, umax)

# show cutoffs: the two axis segments bound the low-scoring quadrant ...
plt.plot([umin, 0], [0, 0], color="black")
plt.plot([0, 0], [umin, 0], color="black")

# ... and the two diagonals are offset from y = x by the state cutoffs
plt.plot([-state2_cutoff, umax - state2_cutoff], [0, umax], color="black")
plt.plot([0, umax], [-state1_cutoff, umax - state1_cutoff], color="black")

plt.xlabel(f"State 1 {model} Score")
plt.ylabel(f"State 2 {model} Score")

# label each section: (x, y, text, horizontal alignment, vertical alignment)
text_offset = 0.1
corner_labels = (
    (umax - text_offset, umax - text_offset, "Neutral Mutants", "right", "top"),
    (
        umax - text_offset,
        umin + text_offset,
        "State 1 Bias Predicted Mutants",
        "right",
        "bottom",
    ),
    (
        umin + text_offset,
        umax - text_offset,
        "State 2 Bias Predicted Mutants",
        "left",
        "top",
    ),
    (umin + text_offset, umin + text_offset, "Low Scoring Mutants", "left", "bottom"),
)
for label_x, label_y, label_text, h_align, v_align in corner_labels:
    plt.text(
        label_x,
        label_y,
        label_text,
        horizontalalignment=h_align,
        verticalalignment=v_align,
    )