In [1]:
# Re-import edited modules automatically before each cell runs, so changes to
# project code on sys.path are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make modules in the parent directory importable (presumably the project's own
# package lives there -- confirm). Guarded so re-running the cell does not keep
# appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [25]:
import os
import sys
import warnings

# Let XLA spill GPU allocations into host memory (CUDA unified memory) so long
# sequences do not abort with out-of-memory errors. JAX reads these variables
# when it initialises its GPU backend, so this cell must run BEFORE jax is
# imported/used; after that, changing them has no effect until a kernel restart.
if "jax" in sys.modules:
    warnings.warn(
        "jax is already imported; TF_FORCE_UNIFIED_MEMORY / "
        "XLA_PYTHON_CLIENT_MEM_FRACTION may not take effect until the kernel is restarted."
    )
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"
# A fraction > 1 lets the allocator reserve more than the physical GPU memory,
# which is what makes unified memory useful (same pair of settings as
# AlphaFold's run_docker.py). setdefault keeps any value the user already set.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "4.0")

In [22]:
# import os
import glob
import hashlib
import pickle
import requests
import shutil
import tarfile
import time
import tqdm.notebook

from string import ascii_uppercase

import jax
import numpy as np

from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.model import config
from alphafold.model import data
from alphafold.model import model

# Progress-bar layout string for tqdm's `bar_format` argument.
# NOTE(review): not referenced in the visible cells -- remove if unused elsewhere.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

In [23]:
# Sanity check: report which backend JAX picked up ("gpu" is expected here).
primary_device = jax.local_devices()[0]
print(primary_device.platform)

gpu


In [6]:
def run_mmseqs2(query_sequence, prefix, use_env=True, filter=False):
    """Fetch an MSA for ``query_sequence`` from the MMseqs2 web API and cache it on disk.

    Two cache files are kept per (prefix, mode):
      * ``{prefix}_{mode}/out.tar.gz`` -- the archive downloaded from the server;
      * ``{prefix}_{mode}.a3m``        -- the concatenated a3m alignment(s).
    When they already exist the server is not contacted again.

    Args:
        query_sequence: amino-acid sequence to search.
        prefix: path prefix for the cache files (this notebook passes the SHA1
            hex digest of the sequence).
        use_env: also append the alignment read from
            ``bfd.mgnify30.metaeuk30.smag30.a3m`` in the archive.
        filter: request the filtered modes ("env"/"all") instead of the
            unfiltered ones ("env-nofilter"/"nofilter").

    Returns:
        The a3m alignment(s) as a single string.

    Raises:
        RuntimeError: if the MMseqs2 job ends with a status other than "COMPLETE".
        requests.HTTPError: if the server answers with an HTTP error code.
    """
    host_url = "https://a3m.mmseqs.com"
    timeout_s = 60  # never block forever on a stalled connection

    def write_atomically(dest, payload, file_mode):
        # Write to a side file and rename it into place, so an interrupted run
        # (e.g. KeyboardInterrupt) never leaves a truncated file that a later
        # call would mistake for a valid cache entry.
        tmp = f"{dest}.part"
        with open(tmp, file_mode) as handle:
            handle.write(payload)
        os.replace(tmp, dest)

    def submit(query_sequence, mode):
        res = requests.post(f"{host_url}/ticket/msa",
                            data={'q': f">1\n{query_sequence}", 'mode': mode},
                            timeout=timeout_s)
        res.raise_for_status()
        return res.json()

    def status(ticket_id):
        res = requests.get(f"{host_url}/ticket/{ticket_id}", timeout=timeout_s)
        res.raise_for_status()
        return res.json()

    def download(ticket_id, path):
        res = requests.get(f"{host_url}/result/download/{ticket_id}", timeout=timeout_s)
        res.raise_for_status()
        write_atomically(path, res.content, "wb")

    if filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    path = f"{prefix}_{mode}"
    os.makedirs(path, exist_ok=True)

    # Call the MMseqs2 API unless the archive is already cached.
    tar_gz_file = f'{path}/out.tar.gz'
    if not os.path.isfile(tar_gz_file):
        out = submit(query_sequence, mode)
        while out["status"] in ["RUNNING", "PENDING"]:
            time.sleep(1)
            out = status(out["id"])
        # Without this check a failed job's response would be downloaded and
        # cached as if it were a valid archive.
        if out["status"] != "COMPLETE":
            raise RuntimeError(f"MMseqs2 API job ended with status {out['status']!r}")
        download(out["id"], tar_gz_file)

    # Reuse the already-parsed a3m when present.
    a3m = f"{prefix}_{mode}.a3m"
    if os.path.isfile(a3m):
        with open(a3m) as handle:
            return handle.read()

    with tarfile.open(tar_gz_file) as tar_gz:
        # The archive comes from a remote server: where the running Python
        # supports extraction filters, refuse absolute paths / path traversal.
        if hasattr(tarfile, "data_filter"):
            tar_gz.extractall(path, filter="data")
        else:
            tar_gz.extractall(path)

    a3m_files = [f"{path}/uniref.a3m"]
    if use_env:
        a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    a3m_lines = []
    for a3m_file in a3m_files:
        with open(a3m_file, "r") as handle:
            for line in handle:
                # Strip NUL characters present in the downloaded files.
                line = line.replace("\x00", "")
                if line:
                    a3m_lines.append(line)

    a3m_str = "".join(a3m_lines)
    write_atomically(a3m, a3m_str, "w")
    return a3m_str

