# Test the TheNovel dataset

# 1 import

In [None]:
import pandas as pd
import os
import re
from tqdm import tqdm
import importlib
from matplotlib.pyplot import figure
from ZHMolGraph.import_modules import *
from ZHMolGraph import ZHMolGraph

In [None]:
# Dataset name used below to locate the interaction CSV and the trained-model
# directory handed to the ZHMolGraph object.
model_Dataset = 'NPInter2'
# Dataset name used below to locate the embedding pickles and the interaction
# pairs that are evaluated.
unseen_Dataset = 'NPInter5'

# 2 define object

In [None]:
# Load the pickled RNA and protein embedding tables of the unseen dataset.
# NOTE: unpickling executes code embedded in the file; only load trusted data.
rna_embedding_path = f'data/Mol2Vec/RPI_{unseen_Dataset}_rnafm_embed_normal.pkl'
protein_embedding_path = f'data/Mol2Vec/RPI_{unseen_Dataset}_proteinprottrans_embed_normal.pkl'

with open(rna_embedding_path, 'rb') as rna_handle:
    test_rnas = pkl.load(rna_handle)

with open(protein_embedding_path, 'rb') as protein_handle:
    test_proteins = pkl.load(protein_handle)

In [None]:
# Create the ZHMolGraph object. The interaction table and the model output
# directory are selected by `model_Dataset`; the RNA / protein tables are the
# already-loaded `test_rnas` / `test_proteins` objects, so no file locations
# are given for them.
vecnet_object = ZHMolGraph.ZHMolGraph(
    # Interaction data
    interactions_location=f'data/interactions/dataset_RPI_{model_Dataset}_RP.csv',
    interactions=None,
    interaction_y_name='Y',
    # No explicit negative set is supplied
    absolute_negatives_location=None,
    absolute_negatives=None,
    # RNA table and the name of its sequence column
    rnas_location=None,
    rnas_dataframe=test_rnas,
    rna_seq_name='RNA_aa_code',
    # Protein table and the name of its sequence column
    proteins_location=None,
    proteins_dataframe=test_proteins,
    protein_seq_name='target_aa_code',
    # Directory of the trained model
    model_out_dir=f'trained_model/ZHMolGraph_VecNN_model_RPI_{model_Dataset}/',
    debug=False,
)

# 3 Load dataset

In [None]:
# Read the interaction table of the unseen dataset and keep an independent
# copy holding only the RNA sequence, protein sequence and 'Y' columns.
dataset_path = f'data/interactions/{unseen_Dataset}_interactions_seqpairs.csv'
interactions = pd.read_csv(dataset_path, sep=',')
interactions_seqpairs = interactions[['RNA_aa_code', 'target_aa_code', 'Y']].copy()

# 4 Load dataset embedding

In [None]:
### Load the pre-trained embeddings of the test set ###
# The RNA and protein embedding tables are read from the pickle files of the
# unseen dataset; the protein table is bound to `test_targets` here.
# NOTE: unpickling executes code embedded in the file; only load trusted data.
rna_pickle_path = f'data/Mol2Vec/RPI_{unseen_Dataset}_rnafm_embed_normal.pkl'
target_pickle_path = f'data/Mol2Vec/RPI_{unseen_Dataset}_proteinprottrans_embed_normal.pkl'

with open(rna_pickle_path, 'rb') as rna_pickle:
    test_rnas = pkl.load(rna_pickle)

with open(target_pickle_path, 'rb') as target_pickle:
    test_targets = pkl.load(target_pickle)



In [None]:
# Column counts of the dense embedding arrays built below; both values are
# also passed to get_TheNovel_test_results.
# NOTE(review): presumably the widths of the RNA-FM and ProtTrans embeddings
# stored in the pickles — confirm against the embedding files.
rna_vector_length = 640
protein_vector_length = 1024

In [None]:
# Copy the per-sequence embedding vectors into dense 2-D float arrays:
# one row per entry of 'normalized_embeddings', one column per dimension.
test_rna_embeddings = test_rnas['normalized_embeddings']
test_target_embeddings = test_targets['normalized_embeddings']

test_rna_array = np.zeros((len(test_rna_embeddings), rna_vector_length))
test_target_array = np.zeros((len(test_target_embeddings), protein_vector_length))

# Rows are filled in positional order, so row k holds the k-th embedding.
for row_number, rna_vector in enumerate(test_rna_embeddings):
    test_rna_array[row_number, :] = rna_vector

for row_number, target_vector in enumerate(test_target_embeddings):
    test_target_array[row_number, :] = target_vector


## Import the GraphSAGE embeddings

In [None]:
# Target row widths for the padded arrays.
rna_padding_length = rna_vector_length
target_padding_length = protein_vector_length

