<a href="https://colab.research.google.com/github/ajasja/PROPER/blob/main/ESMFold_fasta.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **ESMFold with FASTA file support**
Upload a fasta file (`*permuted-dimers.fasta`) and predict structures using ESM2.

#### **Tips and Instructions**
- Click the little ▶ play icon to the left of each cell below.
- Upload the FASTA file (`*permuted-dimers.fasta`) generated by the [PROPER](https://colab.research.google.com/github/ajasja/PROPER/blob/main/PROPER.ipynb) server.
- The result is a zip file with all predictions. The best predictions are those that have the highest pTM and are also permuted.

#### **Notes**
- On Tesla T4 (typical free colab GPU), max total length ~ 900
- Based on notebook by Sergey Ovchinnikov (https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/ESMFold.ipynb)



In [None]:
%%time
#@title Install
#@markdown install ESMFold, OpenFold and download Params (~5min)
version = "1" # @param ["0", "1"]
model_name = "esmfold_v0.model" if version == "0" else "esmfold.model"
import os, time
import subprocess

# Everything below is skipped when the requested parameter file already exists.
if not os.path.isfile(model_name):
  # download esmfold params
  os.system("apt-get install aria2 -qq")
  # Run the download as a child process so it overlaps with the pip installs
  # below. Keeping the process handle lets us notice a failed download instead
  # of polling the filesystem forever. start_new_session=True puts aria2c in
  # its own session so that (presumably - depends on how the front end
  # delivers interrupts) interrupting this cell does not abort the download.
  download = subprocess.Popen(
      ["aria2c", "-q", "-x", "16",
       f"https://colabfold.steineggerlab.workers.dev/esm/{model_name}"],
      start_new_session=True)

  # The marker file makes the library installation a one-time step.
  if not os.path.isfile("finished_install"):
    # install libs
    print("installing libs...")
    os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    print("installing openfold...")
    # install openfold
    os.system("pip install -q git+https://github.com/sokrypton/openfold.git")

    print("installing esmfold...")
    # install esmfold
    os.system("pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  if download.poll() is None:
    print("downloading params...")
  status = download.wait()
  if status != 0 or not os.path.isfile(model_name):
    # Remove partial data so that a re-run starts a fresh download instead of
    # mistaking an incomplete file for a finished one.
    for leftover in (model_name, f"{model_name}.aria2"):
      if os.path.isfile(leftover):
        os.remove(leftover)
    raise RuntimeError(
        f"Downloading {model_name} failed (aria2c exit status {status}). "
        "Re-run this cell to retry.")

#!pip install biopython> /dev/null
#!pip install --upgrade git+https://github.com/ajasja/PROPER.git > /dev/null
#import proper as pr
def read_fasta(filename):
    """Read a FASTA file into a dictionary mapping header text to sequence.

    A line starting with ">" (after stripping surrounding whitespace) opens a
    new record whose key is the rest of that line. All following lines up to
    the next header are stripped and concatenated to form the sequence.
    Lines before the first header are ignored. If a header text occurs more
    than once, the later record replaces the earlier one.

    Args:
        filename: Path of the FASTA file to read.

    Returns:
        dict: header text (without the leading ">") -> sequence string,
        ordered by first appearance of each header.
    """
    sequences = {}
    seq_id = None  # None means "no header seen yet"
    chunks = []
    with open(filename, "r") as file:
        for line in file:
            line = line.strip()
            if line.startswith(">"):  # Header line
                # Compare against None, not truthiness: a bare ">" header
                # yields an empty-string id, which must not be dropped.
                if seq_id is not None:
                    sequences[seq_id] = "".join(chunks)
                seq_id = line[1:]  # Remove ">"
                chunks = []
            elif seq_id is not None:
                chunks.append(line)
    if seq_id is not None:  # Save last sequence
        sequences[seq_id] = "".join(chunks)
    return sequences

In [None]:
#@title Upload fasta file (`*-permuted-dimers.fasta`)
#@markdown Upload only a single fasta file at a time.
from google.colab import files

uploaded = files.upload()  # Upload the FASTA file
if not uploaded:
  # Nothing came back from the upload dialog; fail with a clear message
  # instead of an IndexError on the lookup below.
  raise ValueError('No file was uploaded. Re-run this cell and choose a fasta file.')
if len(uploaded) > 1:
  print('WARNING: multiple files uploaded, only the first will be processed.')
fasta_filename = next(iter(uploaded))  # Get uploaded filename

from pathlib import Path
# Name of the results directory: the uploaded file name without its extension.
JOB_DIR = Path(fasta_filename).stem

sequences_dict = read_fasta(fasta_filename)
if not sequences_dict:
  print(f'WARNING: no sequences found in {fasta_filename}.')
# Last expression of the cell, so the notebook displays the parsed records.
sequences_dict

In [None]:
#@title ##Run **ESMFold** structure prediction
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
import torch
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax
import gc

def parse_output(output):
  """Reduce a model output dict to per-residue / per-pair arrays.

  Reads the keys "aligned_confidence_probs", "plddt", "distogram_logits",
  "positions" and "atom37_atom_exists" from `output` and returns a dict with:

    "pae":         mean over the last axis of (probabilities * bin index 0..63),
                   multiplied by 31
    "plddt":       plddt values at atom index 1
    "sm_contacts": softmax of the distogram logits summed over the bins whose
                   edge is below 8
    "xyz":         positions at atom index 1, taken from the last entry along
                   the first axis

  Every array is restricted to residues whose atom index 1 is flagged as
  existing (presumably the CA atom in atom37 ordering - TODO confirm).
  """
  present = output["atom37_atom_exists"][0,:,1] == 1

  def crop(square):
    # keep only rows and columns of residues that are present
    return square[present,:][:,present]

  bin_index = np.arange(64)
  pae_all = (output["aligned_confidence_probs"][0] * bin_index).mean(-1) * 31

  # 64 bin edges: 0 followed by 63 evenly spaced values
  edges = np.append(0,np.linspace(2.3125,21.6875,63))
  contact_all = softmax(output["distogram_logits"],-1)[0]
  contact_all = contact_all[...,edges<8].sum(-1)

  return {
    "pae": crop(pae_all),
    "plddt": output["plddt"][0,:,1][present],
    "sm_contacts": crop(contact_all),
    "xyz": output["positions"][-1,0,:,1][present],
  }

def get_hash(x):
  """Return the SHA-1 hex digest of the string `x` (encoded via str.encode())."""
  digest = hashlib.sha1()
  digest.update(x.encode())
  return digest.hexdigest()

# Letters A-Z followed by a-z; show_pdb pairs the first N of them with N
# chains when coloring by chain.
alphabet_list = list(ascii_uppercase+ascii_lowercase)
total_num = len(sequences_dict)

# Predict one structure per FASTA record. A failure on one record is reported
# and the loop moves on to the next record.
for num, (seq_id, sequence) in enumerate(sequences_dict.items(), start=1):
    print(f"Processing {seq_id}. Seq {num} of {total_num}...")
    num_recycles = 3
    # Sanitize job name and sequence: the job name keeps only word characters
    # (max 50); "/" is treated as a chain separator like ":"; everything that
    # is not an upper-case letter or ":" is removed from the sequence.
    jobname = re.sub(r'\W+', '', seq_id)[:50]
    sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())

    # Handle multiple copies (optional)
    copies = 1  # Adjust as needed
    sequence = ":".join([sequence] * copies)

    # Short sequence hash, appended to the output file names
    ID = jobname
    seq_hash = get_hash(sequence)[:5]

    # Total number of residues over all ":"-separated chains
    seqs = sequence.split(":")
    length = sum(len(s) for s in seqs)

    # Load model if not already loaded (or if a different model was selected)
    if "model" not in dir() or model_name != model_name_:
        if "model" in dir():
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # NOTE(review): weights_only=False unpickles arbitrary objects; this
        # is only acceptable if the parameter file comes from a trusted source.
        model = torch.load(model_name, weights_only=False)
        model.eval().cuda().requires_grad_(False)
        model_name_ = model_name

    # Use smaller chunks for longer inputs
    if length > 1400:
        chunk_size = 16
    elif length > 700:
        chunk_size = 64
    else:
        chunk_size = 128
    model.set_chunk_size(chunk_size)

    try:
        # Run inference
        torch.cuda.empty_cache()
        output = model.infer(sequence,
                             num_recycles=num_recycles,
                             chain_linker="X" * 25,
                             residue_index_offset=512)

        # Process output: PDB text first (needs the model's own output
        # object), then convert every entry to a numpy array.
        pdb_str = model.output_to_pdb(output)[0]
        output = tree_map(lambda x: x.cpu().numpy(), output)
        ptm = output["ptm"][0]
        plddt = output["plddt"][0, ..., 1].mean()
        O = parse_output(output)

        print(f'Finished {seq_id}: ptm={ptm:.3f}, plddt={plddt:.0f}')

        # Save results
        os.makedirs(JOB_DIR, exist_ok=True)
        prefix = f"{JOB_DIR}/{ID}_ptm{ptm:.3f}_plddt{plddt:.3f}_r{num_recycles}_{seq_hash}"
        np.savetxt(f"{prefix}.pae.txt", O["pae"], "%.3f")
        with open(f"{prefix}.pdb", "w") as out:
            out.write(pdb_str)
    except Exception as e:
        # Best effort: say which record failed and why, then continue.
        print(f"FAILED {seq_id}: {type(e).__name__}: {e}\n")
    else:
        # Only claim success when nothing above raised.
        print(f"Results saved for {seq_id}.\n")

In [None]:
#@title Interactive Structure Viewer
import ipywidgets as widgets
from IPython.display import display, clear_output

# Collect all generated PDBs in the current runtime
import glob
import py3Dmol
# Palette of 40 hex colors. show_pdb assigns them to chains in list order when
# color == "chain". (The name suggests they mirror PyMOL's colors - not
# verifiable from this file.)
pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]



