# **ESMFold_advanced batch version**

This notebook is a batch version of a ColabFold notebook and is intended to be run on Google Colaboratory. - Xiaozhe Ding (dingxiaozhe@gmail.com)

The original notebook by Sergey Ovchinnikov can be found [here](https://github.com/sokrypton/ColabFold/blob/main/ESMFold.ipynb)

For more details about ESMFold, see: [GitHub](https://github.com/facebookresearch/esm/tree/main/esm), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1)

#### **Tips and Instructions**
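- Click the little ▶ play icon to the left of each cell below (or use `Runtime` -> `Run all`).
- Set `input_dir` to a Google Drive folder containing one `.fasta` or `.faa` file per query; only the first record in each file is used.
- Separate the chains of a complex with `:`. They are joined by a poly-glycine linker of `glycine_linker_length` residues and modelled as a single chain.
- One PDB file per prediction is written to `result_dir`.
- The last cell releases the Colab runtime once the batch has finished.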

#### **Colab Limitations**
- On a Tesla T4 (the typical free Colab GPU), the maximum total length is ~900 residues, including any glycine linkers.



In [None]:
#@title ##Mount google drive (run once only)
# Expose Google Drive under /content/drive; the input and result folders used by
# the following cells live inside it.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%time
#@title ##Installation
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)





#Installation
# The whole cell is skipped once esmfold.model exists (e.g. when re-run in the same
# runtime). NOTE(review): os.system return codes are not checked, so a failed
# apt/pip/aria2c step is silent here and only surfaces in later cells.
import os, time
if not os.path.isfile("esmfold.model"):
  # download esmfold params
  os.system("apt-get install aria2 -qq")
  # Started in the background ("&") so the download overlaps with the pip installs below.
  os.system("aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &")

  # install libs
  os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol")
  os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

  # install openfold (pinned to a fixed commit)
  commit = "6908936b68ae89f67755240e2f588c09ec31d4c8"
  os.system(f"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}")

  # install esmfold (sokrypton's fork of esm, "beta" branch)
  os.system(f"pip install -q git+https://github.com/sokrypton/esm.git@beta")

  # wait for Params to finish downloading...
  if not os.path.isfile("esmfold.model"):
    # backup source!
    # The primary download never created the file: fetch it from the mirror (foreground).
    os.system("aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model")
  else:
    # aria2c keeps an "<name>.aria2" control file while a download is in progress;
    # poll every 5 s until it disappears. NOTE(review): there is no timeout, so this
    # loops forever if aria2c dies and leaves the control file behind.
    while os.path.isfile("esmfold.model.aria2"):
      time.sleep(5)

CPU times: user 546 ms, sys: 187 ms, total: 733 ms
Wall time: 2min 39s


In [None]:
#@title ##Load utility functions from ColabFold

#@markdown utility functions from colabfold

import json
import logging
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

from absl import logging as absl_logging
from importlib_metadata import distribution
from tqdm import TqdmExperimentalWarning



# parse_fasta from colabfold.batch
def parse_fasta(fasta_string: str) -> Tuple[List[str], List[str]]:
    """Split FASTA text into record sequences and their header descriptions.

    Arguments:
      fasta_string: The string contents of a FASTA file.
    Returns:
      A tuple ``(sequences, descriptions)`` of equal length and matching order:
      * sequences: the residues of each record, with line breaks removed.
      * descriptions: each record's header line without its leading '>'.
    Every line is stripped of surrounding whitespace; blank lines and lines
    starting with '#' are ignored. A sequence line that appears before the
    first header raises IndexError.
    """
    sequences: List[str] = []
    descriptions: List[str] = []
    for raw_line in fasta_string.splitlines():
        text = raw_line.strip()
        if not text or text.startswith("#"):
            continue  # blank line or comment
        if text.startswith(">"):
            descriptions.append(text[1:])
            sequences.append("")
        else:
            # Continuation line of the most recent record.
            sequences[-1] += text
    return sequences, descriptions

# get_queries adapted from colabfold.batch
def get_queries(
    input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, Union[str, List[str]], Optional[List[str]]]], bool]:
    """Read prediction queries from a file or a directory of files.

    Arguments:
      input_path: one of
        * a ``.csv``/``.tsv`` file with ``id`` and ``sequence`` columns,
        * an ``.a3m`` file (its first sequence is the query),
        * a ``.fasta``/``.faa``/``.fa`` file (one query per record),
        * a directory of ``.a3m``/``.fasta``/``.faa``/``.fa`` files (only the
          first record of each file is used; other files are skipped with a warning).
        Suffixes are matched case-insensitively.
      sort_queries_by: ``"length"`` sorts by total residue count (all chains),
        ``"random"`` shuffles, and any other value keeps the parsed order.
    Returns:
      ``(queries, is_complex)``. Each query is ``(job_name, sequence, a3m_lines)``.
      ``sequence`` is an upper-case str for a single chain, or a list of str when
      the input contained ``:``-separated chains. ``a3m_lines`` is None unless the
      query came from an ``.a3m`` file. ``is_complex`` is True if any query is multi-chain.
    Raises:
      OSError: if ``input_path`` does not exist.
      ValueError: for an empty ``.a3m`` file or an unknown single-file suffix.
    """
    # Not imported at notebook level; only needed for sort_queries_by="random".
    import random

    # The batch cell defines a module-level `logger` only later; fetch the same
    # named logger here so warnings about skipped files cannot raise NameError.
    log = logging.getLogger(__name__)
    fasta_suffixes = (".fasta", ".faa", ".fa")

    def _to_query_sequence(sequence: str) -> Union[str, List[str]]:
        # Upper-case and split ":"-separated chains; a single chain stays a plain str.
        # NOTE(review): lower-case letters are not standard residue codes and are
        # assumed to be treated as unknown residues by the model.
        chains = sequence.upper().split(":")
        return chains[0] if len(chains) == 1 else chains

    input_path = Path(input_path)
    if not input_path.exists():
        raise OSError(f"{input_path} could not be found")

    queries = []
    if input_path.is_file():
        suffix = input_path.suffix.lower()
        if suffix in (".csv", ".tsv"):
            import pandas  # only needed for tabular input; never imported globally

            sep = "\t" if suffix == ".tsv" else ","
            df = pandas.read_csv(input_path, sep=sep)
            assert "id" in df.columns and "sequence" in df.columns, (
                f"{input_path} needs 'id' and 'sequence' columns"
            )
            queries = [
                (seq_id, _to_query_sequence(sequence), None)
                for seq_id, sequence in df[["id", "sequence"]].itertuples(index=False)
            ]
        elif suffix == ".a3m":
            text = input_path.read_text()
            (seqs, _) = parse_fasta(text)
            if len(seqs) == 0:
                raise ValueError(f"{input_path} is empty")
            # Use a list so we can easily extend this to multiple msas later
            queries = [(input_path.stem, seqs[0].upper(), [text])]
        elif suffix in fasta_suffixes:
            (sequences, headers) = parse_fasta(input_path.read_text())
            queries = [
                (header, _to_query_sequence(sequence), None)
                for sequence, header in zip(sequences, headers)
            ]
        else:
            raise ValueError(f"Unknown file format {input_path.suffix}")
    else:
        assert input_path.is_dir(), "Expected either an input file or a input directory"
        for file in sorted(input_path.iterdir()):
            if not file.is_file():
                continue
            suffix = file.suffix.lower()
            if suffix != ".a3m" and suffix not in fasta_suffixes:
                log.warning(f"non-fasta/a3m file in input directory: {file}")
                continue
            print("Parsing fasta file {}".format(file))
            text = file.read_text()
            (seqs, _) = parse_fasta(text)
            if len(seqs) == 0:
                log.error(f"{file} is empty")
                continue
            query_sequence = seqs[0]
            if suffix == ".a3m":
                queries.append((file.stem, query_sequence.upper(), [text]))
            else:
                if len(seqs) > 1:
                    log.warning(
                        f"More than one sequence in {file}, ignoring all but the first sequence"
                    )
                queries.append((file.stem, _to_query_sequence(query_sequence), None))

    if sort_queries_by == "length":
        # Total residue count: "".join works for a single str and for a list of chains
        # (len(list) alone would count chains, not residues).
        queries.sort(key=lambda query: len("".join(query[1])))
    elif sort_queries_by == "random":
        random.shuffle(queries)

    is_complex = False
    for _, query_sequence, a3m_lines in queries:
        if isinstance(query_sequence, list):
            is_complex = True
            break
        if a3m_lines is not None and a3m_lines[0].startswith("#"):
            # a3m header of the form "#<len1>,<len2>,...\t<card1>,<card2>,...".
            tab_sep_entries = a3m_lines[0].splitlines()[0][1:].split("\t")
            if len(tab_sep_entries) == 2:
                query_seq_len = [int(x) for x in tab_sep_entries[0].split(",")]
                query_seqs_cardinality = [int(x) for x in tab_sep_entries[1].split(",")]
                is_single_protein = (
                    len(query_seq_len) == 1 and query_seqs_cardinality[0] == 1
                )
                if not is_single_protein:
                    is_complex = True
                    break
    return queries, is_complex

# TqdmHandler from colabfold.utils (needed for setup_logging)
class TqdmHandler(logging.StreamHandler):
    """Stream handler that emits each formatted record through ``tqdm.write``.

    Routing output through tqdm.write lets log lines coexist with live tqdm
    progress bars (https://stackoverflow.com/a/38895482/3549270).
    """

    def __init__(self):
        super().__init__()

    def emit(self, record):
        # Imported here so the native (console) tqdm is used rather than a notebook variant.
        from tqdm import tqdm as native_tqdm

        native_tqdm.write(self.format(record))

# setup_logging from colabfold.utils
def setup_logging(log_file: Path):
    """Send root-logger INFO+ records to a tqdm-aware console handler and to ``log_file``.

    Any handlers already attached to the root logger are removed first, because
    ``logging.basicConfig`` does nothing while the root logger has handlers.

    Arguments:
      log_file: path of the log file (``str`` is accepted too); missing parent
        directories are created.
    """
    log_file = Path(log_file)
    log_file.parent.mkdir(exist_ok=True, parents=True)
    root = logging.getLogger()
    # Iterate over a copy: removing from root.handlers while iterating over it skips
    # every other handler, and any survivor turns basicConfig below into a no-op.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(message)s",
        handlers=[TqdmHandler(), logging.FileHandler(log_file)],
    )
    # otherwise jax will tell us about its search for devices
    absl_logging.set_verbosity("error")
    warnings.simplefilter(action="ignore", category=TqdmExperimentalWarning)

#from colabfold.utils
def safe_filename(file: str) -> str:
    """Return ``file`` with every character other than an alphanumeric, '_', '.'
    or '-' replaced by '_'."""
    allowed_punctuation = {"_", ".", "-"}
    return "".join(
        ch if (ch.isalnum() or ch in allowed_punctuation) else "_" for ch in file
    )

# default_data_dir from colabfold.
# Per-user cache directory from appdirs, named after __package__ or, when that is
# unset (e.g. in a notebook's __main__), "colabfold". The load cell copies it into data_dir.
import appdirs
default_data_dir = Path(appdirs.user_cache_dir(__package__ or "colabfold"))


In [None]:
#@title ##Load input fastas
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax

def parse_output(output):
  """Extract confidence maps and coordinates from one ESMFold prediction.

  ``output`` is the ``model.infer`` result after every tensor has been turned into
  a numpy array (the batch cell applies ``tree_map(... .cpu().numpy())`` first).
  Batch index 0 is used throughout.

  Returns a dict with keys:
    "pae"         -- pairwise matrix derived from ``aligned_confidence_probs``
    "plddt"       -- ``output["plddt"][0, :, 1]``
    "sm_contacts" -- pairwise probability mass in distogram bins whose lower edge < 8
    "xyz"         -- ``output["positions"][-1, 0, :, 1]``
    "lm_contacts" -- only when ``output["lm_output"]`` contains "contacts"
  Every entry is restricted (rows and columns for pairwise maps) to positions where
  ``atom37_atom_exists[0, :, 1] == 1``.
  """
  # Mean over the last axis of probs * bin index (0..63), i.e. sum(p * i) / 64, then * 31.
  # NOTE(review): presumably rescales the expected bin index to Angstrom -- confirm.
  pae = (output["aligned_confidence_probs"][0] * np.arange(64)).mean(-1) * 31
  # NOTE(review): index 1 of the last axis is assumed to be the CA slot of atom37 -- confirm.
  plddt = output["plddt"][0,:,1]
  
  # Lower edges of the 64 distogram bins: 0 followed by 63 evenly spaced values 2.3125..21.6875.
  bins = np.append(0,np.linspace(2.3125,21.6875,63))
  sm_contacts = softmax(output["distogram_logits"],-1)[0]
  # Contact probability = total probability of bins starting below 8 (presumably Angstrom).
  sm_contacts = sm_contacts[...,bins<8].sum(-1)
  # Last entry along axis 0 (assumed: final structure-module iteration), batch 0, atom slot 1.
  xyz = output["positions"][-1,0,:,1]
  mask = output["atom37_atom_exists"][0,:,1] == 1
  o = {"pae":pae[mask,:][:,mask],
       "plddt":plddt[mask],
       "sm_contacts":sm_contacts[mask,:][:,mask],
       "xyz":xyz[mask]}
  # Language-model contacts are present only when inference ran with return_contacts.
  if "contacts" in output["lm_output"]:
    lm_contacts = output["lm_output"]["contacts"].astype(float)[0]
    o["lm_contacts"] = lm_contacts[mask,:][:,mask]
  return o

def get_hash(x):
  """Return the SHA-1 hex digest of string ``x`` encoded as UTF-8."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

# One-letter labels: 'A'..'Z' followed by 'a'..'z' (52 entries).
alphabet_list = [*ascii_uppercase, *ascii_lowercase]


#################
### Load data ###
#################

# Set input_dir / result_dir below, then hit `Runtime` -> `Run all`.

input_dir = '/content/drive/MyDrive/input_fasta' #@param {type:"string"}
result_dir = '/content/drive/MyDrive/AF2/results' #@param {type:"string"}

# Load queries: one (jobname, sequence or list of chains, a3m_lines) tuple per input.
queries, is_complex = get_queries(input_dir)

# from batch.py
data_dir = Path(default_data_dir)
result_dir = Path(result_dir)
# parents=True: the default result_dir is nested (MyDrive/AF2/results) and the
# intermediate folder may not exist yet, which made mkdir raise FileNotFoundError.
result_dir.mkdir(parents=True, exist_ok=True)

#@markdown ---
#@markdown ###**Advanced Options**
num_recycles = 3 #@param ["0", "1", "2", "3", "6", "12"] {type:"raw"}
# Passed as return_contacts= to model.infer in the batch cell.
get_LM_contacts = False #@param {type:"boolean"}
# Number of glycines inserted between the ":"-separated chains of a complex.
glycine_linker_length = 30 #@param {type:"integer"}

#@markdown **sampling options (experimental)**
#@markdown - Samples are generated via random masking (defined by `masking_rate`) 
#@markdown of input sequence (stochastic_mode="LM") and/or via dropout within structure module (stochastic_mode="SM").
# These were previously commented out while still being read below (num_samples) and
# in the batch cell, which raised NameError on a fresh runtime.
# samples=None means a single run with masking and dropout disabled.
samples = None #@param ["None", "1", "4", "8", "16", "32", "64"] {type:"raw"}
masking_rate = 0.15 #@param {type:"number"}
stochastic_mode = "LM" #@param ["LM", "LM_SM", "SM"]

# Load the ESMFold model once per runtime (skipped when this cell is re-run).
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which cannot
# unpickle a whole model object; pass weights_only=False there if loading fails.
import torch
if "model" not in dir():
  model = torch.load("esmfold.model")
  model.cuda().requires_grad_(False)

# "Best" model bookkeeping; the batch loop compares each sample's ptm to best_ptm.
best_pdb_str = None
best_ptm = 0
best_output = None
traj = []

num_samples = 1 if samples is None else samples


In [None]:
#@title ##Run **ESMFold** in batch

###########
## Batch ##
###########
# For every query: join chains with a glycine linker, run ESMFold num_samples times
# and write one PDB per sample to result_dir. Uses queries, model, result_dir and the
# option variables defined in the "Load input fastas" cell.

logger = logging.getLogger(__name__)

# When True, queries that already have "<jobname>.done.txt" (or "<jobname>.result.zip")
# in result_dir are skipped -- useful for resuming an interrupted batch.
keep_existing_results = False #@param {type:"boolean"}

if 'logging_setup' not in globals():
    setup_logging(Path(result_dir).joinpath("log.txt"))
    logging_setup = True

# Built once; int() guards against a float from a numeric form field, which would
# make the string repetition raise TypeError.
glycine_linker_seq = 'G' * int(glycine_linker_length)

for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries):
    jobname = safe_filename(raw_jobname)
    # Markers from a previous run that already finished this query. Plain string
    # concatenation is used because with_suffix() would cut jobnames containing '.'.
    result_zip = result_dir.joinpath(jobname + ".result.zip")
    if keep_existing_results and result_zip.is_file():
        logger.info(f"Skipping {jobname} (result.zip)")
        continue
    is_done_marker = result_dir.joinpath(jobname + ".done.txt")
    if keep_existing_results and is_done_marker.is_file():
        logger.info(f"Skipping {jobname} (already done)")
        continue

    query_sequence_len = (
        len(query_sequence)
        if isinstance(query_sequence, str)
        else sum(len(s) for s in query_sequence)
    )
    logger.info(
        f"Query {job_number + 1}/{len(queries)}: {jobname} (length {query_sequence_len})"
    )

    # Reset per query so "best" and the trajectory describe this query only; they
    # previously carried over (and traj kept growing) across the whole batch.
    best_pdb_str = None
    best_ptm = 0
    best_output = None
    traj = []

    # A complex arrives as a list of chains; model it as one chain joined by poly-G.
    sequence = (
        query_sequence
        if isinstance(query_sequence, str)
        else glycine_linker_seq.join(query_sequence)
    )

    print('> Sequence to model: ')
    print(sequence)

    # Trunk chunk size by total length (thresholds from the original Tesla T4 settings).
    model.trunk.set_chunk_size(64 if len(sequence) > 700 else 128)

    n_written = 0
    for seed in range(num_samples):
        torch.cuda.empty_cache()
        if samples is None:
            seed = "default"
            mask_rate = 0.0
            model.train(False)
        else:
            torch.manual_seed(seed)
            mask_rate = masking_rate if "LM" in stochastic_mode else 0.0
            model.train("SM" in stochastic_mode)

        try:
            # chain_linker is deliberately not passed: chains are already joined with glycines.
            output = model.infer(sequence,
                                 num_recycles=num_recycles,
                                 residue_index_offset=512,
                                 mask_rate=mask_rate,
                                 return_contacts=get_LM_contacts)
        except RuntimeError as e:
            # Typically CUDA out-of-memory on long inputs. Every sample of this query has
            # the same length, so give up on the query instead of aborting the batch.
            logger.error(f"Could not predict {jobname}. Not enough GPU memory? {e}")
            break

        pdb_str = model.output_to_pdb(output)[0]
        output = tree_map(lambda x: x.cpu().numpy(), output)
        ptm = output["ptm"][0]
        plddt = output["plddt"][0,:,1].mean()
        traj.append(parse_output(output))
        print(f'{seed} ptm: {ptm:.3f} plddt: {plddt:.1f}')
        if ptm > best_ptm:
            best_pdb_str = pdb_str
            best_ptm = ptm
            best_output = output
        if samples is None:
            pdb_filename = result_dir.joinpath(f"{jobname}_unrelaxed_ptm{ptm:.3f}_r{num_recycles}_seed{seed}.pdb")
        else:
            pdb_filename = result_dir.joinpath(f"{jobname}_unrelaxed_ptm{ptm:.3f}_r{num_recycles}_seed{seed}_{stochastic_mode}_m{masking_rate:.2f}.pdb")

        with open(pdb_filename,"w") as out:
            out.write(pdb_str)
        n_written += 1

    # Mark the query as done only when at least one structure was written, so a
    # failed query is retried on the next run with keep_existing_results=True.
    if n_written:
        is_done_marker.touch()

logger.info("Done")


In [None]:
#@title ##Terminate runtime
# Results are written under the Drive mount; flush_and_unmount() flushes any
# outstanding writes to Drive before the VM is released, so the last PDBs are not
# lost. try/finally keeps the original guarantee that the runtime is always released.
from google.colab import drive, runtime
try:
  drive.flush_and_unmount()
finally:
  runtime.unassign()