In [1]:
# Import necessary libraries
import torch
import os
import pandas as pd
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModelForMaskedLM
from scipy.stats import spearmanr
from tqdm import tqdm

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"


In [104]:
amino_acid_properties = {
    'A': {'hydrophobicity': 1.8,  'charge':  0, 'polarity':  0, 'molecular_weight':  89.09, 'volume':  88.6},
    'R': {'hydrophobicity': -4.5, 'charge': +1, 'polarity':  1, 'molecular_weight': 174.20, 'volume': 173.4},
    'N': {'hydrophobicity': -3.5, 'charge':  0, 'polarity':  1, 'molecular_weight': 132.12, 'volume': 114.1},
    'D': {'hydrophobicity': -3.5, 'charge': -1, 'polarity':  1, 'molecular_weight': 133.10, 'volume': 111.1},
    'C': {'hydrophobicity': 2.5,  'charge':  0, 'polarity':  0, 'molecular_weight': 121.15, 'volume': 108.5},
    'Q': {'hydrophobicity': -3.5, 'charge':  0, 'polarity':  1, 'molecular_weight': 146.15, 'volume': 143.8},
    'E': {'hydrophobicity': -3.5, 'charge': -1, 'polarity':  1, 'molecular_weight': 147.13, 'volume': 138.4},
    'G': {'hydrophobicity': -0.4, 'charge':  0, 'polarity':  0, 'molecular_weight':  75.07, 'volume':  60.1},
    'H': {'hydrophobicity': -3.2, 'charge':  0, 'polarity':  1, 'molecular_weight': 155.16, 'volume': 153.2},
    'I': {'hydrophobicity': 4.5,  'charge':  0, 'polarity':  0, 'molecular_weight': 131.17, 'volume': 166.7},
    'L': {'hydrophobicity': 3.8,  'charge':  0, 'polarity':  0, 'molecular_weight': 131.17, 'volume': 166.7},
    'K': {'hydrophobicity': -3.9, 'charge': +1, 'polarity':  1, 'molecular_weight': 146.19, 'volume': 168.6},
    'M': {'hydrophobicity': 1.9,  'charge':  0, 'polarity':  0, 'molecular_weight': 149.21, 'volume': 162.9},
    'F': {'hydrophobicity': 2.8,  'charge':  0, 'polarity':  0, 'molecular_weight': 165.19, 'volume': 189.9},
    'P': {'hydrophobicity': -1.6, 'charge':  0, 'polarity':  0, 'molecular_weight': 115.13, 'volume': 112.7},
    'S': {'hydrophobicity': -0.8, 'charge':  0, 'polarity':  1, 'molecular_weight': 105.09, 'volume':  89.0},
    'T': {'hydrophobicity': -0.7, 'charge':  0, 'polarity':  1, 'molecular_weight': 119.12, 'volume': 116.1},
    'W': {'hydrophobicity': -0.9, 'charge':  0, 'polarity':  0, 'molecular_weight': 204.23, 'volume': 227.8},
    'Y': {'hydrophobicity': -1.3, 'charge':  0, 'polarity':  1, 'molecular_weight': 181.19, 'volume': 193.6},
    'V': {'hydrophobicity': 4.2,  'charge':  0, 'polarity':  0, 'molecular_weight': 117.15, 'volume': 140.0},
}
device = "cuda" if torch.cuda.is_available() else "cpu"


def read_multi_fasta(file_path):
    """
    params:
        file_path: path to a fasta file
    return:
        a dictionary of sequences
    """
    sequences = {}
    current_sequence = ''
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if current_sequence:
                    sequences[header] = current_sequence.upper().replace('-', '<pad>').replace('.', '<pad>')
                    current_sequence = ''
                header = line
            else:
                current_sequence += line
        if current_sequence:
            sequences[header] = current_sequence
    return sequences

