In [None]:
# ANSI escape sequences used to colour terminal messages (reset restores default).
red = "\033[91m"
green = "\033[92m"
yellow = "\033[93m"
reset = "\033[0m"


import sys
import os

# Check if a virtual environment is active: inside a venv, sys.prefix points at
# the venv while sys.base_prefix still points at the base interpreter.
if not hasattr(sys, 'base_prefix') or sys.base_prefix == sys.prefix:
    raise EnvironmentError(f"{red}No virtual environment is activated. Please activate the right venv_2 to run this code. See ReadMe for more details.{reset}")

# Locate the active virtual environment. VIRTUAL_ENV is only exported by the
# activate script; a Jupyter kernel or an interpreter launched directly from the
# venv leaves it unset. The check above guarantees we are in a venv, so
# sys.prefix is the venv root and is a reliable fallback.
venv_path = os.environ.get('VIRTUAL_ENV') or sys.prefix

venv_name = os.path.basename(venv_path)
if venv_name != "venv_2":
    raise EnvironmentError(f"{red}The activated virtual environment is '{venv_name}', not 'venv_2'. However venv_2 must be activated to run this code. See ReadMe for more details.{reset}")

import time
import torch
import numpy as np
import pandas as pd
from transformers import T5EncoderModel, T5Tokenizer
from tm_vec.embed_structure_model import trans_basic_block, trans_basic_block_Config
from tqdm import tqdm
import re
import gc
# Run from the project root so the relative ./data/... and ./src/... paths used
# below resolve. Set the PROJECT_ROOT environment variable to point there instead
# of editing this placeholder.
os.chdir(os.environ.get('PROJECT_ROOT', 'path_to_project_root'))

# Pick the compute device: first CUDA GPU if present, else Apple MPS, else CPU.
device = torch.device(
    'cuda:0' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print("Using device: {}".format(device))

def load_T5_model(target_device=None):
    """Load the local ProtT5-XL (UniRef50) encoder and its tokenizer in eval mode.

    Args:
        target_device: torch device to place the encoder on. Defaults to the
            module-level ``device`` selected at import time (CUDA, then MPS,
            then CPU), i.e. the same device the rest of this file moves its
            input tensors to.

    Returns:
        ``(model, tokeniser)``: the ``T5EncoderModel`` (in eval mode) and the
        ``T5Tokenizer`` loaded from the local weights directory.
    """
    print("loading model")
    weights_dir = "./data/Dataset/weights/ProtT5/prot_t5_xl_uniref50"
    tokeniser = T5Tokenizer.from_pretrained(weights_dir, do_lower_case=False)
    model = T5EncoderModel.from_pretrained(weights_dir)
    gc.collect()

    # Use the file-wide device rather than a local cuda-or-cpu choice: the latter
    # leaves the model on CPU on MPS machines while inputs are moved to MPS.
    if target_device is None:
        target_device = device
    model = model.to(target_device)
    model = model.eval()

    return model, tokeniser

def read_csv(seq_path):
    '''
        Load a sequence CSV into a dict of {id: sequence}.
        IDs are taken from the 'Unnamed: 0' column and converted to str;
        sequences come from the 'Sequence' column. Later rows win on duplicate IDs.
    '''
    frame = pd.read_csv(seq_path)
    id_column = frame['Unnamed: 0']
    sequence_column = frame['Sequence']
    return {str(seq_id): sequence for seq_id, sequence in zip(id_column, sequence_column)}

# Function to extract ProtTrans embedding for a sequence
def featurize_prottrans(sequences, model, tokenizer, device):
    sequences = [(" ".join(seq)) for seq in sequences]
    sequences = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences]
    ids = tokenizer.batch_encode_plus(sequences, add_special_tokens=True, padding="longest",)
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    try:
        with torch.no_grad():
            embedding = model(input_ids=input_ids, attention_mask=attention_mask)
    
    except RuntimeError:
                print("RuntimeError during ProtT5 embedding  (nb sequences in batch={} /n (length of sequences in the batch ={}))".format(len(sequences), [len(seq) for seq in sequences]))
                sys.exit("Stopping execution due to RuntimeError.")
    
    embedding = embedding.last_hidden_state.cpu().numpy()

    features = []
    for seq_num in range(len(sequences)):
        seq_len = (attention_mask[seq_num] == 1).sum()
        seq_emd = embedding[seq_num][:seq_len - 1]
        features.append(seq_emd)

    prottrans_embedding = torch.tensor(features[0])
    prottrans_embedding = torch.unsqueeze(prottrans_embedding, 0).to(device)

    return prottrans_embedding

