<a href="https://colab.research.google.com/github/artemg97/af2bind_prod/blob/main/AF2BIND_prod.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### AF2BIND (v.1.1) -  Prediction of binding sites in protein-ligand and protein-peptide complexes

AF2BIND is a simple, fast notebook that runs inference on the output of [AlphaFold](https://github.com/deepmind/alphafold).


The method uses the binder protocol of the [ColabDesign](https://github.com/sokrypton/ColabDesign) framework to identify binding sites in protein-peptide and protein-ligand complexes.

Authors/Collaborators :

*   Artem Gazizov (agazizov@fas.harvard.edu)
*    Sergey Ovchinnikov (so@fas.harvard.edu)
*    Nicholas Polizzi (nicholasf_polizzi@dfci.harvard.edu)


<figure>
<center>
<img src='https://drive.google.com/uc?export=view&id=1fHB9irpruKRUQBIEd45pp9go4QKsFigg'  width="300" height="150"  align="right" />
</center>
</figure>






In [None]:
#@title Install Colabdesign
#@markdown Please execute this cell by pressing the *Play* button on
#@markdown the left.


#@markdown **Note**: This installs the Colabdesign framework
%%time
%shell pip install py3dmol
import os

# atomium is installed unconditionally, i.e. on every execution of this cell.
os.system("pip install atomium")
# The presence of the "params" directory is used as the marker that the
# one-time setup below (ColabDesign install + AlphaFold weights) already ran.
# NOTE(review): the return codes of os.system are ignored, so a failed
# download still leaves "params" behind and the setup is not retried.
if not os.path.isdir("params"):
  # install ColabDesign from its GitHub repository
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git")
  # for debugging: symlink the installed package into the working directory
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download the AlphaFold parameter archive and unpack it into params/
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar")
  os.system("tar -xf alphafold_params_2022-03-02.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file for ``pdb_code``.

  Resolution order:
    1. ``None`` or ``""``: prompt for an upload and save the first uploaded
       file as ``tmp.pdb``.
    2. An existing file path: returned unchanged.
    3. A 4-character string: treated as an RCSB entry id and fetched from
       ``files.rcsb.org``.
    4. Anything else: treated as an AlphaFold DB accession and fetched from
       ``alphafold.ebi.ac.uk`` (``model_v3`` file name is hard-coded).

  Args:
    pdb_code: file path, 4-character RCSB id, AlphaFold DB accession, or
      empty/None to upload a file.

  Returns:
    The file name the structure is (expected to be) stored under.

  Raises:
    ValueError: if an upload was requested but no file was provided.

  Note:
    The download is best-effort: wget's exit status is not checked, so the
    returned file may not exist if the download failed.
  """
  # Local import so this block does not depend on the cell's import list.
  import subprocess

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    # A cancelled upload yields no entries; fail with a clear message
    # instead of an IndexError on the first key.
    if not upload_dict:
      raise ValueError("No file was uploaded; please select a PDB file.")
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb","wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    filename = f"{pdb_code}.pdb"
    url = f"https://files.rcsb.org/view/{filename}"
  else:
    filename = f"AF-{pdb_code}-F1-model_v3.pdb"
    url = f"https://alphafold.ebi.ac.uk/files/{filename}"

  # Pass the arguments as a list (no shell) so that characters in the
  # user-supplied code cannot be interpreted as shell syntax.
  # -q: quiet, -nc: do not re-download a file that already exists.
  subprocess.run(["wget", "-qnc", url], check=False)
  return filename

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting py3dmol
  Downloading py3Dmol-2.0.3-py2.py3-none-any.whl (12 kB)
Installing collected packages: py3dmol
Successfully installed py3dmol-2.0.3
CPU times: user 1.42 s, sys: 286 ms, total: 1.7 s
Wall time: 2min 19s


In [None]:
#@title **Upload PDB**
from google.colab import files
upload_dict = files.upload()

#@markdown - Please indicate the target chain
target_chain = "A" #@param {type:"string"}
model_type = 'ligand_model' #@param ["ligand_model", "peptide_model"]

# A cancelled upload dialog leaves upload_dict empty; stop with a clear
# message instead of an IndexError when taking the first entry.
if not upload_dict:
  raise ValueError("No file was uploaded; re-run this cell and select a PDB file.")

# Only the first uploaded file is used; it is stored as tmp.pdb, the path
# the later cells read the structure from.
pdb_string = next(iter(upload_dict.values()))

with open("tmp.pdb","wb") as out:
  out.write(pdb_string)


In [None]:
#@title **Run ColabFold** 🔬
clear_mem()

# Probe "binder" sequence: one residue of each of the 20 standard amino
# acids. Its length drives binder_len and the feature extraction below, so
# the value 20 is no longer repeated as a magic number.
binder_seq = "ACDEFGHIKLMNPQRSTVWY"
binder_len = len(binder_seq)

af_model = mk_afdesign_model(protocol="binder", debug=True)
af_model.prep_inputs(pdb_filename="tmp.pdb", chain=target_chain, binder_len=binder_len)


print("target_length",af_model._target_len)
print("binder_length",af_model._binder_len)

af_model.predict(seq=binder_seq, num_recycles=0)

# Pair representation exposed by debug=True; indexed as [row, column, channel].
pair_repr = np.asarray(af_model.aux["debug"]["outputs"]["representations"]["pair"])
target_len = af_model._target_len

# For every target residue (row) take its pair features against each binder
# position (the columns directly after the target) and flatten them into one
# vector, binder position by binder position. This is a single slice/reshape
# instead of growing each vector with np.concatenate in a Python loop.
target_vs_binder = pair_repr[:target_len, target_len:target_len + binder_len]
flat_width = binder_len * pair_repr.shape[-1]

# float64 matches what concatenating onto an empty Python list produced;
# a list of 1-D arrays keeps the type of residues_repr unchanged.
residues_repr = list(target_vs_binder.reshape(target_len, flat_width).astype(np.float64))

target_length 232
binder_length 20
predict models [0] recycles 0 hard 1 soft 0 temp 1 loss 4.77 i_con 4.70 plddt 0.35 ptm 0.84 i_ptm 0.21


In [None]:
#@title  **Load model weights** 🤖
from tensorflow.keras.models import Sequential, model_from_json
import sklearn
from sklearn.preprocessing import StandardScaler
from pickle import load
import os

# Download recipe per model type. "gdown_ids" are fetched in order into the
# working directory, then every name in "downloads" is moved into "dir".
_MODEL_ASSETS = {
  "ligand_model": {
    "dir": "model_ligand_weights",
    "gdown_ids": [
      "1bfD3N5jBXFPr_DyrASfqsalPeX3rgiWS",
      "1NYiiXvmv-WuAHLoTnakLlHwdy5FtS6BG",
      "1U8X09G1TNG6jpbWOP6sLkBMVsasHvqYx",
      "1nOfZoOLmEOucDRxCmL4D1itEcc22cah5",
    ],
    "downloads": ["model_ligand.h5", "model_ligand.json",
                  "scaler_model_ligand.pkl", "trainset.csv"],
    "architecture": "model_ligand.json",
    "weights": "model_ligand.h5",
    "scaler": "scaler_model_ligand.pkl",
  },
  "peptide_model": {
    "dir": "model_peptide_weights",
    "gdown_ids": [
      "1TS8q6lulqtL0xHa66HDVxKvNj7yyrP0k",
      "1TJ0frY6hr3kD11gdC21_R9VtUfi_ZKls",
      "1Lwz-lSNn2pr_cJLD_z2DtHnXRHj0MlCG",
    ],
    "downloads": ["model_peptide.h5", "model_peptide.json",
                  "scaler_peptide.pkl"],
    "architecture": "model_peptide.json",
    "weights": "model_peptide.h5",
    "scaler": "scaler_peptide.pkl",
  },
}

# Any value other than "ligand_model" selects the peptide model.
_assets = _MODEL_ASSETS["ligand_model" if model_type == "ligand_model" else "peptide_model"]
_weights_dir = _assets["dir"]

# The directory doubles as the "already downloaded" marker.
# NOTE(review): command return codes are ignored, so a partially failed
# download is not retried until the directory is removed.
if not os.path.isdir(_weights_dir):
  os.makedirs(_weights_dir, exist_ok=True)
  os.system("pip install gdown")
  for _file_id in _assets["gdown_ids"]:
    os.system("gdown --id " + _file_id)
  # move the downloaded files into the model folder
  for _filename in _assets["downloads"]:
    os.system("mv " + _filename + " " + _weights_dir)

# Architecture (JSON) first, then the weights are loaded into that model.
with open(os.path.join(_weights_dir, _assets["architecture"]), 'r') as json_file:
  loaded_model_json = json_file.read()

model = model_from_json(loaded_model_json)
model.load_weights(os.path.join(_weights_dir, _assets["weights"]))

# SECURITY: unpickling executes arbitrary code from the file; only use scaler
# files from a trusted source. The context manager closes the file handle,
# which the previous load(open(...)) form leaked.
with open(os.path.join(_weights_dir, _assets["scaler"]), 'rb') as _scaler_file:
  scaler = load(_scaler_file)

print("Model and scaler loaded model from disk")


In [None]:
#@title **Scale inputs and get the prediction**
import atomium
#@markdown - Please specify the range :)
top_n = "15" #@param {type:"string"}
top_n=int(top_n)

# Apply the loaded scaler to the per-residue features, then run the model.
pw_scaled=scaler.transform(residues_repr)
Y_submit=model.predict(pw_scaled)


pdb = atomium.open("tmp.pdb")

# Look the chain up once instead of on every loop iteration.
_chain = pdb.model.chain(target_chain)
_n_chain = len(_chain[:])
if _n_chain != len(Y_submit):
  # Scores are matched to residues purely by position, so a length mismatch
  # means labels may be misaligned.
  print("WARNING: chain", target_chain, "has", _n_chain,
        "residues but", len(Y_submit), "predictions were produced;",
        "residue labels may be misaligned.")

# Map the string form of each residue to its predicted score
# (first output column of the model).
residues_dict={}
for _idx in range(min(_n_chain, len(Y_submit))):
  residues_dict[str(_chain[_idx])]=Y_submit[_idx][0]


residues_confidence = sorted(residues_dict.items(), key=lambda x:x[1],reverse=True)

# Never report more residues than exist; range(top_n) used to raise an
# IndexError when top_n exceeded the chain length.
_n_report = max(0, min(top_n, len(residues_confidence)))
_top_residues = residues_confidence[:_n_report]

print("\n 🧪 Top",_n_report, "binding residues sorted by confidence: ")
for _entry in _top_residues:
  print(_entry)

# The residue number is the text between the "." of the residue id and the
# trailing two characters of the residue's string form.
pymol_cmd = "select ch"+str(target_chain)+", " + " +".join(
    " resi " + _label.split(".")[1][0:-2] for _label, _ in _top_residues)


print("\n🧪Pymol Selection Cmd:")

print(pymol_cmd)

In [None]:
import py3Dmol
import matplotlib.pyplot as plt
#@title **Color the structure by confidence**
#partly inspired by OpenFold - https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb#scrollTo=rowN0bVYLe9n
#color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
# confidence_type selects how the colour thresholds are derived further down
# in this cell: "absolute" uses fixed cut-offs 0.9 / 0.7 / 0.5, any other
# value scales those fractions by the highest predicted score.
confidence_type = 'relative_of_top_value' #@param ["relative_of_top_value", "absolute"]
# When True, coloured residues are additionally drawn with a 'stick' style.
view_sidechains = False #@param {type:"boolean"}
# Color bands for visualizing binding sites
# Each entry looks like (lower bound, upper bound, hex colour) in percent.
# NOTE(review): this list is not read by the visible code, which repeats the
# colour literals instead.
CONFIDENCE_BANDS = [
  (0, 50, '#FF7D45'),
  (50, 70, '#FFDB13'),
  (70, 90, '#65CBF3'),
  (90, 100, '#0053D6')
]

def plot_confidence_legend():
  """Draw a stand-alone legend for the four confidence colour bands.

  Zero-height bars are plotted only to register one legend handle per
  colour; ticks and all four axis spines are hidden so just the legend and
  the title remain.

  Returns:
    The ``matplotlib.pyplot`` module, so the caller can chain ``.show()``.
  """
  legend_entries = (
      ('Very low (confidence < 50)', '#FF7D45'),
      ('Low (70 > confidence > 50)', '#FFDB13'),
      ('Confident (90 > confidence > 70)', '#65CBF3'),
      ('Very high (confidence > 90)', '#0053D6'),
  )

  plt.figure(figsize=(1, 1))
  for _, swatch in legend_entries:
    plt.bar(0, 0, color=swatch)
  plt.legend([label for label, _ in legend_entries],
             frameon=False, loc='center', fontsize=20)

  # Strip everything except the legend itself.
  plt.xticks([])
  plt.yticks([])
  axes = plt.gca()
  for side in ('right', 'top', 'left', 'bottom'):
    axes.spines[side].set_visible(False)

  plt.title('Model Confidence', fontsize=20, pad=50)
  return plt

# Band index -> colour, from lowest (0) to highest (3) confidence band.
color_map={0: '#FF7D45', 1: '#FFDB13', 2: '#65CBF3', 3: '#0053D6'}

with open("tmp.pdb") as ifile:
    system = ifile.read()

view = py3Dmol.view(width=800, height=400)
view.addModelsAsFrames(system)

# Default style: everything as a black cartoon; scored residues are
# recoloured below.
view.setStyle({'model': -1}, {"cartoon": {'color': 'black'}})

if(confidence_type=="absolute"):
  # Fixed cut-offs on the raw model output.
  confidence_value_90=0.9
  confidence_value_70=0.7
  confidence_value_50=0.5
else:
  # Cut-offs relative to the best-scoring residue (residues_confidence is
  # sorted in descending order). The local name avoids rebinding the builtin
  # max() for the rest of the kernel session, as the previous `max=` did.
  top_confidence=residues_confidence[0][1]
  confidence_value_90=top_confidence-top_confidence*0.1
  confidence_value_70=top_confidence-top_confidence*0.3
  confidence_value_50=top_confidence-top_confidence*0.5

# Lower bound of each band, checked from the highest band downwards.
_band_cutoffs = (
    (confidence_value_90, color_map[3]),
    (confidence_value_70, color_map[2]),
    (confidence_value_50, color_map[1]),
)

for _label, _confidence in residues_confidence:

  _band_color = None
  for _cutoff, _color in _band_cutoffs:
    if _confidence >= _cutoff:
      _band_color = _color
      break
  else:
    if _confidence < confidence_value_50:
      _band_color = color_map[0]
  if _band_color is None:
    # Only reachable for values that compare false both ways (e.g. NaN);
    # such residues keep the default black style.
    continue

  # Residue number: text between the "." of the residue id and the ")".
  resi = _label.split(".")[1].split(")")[0]

  _style = {'cartoon': {'color': _band_color}}
  if(view_sidechains):
    # NOTE(review): the 'model' key is kept from the original style dict;
    # confirm whether 3Dmol gives it any meaning inside a style spec.
    _style['model'] = -1
    _style['stick'] = {}

  # Restrict the selection to the target chain: selecting by residue number
  # alone also recoloured same-numbered residues of other chains in the file.
  view.setStyle({'chain': target_chain, 'resi': resi}, _style)

view.zoomTo()
view.show()

plot_confidence_legend().show()

In [None]:
#@title **Perform salience (ligand)**
import tensorflow as tf
from keras import activations

# Position of each probe amino acid in the binder sequence
# "ACDEFGHIKLMNPQRSTVWY" (index -> one-letter code).
pw_map_aa={0:"A",1:"C",2:"D",
          3:"E",4:"F",5:"G",
          6:"H",7:"I",8:"K",
          9:"L",10:"M",11:"N",
          12:"P",13:"Q",14:"R",
          15:"S",16:"T",17:"V",
          18:"W",19:"Y"}

# Number of consecutive input features belonging to one probe amino acid
# (the features are read in blocks of 128, 20 blocks = 2560 values).
_PAIR_CHANNELS = 128

# Defined up front so the names exist for both model types; previously the
# final sort ran outside the `if` and raised a NameError for the peptide
# model on a fresh kernel.
per_residue_salience=[]
per_residue_salience_index=[]
prot_map_saliency=[]

if(model_type=="ligand_model"):

  #---------------Compute the gradients---------------------------------
  # Gradient of the model output with respect to the scaled input features.
  pw_scaled_tensor = tf.convert_to_tensor(pw_scaled, dtype=tf.float32)
  with tf.GradientTape() as t:
      t.watch(pw_scaled_tensor)
      output = model(pw_scaled_tensor)

  result = output
  gradients = t.gradient(output, pw_scaled_tensor)

  #---------------Perform mapping---------------------------------
  # Mean absolute gradient per probe amino acid: (n_residues, 20).
  # NOTE(review): an earlier inline comment read "np.max because np.mean has
  # a bias" although the code takes the mean (the heatmap cell uses the max);
  # the mean is kept here - confirm which reduction is intended.
  _n_probe = len(pw_map_aa)
  _abs_grad = np.abs(np.asarray(gradients))[:, :_n_probe * _PAIR_CHANNELS]
  prot_map_saliency = _abs_grad.reshape(
      _abs_grad.shape[0], _n_probe, _PAIR_CHANNELS).mean(axis=-1)

  max_values = prot_map_saliency.max(-1)
  # argmax returns the first maximal index, like list.index(max) did.
  per_residue_salience = prot_map_saliency.argmax(-1).tolist()

  # Residue label -> [predicted score, most salient probe amino acid];
  # the chain is looked up once rather than on every iteration.
  _chain = pdb.model.chain(target_chain)
  _salience_by_residue = {}
  for _idx in range(len(_chain[:])):
    _salience_by_residue[str(_chain[_idx])] = [
        Y_submit[_idx][0], pw_map_aa[per_residue_salience[_idx]]]

  # Highest score first (ties fall back to the amino-acid letter).
  per_residue_salience_index = sorted(
      _salience_by_residue.items(), key=lambda x:x[1], reverse=True)

else:
  print("Salience is only computed for the ligand model; skipping.")

In [None]:
#@title **Output binding sites and saliency (ligand)**
print("#########")
print("Top",top_n, "binding residues sorted by confidence [conf, AA w/ max (salience)] ")
# Slice instead of indexing with range(top_n): a top_n larger than the number
# of available entries prints what exists rather than raising IndexError.
for _entry in per_residue_salience_index[:max(top_n, 0)]:
  print(_entry)


In [None]:
#@title 🔥 **Heatmap  with the normalized AA contribution (ligand)**
import matplotlib.pyplot as plt

#@markdown - Please specify the range :)
top_n = "5" #@param {type:"string"}

top_n=int(top_n)

# This cell reads `gradients`, which the salience cell only defines for the
# ligand model.

# Position of each probe amino acid in the binder sequence (index -> letter).
pw_map_aa={ 0:"A",1:"C",2:"D",
          3:"E",4:"F",5:"G",
          6:"H",7:"I",8:"K",
          9:"L",10:"M",11:"N",
          12:"P",13:"Q",14:"R",
          15:"S",16:"T",17:"V",
          18:"W",19:"Y"}

# Amino-acid groups used to order the heatmap columns
# (ordering as in https://upload.wikimedia.org/wikipedia/commons/f/f5/Blosum62-dayhoff-ordering.svg):
# [C], [S,T,A,G,P], [D,E,Q,N], [H,R,K], [M,I,L,V], [W,Y,F]
blosum_map_template=[[1],[15,16,0,5,12],[2,3,13,11],[6,14,8],[10,7,9,17],[18,19,4]]
blosum_map_template_string=[["C"],["S","T","A","G","P"], ["D","E","Q","N"], ["H","R","K"],  ["M","I","L","V"], ["W","Y","F"]]

# Flattened column order and the matching axis labels.
_blosum_order = [_aa for _group in blosum_map_template for _aa in _group]
x_label_list = [pw_map_aa[_aa] for _aa in _blosum_order]

# Max absolute gradient per probe amino acid: the input features are read in
# 20 consecutive blocks of 128 values. Result shape: (n_residues, 20).
_PAIR_CHANNELS = 128
_n_probe = len(pw_map_aa)
_abs_grad = np.abs(np.asarray(gradients))[:, :_n_probe * _PAIR_CHANNELS]
prot_map_saliency_blosum = _abs_grad.reshape(
    _abs_grad.shape[0], _n_probe, _PAIR_CHANNELS).max(axis=-1)

# Reorder the columns by group and normalise every row to sum to 1. Rows whose
# total is zero stay all-zero instead of turning into NaN.
_grouped = prot_map_saliency_blosum[:, _blosum_order].astype(float)
_row_totals = _grouped.sum(axis=1, keepdims=True)
_normalized = np.divide(_grouped, _row_totals,
                        out=np.zeros_like(_grouped), where=_row_totals > 0)

# Rank residues by predicted score, highest first. Scores are matched to
# residues by position, so only indices valid for all three sequences are used.
_chain = pdb.model.chain(target_chain)
_n_residues = min(len(_chain[:]), len(Y_submit), len(_normalized))
_ranked = sorted(range(_n_residues), key=lambda i: Y_submit[i][0], reverse=True)

heatmap_aminoacids = _normalized[np.array(_ranked, dtype=int)]
heatmap_res_names = [
    str(_chain[_i]) + "; conf: " + str(round(Y_submit[_i][0], 2))
    for _i in _ranked]

# Never plot more rows than there are residues; otherwise the number of tick
# positions and tick labels would disagree.
_n_rows = max(0, min(top_n, len(heatmap_res_names)))

fig, ax = plt.subplots(figsize=(14, max(_n_rows, 1)), dpi=500)

y_label_list = heatmap_res_names[0:_n_rows]

img = ax.imshow(heatmap_aminoacids[0:_n_rows], cmap='bwr', aspect='auto')
ax.set_xticks(np.arange(len(x_label_list)))
ax.set_yticks(np.arange(_n_rows))

ax.set_xticklabels(x_label_list)
ax.set_yticklabels(y_label_list)

fig.colorbar(img)