In [1]:
import torch
import sys
import importlib

In [2]:
import os

In [3]:
import numpy as np

In [4]:
from tqdm import tqdm

In [5]:
sys.path.append('../pooling_generator/')

In [6]:
import token_to_sequence_pooler

In [7]:
importlib.reload(token_to_sequence_pooler)

<module 'token_to_sequence_pooler' from '/oak/stanford/groups/rbaltman/alptartici/pooling_work/protein_token_importance_analysis/../pooling_generator/token_to_sequence_pooler.py'>

In [8]:
from scipy.stats import pearsonr, spearmanr

# ESM-2

In [17]:
accession = "P33261" ## CHANGE THIS LINE
pooler = token_to_sequence_pooler.TokenToSequencePooler(uniprot_accession=accession)
matrix_to_use = "max_of_max_attn_per_layer"
importances = pooler.pageRank_pooling(which_matrix=matrix_to_use,
                                        top_k_attns=100, 
                                        verbose=True, 
                                        return_importance=True, 
                                        apply_to_wt=False, 
                                        diagonal_mask=0,
                                         include_cls = False
                                       )
#importances
np.save(f'/oak/stanford/groups/rbaltman/residue_importances/ESM2-650M/{accession}_importances.npy', importances)

In [9]:
attn_files = os.listdir('/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot/attention_matrices_mean_max_perLayer/')
rep_files =  os.listdir('/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot/representation_matrices/')

In [10]:
accessions_to_work_on = []
for file in tqdm(attn_files):
    if 'mut' in file:
        continue
    if file in rep_files:
        acc = file.split('.pt')[0]
        accessions_to_work_on.append(acc)

100%|██████████| 40353/40353 [00:27<00:00, 1455.79it/s]


In [None]:
problematic_accessions = []
for accession in tqdm(accessions_to_work_on):
    importances = None
    path_to_save = f'/oak/stanford/groups/rbaltman/residue_importances/ESM2-650M/{accession}_res_importances.npy'
    if os.path.exists(path_to_save):
        continue
    try:
        pooler = token_to_sequence_pooler.TokenToSequencePooler(uniprot_accession=accession)
        matrix_to_use = "max_of_max_attn_per_layer"
        importances = pooler.pageRank_pooling(which_matrix=matrix_to_use,
                                                top_k_attns=100, 
                                                verbose=True, 
                                                return_importance=True, 
                                                apply_to_wt=False, 
                                                diagonal_mask=0,
                                                 include_cls = False
                                               )
        if importances is None:
            raise Exception('none in importances')
        #importances
        np.save(path_to_save, importances)
    except:
        problematic_accessions.append(accession)
        num_problematic_acc = len(problematic_accessions)
        if num_problematic_acc > 99 and num_problematic_acc  % 100 == 0:
            print(f'number of problematic accessions is {num_problematic_acc}')

 53%|█████▎    | 15599/29425 [8:34:21<2:33:24,  1.50it/s]  

# protBERT

In [13]:
import token_to_sequence_pooler_nonESM

In [14]:
importlib.reload(token_to_sequence_pooler_nonESM)

<module 'token_to_sequence_pooler_nonESM' from '/oak/stanford/groups/rbaltman/alptartici/pooling_work/protein_token_importance_analysis/../pooling_generator/token_to_sequence_pooler_nonESM.py'>

In [15]:
attn_files = os.listdir('/oak/stanford/groups/rbaltman/protbert_embeddings/attention_matrices_mean_max_perLayer/')
rep_files =  os.listdir('/oak/stanford/groups/rbaltman/protbert_embeddings/representation_matrices/')

In [16]:
accessions_to_work_on_protbert = []
for file in tqdm(attn_files):
    if 'mut' in file:
        continue
    if file in rep_files:
        acc = file.split('.pt')[0]
        accessions_to_work_on_protbert.append(acc)

100%|██████████| 17736/17736 [00:01<00:00, 9279.73it/s]