def read_multi_fasta_for_esm_msa(file_path):
    """
    params:
        file_path: path to a fasta file
    return:
        a dictionary of sequences
    """
    sequences = {}
    current_sequence = ''
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if current_sequence:
                    sequences[header] = current_sequence.upper().replace('.', '-')
                    current_sequence = ''
                header = line
            else:
                current_sequence += line
        if current_sequence:
            sequences[header] = current_sequence
    return sequences


def read_seq(fasta):
    for record in SeqIO.parse(fasta, "fasta"):
        return str(record.seq)

def count_matrix_from_residue_alignment(tokenizer, alignment_dict):
    alignment_seqs = list(alignment_dict.values())
    try:
        aln_start, aln_end = list(alignment_dict.keys())[0].split('/')[-1].split('-')
    except:
        aln_start, aln_end = 1, len(alignment_seqs[0])
    print(f">>> Alignment start: {aln_start}, end: {aln_end}")
    print(f">>> Start tokenizing {len(alignment_seqs)} residue alignment sequences")
    tokenized_results = tokenizer(alignment_seqs, return_tensors="pt", padding=True)
    alignment_ids = tokenized_results["input_ids"][:,1:-1]
    return alignment_ids, int(aln_start)-1, int(aln_end)
    # count distribution of each column, [seq_len, vocab_size]
    count_matrix = torch.zeros(alignment_ids.size(1), tokenizer.vocab_size)
    for i in tqdm(range(alignment_ids.size(1))):
        count_matrix[i] = torch.bincount(alignment_ids[:,i], minlength=tokenizer.vocab_size)
    # calculate coverage of each column and normalize count matrix
    # coverage = (1.0 - (count_matrix == tokenizer.pad_token_id).float().mean(dim=-1)).unsqueeze(-1).to(device)
    
    count_matrix = (count_matrix / count_matrix.sum(dim=1, keepdim=True)).to(device)
    # count_matrix = count_matrix * coverage
    return count_matrix, int(aln_start)-1, int(aln_end)


def count_matrix_from_structure_alignment(tokenizer, alignment_dict):
    alignment_seqs = list(alignment_dict.values())
    print(f">>> Start tokenizing {len(alignment_seqs)} structure alignment sequences")
    if len(alignment_seqs) == 0:
        return None
    tokenized_results = tokenizer(alignment_seqs, return_tensors="pt", padding=True)
    alignment_ids = tokenized_results["input_ids"][:,1:-1]
    return alignment_ids
    # count distribution of each column, [seq_len, vocab_size]
    count_matrix = torch.zeros(alignment_ids.size(1), tokenizer.vocab_size)
    for i in tqdm(range(alignment_ids.size(1))):
        count_matrix[i] = torch.bincount(alignment_ids[:,i], minlength=tokenizer.vocab_size)
    return count_matrix
    count_matrix = (count_matrix / count_matrix.sum(dim=1, keepdim=True)).to(device)
    return count_matrix
    


def calculate_property_difference(wild_aa, mutant_aa, weights=None):
    properties = amino_acid_properties[wild_aa].keys()
    if weights is None:
        weights = {prop: 1 for prop in properties}
    differences = []
    for prop in properties:
        wild_value = amino_acid_properties[wild_aa][prop]
        mutant_value = amino_acid_properties[mutant_aa][prop]
        difference = abs(mutant_value - wild_value)
        weighted_diff = weights.get(prop, 1) * difference
        differences.append(weighted_diff)
    return differences

def tokenize_structure_sequence(structure_sequence):
    shift_structure_sequence = [i + 3 for i in structure_sequence]
    shift_structure_sequence = [1, *shift_structure_sequence, 2]
    return torch.tensor([shift_structure_sequence,], dtype=torch.long)



In [5]:
residue_fasta = "./data/proteingym_v1/aa_seq/A0A1I9GEU1_NEIME_Kennouche_2019.fasta"

