In [None]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
import time
import joblib
from load_dataset import read_fasta
from tqdm import tqdm
import gc
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device: {}".format(device))

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0


In [2]:
gc.collect()

0

In [3]:
#@title Generate embeddings. { display-mode: "form" }
# Generate embeddings via batch-processing
# per_residue indicates that embeddings for each residue in a protein should be returned.
# per_protein indicates that embeddings for a whole protein should be returned (average-pooling)
# max_residues gives the upper limit of residues within one batch
# max_seq_len gives the upper sequences length for applying batch-processing
# max_batch gives the upper number of sequences per batch
def get_embeddings( model, tokenizer, seqs, per_residue, per_protein, 
                   max_residues=4000, max_seq_len=1000, max_batch=100 ):


    results = {"emb": {}, 
                'protein_embs': {}
                }

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    seq_dict   = sorted( seqs.items(), key=lambda kv: len( seqs[kv[0]] ), reverse=True )
    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in tqdm(enumerate(seq_dict,1)):
        seq = seq
        seq_len = len(seq)
        seq = ' '.join(list(seq))
        batch.append((pdb_id,seq,seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed 
        if len(batch) >= max_batch or seq_idx==len(seq_dict) or seq_len>max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            batch = list()

            # add_special_tokens adds extra token at the end of each sequence
            token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding='longest')
            input_ids      = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
            
            try:
                with torch.no_grad():
                    # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
                    
            except RuntimeError as e:
                print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
                # return 
                continue
            for batch_idx, identifier in enumerate(pdb_ids): # for each protein in the current mini-batch
                s_len = seq_lens[batch_idx]
                # slice off padding --> batch-size x seq_len x embedding_dim  
                emb = embedding_repr.last_hidden_state[batch_idx,:s_len]
                results['emb'][identifier]= emb.detach().cpu().numpy().squeeze()
                # if per_residue: # store per-residue embeddings (Lx1024)
                #     results["residue_embs"][ identifier ] = emb.detach().cpu().numpy().squeeze()
                if per_protein: # apply average-pooling to derive per-protein embeddings (1024-d)
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()

    passed_time=time.time()-start
    avg_time = passed_time/len(results["emb"]) if per_residue else passed_time/len(results["emb"])
    print('\n############# EMBEDDING STATS #############')
    # print('Total number of per-residue embeddings: {}'.format(len(results["residue_embs"])))
    # print('Total number of per-protein embeddings: {}'.format(len(results["protein_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time/60, avg_time ))
    print('\n############# END #############')
    return results

In [4]:
#@title Load encoder-part of ProtT5 in half-precision. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
def get_T5_model():
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    model = model.to(device) # move model to GPU
    model = model.eval() # set model to evaluation model
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    return model, tokenizer

In [6]:
# Load the encoder part of ProtT5-XL-U50 in half-precision (recommended)
model, tokenizer = get_T5_model()

# Load example fasta.
seqs = read_fasta( 'data/AA.fasta' )


Read 43493 sequences.
Example:
1EP9_1
VQLKGRDLLTLKNFTGEEIKYMLWLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRLSTETGFALLGGHPCFLTTQDIHLGVNESLTDTARVLSSMADAVLARVYKQSDLDTLAKEASIPIINGLSDLYHPIQILADYLTLQEHYSSLKGLTLSWIGDGNNILHSIMMSAAKFGMHLQAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAHGGNVLITDTWISMGREEEKKKRLQAFQGYQVTMKTAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTIMAVMVSLLTDYSPQLQKPKFX


In [10]:
# Compute embeddings and/or secondary structure predictions
results = get_embeddings( model, tokenizer, seqs, False, True, max_seq_len=800, max_batch=10)

10it [00:02,  3.62it/s]

RuntimeError during embedding for 6ID0_14 (L=1721)


20it [00:03,  7.07it/s]

RuntimeError during embedding for 6H02_1 (L=1334)


30it [00:03,  9.96it/s]

RuntimeError during embedding for 3J2T_1 (L=1144)


1050it [06:31,  3.36it/s]

In [8]:
joblib.dump(results, 'prot_embeddings.joblib')

In [7]:
data = joblib.load('./prot_embeddings.joblib')
# data = ''

In [9]:
# emb
# protein_embs
data['protein_embs']

{'5MQF_8': array([ 0.05660115,  0.11485489,  0.02996082, ...,  0.09422462,
        -0.03572371,  0.03328162], dtype=float32),
 '6S8Q_1': array([ 0.04251173,  0.12166446,  0.05214708, ...,  0.02914876,
        -0.00819473,  0.00396954], dtype=float32),
 '4F91_1': array([ 0.04217843,  0.12154317,  0.05221662, ...,  0.02949451,
        -0.00781159,  0.00408253], dtype=float32),
 '6ICZ_14': array([ 0.05468202,  0.1202598 ,  0.02874776, ...,  0.0982928 ,
        -0.03879813,  0.03072592], dtype=float32),
 '6ICZ_4': array([ 0.05468202,  0.1202598 ,  0.02874776, ...,  0.0982928 ,
        -0.03879813,  0.03072592], dtype=float32),
 '6S8O_1': array([ 0.04234304,  0.12191214,  0.05172106, ...,  0.02839325,
        -0.00760011,  0.00391881], dtype=float32),
 '4F92_1': array([ 0.04186859,  0.12117928,  0.05246502, ...,  0.02958017,
        -0.00793544,  0.0037528 ], dtype=float32),
 '4F93_1': array([ 0.04186859,  0.12117928,  0.05246502, ...,  0.02958017,
        -0.00793544,  0.0037528 ], dtype=f