In [1]:
### biosnap -- random-only-exist

import json
import os
import random
import shutil
from data_utils import *
from tqdm import tqdm
import torch
from sentence_transformers import util,SentenceTransformer
from sentence_transformers.util import semantic_search
from sklearn.metrics import confusion_matrix
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
    
def cls_test(root_dir, output_dir, model_dir, model_name):
    """Score one checkpoint as a two-class DTI classifier and log its metrics.

    Reads ``test_add.jsonl`` from ``root_dir``, builds one prompt per kept
    example, embeds the prompts with the SentenceTransformer checkpoint at
    ``model_dir/model_name`` and assigns each prompt the closer of the two
    label texts ``['negative', 'advise']`` (dot product on normalized
    embeddings). Class 0 is 'negative', class 1 is 'advise'.

    Args:
        root_dir: Directory containing ``test_add.jsonl``.
        output_dir: Directory of the metrics CSV.
        model_dir: Training-run directory that holds the checkpoints.
        model_name: Checkpoint sub-directory name inside ``model_dir``.

    Returns:
        None. Prints the embedding shape and the confusion matrix, and writes
        one row (accuracy, macro precision/recall/F1, AUC, AUPR) to the CSV.
    """
    drugbank_test_path = os.path.join(root_dir, "test_add.jsonl")
    drugbank_test = readJSONL(drugbank_test_path)
    ## Edit point 1: file name of the metrics CSV.
    ou_drugbank_path = os.path.join(output_dir, "dti_retri_protein_biosnap_cls2_ml8192_1202.csv")
    model_path = os.path.join(model_dir, model_name)

    # Only examples with a truthy drug1_id get a prompt. Labels and predictions are
    # built from this same filtered list so that every array shares one index space;
    # indexing the unfiltered list would misalign (or overrun) the search results.
    kept_examples = [example for example in drugbank_test if example["drug1_id"]]

    ## Edit point 2: prompt template. Only the top-1 retrieved chunk is appended;
    ## top2_contents / top3_contents are not used.
    drugbank = []
    for example in kept_examples:
        text = f"""
            Try to figure out drug-target interaction between the drug and the target. 
            The drug name is {example["drug1_name"]}, the drug smiles is {example["drug1_smile"]} and the drug description is {example["drug1_desc"]}. 
            The target protein description is {example["protein_desc"]}
            Please think step by step!
            {example["top1_contents"]}
            """
        drugbank.append(text)

    model = SentenceTransformer(model_path, trust_remote_code=True)
    pool = model.start_multi_process_pool()
    try:
        drugbank_embeddings = model.encode_multi_process(drugbank, pool, normalize_embeddings=True, batch_size=100)
        print(drugbank_embeddings.shape)
        print("embedding完成")

        corpus = ['negative', 'advise']
        corpus_embeddings = model.encode_multi_process(corpus, pool, normalize_embeddings=True, batch_size=50)
    finally:
        # Always shut the worker processes down, even if encoding fails.
        model.stop_multi_process_pool(pool)

    # top_k equals the corpus size, so every query gets both labels back, best first.
    top_selection = 2
    res_drugbank = semantic_search(drugbank_embeddings, corpus_embeddings, query_chunk_size=100, top_k=top_selection, score_function=util.dot_score)

    # Predicted class = corpus index of the best-scoring label.
    y_pred = np.array([int(hits[0]["corpus_id"]) for hits in res_drugbank])

    def softmax(scores):
        """Softmax over a 1-D score vector."""
        exp_scores = np.exp(scores - np.max(scores))  # Subtract max for numerical stability
        return exp_scores / exp_scores.sum(axis=0)

    # Positive-class probability = softmax share of the 'advise' label (corpus_id 1).
    # Hits are sorted by score, so 'advise' is at rank 1 when 'negative' ranks first.
    y_pred_proba = []
    for hits in res_drugbank:
        probabilities = softmax(np.array([hit["score"] for hit in hits]))
        advise_rank = 1 if int(hits[0]["corpus_id"]) == 0 else 0
        y_pred_proba.append(probabilities[advise_rank])
    y_pred_proba = np.array(y_pred_proba)

    # Ground truth: 'negative' -> 0, anything else -> 1.
    y_true = np.array([0 if example["pos"] == 'negative' else 1 for example in kept_examples])

    cm = confusion_matrix(y_true, y_pred)
    print('cm', cm)

    drugbank_acc = accuracy_score(y_true, y_pred)
    drugbank_pre = precision_score(y_true, y_pred, average='macro')
    drugbank_recall = recall_score(y_true, y_pred, average='macro')
    drugbank_f1 = f1_score(y_true, y_pred, average='macro')

    # Calculate AUC and AUPR
    drugbank_auc = roc_auc_score(y_true, y_pred_proba)
    drugbank_aupr = average_precision_score(y_true, y_pred_proba)

    ### Edit point 3: run label stored in the "model" column.
    new_item_drugbank = {
        "model": 'dti_ml8192_retri_protein_biosnap_random_only_exist_prompt3_n0_fold2_bs64_add_' + 'stella_' + 'ep10',
        "chem_acc": drugbank_acc,
        "chem_pre": drugbank_pre,
        "chem_recall": drugbank_recall,
        "chem_f1": drugbank_f1,
        "chem_auc": drugbank_auc,
        "chem_aupr": drugbank_aupr
    }
    # writeCSV_xu is used when the file already exists (presumably it appends a
    # row -- confirm in data_utils); writeCSV is used for a new file.
    if os.path.exists(ou_drugbank_path):
        writeCSV_xu([new_item_drugbank], ou_drugbank_path)
    else:
        writeCSV([new_item_drugbank], ou_drugbank_path)
            
        

