In [1]:
# Re-import edited modules before each cell runs, so changes to local
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make modules in the parent directory importable from this notebook.
sys.path += ['..']

In [3]:
import os

# NOTE: this cell runs before `import jax` in the next cell; the variable is
# presumably read when the GPU backend initialises (letting allocations spill
# into host memory) — confirm against the JAX/XLA docs.
os.environ.update(TF_FORCE_UNIFIED_MEMORY="1")

In [4]:
import hashlib
import pickle
import requests
import tarfile
import time
import tqdm.notebook

from string import ascii_uppercase

import jax
import numpy as np

from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.model import config
from alphafold.model import data
from alphafold.model import model

# Progress-bar layout for the model loop: label, bar, step counts and timing.
TQDM_BAR_FORMAT = (
    '{l_bar}{bar}| {n_fmt}/{total_fmt} '
    '[elapsed: {elapsed} remaining: {remaining}]'
)

In [5]:
# Report the backend of the first local JAX device (e.g. "gpu" when the
# accelerator is visible).
first_device = jax.local_devices()[0]
print(first_device.platform)

gpu


In [6]:
def run_mmseqs2(query_sequence, prefix, use_env=True, filter=False):
    """Fetch an MSA for ``query_sequence`` from the public MMseqs2 server, with on-disk caching.

    The server result is cached as ``{prefix}_{mode}/out.tar.gz`` and the merged
    alignment as ``{prefix}_{mode}.a3m``; when those files already exist no
    request is made. Both cache files are written atomically (temp file +
    rename), so an interrupted run never leaves a truncated file that later
    calls would trust.

    Args:
        query_sequence: amino-acid sequence to search with.
        prefix: path prefix for the cache directory and the merged ``.a3m`` file.
        use_env: also merge ``bfd.mgnify30.metaeuk30.smag30.a3m`` after ``uniref.a3m``.
        filter: request the server's filtered mode instead of the ``*nofilter`` mode.

    Returns:
        The merged A3M text (uniref hits first, then environmental hits if requested).

    Raises:
        RuntimeError: if the server reports an ``ERROR`` status for the job.
        requests.HTTPError: if a request to the server fails.
    """
    host = 'https://a3m.mmseqs.com'

    def submit(query_sequence, mode):
        res = requests.post(f'{host}/ticket/msa', data={'q': f">1\n{query_sequence}", 'mode': mode})
        res.raise_for_status()
        return res.json()

    def status(ID):
        res = requests.get(f'{host}/ticket/{ID}')
        res.raise_for_status()
        return res.json()

    def download(ID, path):
        res = requests.get(f'{host}/result/download/{ID}')
        # Never cache an HTTP error page as if it were the result archive.
        res.raise_for_status()
        tmp_path = f"{path}.part"
        with open(tmp_path, "wb") as out:
            out.write(res.content)
        os.replace(tmp_path, path)

    if filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    path = f"{prefix}_{mode}"
    os.makedirs(path, exist_ok=True)

    # Call the MMseqs2 API unless the result archive is already cached.
    tar_gz_file = f'{path}/out.tar.gz'
    if not os.path.isfile(tar_gz_file):
        out = submit(query_sequence, mode)
        while out["status"] in ["RUNNING", "PENDING"]:
            time.sleep(1)
            out = status(out["id"])
        if out["status"] == "ERROR":
            raise RuntimeError(f'MMseqs2 server reported an error for job {out.get("id")}')
        download(out["id"], tar_gz_file)

    # Reuse the merged a3m if an earlier call already produced it.
    a3m = f"{prefix}_{mode}.a3m"
    if os.path.isfile(a3m):
        with open(a3m) as cached:
            return cached.read()

    # NOTE(review): extractall trusts member paths inside the server-provided
    # archive; consider a tarfile extraction filter on Python versions that have one.
    with tarfile.open(tar_gz_file) as tar_gz:
        tar_gz.extractall(path)
    a3m_files = [f"{path}/uniref.a3m"]
    if use_env:
        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    a3m_lines = []
    for a3m_file in a3m_files:
        with open(a3m_file, "r") as handle:
            for line in handle:
                # Strip NUL characters (presumably padding in the server output)
                # and skip lines that become empty.
                line = line.replace("\x00", "")
                if line:
                    a3m_lines.append(line)
    merged = "".join(a3m_lines)

    tmp_a3m = f"{a3m}.part"
    with open(tmp_a3m, "w") as a3m_out:
        a3m_out.write(merged)
    os.replace(tmp_a3m, a3m)
    return merged