In [None]:
problematic_accessions_protbert = []
for accession in tqdm(accessions_to_work_on):
    importances = None
    path_to_save = f'/oak/stanford/groups/rbaltman/residue_importances/protbert/{accession}_res_importances.npy'
    if os.path.exists(path_to_save):
        continue
    try:
        pooler = token_to_sequence_pooler_nonESM.TokenToSequencePooler(uniprot_accession=accession, PLM='protbert')
        matrix_to_use = "max_of_max_attn_per_layer"
        importances = pooler.pageRank_pooling(which_matrix=matrix_to_use,
                                                top_k_attns=100, 
                                                verbose=True, 
                                                return_importance=True, 
                                                apply_to_wt=False, 
                                                diagonal_mask=0,
                                                 include_cls = False
                                               )
        if importances is None:
            raise Exception('none in importances')
        #importances
        np.save(path_to_save, importances)
    except:
        problematic_accessions.append(accession)
        num_problematic_acc = len(problematic_accessions)
        if num_problematic_acc > 99 and num_problematic_acc  % 100 == 0:
            print(f'number of problematic accessions is {num_problematic_acc}')

  4%|▍         | 1257/29425 [39:51<9:39:35,  1.23s/it] 

number of problematic accessions is 100


  9%|▉         | 2602/29425 [1:16:35<15:16:36,  2.05s/it] 

number of problematic accessions is 200


 13%|█▎        | 3715/29425 [1:48:49<6:06:47,  1.17it/s]  

number of problematic accessions is 300


 16%|█▋        | 4836/29425 [2:22:24<4:53:22,  1.40it/s]  

number of problematic accessions is 400


 21%|██        | 6082/29425 [2:57:01<5:05:19,  1.27it/s]  

number of problematic accessions is 500


 25%|██▍       | 7314/29425 [3:33:49<33:02:13,  5.38s/it] 

number of problematic accessions is 600


 29%|██▉       | 8613/29425 [4:12:21<7:54:01,  1.37s/it]  

number of problematic accessions is 700


 33%|███▎      | 9784/29425 [4:44:46<9:46:05,  1.79s/it] 

number of problematic accessions is 800


 37%|███▋      | 10979/29425 [5:21:22<6:23:09,  1.25s/it] 

number of problematic accessions is 900


 39%|███▉      | 11440/29425 [5:33:41<2:59:33,  1.67it/s] 

# convert to NPZ

## ESM

In [10]:
# Define the paths
input_directory = "/oak/stanford/groups/rbaltman/residue_importances/ESM2-650M"
output_path = "/oak/stanford/groups/rbaltman/residue_importances/esm2_650m_residue_importances.npz"

# Initialize a dictionary to hold the data
data_dict = {}

# Loop through each file in the directory
for filename in tqdm(os.listdir(input_directory)):
    if filename.endswith(".npy"):
        # Construct the full path to the .npy file
        file_path = os.path.join(input_directory, filename)
        
        # Load the .npy file
        data = np.load(file_path)
        
        # Use the filename without the extension as the key
        key = os.path.splitext(filename)[0].split('_')[0]
        #print(f'key is {key}')
        
        # Store the array in the dictionary
        data_dict[key] = data

# Save all data as a single .npz file
np.savez(output_path, **data_dict)

print(f"Data has been saved to {output_path}")

100%|██████████| 29424/29424 [14:50<00:00, 33.03it/s]


Data has been saved to /oak/stanford/groups/rbaltman/residue_importances/esm2_650m_residue_importances.npz


In [11]:
# Define the paths
input_directory = "/oak/stanford/groups/rbaltman/residue_importances/protbert"
output_path = "/oak/stanford/groups/rbaltman/residue_importances/protbert_residue_importances.npz"

# Initialize a dictionary to hold the data
data_dict = {}

# Loop through each file in the directory
for filename in tqdm(os.listdir(input_directory)):
    if filename.endswith(".npy"):
        # Construct the full path to the .npy file
        file_path = os.path.join(input_directory, filename)
        
        # Load the .npy file
        data = np.load(file_path)
        
        # Use the filename without the extension as the key
        key = os.path.splitext(filename)[0].split('_')[0]
        #print(f'key is {key}')
        
        # Store the array in the dictionary
        data_dict[key] = data

# Save all data as a single .npz file
np.savez(output_path, **data_dict)

print(f"Data has been saved to {output_path}")

100%|██████████| 27112/27112 [10:29<00:00, 43.08it/s] 


Data has been saved to /oak/stanford/groups/rbaltman/residue_importances/protbert_residue_importances.npz


#### testing

In [17]:
file_path = "/oak/stanford/groups/rbaltman/residue_importances/protbert_residue_importances.npz"
data_esm = np.load(file_path)

In [19]:
np.sum(data_esm['Q9NZV8'])

1.0000000000000009