<a href="https://colab.research.google.com/github/Chengeng-Yang/ML_projects/blob/main/AF2Rank_sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF2Rank
[AF2Rank](https://github.com/jproney/AF2Rank) implemented using ColabDesign.

If you find this useful, please cite:
- Roney, J.P. and Ovchinnikov, S., 2022. **State-of-the-Art estimation of protein model accuracy using AlphaFold**. [BioRxiv](https://www.biorxiv.org/content/10.1101/2022.03.11.484043v3.full).

In [None]:
# Mount Google Drive. Later cells read and write drive/MyDrive/..., a path
# relative to the working directory (presumably /content on Colab).
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@title ## setup
%%bash
# Fail fast: a failed download or install must not leave a half-populated
# `params` directory that makes every later run skip this setup.
set -eo pipefail

if [ ! -d params ]; then
  # get code (pinned ColabDesign release)
  pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1
  # for debugging: expose the installed package in the working directory
  ln -sfn /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign

  # alphafold params: unpack into a temporary dir, then rename, so `params`
  # only appears once the download completed
  rm -rf params.partial
  mkdir params.partial
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar | tar x -C params.partial
  mv params.partial params
fi

# TMscore binary used by tmscore(); guarded separately from the params download
if [ ! -x TMscore ]; then
  wget -qnc https://zhanggroup.org/TM-score/TMscore.cpp
  # libraries go after the sources that need them when linking statically
  g++ -static -O3 -ffast-math -o TMscore TMscore.cpp -lm
fi

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 80.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 352.1/352.1 kB 29.2 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 77.9/77.9 kB 7.7 MB/s eta 0:00:00


In [None]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import csv
import os
import subprocess
import tempfile

import jax
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

from colabdesign import clear_mem, mk_af_model
from colabdesign.shared.utils import copy_dict

def tmscore(x, y, tmscore_bin="./TMscore"):
  """Compare two CA traces with the TMscore program.

  Both coordinate sets are written as CA-only poly-alanine PDB files (chain A,
  residues numbered 1..N in input order) into a private temporary directory
  and passed to the TMscore executable.

  Args:
    x, y: sequences of (x, y, z) coordinates, one entry per residue
      (e.g. an (N, 3) array).
    tmscore_bin: path of the TMscore executable (default: the binary built
      in the setup cell).

  Returns:
    dict with the keys "rms" (RMSD), "tms" (TM-score) and "gdt" (GDT-TS) for
    every corresponding line found in TMscore's output.

  Raises:
    FileNotFoundError: if `tmscore_bin` does not exist.
  """
  def parse_float(line):
    # value after the first "=", e.g. "TM-score    = 0.9686  (d0= 3.87)" -> 0.9686
    return float(line.split("=")[1].split()[0])

  prefixes = {"RMSD": "rms", "TM-score": "tms", "GDT-TS-score": "gdt"}

  # temporary files avoid clobbering (or leaving behind) 0.pdb / 1.pdb in the cwd
  with tempfile.TemporaryDirectory() as tmp_dir:
    paths = []
    for n, coords in enumerate([x, y]):
      path = os.path.join(tmp_dir, f"{n}.pdb")
      with open(path, "w") as out:
        for k, c in enumerate(coords):
          out.write("ATOM  %5d  %-2s  %3s %s%4d    %8.3f%8.3f%8.3f  %4.2f  %4.2f\n"
                    % (k+1, "CA", "ALA", "A", k+1, c[0], c[1], c[2], 1, 0))
      paths.append(path)
    # run while the temporary files still exist; the pipe is closed by run()
    proc = subprocess.run([tmscore_bin, *paths], capture_output=True, text=True)

  o = {}
  for line in proc.stdout.splitlines():
    line = line.rstrip()
    for prefix, key in prefixes.items():
      if line.startswith(prefix):
        o[key] = parse_float(line)
  return o

def plot_me(scores, x="tm_i", y="composite",
            title=None, diag=False, scale_axis=True, dpi=100, **kwargs):
  """Scatter-plot score field `y` against field `x`, coloured by pLDDT.

  Args:
    scores: list of score dicts; each must contain the keys `x`, `y` and "plddt".
    x, y: keys plotted on the horizontal / vertical axis.
    title: optional figure title.
    diag: if True, also draw the line from (0, 0) to (1, 1).
    scale_axis: if True, any axis whose key has a known label is limited to
      [-0.1, 1.1].
    dpi: figure resolution.
    **kwargs: forwarded to `plt.scatter`.

  Prints the Spearman rank correlation between the plotted x and y values.
  """
  def clip_and_normalize(values, lo=None, hi=None):
    # clamp into [lo, hi] (defaults: data min / max), then map linearly onto [0, 1]
    out = np.copy(values)
    lo = out.min() if lo is None else lo
    hi = out.max() if hi is None else hi
    out[out < lo] = lo
    out[out > hi] = hi
    return (out - lo) / (hi - lo)

  axis_names = {
      "tm_i": "TMscore of Input",
      "tm_o": "TMscore of Output",
      "tm_io": "TMscore between Input and Output",
      "ptm": "Predicted TMscore (pTM)",
      "i_ptm": "Predicted interface TMscore (ipTM)",
      "plddt": "Predicted LDDT (pLDDT)",
      "composite": "Composite",
  }

  plt.figure(figsize=(5,5), dpi=dpi)
  if title is not None:
    plt.title(title)

  xs = np.array([entry[x] for entry in scores])
  ys = np.array([entry[y] for entry in scores])
  # pLDDT 0.5..0.9 mapped onto the first 75% of the colormap
  colours = clip_and_normalize(np.array([entry["plddt"] for entry in scores]), 0.5, 0.9) * 0.75
  plt.scatter(xs, ys, c=colours, s=5, vmin=0, vmax=1, cmap="gist_rainbow", **kwargs)

  if diag:
    plt.plot([0,1], [0,1], color="black")

  plt.xlabel(axis_names.get(x, x))
  plt.ylabel(axis_names.get(y, y))
  if scale_axis and x in axis_names:
    plt.xlim(-0.1, 1.1)
  if scale_axis and y in axis_names:
    plt.ylim(-0.1, 1.1)

  print(spearmanr(xs, ys).correlation)

class af2rank:
  """AF2Rank scorer built on a ColabDesign "fixbb" AlphaFold model.

  A structure is scored by giving it to AlphaFold as a template. `predict`
  returns AlphaFold's logged metrics. It adds TM-scores between the input
  structure, the prediction and the reference structure passed to the
  constructor, plus composite = ptm * plddt * tm_io.
  """

  def __init__(self, pdb, chain=None, model_name="model_1_ptm", model_names=None):
    """Build the model around a reference structure.

    Args:
      pdb: reference (native) structure. Its coordinates are stored in
        `wt_batch` and used for the tm_i / tm_o scores of every prediction.
      chain: chain selection passed to ColabDesign `prep_inputs`
        (e.g. "A" or "A,B").
      model_name: AlphaFold model run by `predict`. A name containing
        "multimer" switches on the multimer model.
      model_names: forwarded unchanged to `mk_af_model`.
    """
    self.args = {"pdb":pdb, "chain":chain,
                 "use_multimer":("multimer" in model_name),
                 "model_name":model_name,
                 "model_names":model_names}
    self.reset()

  def reset(self):
    """(Re)create the ColabDesign model and reload the reference structure from `self.args`."""
    self.model = mk_af_model(protocol="fixbb",
                             use_templates=True,
                             use_multimer=self.args["use_multimer"],
                             debug=False,
                             model_names=self.args["model_names"])

    self.model.prep_inputs(self.args["pdb"], chain=self.args["chain"])
    self.model.set_seq(mode="wildtype")
    # snapshot of the reference inputs; tm_i / tm_o are measured against these coordinates
    self.wt_batch = copy_dict(self.model._inputs["batch"])
    self.wt = self.model._wt_aatype

  def set_pdb(self, pdb, chain=None):
    """Load `pdb` as the structure to score; the reference `wt_batch` is left unchanged."""
    if chain is None: chain = self.args["chain"]
    self.model.prep_inputs(pdb, chain=chain)
    self.model.set_seq(mode="wildtype")
    self.wt = self.model._wt_aatype

  def set_seq(self, seq):
    """Set the query sequence; `predict` also uses it as the template aatype."""
    self.model.set_seq(seq=seq)
    self.wt = self.model._params["seq"][0].argmax(-1)

  def _get_score(self):
    """Collect the metrics of the last prediction into a new dict.

    Starts from a copy of ColabDesign's `aux["log"]` and adds:
      rmsd_io    - the logged "rmsd" entry, renamed (None if absent)
      tm_i, tm_o - TM-score of the input / output CA trace vs. the reference
      tm_io      - TM-score between the input and output CA traces
      composite  - ptm * plddt * tm_io
    """
    score = copy_dict(self.model.aux["log"])

    # NOTE(review): presumably rescales a PAE normalised by 31 back to Angstrom — confirm
    score["pae"] = 31.0 * score["pae"]
    score["rmsd_io"] = score.pop("rmsd",None)

    # atom index 1 -> CA (assumes AlphaFold's atom37 order N, CA, C, ... — TODO confirm)
    i_xyz = self.model._inputs["batch"]["all_atom_positions"][:,1]
    o_xyz = np.array(self.model.aux["atom_positions"][:,1])

    # TMscore to input/output
    if hasattr(self,"wt_batch"):
      n_xyz = self.wt_batch["all_atom_positions"][:,1]
      score["tm_i"] = tmscore(n_xyz,i_xyz)["tms"]
      score["tm_o"] = tmscore(n_xyz,o_xyz)["tms"]

    # TMscore between input and output
    score["tm_io"] = tmscore(i_xyz,o_xyz)["tms"]

    # composite score
    score["composite"] = score["ptm"] * score["plddt"] * score["tm_io"]
    return score

  def predict(self, pdb=None, seq=None, chain=None,
              input_template=True, model_name=None,
              rm_seq=True, rm_sc=True, rm_ic=False,
              recycles=1, iterations=1,
              output_pdb=None, extras=None, verbose=True):
    """Score the current (or a newly given) structure / sequence with AlphaFold.

    Args:
      pdb: if given, structure loaded as the input/template (via `set_pdb`).
      seq: if given, sequence to use (via `set_seq`).
      chain: chain selection used with `pdb` (defaults to the constructor's).
      input_template: if False, the template is removed (rm_template = True).
      model_name: switch to another AlphaFold model. Switching between monomer
        and multimer rebuilds the model with `reset()`.
      rm_seq, rm_sc: values for rm_template_seq / rm_template_sc, which mask
        the template's sequence / side chains.
      rm_ic: template option `rm_ic` (interchain masking).
      recycles: number of AlphaFold recycles.
      iterations: number of predictions. Between them the predicted
        coordinates are fed back as the template. Must be >= 1.
      output_pdb: if given, the prediction is saved to this path.
      extras: extra key/values merged into the returned score dict.
      verbose: print a one-line summary of the main scores.

    Returns:
      the score dict built by `_get_score`, updated with `extras`.

    Raises:
      ValueError: if `iterations` < 1.
    """
    if iterations < 1:
      # no prediction would run; _get_score would then report whatever
      # `self.model.aux` a previous call left behind
      raise ValueError(f"iterations must be >= 1, got {iterations}")

    if model_name is not None:
      self.args["model_name"] = model_name
      if "multimer" in model_name:
        if not self.args["use_multimer"]:
          self.args["use_multimer"] = True
          self.reset()
      else:
        if self.args["use_multimer"]:
          self.args["use_multimer"] = False
          self.reset()

    if pdb is not None: self.set_pdb(pdb, chain)
    if seq is not None: self.set_seq(seq)

    # set template sequence
    self.model._inputs["batch"]["aatype"] = self.wt

    # set other options
    self.model.set_opt(
        template=dict(rm_ic=rm_ic),
        num_recycles=recycles)
    self.model._inputs["rm_template"][:] = not input_template
    self.model._inputs["rm_template_sc"][:] = rm_sc
    self.model._inputs["rm_template_seq"][:] = rm_seq

    # "manual" recycles using templates: each round's prediction becomes the
    # next round's template. The input coordinates are restored in `finally`,
    # so _get_score compares against the real input. An interrupted run (e.g.
    # KeyboardInterrupt) therefore does not leave a predicted structure behind
    # as the input.
    ini_atoms = self.model._inputs["batch"]["all_atom_positions"].copy()
    try:
      for i in range(iterations):
        self.model.predict(models=self.args["model_name"], verbose=False)
        if i < iterations - 1:
          self.model._inputs["batch"]["all_atom_positions"] = self.model.aux["atom_positions"]
    finally:
      self.model._inputs["batch"]["all_atom_positions"] = ini_atoms

    score = self._get_score()
    if extras is not None:
      score.update(extras)

    if output_pdb is not None:
      self.model.save_pdb(output_pdb)

    if verbose:
      print_list = ["tm_i","tm_o","tm_io","composite","ptm","i_ptm","plddt","fitness","id"]
      print_score = lambda k: f"{k} {score[k]:.4f}" if isinstance(score[k],float) else f"{k} {score[k]}"
      print(*[print_score(k) for k in print_list if k in score])

    return score

In [None]:
#@markdown ### **settings**
recycles = 1 #@param ["0", "1", "2", "3", "4"] {type:"raw"}
iterations = 1  # template-feedback rounds in af2rank.predict (not a form field)

# decide what model to use
model_mode = "alphafold-multimer" #@param ["alphafold", "alphafold-multimer"]
model_num = 1 #@param ["1", "2", "3", "4", "5"] {type:"raw"}

# Map each mode to its parameter-set name. An unknown mode raises here. Otherwise
# `model_name` would be undefined, or silently kept from an earlier run of this cell.
MODEL_NAME_FORMATS = {"alphafold": "model_{}_ptm",
                      "alphafold-multimer": "model_{}_multimer_v3"}
if model_mode not in MODEL_NAME_FORMATS:
  raise ValueError(f"unknown model_mode {model_mode!r}; "
                   f"expected one of {sorted(MODEL_NAME_FORMATS)}")
model_name = MODEL_NAME_FORMATS[model_mode].format(model_num)

save_output_pdbs = True #@param {type:"boolean"}

#@markdown ### **advanced**
mask_sequence = True #@param {type:"boolean"}
mask_sidechains = True #@param {type:"boolean"}
mask_interchain = False #@param {type:"boolean"}

# keyword arguments passed to every af2rank.predict call
SETTINGS = {"rm_seq":mask_sequence,
            "rm_sc":mask_sidechains,
            "rm_ic":mask_interchain,
            "recycles":int(recycles),
            "iterations":int(iterations),
            "model_name":model_name}

## test 4m3k


### ZDOCK 2.1, poses created with 3.0.2

#### score all poses

In [None]:
# initialtion
NAME = "4m3k"
CHAIN = "A" # this can be multiple  #change from A to B
NATIVE_PATH = f"{NAME}.pdb"
DECOY_DIR = f"{NAME}"


if save_output_pdbs:
  os.makedirs(f"{NAME}_output", exist_ok=True) #ok_exists=True)


# get data from rcsb; no decoy
%shell wget -qnc https://files.rcsb.org/download/{NAME}.pdb


# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])
# note 230425, reset is problematic

In [None]:
# baseline: score with the template structure removed entirely (input_template=False)
_ = af.predict(input_template=False,
               pdb=NATIVE_PATH,
               **SETTINGS)

tm_i 1.0000 tm_o 0.4811 tm_io 0.4811 composite 0.0607 ptm 0.3343 i_ptm 0.0000 plddt 0.3774


In [None]:
# fresh result list; the native structure is scored first as the reference entry
SCORES = []
SCORES.append(
    af.predict(pdb=NATIVE_PATH, extras={"id": NATIVE_PATH}, **SETTINGS)
)


In [None]:
# upload poses_4m3k_zdock2.1_mda_split_2.1_mark_create-by-3.0.2.zip from local instead; but the file will be deleted when the runtime is ended
%shell tar xvzf drive/MyDrive/poses_{NAME}_zdock2.1_mda_split_2.1_mark_create-by-3.0.2.zip

In [None]:
# score poses
import csv

with open('demo_4m3k_2.1_B.csv', 'w') as myFile:
  writer = csv.writer(myFile)
  writer.writerow(['con', 'dgram_cce', 'exp_res', 'fape', 'hard', 'helix', 'i_ptm','loss', 'models', 'pae', 'plddt', 'ptm', 'recycles', 'seq_ent', 'seqid', 'soft', 'temp', 'rmsd_io', 'tm_i', 'tm_o', 'tm_io', 'composite', 'id'
])

SCORES = []
POSE_DIR = f"poses_{NAME}_zdock2.1_mda_split_2.1_mark_create-by-3.0.2/"
for pose_pdb in os.listdir(POSE_DIR):
  if 'complex' in pose_pdb:
    input_pdb = os.path.join(POSE_DIR, pose_pdb)
    if save_output_pdbs:
      output_pdb = os.path.join(f"{NAME}_output",pose_pdb)
    else:
      output_pdb = None
    SCORES.append(af.predict(pdb=input_pdb, output_pdb=output_pdb,
                            **SETTINGS, extras={"id":pose_pdb}))
    with open('demo_4m3k_2.1_B.csv', 'a') as myFile:
      writer = csv.writer(myFile)
      writer.writerow(SCORES[-1].values())
    %shell cp demo_4m3k_2.1_B.csv drive/MyDrive/demo_4m3k_2.1_B.csv

tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.1109.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.1444.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.5.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.549.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.371.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.787.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.1961.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.1396.pdb
tm_i 1.0000 tm_o 0.9686 tm_io 0.9686 composite 0.8064 ptm 0.8703 i_ptm 0.0000 plddt 0.9566 id complex.986.pdb
tm_i 1.0

In [None]:
# save the main fields of SCORES to a CSV, one row per scored structure
SUMMARY_FIELDS = ['tm_i', 'tm_o', 'tm_io', 'composite', 'ptm', 'i_ptm', 'plddt', 'id']

# DictWriter writes exactly these columns in header order. Writing every value of
# each score dict would put ~23 values under this 8-column header.
with open(f"demo_{NAME}.csv", 'w', newline='') as summary_file:
  summary_writer = csv.DictWriter(summary_file, fieldnames=SUMMARY_FIELDS,
                                  restval="", extrasaction="ignore")
  summary_writer.writeheader()
  summary_writer.writerows(SCORES)


#### Score the complexes with the top 5, middle 5, and bottom 5 ZDOCK scores; also compare with the pose generated by abdockgen

###### single chain

In [None]:
NAME = "4m3k"
CHAIN = "A" # this can be multiple  #change from A to B
NATIVE_PATH = f"{NAME}.pdb"
DECOY_DIR = f"{NAME}"


if save_output_pdbs:
  os.makedirs(f"{NAME}_output", exist_ok=True) #ok_exists=True)


# get data from rcsb; no decoy
%shell wget -qnc https://files.rcsb.org/download/{NAME}.pdb


# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])
# note 230425, reset is problematic

In [None]:
# upload poses_4m3k_zdock2.1_mda_split_2.1_mark_create-by-3.0.2.zip from local instead; but the file will be deleted when the runtime is ended
%shell tar xvzf drive/MyDrive/poses_{NAME}_zdock2.1_mda_split_2.1_mark_create-by-3.0.2.zip

In [None]:
## zdock score: rescore 5 poses from the top, middle and bottom of the ZDOCK ranking
POSE_DIR = f"poses_{NAME}_zdock2.1_mda_split_2.1_mark_create-by-3.0.2/"
# pose numbers follow the ZDOCK rank; the middle/bottom ranges presumably assume
# the run produced 2000 poses
TIERS = {"top": range(1, 6), "medium": range(998, 1003), "bottom": range(1996, 2001)}

# keep the results (tier name -> list of score dicts) instead of only printing them
TIER_SCORES = {}
for tier_name, pose_numbers in TIERS.items():
  print(f"score of {tier_name} 5 poses:")
  TIER_SCORES[tier_name] = []
  for num in pose_numbers:
    pose_pdb = f"complex.{num}.pdb"
    input_pdb = os.path.join(POSE_DIR, pose_pdb)
    output_pdb = os.path.join(f"{NAME}_output", pose_pdb) if save_output_pdbs else None
    TIER_SCORES[tier_name].append(
        af.predict(pdb=input_pdb, output_pdb=output_pdb,
                   **SETTINGS, extras={"id": pose_pdb}))



score of  top 5 poses:
tm_i 1.0000 tm_o 0.9897 tm_io 0.9897 composite 0.8539 ptm 0.9099 i_ptm 0.0000 plddt 0.9482 id complex.1.pdb
tm_i 1.0000 tm_o 0.9898 tm_io 0.9898 composite 0.8542 ptm 0.9100 i_ptm 0.0000 plddt 0.9483 id complex.2.pdb
tm_i 1.0000 tm_o 0.9898 tm_io 0.9898 composite 0.8538 ptm 0.9098 i_ptm 0.0000 plddt 0.9481 id complex.3.pdb
tm_i 1.0000 tm_o 0.9898 tm_io 0.9898 composite 0.8541 ptm 0.9100 i_ptm 0.0000 plddt 0.9483 id complex.4.pdb
tm_i 1.0000 tm_o 0.9897 tm_io 0.9897 composite 0.8540 ptm 0.9100 i_ptm 0.0000 plddt 0.9483 id complex.5.pdb
score of  medium 5 poses:
tm_i 1.0000 tm_o 0.9897 tm_io 0.9897 composite 0.8540 ptm 0.9100 i_ptm 0.0000 plddt 0.9482 id complex.998.pdb
tm_i 1.0000 tm_o 0.9898 tm_io 0.9898 composite 0.8539 ptm 0.9099 i_ptm 0.0000 plddt 0.9482 id complex.999.pdb
tm_i 1.0000 tm_o 0.9897 tm_io 0.9897 composite 0.8537 ptm 0.9098 i_ptm 0.0000 plddt 0.9481 id complex.1000.pdb
tm_i 1.0000 tm_o 0.9896 tm_io 0.9896 composite 0.8539 ptm 0.9099 i_ptm 0.0000 pl

In [None]:
# abdockgen: score the pose generated by abdockgen against the 4m3k native
# (another candidate input: 4m3k_B_A_bindingseq_chain-match-sample.pdb)
# NOTE(review): no visible cell downloads or copies this file; it must already be
# in the working directory.
pose_pdb = '4m3k_B_A_pred_sample-ckpt.pdb'
input_pdb = pose_pdb
output_pdb = os.path.join(f"{NAME}_output", pose_pdb) if save_output_pdbs else None
af.predict(pdb=input_pdb, output_pdb=output_pdb,
           **SETTINGS, extras={"id": pose_pdb})

tm_i 0.7982 tm_o 0.8076 tm_io 0.9861 composite 0.8587 ptm 0.9134 i_ptm 0.0000 plddt 0.9534 id 4m3k_B_A_pred_sample-ckpt.pdb


{'con': 0.8482070565223694,
 'dgram_cce': 1.1633045673370361,
 'exp_res': 0.0009422175935469568,
 'fape': 0.34512099623680115,
 'hard': 1.0,
 'helix': 5.154178142547607,
 'i_ptm': 0.0,
 'loss': 1.1633045673370361,
 'models': [0],
 'pae': 3.139632999897003,
 'plddt': 0.9533663392066956,
 'ptm': 0.91341233253479,
 'recycles': 1,
 'seq_ent': 705.3644409179688,
 'seqid': 1.0,
 'soft': 0.0,
 'temp': 1.0,
 'rmsd_io': 0.7177258729934692,
 'tm_i': 0.7982,
 'tm_o': 0.8076,
 'tm_io': 0.9861,
 'composite': 0.858712221308938,
 'id': '4m3k_B_A_pred_sample-ckpt.pdb'}

###### chain A,B

In [None]:
## zdock
NAME = "4m3k"
CHAIN = "A,B" # this can be multiple  #change from A to B #must change model_mode to alpha-fold-multimer
NATIVE_PATH = f"{NAME}.pdb"
DECOY_DIR = f"{NAME}"


if save_output_pdbs:
  os.makedirs(f"{NAME}_output", exist_ok=True) #ok_exists=True)


# get data from rcsb; no decoy
%shell wget -qnc https://files.rcsb.org/download/{NAME}.pdb


# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])
# note 230425, reset is problematic

In [None]:
# upload poses_4m3k_zdock2.1_mda_split_2.1_mark_create-by-3.0.2.zip from local instead; but the file will be deleted when the runtime is ended
%shell tar xvzf drive/MyDrive/poses_{NAME}_zdock2.1_mda_split_2.1_mark_create-by-3.0.2.zip

In [None]:
# sanity check: score the native complex itself; the returned dict is the cell output
af.predict(pdb=NATIVE_PATH,
           extras={"id": NATIVE_PATH},
           **SETTINGS)

tm_i 1.0000 tm_o 0.7436 tm_io 0.7436 composite 0.3985 ptm 0.6551 i_ptm 0.1002 plddt 0.8180 id 4m3k.pdb


{'con': 1.5922664403915405,
 'dgram_cce': 0.7734336853027344,
 'exp_res': 0.0024014548398554325,
 'fape': 0.5252897143363953,
 'hard': 1.0,
 'helix': 4.22415828704834,
 'i_ptm': 0.1001817062497139,
 'loss': 0.7734336853027344,
 'models': [0],
 'pae': 4.545825764536858,
 'plddt': 0.8179619908332825,
 'ptm': 0.6551194190979004,
 'recycles': 1,
 'seq_ent': 1014.6396484375,
 'seqid': 1.0,
 'soft': 0.0,
 'temp': 1.0,
 'rmsd_io': 11.548102378845215,
 'tm_i': 1.0,
 'tm_o': 0.7436,
 'tm_io': 0.7436,
 'composite': 0.3984675663897619,
 'id': '4m3k.pdb'}

In [None]:
## zdock score: rescore 5 poses from the top, middle and bottom of the ZDOCK ranking
POSE_DIR = f"poses_{NAME}_zdock2.1_mda_split_2.1_mark_create-by-3.0.2/"
# pose numbers follow the ZDOCK rank; the middle/bottom ranges presumably assume
# the run produced 2000 poses
TIERS = {"top": range(1, 6), "medium": range(998, 1003), "bottom": range(1996, 2001)}

# keep the results (tier name -> list of score dicts) instead of only printing them
TIER_SCORES = {}
for tier_name, pose_numbers in TIERS.items():
  print(f"score of {tier_name} 5 poses:")
  TIER_SCORES[tier_name] = []
  for num in pose_numbers:
    pose_pdb = f"complex.{num}.pdb"
    input_pdb = os.path.join(POSE_DIR, pose_pdb)
    output_pdb = os.path.join(f"{NAME}_output", pose_pdb) if save_output_pdbs else None
    TIER_SCORES[tier_name].append(
        af.predict(pdb=input_pdb, output_pdb=output_pdb,
                   **SETTINGS, extras={"id": pose_pdb}))



score of  top 5 poses:
tm_i 0.7008 tm_o 0.7432 tm_io 0.6909 composite 0.3702 ptm 0.6544 i_ptm 0.0966 plddt 0.8188 id complex.1.pdb
tm_i 0.7009 tm_o 0.7441 tm_io 0.6984 composite 0.3744 ptm 0.6544 i_ptm 0.0966 plddt 0.8191 id complex.2.pdb


KeyboardInterrupt: ignored

In [None]:
# abdockgen #use a different native; chain name = A,H
NAME = "4m3k"
CHAIN = "A,H" # this can be multiple  #change from A to B
NATIVE_PATH = f'4m3k_B_A_bindingseq_chain-match-sample.pdb'
DECOY_DIR = f"{NAME}"


if save_output_pdbs:
  os.makedirs(f"{NAME}_output", exist_ok=True) #ok_exists=True)

# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])
# note 230425, reset is problematic


%shell cp drive/MyDrive/4m3k_B_A_bindingseq_chain-match-sample_edited.pdb ./
pose_pdb =  '4m3k_B_A_pred_sample-ckpt.pdb'
input_pdb = '4m3k_B_A_pred_sample-ckpt.pdb'
if save_output_pdbs:
  output_pdb = os.path.join(f"{NAME}_output",pose_pdb)
else:
  output_pdb = None
af.predict(pdb=input_pdb,
           output_pdb=output_pdb,
                        **SETTINGS, extras={"id":pose_pdb})

tm_i 0.8007 tm_o 0.7932 tm_io 0.9382 composite 0.7331 ptm 0.8836 i_ptm 0.3133 plddt 0.8843 id 4m3k_B_A_pred_sample-ckpt.pdb


{'con': 1.4765971899032593,
 'dgram_cce': 1.070410966873169,
 'exp_res': 0.0323721207678318,
 'fape': 0.40577933192253113,
 'hard': 1.0,
 'helix': 4.184812068939209,
 'i_ptm': 0.3133479356765747,
 'loss': 1.070410966873169,
 'models': [0],
 'pae': 3.6233732998371124,
 'plddt': 0.8843065276741982,
 'ptm': 0.8835690021514893,
 'recycles': 1,
 'seq_ent': 743.3455810546875,
 'seqid': 1.0,
 'soft': 0.0,
 'temp': 1.0,
 'rmsd_io': 10.394646644592285,
 'tm_i': 0.8007,
 'tm_o': 0.7932,
 'tm_io': 0.9382,
 'composite': 0.7330586635726956,
 'id': '4m3k_B_A_pred_sample-ckpt.pdb'}

## Rank sequences (example only; not part of the 4m3k analysis above)
Example: ParD and ParE are a toxin–antitoxin pair of proteins. If the two proteins bind, the organism survives; if they do not, it does not survive. Mike Laub et al. created a library of mutants that targets this interface and measured each mutant's "fitness". Let's see how well AlphaFold can predict this, using the template trick.

In [None]:
# get data
%shell wget -qnc https://files.ipd.uw.edu/krypton/5CEG_AD_trim.pdb
%shell wget -qnc https://files.ipd.uw.edu/krypton/design/Library_fitness_vs_parE3_replicate_A.csv
%shell wget -qnc https://files.ipd.uw.edu/krypton/design/Library_fitness_vs_parE3_replicate_B.csv

# lets parse the data
lib_a = dict([line.rstrip().split(",") for line in open("Library_fitness_vs_parE3_replicate_A.csv")])
lib_b = dict([line.rstrip().split(",") for line in open("Library_fitness_vs_parE3_replicate_B.csv")])
lib_ab = jax.tree_map(lambda a,b:(float(a)+float(b))/2,lib_a,lib_b)

# get sequences
seqs = {}
for mut,sco in lib_ab.items():
  seq = list("RHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLTEVRLVWSPTAKADLIDIYVMIGSENIRAADRYYDQLEARALQLADQPRMGVRRPDIRPSARMLVEAPFVLLYETVPDTDDGPVEWVEIVRVVDGRRDLNRLF")
  # mutate seq
  for i,m in zip([10,11,12,15],list(mut)): seq[i] = m
  seq = "".join(seq)
  seqs[mut] = {"seq":seq, "sco":sco}

NAME = "toxin"
if save_output_pdbs:
  os.makedirs(f"{NAME}_output",ok_exists=True)


In [None]:
# setup model
clear_mem()
af = af2rank("5CEG_AD_trim.pdb", chain="A,B", model_name=SETTINGS["model_name"])
SCORES,LABELS = [],[]

In [None]:
# Score every library variant. Labels already in LABELS are skipped, so re-running
# this cell resumes after an interruption.
for label, entry in seqs.items():
  if label in LABELS:
    continue

  output_pdb = os.path.join(f"{NAME}_output", f"{label}.pdb") if save_output_pdbs else None

  score = af.predict(seq=entry["seq"], **SETTINGS, output_pdb=output_pdb,
                     extras={"fitness": entry["sco"], "id": label})
  SCORES.append(score)
  LABELS.append(label)

In [None]:
# measured fitness vs. AF2Rank composite score (plot_me also prints the Spearman correlation)
plot_me(SCORES, y="composite", x="fitness", scale_axis=False)