<a href="https://colab.research.google.com/github/PeptoneLtd/proteinmpnn_ddg/blob/main/ProteinMPNN_ddG_binding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#ProteinMPNN-ddG for binding affinity

Scores all possible point mutations of a protein to improving binding at an interface


In [1]:
#@title Install ProteinMPNN-ddG (and colabdesign)
import os
try:
  import proteinmpnn_ddg
except:
  os.system("pip install -q proteinmpnn_ddg[cuda12]@git+https://github.com/PeptoneLtd/proteinmpnn_ddg.git@paper")

from proteinmpnn_ddg import predict_logits_for_all_point_mutations_of_single_pdb

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
pd.options.display.float_format = '{:,.2f}'.format

from google.colab import files
from google.colab import data_table
data_table.disable_dataframe_formatter()

def get_pdb(pdb_code=""):
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  elif os.path.isfile(pdb_code):
    return pdb_code
  elif len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"
  else:
    os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
    return f"AF-{pdb_code}-F1-model_v3.pdb"

In [2]:
import warnings, os, re
warnings.simplefilter(action='ignore', category=FutureWarning)

os.system("mkdir -p output")

# USER OPTIONS
#@markdown # ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown (v_48_020 recommended)

#@markdown # Input Options
pdb='7WPH' #@param {type:"string"}
#@markdown (leave `pdb` as  blank to get an upload prompt)
binder_chains = "H,L" #@param {type:"string"}
receptor_chains = "A" #@param {type:"string"}
#@markdown (You can specify several chains, separating by commas e.g. "A,C")

#@markdown Only the chains specified will be loaded from the PDB file for prediction
# chains_to_predict = "" #@param {type:"string"}
# #@markdown (Leave `chains_to_predict` empty to predict all chains)

nrepeats = 1
seed = 42

# cleaning user options
binder_chains, receptor_chains = (re.sub("[^A-Za-z]+",",", chains).split(',') for chains in (
   binder_chains, receptor_chains)
)
chains = list(binder_chains) + list(receptor_chains)

pdb_path = get_pdb(pdb)

In [3]:
#@title Run ProteinMPNN-binding-ddG
#@markdown We compute the logit difference with and without the binding partner
#@markdown then take the difference to identify mutations that should improve
#@markdown binding to the receptor chains.

#@markdown We would suggest taking the mutations with the
#@markdown highest binding ddG after filtering for unbound ddG > 0

#@markdown (Positive values are good mutations, which strengthen binding/stability)

# # @markdown ```df[df.unbound_ddg>0].sort_values('binding_ddg', ascending=False)```
data = {}
for source, chains_to_predict, context_chains in [
    ('unbound', binder_chains, []),
    ('bound', binder_chains, receptor_chains)
]:
  dfs = []
  for chain in chains_to_predict:
    df = predict_logits_for_all_point_mutations_of_single_pdb(
        model_name,
        chains_to_predict+context_chains,
        pdb_path,
        nrepeat=nrepeats,
        seed=seed,
        chain_to_predict=chain,
        pad_inputs=False,
        apply_ddG_correction=True)
    df['chain'] = chain
    dfs.append(df)
  df = pd.concat(dfs).set_index(['chain','pre','pos', 'post']).rename(
      columns={'logit_difference_ddg': f'{source}_ddg'})[f'{source}_ddg']
  data[source] = df

df = pd.concat(data.values(), axis=1)
df['binding_ddg'] = df['bound_ddg'] - df['unbound_ddg']
df.to_csv('predictions.csv')

display(df[df.unbound_ddg>0].sort_values(
    'binding_ddg', ascending=False
    ).head(
      10
      ).style.hide().format(
          precision=2, decimal="."
          )
      )

unbound_ddg,bound_ddg,binding_ddg
0.49,1.37,0.88
0.52,1.12,0.6
0.49,1.09,0.6
1.2,1.77,0.56
0.74,1.29,0.55
0.69,1.17,0.48
1.35,1.81,0.46
1.06,1.47,0.42
0.54,0.95,0.41
0.19,0.56,0.37


In [4]:
#@title download predictions (optional)
from google.colab import files
files.download(f'predictions.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>