In [5]:
import pandas as pd
import os
import numpy as np
from collections import defaultdict
from glob import glob
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import wilcoxon
import warnings

# Suppress warnings
warnings.filterwarnings('ignore')

# find all the files with "predictions" in the path from the parent path "results"
pred_files = glob("results/**/**/*predictions*.csv", recursive=True)
# Check for potential duplicates by examining file paths
unique_files = set(pred_files)
if len(unique_files) < len(pred_files):
    print(f"Warning: Found {len(pred_files) - len(unique_files)} duplicate file paths")
    # Count occurrences of each file path
    path_counts = {}
    for path in pred_files:
        path_counts[path] = path_counts.get(path, 0) + 1
    # Show duplicated paths
    duplicates = {path: count for path, count in path_counts.items() if count > 1}
    print(f"Duplicated paths: {duplicates}")
    # Use only unique paths
    pred_files = list(unique_files)
print(f"Found {len(pred_files)} prediction files")

# Load predictions and organize by model and data type
model_predictions = {}
data_types = set()
models = set()

for file_path in pred_files:
    # Extract model and data type from path
    path_parts = file_path.split(os.sep)
    model_name = path_parts[1]  # e.g., "BERT-base"
    
    # Handle special case for trained models with complex structure
    if 'trained_models' in file_path:
        filename = os.path.basename(file_path)
        
        if 'se_domainHF' in file_path:
            # Extract model type from filename (e.g., SE-COPD, SE-Autoimmune)
            # Extract the disease from the model name
            model_disease = filename.split('_')[2].replace('HF', '')
            # Extract the target dataset from the filename
            target_dataset = filename.split('_')[-2]
            model_name = f"SE-{model_disease.lower()}"
            data_type = target_dataset
        elif 'moe_tokensHF' in file_path:
            # Handle MOE models
            data_type = filename.split('_')[-2]
            model_name = "MOE-ALL"
        elif 'se_all_tokensHF' in file_path:
            # Handle SE-ALL models
            data_type = filename.split('_')[-2]
            model_name = "SE-ALL"
        else:
            # For other trained models
            data_type = filename.split('_')[-2]
    else:
        # For standard models, extract data type from filename
        filename = os.path.basename(file_path)
        data_type = filename.split('_')[1] if '_' in filename else 'unknown'
    
    # Load predictions
    df = pd.read_csv(file_path)
    if 'prediction' in df.columns and 'label' in df.columns:
        key = (model_name, data_type)
        model_predictions[key] = {
            'labels': df['label'].values,
            'predictions': df['prediction'].values
        }
        models.add(model_name)
        data_types.add(data_type)

# Convert sets to sorted lists for consistent ordering
models = sorted(list(models))
data_types = sorted(list(data_types))

print(f"Found {len(models)} models and {len(data_types)} data types")

