In [None]:
# FORCE MODULE RELOAD - FIX JUPYTER CACHING ISSUE (MUST BE FIRST!)
import sys


def purge_cached_modules(exact_names=('cleanedDataParser',), package_prefixes=('modules',)):
    """
    Evict project modules from ``sys.modules`` so the next import re-executes them.

    Args:
        exact_names: Module names removed when they match exactly.
        package_prefixes: Package names whose submodules (``"<prefix>.<anything>"``)
            are removed. Matching is anchored at the start of the module name, so a
            third-party module that merely contains ``"modules."`` somewhere in its
            dotted path (e.g. ``torch.nn.modules.linear``) is left alone.

    Returns:
        list[str]: Names of the modules that were removed, in removal order.
    """
    removed = []
    # Snapshot the keys first: sys.modules must not change size while it is iterated.
    for name in list(sys.modules):
        in_project_package = any(name.startswith(f"{prefix}.") for prefix in package_prefixes)
        if name in exact_names or in_project_package:
            print(f"Removing cached {name} module...")
            del sys.modules[name]
            removed.append(name)
    return removed


purge_cached_modules()

print("All cached modules removed! Fresh imports will now load the latest fixes.")

In [None]:
# ALL-IN-ONE FULLY OPTIMIZED COMPLETE PIPELINE

# ===== IMPORTS =====
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import importlib
from modularized_data_parser import *
from modules.two_way_players import get_cleaned_two_way_data
from modules.modeling import (
    ModelResults, create_keras_model, print_metrics,
    run_basic_regressions, run_advanced_models, 
    run_nonlinear_models, run_neural_network,
    select_best_models_by_category, apply_proper_war_adjustments
)
from modules.park_factors import (
    calculate_park_factors, 
    apply_enhanced_hitter_park_adjustments, 
    apply_enhanced_pitcher_park_adjustments
)
from modules.name_mapping_caching import create_name_mapping

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from scipy import stats

def plot_consolidated_model_comparison(model_results, model_names=None, show_residuals=True, show_metrics=True):
    """
    Build consolidated comparison figures and summary statistics for several models.

    For every requested model the function looks up the entries keyed
    ``"{model}_{hitter|pitcher}_{war|warp}"`` in ``model_results.results``, computes
    RMSE / MAE / R² from ``y_true`` and ``y_pred``, optionally draws Plotly figures,
    and prints a text summary.

    Args:
        model_results: Object exposing a ``results`` mapping. Each entry must provide
            ``y_true`` and ``y_pred`` and may provide ``player_names``.
        model_names: Models to compare. ``None`` delegates the choice to
            ``select_best_models_by_category``.
        show_residuals: Draw the 2x2 residual-diagnostics figure.
        show_metrics: Draw the 2x2 actual-vs-predicted scatter figure.

    Returns:
        dict: ``{model_name: {"{player_type}_{metric}": {"rmse", "mae", "r2", "count"}}}``,
        or ``{}`` when no result entry contained data.
    """
    if model_names is None:
        model_names = select_best_models_by_category(model_results)
        print(f"🎯 Auto-selected models for comparison: {[m.upper() for m in model_names]}")

    print("\n📊 CONSOLIDATED MODEL COMPARISON SYSTEM")
    print("="*70)
    print("🔍 Replacing individual graphs with unified selectable trace visualizations...")

    # Collect one tidy frame per (model, player type, metric) and concatenate once
    frames = []
    model_stats = {}

    for model_name in model_names:
        model_stats[model_name] = {}

        for player_type in ['hitter', 'pitcher']:
            for metric in ['war', 'warp']:
                key = f"{model_name}_{player_type}_{metric}"
                if key not in model_results.results:
                    continue
                data = model_results.results[key]

                # Flatten so column-vector predictions of shape (n, 1) line up with
                # 1-D targets instead of broadcasting into an (n, n) residual matrix.
                y_true = np.asarray(data['y_true'], dtype=float).ravel()
                y_pred = np.asarray(data['y_pred'], dtype=float).ravel()
                if y_true.size == 0:
                    continue
                if y_true.size != y_pred.size:
                    print(f"⚠️ Skipping {key}: y_true has {y_true.size} values but y_pred has {y_pred.size}")
                    continue

                residuals = y_true - y_pred

                # Calculate comprehensive statistics
                rmse = np.sqrt(np.mean(residuals**2))
                mae = np.mean(np.abs(residuals))
                ss_tot = np.sum((y_true - np.mean(y_true))**2)
                # R² is undefined when every target is identical (zero total variance)
                r2 = 1 - (np.sum(residuals**2) / ss_tot) if ss_tot > 0 else float('nan')

                model_stats[model_name][f"{player_type}_{metric}"] = {
                    'rmse': rmse,
                    'mae': mae,
                    'r2': r2,
                    'count': int(y_true.size)
                }

                # Fall back to placeholder labels when names are missing or too short
                names = data['player_names'] if 'player_names' in data else None
                if names is not None and len(names) >= y_true.size:
                    names = list(names)[:y_true.size]
                else:
                    names = [f"Player_{i}" for i in range(y_true.size)]

                frames.append(pd.DataFrame({
                    'Model': model_name.title(),
                    'PlayerType': player_type.title(),
                    'Metric': metric.upper(),
                    'Category': f"{player_type.title()} {metric.upper()}",
                    'Actual': y_true,
                    'Predicted': y_pred,
                    'Residual': residuals,
                    'Player': names
                }))

    if not frames:
        print("❌ No data available for consolidated comparison")
        return {}

    df = pd.concat(frames, ignore_index=True)

    # The palette is shared by BOTH figures, so it must not live inside the
    # show_metrics branch (doing so raised NameError when show_metrics=False).
    if show_metrics or show_residuals:
        colors = px.colors.qualitative.Set1

    # 1. CONSOLIDATED SCATTER PLOTS WITH SELECTABLE TRACES
    if show_metrics:
        print("\nCreating consolidated prediction accuracy plots...")

        categories = ['Hitter WAR', 'Hitter WARP', 'Pitcher WAR', 'Pitcher WARP']
        fig_scatter = make_subplots(
            rows=2, cols=2,
            subplot_titles=categories,
            vertical_spacing=0.1,
            horizontal_spacing=0.1
        )

        for i, category in enumerate(categories):
            cat_data = df[df['Category'] == category]
            if cat_data.empty:
                continue

            row = (i // 2) + 1
            col = (i % 2) + 1

            # Size the diagonal from every model in the panel, not just the first one
            min_val = min(cat_data['Actual'].min(), cat_data['Predicted'].min())
            max_val = max(cat_data['Actual'].max(), cat_data['Predicted'].max())

            for j, model in enumerate(cat_data['Model'].unique()):
                model_data = cat_data[cat_data['Model'] == model]

                fig_scatter.add_trace(
                    go.Scatter(
                        x=model_data['Actual'],
                        y=model_data['Predicted'],
                        mode='markers',
                        name=f"{model} {category}",
                        marker=dict(color=colors[j % len(colors)], size=6, opacity=0.7),
                        text=model_data['Player'],
                        hovertemplate="<b>%{text}</b><br>" +
                                      "Actual: %{x:.3f}<br>" +
                                      "Predicted: %{y:.3f}<br>" +
                                      f"Model: {model}<br>" +
                                      "<extra></extra>",
                        showlegend=(i == 0)  # Only show legend for first subplot
                    ),
                    row=row, col=col
                )

                # Add perfect prediction line once per subplot
                if j == 0:
                    fig_scatter.add_trace(
                        go.Scatter(
                            x=[min_val, max_val],
                            y=[min_val, max_val],
                            mode='lines',
                            line=dict(dash='dash', color='red', width=2),
                            name='Perfect Prediction',
                            showlegend=(i == 0)
                        ),
                        row=row, col=col
                    )

        fig_scatter.update_layout(
            title="Consolidated Model Comparison: Prediction Accuracy (Click Legend to Toggle)",
            height=800,
            width=1200,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
        )

        fig_scatter.show()

    # 2. CONSOLIDATED RESIDUAL ANALYSIS
    if show_residuals:
        print("\n🔍 Creating consolidated residual analysis...")

        fig_residuals = make_subplots(
            rows=2, cols=2,
            subplot_titles=['Residuals vs Fitted', 'Residual Distributions', 'Q-Q Plot', 'Model Performance'],
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}]]
        )

        plotted_models = df['Model'].unique()

        # Residuals vs Fitted
        for j, model in enumerate(plotted_models):
            model_data = df[df['Model'] == model]

            fig_residuals.add_trace(
                go.Scatter(
                    x=model_data['Predicted'],
                    y=model_data['Residual'],
                    mode='markers',
                    name=f"{model} Residuals",
                    marker=dict(color=colors[j % len(colors)], size=4, opacity=0.6),
                    showlegend=True
                ),
                row=1, col=1
            )

        # Add horizontal line at y=0
        fig_residuals.add_hline(y=0, line_dash="dash", line_color="gray", row=1, col=1)

        # Residual Distributions
        for j, model in enumerate(plotted_models):
            model_data = df[df['Model'] == model]

            fig_residuals.add_trace(
                go.Histogram(
                    x=model_data['Residual'],
                    name=f"{model} Distribution",
                    opacity=0.7,
                    nbinsx=30,
                    showlegend=False
                ),
                row=1, col=2
            )

        # Q-Q Plot (simplified - first model only, for clarity)
        if len(plotted_models) > 0:
            best_model = plotted_models[0]
            best_data = df[df['Model'] == best_model]
            residuals = best_data['Residual'].values

            sorted_residuals = np.sort(residuals)
            n = len(sorted_residuals)
            theoretical_quantiles = stats.norm.ppf(np.linspace(0.01, 0.99, n))

            fig_residuals.add_trace(
                go.Scatter(
                    x=theoretical_quantiles,
                    y=sorted_residuals,
                    mode='markers',
                    name=f"{best_model} Q-Q",
                    marker=dict(size=4),
                    showlegend=False
                ),
                row=2, col=1
            )

            # Reference line for a normal distribution with the sample mean / std
            slope = np.std(residuals)
            intercept = np.mean(residuals)
            line_min, line_max = min(theoretical_quantiles), max(theoretical_quantiles)

            fig_residuals.add_trace(
                go.Scatter(
                    x=[line_min, line_max],
                    y=[intercept + slope * line_min, intercept + slope * line_max],
                    mode='lines',
                    line=dict(color='red', dash='dash'),
                    name='Normal Line',
                    showlegend=False
                ),
                row=2, col=1
            )

        # Model Performance Comparison: per-model mean of each statistic
        models = list(model_stats.keys())
        stat_key_by_label = {'R²': 'r2', 'RMSE': 'rmse', 'MAE': 'mae'}

        for metric_name, stat_key in stat_key_by_label.items():
            metric_values = []
            for model in models:
                values = [entry[stat_key] for entry in model_stats[model].values()]
                metric_values.append(np.mean(values) if values else 0)

            fig_residuals.add_trace(
                go.Bar(
                    x=models,
                    y=metric_values,
                    name=metric_name,
                    showlegend=False
                ),
                row=2, col=2
            )

        fig_residuals.update_layout(
            title="Consolidated Residual Analysis (Click Legend to Toggle Models)",
            height=800,
            width=1200,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
        )

        fig_residuals.show()

    # 3. COMPREHENSIVE STATISTICAL SUMMARY
    print("\nCONSOLIDATED MODEL PERFORMANCE SUMMARY")
    print("="*70)

    for model_name in model_names:
        per_category = model_stats.get(model_name)
        if per_category is None:
            continue

        print(f"\n🤖 {model_name.upper()} MODEL:")

        if not per_category:
            # Avoid np.mean([]) (NaN plus RuntimeWarning) for models without any results
            print("   ⚠️ No results available for this model")
            continue

        total_predictions = sum(entry['count'] for entry in per_category.values())
        avg_r2 = np.mean([entry['r2'] for entry in per_category.values()])
        avg_rmse = np.mean([entry['rmse'] for entry in per_category.values()])
        avg_mae = np.mean([entry['mae'] for entry in per_category.values()])

        print(f"   📊 Overall Performance:")
        print(f"      • Total Predictions: {total_predictions}")
        print(f"      • Average R²: {avg_r2:.4f}")
        print(f"      • Average RMSE: {avg_rmse:.4f}")
        print(f"      • Average MAE: {avg_mae:.4f}")

        print(f"   📈 By Category:")
        for key, entry in per_category.items():
            category = key.replace('_', ' ').title()
            print(f"      • {category}: R²={entry['r2']:.4f}, RMSE={entry['rmse']:.4f}, Count={entry['count']}")

    return model_stats