if __name__ == "__main__":
    # Edit point 4: dataset split to evaluate and the directory for the metrics CSV.
    root_dir = "/root/autodl-tmp/dataset_dti/datasets/datasets_tmp_all/biosnap_tmp/random_only_exist_retrieval/fold2/"
    output_dir = "/root/autodl-tmp/piccolo-embedding/test-dti/test_results/"
    # Edit point 5: training-run directory and the checkpoints inside it to score.
    model_dir = "/root/autodl-tmp/piccolo-embedding/scripts-dti/formal_biosnap_retri_8192/dti_8192_biosnap_random_only_exist_prompt3_fold2_cls2_add_stella_n0_epoch10_bs64_lr1e5_ml8192_1202/"
    model_names = ["checkpoint-1280"]

    # Evaluate every listed checkpoint in order, announcing each one first.
    for checkpoint in model_names:
        print(f"-------------------------------- {checkpoint} --------------------------------------")
        cls_test(
            root_dir=root_dir,
            output_dir=output_dir,
            model_dir=model_dir,
            model_name=checkpoint,
        )

    print("----------------------------- All models OK!!!!! ----------------------")
         

  from tqdm.autonotebook import tqdm, trange


-------------------------------- checkpoint-1280 --------------------------------------




(2198, 1792)
embedding完成
cm [[ 765  104]
 [  89 1240]]
CSV file /root/autodl-tmp/piccolo-embedding/test-dti/test_results/dti_retri_protein_biosnap_cls2_ml8192_1202.csv has been append.
----------------------------- All models OK!!!!! ----------------------


In [1]:
##### biosnap -- cluster-only

import json
import os
import random
import shutil
from data_utils import *
from tqdm import tqdm
import torch
from sentence_transformers import util,SentenceTransformer
from sentence_transformers.util import semantic_search
from sklearn.metrics import confusion_matrix
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
    
