<a href="https://colab.research.google.com/github/artemg97/af2bind_prod/blob/main/AF2BIND_prod.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### AF2BIND: Prediction of ligand-binding sites using AlphaFold2

AF2BIND is a simple and fast notebook that runs inference on the outputs of [AlphaFold](https://github.com/deepmind/alphafold).


The method uses the binder protocol of the [ColabDesign](https://github.com/sokrypton/ColabDesign) framework, which facilitates the identification of binding sites for protein-peptide and protein-ligand complexes.

Authors/Collaborators :

*   Artem Gazizov (agazizov@fas.harvard.edu)
*    Sergey Ovchinnikov (so@fas.harvard.edu)
*    Nicholas Polizzi (nicholasf_polizzi@dfci.harvard.edu)

<!--<img src="https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo.png" width="300">.-->

<figure>
<center>
<img src='https://raw.githubusercontent.com/artemg97/af2bind_prod/main/logo.png'  width="300" height="150"  align=left />
</center>
</figure>





In [None]:
%%time
#@title Install AlphaFold2 (~2 mins)
#@markdown Please execute this cell by pressing the *Play* button on
#@markdown the left.

#@markdown **Note**: This installs the ColabDesign framework and downloads the AlphaFold2 and AF2BIND parameters.
import os, time
# One-time setup, skipped entirely when a "params" directory already exists.
if not os.path.isdir("params"):
  # Kick off the parameter downloads in a background subshell (note the
  # trailing "&") so they overlap with the pip install below. The subshell
  # installs aria2, fetches the AlphaFold parameter tarball and the AF2BIND
  # parameter zip, unpacks them, and finally creates params/done.txt as a
  # completion marker.
  # NOTE(review): the shell commands are chained with ";" rather than "&&",
  # so params/done.txt is created even if an earlier download or unpack step
  # failed; a failure would only surface later, when the files are read.
  print("installing ColabDesign")
  os.system("(mkdir params; apt-get install aria2 -qq; \
  aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \
  aria2c -q -x 16 https://files.ipd.uw.edu/krypton/af2bind_params.zip; \
  tar -xf alphafold_params_2021-07-14.tar -C params; unzip af2bind_params.zip; touch params/done.txt )&")

  # Install a pinned ColabDesign release and symlink the installed package
  # into the working directory.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")

  # Block until the background subshell has written its completion marker,
  # polling every 5 seconds. There is no timeout on this wait.
  if not os.path.isfile("params/done.txt"):
    print("downloading params")
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)

import os
from colabdesign import mk_afdesign_model
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file for `pdb_code`.

  Resolution order:
    1. empty string / None -> prompt for an upload (Colab) and save it as
       "tmp.pdb";
    2. an existing file path -> returned unchanged;
    3. a 4-character code -> downloaded from files.rcsb.org;
    4. anything else -> treated as an AlphaFold DB accession and downloaded
       from alphafold.ebi.ac.uk (model_v4).

  Raises:
    ValueError: the upload dialog returned no file.
    FileNotFoundError: the download did not produce the expected file
      (unknown identifier, no network, or `wget` not installed).
  """
  # Local import so this block does not depend on a file-level import.
  import subprocess

  def _download(url, filename):
    # Pass an argument list instead of a shell string: the identifier comes
    # from a free-text form field and must not be interpreted by a shell.
    # "-nc" keeps an already-downloaded copy, as before.
    subprocess.run(["wget", "-qnc", url], check=False)
    if not os.path.isfile(filename):
      raise FileNotFoundError(f"could not download {url}")
    return filename

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    # Only the first uploaded file is used.
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb","wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    return _download(f"https://files.rcsb.org/view/{pdb_code}.pdb",
                     f"{pdb_code}.pdb")

  af_name = f"AF-{pdb_code}-F1-model_v4.pdb"
  return _download(f"https://alphafold.ebi.ac.uk/files/{af_name}", af_name)

In [None]:
#@title **Run AF2BIND** 🔬
from colabdesign.af.alphafold.common import residue_constants
import pandas as pd
# Inverse of residue_constants.restype_order, so that the integer codes in
# af_model._pdb["batch"]["aatype"] can be looked up further down in this cell
# (unknown codes fall back to "X" there).
aa_order = {v:k for k,v in residue_constants.restype_order.items()}

target_pdb = "6w70" #@param {type:"string"}
target_chain = "A" #@param {type:"string"}
#@markdown - Please indicate target pdb and chain (leave pdb blank for custom upload)
# get_pdb accepts an empty value (upload), a local path, a 4-character PDB
# code, or any other string, which it treats as an AlphaFold DB accession.
pdb_filename = get_pdb(target_pdb)
# Read by the "Select Model" cell when it builds the PyMOL selection string.
top_n = 15
import jax, pickle
import jax.numpy as jnp
def af2bind(inputs,outputs,params,aux):
  """Loss callback that scores every target residue with the AF2BIND heads.

  The last 20 residues of the complex are treated as the binder. One score
  column is produced per entry of params["af2bind"]:
    * entries containing an "mlp" key run a 5-layer MLP on the standardised
      target-vs-binder pair representation;
    * any other entry (e.g. an empty dict) uses the distogram head.

  Side effect: aux["af2bind"] is set to the sigmoid of all score columns.

  Returns:
    {"af2bind": loss}, where loss is the mean raw score of column
    opt["type"] over the residues selected by opt["site"].
  """
  cfg = inputs["opt"]["af2bind"]

  def bypass_relu(v):
    # Forward value equals relu(v); the gradient is that of leaky_relu(v),
    # because the difference is wrapped in stop_gradient.
    leaky = jax.nn.leaky_relu(v)
    return leaky + jax.lax.stop_gradient(jax.nn.relu(v) - leaky)

  def mlp_score(head):
    # Flatten the target x binder pair features to one vector per target residue.
    feats = outputs["representations"]["pair"][:-20,-20:]
    feats = feats.reshape(feats.shape[0],-1)
    feats = (feats - head["scale"]["mean"])/head["scale"]["std"]
    layers = head["mlp"]
    n_layers = 5
    for k in range(n_layers):
      feats = feats @ layers["weights"][k] + layers["bias"][k]
      # No activation after the final layer.
      if k < n_layers - 1:
        feats = jnp.where(cfg["bypass_relu"],
                          bypass_relu(feats),
                          jax.nn.relu(feats))
    return feats[:,0]

  def dgram_score():
    logits = outputs["distogram"]["logits"][:-20,-20:]
    # 20 bin = 8 angstroms
    near = jax.nn.logsumexp(logits[...,:20],-1)
    # todo: check if excluding last bin makes sense
    far = jax.nn.logsumexp(logits[...,20:-1],-1)
    # Best near-vs-far log-odds over the binder positions.
    return (near - far).max(-1)

  columns = [mlp_score(head) if "mlp" in head else dgram_score()
             for head in params["af2bind"]]
  scores = jnp.stack(columns,-1)
  aux["af2bind"] = jax.nn.sigmoid(scores)

  site_mask = cfg["site"]
  selected = scores[:,cfg["type"]]
  # The epsilon keeps the masked mean finite when no site is selected.
  masked_mean = (selected * site_mask).sum() / (site_mask.sum() + 1e-8)
  return {"af2bind":masked_mean}

# Build the model only once per session; re-running this cell reuses the
# existing `af_model` and just re-prepares the inputs below.
if "af_model" not in dir():
  # Binder-protocol design model with `af2bind` registered as an extra loss
  # callback. debug=True is presumably what makes aux["debug"] available to
  # the optional distogram analysis — TODO confirm against ColabDesign.
  af_model = mk_afdesign_model(protocol="binder",
                               debug=True,
                               loss_callback=af2bind,
                               use_bfloat16=False)
  af_model.opt["weights"]["af2bind"] = 1.0
  # Options read by the `af2bind` callback via inputs["opt"]["af2bind"];
  # "site" is a placeholder here and is resized once the target is known.
  af_model.opt["af2bind"] = {"type":0,
                             "site":np.full(1,False),
                             "bypass_relu":False}
  # NOTE(review): pickle.load can execute arbitrary code from the file being
  # loaded. These .pkl files are presumably the ones unpacked from
  # af2bind_params.zip by the install cell — only load them from a trusted source.
  af2bind_params = []
  for m in ["ligand_model","peptide_model"]:
    with open(f"{m}.pkl",'rb') as handle:
      af2bind_params.append(pickle.load(handle))
  # The trailing empty dict has no "mlp" key, so the callback evaluates the
  # distogram-based score for it (third column, labelled "dgram" below).
  af_model._params["af2bind"] = af2bind_params + [{}]

# Pair the target with a 20-residue binder whose sequence contains each of the
# 20 amino-acid letters exactly once.
af_model.prep_inputs(pdb_filename=pdb_filename, chain=target_chain, binder_len=20)
af_model.set_seq("ACDEFGHIKLMNPQRSTVWY")
# Zero the loss weights, then switch only the af2bind term back on.
af_model.set_opt(weights=0)
# No site selected for the plain prediction pass.
af_model.set_opt("af2bind",site=np.full(af_model._target_len,False))
af_model.set_weights(af2bind=1.0)
af_model.predict(verbose=False)

# Per-residue scores written by the `af2bind` callback: one row per target
# residue, one column per model (ligand, peptide, dgram).
preds = af_model.aux["af2bind"].copy()

labels = ["chain","resi","resn","ligand","peptide","dgram"]
chain_ids = af_model._pdb["idx"]["chain"]
residue_numbers = af_model._pdb["idx"]["residue"]
aatype_codes = af_model._pdb["batch"]["aatype"]

# One table row per target residue: identifiers first, then the three scores
# rounded to three decimals. Unknown residue types are reported as "X".
data = [
  [chain_ids[k], residue_numbers[k], aa_order.get(aatype_codes[k],"X")]
  + [round(float(score),3) for score in preds[k]]
  for k in range(af_model._target_len)
]

df = pd.DataFrame(data, columns=labels)
df.to_csv('results.csv')

# Default to the first (ligand) column until the "Select Model" cell is run.
model_m = 0

In [None]:
#@title **Select Model**
from google.colab import data_table
from IPython.display import display, HTML
model_type = 'ligand' #@param ["ligand", "peptide", "dgram"]
data_table.enable_dataframe_formatter()
# Show every residue ranked by the chosen model's score (rank 0 = highest).
df_sorted = df.sort_values(model_type,ascending=False, ignore_index=True).rename_axis('rank').reset_index()
display(data_table.DataTable(df_sorted, min_width=100, num_rows_per_page=15, include_index=False))

# Column of `preds` that belongs to the chosen model.
model_m = {"ligand":0, "peptide":1, "dgram":2}[model_type]

# Indices of the `top_n` highest-scoring residues. The slice now uses `top_n`
# instead of a second hard-coded 15 so the two can no longer drift apart.
top_n_idx = preds[:,model_m].argsort()[::-1][:top_n]

# Build "select ch<chain>, resi a + resi b + ...". Joining the terms (rather
# than appending " +" while n < top_n-1) avoids a dangling trailing " +" when
# the target has fewer than `top_n` residues.
residue_numbers = af_model._pdb["idx"]["residue"]
pymol_cmd = f"select ch{target_chain}," + " +".join(
  f" resi {residue_numbers[i]}" for i in top_n_idx)

print("\n🧪Pymol Selection Cmd:")
print(pymol_cmd)

In [None]:
import matplotlib.pyplot as plt
from scipy.special import softmax
import copy

#@title **Display Structure (Colored by Confidence)**
#partly inspired by OpenFold - https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb#scrollTo=rowN0bVYLe9n
#color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}
rescale_by_max_conf = True #@param {type:"boolean"}
show_ligand = False

# Scores of the selected model, optionally rescaled so the best residue is 1.
preds_adj = preds[:,model_m].copy()
if rescale_by_max_conf:
  preds_adj = preds_adj / preds[:,model_m].max()

# Work on a deep copy of the prediction outputs and store the scores in the
# pLDDT slot of the target residues, so the structure is coloured by score.
L = af_model._target_len
aux = copy.deepcopy(af_model.aux["all"])
aux["plddt"][:,:L] = preds_adj

if not show_ligand:
  # Put the target back on its input coordinates and mask out every atom
  # after the first L residues.
  native_coords = af_model._pdb["batch"]["all_atom_positions"][:L].copy()
  aux["atom_positions"][:,:L] = native_coords
  aux["atom_mask"][:,L:] = 0

af_model.save_pdb("output.pdb",aux={"all":aux})

af_model.plot_pdb(aux={"all":aux})

def plot_plddt_legend(dpi=100):
  """Draw a standalone colour legend for the confidence colouring.

  Args:
    dpi: resolution of the (tiny) figure that hosts the legend.

  Returns:
    The `matplotlib.pyplot` module, so the caller can chain `.show()`.
  """
  # (legend text, swatch colour) pairs in display order; the first entry is a
  # white swatch that acts as the legend's heading.
  entries = (
    ('confidence:', "#FFFFFF"),
    ('<50', "#FF0000"),
    ('60', "#FFFF00"),
    ('70', "#00FF00"),
    ('80', "#00FFFF"),
    ('>90', "#0000FF"),
  )
  plt.figure(figsize=(1,0.1),dpi=dpi)
  # Zero-height bars are drawn only to obtain one legend handle per colour.
  for _, swatch in entries:
    plt.bar(0, 0, color=swatch)
  captions = [text for text, _ in entries]
  plt.legend(captions,
             frameon=False,
             loc='center',
             ncol=6,
             handletextpad=1,
             columnspacing=1,
             markerscale=0.5)
  # Hide the axes so only the legend remains visible.
  plt.axis(False)
  return plt

plot_plddt_legend().show()

In [None]:
#@title **Download Predictions**
from google.colab import files
# Bundle the coloured structure and the score table into a single archive and
# hand it to the browser for download.
archive_name = "output.zip"
os.system("zip -r " + archive_name + " output.pdb results.csv")
files.download(archive_name)

In [None]:
#@title **Optional Analysis**
import matplotlib.pyplot as plt
# Each analysis is gated by its own `if` further down in this cell.
run_saliency = False #@param {type:"boolean"}
show_distogram = False #@param {type:"boolean"}
#@markdown select position
# `rank` is a 0-based index into the residues sorted by descending score of
# the currently selected model (`model_m`); it is only read when `pos` is empty.
rank = 0 #@param {type:"raw"}
# When enabled (and `pos` is empty) the positions ranked 0..rank are analysed together.
include_all_top_ranked = False #@param {type:"boolean"}
# The letters of `pos` are taken as the chain id and its digits as the residue number.
pos = "" #@param {type:"string"}
#@markdown - select which position to analyze using either `rank` or `pos` (example: `A10`, for chain A, residue 10)

#@markdown advanced settings (for saliency)
# hard / soft / alpha are passed unchanged to af_model.set_opt(...) below;
# their exact semantics are defined by ColabDesign.
hard = True #@param {type:"boolean"}
soft = False #@param {type:"boolean"}
alpha = 2.0 #@param {type:"raw"}
# When enabled, af_model._norm_seq_grad() is called after the gradient pass.
normalize_gradient = True #@param {type:"boolean"}
# Forwarded to the `af2bind` loss callback as opt["af2bind"]["bypass_relu"].
bypass_relu = True #@param {type:"boolean"}
#@markdown - (experimental option) if saliency is zero, this means this position has vanishing gradient issues. you can set `bypass_relu=True` to avoid this.

# Resolve the position(s) to analyse into:
#   i - index (or array of indices) into the target residues / rows of `preds`
#   c - chain id(s)
#   r - residue number(s)
#   a - one-letter residue code, or None when several positions are selected
if pos == "":
  ranked_idx = preds[:,model_m].argsort()[::-1]
  if include_all_top_ranked:
    i = ranked_idx[:rank+1]
    a = None
  else:
    # Fail with an explicit message instead of a bare indexing error.
    if not -len(ranked_idx) <= rank < len(ranked_idx):
      raise IndexError(
        f"rank={rank} is out of range: the target has {len(ranked_idx)} residues")
    i = ranked_idx[rank]
    a = aa_order.get(af_model._pdb["batch"]["aatype"][i],"X")
  c = af_model._pdb["idx"]["chain"][i]
  r = af_model._pdb["idx"]["residue"][i]
else:
  # Letters form the chain id, digits form the residue number (e.g. "A10").
  c = ''.join(filter(str.isalpha, pos))
  residue_digits = ''.join(filter(str.isdigit, pos))
  if not residue_digits:
    raise ValueError(f"pos={pos!r} contains no residue number (expected e.g. 'A10')")
  r = int(residue_digits)
  matches = np.argwhere((af_model._pdb["idx"]["chain"] == c) & (af_model._pdb["idx"]["residue"] == r))
  if len(matches) == 0:
    raise ValueError(
      f"pos={pos!r} (chain {c!r}, residue {r}) was not found in the target structure")
  # First match wins if the chain/residue pair occurs more than once.
  i = matches[0][0]
  a = aa_order.get(af_model._pdb["batch"]["aatype"][i],"X")

if run_saliency:
  # Mark the selected target position(s); the `af2bind` loss callback
  # averages the raw score of model `model_m` over these sites.
  sites = np.full(af_model._target_len,False)
  sites[i] = True
  af_model.set_opt(af2bind=dict(site=sites,
                                type=model_m,
                                bypass_relu=bypass_relu),
                  dropout=False,
                  soft=soft,
                  hard=hard,
                  alpha=alpha,
                  sample_models=False)
  # Same 20-letter binder sequence as in the prediction cell: binder position
  # k holds the k-th letter of "ACDEFGHIKLMNPQRSTVWY".
  af_model.set_seq("ACDEFGHIKLMNPQRSTVWY")
  # Presumably a single loss/gradient evaluation that fills aux["grad"]
  # — TODO confirm against ColabDesign.
  af_model.run()
  if normalize_gradient:
    af_model._norm_seq_grad()

  # Gradient with respect to the binder sequence. Assumed to be
  # (20 binder positions x 20 amino-acid channels) after the [0] — TODO confirm.
  saliency_map = af_model.aux["grad"]["seq"][0]
  # Amino-acid ordering used for both plot axes.
  blosum_map = list("CSTAGPDEQNHRKMILVWYF")
  # Ordering of the binder positions (matches the sequence set above).
  cs_label_list = list("ACDEFGHIKLMNPQRSTVWY")
  # Presumably the amino-acid channel order of the gradient — TODO confirm.
  af_label_list = list("ARNDCQEGHILKMFPSTWYV")

  # Permute rows (binder positions) and columns (amino-acid channels) into the
  # `blosum_map` ordering.
  indices_A_Y_mapping = np.array([cs_label_list.index(letter) for letter in blosum_map])
  indices_A_R_mapping = np.array([af_label_list.index(letter) for letter in blosum_map])
  saliency_map = saliency_map[indices_A_Y_mapping,:][:,indices_A_R_mapping]

  # Symmetric colour limits so that zero maps to the centre of the colormap.
  max_val = np.abs(saliency_map).max()

  if pos == "" and include_all_top_ranked:
    print(f"including the top {rank+1} ranked positions")
    plt.title(f"avg_conf={preds[i,model_m].mean():.3f}")
  else:
    plt.title(f"chain={c} residue={a}{r} conf={preds[i,model_m]:.3f}")
  # Transposed for display: the x axis follows the binder positions ("inputs"),
  # the y axis follows the gradient channels.
  plt.imshow(saliency_map.T, cmap="bwr_r", vmin=-max_val, vmax=max_val)
  plt.xticks(np.arange(20),blosum_map)
  plt.yticks(np.arange(20),blosum_map)
  plt.xlabel("inputs"); plt.ylabel("gradient of aminoacids");
  plt.colorbar()
  plt.show()

if show_distogram:
  # Distogram logits between the selected target position(s) and the 20 binder
  # positions (the last 20 residues of the complex).
  dgram_logits = af_model.aux["debug"]["outputs"]["distogram"]["logits"][i,-20:]
  # Drop the final bin, then normalise the remaining 63 bins to probabilities.
  dgram = softmax(dgram_logits[...,:-1],-1)

  # Upper and lower edge of every bin, and the bin centres used as tick labels.
  upper_breaks = np.linspace(2.3125,21.6875,63)
  lower_breaks = np.concatenate([[2.0], upper_breaks[:-1]])
  mid_point = 0.5 * (lower_breaks + upper_breaks)

  cs_label_list = list("ACDEFGHIKLMNPQRSTVWY")
  several_positions = pos == "" and include_all_top_ranked

  plt.figure(figsize=(12,3))
  if several_positions:
    # Several target positions: show the per-bin maximum across them.
    heading = f"avg_conf={preds[i,model_m].mean():.3f}"
    image = dgram.max(0)
  else:
    heading = f"chain={c} residue={a}{r} conf={preds[i,model_m]:.3f}"
    image = dgram
  plt.title(heading)
  plt.imshow(image)

  # One row per binder residue; every fifth bin is labelled with its centre.
  plt.yticks(np.arange(20),cs_label_list)
  tick_bins = np.arange(63)[::5]
  plt.xticks(tick_bins,np.round(mid_point[tick_bins],1))
  plt.xlabel("distances (angstroms)")
  plt.colorbar()
  plt.show()