<a href="https://colab.research.google.com/github/andrewfavor95/GuidedHallucination/blob/main/demos/cyclic_oligo_loss_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign - hallucination custom loss example: cyclic oligomers

In [None]:
#@title install
%%bash
if [ ! -d params ]; then
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git
  # for debugging
  ln -s "$(python -c 'import os, colabdesign; print(os.path.dirname(colabdesign.__file__))')" colabdesign
  # then add the GuidedHallucination repo for loss functions
  git clone https://github.com/andrewfavor95/GuidedHallucination.git
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
  for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1
  do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done
fi

In [None]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import colabdesign
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_rmsd, _np_norm,_np_ang,_np_kabsch
from IPython.display import HTML
from google.colab import files
import numpy as np
import pdb as pydebug
import string

import jax
import jax.numpy as jnp


# Define custom loss function

Here we write a custom loss function to bias the hallucinated oligomer towards cyclic symmetry.
An AlphaFold-hallucinated protein with multiple chain copies will often form a repeat motif (like a DHR), so if we want an oligomer, it is useful to have a way to bias the design towards complete cyclic symmetry.
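
As a compact summary of the loss defined in the next cell (the symbols $N$, $R_{i,j}$, $\theta^{*}$ and $\mathcal{L}_{\mathrm{cyc}}$ are notation introduced here, not names from the code): let $N$ be the number of chain copies and $R_{i,j}$ the Kabsch rotation between the centered CA coordinates of chains $i$ and $j$. Ignoring the small numerical guard inside the arccos, the loss is

$$
\theta^{*} = \frac{2\pi}{N}, \qquad
\theta_{i,j} = \arccos\!\left(\frac{\operatorname{tr}(R_{i,j}) - 1}{2}\right), \qquad
\mathcal{L}_{\mathrm{cyc}} = \left(\frac{1}{N}\sum_{i=0}^{N-1} \theta_{i,\,(i+1) \bmod N} \;-\; \theta^{*}\right)^{2}
$$

For a trimer the target is $\theta^{*} = 2\pi/3$, i.e. 120° between neighbouring chains. Only the mean angle is penalized, so this acts as a bias rather than a hard constraint on every chain pair.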

In [None]:

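def cyclic_sym_loss(inputs,outputs):
    # NOTE: reads the globals `copies` and `chain_len`; define them before the loss is first evaluated

    target_theta = (2*jnp.pi/copies) # target rotation angle between adjacent chains of a homo-oligomer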

    pred = outputs["structure_module"]["final_atom_positions"][:,residue_constants.atom_order["CA"]]
    pred = pred - pred.mean(axis=0)

    chain_pairs = [(i%copies,j%copies) for i,j in zip(range(0,copies),range(1,copies+1))]

    theta_ij_list = []
    for i,j in chain_pairs:

        pred_i = pred[chain_len*i:chain_len*(i+1),:]-pred[chain_len*i:chain_len*(i+1),:].mean(axis=0) # get centered coords for chain i
        pred_j = pred[chain_len*j:chain_len*(j+1),:]-pred[chain_len*j:chain_len*(j+1),:].mean(axis=0) # get centered coords for chain j

        R_ij = _np_kabsch(pred_j,pred_i) # get rotation matrix from chain i to chain j via Kabsch algorithm

        cos_theta_ij = jnp.clip((jnp.trace(R_ij) - 1.0)/2.0, -1.0 + 1e-6, 1.0 - 1e-6) # keep arccos inside its domain
        theta_ij = jnp.arccos(cos_theta_ij) # get theta of rotation matrix from trace
        theta_ij_list.append(theta_ij)

    
    cyclic_sym_loss_val = jnp.square(jnp.array(theta_ij_list).mean() - target_theta)

    return {"cyclic_sym_loss": cyclic_sym_loss_val  }
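
To make the geometry behind this loss concrete, here is a minimal NumPy sketch run on synthetic coordinates, not on a predicted structure. It assumes chains are concatenated in order, as the slicing in `cyclic_sym_loss` implies. The helpers `rotation_z`, `kabsch_rotation`, `rotation_angle` and `mean_angle_loss` are hypothetical names written for this illustration; `kabsch_rotation` is a plain SVD-based Kabsch fit and is not ColabDesign's `_np_kabsch`.

```python
import numpy as np

def rotation_z(theta):
    """Rotation matrix for an angle theta (radians) about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def kabsch_rotation(P, Q):
    """Proper rotation R minimizing ||P @ R - Q|| for centered (N, 3) arrays."""
    U, _, Vt = np.linalg.svd(P.T @ Q)
    d = np.sign(np.linalg.det(U @ Vt))  # avoid returning a reflection
    return U @ np.diag([1.0, 1.0, d]) @ Vt

def rotation_angle(R):
    """Rotation angle from the trace: tr(R) = 1 + 2*cos(theta)."""
    return np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0))

def mean_angle_loss(coords, copies, chain_len):
    """Squared gap between the mean neighbour angle and 2*pi/copies."""
    thetas = []
    for i in range(copies):
        j = (i + 1) % copies  # neighbour pairs, wrapping the last chain to the first
        ci = coords[chain_len * i:chain_len * (i + 1)]
        cj = coords[chain_len * j:chain_len * (j + 1)]
        ci = ci - ci.mean(axis=0)
        cj = cj - cj.mean(axis=0)
        thetas.append(rotation_angle(kabsch_rotation(ci, cj)))
    thetas = np.array(thetas)
    return thetas, (thetas.mean() - 2 * np.pi / copies) ** 2

copies, chain_len = 3, 60
rng = np.random.default_rng(0)
# Synthetic "chain": random points placed away from the z axis
chain0 = rng.normal(size=(chain_len, 3)) + np.array([10.0, 0.0, 0.0])

# Ideal C3 ring: each copy is the previous one rotated by 120 degrees about z
step = rotation_z(2 * np.pi / copies)
ring = np.concatenate(
    [chain0 @ np.linalg.matrix_power(step, k).T for k in range(copies)], axis=0)
ring_thetas, ring_loss = mean_angle_loss(ring, copies, chain_len)

# Purely translational arrangement: copies shifted along z with no rotation
stack = np.concatenate(
    [chain0 + k * np.array([0.0, 0.0, 10.0]) for k in range(copies)], axis=0)
stack_thetas, stack_loss = mean_angle_loss(stack, copies, chain_len)

print("ring angles (deg): ", np.degrees(ring_thetas), "loss:", ring_loss)
print("stack angles (deg):", np.degrees(stack_thetas), "loss:", stack_loss)
```

In this toy setup the ideal ring yields neighbour angles of 120° and a loss near zero, while the translated stack yields angles near 0° and a loss near $(2\pi/3)^2$. The wrap-around pair (last chain, first chain) is what ties the target angle to a closed ring.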



# Example 1:

Hallucinating a trimer

In [None]:
copies = 3
chain_len = 60 # number of residues per chain

In [None]:
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=cyclic_sym_loss)

af_model.prep_inputs(length=chain_len, copies=copies, homooligomer=True)

af_model.opt["weights"]["cyclic_sym_loss"] = 0.01

print("weights", af_model.opt["weights"])



weights {'con': 1.0, 'exp_res': 0.0, 'i_con': 1.0, 'i_pae': 0.0, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0, 'cyclic_sym_loss': 0.01}


In [None]:
# Start with 50 iters of soft design
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# three stage design  
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 9.26 con 4.92 i_con 4.22 plddt 0.51 ptm 0.22 i_ptm 0.11 cyclic_sym_loss 0.12
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 7.93 con 3.85 i_con 4.08 plddt 0.47 ptm 0.25 i_ptm 0.12 cyclic_sym_loss 0.00
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.93 con 3.81 i_con 4.10 plddt 0.33 ptm 0.17 i_ptm 0.09 cyclic_sym_loss 0.02
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.65 con 3.55 i_con 4.10 plddt 0.42 ptm 0.21 i_ptm 0.10 cyclic_sym_loss 0.00
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 7.66 con 2.96 i_con 4.70 plddt 0.61 ptm 0.27 i_ptm 0.09 cyclic_sym_loss 0.00
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.23 con 3.22 i_con 4.01 plddt 0.47 ptm 0.27 i_ptm 0.13 cyclic_sym_loss 0.00
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 7.24 con 2.82 i_con 4.42 plddt 0.63 ptm 0.29 i_ptm 0.11 cyclic_sym_loss 0.00
8 models [1] recycles 0 hard 0 soft 1 temp 1 loss 6.95 con 2.79 i_con 4.16 plddt 0.58 ptm 0.33 i_ptm 0.14 cycli

In [None]:
af_model.plot_pdb()

This looks pretty nice and compact.  

Let's see if we can change some weights for the default loss functions in order to tune the types of contacts made between chains.








# Example 2:

What happens if we increase the weights of inter-chain loss metrics?

In [None]:
copies = 3
chain_len = 60

In [None]:
clear_mem()

af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=cyclic_sym_loss)

af_model.prep_inputs(length=chain_len, copies=copies, homooligomer=True)

af_model.opt["weights"]["cyclic_sym_loss"] = 0.1

af_model.opt["weights"]["i_con"] = 1.0
af_model.opt["weights"]["con"] = 0.0
af_model.opt["weights"]["i_pae"] = 0.1
af_model.opt["weights"]["i_ptm"] = 0.1
af_model.opt["weights"]["pae"] = 0.0



print("weights", af_model.opt["weights"])

weights {'con': 0.0, 'exp_res': 0.0, 'i_con': 1.0, 'i_pae': 0.1, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0, 'cyclic_sym_loss': 0.1, 'i_ptm': 0.1}
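
To see at a glance what changed relative to Example 1, the two printed weight dictionaries can be compared directly. This is a small plain-Python sketch; the names `weights_ex1`, `weights_ex2` and `changed` are introduced here for illustration, and the values are copied from the two printouts in this notebook.

```python
# Weight dictionaries copied from the Example 1 and Example 2 printouts
weights_ex1 = {'con': 1.0, 'exp_res': 0.0, 'i_con': 1.0, 'i_pae': 0.0,
               'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0,
               'cyclic_sym_loss': 0.01}
weights_ex2 = {'con': 0.0, 'exp_res': 0.0, 'i_con': 1.0, 'i_pae': 0.1,
               'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0,
               'cyclic_sym_loss': 0.1, 'i_ptm': 0.1}

# Keep only the entries whose value differs (None means "not set in Example 1")
changed = {name: (weights_ex1.get(name), new)
           for name, new in weights_ex2.items()
           if weights_ex1.get(name) != new}

for name, (old, new) in sorted(changed.items()):
    print(f"{name:>16}: {old} -> {new}")
```

Four entries differ: `con` is switched off, `i_pae` and `i_ptm` are switched on, and `cyclic_sym_loss` is raised tenfold. `i_con` stays at 1.0, so the inter-chain contact term is emphasized relative to the intra-chain one rather than increased in absolute terms.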


In [None]:
# Start with 50 iters of soft design
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# three stage design  
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 9.26 con 4.92 i_con 4.22 plddt 0.51 ptm 0.22 i_ptm 0.11 cyclic_sym_loss 0.12
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 7.93 con 3.85 i_con 4.08 plddt 0.47 ptm 0.25 i_ptm 0.12 cyclic_sym_loss 0.00
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.93 con 3.81 i_con 4.10 plddt 0.33 ptm 0.17 i_ptm 0.09 cyclic_sym_loss 0.02
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.65 con 3.55 i_con 4.10 plddt 0.42 ptm 0.21 i_ptm 0.10 cyclic_sym_loss 0.00
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 7.66 con 2.96 i_con 4.70 plddt 0.61 ptm 0.27 i_ptm 0.09 cyclic_sym_loss 0.00
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.23 con 3.22 i_con 4.01 plddt 0.47 ptm 0.27 i_ptm 0.13 cyclic_sym_loss 0.00
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 7.24 con 2.82 i_con 4.42 plddt 0.63 ptm 0.29 i_ptm 0.11 cyclic_sym_loss 0.00
8 models [1] recycles 0 hard 0 soft 1 temp 1 loss 6.95 con 2.79 i_con 4.16 plddt 0.58 ptm 0.33 i_ptm 0.14 cycli

In [None]:
af_model.plot_pdb()

This gives a more interesting trimer. While the individual monomers don't look very stable on their own, the overall complex has a more distinctive fold, which would be a good starting point for further design.