In [6]:
structure_fasta = "./data/proteingym_v1/struc_seq/2048/A0A1I9GEU1_NEIME_Kennouche_2019.fasta"

In [93]:
alignment_dict

{'>A0A1I9GEU1_NEIME/1-161': 'FTLIELMIVIAIVGILAAVALPAYQDYTARAQVSEAILLAEGQKSAVTEYYLNHGEWPGDNSSAGVATSADIKGKYVQSVTVANGVITAQMASSNVNNEIKSKKLSLWAKRQNGSVKWFCGQPVTRTTATATDVAAANGKTDDKINTKHLPSTCRDDSSAS',
 '>UniRef100_UPI0018A25760/3-135': '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>ARAQVSEAILLAEGQKSAVTEYYLNHGKWPENNDSAGVASASKIIGKYVKSVTVTNGVVTAQMKPSGVNNEIKDKRLSLWAKREDGSVKWFCGQPVKRDNVAAADDDVTDDKNNNGIDTKHLPSTCRDKSSA<pad>',
 '>UniRef100_UPI00145ECBA7/3-130': '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>ARAQVSEAILLAEGQKSAVTEYYLNHGKWPENNGDAGVASASKIIGKYVKQVEVKNGVVTATMNSSNVNKEIQGKRLSLWAKRQDGSVKWFCGQPVK<pad><pad>RNANAKDDAVTADKDKEIETKHLPSTCRDES<pad><pad><pad>',
 '>UniRef100_Q00045/8-162': 'FTLIELMIVIAIVGILAAVALPAYQDYTARAQVSEAILLAEGQKSAVTEYYLNHGIWPENNPA<pad>GVASASDIKGKYVQSVTVANGVVTAQMKSDGVNKEIKNKKLSLWARREAGSVKWFCGQPVTR

In [7]:
sequence = read_seq(residue_fasta)
structure_sequence = read_seq(structure_fasta)



In [8]:
len(sequence)

161

In [9]:
len(structure_sequence)

724

In [10]:
deprot = AutoModelForMaskedLM.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)

In [11]:
structure_sequence = [int(i) for i in structure_sequence.split(",")]
ss_input_ids = tokenize_structure_sequence(structure_sequence).to(device)
tokenized_results = tokenizer([sequence], return_tensors="pt")
input_ids = tokenized_results["input_ids"].to(device)
attention_mask = tokenized_results["attention_mask"].to(device)


In [12]:
output_attention = True

In [13]:
deprot = deprot.to(device)  # Move the model to the same device as the tensors

outputs = deprot(
    input_ids=input_ids,
    attention_mask=attention_mask,
    ss_input_ids=ss_input_ids,
    labels=input_ids,
    output_attentions=output_attention,  # No need to call .to(device) on a boolean
    return_dict=output_attention
)


In [14]:
logits = outputs.logits[0]
logits = torch.log_softmax(logits[1:-1, :], dim=-1)

In [17]:
logits

