# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import os
import json
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from tqdm import tqdm
from os.path import exists
import matplotlib.pyplot as plt
import pandas as pd
import pickle

In [2]:
# Mirror the installed torch version into the TORCH environment variable and echo it.
# NOTE(review): presumably consumed by an install/wheel-URL step elsewhere — confirm it is still needed.
torch_version = torch.__version__
os.environ['TORCH'] = torch_version
print(torch_version)

2.5.1


# Load project modules (strokeDTI)

In [3]:
from strokeDTI.predict_dti.encoder import *
from strokeDTI.predict_dti.model import *
from strokeDTI.predict_dti.data_processing import *
from strokeDTI.predict_dti.train_test_utility import *
from strokeDTI.predict_dti.samples_for_testing import *
from strokeDTI.predict_dti.params import *

# Smoke-test model construction

In [4]:
# Build a single architecture by name as a smoke test of the model factory.
probe_architecture = 'ResGatedGraphConv'
testModel = get_model_from_name(probe_architecture)

In [5]:
# Display the module tree so the layer sizes of the probe model can be inspected.
testModel

ResGatedModel(
  (drug_model): ResGatedGraphConv(33, 256)
  (target_model): CNN(
    (extraction): Sequential(
      (0): Conv1d(26, 64, kernel_size=(3,), stride=(1,))
      (1): ReLU(inplace=True)
      (2): Conv1d(64, 256, kernel_size=(3,), stride=(1,))
      (3): ReLU(inplace=True)
      (4): Conv1d(256, 1024, kernel_size=(5,), stride=(1,))
      (5): ReLU(inplace=True)
      (6): Dropout(p=0.25, inplace=False)
      (7): AdaptiveMaxPool1d(output_size=1)
    )
    (output): Sequential(
      (0): Linear(in_features=1024, out_features=768, bias=True)
      (1): ReLU()
    )
  )
  (mlp): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=1024, out_features=512, bias=True)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=1, bias=True)
    (9): ReLU()

In [6]:
# Show the raw target-sequence dictionary (target name -> protein sequence string).
# NOTE(review): `sequence_dic` is not defined in this notebook; it presumably comes from one of
# the star imports above (likely samples_for_testing) — confirm.
sequence_dic