# Embed a protein using tm_vec (takes as input a prottrans embedding)
def embed_tm_vec(prottrans_embedding, model_deep, device, seq):
    """Embed a ProtT5 embedding with the TM-Vec model.

    Args:
        prottrans_embedding: tensor whose first two dims are used as the
            (batch, tokens) shape of the key-padding mask.
        model_deep: TM-Vec model, called as
            ``model_deep(x, src_mask=None, src_key_padding_mask=mask)``.
        device: torch device the padding mask is created on.
        seq: original sequence, used only in the error message.

    Returns:
        The model output as a CPU numpy array, or ``None`` if the model raised
        RuntimeError (the failing sequence is printed).
    """
    # All-False mask: no position is treated as padding.
    padding = torch.zeros(prottrans_embedding.shape[0:2], dtype=torch.bool, device=device)

    try:
        # Inference only: without no_grad an autograd graph is built per call.
        with torch.no_grad():
            tm_vec_embedding = model_deep(prottrans_embedding, src_mask=None, src_key_padding_mask=padding)
    except RuntimeError:
        print("RuntimeError during TM_Vec embedding sequence {}".format(seq))
        return None

    return tm_vec_embedding.cpu().detach().numpy()

def encode(sequences, model_deep, model, tokenizer, device):
    """Embed each sequence with ProtT5 then TM-Vec, one sequence at a time.

    Args:
        sequences: iterable of amino-acid strings.
        model_deep: TM-Vec model passed to ``embed_tm_vec``.
        model, tokenizer: ProtT5 encoder and tokenizer passed to ``featurize_prottrans``.
        device: torch device used for both steps.

    Returns:
        The per-sequence TM-Vec outputs concatenated along axis 0, or ``None``
        if any sequence failed to embed.

    Raises:
        ValueError: from ``np.concatenate`` when ``sequences`` is empty.
    """
    embed_all_sequences = []
    for seq in sequences:
        protrans_sequence = featurize_prottrans([seq], model, tokenizer, device)
        if protrans_sequence is None:
            return None
        embedded_sequence = embed_tm_vec(protrans_sequence, model_deep, device, seq)
        # embed_tm_vec signals a RuntimeError by returning None; propagate it
        # instead of letting np.concatenate fail on a None entry.
        if embedded_sequence is None:
            return None
        embed_all_sequences.append(embedded_sequence)
    return np.concatenate(embed_all_sequences, axis=0)

def get_embeddings(seq_path, emb_path, max_residues, max_seq_len, max_batch):
    """Time embedding every sequence in ``seq_path`` with ProtT5 + TM-Vec.

    Both models are (re)loaded on every call; loading is excluded from the timing.

    Args:
        seq_path: CSV read by ``read_csv``.
        emb_path: currently unused -- embeddings are collected in memory and
            discarded, since this function only measures time.
        max_residues: flush a batch once its residue estimate reaches this value.
        max_seq_len: a sequence longer than this flushes its batch immediately.
        max_batch: flush a batch once it holds this many sequences.

    Returns:
        Elapsed seconds (monotonic clock) spent in the embedding loop.

    Raises:
        RuntimeError: if a batch could not be embedded (``encode`` returned None).
    """
    emb_dict = {}

    # Read in CSV
    sequences_dict = read_csv(seq_path)

    model, tokeniser = load_T5_model()

    # TM-Vec model paths
    tm_vec_model_cpnt = "./data/Dataset/weights/TM_Vec/tm_vec_cath_model.ckpt"
    tm_vec_model_config = "./data/Dataset/weights/TM_Vec/tm_vec_cath_model_params.json"

    # Load the TM-Vec model
    tm_vec_model_config = trans_basic_block_Config.from_json(tm_vec_model_config)
    model_deep = trans_basic_block.load_from_checkpoint(tm_vec_model_cpnt, config=tm_vec_model_config)
    model_deep = model_deep.to(device)
    model_deep = model_deep.eval()

    # Longest sequences first (stable sort keeps file order among equal lengths).
    sorted_sequences_tuple = sorted(sequences_dict.items(), key=lambda item: len(item[1]), reverse=True)
    n_sequences = len(sorted_sequences_tuple)

    # NOTE(review): encode() embeds one sequence at a time, so max_batch and
    # max_residues only change how often encode() is called, not the GPU batch size.
    start = time.perf_counter()

    batch = []
    batch_keys = []
    for seq_idx, (seq_key, seq) in enumerate(tqdm(sorted_sequences_tuple, desc="Embedding sequences"), 1):
        seq_len = len(seq)
        batch.append(seq)
        batch_keys.append(seq_key)

        # NOTE(review): seq is already in batch, so its length is counted twice;
        # kept unchanged so benchmark results stay comparable.
        n_res_batch = sum(len(s) for s in batch) + seq_len
        if (len(batch) >= max_batch or n_res_batch >= max_residues
                or seq_idx == n_sequences or seq_len > max_seq_len):
            embedded_batch = encode(batch, model_deep, model, tokeniser, device)
            if embedded_batch is None:
                raise RuntimeError(
                    f"Failed to embed batch of {len(batch)} sequence(s) starting with '{batch_keys[0]}'")
            for row, key in enumerate(batch_keys):
                emb_dict[key] = embedded_batch[row]
            batch = []
            batch_keys = []

    return time.perf_counter() - start