def cls_test(root_dir, output_dir, model_dir, model_name):
    """Score one checkpoint as a two-class DTI classifier and log its metrics.

    Reads ``test_add.jsonl`` from ``root_dir``, builds one prompt per kept
    example, embeds the prompts with the SentenceTransformer checkpoint at
    ``model_dir/model_name`` and assigns each prompt the closer of the two
    label texts ``['negative', 'advise']`` (dot product on normalized
    embeddings). Class 0 is 'negative', class 1 is 'advise'.

    Args:
        root_dir: Directory containing ``test_add.jsonl``.
        output_dir: Directory of the metrics CSV.
        model_dir: Training-run directory that holds the checkpoints.
        model_name: Checkpoint sub-directory name inside ``model_dir``.

    Returns:
        None. Prints the embedding shape and the confusion matrix, and writes
        one row (accuracy, macro precision/recall/F1, AUC, AUPR) to the CSV.
    """
    drugbank_test_path = os.path.join(root_dir, "test_add.jsonl")
    drugbank_test = readJSONL(drugbank_test_path)
    ## Edit point 1: file name of the metrics CSV.
    ou_drugbank_path = os.path.join(output_dir, "dti_biosnap_ml8192_0324.csv")
    model_path = os.path.join(model_dir, model_name)

    # Only examples with a truthy drug1_id get a prompt. Labels and predictions are
    # built from this same filtered list so that every array shares one index space;
    # indexing the unfiltered list would misalign (or overrun) the search results.
    kept_examples = [example for example in drugbank_test if example["drug1_id"]]

    ## Edit point 2: prompt template. Only the top-1 retrieved chunk is appended;
    ## top2_contents / top3_contents are not used.
    drugbank = []
    for example in kept_examples:
        text = f"""
            Try to figure out drug-target interaction between the drug and the target. 
            The drug name is {example["drug1_name"]}, the drug smiles is {example["drug1_smile"]} and the drug description is {example["drug1_desc"]}. 
            The target protein description is {example["protein_desc"]}
            Please think step by step!
            {example["top1_contents"]}
            """
        drugbank.append(text)

    model = SentenceTransformer(model_path, trust_remote_code=True)
    pool = model.start_multi_process_pool()
    try:
        drugbank_embeddings = model.encode_multi_process(drugbank, pool, normalize_embeddings=True, batch_size=100)
        print(drugbank_embeddings.shape)
        print("embedding完成")

        corpus = ['negative', 'advise']
        corpus_embeddings = model.encode_multi_process(corpus, pool, normalize_embeddings=True, batch_size=50)
    finally:
        # Always shut the worker processes down, even if encoding fails.
        model.stop_multi_process_pool(pool)

    # top_k equals the corpus size, so every query gets both labels back, best first.
    top_selection = 2
    res_drugbank = semantic_search(drugbank_embeddings, corpus_embeddings, query_chunk_size=100, top_k=top_selection, score_function=util.dot_score)

    # Predicted class = corpus index of the best-scoring label.
    y_pred = np.array([int(hits[0]["corpus_id"]) for hits in res_drugbank])

    def softmax(scores):
        """Softmax over a 1-D score vector."""
        exp_scores = np.exp(scores - np.max(scores))  # Subtract max for numerical stability
        return exp_scores / exp_scores.sum(axis=0)

    # Positive-class probability = softmax share of the 'advise' label (corpus_id 1).
    # Hits are sorted by score, so 'advise' is at rank 1 when 'negative' ranks first.
    y_pred_proba = []
    for hits in res_drugbank:
        probabilities = softmax(np.array([hit["score"] for hit in hits]))
        advise_rank = 1 if int(hits[0]["corpus_id"]) == 0 else 0
        y_pred_proba.append(probabilities[advise_rank])
    y_pred_proba = np.array(y_pred_proba)

    # Ground truth: 'negative' -> 0, anything else -> 1.
    y_true = np.array([0 if example["pos"] == 'negative' else 1 for example in kept_examples])

    cm = confusion_matrix(y_true, y_pred)
    print('cm', cm)

    drugbank_acc = accuracy_score(y_true, y_pred)
    drugbank_pre = precision_score(y_true, y_pred, average='macro')
    drugbank_recall = recall_score(y_true, y_pred, average='macro')
    drugbank_f1 = f1_score(y_true, y_pred, average='macro')

    # Calculate AUC and AUPR
    drugbank_auc = roc_auc_score(y_true, y_pred_proba)
    drugbank_aupr = average_precision_score(y_true, y_pred_proba)

    ### Edit point 3: run label stored in the "model" column.
    new_item_drugbank = {
        "model": 'dti_ml8192_biosnap_cluster_only_n0_fold0_add_' + 'stella_' + 'ep?',
        "chem_acc": drugbank_acc,
        "chem_pre": drugbank_pre,
        "chem_recall": drugbank_recall,
        "chem_f1": drugbank_f1,
        "chem_auc": drugbank_auc,
        "chem_aupr": drugbank_aupr
    }
    # writeCSV_xu is used when the file already exists (presumably it appends a
    # row -- confirm in data_utils); writeCSV is used for a new file.
    if os.path.exists(ou_drugbank_path):
        writeCSV_xu([new_item_drugbank], ou_drugbank_path)
    else:
        writeCSV([new_item_drugbank], ou_drugbank_path)
            
        

if __name__ == "__main__":
    # Edit point 4: dataset split to evaluate and the directory for the metrics CSV.
    root_dir = "/root/autodl-tmp/dataset_dti/datasets/datasets_tmp_all/biosnap_tmp/cluster_only_retrieval/fold0/"
    output_dir = "/root/autodl-tmp/piccolo-embedding/test-dti/test_results/"
    # Edit point 5: training-run directory and the checkpoints inside it to score.
    model_dir = "/root/autodl-tmp/piccolo-embedding/scripts-dti/formal_biosnap_retri_8192/dti_8192_biosnap_cluster_only_prompt3_fold0_cls2_add_stella_n0_epoch15_bs64_lr1e5_ml8192_1202/"
    model_names = ["checkpoint-445"]

    # Evaluate every listed checkpoint in order, announcing each one first.
    for checkpoint in model_names:
        print(f"-------------------------------- {checkpoint} --------------------------------------")
        cls_test(
            root_dir=root_dir,
            output_dir=output_dir,
            model_dir=model_dir,
            model_name=checkpoint,
        )

    print("----------------------------- All models OK!!!!! ----------------------")
         

  from tqdm.autonotebook import tqdm, trange