In [7]:
sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK'

MIN_SEQUENCE_LENGTH = 16
MAX_SEQUENCE_LENGTH = 2500

save_all_models = True
save_msa = True

aatypes = set('ACDEFGHIKLMNPQRSTVWY')  # 20 standard aatypes


def clean_and_validate_sequence(raw_sequence, min_length=MIN_SEQUENCE_LENGTH,
                                max_length=MAX_SEQUENCE_LENGTH):
    """Remove all whitespace from ``raw_sequence``, upper-case it and validate it.

    Args:
        raw_sequence: amino-acid sequence, possibly containing spaces or line breaks.
        min_length: minimum allowed length (inclusive).
        max_length: maximum allowed length (inclusive).

    Returns:
        The cleaned, upper-cased sequence.

    Raises:
        ValueError: if the sequence contains letters outside the 20 standard
            amino acids, or its length is outside [min_length, max_length].
    """
    # str.split() with no separator splits on every whitespace character
    # (spaces, tabs, '\n', '\r', ...), so Windows line endings are removed too.
    cleaned = ''.join(raw_sequence.split()).upper()
    if not set(cleaned).issubset(aatypes):
        raise ValueError(f'Input sequence contains non-amino acid letters: {set(cleaned) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')
    if len(cleaned) < min_length:
        raise ValueError(f'Input sequence is too short: {len(cleaned)} amino acids, while the minimum is {min_length}')
    if len(cleaned) > max_length:
        raise ValueError(f'Input sequence is too long: {len(cleaned)} amino acids, while the maximum is {max_length}. Please use the full AlphaFold system for long sequences.')
    return cleaned


# Remove all whitespace, upper-case, and validate before searching.
sequence = clean_and_validate_sequence(sequence)

In [8]:
msa_method = "mmseqs2"
prefix = hashlib.sha1(sequence.encode()).hexdigest()

output_dir = 'prediction'
os.makedirs(output_dir, exist_ok=True)

# --- Search against genetic databases ---
# "mmseqs2" queries the public MMseqs2 server (results cached under `prefix`);
# any other value falls back to a jackhmmer search.
if msa_method == "mmseqs2":
    msa, deletion_matrix = parsers.parse_a3m(run_mmseqs2(sequence, prefix, filter=False))
    msas, deletion_matrices = [msa], [deletion_matrix]
else:
    # NOTE(review): `run_jackhmmer` is not defined in this notebook; this
    # branch raises NameError unless it is provided elsewhere.
    msas, deletion_matrices = run_jackhmmer(sequence, prefix)

# TODO: honour `save_msa` by pickling {"msas": msas, "deletion_matrices":
# deletion_matrices} into `output_dir`.

# Concatenate all MSAs, then drop duplicate sequences keeping first occurrences in order.
full_msa = [seq for per_db_msa in msas for seq in per_db_msa]
deduped_full_msa = list(dict.fromkeys(full_msa))
total_msa_size = len(deduped_full_msa)
print(f'\n{total_msa_size} Sequences Found in Total\n')

# Every aligned row must have the query's length for a rectangular character array.
assert all(len(row) == len(sequence) for row in deduped_full_msa), 'MSA rows differ in length from the query'
msa_arr = np.array([list(seq) for seq in deduped_full_msa])
num_alignments, num_res = msa_arr.shape


3214 Sequences Found in Total



In [16]:
# Sanity check: one row per deduplicated MSA sequence, one column per aligned position.
msa_arr.shape

(3214, 59)

## Run the model

In [9]:
def _placeholder_template_feats(num_templates_, num_res_):
    return {
        'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
        'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
        'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37], np.float32),
        'template_domain_names': np.zeros([num_templates_], np.float32),
        'template_sum_probs': np.zeros([num_templates_], np.float32),
    }

In [10]:
# List the parameter files under the data directory that the model-running
# cell passes to `data.get_model_haiku_params`.
!ls ../../alphafold/alphafold/data/params/

LICENSE			params_model_2_ptm.npz	params_model_4_ptm.npz
params_model_1.npz	params_model_3.npz	params_model_5.npz
params_model_1_ptm.npz	params_model_3_ptm.npz	params_model_5_ptm.npz
params_model_2.npz	params_model_4.npz


In [11]:
# Only model_1 is run for now; add 'model_2', 'model_3', 'model_4', 'model_5' for the full set.
model_names = ['model_1']