def plot_comprehensive_residual_analysis(model_results, model_names=None, player_type="both", metric="both"):
    """
    Residual-diagnostics entry point; a thin wrapper around
    ``plot_consolidated_model_comparison``.

    The call is forwarded with the residual figure enabled and the
    actual-vs-predicted scatter figure disabled.

    Args:
        model_results: ModelResults object handed through unchanged.
        model_names: Models to compare, handed through unchanged
            (``None`` lets the consolidated function pick them).
        player_type: Accepted for backward compatibility; currently not used.
        metric: Accepted for backward compatibility; currently not used.

    Returns:
        Whatever ``plot_consolidated_model_comparison`` returns.
    """
    # Delegate to the consolidated system, residual view only
    return plot_consolidated_model_comparison(
        model_results,
        model_names,
        show_residuals=True,
        show_metrics=False,
    )

In [None]:
def plot_results(title, y_true, y_pred, player_names=None):
    """
    Scatter actual vs. predicted values with per-player hover tooltips.

    Args:
        title: Figure title.
        y_true: Actual target values (any array-like; flattened to 1-D).
        y_pred: Predicted values (any array-like; flattened to 1-D, so column
            vectors of shape (n, 1) are accepted).
        player_names: Optional hover labels, one per point. Defaults to
            ``Player_0``, ``Player_1``, ...

    Returns:
        None. The figure is displayed with ``fig.show()``; when either input is
        empty a message is printed and nothing is drawn.
    """
    # Flatten both inputs: an (n, 1) prediction array minus an (n,) target array
    # would otherwise broadcast to an (n, n) error matrix and corrupt customdata.
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()

    # min()/max() below are undefined for empty data
    if actual.size == 0 or predicted.size == 0:
        print(f"⚠️ No data to plot for '{title}'")
        return

    if player_names is None:
        player_names = [f"Player_{i}" for i in range(actual.size)]

    # Calculate errors for additional hover info
    errors = predicted - actual

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=actual,
        y=predicted,
        mode='markers',
        marker=dict(size=8, opacity=0.7),
        text=player_names,
        customdata=np.column_stack((errors, actual, predicted)),
        hovertemplate="<b>%{text}</b><br>" +
                      "Actual WAR: %{customdata[1]:.3f}<br>" +
                      "Predicted WAR: %{customdata[2]:.3f}<br>" +
                      "Error: %{customdata[0]:.3f}<br>" +
                      "<extra></extra>",
        name='Predictions'
    ))

    # Add perfect prediction line spanning the full data range
    min_val = float(min(actual.min(), predicted.min()))
    max_val = float(max(actual.max(), predicted.max()))
    fig.add_trace(go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line=dict(dash='dash', color='red'),
        name='Perfect Prediction'
    ))

    fig.update_layout(
        title=title,
        xaxis_title="Actual WAR",
        yaxis_title="Predicted WAR",
        template='plotly_white',
        width=600,
        height=600
    )

    fig.show()

def plot_training_history(history):
    """
    Draw training (and, when present, validation) loss per epoch.

    Args:
        history: Object with a ``history`` mapping (e.g. a Keras History) that
            may contain ``'loss'`` and ``'val_loss'`` sequences.

    Returns:
        None. Shows a Plotly figure, or prints a notice when ``history`` has no
        ``history`` attribute.
    """
    if not hasattr(history, 'history'):
        print("No training history available")
        return

    # Keras-style history object
    recorded = history.history
    train_curve = recorded.get('loss', [])
    val_curve = recorded.get('val_loss', [])

    fig = go.Figure()

    # Epochs are numbered from 1 and sized by the training curve
    epoch_axis = list(range(1, len(train_curve) + 1))

    curves = [(train_curve, 'Training Loss', 'blue')]
    if val_curve:
        curves.append((val_curve, 'Validation Loss', 'red'))

    for values, label, colour in curves:
        fig.add_trace(go.Scatter(
            x=epoch_axis,
            y=values,
            mode='lines',
            name=label,
            line=dict(color=colour)
        ))

    fig.update_layout(
        title='Training History',
        xaxis_title='Epoch',
        yaxis_title='Loss',
        template='plotly_white'
    )

    fig.show()

# Confirmation message printed when this cell has finished defining its plotting helpers.
print("✅ Utility functions loaded: plot_results, plot_training_history (print_metrics from module)")

In [None]:
# ===== MODEL RESULTS CONTAINER =====
# Create the notebook-level ModelResults instance. The class itself is imported
# from modules.modeling in the imports cell; this cell only instantiates it.
model_results = ModelResults()

print("✅ ModelResults class loaded from modules/modeling.py")

In [None]:
# ===== DEMONSTRATE NEW COMPREHENSIVE FANGRAPHS INTEGRATION =====
# Runs at cell-execution time. Any failure (including a failed import of the
# demo function) is reported with a traceback instead of aborting the cell,
# so the function definitions below are still created.
try:
    print("🚀 TESTING COMPREHENSIVE FANGRAPHS INTEGRATION SYSTEM")
    print("="*80)
    
    # Imported inside the try block so an ImportError is handled below as well
    from modularized_data_parser import demonstrate_comprehensive_system
    demonstrate_comprehensive_system()
    
except Exception as e:
    print(f"❌ Error demonstrating comprehensive system: {e}")
    # traceback is only needed on the failure path
    import traceback
    traceback.print_exc()

# Team abbreviations and seasons scanned by the defensive-value fallback search.
_DEFENSIVE_TEAM_ABBRS = (
    'BOS', 'NYY', 'TB', 'TOR', 'BAL', 'CLE', 'DET', 'KC', 'MIN', 'CWS',
    'HOU', 'LAA', 'OAK', 'SEA', 'TEX', 'ATL', 'MIA', 'NYM', 'PHI', 'WSN',
    'CHC', 'CIN', 'MIL', 'PIT', 'STL', 'ARI', 'COL', 'LAD', 'SD', 'SF',
)
_DEFENSIVE_FALLBACK_YEARS = (2016, 2017, 2018, 2019, 2020, 2021)


def _normalize_season_label(season):
    """Return the season formatted as an integer string (e.g. '2021'), or None if it is not numeric."""
    try:
        return str(int(float(season)))
    except (TypeError, ValueError, OverflowError):
        return None


