<img src="https://github.com/Multiomics-Analytics-Group/course_protein_language_modeling/blob/main/img/nb_logo.png?raw=1" width="600">

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Multiomics-Analytics-Group/course_protein_language_modeling/blob/main/notebooks/embeddings.ipynb)


This is a version of the notebook from [SETH](https://github.com/DagmarIlz/SETH) --- [here](https://colab.research.google.com/drive/1vDWh5YI_BPxQg0ku6CxKtSXEJ25u2wSq?usp=sharing).

In [1]:
#!pip install "transformers[torch]" sentencepiece h5py biopython > /dev/null

<h3><span style="color:red">Important</span></h3> 
If you are running in Google Colab, change the Notebook settings to use `GPU`.

Just follow **Edit** > **Notebook settings** or **Runtime** > **Change runtime type** and select **GPU** as Hardware accelerator.

![gpu.png](../img/gpu.png)


# Embedding Protein Sequences

In this notebook, we will use a pre-trained language model, [ProtT5-XL-UniRef50](https://huggingface.co/Rostlab/prot_t5_xl_uniref50), to encode the protein sequences of 5000+ $\beta$-$lactamase$ TEM-type varients from FASTA file [P62593.fasta](https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/examples/data/P62593.fasta). This data was subsetted from a deep mutational scan released by [Gray et al. (2018)](https://www.cell.com/cell-systems/pdfExtended/S2405-4712(17)30492-1). 

The goal of this notebook is to obtain an embedding (fixed-dimensional vector representation) for each mutated sequence.

Although the embedding won't capture all the information from the original data, good embedding representations allow us to analyze, cluster, or use them as features to train machine learning models. 

The embeddings generated in this notebook will then be used in the next exercise (prediction.ipynb) to train a simple varient predictor (i.e., predict the activity of the protein mutation).

<div class="warning" style='background-color:#E9D8FD; color: #69337A; border-left: solid #805AD5 4px; border-radius: 4px; padding:0.7em;'>
<span>
<p style='margin-top:1em; text-align:center'><b>NOTE</b></p>
<p style='margin-left:1em;'>
    Even when using GPU, embedding the protein sequences takes some time (~25mins) so to begin go ahead and run all cells of this notebook so that the process is started in the background as we review the notebook.
</p>
<p style='margin-top:1em; text-align:center'>
    A shortcut to running all cells is going to the "Runtime" menu and selecting "Run all".

</span>
</div>

----

## The Data: P62593 Sequences

To start we will import and explore the dataset: 

In [2]:
# Set up working directories and download files/checkpoints 
!mkdir protT5 # directory for storing checkpoints, results etc
!mkdir protT5/output # directory for storing your embeddings
!curl -o P62593.fasta https://dl.fbaipublicfiles.com/fair-esm/examples/P62593.fasta

mkdir: cannot create directory ‘protT5’: File exists
mkdir: cannot create directory ‘protT5/output’: File exists
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1719k  100 1719k    0     0  3634k      0 --:--:-- --:--:-- --:--:-- 3628k


In [3]:
# Import dependencies
from transformers import T5EncoderModel, T5Tokenizer
import torch
import h5py
import time
from Bio import SeqIO

# Path variables
per_protein_path = "./protT5/output/per_protein_embeddings.h5" # where to store the embeddings
seq_path = 'P62593.fasta' # where the fasta file is saved

# check whether GPU is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using {device}")

Using cpu


In [4]:
def read_fasta(fasta_path:str) -> dict:
    '''
    reads in fasta file and returns a dictionary with primary id/sequence key/value pairs 
    '''
    
    # dictionary to append to
    seqs = {}
    
    # read in and parse fasta file
    with open(fasta_path) as handle:
        for record in SeqIO.parse(handle, "fasta"):
            # append each varient to the dict
            seqs[record.id] = record.seq

    # verbose
    example_id=next(iter(seqs))
    print(f"Read {len(seqs)} sequences.")
    print(f"Example:\nKey: {example_id}\nValue: {seqs[example_id]}")

    return seqs

In [5]:
# read in file
fasta_output = read_fasta(seq_path)

Read 5397 sequences.
Example:
Key: 0|beta-lactamase_P20P|1.581033423
Value: MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW


In the FASTA file there are 5,397 sequences. As we can see in the example above from our fasta dictionary, each entry contains:

- key: `{index}|beta-lactamase_{mutation}|{scaled_varient_effect}`
    > in prediction.ipynb we will be predicting the `scaled_varient_effect` value, which describes the scaled effect of the mutation. 
- value: the mutated $\beta$-lactamase sequence, where a single residue is mutated (swapped with another amino acid)

## The Model: ProtT5-XL-UniRef50

ProtT5-XL-UniRef50 is based on the t5-3b model and was pretrained on a large corpus of protein sequences in a self-supervised fashion. This means it was pretrained on the raw protein sequences only, with **no humans-in-the-loop labelling** them in any way (which is why it can use lots of publicly available data) with an automatic process to generate inputs and labels from those protein sequences.

This model only contains the encoder portion of the original ProtT5-XL-UniRef50 model using half precision (float16). As such, this model can efficiently be used to create protein/ amino acid representations. When used for training downstream networks/ feature extraction, these embeddings produced the same performance (established empirically by comparing on several downstream tasks). 

In the following cells we will prepare functions that will later assist us when generating the embeddings. 

### `get_T5_model()` Load encoder-part of ProtT5 in half-precision

To start we create a function that will load the model and associated tokenizer. Recall from the previous notebook (model_training.ipynb) where every model on `transformers` comes with an associated `tokenizer` that handles tokenization for it, where tokenization for protein language models involve coverting each amino acid to a single token.

**Recall: Fine-tuning flow chart from the previous notebook**
![Chart of the pretrained model fine-tuning process](../img/fine-tuning.png)


This function accomplishes the "Model Checkpoint Loading" step. 

In [6]:
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
def get_T5_model():
    '''
    retrieves the model and tokenizer
    '''
    # specify the encorder-part of the model 
    model_checkpoint = 'Rostlab/prot_t5_xl_half_uniref50-enc'
    
    # import the model 
    model = T5EncoderModel.from_pretrained(model_checkpoint)
    
    model = model.to(device) # move model to GPU
    model = model.eval() # set model to evaluation model
    
    # import tokenizer
    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)

    return model, tokenizer

Let's use the function:

In [7]:
# loading model
model, tokenizer = get_T5_model()

config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


### `get_embeddings()` Using the model to generate the embeddings

From the flow chart above, the function `get_embeddings()` includes the 'Tokenization' and 'Dataset Creation' steps. Additionally, the model will encode the sequences. 

In [8]:
def get_embeddings(model, tokenizer, seqs, max_seq_len=1000, max_batch=100):
    '''
    use the encoder to embbed the sequences via batch-processing
    -----
    
    parameters:
        model: from get_T5_model()
        tokenizer: from get_T5_model()
        seqs: the dictionary of sequences generated by read_fasta() 
        max_seq_length: the upper sequences length for applying batch-processing
        max_batch: the upper number of sequences per batch
        
    returns:
        results: a dictionary containing the embedding representations of the sequences
    '''

    # initialize a dictionary, the embeddings will be accessible from results['protein_embs'] 
    results = {"protein_embs" : dict()}

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    seq_dict = sorted(seqs.items(), 
                      # 'key' option is a function that serves as a basis of sort comparison.
                      key=lambda kv: len(seqs[kv[0]]), 
                      # sort by descending order
                      reverse=True
                     )
    
    # for time tracking
    start = time.time()
    
    # initialize empty list
    batch = list()
    
    # for each item in the dictionary
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict, 1):
        
        # add space between residues
        seq = ' '.join(list(seq))
        
        # length of sequence with spaces
        seq_len = len(seq)
        
        # append to batch list as tuple
        batch.append((pdb_id, seq, seq_len))

        # creates n-tuple pairs from each element in batch
        pdb_ids, seqs, seq_lens = zip(*batch)
        
        # empty list
        batch = list()
        
        # Data Preparation and Tokenization:

        # add_special_tokens adds extra token at the end of each sequence
        token_encoding = tokenizer.batch_encode_plus(seqs, add_special_tokens=True, padding="longest")
        
        # making the tokenized sequence into a tensor
        input_ids = torch.tensor(token_encoding['input_ids']).to(device)
        
        # now making the mask into a tensor
        attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)
        
        # Generate Embedding:
        
        # using the model to encode the sequence, generating an embedding representation
        try:
            with torch.no_grad():
                embedding_repr = model(input_ids, attention_mask=attention_mask)
                # verbosity for progress tracking
                print(f'Currently, embedding {pdb_id}')
        except RuntimeError:
            print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
            continue
            
        # Putting together the dataset:
        
        # slice off padding if any 
        emb = embedding_repr.last_hidden_state[0,:seq_len]

        # average along column
        protein_emb = emb.mean(dim=0)

        # save the embedding into results dictionary where the key = the fasta file entry header
        results["protein_embs"][pdb_id] = protein_emb.detach().cpu().numpy().squeeze()

    # get time elapsed
    passed_time=time.time()-start
    avg_time = passed_time/len(results["protein_embs"])
    
    # final verbose
    print('\n############# EMBEDDING STATS #############')
    print('Total number of per-protein embeddings: {}'.format(len(results["protein_embs"])))
    print("Time for generating embeddings: {:.1f}[m] ({:.3f}[s/protein])".format(
        passed_time/60, avg_time ))
    print('\n############# END #############')
    
    return results

