In [1]:
import torch
import esm
from Bio import SeqIO
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import string
import numpy as np
from scipy.spatial.distance import squareform, pdist, cdist

In [2]:
# Translation table that deletes, in a single str.translate pass, every lowercase
# letter as well as the "." and "*" characters (mapping a character to None removes it).
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file.

    Args:
        filename: Path of the file, parsed in "fasta" format.

    Returns:
        A ``(description, sequence)`` pair for the first record in the file.
    """
    records = SeqIO.parse(filename, "fasta")
    first = next(records)
    return first.description, str(first.seq)

def remove_insertions(sequence: str) -> str:
    """Strip insertions from a sequence so aligned MSA rows can be loaded.

    Every character listed in the module-level ``translation`` table is deleted;
    all other characters are kept in their original order.
    """
    without_insertions = sequence.translate(translation)
    return without_insertions
    
def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Load every record of an MSA file, with insertions removed from each sequence.

    Args:
        filename: Path of the file, parsed in "fasta" format.

    Returns:
        One ``(description, sequence)`` pair per record, in file order.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        cleaned = remove_insertions(str(record.seq))
        entries.append((record.description, cleaned))
    return entries

In [3]:
# MSAs keyed by identifier; each value is the list of (description, sequence)
# pairs returned by read_msa.
msas = {"1a3a": read_msa("test_msa.a3m")}

In [4]:
# Greedily select sequences from the MSA to maximize (or, with mode="min", minimize)
# the mean Hamming distance to the sequences already selected.
# Alternatively, hhfilter can be used.
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """Greedily pick a subset of at most ``num_seqs`` sequences from an MSA.

    Selection starts from the first entry. Each step adds the not-yet-selected
    entry whose mean Hamming distance to all entries selected so far is the
    largest (``mode="max"``) or the smallest (``mode="min"``).

    Args:
        msa: ``(description, sequence)`` pairs. Assumes all sequences have the
            same length, as in an aligned MSA, so they form a rectangular array.
        num_seqs: Number of entries to keep.
        mode: ``"max"`` or ``"min"``.

    Returns:
        The selected entries in their original MSA order. If ``msa`` has at
        most ``num_seqs`` entries, the input list itself is returned. If
        ``num_seqs`` is zero or negative, an empty list is returned.

    Raises:
        AssertionError: If ``mode`` is not ``"max"`` or ``"min"``.
    """
    assert mode in ("max", "min")
    # Without this guard the seed index below would still return one sequence
    # even though none were requested.
    if num_seqs <= 0:
        return []
    if len(msa) <= num_seqs:
        return msa

    # One row per sequence, one uint8 code per character, so cdist can compare rows.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    all_indices = np.arange(len(msa))
    indices = [0]  # the first sequence is always kept
    # Row i holds the Hamming distance from the i-th selected sequence to every sequence.
    pairwise_distances = np.zeros((0, len(msa)))
    for _ in range(num_seqs - 1):
        # Only the most recently selected sequence needs a new distance row.
        dist = cdist(array[indices[-1:]], array, "hamming")
        pairwise_distances = np.concatenate([pairwise_distances, dist])
        # Drop already-selected columns, then average over the selected rows.
        shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
        shifted_index = optfunc(shifted_distance)
        # Map the position among remaining candidates back to an index into the MSA.
        index = np.delete(all_indices, indices)[shifted_index]
        indices.append(index)
    indices = sorted(indices)
    return [msa[idx] for idx in indices]

In [7]:
# Load the pretrained MSA Transformer and its alphabet, then derive the batch
# converter from the alphabet.
msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()
# Switch to evaluation mode; eval() modifies the module in place and returns it.
msa_transformer.eval()

In [17]:
# Containers for the (currently disabled) contact-prediction evaluation below.
msa_transformer_predictions = {}
msa_transformer_results = []
for name, inputs in msas.items():
    inputs = greedy_select(inputs, num_seqs=128) # can change this to pass more/fewer sequences
    msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
    # Move the tokens to whichever device holds the model parameters.
    msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
    # Inference only: disable autograd so no computation graph is retained
    # just to inspect the shape of the logits.
    with torch.no_grad():
        print(msa_transformer(msa_transformer_batch_tokens)["logits"].shape)
    # Disabled evaluation code, kept for reference. NOTE(review): it uses
    # `evaluate_prediction`, `contacts` and `pd`, none of which are defined in
    # the cells shown here; confirm they exist before re-enabling.
    # msa_transformer_predictions[name] = msa_transformer.predict_contacts(msa_transformer_batch_tokens)[0].cpu()
    # metrics = {"id": name, "model": "MSA Transformer (Unsupervised)"}
    # metrics.update(evaluate_prediction(msa_transformer_predictions[name], contacts[name]))
    # msa_transformer_results.append(metrics)
# msa_transformer_results = pd.DataFrame(msa_transformer_results)
# display(msa_transformer_results)

torch.Size([1, 128, 146, 33])


In [22]:
# Forward pass on the tokens left over from the last iteration of the loop above.
# Autograd is disabled because the result is only inspected, never back-propagated.
# NOTE(review): despite its name, `emb` holds the model's "logits" output (the
# shape printed earlier ends in 33), not hidden representations; confirm this
# is what is wanted.
with torch.no_grad():
    emb = msa_transformer(msa_transformer_batch_tokens)["logits"]
emb

tensor([[[[-12.9309, -23.8732, -24.2215,  ...,  -1.9398, -13.9956, -14.7635],
          [-13.9280, -12.3575, -12.6524,  ...,  -1.0194, -13.5932, -13.7057],
          [-15.9826, -20.6273, -20.8905,  ...,  -8.1170, -15.6904, -15.7424],
          ...,
          [-15.7191, -18.0337, -18.1623,  ...,  -6.7042, -15.4590, -15.6125],
          [-16.9020, -19.7343, -19.7679,  ...,  -7.2387, -16.2089, -16.1908],
          [-15.4685, -15.6624, -15.8068,  ...,  -4.8952, -14.3706, -14.6087]],

         [[-12.9736, -21.9237, -21.9615,  ...,   3.6814, -15.1494, -16.3454],
          [-15.6551, -17.5572, -17.5652,  ...,  -0.5021, -15.4875, -15.4964],
          [-15.6690, -20.1051, -20.2912,  ...,  -2.8869, -14.9445, -14.9622],
          ...,
          [-15.1065, -17.7589, -17.7223,  ...,  -2.5858, -15.5110, -15.8153],
          [-15.7245, -16.8178, -16.7517,  ...,   1.2008, -15.7335, -15.7986],
          [-14.2028, -15.3096, -15.2191,  ...,  10.8708, -12.2719, -11.9081]],

         [[-13.3200, -20.9067,

In [23]:
# Display the tensor currently bound to `emb`.
emb

tensor([[[[-12.9309, -23.8732, -24.2215,  ...,  -1.9398, -13.9956, -14.7635],
          [-13.9280, -12.3575, -12.6524,  ...,  -1.0194, -13.5932, -13.7057],
          [-15.9826, -20.6273, -20.8905,  ...,  -8.1170, -15.6904, -15.7424],
          ...,
          [-15.7191, -18.0337, -18.1623,  ...,  -6.7042, -15.4590, -15.6125],
          [-16.9020, -19.7343, -19.7679,  ...,  -7.2387, -16.2089, -16.1908],
          [-15.4685, -15.6624, -15.8068,  ...,  -4.8952, -14.3706, -14.6087]],

         [[-12.9736, -21.9237, -21.9615,  ...,   3.6814, -15.1494, -16.3454],
          [-15.6551, -17.5572, -17.5652,  ...,  -0.5021, -15.4875, -15.4964],
          [-15.6690, -20.1051, -20.2912,  ...,  -2.8869, -14.9445, -14.9622],
          ...,
          [-15.1065, -17.7589, -17.7223,  ...,  -2.5858, -15.5110, -15.8153],
          [-15.7245, -16.8178, -16.7517,  ...,   1.2008, -15.7335, -15.7986],
          [-14.2028, -15.3096, -15.2191,  ...,  10.8708, -12.2719, -11.9081]],

         [[-13.3200, -20.9067,