{'ripk1': 'MQPDMSLNVIKMKSSDFLESAELDSGGFGKVSLCFHRTQGLMIMKTVYKGPNCIEHNEAL\n    LEEAKMMNRLRHSRVVKLLGVIIEEGKYSLVMEYMEKGNLMHVLKAEMSTPLSVKGRIIL\nEIIEGMCYLHGKGVIHKDLKPENILVDNDFHIKIADLGLASFKMWSKLNNEEHNELREVD\nGTAKKNGGTLYYMAPEHLNDVNAKPTEKSDVYSFAVVLWAIFANKEPYENAICEQQLIMC\nIKSGNRPDVDDITEYCPREIISLMKLCWEANPEARPTFPGIEEKFRPFYLSQLEESVEED\nVKSLKKEYSNENAVVKRMQSLQLDCVAVPSSRSNSATEQPGSLHSSQGLGMGPVEESWFA\nPSLEHPQEENEPSLQSKLQDEANYHLYGSRMDRQTKQQPRQNVAYNREEERRRRVSHDPF\nAQQRPYENFQNTEGKGTAYSSAASHGNAVHQPSGLTSQPQVLYQNNGLYSSHGFGTRPLD\nPGTAGPRVWYRPIPSHMPSLHNIPVPETNYLGNTPTMPFSSLPPTDESIKYTIYNSTGIQ\nIGAYNYMEIGGTSSSLLDSTNTNFKEEPAAKYQAIFDNTTSLTDKHLDPIRENLGKHWKN\nCARKLGFTQSQIDEIDHDYERDGLKEKVYQMLQKWVMREGIKGATVGKLAQALHQCSRID\nLLSSLIYVSQN',
 'ripk3': 'MSCVKLWPSGAPAPLVSIEELENQELVGKGGFGTVFRAQHRKWG\n    YDVAVKIVNSKAISREVKAMASLDNEFVLRLEGVIEKVNWDQDPKPALVTKFMEN\n    GSLSGLLQSQCPRPWPLLCRLLKEVVLGMFYLHDQNPVLLHRDLKPSNVLLDPEL\n    HVKLADFGLSTFQGGSQSGTGSGEPGGTLGYLAPELFVNVNRKASTASDVYSFGI\n    LMWAVLAGREVELPTEPSLVYEAVCNRQNRPSLAELPQAGPETPG

In [7]:
# Provenance: data/sequence_dic.json appears to have been written once from `sequence_dic`.
# The writer is kept as real (but disabled) code instead of a commented-out block so it
# cannot silently rot. Flip the flag to regenerate; it defaults to off so Restart & Run All
# never overwrites the snapshot.
REGENERATE_SEQUENCE_JSON = False

if REGENERATE_SEQUENCE_JSON:
    # JSON is UTF-8 by spec; be explicit instead of relying on the locale default.
    with open("../../data/sequence_dic.json", "w", encoding="utf-8") as file:
        json.dump(sequence_dic, file)

In [8]:
# Load the target protein sequences (target name -> sequence string) from the JSON snapshot.
# JSON is UTF-8 by spec; pass the encoding explicitly so the read does not depend on the
# platform's locale default (e.g. cp1252 on Windows).
with open("../../data/sequence_dic.json", "r", encoding="utf-8") as file:
    loaded_sequence_dic = json.load(file)

In [9]:
# Display the round-tripped dictionary; it should match `sequence_dic` shown earlier.
loaded_sequence_dic

{'ripk1': 'MQPDMSLNVIKMKSSDFLESAELDSGGFGKVSLCFHRTQGLMIMKTVYKGPNCIEHNEAL\n    LEEAKMMNRLRHSRVVKLLGVIIEEGKYSLVMEYMEKGNLMHVLKAEMSTPLSVKGRIIL\nEIIEGMCYLHGKGVIHKDLKPENILVDNDFHIKIADLGLASFKMWSKLNNEEHNELREVD\nGTAKKNGGTLYYMAPEHLNDVNAKPTEKSDVYSFAVVLWAIFANKEPYENAICEQQLIMC\nIKSGNRPDVDDITEYCPREIISLMKLCWEANPEARPTFPGIEEKFRPFYLSQLEESVEED\nVKSLKKEYSNENAVVKRMQSLQLDCVAVPSSRSNSATEQPGSLHSSQGLGMGPVEESWFA\nPSLEHPQEENEPSLQSKLQDEANYHLYGSRMDRQTKQQPRQNVAYNREEERRRRVSHDPF\nAQQRPYENFQNTEGKGTAYSSAASHGNAVHQPSGLTSQPQVLYQNNGLYSSHGFGTRPLD\nPGTAGPRVWYRPIPSHMPSLHNIPVPETNYLGNTPTMPFSSLPPTDESIKYTIYNSTGIQ\nIGAYNYMEIGGTSSSLLDSTNTNFKEEPAAKYQAIFDNTTSLTDKHLDPIRENLGKHWKN\nCARKLGFTQSQIDEIDHDYERDGLKEKVYQMLQKWVMREGIKGATVGKLAQALHQCSRID\nLLSSLIYVSQN',
 'ripk3': 'MSCVKLWPSGAPAPLVSIEELENQELVGKGGFGTVFRAQHRKWG\n    YDVAVKIVNSKAISREVKAMASLDNEFVLRLEGVIEKVNWDQDPKPALVTKFMEN\n    GSLSGLLQSQCPRPWPLLCRLLKEVVLGMFYLHDQNPVLLHRDLKPSNVLLDPEL\n    HVKLADFGLSTFQGGSQSGTGSGEPGGTLGYLAPELFVNVNRKASTASDVYSFGI\n    LMWAVLAGREVELPTEPSLVYEAVCNRQNRPSLAELPQAGPETPG

# Score every drug–target pair with each model and fold

In [10]:
# Architectures to evaluate; each name is handed to setup_model in the scoring loop.
model_name_list = [
    'transformer_cnn',
    'gatv2conv_cnn',
    'gineconv_cnn',
    'mpnn_cnn',
    'ResGatedGraphConv',
]

In [11]:
# Small drug subset (columns used below: drug_names, drug_smiles) to keep the run short.
drug_csv_path = '../../data/drug_list_with_smiles_first_20.csv'
drug_df = pd.read_csv(drug_csv_path)

In [12]:
# One-off step kept for provenance: the 20-drug subset file read above was produced by
# slicing the first 20 rows and saving them. Presumably drug_df held the full drug list at
# that time — re-running this against the subset would just rewrite the same 20 rows.
# # Save first 20 to save time

# drug_df = drug_df[:20]

# drug_df.to_csv('../../data/drug_list_with_smiles_first_20.csv', index=False)

In [13]:
# Clean each target sequence with the project helper remove_line.
# NOTE(review): presumably strips the embedded newlines/indentation visible in the raw
# sequence dump above — confirm against remove_line.
cleaned_target_dict = {target: remove_line(seq) for target, seq in loaded_sequence_dic.items()}


# Trained checkpoints are addressed as folds 1..total_fold.
total_fold = 5

# One record per (model, fold, target, drug) score.
results = []

# The drug list is loop-invariant: materialize the (name, SMILES) pairs once instead of
# re-running drug_df.iterrows() (which builds a Series per row) for every
# model x fold x target combination.
drug_pairs = list(zip(drug_df['drug_names'], drug_df['drug_smiles']))

total_iterations = len(model_name_list) * total_fold * len(drug_pairs) * len(cleaned_target_dict)
with tqdm(total=total_iterations, desc="Computing DTI_scores", position=0) as pbar:
    for model_name in model_name_list:
        for fold in range(1, total_fold + 1):
            # Set up the model once per (architecture, fold) and reuse it for every pair.
            # NOTE(review): whether setup_model/test switch to eval mode or disable autograd
            # is not visible here — confirm, since dropout layers exist in these models.
            test_model = setup_model(model_name, fold, model_output="../../data/trained_model/")

            for target, target_sequence in cleaned_target_dict.items():
                for drug_name, drug_smiles in drug_pairs:
                    dti_score = test(test_model, drug_smiles, target_sequence)

                    results.append({
                        'drug_names': drug_name,
                        'drug_smiles': drug_smiles,
                        'model': model_name,
                        'fold': fold,
                        'target': target,
                        'DTI_score': dti_score
                    })

                    pbar.update(1)


Computing DTI_scores: 100%|██████████| 1500/1500 [02:38<00:00,  9.48it/s]


In [14]:
# Tabulate the scores. Passing the column list explicitly keeps a stable schema (and CSV
# header) even if `results` is empty; for non-empty input the order matches the record keys.
results_df = pd.DataFrame(
    results,
    columns=['drug_names', 'drug_smiles', 'model', 'fold', 'target', 'DTI_score'],
)
# Refuse to continue with a partial run (e.g. the scoring cell was interrupted).
assert len(results_df) == total_iterations, (
    f"expected {total_iterations} scores, got {len(results_df)}"
)

In [15]:
# Persist the per-(model, fold, target, drug) scores without the pandas index column.
dti_output_path = '../../data/DTI_output.csv'
results_df.to_csv(dti_output_path, index=False)