In [1]:
import torch

from Bio import SeqIO
# from transformers import AutoTokenizer, AutoModelForMaskedLM

In [2]:
import random

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
torch.cuda.empty_cache()

In [5]:
from esm.pretrained import esm_msa1b_t12_100M_UR50S

In [6]:
msa_transformer, msa_alphabet = esm_msa1b_t12_100M_UR50S()

In [7]:
def read_multi_fasta_for_esm_msa(file_path):
    """
    params:
        file_path: path to a fasta file
    return:
        a dictionary of sequences
    """
    sequences = {}
    current_sequence = ''
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if current_sequence:
                    sequences[header] = current_sequence.upper().replace('.', '-')
                    current_sequence = ''
                header = line
            else:
                current_sequence += line
        if current_sequence:
            sequences[header] = current_sequence
    return sequences


def read_seq(fasta):
    for record in SeqIO.parse(fasta, "fasta"):
        return str(record.seq)

In [8]:
from huggingface_hub import login

In [9]:
# Replace 'your_access_token' with your actual token
login(token='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')

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /home/yining_yang/.cache/huggingface/token
Login successful


In [10]:
aa_seq_aln_file = "./data/proteingym_v1/aa_seq_aln_a2m/A0A1I9GEU1_NEIME_Kennouche_2019.a2m"


In [12]:
alignment_dict_esm_msa_dict = read_multi_fasta_for_esm_msa(aa_seq_aln_file)

# Convert the dictionary items to a list of tuples (label, msa_string)
msa_batch = list(alignment_dict_esm_msa_dict.items())

# Optionally, remove the last entry if needed (as in your original code)
msa_batch = msa_batch[:-1]

# Subsample if the MSA depth exceeds the maximum allowed (e.g., 1024 sequences)
max_depth = 50

# If the number of alignments exceeds max_depth, subsample the list
if len(msa_batch) > max_depth:
    print(f"MSA batch has {len(msa_batch)} alignments; subsampling to {max_depth}")
    msa_batch = random.sample(msa_batch, max_depth)

# Initialize the batch converter using the msa_alphabet from the pretrained model.
msa_batch_converter = msa_alphabet.get_batch_converter()

# Convert the (subsampled) msa_batch to tokens.
msa_labels, msa_strs, msa_tokens = msa_batch_converter(msa_batch)

# Print a summary of the inputs.
print("MSA Labels:", msa_labels)
print("MSA Strings:\n", msa_strs[0])  # Print the MSA for the first entry.

# Move the tokens to the appropriate device (GPU if available).
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
msa_tokens = msa_tokens.to(device)
msa_transformer = msa_transformer.to(device)

with torch.no_grad():
    # Specify the representation layers you want.
    # Here, we extract representations from layer 12.
    results = msa_transformer(msa_tokens, repr_layers=[12], return_contacts=False)
    token_representations = results["representations"][12]
    # token_representations shape: (batch, num_seqs, seq_length, representation_dim)


MSA batch has 5552 alignments; subsampling to 50
MSA Labels: [['>UniRef100_UPI00155E8EA2/3-133', '>UniRef100_UPI001896076C/9-134', '>UniRef100_UPI00082DC2DF/23-166', '>UniRef100_UPI001783E099/10-149', '>UniRef100_UPI00117D17EC/32-157', '>UniRef100_A0A7U6KL99/8-150', '>UniRef100_A0A3D1RW60/8-127', '>UniRef100_Q8RLT9/8-161', '>UniRef100_UPI0015D46767/1-145', '>UniRef100_A0A7C2E7M6/8-143', '>UniRef100_UPI000766641C/8-166', '>UniRef100_A0A495VA47/16-144', '>UniRef100_UPI00053BC81F/8-130', '>UniRef100_A0A1Y1SHR7/8-145', '>UniRef100_UPI00054116BB/8-147', '>UniRef100_UPI00186924AF/12-136', '>UniRef100_UPI001583061F/9-166', '>UniRef100_UPI001603077F/7-140', '>UniRef100_A0A3S0U622/8-166', '>UniRef100_UPI000F4F87FE/30-156', '>UniRef100_UPI001248B4FD/8-145', '>UniRef100_A0A0H1ATG7/7-152', '>UniRef100_UPI0003FED8BB/8-130', '>UniRef100_A0A1Y1QZK2/9-146', '>UniRef100_A0A377R156/8-182', '>UniRef100_A0A7Y1TCQ9/8-131', '>UniRef100_UPI0013722FC1/6-180', '>UniRef100_A0A3S8VGN9/8-182', '>UniRef100_UPI000F

In [13]:
token_representations

tensor([[[[-0.0203, -0.2191,  0.0057,  ...,  0.0508,  0.0891, -0.0280],
          [ 0.6360, -0.2429, -0.1378,  ...,  0.6515,  0.0030,  1.5509],
          [-0.0981, -0.9487,  0.8002,  ..., -0.6415,  0.2205,  1.1969],
          ...,
          [-0.0922, -0.5531,  0.0412,  ..., -0.8715,  0.3688, -0.1279],
          [-0.9696, -0.5027, -0.1401,  ..., -1.1013,  0.7913,  0.1694],
          [-0.9449, -0.7645,  0.1084,  ..., -0.0254,  0.9243, -0.1133]],

         [[-0.0076, -0.2741,  0.0303,  ..., -0.0188,  0.1474,  0.3271],
          [ 0.7523, -0.1193, -0.3029,  ...,  0.3149,  0.2522,  1.4373],
          [ 0.3001, -0.4573,  0.3895,  ..., -0.7306,  0.0835,  1.1819],
          ...,
          [-0.4645, -0.3951, -0.0735,  ...,  0.1369, -0.1148, -0.2066],
          [-0.6714, -0.8012,  0.0962,  ..., -1.5943,  0.0934,  1.5758],
          [-0.3111, -0.4528,  0.0788,  ..., -0.7532,  0.6446,  0.8619]],

         [[-0.1610, -0.1352, -0.0941,  ..., -0.0636, -0.0791, -0.0395],
          [ 0.9082, -1.0078, -