## CB with ProteinMPNN

In [1]:
from __future__ import annotations

import sys
import os
import pandas as pd
import matplotlib.pyplot as plt

sys.path.append("../utilities/")

from tqdm.notebook import tqdm
from colabdesign.mpnn import mk_mpnn_model
from cbutils import aa_code, get_chain_seq, get_chain_seq_for_scoring, make_consensus_sequence, setup_aligner, alignment_to_mapping, mapping_to_sequence, mpnn_score, add_scaled_outputs

# Turn off JAX/XLA's default up-front GPU memory preallocation for this process.
# NOTE(review): this is set *after* importing colabdesign (which presumably pulls in jax);
# it only takes effect if no JAX backend has been initialised yet — confirm, or move it
# above the imports to be safe.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

### Select structures and chains

In [2]:
# Input structures, one per conformational state, keyed by state name.
pdbs = dict(
    open="../pdbs/lpla/3a7r.pdb",
    closed="../pdbs/lpla/1x2g.pdb",
)

# Chain used for scoring in each structure (chain A in both states).
chains = {state: "A" for state in ("open", "closed")}

### Align sequences, generate mutants, and score

In [None]:
# get protein sequences from pdb files and form consensus sequence
seqs = {pdb: get_chain_seq(pdbs[pdb], chains[pdb]) for pdb in pdbs}
scoring_seqs = {pdb: get_chain_seq_for_scoring(pdbs[pdb], chains[pdb]) for pdb in pdbs} #skip any gaps

con_seq = make_consensus_sequence(list(seqs.values()))

# align each scoring sequence to the consensus sequence (first returned alignment is kept)
aligner = setup_aligner()
alignments = {pdb: aligner.align(con_seq, seq)[0] for pdb, seq in scoring_seqs.items()}

# create mapping of positions from consensus sequence to each pdb sequence
mappings = {pdb: alignment_to_mapping(alignment) for pdb, alignment in alignments.items()}

# create every single-point mutant of the consensus sequence; each mutant is labelled
# <original aa><1-based position><new aa>
muts = []
mut_seqs = []
for i, aa in enumerate(con_seq):
    for aa_new in aa_code:
        if aa_new != aa:
            mut_seqs.append(con_seq[:i] + aa_new + con_seq[i + 1 :])
            muts.append(f"{aa}{i+1}{aa_new}")

homooligomer = False  # if structure is a homooligomer
fix_pos = None  # don't fix any positions
inverse = True  # whether to invert the fix_pos selection (presumably a no-op while fix_pos is None)
model_name = "v_48_020"  # use default model checkpoint

# initialize proteinMPNN model once per session, but reload it whenever model_name
# differs from the checkpoint that was loaded: a bare "is mpnn_model defined?" check
# would silently keep scoring with a stale checkpoint after model_name is edited.
if "mpnn_model" not in globals() or globals().get("_mpnn_model_name") != model_name:
    mpnn_model = mk_mpnn_model(model_name)
    _mpnn_model_name = model_name

output_data = pd.DataFrame({'mut': muts, 'seq': mut_seqs})

# for each pdb file, score all mutations and save scores relative to WT score
for structure in pdbs:
    # per-structure lookups, hoisted out of the mutant loop
    scoring_seq = scoring_seqs[structure]
    mapping = mappings[structure]

    #load structure model
    mpnn_model.prep_inputs(
        pdb_filename=pdbs[structure],
        chain=chains[structure],
        homooligomer=homooligomer,
        fix_pos=fix_pos,
        inverse=inverse,
        verbose=True,
    )

    #score wild type (consensus) sequence, mapped onto this structure's sequence
    wt_seq = mapping_to_sequence(con_seq, scoring_seq, mapping)
    wt_score = mpnn_score(wt_seq, mpnn_model)

    #score mutants, each relative to the WT score of this structure
    output_scores = [
        mpnn_score(mapping_to_sequence(mut_seq, scoring_seq, mapping), mpnn_model) - wt_score
        for mut_seq in tqdm(mut_seqs)
    ]

    #save scores
    output_data["pmpnn_" + structure] = output_scores



lengths [337]


  0%|          | 0/6403 [00:00<?, ?it/s]

### Analysis

In [None]:
model = 'pmpnn'
frac_mutants = 0.05  # fraction of filter-passing mutants labelled biased, split evenly between the two states

# column names read below; presumably created by add_scaled_outputs (cbutils) — confirm there
state1_scaled_col, state2_scaled_col = f'{model}_state1_scaled', f'{model}_state2_scaled'
state1_bias_col, state2_bias_col = f'{model}_state1_bias', f'{model}_state2_bias'
assignment_col = f'{model}_assignment'

