In [67]:
#@title Set up working directories and download files/checkpoints. { display-mode: "form" }
# Create directory for storing model weights (2.3GB) and example sequences.
# Here we use the encoder-part of ProtT5-XL-U50 in half-precision (fp16) as 
# it performed best in our benchmarks (also outperforming ProtBERT-BFD).
# Also download secondary structure prediction checkpoint to show annotation extraction from embeddings
!mkdir protT5 # root directory for storing checkpoints, results etc
!mkdir protT5/protT5_checkpoint # directory holding the ProtT5 checkpoint
!mkdir protT5/subcell_checkpoint # directory storing the supervised classifier's checkpoint
!mkdir protT5/output # directory for storing your embeddings & predictions
!wget -nc -P protT5/ https://rostlab.org/~deepppi/example_seqs.fasta
!wget -nc -P protT5/protT5_checkpoint https://rostlab.org/~deepppi/protT5_xl_u50_encOnly_fp16_checkpoint/pytorch_model.bin
!wget -nc -P protT5/protT5_checkpoint https://rostlab.org/~deepppi/protT5_xl_u50_encOnly_fp16_checkpoint/config.json
# Huge kudos to the bio_embeddings team here! We will integrate the new encoder, half-prec ProtT5 checkpoint soon
!wget -nc -P protT5/subcell_checkpoint http://data.bioembeddings.com/public/embeddings/feature_models/t5/subcell_checkpoint.pt

mkdir: protT5: File exists


python(74191) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74192) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


mkdir: protT5/protT5_checkpoint: File exists
mkdir: protT5/subcell_checkpoint: File exists
mkdir: protT5/output: File exists


python(74193) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74194) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


File ‘protT5/example_seqs.fasta’ already there; not retrieving.



python(74195) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74196) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


File ‘protT5/protT5_checkpoint/pytorch_model.bin’ already there; not retrieving.

File ‘protT5/protT5_checkpoint/config.json’ already there; not retrieving.



python(74197) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74198) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


File ‘protT5/subcell_checkpoint/subcell_checkpoint.pt’ already there; not retrieving.



In [68]:
# In the following you can define your desired output. Current options:
# per_residue embeddings
# per_protein embeddings
# secondary structure predictions

# Replace this file with your own (multi-)FASTA
# Headers are expected to start with ">";
seq_path = "./protT5/example_seqs.fasta"

# whether to retrieve embeddings for each residue in a protein 
# --> Lx1024 matrix per protein with L being the protein's length
# as a rule of thumb: 1k proteins require around 1GB RAM/disk
per_residue = False 
per_residue_path = "./protT5/output/per_residue_embeddings.h5" # where to store the embeddings

# whether to retrieve per-protein embeddings 
# --> only one 1024-d vector per protein, irrespective of its length
per_protein = False
per_protein_path = "./protT5/output/per_protein_embeddings.h5" # where to store the embeddings

# whether to retrieve secondary structure predictions
# This can be replaced by your method after being trained on ProtT5 embeddings
subcell_mem = True
subcell_path = "./protT5/output/subcell.csv" # file for storing predictions
mem_path = "./protT5/output/membrane.csv" # file for storing predictions
# make sure that either per-residue or per-protein embeddings are stored
assert per_protein is True or per_residue is True or subcell_mem is True, print(
    "Minimally, you need to active per_residue, per_protein or sec_struct. (or any combination)")


In [69]:
from transformers import T5EncoderModel, T5Tokenizer
import torch
from torch import nn
import h5py
import time
import csv

In [70]:
# In the following you can define your desired output. Current options:
# per_residue embeddings
# per_protein embeddings
# secondary structure predictions

# Replace this file with your own (multi-)FASTA
# Headers are expected to start with ">";
seq_path = "./protT5/example_seqs.fasta"

# whether to retrieve embeddings for each residue in a protein 
# --> Lx1024 matrix per protein with L being the protein's length
# as a rule of thumb: 1k proteins require around 1GB RAM/disk
per_residue = True 
per_residue_path = "./protT5/output/per_residue_embeddings.h5" # where to store the embeddings

# whether to retrieve per-protein embeddings 
# --> only one 1024-d vector per protein, irrespective of its length
per_protein = True
per_protein_path = "./protT5/output/per_protein_embeddings.h5" # where to store the embeddings

# whether to retrieve secondary structure predictions
# This can be replaced by your method after being trained on ProtT5 embeddings
sec_struct = True
sec_struct_path = "./protT5/output/ss3_preds.fasta" # file for storing predictions

# make sure that either per-residue or per-protein embeddings are stored
assert per_protein is True or per_residue is True or sec_struct is True, print(
    "Minimally, you need to active per_residue, per_protein or sec_struct. (or any combination)")

In [71]:

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print("Using {}".format(device))

Using mps


