In [1]:
import torch
import os
import numpy as np
from tqdm import tqdm

In [None]:
proteins_to_load = np.load('proteins_to_load.npy', allow_pickle=True)

In [2]:
pooling_method = "cls_pooled"
embedding_dict = {}
parent_dir_poolings = "/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot/post_pooling_seq_vectors"
embedding_dir = os.path.join(parent_dir_poolings, pooling_method)
path_dict = f"dict_embeddings/{pooling_method}.pt"


In [3]:
for file_index in tqdm(range(len(proteins_to_load))):
    acc = proteins_to_load[file_index]
    if acc not in proteins_to_load:
        continue
    file_name  = f"{acc}.pt"
    if os.path.exists(f"{embedding_dir}/{file_name}"):
        accession = file_name[:-3]  # Remove file extension
        embedding_path = os.path.join(embedding_dir, file_name)
        embedding = torch.load(embedding_path)
        if not isinstance(embedding, torch.Tensor):
            embedding = torch.from_numpy(embedding).to(torch.float32).squeeze()
        
        #print("embedding", embedding, flush=True)
        embedding_dict[accession] = embedding

100%|██████████| 15879/15879 [03:01<00:00, 87.51it/s] 


In [8]:
torch.save(embedding_dict, path_dict)

## max pooled

In [9]:
pooling_method = "max_pooled"
embedding_dict = {}
parent_dir_poolings = "/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot/post_pooling_seq_vectors"
embedding_dir = os.path.join(parent_dir_poolings, pooling_method)
path_dict = f"dict_embeddings/{pooling_method}.pt"

for file_index in tqdm(range(len(proteins_to_load))):
    acc = proteins_to_load[file_index]
    if acc not in proteins_to_load:
        continue
    file_name  = f"{acc}.pt"
    if os.path.exists(f"{embedding_dir}/{file_name}"):
        accession = file_name[:-3]  # Remove file extension
        embedding_path = os.path.join(embedding_dir, file_name)
        embedding = torch.load(embedding_path)
        if not isinstance(embedding, torch.Tensor):
            embedding = torch.from_numpy(embedding).to(torch.float32).squeeze()
        
        #print("embedding", embedding, flush=True)
        embedding_dict[accession] = embedding


torch.save(embedding_dict, path_dict)

100%|██████████| 15879/15879 [10:23<00:00, 25.46it/s]


# mean pooled

In [10]:
pooling_method = "mean_pooled"
embedding_dict = {}
parent_dir_poolings = "/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot/post_pooling_seq_vectors"
embedding_dir = os.path.join(parent_dir_poolings, pooling_method)
path_dict = f"dict_embeddings/{pooling_method}.pt"

for file_index in tqdm(range(len(proteins_to_load))):
    acc = proteins_to_load[file_index]
    if acc not in proteins_to_load:
        continue
    file_name  = f"{acc}.pt"
    if os.path.exists(f"{embedding_dir}/{file_name}"):
        accession = file_name[:-3]  # Remove file extension
        embedding_path = os.path.join(embedding_dir, file_name)
        embedding = torch.load(embedding_path)
        if not isinstance(embedding, torch.Tensor):
            embedding = torch.from_numpy(embedding).to(torch.float32).squeeze()
        
        #print("embedding", embedding, flush=True)
        embedding_dict[accession] = embedding


torch.save(embedding_dict, path_dict)

100%|██████████| 15879/15879 [09:37<00:00, 27.48it/s]


# Pool PaRTI

In [11]:
pooling_method = "PR_contact_prune_topk_no_enh"
embedding_dict = {}
parent_dir_poolings = "/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot/post_pooling_seq_vectors"
embedding_dir = os.path.join(parent_dir_poolings, pooling_method)
path_dict = f"dict_embeddings/{pooling_method}.pt"

for file_index in tqdm(range(len(proteins_to_load))):
    acc = proteins_to_load[file_index]
    if acc not in proteins_to_load:
        continue
    file_name  = f"{acc}.pt"
    if os.path.exists(f"{embedding_dir}/{file_name}"):
        accession = file_name[:-3]  # Remove file extension
        embedding_path = os.path.join(embedding_dir, file_name)
        embedding = torch.load(embedding_path)
        if not isinstance(embedding, torch.Tensor):
            embedding = torch.from_numpy(embedding).to(torch.float32).squeeze()
        
        #print("embedding", embedding, flush=True)
        embedding_dict[accession] = embedding


torch.save(embedding_dict, path_dict)

100%|██████████| 15879/15879 [09:18<00:00, 28.42it/s]


In [13]:
torch.load("dict_embeddings/cls_pooled.pt")['Q7L5N1']

tensor([ 0.0710, -0.0341,  0.0738,  ..., -0.3235,  0.2001,  0.0547])

In [14]:
torch.load("dict_embeddings/max_pooled.pt")['Q7L5N1']

tensor([0.5824, 0.4923, 0.4563,  ..., 0.2624, 0.5301, 0.7738])

In [15]:
torch.load("dict_embeddings/mean_pooled.pt")['Q7L5N1']

tensor([ 0.0026, -0.0982, -0.0214,  ..., -0.2482, -0.0141,  0.1950])

In [16]:
torch.load("dict_embeddings/PR_contact_prune_topk_no_enh.pt")['Q7L5N1']

tensor([ 0.0016, -0.0949, -0.0191,  ..., -0.2570, -0.0132,  0.1928],
       dtype=torch.float64)