# Per-model outputs; populated only once the commented-out post-processing
# inside the loop is re-enabled.
plddts = {}
pae_outputs = {}
unrelaxed_proteins = {}

# Loop-invariant settings: no templates, features sized to the full query.
num_templates = 0
num_res = len(sequence)
PARAMS_DIR = '../../alphafold/alphafold/data'
RANDOM_SEED = 0

# One progress step per model; this notebook runs no relaxation step.
with tqdm.notebook.tqdm(total=len(model_names), bar_format=TQDM_BAR_FORMAT) as pbar:
    for model_name in model_names:
        pbar.set_description(f'Running {model_name}')

        feature_dict = {}
        feature_dict.update(pipeline.make_sequence_features(sequence, 'test', num_res))
        feature_dict.update(pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))
        feature_dict.update(_placeholder_template_feats(num_templates, num_res))

        cfg = config.model_config(model_name)
        params = data.get_model_haiku_params(model_name, PARAMS_DIR)
        model_runner = model.RunModel(cfg, params)
        processed_feature_dict = model_runner.process_features(feature_dict,
                                                               random_seed=RANDOM_SEED)
        # Kept after the loop (only the last model's result survives) so later
        # cells can inspect prediction_result["representations"].
        prediction_result = model_runner.predict(processed_feature_dict)

        # mean_plddt = prediction_result['plddt'].mean()

        # if 'predicted_aligned_error' in prediction_result:
        #     pae_outputs[model_name] = (
        #       prediction_result['predicted_aligned_error'],
        #       prediction_result['max_predicted_aligned_error']
        #     )
        # else:
        #     # Get the pLDDT confidence metrics. Do not put pTM models here as they
        #     # should never get selected.
        #     plddts[model_name] = prediction_result['plddt']

        # # Set the b-factors to the per-residue plddt.
        # final_atom_mask = prediction_result['structure_module']['final_atom_mask']
        # b_factors = prediction_result['plddt'][:, None] * final_atom_mask
        # unrelaxed_protein = protein.from_prediction(processed_feature_dict,
        #                                             prediction_result,
        #                                             b_factors=b_factors)
        # unrelaxed_proteins[model_name] = unrelaxed_protein

        # Free the model and its parameters before loading the next one.
        del model_runner
        del params
        pbar.update(n=1)

  0%|          | 0/2 [elapsed: 00:00 remaining: ?]

In [21]:
# "single" representation returned by the last model run in the loop.
# NOTE(review): this displays the full array; use `.shape` or a slice if only
# dimensions or a preview are needed.
prediction_result["representations"]["single"]

DeviceArray([[ 1.58250170e+01,  4.19535637e+01,  1.16711136e+02, ...,
              -4.44708447e+03, -1.49557039e-01,  4.53268852e+01],
             [ 7.79034376e+00,  2.44731407e+01,  9.70779419e+01, ...,
              -3.94070557e+03, -1.59975433e+01,  5.48255730e+01],
             [ 2.90133324e+01,  8.11811733e+00,  6.85647354e+01, ...,
              -4.27221973e+03, -6.05358124e+00, -5.11846542e-01],
             ...,
             [ 3.52386436e+01,  2.27563324e+01, -2.04485016e+01, ...,
              -3.81309277e+03, -2.90400922e-01,  5.91051598e+01],
             [-4.56238031e+00,  1.97234285e+00, -2.46919346e+00, ...,
              -2.43228076e+03,  7.93027306e+00,  1.86131725e+01],
             [-1.44213533e+01,  7.59067869e+00, -2.59044952e+01, ...,
              -2.40460889e+03,  4.86359901e+01,  6.34384766e+01]],            dtype=float32)

In [17]:
# MSA representation; presumably (num_msa_rows, num_res, channels). Its row count
# is far below msa_arr.shape[0], suggesting the model sees a sampled/clustered MSA — confirm.
prediction_result["representations"]["msa"].shape

(508, 59, 256)

In [18]:
# Presumably the representation of the first MSA row (the query): (num_res, channels) — confirm.
prediction_result["representations"]["msa_first_row"].shape

(59, 256)

In [19]:
# Pair representation; presumably (num_res, num_res, channels).
prediction_result["representations"]["pair"].shape

(59, 59, 128)

In [20]:
# Structure-module activations; presumably (num_res, channels) — confirm against the model code.
prediction_result["representations"]["structure_module"].shape

(59, 384)