In [1]:
from typing import List, Tuple
import os
import string
import sys
import numpy as np
import torch
from scipy.spatial.distance import squareform, pdist, cdist
from Bio import SeqIO
import pickle
from time import sleep
import esm
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f4674456310>

In [2]:
# This is an efficient way to delete lowercase characters and insertion characters from a string
deletekeys = dict.fromkeys(string.ascii_lowercase)
deletekeys["."] = None
deletekeys["*"] = None
translation = str.maketrans(deletekeys)

In [3]:
def remove_insertions(sequence: str) -> str:
    """ Removes any insertions into the sequence. Needed to load aligned sequences in an MSA. """
    return sequence.translate(translation)

def read_msa(filename: str) -> List[Tuple[str, str]]:
    """ Reads the sequences from an MSA file, automatically removes insertions."""
    return [(record.description, remove_insertions(str(record.seq))) for record in SeqIO.parse(filename, "fasta")]

In [4]:
# Select sequences from the MSA to maximize the hamming distance
# Alternatively, can use hhfilter
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    assert mode in ("max", "min")
    if len(msa) <= num_seqs:
        return msa

    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    all_indices = np.arange(len(msa))
    indices = [0]
    pairwise_distances = np.zeros((0, len(msa)))
    for _ in range(num_seqs - 1):
        dist = cdist(array[indices[-1:]], array, "hamming")
        pairwise_distances = np.concatenate([pairwise_distances, dist])
        shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
        shifted_index = optfunc(shifted_distance)
        index = np.delete(all_indices, indices)[shifted_index]
        indices.append(index)
    indices = sorted(indices)
    return [msa[idx] for idx in indices]

In [5]:
def main(pdb_ids, msas_pickle=False):
    """Creates a pt file containing the embeddings for each pdb ID in the input."""

    torch.cuda.set_device(1)

    if msas_pickle == False:
        msas = {}
        for name in pdb_ids:
            msa_path = "../data/alignments/aligned_" + name.lower() + ".a3m"
            if os.path.isfile(msa_path):
                msas[name] = read_msa(msa_path)
    else:
        with open(msas_pickle, 'rb') as unpickled_msas:
            msas = pickle.load(unpickled_msas)

    msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_transformer = msa_transformer.eval().cuda()
    msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()

    for name in pdb_ids:
        try:
            inputs = msas[name]
            inputs = greedy_select(inputs, num_seqs=512)
            msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
            msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
            with torch.no_grad():
                result = msa_transformer(msa_transformer_batch_tokens, repr_layers=[12])
            embeddings = result["representations"][12]
            torch.save(embeddings, '../data/embeddings/embeddings_' + name + '.pt')
        except BaseException as e:
            print("Couldn't create embeddings for " + name)
            print(e)
            torch.cuda.empty_cache()
            sleep(5)

In [10]:
if __name__ == "__main__":
    arguments = sys.argv[1:]
    if type(arguments[0]) == list:
        pdbs = [pdb for pdb in arguments[0] if len(pdb) == 6]
    else:
        pdbs = [arg for arg in arguments if len(arg) == 6]
    if len(pdbs) == 0:
        print("Please provide some pdb IDs in the format '1abc_a'")

    msa_pkl = [arg for arg in arguments if arg.endswith('.pkl')]
    if len(msa_pkl) == 0:
        msa_pkl = [False]

    main(pdbs, msas_pickle=msa_pkl[0])

Please provide some pdb IDs in the format '1abc_a'


ValueError: IPython won't let you open fd=False by default as it is likely to crash IPython. If you know what you are doing, you can use builtins' open.

In [11]:
main(['1um_b'])

1um_b
'1um_b'