# Create pairwise statistical comparison for each data type
for data_type in data_types:
    # Filter models that have predictions for this data type
    available_models = [model for model in models if (model, data_type) in model_predictions]
    
    if len(available_models) < 2:
        print(f"Skipping {data_type} - not enough models for comparison")
        continue
    
    n_models = len(available_models)
    p_values = np.ones((n_models, n_models))
    
    # Compute pairwise p-values using Wilcoxon test
    for i, model1 in enumerate(available_models):
        for j, model2 in enumerate(available_models):
            if i == j:
                p_values[i, j] = 1.0  # Same model, p-value = 1
                continue
                
            # Get predictions for both models
            y_true = model_predictions[(model1, data_type)]['labels']
            y1_score = model_predictions[(model1, data_type)]['predictions']
            
            # Check if the second model has predictions for this data type
            if (model2, data_type) not in model_predictions:
                continue
            
            y2_score = model_predictions[(model2, data_type)]['predictions']
            
            # Ensure all arrays have the same length
            if len(y_true) != len(y1_score) or len(y_true) != len(y2_score):
                print(f"Warning: Mismatched lengths for {model1} vs {model2} on {data_type}")
                continue
            
            try:
                # Perform Wilcoxon signed-rank test on the predictions
                stat, p_value = wilcoxon(y1_score, y2_score)
                p_values[i, j] = p_value
            except Exception as e:
                print(f"Error computing p-value for {model1} vs {model2} on {data_type}: {e}")
    
    # Create heatmap
    plt.figure(figsize=(12, 10))
    mask = np.triu(np.ones_like(p_values, dtype=bool))  # Mask for upper triangle
    
    # Create a custom colormap for p-values
    cmap = sns.diverging_palette(220, 10, as_cmap=True)
    
    # Plot the heatmap
    ax = sns.heatmap(
        p_values, 
        mask=mask,
        cmap=cmap,
        vmax=0.05,
        vmin=0,
        center=0.025,
        square=True,
        linewidths=.5,
        cbar_kws={"shrink": .5, "label": "p-value"},
        annot=True,
        fmt=".2f",
        xticklabels=available_models,
        yticklabels=available_models
    )
    
    # Highlight significant p-values (p < 0.05)
    for i in range(n_models):
        for j in range(n_models):
            if i > j and p_values[i, j] < 0.05:
                ax.add_patch(plt.Rectangle((j, i), 1, 1, fill=False, edgecolor='black', lw=2))
    
    plt.title(f'Pairwise Wilcoxon Test p-values for {data_type}')
    plt.tight_layout()
    plt.savefig(f"results/wilcoxon_heatmap_{data_type}.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"Created heatmap for {data_type}")

# Calculate ROC-AUC for each model and data type
results = []
for (model_name, data_type), data in model_predictions.items():
    try:
        auc = roc_auc_score(data['labels'], data['predictions'])
        results.append({
            'model': model_name,
            'data_type': data_type,
            'roc_auc': round(auc, 4)
        })
    except Exception as e:
        print(f"Error calculating ROC-AUC for {model_name} on {data_type}: {e}")

# Create and save results dataframe
results_df = pd.DataFrame(results)
# Sort by data_type and then by roc_auc
results_df = results_df.sort_values(by=['data_type', 'roc_auc'], ascending=[False, False])
results_df.to_csv("results/model_performance_roc_auc.csv", index=False)
print(f"Saved results for {len(results)} model-dataset combinations")

# Add ROC-AUC values to the base table if it exists
try:
    base_table = pd.read_excel("results/base_table_moe.xlsx")
    
    for i, row in results_df.iterrows():
        for j, row_base in base_table.iterrows():
            if row['model'].lower() == row_base['Model'].lower() and row['data_type'].lower() == row_base['Dataset'].lower():
                base_table.at[j, 'roc_auc'] = row['roc_auc']
    
    base_table.to_excel("results/base_table_final.xlsx", index=False)
    print("Updated base table with ROC-AUC values")
except Exception as e:
    print(f"Could not update base table: {e}")

# Display model counts
model_counts = {model: sum(1 for key in model_predictions if key[0] == model) for model in models}
print("Model counts:", model_counts)

Duplicated paths: {'results\\BERT-base\\BERT-base_all_False_predictions.csv': 2, 'results\\BERT-base\\BERT-base_AUTOIMMUNE_False_predictions.csv': 2, 'results\\BERT-base\\BERT-base_CANCER_False_predictions.csv': 2, 'results\\BERT-base\\BERT-base_COPD_False_predictions.csv': 2, 'results\\BERT-base\\BERT-base_CVD_False_predictions.csv': 2, 'results\\BERT-base\\BERT-base_PARASITIC_False_predictions.csv': 2, 'results\\BERT-large\\BERT-large_all_False_predictions.csv': 2, 'results\\BERT-large\\BERT-large_AUTOIMMUNE_False_predictions.csv': 2, 'results\\BERT-large\\BERT-large_CANCER_False_predictions.csv': 2, 'results\\BERT-large\\BERT-large_COPD_False_predictions.csv': 2, 'results\\BERT-large\\BERT-large_CVD_False_predictions.csv': 2, 'results\\BERT-large\\BERT-large_PARASITIC_False_predictions.csv': 2, 'results\\BioBERT\\BioBERT_all_False_predictions.csv': 2, 'results\\BioBERT\\BioBERT_AUTOIMMUNE_False_predictions.csv': 2, 'results\\BioBERT\\BioBERT_CANCER_False_predictions.csv': 2, 'result