In [1]:
import torch

from Bio import SeqIO
# from transformers import AutoTokenizer, AutoModelForMaskedLM

In [2]:
import random

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
torch.cuda.empty_cache()

In [5]:
from esm.pretrained import esm_msa1b_t12_100M_UR50S

In [6]:
msa_transformer, msa_alphabet = esm_msa1b_t12_100M_UR50S()

In [7]:
def read_multi_fasta_for_esm_msa(file_path):
    """
    params:
        file_path: path to a fasta file
    return:
        a dictionary of sequences
    """
    sequences = {}
    current_sequence = ''
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if current_sequence:
                    sequences[header] = current_sequence.upper().replace('.', '-')
                    current_sequence = ''
                header = line
            else:
                current_sequence += line
        if current_sequence:
            sequences[header] = current_sequence
    return sequences


def read_seq(fasta):
    for record in SeqIO.parse(fasta, "fasta"):
        return str(record.seq)

In [8]:
from huggingface_hub import login

In [9]:
# Replace 'your_access_token' with your actual token
login(token='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')

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /home/yining_yang/.cache/huggingface/token
Login successful


In [11]:
aa_seq_aln_file = "./data/proteingym_v1/aa_seq_aln_a2m/A0A1I9GEU1_NEIME_Kennouche_2019.a2m"

In [12]:
alignment_dict_esm_msa_dict = read_multi_fasta_for_esm_msa(aa_seq_aln_file)

# Convert the dictionary items to a list of tuples (label, msa_string)
msa_batch = list(alignment_dict_esm_msa_dict.items())

# Optionally, remove the last entry if needed (as in your original code)
msa_batch = msa_batch[:-1]

# Subsample if the MSA depth exceeds the maximum allowed (e.g., 1024 sequences)
max_depth = 128

# If the number of alignments exceeds max_depth, subsample the list
if len(msa_batch) > max_depth:
    print(f"MSA batch has {len(msa_batch)} alignments; subsampling to {max_depth}")
    msa_batch = random.sample(msa_batch, max_depth)

# Initialize the batch converter using the msa_alphabet from the pretrained model.
msa_batch_converter = msa_alphabet.get_batch_converter()

# Convert the (subsampled) msa_batch to tokens.
msa_labels, msa_strs, msa_tokens = msa_batch_converter(msa_batch)

# Print a summary of the inputs.
print("MSA Labels:", msa_labels)
print("MSA Strings:\n", msa_strs[0])  # Print the MSA for the first entry.

# Move the tokens to the appropriate device (GPU if available).
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
msa_tokens = msa_tokens.to(device)
msa_transformer = msa_transformer.to(device)

with torch.no_grad():
    # Specify the representation layers you want.
    # Here, we extract representations from layer 12.
    results = msa_transformer(msa_tokens, repr_layers=[12], return_contacts=False)
    token_representations = results["representations"][12]
    # token_representations shape: (batch, num_seqs, seq_length, representation_dim)


MSA batch has 5552 alignments; subsampling to 128
MSA Labels: [['>UniRef100_UPI00119DB9B3/3-130', '>UniRef100_UPI0005E9AF04/8-162', '>UniRef100_UPI00145CF9D1/21-150', '>UniRef100_UPI000D31B81F/8-143', '>UniRef100_A0A1P8FNR1/8-149', '>UniRef100_UPI0018CDB834/6-156', '>UniRef100_A0A0W0XRY8/6-135', '>UniRef100_UPI000A191483/8-194', '>UniRef100_UPI000317E613/13-138', '>UniRef100_UPI0005E3FFB9/8-160', '>UniRef100_A0A4Z1BZ50/8-151', '>UniRef100_UPI0003B3575C/12-152', '>UniRef100_UPI001372B1D1/7-143', '>UniRef100_A0A4Q2JAI3/8-139', '>UniRef100_UPI000F5C00F1/8-165', '>UniRef100_UPI0015D5D4E9/13-142', '>UniRef100_A0A4P9UMA4/12-161', '>UniRef100_A0A7X0AE88/10-144', '>UniRef100_Q5WVD6/6-136', '>UniRef100_A0A838YBR4/7-142', '>UniRef100_UPI0007947582/7-157', '>UniRef100_A0A7Z2V3F6/9-164', '>UniRef100_UPI00145F8E08/31-158', '>UniRef100_UPI001909AF05/6-137', '>UniRef100_UPI000E570538/8-165', '>UniRef100_UPI0014032FE8/15-137', '>UniRef100_A0A385BZX7/6-145', '>UniRef100_A0A0W1KTY4/8-142', '>UniRef100_U

In [None]:
token_representations

tensor([[[[-0.0117, -0.2484, -0.0045,  ...,  0.0952,  0.0186,  0.1032],
          [ 0.3641, -0.3213, -0.1205,  ...,  0.4233, -0.1054,  0.9100],
          [-0.2069, -0.1613,  0.6332,  ..., -0.1065,  0.1151,  1.2488],
          ...,
          [-0.4485,  0.0610,  0.5461,  ..., -1.0651,  0.3013, -0.0944],
          [-0.6391,  0.0368,  0.5162,  ..., -0.4670,  0.6478,  0.7312],
          [-0.5557, -0.3431,  0.3554,  ...,  0.2923,  0.6270,  0.7664]],

         [[ 0.0159, -0.3496, -0.0382,  ...,  0.0543,  0.1593,  0.2793],
          [ 0.6366, -0.9087, -0.2999,  ...,  0.9616,  0.2380,  0.1537],
          [-0.5837, -1.5020,  0.1612,  ...,  0.1205,  0.4152,  0.5169],
          ...,
          [-0.8305, -0.2286,  0.2418,  ..., -1.2774,  0.5381,  0.7949],
          [ 0.0947,  0.2974,  0.1551,  ..., -0.2894,  0.9339,  1.0935],
          [-0.1476,  0.0495,  0.1829,  ..., -0.5649,  1.0071,  1.3754]],

         [[ 0.0353, -0.3179,  0.0543,  ...,  0.0405,  0.1117,  0.2224],
          [ 1.0026,  0.3020,  

In [17]:
print("MSA embedding output dimensionalities: (1, N, L, d): ",token_representations.size())

MSA embedding output dimensionalities: (1, N, L, d):  torch.Size([1, 128, 162, 768])


In [14]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 524991 KiB |   1243 MiB |  38850 MiB |  38337 MiB |
|       from large pool | 524023 KiB |   1242 MiB |  38693 MiB |  38181 MiB |
|       from small pool |    968 KiB |      1 MiB |    156 MiB |    155 MiB |
|---------------------------------------------------------------------------|
| Active memory         | 524991 KiB |   1243 MiB |  38850 MiB |  38337 MiB |
|       from large pool | 524023 KiB |   1242 MiB |  38693 MiB |  38181 MiB |
|       from small pool |    968 KiB |      1 MiB |    156 MiB |    155 MiB |
|---------------------------------------------------------------