# ProstT5 

The model can be found in [HuggingFace](https://huggingface.co/Rostlab/ProstT5) with some initial code.

#### Test Feature Extraction code from HuggingFace

In [1]:
# add .. to path
import sys
sys.path.append("../..")
from src.uniprot import download_fasta_parallel

In [2]:
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False) #.to(device) - the tokenizer is not a pytorch object and cannot be loaded to the device

# Load the model
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
model.full() if device=='cpu' else model.half()

print("Model loaded")

  from .autonotebook import tqdm as notebook_tqdm


cuda:0


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Model loaded


In [3]:
# stolen from libs/ProstT5/notebooks/ProstT5_inverseFolding.ipynb
#@title Read in file in FASTA format. { display-mode: "form" }
def read_fasta( in_path, is_3Di ):
    '''
        Reads in fasta file containing a single or multiple sequences.
        Returns dictionary.
    '''

    sequences = dict()
    with open( in_path, 'r' ) as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                # starts with P and is 6 characters long
                # get index of first P
                uniprot_id = line[line.find('P'):line.find('P')+6]
                sequences[ uniprot_id ] = ''
            else:
                # repl. all whie-space chars and join seqs spanning multiple lines
                if is_3Di:
                    sequences[ uniprot_id ] += ''.join( line.split() ).replace("-","").lower() # drop gaps and cast to lower-case
                else:
                    sequences[ uniprot_id ] += ''.join( line.split() ).replace("-","")
                    

    example = sequences[uniprot_id]

    print("##########################")
    print(f"Input is 3Di: {is_3Di}")
    print(f"Example sequence: >{uniprot_id}\n{example}")
    print("##########################")

    return sequences

In [4]:
# download random fastas from uniprot to get the sequences
uniprot_ids = [f"P{str(i).zfill(5)}" for i in range(12345, 12366)]
fasta_list, failed_ids = download_fasta_parallel(uniprot_ids, num_proc=8, save=True, save_dir="../../data/")

In [5]:
import glob
import re
fasta_file = glob.glob("../../data/2024*.fasta")[0]
sequences = read_fasta(fasta_file, is_3Di=False)
sequences = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for _,sequence in sequences.items()]
sequences = sorted(sequences, key=len)
print(sequences[0])
sequences = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s # this expects 3Di sequences to be already lower-case
                      for s in sequences
                    ]
# only use the shortest sequence for now
sequences = sequences[:1] 

##########################
Input is 3Di: False
Example sequence: >P12365
MDPYKHRPSSAFNAPYWTTNSGAPVWNNDSSLTVGARGPILLEDYHCEKLANFDRERIPERVVHARGASAKGFFEVTHDITHLTCADFLRAPGVQTPVIVRFSTVIHERGSPETLRDPRGFAVKFYTREGNWDLVGNNFPVFFIRDGIKFPDMVHALKPNPRTHIQDNWRILDFFSHHPESLHMFSFLFDDVGIPADYRHMDGSGVHTYTLVSRAGTVTYVKFHWRPTCGVRSLMDDEAVRCGANHSHATKDLTDAIAAGNFPEWTLYIQTMDPEMEDRLDDLDPLDVTKTWPEDTFPLQPVGRLVLNRNIDNFFAENEQLAFCPGLIVPGIYYSDDKLLQTRIFSYSDTQRHRLGPNYLLLPANAPKCAHHNNHYDGSMNFMHRHEEVDYFPSRYDAVRNAPRYPIPTAHIAGRREKTVISKENNFKQPGERYRAMDPARQERFITRWVDALSDPRLTHEIRTIWLSNWSQADRSLGQKLASRLSAKPSM
##########################
S G G K K I K V D K P L G L G G G L T V D I D A


In [6]:
min_len = min([ len("".join(s.removeprefix("<AA2fold> ").split())) for s in sequences])
max_len = max([ len("".join(s.removeprefix("<AA2fold>").split())) for s in sequences])
print(min_len, max_len)


24 24


In [7]:
ids = tokenizer.batch_encode_plus(sequences,
                                  add_special_tokens=True,
                                  padding="longest",
                                  return_tensors='pt').to(device)

# Generation configuration for "folding" (AA-->3Di)
gen_kwargs_aa2fold = {
                  "do_sample": True,
                  "num_beams": 3, 
                  "top_p" : 0.95, 
                  "temperature" : 1.2, 
                  "top_k" : 6,
                  "repetition_penalty" : 1.2,
}

# translate from AA to 3Di (AA-->3Di)
translations=[]
with torch.no_grad():
  translations = model.generate( 
              ids.input_ids, 
              attention_mask=ids.attention_mask, 
              max_length=max_len, # max length of generated text
              min_length=min_len, # minimum length of the generated text
              early_stopping=True, # stop early if end-of-text token is generated
              num_return_sequences=1, # return only a single sequence
              **gen_kwargs_aa2fold
  )
# Decode and remove white-spaces between tokens
decoded_translations = tokenizer.batch_decode( translations, skip_special_tokens=True )
structure_sequences = [ "".join(ts.split(" ")) for ts in decoded_translations ] # predicted 3Di strings

In [8]:
sequence_examples_backtranslation = [ "<fold2AA>" + " " + s for s in decoded_translations]

# tokenize sequences and pad up to the longest sequence in the batch
ids_backtranslation = tokenizer.batch_encode_plus(sequence_examples_backtranslation,
                                  add_special_tokens=True,
                                  padding="longest",
                                  return_tensors='pt').to(device)

# Example generation configuration for "inverse folding" (3Di-->AA)
gen_kwargs_fold2AA = {
            "do_sample": True,
            "top_p" : 0.85,
            "temperature" : 1.0,
            "top_k" : 3,
            "repetition_penalty" : 1.2,
}

# translate from 3Di to AA (3Di-->AA)
with torch.no_grad():
  backtranslations = model.generate( 
              ids_backtranslation.input_ids, 
              attention_mask=ids_backtranslation.attention_mask, 
              max_length=max_len, # max length of generated text
              min_length=min_len, # minimum length of the generated text
              #early_stopping=True, # stop early if end-of-text token is generated; only needed for beam-search
              num_return_sequences=1, # return only a single sequence
              **gen_kwargs_fold2AA
  )
# Decode and remove white-spaces between tokens
decoded_backtranslations = tokenizer.batch_decode( backtranslations, skip_special_tokens=True )
aminoAcid_sequences = [ "".join(ts.split(" ")) for ts in decoded_backtranslations ] # predicted amino acid strings


In [9]:
sequence = sequences[0]
sequence = sequence.removeprefix("<AA2fold> ")
# remove whitespace 
sequence = "".join(sequence.split())
print("Before translation:")
print(sequence)
print(len(sequence))

Before translation:
SGGKKIKVDKPLGLGGGLTVDIDA
24


In [10]:
print("After translation:")
print(structure_sequences[0])
print(len(structure_sequences[0]))

After translation:
ddwdwdwdwddppppdidididt
23


In [11]:
print("Back translation:")
print(aminoAcid_sequences[0])
print(len(aminoAcid_sequences[0]))

Back translation:
MATSIKLVFRNDGNNQWHYEIIP
23


In [12]:
# TODO use backtranslation to generate structure -> compare structure to original structure