<a href="https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/RoseTTAFold2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **RoseTTAFold2**
RoseTTAFold2 is a method for protein structure prediction.


#### **Tips and Instructions**
- Click the little ▶ play icon to the left of each cell below.
- Use "/" to specify chain breaks (e.g. sequence="AAA/AAA").


**<font color="red">NOTE:</font>** This notebook is in active development; we are still working on adding all the options from the [manuscript](https://www.biorxiv.org/content/10.1101/2023.05.24.542179v1) (such as templates).


In [None]:
#@title setup **RoseTTAFold2** (~1m)
%%time
import os, time, sys
# Start the model-weight download first (as a background shell job) so that it
# runs while the packages below are being installed.
if not os.path.isfile("RF2_apr23.pt"):
  # send param download into background
  os.system("(apt-get install aria2; aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt) &")

# One-time install, keyed on the presence of the cloned repository directory.
# NOTE(review): none of the os.system return codes are checked, so a failed
# step is silent at this point.
if not os.path.isdir("RoseTTAFold2"):
  print("installing RoseTTAFold2...")
  os.system("git clone https://github.com/sokrypton/RoseTTAFold2.git")
  os.system("pip -q install py3Dmol")
  os.system("pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html")
  os.system("cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .")
  # api.py is fetched into the working directory; it is imported below as `api`
  os.system("wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py")

  # install hhsuite
  print("installing hhsuite...")
  os.makedirs("hhsuite", exist_ok=True)
  os.system(f"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/")


# Block (polling every 5 s) for as long as the `.aria2` file exists.
# Presumably aria2c keeps this control file only while the download is in
# progress -- TODO confirm.
# NOTE(review): the outer check runs once; if it is reached before the
# background job has created the file, the wait is skipped entirely.
if os.path.isfile(f"RF2_apr23.pt.aria2"):
  print("downloading RoseTTAFold2 params...")
  while os.path.isfile(f"RF2_apr23.pt.aria2"):
    time.sleep(5)

# Path setup and imports run only once per kernel: `IMPORTED` is set at the end
# of this branch, so re-running the cell skips it.
if not "IMPORTED" in dir():
  # make the cloned network code importable (used for `parsers` here)
  if 'RoseTTAFold2/network' not in sys.path:
    os.environ["DGLBACKEND"] = "pytorch"
    sys.path.append('RoseTTAFold2/network')
  # expose the hhsuite tools to shell commands; get_msa calls `hhfilter` by name
  # NOTE(review): these PATH entries are relative to the current working directory
  if "hhsuite" not in os.environ['PATH']:
    os.environ['PATH'] += ":hhsuite/bin:hhsuite/scripts"

  # These names are relied on by the later cells (plt, files, np, torch, re, random, ...).
  import matplotlib.pyplot as plt
  from google.colab import files
  import numpy as np
  # NOTE(review): parse_a3m is not referenced in the visible cells
  from parsers import parse_a3m
  from api import run_mmseqs2
  import py3Dmol
  import torch
  from string import ascii_uppercase, ascii_lowercase
  import hashlib, re, os
  import random

  # SHA-1 hex digest of a string; the run cell uses its first 5 characters as a jobname suffix
  def get_hash(x): return hashlib.sha1(x.encode()).hexdigest()
  alphabet_list = list(ascii_uppercase+ascii_lowercase)
  from collections import Counter

  IMPORTED = True

def get_msa(seq, jobname, cov=50, id=90, mode="unpaired_paired"):
  """Build the input alignment for a job and write it to `{jobname}/msa.a3m`.

  Identical chains are collapsed to unique sequences (first-seen order) for the
  MMseqs2 queries. Every alignment row is then expanded again so that each
  unique chain is repeated once per copy, with chains separated by "/". Copies
  of the same chain therefore end up adjacent (A, B, A -> A/A/B).

  Args:
    seq: a single sequence string, or a list of chain sequences.
    jobname: output directory; intermediate files go to `{jobname}/msa/`.
    cov: value passed to `hhfilter -cov`.
    id: value passed to `hhfilter -id`.
    mode: "paired", "unpaired" or "unpaired_paired". Pairing is only attempted
      when there is more than one unique chain; with a single unique chain the
      unpaired MSA is always built.

  Raises:
    subprocess.CalledProcessError: if hhfilter exits with a non-zero status.
    FileNotFoundError: if the hhfilter executable cannot be found.
  """
  import subprocess  # local import: not provided by the notebook's import block

  def _hhfilter(src, dst):
    # Argument list instead of a shell string, so a jobname containing spaces
    # or shell metacharacters cannot break (or inject into) the command.
    # check=True: a failed filter must not fall through to reading a stale
    # output file left over from an earlier run.
    subprocess.run(
      ["hhfilter", "-i", src, "-id", str(id), "-cov", str(cov), "-o", dst],
      check=True)

  seqs = [seq] if isinstance(seq, str) else seq
  counts = Counter(seqs)
  u_seqs = list(counts.keys())
  u_nums = list(counts.values())

  def _expand(chains):
    # repeat each unique chain by its copy number and join everything with "/"
    return "/".join("/".join([chain] * num) for chain, num in zip(chains, u_nums))

  # the first row of the MSA is the query itself
  msa = [_expand(u_seqs)]
  path = os.path.join(jobname, "msa")
  os.makedirs(path, exist_ok=True)

  if mode in ["paired", "unpaired_paired"] and len(u_seqs) > 1:
    print("getting paired MSA")
    out_paired = run_mmseqs2(u_seqs, f"{path}/", use_pairing=True)
    # sequences[k] gathers, across the per-chain a3m results, the sequence
    # line(s) of the k-th record, i.e. one entry per chain for paired row k
    sequences = []
    for a3m_lines in out_paired:
      n = -1
      for line in a3m_lines.split("\n"):
        if not line:
          continue
        if line.startswith(">"):
          n += 1
          if len(sequences) < n + 1:
            sequences.append([])
        else:
          sequences[n].append(line)

    # filter the concatenated rows; headers carry the row index so that the
    # surviving rows can be mapped back to their per-chain pieces
    with open(f"{path}/paired_in.a3m", "w") as handle:
      for n, sequence in enumerate(sequences):
        handle.write(f">n{n}\n{''.join(sequence)}\n")
    _hhfilter(f"{path}/paired_in.a3m", f"{path}/paired_out.a3m")
    with open(f"{path}/paired_out.a3m", "r") as handle:
      for line in handle:
        if line.startswith(">"):
          n = int(line[2:])  # header written above as ">n{index}"
          msa.append(_expand(sequences[n]))

  if mode in ["unpaired", "unpaired_paired"] or len(u_seqs) == 1:
    print("getting unpaired MSA")
    out = run_mmseqs2(u_seqs, f"{path}/")
    Ls = [len(s) for s in u_seqs]
    for n, a3m_lines in enumerate(out):
      with open(f"{path}/in_{n}.a3m", "w") as handle:
        handle.write(a3m_lines)
      # filter
      _hhfilter(f"{path}/in_{n}.a3m", f"{path}/out_{n}.a3m")
      with open(f"{path}/out_{n}.a3m", "r") as handle:
        for line in handle:
          if not line.startswith(">"):
            # hit for chain n only; every other chain is filled with gaps
            xs = ['-' * l for l in Ls]
            xs[n] = line.rstrip()
            msa.append(_expand(xs))

  with open(f"{jobname}/msa.a3m", "w") as handle:
    for n, sequence in enumerate(msa):
      handle.write(f">n{n}\n{sequence}\n")

In [None]:
#@title ###run **RoseTTAFold2**
sequence = "PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK" #@param {type:"string"}
jobname = "test" #@param {type:"string"}

#@markdown symmetry settings
sym = "C" #@param ["C", "D", "T", "I", "O"]
order = 1 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"] {type:"raw"}
use_diag = False #@param {type:"boolean"}

#@markdown msa settings
msa_method = "mmseqs2" #@param ["mmseqs2","single_sequence","custom_a3m"]
pair_mode = "unpaired_paired" #@param ["unpaired_paired","paired","unpaired"] {type:"string"}

#@markdown RoseTTAFold2 settings
num_recycles = 3 #@param [0, 1, 3, 6, 12, 24] {type:"raw"}
use_mlm = True #@param {type:"boolean"}
random_seed = 0 #@param {type:"integer"}
num_models = 1 #@param ["1", "2", "4", "8", "16", "32"] {type:"raw"}

# for T/I/O the order is dropped, so `symm` below is just the letter
if sym in ["T","I","O"]:
  order = ""

# process
# "/" and ":" both mark chain breaks; everything that is not an uppercase
# letter is removed, repeated breaks are collapsed and the ends are trimmed
sequence = re.sub("[^A-Z:]", "", sequence.replace("/",":").upper())
sequence = re.sub(":+",":",sequence).strip(":")

sequences = sequence.split(":")
lengths = [len(s) for s in sequences]
# chain-break-free sequence: hashed for the jobname and stored in the settings file
sequence = "".join(sequences)
if not sequence:
  raise ValueError("sequence is empty after removing non-letter characters")

symm = sym + str(order)
jobname = jobname+"_"+symm+"_"+get_hash(sequence)[:5]

print(f"jobname: {jobname}")
print(f"lengths: {lengths}")

os.makedirs(jobname, exist_ok=True)
if msa_method == "mmseqs2":
  get_msa(sequences, jobname, mode=pair_mode)

elif msa_method == "single_sequence":
  # keep the "/" chain breaks (as get_msa does); writing the joined sequence
  # would fuse a multi-chain input into a single chain
  with open(f"{jobname}/msa.a3m","w") as a3m:
    a3m.write(f">{jobname}\n{'/'.join(sequences)}\n")

elif msa_method == "custom_a3m":
  print("upload custom a3m")
  msa_dict = files.upload()
  if not msa_dict:
    raise ValueError("no a3m file was uploaded")
  # only the first uploaded file is used
  lines = msa_dict[list(msa_dict.keys())[0]].decode().splitlines()
  a3m_lines = []
  for line in lines:
    line = line.replace("\x00","")
    # drop empty lines and '#' comment lines
    if len(line) > 0 and not line.startswith('#'):
      a3m_lines.append(line)

  with open(f"{jobname}/msa.a3m","w") as a3m:
    a3m.write("\n".join(a3m_lines))

# the predictor is created once per kernel and reused on re-runs of this cell
if "pred" not in dir():
  from predict import Predictor
  print("initializing RoseTTAFold2...")
  model_params = "RF2_apr23.pt"
  if (torch.cuda.is_available()):
    pred = Predictor(model_params, torch.device("cuda:0"))
  else:
    print ("WARNING: using CPU")
    pred = Predictor(model_params, torch.device("cpu"))

# run one prediction per seed and remember the seed with the lowest mean PAE;
# the display cell reads `best_seed`
best_pae = None
best_seed = None
for seed in range(random_seed,random_seed+num_models):
  torch.manual_seed(seed)
  random.seed(seed)
  np.random.seed(seed)
  npz = f"{jobname}/rf2_seed{seed}_00.npz"
  pred.predict(inputs=[f"{jobname}/msa.a3m"],
              out_prefix=f"{jobname}/rf2_seed{seed}",
              symm=symm,
              ffdb=None, #TODO (templates),
              n_recycles=num_recycles,
              msa_mask=0.15 if use_mlm else 0.0,
              symm_diag=use_diag)
  # context manager so the .npz archive is closed after reading
  with np.load(npz) as result:
    pae = result["pae"].mean()
  if best_pae is None or pae < best_pae:
    best_pae = pae
    best_seed = seed

In [None]:
#@title Display 3D structure {run: "auto"}

# Colouring mode for the 3D view; read as a global inside plot_pdb.
color = "plddt" #@param ["plddt","rainbow"]
import py3Dmol
from string import ascii_uppercase,ascii_lowercase
# NOTE(review): same definition as in the setup cell; not referenced in the visible cells.
alphabet_list = list(ascii_uppercase+ascii_lowercase)

def plot_pdb(pdb):
  """Show a PDB file as a cartoon in a py3Dmol viewer.

  Colouring follows the notebook-level `color` setting: "rainbow" applies the
  'spectrum' colour; any other value colours by the atoms' `b` property with a
  'roygb' gradient between 50 and 90 (presumably pLDDT stored in the B-factor
  column -- confirm against the predictor output).

  Args:
    pdb: path of the PDB file to display.
  """
  hbondCutoff = 4.0
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
  # context manager so the file handle is closed instead of being leaked
  with open(pdb,'r') as handle:
    pdb_str = handle.read()
  view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  if color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  else:
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  view.zoomTo()
  view.show()
# Show the model from the best-scoring seed chosen in the run cell.
plot_pdb(f"{jobname}/rf2_seed{best_seed}_00_pred.pdb")
# Copy every array out of the archive, then close it (np.load keeps the file open otherwise).
with np.load(f"{jobname}/rf2_seed{best_seed}_00.npz") as npz_file:
  output = dict(npz_file)

# Two panels side by side: the "lddt" array as a line plot and the "pae" array as a heat map.
fig, (ax_lddt, ax_pae) = plt.subplots(1, 2, figsize=(10,5))
ax_lddt.set_title("predicted LDDT")
ax_lddt.plot(output["lddt"])
ax_lddt.set_ylim(0,1.0)
ax_lddt.set_xlabel("position")
ax_lddt.set_ylabel("predicted LDDT")
ax_pae.set_title("predicted alignment error")
ax_pae.imshow(output["pae"],vmin=0,vmax=30,cmap="bwr")
ax_pae.set_xlabel("position")
ax_pae.set_ylabel("position")
plt.show()

In [35]:
#@title Download prediction

#@markdown Once this cell has been executed, a zip-archive with 
#@markdown the obtained prediction will be automatically downloaded 
#@markdown to your computer.

import shutil

# add settings file
# Record the options used in the run cell next to the results, one key=value per line.
settings_path = f"{jobname}/settings.txt"
settings = {
  "method": "RoseTTAFold2",
  "sequence": sequence,
  "sym": sym,
  "order": order,
  "use_diag": use_diag,
  "random_seed": random_seed,
  "msa_method": msa_method,
  "pair_mode": pair_mode,  # previously missing although it affects the MSA
  "num_recycles": num_recycles,
  "use_mlm": use_mlm,
  "num_models": num_models,
}
with open(settings_path, "w") as text_file:
  for key, value in settings.items():
    text_file.write(f"{key}={value}\n")

# --- Download the predictions ---
# Build {jobname}.zip containing the {jobname}/ directory without going through
# a shell, so job names with spaces or shell metacharacters are handled.
# An existing archive of the same name is replaced rather than updated.
shutil.make_archive(jobname, "zip", root_dir=".", base_dir=jobname)
files.download(f'{jobname}.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>