def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
  """Build a py3Dmol view of a structure given as PDB text.

  Args:
    pdb_str: Contents of a PDB file.
    show_sidechains: Add stick/sphere styles for side-chain atoms.
    show_mainchains: Add sticks for the C, O, N and CA atoms.
    color: "pLDDT" (gradient over the B-factor column between vmin and
      vmax), "rainbow" (spectrum) or "chain" (one palette color per chain).
      Any other value leaves the default style untouched.
    chains: Number of chains to color in "chain" mode. When None it is
      len(Ls) if Ls is given, otherwise the number of distinct chain IDs
      found in the ATOM/HETATM records of pdb_str (at least 1).
    vmin, vmax: Range of the pLDDT color gradient.
    size: (width, height) of the viewer.
    hbondCutoff: Passed to py3Dmol as a model-loading option.
    Ls: Optional list with one entry per chain; only its length is used.
    animate: Load pdb_str as animation frames and start the animation.

  Returns:
    The py3Dmol view; call .show() on it to display the structure.
  """
  if chains is None:
    if Ls is not None:
      chains = len(Ls)
    else:
      # Without an explicit count, read the chain IDs (column 22 of the
      # ATOM/HETATM records) so that every chain gets styled in "chain" mode
      # rather than only the first one.
      chain_ids = {rec[21] for rec in pdb_str.splitlines()
                   if rec.startswith(("ATOM", "HETATM")) and len(rec) > 21}
      chains = max(1, len(chain_ids))
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])
  if animate:
    view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  else:
    view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  if color == "pLDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # Chains are selected by letter, A, B, ... in alphabet_list order.
    for chain, chain_color in zip(alphabet_list[:chains], pymol_color_list):
      view.setStyle({'chain':chain},{'cartoon': {'color':chain_color}})
  if show_sidechains:
    BB = ['C','O','N']
    # everything except backbone atoms, for residues other than GLY/PRO
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # GLY: a sphere on CA
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # PRO: all atoms except C and O
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  view.zoomTo()
  if animate: view.animate()
  return view
