<a href="https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/fixbb_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

install

In [11]:
%%bash
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install biopython dm-haiku ml-collections py3Dmol
fi
if [ ! -d params ]; then
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
fi
wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py

import libraries

In [12]:
import sys
sys.path.append('/content/af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline, templates
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from utils import *
import colabfold as cf

setup model

In [3]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
model_config.model.num_recycle = 1
model_config.data.common.num_recycle = 1

# backprop through recycles
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# number of sequences
model_config.data.common.max_extra_msa = 1
model_config.data.eval.max_msa_clusters = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout
model_config = set_dropout(model_config, 0.0)

# setup model
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# load the other model_params (during optimization, we randomly pick which model to use)
for model_name in ["model_1_ptm","model_2_ptm","model_4_ptm","model_5_ptm"]:
  params = data.get_model_haiku_params(model_name, '.')
  model_params.append({k: params[k] for k in model_runner.params.keys()})

example

In [4]:
# setup inputs
example = "1QYS"
!wget -qnc https://files.rcsb.org/view/{example}.pdb
protein_obj = protein.from_pdb_string(pdb_to_string(f"{example}.pdb"))

batch = {'aatype': protein_obj.aatype,
          'all_atom_positions': protein_obj.atom_positions,
          'all_atom_mask': protein_obj.atom_mask}
batch.update(all_atom.atom37_to_frames(**batch))

query_sequence = "".join([order_restype[a] for a in protein_obj.aatype])
starting_sequence = query_sequence

# one_hot_encode
feature_dict = {
    **pipeline.make_sequence_features(sequence=starting_sequence,description="none",num_res=len(starting_sequence)),
    **pipeline.make_msa_features(msas=[[starting_sequence]],deletion_matrices=[[[0]*len(starting_sequence)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)
wt_seq = jax.nn.one_hot(inputs["aatype"][0],20)

setup gradient

In [97]:
def get_grad_fn(model_runner, inputs):

  def mod(params, key, model_params, opt):
    ############################
    # set amino acid sequence
    ############################
    seq_logits = params["seq"]
    seq = soft_seq(seq_logits)
    
    mask = opt["mask"][:,None]
    pseudo_seq = mask * seq + (1-mask) * seq_logits

    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod)
    
    ####################
    # set sidechains identity
    ####################
    N,L = inputs_mod["aatype"].shape[:2]
    ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)

    aatype = jnp.zeros((N,L,21)).at[...,:20].set(seq)
    aatype_ala = jnp.zeros((N,L,21)).at[:].set(ALA)
    aatype_pseudo = mask * aatype + (1-mask) * aatype_ala
    update_aatype(aatype_pseudo, inputs_mod)
    
    # get output
    outputs = model_runner.apply(model_params, key, inputs_mod)
            
    # losses
    dgram_loss = get_dgram_loss(batch, outputs, model_config=model_runner.config)
    fape_loss = get_fape_loss(batch, outputs, model_config=model_runner.config)

    # we are just monitoring rmsd, but it's not used in loss
    rmsd_loss = jnp_rmsd(protein_obj.atom_positions[:,1,:],
                         outputs["structure_module"]["final_atom_positions"][:,1,:])

    loss = dgram_loss # + fape_loss
    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    return loss, ({"losses":{"rmsd":rmsd_loss,
                             "dgram":dgram_loss,
                             "fape":fape_loss},
                   "outputs":outs,"seq":seq,"pseudo_seq":pseudo_seq})
  
  return jax.value_and_grad(mod, has_aux=True, argnums=0)

where the magic happens

In [99]:
# gradient function
grad_fn = jax.jit(get_grad_fn(model_runner, inputs))

In [100]:
init_fun, update_fun, get_params = adam(step_size=1e-2)
def step(i, state, key, model_params, opt):
  (loss, outs), grad = grad_fn(get_params(state), key, model_params=model_params, opt=opt)
  grad["seq"] = grad["seq"] / jnp.sqrt(jnp.square(grad["seq"]).sum())
  state = update_fun(i, grad, state)
  return state, outs

For complex targets, we find directly optimizing a single one_hot encoded sequence (even with gumbel-st) to be very challenging. Instead, we start with a continious representation and then one residue at a time switch to one_hot in second round of optimization.

**design adversarial vector**

In [105]:
L,A = wt_seq.shape
key = jax.random.PRNGKey(0)

seq = jnp.zeros((L,A))
mask = jnp.zeros((L,))
state = init_fun({"seq":seq})

In [106]:
XYZ = []
SEQ = []
for i in range(500): # number of iterations (might be overkill)
  key,subkey = jax.random.split(key)
  n = jax.random.randint(subkey,[],0,5) # select which model to use
  state, outs = step(i, state, subkey, model_params[n], {"mask":mask})
  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()
  if i < 200:
    # save for animation
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))
  if (i+1) % 10 == 0:
    losses = outs["losses"]
    print(f'{i+1} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

10 0 4.665 4.283 34.885 0.043
20 0 4.212 3.462 27.203 0.065
30 0 3.617 3.096 24.548 0.076
40 0 3.673 2.354 17.051 0.087
50 0 3.266 1.734 7.716 0.087
60 0 3.123 2.059 15.548 0.098
70 0 2.899 1.487 5.342 0.120
80 0 2.696 1.184 3.483 0.098
90 0 2.530 1.028 3.037 0.098
100 0 2.104 0.417 1.358 0.130
110 0 1.884 0.411 1.284 0.120
120 0 1.808 0.380 0.937 0.174
130 0 1.807 0.391 0.927 0.163
140 0 1.743 0.321 0.817 0.174
150 0 1.735 0.409 0.935 0.185
160 0 1.581 0.283 0.677 0.185
170 0 1.538 0.278 0.634 0.185
180 0 1.639 0.250 0.633 0.207
190 0 1.607 0.258 0.679 0.228
200 0 1.495 0.305 0.659 0.207
210 0 1.564 0.273 0.737 0.217
220 0 1.455 0.421 0.586 0.217
230 0 1.401 0.240 0.568 0.185
240 0 1.428 0.244 0.611 0.185
250 0 1.428 0.258 0.634 0.196
260 0 1.274 0.230 0.476 0.217
270 0 1.335 0.248 0.498 0.228
280 0 1.917 0.320 0.848 0.207
290 0 1.297 0.201 0.520 0.217
300 0 1.234 0.201 0.462 0.217
310 0 1.219 0.227 0.447 0.228
320 0 1.287 0.246 0.592 0.228
330 0 1.288 0.294 0.526 0.250
340 0 1.255 0.

In [107]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from IPython.display import HTML

def make_animation(positions, seq, pos_ref=None, line_w=2.0, dpi=100):

  def align(P, Q):
    p = P - P.mean(0,keepdims=True)
    q = Q - Q.mean(0,keepdims=True)
    return p @ cf.kabsch(p,q)

  if pos_ref is None: pos_ref = positions[-1]
  pos_ref = pos_ref - pos_ref.mean(0,keepdims=True)
  pos_ref = pos_ref @ cf.kabsch(pos_ref,pos_ref,return_v=True)

  new_positions = []
  for i in range(len(positions)):
    new_positions.append(align(positions[i],pos_ref))
  pos = np.asarray(new_positions)

  fig = plt.figure()
  gs = GridSpec(4,1, figure=fig)
  ax1,ax2 = fig.add_subplot(gs[:3,:]),fig.add_subplot(gs[3:,:])
  fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 0.9, left = 0.1, hspace = 0, wspace = 0)
  fig.set_figwidth(5);fig.set_figheight(6)
  fig.set_dpi(dpi)

  ax2.set_xlabel("positions")
  ax2.set_yticks([])

  z_min,z_max = pos_ref[...,2].min(),pos_ref[...,2].max()
  xy_min,xy_max = pos_ref[...,:2].min() - 5, pos_ref[...,:2].max() + 5
  for ax in [ax1]:
    ax.set_xlim(xy_min, xy_max); ax.set_ylim(xy_min, xy_max)
    ax.axis(False)

  ims = []
  for k,(x,s) in enumerate(zip(pos,seq)):
    ims.append([cf.add_text("colored by N→C\n(overlaid on target)", ax1),
                cf.plot_pseudo_3D(x, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max),
                ax2.imshow(s.T, animated=True, cmap="bwr_r",vmin=-1, vmax=1)])
    
  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=60)
  plt.close()
  return ani.to_html5_video()

