# extra controls

In [1]:
%%time
#@title setup
# Colab form toggle: when True, the XLA memory environment variables below are set.
unified_memory = True #@param {type:"boolean"}
import os, time, gc
if unified_memory:
  # These are written to os.environ before colabdesign / jax are imported later
  # in this cell.
  # NOTE(review): presumably this lets XLA oversubscribe GPU memory into host RAM
  # (fraction 4.0 = 4x device memory) — confirm against the JAX GPU memory docs.
  ENV = {"TF_FORCE_UNIFIED_MEMORY":"1", "XLA_PYTHON_CLIENT_MEM_FRACTION":"4.0"}
  for k,v in ENV.items(): os.environ[k] = v

# The presence of the `params` directory is the only "already installed" marker:
# when it exists, every install/download step below is skipped.
# NOTE(review): the background job creates `params` as its first step, so a run
# interrupted before `params/done.txt` exists will skip the whole block (including
# the wait loop) on re-run — delete `params` to force a clean setup.
if not os.path.isdir("params"):
  # Start the AlphaFold parameter download in a background subshell (trailing `&`)
  # so the package installs below can proceed in parallel.
  # The steps are `;`-separated, so `touch params/done.txt` runs even if the
  # download or extraction step fails.
  print("installing ColabDesign")
  os.system("(mkdir params; apt-get install aria2 -qq; \
  aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \
  tar -xf alphafold_params_2022-12-06.tar -C params; touch params/done.txt )&")

  # Install ColabDesign from the `gamma_1` ref, symlink the installed package into
  # the working directory, and fetch ColabFold's colabfold.py as colabfold_utils.py
  # (imported further down for run_mmseqs2).
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@gamma_1")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  os.system("wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O colabfold_utils.py")

  # Download the HH-suite v3.3.0 SSE2 Linux release and unpack it into ./hhsuite
  print("installing HHsuite")
  os.makedirs("hhsuite", exist_ok=True)
  os.system(f"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/")

  # Block until the background job started above has touched params/done.txt,
  # polling every 5 seconds. There is no timeout.
  if not os.path.isfile("params/done.txt"):
    print("downloading AlphaFold params")
    while not os.path.isfile("params/done.txt"):
      time.sleep(5)
# Make the HH-suite binaries and scripts callable by name. The entries are relative
# paths, so they resolve only while the working directory contains ./hhsuite.
# The guard keeps PATH from growing when the cell is re-run.
if "hhsuite" not in os.environ['PATH']:
  os.environ['PATH'] += ":hhsuite/bin:hhsuite/scripts"

import re, tempfile
from IPython.display import HTML
from google.colab import files
import numpy as np
from colabdesign import mk_af_model, clear_mem
from colabdesign.af.contrib import predict
from colabdesign.af.contrib.cyclic import add_cyclic_offset
from colabdesign.shared.protein import _np_rmsd, _np_kabsch
from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap

import jax
import jax.numpy as jnp
from colabfold_utils import run_mmseqs2
import matplotlib.pyplot as plt
import string
import numpy as np

def clear_mem():
  """Delete every live JAX array/buffer on the default backend.

  This definition shadows the `clear_mem` imported from `colabdesign` earlier
  in this cell, so all later calls in the notebook use this version.

  The public `jax.live_arrays()` API is used when the installed JAX provides
  it. Older releases fall back to the backend client's `live_buffers()`,
  which is what this function relied on originally.

  Any array still referenced from Python is invalidated by this call.
  """
  if hasattr(jax, "live_arrays"):
    live = jax.live_arrays()
  else:
    # legacy path for JAX releases that predate `jax.live_arrays`
    live = jax.lib.xla_bridge.get_backend().live_buffers()
  for buf in live: buf.delete()

