In [1]:
import torch
from model_protein_moe import trans_basic_block, trans_basic_block_Config
from utils_search import *
from transformers import T5EncoderModel, T5Tokenizer
import re
import gc
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from collections import defaultdict

from huggingface_protein_vec import ProteinVec, ProteinVecConfig
from tqdm.auto import tqdm
from datasets import load_dataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Paths to the Protein-Vec mixture-of-experts checkpoint and its JSON hyper-parameters
vec_model_cpnt, vec_model_config = (
    'protein_vec_models/protein_vec.ckpt',
    'protein_vec_models/protein_vec_params.json',
)

In [None]:
# ProtT5 encoder and tokenizer (lhallee/prot_t5_enc); embed_seqs hands these to
# featurize_prottrans before the Protein-Vec model is applied.
tokenizer = T5Tokenizer.from_pretrained("lhallee/prot_t5_enc", do_lower_case=False)
model = T5EncoderModel.from_pretrained("lhallee/prot_t5_enc")
model = model.to(device)
model.eval()
gc.collect()

In [None]:
# Load the Protein-Vec model from its checkpoint and JSON config.
# The parsed config gets its own name so that `vec_model_config` keeps holding the
# JSON path; overwriting it (as before) made this cell fail when re-run.
vec_model_cfg = trans_basic_block_Config.from_json(vec_model_config)
model_deep = trans_basic_block.load_from_checkpoint(vec_model_cpnt, config=vec_model_cfg).to(device).eval()
gc.collect()

In [None]:
def embed_seqs(model, model_deep, tokenizer, seqs, device, aspects=None):
    """Embed protein sequences with the local Protein-Vec pipeline, one at a time.

    Each sequence is featurized with ``featurize_prottrans`` (using ``model`` and
    ``tokenizer``) and the result is passed through ``embed_vec`` with ``model_deep``.

    Args:
        model: encoder passed to ``featurize_prottrans``.
        model_deep: Protein-Vec model passed to ``embed_vec``.
        tokenizer: tokenizer passed to ``featurize_prottrans``.
        seqs: iterable of amino-acid sequence strings.
        device: torch device forwarded to both helpers.
        aspects: optional subset of 'TM', 'PFAM', 'GENE3D', 'ENZYME', 'MFO', 'BPO',
            'CCO' to keep enabled. ``None`` (the default) enables every aspect.

    Returns:
        np.ndarray: the per-sequence embeddings concatenated along axis 0.

    Raises:
        ValueError: if ``seqs`` is empty or ``aspects`` names an unknown aspect.
    """
    all_cols = ['TM', 'PFAM', 'GENE3D', 'ENZYME', 'MFO', 'BPO', 'CCO']
    enabled = set(all_cols) if aspects is None else set(aspects)
    unknown = enabled - set(all_cols)
    if unknown:
        raise ValueError(f"Unknown aspect(s) {sorted(unknown)}; expected a subset of {all_cols}")
    # Mask entries are True for aspects that are switched OFF, shape (1, n_aspects).
    masks = torch.tensor([col not in enabled for col in all_cols], dtype=torch.bool)[None, :]

    seqs = list(seqs)
    if not seqs:
        raise ValueError("seqs must contain at least one sequence")

    embed_all_sequences = []
    for seq in tqdm(seqs):
        protrans_sequence = featurize_prottrans([seq], model, tokenizer, device)
        embedded_sequence = embed_vec(protrans_sequence, model_deep, masks, device)
        embed_all_sequences.append(embedded_sequence)
    return np.concatenate(embed_all_sequences)

In [2]:
# Validation split of the triplet dataset; the first four positives serve as a small
# parity set for comparing the local and Hugging Face embeddings below.
valid_split = load_dataset('lhallee/triplets', split='valid')
all_seqs = valid_split['positives']
seqs = all_seqs[:4]
seqs

