<a href="https://colab.research.google.com/github/andrewfavor95/GuidedHallucination/blob/main/demos/cyclic_oligo_loss_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign - hallucination custom loss example: cyclic oligomers

In [1]:
#@title install
%%bash
if [ ! -d params ]; then
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git
  # for debugging
  ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign
  # then add the GuidedHallucination repo for loss functions
  git clone https://github.com/andrewfavor95/GuidedHallucination.git
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
  for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1
  do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done
fi

Cloning into 'GuidedHallucination'...


In [2]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import colabdesign
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from IPython.display import HTML
from google.colab import files
import numpy as np
import pdb as pydebug
import string

import jax
import jax.numpy as jnp


# Define custom loss function

Here we will write a custom loss function to bias the hallucinated oligomer towards cyclic symmetry.
Often, an alphafold-hallucinated protein with multiple chain copies will form a repeat motif (like a DHR), but if we want to make an oligomer, it is useful to have a way to bias the design towards complete cyclic symmetry.

In [3]:

def xform_diff_loss(inputs,outputs,opt,center_coords=False):

    def quat2rot(quat, dtype='f8', shape=(3, 3)):
        quat = jnp.asarray(quat)
        assert quat.shape[-1] == 4
        qr = quat[..., 0]
        qi = quat[..., 1]
        qj = quat[..., 2]
        qk = quat[..., 3]
        outshape = quat.shape[:-1]
        rot = jnp.zeros(outshape + shape)

        rot = rot.at[..., 0, 0].set(1 - 2 * (qj**2 + qk**2))
        rot = rot.at[..., 0, 1].set(2 * (qi * qj - qk * qr))
        rot = rot.at[..., 0, 2].set(2 * (qi * qk + qj * qr))
        rot = rot.at[..., 1, 0].set(2 * (qi * qj + qk * qr))
        rot = rot.at[..., 1, 1].set(1 - 2 * (qi**2 + qk**2))
        rot = rot.at[..., 1, 2].set(2 * (qj * qk - qi * qr))
        rot = rot.at[..., 2, 0].set(2 * (qi * qk - qj * qr))
        rot = rot.at[..., 2, 1].set(2 * (qj * qk + qi * qr))
        rot = rot.at[..., 2, 2].set(1 - 2 * (qi**2 + qj**2))

        return rot

    def rot2quat(matrix, isprecise=False):
        M = jnp.array(matrix, copy=False)[:4, :4]
        if isprecise:
            q = jnp.empty((4, ))
            t = jnp.trace(M)
            if t > M[3, 3]:
                q[0] = t
                q[3] = M[1, 0] - M[0, 1]
                q[2] = M[0, 2] - M[2, 0]
                q[1] = M[2, 1] - M[1, 2]
            else:
                i, j, k = 0, 1, 2
                if M[1, 1] > M[0, 0]:
                    i, j, k = 1, 2, 0
                if M[2, 2] > M[i, i]:
                    i, j, k = 2, 0, 1
                t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
                q[i] = t
                q[j] = M[i, j] + M[j, i]
                q[k] = M[k, i] + M[i, k]
                q[3] = M[k, j] - M[j, k]
                q = q[[3, 0, 1, 2]]
            q *= 0.5 / math.sqrt(t * M[3, 3])
        else:
            m00 = M[0, 0]
            m01 = M[0, 1]
            m02 = M[0, 2]
            m10 = M[1, 0]
            m11 = M[1, 1]
            m12 = M[1, 2]
            m20 = M[2, 0]
            m21 = M[2, 1]
            m22 = M[2, 2]
            # symmetric matrix K
            K =jnp.array([[m00-m11-m22, 0.0,         0.0,         0.0],
                             [m01+m10,     m11-m00-m22, 0.0,         0.0],
                             [m02+m20,     m12+m21,     m22-m00-m11, 0.0],
                             [m21-m12,     m02-m20,     m10-m01,     m00+m11+m22]])
            K /= 3.0
            # quaternion is eigenvector of K that corresponds to largest eigenvalue
            w, V = jnp.linalg.eigh(K)
            q = V[[3, 0, 1, 2], jnp.argmax(w)]

        flip =  q[0] < 0
        q = jnp.where(flip, -q, q)

        return q

    def average_quats_mean(quats):
        
        flip = quats @ quats[0,:] < 0 # see if quats have correct handedness

        quats = jnp.where(flip, -quats.T, quats.T).T # flip quats to have correct handedness

        q_avg = quats.mean(axis=0)/jnp.linalg.norm(quats.mean(axis=0))

        return q_avg


    pred = outputs["structure_module"]["final_atom_positions"][:,residue_constants.atom_order["CA"]]
    pred = pred - pred.mean(axis=0)

  
    sub_pred_dict = {}
    
    for i in range(copies):
        start_i = chain_len*i
        stop_i = chain_len*(i+1)
        sub_pred_dict[i] = pred[start_i:stop_i,:]
            
    chain_order_sorted = [i for i in range(copies)]

    # Define different chain pairs to compare transforms between
    xform_pairs = [(chain_order_sorted[i%copies],chain_order_sorted[j%copies]) for i,j in zip(range(0,copies),range(1,copies+1))]

    # Start gathering the info for all the chain transformations
    rmsd_list = []
    quat_list = []

    R_ij_list = []
    v_ij_list = []
    
    for i,j in xform_pairs:
        pred_i = sub_pred_dict[i]-sub_pred_dict[i].mean(-2) 
        pred_j = sub_pred_dict[j]-sub_pred_dict[j].mean(-2)

        v_ij_list.append(sub_pred_dict[j].mean(-2)-sub_pred_dict[i].mean(-2))
        R_ij = colabdesign.shared.protein._np_kabsch(pred_j,pred_i)
        q_ij = rot2quat(R_ij)
        
        rmsd_ij = colabdesign.shared.protein._np_rmsd(sub_pred_dict[i],sub_pred_dict[j])
        rmsd_list.append(rmsd_ij)

        R_ij_list.append(R_ij)
        quat_list.append(q_ij)


    q_avg = average_quats_mean(jnp.array(quat_list))
    R_avg = quat2rot(q_avg)

    B_avg = (R_avg - jnp.identity(3)) @ jnp.linalg.inv((R_avg + jnp.identity(3)))

    rod_vector = jnp.array([B_avg[2, 1], B_avg[0, 2], B_avg[1, 0]]).T
    rod_mag = jnp.linalg.norm(rod_vector)

    theta = 2 * jnp.arctan(rod_mag)

    chain_rmsd_loss_val = jnp.square(jnp.array(rmsd_list)).mean()

    theta_error = jnp.square(theta - (jnp.pi * 2 / copies) )

    xform_diff_loss_val = theta_error
    
    return {"chain_rmsd_loss": chain_rmsd_loss_val, "xform_diff_loss": xform_diff_loss_val,  }



