In [1]:
# In the following you can define your desired output. Current options:
# per_residue embeddings
# per_protein embeddings
# secondary structure predictions

# Replace this file with your own (multi-)FASTA
# Headers are expected to start with ">";
# NOTE(review): seq_path is not read by any cell shown here (sequences come from the Davis CSVs) — confirm it is still needed.
seq_path = "./protT5/example_seqs.fasta"

# whether to retrieve embeddings for each residue in a protein 
# --> Lx1024 matrix per protein with L being the protein's length
# as a rule of thumb: 1k proteins require around 1GB RAM/disk
per_residue = True 
per_residue_path = "./protT5/output/per_residue_embeddings.h5" # where to store the embeddings

# whether to retrieve per-protein embeddings 
# --> only one 1024-d vector per protein, irrespective of its length
per_protein = False
per_protein_path = "./protT5/output/per_protein_embeddings.h5" # where to store the embeddings

# whether to retrieve secondary structure predictions
# This can be replaced by your method after being trained on ProtT5 embeddings
# NOTE: no secondary-structure model is loaded in this notebook (its loader is commented out).
sec_struct = False
sec_struct_path = "./protT5/output/ss3_preds.fasta" # file for storing predictions

# make sure that at least one kind of output (per-residue, per-protein or sec. structure) is requested;
# the second operand of `assert` is the error message, so it must be a string (not a print() call)
assert per_protein or per_residue or sec_struct, \
    "Minimally, you need to activate per_residue, per_protein or sec_struct. (or any combination)"


In [2]:
# Show the GPU model, driver/CUDA version and current memory use before loading the model.
!nvidia-smi

Mon Dec 16 02:34:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 2070 ...    Off | 00000000:01:00.0  On |                  N/A |
| 37%   41C    P5              30W / 215W |    578MiB /  8192MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
#@title Import dependencies and check whether GPU is available. { display-mode: "form" }
from transformers import T5EncoderModel, T5Tokenizer
import torch
import h5py
import time
import gc
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using {device}")

Using cuda:0


In [4]:
#@title Load ProtT5 in half-precision. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
def get_T5_model(model_path="../protT5/protT5_checkpoint/",
                 tokenizer_name="Rostlab/prot_t5_xl_uniref50"):
    """Load the ProtT5 encoder and its tokenizer.

    The encoder weights are loaded in float16 from ``model_path``, moved to the
    module-level ``device`` and switched to evaluation mode. When ``device`` is
    the CPU the weights are cast back to float32, because half-precision
    inference on CPU is unsupported or very slow for several ops (the ProtTrans
    reference code does the same).

    Args:
        model_path: local directory or hub id of the T5 encoder checkpoint.
            NOTE(review): other paths in this notebook use "./protT5/...",
            this default uses "../protT5/..." — confirm which is intended.
        tokenizer_name: local directory or hub id of the ProtT5 tokenizer.

    Returns:
        (model, tokenizer)
    """
    model = T5EncoderModel.from_pretrained(model_path, torch_dtype=torch.float16)
    model = model.to(device)  # move model to GPU (or CPU)
    if device.type == "cpu":
        # float16 kernels are largely unavailable on CPU; use full precision there.
        model = model.to(torch.float32)
    model = model.eval()  # set model to evaluation mode (disables dropout)
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name, do_lower_case=False)

    return model, tokenizer

