<a href="https://colab.research.google.com/github/andrewfavor95/GuidedHallucination/blob/main/demos/ss_spec_loss_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AfDesign - hallucination custom loss example: SS-spec

In [2]:
#@title install
%%bash
# One-time environment setup. Everything is skipped when a `params` directory
# already exists, so re-running the cell is cheap.
if [ ! -d params ]; then
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git
  # for debugging: expose the installed package in the working directory.
  # The location is resolved from the interpreter instead of hard-coding a
  # python3.7 dist-packages path, which dangles on any other Python version.
  # NOTE(review): assumes `python3` is the interpreter `pip` installed into.
  ln -s "$(python3 -c 'import importlib.util as u; print(u.find_spec("colabdesign").submodule_search_locations[0])')" colabdesign
  # then add the GuidedHallucination repo for loss functions
  git clone https://github.com/andrewfavor95/GuidedHallucination.git
  # download params
  mkdir params
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params
  # additional weight files, fetched only if not already present (wget -nc)
  for W in openfold_model_ptm_1 openfold_model_ptm_2 openfold_model_no_templ_ptm_1
  do wget -qnc https://files.ipd.uw.edu/krypton/openfold/${W}.npz -P params; done
fi

In [3]:
#@title import libraries
import warnings
# Suppress FutureWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import colabdesign
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from IPython.display import HTML
# NOTE(review): Colab-specific module; presumably unavailable outside Google Colab.
from google.colab import files
import numpy as np
# Standard-library debugger under an alias; not referenced in the cells shown here.
import pdb as pydebug
import string

import jax
import jax.numpy as jnp


# Define custom loss function

Here we will write a custom loss function to score the designed protein based on whether it possesses secondary structure features in a specified arrangement.

In [5]:
   
def ss_spec_loss(inputs,outputs,opt):
    """Score how well the predicted distogram matches a secondary-structure spec.

    For every residue, the distogram logits of the (i, i+3) residue pairs that
    span it are summed and soft-maxed into a distribution over distance bins.
    That distribution is multiplied by the per-bin probability expected for the
    secondary-structure type requested at that residue; the loss is the mean
    negative log of the resulting overlap. Residues not covered by the spec
    keep a weight of 1 in every bin, so their overlap is 1 and they contribute
    (almost) nothing.

    Reads two names from the enclosing notebook scope at call time:
      * ``af_model`` -- supplies ``_len`` (used as the per-chain length) and
        ``_args["copies"]`` (number of chain copies).
      * ``ss_spec``  -- list of strings ``"<type>,<start>:<stop>"`` where
        ``type`` is ``H`` (helix) or ``S`` (strand) and ``start:stop`` is a
        1-indexed, inclusive residue range within one chain. Entries with any
        other type are ignored.

    Args:
      inputs: unused.
      outputs: must provide ``outputs["distogram"]["logits"]`` (indexed here
        as [residue, residue, bin]) and ``outputs["distogram"]["bin_edges"]``.
      opt: unused.

    Returns:
      ``{"ss_spec_loss": value}`` where ``value`` is the scalar mean loss.
    """

    def ss_probs_for_bins(outputs,
                      alpha_mu=5.1,
                      alpha_sig=0.7,
                      beta_mu=10.2,
                      beta_sig=0.7):
        """Probability mass of each distance bin under a helix / strand Gaussian.

        The helix model is N(alpha_mu, alpha_sig) and the strand model is
        N(beta_mu, beta_sig); presumably these describe i -> i+3 distances in
        the same units as ``bin_edges`` -- TODO confirm.
        """
        bin_edges = outputs["distogram"]["bin_edges"]
        # Close the first bin at 0 and the last one at a very large distance.
        bin_lowers = jnp.append(0, bin_edges)
        bin_uppers = jnp.append(bin_edges, 1e3)

        def bin_mass(mu, sig):
            cdf = jax.scipy.stats.norm.cdf
            return cdf(bin_uppers, mu, sig) - cdf(bin_lowers, mu, sig)

        return {'helix': bin_mass(alpha_mu, alpha_sig),
                'strand': bin_mass(beta_mu, beta_sig)}

    def parse_ss_spec(ss_spec,outputs,chain_len,copies):
        """Build a (residues, bins) array of target bin probabilities.

        Rows start as all ones; rows inside an ``H``/``S`` range are replaced by
        the helix/strand bin probabilities, repeated for every chain copy.
        """
        P_ss_by_bin = ss_probs_for_bins(outputs)
        dgram = outputs["distogram"]["logits"]

        bin_ss_probs = jnp.ones((dgram.shape[0],dgram.shape[-1]))
        ss_key = {'H': 'helix', 'S': 'strand'}

        for ss_spec_str_i in ss_spec:
            fields = ss_spec_str_i.split(',')
            ss_i, range_i = fields[0], fields[-1]
            if ss_i not in ss_key:
                continue  # unknown type: leave these rows unconstrained

            bounds = range_i.split(':')
            first, last = int(bounds[0]), int(bounds[-1])

            for chain_ind in range(copies):
                # 1-indexed inclusive range -> 0-indexed half-open slice,
                # shifted to the chain_ind-th copy.
                start_ind = (chain_ind*chain_len)+first-1
                stop_ind = (chain_ind*chain_len)+last
                bin_ss_probs = bin_ss_probs.at[start_ind:stop_ind,:].set(P_ss_by_bin[ss_key[ss_i]])

        return bin_ss_probs

    chain_len = af_model._len
    copies = af_model._args["copies"]

    dgram = outputs["distogram"]["logits"]
    # Size every per-residue array from the distogram itself. Using the
    # per-chain length here made the slices below disagree with dgram_diag
    # whenever copies > 1.
    total_len = dgram.shape[0]

    bin_ss_probs = parse_ss_spec(ss_spec,outputs,chain_len,copies)

    # Row k holds the logits of the residue pair (k, k+3): shape (total_len-3, bins).
    dgram_diag = jnp.diagonal(dgram,offset=3,axis1=0,axis2=1).T

    # Row r accumulates the logits of every (k, k+3) pair with k <= r <= k+3.
    dgram_smooth = jnp.zeros((total_len,dgram.shape[-1]))
    for i in range(4):
        dgram_smooth = dgram_smooth.at[i:total_len+i-3,:].add(dgram_diag)

    px = jax.nn.softmax(dgram_smooth) # probability of a given residue (axis 0) being in a given bin (axis 1)

    correct_ss_prob = (px*bin_ss_probs).sum(-1) # per-residue probability of being in the CORRECT secondary structure
    correct_ss_ent = -jnp.log(correct_ss_prob + 1e-8) # 1e-8 guards against log(0)

    ss_spec_loss_val = correct_ss_ent.mean()

    return {"ss_spec_loss":ss_spec_loss_val}

    