-------------------------------- checkpoint-445 --------------------------------------




(419, 1792)
embedding完成
cm [[153  30]
 [  9 227]]
CSV file /root/autodl-tmp/piccolo-embedding/test-dti/test_results/dti_biosnap_ml8192_0324.csv has been created.
----------------------------- All models OK!!!!! ----------------------


In [1]:
### biosnap -- scarce

import json
import os
import random
import shutil
from data_utils import *
from tqdm import tqdm
import torch
from sentence_transformers import util,SentenceTransformer
from sentence_transformers.util import semantic_search
from sklearn.metrics import confusion_matrix
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
    
def cls_test(root_dir, output_dir, model_dir, model_name):
    """Score one checkpoint as a two-class DTI classifier and log its metrics.

    Reads ``test_add.jsonl`` from ``root_dir``, builds one prompt per kept
    example, embeds the prompts with the SentenceTransformer checkpoint at
    ``model_dir/model_name`` and assigns each prompt the closer of the two
    label texts ``['negative', 'advise']`` (dot product on normalized
    embeddings). Class 0 is 'negative', class 1 is 'advise'.

    Args:
        root_dir: Directory containing ``test_add.jsonl``.
        output_dir: Directory of the metrics CSV.
        model_dir: Training-run directory that holds the checkpoints.
        model_name: Checkpoint sub-directory name inside ``model_dir``.

    Returns:
        None. Prints the embedding shape and the confusion matrix, and writes
        one row (accuracy, macro precision/recall/F1, AUC, AUPR) to the CSV.
    """
    drugbank_test_path = os.path.join(root_dir, "test_add.jsonl")
    drugbank_test = readJSONL(drugbank_test_path)
    ## Edit point 1: file name of the metrics CSV.
    ou_drugbank_path = os.path.join(output_dir, "dti_biosnap_ml8192_0324-2.csv")
    model_path = os.path.join(model_dir, model_name)

    # Only examples with a truthy drug1_id get a prompt. Labels and predictions are
    # built from this same filtered list so that every array shares one index space;
    # indexing the unfiltered list would misalign (or overrun) the search results.
    kept_examples = [example for example in drugbank_test if example["drug1_id"]]

    ## Edit point 2: prompt template. Only the top-1 retrieved chunk is appended;
    ## top2_contents / top3_contents are not used.
    drugbank = []
    for example in kept_examples:
        text = f"""
            Try to figure out drug-target interaction between the drug and the target. 
            The drug name is {example["drug1_name"]}, the drug smiles is {example["drug1_smile"]} and the drug description is {example["drug1_desc"]}. 
            The target protein description is {example["protein_desc"]}
            Please think step by step!
            {example["top1_contents"]}
            """
        drugbank.append(text)

    model = SentenceTransformer(model_path, trust_remote_code=True)
    pool = model.start_multi_process_pool()
    try:
        # This cell encodes the prompts with batch_size=50.
        drugbank_embeddings = model.encode_multi_process(drugbank, pool, normalize_embeddings=True, batch_size=50)
        print(drugbank_embeddings.shape)
        print("embedding完成")

        corpus = ['negative', 'advise']
        corpus_embeddings = model.encode_multi_process(corpus, pool, normalize_embeddings=True, batch_size=50)
    finally:
        # Always shut the worker processes down, even if encoding fails.
        model.stop_multi_process_pool(pool)

    # top_k equals the corpus size, so every query gets both labels back, best first.
    top_selection = 2
    res_drugbank = semantic_search(drugbank_embeddings, corpus_embeddings, query_chunk_size=100, top_k=top_selection, score_function=util.dot_score)

    # Predicted class = corpus index of the best-scoring label.
    y_pred = np.array([int(hits[0]["corpus_id"]) for hits in res_drugbank])

    def softmax(scores):
        """Softmax over a 1-D score vector."""
        exp_scores = np.exp(scores - np.max(scores))  # Subtract max for numerical stability
        return exp_scores / exp_scores.sum(axis=0)

    # Positive-class probability = softmax share of the 'advise' label (corpus_id 1).
    # Hits are sorted by score, so 'advise' is at rank 1 when 'negative' ranks first.
    y_pred_proba = []
    for hits in res_drugbank:
        probabilities = softmax(np.array([hit["score"] for hit in hits]))
        advise_rank = 1 if int(hits[0]["corpus_id"]) == 0 else 0
        y_pred_proba.append(probabilities[advise_rank])
    y_pred_proba = np.array(y_pred_proba)

    # Ground truth: 'negative' -> 0, anything else -> 1.
    y_true = np.array([0 if example["pos"] == 'negative' else 1 for example in kept_examples])

    cm = confusion_matrix(y_true, y_pred)
    print('cm', cm)

    drugbank_acc = accuracy_score(y_true, y_pred)
    drugbank_pre = precision_score(y_true, y_pred, average='macro')
    drugbank_recall = recall_score(y_true, y_pred, average='macro')
    drugbank_f1 = f1_score(y_true, y_pred, average='macro')

    # Calculate AUC and AUPR
    drugbank_auc = roc_auc_score(y_true, y_pred_proba)
    drugbank_aupr = average_precision_score(y_true, y_pred_proba)

    ### Edit point 3: run label stored in the "model" column.
    new_item_drugbank = {
        "model": 'dti_ml8192_biosnap_scarce_5_n0_fold2_add_' + 'stella_' + 'ep30',
        "chem_acc": drugbank_acc,
        "chem_pre": drugbank_pre,
        "chem_recall": drugbank_recall,
        "chem_f1": drugbank_f1,
        "chem_auc": drugbank_auc,
        "chem_aupr": drugbank_aupr
    }
    # writeCSV_xu is used when the file already exists (presumably it appends a
    # row -- confirm in data_utils); writeCSV is used for a new file.
    if os.path.exists(ou_drugbank_path):
        writeCSV_xu([new_item_drugbank], ou_drugbank_path)
    else:
        writeCSV([new_item_drugbank], ou_drugbank_path)
            
        

