# Transform MSAs into the corresponding embeddings of the MSA Transformer

In [5]:
import esm
import torch
import os
from Bio import SeqIO
import itertools
from typing import List, Tuple
import string
import time
from tqdm.notebook import tqdm

In [6]:
# Turn autograd off globally: this notebook only runs forward passes, so no gradients need to be recorded.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f69b043b820>

In [7]:
# Translation table used to strip alignment insertions in a single pass:
# every lowercase letter plus "." and "*" is mapped to None, i.e. deleted by str.translate.
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file.

    Args:
        filename: Path of the file, parsed with Biopython's "fasta" format.

    Returns:
        A ``(description, sequence)`` pair for the first record in the file.
    """
    records = SeqIO.parse(filename, "fasta")
    first_record = next(records)
    description = first_record.description
    sequence = str(first_record.seq)
    return description, sequence

def remove_insertions(sequence: str) -> str:
    """Strip insertion characters from an aligned sequence.

    Every character listed in the module-level ``translation`` table is deleted,
    which is needed before the aligned sequences of an MSA can be loaded.

    Args:
        sequence: One sequence string taken from an MSA file.

    Returns:
        The sequence with the characters of the ``translation`` table removed.
    """
    without_insertions = sequence.translate(translation)
    return without_insertions

def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
    """Read at most the first ``nseq`` sequences of an MSA file.

    Insertions are removed from every sequence through ``remove_insertions``.

    Args:
        filename: Path of the MSA file, parsed with Biopython's "fasta" format.
        nseq: Maximum number of records to read from the top of the file.

    Returns:
        A list of ``(description, sequence)`` pairs, in file order.
    """
    first_records = itertools.islice(SeqIO.parse(filename, "fasta"), nseq)
    msa = []
    for record in first_records:
        cleaned = remove_insertions(str(record.seq))
        msa.append((record.description, cleaned))
    return msa

## Import the MSA Transformer

In [8]:
# This can take a while to download (1.3 GB), but once done, the model is kept in memory
msa_transformer, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_transformer = msa_transformer.eval()  # switch the model to evaluation mode
msa_batch_converter = msa_alphabet.get_batch_converter()  # callable that turns lists of (label, sequence) MSAs into token tensors

## Read  MSA file

In [9]:
# Maximum number of sequences read from the top of each MSA file (passed to read_msa as nseq).
NB_seqs_per_msa = 100
# NOTE(review): NB_msa is not referenced by any visible cell, and it differs from the
# number of files printed further down (3032) — confirm whether it is still needed.
NB_msa= 4431
# Alternative kept for reference: request all layers 1..12 at once.
#mb_layers = [i for i in range(1,13)] #we get all the intermediate layers at once
# Layers whose representations are requested from the model (passed as repr_layers) and saved.
Emb_layers = [2,6,11]

In [10]:
# Input and output locations. Both are hard-coded absolute paths; adjust them before running elsewhere.
PATH_TO_MSAs = "C:/Users/caspe/PycharmProjects/NLP/MSA_a3m/" #where to read the MSAs
EMBEDDINGS_PATH = "C:/Users/caspe/PycharmProjects/NLP/MSA_embeddings/" #where to save the embeddings

# Save into a dedicated sub-folder. makedirs(exist_ok=True) also creates a missing
# parent folder and is safe to re-run, unlike an isdir() check followed by mkdir().
EMBEDDINGS_PATH += 'MSA_transformer_embeddings/'
os.makedirs(EMBEDDINGS_PATH, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort them so every run
# processes (and batches) the MSA files in the same order.
MSAs = sorted(os.listdir(PATH_TO_MSAs))

In [11]:
# Number of directory entries found in PATH_TO_MSAs (one per MSA file, assuming the folder holds nothing else).
print(len(MSAs))

3032


In [12]:
# format MSA string as needed
#for f in MSAs:
#    with open(PATH_TO_MSAs+f, "r") as msa_file:
#        content = msa_file.read()

#    content = content.replace(">", ">\n")
#    with open(PATH_TO_MSAs+f, "w") as msa_file:
#        msa_file.write(content)

In [13]:
# Read every MSA file and tokenize all of them as one batch.
# NOTE(review): the embedding loop further down re-reads the files one at a time,
# so this cell looks like a sanity check only; with many files it may use a lot of memory.
msa_data = [read_msa(PATH_TO_MSAs+f, NB_seqs_per_msa) for f in MSAs]
msa_batch_labels, msa_batch_strs, msa_batch_tokens = msa_batch_converter(msa_data)
print(msa_batch_tokens.size(), msa_batch_tokens.dtype)
# Should be a 3D tensor with dtype torch.int64.
# Presumably of shape (number of MSAs, sequences per MSA, padded sequence length) — TODO confirm the axis meaning.

torch.Size([5, 10, 184]) torch.int64


## Run the MSA transformer

In [16]:
# Choose the compute device: the first CUDA GPU when one is available, otherwise the CPU.
torch.cuda.empty_cache()
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Left disabled: moving the full token batch to the device.
#msa_batch_tokens = msa_batch_tokens.to(device)
#print(msa_batch_tokens.is_cuda)

# Move the model to the selected device.
msa_transformer = msa_transformer.to(device)

10.317943334579468


In [11]:
#torch.cuda.empty_cache()
#torch.cuda.memory_stats(device)

torch.Size([5, 10, 184, 768])

In [None]:
# This is the long part: embed every MSA file, one file per forward pass so that a
# single MSA at a time has to fit on the device, and time the whole run.
t1 = time.time()
for msa_file in tqdm(MSAs):
    # Drop the file extension whatever its length (the old [:-4] slice assumed exactly 3 characters).
    msa_name = os.path.splitext(msa_file)[0]
    # One output folder per MSA; exist_ok lets an interrupted run be restarted without crashing.
    msa_output_dir = EMBEDDINGS_PATH + msa_name
    os.makedirs(msa_output_dir, exist_ok=True)

    # The batch converter expects a list of MSAs, so wrap the single MSA in a list.
    msa_data = [read_msa(PATH_TO_MSAs + msa_file, NB_seqs_per_msa)]
    msa_batch_labels, msa_batch_strs, msa_batch_tokens = msa_batch_converter(msa_data)
    torch.cuda.empty_cache()
    msa_batch_tokens = msa_batch_tokens.to(device)
    results = msa_transformer(msa_batch_tokens, repr_layers=Emb_layers)

    # Keep only index 0 along the second axis. clone() makes torch.save store just the
    # slice rather than the whole underlying tensor storage; cpu() makes the saved file
    # loadable on machines without a GPU.
    embeddings = [
        results["representations"][emb_layer][:, 0, :, :].clone().cpu()
        for emb_layer in Emb_layers
    ]

    # Save ONE tensor per layer file. Previously the whole `embeddings` list (all layers)
    # was written into every per-layer file, so each file held the same data.
    for emb_layer, layer_embedding in zip(Emb_layers, embeddings):
        torch.save(
            layer_embedding,
            f"{msa_output_dir}/embeddings_layer_{emb_layer}_MSA_Transformer.pt",
        )

t2 = time.time()
print(t2-t1)

In [12]:
# Inspect the shape of one requested layer. Only the layers listed in Emb_layers were
# passed as repr_layers, so index with one of them (the hard-coded key 1 is not in Emb_layers).
results["representations"][Emb_layers[0]].shape
# Should have the same leading dimensions as msa_batch_tokens plus a final embedding dimension of 768.

In [None]:
# For each requested layer, keep only index 0 along the second axis and copy it
# so the result no longer shares memory with the full representation tensor.

embeddings = [
    results["representations"][layer][:, 0].clone()
    for layer in Emb_layers
]

In [None]:
# Save ONE tensor per layer file. `embeddings` is ordered like Emb_layers, so pair them up;
# previously the whole list (all layers) was written into every per-layer file.
# cpu() keeps the saved files loadable on machines without a GPU.
for emb_layer, layer_embedding in zip(Emb_layers, embeddings):
    torch.save(layer_embedding.cpu(), EMBEDDINGS_PATH+'embeddings_layer_'+str(emb_layer)+'_MSA_Transformer.pt')