# **ESMFold**
For more details, see: [GitHub](https://github.com/facebookresearch/esm/tree/main/esm), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1)

#### **Tips and Instructions**
- Click the little ▶ play icon to the left of each cell below.
- Use "/" to specify chain breaks (e.g. sequence="AAA/AAA").
- For homo-oligomeric predictions, set copies > 1.
- See [experimental notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/ESMFold_advanced.ipynb) for more advanced options (like sampling).

#### **Colab Limitations**
- For short monomeric proteins (fewer than 400 residues), consider using the [ESMFold API](https://esmatlas.com/resources?action=fold) (no need for a GPU, super fast!)
- On Tesla T4 (typical free colab GPU), max total length ~ 900

In [1]:
%%time
#@title install
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)

import os, time
if not os.path.isfile("esmfold.model"):
  import subprocess

  # download esmfold params
  os.system("apt-get install aria2 -qq")
  try:
    # Started without waiting so the download overlaps with the pip installs
    # below. Keeping the process handle lets us wait on it later and read its
    # exit status, instead of polling for aria2's control file (which is left
    # behind forever if the download aborts).
    params_download = subprocess.Popen(
        ["aria2c", "-q", "-x", "16",
         "https://colabfold.steineggerlab.workers.dev/esm/esmfold.model"])
  except OSError:
    # aria2c could not be launched (for example the apt-get step failed).
    params_download = None

  # install libs
  print("installing libs...")
  os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
  os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

  print("installing openfold...")
  # install openfold
  os.system("pip install -q git+https://github.com/sokrypton/openfold.git")

  print("installing esmfold...")
  # install esmfold
  os.system("pip install -q git+https://github.com/sokrypton/esm.git@beta")

  # wait for Params to finish downloading...
  primary_ok = (params_download is not None
                and params_download.wait() == 0
                and os.path.isfile("esmfold.model"))
  if not primary_ok:
    # backup source!
    # Remove whatever the failed primary download left behind so the backup
    # download starts from a clean slate under the expected file name.
    for leftover in ("esmfold.model", "esmfold.model.aria2"):
      if os.path.isfile(leftover):
        os.remove(leftover)
    os.system("aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model")

installing libs...
installing openfold...
installing esmfold...
CPU times: user 759 ms, sys: 119 ms, total: 878 ms
Wall time: 3min 28s


In [4]:
#@title ##run **ESMFold**
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax
import time

def calculate_min_distance(xyz_a, xyz_b):
    """Return the smallest Euclidean distance between a point of xyz_a and a point of xyz_b.

    Both arguments are indexed as arrays of coordinate rows; every row of
    ``xyz_a`` is compared against every row of ``xyz_b`` via broadcasting.
    """
    # (len_a, 1, dim) - (1, len_b, dim) -> (len_a, len_b, dim) offsets
    pairwise_offsets = xyz_a[:, None] - xyz_b[None, :]
    pairwise_norms = np.linalg.norm(pairwise_offsets, axis=-1)
    return pairwise_norms.min()

def parse_output(output):
  """Reduce one model output dict to per-residue / per-pair summary arrays.

  Only the first batch entry is used, and only residues whose
  ``atom37_atom_exists`` flag in atom slot 1 equals 1 are kept.

  Returns a dict with:
    "pae"         - mean over the 64 bins of (probability * bin index), times 31
    "plddt"       - ``output["plddt"]`` at atom slot 1
    "sm_contacts" - softmaxed distogram mass in bins whose value is below 8
    "xyz"         - atom slot 1 coordinates from the last entry of "positions"
    "lm_contacts" - only when ``output["lm_output"]`` has a "contacts" entry
  """
  residue_present = output["atom37_atom_exists"][0,:,1] == 1

  def keep_pairs(matrix):
    # restrict a residue-by-residue matrix to the kept residues on both axes
    return matrix[residue_present,:][:,residue_present]

  bin_weighted = output["aligned_confidence_probs"][0] * np.arange(64)
  pae = bin_weighted.mean(-1) * 31

  # 64 bin values: 0 followed by 63 evenly spaced values from 2.3125 to 21.6875
  bin_values = np.append(0,np.linspace(2.3125,21.6875,63))
  contact_prob = softmax(output["distogram_logits"],-1)[0]
  contact_prob = contact_prob[...,bin_values<8].sum(-1)

  parsed = {
    "pae": keep_pairs(pae),
    "plddt": output["plddt"][0,:,1][residue_present],
    "sm_contacts": keep_pairs(contact_prob),
    "xyz": output["positions"][-1,0,:,1][residue_present],
  }

  lm_output = output["lm_output"]
  if "contacts" in lm_output:
    parsed["lm_contacts"] = keep_pairs(lm_output["contacts"].astype(float)[0])
  return parsed

def get_hash(x):
  """Return the hexadecimal SHA-1 digest of the string ``x``."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

# "A".."Z" followed by "a".."z", one character per list entry
alphabet_list = [*ascii_uppercase, *ascii_lowercase]

jobname = ['6iqj'] #@param {type:"raw"}

sequence = ['MSWQSYVDDHLMCEVEGNHLTHAAIFGQDGSVWAQSSAFPQLKPAEIAGINKDFEEAGHLAPTGLFLGGEKYMVVQGEAGAVIRGKKGPGGVTIKKTTQALVFGIYDEPMTGGQCNLVVERLGDYLIESGL/VPPPPPPPPPLP'] #@param {type:"raw"}

# torch is needed on every run (seeding, cache clearing), not only when the
# model still has to be loaded, so import it unconditionally.
import torch

# Load the model once and reuse it across jobs and across cell re-runs.
if "model" not in dir():
  # NOTE(review): torch.load unpickles arbitrary Python objects; only load
  # "esmfold.model" files obtained from a trusted source.
  model = torch.load("esmfold.model")
  model.cuda().requires_grad_(False)

# Each (jobname, sequence) pair is one prediction job. For multi-chain input
# the prediction is repeated with one more recycle each round until chains A
# and B come within min_distance_threshold of each other, or max_recycles is
# reached.
for this_jobname, this_sequence in zip(jobname, sequence):
  # Normalise the input: "/" and ":" both mark chain breaks, anything that is
  # not an upper-case letter or ":" is dropped, repeated breaks collapse and
  # leading/trailing breaks are removed.
  this_sequence = re.sub("[^A-Z:]", "", this_sequence.replace("/",":").upper())
  this_sequence = re.sub(":+",":",this_sequence)
  this_sequence = re.sub("^[:]+","",this_sequence)
  this_sequence = re.sub("[:]+$","",this_sequence)

  #@markdown ---
  #@markdown ###**Advanced Options**
  initial_recycles = 3 #@param ["0", "1", "2", "3", "6", "12"] {type:"raw"}
  min_distance_threshold = 8.0
  max_recycles = 12
  min_distance = float('inf')
  get_LM_contacts = False

  #@markdown **multimer options (experimental)**
  #@markdown - use "/" to specify chainbreaks, (eg. sequence="AAA/AAA")
  #@markdown - for homo-oligomeric predictions, set copies > 1

  copies = 1 #@param {type:"integer"}
  chain_linker = 30 #@param {type:"number"}
  if copies == "" or copies <= 0: copies = 1
  # The linker length is used both to repeat "X" and to index coordinates,
  # so it has to be an integer even if the form delivers a float.
  chain_linker = int(chain_linker)
  this_sequence = ":".join([this_sequence] * copies)

  ID = this_jobname+"_"+get_hash(this_sequence)[:5]
  seqs = this_sequence.split(":")
  lengths = [len(s) for s in seqs]
  length = sum(lengths)
  print("length",length)

  u_seqs = list(set(seqs))
  if len(seqs) == 1: mode = "mono"
  elif len(u_seqs) == 1: mode = "homo"
  else: mode = "hetero"

  # optimized for Tesla T4
  if length > 700:
    model.trunk.set_chunk_size(64)
  else:
    model.trunk.set_chunk_size(128)

  best_pdb_str = None
  best_ptm = 0
  best_output = None
  traj = []

  initial_seed = int(time.time()) % (2**32 - 1)
  torch.manual_seed(initial_seed)

  torch.cuda.empty_cache()

  seed = "default"
  mask_rate = 0.0
  model.train(False)

  # Create the output directory once, without going through a shell (the
  # job name is user input).
  os.makedirs(ID, exist_ok=True)

  num_recycles = initial_recycles
  while True:
    output = model.infer(this_sequence,
                        num_recycles=num_recycles,
                        chain_linker="X"*chain_linker,
                        residue_index_offset=512,
                        mask_rate=mask_rate,
                        return_contacts=get_LM_contacts)

    pdb_str = model.output_to_pdb(output)[0]
    # move every tensor in the output to host memory as a numpy array
    output = tree_map(lambda x: x.cpu().numpy(), output)
    ptm = output["ptm"][0]
    plddt = output["plddt"][0,:,1].mean()
    traj.append(parse_output(output))
    print(f'{seed} ptm: {ptm:.3f} plddt: {plddt:.1f}')
    if ptm > best_ptm:
      best_pdb_str = pdb_str
      best_ptm = ptm
      best_output = output

    pdb_filename = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_{seed}.pdb"

    with open(pdb_filename,"w") as out:
      out.write(pdb_str)

    if len(lengths) < 2:
      # Single chain: there is no inter-chain distance to converge on, so one
      # prediction at initial_recycles is all that is needed.
      break

    # Chain A is the first lengths[0] rows; chain B starts after the linker.
    # NOTE(review): assumes model.infer lays residues out as
    # chain / linker / chain in input order, which the earlier chain_index
    # arithmetic (with its fixed linker of 30) also relied on - confirm
    # against the ESMFold fork in use.
    atom1_xyz = output["positions"][-1,0,:,1]
    chainb_start = lengths[0] + chain_linker
    chaina_xyz = atom1_xyz[:lengths[0]]
    chainb_xyz = atom1_xyz[chainb_start:chainb_start + lengths[1]]
    min_distance = calculate_min_distance(chaina_xyz, chainb_xyz)
    print(f'{seed} Minimum distance between chains A and B: {min_distance:.2f} Å')

    if min_distance < min_distance_threshold or num_recycles >= max_recycles:
      break

    num_recycles += 1

length 143
default ptm: 0.662 plddt: 76.7
default Minimum distance between chains A and B: 19.65 Å
default ptm: 0.659 plddt: 75.6
default Minimum distance between chains A and B: 8.99 Å
default ptm: 0.670 plddt: 75.8
default Minimum distance between chains A and B: 5.33 Å
CPU times: user 25.7 s, sys: 178 ms, total: 25.9 s
Wall time: 26 s
Parser   : 239 ms