def _lookup_defensive_value(enhanced_defensive_values, hitter_match, team, season):
    """Find a player's 'enhanced_def_value', preferring the row's own team/season key; 0 if none."""
    name_variants = (
        hitter_match,
        hitter_match.replace(' ', '').replace('.', ''),
        hitter_match.split()[0],  # First name only
    )

    # 1) Exact (team, season) of the row being processed. 2) Legacy scan over
    #    every known team / year, in the original year-major order.
    candidate_slots = []
    season_label = _normalize_season_label(season)
    if season_label is not None and isinstance(team, str) and team:
        candidate_slots.append((team, season_label))
    candidate_slots.extend(
        (team_abbr, year)
        for year in _DEFENSIVE_FALLBACK_YEARS
        for team_abbr in _DEFENSIVE_TEAM_ABBRS
    )

    for team_abbr, year in candidate_slots:
        for variant in name_variants:
            key = f"{variant}_{team_abbr}_{year}"
            if key in enhanced_defensive_values:
                return enhanced_defensive_values[key].get('enhanced_def_value', 0)
    return 0


def _warn_if_misaligned(label, features, names):
    """Print a warning when cleaning changed the row count, since names/seasons are not re-filtered."""
    if len(features) != len(names):
        print(f"⚠️ WARNING: {label} cleaning changed the row count "
              f"({len(names)} -> {len(features)}); player names / seasons are no longer aligned with the features")


def data_preparation():
    """
    Assemble hitter and pitcher feature matrices with WARP and WAR targets.

    Hitter rows carry 7 features: K, BB, park-adjusted AVG / OBP / SLG, an
    enhanced baserunning value and an enhanced defensive value. Pitcher rows
    carry 5 features: IP, BB, K, HR and park-adjusted ERA. Park factors are
    computed once and passed to every adjustment call.

    For players flagged as two-way, the hitter WAR target is
    ``Total WAR - Primary WAR``; all other hitters use ``Total WAR``. Pitcher
    WAR targets use ``Primary WAR``.

    Returns:
        tuple: ``(x_warp, y_warp, x_war, y_war, a_warp, b_warp, a_war, b_war,
        hitter_names_warp, hitter_names_war, pitcher_names_warp, pitcher_names_war,
        hitter_seasons_warp, hitter_seasons_war, pitcher_seasons_warp, pitcher_seasons_war)``
        where ``x_*`` / ``y_*`` are hitter features / targets and ``a_*`` / ``b_*``
        are pitcher features / targets.

    Note:
        NOTE(review): the name and season lists are NOT passed through
        ``validate_and_clean_data_enhanced``. If that helper drops rows, they go
        out of sync with the features; a warning is printed when the lengths differ.
        NOTE(review): the WAR mappings are looked up by player name, so every
        season of the same player appears to resolve to one WAR row. Confirm
        against ``create_optimized_name_mapping_with_indices``.
    """
    print("=== FULLY ENHANCED DATA PREPARATION WITH COMPREHENSIVE FANGRAPHS INTEGRATION ===")
    hitter_data = clean_sorted_hitter()
    hitter_pred_data = clean_yearly_warp_hitter()
    pitcher_data = clean_sorted_pitcher()
    pitcher_pred_data = clean_yearly_warp_pitcher()
    war_values = clean_war()
    
    # Load ENHANCED baserunning system with run expectancy
    print("Loading ENHANCED baserunning system with run expectancy...")
    enhanced_baserunning_values = calculate_enhanced_baserunning_values()
    print(f"Enhanced baserunning values: {len(enhanced_baserunning_values)} players")
    
    # Load enhanced defensive system with OAA integration and framing
    print("Loading enhanced defensive system...")
    enhanced_defensive_values = clean_enhanced_defensive_players()
    print(f"Enhanced defensive values: {len(enhanced_defensive_values)} player-seasons")
    
    # Load park factors ONCE instead of recalculating for each player
    print("Loading park factors ONCE (fixes repetitive loading)...")
    park_factors = calculate_park_factors()
    print(f"Loaded park factors for {len(park_factors)} stadiums")

    # Get comprehensive two-way player analysis
    print("Identifying true 2-way players using MLB criteria...")
    two_way_analysis = get_cleaned_two_way_data()

    # Keys look like "<name>_<year>"; strip the year suffix to get a plain name set
    official_two_way_players = two_way_analysis['two_way_players']
    two_way_players = {player_key.rsplit('_', 1)[0] for player_key in official_two_way_players}

    print(f"True 2-way players found:")
    for player_key, data in official_two_way_players.items():
        player_name, year = player_key.rsplit('_', 1)
        print(f"  {player_name} ({year}): Hitter WARP={data['hitting_warp']:.2f}, Pitcher WARP={data['pitching_warp']:.2f}")

    print(f"Loaded data - Hitters: {len(hitter_data)}, WARP hitters: {len(hitter_pred_data)}, WAR: {len(war_values)}")
    print(f"Enhanced baserunning values: {len(enhanced_baserunning_values)}")

    print("Creating SUPERIOR name mappings with index-based duplicate handling...")
    # Maps a WARP name to a positional index into war_values
    warp_to_war_map = create_optimized_name_mapping_with_indices(hitter_pred_data, war_values)
    
    # For hitter stats, use the traditional name-to-name mapping
    warp_to_hitter_map = create_name_mapping(hitter_pred_data['Name'].tolist(), hitter_data['Hitters'].tolist())

    hitter_stats = hitter_data
    hitter_stat_columns = ['K', 'BB', 'AVG', 'OBP', 'SLG']
    x_warp, y_warp, x_war, y_war = [], [], [], []
    hitter_names_warp, hitter_names_war = [], []
    # Season tracking for multi-year modeling
    hitter_seasons_warp, hitter_seasons_war = [], []

    for _, row in hitter_pred_data.iterrows():
        warp_name = row['Name']
        team = row['Team']
        season = row.get('Year', row.get('Season', 2021))  # Try Year first, then Season, default 2021
        
        hitter_match = warp_to_hitter_map.get(warp_name)
        if not hitter_match:
            continue

        stat_rows = hitter_stats.loc[hitter_stats['Hitters'] == hitter_match, hitter_stat_columns]
        if stat_rows.empty:
            continue

        # Use the first matching row only: flattening several matches would yield
        # a feature vector longer than 5 and break the fixed index positions below.
        stats = stat_rows.values[0].tolist()

        # Use ENHANCED baserunning values with run expectancy
        enhanced_baserunning_val = enhanced_baserunning_values.get(warp_name, 0.0)
        stats.append(enhanced_baserunning_val)

        # Apply ENHANCED park factor adjustments (park_factors passed in, not recomputed)
        park_adjusted_stats = apply_enhanced_hitter_park_adjustments(
            {'AVG': stats[2], 'OBP': stats[3], 'SLG': stats[4]}, warp_name, team, park_factors)

        # Replace original stats with park-adjusted ones if available
        if 'AVG_park_adj' in park_adjusted_stats:
            stats[2] = park_adjusted_stats['AVG_park_adj']
            stats[3] = park_adjusted_stats['OBP_park_adj']
            stats[4] = park_adjusted_stats['SLG_park_adj']

        # Enhanced defensive value: the row's own team/season first, then the legacy scan
        defensive_val = _lookup_defensive_value(enhanced_defensive_values, hitter_match, team, season)
        stats.append(defensive_val)

        x_warp.append(stats)
        y_warp.append(row['WARP'])
        hitter_names_warp.append(warp_name)
        hitter_seasons_warp.append(season)

        # Use INDEX-based mapping for WAR targets
        if warp_name in warp_to_war_map:
            target_idx = warp_to_war_map[warp_name]  # Get INDEX not name
            war_row = war_values.iloc[target_idx]    # Use index to get correct row
            total_war = war_row['Total WAR']

            # 2-WAY PLAYER FIX: only applied to players in the two-way set
            if warp_name in two_way_players:
                # For 2-way players, use hitting component only (Total - Primary)
                primary_war = war_row.get('Primary WAR', 0)
                # A NaN Primary WAR must not leak into the target, so treat it as missing
                if primary_war is not None and not pd.isna(primary_war) and primary_war != 0:
                    hitting_war = total_war - primary_war  # Hitting + fielding + baserunning
                    print(f"  TRUE 2-way player {warp_name}: Total WAR {total_war:.2f} -> Hitting WAR {hitting_war:.2f}")
                    target_war = hitting_war
                else:
                    target_war = total_war  # Fallback if no Primary WAR
            else:
                # Single-role hitters use Total WAR
                target_war = total_war

            x_war.append(stats)
            y_war.append(target_war)
            hitter_names_war.append(warp_name)
            hitter_seasons_war.append(season)

    # Use enhanced data cleaning for neural networks
    print("Cleaning data with enhanced neural network-safe algorithms...")
    x_warp, y_warp = validate_and_clean_data_enhanced(x_warp, y_warp)
    x_war, y_war = validate_and_clean_data_enhanced(x_war, y_war)
    _warn_if_misaligned("hitter WARP", x_warp, hitter_names_warp)
    _warn_if_misaligned("hitter WAR", x_war, hitter_names_war)

    print(f"Successfully matched {len(x_warp)} hitters with 7 features:")
    print(f"  - 5 hitting stats (with park adjustments)")
    print(f"  - Enhanced baserunning (run expectancy + situational)")
    print(f"  - Enhanced defense (OAA integration + framing)")
    # min()/max() raise on an empty sequence, so guard the no-match case
    if len(y_war) > 0:
        print(f"WAR target range after enhanced cleaning: {min(y_war):.2f} to {max(y_war):.2f}")
    else:
        print("WAR target range after enhanced cleaning: n/a (no hitters matched to WAR)")

    # Pitcher processing with enhanced mapping and park adjustments
    pitcher_warp_to_main = create_name_mapping(pitcher_pred_data['Name'].tolist(), pitcher_data['Pitchers'].tolist())
    pitcher_warp_to_war = create_optimized_name_mapping_with_indices(pitcher_pred_data, war_values)
    pitcher_stats = pitcher_data
    pitcher_stat_columns = ['IP', 'BB', 'K', 'HR', 'ERA']

    a_warp, b_warp, a_war, b_war = [], [], [], []
    pitcher_names_warp, pitcher_names_war = [], []
    pitcher_seasons_warp, pitcher_seasons_war = [], []

    for _, row in pitcher_pred_data.iterrows():
        warp_name = row['Name']
        team = row['Team']
        season = row.get('Year', row.get('Season', 2021))
        
        pitcher_match = pitcher_warp_to_main.get(warp_name)
        if not pitcher_match:
            continue

        stat_rows = pitcher_stats.loc[pitcher_stats['Pitchers'] == pitcher_match, pitcher_stat_columns]
        if stat_rows.empty:
            continue

        # First matching row only (see the hitter loop)
        stats = stat_rows.values[0].tolist()

        # Apply enhanced park adjustments for pitchers
        park_adjusted_stats = apply_enhanced_pitcher_park_adjustments(
            {'ERA': stats[4]}, warp_name, team, park_factors)
        if 'ERA_park_adj' in park_adjusted_stats:
            stats[4] = park_adjusted_stats['ERA_park_adj']

        a_warp.append(stats)
        b_warp.append(row['WARP'])
        pitcher_names_warp.append(warp_name)
        pitcher_seasons_warp.append(season)

        # Use index-based mapping for pitchers too
        if warp_name in pitcher_warp_to_war:
            target_idx = pitcher_warp_to_war[warp_name]
            war_row = war_values.iloc[target_idx]
            if 'Primary WAR' in war_row:
                # Primary WAR is used directly as the pitcher target
                a_war.append(stats)
                b_war.append(war_row['Primary WAR'])
                pitcher_names_war.append(warp_name)
                pitcher_seasons_war.append(season)

    # Enhanced data cleaning for pitchers too
    a_warp, b_warp = validate_and_clean_data_enhanced(a_warp, b_warp)
    a_war, b_war = validate_and_clean_data_enhanced(a_war, b_war)
    _warn_if_misaligned("pitcher WARP", a_warp, pitcher_names_warp)
    _warn_if_misaligned("pitcher WAR", a_war, pitcher_names_war)

    print(f"Successfully matched {len(a_warp)} pitchers with enhanced park factors")
    print(f"2-way player fix applied to {len(two_way_players)} TRUE 2-way players")
    print(f"Enhanced park factors applied to all players")
    print(f"Index-based mapping FIXES duplicate name issues")
    print(f"Enhanced baserunning with run expectancy REPLACES simple counting")
    print(f"Neural network-safe data cleaning applied")
    print(f"FIXED: Park factors loaded once instead of {len(hitter_pred_data) + len(pitcher_pred_data)} times")
    print(f"NEW: Season information captured for {len(hitter_seasons_warp)} hitter + {len(pitcher_seasons_warp)} pitcher observations")
    print(f"🚀 READY FOR COMPREHENSIVE FANGRAPHS-ENHANCED MODELING!")
    
    return (x_warp, y_warp, x_war, y_war, a_warp, b_warp, a_war, b_war,
            hitter_names_warp, hitter_names_war, pitcher_names_warp, pitcher_names_war,
            hitter_seasons_warp, hitter_seasons_war, pitcher_seasons_warp, pitcher_seasons_war)