In [7]:
def fasta_to_protein(fasta: str) -> str:
    """Extract the short protein name from a FASTA description line.

    Example: ``"sp|P07550|ADRB2_HUMAN Beta-2 ..."`` -> ``"ADRB2"``. Takes the
    first space-separated token, then its last ``|`` field, then the part
    before the first ``_``.
    """
    identifier = fasta.partition(" ")[0]
    entry_name = identifier.rpartition("|")[2]
    return entry_name.partition("_")[0]

In [12]:
# Destination for the per-protein MSA pickles (`<protein>.pickle`).
msa_output_dir = "../data/msa_sequences"
os.makedirs(name=msa_output_dir, exist_ok=True)

In [8]:
# Read the whole multi-record FASTA and split it into parallel lists of
# sequences and their description lines (consumed pairwise via zip below).
input_fasta_path = "../data/AVE_unbiased.fasta"
with open(input_fasta_path) as fasta_handle:
    input_fasta_str = fasta_handle.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)

In [13]:
# Validation limits for query sequences; loop-invariant, so defined once.
MIN_SEQUENCE_LENGTH = 16
MAX_SEQUENCE_LENGTH = 2800
aatypes = set('ACDEFGHIKLMNPQRSTVWY')  # 20 standard aatypes

for sequence, description in zip(input_seqs, input_descs):
    protein_name = fasta_to_protein(description)
    msa_pickle_path = f"{msa_output_dir}/{protein_name}.pickle"

    # Resume support: an interrupted run can be restarted without redoing
    # proteins whose MSA pickle was already written.
    if os.path.isfile(msa_pickle_path):
        continue

    print(protein_name, sequence, description)

    # Remove all whitespaces, tabs and end lines; upper-case
    sequence = sequence.translate(str.maketrans('', '', ' \n\t')).upper()

    if not set(sequence).issubset(aatypes):
        raise ValueError(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')
    if len(sequence) < MIN_SEQUENCE_LENGTH:
        raise ValueError(f'Input sequence is too short: {len(sequence)} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')
    if len(sequence) > MAX_SEQUENCE_LENGTH:
        raise ValueError(f'Input sequence is too long: {len(sequence)} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')

    # Content-addressed cache prefix for run_mmseqs2's on-disk files.
    prefix = hashlib.sha1(sequence.encode()).hexdigest()

    # Query the MMseqs2 web API (run_mmseqs2 caches its results under `prefix`).
    msa, deletion_matrix = parsers.parse_a3m(run_mmseqs2(sequence, prefix, filter=False))
    msas, deletion_matrices = [msa], [deletion_matrix]

    # Write via a side file + rename so an interrupt never leaves a partial
    # pickle that the skip-check above would treat as finished.
    tmp_pickle_path = f"{msa_pickle_path}.part"
    with open(tmp_pickle_path, "wb") as handle:
        pickle.dump({"msas": msas, "deletion_matrices": deletion_matrices}, handle)
    os.replace(tmp_pickle_path, msa_pickle_path)

    # Count distinct sequences across all MSAs (order-preserving dedup).
    deduped_full_msa = list(dict.fromkeys(seq for msa in msas for seq in msa))
    total_msa_size = len(deduped_full_msa)
    print(f'\n{total_msa_size} Sequences Found in Total\n')

ADRB2 MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL sp|P07550|ADRB2_HUMAN Beta-2 adrenergic receptor OS=Homo sapiens OX=9606 GN=ADRB2 PE=1 SV=3


KeyboardInterrupt: 

In [None]:
# Delete run_mmseqs2's working-directory cache (the `<sha1>_env-nofilter`
# directories and `.a3m` files) once the MSAs have been pickled.
for cache_path in glob.glob("./*env-nofilter*"):
    if os.path.isfile(cache_path):
        os.remove(cache_path)
    else:
        shutil.rmtree(cache_path)

## Run the model

In [9]:
# Destination for the per-protein embeddings (`<protein>.npy`).
embedding_output_dir = "../data/model_embeddings"
os.makedirs(name=embedding_output_dir, exist_ok=True)

In [10]:
def _placeholder_template_feats(num_templates_, num_res_):
    return {
        'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
        'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
        'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37], np.float32),
        'template_domain_names': np.zeros([num_templates_], np.float32),
        'template_sum_probs': np.zeros([num_templates_], np.float32),
    }

In [26]:
# The model configuration and weights do not depend on the protein, so load
# them and build the runner once instead of re-reading the params every iteration.
# model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5']
model_name = "model_1"
ALPHAFOLD_PARAMS_DIR = '../../alphafold/alphafold/data'
cfg = config.model_config(model_name)
params = data.get_model_haiku_params(model_name, ALPHAFOLD_PARAMS_DIR)
model_runner = model.RunModel(cfg, params)

for sequence, description in zip(input_seqs, input_descs):
    protein_name = fasta_to_protein(description)
    print(protein_name)

    input_protein_path = f"{msa_output_dir}/{protein_name}.pickle"
    output_protein_path = f"{embedding_output_dir}/{protein_name}.npy"

    # Resume support: skip proteins whose embedding already exists.
    if os.path.isfile(output_protein_path):
        continue

    # NOTE(review): pickle.load can execute arbitrary code -- only load MSA
    # pickles produced by this notebook.
    with open(input_protein_path, "rb") as handle:
        msas_dict = pickle.load(handle)
    msas, deletion_matrices = msas_dict["msas"], msas_dict["deletion_matrices"]

    # Count distinct sequences across all MSAs (order-preserving dedup).
    deduped_full_msa = list(dict.fromkeys(seq for msa in msas for seq in msa))
    total_msa_size = len(deduped_full_msa)
    print(f'\n{total_msa_size} Sequences Found in Total\n')

    # Normalise exactly like the MSA cell did before querying MMseqs2, so the
    # sequence features describe the same string as the MSA query row.
    sequence = sequence.translate(str.maketrans('', '', ' \n\t')).upper()

    print(f'Running {model_name}')
    num_templates = 0
    num_res = len(sequence)

    feature_dict = {}
    feature_dict.update(pipeline.make_sequence_features(sequence, 'test', num_res))
    feature_dict.update(pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))
    feature_dict.update(_placeholder_template_feats(num_templates, num_res))

    processed_feature_dict = model_runner.process_features(feature_dict,
                                                           random_seed=0)
    try:
        prediction_result = model_runner.predict(processed_feature_dict)
    except RuntimeError:
        # e.g. GPU out-of-memory on long sequences (see MTOR); skip this protein.
        print(f"Failed for {protein_name}")
        continue

    np.save(output_protein_path, prediction_result["representations"]["single"])

    # Drop per-protein intermediates before the next (possibly larger) protein.
    del prediction_result
    del feature_dict
    del processed_feature_dict