**NOTE: What are special tokens?** 

Special tokens aren't present in the input text, but carry important meaning that we want the model to act on. For exmaple (not spevific to our model): 
- [PAD] Padding token — Added to the end of shorter inputs so that all inputs have the same length. This is because inputs to a neural network model are typically batched and the model operates on entire batches. 
- [UNK] Unknown token — Used to limit the number of distinct tokens. For example, if we want a vocabulary of at most 1000 tokens but the input text has 1200, then the remaining 200 will be converted to [UNK].
    
You can read more [here](https://medium.com/@alexkubiesa/special-tokens-in-tensorflow-3c7718dcb0ef).

We will move on and use the function to get the embeddings:

In [9]:
# Compute embeddings
results = get_embeddings(model, tokenizer, fasta_output)

Currently, embedding 0|beta-lactamase_P20P|1.581033423
Currently, embedding 1|beta-lactamase_D207D|1.42563125
Currently, embedding 2|beta-lactamase_A215A|1.422813331
Currently, embedding 3|beta-lactamase_C75C|1.4155315119999998
Currently, embedding 4|beta-lactamase_N134N|1.39696596
Currently, embedding 5|beta-lactamase_L137L|1.355533136
Currently, embedding 6|beta-lactamase_L28L|1.3516090040000002
Currently, embedding 7|beta-lactamase_L199L|1.3516090040000002
Currently, embedding 8|beta-lactamase_F149F|1.32191175
Currently, embedding 9|beta-lactamase_A200A|1.295473865
Currently, embedding 10|beta-lactamase_E210E|1.29406548
Currently, embedding 11|beta-lactamase_H24H|1.282201552
Currently, embedding 12|beta-lactamase_L19L|1.280029666
Currently, embedding 13|beta-lactamase_A183A|1.279505214
Currently, embedding 14|beta-lactamase_T27T|1.248455477
Currently, embedding 15|beta-lactamase_L38L|1.245034082
Currently, embedding 16|beta-lactamase_I229I|1.23953749
Currently, embedding 17|beta-lac

KeyboardInterrupt: 

### `save_embeddings()` Writing the embeddings to a file

For our final function, we will write the embeddings to a file. We will load this file into prediction.ipynb to train machine learning models. This is also a copy of the embedding file in the reposition in _data/_. 

In [10]:
def save_embeddings(emb_dict:dict , out_path:str):
    '''
    takes the resulting embeddings from get_embeddings() and saves to a compressed h5 file
    -----
    
    parameters:
        emb_dict (dict): dictionary that is in results['protein_embs'] 
        out_path (str): path and filename where the embeddings will be saved
    '''
    
    with h5py.File(str(out_path), "w") as hf:
        for sequence_id, embedding in emb_dict.items():
            hf.create_dataset(sequence_id, data=embedding)
            
    return None

In [None]:
# write embeddings to file
save_embeddings(results["protein_embs"], per_protein_path)