In [72]:
#@title Network architecture for subcell. loc. prediction and Membrane-bound pred. { display-mode: "form" }
# Feed forward neural network to predict a) subcellular localization and 
# b) classifies membrane-bound from water-soluble proteins
class FNN( nn.Module ):
    
    def __init__( self ):
        super(FNN, self).__init__()
        # Linear layer, taking embedding dimension 1024 to make predictions:
        self.layer = nn.Sequential(
                        nn.Linear( 1024, 32),
                        nn.Dropout( 0.25 ),
                        nn.ReLU(),
                        )
        # subcell. classification head
        self.loc_classifier = nn.Linear( 32, 10)
        # membrane classification head
        self.mem_classifier = nn.Linear( 32,  2)

    def forward( self, x):
        # Inference
        out = self.layer( x ) 
        Yhat_loc = self.loc_classifier(out)
        Yhat_mem = self.mem_classifier(out)
        return Yhat_loc, Yhat_mem 

In [73]:
def load_subcell_model():
  checkpoint_dir="./protT5/subcell_checkpoint/subcell_checkpoint.pt"
  state = torch.load(checkpoint_dir, map_location=torch.device('mps'))
  model = FNN()
  model.load_state_dict(state['state_dict'])
  model = model.eval()
  model = model.half()
  model = model.to(device)
  return model

In [74]:
def get_T5_model():
    model = T5EncoderModel.from_pretrained("./protT5/protT5_checkpoint/", torch_dtype=torch.float16)
    model = model.to(device) # move model to GPU
    model = model.eval() # set model to evaluation model
    tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False ) 

    return model, tokenizer

In [75]:
csv_file_path = '../merged.csv'  

seqs = {}    

with open(csv_file_path, 'r') as csv_file:
    csv_reader = csv.DictReader(csv_file)
    
    for row in csv_reader:
        uniprot_id = row['Uniprot_ID']
        protein_sequence = row['Protein_sequence']
        
        seqs[uniprot_id] = protein_sequence

