# Create comprehensive comparison charts for all baseline models

In [None]:
import os, json, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd
import logging
import sys
import gc
import builtins
from sklearn.model_selection import GroupKFold, StratifiedKFold, train_test_split
from sklearn.model_selection import cross_validate
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
from sklearn.impute import SimpleImputer
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.metrics import make_scorer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import glob
from joblib import dump
from tqdm import tqdm
import warnings

# Single seed reused wherever randomness is involved; NumPy's global RNG is
# seeded with it right away.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

# Suppress the warning categories this notebook does not want displayed:
# generic UserWarning, pandas' DtypeWarning, and FutureWarning.
for _ignored_category in (UserWarning, pd.errors.DtypeWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Optional pandas display settings, left disabled; uncomment to show wide or
# long frames in full.
# pd.set_option('display.max_columns', None)
# pd.set_option('display.max_rows', None)
# pd.set_option('display.width', 1000)
# pd.set_option('display.max_colwidth', None)

# Metric columns that every comparison chart reads from ``summary_df``.
_COMPARISON_METRICS = ('Accuracy', 'F1_macro', 'Precision_macro', 'Recall_macro')


def create_comparison_visualizations(summary_df, outdir):
    """
    Create comprehensive comparison charts for all baseline models.

    Five PNG files are written into ``outdir``: per-metric bar charts, a radar
    chart, a heatmap, an average-rank chart and a box/swarm distribution plot.

    Parameters
    ----------
    summary_df : pandas.DataFrame
        One row per model. Must contain a ``'model'`` column plus the metric
        columns ``'Accuracy'``, ``'F1_macro'``, ``'Precision_macro'`` and
        ``'Recall_macro'``.
    outdir : str or pathlib.Path
        Destination directory; it is created (with parents) if missing.

    Raises
    ------
    KeyError
        If any required column is absent. Raised before any file is written
        so a bad frame never leaves a partial set of charts behind.
    """
    # Accept plain strings as well as Path objects; the "/" joins below need a Path.
    outdir = Path(outdir)
    metrics = list(_COMPARISON_METRICS)

    missing = [col for col in ['model', *metrics] if col not in summary_df.columns]
    if missing:
        raise KeyError(f"summary_df is missing required column(s): {missing}")

    outdir.mkdir(parents=True, exist_ok=True)

    print("\nGenerating comparison visualizations...")

    # (file name, plotting helper) pairs, in the order they are produced and listed.
    charts = (
        ("metrics_comparison_bars.png", _plot_metric_bars),
        ("metrics_radar_chart.png", _plot_metric_radar),
        ("metrics_heatmap.png", _plot_metric_heatmap),
        ("model_ranking.png", _plot_model_ranking),
        ("metrics_distribution.png", _plot_metric_distribution),
    )
    for filename, plot_chart in charts:
        plot_chart(summary_df, metrics, outdir / filename)

    # "\u2713" is the check mark; written as an escape so a wrong file encoding
    # cannot garble it.
    print(f"\u2713 Generated {len(charts)} comparison visualizations in {outdir}")
    for filename, _ in charts:
        print(f"  - {filename}")


def _save_figure(fig, path):
    """Tighten the layout of ``fig`` and write it to ``path`` at 300 dpi."""
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')


def _plot_metric_bars(summary_df, metrics, path):
    """Draw one bar chart per metric on a 2x2 grid, comparing all models."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    try:
        # axes.flat walks the grid row by row, matching the order of ``metrics``.
        for ax, metric in zip(axes.flat, metrics):
            summary_df.plot(x='model', y=metric, kind='bar', ax=ax,
                            legend=False, color='steelblue')
            ax.set_title(f'{metric} Comparison', fontsize=14, fontweight='bold')
            ax.set_xlabel('Model', fontsize=12)
            ax.set_ylabel(metric, fontsize=12)
            ax.set_xticklabels(summary_df['model'], rotation=45, ha='right')
            ax.grid(axis='y', alpha=0.3)

            # Add value labels on bars
            for container in ax.containers:
                ax.bar_label(container, fmt='%.3f', padding=3)

        _save_figure(fig, path)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def _plot_metric_radar(summary_df, metrics, path):
    """Draw a radar (spider) chart with one closed polygon per model."""
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))
    try:
        # One evenly spaced spoke per metric; the first angle is repeated at
        # the end so each polygon closes.
        angles = [n / len(metrics) * 2 * np.pi for n in range(len(metrics))]
        angles += angles[:1]

        # Put the first metric at the top and lay the rest out clockwise.
        ax.set_theta_offset(np.pi / 2)
        ax.set_theta_direction(-1)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(metrics)

        for _, row in summary_df.iterrows():
            values = row[metrics].tolist()
            values += values[:1]
            ax.plot(angles, values, 'o-', linewidth=2, label=row['model'])
            ax.fill(angles, values, alpha=0.15)

        # Radial axis fixed to 0..1 -- presumably all metrics are fractions in
        # that range; confirm against the caller.
        ax.set_ylim(0, 1)
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
        ax.set_title('Multi-Metric Performance Comparison', size=16,
                     fontweight='bold', pad=20)
        ax.grid(True)

        _save_figure(fig, path)
    finally:
        plt.close(fig)


def _plot_metric_heatmap(summary_df, metrics, path):
    """Draw an annotated heatmap with models as rows and metrics as columns."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        metrics_matrix = summary_df.set_index('model')[metrics]
        sns.heatmap(metrics_matrix, annot=True, fmt='.3f', cmap='YlGnBu',
                    cbar_kws={'label': 'Score'}, linewidths=0.5, ax=ax)
        ax.set_title('Performance Metrics Heatmap', fontsize=16,
                     fontweight='bold', pad=15)
        ax.set_xlabel('Metrics', fontsize=12)
        ax.set_ylabel('Models', fontsize=12)

        _save_figure(fig, path)
    finally:
        plt.close(fig)


def _plot_model_ranking(summary_df, metrics, path):
    """Draw a horizontal bar chart of each model's average rank over all metrics."""
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Rank 1 is the highest score within each metric; average the ranks per model.
        ranks = summary_df[metrics].rank(ascending=False)
        ranks['model'] = summary_df['model']
        ranks['avg_rank'] = ranks[metrics].mean(axis=1)
        ranks = ranks.sort_values('avg_rank')

        y_pos = np.arange(len(ranks))
        ax.barh(y_pos, ranks['avg_rank'], color='coral')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(ranks['model'])
        # Best (lowest) average rank goes at the top.
        ax.invert_yaxis()
        ax.set_xlabel('Average Rank (lower is better)', fontsize=12)
        ax.set_title('Model Ranking Based on Average Performance',
                     fontsize=16, fontweight='bold')
        ax.grid(axis='x', alpha=0.3)

        # Add value labels just right of each bar end
        for y, avg_rank in enumerate(ranks['avg_rank']):
            ax.text(avg_rank + 0.05, y, f'{avg_rank:.2f}', va='center')

        _save_figure(fig, path)
    finally:
        plt.close(fig)


def _plot_metric_distribution(summary_df, metrics, path):
    """Draw a box plot per metric with each model's score overlaid as a point."""
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Long format: one row per (model, metric) pair.
        melted = summary_df.melt(id_vars='model', value_vars=metrics,
                                 var_name='Metric', value_name='Score')
        # NOTE(review): newer seaborn releases appear to deprecate `palette`
        # without `hue`; call kept unchanged for compatibility with older
        # versions -- confirm against the pinned seaborn version.
        sns.boxplot(data=melted, x='Metric', y='Score', ax=ax, palette='Set2')
        sns.swarmplot(data=melted, x='Metric', y='Score', color='black',
                      alpha=0.5, ax=ax)
        ax.set_title('Distribution of Metrics Across All Models',
                     fontsize=16, fontweight='bold')
        ax.set_ylabel('Score', fontsize=12)
        ax.grid(axis='y', alpha=0.3)

        _save_figure(fig, path)
    finally:
        plt.close(fig)