## Conduct a preliminary analysis of all model predictions
1. Find the problem where all the models are correct
2. Find the problem where all models get it wrong
3. Find questions that get some parts right

In [1]:
import torch
from sklearn.metrics import accuracy_score

In [2]:
hy_c10_train_analysis_path = "../02-train&valid/hyena/result/hyena_c10_690_train.pt"
hy_c10_valid_analysis_path = "../02-train&valid/hyena/result/hyena_c10_690_valid.pt"
hy_c10_test_analysis_path = "../02-train&valid/hyena/result/hyena_c10_690_test.pt"
nt_c10_train_analysis_path = "../02-train&valid/ntv2/result/ntv2_c10_train_results.pt"
nt_c10_valid_analysis_path = "../02-train&valid/ntv2/result/ntv2_c10_valid_results.pt"
nt_c10_test_analysis_path = "../02-train&valid/ntv2/result/ntv2_c10_test_results.pt"
cd_c10_train_analysis_path = "../02-train&valid/cdgpt/result/cdgpt_c10_2_3036_train_results.pt"
cd_c10_valid_analysis_path = "../02-train&valid/cdgpt/result/cdgpt_c10_2_3036_valid_results.pt"
cd_c10_test_analysis_path = "../02-train&valid/cdgpt/result/cdgpt_c10_2_3036_test_results.pt"

hy_train_file_path = "../01-data/C10_hyena_20kbp_train_dataset.pt"
hy_valid_file_path = "../01-data/C10_hyena_20kbp_valid_dataset.pt"
hy_test_file_path = "../01-data/C10_hyena_20kbp_test_dataset.pt"

nt_train_file_path = "../01-data/C10_ntv2_12kbp_train_dataset.pt"
nt_valid_file_path = "../01-data/C10_ntv2_12kbp_valid_dataset.pt"
nt_test_file_path = "../01-data/C10_ntv2_12kbp_test_dataset.pt"

In [6]:

import torch

def compare_predictions(result_files):
    results = [torch.load(f) for f in result_files]
    num_samples = len(results[0]["prediction"])  
    num_models = len(results)  
    
    # Check the predictions for each sample
    all_correct_indices = []
    all_incorrect_indices = []
    one_correct_rest_incorrect = []
    one_incorrect_rest_correct = []
    
    # A single model is correct and other models are wrong
    single_model_correct_rest_incorrect = {f"{chr(97 + i)}_correct_rest_incorrect": [] for i in range(num_models)}
    
    # When a single model is wrong and other models are correct
    single_model_incorrect_rest_correct = {f"{chr(97 + i)}_incorrect_rest_correct": [] for i in range(num_models)}

    for i in range(num_samples):
        label = results[0]["label"][i]
        predictions = [result["prediction"][i] for result in results]
        correct_preds = [pred == label for pred in predictions]
        num_correct = sum(correct_preds)

        if num_correct == num_models:
            all_correct_indices.append(i)
        
        elif num_correct == 0:
            all_incorrect_indices.append(i)
        
        elif num_correct == 1:
            one_correct_rest_incorrect.append(i)
            for model_idx, is_correct in enumerate(correct_preds):
                if is_correct:
                    model_label = chr(97 + model_idx)
                    single_model_correct_rest_incorrect[f"{model_label}_correct_rest_incorrect"].append(i)
        
        elif num_correct == num_models - 1:
            one_incorrect_rest_correct.append(i)
            for model_idx, is_correct in enumerate(correct_preds):
                if not is_correct:
                    model_label = chr(97 + model_idx)
                    single_model_incorrect_rest_correct[f"{model_label}_incorrect_rest_correct"].append(i)

    def format_output(label, indices):
        count = len(indices)
        proportion = count / num_samples * 100
        print(f"{label} Quantity: {count} ({proportion:.2f}%), subscript: {indices}")

    print("\n=== The model predicts the comparison results ===")
    format_output("All models predict the correct example", all_correct_indices)
    format_output("All models predict examples of errors", all_incorrect_indices)
    format_output("An example of one correct prediction and the rest wrong", one_correct_rest_incorrect)
    format_output("An example of one incorrect prediction and the rest correct", one_incorrect_rest_correct)

    print("\n=== Details of how a single model predicted correctly and other models predicted incorrectly ===")
    for model_label, indices in single_model_correct_rest_incorrect.items():
        format_output(f"Example of model {model_label[0].upper()} getting it right and other models getting it wrong", indices)

    print("\n=== Individual models predicted wrong and other models predicted correct details ===")
    for model_label, indices in single_model_incorrect_rest_correct.items():
        format_output(f"Example of model {model_label[0].upper()} getting it wrong and other models getting it right", indices)