tensor([[-16.9099, -12.5894, -12.2322,  ...,  -8.7536, -11.9130, -16.8715],
        [-22.4552, -17.1666, -18.0220,  ..., -14.3106, -19.9979, -22.4330],
        [-21.9546, -16.4669, -17.9962,  ..., -11.6968, -19.0950, -21.9827],
        ...,
        [-22.1040, -16.7349, -17.2137,  ...,  -3.3649, -18.5772, -22.1109],
        [-22.3519, -17.0825, -17.8749,  ...,  -5.2220, -19.3547, -22.3801],
        [-22.8444, -15.9082, -17.2419,  ...,  -4.3891, -19.0890, -22.8425]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward0>)

In [18]:
tokenizer

EsmTokenizer(name_or_path='AI4Protein/ProSST-2048', vocab_size=25, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	23: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	24: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [19]:
input_ids.size()

torch.Size([1, 163])

In [20]:
outputs.attentions[-1].size()

torch.Size([1, 12, 163, 163])

In [21]:
len(outputs.attentions)

12

In [22]:
aa_seq_aln_file = "./data/proteingym_v1/aa_seq_aln_a2m/A0A1I9GEU1_NEIME_Kennouche_2019.a2m"

alignment_dict = read_multi_fasta(aa_seq_aln_file)

In [23]:
import random

In [24]:
sample_times = 1
sample_ratio=1.0

alignment_matrix, aln_start, aln_end = count_matrix_from_residue_alignment(tokenizer, alignment_dict)


>>> Alignment start: 1, end: 161
>>> Start tokenizing 5553 residue alignment sequences


In [62]:
tokenizer.decode("135")

'<cls> A D'

In [25]:
alignment_matrix.size()

torch.Size([5553, 161])

In [26]:
alignment_matrix

tensor([[ 7, 19, 12,  ..., 18,  3, 18],
        [ 0,  0,  0,  ..., 18,  3,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 7, 19, 12,  ...,  0,  0,  0],
        [ 7, 19, 12,  ...,  0,  0,  0],
        [23,  7, 12,  ...,  0,  0,  0]])

In [27]:
for sample in range(sample_times):
    if sample_ratio < 1.0:
        print(f">>> Sample {sample+1}/{sample_times} with ratio {sample_ratio}")
        sample_size = int(len(alignment_matrix) * sample_ratio)
        sample_indices = random.sample(range(len(alignment_matrix)), sample_size)
        alignment_matrix_sample = alignment_matrix[sample_indices]
    else:
        alignment_matrix_sample = alignment_matrix


In [28]:
alignment_matrix_sample.size()

torch.Size([5553, 161])

In [29]:
outputs.attentions[-1].mean(dim=1)[0][1:-1,1:-1].size()

torch.Size([161, 161])

In [30]:
help(torch.bincount)

Help on built-in function bincount in module torch:

bincount(...)
    bincount(input, weights=None, minlength=0) -> Tensor
    
    Count the frequency of each value in an array of non-negative ints.
    
    The number of bins (size 1) is one larger than the largest value in
    :attr:`input` unless :attr:`input` is empty, in which case the result is a
    tensor of size 0. If :attr:`minlength` is specified, the number of bins is at least
    :attr:`minlength` and if :attr:`input` is empty, then the result is tensor of size
    :attr:`minlength` filled with zeros. If ``n`` is the value at position ``i``,
    ``out[n] += weights[i]`` if :attr:`weights` is specified else
    ``out[n] += 1``.
    
    Note:
        This operation may produce nondeterministic gradients when given tensors on a CUDA device. See :doc:`/notes/randomness` for more information.
    
    Arguments:
        input (Tensor): 1-d int tensor
        weights (Tensor): optional, weight for each value in the input tens

In [31]:
[outputs.attentions[-1][0,i][aln_start:aln_end,aln_start:aln_end].size() for i in range(12)]

[torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161]),
 torch.Size([161, 161])]

In [32]:
outputs.attentions[-1].var(dim=1)[0][1:-1,1:-1]

tensor([[2.5114e-05, 4.6668e-03, 1.4003e-03,  ..., 4.5770e-07, 3.0292e-06,
         8.4082e-06],
        [1.4714e-03, 4.4108e-07, 1.1292e-03,  ..., 2.6650e-07, 8.8435e-07,
         5.0666e-07],
        [3.5260e-03, 4.6967e-04, 1.0013e-07,  ..., 1.7462e-07, 1.7166e-06,
         8.5394e-07],
        ...,
        [2.7683e-07, 4.5892e-07, 2.8372e-07,  ..., 9.6783e-08, 2.2760e-03,
         1.6854e-03],
        [3.0547e-06, 5.0321e-07, 7.0277e-07,  ..., 3.1275e-03, 3.8788e-08,
         3.2378e-03],
        [9.1764e-07, 1.0339e-07, 3.7095e-07,  ..., 1.0505e-03, 9.9248e-04,
         3.1883e-06]], device='cuda:0', grad_fn=<SliceBackward0>)