In [76]:
def get_embeddings( model, tokenizer, seqs, per_residue, per_protein, subcell_mem, 
                   max_residues=4000, max_seq_len=1000, max_batch=100):

    if subcell_mem:
      subcell_model = load_subcell_model()

    results = {"residue_embs" : dict(), 
               "protein_embs" : dict(),
               "subcell" : dict(),
               "mem" : dict(), 
               }

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    seq_dict   = sorted( seqs.items(), key=lambda kv: len( seqs[kv[0]] ), reverse=True )
    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict,1):
        seq = seq
        seq_len = len(seq)
        seq = ' '.join(list(seq))
        batch.append((pdb_id,seq,seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed 
        n_res_batch = sum([ s_len for  _, _, s_len in batch ]) + seq_len 
        if len(batch) >= max_batch or n_res_batch>=max_residues or seq_idx==len(seq_dict) or seq_len>max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            batch = list()

            # add_special_tokens adds extra token at the end of each sequence
            token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
            input_ids      = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
            
            try:
                with torch.no_grad():
                    # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
                continue

            for batch_idx, identifier in enumerate(pdb_ids): # for each protein in the current mini-batch
                s_len = seq_lens[batch_idx]
                # slice off padding --> batch-size x seq_len x embedding_dim  
                emb = embedding_repr.last_hidden_state[batch_idx,:s_len]

                if subcell_mem: # in case you want to predict secondary structure from embeddings
                  subcell_Yhat, mem_Yhat = subcell_model(emb.mean(dim=0,keepdims=True))
                  results["subcell"][identifier] = torch.max( subcell_Yhat, dim=1)[1].detach().cpu().numpy().squeeze()
                  results["mem"][identifier] = torch.max( mem_Yhat, dim=1)[1].detach().cpu().numpy().squeeze()


                if per_residue: # store per-residue embeddings (Lx1024)
                    results["residue_embs"][ identifier ] = emb.detach().cpu().numpy().squeeze()
                if per_protein: # apply average-pooling to derive per-protein embeddings (1024-d)
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()


    passed_time=time.time()-start
    avg_time = passed_time/len(seq_dict)
    print('\n############# EMBEDDING STATS #############')
    print('Total number of per-residue embeddings: {}'.format(len(results["residue_embs"])))
    print('Total number of per-protein embeddings: {}'.format(len(results["protein_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time/60, avg_time ))
    print('\n############# END #############')
    return results

In [77]:
def save_embeddings(emb_dict,out_path):
    with h5py.File(str(out_path), "w") as hf:
        for sequence_id, embedding in emb_dict.items():
            # noinspection PyUnboundLocalVariable
            hf.create_dataset(sequence_id, data=embedding)
    return None

In [78]:
def write_prediction_csv(predictions, out_path, mode):
  # Label mapping for subcellular localization
  subcell_mapping = {
      0: "Cell_Membrane",
      1: "Cytoplasm",
      2: "Endoplasmatic Reticulum",
      3: "Golgi Apparatus",
      4: "Lysosome or vacuole",
      5: "Mitochondrion",
      6: "Nucleus",
      7: "PEROXISOME",
      8: "Plastid",
      9: "Extracellular"
  }
  # Label mapping for membrane-bound
  mem_mapping = {
      0: "Soluble",
      1: "Membrane-bound"
  }

  if mode=="subcell":
    class_mapping=subcell_mapping
  elif mode=="mem":
    class_mapping=mem_mapping
  else:
    raise NotImplemented

  with open(out_path, 'w+') as out_f:
      out_f.write( '\n'.join( 
          [ "{},{}".format( 
              seq_id, class_mapping[int(yhat)]) 
          for seq_id, yhat in predictions.items()
          ] 
            ) )
  return None

In [79]:
model, tokenizer = get_T5_model()

In [83]:
# seqs_limited = dict(list(seqs.items())[:100])
seqs_limited = seqs

In [84]:
print(seqs_limited)

{'A0A009G0Z5': 'MSSALLPLAILVEFGGGFLVLIGLQTRLAAFLLFGFSLVAAVLFHSGSDMNSQIMFMKNISMAGGLLALVIFGAGGLSVDKKLK', 'A0A009IC41': 'MKKADIVVLLLDPKDLKAILAPLKKWLADKTIVSMMAGVNIQQLTSITGSKKIIRVISNPPVLTYTGTHVLIGSDYLEPLDKEVIETIYSATGRTYWANSESQSDAIIALSGSGPAYFFYILDSMVKTGVSMGLDKQFALDLILQAASGAVEMVRKSNVQPSELCGKVTLANGITESALRMFELGNLSDDIRLALKAAYHRSKEISLEISAEITRH', 'A0A009Q4U8': 'MSCTEKNNIGPSVLSPKNSEQYEVGVKQQIRNFLVTAAIFDLKQDNQYSKVNAEGTFDFISQGEQHNQGIELGLTGALTDTLDVSSGVTYTKSRLVDIDTDIYKGHQTQNVPKVRATAQLSYKVPSVEGLRLLSGMQYSSSKYANKEGTAKVGGYSVFNIGAAYKTNFAGHDTTFRFNIDNLFNKKYWRDVGAFMGDDYLFLGNPRTAQFSTTFSF', 'A0A009RPZ1': 'MLTNLREQWFSNVRADILSGLVVGLALIPEAIAFSIIAGVDPQIGLYASFCIAVIISFAGGRPAMISAATGAMALVMTTLVKEHGLQYLFAATILTGVIQIIVGYLKLAKLMRFVSKSVVIGFVNALAILIFMAQLPELVNVSWYVYLLVAIGLVIIYLFPYVPKLGKIFPSPLICIVIVTLLALFLGLDVRTVGDMGALPNTLPIFLIPDIPLNLDTLLIILPYSLALAAVGLLESMMTATIVDEMTDSPSNKFKECKGQGIANIASGFMGGMGS', 'A0A009T383': 'MDFNIALILGQDGITSGAIYALLALCIILVFTVTRILLIPLGEFTVFGALTLASIQAGTPSTIVWLVSAFCLVNLCLDAWESLRNKTAFQWKKQLGLVGYCIILVLCMYQLPLADLPTFF

In [85]:
# Load example fasta.

# Compute embeddings and/or secondary structure predictions
results = get_embeddings( model, tokenizer, seqs_limited,
                         per_residue, per_protein, subcell_mem)

# Store per-residue embeddings
if per_residue:
  save_embeddings(results["residue_embs"], per_residue_path)
if per_protein:
  save_embeddings(results["protein_embs"], per_protein_path)
if subcell_mem:
  print("Start writing predictions")
  write_prediction_csv(results["subcell"], subcell_path,mode="subcell")
  write_prediction_csv(results["mem"], mem_path, mode="mem")

RuntimeError during embedding for A0A024GBR2 (L=4888)
RuntimeError during embedding for A0A087SX42 (L=3659)
RuntimeError during embedding for A0A026WS40 (L=3476)
RuntimeError during embedding for A0A085H9D8 (L=3328)
RuntimeError during embedding for A0A087YLY5 (L=3055)
RuntimeError during embedding for A0A084ZKS7 (L=2908)
RuntimeError during embedding for A0A077RZ86 (L=2905)
RuntimeError during embedding for A0A077YYA4 (L=2639)
RuntimeError during embedding for A0A077Z7F2 (L=2601)
RuntimeError during embedding for A0A061IMH7 (L=2356)
RuntimeError during embedding for A0A015LEM1 (L=2298)
RuntimeError during embedding for A0A016U420 (L=2288)
RuntimeError during embedding for A0A084AK91 (L=2258)
RuntimeError during embedding for A0A023F2B5 (L=2158)
RuntimeError during embedding for A0A087SEJ7 (L=2118)
RuntimeError during embedding for A0A089JUR5 (L=2118)
RuntimeError during embedding for A0A064C176 (L=1978)
RuntimeError during embedding for A0A024URW2 (L=1855)
RuntimeError during embeddin

KeyboardInterrupt: 