In [5]:
def get_embeddings(model, tokenizer, seqs, per_residue, per_protein, sec_struct,
                   max_residues=15000, max_seq_len=1200, max_batch=100):
    """Embed protein sequences with a ProtT5 encoder, batching by length.

    Sequences are processed longest-first so that sequences of similar length
    share a batch. A batch is flushed when it holds ``max_batch`` sequences,
    when its residue count (plus the current sequence length, as a safety
    margin) reaches ``max_residues``, when the current sequence is longer than
    ``max_seq_len``, or after the last sequence.

    Args:
        model: encoder called as ``model(input_ids, attention_mask=...)`` and
            returning an object with a ``last_hidden_state`` tensor. Inputs are
            moved to ``model.device`` when the model has one, otherwise to the
            module-level ``device``.
        tokenizer: tokenizer providing ``batch_encode_plus``.
        seqs: mapping ``identifier -> amino-acid string``.
        per_residue: store one (L x D) array per sequence in ``"residue_embs"``.
        per_protein: store the mean over residues (D,) in ``"protein_embs"``.
        sec_struct: secondary-structure prediction; no predictor is loaded in
            this notebook, so a true value raises NotImplementedError.
        max_residues, max_seq_len, max_batch: batching limits. Every input is
            padded/truncated to ``max_seq_len`` tokens including the end token
            the tokenizer appends, so at most ``max_seq_len - 1`` residues of a
            sequence are embedded.

    Returns:
        dict with keys ``"residue_embs"``, ``"protein_embs"`` and
        ``"sec_structs"``, each mapping identifiers to numpy arrays. Sequences of
        a batch whose forward pass raised ``RuntimeError`` (e.g. CUDA OOM) are
        reported and omitted.

    Raises:
        NotImplementedError: if ``sec_struct`` is true.
    """
    if sec_struct:
        # The secondary-structure head is never loaded; previously this failed
        # later with a NameError on the undefined prediction tensor.
        raise NotImplementedError(
            "sec_struct=True requires a secondary-structure model, which is not loaded.")

    results = {"residue_embs": dict(),
               "protein_embs": dict(),
               "sec_structs": dict()}

    # Put inputs on the same device as the model; fall back to the global device.
    model_device = getattr(model, "device", None)
    if model_device is None:
        model_device = device

    # sort sequences according to length (reduces unnecessary padding --> speeds up embedding)
    seq_items = sorted(seqs.items(), key=lambda kv: len(kv[1]), reverse=True)
    batch = []
    for seq_idx, (pdb_id, seq) in enumerate(seq_items, 1):
        seq_len = len(seq)
        # ProtT5 expects residues separated by single spaces
        batch.append((pdb_id, ' '.join(seq), seq_len))

        # count residues in current batch and add the last sequence length to
        # avoid that batches with (n_res_batch > max_residues) get processed
        n_res_batch = sum(s_len for _, _, s_len in batch) + seq_len
        if (len(batch) >= max_batch or n_res_batch >= max_residues
                or seq_idx == len(seq_items) or seq_len > max_seq_len):
            pdb_ids, batch_seqs, seq_lens = zip(*batch)
            batch = []

            # add_special_tokens adds extra token at the end of each sequence
            token_encoding = tokenizer.batch_encode_plus(batch_seqs,
                                                         add_special_tokens=True,
                                                         max_length=max_seq_len,
                                                         padding='max_length',
                                                         truncation=True,
                                                         return_tensors='pt')
            input_ids = token_encoding['input_ids'].to(model_device)
            attention_mask = token_encoding['attention_mask'].to(model_device)

            try:
                with torch.no_grad():
                    # returns: ( batch-size x max_seq_len x embedding_dim )
                    embedding_repr = model(input_ids, attention_mask=attention_mask)
            except RuntimeError as e:
                # Best effort: report every sequence of the dropped batch and go on.
                print("RuntimeError during embedding for {} (L={}): {}".format(
                    ", ".join(map(str, pdb_ids)), max(seq_lens), e))
                continue

            hidden = embedding_repr.last_hidden_state
            for batch_idx, identifier in enumerate(pdb_ids):  # for each protein in the current mini-batch
                # Attended tokens = encoded residues + the appended end token.
                # Clamping keeps a truncated sequence from returning the end-token
                # embedding as if it were a residue.
                n_encoded = int(attention_mask[batch_idx].sum()) - 1
                s_len = min(seq_lens[batch_idx], n_encoded)
                # slice off padding --> seq_len x embedding_dim
                emb = hidden[batch_idx, :s_len]
                if per_residue:  # store per-residue embeddings (Lx1024)
                    results["residue_embs"][identifier] = emb.detach().cpu().numpy().squeeze()
                if per_protein:  # apply average-pooling to derive per-protein embeddings (1024-d)
                    protein_emb = emb.mean(dim=0)
                    results["protein_embs"][identifier] = protein_emb.detach().cpu().numpy().squeeze()

    return results