# Monomer design demos:

The specified secondary structure in this demo will correspond to that of [Top7](https://www.rcsb.org/structure/1qys), a *de novo* designed alpha-beta fold [[1]](https://www.science.org/doi/10.1126/science.1089427).

In [8]:
# Length of the designed chain, in residues.
chain_len = 92

# Target secondary structure, one entry per element: "<type>,<start>:<stop>".
# type is S (strand) or H (helix); the residue range is 1-indexed and inclusive.
ss_spec = [
    'S,2:9',
    'S,14:21',
    'H,25:42',
    'S,45:51',
    'H,54:72',
    'S,76:82',
    'S,85:91',
]

## Design test 1: 
First we'll just see what kind of output we get when adding this loss function, without modifying any other loss function weights.

The only other loss weight currently set is that of the contact loss (`con`), with a weight of 1.0, so we will also set our ss-spec loss weight to 1.0.


In [5]:
# Build a fresh hallucination model with the custom loss attached.
# clear_mem() is called first; presumably it releases memory held by a previous model -- TODO confirm.
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=ss_spec_loss) # add our custom loss

# The key must match the one in the dict returned by ss_spec_loss.
af_model.opt["weights"]["ss_spec_loss"] = 1.0

# chain_len comes from the design-target cell above.
af_model.prep_inputs(length=chain_len)
print("weights", af_model.opt["weights"])

weights {'con': 1.0, 'exp_res': 0.0, 'pae': 0.0, 'plddt': 0.0, 'seq_ent': 0.0, 'ss_spec_loss': 1.0}


In [6]:
# Start with 50 iters of soft design
# (seed=0 is passed to restart, presumably fixing the random initialisation)
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# three stage design
# Restart from the "pseudo" sequence stored in af_model.aux by the soft stage,
# keeping the history recorded so far (keep_history=True).
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)