def prepare_train_test_splits():
    """
    Run ``data_preparation()`` and split every dataset 75% / 25% into train and test.

    Features, targets, player names and seasons of each dataset are split together
    with ``random_state=1``. When no pitcher WAR rows exist, the pitcher WARP
    split is reused in its place.

    Returns:
        tuple of 24 items: train/test features and targets for hitter WARP, hitter WAR,
        pitcher WARP and pitcher WAR (16 items), followed by the TEST player names
        (4 items) and the TEST seasons (4 items) in that same dataset order.
    """
    (x_warp, y_warp, x_war, y_war, a_warp, b_warp, a_war, b_war,
     hitter_names_warp, hitter_names_war, pitcher_names_warp, pitcher_names_war,
     hitter_seasons_warp, hitter_seasons_war, pitcher_seasons_warp, pitcher_seasons_war) = data_preparation()

    def _split_75_25(features, targets, names, seasons):
        # Result order: X_train, X_test, y_train, y_test,
        #               names_train, names_test, seasons_train, seasons_test
        return train_test_split(
            features, targets, names, seasons, test_size=0.25, train_size=0.75, random_state=1
        )

    hitter_warp_parts = _split_75_25(x_warp, y_warp, hitter_names_warp, hitter_seasons_warp)
    hitter_war_parts = _split_75_25(x_war, y_war, hitter_names_war, hitter_seasons_war)
    pitcher_warp_parts = _split_75_25(a_warp, b_warp, pitcher_names_warp, pitcher_seasons_warp)

    if len(a_war) > 0:
        pitcher_war_parts = _split_75_25(a_war, b_war, pitcher_names_war, pitcher_seasons_war)
    else:
        # No pitcher WAR data: fall back to the pitcher WARP split
        pitcher_war_parts = pitcher_warp_parts

    ordered_parts = (hitter_warp_parts, hitter_war_parts, pitcher_warp_parts, pitcher_war_parts)

    feature_target_items = tuple(item for parts in ordered_parts for item in parts[:4])
    test_name_items = tuple(parts[5] for parts in ordered_parts)
    test_season_items = tuple(parts[7] for parts in ordered_parts)

    return feature_target_items + test_name_items + test_season_items

# Confirmation message printed when this cell has finished defining data_preparation
# and prepare_train_test_splits; neither function is called here.
print("✅ Enhanced data preparation and train/test split functions loaded with COMPREHENSIVE FANGRAPHS INTEGRATION")

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

