### 0. Imports

In [1]:
import sys 
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, T5EncoderModel
import torch
import re

  from .autonotebook import tqdm as notebook_tqdm


### 1. Teacher selection
We only have one teacher, which is ProstT5.


In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False) #.to(device) - the tokenizer is not a pytorch object and cannot be loaded to the device

# Load the model
model = T5EncoderModel.from_pretrained("Rostlab/ProstT5").to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
model.full() if device=='cpu' else model.half()
print()

cuda:0


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565





2. Prepare a FASTA file for distillation

In [4]:
save_path = "../../data/uniprotkb"
fasta_path = f"{save_path}/uniprot_sprot.fasta"


In [5]:
uniprot_kb = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
!wget $uniprot_kb -O $save_path/uniprot_sprot.fasta.gz
!gunzip -f $save_path/uniprot_sprot.fasta.gz

--2024-11-08 14:55:27--  https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz
Resolving ftp.uniprot.org (ftp.uniprot.org)... 128.175.240.195
Connecting to ftp.uniprot.org (ftp.uniprot.org)|128.175.240.195|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 92678508 (88M) [application/x-gzip]
Saving to: ‘../../data/uniprotkb/uniprot_sprot.fasta.gz’


2024-11-08 14:55:55 (3.24 MB/s) - ‘../../data/uniprotkb/uniprot_sprot.fasta.gz’ saved [92678508/92678508]



In [6]:
def read_fasta(in_path, is_3Di):
    '''
        Reads in fasta file containing a single or multiple sequences.
        Returns dictionary.
    '''

    sequences = dict()
    with open( in_path, 'r' ) as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                # starts with P and is 6 characters long
                # get index of first P
                uniprot_id = line[line.find('P'):line.find('P')+6]
                sequences[ uniprot_id ] = ''
            else:
                # repl. all whie-space chars and join seqs spanning multiple lines
                if is_3Di:
                    sequences[ uniprot_id ] += ''.join( line.split() ).replace("-","").lower() # drop gaps and cast to lower-case
                else:
                    sequences[ uniprot_id ] += ''.join( line.split() ).replace("-","")
                    

    example = sequences[uniprot_id]

    print("##########################")
    print(f"Input is 3Di: {is_3Di}")
    print(f"Example sequence: >{uniprot_id}\n{example}")
    print("##########################")

    return sequences

sequences = read_fasta(fasta_path, is_3Di=False)
sequences = list(sequences.values())
# replace ambigous amino acisds with X
sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequences]
sequence_lengths = [len("".join(sequence.split())) for sequence in sequences]
print(f"Number of sequences: {len(sequences)}")

# DEBUG ---------------------------------------
sequences = sequences[:100]
sequence_lengths = sequence_lengths[:100]

##########################
Input is 3Di: False
Example sequence: >PE=3 S
MGLRYSKDVKDRYGDREPEGRIPITLNMPQSLYGRYNCKSCWFANKGLLKCSNHYLCLKCLTLMLRRSDYCGICGEVLPKKLVFENSPSAPPYEA
##########################
Number of sequences: 166692


3. Generate embeddings using your teacher model

In [7]:
# prepare sequences for model input
sequence_input = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s # this expects 3Di sequences to be already lower-case
                        for s in sequences
                    ]
ids = tokenizer.batch_encode_plus(sequence_input,
                                    add_special_tokens=True,
                                    padding="longest",
                                    return_tensors='pt').to(device)

In [8]:
with torch.no_grad():
    embedding_rpr = model(
              ids.input_ids, 
              attention_mask=ids.attention_mask
              )

OutOfMemoryError: CUDA out of memory. Tried to allocate 10.22 GiB. GPU 0 has a total capacity of 6.00 GiB of which 0 bytes is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 14.96 GiB is allocated by PyTorch, and 227.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [34]:
print(sequences)
print("".join(sequences[0].split()))
print(len("".join(sequences[0].split())))

['M S F T I P T N L Y K P L A T K P K H L S S S S F A P R S K I V C Q Q E N D Q Q Q P K K L E L A K V G A N A A A A L A L S S V L L S S W S V A P D A A M A D I A G L T P C K E S K Q F A K R E K Q A L K K L Q A S L K L Y A D D S A P A L A I K A T M E K T K K R F D N Y G K Y G L L C G S D G L P H L I V S G D Q R H W G E F I T P G I L F L Y I A G W I G W V G R S Y L I A I R D E K K P T Q K E I I I D V P L A S S L L F R G F S W P V A A Y R E L L N G E L V D N N F']
MSFTIPTNLYKPLATKPKHLSSSSFAPRSKIVCQQENDQQQPKKLELAKVGANAAAALALSSVLLSSWSVAPDAAMADIAGLTPCKESKQFAKREKQALKKLQASLKLYADDSAPALAIKATMEKTKKRFDNYGKYGLLCGSDGLPHLIVSGDQRHWGEFITPGILFLYIAGWIGWVGRSYLIAIRDEKKPTQKEIIIDVPLASSLLFRGFSWPVAAYRELLNGELVDNNF
231


In [36]:
per_residue = embedding_rpr.last_hidden_state[0,1:len("".join(sequences[0].split()))+1]
per_protein = embedding_rpr.last_hidden_state[0,1:len("".join(sequences[0].split()))+1].mean(dim=0)

print(per_residue.shape)
print(per_protein.shape)

torch.Size([231, 1024])
torch.Size([1024])


## Question: Per protein or per residue embeddings?
Most likely per protein, because step 5 says:
```Then you can run the following command to distill your model. Please make sure that your "*.feat" file is matrix with shape (number of sequences, dimension of embedding)```