<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/omegafold_hacks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **OmegaFold_hacks**
**WARNING** This is an experimental notebook that implements various hacks (unintended uses) of OmegaFold. It has not been optimized or benchmarked. Use at your own risk!

For details about the original model, see: [Github](https://github.com/HeliXonProtein/OmegaFold), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.21.500999v1)


#### **Tips and Instructions**
- Click the little ▶ play icon to the left of each cell below. Run the cells in order: each cell uses variables defined by the cells above it.

#### **Experimental Options**
- Use "/" (or ":") to specify chain breaks (e.g. sequence="AAA/AAA").
- If the prediction confidence is low, try enabling `use_real_msa` (currently supported only for monomeric and homo-oligomeric predictions).


In [1]:
#@markdown ##Install
import os,sys,re
from IPython.utils import io
if "SETUP_DONE" not in dir():
  import torch
  device = "cuda" if torch.cuda.is_available() else "cpu"
  with io.capture_output() as captured:
    if not os.path.isdir("OmegaFold"):
      %shell git clone --quiet https://github.com/sokrypton/OmegaFold.git
      %shell cd OmegaFold; pip -q install -r requirements.txt
      %shell pip -q install py3Dmol
      %shell apt-get install aria2 -qq > /dev/null
      %shell aria2c -q -x 16 https://helixon.s3.amazonaws.com/release1.pt
      %shell mkdir -p ~/.cache/omegafold_ckpt
      %shell mv release1.pt ~/.cache/omegafold_ckpt/model.pt
      %shell wget -qnc https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz
      %shell tar xfz hhsuite-3.3.0-SSE2-Linux.tar.gz
      %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py
  os.environ['PATH'] += ":/content/bin:/content/scripts"
  SETUP_DONE = True

In [2]:
#@markdown ##Prep inputs
from string import ascii_uppercase, ascii_lowercase
import colabfold as cf
alphabet_list = list(ascii_uppercase+ascii_lowercase)

sequence = "PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK" #@param {type:"string"}
jobname = "test" #@param {type:"string"}

# filter inputs
sequence = re.sub("[^A-Z:/]", "", sequence.upper())
sequence = re.sub(":+",":",sequence)
sequence = re.sub("/+","/",sequence)
sequence = re.sub("^[:/]+","",sequence)
sequence = re.sub("[:/]+$","",sequence)
jobname = re.sub(r'\W+', '', jobname)[:50]

num_msa = 16
use_real_msa = False #@param {type:"boolean"}

sequence = re.sub("[^A-Z\/\:]", "", sequence.upper()).replace("/",":")
ID = jobname+"_"+cf.get_hash(sequence)[:5]
seqs = sequence.split(":")
lengths = [len(s) for s in seqs]
u_seqs = list(set(seqs))

if len(seqs) == 1: mode = "mono"
elif len(u_seqs) == 1: mode = "homo"
else: mode = "hetero"

if use_real_msa and mode == "hetero":
  print("ERROR: the msa mode is currently not supported for hetero-multimers")
  use_real_msa = False

if use_real_msa:
  ID_ = cf.get_hash(seqs[0])[:5]
  a3m_lines = cf.run_mmseqs2(seqs[0], ID_, filter=True)[0]
  with open(f"{ID}.a3m","w") as a3m:
    a3m.write(a3m_lines)
  
  %shell hhfilter -v 0 -qid 30 -cov 75 -diff {num_msa} -i {ID}.a3m -o {ID}.fasta
  
  a3m_lines = open(f"{ID}.fasta","r").readlines()
  with open(f"{ID}.fasta","w") as a3m:
    n = 0
    for a3m_line in a3m_lines:
      if a3m_line.startswith(">"):
        if n == 0:
          a3m.write(f">{ID}\n")
        else:
          a3m.write(a3m_line)
      else:
        a3m_line = ":".join([a3m_line.rstrip()] * len(lengths))
        a3m.write(f"{a3m_line}\n")
        n += 1
      if n == num_msa: break
    # add pseudo_msa
    while n < num_msa:
      a3m.write(f">tmp\n{sequence}\n")
      n += 1
else:
  with open(f"{ID}.fasta","w") as out:
    out.write(f">{ID}\n{sequence}\n")

In [3]:
%%time 
#@markdown ## Run OmegaFold
num_cycle = 4 #@param ["1", "2", "4", "8", "16", "32"] {type:"raw"}

if use_real_msa:
  %shell python ./OmegaFold/main.py --device={device} --num_cycle={num_cycle} --real_msa=1 {ID}.fasta .
else:
  %shell python ./OmegaFold/main.py --device={device} --num_cycle={num_cycle} --num_pseudo_msa={num_msa-1} {ID}.fasta .

def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1, chain_ids=None):
  """Renumber residues (and optionally reassign chain IDs) in a PDB string.

  Only ``ATOM`` records are kept; every other record (HETATM, TER, END,
  REMARK, ...) is dropped from the output.

  Args:
    pdb_str: PDB-formatted text.
    Ls: optional list of chain lengths (in residues). When given, residues are
      assigned, in file order, to chains taken from ``chain_ids``, and residue
      numbering restarts at ``offset`` at the start of every chain.
    renum: if True, write residue numbers starting at ``offset`` that advance
      by the same step as the original numbering (so gaps are preserved);
      if False, keep the original residue numbers.
    offset: first residue number when renumbering.
    chain_ids: sequence of chain IDs used with ``Ls``; defaults to the
      notebook's ``alphabet_list`` (A-Z, a-z).

  Returns:
    The rewritten ATOM records joined by newlines.

  Raises:
    ValueError: if ``Ls`` is given but the structure has more residues than
      the chain lengths cover.
  """
  # Map 0-based residue index (in file order) -> output chain ID.
  new_chain = None
  if Ls is not None:
    if chain_ids is None:
      chain_ids = alphabet_list
    new_chain = {}
    L_init = 0
    for L, c in zip(Ls, chain_ids):
      new_chain.update({i: c for i in range(L_init, L_init + L)})
      L_init += L

  n, num, pdb_out = 0, offset, []
  res_ = None        # (resSeq, iCode) of the previous ATOM record
  chain_ = None      # chain ID of the previous ATOM record
  new_chain_ = None  # output chain currently being written (Ls mode only)
  for line in pdb_str.split("\n"):
    if line[:4] != "ATOM":
      continue
    chain = line[21:22]
    resnum = int(line[22:26])       # resSeq, PDB columns 23-26
    res = (resnum, line[26:27])     # plus insertion code, column 27
    if res_ is None:
      res_, chain_ = res, chain
    if res != res_ or chain != chain_:
      # new residue: advance by the original numbering step so gaps survive
      num += resnum - res_[0]
      n += 1
      res_, chain_ = res, chain
    if new_chain is not None:
      if n not in new_chain:
        raise ValueError(
            f"PDB has more residues than the {L_init} covered by Ls")
      if new_chain[n] != new_chain_:
        # entering a new output chain: restart numbering
        num = offset
        new_chain_ = new_chain[n]
    N = num if renum else resnum
    if new_chain is None:
      pdb_out.append("%s%4i%s" % (line[:22], N, line[26:]))
    else:
      pdb_out.append("%s%s%4i%s" % (line[:21], new_chain[n], N, line[26:]))
  return "\n".join(pdb_out)