In [6]:
# Resume point: -1 means no chunk has been processed yet, i.e. start from the first sequence.
last_processed = -1
try:
    with open('checkpoint.txt') as f:
        last_processed = int(f.read().strip())
except FileNotFoundError:
    # Fresh run: no checkpoint has been written yet, keep the -1 default.
    pass

In [7]:
# Show the loaded resume point (-1 = nothing processed yet, start from the first sequence)
last_processed

-1

In [8]:
def update_checkpoint(val: int, path: str = 'checkpoint.txt') -> None:
    """Persist ``val`` as the resume checkpoint read at notebook start.

    The value is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. The rename is atomic, so an interrupted write never
    leaves a truncated checkpoint that would fail to parse on the next run.

    Args:
        val: index to store (written as its decimal string).
        path: checkpoint file location; defaults to ``checkpoint.txt`` in the cwd.
    """
    import os

    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        f.write(str(val))
    os.replace(tmp_path, path)

In [9]:
from collections import OrderedDict
import pandas as pd

df_train = pd.read_csv('./data/davis/raw/data_train.csv', usecols=['target_sequence'])
df_test = pd.read_csv('./data/davis/raw/data_test.csv', usecols=['target_sequence'])
# ignore_index=True: train and test both carry a 0-based RangeIndex, so without it two
# different sequences could share a row label and one would overwrite the other in the
# dict built below.
df = pd.concat([df_train, df_test], axis=0, ignore_index=True).drop_duplicates()
# row label -> unique target sequence, in first-seen order
# (`into` is passed by keyword; positional use is deprecated in pandas 2.x)
seqs = df['target_sequence'].to_dict(into=OrderedDict)

  from pandas.core import (
  seqs = df['target_sequence'].to_dict(OrderedDict)


In [10]:
# Load the encoder part of ProtT5-XL-U50 in half-precision (recommended)
# get_T5_model returns the encoder (moved to `device`, in eval mode) and its tokenizer.
model, tokenizer = get_T5_model()

In [10]:
# Resume support: each checkpoint step covers `seq_to_write` sequences, so skip
# every sequence that belongs to an already-finished step.
keys = list(seqs.keys())
seq_num_to_process = 1   # sequences passed to get_embeddings per call
seq_to_write = 1000      # sequences per checkpoint step
index = seq_to_write * (last_processed + 1)
key_list = keys[index:]
index

0

In [11]:
import numpy as np
from tqdm import trange, tqdm

In [12]:
# Embed every remaining sequence and collect {sequence: per-residue embedding tensor}.
# This loop only reads per-residue embeddings, so fail fast if they are disabled.
assert per_residue, "per_residue must be True: this loop stores per-residue embeddings"

buffer_dict = {}
write_counter = last_processed + 1  # NOTE(review): unused in the cells shown here
max_length = 1200  # tokens per sequence given to the tokenizer, including the end token
# NOTE(review): update_checkpoint is never called, so checkpoint.txt is not advanced by this loop.

for i in tqdm(range(0, len(key_list), seq_num_to_process)):
    seq_keys = key_list[i : i + seq_num_to_process]
    temp_data_dic = {seq_key: seqs[seq_key] for seq_key in seq_keys}

    results = get_embeddings(model, tokenizer, temp_data_dic,
                             per_residue, per_protein, sec_struct,
                             max_seq_len=max_length, max_batch=10)

    # Store every sequence of the chunk (not only the last one). Sequences whose
    # forward pass failed are missing from results; report them instead of
    # crashing with a KeyError.
    for seq_key in seq_keys:
        embedding = results['residue_embs'].get(seq_key)
        if embedding is None:
            print(f"No embedding for {seq_key!r}; skipping")
            continue
        buffer_dict[seqs[seq_key]] = torch.from_numpy(embedding)

    del results
    gc.collect()

100%|██████████| 379/379 [03:28<00:00,  1.82it/s]


In [None]:
# Persist the {target sequence: per-residue embedding tensor} mapping built above.
torch.save(buffer_dict, './data/davis/raw/prot5.pth')

In [2]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph


def extract_individual_graphs(edge_index, batch, num_nodes=None):
    """
    Extracts individual graphs from batched edge_index and batch tensors.

    Edges whose two endpoints belong to different graphs are dropped.

    Args:
        edge_index (torch.LongTensor): Tensor of shape (2, num_edges) representing edge connections.
        batch (torch.LongTensor): Tensor of shape (num_nodes,) indicating graph membership for each node.
        num_nodes (int, optional): Total number of nodes in the batch. Forwarded to
            ``subgraph`` so that trailing isolated nodes (absent from ``edge_index``)
            are still in range; defaults to ``batch.numel()``.

    Returns:
        List[Data]: one PyTorch Geometric Data object per graph id in ``batch``
        (ascending id order), with node indices relabelled to start at 0.
    """
    if num_nodes is None:
        num_nodes = batch.numel()

    graphs = []
    for graph_id in batch.unique():
        # Global indices of the nodes that belong to the current graph
        node_indices = (batch == graph_id).nonzero(as_tuple=False).view(-1)

        # Keep only edges inside this graph and relabel nodes to 0..k-1.
        # The second return value is edge_attr (None here), not an edge mask.
        sub_edge_index, _ = subgraph(
            node_indices,
            edge_index,
            relabel_nodes=True,
            num_nodes=num_nodes,
        )

        # Node features could be sliced the same way, e.g. x[node_indices]
        graphs.append(Data(edge_index=sub_edge_index, num_nodes=node_indices.size(0)))

    return graphs


# Example usage
# Example edge_index and batch tensors: 8 nodes split into 3 graphs (ids 0, 1, 2).
# The last edge (7 -> 0) connects graph 2 to graph 0, so extraction drops it.
edge_index = torch.tensor([
    [0, 1, 0, 3, 4, 6, 7],
    [1, 2, 2, 4, 5, 7, 0]
], dtype=torch.long)  # Shape: (2, 7)

batch = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.long)  # 8 nodes in total
num_nodes = batch.size(0)

