<a href="https://colab.research.google.com/github/SkywalkerLuke/DiffuSeq/blob/main/ESMFold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**ESMFold**
For more details, see: [GitHub](https://github.com/facebookresearch/esm/tree/main/esm), [Preprint](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1)

#### **Tips and Instructions**
- Click the little ▶ play icon to the left of each cell below.
- Use "/" to specify chain breaks (e.g. sequence="AAA/AAA"); ":" works as well.
- For homo-oligomeric predictions, set copies > 1.
- See the [experimental notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/ESMFold_advanced.ipynb) for more advanced options (like sampling).

#### **Colab Limitations**
- For short monomeric proteins (fewer than 400 residues), consider using the [ESMFold API](https://esmatlas.com/resources?action=fold) (no GPU needed, super fast!).
- On a Tesla T4 (the typical free Colab GPU), the maximum total length is ~900 residues.

In [1]:
%%time
#@title install
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)
version = "1" # @param ["0", "1"]
model_name = "esmfold_v0.model" if version == "0" else "esmfold.model"
import os, time
if not os.path.isfile(model_name):
  # download esmfold params
  os.system("apt-get install aria2 -qq")
  os.system(f"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &")

  if not os.path.isfile("finished_install"):
    # install libs
    print("installing libs...")
    os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    print("installing openfold...")
    # install openfold
    os.system(f"pip install -q git+https://github.com/sokrypton/openfold.git")

    print("installing esmfold...")
    # install esmfold
    os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  while not os.path.isfile(model_name):
    time.sleep(5)
  if os.path.isfile(f"{model_name}.aria2"):
    print("downloading params...")
  while os.path.isfile(f"{model_name}.aria2"):
    time.sleep(5)

installing libs...
installing openfold...
installing esmfold...
CPU times: user 8.74 ms, sys: 6.39 ms, total: 15.1 ms
Wall time: 3min 20s


In [None]:
#@title ##run **ESMFold**
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
import torch
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax
import gc

def parse_output(output):
  """Summarise a numpy-converted ESMFold output dict for plotting and saving.

  Args:
    output: dict of numpy arrays from ``model.infer`` after
      ``tree_map(lambda x: x.cpu().numpy(), ...)``; index 0 on the batch axis
      selects the single prediction.

  Returns:
    dict restricted to residues whose atom37 slot 1 exists (presumably CA —
    confirm against the atom37 ordering):
      "pae": (N, N) approximate PAE — probability-weighted bin index averaged
             over the 64 bins, times 31,
      "plddt": (N,) per-residue pLDDT (atom37 slot 1),
      "sm_contacts": (N, N) summed distogram probability over bins whose value
             in the bin table below is < 8,
      "xyz": (N, 3) coordinates of atom slot 1 from positions[-1] (presumably
             the final structure-module iteration).
  """
  exists = output["atom37_atom_exists"][0, :, 1] == 1
  pair = np.ix_(exists, exists)

  pae_full = (output["aligned_confidence_probs"][0] * np.arange(64)).mean(-1) * 31
  plddt_full = output["plddt"][0, :, 1]

  # distogram bin table: 0 followed by 63 evenly spaced values 2.3125 .. 21.6875
  bin_values = np.append(0, np.linspace(2.3125, 21.6875, 63))
  dist_probs = softmax(output["distogram_logits"], -1)[0]
  contact_prob = dist_probs[..., bin_values < 8].sum(-1)

  ca_xyz = output["positions"][-1, 0, :, 1]

  return {"pae": pae_full[pair],
          "plddt": plddt_full[exists],
          "sm_contacts": contact_prob[pair],
          "xyz": ca_xyz[exists]}

def get_hash(x):
  """Return the hex SHA-1 digest of the string ``x`` (encoded as UTF-8)."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()
# Chain labels used for ticks/colouring: "A".."Z" then "a".."z".
alphabet_list = list(ascii_uppercase+ascii_lowercase)

# Output folder name: word characters only, at most 50 of them.
jobname = "test" #@param {type:"string"}
jobname = re.sub(r'\W+', '', jobname)[:50]

# Normalise the input: "/" and ":" both mark chain breaks, anything other than
# A-Z and ":" is dropped, and repeated/leading/trailing breaks are removed.
sequence = "GWSTELEKHREELKEFLKKEGITNVEIRIDNGRLEVRVEGGTERLKRFLEELRQKLEKKGYTVDIKIE" #@param {type:"string"}
sequence = re.sub("[^A-Z:]", "", sequence.replace("/",":").upper())
sequence = re.sub(":+",":",sequence)
sequence = re.sub("^[:]+","",sequence)
sequence = re.sub("[:]+$","",sequence)
if not sequence:
  # fail here with a clear message instead of deep inside model.infer
  raise ValueError("sequence is empty after cleaning; enter at least one amino-acid letter")
copies = 1 #@param {type:"integer"}
if copies == "" or copies <= 0: copies = 1
sequence = ":".join([sequence] * copies)
num_recycles = 3 #@param ["0", "1", "2", "3", "6", "12", "24"] {type:"raw"}
# length of the poly-"X" linker string passed to model.infer (presumably placed between chains)
chain_linker = 25

ID = jobname+"_"+get_hash(sequence)[:5]
seqs = sequence.split(":")
lengths = [len(s) for s in seqs]
length = sum(lengths)
print("length",length)

# "mono": one chain; "homo": several copies of one chain; "hetero": distinct chains.
u_seqs = list(set(seqs))
if len(seqs) == 1: mode = "mono"
elif len(u_seqs) == 1: mode = "homo"
else: mode = "hetero"

if "model" not in dir() or model_name != model_name_:
  if "model" in dir():
    # delete old model from memory
    del model
    gc.collect()
    if torch.cuda.is_available():
      torch.cuda.empty_cache()

  model = torch.load(model_name, weights_only=False)
  model.eval().cuda().requires_grad_(False)
  model_name_ = model_name

# Chunk size for the model's chunked computation: the smaller value is used
# for long inputs (values tuned for a Tesla T4); presumably this lowers peak
# GPU memory at some cost in speed.
if length > 700:
  model.set_chunk_size(64)
else:
  model.set_chunk_size(128)

torch.cuda.empty_cache()
output = model.infer(sequence,
                     num_recycles=num_recycles,
                     chain_linker="X"*chain_linker,
                     residue_index_offset=512)

pdb_str = model.output_to_pdb(output)[0]
output = tree_map(lambda x: x.cpu().numpy(), output)
ptm = output["ptm"][0]
O = parse_output(output)
# Mean pLDDT over the residues parse_output keeps (atom37 existence mask),
# i.e. the same residues that are plotted and whose PAE is saved. The unmasked
# output["plddt"] also covers positions the mask drops (presumably the poly-X
# chain linker in multi-chain inputs — confirm against model.infer).
plddt = O["plddt"].mean()
print(f'ptm: {ptm:.3f} plddt: {plddt:.3f}')
os.makedirs(ID, exist_ok=True)
prefix = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_default"
np.savetxt(f"{prefix}.pae.txt",O["pae"],"%.3f")
with open(f"{prefix}.pdb","w") as out:
  out.write(pdb_str)

length 68
ptm: 0.817 plddt: 90.519
CPU times: user 12.2 s, sys: 8.07 s, total: 20.2 s
Wall time: 51.8 s


In [1]:
# Mount Google Drive at /content/drive; the batch cells below read and write
# under /content/drive/MyDrive/B-cell/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:


import os, re, glob, hashlib
import numpy as np
import pandas as pd
from tqdm import tqdm
from jax.tree_util import tree_map
from scipy.special import softmax
import torch

INPUT_DIR  = "/content/drive/MyDrive/B-cell/data"
OUTPUT_DIR = "/content/drive/MyDrive/B-cell/esmfold_outputs_min"
SEQ_COL = "Epitope.1"

os.makedirs(OUTPUT_DIR, exist_ok=True)

def parse_output(output):
  """Extract the masked PAE matrix and per-residue pLDDT from an ESMFold output.

  NOTE: this redefines the ``parse_output`` from the single-sequence cell and
  returns only "pae" and "plddt" (no "sm_contacts"/"xyz"). If that cell is
  re-run after this one, its ``O`` will lack the contact map.

  Args:
    output: dict of numpy arrays from ``model.infer`` after
      ``tree_map(lambda x: x.cpu().numpy(), ...)``; index 0 on the batch axis
      selects the single prediction.

  Returns:
    {"pae": (N, N), "plddt": (N,)} restricted to residues whose atom37 slot 1
    exists (presumably CA — confirm against the atom37 ordering).
  """
  # Approximate PAE: probability-weighted bin index averaged over the 64 bins,
  # times 31 (same formula as the single-sequence cell). The distogram softmax
  # and coordinates are not needed here, so they are not computed.
  pae = (output["aligned_confidence_probs"][0] * np.arange(64)).mean(-1) * 31
  plddt = output["plddt"][0, :, 1]
  mask = output["atom37_atom_exists"][0, :, 1] == 1
  return {"pae": pae[mask, :][:, mask],
          "plddt": plddt[mask]}

def clean(seq):
  """Normalise a raw sequence cell into ESMFold input.

  The value is stringified and upper-cased, and "/" is treated as a chain
  break (same as ":"). Every character other than A-Z and ":" is dropped,
  runs of ":" collapse to one, and leading/trailing ":" are stripped.
  """
  text = str(seq).replace("/", ":").upper()
  text = re.sub("[^A-Z:]", "", text)
  text = re.sub(":+", ":", text)
  return text.strip(":")

def short_hash(x):
  """First 8 hex characters of the SHA-1 of ``x`` (UTF-8); used in output file names."""
  full_digest = hashlib.sha1(x.encode()).hexdigest()
  return full_digest[:8]

# Requirement: a global ``model`` must already be loaded, the same way as in
# the single-sequence cell above, e.g.:
# model = torch.load(model_name, weights_only=False)
# model.eval().cuda().requires_grad_(False)

def run_one_sequence(sequence, out_prefix, num_recycles=3, chain_linker=25):
  """Fold one (possibly ":"-separated multi-chain) sequence and save its outputs.

  Uses the global ESMFold ``model``. It writes ``out_prefix + ".pae.txt"``
  first and ``out_prefix + ".pdb"`` last. The batch loop treats an existing
  ``.pdb`` as "already done", so the PDB must only appear once every output
  is on disk.

  Args:
    sequence: cleaned sequence; ":" separates chains.
    out_prefix: path prefix for the output files (its directory must exist).
    num_recycles: passed to ``model.infer``.
    chain_linker: length of the poly-"X" linker string passed to ``model.infer``.

  Returns:
    (ptm, plddt) as Python floats. plddt is the mean over the residues kept
    by ``parse_output``'s mask, matching the saved PAE matrix.
  """
  length = sum(len(s) for s in sequence.split(":"))
  # smaller chunk size for long inputs (thresholds tuned for a Tesla T4, as in
  # the single-sequence cell); presumably lowers peak GPU memory
  model.set_chunk_size(64 if length > 700 else 128)

  torch.cuda.empty_cache()
  output = model.infer(sequence,
                       num_recycles=num_recycles,
                       chain_linker="X"*chain_linker,
                       residue_index_offset=512)
  pdb_str = model.output_to_pdb(output)[0]
  output = tree_map(lambda x: x.cpu().numpy(), output)
  O = parse_output(output)
  ptm = float(output["ptm"][0])
  plddt = float(O["plddt"].mean())

  np.savetxt(out_prefix + ".pae.txt", O["pae"], "%.3f")
  # written last: its existence marks this sequence as complete for resuming
  with open(out_prefix + ".pdb", "w") as f:
    f.write(pdb_str)

  return ptm, plddt

# Batch-fold every unique sequence from the fold/test TSV files.
files = sorted(glob.glob(os.path.join(INPUT_DIR, "dataset_fold*.tsv")))
testf = os.path.join(INPUT_DIR, "dataset_test.tsv")
if os.path.exists(testf):
  files.append(testf)

for fp in files:
  name = os.path.splitext(os.path.basename(fp))[0]
  df = pd.read_csv(fp, sep="\t")
  # Missing cells arrive from pandas as float NaN, which clean() would turn
  # into the bogus sequence "NAN"; drop them, and skip anything cleaning empties.
  seqs = [clean(s) for s in df[SEQ_COL].dropna().tolist()]
  uniq = sorted({s for s in seqs if s})

  out_dir = os.path.join(OUTPUT_DIR, name)
  os.makedirs(out_dir, exist_ok=True)

  n_failed = 0
  for s in tqdm(uniq, desc=name):
    out_prefix = os.path.join(out_dir, f"{s}_{short_hash(s)}")
    if os.path.exists(out_prefix + ".pdb"):
      continue  # already folded on a previous run
    try:
      run_one_sequence(s, out_prefix, num_recycles=3, chain_linker=25)
    except RuntimeError as e:
      # e.g. CUDA out-of-memory on an unusually long input. Without this, a
      # sequence that always fails would abort every resumed run at the same point.
      n_failed += 1
      tqdm.write(f"[{name}] failed ({type(e).__name__}) on {s[:40]}: {e}")
  if n_failed:
    print(f"{name}: {n_failed} sequence(s) failed and were skipped")

dataset_fold1:  12%|█▏        | 2942/24888 [1:16:28<9:24:14,  1.54s/it]

In [None]:
import os, glob, time
import numpy as np
from tqdm.auto import tqdm

# ======================
# Configuration
# ======================
PDB_ROOT = "/content/drive/MyDrive/B-cell/esmfold_outputs_min"      # PDBs, searched recursively
GRAPH_ROOT = "/content/drive/MyDrive/B-cell/esmfold_contacts_ca8"   # outputs mirror PDB_ROOT's layout here

THR_A = 8.0              # CA-CA cutoff; a pair is a contact when distance < THR_A
CHAIN_ID = "A"           # set to None if unsure of the chain ID (reads CA atoms of every chain)
MIN_SEQ_SEP = None       # e.g. 3 drops edges with |i-j| < 3; None disables the filter
OVERWRITE = False        # recompute even when all output files already exist
SAVE_EDGE_INDEX = True   # also save a (2, E) edge_index array next to the adjacency

# ======================
# Helpers: read / compute / save
# ======================
def ensure_dir(p):
    """Create directory ``p`` and any missing parents; a no-op if it already exists."""
    os.makedirs(name=p, exist_ok=True)

def read_ca_from_pdb_path(pdb_path, chain_id=None):
    """Read C-alpha coordinates from a PDB file by scanning fixed-width columns.

    Only ATOM/HETATM records with atom name "CA" are considered. In addition:
      * calcium ions (residue name "CA", which also carry atom name "CA") are skipped;
      * each residue (chain, resSeq, iCode) contributes one CA, the first one seen,
        so alternate locations do not create duplicate nodes;
      * reading stops at the first ENDMDL record, so only the first model is used.

    Args:
        pdb_path: path to a PDB-format text file.
        chain_id: if not None, keep only atoms whose chain ID (column 22) equals it.

    Returns:
        (res_ids, xyz). res_ids is a list of (chain, resseq, icode) tuples in file
        order. xyz is a float32 array of shape (len(res_ids), 3), which is (0, 3)
        when no CA matched.
    """
    res_ids = []
    xyz = []
    seen = set()
    with open(pdb_path, "r") as f:
        for line in f:
            if line.startswith("ENDMDL"):
                break  # ignore any further models in multi-model files
            if not (line.startswith("ATOM") or line.startswith("HETATM")):
                continue
            if line[12:16].strip() != "CA":
                continue
            if line[17:20].strip() == "CA":
                continue  # calcium ion, not a C-alpha

            ch = line[21].strip()
            if chain_id is not None and ch != chain_id:
                continue

            resseq = int(line[22:26])
            icode = line[26].strip()
            key = (ch, resseq, icode)
            if key in seen:
                continue  # alternate location of a CA already recorded
            seen.add(key)

            x = float(line[30:38]); y = float(line[38:46]); z = float(line[46:54])
            res_ids.append(key)
            xyz.append((x, y, z))

    return res_ids, np.asarray(xyz, dtype=np.float32).reshape(-1, 3)

def pairwise_distances(X):
    """Dense Euclidean distance matrix between the rows of ``X``.

    Args:
        X: (N, D) array of coordinates.

    Returns:
        (N, N) array whose entry [i, j] is ||X[i] - X[j]||.
    """
    deltas = np.expand_dims(X, 1) - np.expand_dims(X, 0)
    squared = (deltas ** 2).sum(axis=-1)
    return np.sqrt(squared)

def make_contact_adj_from_ca(ca_xyz, thr_A=8.0, remove_self=True, min_seq_sep=None):
    """Binary CA-CA contact map: 1 where the distance is strictly below ``thr_A``.

    Args:
        ca_xyz: (N, 3) CA coordinates; row order is taken as residue order.
        thr_A: distance cutoff; pairs with distance < thr_A are contacts.
        remove_self: zero the diagonal (no self-loops).
        min_seq_sep: if a positive value, also zero every pair whose row-index
            distance |i - j| is below it, dropping sequence-local contacts.

    Returns:
        (N, N) uint8 adjacency matrix (this is the main output downstream).
    """
    adj = (pairwise_distances(ca_xyz) < thr_A).astype(np.uint8)

    if remove_self:
        np.fill_diagonal(adj, 0)

    if min_seq_sep is not None and min_seq_sep > 0:
        size = adj.shape[0]
        reach = min_seq_sep - 1
        for row in range(size):
            adj[row, max(0, row - reach):min(size, row + reach + 1)] = 0

    return adj

def adjacency_to_edge_index(A):
    """Convert a dense adjacency matrix into a COO edge list.

    Returns:
        int64 array of shape (2, E). Row 0 holds source indices and row 1
        destination indices, one column per positive entry of ``A``, in
        row-major order.
    """
    rows, cols = np.nonzero(A > 0)
    return np.vstack((rows, cols)).astype(np.int64)

def mirror_out_path(pdb_path, pdb_root, graph_root, suffix):
    """Map a PDB path under ``pdb_root`` to the same relative location under ``graph_root``.

    The file's extension is replaced by ``suffix`` and the output's parent
    directory is created if needed:

        pdb_root/**/xxx.pdb  ->  graph_root/**/xxx{suffix}
    """
    stem, _ext = os.path.splitext(os.path.relpath(pdb_path, pdb_root))
    target = os.path.join(graph_root, stem + suffix)
    ensure_dir(os.path.dirname(target))
    return target

# ======================
# Main pipeline (verbose progress reporting)
# ======================
def _threshold_tag(thr_A):
    """Filename tag for a distance cutoff, e.g. 8.0 -> "8", 7.5 -> "7.5".

    "{:.0f}" would round 7.5, 8.0 and 8.5 all to "8", making their output
    files collide (and be skipped as already done); "{:g}" keeps fractional
    cutoffs distinct while giving the same tag for integral ones.
    """
    return f"{thr_A:g}"


def _group_by_split(pdb_files, pdb_root):
    """Group PDB paths by their first path component relative to ``pdb_root``."""
    groups = {}
    for p in pdb_files:
        split = os.path.relpath(p, pdb_root).split(os.sep)[0]
        groups.setdefault(split, []).append(p)
    return groups


def _process_one_pdb(pdb_path, pdb_root, graph_root, thr_A, chain_id,
                     min_seq_sep, overwrite, save_edge_index):
    """Build and save the contact graph of one PDB; return "done", "skip" or "noCA"."""
    tag = _threshold_tag(thr_A)
    out_adj = mirror_out_path(pdb_path, pdb_root, graph_root, f".ca{tag}.adj.npy")
    out_meta = mirror_out_path(pdb_path, pdb_root, graph_root, f".ca{tag}.meta.npz")
    out_edge = mirror_out_path(pdb_path, pdb_root, graph_root, f".ca{tag}.edge_index.npy")

    already = os.path.exists(out_adj) and os.path.exists(out_meta) and (
        (not save_edge_index) or os.path.exists(out_edge)
    )
    if (not overwrite) and already:
        return "skip"

    res_ids, ca_xyz = read_ca_from_pdb_path(pdb_path, chain_id=chain_id)
    if ca_xyz.shape[0] == 0:
        return "noCA"

    A = make_contact_adj_from_ca(
        ca_xyz, thr_A=thr_A, remove_self=True, min_seq_sep=min_seq_sep
    )
    np.save(out_adj, A)

    if save_edge_index:
        np.save(out_edge, adjacency_to_edge_index(A))

    # Metadata keeps the source path for traceability (saved, never printed).
    # res_ids is an object array, so reading it back needs np.load(..., allow_pickle=True).
    np.savez(
        out_meta,
        pdb_path=pdb_path,
        thr_A=float(thr_A),
        chain_id="" if chain_id is None else chain_id,
        min_seq_sep=-1 if min_seq_sep is None else int(min_seq_sep),
        n_res=int(len(res_ids)),
        res_ids=np.array(res_ids, dtype=object),
    )
    return "done"


def run_all_verbose(
    pdb_root,
    graph_root,
    thr_A=8.0,
    chain_id="A",
    min_seq_sep=None,
    overwrite=False,
    save_edge_index=True,
    show_tracebacks=False,
):
    """Build CA contact graphs for every PDB under ``pdb_root``, with verbose progress.

    Outputs mirror the input tree under ``graph_root`` (see mirror_out_path):
    ``<name>.ca<thr>.adj.npy`` (uint8 adjacency), ``<name>.ca<thr>.meta.npz``
    and, if ``save_edge_index``, ``<name>.ca<thr>.edge_index.npy`` (int64,
    shape (2, E)). Failures on individual files are counted by exception type
    instead of aborting the run. File names are intentionally not printed.

    Args:
        pdb_root: directory searched recursively for ``*.pdb``.
        graph_root: output root (created if missing).
        thr_A: CA-CA cutoff; a pair is a contact when distance < thr_A.
        chain_id: keep only CA atoms of this chain; None reads every chain.
        min_seq_sep: if a positive int, drop edges with |i - j| < min_seq_sep.
        overwrite: recompute even when all outputs already exist.
        save_edge_index: also write the (2, E) edge_index array.
        show_tracebacks: print each failure's traceback (may reveal file names).

    Raises:
        FileNotFoundError: if ``pdb_root`` does not exist.
    """
    t_start = time.time()

    print("========== 0) 配置 ==========")
    print("PDB_ROOT  :", pdb_root)
    print("GRAPH_ROOT:", graph_root)
    print("thr_A     :", thr_A)
    print("chain_id  :", chain_id)
    print("min_seq_sep:", min_seq_sep)
    print("overwrite :", overwrite)
    print("save_edge_index:", save_edge_index)
    print()

    print("========== 1) 路径检查 ==========")
    print("PDB_ROOT exists?  ", os.path.exists(pdb_root))
    if not os.path.exists(pdb_root):
        raise FileNotFoundError(f"PDB_ROOT not found: {pdb_root}")
    ensure_dir(graph_root)
    print("GRAPH_ROOT exists?", os.path.exists(graph_root))
    print()

    print("========== 2) 枚举 split 文件夹（PDB_ROOT 下的一级子目录） ==========")
    split_dirs = sorted([d for d in glob.glob(os.path.join(pdb_root, "*")) if os.path.isdir(d)])
    print("split_dirs count:", len(split_dirs))
    if split_dirs:
        print("split dir names :", [os.path.basename(d) for d in split_dirs])
    print()

    # PDB count per split (one level only), then the recursive total that is
    # actually processed.
    print("========== 3) 统计文件数量 ==========")
    per_split_counts = [
        (os.path.basename(sd), len(glob.glob(os.path.join(sd, "*.pdb"))))
        for sd in split_dirs
    ]
    print("PDB count (one-level per split):")
    for name, c in per_split_counts:
        print(f"  - {name}: {c}")

    pdb_files = sorted(glob.glob(os.path.join(pdb_root, "**", "*.pdb"), recursive=True))
    print("PDB count (recursive under PDB_ROOT):", len(pdb_files))
    if len(pdb_files) == 0:
        print("没有找到任何 .pdb 文件：请检查目录层级/扩展名。")
        return
    print()

    print("========== 4) 开始处理（按 split 分组） ==========")
    groups = _group_by_split(pdb_files, pdb_root)
    print("Will process splits:")
    for k in sorted(groups.keys()):
        print(f"  - {k}: {len(groups[k])} pdb")
    print()

    counts = {"done": 0, "skip": 0, "noCA": 0, "fail": 0}
    fail_types = {}
    total = len(pdb_files)

    # context managers close the progress bars even if the run is interrupted
    with tqdm(total=total, desc="TOTAL", unit="pdb") as pbar_total:
        for split in sorted(groups.keys()):
            with tqdm(sorted(groups[split]), desc=f"SPLIT {split}", unit="pdb", leave=False) as pbar_split:
                for pdb_path in pbar_split:
                    try:
                        status = _process_one_pdb(
                            pdb_path, pdb_root, graph_root, thr_A, chain_id,
                            min_seq_sep, overwrite, save_edge_index,
                        )
                    except Exception as e:
                        # Best effort: one bad file must not stop the batch. Only
                        # the exception type is recorded (never the file name).
                        status = "fail"
                        err = type(e).__name__
                        fail_types[err] = fail_types.get(err, 0) + 1
                        if show_tracebacks:
                            import traceback
                            traceback.print_exc()
                    counts[status] += 1
                    pbar_total.update(1)
                    # running tallies on the bar (no file names)
                    pbar_total.set_postfix(dict(counts))

    print()
    print("========== 5) 汇总 ==========")
    print("total pdb:", total)
    print("done     :", counts["done"])
    print("skipped  :", counts["skip"])
    print("no CA    :", counts["noCA"])
    print("failed   :", counts["fail"])
    if fail_types:
        print("failed by type:", fail_types)
    print("elapsed(s):", round(time.time() - t_start, 2))
# Run the contact-graph extraction with the configuration above.
run_all_verbose(
    PDB_ROOT,
    GRAPH_ROOT,
    thr_A=THR_A,
    chain_id=CHAIN_ID,
    min_seq_sep=MIN_SEQ_SEP,
    overwrite=OVERWRITE,
    save_edge_index=SAVE_EDGE_INDEX,
)

In [None]:
#@title display (optional) {run: "auto"}
import py3Dmol
pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]

def show_pdb(pdb_str, show_sidechains=False, show_mainchains=False,
             color="pLDDT", chains=None, vmin=50, vmax=90,
             size=(800,480), hbondCutoff=4.0,
             Ls=None,
             animate=False):
  """Render a PDB string with py3Dmol and return the view.

  Args:
    pdb_str: PDB-format text.
    show_sidechains: overlay side-chain sticks (CA sphere for GLY).
    show_mainchains: overlay backbone (C, O, N, CA) sticks.
    color: "pLDDT" (B-factor column on the "roygb" gradient from vmin to vmax),
      "rainbow" (spectrum along the chain) or "chain" (one colour per chain).
    chains: number of chains to colour in "chain" mode; defaults to len(Ls), or 1.
    vmin, vmax: gradient range for "pLDDT" mode.
    size: (width, height) of the viewer in pixels.
    hbondCutoff: passed to 3Dmol when the model is added.
    Ls: per-chain lengths; only used to infer ``chains``.
    animate: load the input as animation frames and play them.

  Returns:
    the py3Dmol view (call ``.show()`` to display it).
  """
  if chains is None:
    chains = 1 if Ls is None else len(Ls)
  view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])
  if animate:
    view.addModelsAsFrames(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  else:
    view.addModel(pdb_str,'pdb',{'hbondCutoff':hbondCutoff})
  if color == "pLDDT":
    # assumes the PDB stores pLDDT in the B-factor column, as ESMFold's output does
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # cycle the palette so chains beyond its length are still coloured
    for i, chain_id in zip(range(chains), alphabet_list):
      chain_color = pymol_color_list[i % len(pymol_color_list)]
      view.setStyle({'chain':chain_id},{'cartoon': {'color':chain_color}})
  if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  view.zoomTo()
  if animate: view.animate()
  return view

color = "confidence" #@param ["confidence", "rainbow", "chain"]
if color == "confidence": color = "pLDDT"
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
show_pdb(pdb_str, color=color,
         show_sidechains=show_sidechains,
         show_mainchains=show_mainchains,
         Ls=lengths).show()

In [None]:
#@title plot confidence (optional)

# Resolution (dots per inch) of the confidence figure, shown inline and saved as {prefix}.png.
dpi = 100 #@param {type:"integer"}

def plot_ticks(Ls):
  """Draw chain-boundary lines on the current square heatmap and label chains.

  Args:
    Ls: list of per-chain lengths. Black lines are drawn at every internal
      chain boundary, and the y-ticks sit at each chain's midpoint, labelled
      A, B, C, ...
  """
  total = sum(Ls)
  edges = np.cumsum([0] + Ls)
  for boundary in edges[1:-1]:
    plt.plot([0, total], [boundary, boundary], color="black")
    plt.plot([boundary, boundary], [0, total], color="black")
  centers = (edges[1:] + edges[:-1]) / 2
  plt.yticks(centers, alphabet_list[:len(centers)])

def plot_confidence(O, Ls=None, dpi=100):
  """Plot per-residue pLDDT, the PAE matrix and any available contact maps.

  Panels, from left to right: the pLDDT line plot, the PAE heatmap, then
  "lm_contacts" and "sm_contacts" heatmaps for whichever of those keys ``O``
  contains. Each panel is 5 x 4 inches. A missing contact key drops its panel
  instead of raising KeyError; e.g. an ``O`` built by the batch cell's
  parse_output has no "sm_contacts".

  Args:
    O: dict with "plddt" (N,) and "pae" (N, N) arrays, and optionally
      "lm_contacts" / "sm_contacts" (N, N) maps (plotted on a 0-1 scale).
    Ls: per-chain lengths; when given, chain boundaries are drawn.
    dpi: figure resolution.

  Returns:
    the ``matplotlib.pyplot`` module, so callers can ``savefig``/``show``.
  """
  # optional contact-map panels, in display order, for the keys O provides
  contact_panels = [(key, title) for key, title in (("lm_contacts", "contacts from LM"),
                                                     ("sm_contacts", "contacts from Structure Module"))
                    if key in O]
  n_panels = 2 + len(contact_panels)
  plt.figure(figsize=(5 * n_panels, 4), dpi=dpi)
  multi_chain = Ls is not None and len(Ls) > 1

  plt.subplot(1, n_panels, 1)
  plt.title('Predicted lDDT')
  plt.plot(O["plddt"])
  if Ls is not None:
    L_prev = 0
    for L_i in Ls[:-1]:
      L_prev += L_i
      plt.plot([L_prev, L_prev], [0, 100], color="black")
  plt.xlim(0, O["plddt"].shape[0])
  plt.ylim(0, 100)
  plt.ylabel('plDDT')
  plt.xlabel('position')

  plt.subplot(1, n_panels, 2)
  plt.title('Predicted Aligned Error')
  Ln = O["pae"].shape[0]
  plt.imshow(O["pae"], cmap="bwr", vmin=0, vmax=30, extent=(0, Ln, Ln, 0))
  if multi_chain: plot_ticks(Ls)
  plt.colorbar()
  plt.xlabel('Scored residue')
  plt.ylabel('Aligned residue')

  for panel, (key, title) in enumerate(contact_panels, start=3):
    plt.subplot(1, n_panels, panel)
    plt.title(title)
    plt.imshow(O[key], cmap="Greys", vmin=0, vmax=1, extent=(0, Ln, Ln, 0))
    if multi_chain: plot_ticks(Ls)
  return plt

# plot_confidence returns the pyplot module, so savefig targets the figure it
# just drew. Save before show(), which may release the figure on some backends.
plot_confidence(O, Ls=lengths, dpi=dpi).savefig(f'{prefix}.png', bbox_inches='tight')
plt.show()

In [None]:
#@title download predictions
from google.colab import files
os.system(f"zip {ID}.zip {ID}/*")
files.download(f'{ID}.zip')