ADRB2
AL1A1
ESR1
FEN1
GLCM
IDHC
KAT2A
MK01
MTOR

2599 Sequences Found in Total

Running model_1


2021-07-27 13:23:28.349077: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:272] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.11GiB with freed_by_count=0. The caller indicates that this is not a failure, but may mean that there could be performance gains if more memory were available.


Failed for MTOR
OPRK

9107 Sequences Found in Total

Running model_1
KPYM

47035 Sequences Found in Total

Running model_1
PPARG

5122 Sequences Found in Total

Running model_1
P53

1988 Sequences Found in Total

Running model_1
VDR

4301 Sequences Found in Total

Running model_1


In [28]:
# Inspect one saved embedding. Build the path explicitly instead of reusing
# `output_protein_path` leaked from the embedding loop (hidden state that breaks
# when cells are run out of order). The last FASTA entry is the file that loop
# variable pointed to after the loop finished.
example_protein = fasta_to_protein(input_descs[-1])
embedding = np.load(f"{embedding_output_dir}/{example_protein}.npy")

In [34]:
# Total number of scalar values in the embedding; `.size` gives the same count
# as len(embedding.flatten()) without materialising a flattened copy.
embedding.size

163968

In [35]:
# Display the saved `representations["single"]` array (numpy truncates the repr).
embedding

array([[-3.7297489e+01,  1.2727687e+01,  9.4766846e+01, ...,
        -2.1827852e+03, -6.8753311e+01,  8.1704414e+01],
       [ 1.3574860e+01,  8.7939102e+01,  5.9374817e+01, ...,
        -2.2468887e+03, -2.9185944e+01,  9.7156219e+01],
       [ 1.8065815e+01,  2.9657791e+01,  9.8774246e+01, ...,
        -2.6101824e+03, -9.2502422e+00,  4.0248898e+01],
       ...,
       [-1.3179278e+01, -8.5475168e+00,  5.2997832e+00, ...,
        -2.5445649e+03, -8.7797909e+00,  1.0725223e+02],
       [-7.0162826e+00,  5.4774399e+01,  1.9697548e+01, ...,
        -2.3585698e+03, -2.2868738e+00,  2.6496464e+01],
       [-5.9840130e+01, -9.6369982e+00,  3.1638180e+01, ...,
        -2.5015537e+03, -5.1954197e+01, -1.7446728e+00]], dtype=float32)