#Example 1: 
Hallucinating a trimer

In [4]:
copies = 3
chain_len = 30 # maybe change variable names later, or even add this as a cell to enter info into

In [5]:
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=xform_diff_loss)

af_model.prep_inputs(length=chain_len, copies=copies, homooligomer=True)

af_model.opt["weights"]["xform_diff_loss"] = 0.1
af_model.opt["weights"]["chain_rmsd_loss"] = 0.1

# af_model.prep_inputs(length=chain_len)

print("weights", af_model.opt["weights"])


weights {'con': 1.0, 'exp_res': 0.0, 'i_con': 1.0, 'i_pae': 0.0, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0, 'xform_diff_loss': 0.1, 'chain_rmsd_loss': 0.1}


In [6]:
# Start with 50 iters of soft design
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# three stage design  
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 19.38 con 4.22 i_con 4.17 plddt 0.30 ptm 0.14 i_ptm 0.06 chain_rmsd_loss 6.69 xform_diff_loss 4.30
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 25.45 con 4.22 i_con 3.99 plddt 0.35 ptm 0.15 i_ptm 0.05 chain_rmsd_loss 13.19 xform_diff_loss 4.04
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 22.06 con 4.14 i_con 4.15 plddt 0.40 ptm 0.19 i_ptm 0.06 chain_rmsd_loss 13.76 xform_diff_loss 0.00
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 21.36 con 4.07 i_con 4.13 plddt 0.36 ptm 0.19 i_ptm 0.06 chain_rmsd_loss 13.16 xform_diff_loss 0.00
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 12.09 con 4.06 i_con 4.17 plddt 0.38 ptm 0.18 i_ptm 0.05 chain_rmsd_loss 3.86 xform_diff_loss 0.00
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 8.92 con 3.86 i_con 4.04 plddt 0.36 ptm 0.20 i_ptm 0.08 chain_rmsd_loss 1.02 xform_diff_loss 0.00
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 19.76 con 3.98 i_con 4.07 plddt 0.33 ptm 0.18 i_ptm 0.

In [7]:
af_model.plot_pdb()

This is alright.  I think we can increase other loss weights to make the contacts better.








# Example 2:

Can we increase the i_con and i_pae weights to get better interaction between chains?

In [8]:
copies = 3
chain_len = 60 

In [9]:
clear_mem()

af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=xform_diff_loss)

af_model.prep_inputs(length=chain_len, copies=copies, homooligomer=True)

af_model.opt["weights"]["xform_diff_loss"] = 0.1
af_model.opt["weights"]["chain_rmsd_loss"] = 0.1

print("weights", af_model.opt["weights"])

weights {'con': 1.0, 'exp_res': 0.0, 'i_con': 1.0, 'i_pae': 0.0, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0, 'xform_diff_loss': 0.1, 'chain_rmsd_loss': 0.1}


In [10]:
# Start with 50 iters of soft design
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# three stage design  
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 50.57 con 4.92 i_con 4.22 plddt 0.51 ptm 0.22 i_ptm 0.11 chain_rmsd_loss 41.35 xform_diff_loss 0.08
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 44.51 con 4.12 i_con 4.03 plddt 0.43 ptm 0.19 i_ptm 0.11 chain_rmsd_loss 33.71 xform_diff_loss 2.65
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 42.50 con 4.44 i_con 4.26 plddt 0.34 ptm 0.15 i_ptm 0.08 chain_rmsd_loss 33.32 xform_diff_loss 0.47
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 35.42 con 3.96 i_con 4.16 plddt 0.35 ptm 0.20 i_ptm 0.10 chain_rmsd_loss 27.09 xform_diff_loss 0.21
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 14.65 con 3.89 i_con 4.18 plddt 0.44 ptm 0.23 i_ptm 0.11 chain_rmsd_loss 6.58 xform_diff_loss 0.00
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 12.02 con 4.10 i_con 4.10 plddt 0.39 ptm 0.22 i_ptm 0.12 chain_rmsd_loss 3.79 xform_diff_loss 0.02
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 22.44 con 3.87 i_con 4.02 plddt 0.44 ptm 0.26 i_ptm 

In [11]:
af_model.plot_pdb()