def find_best_params(seq_path, emb_path, max_seq_len=3263,
                     max_residues_values=None, max_batch_values=None):
    """Grid-search ``max_residues`` x ``max_batch`` and report the fastest combination.

    Args:
        seq_path, emb_path, max_seq_len: forwarded to ``get_embeddings``.
        max_residues_values: candidate ``max_residues`` values. Defaults to
            powers of 4 from 1 to 2**40 (21 values).
        max_batch_values: candidate ``max_batch`` values, same default.

    With both defaults this is 21 x 21 = 441 full embedding runs, and each run
    reloads both models; pass smaller grids for quick experiments.

    Returns:
        ``(results, best_params)``. ``results`` is a list of
        ``(max_residues, max_batch, seconds_or_error_string)``; ``best_params``
        is the fastest successful tuple, or ``None`` if none succeeded.

    NOTE(review): featurize_prottrans calls sys.exit on RuntimeError, and the
    resulting SystemExit is not caught by ``except Exception`` below, so such a
    failure stops the whole sweep instead of being recorded.
    """
    if max_residues_values is None:
        max_residues_values = [2**i for i in range(0, 42, 2)]
    if max_batch_values is None:
        max_batch_values = [2**i for i in range(0, 42, 2)]

    results = []

    for max_residues in max_residues_values:
        for max_batch in max_batch_values:
            try:
                total_time = get_embeddings(seq_path, emb_path, max_residues, max_seq_len, max_batch)
                results.append((max_residues, max_batch, total_time))
                print(f"Tested max_residues={max_residues}, max_batch={max_batch}, time={total_time:.2f} seconds")
            except MemoryError:
                results.append((max_residues, max_batch, "Memory Error"))
                print(f"Memory Error for max_residues={max_residues}, max_batch={max_batch}")
            except Exception as e:
                results.append((max_residues, max_batch, f"Error: {e}"))
                print(f"Failed max_residues={max_residues}, max_batch={max_batch} with error: {e}")

    # Successful runs store a number in the time slot; failures store a string.
    valid_results = [result for result in results if isinstance(result[2], (int, float))]
    if not valid_results:
        print("No valid parameter combinations found.")
        return results, None

    best_params = min(valid_results, key=lambda result: result[2])
    print(f"Best parameters: max_residues={best_params[0]}, max_batch={best_params[1]} with time={best_params[2]:.2f} seconds")
    return results, best_params

if __name__ == '__main__':
    # Benchmark embedding throughput on the validation split.
    seq_path = "./data/Dataset/csv/Val.csv"
    emb_path = "./data/Dataset/embeddings/Val_TM_Vec_test.npz"
    results, best_params = find_best_params(seq_path=seq_path, emb_path=emb_path)

    # Keep every tested combination for later analysis; the "time" column holds
    # seconds for successful runs and an error string otherwise.
    results_df = pd.DataFrame(
        data=results,
        columns=["max_residues", "max_batch", "time"],
    )
    results_df.to_csv(
        "./src/all/models/TM_Vec/embedding_time_results.csv",
        index=False,
    )


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Using device: cuda:0
loading model


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.3.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint data/Dataset/weights/TM_Vec/tm_vec_cath_model.ckpt`
Embedding sequences:   8%|▊         | 531/6711 [00:40<07:50, 13.13it/s]


KeyboardInterrupt: 