# In [13]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Use seaborn's "husl" palette as the colour cycle for the plots below
sns.set_palette(palette="husl")

# Helper for cleaning column names when the datasets are read
def clean_column_names(df):
    """
    Strip leading and trailing whitespace from the column labels of ``df``.

    The frame is modified in place and is also returned, so the call can be
    wrapped directly around ``pd.read_csv(...)``. Labels that are not strings
    (for example integer column labels) are left unchanged.

    Args:
        df: DataFrame whose column labels should be cleaned.

    Returns:
        The same DataFrame object, with cleaned column labels.
    """
    # ``df.columns.str.strip()`` replaces non-string labels with NaN in a
    # mixed header and raises AttributeError when no label is a string, so
    # strip label by label instead. ``Index.map`` keeps the index name.
    df.columns = df.columns.map(
        lambda label: label.strip() if isinstance(label, str) else label
    )
    return df

# Load the three metrics CSVs (DEKOIS, DUD-E, LitPCBA, in that order) and
# strip stray whitespace from each header row
dekois_df, dude_df, litpcba_df = (
    clean_column_names(pd.read_csv(csv_path))
    for csv_path in (
        'dekois_complete_metrics2.csv',
        'dude_complete_metrics.csv',
        'litpcba_complete_metrics.csv',
    )
)

def create_performance_plot(datasets, metric, dataset_names, figsize=(20, 15)):
    """
    Create stacked bar-chart subplots comparing model performance across
    percentile thresholds, one subplot per dataset.

    Within each subplot every model is a group of bars (one bar per
    ``'Percentile'`` value, averaged with ``pivot_table``) and models are
    ordered by their mean ``metric`` across percentiles, highest first.

    Side effect: the ``'Model'`` column of every frame in ``datasets`` is
    stripped of surrounding whitespace in place.

    Args:
        datasets: Iterable of DataFrames, each with ``'Model'``,
            ``'Percentile'`` and ``metric`` columns.
        metric: Name of the column to plot; also used in the figure title
            and the y-axis label.
        dataset_names: Iterable of subplot titles, one per dataset and in
            the same order.
        figsize: Size of the whole figure, passed to ``plt.subplots``.

    Returns:
        The matplotlib Figure containing one row of axes per dataset.

    Raises:
        ValueError: If ``datasets`` is empty or its length differs from
            that of ``dataset_names``.
    """
    datasets = list(datasets)
    dataset_names = list(dataset_names)
    if not datasets:
        raise ValueError("datasets must contain at least one DataFrame")
    if len(datasets) != len(dataset_names):
        raise ValueError(
            f"got {len(datasets)} datasets but {len(dataset_names)} dataset names"
        )

    # One row per dataset; squeeze=False keeps `axes` a 2-D array even when
    # there is only a single dataset, so indexing below is uniform.
    fig, axes = plt.subplots(len(datasets), 1, figsize=figsize, squeeze=False)
    fig.suptitle(f'Model {metric} Performance Across Percentile Thresholds', fontsize=16, y=0.95)

    for ax, df, name in zip(axes[:, 0], datasets, dataset_names):
        # Clean model names (in place: later code grouping on 'Model' sees
        # the cleaned values too)
        df['Model'] = df['Model'].str.strip()

        # Rows are models, columns are percentile thresholds
        pivot_df = df.pivot_table(
            values=metric,
            index='Model',
            columns='Percentile',
            aggfunc='mean'
        )

        # Order models by their average across thresholds, best first,
        # without adding a helper column to the plotted table
        model_order = pivot_df.mean(axis=1).sort_values(ascending=False).index
        pivot_df.loc[model_order].plot(kind='bar', ax=ax, width=0.8)

        # Set titles and labels
        ax.set_title(f'{name} Dataset', pad=20)
        ax.set_ylabel(f'{metric} Value')

        # Rotate x-axis labels
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, horizontalalignment='right')

        # Anchor the legend's upper-left corner just outside the right edge
        # of the axes; without an explicit `loc` the 'best' heuristic picks
        # which corner is pinned to the anchor point.
        ax.legend(
            title='Percentile Threshold (%)',
            bbox_to_anchor=(1.05, 1),
            loc='upper left',
        )

        # Adjust grid
        ax.grid(True, linestyle='--', alpha=0.7)

        # Adjust tick parameters
        ax.tick_params(axis='x', labelsize=8)

    fig.tight_layout()
    return fig

# Frames and their display labels, kept in matching order
datasets = [dekois_df, dude_df, litpcba_df]
dataset_names = ['DEKOIS', 'DUD-E', 'LitPCBA']

# EF comparison: build the figure, then write it out as a 300-dpi PNG
ef_fig = create_performance_plot(datasets, metric='EF', dataset_names=dataset_names)
ef_fig.savefig('enrichment_factor_comparison.png', dpi=300, bbox_inches='tight')

# REF comparison: same procedure for the REF column
ref_fig = create_performance_plot(datasets, metric='REF', dataset_names=dataset_names)
ref_fig.savefig('relative_enrichment_factor_comparison.png', dpi=300, bbox_inches='tight')

# Display summary statistics
def print_dataset_summary(df, dataset_name, top_n=5):
    """
    Print the best models of a dataset ranked by average EF and average REF.

    Model names are compared with surrounding whitespace removed, so the
    ranking does not depend on the ``'Model'`` column having been cleaned
    beforehand. ``df`` itself is not modified.

    Args:
        df: DataFrame with ``'Model'``, ``'EF'`` and ``'REF'`` columns.
        dataset_name: Label printed in the summary header.
        top_n: Number of models listed per metric (default 5).
    """
    print(f"\n{dataset_name} Dataset Summary:")
    print("-" * 50)

    # Whitespace-normalised grouping key; non-string entries pass through
    # unchanged. The mapped Series keeps the name 'Model', so the printed
    # index header is the same as when grouping by the column name.
    model_names = df['Model'].map(
        lambda model: model.strip() if isinstance(model, str) else model
    )

    # Compute every ranking before printing so a missing metric column
    # fails before any ranking has been printed
    rankings = {
        metric: df.groupby(model_names)[metric].mean().sort_values(ascending=False)
        for metric in ('EF', 'REF')
    }

    for metric, ranking in rankings.items():
        print(f"\nTop {top_n} Models by Average {metric}:")
        print(ranking.head(top_n).round(3))

# Print the EF/REF summary for every dataset, pairing each frame with its
# label in list order
for df, name in zip(datasets, dataset_names):
    print_dataset_summary(df, name)