# Rebuild each array with every row zero-padded on the right up to the target
# width.
# NOTE(review): when the arrays were allocated with these same widths the
# padding adds nothing and this only copies the data — confirm it is needed.
test_rna_array = np.array([
    np.pad(rna_row, (0, rna_padding_length - rna_row.shape[0]), mode='constant')
    for rna_row in test_rna_array
])

test_target_array = np.array([
    np.pad(target_row, (0, target_padding_length - target_row.shape[0]), mode='constant')
    for target_row in test_target_array
])


In [None]:
# Store the dense test-set embedding arrays on the model object.
# NOTE(review): presumably read by get_TheNovel_test_results — confirm in ZHMolGraph.
vecnet_object.normalized_rna_embeddings = test_rna_array
vecnet_object.normalized_target_embeddings = test_target_array

# 5 Prepare the test

In [None]:
# print(interactions_seqpairs)
# Keep only the pairs whose RNA sequence is longer than 100 characters and
# whose 'Y' value equals 1, then renumber the remaining rows from 0.
interactions_seqpairs_copy = interactions_seqpairs.copy()
has_long_rna = interactions_seqpairs_copy['RNA_aa_code'].apply(len) > 100
has_label_one = interactions_seqpairs_copy['Y'] == 1
interactions_seqpairs_copy = (
    interactions_seqpairs_copy[has_long_rna & has_label_one].reset_index(drop=True)
)
# print(interactions_seqpairs_copy)

In [None]:
# Build the table that is evaluated: for every row of the NPInter5 spreadsheet
# (columns 'Protein names', 'RNA names', 'Labels') attach the protein and RNA
# sequences recorded under those names in `interactions`.
negative_dataframe_file = 'data/interactions/NPInter5.xlsx'
negative_dataframe = pd.read_excel(negative_dataframe_file)


def _first_sequence_per_name(names, name_column, sequence_column):
    """Return the first sequence stored in `interactions` for each name.

    Args:
        names: Series of names to look up.
        name_column: column of `interactions` holding the names.
        sequence_column: column of `interactions` holding the sequences.

    Returns:
        Series of sequences in the order of `names`, indexed 0..len(names)-1.

    Raises:
        IndexError: if a name has no row in `interactions`. This is the same
            exception type the previous per-row `.iloc[0]` lookup raised, now
            with a message that identifies the offending name.
    """
    # One lookup table instead of a full-table scan per spreadsheet row;
    # keep='first' reproduces the "first matching row" choice of `.iloc[0]`.
    first_by_name = (
        interactions.drop_duplicates(subset=name_column, keep='first')
        .set_index(name_column)[sequence_column]
    )
    # A missing (NaN) name never compared equal to anything before, so it
    # counts as unmatched here as well.
    unmatched = names[names.isna() | ~names.isin(first_by_name.index)]
    if not unmatched.empty:
        raise IndexError(
            f"{unmatched.nunique(dropna=False)} value(s) of '{name_column}' have no "
            f"match in interactions, e.g. {unmatched.iloc[0]!r}"
        )
    return names.map(first_by_name).reset_index(drop=True)


# Single-column frames, built in one step rather than enlarged row by row.
Protein_sequence = _first_sequence_per_name(
    negative_dataframe['Protein names'], 'Protein names', 'target_aa_code'
).to_frame('target_aa_code')
RNA_sequence = _first_sequence_per_name(
    negative_dataframe['RNA names'], 'RNA names', 'RNA_aa_code'
).to_frame('RNA_aa_code')

negative_interaction_dataframe = pd.concat(
    [negative_dataframe, Protein_sequence, RNA_sequence], axis=1
)

# rename() returns a new frame here: renaming in place on a column selection
# triggers pandas' SettingWithCopyWarning.
interactions_seqpairs_balanced = (
    negative_interaction_dataframe[['target_aa_code', 'RNA_aa_code', 'Labels']]
    .rename(columns={'Labels': 'Y'})
)

In [None]:
# Evaluate the trained model on the unseen dataset.
graphsage_model_path = vecnet_object.model_out_dir
# Derive the result prefix from the dataset names that are passed to the call
# below, so changing the datasets cannot silently reuse another run's result
# path. For NPInter2 / NPInter5 this is the same string as the former
# hard-coded value.
result_auc_aup_path = f'result/Mymethod_Train{model_Dataset}_Test{unseen_Dataset}_fold'
embedding_type = 'Pretrain'
vecnet_object.get_TheNovel_test_results(
    model_dataset=model_Dataset,
    graphsage_path=graphsage_model_path,
    unseen_dataset=unseen_Dataset,
    test_dataframe=interactions_seqpairs_balanced,
    rna_vector_length=rna_vector_length,
    protein_vector_length=protein_vector_length,
    rnas=test_rnas,
    proteins=test_proteins,
    result_path=result_auc_aup_path,
    embedding_type=embedding_type,
)