In [1]:
%%capture
!pip install transformers
!pip install torch torchvision torchaudio
!pip install sentencepiece

In [2]:
import subprocess
import multiprocessing as mp
import sys
import time
from transformers import T5EncoderModel, T5Tokenizer
import torch
import h5py
import pickle
import re

In [3]:
def get_T5_model(device, model_name="Rostlab/prot_t5_xl_half_uniref50-enc"):
    '''
        Load a ProtT5 encoder and its matching tokenizer.

        Parameters
        ----------
        device : torch.device or str
            Device the encoder weights are moved to (e.g. a CUDA GPU).
        model_name : str
            Hugging Face checkpoint used for BOTH the encoder and the tokenizer,
            so the two can never get out of sync.
            Default: "Rostlab/prot_t5_xl_half_uniref50-enc".

        Returns
        -------
        (model, tokenizer)
            `model` is a T5EncoderModel on `device`, switched to evaluation mode;
            `tokenizer` is the T5Tokenizer for the same checkpoint, created with
            do_lower_case=False.
    '''
    model = T5EncoderModel.from_pretrained(model_name)
    model = model.to(device)  # move model to the requested device (GPU if available)
    model = model.eval()      # set model to evaluation mode (disables dropout)
    tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)

    return model, tokenizer

def read_fasta( fasta_path, split_char="!", id_field=0):
    '''
        Read a (multi-)FASTA file into a dict mapping identifier -> sequence.

        Identifiers come from header lines (lines starting with '>'): every '>'
        is removed, the line is stripped, split on `split_char`, and field
        `id_field` is kept. E.g. set split_char="|" and id_field=1 for
        SwissProt/UniProt headers; with the default "!" (normally absent from
        headers) the whole header is used. '/' and '.' in identifiers are
        replaced by '_'.

        Sequence lines of one record are concatenated; whitespace and gap
        characters ('-') are removed, residues are upper-cased, and the rare
        amino acids U, Z, O and B are mapped to the unknown residue X.

        Parameters
        ----------
        fasta_path : str or path-like
            FASTA file to read.
        split_char : str
            Separator used to split header lines.
        id_field : int
            Index of the header field used as identifier.

        Returns
        -------
        dict
            Identifier -> sequence, in file order. An empty file yields {}.
            If an identifier occurs more than once, a warning is printed and
            the later record replaces the earlier one.

        Raises
        ------
        ValueError
            If sequence data appears before the first header line.
        IndexError
            If a header has fewer than `id_field + 1` fields.
    '''
    seqs = dict()
    uniprot_id = None  # identifier of the record currently being read
    with open( fasta_path, 'r' ) as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip().split(split_char)[id_field]
                # replace tokens that are mis-interpreted when loading h5
                uniprot_id = uniprot_id.replace("/","_").replace(".","_")
                if uniprot_id in seqs:
                    # the earlier record would otherwise vanish without notice
                    print("Warning: duplicate identifier {}; keeping only its last record.".format(uniprot_id))
                seqs[ uniprot_id ] = ''
            else:
                # remove all white-space chars (joins seqs spanning multiple lines), drop gaps and cast to upper-case
                seq = ''.join( line.split() ).upper().replace("-","")
                if not seq:
                    continue  # blank line
                if uniprot_id is None:
                    raise ValueError("Sequence data found before the first FASTA header in {}".format(fasta_path))
                # map rare/non-standard AAs (U, Z, O, B) to unknown/X, as recommended for ProtT5
                seq = re.sub(r"[UZOB]", "X", seq)
                seqs[ uniprot_id ] += seq
    print("Read {} sequences.".format(len(seqs)))
    if seqs:
        example_id = next(iter(seqs))
        print("Example:\n{}\n{}".format(example_id, seqs[example_id]))

    return seqs