In [7]:
result_files = [hy_c10_test_analysis_path, nt_c10_test_analysis_path, cd_c10_test_analysis_path]
compare_predictions(result_files)


=== The model predicts the comparison results ===
All models predict the correct example Quantity: 474 (47.98%), subscript: [2, 3, 4, 6, 8, 10, 11, 13, 14, 15, 16, 19, 20, 22, 23, 24, 27, 28, 31, 35, 36, 39, 44, 49, 50, 52, 54, 59, 60, 63, 64, 65, 66, 67, 69, 71, 73, 76, 77, 80, 81, 83, 86, 89, 91, 93, 95, 97, 101, 103, 104, 107, 108, 112, 113, 114, 115, 123, 124, 125, 126, 129, 133, 134, 135, 139, 143, 146, 147, 148, 149, 151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 164, 167, 168, 169, 170, 176, 178, 179, 180, 183, 185, 187, 188, 190, 191, 197, 200, 201, 206, 207, 208, 209, 210, 212, 213, 215, 221, 222, 224, 226, 228, 236, 237, 239, 242, 243, 244, 245, 246, 247, 250, 251, 252, 253, 256, 260, 263, 266, 268, 269, 273, 274, 275, 276, 277, 278, 280, 282, 284, 285, 286, 287, 291, 293, 294, 295, 296, 297, 300, 303, 306, 309, 310, 314, 315, 320, 321, 322, 324, 325, 329, 330, 331, 332, 334, 336, 337, 338, 342, 343, 344, 346, 347, 348, 351, 357, 358, 361, 363, 365, 366, 370, 374, 37

  results = [torch.load(f) for f in result_files]


In [8]:
'''
Pass in two models for more detailed analysis
'''
def compare_predictions_two_models(result_files):
    results = [torch.load(f) for f in result_files]
    num_samples = len(results[0]["prediction"])

    all_correct_indices = []
    all_incorrect_indices = []
    one_correct_rest_incorrect = []
    one_incorrect_rest_correct = []
    
    a_correct_b_incorrect = []  # Model a is right, model b is wrong
    b_correct_a_incorrect = []  # Model b is right, model a is wrong

    for i in range(num_samples):
        label = results[0]["label"][i]
        predictions = [result["prediction"][i] for result in results]
        correct_preds = [pred == label for pred in predictions]
        num_correct = sum(correct_preds)

        if num_correct == 2:
            all_correct_indices.append(i)
        
        elif num_correct == 0:
            all_incorrect_indices.append(i)
        
        elif num_correct == 1:
            one_correct_rest_incorrect.append(i)
            if correct_preds[0] and not correct_preds[1]:  
                a_correct_b_incorrect.append(i)
            elif correct_preds[1] and not correct_preds[0]:
                b_correct_a_incorrect.append(i)

    def format_output(label, indices):
        count = len(indices)
        proportion = count / num_samples * 100
        print(f"{label} Quantity: {count} ({proportion:.2f}%), subscript: {indices}")

    print("\n=== Comparison of the predictions of the two models ===")
    format_output("All models predict the correct example", all_correct_indices)
    format_output("All models predict examples of errors", all_incorrect_indices)
    format_output("An example of one correct prediction and the rest wrong", one_correct_rest_incorrect)
    format_output("Example of model a predicting correctly and Model b predicting incorrectly", a_correct_b_incorrect)
    format_output("Example of model b predicting correctly and Model a predicting incorrectly", b_correct_a_incorrect)

In [12]:
result_files_1 = [nt_c10_test_analysis_path, cd_c10_test_analysis_path]
compare_predictions_two_models(result_files_1)


=== Comparison of the predictions of the two models ===
All models predict the correct example Quantity: 559 (56.58%), subscript: [2, 3, 4, 6, 8, 9, 10, 11, 13, 14, 15, 16, 19, 20, 21, 22, 23, 24, 27, 28, 30, 31, 33, 35, 36, 39, 42, 44, 46, 49, 50, 52, 53, 54, 55, 59, 60, 63, 64, 65, 66, 67, 69, 71, 73, 76, 77, 79, 80, 81, 83, 86, 89, 90, 91, 93, 94, 95, 97, 101, 103, 104, 105, 107, 108, 112, 113, 114, 115, 116, 119, 123, 124, 125, 126, 129, 130, 133, 134, 135, 138, 139, 142, 143, 144, 146, 147, 148, 149, 151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 164, 167, 168, 169, 170, 176, 178, 179, 180, 182, 183, 185, 186, 187, 188, 190, 191, 194, 197, 200, 201, 202, 203, 206, 207, 208, 209, 210, 212, 213, 215, 216, 220, 221, 222, 223, 224, 226, 228, 229, 232, 236, 237, 239, 242, 243, 244, 245, 246, 247, 250, 251, 252, 253, 254, 256, 260, 262, 263, 266, 268, 269, 272, 273, 274, 275, 276, 277, 278, 280, 282, 284, 285, 286, 287, 290, 291, 293, 294, 295, 296, 297, 300, 303, 306, 309, 310

  results = [torch.load(f) for f in result_files]


In [11]:
result_files_2 = [hy_c10_test_analysis_path, cd_c10_test_analysis_path]
compare_predictions_two_models(result_files_2)


=== Comparison of the predictions of the two models ===
All models predict the correct example Quantity: 518 (52.43%), subscript: [2, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 19, 20, 22, 23, 24, 27, 28, 31, 35, 36, 39, 44, 49, 50, 52, 54, 56, 59, 60, 63, 64, 65, 66, 67, 69, 71, 73, 76, 77, 80, 81, 82, 83, 86, 89, 91, 93, 95, 97, 101, 103, 104, 107, 108, 110, 112, 113, 114, 115, 123, 124, 125, 126, 128, 129, 131, 133, 134, 135, 139, 143, 146, 147, 148, 149, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 164, 166, 167, 168, 169, 170, 176, 178, 179, 180, 183, 185, 187, 188, 190, 191, 197, 200, 201, 206, 207, 208, 209, 210, 212, 213, 215, 217, 221, 222, 224, 226, 227, 228, 235, 236, 237, 239, 242, 243, 244, 245, 246, 247, 250, 251, 252, 253, 255, 256, 260, 263, 266, 268, 269, 270, 273, 274, 275, 276, 277, 278, 280, 282, 284, 285, 286, 287, 291, 293, 294, 295, 296, 297, 300, 303, 304, 306, 309, 310, 314, 315, 319, 320, 321, 322, 324, 325, 329, 330, 331, 332, 334, 336, 337, 338, 

  results = [torch.load(f) for f in result_files]


In [14]:
result_files_2 = [nt_c10_test_analysis_path, cd_c10_test_analysis_path]
compare_predictions_two_models(result_files_2)


=== Comparison of the predictions of the two models ===
All models predict the correct example Quantity: 559 (56.58%), subscript: [2, 3, 4, 6, 8, 9, 10, 11, 13, 14, 15, 16, 19, 20, 21, 22, 23, 24, 27, 28, 30, 31, 33, 35, 36, 39, 42, 44, 46, 49, 50, 52, 53, 54, 55, 59, 60, 63, 64, 65, 66, 67, 69, 71, 73, 76, 77, 79, 80, 81, 83, 86, 89, 90, 91, 93, 94, 95, 97, 101, 103, 104, 105, 107, 108, 112, 113, 114, 115, 116, 119, 123, 124, 125, 126, 129, 130, 133, 134, 135, 138, 139, 142, 143, 144, 146, 147, 148, 149, 151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 164, 167, 168, 169, 170, 176, 178, 179, 180, 182, 183, 185, 186, 187, 188, 190, 191, 194, 197, 200, 201, 202, 203, 206, 207, 208, 209, 210, 212, 213, 215, 216, 220, 221, 222, 223, 224, 226, 228, 229, 232, 236, 237, 239, 242, 243, 244, 245, 246, 247, 250, 251, 252, 253, 254, 256, 260, 262, 263, 266, 268, 269, 272, 273, 274, 275, 276, 277, 278, 280, 282, 284, 285, 286, 287, 290, 291, 293, 294, 295, 296, 297, 300, 303, 306, 309, 310

  results = [torch.load(f) for f in result_files]


## Save the models that are correct and the models that are wrong
##### Take out the sequences in which all the 3 models are correct and the sequences in which all the 3 models are wrong, and carry out the sequence comparison

In [15]:
import torch
import os

def compare_predictions(result_files, test_file, output_dir="3model_result"):
    results = [torch.load(f) for f in result_files]
    test_data = torch.load(test_file)
    
    os.makedirs(output_dir, exist_ok=True)
    
    num_samples = len(results[0]["prediction"])
    num_models = len(results) 
    
    all_correct_indices = []
    all_incorrect_indices = []
    
    for i in range(num_samples):
        label = results[0]["label"][i]
        predictions = [result["prediction"][i] for result in results]
        correct_preds = [pred == label for pred in predictions]
        num_correct = sum(correct_preds)

        if num_correct == num_models:
            all_correct_indices.append(i)
        
        elif num_correct == 0:
            all_incorrect_indices.append(i)
    
    all_correct_data = {key: [] for key in test_data.keys()}
    for idx in all_correct_indices:
        label_in_hyena = results[0]["label"][idx]
        label_in_test = test_data["labels"][idx]
        if label_in_hyena == label_in_test:
            for key in test_data.keys():
                all_correct_data[key].append(test_data[key][idx])

    # Save the samples that all models predict correctly
    correct_save_path = os.path.join(output_dir, "3model_all_correct.pt")
    torch.save(all_correct_data, correct_save_path)
    print(f"Saved all correct data to {correct_save_path}")
    
    # Process samples where all models predict errors
    all_wrong_data = {key: [] for key in test_data.keys()}
    for idx in all_incorrect_indices:
        label_in_hyena = results[0]["label"][idx]
        label_in_test = test_data["labels"][idx]
        # pt and data/Hyena_602.pt are consistent
        if label_in_hyena == label_in_test:
            # Adds the data for each field of the corresponding subscript in test.pt to the corresponding list
            for key in test_data.keys():
                all_wrong_data[key].append(test_data[key][idx])

    # Save samples where all models predict errors
    wrong_save_path = os.path.join(output_dir, "3model_all_wrong.pt")
    torch.save(all_wrong_data, wrong_save_path)
    print(f"Saved all wrong data to {wrong_save_path}")


In [33]:
# 文件路径配置
result_files = [hy_c10_test_analysis_path, nt_c10_test_analysis_path, cd_c10_test_analysis_path]
test_file = nt_test_file_path
compare_predictions(result_files, test_file)

Saved all correct data to 3model_result/3model_all_correct.pt
Saved all wrong data to 3model_result/3model_all_wrong.pt


  results = [torch.load(f) for f in result_files]
  test_data = torch.load(test_file)


In [34]:
import torch

correct_file_path = '3model_result/3model_all_correct.pt'
correct_data = torch.load(correct_file_path)

print(f"len is: {len(correct_data['labels'])}")
print(f"table head is :{correct_data.keys()}")
first_row = {key: value[0] for key, value in correct_data.items()}
print("First row of data: ", first_row)

len is: 474
table head is :dict_keys(['gene_id', 'sequences', 'labels'])
First row of data:  {'gene_id': 'ENSG00000254901', 'sequences': 'ATAAATAAATAAATAAGTGGGCCAGGTGCGGTGGCTCACGCCTATAATCCCAGAACTTTGGGATGCCAAGGTGGGCTGAGTGCTTGAGTACAGGAATTCACGACCAGCCTGGGCAACATGACAAGACCCCATATTTATAATTTTTTTTTTTAATTAGCTGGTCACAGGCTGGCCACAGTGGCTCACGCCTGTAATCCCAGGACTTTGGGAGGCCAAGGCAGGTGGATCACCTGTGATCAGGAGTTTGAGACCAGCTTGGCCAACATGGTGAAACTCTGTCTCTACTAAAAATACAAACATTAGCTGGGAATGGTGGCACGCACCTGTAATTCCAGCTACTCAGGAGGCTAAGGCAGAAGAATCGCTTGAACCTGGGAGGTGGAGGTTGCAGTGAGCCGAGATTGTGCCACTGCACTCCAGCCTGGGCAACAGAGTGAGACTCTGTCTCAAAAAAAAAAAAAAAAAAAAAATTAGCTGGGTGTGCTGGTGTGAGCGCGTATTCCTAGCTCCTCAGGAGGCTGAGGCAGGAGGATCACTTGAGCCCAGGAGGCAGAGGTTGCAGTGAGCTGAGATCACACCACTTTACTCTAGCCTGGGCAACAGAGCAAGATGCTGTCTCAAAAACAAAGAAAGAAAGAAAGAAAGAAAGAAACCTCTTCCAGAAGGCCAAACACCCAACATGTCTACACTCACTGCACCCAAGTTGGGGTGAGCAAATGTTTTAAATTCCCCTTCTCTTCTTAATTTGCATTTTCCAGATGTCCACCTGGTTGGGTCATAGTTTAACCAAATAAATCATTCGTTGGGATGGGAAAGCCAAGAGTGGGTTCAGCTTGCTCCGCTCACAGGAGCTGCCACAAAG

  correct_data = torch.load(correct_file_path)


In [35]:
wrong_file_path = '3model_result/3model_all_wrong.pt'
wrong_data = torch.load(wrong_file_path)

print(f"len is: {len(wrong_data['labels'])}")
print(f"table head is :{wrong_data.keys()}")
first_row = {key: value[0] for key, value in wrong_data.items()}
print("First row of data: ", first_row)

len is: 93
table head is :dict_keys(['gene_id', 'sequences', 'labels'])
First row of data:  {'gene_id': 'ENSG00000138380', 'sequences': 'TCCTCAACAAAATACTTGCAAACCAAATTCAACAACACATTAAGAAGATCATCATGACCAAGTAAGATTTATCCCAGGGATGCAAGAATAGTTCAACATACACAAATCAATGTAATACACTGTGATGGTTAACACTGAGTGTCAACTTGATTGGATTGAAGGATGCAAAGTATTGTTCCTGGGTGTGTCTGTGAAGGTGATGCCAAACGAGATGAACATTTGAGACAGTGGAATGGAAGAGGCAGACCTACCCTCTATCTGGGTAGGTACCATCTAATCAGCTGCCAGCATGGCTAGGATAAAAGGAGGCAGAGGAATGTGGAAGGACTAGACTGGATCAGTCTTCAGGCCTTTATCTTACTCCTGTGCTTCCTGCCCTCAAACATTGAACTCCAAGTTCTACAGCTTTTGGATTCTTGGACCAGTGGTTTGCCAGGGATTCTCAGGTGTTTGGCCATAGACTGAAGGCTGTACTGCCGGCTTCCCTACTTTTGAGGTTTTGGGACTCGGACTGGCCTCCTTGCTCCTCATCTTGCAGACAGCCTATCGTAGGACTTCATCTTGTGATCGTGTGAGTCAATACTCCTTAATAAACTCCCTCTCGGCTGTGCGCAGTGGCTCACACCTGTAATCCCAGTACTTTGGGAGGCCAAGAGTTCAAGACCAGCCTGACCAGCATGGTGAAACCCCGTCTCTACTAAAACTACAAAAATTAGCCAGCCGGGTGTGGTGGCACGCGCCTGTAATCCCAGCTACTCGGGAGGCTGAAGCAGGAGAATTGCTTGAACCCGGGAGGTGGAGGTTGCAGTGAGCCGAGATCGTGCCACTGCACTCCAGCCTGGGCGACAGAGCAAGACTCCGTC

  wrong_data = torch.load(wrong_file_path)


## Save the questions that some models do right and some do wrong
#### Of the three models, the sequence in which only two models get it right is compared with the sequence in which only one model gets it right

In [None]:
import torch
import os

def compare_predictions(result_files, test_file, output_dir="3model_result"):
    results = [torch.load(f) for f in result_files]
    test_data = torch.load(test_file)
    
    os.makedirs(output_dir, exist_ok=True)
    
    num_samples = len(results[0]["predictions"])  
    num_models = len(results)  
    
    all_correct_indices = []
    all_incorrect_indices = []
    one_correct_indices = []
    two_correct_indices = []
    
    for i in range(num_samples):
        label = results[0]["labels"][i]
        predictions = [result["predictions"][i] for result in results]
        correct_preds = [pred == label for pred in predictions]
        num_correct = sum(correct_preds)

        # All the models predicted correctly
        if num_correct == num_models:
            all_correct_indices.append(i)
        
        # All the models predicted wrong
        elif num_correct == 0:
            all_incorrect_indices.append(i)

        # Only one model predicted correctly
        elif num_correct == 1:
            one_correct_indices.append(i)
        
        # Only two models predicted correctly
        elif num_correct == 2:
            two_correct_indices.append(i)

    # Process samples for which only one model predicts correctly
    one_correct_data = {key: [] for key in test_data.keys()}
    for idx in one_correct_indices:
        for key in test_data.keys():
            one_correct_data[key].append(test_data[key][idx])

    one_correct_save_path = os.path.join(output_dir, "3model_have_1_correct.pt")
    torch.save(one_correct_data, one_correct_save_path)
    print(f"Saved one correct data to {one_correct_save_path}")
    
    # Process samples where only two models predict correctly
    two_correct_data = {key: [] for key in test_data.keys()}
    for idx in two_correct_indices:
        for key in test_data.keys():
            two_correct_data[key].append(test_data[key][idx])

    two_correct_save_path = os.path.join(output_dir, "3model_have_2_correct.pt")
    torch.save(two_correct_data, two_correct_save_path)
    print(f"Saved two correct data to {two_correct_save_path}")

result_files = ["data/Hyena_602.pt", "data/NTv2_240.pt", "data/CDgpt_2_8754.pt"]
test_file = "../test.pt"
compare_predictions(result_files, test_file)


Saved one correct data to 3model_result/3model_have_1_correct.pt
Saved two correct data to 3model_result/3model_have_2_correct.pt


  results = [torch.load(f) for f in result_files]
  test_data = torch.load(test_file)
