In [1]:
import os
os.chdir('/home/ku76797/Documents/internship/Work/CATHe')

import time
import torch
import numpy as np
import pandas as pd
from transformers import T5EncoderModel, T5Tokenizer
from tqdm import tqdm

# Select the compute backend once at import time: the first CUDA GPU if one is
# available, otherwise Apple's MPS backend, otherwise the CPU.
device = torch.device(
    'cuda:0' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print("Using device: {}".format(device))


def get_T5_model(model_dir):
    """Load the T5 encoder and its tokenizer from ``model_dir``.

    The encoder is moved to the module-level ``device`` and switched to
    evaluation mode before being returned.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading ProsT5 from: {}".format(model_dir))
    encoder = T5EncoderModel.from_pretrained(model_dir)
    encoder = encoder.to(device).eval()
    # Keep the casing of the input tokens as given.
    tok = T5Tokenizer.from_pretrained(model_dir, do_lower_case=False)
    return encoder, tok


def read_csv(seq_path):
    """Read a CSV file of sequences into an ``{id: sequence}`` dictionary.

    The file must contain an ``Unnamed: 0`` column (the name pandas assigns
    to a column whose header is empty) holding the IDs, and a ``Sequence``
    column holding the sequences.

    Args:
        seq_path: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        Dictionary mapping each ID, converted with ``int()``, to the value of
        its ``Sequence`` cell. If an ID occurs more than once, the last row
        wins.
    """
    df = pd.read_csv(seq_path)
    # Iterate the two columns directly instead of using iterrows(), which
    # builds a fresh Series for every row.
    return dict(zip(map(int, df['Unnamed: 0']), df['Sequence']))


def get_embeddings(seq_path, model_dir, half_precision, is_3Di,
                   max_residues, max_seq_len, max_batch):
    """Embed every sequence of a CSV file with ProstT5 and time the run.

    Sequences are loaded with ``read_csv``, sorted longest-first and grouped
    into batches. The pending batch is sent to the model as soon as one of
    the following holds:

    * it contains ``max_batch`` sequences,
    * its residue count plus the current sequence's length (deliberately
      counted a second time as a safety margin) reaches ``max_residues``,
    * the current sequence is longer than ``max_seq_len``,
    * the last sequence has been reached.

    Each sequence is reduced to the mean of its per-residue encoder states.
    The embeddings are only used to verify that every sequence was
    processed; they are not returned.

    Args:
        seq_path: CSV path handed to ``read_csv``.
        model_dir: Model name or directory handed to ``get_T5_model``.
        half_precision: If true, convert the model with ``model.half()``.
        is_3Di: If true, prefix sequences with ``<fold2AA>``, otherwise with
            ``<AA2fold>``.
        max_residues: Residue budget that triggers a batch flush.
        max_seq_len: Sequences longer than this flush the batch immediately.
        max_batch: Maximum number of sequences per batch.

    Returns:
        Elapsed seconds of the embedding loop (CSV reading and model loading
        are not included), or ``None`` if the model raised ``RuntimeError``
        or not every sequence received an embedding.
    """
    emb_dict = dict()

    # Read in CSV
    seq_dict = read_csv(seq_path)
    prefix = "<fold2AA>" if is_3Di else "<AA2fold>"
    # NOTE(review): the ProstT5 model card expects 3Di input in lower case;
    # sequences are passed through with their original casing here - confirm
    # the CSV contents before using is_3Di=True.

    model, tokenizer = get_T5_model(model_dir)
    if half_precision:
        model = model.half()

    # sort sequences by length to trigger OOM at the beginning
    seq_dict = sorted(seq_dict.items(), key=lambda kv: len(kv[1]), reverse=True)
    n_seqs = len(seq_dict)

    # perf_counter is monotonic, so the measurement cannot be skewed by
    # system clock adjustments during the benchmark.
    start = time.perf_counter()
    batch = []
    batch_residues = 0  # running residue count of the pending batch
    for seq_idx, (pdb_id, seq) in enumerate(tqdm(seq_dict, desc="Embedding sequences"), 1):
        # replace non-standard AAs
        seq = seq.replace('U', 'X').replace('Z', 'X').replace('O', 'X').replace('B', 'X')
        seq_len = len(seq)
        # the tokenizer expects the prefix followed by space-separated residues
        seq = prefix + ' ' + ' '.join(list(seq))
        batch.append((pdb_id, seq, seq_len))
        batch_residues += seq_len

        # add the last sequence length once more to avoid that batches with
        # (n_res_batch > max_residues) get processed
        n_res_batch = batch_residues + seq_len
        flush = (
            len(batch) >= max_batch
            or n_res_batch >= max_residues
            or seq_idx == n_seqs
            or seq_len > max_seq_len
        )
        if not flush:
            continue

        pdb_ids, seqs, seq_lens = zip(*batch)
        batch = []
        batch_residues = 0

        token_encoding = tokenizer.batch_encode_plus(seqs,
                                                     add_special_tokens=True,
                                                     padding="longest",
                                                     return_tensors='pt'
                                                     ).to(device)
        try:
            with torch.no_grad():
                embedding_repr = model(token_encoding.input_ids,
                                       attention_mask=token_encoding.attention_mask)
        except RuntimeError as err:
            # pdb_id / seq_len refer to the last (shortest) sequence of the
            # failing batch.
            print("RuntimeError during embedding for {} (L={})".format(pdb_id, seq_len))
            # Report the cause so out-of-memory can be told apart from other
            # runtime failures.
            print("  cause: {}".format(err))
            return None

        # batch-size x seq_len x embedding_dim
        # extra token is added at the end of the seq
        for batch_idx, identifier in enumerate(pdb_ids):
            s_len = seq_lens[batch_idx]
            # position 0 is the prefix token, so residues sit at 1..s_len;
            # slicing also drops padding and the trailing special token
            emb = embedding_repr.last_hidden_state[batch_idx, 1:s_len+1]
            # average over residues to get one vector per sequence
            emb = emb.mean(dim=0)
            emb_dict[identifier] = emb.detach().cpu().numpy().squeeze()

    end = time.perf_counter()

    # every sequence must have produced an embedding for the timing to count
    if len(emb_dict) != n_seqs:
        return None

    total_time = end - start
    return total_time

def find_best_params(seq_path, model_dir, half_precision=True, is_3Di=False, max_seq_len=3263,
                     max_residues_values=None, max_batch_values=None):
    """Grid-search ``max_residues`` / ``max_batch`` for the fastest embedding run.

    ``get_embeddings`` is called once per combination; it loads the model
    itself, so every combination pays the model-loading cost (which is not
    part of the measured time).

    Args:
        seq_path: CSV path forwarded to ``get_embeddings``.
        model_dir: Model name or directory forwarded to ``get_embeddings``.
        half_precision: Forwarded to ``get_embeddings``.
        is_3Di: Forwarded to ``get_embeddings``.
        max_seq_len: Forwarded to ``get_embeddings``.
        max_residues_values: Candidate ``max_residues`` values. Defaults to
            ``[64, 256, 1024, 4096]``.
        max_batch_values: Candidate ``max_batch`` values. Defaults to
            ``[64, 256, 1024, 4096]``.

    Returns:
        ``(results, best_params)``. ``results`` is a list of
        ``(max_residues, max_batch, outcome)`` tuples where ``outcome`` is
        the elapsed time in seconds or an error description string.
        ``best_params`` is the tuple with the smallest time, or ``None`` if
        no combination completed.
    """
    if max_residues_values is None:
        max_residues_values = [2**i for i in range(6, 14, 2)]
    if max_batch_values is None:
        max_batch_values = [2**i for i in range(6, 14, 2)]

    results = []

    for max_residues in max_residues_values:
        for max_batch in max_batch_values:
            try:
                total_time = get_embeddings(seq_path, model_dir, half_precision, is_3Di, max_residues, max_seq_len, max_batch)
                if total_time is None:
                    # get_embeddings signals a failed or incomplete run with None
                    results.append((max_residues, max_batch, "Runtime Error"))
                    print(f"Runtime Error for max_residues={max_residues}, max_batch={max_batch}")
                    continue
                results.append((max_residues, max_batch, total_time))
                print(f"Tested max_residues={max_residues}, max_batch={max_batch}, time={total_time:.2f} seconds")
            except MemoryError:
                results.append((max_residues, max_batch, "Memory Error"))
                print(f"Memory Error for max_residues={max_residues}, max_batch={max_batch}")
            except Exception as e:
                # Keep the search going: record the failure and try the next
                # combination.
                results.append((max_residues, max_batch, f"Error: {e}"))
                print(f"Failed max_residues={max_residues}, max_batch={max_batch} with error: {e}")

    # Only entries whose third field is a number are successful timings.
    valid_results = [result for result in results if isinstance(result[2], (int, float))]
    best_params = None
    if valid_results:
        best_params = min(valid_results, key=lambda x: x[2])
        print(f"Best parameters: max_residues={best_params[0]}, max_batch={best_params[1]} with time={best_params[2]:.2f} seconds")
    else:
        print("No valid parameter combinations found.")

    return results, best_params

if __name__ == '__main__':
    # Benchmark the batching parameters on the validation split.
    seq_path = "./data/Dataset/csv/Val.csv"

    results, best_params = find_best_params(
        seq_path,
        model_dir="Rostlab/ProstT5",
        half_precision=True,
        is_3Di=False,
    )

    # Persist every tested combination (timings and error labels alike) so
    # the run can be analysed later.
    results_df = pd.DataFrame(
        results,
        columns=["max_residues", "max_batch", "time"],
    )
    results_df.to_csv(
        "./src/all/models/ProstT5/embedding_time_results.csv",
        index=False,
    )


Using device: cuda:0
Loading ProsT5 from: Rostlab/ProstT5


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 65.09it/s] 


Tested max_residues=1, max_batch=1, time=103.11 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:44<00:00, 64.42it/s] 


Tested max_residues=1, max_batch=64, time=104.18 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:44<00:00, 64.36it/s] 


Tested max_residues=1, max_batch=4096, time=104.27 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:45<00:00, 63.76it/s] 


Tested max_residues=1, max_batch=262144, time=105.26 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:44<00:00, 63.92it/s] 


Tested max_residues=1, max_batch=16777216, time=104.99 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 64.61it/s] 


Tested max_residues=1, max_batch=1073741824, time=103.87 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 64.90it/s] 


Tested max_residues=1, max_batch=68719476736, time=103.40 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.48it/s] 


Tested max_residues=64, max_batch=1, time=102.49 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.60it/s] 


Tested max_residues=64, max_batch=64, time=102.30 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.36it/s] 


Tested max_residues=64, max_batch=4096, time=102.68 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.24it/s] 


Tested max_residues=64, max_batch=262144, time=102.87 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 65.00it/s] 


Tested max_residues=64, max_batch=16777216, time=103.25 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 65.10it/s] 


Tested max_residues=64, max_batch=1073741824, time=103.09 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.32it/s] 


Tested max_residues=64, max_batch=68719476736, time=102.74 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 64.92it/s] 


Tested max_residues=4096, max_batch=1, time=103.38 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [00:56<00:00, 117.93it/s]


Tested max_residues=4096, max_batch=64, time=56.91 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [00:57<00:00, 117.52it/s]


Tested max_residues=4096, max_batch=4096, time=57.11 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [00:57<00:00, 117.52it/s]


Tested max_residues=4096, max_batch=262144, time=57.11 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [00:57<00:00, 117.56it/s]


Tested max_residues=4096, max_batch=16777216, time=57.09 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [00:57<00:00, 117.72it/s]


Tested max_residues=4096, max_batch=1073741824, time=57.01 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [00:57<00:00, 117.59it/s]


Tested max_residues=4096, max_batch=68719476736, time=57.07 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 65.03it/s] 


Tested max_residues=262144, max_batch=1, time=103.19 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:   1%|          | 63/6711 [00:00<00:20, 325.81it/s]


RuntimeError during embedding for 1047893 (L=457)
Runtime Error for max_residues=262144, max_batch=64
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  12%|█▏        | 784/6711 [00:00<00:03, 1579.39it/s]

RuntimeError during embedding for 1046347 (L=253)
Runtime Error for max_residues=262144, max_batch=4096
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  12%|█▏        | 784/6711 [00:00<00:03, 1614.36it/s]

RuntimeError during embedding for 1046347 (L=253)
Runtime Error for max_residues=262144, max_batch=262144
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  12%|█▏        | 784/6711 [00:00<00:03, 1584.30it/s]

RuntimeError during embedding for 1046347 (L=253)
Runtime Error for max_residues=262144, max_batch=16777216
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  12%|█▏        | 784/6711 [00:00<00:03, 1746.54it/s]

RuntimeError during embedding for 1046347 (L=253)
Runtime Error for max_residues=262144, max_batch=1073741824
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  12%|█▏        | 784/6711 [00:00<00:03, 1769.38it/s]

RuntimeError during embedding for 1046347 (L=253)
Runtime Error for max_residues=262144, max_batch=68719476736
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:43<00:00, 65.09it/s] 


Tested max_residues=16777216, max_batch=1, time=103.10 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:   1%|          | 63/6711 [00:00<00:11, 570.20it/s]


RuntimeError during embedding for 1047893 (L=457)
Runtime Error for max_residues=16777216, max_batch=64
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  61%|██████    | 4095/6711 [00:01<00:01, 2190.37it/s] 

RuntimeError during embedding for 1048794 (L=108)
Runtime Error for max_residues=16777216, max_batch=4096
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2203.44it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=16777216, max_batch=262144
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2214.89it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=16777216, max_batch=16777216
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2179.79it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=16777216, max_batch=1073741824
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2218.03it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=16777216, max_batch=68719476736
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.67it/s] 


Tested max_residues=1073741824, max_batch=1, time=102.19 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:   1%|          | 63/6711 [00:00<00:11, 570.53it/s]


RuntimeError during embedding for 1047893 (L=457)
Runtime Error for max_residues=1073741824, max_batch=64
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  61%|██████    | 4095/6711 [00:01<00:01, 2191.39it/s] 

RuntimeError during embedding for 1048794 (L=108)
Runtime Error for max_residues=1073741824, max_batch=4096
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2171.77it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=1073741824, max_batch=262144
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2216.83it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=1073741824, max_batch=16777216
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2180.93it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=1073741824, max_batch=1073741824
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2219.03it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=1073741824, max_batch=68719476736
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|██████████| 6711/6711 [01:42<00:00, 65.31it/s] 


Tested max_residues=68719476736, max_batch=1, time=102.76 seconds
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:   1%|          | 63/6711 [00:00<00:11, 577.24it/s]


RuntimeError during embedding for 1047893 (L=457)
Runtime Error for max_residues=68719476736, max_batch=64
Loading ProsT5 from: Rostlab/ProstT5


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences:  61%|██████    | 4095/6711 [00:01<00:01, 2186.22it/s] 

RuntimeError during embedding for 1048794 (L=108)
Runtime Error for max_residues=68719476736, max_batch=4096
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2233.78it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=68719476736, max_batch=262144
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2174.92it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=68719476736, max_batch=16777216
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2199.71it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=68719476736, max_batch=1073741824
Loading ProsT5 from: Rostlab/ProstT5



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Embedding sequences: 100%|█████████▉| 6710/6711 [00:03<00:00, 2221.93it/s] 

RuntimeError during embedding for 1046789 (L=14)
Runtime Error for max_residues=68719476736, max_batch=68719476736
Best parameters: max_residues=4096, max_batch=64 with time=56.91 seconds