def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to the path of a local structure file.

  - None or "": prompt for a Colab upload, store the first uploaded file as
    "tmp.pdb" and return that name.
  - an existing file path: returned unchanged.
  - a 4-character string: fetched from files.rcsb.org as `tmp/<code>.cif`.
  - anything else: fetched from alphafold.ebi.ac.uk as
    `tmp/AF-<code>-F1-model_v4.pdb`.

  The expected download path is returned without checking that `wget`
  succeeded.
  """
  if pdb_code is None or pdb_code == "":
    uploaded = files.upload()
    first_key = list(uploaded.keys())[0]
    with open("tmp.pdb","wb") as out:
      out.write(uploaded[first_key])
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  # remote lookup: pick the source from the identifier length
  if len(pdb_code) == 4:
    url = f"https://files.rcsb.org/download/{pdb_code}.cif"
    local_path = f"tmp/{pdb_code}.cif"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v4.pdb"
    local_path = f"tmp/AF-{pdb_code}-F1-model_v4.pdb"

  os.makedirs("tmp",exist_ok=True)
  os.system(f"wget -qnc {url} -P tmp/")
  return local_path

def run_hhalign(query_sequence, target_sequence, query_a3m=None, target_a3m=None):
  """Align a query against a target with the `hhalign` executable.

  When `query_a3m` / `target_a3m` is None, the corresponding sequence is
  written to a temporary single-record file (header ">Q" / ">T") and that
  file is used as the hhalign input instead. The alignment is written to a
  third temporary file.

  Returns the two values produced by `predict.parse_hhalign_output` for that
  alignment file.
  """
  def _write_single(handle, header, sequence):
    # flush so the record is on disk before the external process reads it
    handle.write(f">{header}\n{sequence}\n".encode())
    handle.flush()
    return handle.name

  with tempfile.NamedTemporaryFile() as tmp_query, \
       tempfile.NamedTemporaryFile() as tmp_target, \
       tempfile.NamedTemporaryFile() as tmp_alignment:
    query_path = query_a3m
    if query_path is None:
      query_path = _write_single(tmp_query, "Q", query_sequence)

    target_path = target_a3m
    if target_path is None:
      target_path = _write_single(tmp_target, "T", target_sequence)

    os.system(f"hhalign -hide_cons -i {query_path} -t {target_path} -o {tmp_alignment.name}")
    # parse while the temporary alignment file still exists
    X, start_indices = predict.parse_hhalign_output(tmp_alignment.name)
  return X, start_indices

def run_do_not_align(query_sequence, target_sequence, **arg):
  """Return the two sequences unaligned, each with a start index of 0.

  Extra keyword arguments are accepted and ignored.
  """
  sequences = [query_sequence, target_sequence]
  start_indices = [0, 0]
  return sequences, start_indices

def run_hhfilter(input, output, id=90, qid=10):
  """Run the `hhfilter` executable on `input`, writing the result to `output`.

  `id` and `qid` are forwarded as hhfilter's `-id` and `-qid` options.
  The command's exit status is not checked and nothing is returned.
  """
  command = f"hhfilter -id {id} -qid {qid} -i {input} -o {output}"
  os.system(command)

@jax.jit
def get_coevolution(X):
  '''Score position-position coupling from an integer-encoded MSA.

  `X` is a rank-2 integer array (sequences x positions); it is one-hot
  encoded here with 22 classes. Returns a (positions x positions) matrix of
  APC-corrected coupling scores whose diagonal is zero.
  '''
  onehot = jax.nn.one_hot(X,22)
  n_seq, n_pos, n_states = onehot.shape

  # covariance between all (position, state) features
  features = onehot.reshape(n_seq,-1)
  cov = jnp.cov(features.T)

  # shrinkage-regularised inverse covariance
  ridge = 4.5/jnp.sqrt(n_seq) * jnp.eye(cov.shape[0])
  inv_cov = jnp.linalg.inv(cov + ridge)

  # partial correlation coefficient
  d = jnp.diag(inv_cov)
  pcc = inv_cov / jnp.sqrt(d[:,None] * d[None,:])

  # per position pair: root of summed squares over the first 20 states of each side
  blocks = pcc.reshape(n_pos,n_states,n_pos,n_states)[:,:20,:,:20]
  raw = jnp.sqrt(jnp.square(blocks).sum((1,3)))
  diag_idx = jnp.arange(n_pos)
  raw = raw.at[diag_idx,diag_idx].set(0)

  # average product correction (apc)
  col_totals = raw.sum(0,keepdims=True)
  row_totals = raw.sum(1,keepdims=True)
  apc = col_totals * row_totals / raw.sum()
  return (raw - apc).at[diag_idx,diag_idx].set(0)

def plot_3D(aux, Ls, file_name, show=False):
  """Draw a two-panel pseudo-3D view of a prediction and save it to `file_name`.

  Left panel: coloured per chain (title "chain") when `Ls` has more than one
  entry, otherwise drawn with plot_pseudo_3D's default colouring (title
  "length"). Right panel: coloured by `aux["plddt"]` on a 0.5-0.9 scale.

  `aux` must provide "atom_positions" and "plddt"; `Ls` lists the chain
  lengths. The figure is saved at 200 dpi, then shown if `show` is true and
  closed otherwise.
  """
  plt.figure(figsize=(10,5))

  # one atom per residue (atom index 1 — presumably C-alpha, TODO confirm),
  # re-oriented with the matrix _np_kabsch returns for the self-alignment
  coords = aux["atom_positions"][:,1]
  coords = coords @ _np_kabsch(coords, coords, return_v=True, use_jax=False)

  # left panel: chain identity or default colouring
  left_ax = plt.subplot(1,2,1)
  multi_chain = len(Ls) > 1
  plt.title("chain" if multi_chain else "length")
  if multi_chain:
    chain_index = np.concatenate([[idx]*length for idx,length in enumerate(Ls)])
    plot_pseudo_3D(xyz=coords, c=chain_index, cmap=pymol_cmap, cmin=0, cmax=39, Ls=Ls, ax=left_ax)
  else:
    plot_pseudo_3D(xyz=coords, Ls=Ls, ax=left_ax)
  plt.axis(False)

  # right panel: confidence colouring
  right_ax = plt.subplot(1,2,2)
  plt.title("plddt")
  plot_pseudo_3D(xyz=coords, c=aux["plddt"], cmin=0.5, cmax=0.9, Ls=Ls, ax=right_ax)
  plt.axis(False)

  plt.savefig(file_name, dpi=200, bbox_inches='tight')
  if show:
    plt.show()
  else:
    plt.close()

installing ColabDesign
installing HHsuite
CPU times: user 1.41 s, sys: 859 ms, total: 2.27 s
Wall time: 50.5 s


In [2]:
# Fetch the AF_Cluster repository (main branch, unpinned) as a zip archive and
# unpack it quietly, overwriting existing files. The loops below read MSAs from
# the resulting AF_Cluster-main/ directory.
!wget -qnc https://codeload.github.com/HWaymentSteele/AF_Cluster/zip/refs/heads/main -O AF_Cluster.zip
!unzip -qqo AF_Cluster.zip

In [3]:
# Results of the full-MSA sweep: mode name -> list (one per model) of per-seed mean pLDDT.
data = {}
# Query sequence; only its length is used below (af.prep_inputs).
# NOTE(review): presumably the RfaH sequence matching the RfaH a3m files loaded later — confirm.
seq = "MQSWYLLYCKRGQLQRAQEHLERQAVNCLAPMITLEKIVRGKRTAVSEPLFPNYLFVEFDPEVIHTTTINATRGVSHFVRFGASPAIVPSAVIHQLSVYKPKDIVDPATPYPGDKVIITEGAFEGFQAIFTEPDGEARSMLLLNLINKEIKHSVKNTEFRKA"
# Number of random seeds (0 .. N_SEEDS-1) predicted per model.
N_SEEDS=50

In [None]:
# Sweep a 2x2 grid of settings on the full ColabFold MSA:
#   a: "ptm" -> model_type "alphafold2_ptm", "noptm" -> "alphafold2"
#   b: "mask" -> use_mlm=True, "nomask" -> use_mlm=False
# For every setting, each of the 5 model indices is run with N_SEEDS seeds and
# the mean of af.aux["plddt"] is recorded in data["<a>_<b>"][model][seed].
# NOTE(review): the MSA path is absolute under /content, so this cell assumes
# the AF_Cluster archive was unpacked there (Colab's default working directory).
for a in ["ptm","noptm"]:
  for b in ["mask","nomask"]:
    # free device buffers left over from the previous model before building a new one
    clear_mem()
    af = mk_af_model("hallucination",
                    model_type = "alphafold2_ptm" if a == "ptm" else "alphafold2",
                    use_bfloat16=True,
                    optimize_seq=False,
                    num_msa=512,
                    num_extra_msa=5120,
                    use_cluster_profile=True,
                    use_mlm = b=="mask")
    af.prep_inputs(len(seq))
    af.set_msa(a3m_filename="/content/AF_Cluster-main/data_sep2022/04_OtherFoldswitchers/00_RfaH/RfaH_from_colabfold.a3m")
    mode = f"{a}_{b}"
    # reset this mode's results so re-running the cell does not append duplicates
    data[mode] = []
    for m in range(5):
      data[mode].append([])
      for seed in range(N_SEEDS):
        af.predict(num_recycles=3, models=[m], verbose=False, seed=seed)
        data[mode][-1].append(af.aux["plddt"].mean())
      # progress line: mean over seeds of the per-prediction mean pLDDT for this model
      print(mode,m,np.mean(data[mode][-1]))

In [None]:
import pickle
# Persist the `data` dict from the full-MSA sweep to data.pickle in the working directory.
with open('data.pickle', 'wb') as handle:
  pickle.dump(data, handle)

In [None]:
# Results of the per-cluster runs: "<cluster>_<a>_<b>" -> list (one per model) of per-seed mean pLDDT.
data_clusters = {}
# Cluster numbers to run; each selects the MSA file msas/RFAH_<k, zero-padded to 3 digits>.a3m.
CLUSTER_SET = [49] # add more cluster numbers here
# NOTE(review): N_SEEDS is already set to 50 in the earlier config cell; this re-assignment
# silently overrides it for every later cell, so keep the two values in sync.
N_SEEDS = 50

In [None]:
# -p makes this cell re-runnable: no error when output/ already exists
!mkdir -p output

In [None]:
# Same protocol as the full-MSA sweep, restricted to the "ptm" + "mask" setting
# and run once per cluster MSA listed in CLUSTER_SET. One model object is built
# per (a, b) setting and reused for every cluster; inputs and MSA are reset per cluster.
# Every prediction is also written to output/ as a PDB file.
# NOTE(review): the MSA path is absolute under /content (Colab's default working directory).
for a in ["ptm"]:
  for b in ["mask"]:
    # free device buffers left over from earlier models before building a new one
    clear_mem()
    af = mk_af_model("hallucination",
                      model_type = "alphafold2_ptm" if a == "ptm" else "alphafold2",
                      use_bfloat16=True,
                      optimize_seq=False,
                      num_msa=512,
                      num_extra_msa=5120,
                      use_cluster_profile=True,
                      use_mlm = b=="mask")
    for k in CLUSTER_SET:
      mode = f"{k}_{a}_{b}"
      # reset this mode's results so re-running the cell does not append duplicates
      data_clusters[mode] = []
      af.prep_inputs(len(seq))
      af.set_msa(a3m_filename=f"/content/AF_Cluster-main/data_sep2022/04_OtherFoldswitchers/00_RfaH/msas/RFAH_{k:03}.a3m")
      for m in range(5):
        data_clusters[mode].append([])
        for seed in range(N_SEEDS):
          af.predict(num_recycles=3, models=[m], verbose=False, seed=seed)
          # NOTE(review): the "_ptm" suffix is hard-coded; it matches the only value of
          # `a` used here but would mislabel files if "noptm" were added to the loop.
          af.save_pdb(f"output/{mode}.model_{m+1}_ptm.seed{seed:02}.pdb")
          data_clusters[mode][-1].append(af.aux["plddt"].mean())
        # progress line: mean over seeds of the per-prediction mean pLDDT for this model
        print(mode,m,np.mean(data_clusters[mode][-1]))

In [None]:
# NOTE(review): `pickle` is also imported in the earlier save cell; the repeat is
# harmless but keeps this cell runnable on its own.
import pickle
# Persist the `data_clusters` dict from the per-cluster runs to data_clusters.pickle.
with open('data_clusters.pickle', 'wb') as handle:
  pickle.dump(data_clusters, handle)