import os  # used below for basename; imported here so the cell does not rely on another cell's imports

# Find all generated PDB files (assuming naming pattern used earlier)
pdb_files = sorted(glob.glob(f"{JOB_DIR}/*.pdb"))

# Load all PDB contents, keyed by file name without the directory
pdb_dict = {}
for pdb_file in pdb_files:
    with open(pdb_file) as f:
        pdb_dict[os.path.basename(pdb_file)] = f.read()

# Create dropdown and display button
pdb_selector = widgets.Dropdown(
    options=list(pdb_dict.keys()),
    description='Structure:',
    layout=widgets.Layout(width='600px')
)

# NOTE: the button is wired up but not part of the displayed VBox below;
# the view is refreshed through the dropdown observer instead.
button = widgets.Button(description="Show Structure", button_style='success')

# Output widget to hold the viewer
output = widgets.Output()

# Callback to update the viewer
def on_button_click(b):
    """Render the structure currently selected in the dropdown.

    `b` is whatever the caller passes (button, change event or None) and is
    not used.
    """
    with output:
        clear_output(wait=True)
        if pdb_selector.value is None:
            # Empty dropdown: there is nothing to look up in pdb_dict.
            print(f"No PDB files found in {JOB_DIR}/ - run the prediction cell first.")
            return
        pdb_str = pdb_dict[pdb_selector.value]
        color = "chain" #@param ["confidence", "rainbow", "chain"]
        if color == "confidence": color = "pLDDT"
        show_sidechains = False #@param {type:"boolean"}
        show_mainchains = False #@param {type:"boolean"}
        show_pdb(pdb_str,
                 color=color,
                 show_sidechains=show_sidechains,
                 show_mainchains=show_mainchains,
                 Ls=None).show()

button.on_click(on_button_click)
# Observe only the selected value; without `names` the callback runs for
# every trait change of the dropdown and redraws more often than needed.
pdb_selector.observe(on_button_click, names="value")
# Display UI

display(widgets.VBox([pdb_selector, output]))
on_button_click(None)

In [None]:
#@title Download all predictions
from google.colab import files
import os
import zipfile
import glob
import os

# Name of the output zip file
zip_filename = f"{JOB_DIR}.zip"

# Get all .pdb files in the job directory (sorted for a stable archive order)
pdb_files = sorted(glob.glob(f"{JOB_DIR}/*.pdb"))

if not pdb_files:
    # Avoid building and downloading an empty archive.
    print(f"No .pdb files found in {JOB_DIR}/ - nothing to download.")
else:
    # Create a ZIP file and add each .pdb file, stored without the directory part
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for pdb_file in pdb_files:
            zipf.write(pdb_file, arcname=os.path.basename(pdb_file))
            print(f"Added: {pdb_file}")

    print(f"\nCreated ZIP archive: {zip_filename}")

    # Trigger download
    files.download(zip_filename)