def plot_quadrant_analysis_px_toggle(model_results, season_col="Season", model_names=None, show_hitters=True, show_pitchers=True):
    """
    Animated quadrant chart of WAR prediction error (x) vs WARP prediction error (y).

    Each point is one player-season for one model, placed at
    (Actual WAR - Predicted WAR, Actual WARP - Predicted WARP). The figure is
    faceted by model (columns) and player type (rows) and animated over
    ``season_col``. Every facet carries the zero axes, an orange "cross"
    (±1 on either axis) and a green ±1 square; buttons toggle the overlays.

    Args:
        model_results: Object exposing a ``results`` mapping keyed by
            ``"{model}_{player_type}_{metric}"``. Each entry must provide
            ``player_names``, ``y_true`` and ``y_pred`` and may provide a
            sequence under ``season_col``.
        season_col: Key holding per-prediction seasons; also used as the
            animation-frame column name.
        model_names: Models to plot. ``None`` asks
            ``select_best_models_by_category`` to choose.
        show_hitters: Include hitter results.
        show_pitchers: Include pitcher results.

    Returns:
        None. The figure is shown and a textual summary is printed.

    Notes:
        * WAR and WARP predictions live under separate result keys, so they are
          paired on (Player, Model, PlayerType, season). Only players that have
          BOTH metrics can be placed on the chart; the rest are dropped. If the
          same key occurs more than once, the first non-null values win.
        * Predictions without season information get the placeholder "2021".
        * This notebook defines a function with this name in more than one
          cell; whichever cell ran last is the one in effect.
    """
    if model_names is None:
        model_names = select_best_models_by_category(model_results)
        print(f"🎯 Auto-selected models: {[m.upper() for m in model_names]}")

    # Honor the player-type switches (they used to be accepted but ignored).
    player_types = [
        name for name, enabled in (("hitter", show_hitters), ("pitcher", show_pitchers)) if enabled
    ]

    def normalize_season(raw):
        """Canonical string form so 2021, 2021.0 and "2021" land in one frame."""
        try:
            return str(int(raw))
        except (ValueError, TypeError):
            return str(raw)

    data = []
    print("🔍 Collecting data from model results...")
    for model in model_names:
        for player_type in player_types:
            for metric in ["war", "warp"]:
                key = f"{model}_{player_type}_{metric}"
                if key not in model_results.results:
                    print(f"   Key {key} not found in results")
                    continue

                res = model_results.results[key]
                if len(res["player_names"]) == 0:
                    print(f"   No data for {key}")
                    continue
                print(f"   Found {len(res['player_names'])} entries for {key}")

                seasons = res[season_col] if season_col in res else None
                for i, player in enumerate(res["player_names"]):
                    if seasons is not None and len(seasons) > i and seasons[i] is not None:
                        season_value = normalize_season(seasons[i])
                    else:
                        season_value = "2021"  # Default placeholder when the season is unknown

                    data.append({
                        "Player": player,
                        "Model": model.title(),
                        "PlayerType": player_type.title(),
                        season_col: season_value,
                        f"Actual {metric.upper()}": res["y_true"][i],
                        f"Predicted {metric.upper()}": res["y_pred"][i]
                    })

    if not data:
        print("❌ No data available for quadrant analysis.")
        print("Available keys in model_results:", list(model_results.results.keys())[:10])
        return

    # Each collected record carries only ONE metric, so a raw DataFrame has NaN
    # in either the WAR or the WARP columns of every row. Collapse the records
    # per player-season (GroupBy.first keeps the first non-null per column) so a
    # row holds both metrics; reindex guarantees all four columns exist even
    # when a metric is missing altogether.
    key_cols = ["Player", "Model", "PlayerType", season_col]
    value_cols = ["Actual WAR", "Predicted WAR", "Actual WARP", "Predicted WARP"]
    df = (
        pd.DataFrame(data)
        .reindex(columns=key_cols + value_cols)
        .groupby(key_cols, as_index=False, sort=False)
        .first()
        .dropna(subset=value_cols)
        .reset_index(drop=True)
    )

    if df.empty:
        print("❌ No players have both WAR and WARP predictions; nothing to plot.")
        return
    print(f"✅ Collected {len(df)} data points for analysis")

    # Enhanced delta and error calculations
    df["WAR_Delta"] = df["Actual WAR"] - df["Predicted WAR"]
    df["WARP_Delta"] = df["Actual WARP"] - df["Predicted WARP"]

    # Relative error; an actual value of 0 yields NaN instead of dividing by zero
    df["WAR_Error_%"] = df["WAR_Delta"].abs() / df["Actual WAR"].replace(0, float("nan")).abs() * 100
    df["WARP_Error_%"] = df["WARP_Delta"].abs() / df["Actual WARP"].replace(0, float("nan")).abs() * 100

    # Multiple accuracy zone definitions
    df["In_Accuracy_Zone"] = (df["WAR_Error_%"] <= 10) & (df["WARP_Error_%"] <= 10)
    df["WAR_Delta_1"] = df["WAR_Delta"].abs() <= 1.0
    df["WARP_Delta_1"] = df["WARP_Delta"].abs() <= 1.0
    df["Both_Delta_1"] = df["WAR_Delta_1"] & df["WARP_Delta_1"]  # Green intersection
    df["Either_Delta_1"] = df["WAR_Delta_1"] | df["WARP_Delta_1"]  # Orange cross

    df["AccuracyZone"] = df["In_Accuracy_Zone"].map({True: "≤10% Error Both", False: "Outside Zone"})
    df["Delta1Zone"] = df["Both_Delta_1"].map({True: "Both ≤1", False: "Outside ±1"})

    # Chronological frame order. It is handed to Plotly Express through
    # `category_orders`, which is the documented way to force an ordering.
    unique_seasons = [s for s in df[season_col].unique() if s is not None]
    try:
        sorted_season_strings = sorted(unique_seasons, key=int)
        print(f"📅 Sorted seasons chronologically: {sorted_season_strings}")
    except (ValueError, TypeError):
        sorted_season_strings = sorted(str(s) for s in unique_seasons)
        print(f"📅 Sorted seasons as strings: {sorted_season_strings}")

    min_val = min(df["WAR_Delta"].min(), df["WARP_Delta"].min())
    max_val = max(df["WAR_Delta"].max(), df["WARP_Delta"].max())
    buffer = (max_val - min_val) * 0.05 or 0.5  # keep a non-degenerate range for a single point
    axis_lo, axis_hi = min_val - buffer, max_val + buffer

    # Create the enhanced faceted figure (px.scatter already returns a go.Figure)
    fig = px.scatter(
        df,
        x="WAR_Delta",
        y="WARP_Delta",
        color="PlayerType",
        symbol="AccuracyZone",
        hover_name="Player",
        facet_col="Model",
        facet_row="PlayerType",
        animation_frame=season_col,
        animation_group="Player",
        category_orders={season_col: sorted_season_strings},
        title="Enhanced Quadrant Analysis: WAR vs WARP Deltas (Chronological Animation)",
        range_x=[axis_lo, axis_hi],
        range_y=[axis_lo, axis_hi],
        width=1200,
        height=800,
        template="seaborn"
    )

    # Quadrant reference lines (added to every facet)
    fig.add_hline(y=0, line_dash="dash", line_color="gray", line_width=1)
    fig.add_vline(x=0, line_dash="dash", line_color="gray", line_width=1)

    # Accuracy overlays. They are ADDED per facet with add_shape rather than
    # assigned through update_layout(shapes=...), which would overwrite the
    # zero lines created just above. `name` tags each overlay for the toggles.
    cross_line = dict(color="orange", width=2, dash="dot")
    for edge in (-1, 1):
        fig.add_shape(type="line", x0=edge, y0=axis_lo, x1=edge, y1=axis_hi,
                      line=cross_line, name="cross", row="all", col="all")
        fig.add_shape(type="line", x0=axis_lo, y0=edge, x1=axis_hi, y1=edge,
                      line=cross_line, name="cross", row="all", col="all")
    fig.add_shape(type="rect", x0=-1, y0=-1, x1=1, y1=1,
                  line=dict(color="green", width=2, dash="dash"),
                  fillcolor="green", opacity=0.1,
                  name="intersection", row="all", col="all")

    def shapes_with_zones(visible_zones):
        """Full layout-shape list with only the named overlay groups visible.

        A relayout of "shapes" replaces the whole list, so the untagged zero
        lines are always included unchanged.
        """
        specs = []
        for shape in fig.layout.shapes:
            spec = shape.to_plotly_json()
            zone = spec.get("name")
            if zone in ("cross", "intersection"):
                spec["visible"] = zone in visible_zones
            specs.append(spec)
        return specs

    # Legend above the plot so it does not cover the y-axis
    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,  # Above the plot
            xanchor="center",
            x=0.5,   # Centered horizontally
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="gray",
            borderwidth=1
        )
    )

    # Zone toggle buttons are APPENDED so the Play/Pause menu that Plotly
    # Express created for the animation is kept.
    zone_menu = dict(
        type="buttons",
        direction="left",
        x=0.15,  # Moved right so not covered by legend
        y=1.15,  # Above legend
        showactive=True,
        buttons=[
            dict(label="All Zones", method="relayout",
                 args=[{"shapes": shapes_with_zones({"cross", "intersection"})}]),
            dict(label="Cross Only", method="relayout",
                 args=[{"shapes": shapes_with_zones({"cross"})}]),
            dict(label="Intersection", method="relayout",
                 args=[{"shapes": shapes_with_zones({"intersection"})}]),
            dict(label="No Zones", method="relayout",
                 args=[{"shapes": shapes_with_zones(set())}])
        ]
    )
    fig.layout.updatemenus = list(fig.layout.updatemenus) + [zone_menu]

    # Restyle the existing animation slider in place (there is none when the
    # data contains a single season).
    for slider in fig.layout.sliders:
        slider.update(currentvalue=dict(prefix="Year: "), x=0.15, len=0.7)

    fig.show()

    # Enhanced statistical summary with better formatting
    print("\n" + "="*60)
    print("📊 INTERACTIVE QUADRANT ANALYSIS SUMMARY")
    print("="*60)

    for model in df["Model"].unique():
        mdf = df[df["Model"] == model]
        total = len(mdf)

        acc_10pct = mdf["In_Accuracy_Zone"].sum()
        both_delta1 = mdf["Both_Delta_1"].sum()
        either_delta1 = mdf["Either_Delta_1"].sum()

        print(f"\n🔍 {model.upper()} MODEL ({total} predictions):")
        print(f"   📈 10% Accuracy Zone (both WAR & WARP): {acc_10pct}/{total} ({acc_10pct/total*100:.1f}%)")
        print(f"   🎯 Delta 1 Cross (either ≤1): {either_delta1}/{total} ({either_delta1/total*100:.1f}%)")
        print(f"   ✅ Delta 1 Intersection (both ≤1): {both_delta1}/{total} ({both_delta1/total*100:.1f}%)")

        # Sample accurate players
        accurate_players = mdf[mdf["In_Accuracy_Zone"]]["Player"].unique()
        if len(accurate_players) > 0:
            sample = ", ".join(list(accurate_players[:3]))
            print(f"   🌟 Sample accurate: {sample}{'...' if len(accurate_players) > 3 else ''}")

    print(f"\n💡 INTERACTIVE FEATURES:")
    print(f"   🖱️  Legend: Click PlayerType/AccuracyZone to show/hide")
    print(f"   🎬 Animation: Chronologically ordered year progression")
    print(f"   🔘 Accuracy Zones: Toggle orange cross vs green intersection")
    print(f"   🎯 Hover: Detailed player performance information")

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from scipy import stats

