<a href="https://colab.research.google.com/github/andrewfavor95/GuidedHallucination/blob/main/ss_spec_loss_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign - hallucination custom loss example: SS-spec

In [1]:
#@title install
%%bash
if [ ! -d params ]; then
  # get code
  # pip -q install git+https://github.com/sokrypton/ColabDesign.git
  pip -q install git+https://github.com/andrewfavor95/ColabDesign.git
  # for debugging: expose the installed package in the working directory.
  # The install location is looked up from the interpreter instead of being
  # hard-coded to a python3.7 dist-packages path, which leaves a dangling
  # symlink on any runtime that ships a different Python version.
  PKG_DIR=$(python3 -c "import importlib.util as u; print(u.find_spec('colabdesign').submodule_search_locations[0])")
  # Skip the link (the traceback above is already printed) if the lookup failed.
  if [ -n "${PKG_DIR}" ]; then ln -s "${PKG_DIR}" colabdesign; fi
  # then add the GuidedHallucination repo for loss functions
  git clone https://github.com/andrewfavor95/GuidedHallucination.git
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
  for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1
  do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done
fi

Cloning into 'GuidedHallucination'...


In [2]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import colabdesign
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from IPython.display import HTML
from google.colab import files
import numpy as np
import pdb as pydebug
import string

import jax
import jax.numpy as jnp


# Here's a custom loss function that biases the hallucination towards adopting specific secondary structure motifs at specified regions.

In [3]:

def ss_spec_loss(inputs,outputs,opt):
    """Custom AfDesign loss biasing chosen regions toward helix or strand geometry.

    For every residue the distogram logits of the (i, i+3) and (i-3, i) pairs
    are summed, turned into bin probabilities, and scored by the negative log
    of the probability mass that falls into the bins allowed for that residue.
    Residues outside every specified region allow all bins.

    NOTE: besides its arguments this function reads two notebook-level globals
    at call time:
      * ``ss_spec``  - list of strings ``"<type>,<start>:<stop>"`` where type is
        ``'H'`` or ``'S'`` and start/stop are 1-based, inclusive positions
        within one chain. Entries with any other type are ignored.
      * ``af_model`` - supplies ``_len`` (length of one chain) and
        ``_args["copies"]`` (number of chains the spec is repeated over).

    Args:
        inputs: unused; present because the loss-callback signature requires it.
        outputs: model outputs; reads ``["distogram"]["logits"]``,
            ``["distogram"]["bin_edges"]`` and, when the PAE term is enabled,
            ``["predicted_aligned_error"]["logits"]``.
        opt: option mapping. The *presence* of the key ``"ss_pae_calc"``
            (its value is never inspected) enables the extra PAE term;
            ``"SS_weight"``, ``"SH_weight"`` and ``"HH_weight"`` optionally
            override the defaults 0.5, 0.35 and 0.15.

    Returns:
        ``{"ss_spec_loss": value}``, plus ``"ss_pae_loss"`` when enabled.
    """

    def parse_ss_spec(ss_spec,outputs,chain_len,copies,helix_cutoff=6.0,sheet_cutoff=9.0):
        """Translate the spec strings into a per-residue mask of allowed bins.

        Returns ``(bins, ss_type_list, ss_range_list)``: ``bins`` is a boolean
        array (n_residues, n_bins); the two lists hold, per region and per
        chain copy, the type letter and the 0-based half-open (start, stop).
        """
        dgram = outputs["distogram"]["logits"]
        # Prepend 0 so that there is one lower edge per distogram bin.
        dgram_bins = jnp.append(0,outputs["distogram"]["bin_edges"])

        # Default: every bin is allowed, so unspecified residues add ~0 loss.
        bins = jnp.full((dgram.shape[0],dgram.shape[-1]),True)
        ss_type_list = []
        ss_range_list = []

        for ss_spec_str_i in ss_spec:
            fields = ss_spec_str_i.split(',')
            ss_i = fields[0]
            range_fields = fields[-1].split(':')
            first_resi = int(range_fields[0])
            last_resi = int(range_fields[-1])

            if ss_i=='H':
                # helix: keep bins whose lower edge is below the helix cutoff
                allowed_bins = dgram_bins < helix_cutoff
            elif ss_i=='S':
                # strand: keep bins whose lower edge is above the sheet cutoff
                allowed_bins = dgram_bins > sheet_cutoff
            else:
                # unrecognised type letters contribute nothing
                continue

            for chain_ind in range(copies):
                # Shift the 1-based inclusive range into this copy's 0-based slice.
                chain_offset = chain_ind*chain_len
                start_ind = chain_offset+first_resi-1
                stop_ind = chain_offset+last_resi

                bins = bins.at[start_ind:stop_ind,:].set(allowed_bins)
                ss_type_list.append(ss_i)
                ss_range_list.append((start_ind,stop_ind))

        return bins,ss_type_list,ss_range_list

    def get_ss_pae(outputs,ss_type_list,ss_range_list,SS_weight,SH_weight,HH_weight):
        """Weighted sum of mean PAE scores over every ordered pair of regions.

        Each pair (a region paired with itself included) contributes the mean
        of its PAE block times ``weight / num_regions``; the weight is
        ``SS_weight`` for strand-strand, ``HH_weight`` for helix-helix and
        ``SH_weight`` for mixed pairs.
        """
        pae_prob = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"])
        # Per-pair score: mean over bins of (probability * bin index).
        pae = (pae_prob * jnp.arange(pae_prob.shape[-1])).mean(-1)

        num_regions = len(ss_type_list)
        same_type_weight = {('S','S'): SS_weight, ('H','H'): HH_weight}

        pae_spec_error = 0.0
        for (start_resi_i,stop_resi_i), ss_type_i in zip(ss_range_list,ss_type_list):
            for (start_resi_j,stop_resi_j), ss_type_j in zip(ss_range_list,ss_type_list):
                pair_weight = same_type_weight.get((ss_type_i,ss_type_j),SH_weight)
                block_mean = pae[start_resi_i:stop_resi_i,start_resi_j:stop_resi_j].mean()
                pae_spec_error += block_mean*(pair_weight/num_regions)

        return pae_spec_error

    def opt_or_default(key,default):
        """Return ``opt[key]`` when the key is present, otherwise ``default``."""
        return opt[key] if key in opt.keys() else default

    # Notebook globals (see docstring): chain geometry and the region spec.
    chain_len = af_model._len
    copies = af_model._args["copies"]

    dgram = outputs["distogram"]["logits"]
    bins,ss_type_list,ss_range_list = parse_ss_spec(ss_spec,outputs,chain_len,copies)

    # Logits of the (i, i+3) pairs, one row per i; 3 rows shorter than the chain.
    dgram_diag = jnp.diagonal(dgram,offset=3,axis1=0,axis2=1).T
    buffer = jnp.zeros((3,dgram_diag.shape[-1]))

    # Pad at the end and at the start so row i holds (i, i+3) plus (i-3, i).
    dgram_n3 = jnp.append(dgram_diag,buffer,axis=0)+jnp.append(buffer,dgram_diag,axis=0)

    px = jax.nn.softmax(dgram_n3)

    # Negative log of the probability mass inside the allowed bins, per residue.
    loss_per_resi = -jnp.log((bins * px + 1e-8).sum(-1))

    ss_spec_loss_val = loss_per_resi.sum()/jnp.sqrt(len(loss_per_resi))

    losses = {"ss_spec_loss":ss_spec_loss_val}

    # Only key presence is tested, never the value.
    # NOTE(review): presumably because opt values may be traced arrays under
    # jit, where a truthiness test would fail - confirm before changing.
    if "ss_pae_calc" in opt.keys():
        SS_weight = opt_or_default("SS_weight",0.5)
        HH_weight = opt_or_default("HH_weight",0.15)
        SH_weight = opt_or_default("SH_weight",0.35)

        losses["ss_pae_loss"] = get_ss_pae(outputs,ss_type_list,ss_range_list,SS_weight,SH_weight,HH_weight)

    return losses

    