if __name__ == "__main__":
    # Edit point 4: dataset split to evaluate and the directory for the metrics CSV.
    root_dir = "/root/autodl-tmp/dataset_dti/datasets/datasets_tmp_all/biosnap_tmp/scarce_exist_retrieval/5/fold2/"
    output_dir = "/root/autodl-tmp/piccolo-embedding/test-dti/test_results/"
    # Edit point 5: training-run directory and the checkpoints inside it to score.
    model_dir = "/root/autodl-tmp/piccolo-embedding/scripts-dti/formal_biosnap_retri_8192/dti_8192_biosnap_scarce_5_fold2_n0_epoch40_bs64_lr1e5_0324/"
    model_names = ["checkpoint-240"]

    # Evaluate every listed checkpoint in order, announcing each one first.
    for checkpoint in model_names:
        print(f"-------------------------------- {checkpoint} --------------------------------------")
        cls_test(
            root_dir=root_dir,
            output_dir=output_dir,
            model_dir=model_dir,
            model_name=checkpoint,
        )

    print("----------------------------- All models OK!!!!! ----------------------")

    

  from tqdm.autonotebook import tqdm, trange


-------------------------------- checkpoint-240 --------------------------------------




(9899, 1792)
embedding完成
cm [[1814 2208]
 [ 364 5513]]
CSV file /root/autodl-tmp/piccolo-embedding/test-dti/test_results/dti_biosnap_ml8192_0324-2.csv has been append.
----------------------------- All models OK!!!!! ----------------------


In [1]:
### human -- random-only-exist

import json
import os
import random
import shutil
from data_utils import *
from tqdm import tqdm
import torch
from sentence_transformers import util,SentenceTransformer
from sentence_transformers.util import semantic_search
from sklearn.metrics import confusion_matrix
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
    