1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 10.15 con 4.88 plddt 0.49 ptm 0.16 ss_spec_loss 5.27
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 8.82 con 4.21 plddt 0.44 ptm 0.17 ss_spec_loss 4.61
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 8.89 con 3.81 plddt 0.46 ptm 0.24 ss_spec_loss 5.08
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 8.24 con 4.06 plddt 0.42 ptm 0.17 ss_spec_loss 4.17
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 8.24 con 4.05 plddt 0.38 ptm 0.17 ss_spec_loss 4.19
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 9.07 con 3.95 plddt 0.45 ptm 0.21 ss_spec_loss 5.12
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 7.88 con 4.07 plddt 0.45 ptm 0.22 ss_spec_loss 3.81
8 models [1] recycles 0 hard 0 soft 1 temp 1 loss 7.23 con 3.94 plddt 0.45 ptm 0.24 ss_spec_loss 3.29
9 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.19 con 4.02 plddt 0.40 ptm 0.17 ss_spec_loss 3.17
10 models [2] recycles 0 hard 0 soft 1 temp 1 loss 6.92 con 3.86 plddt 0.42 ptm 0

In [7]:
# Display the structure of the current design (design test 1).
af_model.plot_pdb()

This looks alright.  The desired secondary structure features are mostly in the specified regions, but I don't have much confidence in the integrity of this design as a well-behaved monomer.

The beta strands aren't paired as well as they could be, and it would be nice if we had a more compact fold overall.

Let's see if we can improve the overall structure by upweighting the pae loss, in the hope of improving the packing between secondary structure motifs.






## Design test 2:

Can we improve the interactions between secondary structure by adding additional weights for the standard loss functions?

We'll add a small weight term for pae (0.1), and see if that helps bias our design towards a better fold.

In [15]:
# Rebuild the model from scratch for the second test, again with the custom loss attached.
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination",
                             debug=False,
                             loss_callback=ss_spec_loss) # add our custom loss


# Same ss-spec weight as in design test 1.
af_model.opt["weights"]["ss_spec_loss"] = 1.0

# The only change from design test 1: a small weight on the pae loss.
af_model.opt["weights"]["pae"] = 0.1

# chain_len comes from the design-target cell above.
af_model.prep_inputs(length=chain_len)

print("weights", af_model.opt["weights"])


weights {'con': 1.0, 'exp_res': 0.0, 'pae': 0.1, 'plddt': 0.0, 'seq_ent': 0.0, 'ss_spec_loss': 1.0}


In [16]:
# Start with 50 iters of soft design
# (same mode and seed as design test 1, so the runs differ only in the loss weights set above)
af_model.restart(mode="gumbel", seed=0)
af_model.design_soft(50)

# three stage design
# Restart from the "pseudo" sequence stored in af_model.aux by the soft stage,
# keeping the history recorded so far (keep_history=True).
af_model.restart(seq=af_model.aux["seq"]["pseudo"], keep_history=True)
af_model.design_3stage(50,50,5)

1 models [2] recycles 0 hard 0 soft 1 temp 1 loss 10.21 pae 0.63 con 4.88 plddt 0.49 ptm 0.16 ss_spec_loss 5.27
2 models [3] recycles 0 hard 0 soft 1 temp 1 loss 8.89 pae 0.65 con 4.21 plddt 0.44 ptm 0.17 ss_spec_loss 4.61
3 models [4] recycles 0 hard 0 soft 1 temp 1 loss 8.95 pae 0.57 con 3.82 plddt 0.46 ptm 0.24 ss_spec_loss 5.07
4 models [4] recycles 0 hard 0 soft 1 temp 1 loss 8.37 pae 0.61 con 4.01 plddt 0.43 ptm 0.18 ss_spec_loss 4.30
5 models [3] recycles 0 hard 0 soft 1 temp 1 loss 8.12 pae 0.64 con 4.19 plddt 0.40 ptm 0.15 ss_spec_loss 3.87
6 models [4] recycles 0 hard 0 soft 1 temp 1 loss 8.94 pae 0.56 con 3.76 plddt 0.50 ptm 0.24 ss_spec_loss 5.13
7 models [1] recycles 0 hard 0 soft 1 temp 1 loss 8.98 pae 0.47 con 3.84 plddt 0.58 ptm 0.29 ss_spec_loss 5.09
8 models [1] recycles 0 hard 0 soft 1 temp 1 loss 8.17 pae 0.57 con 3.86 plddt 0.48 ptm 0.25 ss_spec_loss 4.26
9 models [4] recycles 0 hard 0 soft 1 temp 1 loss 7.26 pae 0.62 con 4.01 plddt 0.41 ptm 0.18 ss_spec_loss 3.19


In [17]:
# Display the structure of the current design (design test 2).
af_model.plot_pdb()

This looks a bit better! The beta sheets have more satisfied H-bonds, and the helices are packed better against adjacent motifs.

Interestingly, the fold is very similar to Top7!