# Example 1:
Hallucinating a monomer with the same chain length and secondary structure locations as Top7:

In [4]:
# Length of the chain to hallucinate.
chain_len = 92

# One entry per region, "<type>,<start>:<stop>":
# 'S' = strand, 'H' = helix, positions are 1-based and inclusive.
ss_spec = [
    'S,2:9',
    'S,14:21',
    'H,25:42',
    'S,45:51',
    'H,54:72',
    'S,76:82',
    'S,85:91',
]

In [5]:
# Example 1 setup: hallucination model with the custom secondary-structure loss.
clear_mem()
# ss_spec_loss reads the notebook globals `af_model` and `ss_spec` when it is
# called, so both must be defined before any design step runs.
af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=ss_spec_loss) # add our custom loss

# The weight key must match the key of the dict returned by ss_spec_loss.
af_model.opt["weights"]["ss_spec_loss"] = 0.1

# Weight for the "pae" loss term (0.0 in Example 2, where it is left unset).
af_model.opt["weights"]["pae"] = 0.1

af_model.prep_inputs(length=chain_len)

# Show the final weights so the run is self-documenting.
print("weights", af_model.opt["weights"])


weights {'con': 1.0, 'exp_res': 0.0, 'pae': 0.1, 'plddt': 0.0, 'seq_ent': 0.0, 'ss_spec_loss': 0.1}