['MRALKARSRLASRRQLKKLDEDSLTKQPEEVFDVLEKLGEGSYGSVYKAIHKETGQIVAIKQVPVESDLQEIIKEISIMQQCDSPHVVKYYGSYFKNTDLWIVMEYCGAGSVSDIIRLRNKTLTEDEIATILQSTLKGLEYLHFMRKIHRDIKAGNILLNTEGHAKLADFGVAGQLTDTMAKRNTVIGTPFWMAPEVIQEIGYNCVADIWSLGITAIEMAEGKPPYADIHPMRAIFMIPTNPPPTFRKPEVWSDNFMDFVKQCLVKSPEQRATATQLLQHPFVKSAKGAAILRDLINEAMDVKLKRQEAQQRAVDQDDDENSEEDEMDSGTMVRAAGDDMGTVRVASTMSGGANTMIEHGDTLPSQLGTMVINTEDEEEEGTMKRRDETMQPAKPSFLEYFEQKEKENQINSFGKNVSGSLKNSSDWKIPQDGDYEFLKSWTVEDLQKRLSALDPMMEQEMEEIRQKYRSKRQPILDAIEAKKRRQQNF',
 'MSAPTADIRARAPEAKKVHIADTAINRHNWYKHVNWLNVFLIIGIPLYGCIQAFWVPLQLKTAIWAVIYYFFTGLGITAGYHRLWAHCSYSATLPLRIWLAAVGGGAVEGSIRWWARDHRAHHRYTDTDKDPYSVRKGLLYSHLGWMVMKQNPKRIGRTDISDLNEDPVVVWQHRNYLKVVFTMGLAVPMLVAGLGWGDWLGGFVYAGILRIFFVQQATFCVNSLAHWLGDQPFDDRNSPRDHVITALVTLGEGYHNFHHEFPSDYRNAIEWHQYDPTKWSIWAWKQLGLAYDLKKFRANEIEKGRVQQLQKKLDRKRATLDWGTPLDQLPVMEWDDYVEQAKNGRGLVAIAGVVHDVTDFIKDHPGGKAMISSGIGKDATAMFNGGVYYHSNAAHNLLSTMRVGVIRGGCEVEIWKRAQKENVEYVRDGSGQRVIRAGEQPTKIPEPIPTADAA',
 'LCFQLLPVVIGAAVVAFTVLYLFFRSVSSYHKKRGKKSPVTLQD

In [None]:
# Reference embeddings from the local Protein-Vec pipeline
emb = embed_seqs(model=model, model_deep=model_deep, tokenizer=tokenizer, seqs=seqs, device=device)
emb.shape

In [None]:
np.save(file='local_emb_test.npy', arr=emb)  # persist local-pipeline embeddings for the parity check

In [3]:
# Hugging Face port of Protein-Vec, loaded for the parity check against the local pipeline.
# NOTE(review): this rebinds `model` and `tokenizer`, so the ProtT5 encoder/tokenizer
# loaded earlier are no longer reachable under those names in later cells.
tokenizer = T5Tokenizer.from_pretrained('lhallee/ProteinVec')
hf_config = ProteinVecConfig()
model = ProteinVec.from_pretrained('lhallee/ProteinVec', config=hf_config)
model.to_eval()
model = model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Space-separate residues and map the rare amino acids U, Z, O, B to X before tokenizing
sequences = [re.sub(r"[UZOB]", "X", " ".join(seq)) for seq in seqs]
ids = tokenizer.batch_encode_plus(sequences, add_special_tokens=True, padding=True)
input_ids, attention_mask = (torch.tensor(ids[key]).to(device) for key in ('input_ids', 'attention_mask'))

In [5]:
# Inference only: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    emb = model.embed_batch(input_ids, attention_mask, aspect=6).detach().cpu().numpy()

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


In [6]:
np.save(file='hf_emb_test.npy', arr=emb)  # persist Hugging Face embeddings for the parity check

In [7]:
# Reload both saved embedding sets for side-by-side comparison
local_emb, hf_emb = map(np.load, ('local_emb_test.npy', 'hf_emb_test.npy'))

In [8]:
# Display the embeddings saved by the local pipeline (local_emb_test.npy)
local_emb

array([[-9.0215225 ,  6.265668  ,  4.1960397 , ...,  0.81374305,
         5.500475  , -2.5042906 ],
       [-8.844703  ,  6.1047983 ,  4.215445  , ...,  0.78629345,
         5.840502  , -2.5396354 ],
       [-8.900217  ,  6.1098895 ,  4.2216673 , ...,  1.0021304 ,
         5.853938  , -2.494628  ],
       [-8.937391  ,  6.2477336 ,  4.0384502 , ...,  0.830227  ,
         5.735293  , -2.391416  ]], dtype=float32)

In [9]:
# Display the embeddings saved from the Hugging Face ProteinVec model (hf_emb_test.npy)
hf_emb

array([[-9.021523 ,  6.2656674,  4.196039 , ...,  0.8137429,  5.500475 ,
        -2.5042913],
       [-8.844703 ,  6.104799 ,  4.215445 , ...,  0.7862937,  5.8405027,
        -2.5396352],
       [-8.900217 ,  6.10989  ,  4.221667 , ...,  1.0021304,  5.853938 ,
        -2.4946282],
       [-8.937391 ,  6.247734 ,  4.038451 , ...,  0.8302269,  5.7352924,
        -2.3914163]], dtype=float32)

In [14]:
# Parity check between the local pipeline and the Hugging Face port.
# np.allclose broadcasts its inputs, so shapes are compared first; otherwise a
# single-row result could be reported as "close" to a multi-row one.
local_emb.shape == hf_emb.shape and np.allclose(local_emb, hf_emb, rtol=1e-3)

True

In [None]:
# UniProt metadata export (the filename suggests reviewed entries as of 2023-07-03)
meta_data_new = pd.read_csv('data/uniprotkb_AND_reviewed_true_2023_07_03.tsv', delimiter='\t')

In [None]:
# Keep only proteins whose creation date is after 2022-05-25 (string comparison on the date column)
is_new = meta_data_new['Date of creation'] > '2022-05-25'
new_proteins = meta_data_new.loc[is_new].reset_index(drop=True)

In [None]:
print('Number of new proteins deposited after 2022-05-25')
print(len(new_proteins))

# Drop proteins longer than MAX_SEQ_LEN amino-acid residues. A new frame is built
# (rather than adding a column in place) and its index is reset, so row labels stay
# contiguous after filtering.
MAX_SEQ_LEN = 2000
new_proteins = (
    new_proteins
    .assign(length=new_proteins['Sequence'].str.len())
    .loc[lambda df: df['length'] <= MAX_SEQ_LEN]
    .reset_index(drop=True)
)
print(f'Filtered proteins longer than {MAX_SEQ_LEN} amino acids')
print(len(new_proteins))

In [None]:
# Forward pass of the local Protein-Vec pipeline over the new proteins.
# Every aspect is turned on (therefore nothing is masked out).
#
# The ProtT5 encoder/tokenizer are loaded here under their own names. An earlier cell
# rebinds `model` and `tokenizer` to the Hugging Face ProteinVec model, so relying on
# those names would hand featurize_prottrans a different model than the ProtT5 encoder
# that embed_seqs used, and the result would depend on cell execution history.
prottrans_tokenizer = T5Tokenizer.from_pretrained("lhallee/prot_t5_enc", do_lower_case=False)
prottrans_model = T5EncoderModel.from_pretrained("lhallee/prot_t5_enc").to(device).eval()
gc.collect()

sampled_keys = np.array(['TM', 'PFAM', 'GENE3D', 'ENZYME', 'MFO', 'BPO', 'CCO'])
all_cols = np.array(['TM', 'PFAM', 'GENE3D', 'ENZYME', 'MFO', 'BPO', 'CCO'])
masks = [all_cols[k] in sampled_keys for k in range(len(all_cols))]
masks = torch.logical_not(torch.tensor(masks, dtype=torch.bool))[None, :]

# Pull out sequences for the new proteins
flat_seqs = new_proteins['Sequence'].values

# Embed one sequence at a time; tqdm reports progress instead of printing every 50 items.
embed_all_sequences = []
for i in tqdm(range(len(flat_seqs)), desc='Embedding new proteins'):
    protrans_sequence = featurize_prottrans(flat_seqs[i:i + 1], prottrans_model, prottrans_tokenizer, device)
    embedded_sequence = embed_vec(protrans_sequence, model_deep, masks, device)
    embed_all_sequences.append(embedded_sequence)


In [None]:
# Stack the per-sequence embeddings into one array (rows = new proteins)
query_embeddings = np.concatenate(embed_all_sequences, axis=0)

Now that we have embeddings for the newly discovered proteins, we can visualize them with t-SNE and transfer annotations to them.

In [None]:
# 2-D t-SNE projection of the query embeddings for plotting.
# random_state pins the random initialisation so the figure is reproducible, and
# perplexity is capped below the sample count (t-SNE rejects perplexity >= n_samples).
TSNE_SEED = 0
tsne_perplexity = min(10, len(query_embeddings) - 1)
all_X_embedded = TSNE(
    n_components=2,
    perplexity=tsne_perplexity,
    learning_rate='auto',
    init='random',
    random_state=TSNE_SEED,
).fit_transform(query_embeddings)
all_X_embedded_df = pd.DataFrame(all_X_embedded, columns=["Dim1", "Dim2"])
all_X_embedded_df['Pfam'] = new_proteins['Pfam'].values[:len(all_X_embedded_df)]
all_X_embedded_df['EC'] = new_proteins['EC number'].values[:len(all_X_embedded_df)]

In [None]:
# For readability, plot only points annotated with one of the 20 most frequent Pfam terms
pfam_counts = all_X_embedded_df['Pfam'].value_counts()
top_ranks = list(pfam_counts.head(20).index)
top_pfam_points = all_X_embedded_df[all_X_embedded_df['Pfam'].isin(top_ranks)]
sns.lmplot(data=top_pfam_points, x="Dim1", y="Dim2", hue="Pfam", fit_reg=False)


In [None]:
# Lookup database: precomputed Protein-Vec embeddings plus one metadata row per embedding.
# Per the original note, only embeddings of proteins that were trained on are pulled from it.
lookup_proteins_meta = pd.read_csv('protein_vec_embeddings/lookup_embeddings_meta_data.tsv', delimiter="\t")
embeddings = np.load('protein_vec_embeddings/lookup_embeddings.npy')


In [None]:
print("Maximum date of lookup database protein")
lookup_proteins_meta['Date of creation'].max()

We can run the search and nearest-neighbor pipeline for any of the available aspects:
 - 'Gene Ontology (biological process)'
 - 'Gene Ontology (molecular function)' 
 - 'Gene Ontology (cellular component)' 
 - 'Gene3D' 
 - 'Pfam' 
 - 'EC number'

In [None]:
# User parameter: the metadata column whose annotations are searched and transferred,
# e.g. 'Pfam' or 'EC number' (see the list of aspects above).
column = "Pfam"

In [None]:
# Keep only lookup proteins that carry an annotation for the relevant aspect, so null
# annotations are never transferred. A positional boolean mask selects both the metadata
# rows and the matching embedding rows; the assert makes the 1:1 row alignment explicit
# instead of relying on the metadata index labels doubling as embedding positions.
has_annotation = lookup_proteins_meta[column].notna().to_numpy()
assert len(has_annotation) == len(embeddings), "metadata rows must align 1:1 with embedding rows"
col_lookup = lookup_proteins_meta[has_annotation]
col_lookup_embeddings = embeddings[has_annotation]
col_meta_data = col_lookup[column].values

# load database
lookup_database = load_database(col_lookup_embeddings)

# Query for the 1st nearest neighbor
k = 1
D, I = query(lookup_database, query_embeddings, k)

# I holds one row of neighbour indices per query; fancy indexing gathers all their
# annotations at once. The tolist() round-trip yields the same array the former
# per-row loop built (np.array of a list of lists).
near_ids = np.array(col_meta_data[I].tolist())

In [None]:
print("Annotations for the nearest neighbors (with aspect annotations) of newly discovered proteins", near_ids, sep="\n")