In [33]:
outputs.attentions[-1].mean(dim=1)[0][1:-1,1:-1]

tensor([[0.0023, 0.0311, 0.0163,  ..., 0.0004, 0.0009, 0.0013],
        [0.0241, 0.0004, 0.0180,  ..., 0.0003, 0.0005, 0.0006],
        [0.0272, 0.0132, 0.0002,  ..., 0.0003, 0.0008, 0.0007],
        ...,
        [0.0003, 0.0003, 0.0003,  ..., 0.0002, 0.0299, 0.0278],
        [0.0008, 0.0004, 0.0004,  ..., 0.0340, 0.0002, 0.0435],
        [0.0006, 0.0003, 0.0004,  ..., 0.0239, 0.0257, 0.0009]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [34]:
from huggingface_hub import login

In [35]:
# Replace 'your_access_token' with your actual token
login(token='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')

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /home/yining_yang/.cache/huggingface/token
Login successful


In [36]:
import esm

In [None]:
# from transformers import AutoModel
# esm = AutoModel.from_pretrained("EvolutionaryScale/esm3-sm-open-v1",trust_remote_code=True)

OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like EvolutionaryScale/esm3-sm-open-v1 is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

In [42]:
msa_transformer, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()


In [67]:
msa_alphabet.encode("AC-F")

[5, 23, 30, 18]

In [79]:
help(msa_alphabet)

Help on Alphabet in module esm.data object:

class Alphabet(builtins.object)
 |  Alphabet(standard_toks: Sequence[str], prepend_toks: Sequence[str] = ('<null_0>', '<pad>', '<eos>', '<unk>'), append_toks: Sequence[str] = ('<cls>', '<mask>', '<sep>'), prepend_bos: bool = True, append_eos: bool = False, use_msa: bool = False)
 |  
 |  Methods defined here:
 |  
 |  __init__(self, standard_toks: Sequence[str], prepend_toks: Sequence[str] = ('<null_0>', '<pad>', '<eos>', '<unk>'), append_toks: Sequence[str] = ('<cls>', '<mask>', '<sep>'), prepend_bos: bool = True, append_eos: bool = False, use_msa: bool = False)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  __len__(self)
 |  
 |  encode(self, text)
 |  
 |  get_batch_converter(self, truncation_seq_length: int = None)
 |  
 |  get_idx(self, tok)
 |  
 |  get_tok(self, ind)
 |  
 |  to_dict(self)
 |  
 |  tokenize(self, text, **kwargs) -> List[str]
 |      Inspired by https://github.com/huggingface/transforme

In [57]:
msa_alphabet.all_special_tokens

['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']

In [86]:
msa_alphabet.encode("ACEF")

[5, 23, 9, 18]

In [108]:
list(alignment_dict.values())[:5]

['FTLIELMIVIAIVGILAAVALPAYQDYTARAQVSEAILLAEGQKSAVTEYYLNHGEWPGDNSSAGVATSADIKGKYVQSVTVANGVITAQMASSNVNNEIKSKKLSLWAKRQNGSVKWFCGQPVTRTTATATDVAAANGKTDDKINTKHLPSTCRDDSSAS',
 '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>ARAQVSEAILLAEGQKSAVTEYYLNHGKWPENNDSAGVASASKIIGKYVKSVTVTNGVVTAQMKPSGVNNEIKDKRLSLWAKREDGSVKWFCGQPVKRDNVAAADDDVTDDKNNNGIDTKHLPSTCRDKSSA<pad>',
 '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>ARAQVSEAILLAEGQKSAVTEYYLNHGKWPENNGDAGVASASKIIGKYVKQVEVKNGVVTATMNSSNVNKEIQGKRLSLWAKRQDGSVKWFCGQPVK<pad><pad>RNANAKDDAVTADKDKEIETKHLPSTCRDES<pad><pad><pad>',
 'FTLIELMIVIAIVGILAAVALPAYQDYTARAQVSEAILLAEGQKSAVTEYYLNHGIWPENNPA<pad>GVASASDIKGKYVQSVTVANGVVTAQMKSDGVNKEIKNKKLSLWARREAGSVKWFCGQPVTRDNAGTDAVTADTTGKDKEIDTKHLPSTCR<pad><pad><pad><pad><pad><pad>',
 '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

In [163]:
alignment_dict_esm_msa_dict = read_multi_fasta_for_esm_msa(aa_seq_aln_file)

In [None]:

msa_batch_converter = msa_alphabet.get_batch_converter()
# prot_ind=0
# msa = "\n".join(list(alignment_dict_esm_msa_dict.values())[:-1])
# # msa_labels = ["A0A1I9GEU1"]  # Label for the MSA; adjust as necessary
# msa_batch = [("PROT1", msa)]



In [173]:
msa_batch = list(alignment_dict_esm_msa_dict.items())[:-1]

In [179]:
msa_labels, msa_strs, msa_tokens = msa_batch_converter(msa_batch)

In [180]:
# Print a summary of the inputs.
print("MSA Labels:", msa_labels)
print("MSA Strings:\n", msa_strs[0])  # Print the msa for the first entry.

# Move the tokens to the appropriate device (GPU if available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
msa_tokens = msa_tokens.to(device)
msa_transformer = msa_transformer.to(device)


MSA Labels: [['>A0A1I9GEU1_NEIME/1-161', '>UniRef100_UPI0018A25760/3-135', '>UniRef100_UPI00145ECBA7/3-130', '>UniRef100_Q00045/8-162', '>UniRef100_A0A7U1LX12/3-136', '>UniRef100_UPI000BA48856/3-130', '>UniRef100_UPI0005E9AABB/13-136', '>UniRef100_UPI0013E0C905/9-133', '>UniRef100_UPI0018C26155/1-133', '>UniRef100_D6HB30/9-137', '>UniRef100_B0FXJ6/2-160', '>UniRef100_UPI0005E0CA5C/8-166', '>UniRef100_UPI000C3233B7/8-166', '>UniRef100_UPI000C332612/8-166', '>UniRef100_UPI000766DABE/8-165', '>UniRef100_UPI000C323150/8-166', '>UniRef100_UPI000C34031D/8-165', '>UniRef100_UPI000C33F616/8-166', '>UniRef100_P57039/8-168', '>UniRef100_UPI000C326753/8-166', '>UniRef100_UPI000C33E39E/8-166', '>UniRef100_UPI000C330815/8-165', '>UniRef100_UPI000C321D0A/8-166', '>UniRef100_A0A0G4BXR5/8-166', '>UniRef100_UPI000766626E/8-166', '>UniRef100_UPI0018A26FAC/12-140', '>UniRef100_UPI000E57AB74/8-165', '>UniRef100_UPI000FC9B0E8/8-165', '>UniRef100_A0A807D5K8/8-165', '>UniRef100_UPI000E58367C/8-166', '>UniRef

In [181]:
with torch.no_grad():
    # Specify the representation layers you want.
    # Here, we extract representations from layer 12.
    results = msa_transformer(msa_tokens, repr_layers=[12], return_contacts=False)
    token_representations = results["representations"][12]
    # token_representations shape: (batch, num_seqs, seq_length, representation_dim)


RuntimeError: Using model with MSA position embedding trained on maximum MSA depth of 1024, but received 5552 alignments.

In [177]:
msa_alphabet.all_toks

['<cls>',
 '<pad>',
 '<eos>',
 '<unk>',
 'L',
 'A',
 'G',
 'V',
 'S',
 'E',
 'R',
 'T',
 'I',
 'D',
 'P',
 'K',
 'Q',
 'N',
 'F',
 'Y',
 'M',
 'H',
 'W',
 'C',
 'X',
 'B',
 'U',
 'Z',
 'O',
 '.',
 '-',
 '<null_1>',
 '<mask>']

In [183]:
msa_batch

[('>A0A1I9GEU1_NEIME/1-161',
  'FTLIELMIVIAIVGILAAVALPAYQDYTARAQVSEAILLAEGQKSAVTEYYLNHGEWPGDNSSAGVATSADIKGKYVQSVTVANGVITAQMASSNVNNEIKSKKLSLWAKRQNGSVKWFCGQPVTRTTATATDVAAANGKTDDKINTKHLPSTCRDDSSAS'),
 ('>UniRef100_UPI0018A25760/3-135',
  '----------------------------ARAQVSEAILLAEGQKSAVTEYYLNHGKWPENNDSAGVASASKIIGKYVKSVTVTNGVVTAQMKPSGVNNEIKDKRLSLWAKREDGSVKWFCGQPVKRDNVAAADDDVTDDKNNNGIDTKHLPSTCRDKSSA-'),
 ('>UniRef100_UPI00145ECBA7/3-130',
  '----------------------------ARAQVSEAILLAEGQKSAVTEYYLNHGKWPENNGDAGVASASKIIGKYVKQVEVKNGVVTATMNSSNVNKEIQGKRLSLWAKRQDGSVKWFCGQPVK--RNANAKDDAVTADKDKEIETKHLPSTCRDES---'),
 ('>UniRef100_Q00045/8-162',
  'FTLIELMIVIAIVGILAAVALPAYQDYTARAQVSEAILLAEGQKSAVTEYYLNHGIWPENNPA-GVASASDIKGKYVQSVTVANGVVTAQMKSDGVNKEIKNKKLSLWARREAGSVKWFCGQPVTRDNAGTDAVTADTTGKDKEIDTKHLPSTCR------'),
 ('>UniRef100_A0A7U1LX12/3-136',
  '----------------------------ARAQVSEAILLAEGQKSAVTEYYLNHGEWPENNTSAGVASADKIKGKYVQKVEVAKGVVTAQMASSNVNKEIKDKKLSLWARREDGSVKWFCGQPVTRGAGNKADDVTKAGNDNEKINTKHLPSTCRDKST--'

In [184]:
alignment_dict_esm_msa_dict = read_multi_fasta_for_esm_msa(aa_seq_aln_file)

# Convert the dictionary items to a list of tuples (label, msa_string)
msa_batch = list(alignment_dict_esm_msa_dict.items())

# Optionally, remove the last entry if needed (as in your original code)
msa_batch = msa_batch[:-1]

# Subsample if the MSA depth exceeds the maximum allowed (e.g., 1024 sequences)
max_depth = 1024

# If the number of alignments exceeds max_depth, subsample the list
if len(msa_batch) > max_depth:
    print(f"MSA batch has {len(msa_batch)} alignments; subsampling to {max_depth}")
    msa_batch = random.sample(msa_batch, max_depth)

# Initialize the batch converter using the msa_alphabet from the pretrained model.
msa_batch_converter = msa_alphabet.get_batch_converter()

# Convert the (subsampled) msa_batch to tokens.
msa_labels, msa_strs, msa_tokens = msa_batch_converter(msa_batch)

# Print a summary of the inputs.
print("MSA Labels:", msa_labels)
print("MSA Strings:\n", msa_strs[0])  # Print the MSA for the first entry.

# Move the tokens to the appropriate device (GPU if available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
msa_tokens = msa_tokens.to(device)
msa_transformer = msa_transformer.to(device)

with torch.no_grad():
    # Specify the representation layers you want.
    # Here, we extract representations from layer 12.
    results = msa_transformer(msa_tokens, repr_layers=[12], return_contacts=False)
    token_representations = results["representations"][12]
    # token_representations shape: (batch, num_seqs, seq_length, representation_dim)


MSA batch has 5552 alignments; subsampling to 1024
MSA Labels: [['>UniRef100_UPI000646B84C/8-158', '>UniRef100_UPI00036ABE6C/7-142', '>UniRef100_UPI00061B9051/13-140', '>UniRef100_A0A6G8S729/7-178', '>UniRef100_UPI0012B36174/20-150', '>UniRef100_A0A154QKC9/10-146', '>UniRef100_UPI00145CBC7F/8-140', '>UniRef100_UPI00076655CB/8-165', '>UniRef100_D4P890/8-149', '>UniRef100_A0A495I1W9/9-175', '>UniRef100_A0A7T9UUD7/8-179', '>UniRef100_UPI0012B36828/3-126', '>UniRef100_UPI0015DE0C70/9-131', '>UniRef100_UPI0018A27A2E/12-135', '>UniRef100_A0A4P7C1Z3/8-164', '>UniRef100_N8XK23/6-140', '>UniRef100_N8R3C7/8-182', '>UniRef100_A0A849RDX7/8-180', '>UniRef100_UPI0008A5E0B2/11-149', '>UniRef100_A0A7L6A5K2/8-143', '>UniRef100_A0A496JU22/9-144', '>UniRef100_UPI00076681EC/8-167', '>UniRef100_UPI000FF661E4/7-178', '>UniRef100_UPI0013B4560C/36-165', '>UniRef100_UPI001893919B/9-133', '>UniRef100_A0A4R2MXF6/9-161', '>UniRef100_UPI00155AECAF/30-157', '>UniRef100_UPI0009E894F9/8-133', '>UniRef100_UPI0012412C1

OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacty of 10.90 GiB of which 631.06 MiB is free. Including non-PyTorch memory, this process has 10.05 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 331.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [89]:
sample_times = 1
sample_ratio=1.0

alignment_matrix_msa1b, aln_start, aln_end = count_matrix_from_residue_alignment(msa_alphabet.encode, alignment_dict)

>>> Alignment start: 1, end: 161
>>> Start tokenizing 5553 residue alignment sequences


TypeError: encode() got an unexpected keyword argument 'return_tensors'

In [43]:
msa_transformer

MSATransformer(
  (embed_tokens): Embedding(33, 768, padding_idx=1)
  (dropout_module): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x AxialTransformerLayer(
      (row_self_attention): NormalizedResidualBlock(
        (layer): RowSelfAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (dropout_module): Dropout(p=0.1, inplace=False)
        )
        (dropout_module): Dropout(p=0.1, inplace=False)
        (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (column_self_attention): NormalizedResidualBlock(
        (layer): ColumnSelfAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768,

In [45]:
alignment_dict.items

<function dict.items>

In [90]:
len(alignment_dict.items())

5553

In [88]:
msa_alphabet.all_toks

['<cls>',
 '<pad>',
 '<eos>',
 '<unk>',
 'L',
 'A',
 'G',
 'V',
 'S',
 'E',
 'R',
 'T',
 'I',
 'D',
 'P',
 'K',
 'Q',
 'N',
 'F',
 'Y',
 'M',
 'H',
 'W',
 'C',
 'X',
 'B',
 'U',
 'Z',
 'O',
 '.',
 '-',
 '<null_1>',
 '<mask>']

In [87]:
# Load the ESM-MSA-1b model and tokenizer from Hugging Face Hub.
msa_model_name = "EvolutionaryScale/esm3-sm-open-v1" # or another MSA-capable model
msa_tokenizer = AutoTokenizer.from_pretrained(msa_model_name)
msa_model = AutoModelForMaskedLM.from_pretrained(msa_model_name).to(device)
msa_model.eval()  # set to evaluation mode


OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like EvolutionaryScale/esm3-sm-open-v1 is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.