#scale columns and calculate bias
add_scaled_outputs(output_data, model, state1_col = 'open', state2_col = 'closed')

# drop mutants without a state 1 bias and rank the rest by state 1 bias, highest first
output_data = output_data.dropna(subset=[state1_bias_col]).sort_values(by=state1_bias_col, ascending=False)

# filter out low-scoring mutants: a mutant passes if its scaled score is > 0 in at least one state
passes_filter = (output_data[state1_scaled_col] > 0) | (output_data[state2_scaled_col] > 0)
passing_mutants = output_data[passes_filter]
nonpassing = output_data[~passes_filter]

# take top n biased mutants in each direction
n_mutants_passing_filter = len(passing_mutants)
n_biased = round((frac_mutants / 2) * n_mutants_passing_filter)

# Slice with an explicit end index instead of negative indices: when n_biased == 0,
# `[-n_biased:]` is `[0:]`, which would label *every* passing mutant as state 2.
# NOTE(review): the bottom of the state 1 bias ranking is taken as state 2 biased —
# this assumes state 2 bias orders inversely to state 1 bias; confirm in add_scaled_outputs.
neutral_end = n_mutants_passing_filter - n_biased
state1_biased = passing_mutants.iloc[:n_biased]
neutral = passing_mutants.iloc[n_biased:neutral_end]
state2_biased = passing_mutants.iloc[neutral_end:]

s1_set, s2_set, neutral_set, nonpassing_set = set(state1_biased['mut']), set(state2_biased['mut']), set(neutral['mut']), set(nonpassing['mut'])

# label every mutant using the membership sets built once above (not rebuilt per row)
assignments = []
for m in output_data['mut']:
    if m in s1_set:
        assignment = 'state1'
    elif m in s2_set:
        assignment = 'state2'
    elif m in neutral_set:
        assignment = 'neutral'
    elif m in nonpassing_set:
        assignment = 'low'
    else:
        assignment = None

    assignments.append(assignment)

#label mutants
output_data[assignment_col] = assignments

cmap = {'state1': 'red', 'state2': 'blue', 'neutral': 'grey', 'low': 'lightgrey'}

passing = output_data[output_data[assignment_col] != 'low']
nonpassing = output_data[output_data[assignment_col] == 'low']

# smallest bias among the mutants labelled for each state (NaN if a state has none)
state1_cutoff = output_data.loc[output_data[assignment_col] == 'state1', state1_bias_col].min()
state2_cutoff = output_data.loc[output_data[assignment_col] == 'state2', state2_bias_col].min()

plt.figure(figsize = (10,10))
# derive the percentage from frac_mutants so the title cannot drift from the selection
plt.title(f'Conformational Design Mutants (Top {frac_mutants * 100:g}% mutants)')

plt.scatter(passing[state1_scaled_col], passing[state2_scaled_col], marker = 'o', alpha = 0.7, edgecolor = 'black', c=[cmap[x] for x in passing[assignment_col]])
plt.scatter(nonpassing[state1_scaled_col], nonpassing[state2_scaled_col], marker = 'o', alpha = 0.25, edgecolor = 'black', c=[cmap[x] for x in nonpassing[assignment_col]])

# set limits to be equal on both axes
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()

umin, umax = min(xmin, ymin), max(xmax, ymax)
plt.xlim(umin, umax)
plt.ylim(umin, umax)

#show cutoffs: the low-score filter boundary along both negative half-axes ...
plt.plot([umin, 0], [0,0], color = 'black')
plt.plot([0, 0], [umin,0], color = 'black')

# ... and the diagonal bias cutoffs (y = x + state2_cutoff and y = x - state1_cutoff)
plt.plot([-state2_cutoff, umax-state2_cutoff], [0, umax], color = 'black')
plt.plot([0, umax], [-state1_cutoff, umax -state1_cutoff], color = 'black')

plt.xlabel(f'State 1 {model} Score')
plt.ylabel(f"State 2 {model} Score")

#label each section
text_offset = 0.1
plt.text(umax - text_offset, umax - text_offset, 'Neutral Mutants', horizontalalignment = 'right', verticalalignment = 'top')
plt.text(umax - text_offset, umin + text_offset, 'State 1 Bias Predicted Mutants', horizontalalignment = 'right', verticalalignment = 'bottom')
plt.text(umin + text_offset, umax - text_offset, 'State 2 Bias Predicted Mutants', horizontalalignment = 'left', verticalalignment = 'top')
plt.text(umin + text_offset, umin + text_offset, 'Low Scoring Mutants', horizontalalignment = 'left', verticalalignment = 'bottom')

#