# Relabel chains (A, B, ...) from the input chain lengths, restart numbering per
# chain, and overwrite the prediction in place (only ATOM records are kept).
with open(f"{ID}.pdb", "r") as handle:
  raw_pdb_str = handle.read()
pdb_str = renum_pdb_str(raw_pdb_str, Ls=lengths)
with open(f"{ID}.pdb", "w") as out:
  out.write(pdb_str)

INFO:root:Loading weights from /root/.cache/omegafold_ckpt/model.pt
INFO:root:Constructing OmegaFold
INFO:root:Reading test_a5e17.fasta
INFO:root:Predicting 1th chain in test_a5e17.fasta
INFO:root:59 residues in this chain.
INFO:root:Finished prediction in 10.11 seconds.
INFO:root:Saving prediction to ./test_a5e17.pdb
INFO:root:Saved
INFO:root:Done!
CPU times: user 168 ms, sys: 37 ms, total: 205 ms
Wall time: 29.8 s


In [4]:
#@markdown ##Display
import py3Dmol


# 40 hex colors (a PyMOL-style palette); show_pdb assigns them in order to
# chains A, B, C, ... when called with color="chain".
pymol_color_list = """
  #33ff33 #00ffff #ff33cc #ffff00 #ff9999 #e5e5e5 #7f7fff #ff7f00
  #7fff7f #199999 #ff007f #ffdd5e #8c3f99 #b2b2b2 #007fff #c4b200
  #8cb266 #00bfbf #b27f7f #fcd1a5 #ff7f7f #ffbfdd #7fffff #ffff7f
  #00ff7f #337fcc #d8337f #bfff3f #ff7fff #d8d8ff #3fffbf #b78c4c
  #339933 #66b2b2 #ba8c84 #84bf00 #b24c66 #7f7f7f #3f3fa5 #a5512b
""".split()

def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
  """Render a PDB string with py3Dmol and return the view.

  Args:
    pdb_str: PDB-formatted text. With animate=True it is loaded via
      addModelsAsFrames (i.e. it may contain several models).
    show_sidechains: draw side chains as sticks; GLY is shown as a CA sphere.
    show_mainchains: draw backbone atoms (N, CA, C, O) as sticks.
    color: "pLDDT" colors the cartoon by the B-factor column between vmin and
      vmax; "rainbow" uses a spectrum; "chain" gives each chain its own color
      from pymol_color_list. Any other value applies no cartoon style.
    chains: number of chains to color when color="chain"; defaults to
      len(Ls), or 1 when Ls is None.
    vmin, vmax: B-factor range mapped onto the color gradient.
    size: (width, height) of the viewer.
    hbondCutoff: passed to py3Dmol when the model is added.
    Ls: optional chain lengths; only used to infer `chains`.
    animate: load pdb_str as frames and animate them.

  Returns:
    The py3Dmol view (call .show() to display it).
  """
  if chains is None:
    chains = 1 if Ls is None else len(Ls)
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])
  if animate:
    view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  else:
    view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  if color == "pLDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # one color per chain A, B, C, ...; zip stops at the shortest input
    for _, chain_id, chain_color in zip(range(chains), alphabet_list, pymol_color_list):
      view.setStyle({'chain':chain_id},{'cartoon': {'color':chain_color}})
  if show_sidechains:
    BB = ['C','O','N']
    # everything except backbone C/O/N for all residues but GLY and PRO
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # GLY has no side chain: mark its CA with a sphere
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # PRO keeps N so the ring is drawn closed
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  view.zoomTo()
  if animate: view.animate()
  return view

color = "confidence" #@param ["confidence", "rainbow", "chain"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

# show_pdb calls the confidence coloring scheme "pLDDT"
color = "pLDDT" if color == "confidence" else color
show_pdb(
    pdb_str,
    color=color,
    show_sidechains=show_sidechains,
    show_mainchains=show_mainchains,
    Ls=lengths,
).show()

In [5]:
#@title Download prediction
from google.colab import files
# send the renumbered prediction to the browser
files.download("%s.pdb" % ID)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>