In [None]:
# Demo: read a previously saved NumPy array back from disk.
import numpy as np

data = np.load('filename.npy')

# ---- ProtBERT tokenizer / model setup ----
print("Load ProtBERT Model...")
from transformers import BertModel, BertTokenizer

# do_lower_case=False: the tokenizer must not lower-case the input letters.
tokenizer = BertTokenizer.from_pretrained(
    "Rostlab/prot_bert",
    do_lower_case=False,
)
# Load the pretrained weights, then move the model to the configured device.
model = BertModel.from_pretrained("Rostlab/prot_bert")
model = model.to(config.device)

def get_bert_embedding(
    sequence : str,
    len_seq_limit : int
):
    '''
    Compute a fixed-size embedding for one protein sequence with the
    pre-trained ProtBERT model.

    The residues are joined with single spaces, tokenized with truncation and
    padding to ``len_seq_limit`` tokens, and passed through the module-level
    ``model``. The last-hidden-state vector at the first token position
    (index 0 of the sequence axis) is returned as a NumPy array.

    INPUTS:
    - sequence (str) : protein sequence (ex : AAABBB) from fasta file
    - len_seq_limit (int) : maximum tokenized length; longer inputs are
      truncated and shorter ones are padded up to this length

    OUTPUTS:
    - output_hidden : 1-D numpy array of length 1024 holding the embedding
      vector of the input sequence

    RAISES:
    - AssertionError : if the returned vector does not have 1024 entries
    '''
    # Local import keeps this block self-contained; torch is already a
    # dependency of the transformers model used below.
    import torch

    # The tokenizer is fed one residue per whitespace-separated token.
    sequence_w_spaces = ' '.join(sequence)
    # NOTE(review): padding='max_length' makes every forward pass process
    # len_seq_limit tokens regardless of the real sequence length. Left
    # unchanged here to keep the produced embeddings identical.
    encoded_input = tokenizer(
        sequence_w_spaces,
        truncation=True,
        max_length=len_seq_limit,
        padding='max_length',
        return_tensors='pt').to(config.device)

    # Inference only: disable autograd so no computation graph is built and
    # retained for each call (saves memory and time over many sequences).
    with torch.no_grad():
        output = model(**encoded_input)

    # [:, 0] selects the first token position; [0] drops the batch axis of 1.
    output_hidden = output['last_hidden_state'][:, 0][0].detach().cpu().numpy()
    assert len(output_hidden) == 1024, (
        f"expected a 1024-dimensional embedding, got {len(output_hidden)}")
    return output_hidden

### SHARED HELPERS FOR TRAIN / TEST EMBEDDING COLLECTION :
def _count_fasta_records(fasta_path):
    '''
    Count the records of a FASTA file by streaming through it once.

    INPUTS:
    - fasta_path : path of the FASTA file to read

    OUTPUTS:
    - int : number of records yielded by SeqIO.parse for this file
    '''
    # Counting lazily avoids holding every parsed record in memory at once.
    return sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))


def _embed_fasta(
    fasta_path,
    ids_out_path,
    embeddings_out_path,
    total=None,
    len_seq_limit=1200,
    checkpoint_every=100,
):
    '''
    Embed every record of a FASTA file and save ids / embeddings as .npy files.

    INPUTS:
    - fasta_path : path of the FASTA file to read
    - ids_out_path : destination .npy file for the record ids
    - embeddings_out_path : destination .npy file for the embedding vectors
    - total (int or None) : number of records, only used for the progress bar
    - len_seq_limit (int) : forwarded to get_bert_embedding
    - checkpoint_every (int) : both output files are rewritten after every
      `checkpoint_every` processed records, so a crash loses little work

    OUTPUTS:
    - (ids, vectors) : list of record ids and list of embedding vectors,
      in file order
    '''
    ids = []
    vectors = []

    def _save():
        # Overwrites both files with everything collected so far.
        np.save(ids_out_path, np.array(ids))
        np.save(embeddings_out_path, np.array(vectors))

    records = SeqIO.parse(fasta_path, "fasta")
    for n_done, record in enumerate(tqdm(records, total=total), start=1):
        ids.append(record.id)
        vectors.append(
            get_bert_embedding(
                sequence=str(record.seq), len_seq_limit=len_seq_limit))
        if n_done % checkpoint_every == 0:
            _save()

    # Final write so records after the last checkpoint are not lost
    # (also produces empty arrays for an empty input file).
    _save()
    return ids, vectors


### COLLECTING FOR TRAIN SAMPLES :
print("Loading train set ProtBERT Embeddings...")
n_train_records = _count_fasta_records(config.train_sequences_path)
print("Total Nb of Elements : ", n_train_records)
t0 = time.time()
ids_list, embed_vects_list = _embed_fasta(
    config.train_sequences_path,
    '/kaggle/working/train_ids.npy',
    '/kaggle/working/train_embeddings.npy',
    total=n_train_records,
)
print('Total Elapsed Time:',time.time()-t0)

### COLLECTING FOR TEST SAMPLES :
print("Loading test set ProtBERT Embeddings...")
n_test_records = _count_fasta_records(config.test_sequences_path)
print("Total Nb of Elements : ", n_test_records)
t0 = time.time()
ids_list, embed_vects_list = _embed_fasta(
    config.test_sequences_path,
    '/kaggle/working/test_ids.npy',
    '/kaggle/working/test_embeddings.npy',
    total=n_test_records,
)
print('Total Elasped Time:',time.time()-t0)