def cls_test(root_dir, output_dir, model_dir, model_name):
    """Score one checkpoint as a two-class DTI classifier and log its metrics.

    Reads ``test_add.jsonl`` from ``root_dir``, builds one prompt per example
    (drug SMILES, protein description and the top-1 retrieved chunk), embeds
    the prompts with the SentenceTransformer checkpoint at
    ``model_dir/model_name`` and assigns each prompt the closer of the two
    label texts ``['negative', 'advise']`` (dot product on normalized
    embeddings). Class 0 is 'negative', class 1 is 'advise'.

    Args:
        root_dir: Directory containing ``test_add.jsonl``.
        output_dir: Directory of the metrics CSV.
        model_dir: Training-run directory that holds the checkpoints.
        model_name: Checkpoint sub-directory name inside ``model_dir``.

    Returns:
        None. Prints the embedding shape and the confusion matrix, and writes
        one row (accuracy, macro precision/recall/F1, AUC, AUPR) to the CSV.
    """
    drugbank_test_path = os.path.join(root_dir, "test_add.jsonl")
    drugbank_test = readJSONL(drugbank_test_path)
    ## Edit point 1: file name of the metrics CSV.
    ou_drugbank_path = os.path.join(output_dir, "dti_retri_protein_human_cls2_ml8192_1127.csv")
    model_path = os.path.join(model_dir, model_name)

    ## Edit point 2: prompt template. Only the top-1 retrieved chunk is appended;
    ## top2_contents / top3_contents are not used.
    drugbank = []
    for example in drugbank_test:
        text = f"""
            Try to figure out drug-target interaction between the drug and the target. 
            The drug smiles is {example["drug1_smile"]}
            The target protein description is {example["protein_desc"]}
            Please think step by step!
            {example["top1_contents"]}
            """
        drugbank.append(text)

    model = SentenceTransformer(model_path, trust_remote_code=True)
    pool = model.start_multi_process_pool()
    try:
        drugbank_embeddings = model.encode_multi_process(drugbank, pool, normalize_embeddings=True, batch_size=100)
        print(drugbank_embeddings.shape)
        print("embedding完成")

        corpus = ['negative', 'advise']
        corpus_embeddings = model.encode_multi_process(corpus, pool, normalize_embeddings=True, batch_size=50)
    finally:
        # Always shut the worker processes down, even if encoding fails.
        model.stop_multi_process_pool(pool)

    # top_k equals the corpus size, so every query gets both labels back, best first.
    top_selection = 2
    res_drugbank = semantic_search(drugbank_embeddings, corpus_embeddings, query_chunk_size=100, top_k=top_selection, score_function=util.dot_score)

    # Predicted class = corpus index of the best-scoring label.
    y_pred = np.array([int(hits[0]["corpus_id"]) for hits in res_drugbank])

    def softmax(scores):
        """Softmax over a 1-D score vector."""
        exp_scores = np.exp(scores - np.max(scores))  # Subtract max for numerical stability
        return exp_scores / exp_scores.sum(axis=0)

    # Positive-class probability = softmax share of the 'advise' label (corpus_id 1).
    # Hits are sorted by score, so 'advise' is at rank 1 when 'negative' ranks first.
    y_pred_proba = []
    for hits in res_drugbank:
        probabilities = softmax(np.array([hit["score"] for hit in hits]))
        advise_rank = 1 if int(hits[0]["corpus_id"]) == 0 else 0
        y_pred_proba.append(probabilities[advise_rank])
    y_pred_proba = np.array(y_pred_proba)

    # Ground truth: 'negative' -> 0, anything else -> 1.
    y_true = np.array([0 if example["pos"] == 'negative' else 1 for example in drugbank_test])

    cm = confusion_matrix(y_true, y_pred)
    print('cm', cm)

    drugbank_acc = accuracy_score(y_true, y_pred)
    drugbank_pre = precision_score(y_true, y_pred, average='macro')
    drugbank_recall = recall_score(y_true, y_pred, average='macro')
    drugbank_f1 = f1_score(y_true, y_pred, average='macro')

    # Calculate AUC and AUPR
    drugbank_auc = roc_auc_score(y_true, y_pred_proba)
    drugbank_aupr = average_precision_score(y_true, y_pred_proba)

    ### Edit point 3: run label stored in the "model" column.
    new_item_drugbank = {
        "model": 'dti_human_random_only_exist_prompt3_n0_fold2_bs64_retrain_add_' + 'stella_' + 'ep10',
        "chem_acc": drugbank_acc,
        "chem_pre": drugbank_pre,
        "chem_recall": drugbank_recall,
        "chem_f1": drugbank_f1,
        "chem_auc": drugbank_auc,
        "chem_aupr": drugbank_aupr
    }
    # writeCSV_xu is used when the file already exists (presumably it appends a
    # row -- confirm in data_utils); writeCSV is used for a new file.
    if os.path.exists(ou_drugbank_path):
        writeCSV_xu([new_item_drugbank], ou_drugbank_path)
    else:
        writeCSV([new_item_drugbank], ou_drugbank_path)