def plot_consolidated_model_comparison(model_results, model_names=None, show_residuals=True, show_metrics=True):
    """
    Consolidated model comparison: unified figures plus a printed summary.

    For every requested model and every hitter/pitcher x WAR/WARP result found
    in ``model_results.results`` (keys ``"{model}_{player_type}_{metric}"``),
    RMSE, MAE and R² are computed from ``y_true`` / ``y_pred``. Optionally two
    figures are shown:

    * ``show_metrics``: a 2x2 grid of actual-vs-predicted scatter plots, one
      panel per category, with a dashed perfect-prediction line.
    * ``show_residuals``: residuals vs fitted, residual histograms, a Q-Q plot
      for the first model only, and a bar chart of the averaged metrics.

    Args:
        model_results: Your ModelResults object (only ``.results`` is read
            when ``model_names`` is given).
        model_names: List of models to compare (None = auto-select best via
            ``select_best_models_by_category``).
        show_residuals: Whether to include the residual analysis figure.
        show_metrics: Whether to include the scatter plot comparisons.

    Returns:
        ``{model_name: {"{player_type}_{metric}": {"rmse", "mae", "r2",
        "count"}}}``, or ``{}`` when no predictions were found at all.
    """
    if model_names is None:
        model_names = select_best_models_by_category(model_results)
        print(f"🎯 Auto-selected models for comparison: {[m.upper() for m in model_names]}")

    print("\n📊 CONSOLIDATED MODEL COMPARISON SYSTEM")
    print("="*70)
    print("🔍 Replacing individual graphs with unified selectable trace visualizations...")

    # Collect all data for consolidated visualizations
    all_data = []
    model_stats = {}

    for model_name in model_names:
        model_stats[model_name] = {}

        for player_type in ['hitter', 'pitcher']:
            for metric in ['war', 'warp']:
                key = f"{model_name}_{player_type}_{metric}"
                if key not in model_results.results:
                    continue
                data = model_results.results[key]
                if len(data['y_true']) == 0:
                    continue

                y_true = np.array(data['y_true'])
                y_pred = np.array(data['y_pred'])
                residuals = y_true - y_pred

                # Calculate comprehensive statistics
                ss_res = np.sum(residuals**2)
                ss_tot = np.sum((y_true - np.mean(y_true))**2)
                rmse = np.sqrt(np.mean(residuals**2))
                mae = np.mean(np.abs(residuals))
                if ss_tot == 0:
                    # Constant targets (e.g. a single sample): avoid dividing by
                    # zero. Perfect predictions score 1, anything else 0.
                    r2 = 1.0 if ss_res == 0 else 0.0
                else:
                    r2 = 1 - ss_res / ss_tot

                # Store statistics
                model_stats[model_name][f"{player_type}_{metric}"] = {
                    'rmse': rmse,
                    'mae': mae,
                    'r2': r2,
                    'count': len(y_true)
                }

                # Add to plotting data; fall back to a synthetic label when a
                # name is missing for this index.
                names = data['player_names'] if 'player_names' in data else None
                for i in range(len(residuals)):
                    has_name = names is not None and i < len(names)
                    all_data.append({
                        'Model': model_name.title(),
                        'PlayerType': player_type.title(),
                        'Metric': metric.upper(),
                        'Category': f"{player_type.title()} {metric.upper()}",
                        'Actual': y_true[i],
                        'Predicted': y_pred[i],
                        'Residual': residuals[i],
                        'Player': names[i] if has_name else f"Player_{i}"
                    })

    if not all_data:
        print("❌ No data available for consolidated comparison")
        return {}

    df = pd.DataFrame(all_data)

    # One colour per model, shared by BOTH figures. It is built here rather
    # than inside the scatter section, so show_metrics=False with
    # show_residuals=True no longer hits an undefined `colors`.
    model_colors = {}
    if show_metrics or show_residuals:
        palette = px.colors.qualitative.Set1
        model_colors = {
            model: palette[k % len(palette)] for k, model in enumerate(df['Model'].unique())
        }

    # Utility: add group toggle buttons
    def add_group_buttons(fig, group_labels):
        """Add one real toggle button per legend group plus a "Show All" button.

        Trace visibility only accepts True / False / "legendonly", so a
        "toggle" value cannot work. Each group gets its own single-button menu:
        `args` hides the group, and `args2` (run when the button is clicked
        while active) shows it again. `active=-1` starts every button inactive
        so the first click hides.
        """
        labels = list(group_labels)
        slots = len(labels) + 1
        menus = []

        for slot, g in enumerate(labels):
            members = [i for i, tr in enumerate(fig.data) if tr.legendgroup == g]
            if not members:
                # An empty index list would make restyle address every trace.
                continue
            menus.append(dict(
                type="buttons",
                direction="right",
                active=-1,
                showactive=True,
                x=(slot + 0.5) / slots, xanchor="center",
                y=1.15, yanchor="top",
                buttons=[dict(
                    label=f"Toggle {g}",
                    method="restyle",
                    args=[{"visible": "legendonly"}, members],
                    args2=[{"visible": True}, members]
                )]
            ))

        # Show all
        menus.append(dict(
            type="buttons",
            direction="right",
            active=-1,
            showactive=False,
            x=(slots - 0.5) / slots, xanchor="center",
            y=1.15, yanchor="top",
            buttons=[dict(
                label="Show All",
                method="restyle",
                args=[{"visible": True}, list(range(len(fig.data)))]
            )]
        ))

        fig.update_layout(updatemenus=menus)

    # 1. CONSOLIDATED SCATTER PLOTS WITH SELECTABLE TRACES
    if show_metrics:
        print("\n📈 Creating consolidated prediction accuracy plots...")

        categories = ['Hitter WAR', 'Hitter WARP', 'Pitcher WAR', 'Pitcher WARP']
        fig_scatter = make_subplots(
            rows=2, cols=2,
            subplot_titles=categories,
            vertical_spacing=0.1,
            horizontal_spacing=0.1
        )

        # Each legend group gets exactly one legend entry, taken from the first
        # panel in which it appears (a model may lack data for the first panel).
        legend_shown = set()

        for i, category in enumerate(categories):
            cat_data = df[df['Category'] == category]
            if cat_data.empty:
                continue

            row = (i // 2) + 1
            col = (i % 2) + 1

            for model in cat_data['Model'].unique():
                model_data = cat_data[cat_data['Model'] == model]

                fig_scatter.add_trace(
                    go.Scatter(
                        x=model_data['Actual'],
                        y=model_data['Predicted'],
                        mode='markers',
                        name=f"{model} {category}",
                        legendgroup=model,
                        marker=dict(color=model_colors[model], size=6, opacity=0.7),
                        text=model_data['Player'],
                        hovertemplate="<b>%{text}</b><br>" +
                                      "Actual: %{x:.3f}<br>" +
                                      "Predicted: %{y:.3f}<br>" +
                                      f"Model: {model}<br>" +
                                      "<extra></extra>",
                        showlegend=(model not in legend_shown)
                    ),
                    row=row, col=col
                )
                legend_shown.add(model)

            # Perfect prediction line, spanning every model in this panel
            min_val = min(cat_data['Actual'].min(), cat_data['Predicted'].min())
            max_val = max(cat_data['Actual'].max(), cat_data['Predicted'].max())

            fig_scatter.add_trace(
                go.Scatter(
                    x=[min_val, max_val],
                    y=[min_val, max_val],
                    mode='lines',
                    line=dict(dash='dash', color='red', width=2),
                    name='Perfect Prediction',
                    legendgroup='Perfect Prediction',
                    showlegend=('Perfect Prediction' not in legend_shown)
                ),
                row=row, col=col
            )
            legend_shown.add('Perfect Prediction')

        fig_scatter.update_layout(
            title="Consolidated Model Comparison: Prediction Accuracy (Click Legend or Use Buttons to Toggle)",
            height=800,
            width=1200,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
        )

        add_group_buttons(fig_scatter, df['Model'].unique())
        fig_scatter.show()

    # 2. CONSOLIDATED RESIDUAL ANALYSIS
    if show_residuals:
        print("\n🔍 Creating consolidated residual analysis...")

        fig_residuals = make_subplots(
            rows=2, cols=2,
            subplot_titles=['Residuals vs Fitted', 'Residual Distributions', 'Q-Q Plot', 'Model Performance'],
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}]]
        )

        # Residuals vs Fitted (all categories of a model pooled together)
        for model in df['Model'].unique():
            model_data = df[df['Model'] == model]

            fig_residuals.add_trace(
                go.Scatter(
                    x=model_data['Predicted'],
                    y=model_data['Residual'],
                    mode='markers',
                    name=f"{model} Residuals",
                    legendgroup=model,
                    marker=dict(color=model_colors[model], size=4, opacity=0.6),
                    showlegend=True
                ),
                row=1, col=1
            )

        # Add horizontal line at y=0
        fig_residuals.add_hline(y=0, line_dash="dash", line_color="gray", row=1, col=1)

        # Residual Distributions
        for model in df['Model'].unique():
            model_data = df[df['Model'] == model]

            fig_residuals.add_trace(
                go.Histogram(
                    x=model_data['Residual'],
                    name=f"{model} Distribution",
                    legendgroup=model,
                    marker=dict(color=model_colors[model]),
                    opacity=0.7,
                    nbinsx=30,
                    showlegend=False
                ),
                row=1, col=2
            )

        # Q-Q Plot (simplified - only the first model, for clarity)
        if len(df['Model'].unique()) > 0:
            best_model = df['Model'].unique()[0]
            best_data = df[df['Model'] == best_model]
            residuals = best_data['Residual'].values

            sorted_residuals = np.sort(residuals)
            n = len(sorted_residuals)
            theoretical_quantiles = stats.norm.ppf(np.linspace(0.01, 0.99, n))

            fig_residuals.add_trace(
                go.Scatter(
                    x=theoretical_quantiles,
                    y=sorted_residuals,
                    mode='markers',
                    name=f"{best_model} Q-Q",
                    legendgroup=best_model,
                    marker=dict(size=4),
                    showlegend=False
                ),
                row=2, col=1
            )

            # Reference line of a normal distribution with the residuals' mean/std
            slope = np.std(residuals)
            intercept = np.mean(residuals)
            line_min, line_max = min(theoretical_quantiles), max(theoretical_quantiles)

            fig_residuals.add_trace(
                go.Scatter(
                    x=[line_min, line_max],
                    y=[intercept + slope * line_min, intercept + slope * line_max],
                    mode='lines',
                    line=dict(color='red', dash='dash'),
                    name='Normal Line',
                    legendgroup='Normal Line',
                    showlegend=False
                ),
                row=2, col=1
            )

        # Model Performance Comparison (each metric averaged over a model's categories)
        models = list(model_stats.keys())
        stat_keys = {'R²': 'r2', 'RMSE': 'rmse', 'MAE': 'mae'}

        for metric_name, stat_key in stat_keys.items():
            metric_values = []
            for model in models:
                values = [entry[stat_key] for entry in model_stats[model].values()]
                metric_values.append(np.mean(values) if values else 0)

            fig_residuals.add_trace(
                go.Bar(
                    x=models,
                    y=metric_values,
                    name=metric_name,
                    legendgroup=metric_name,
                    showlegend=False
                ),
                row=2, col=2
            )

        fig_residuals.update_layout(
            title="Consolidated Residual Analysis (Click Legend or Use Buttons to Toggle Models)",
            height=800,
            width=1200,
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
        )

        add_group_buttons(fig_residuals, df['Model'].unique())
        fig_residuals.show()

    # 3. COMPREHENSIVE STATISTICAL SUMMARY
    print("\n📋 CONSOLIDATED MODEL PERFORMANCE SUMMARY")
    print("="*70)

    for model_name in model_names:
        category_stats = model_stats.get(model_name)
        if category_stats is None:
            continue

        print(f"\n🤖 {model_name.upper()} MODEL:")
        if not category_stats:
            # Nothing was found for this model; averaging an empty list would
            # only produce NaN and a RuntimeWarning.
            print(f"   📊 No predictions available")
            continue

        total_predictions = sum(entry['count'] for entry in category_stats.values())
        avg_r2 = np.mean([entry['r2'] for entry in category_stats.values()])
        avg_rmse = np.mean([entry['rmse'] for entry in category_stats.values()])
        avg_mae = np.mean([entry['mae'] for entry in category_stats.values()])

        print(f"   📊 Overall Performance:")
        print(f"      • Total Predictions: {total_predictions}")
        print(f"      • Average R²: {avg_r2:.4f}")
        print(f"      • Average RMSE: {avg_rmse:.4f}")
        print(f"      • Average MAE: {avg_mae:.4f}")

        print(f"   📈 By Category:")
        for key, entry in category_stats.items():
            category = key.replace('_', ' ').title()
            print(f"      • {category}: R²={entry['r2']:.4f}, RMSE={entry['rmse']:.4f}, Count={entry['count']}")

    print(f"\n✅ CONSOLIDATED COMPARISON COMPLETE")
    print(f"   📈 Unified scatter plots: All models on same plots with toggleable traces")
    print(f"   🔍 Integrated residual analysis: Comprehensive diagnostic plots")
    print(f"   📊 Statistical summary: Complete performance metrics")
    print(f"   🖱️  Interactive legends: Click to show/hide individual models, use buttons for group control")

    return model_stats