In [108]:
HTML(make_animation(XYZ,SEQ,pos_ref=batch["all_atom_positions"][:,1,:]))

In [109]:
save_pdb(outs)
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})
BB = ['C','O','N']
view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                    {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
              {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
              {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})  
view.zoomTo()
view.show()

design one_hot

In [110]:
XYZ = []
SEQ = []
RMSD = np.inf
BEST_outs = None
for n in range(500,1500):
  key,subkey = jax.random.split(key)
  n = 0 #jax.random.randint(subkey,[],0,5) # select which model to use
  state, outs = step(i, state, subkey, model_params[0], {"mask":mask})
  seq = outs["seq"].argmax(-1)
  seq_id = (seq == wt_seq.argmax(-1)).mean()
  i += 1
  losses = outs["losses"]
  if mask.sum() == len(seq) and losses["rmsd"] < RMSD:
    RMSD = losses["rmsd"]
    BEST_outs = outs
  if i % 10 == 0:
    # pick random position to flip to one_hot
    if mask.mean() < 1:
      mask = mask.at[np.random.choice(np.where(mask == 0)[0])].set(1)    
    print(f'{i} {int(mask.sum())} {losses["dgram"]:.3f} {losses["fape"]:.3f} {losses["rmsd"]:.3f} {seq_id:.3f}')

    # save for animation
    XYZ.append(np.asarray(outs["outputs"]["final_atom_positions"][:,1,:]))
    SEQ.append(np.asarray(outs["pseudo_seq"]))

500 1 1.120 0.309 0.428 0.228
510 2 1.030 0.190 0.362 0.217
520 3 0.977 0.171 0.339 0.228
530 4 1.026 0.275 0.362 0.228
540 5 1.040 0.182 0.368 0.250
550 6 0.970 0.169 0.337 0.261
560 7 1.015 0.253 0.337 0.272
570 8 0.961 0.178 0.349 0.250
580 9 0.940 0.167 0.340 0.261
590 10 0.953 0.170 0.350 0.239
600 11 0.977 0.171 0.427 0.239
610 12 0.927 0.168 0.324 0.261
620 13 0.936 0.170 0.349 0.239
630 14 0.890 0.156 0.335 0.250
640 15 0.882 0.164 0.341 0.239
650 16 0.931 0.212 0.356 0.228
660 17 0.878 0.169 0.321 0.228
670 18 0.902 0.163 0.337 0.239
680 19 0.889 0.171 0.314 0.261
690 20 0.948 0.202 0.384 0.272
700 21 0.897 0.234 0.324 0.261
710 22 0.941 0.184 0.341 0.272
720 23 0.915 0.176 0.328 0.272
730 24 0.949 0.168 0.352 0.239
740 25 0.923 0.191 0.333 0.217
750 26 0.903 0.162 0.309 0.228
760 27 0.940 0.160 0.332 0.217
770 28 0.949 0.160 0.328 0.217
780 29 0.928 0.163 0.336 0.217
790 30 0.919 0.164 0.316 0.217
800 31 0.921 0.164 0.343 0.217
810 32 0.930 0.166 0.335 0.207
820 33 0.945 0.17

In [111]:
HTML(make_animation(XYZ,SEQ,pos_ref=batch["all_atom_positions"][:,1,:]))

In [112]:
for model_param in model_params:
  params = {"seq":BEST_outs["seq"]}
  (loss, outs), grad = grad_fn(params, key, model_params=model_param, opt={"mask":mask})
  print(outs["losses"]["rmsd"])

0.89586204
0.9311516
2.9526575
1.0286101
0.9469117


In [113]:
save_pdb(BEST_outs)
view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(pdb_to_string("tmp.pdb"),'pdb')
view.setStyle({'cartoon': {}})
BB = ['C','O','N']
view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                    {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
              {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}})
view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
              {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}})  
view.zoomTo()
view.show()

In [114]:
"".join([order_restype[a] for a in seq])

'AYLIHCHMVKPGRHIHEAFIVDDMDAISEHISFLYTAIAAIGPSHCHVCVHTPSRSMSGSMAQMLTRMFTALPYTEVKVVFHGNSTCVHCTR'