if __name__ == "__main__":
    # Edit point 4: dataset split to evaluate and the directory for the metrics CSV.
    root_dir = "/root/autodl-tmp/dataset_dti/datasets/datasets_tmp_all/human_tmp_only_protein/random_only_exist_retrieval/fold2/"
    output_dir = "/root/autodl-tmp/piccolo-embedding/test-dti/test_results/"
    # Edit point 5: training-run directory and the checkpoints inside it to score.
    model_dir = "/root/autodl-tmp/piccolo-embedding/scripts-dti/formal_human_retri_8192/dti_8192_human_random_only_exist_prompt3_fold2_cls2_add_stella_n0_epoch10_bs64_lr1e5_ml8192_1202/"
    model_names = ["checkpoint-630"]

    # Evaluate every listed checkpoint in order, announcing each one first.
    for checkpoint in model_names:
        print(f"-------------------------------- {checkpoint} --------------------------------------")
        cls_test(
            root_dir=root_dir,
            output_dir=output_dir,
            model_dir=model_dir,
            model_name=checkpoint,
        )

    print("----------------------------- All models OK!!!!! ----------------------")
         

  from tqdm.autonotebook import tqdm, trange


-------------------------------- checkpoint-630 --------------------------------------




(949, 1792)
embedding完成
cm [[481  25]
 [ 25 418]]
CSV file /root/autodl-tmp/piccolo-embedding/test-dti/test_results/dti_retri_protein_human_cls2_ml8192_1127.csv has been append.
----------------------------- All models OK!!!!! ----------------------


In [1]:
### human cold-balanced

import json
import os
import random
import shutil
from data_utils import *
from tqdm import tqdm
import torch
from sentence_transformers import util,SentenceTransformer
from sentence_transformers.util import semantic_search
from sklearn.metrics import confusion_matrix
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
    