In [6]:
# Start with 50 iters of soft design; the fixed seed keeps the run repeatable.
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# Three stage design, restarted from the pseudo-sequence of the soft run above
# while keeping the optimisation history.
# NOTE(review): the arguments (50,50,5) are presumably the iteration counts of
# the three stages - confirm against the ColabDesign documentation.
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 6.95 pae 0.63 con 4.88 ptm 0.16 i_ptm 0 ss_spec_loss 20.06
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 5.92 pae 0.65 con 4.19 ptm 0.17 i_ptm 0 ss_spec_loss 16.63
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.91 pae 0.58 con 3.89 ptm 0.22 i_ptm 0 ss_spec_loss 19.56
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.84 pae 0.54 con 3.67 ptm 0.29 i_ptm 0 ss_spec_loss 21.14
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 5.84 pae 0.63 con 4.19 ptm 0.16 i_ptm 0 ss_spec_loss 15.95
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.63 pae 0.63 con 4.12 ptm 0.17 i_ptm 0 ss_spec_loss 14.40
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 5.51 pae 0.60 con 3.99 ptm 0.21 i_ptm 0 ss_spec_loss 14.62
8 models [1] recycles 0 hard 0 soft 1 temp 1 loss 5.69 pae 0.61 con 3.96 ptm 0.20 i_ptm 0 ss_spec_loss 16.62
9 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.41 pae 0.59 con 3.85 ptm 0.22 i_ptm 0 ss_spec_loss 15.06
10 models [2] recyc

In [7]:
# Plot the structure of the Example 1 design (output is not part of this export).
af_model.plot_pdb()

This looks alright.  The specified secondary structure features are in the desired sequence regions, but I don't have much confidence in the integrity of this design as a well-behaved monomer.

Let's see if we can improve the structure overall by adding a PAE loss restricted to residue pairs within and between the specified secondary structure motifs.









# Example 2:

Can we improve the interactions between secondary structure motifs by upweighting the pae terms?

In [8]:
# Same chain length and region spec as Example 1, so the two runs are comparable.
chain_len = 92

# "<type>,<start>:<stop>" with 'S' = strand, 'H' = helix (1-based, inclusive).
ss_spec = [
    'S,2:9', 'S,14:21',
    'H,25:42',
    'S,45:51',
    'H,54:72',
    'S,76:82', 'S,85:91',
]


In [9]:
# Example 2 setup: same model as Example 1, plus the region-restricted PAE term.
clear_mem()
# ss_spec_loss reads the notebook globals `af_model` and `ss_spec` when it is
# called, so both must be defined before any design step runs.
af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=ss_spec_loss) # add our custom loss

# Weight keys must match the keys of the dict returned by ss_spec_loss.
# Unlike Example 1, the "pae" weight is not set here (it prints as 0.0 below).
af_model.opt["weights"]["ss_spec_loss"] = 0.1
af_model.opt["weights"]["ss_pae_loss"] = 0.1


# ss_spec_loss only checks that this key exists; its value is never read,
# so the term is switched off by not setting the key rather than by False.
af_model.opt["ss_pae_calc"] = True # Say that we do want to add an additional loss term for ss-pae

# Per-pair-type weights read by ss_spec_loss; these equal its built-in defaults.
af_model.opt["SS_weight"] = 0.5 # strand-strand interactions are weighted as most important
af_model.opt["SH_weight"] = 0.35 # strand-helix interactions are second most important
af_model.opt["HH_weight"] = 0.15 # helix-helix interactions are third most important

af_model.prep_inputs(length=chain_len)

# Show the final weights so the run is self-documenting.
print("weights", af_model.opt["weights"])

weights {'con': 1.0, 'exp_res': 0.0, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0, 'ss_pae_loss': 0.1, 'ss_spec_loss': 0.1}


In [10]:
# Same schedule and seed as Example 1, so differences come from the loss terms.
# Start with 50 iters of soft design.
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# Three stage design, restarted from the pseudo-sequence of the soft run above
# while keeping the optimisation history.
# NOTE(review): the arguments (50,50,5) are presumably the iteration counts of
# the three stages - confirm against the ColabDesign documentation.
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 7.07 con 4.88 ptm 0.16 i_ptm 0 ss_pae_loss 1.85 ss_spec_loss 20.06
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 6.04 con 4.19 ptm 0.17 i_ptm 0 ss_pae_loss 1.87 ss_spec_loss 16.63
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 6.01 con 3.90 ptm 0.22 i_ptm 0 ss_pae_loss 1.70 ss_spec_loss 19.43
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.80 con 3.95 ptm 0.19 i_ptm 0 ss_pae_loss 1.69 ss_spec_loss 16.80
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 5.75 con 4.08 ptm 0.16 i_ptm 0 ss_pae_loss 1.81 ss_spec_loss 14.83
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.49 con 4.03 ptm 0.15 i_ptm 0 ss_pae_loss 1.78 ss_spec_loss 12.84
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 5.92 con 4.19 ptm 0.18 i_ptm 0 ss_pae_loss 1.77 ss_spec_loss 15.47
8 models [1] recycles 0 hard 0 soft 1 temp 1 loss 5.94 con 4.01 ptm 0.19 i_ptm 0 ss_pae_loss 1.65 ss_spec_loss 17.69
9 models [4] recycles 0 hard 0 soft 1 temp 1 loss 5.15 con 4.04 

In [11]:
# Plot the structure of the Example 2 design (output is not part of this export).
af_model.plot_pdb()

This looks a bit better! The beta sheets have more satisfied H-bonds, and the helices are packed better against adjacent motifs.

Interestingly, the fold is very similar to Top7!