In [1]:
# Prepare fasta
# Generate batches for dataloader
import pickle5 as pickle
import torch
import pandas as pd
import numpy as np
import collections

df = pd.read_csv('data/LigID_pdbchain_partitions.csv')

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
from Bio import SeqIO
from tqdm import tqdm
fasta_sequences = SeqIO.parse(open('data/datasetB_all_sequences.fasta'),'fasta')
output_handle = open('data/datasetB_multiclass.fasta', "w")
count = 0
for record in tqdm(fasta_sequences):
    if record.id in df['pdbchain'].values:
        SeqIO.write(record, output_handle, "fasta")
        count+=1
output_handle.close()
print(count)

fasta_sequences_ = SeqIO.parse(open('data/datasetB_multiclass.fasta'),'fasta')
len(list(fasta_sequences_))

99298it [00:40, 2458.63it/s]

29409





In [30]:
#@title Import dependencies and check whether GPU is available. { display-mode: "form" }
from transformers import T5EncoderModel, T5Tokenizer
import torch
import h5py
import time
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))

Using cuda:0


In [31]:
# In the following you can define your desired output. Current options:
# per_residue embeddings
# per_protein embeddings
# secondary structure predictions

# Replace this file with your own (multi-)FASTA
# Headers are expected to start with ">";
seq_path = "./protT5/example_seqs.fasta"

# whether to retrieve embeddings for each residue in a protein 
# --> Lx1024 matrix per protein with L being the protein's length
# as a rule of thumb: 1k proteins require around 1GB RAM/disk
per_residue = True 
per_residue_path = "./protT5/output/per_residue_embeddings.h5" # where to store the embeddings

# whether to retrieve per-protein embeddings 
# --> only one 1024-d vector per protein, irrespective of its length
per_protein = False
per_protein_path = "./protT5/output/per_protein_embeddings.h5" # where to store the embeddings

# whether to retrieve secondary structure predictions
# This can be replaced by your method after being trained on ProtT5 embeddings
sec_struct = False
sec_struct_path = "./protT5/output/ss3_preds.fasta" # file for storing predictions

# make sure that either per-residue or per-protein embeddings are stored
assert per_protein is True or per_residue is True or sec_struct is True, print(
    "Minimally, you need to active per_residue, per_protein or sec_struct. (or any combination)")

In [32]:
#@title Network architecture for secondary structure prediction. { display-mode: "form" }
# Convolutional neural network (two convolutional layers) to predict secondary structure
class ConvNet( torch.nn.Module ):
    def __init__( self ):
        super(ConvNet, self).__init__()
        # This is only called "elmo_feature_extractor" for historic reason
        # CNN weights are trained on ProtT5 embeddings
        self.elmo_feature_extractor = torch.nn.Sequential(
                        torch.nn.Conv2d( 1024, 32, kernel_size=(7,1), padding=(3,0) ), # 7x32
                        torch.nn.ReLU(),
                        torch.nn.Dropout( 0.25 ),
                        )
        n_final_in = 32
        self.dssp3_classifier = torch.nn.Sequential(
                        torch.nn.Conv2d( n_final_in, 3, kernel_size=(7,1), padding=(3,0)) # 7
                        )
        
        self.dssp8_classifier = torch.nn.Sequential(
                        torch.nn.Conv2d( n_final_in, 8, kernel_size=(7,1), padding=(3,0))
                        )
        self.diso_classifier = torch.nn.Sequential(
                        torch.nn.Conv2d( n_final_in, 2, kernel_size=(7,1), padding=(3,0))
                        )
        

    def forward( self, x):
        # IN: X = (B x L x F); OUT: (B x F x L, 1)
        x = x.permute(0,2,1).unsqueeze(dim=-1) 
        x         = self.elmo_feature_extractor(x) # OUT: (B x 32 x L x 1)
        d3_Yhat   = self.dssp3_classifier( x ).squeeze(dim=-1).permute(0,2,1) # OUT: (B x L x 3)
        d8_Yhat   = self.dssp8_classifier( x ).squeeze(dim=-1).permute(0,2,1) # OUT: (B x L x 8)
        diso_Yhat = self.diso_classifier(  x ).squeeze(dim=-1).permute(0,2,1) # OUT: (B x L x 2)
        return d3_Yhat, d8_Yhat, diso_Yhat

