<a href="https://colab.research.google.com/github/CarlWang96/CS50-IDE/blob/main/diffusion_foldcond_symmetry.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**RFdiffusion** - conditional fold generation with symmetry
---

**<font color="red">NOTE</font>** This notebook is still in development; we are working on adding all the options from the [manuscript](https://www.biorxiv.org/content/10.1101/2022.12.09.519842v2).

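**Instructions**:
1. Select the blueprint mode (`manual` or `automated`).
2. Enter the settings and hit the ▶️ button.
   - **RFdiffusion** takes ~1 min to set up the first time; subsequent runs of this cell take only seconds!

3. Modify the blueprint (click a button to cycle through its states):
   - use the diagonal to define the secondary structure elements, SSEs (`H:helix E:sheet C:coil ?:undefined`)
   - use the off-diagonal to define interactions between SSEs (`0:no_contact 1:contact ?:undefined`)
   - use the text box in the last column to set the length of each SSE
   - set the buffer length (`buff_length`) between SSEs
4. Run the **run RFdiffusion** cell (with `download_adj_ss` checked it downloads the blueprint files instead of running diffusion).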

In [None]:
# py3Dmol is used by the viewer further down to render designed structures
!pip install py3Dmol pyrsistent

Collecting py3Dmol
  Downloading py3Dmol-2.0.4-py2.py3-none-any.whl (12 kB)
Collecting pyrsistent
  Downloading pyrsistent-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (117 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.7/117.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: py3Dmol, pyrsistent
Successfully installed py3Dmol-2.0.4 pyrsistent-0.20.0


In [None]:
# Run counter: the "run RFdiffusion" cell embeds it in the names of the blueprint
# files it downloads and increments it after every run.
count = 1

In [None]:
#@title Generate blueprint for **RFdiffusion**
name = "test"
blueprint_mode = "manual" #@param ["manual", "automated"]
run_mode = "unconditional"

#@markdown ---
#@markdown **Manual** blueprint (define number of secondary structure `elements` (SSE))
elements = 10 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "24", "25", "28", "30", "32"] {type:"raw"}
#@markdown ---
#@markdown **Automated** blueprint (from input PDB)
pdb = "6MRR" #@param {type:"string"}
chain = "A" #@param {type:"string"}
trim_loops = True #@param {type:"boolean"}
if chain == "": chain = None

import os, time, sys

######################################################################
# SETUP RFDIFFUSION
######################################################################
if not os.path.isdir("RFdiffusion"):
  print("installing RFdiffusion...")
  # -y: the shell spawned by os.system cannot answer apt's confirmation prompt
  os.system("apt-get install -y aria2")
  # send param download into background; it completes while the installs below run.
  # https so the checkpoint (presumably deserialized by torch in run_inference.py,
  # i.e. pickle-based) cannot be swapped in transit — TODO confirm host keeps https.
  os.system("aria2c -q -x 16 https://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt &")

  # install RFdiffusion
  os.system("git clone https://github.com/sokrypton/RFdiffusion.git")
  os.system("pip -q install jedi omegaconf hydra-core icecream")
  os.system("pip -q install dgl -f https://data.dgl.ai/wheels/cu117/repo.html")
  os.system("cd RFdiffusion/env/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .")

  # extras
  os.system("pip -q install py3Dmol pydssp")

if not os.path.isdir("RFdiffusion/models"):
  print("downloading RFdiffusion params...")
  os.makedirs("RFdiffusion/models", exist_ok=True)
  models = ["Complex_Fold_base_ckpt.pt"]
  for m in models:
    # aria2c keeps a "<file>.aria2" control file until the download has finished
    while os.path.isfile(f"{m}.aria2"):
      time.sleep(5)
  os.system(f"mv {' '.join(models)} RFdiffusion/models")
  # models/ now exists, so later runs skip this branch: report a failed download now
  missing = [m for m in models if not os.path.isfile(f"RFdiffusion/models/{m}")]
  if missing:
    print(f"WARNING: model params {missing} were not downloaded; delete the RFdiffusion directory and rerun this cell")
  print("----------------------------------")

if 'RFdiffusion' not in sys.path:
  os.environ["DGLBACKEND"] = "pytorch"
  sys.path.append('RFdiffusion')
######################################################################

from IPython.display import display
import ipywidgets as widgets
import torch
import random, string, re
import numpy as np
import subprocess
import matplotlib.pyplot as plt
import py3Dmol
from google.colab import files, output

def get_pdb(pdb_code=None):
  '''Resolve `pdb_code` to a local PDB file path, fetching the file if needed.

  - None or "": open the Colab upload dialog and save the file as "tmp.pdb".
  - path of an existing local file: returned unchanged.
  - 4-character code: downloaded from RCSB as "{code}.pdb".
  - anything else: downloaded from the AlphaFold DB as
    "AF-{code}-F1-model_v3.pdb" (code presumably a UniProt accession).

  Raises:
    ValueError: the upload dialog returned no file.
    FileNotFoundError: the download did not produce the expected file.
  '''
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    filename = f"{pdb_code}.pdb"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb"
    filename = f"AF-{pdb_code}-F1-model_v3.pdb"

  # argument list (no shell) so a user-typed code cannot inject shell syntax;
  # -nc reuses a previously downloaded copy instead of fetching it again
  subprocess.run(["wget", "-qnc", url])
  if not os.path.isfile(filename):
    raise FileNotFoundError(f"could not download '{pdb_code}' from {url}")
  return filename

def pdb_to_string(pdb_file, chains=None, models=[1]):
  '''Read a PDB file and return its cleaned coordinate records as one string.

  Args:
    pdb_file: path to a PDB-format file.
    chains: chain id, comma-separated ids ("A,B") or list of ids to keep;
      None keeps every chain.
    models: model number or list of model numbers to keep; None keeps all.
      Records before the first MODEL line count as model 1.

  Returns:
    Newline-joined lines of the selected models: ATOM records plus the
    MODEL/TER/ENDMDL markers. HETATM records of modified residues (built-in
    table plus the file's MODRES entries that map to a standard residue) are
    rewritten as ATOM records of the parent residue. Insertion codes are
    blanked, and only the first record per (model, chain, residue number,
    residue name, atom name) is kept, which drops alternate locations.
  '''
  # modified residue -> parent standard residue
  MODRES = {'MSE':'MET','MLY':'LYS','FME':'MET','HYP':'PRO',
            'TPO':'THR','CSO':'CYS','SEP':'SER','M3L':'LYS',
            'HSK':'HIS','SAC':'SER','PCA':'GLU','DAL':'ALA',
            'CME':'CYS','CSD':'CYS','OCS':'CYS','DPR':'PRO',
            'B3K':'LYS','ALY':'LYS','YCM':'CYS','MLZ':'LYS',
            '4BF':'TYR','KCX':'LYS','B3E':'GLU','B3D':'ASP',
            'HZP':'PRO','CSX':'CYS','BAL':'ALA','HIC':'HIS',
            'DBZ':'ALA','DCY':'CYS','DVA':'VAL','NLE':'LEU',
            'SMC':'CYS','AGM':'ARG','B3A':'ALA','DAS':'ASP',
            'DLY':'LYS','DSN':'SER','DTH':'THR','GL3':'GLY',
            'HY3':'PRO','LLP':'LYS','MGN':'GLN','MHS':'HIS',
            'TRQ':'TRP','B3Y':'TYR','PHI':'PHE','PTR':'TYR',
            'TYS':'TYR','IAS':'ASP','GPL':'LYS','KYN':'TRP',
            'SEC':'CYS'}
  STANDARD_RESIDUES = {'ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE',
                       'LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL'}

  if chains is not None:
    if "," in chains: chains = chains.split(",")
    if not isinstance(chains, list): chains = [chains]
  if models is not None and not isinstance(models, list):
    models = [models]

  modres = dict(MODRES)
  lines = []
  seen = set()  # O(1) duplicate check per atom
  model = 1
  with open(pdb_file, "rb") as handle:
    for raw in handle:
      line = raw.decode("utf-8", "ignore").rstrip()
      if line[:5] == "MODEL":
        model = int(line[5:])
      if models is not None and model not in models:
        continue
      if line[:6] == "MODRES":
        # the file's own modified-residue table, only for standard parents
        k, v = line[12:15], line[24:27]
        if k not in modres and v in STANDARD_RESIDUES:
          modres[k] = v
      if line[:6] == "HETATM":
        k = line[17:20]
        if k in modres:
          line = "ATOM  " + line[6:17] + modres[k] + line[20:]
      if line[:4] == "ATOM":
        chain = line[21:22]
        if chains is None or chain in chains:
          atom = line[12:16].strip()
          res_name = line[17:20]
          res_num = line[22:27].strip()
          if res_num and res_num[-1].isalpha():  # insertion code
            res_num = res_num[:-1]
            line = line[:26] + " " + line[27:]
          key = f"{model}_{chain}_{res_num}_{res_name}_{atom}"
          if key not in seen:  # skip alternative placements
            lines.append(line)
            seen.add(key)
      if line[:5] == "MODEL" or line[:3] == "TER" or line[:6] == "ENDMDL":
        lines.append(line)
  return "\n".join(lines)

def from_pdb(pdb_code=None, chains=None, trim_loops=False, return_pdb_str=False, contact_cutoff=6.0):
  '''Summarize a structure as a blueprint (SSE lengths + SSE contact map).

  Args:
    pdb_code: anything get_pdb accepts (None/"" uploads, local path,
      4-character PDB code, or an AlphaFold DB accession).
    chains: chain selection passed to pdb_to_string.
    trim_loops: if True, residues pydssp labels "-" are dropped; if False
      they become coil ("C") elements.
    return_pdb_str: also return the cleaned PDB text under "pdb_str".
    contact_cutoff: virtual-Cb distance (coordinate units, Angstrom for PDB
      files) below which two residues count as in contact.

  Returns:
    dict with "txt": array of element lengths, and "adj": object array
    (num_sse x num_sse) with the element type on the diagonal and "0"/"1"
    contact flags off-diagonal ("0" whenever either element is coil).

  Raises:
    ValueError: no ATOM records were found for the requested chains.
  '''
  import pydssp

  def process(secondary_structure, contact_map):
    '''Collapse per-residue labels into SSEs and an SSE-level contact map.'''
    secondary_structure = np.array(secondary_structure)
    last = len(secondary_structure) - 1
    # inclusive start/end index of every maximal run of one H/E/C label
    sse_start, sse_end = [], []
    for i, label in enumerate(secondary_structure):
      if label not in ("H", "E", "C"):
        continue
      if i == 0 or secondary_structure[i-1] != label:
        sse_start.append(i)
      if i == last or secondary_structure[i+1] != label:
        sse_end.append(i)

    sse_types = secondary_structure[sse_start]
    sse_lengths = np.array(sse_end) - np.array(sse_start) + 1
    num_sse = len(sse_lengths)
    reduced_contact_map = np.full((num_sse, num_sse), '0', dtype=object)
    np.fill_diagonal(reduced_contact_map, sse_types)

    # contacts are only recorded between non-coil elements
    for i in range(num_sse):
      for j in range(num_sse):
        if i != j and sse_types[i] != "C" and sse_types[j] != "C":
          block = contact_map[sse_start[i]:sse_end[i]+1, sse_start[j]:sse_end[j]+1]
          reduced_contact_map[i, j] = str(int(np.any(block)))

    return {"txt": sse_lengths, "adj": reduced_contact_map}

  def coord_2_cb(coord):
    '''Place a virtual Cb from the N, Ca, C backbone atoms.'''
    N, Ca, C = coord[:,0], coord[:,1], coord[:,2]
    b = Ca - N
    c = C - Ca
    a = np.cross(b, c)
    return -0.57910144*a + 0.5689693*b - 0.5441217*c + Ca

  pdb_filename = get_pdb(pdb_code)
  pdb_str = pdb_to_string(pdb_filename, chains=chains)
  # fail clearly (e.g. wrong chain id) instead of deep inside pydssp
  if not any(line.startswith("ATOM") for line in pdb_str.splitlines()):
    raise ValueError(f"no ATOM records found in {pdb_filename} for chains={chains!r}")
  coord = pydssp.read_pdbtext(pdb_str)

  ss = pydssp.assign(coord)
  if not trim_loops:
    # keep loops as explicit coil ("C") elements instead of dropping them
    ss = [("C" if s == "-" else s) for s in ss]
  cb = coord_2_cb(coord)
  con = np.sqrt(np.square(cb[:,None] - cb[None,:]).sum(-1)) < contact_cutoff
  out = process(ss, con)
  if return_pdb_str:
    out["pdb_str"] = pdb_str
  return out

def get_adj_ss(adj, txt, buff=0):
  '''Expand a per-element blueprint into per-residue RFdiffusion inputs.

  Elements with a non-positive length are skipped entirely. Each kept element
  is preceded by `buff` buffer residues, so the total length is
  len(kept) * buff + sum(lengths of kept elements).

  Args:
    adj: square matrix (nested lists or array); the diagonal holds the element
      type ("H", "E", "C", "?") and off-diagonal cells hold "0"/"1"/"?".
    txt: length of each element.
    buff: number of buffer residues inserted before every kept element.

  Returns:
    {"adj": (L, L) int array, "sse": (L,) int array} where
    sse: 0=H, 1=E, 2=C, 3=?/buffer;
    adj: 0=no contact, 1=contact, 2=unspecified (buffers, "?" cells, and the
    self block of a "?" element); self blocks of H/E/C elements are 0.
  '''
  SSE_CODE = {"H": 0, "E": 1, "C": 2, "?": 3}
  SELF_CODE = {"H": 0, "E": 0, "C": 0, "?": 2}
  PAIR_CODE = {"0": 0, "1": 1, "?": 2}

  kept = [i for i in range(len(adj)) if txt[i] > 0]
  lengths = {i: int(txt[i]) for i in kept}

  # first residue of each kept element; a buffer precedes every element
  starts = {}
  pos = 0
  for i in kept:
    pos += buff
    starts[i] = pos
    pos += lengths[i]
  L = pos  # counts only kept (positive) lengths

  full_adj = np.full((L, L), 2)
  full_sse = np.full((L,), 3)
  for i in kept:
    si, li = starts[i], lengths[i]
    full_sse[si:si + li] = SSE_CODE[str(adj[i][i])]
    for j in kept:
      sj, lj = starts[j], lengths[j]
      key = str(adj[i][j])
      full_adj[si:si + li, sj:sj + lj] = SELF_CODE[key] if i == j else PAIR_CODE[key]
  return {"adj": full_adj, "sse": full_sse}

class RFdiff_js:
  '''Blueprint-editor state plus the HTML/JavaScript grid that edits it.

  Users of these methods must provide `self.elements`, `self.buff_length` and
  `self.buttons["buff_length"]` (anything with a `.value`). State:
    adj: elements x elements matrix; the diagonal holds the SSE type
      ("H", "E", "C", "?"), off-diagonal cells hold the contact flag
      ("0", "1", "?") and are kept symmetric.
    txt: length of each SSE.
  The JavaScript produced by create_html_code mirrors grid_callback, so the
  page updates instantly while Python keeps the authoritative copy.
  '''

  # default SSE length assigned when an element switches to each type
  DEFAULT_LENGTHS = {"H": 19, "E": 11, "C": 3, "?": 0}

  def reset_callback(self):
    '''Reset to an all-helix blueprint without contacts (mirrors JS reset()).'''
    n = self.elements
    self.adj = [["H" if row == col else "0" for col in range(n)] for row in range(n)]
    self.txt = [self.DEFAULT_LENGTHS["H"]] * n
    self.buttons["buff_length"].value = self.buff_length

  def grid_callback(self, row, col, new_value):
    '''Apply a click on grid cell (row, col) whose new state is `new_value`.

    Off-diagonal: set the (symmetric) contact flag.
    Diagonal: set the SSE type and its default length, then update that
    element's contacts: "?" marks all of them undefined; "C" or "H" reset
    them to "0" except against undefined elements; "E" leaves them as is.
    '''
    if row != col:
      self.adj[row][col] = new_value
      self.adj[col][row] = new_value
      return

    self.txt[row] = self.DEFAULT_LENGTHS[new_value]
    self.adj[row][col] = new_value
    for other in range(self.elements):
      if other == row:
        continue
      if new_value == "?":
        flag = "?"
      elif self.adj[other][other] != "?" and new_value in ("C", "H"):
        flag = "0"
      else:
        continue
      self.adj[row][other] = flag
      self.adj[other][row] = flag

  def text_callback(self, row, new_value):
    '''Store the SSE length typed into row `row`'s text box.

    Non-integer input is ignored (the previous length is kept) and negative
    input is clamped to 0, so a typo cannot corrupt the blueprint length.
    '''
    try:
      length = int(new_value)
    except (TypeError, ValueError):
      return
    self.txt[row] = max(0, length)

  def create_html_code(self):
    '''Build self.html_code: the clickable grid, one length box per row, a
    reset button, and the JavaScript wiring them to the Python callbacks.'''
    diag_colors = {"H": "red", "E": "yellow", "C": "lime", "?": "lightgray"}
    pair_colors = {"0": "white", "1": "lightblue", "?": "lightgray"}
    lengths = self.DEFAULT_LENGTHS

    def style(row, col):
      state = self.adj[row][col]
      if row == col:
        color = diag_colors[state]
        disabled = ""
      else:
        color = pair_colors[state]
        # contacts of coil/undefined elements are not editable
        locked = self.adj[row][row] in ["?", "C"] or self.adj[col][col] in ["?", "C"]
        disabled = "disabled" if locked else ""
      return {"color": color,
              "text": state,
              "id": f"button-{row}-{col}",
              "disabled": disabled,
              "opacity": 1 if disabled == "" else 0.2}

    html_grid = ""
    for row in range(self.elements):
      for col in range(self.elements):
        button = style(row, col)
        html_grid += f"""
            <button id="{button['id']}" style="opacity:{button['opacity']};width:30px;height:30px;background-color:{button['color']};border: 2px solid #000;color:#000;padding:0;font-weight:bold;" onclick="buttonClick('{button['id']}', {row}, {col})" {button['disabled']}>{button['text']}</button>
            """
      html_grid += f"""
        <input id="text-{row}" type="text" value="{self.txt[row]}" style="width:50px;height:24px; background-color:#ffffff; text-align:center; border:2px solid lightgray;" onchange="textFieldChanged({row}, this)">
        """

    self.html_code = f"""
    <div style="display: grid; grid-template-columns: repeat({self.elements + 1}, 30px); grid-gap: 2px;">{html_grid}</div>
    <button id="reset_button" style="width:62px;height:30px;background-color:#ffffff;border: 2px solid #000;color:#000;padding:0;font-weight:bold;margin-top: 2px;" onclick="reset()">reset</button>
    <script>
    // Each handler updates the page immediately and then informs Python via
    // google.colab.kernel.invokeFunction so both copies of the blueprint agree.
    function buttonClick(button_id, row, col) {{
        var button = document.getElementById(button_id);
        if (row === col) {{
            // diagonal: cycle the SSE type H -> E -> C -> ? -> H and reset its length
            var state_mapping = {{
                "H": {{ "text": "E", "color": "yellow",    "length": "{lengths['E']}" }},
                "E": {{ "text": "C", "color": "lime",      "length": "{lengths['C']}" }},
                "C": {{ "text": "?", "color": "lightgray", "length": "{lengths['?']}" }},
                "?": {{ "text": "H", "color": "red",       "length": "{lengths['H']}"}},
            }};
            var current_state = button.textContent;
            var update = state_mapping[current_state];

            button.textContent = update.text;
            button.style.backgroundColor = update.color;
            google.colab.kernel.invokeFunction("grid_callback", [row, col, update.text], {{}});

            // Update the corresponding text field value
            var textField = document.getElementById("text-" + row);
            textField.value = update.length;

            // Enable/disable off-diagonal buttons
            for (var i = 0; i < {self.elements}; i++) {{
                if (i !== row) {{
                    var row_button = document.getElementById("button".concat("-", row, "-", i));
                    var col_button = document.getElementById("button".concat("-", i, "-", row));
                    var diag_button = document.getElementById("button".concat("-", i, "-", i));

                    if (button.textContent === "C" || button.textContent === "?") {{
                        // coil/undefined elements take no editable contacts
                        row_button.disabled = col_button.disabled = true;
                        row_button.style.opacity = col_button.style.opacity = 0.2;

                        if (button.textContent === "?") {{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'lightgray';
                            row_button.textContent = col_button.textContent = '?';
                        }} else if (button.textContent === "C" && diag_button.textContent !== "?") {{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'white';
                            row_button.textContent = col_button.textContent = '0';
                        }}
                    }} else if (button.textContent === "H") {{
                        if (diag_button.textContent == "C"){{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'white';
                            row_button.textContent = col_button.textContent = '0';
                        }} else if (diag_button.textContent !== "?"){{
                            row_button.style.backgroundColor = col_button.style.backgroundColor = 'white';
                            row_button.textContent = col_button.textContent = '0';
                            row_button.disabled = col_button.disabled = false;
                            row_button.style.opacity = col_button.style.opacity = 1;
                        }}
                    }}
                }}
            }}

        }} else {{
            // off-diagonal: cycle the contact flag 0 -> 1 -> ? -> 0, symmetrically
            var off_diag_state_mapping = {{
                "0": {{ "text": "1", "color": "lightblue" }},
                "1": {{ "text": "?", "color": "lightgray" }},
                "?": {{ "text": "0", "color": "white" }},
            }};
            var current_state = button.textContent;
            var update = off_diag_state_mapping[current_state];
            var sym_button = document.getElementById("button".concat("-", col, "-", row));
            button.textContent = sym_button.textContent = update.text;
            button.style.backgroundColor = sym_button.style.backgroundColor = update.color;
            google.colab.kernel.invokeFunction("grid_callback", [row, col, update.text], {{}});
        }}
    }}
    function textFieldChanged(row, textField) {{
        var newValue = textField.value;
        google.colab.kernel.invokeFunction("text_callback", [row, newValue], {{}});
    }}
    function reset() {{
        for (var row = 0; row < {self.elements}; row++) {{
            for (var col = 0; col < {self.elements}; col++) {{
                var button = document.getElementById("button".concat("-", row, "-", col));
                if (row === col) {{
                    button.textContent = "H";
                    button.style.backgroundColor = "red";
                }} else {{
                    button.textContent = "0";
                    button.style.backgroundColor = "white";
                }}
                button.disabled = false;
                button.style.opacity = 1;
            }}
            var textField = document.getElementById("text-" + row);
            textField.value = "{lengths['H']}";
        }}
        google.colab.kernel.invokeFunction('reset_callback', [], {{}});
    }}
    </script>
    """
class RFdiff_gui(RFdiff_js):
  '''Colab front-end around the blueprint grid: edits the blueprint, runs
  RFdiffusion on it, and shows or downloads the result.

  Args:
    elements: number of SSEs for a fresh blueprint (used unless both `adj`
      and `txt` are given).
    adj, txt: an existing blueprint, e.g. the dict returned by from_pdb.
    buff_length: initial value of the `buff_length` box.
    name: output prefix; results are written under outputs/.
  '''

  def __init__(self, elements=5, adj=None, txt=None, buff_length=5, name="test"):
    self.path = self.name = name
    self.input = widgets.Output()
    self.output = widgets.Output()
    self.buff_length = buff_length
    # whether chain B is a binder target (viewer colors it); diffuse() updates it
    self.use_target = False

    # expose the grid callbacks to the JavaScript built in create_html_code
    output.register_callback("reset_callback", self.reset_callback)
    output.register_callback("grid_callback", self.grid_callback)
    output.register_callback("text_callback", self.text_callback)

    button_style = widgets.Layout(width='84px', height='35px', border='2px solid black')
    self.buttons = {
        "buff_length": widgets.BoundedIntText(description='buff_length', value=self.buff_length, min=0, max=20),
        "reset":       widgets.Button(description='reset',     layout=button_style),
        "animate":     widgets.Button(description='animate',   layout=button_style),
        "freeze":      widgets.Button(description='freeze',    layout=button_style),
        "download":    widgets.Button(description='download',  layout=button_style),
    }
    self.buttons["animate"].on_click(self._plot_pdb)
    self.buttons["freeze"].on_click(self._plot_pdb)
    self.buttons["download"].on_click(self._download)

    # prep inputs
    if adj is not None and txt is not None:
      self.elements = len(adj)
      self.adj, self.txt = adj, txt
    else:
      self.elements = elements
      self.reset_callback()

  def redraw(self):
    '''Re-render the blueprint grid and options into the input area.'''
    self.create_html_code()
    with self.input:
      self.input.clear_output(wait=True)
      display(
          widgets.VBox([
          widgets.HTML(self.html_code),
          widgets.Label("Options"),
          self.buttons["buff_length"],
        ])
      )

  def display_input(self):
    self.redraw()
    display(self.input)

  def display_output(self):
    display(self.output)

  def _download(self, button):
    '''Zip all outputs of the current run and send the archive to the browser.'''
    os.system(f"zip -r {self.path}.result.zip outputs/{self.path}* outputs/traj/{self.path}*")
    files.download(f"{self.path}.result.zip")

  def _plot_pdb(self, button):
    '''Show the design in 3Dmol: the trajectory for "animate", else the final model.'''
    animate = button.description == "animate"
    if animate:
      pdb = f"outputs/traj/{self.path}_0_pX0_traj.pdb"
    else:
      pdb = f"outputs/{self.path}_0.pdb"
    with self.output:
      self.output.clear_output(wait=True)
      if not os.path.isfile(pdb):
        print(f"no structure found at {pdb}; did the RFdiffusion run finish?")
        return
      with open(pdb, 'r') as handle:
        pdb_str = handle.read()
      view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
      if animate:
        view.addModelsAsFrames(pdb_str, 'pdb')
      else:
        view.addModel(pdb_str, 'pdb')
      view.setStyle({"ss":"h","chain":"A"},{'cartoon': {'color':'red'}})
      view.setStyle({"ss":"c","chain":"A"},{'cartoon': {'color':'lime'}})
      view.setStyle({"ss":"s","chain":"A"},{'cartoon': {'color':'yellow'}})
      if getattr(self, "use_target", False):
        view.setStyle({"chain":"B"},{'cartoon': {'color':'white'}})
      view.zoomTo()
      if animate:
        view.animate({'loop': 'backAndForth'})
      out = widgets.Output()
      with out: view.show()
      # offer the other view mode next to the download button
      toggle = self.buttons["freeze"] if animate else self.buttons["animate"]
      display(widgets.VBox([out, widgets.HBox([toggle, self.buttons["download"]])]))

  def _make_path(self):
    '''Create outputs/{path}; if outputs/{path}_0.pdb already exists, switch
    `self.path` to `self.name` plus a random 5-character suffix first.'''
    os.makedirs(f"outputs/{self.path}", exist_ok=True)
    while os.path.exists(f"outputs/{self.path}_0.pdb"):
      self.path = self.name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
      os.makedirs(f"outputs/{self.path}", exist_ok=True)

  def _get_adj_ss(self):
    '''Expand the blueprint into per-residue tensors and save them as
    outputs/{path}/tmp_ss.pt and tmp_adj.pt (the scaffold_dir used by diffuse).'''
    full = get_adj_ss(adj=self.adj,
                      txt=self.txt,
                      buff=self.buttons["buff_length"].value)
    self._sse = full["sse"]
    self._adj = full["adj"]

    # save results
    loc = [f"outputs/{self.path}/tmp_ss.pt",
           f"outputs/{self.path}/tmp_adj.pt"]
    torch.save(torch.from_numpy(self._sse).float(), loc[0])
    torch.save(torch.from_numpy(self._adj).float(), loc[1])

  def _get_len(self):
    '''Total designed length: one buffer plus the SSE length for every element
    with a positive length (non-positive lengths are skipped, as in get_adj_ss).'''
    buff = self.buttons["buff_length"].value
    kept = [self.txt[i] for i in range(len(self.adj)) if self.txt[i] > 0]
    return int(len(kept) * buff + sum(kept))

  def diffuse(self, iterations=50, mask_loops=True, extra_cmd=None, use_target=None):
    '''Run scaffold-guided RFdiffusion on the current blueprint, then show it.

    Args:
      iterations: number of diffusion steps (diffuser.T).
      mask_loops: forwarded as scaffoldguided.mask_loops.
      extra_cmd: extra command-line overrides; a "--config-name=symmetry"
        entry and the override right after it are moved next to the script.
      use_target: whether a binder target (chain B) is part of the run; None
        reads the notebook-level `use_target` form value (False if unset).
    '''
    if use_target is None:
      use_target = globals().get("use_target", False)
    self.use_target = use_target
    self.redraw()
    self._make_path()
    self._get_adj_ss()
    # run
    with self.output:
      self.output.clear_output()
      cmd = ["./RFdiffusion/run_inference.py",
             "inference.num_designs=1",
             f"inference.output_prefix=outputs/{self.path}",
             "scaffoldguided.scaffoldguided=True",
             f"scaffoldguided.scaffold_dir=outputs/{self.path}",
             f"diffuser.T={iterations}",
             f"scaffoldguided.mask_loops={mask_loops}"]

      if extra_cmd is not None:
        if "--config-name=symmetry" in extra_cmd:
          # the config flag and the override following it go directly after the
          # script path (presumably required by Hydra's CLI parsing — TODO confirm)
          n = extra_cmd.index("--config-name=symmetry")
          cmd = [cmd[0], "--config-name=symmetry", extra_cmd[n+1]] + cmd[1:] + extra_cmd[:n]
          if len(extra_cmd) > n+2:
            cmd += extra_cmd[n+2:]
        else:
          cmd += extra_cmd

      self.cmd_str = " ".join(cmd)
      steps = iterations - 1
      return_code = self._run(self.cmd_str, "Timestep", steps)

    # plotting clears the output area, so keep the error log visible on failure
    if return_code == 0:
      self._plot_pdb(self.buttons["freeze"])

  def _run(self, command, trigger, total_timesteps):
    '''Run `command` in a shell, advancing a progress bar once per output line
    containing `trigger`.

    stderr is merged into stdout: an unread stderr pipe would block the child
    once its buffer fills. Returns the exit code; on failure the last output
    lines are printed.
    '''
    progress = widgets.FloatProgress(min=0, max=1, description='running', bar_style='info')
    display(progress)
    pattern = re.compile(f'.*{trigger}.*')
    steps = max(1, total_timesteps)  # avoid division by zero for a single step
    progress_counter = 0
    tail = []  # last lines of output, shown if the run fails
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, text=True)
    for line in process.stdout:
      tail.append(line)
      if len(tail) > 20:
        tail.pop(0)
      if pattern.match(line):
        progress_counter += 1
        progress.value = min(1.0, progress_counter / steps)
    return_code = process.wait()
    if return_code != 0:
      progress.bar_style = 'danger'
      progress.description = "failed"
      print("".join(tail))
    else:
      progress.description = "done"
    return return_code

  def _download_adj_ss(self, L, use_target=None, cyclic_sym=None, count=None):
    '''Download the blueprint tensors and a name/length file as a zip (no diffusion).

    Args:
      L: total designed length written into the info file.
      use_target, cyclic_sym, count: None reads the notebook-level variables
        of the same names, falling back to False, "C1" and 1.
    '''
    if use_target is None:
      use_target = globals().get("use_target", False)
    if cyclic_sym is None:
      cyclic_sym = globals().get("cyclic_sym", "C1")
    if count is None:
      count = globals().get("count", 1)
    self.use_target = use_target
    self.redraw()
    self._make_path()
    self._get_adj_ss()

    path_n = f"outputs_{self.path}_{cyclic_sym}_{count}/"
    os.makedirs(path_n, exist_ok=True)
    loc = path_n + f"lengthAndsym_{self.path}_{cyclic_sym}_{count}.txt"
    with open(loc, 'w') as fout:
      fout.write(f"name:{self.path}_{cyclic_sym}_{count},length:{L}")

    os.system(f"cp outputs/{self.path}/*.pt {path_n}")
    os.system(f"zip -r {self.path}_{count}_{cyclic_sym}.adj.zip {path_n}*")
    files.download(f"{self.path}_{count}_{cyclic_sym}.adj.zip")

# Build the blueprint editor (a fresh all-helix grid in "manual" mode, or a
# blueprint derived from the input structure otherwise) and show it.
if blueprint_mode != "automated":
  rfdiff = RFdiff_gui(elements, name=name)
else:
  pdb_feats = from_pdb(pdb, chains=chain, trim_loops=trim_loops)
  # buffers of 5 residues presumably stand in for the trimmed loops
  rfdiff = RFdiff_gui(**pdb_feats, name=name, buff_length=5 if trim_loops else 0)
rfdiff.display_input()

installing RFdiffusion...
downloading RFdiffusion params...
----------------------------------


Output()

In [None]:
%%time
#@title run **RFdiffusion**
iterations = 50 #@param ["25", "50", "100", "200"] {type:"raw"}
mask_loops = True #@param {type:"boolean"}
#@markdown **Optional**: specify target info (for binder design)
use_target = False #@param {type:"boolean"}
symmetry = False #@param {type:"boolean"}
target_pdb = "" #@param {type:"string"}
target_chain = "A" #@param {type:"string"}
target_hotspot = "" #@param {type:"string"}
cyclic_sym = "C1" #@param {type:"string"}
download_adj_ss = True #@param {type:"boolean"}
extra_cmd = None


if "rfdiff" in dir():
  L = rfdiff._get_len()

  if use_target:
    # prep target features
    rfdiff._make_path()
    path = f"outputs/{rfdiff.path}/target"
    os.makedirs(path, exist_ok=True)
    target = from_pdb(target_pdb, target_chain, return_pdb_str=True)
    target_pdb = f"{path}/input.pdb"
    with open(target_pdb,"w") as handle:
      handle.write(target["pdb_str"])
      full = get_adj_ss(adj=target["adj"], txt=target["txt"])
      torch.save(torch.from_numpy(full["sse"]).float(),f"{path}/ss.pt")
      torch.save(torch.from_numpy(full["adj"]).float(),f"{path}/adj.pt")

    extra_cmd = ["scaffoldguided.target_pdb=True",
                f"scaffoldguided.target_path={path}/input.pdb",
                f"scaffoldguided.target_ss={path}/ss.pt",
                f"scaffoldguided.target_adj={path}/adj.pt",
                "denoiser.noise_scale_ca=0",
                "denoiser.noise_scale_frame=0"]
    if target_hotspot != "":
      extra_cmd += [f"'ppi.hotspot_res=[{target_hotspot}]'"]
  if symmetry:
    if len(cyclic_sym.split("C")) != 2:
      print("Error, can only support cyclic symmetry for now.... Input string should be in the format 'C' followed by a number.\nDefaulting to monomeric diffusion...")
    elif L%int(cyclic_sym.split("C")[-1]) != 0:
      print(f"Make sure to have a total length {L} which is a multiple of the symmetry number {cyclic_sym.split('C')[-1]}.\nThis is not the case here, so will default to monomeric diffusion....")
    else:
      sym_cmd = [f"contigmap.contigs=[\\'{L}\\']", "--config-name=symmetry", f"inference.symmetry={cyclic_sym}"]
      if extra_cmd:
        extra_cmd += sym_cmd
      else:
        extra_cmd = sym_cmd

  rfdiff.display_output()
  if download_adj_ss:
    rfdiff._download_adj_ss(L)
  else:
    rfdiff.diffuse(iterations, mask_loops=mask_loops, extra_cmd=extra_cmd)

  count += 1

else:
  print("Error, looks like you didn't run the cell above")

Output()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

CPU times: user 38.8 ms, sys: 6.09 ms, total: 44.9 ms
Wall time: 128 ms