# Extract individual graphs
individual_graphs = extract_individual_graphs(edge_index, batch, num_nodes)
print(individual_graphs)

# Display the extracted graphs
for idx, graph in enumerate(individual_graphs):
    print(f"Graph {idx}:")
    print(f"  Number of Nodes: {graph.num_nodes}")
    print(f"  Edge Index:\n{graph.edge_index}\n")


tensor([ True,  True,  True, False, False, False, False, False])
tensor([0, 1, 2])
tensor([[0, 1, 0],
        [1, 2, 2]])
Data(edge_index=[2, 3], num_nodes=3)
tensor([False, False, False,  True,  True,  True, False, False])
tensor([3, 4, 5])
tensor([[0, 1],
        [1, 2]])
Data(edge_index=[2, 2], num_nodes=3)
tensor([False, False, False, False, False, False,  True,  True])
tensor([6, 7])
tensor([[0],
        [1]])
Data(edge_index=[2, 1], num_nodes=2)
[Data(edge_index=[2, 3], num_nodes=3), Data(edge_index=[2, 2], num_nodes=3), Data(edge_index=[2, 1], num_nodes=2)]
Graph 0:
  Number of Nodes: 3
  Edge Index:
tensor([[0, 1, 0],
        [1, 2, 2]])

Graph 1:
  Number of Nodes: 3
  Edge Index:
tensor([[0, 1],
        [1, 2]])

Graph 2:
  Number of Nodes: 2
  Edge Index:
tensor([[0],
        [1]])