#@title Load the checkpoint for secondary structure prediction. { display-mode: "form" }
def load_sec_struct_model():
  checkpoint_dir="./protT5/sec_struct_checkpoint/secstruct_checkpoint.pt"
  state = torch.load( checkpoint_dir )
  model = ConvNet()
  model.load_state_dict(state['state_dict'])
  model = model.eval()
  model = model.to(device)
  print('Loaded sec. struct. model from epoch: {:.1f}'.format(state['epoch']))

  return model

#@title Load encoder-part of ProtT5 in half-precision. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
def get_T5_model():
    model = T5EncoderModel.from_pretrained("pretrained_models/prot_t5_xl_uniref50")
    model = model.to(device) # move model to GPU
    model = model.eval() # set model to evaluation model
    tokenizer = T5Tokenizer.from_pretrained('pretrained_models/prot_t5_xl_uniref50', do_lower_case=False)

    return model, tokenizer

In [33]:
#@title Generate embeddings. { display-mode: "form" }
# Generate embeddings via batch-processing
# per_residue indicates that embeddings for each residue in a protein should be returned.
# per_protein indicates that embeddings for a whole protein should be returned (average-pooling)
# max_residues gives the upper limit of residues within one batch
# max_seq_len gives the upper sequences length for applying batch-processing
# max_batch gives the upper number of sequences per batch
def get_embeddings( model, tokenizer, seqs, per_residue, per_protein, sec_struct, 
                   max_residues=4000, max_seq_len=1000, max_batch=100 ):

    if sec_struct:
      sec_struct_model = load_sec_struct_model()

    results = {"residue_embs" : dict(), 
               "protein_embs" : dict(),
               "sec_structs" : dict() 
               }

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    seq_dict   = sorted( seqs.items(), key=lambda kv: len( seqs[kv[0]] ), reverse=True )
    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict,1):
        print('Running %s'%(seq_idx))
        seq = seq
        seq_len = len(seq)
        seq = ' '.join(list(seq))
        batch.append((pdb_id,seq,seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed 
        n_res_batch = sum([ s_len for  _, _, s_len in batch ]) + seq_len 
        if len(batch) >= max_batch or n_res_batch>=max_residues or seq_idx==len(seq_dict) or seq_len>max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            batch = list()

            # add_special_tokens adds extra token at the end of each sequence
            token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
            input_ids      = torch.tensor(token_encoding['input_ids']).to(device)
            attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
            
            try:
                with torch.no_grad():
                    # returns: ( batch-size x max_seq_len_in_minibatch x embedding_dim )
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError:
                print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
                continue

            if sec_struct: # in case you want to predict secondary structure from embeddings
              d3_Yhat, d8_Yhat, diso_Yhat = sec_struct_model(embedding_repr.last_hidden_state)


            for batch_idx, identifier in enumerate(pdb_ids): # for each protein in the current mini-batch
                s_len = seq_lens[batch_idx]
                # slice off padding --> batch-size x seq_len x embedding_dim  
                emb = embedding_repr.last_hidden_state[batch_idx,:s_len]
                if sec_struct: # get classification results
                    results["sec_structs"][identifier] = torch.max( d3_Yhat[batch_idx,:s_len], dim=1 )[1].detach().cpu().numpy().squeeze()
                if per_residue: # store per-residue embeddings (Lx1024)
                    results["residue_embs"][ identifier ] = emb.detach().cpu().numpy().squeeze()
                if per_protein: # apply average-pooling to derive per-protein embeddings (1024-d)
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()


    passed_time=time.time()-start
    avg_time = passed_time/len(results["residue_embs"]) if per_residue else passed_time/len(results["protein_embs"])
    print('\n############# EMBEDDING STATS #############')
    print('Total number of per-residue embeddings: {}'.format(len(results["residue_embs"])))
    print('Total number of per-protein embeddings: {}'.format(len(results["protein_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time/60, avg_time ))
    print('\n############# END #############')
    return results

In [34]:
# Load the encoder part of ProtT5-XL-U50 in half-precision (recommended)
model, tokenizer = get_T5_model()

Some weights of the model checkpoint at pretrained_models/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.13.layer.0.SelfAttention.q.weight', 'decoder.block.23.layer.2.DenseReluDense.wi.weight', 'decoder.block.17.layer.0.SelfAttention.k.weight', 'decoder.block.8.layer.1.layer_norm.weight', 'decoder.block.22.layer.2.layer_norm.weight', 'decoder.block.18.layer.1.EncDecAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.k.weight', 'decoder.block.23.layer.2.DenseReluDense.wo.weight', 'decoder.block.18.layer.1.EncDecAttention.v.weight', 'lm_head.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.8.layer.1.EncDecAttention.v.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.11.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.22.layer.0.layer_norm.weight', 'decoder.block.17.layer.1.EncDecAttention.q.

In [35]:
#@title Read in file in fasta format. { display-mode: "form" }
def read_fasta( fasta_path, split_char="!", id_field=0):
    '''
        Reads in fasta file containing multiple sequences.
        Split_char and id_field allow to control identifier extraction from header.
        E.g.: set split_char="|" and id_field=1 for SwissProt/UniProt Headers.
        Returns dictionary holding multiple sequences or only single 
        sequence, depending on input file.
    '''
    seqs = dict()
    with open( fasta_path, 'r' ) as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip().split(split_char)[id_field]
                # replace tokens that are mis-interpreted when loading h5
                uniprot_id = uniprot_id.replace("/","_").replace(".","_")
                seqs[ uniprot_id ] = ''
            else:
                # repl. all whie-space chars and join seqs spanning multiple lines, drop gaps and cast to upper-case
                seq= ''.join( line.split() ).upper().replace("-","")
                # repl. all non-standard AAs and map them to unknown/X
                seq = seq.replace('U','X').replace('Z','X').replace('O','X')
                seqs[ uniprot_id ] += seq 
    example_id=next(iter(seqs))
    print("Read {} sequences.".format(len(seqs)))
    print("Example:\n{}\n{}".format(example_id,seqs[example_id]))

    return seqs

# Load example fasta.
seqs = read_fasta('data/datasetB_multiclass.fasta')

Read 29409 sequences.
Example:
3es3_A
TSQVRQNYDQDSEAAINRQINLELYASYVYLSMSYYFDRDDVALKNFAKYFLHQSHEERCHAEKLMKLQNQRGGRIFLQDIQKPDRDDWESGLNAMEAALQLEKNVNQSLLELHKLATDKNDPHLSDFIETHYLNCQVCAIKCLGDHVTNLRKMGAPESGLAEYLFDKHTLG


In [36]:
# Compute embeddings and/or secondary structure predictions
results = get_embeddings(model, tokenizer, seqs,
                         per_residue, per_protein, sec_struct)

Running 1
Running 2
Running 3
Running 4
Running 5
Running 6
Running 7
RuntimeError during embedding for 6rs0_A (L=998)
Running 8
Running 9
Running 10
Running 11
RuntimeError during embedding for 4anj_A (L=995)
Running 12
Running 13
Running 14
Running 15
RuntimeError during embedding for 3f2b_A (L=994)
Running 16
Running 17
Running 18
Running 19
RuntimeError during embedding for 5dhf_C (L=994)
Running 20
Running 21
Running 22
Running 23
RuntimeError during embedding for 6k7l_A (L=992)
Running 24
Running 25
Running 26
Running 27
RuntimeError during embedding for 3ho8_B (L=989)
Running 28
Running 29
Running 30
Running 31
RuntimeError during embedding for 4rpu_A (L=988)
Running 32
Running 33
Running 34
Running 35
RuntimeError during embedding for 6jxi_A (L=987)
Running 36
Running 37
Running 38
Running 39
RuntimeError during embedding for 4nge_A (L=985)
Running 40
Running 41
Running 42
Running 43
RuntimeError during embedding for 2o9j_A (L=980)
Running 44
Running 45
Running 46
Running 47
Ru

In [38]:
#@title Write embeddings to disk. { display-mode: "form" }
def save_embeddings(emb_dict,out_path):
    with h5py.File(str(out_path), "w") as hf:
        for sequence_id, embedding in emb_dict.items():
            # noinspection PyUnboundLocalVariable
            hf.create_dataset(sequence_id, data=embedding)
    return None


per_residue = True 
per_residue_path = "./protT5/output/per_residue_embeddings.h5" # where to store the embeddings

# Store per-residue embeddings
if per_residue:
  save_embeddings(results["residue_embs"], per_residue_path)
if per_protein:
  save_embeddings(results["protein_embs"], per_protein_path)

In [54]:
results['residue_embs']['6f5l_A'].shape

(693, 1024)