def cls_test(root_dir, output_dir, model_dir, model_name):
    """Score one checkpoint as a two-class DTI classifier and log its metrics.

    Reads ``test_add.jsonl`` from ``root_dir``, builds one prompt per example
    (drug SMILES, protein description and the top-1 retrieved chunk), embeds
    the prompts with the SentenceTransformer checkpoint at
    ``model_dir/model_name`` and assigns each prompt the closer of the two
    label texts ``['negative', 'advise']`` (dot product on normalized
    embeddings). Class 0 is 'negative', class 1 is 'advise'.

    Args:
        root_dir: Directory containing ``test_add.jsonl``.
        output_dir: Directory of the metrics CSV.
        model_dir: Training-run directory that holds the checkpoints.
        model_name: Checkpoint sub-directory name inside ``model_dir``.

    Returns:
        None. Prints the embedding shape and the confusion matrix, dumps the
        per-example labels and positive-class probabilities to a fixed CSV
        path, and writes one metrics row (accuracy, macro precision/recall/F1,
        AUC, AUPR) to the CSV in ``output_dir``.
    """
    drugbank_test_path = os.path.join(root_dir, "test_add.jsonl")
    drugbank_test = readJSONL(drugbank_test_path)
    ## Edit point 1: file name of the metrics CSV.
    ou_drugbank_path = os.path.join(output_dir, "dti_retri_protein_human_cls2_ml8192_1127.csv")
    model_path = os.path.join(model_dir, model_name)

    ## Edit point 2: prompt template. Only the top-1 retrieved chunk is appended;
    ## top2_contents / top3_contents are not used.
    drugbank = []
    for example in drugbank_test:
        text = f"""
            Try to figure out drug-target interaction between the drug and the target. 
            The drug smiles is {example["drug1_smile"]}
            The target protein description is {example["protein_desc"]}
            Please think step by step!
            {example["top1_contents"]}
            """
        drugbank.append(text)

    model = SentenceTransformer(model_path, trust_remote_code=True)
    pool = model.start_multi_process_pool()
    try:
        drugbank_embeddings = model.encode_multi_process(drugbank, pool, normalize_embeddings=True, batch_size=100)
        print(drugbank_embeddings.shape)
        print("embedding完成")

        corpus = ['negative', 'advise']
        corpus_embeddings = model.encode_multi_process(corpus, pool, normalize_embeddings=True, batch_size=50)
    finally:
        # Always shut the worker processes down, even if encoding fails.
        model.stop_multi_process_pool(pool)

    # top_k equals the corpus size, so every query gets both labels back, best first.
    top_selection = 2
    res_drugbank = semantic_search(drugbank_embeddings, corpus_embeddings, query_chunk_size=100, top_k=top_selection, score_function=util.dot_score)

    # Predicted class = corpus index of the best-scoring label.
    y_pred = np.array([int(hits[0]["corpus_id"]) for hits in res_drugbank])

    def softmax(scores):
        """Softmax over a 1-D score vector."""
        exp_scores = np.exp(scores - np.max(scores))  # Subtract max for numerical stability
        return exp_scores / exp_scores.sum(axis=0)

    # Positive-class probability = softmax share of the 'advise' label (corpus_id 1).
    # Hits are sorted by score, so 'advise' is at rank 1 when 'negative' ranks first.
    y_pred_proba = []
    for hits in res_drugbank:
        probabilities = softmax(np.array([hit["score"] for hit in hits]))
        advise_rank = 1 if int(hits[0]["corpus_id"]) == 0 else 0
        y_pred_proba.append(probabilities[advise_rank])
    y_pred_proba = np.array(y_pred_proba)

    # Ground truth: 'negative' -> 0, anything else -> 1.
    y_true = np.array([0 if example["pos"] == 'negative' else 1 for example in drugbank_test])

    cm = confusion_matrix(y_true, y_pred)
    print('cm', cm)

    drugbank_acc = accuracy_score(y_true, y_pred)
    drugbank_pre = precision_score(y_true, y_pred, average='macro')
    drugbank_recall = recall_score(y_true, y_pred, average='macro')
    drugbank_f1 = f1_score(y_true, y_pred, average='macro')

    # Calculate AUC and AUPR
    drugbank_auc = roc_auc_score(y_true, y_pred_proba)
    drugbank_aupr = average_precision_score(y_true, y_pred_proba)

    # Extra step specific to this cell: save the true labels and the predicted
    # positive-class probabilities to a CSV file. pandas is imported locally
    # because the cell's import list does not include it.
    import pandas as pd
    results_df = pd.DataFrame({
        'true_labels': y_true,
        'predicted_prob': y_pred_proba
    })
    results_df.to_csv('/root/autodl-tmp/piccolo-embedding/test-dti/test_results/cold-csv/cold-predictions.csv', index=False)
    print('csv输出完成！！')

    ### Edit point 3: run label stored in the "model" column.
    new_item_drugbank = {
        "model": 'dti_human_balanced_cold_only_prompt3_n0_fold2_bs64_retrain_add_' + 'stella_' + 'ep5',
        "chem_acc": drugbank_acc,
        "chem_pre": drugbank_pre,
        "chem_recall": drugbank_recall,
        "chem_f1": drugbank_f1,
        "chem_auc": drugbank_auc,
        "chem_aupr": drugbank_aupr
    }
    # writeCSV_xu is used when the file already exists (presumably it appends a
    # row -- confirm in data_utils); writeCSV is used for a new file.
    if os.path.exists(ou_drugbank_path):
        writeCSV_xu([new_item_drugbank], ou_drugbank_path)
    else:
        writeCSV([new_item_drugbank], ou_drugbank_path)

if __name__ == "__main__":
    # Edit point 4: dataset split to evaluate and the directory for the metrics CSV.
    root_dir = "/root/autodl-tmp/dataset_dti/datasets/datasets_tmp_all/human_tmp_only_protein/cold_only_retrieval_balanced/fold2/"
    output_dir = "/root/autodl-tmp/piccolo-embedding/test-dti/test_results/"
    # Edit point 5: training-run directory and the checkpoints inside it to score.
    model_dir = "/root/autodl-tmp/piccolo-embedding/scripts-dti/formal_human_retri_8192/dti_8192_human_balanced_cold_only_prompt3_fold2_cls2_add_stella_n0_epoch5_bs64_lr1e5_ml8192_1202/"
    model_names = ["checkpoint-245"]

    # Evaluate every listed checkpoint in order, announcing each one first.
    for checkpoint in model_names:
        print(f"-------------------------------- {checkpoint} --------------------------------------")
        cls_test(
            root_dir=root_dir,
            output_dir=output_dir,
            model_dir=model_dir,
            model_name=checkpoint,
        )

    print("----------------------------- All models OK!!!!! ----------------------")

    

  from tqdm.autonotebook import tqdm, trange


-------------------------------- checkpoint-245 --------------------------------------




(240, 1792)
embedding完成
cm [[108  12]
 [ 60  60]]
csv输出完成！！
CSV file /root/autodl-tmp/piccolo-embedding/test-dti/test_results/dti_retri_protein_human_cls2_ml8192_1127.csv has been append.
----------------------------- All models OK!!!!! ----------------------