def get_embeddings(model, tokenizer, seqs, device, max_residues = 4000, max_seq_len = 1000, max_batch = 100):
    '''
        Compute one mean-pooled embedding per protein.

        Sequences are processed longest-first (reduces padding per batch) and
        collected into mini-batches. After appending a sequence, the batch is
        sent through the model when any of these holds:
          * it holds `max_batch` sequences,
          * its residue count plus the just-added length reaches `max_residues`
            (look-ahead so oversized batches are avoided),
          * the just-added sequence is longer than `max_seq_len`,
          * the last sequence has been added.

        Parameters
        ----------
        model : callable
            Called as model(input_ids, attention_mask=...); its result must
            expose `.last_hidden_state` indexed as [batch, position, feature].
        tokenizer : object
            Provides `batch_encode_plus` returning 'input_ids' and 'attention_mask'.
        seqs : dict
            Identifier -> amino-acid sequence. Not modified.
        device : torch.device or str
            Device the input tensors are moved to.
        max_residues, max_seq_len, max_batch : numbers
            Batching limits described above.

        Returns
        -------
        dict
            {"protein_embs": {identifier: 1-D numpy array}}, each array being
            the mean of the last hidden state over the residue positions only
            (the appended special token and padding are excluded). Proteins of
            a mini-batch that raised RuntimeError (e.g. CUDA out-of-memory) are
            absent; their identifiers are printed.
    '''
    results = {"protein_embs" : dict()}

    # sort sequences by length, longest first (reduces unnecessary padding --> speeds up embedding)
    seq_items = sorted(seqs.items(), key=lambda kv: len(kv[1]), reverse=True)
    n_total = len(seq_items)
    start = time.time()
    batch = list()

    for seq_idx, (pdb_id, seq) in enumerate(seq_items, 1):
        seq_len = len(seq)
        # the tokenizer expects residues separated by single spaces
        batch.append((pdb_id, ' '.join(seq), seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed
        n_res_batch = sum(s_len for _, _, s_len in batch) + seq_len
        flush = (len(batch) >= max_batch
                 or n_res_batch >= max_residues
                 or seq_idx == n_total
                 or seq_len > max_seq_len)
        if not flush:
            continue

        # distinct names so the `seqs` argument is not rebound
        batch_ids, batch_seqs, batch_lens = zip(*batch)
        batch = list()

        # add_special_tokens adds extra token at the end of each sequence
        token_encoding = tokenizer.batch_encode_plus(batch_seqs, add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(token_encoding['input_ids']).to(device)
        attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
        try:
            with torch.no_grad():
                # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
                embedding_repr = model(input_ids, attention_mask=attention_mask)
        except RuntimeError as err:
            # the whole mini-batch is lost, so report every affected identifier
            print("RuntimeError during embedding of {} sequence(s) (max L={}): {}\nSkipped: {}".format(
                len(batch_ids), max(batch_lens), err, ", ".join(map(str, batch_ids))))
            continue

        for batch_idx, (identifier, s_len) in enumerate(zip(batch_ids, batch_lens)):
            # slice off the appended special token and padding --> s_len x embedding_dim
            emb = embedding_repr.last_hidden_state[batch_idx, :s_len]
            # average over residue positions --> one vector per protein
            protein_emb = emb.mean(dim=0)
            results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()
    print("Time using batching:\t\t" ,time.time()-start)

    return results

In [4]:
# Use the first CUDA GPU when one is present, otherwise the CPU, and load ProtT5 there.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model, tokenizer = get_T5_model(device)

Downloading (…)lve/main/config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

In [5]:
# Input FASTA holding the protein sequences to embed
seq_path = "/content/ORF2p_v3.fasta"
seqs = read_fasta(fasta_path=seq_path)

Read 721 sequences.
Example:
L1M5_orf2DF0000008
MVDLNPXISIITLNVNGLNTPIKRQRLSDWIKKQDPTICCLQETHFKYKDTXRLKVKGWKKIYHANTXQKKAGVAILISDKVDFKARNITRDKEGHFIMIKGSIHQEDITILNVYAPNNRASKYMKQKLTELKGEIDKSTIIVGDFNTPLSVIDRTSRQKISKDIEDLNNTINQLDLIDIYRTLHPTTAEYTFFSSAHGTFTKIDHXLGHKTSLNKFXRIEIIQSMFSDHNGIKLEINNKKIXRKSPNIWKLNNTLLNNPWVKEEITXEIRKYFELNXNENTTYQNLWDAAKAVLRGKFIALNAYIRKEERLKINDLSFHLKKLEKEEQIKPKXSRRKEIIKIRAEINEIENKXTIEKINEAKSWFFEKINKIDKPLARLIKKKREKTQITNIRNEKGDITTDPTDIKRIIREYYEQLYANKFXNLDEMDKFLEKHXLPKLTQEEIENLNSPISIKEIEFVIKNLPTKKTPGPDGFTGEFYQTFKEEIIPILHKLFQKIEKEGTLPNSFYEASXTLIPKPDKDITRKENYRPISLMNIDAKILNKILANRIQQYIKRIIHHDQVGFIPGMQGWFNIRKSINVIHHINRIKEKNHMIISIDAEKAFDKIQHPFMIKTLSKLGIEGNFLNLIKGIYEKPTANIILNGEXLXAFPLRSGTRQGCPLSPLLFNIVLEVLASAIRQEKEIKGIKIGKEEIKLSLFADDMIVYVENPKESTXKLLELISEFSKVAGYKVNIQKSIVFLYTSNKQLENEIXKTIPFTIASKNXKYLGINLTKXVQDLYTENYKTLLREIKEDLNKWRDIPCSWIGRLNIVKMSILPKLIYRFNAIPIKIPAGFFVEIDKLILKFIWKCKGPRIAKTILKKKNKVGGLTLPDXKTYYKATVIKTVWYWRKDRQIDQWNRIESPEIDPHIYGXLIFDKGAXAIQWGKDSLFNKWCWNNWISIWKKMXLDPY

In [6]:
# Batching limits: residue budget per mini-batch, length above which a sequence
# closes its mini-batch immediately, and maximum number of sequences per mini-batch.
results = get_embeddings(
    model,
    tokenizer,
    seqs,
    device,
    max_residues=4 * 1535.0,
    max_seq_len=1535,
    max_batch=10,
)

Time using batching:		 1044.9500122070312


# Export results as a pickle file and check the embedding length

In [10]:
# Store data as pickle file
# Persist only the identifier -> embedding mapping, not the wrapping results dict
with open('embeddings_v2.pickle', mode='wb') as pickle_out:
    pickle.dump(results['protein_embs'], pickle_out)

In [None]:
# Load data as pickle
# Read back the embeddings from the same working-directory-relative path the export
# cell writes to (a hard-coded /content/ prefix only matches when cwd is /content).
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this notebook.
with open('embeddings_v2.pickle', 'rb') as handle:
    a = pickle.load(handle)

In [16]:
# Spot check: length of one reloaded embedding vector
example_emb = a['TapTer-5_461DF0280958']
len(example_emb)

1024