def plot_comprehensive_residual_analysis(model_results, model_names=None, player_type="both", metric="both"):
    """
    Residual diagnostics for the selected models, optionally restricted to one
    player type and/or one metric.

    This is a thin wrapper around ``plot_consolidated_model_comparison`` called
    with ``show_residuals=True, show_metrics=False``; that function draws the
    residuals-vs-fitted scatter, the residual histograms, a Q-Q plot and the
    averaged-metric bar chart, and prints the summary.

    Args:
        model_results: Your ModelResults object (its ``results`` mapping is
            keyed by ``"{model}_{player_type}_{metric}"``).
        model_names: List of models to compare (None = auto-select best).
        player_type: 'hitter', 'pitcher', or 'both' (case-insensitive).
        metric: 'war', 'warp', or 'both' (case-insensitive).

    Returns:
        Whatever ``plot_consolidated_model_comparison`` returns: a dictionary
        with the per-model statistics.
    """
    from types import SimpleNamespace  # local import: used only for the filtered view

    def resolve(value, allowed, label):
        """Translate a selector into the list of allowed names it covers."""
        choice = str(value).lower()
        if choice in allowed:
            return [choice]
        if choice != "both":
            # Best effort rather than a crash: unknown selectors were silently
            # accepted before, so keep analysing everything but say so.
            print(f"⚠️ Unrecognized {label} {value!r}; analyzing both")
        return list(allowed)

    player_types = resolve(player_type, ("hitter", "pitcher"), "player_type")
    metrics = resolve(metric, ("war", "warp"), "metric")

    if len(player_types) == 2 and len(metrics) == 2:
        # Nothing to filter: delegate with the original object untouched.
        return plot_consolidated_model_comparison(model_results, model_names, show_residuals=True, show_metrics=False)

    # Model auto-selection must see the complete results, so do it before filtering.
    if model_names is None:
        model_names = select_best_models_by_category(model_results)
        print(f"🎯 Auto-selected models for comparison: {[m.upper() for m in model_names]}")

    # Keep only the result keys for the requested player type / metric and hand
    # the consolidated plotter a lightweight stand-in exposing `.results`
    # (the only attribute it reads once model_names is supplied).
    wanted_suffixes = tuple(f"_{p}_{m}" for p in player_types for m in metrics)
    filtered_results = {
        key: model_results.results[key]
        for key in model_results.results.keys()
        if str(key).endswith(wanted_suffixes)
    }
    return plot_consolidated_model_comparison(
        SimpleNamespace(results=filtered_results), model_names, show_residuals=True, show_metrics=False
    )


In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

def plot_quadrant_analysis_px_toggle(model_results, season_col="Season", model_names=None, show_hitters=True, show_pitchers=True):
    """
    Quadrant analysis of WAR vs WARP prediction errors with season animation.

    Each point is one player-season for one model, placed at
    (Actual WAR - Predicted WAR, Actual WARP - Predicted WARP). Facet columns
    are models, facet rows are player types, and ``season_col`` drives the
    animation. Every facet shows the zero axes plus two accuracy overlays: an
    orange cross (±1 on either axis) and a green ±1 square, switchable with the
    "Accuracy Zones" buttons.

    Args:
        model_results: Object exposing a ``results`` mapping keyed by
            ``"{model}_{player_type}_{metric}"``; entries provide
            ``player_names``, ``y_true``, ``y_pred`` and optionally a sequence
            under ``season_col``.
        season_col: Key of the per-prediction seasons and name of the
            animation-frame column.
        model_names: Models to plot (None = auto-select via
            ``select_best_models_by_category``).
        show_hitters: Include hitter results (default keeps both types so they
            can still be toggled from the legend).
        show_pitchers: Include pitcher results.

    Returns:
        None. The figure is shown and a summary is printed.

    Notes:
        * WAR and WARP predictions are stored under separate keys, so they are
          paired on (Player, Model, PlayerType, season); players lacking either
          metric cannot be placed on the chart and are dropped.
        * Predictions without season information get the placeholder "2021".
        * This notebook defines a function with this name in more than one
          cell; whichever cell ran last is the one in effect.
    """
    if model_names is None:
        model_names = select_best_models_by_category(model_results)
        print(f"🎯 Auto-selected models: {[m.upper() for m in model_names]}")

    player_types = [
        name for name, enabled in (("hitter", show_hitters), ("pitcher", show_pitchers)) if enabled
    ]

    def normalize_season(raw):
        """Canonical string form so int, float and str seasons share one frame."""
        try:
            return str(int(raw))
        except (ValueError, TypeError):
            return str(raw)

    # Collect one record per prediction (each record carries a single metric)
    data = []
    for model in model_names:
        for player_type in player_types:
            for metric in ["war", "warp"]:
                key = f"{model}_{player_type}_{metric}"
                if key not in model_results.results:
                    continue
                res = model_results.results[key]
                seasons = res[season_col] if season_col in res else None
                for i, player in enumerate(res["player_names"]):
                    # Handle missing season data gracefully
                    if seasons is not None and len(seasons) > i and seasons[i] is not None:
                        season_value = normalize_season(seasons[i])
                    else:
                        season_value = "2021"  # Default season placeholder

                    data.append({
                        "Player": player,
                        "Model": model.title(),
                        "PlayerType": player_type.title(),
                        season_col: season_value,
                        f"Actual {metric.upper()}": res["y_true"][i],
                        f"Predicted {metric.upper()}": res["y_pred"][i]
                    })

    if not data:
        print("No data available for quadrant analysis.")
        return

    # Merge the single-metric records into one row per player-season. Without
    # this every row has NaN in either the WAR or the WARP columns, so no point
    # has both coordinates. GroupBy.first keeps the first non-null value per
    # column; reindex guarantees all four columns exist.
    key_cols = ["Player", "Model", "PlayerType", season_col]
    value_cols = ["Actual WAR", "Predicted WAR", "Actual WARP", "Predicted WARP"]
    df = (
        pd.DataFrame(data)
        .reindex(columns=key_cols + value_cols)
        .groupby(key_cols, as_index=False, sort=False)
        .first()
        .dropna(subset=value_cols)
        .reset_index(drop=True)
    )

    if df.empty:
        print("No players have both WAR and WARP predictions; nothing to plot.")
        return

    # Enhanced delta and error calculations
    df["WAR_Delta"] = df["Actual WAR"] - df["Predicted WAR"]
    df["WARP_Delta"] = df["Actual WARP"] - df["Predicted WARP"]

    # Relative error; an actual value of 0 yields NaN instead of dividing by zero
    df["WAR_Error_%"] = df["WAR_Delta"].abs() / df["Actual WAR"].replace(0, float("nan")).abs() * 100
    df["WARP_Error_%"] = df["WARP_Delta"].abs() / df["Actual WARP"].replace(0, float("nan")).abs() * 100

    # Multiple accuracy zone definitions
    df["In_Accuracy_Zone"] = (df["WAR_Error_%"] <= 10) & (df["WARP_Error_%"] <= 10)  # 10% both
    df["WAR_Delta_1"] = df["WAR_Delta"].abs() <= 1.0  # Delta 1 margins
    df["WARP_Delta_1"] = df["WARP_Delta"].abs() <= 1.0
    df["Both_Delta_1"] = df["WAR_Delta_1"] & df["WARP_Delta_1"]  # Green intersection
    df["Either_Delta_1"] = df["WAR_Delta_1"] | df["WARP_Delta_1"]  # Orange cross

    df["AccuracyZone"] = df["In_Accuracy_Zone"].map({True: "≤10% Error Both", False: "Outside Zone"})
    df["Delta1Zone"] = df["Both_Delta_1"].map({True: "Both ≤1", False: "Outside ±1"})

    # Chronological frame order, handed to Plotly Express via `category_orders`
    unique_seasons = [s for s in df[season_col].unique() if s is not None]
    try:
        season_order = sorted(unique_seasons, key=int)
    except (ValueError, TypeError):
        season_order = sorted(str(s) for s in unique_seasons)

    min_val = min(df["WAR_Delta"].min(), df["WARP_Delta"].min())
    max_val = max(df["WAR_Delta"].max(), df["WARP_Delta"].max())
    buffer = (max_val - min_val) * 0.05 or 0.5  # keep a non-degenerate range for a single point
    axis_lo, axis_hi = min_val - buffer, max_val + buffer

    # Create the base PX figure (px.scatter already returns a go.Figure)
    fig = px.scatter(
        df,
        x="WAR_Delta",
        y="WARP_Delta",
        color="PlayerType",
        symbol="AccuracyZone",
        hover_name="Player",
        facet_col="Model",
        facet_row="PlayerType",
        animation_frame=season_col,
        animation_group="Player",
        category_orders={season_col: season_order},
        title="Enhanced Quadrant Analysis: WAR vs WARP Deltas (Click Legend to Toggle)",
        range_x=[axis_lo, axis_hi],
        range_y=[axis_lo, axis_hi],
        width=1200,
        height=800,
        template="seaborn"
    )

    # Add quadrant reference lines (every facet)
    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    fig.add_vline(x=0, line_dash="dash", line_color="gray")

    # DUAL accuracy zone overlays, added to every facet with add_shape.
    # Assigning them through update_layout(shapes=...) would overwrite the
    # zero lines above and would not tie the shapes to the individual facets.
    # The lines span the plotted range instead of a hard-coded ±4.
    cross_line = dict(color="orange", width=3, dash="dot")
    for edge in (-1, 1):
        # Vertical then horizontal arm of the orange cross (WAR≤1 OR WARP≤1)
        fig.add_shape(type="line", x0=edge, y0=axis_lo, x1=edge, y1=axis_hi,
                      line=cross_line, name="cross", row="all", col="all")
        fig.add_shape(type="line", x0=axis_lo, y0=edge, x1=axis_hi, y1=edge,
                      line=cross_line, name="cross", row="all", col="all")
    # Green intersection rectangle (WAR≤1 AND WARP≤1)
    fig.add_shape(type="rect", x0=-1, y0=-1, x1=1, y1=1,
                  line=dict(color="green", width=2, dash="dash"),
                  fillcolor="green", opacity=0.1,
                  name="intersection", row="all", col="all")

    def shapes_with_zones(visible_zones):
        """Full layout-shape list with only the named overlay groups visible.

        A relayout of "shapes" replaces the whole list, so the untagged zero
        lines are always included unchanged.
        """
        specs = []
        for shape in fig.layout.shapes:
            spec = shape.to_plotly_json()
            zone = spec.get("name")
            if zone in ("cross", "intersection"):
                spec["visible"] = zone in visible_zones
            specs.append(spec)
        return specs

    # Configure interactive legend (horizontal, inside plot, top-left)
    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="top",
            y=0.98,
            xanchor="left",
            x=0.01,
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="gray",
            borderwidth=1
        )
    )

    # Accuracy zone toggle buttons, APPENDED so the animation's Play/Pause menu
    # created by Plotly Express is kept.
    zone_menu = dict(
        type="buttons",
        direction="left",
        x=0.02,
        y=0.02,
        showactive=True,
        buttons=[
            dict(
                label="All Zones",
                method="relayout",
                args=[{"shapes": shapes_with_zones({"cross", "intersection"})}]
            ),
            dict(
                label="Cross Only",
                method="relayout",
                args=[{"shapes": shapes_with_zones({"cross"})}]
            ),
            dict(
                label="Intersection",
                method="relayout",
                args=[{"shapes": shapes_with_zones({"intersection"})}]
            ),
            dict(
                label="No Zones",
                method="relayout",
                args=[{"shapes": shapes_with_zones(set())}]
            )
        ]
    )
    fig.layout.updatemenus = list(fig.layout.updatemenus) + [zone_menu]

    # add_annotation keeps the facet labels, which are annotations themselves
    fig.add_annotation(text="Accuracy Zones:", x=0.02, y=0.06, xref="paper", yref="paper",
                       align="left", showarrow=False)

    fig.show()

    # Enhanced statistical summary
    print("=== INTERACTIVE QUADRANT ANALYSIS SUMMARY ===")
    for model in df["Model"].unique():
        mdf = df[df["Model"] == model]
        total = len(mdf)

        acc_10pct = mdf["In_Accuracy_Zone"].sum()
        both_delta1 = mdf["Both_Delta_1"].sum()
        either_delta1 = mdf["Either_Delta_1"].sum()

        print(f"\n{model} ({total} entries):")
        print(f"  10% Accuracy Zone (both): {acc_10pct} ({acc_10pct/total*100:.1f}%)")
        print(f"  Delta 1 Cross (either): {either_delta1} ({either_delta1/total*100:.1f}%)")
        print(f"  Delta 1 Intersection (both): {both_delta1} ({both_delta1/total*100:.1f}%)")

    print(f"\n💡 INTERACTIVE CONTROLS:")
    print(f"   • Legend: Click PlayerType/AccuracyZone items to show/hide traces")
    print(f"   • Accuracy Zone buttons: Toggle visualization overlays")
    print(f"   • Animation: Play through seasons to see temporal patterns")
    print(f"   • Hover: Detailed player information with error percentages")

In [None]:
# ===== RUN THE FULL MODELING PIPELINE END TO END (MODULARIZED HELPERS) =====
# Order of work: data preparation -> model families -> WAR adjustments ->
# quadrant / animated plots -> printed summary.
# Everything lives in a single try block; the handler at the bottom prints the
# error message and traceback for any Exception instead of letting it propagate.
try:

    # ---- Data preparation ----
    print("\nPreparing data with fuzzy matching and caching...")
    data_splits = prepare_train_test_splits()
    print("Data preparation complete!")

    # ---- Model families ----
    print("\n1. Running Basic Regression Models (including NEW Ridge)...")
    # Positional arguments common to every model-family runner below.
    shared_args = (data_splits, model_results, print_metrics, plot_results)
    run_basic_regressions(*shared_args)

    print("\n2. Running Advanced Models...")
    run_advanced_models(*shared_args)

    # Step 3 (ensemble / AdaBoost) is currently commented out; the later step
    # numbers are left as they were.
    # print("\n3. Running NEW Ensemble Models (AdaBoost)...")
    # run_ensemble_models(data_splits, model_results, print_metrics, plot_results)

    print("\n4. Running NEW Non-linear Models (SVR + Gaussian Process)...")
    run_nonlinear_models(*shared_args)

    print("\n5. Running Neural Network with AdamW optimizer...")
    # The neural-network runner takes one extra plotting callback.
    run_neural_network(*shared_args, plot_training_history)

    # ---- WAR adjustments ----
    print("\n6. Applying PROPER WAR Adjustments with Real Position Data...")
    adjusted_model_results = apply_proper_war_adjustments(model_results)

    # ---- Visual analysis on the adjusted results ----
    print("\n7. Generating ENHANCED Analysis (comprehensive statistics + dual accuracy zones)...")

    # Both plots receive the same player-type toggles; set either flag to
    # False to reduce data density if needed.
    player_toggles = {"show_hitters": True, "show_pitchers": True}

    print("\n📊 ENHANCED QUADRANT ANALYSIS (Faceted with Dual Accuracy Zones):")
    plot_quadrant_analysis_px_toggle(adjusted_model_results, **player_toggles)

    print("\n🎬 ENHANCED ANIMATED ANALYSIS (Comprehensive Statistics + Dual Zones):")
    plot_war_warp_animated(adjusted_model_results, **player_toggles)

    # ---- Closing summary (static text, printed one line at a time) ----
    # NOTE(review): the banner reports 10 algorithms, but only 7 are named in
    # the list beneath it — confirm the intended count against the modules.
    completion_lines = (
        "\n🎉 COMPLETE ENHANCED MODEL SUITE TESTING FINISHED!",
        "   Total algorithms tested: 10",
        "   • Linear methods: Ridge, ElasticNet",
        "   • Tree/Ensemble: KNN, Random Forest, XGBoost",
        "   • Non-linear: SVR",
        "   • Neural: Keras with AdamW",
    )
    feature_lines = (
        "\n✨ ENHANCED ANALYSIS FEATURES:",
        "   • Dual accuracy zones: Orange cross (±1 margins) + Green intersection",
        "   • Comprehensive statistics: All metrics from original quadrant analysis",
        "   • Error percentage calculations: 10% accuracy zone analysis",
        "   • Toggleable player types: Reduce data density when needed",
        "   • Sample player examples: Shows which players fall in accuracy zones",
        "   • Year-over-year consistency: Temporal accuracy patterns",
        "   • Faceted visualization: Player type × Model matrix view",
        "   • Animation-ready: Seamless temporal visualization",
    )
    usage_tip_lines = (
        "\n💡 USAGE TIPS:",
        "   • Use show_hitters=False to focus on pitchers only",
        "   • Use show_pitchers=False to focus on hitters only",
        "   • Toggle accuracy zone buttons in faceted plot for different views",
        "   • Animation shows prediction evolution over seasons",
    )
    for summary_line in completion_lines + feature_lines + usage_tip_lines:
        print(summary_line)

except Exception as e:
    # Report the failure in the cell output: short message first, then the
    # full traceback. The exception is not re-raised.
    import traceback
    print(f"\n❌ Error: {e}")
    traceback.print_exc()