## Chicago Traffic Crash Analysis & Prediction


#### Part 1: Setup & Environment Configuration

# %%

# Data manipulation
import numpy as np
import pandas as pd
from datetime import datetime
import time
import os
import pickle

# Visualization libraries
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Machine learning libraries
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score, StratifiedKFold, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report, roc_curve, precision_recall_curve, 
    auc, average_precision_score, confusion_matrix,
    mean_absolute_error, mean_squared_error, r2_score, median_absolute_error,
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)
from sklearn.feature_selection import mutual_info_classif, RFE
from sklearn.inspection import permutation_importance

# ML Models - Classification
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC, LinearSVC
from sklearn.neural_network import MLPClassifier 
from sklearn.naive_bayes import GaussianNB
from sklearn.dummy import DummyClassifier

# ML Models - Regression
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet, PoissonRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.dummy import DummyRegressor

# Advanced ML libraries
import lightgbm as lgb
import xgboost as xgb
import catboost as cb
import tensorflow as tf
import shap
from folium.plugins import HeatMap
import folium

# Constants
# Fixed seed value; presumably passed as `random_state` further down the
# notebook (no consumer is visible in this part of the file).
random_seed = 45665456
# Global plotting defaults: the seaborn-v0_8 white-grid matplotlib style sheet
# and 'viridis' as seaborn's default palette.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('viridis')
# Colormap lookup via the registry (replaces the deprecated plt.cm.get_cmap)
def get_viridis_colormap(n=10):
    """Fetch the 'viridis' colormap from matplotlib's ``plt.colormaps`` registry.

    ``n`` is accepted for call compatibility but is not used by the lookup.
    """
    viridis_cmap = plt.colormaps['viridis']
    return viridis_cmap

#### Part 2: Viridis Color Standardization

# %%
# === Enhanced Color Standardization with Lighter Palette ===
def get_viridis_colors(n_colors=10, start=0.2, end=0.8):
    """
    Sample ``n_colors`` evenly spaced colors from the viridis colormap.

    ``start`` and ``end`` bound the sampled stretch of the colormap
    (inclusive on both sides); the default 0.2-0.8 window skips both ends of
    the scale. Returns a list with one ``plt.cm.viridis`` color per sample.
    """
    sample_points = np.linspace(start, end, n_colors)
    return list(map(plt.cm.viridis, sample_points))

# Define standard colors for consistent usage with lighter shades
# Shared palette lookup for the plotting helpers in this file. The single-color
# entries are hex strings; the three 'categorical*' entries are lists of ten
# colors sampled from the viridis colormap by get_viridis_colors().
VIRIDIS_COLORS = {
    # Primary palette (much lighter versions)
    'main': '#6a50a7',       # Light purple
    'secondary': '#52c2c4',  # Light teal
    'tertiary': '#8ddc6e',   # Light green
    'highlight': '#fee04c',  # Light yellow
    
    # Very light shades for fills and backgrounds
    'light_main': '#b9a5dd',     # Very light purple
    'light_secondary': '#a8e7e8', # Very light teal
    'light_tertiary': '#c9edb0',  # Very light green
    'light_highlight': '#fff3a6', # Very light yellow
    
    # Extra light background colors
    'bg_main': '#f3effa',     # Extra light purple
    'bg_secondary': '#f0fafa', # Extra light teal
    'bg_tertiary': '#f7fcf1',  # Extra light green
    'bg_highlight': '#fffbec', # Extra light yellow
    
    # Categorical color palettes (all lighter)
    # Each call narrows the sampled colormap window: 0.2-0.8, 0.3-0.7, 0.4-0.7.
    'categorical': get_viridis_colors(10, 0.2, 0.8),  # Light palette
    'categorical_light': get_viridis_colors(10, 0.3, 0.7),  # Even lighter
    'categorical_pastel': get_viridis_colors(10, 0.4, 0.7)  # Pastel palette
}

# Function to create an alpha version of any color
def with_alpha(color, alpha=0.5):
    """Return ``color`` as an ``(r, g, b, alpha)`` tuple with the given opacity.

    Parameters
    ----------
    color : str or sequence
        * ``'#rrggbb'`` (or longer, e.g. ``'#rrggbbaa'``) hex strings are
          parsed directly; only the first six hex digits are used.
        * Any other string (named color, short hex such as ``'#abc'``) is
          resolved through ``matplotlib.colors.to_rgba``.
        * Any other indexable sequence (RGB/RGBA tuple, list, array row)
          contributes its first three components unchanged.
    alpha : float, default=0.5
        Value placed in the fourth slot of the result.

    Returns
    -------
    tuple
        ``(r, g, b, alpha)``. Channels parsed from hex are scaled to 0-1.
    """
    if isinstance(color, str):
        if color.startswith('#') and len(color) >= 7:
            # Two hex digits per channel, scaled from 0-255 to 0-1
            r, g, b = (int(color[i:i + 2], 16) / 255.0 for i in (1, 3, 5))
            return (r, g, b, alpha)
        # Named colors / short hex: slicing the string would yield characters,
        # so let matplotlib resolve it to numeric channels.
        return (*mcolors.to_rgba(color)[:3], alpha)

    # Non-string input is taken to be an RGB(A) sequence. The type check above
    # matters: tuples have no ``startswith``, so testing the prefix first
    # would raise before this branch could ever run.
    return (*color[:3], alpha)

# Standardize matplotlib settings with improved defaults
plt.rcParams.update({
    # Figure geometry and text sizes
    'figure.figsize': (10, 6),
    'font.size': 12,
    'axes.titlesize': 16,
    'axes.labelsize': 13,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    # Backgrounds: nearly white axes on a white figure
    'axes.facecolor': '#fcfcfc',
    'figure.facecolor': 'white',
    # Faint grid, no top/right spines
    'axes.grid': True,
    'grid.alpha': 0.2,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Per-axes styling helper (background, grid and spine treatment)
def apply_viridis_style(ax=None):
    """Style an axes with the light house look and hand it back.

    When ``ax`` is omitted the current axes (``plt.gca()``) is styled. The
    helper sets a near-white background, a faint dashed grid, hides the top
    and right spines and colors the left and bottom spines light gray.
    """
    target = plt.gca() if ax is None else ax

    # Background and grid
    target.set_facecolor('#fcfcfc')
    target.grid(True, alpha=0.2, linestyle='--', color='#cccccc')

    # Spines: drop top/right, soften left/bottom
    for side in ('top', 'right'):
        target.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        target.spines[side].set_color('#dddddd')

    return target

#### Part 3: Data Loading & Preprocessing

# %%
# === Data Loading ===
def load_data(file_path):
    """Read the crash CSV at ``file_path`` and parse its date columns.

    'CRASH_DATE' and 'DATE_POLICE_NOTIFIED' are converted to datetimes when
    present; values that cannot be parsed become NaT. Progress is printed
    before and after loading. Returns the resulting DataFrame.
    """
    print(f"Loading data from {file_path}...")
    crashes = pd.read_csv(file_path)

    # Only touch the date columns that actually exist in this file
    present_date_cols = [
        name for name in ('CRASH_DATE', 'DATE_POLICE_NOTIFIED')
        if name in crashes.columns
    ]
    for name in present_date_cols:
        crashes[name] = pd.to_datetime(crashes[name], errors='coerce')

    n_rows, n_cols = len(crashes), crashes.columns.size
    print(f"Loaded {n_rows} records with {n_cols} columns")
    return crashes

# %%
# === Preprocessing Functions ===
def preprocess_data(df):
    """Derive the time, severity and condition features used by the analysis.

    Works on a copy; the caller's frame is left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw crash records. After column names are upper-cased and spaces
        replaced by underscores it must provide CRASH_DATE, CRASH_HOUR,
        INJURIES_TOTAL, INJURIES_FATAL, FIRST_CRASH_TYPE, WEATHER_CONDITION,
        ROADWAY_SURFACE_COND and INTERSECTION_RELATED_I.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with these columns added: CRASH_DATETIME, YEAR, MONTH,
        DAY, HOUR, DAY_OF_WEEK (day name), SEVERE, TIME_OF_DAY, SEASON,
        IS_WEEKEND, IS_NIGHTTIME, IS_VULNERABLE, BAD_WEATHER, BAD_SURFACE and
        AT_INTERSECTION. Missing values in the injury-count columns that are
        present are replaced by 0.
    """
    # Make a copy to avoid modifying the original
    df_processed = df.copy()

    # 1. Standardize column names
    df_processed.columns = df_processed.columns.str.upper().str.replace(' ', '_')

    # 2. Build the crash timestamp at hour resolution.
    #    CRASH_DATE is snapped to midnight first: when it already carries a
    #    time of day, adding CRASH_HOUR on top would count the hour twice and
    #    push afternoon crashes into the next day. For date-only values the
    #    snap is a no-op. Unparseable dates become NaT.
    crash_day = pd.to_datetime(df_processed['CRASH_DATE'], errors='coerce').dt.normalize()
    df_processed['CRASH_DATETIME'] = crash_day + pd.to_timedelta(
        df_processed['CRASH_HOUR'], unit='h')

    # 3. Derive time-based features
    crash_dt = df_processed['CRASH_DATETIME'].dt
    df_processed['YEAR'] = crash_dt.year
    df_processed['MONTH'] = crash_dt.month
    df_processed['DAY'] = crash_dt.day
    df_processed['HOUR'] = crash_dt.hour
    df_processed['DAY_OF_WEEK'] = crash_dt.day_name()

    # 4. Create severity flag (NaN counts compare False, i.e. not severe)
    df_processed['SEVERE'] = ((df_processed['INJURIES_TOTAL'] > 0) |
                             (df_processed['INJURIES_FATAL'] > 0)).astype(int)

    # 5. Create time-of-day category
    df_processed['TIME_OF_DAY'] = df_processed['HOUR'].apply(get_time_of_day)

    # 6. Create seasonal category
    df_processed['SEASON'] = df_processed['MONTH'].apply(get_season)

    # 7. Additional flags for analysis
    #    The weekend flag is taken from the derived day name rather than the
    #    numeric CRASH_DAY_OF_WEEK code, so it does not depend on how that
    #    code is numbered (the source data uses Sunday=1, under which the
    #    codes 6 and 7 are Friday and Saturday).
    df_processed['IS_WEEKEND'] = df_processed['DAY_OF_WEEK'].isin(
        ['Saturday', 'Sunday']).astype(int)
    df_processed['IS_NIGHTTIME'] = ((df_processed['HOUR'] < 6) |
                                   (df_processed['HOUR'] >= 20)).astype(int)
    df_processed['IS_VULNERABLE'] = df_processed['FIRST_CRASH_TYPE'].isin(
        ['PEDESTRIAN', 'PEDALCYCLIST']).astype(int)

    # 8. Weather conditions
    bad_weather = ['RAIN', 'SNOW', 'SLEET', 'FREEZING RAIN']
    df_processed['BAD_WEATHER'] = df_processed['WEATHER_CONDITION'].isin(
        bad_weather).astype(int)

    # 9. Surface conditions
    bad_surface = ['ICE', 'SNOW OR SLUSH', 'WET']
    df_processed['BAD_SURFACE'] = df_processed['ROADWAY_SURFACE_COND'].isin(
        bad_surface).astype(int)

    # 10. Intersection flag
    df_processed['AT_INTERSECTION'] = (
        df_processed['INTERSECTION_RELATED_I'] == 'Y').astype(int)

    # 11. Fill common NA values
    for col in ['INJURIES_TOTAL', 'INJURIES_FATAL', 'INJURIES_INCAPACITATING',
                'INJURIES_NON_INCAPACITATING', 'INJURIES_REPORTED_NOT_EVIDENT']:
        if col in df_processed.columns:
            df_processed[col] = df_processed[col].fillna(0)

    return df_processed


def get_time_of_day(hour):
    """Convert hour to time of day category.

    5-11 -> 'Morning', 12-16 -> 'Afternoon', 17-20 -> 'Evening'; every other
    value, including NaN (all comparisons are False), falls through to 'Night'.
    """
    if 5 <= hour < 12:
        return 'Morning'
    elif 12 <= hour < 17:
        return 'Afternoon'
    elif 17 <= hour < 21:
        return 'Evening'
    else:
        return 'Night'


def get_season(month):
    """Convert month to season.

    12/1/2 -> 'Winter', 3-5 -> 'Spring', 6-8 -> 'Summer'; every other value,
    including NaN, is reported as 'Fall'.
    """
    if month in [12, 1, 2]:
        return 'Winter'
    elif month in [3, 4, 5]:
        return 'Spring'
    elif month in [6, 7, 8]:
        return 'Summer'
    else:
        return 'Fall'

#### Part 4: Missing Data Analysis

# %%
# === Missing Data Analysis ===
def analyze_missing_data(df, plot=True):
    """Summarize missing values per column and optionally chart them.

    Returns a DataFrame indexed by column name with 'missing_count' and
    'missing_pct' (percent of rows, rounded to 2 decimals), restricted to
    columns that have at least one missing value and ordered from most to
    least missing. A summary is printed; with ``plot=True`` and a non-empty
    result, the 20 worst columns are drawn as a bar chart.
    """
    # Per-column null totals, keeping only columns that actually have gaps
    null_totals = df.isnull().sum()
    null_totals = null_totals.loc[null_totals > 0].sort_values(ascending=False)
    null_share = (null_totals / len(df) * 100).round(2)

    missing_df = pd.DataFrame({
        'missing_count': null_totals,
        'missing_pct': null_share
    })

    # Console summary
    print(f"Found {len(missing_df)} columns with missing values")
    print("\nTop columns with missing values:")
    print(missing_df.head(10))

    should_plot = plot and not missing_df.empty
    if not should_plot:
        return missing_df

    # Bar chart of the 20 columns with the highest missing percentage
    plt.figure(figsize=(12, 6))
    worst_columns = missing_df['missing_pct'].head(20)
    worst_columns.plot(kind='bar', color=VIRIDIS_COLORS['main'])
    plt.title('Top Columns by Percentage of Missing Values', fontsize=14)
    plt.ylabel('Percent Missing', fontsize=12)
    plt.xlabel('Column', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return missing_df

#### Part 5: EDA Visualizations

# %%

# === Temporal Analysis Functions ===
def plot_crashes_by_hour(df, color_palette=None):
    """Plot crash counts for each hour of the day as a bar chart.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'CRASH_HOUR' or, failing that, 'HOUR' (hours 0-23).
    color_palette : sequence of colors, optional
        Bar colors, one per hour starting at hour 0. Defaults to 24 evenly
        spaced viridis colors.

    Every hour 0-23 gets a bar, including hours without crashes, so each
    color and tick position always belongs to the hour it represents. The
    chart also marks the hourly average, the peak hour and two shaded
    rush-hour windows (7-9 and 16-18), then calls ``plt.show()``.
    """
    hour_col = 'CRASH_HOUR' if 'CRASH_HOUR' in df.columns else 'HOUR'
    # Reindex to the full day: value_counts() alone omits hours with no
    # crashes, which would shift the per-hour colors and skew the average.
    hour_counts = df[hour_col].value_counts().reindex(range(24), fill_value=0)

    # Average crashes per hour of day
    average = hour_counts.mean()

    # Create color gradient
    if color_palette is None:
        # Updated syntax for Matplotlib 3.7+
        viridis = plt.colormaps['viridis']  # Updated from plt.cm.get_cmap('viridis', 24)
        colors = [viridis(i / 23) for i in range(24)]
    else:
        colors = color_palette

    plt.figure(figsize=(12, 6))
    plt.bar(hour_counts.index, hour_counts.values, width=0.7, color=colors)

    # Add average line
    plt.axhline(y=average, color='#555555', linestyle='--', linewidth=1.5)
    plt.text(23, average * 1.05, f'Average: {average:.0f}',
             ha='right', va='bottom', color='black', fontsize=10,
             bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.3'))

    # Add peak annotation
    peak_hour = hour_counts.idxmax()
    peak_value = hour_counts.max()
    plt.annotate(f'Peak: {peak_value:,} crashes',
                xy=(peak_hour, peak_value),
                xytext=(peak_hour - 3, peak_value * 0.8),
                arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"),
                bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7))

    # Highlight rush hours
    rush_periods = [
        {"period": "Morning Rush", "start": 7, "end": 9},
        {"period": "Evening Rush", "start": 16, "end": 18}
    ]
    for period in rush_periods:
        plt.axvspan(period["start"], period["end"], alpha=0.2, color='gray')
        plt.text((period["start"] + period["end"]) / 2, peak_value * 0.9,
                period["period"], ha='center', fontsize=10)

    # Customize plot
    plt.title('Crash Counts by Hour of Day', fontsize=14, pad=20)
    plt.xlabel('Hour of Day (24h)', fontsize=12)
    plt.ylabel('Number of Crashes', fontsize=12)
    plt.xticks(range(24))
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()
# %%
def plot_crashes_by_day_of_week(df):
    """Plot crash counts per weekday (Monday to Sunday) as a labeled bar chart.

    Parameters
    ----------
    df : pandas.DataFrame
        Either a non-numeric 'DAY_OF_WEEK' column holding English day names,
        or a numeric 'CRASH_DAY_OF_WEEK' column using the Chicago Data Portal
        encoding (1 = Sunday ... 7 = Saturday).

    All seven days are always drawn (zero for days without crashes), the
    weekend is shaded, and the figure is shown with ``plt.show()``.
    """
    # Display order: Monday first, weekend last
    days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']

    # Check if we have string day names or numeric values
    has_day_names = ('DAY_OF_WEEK' in df.columns
                     and not pd.api.types.is_numeric_dtype(df['DAY_OF_WEEK']))
    if has_day_names:
        day_names = df['DAY_OF_WEEK'].astype(object)
    else:
        # The dataset numbers days from Sunday=1, so translate codes to names
        # instead of assuming the sorted codes run Monday..Sunday.
        code_to_day = {1: 'Sunday', 2: 'Monday', 3: 'Tuesday', 4: 'Wednesday',
                       5: 'Thursday', 6: 'Friday', 7: 'Saturday'}
        day_names = df['CRASH_DAY_OF_WEEK'].map(code_to_day)

    # Reindex so a day with no crashes still gets a (zero-height) bar;
    # otherwise the seven labels and the counts would differ in length.
    dow_counts = day_names.value_counts().reindex(days, fill_value=0)

    # Create colors from the viridis palette
    colors = get_viridis_colors(7)

    plt.figure(figsize=(10, 6))

    # Create the bar chart
    bars = plt.bar(days, dow_counts.values, color=colors)

    # Add data labels on top of each bar
    label_offset = dow_counts.max() * 0.01
    for count, bar in zip(dow_counts, bars):
        plt.text(bar.get_x() + bar.get_width()/2., count + label_offset,
                f"{count:,}", ha='center', va='bottom', fontsize=10)

    # Highlight weekends (bar positions 5 and 6 = Saturday, Sunday)
    plt.axvspan(4.5, 6.5, alpha=0.2, color='gray')
    plt.text(5.5, dow_counts.max() * 0.9, "Weekend", ha='center',
            bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7))

    # Customize the plot
    plt.title('Crashes by Day of Week', fontsize=14, pad=20)
    plt.xlabel('Day of Week', fontsize=12)
    plt.ylabel('Number of Crashes', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

# %%
def plot_monthly_trend(df, year=None):
    """Plot monthly crash trends with seasonal highlights.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a 'MONTH' column (1-12) and, when ``year`` is given, a 'YEAR'
        column.
    year : int, optional
        Restrict the chart to this year; all rows are used when omitted.

    Draws one bar per calendar month (zero for months without crashes) with
    count labels, shaded season bands and the monthly average, then calls
    ``plt.show()``.
    """
    # Filter by year if specified (no copy needed: the frame is only read)
    df_year = df[df['YEAR'] == year] if year is not None else df

    # Group by month and make sure all 12 months are represented. Reindexing
    # keeps the counts as integers, so labels read "1,234" rather than
    # "1,234.0" when some months have no rows.
    monthly_counts = df_year.groupby('MONTH').size().reindex(range(1, 13), fill_value=0)

    # Month names
    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                  'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    # Create viridis color gradient
    colors = plt.cm.viridis(np.linspace(0, 1, 12))

    plt.figure(figsize=(12, 6))

    # Plot the bars
    plt.bar(range(1, 13), monthly_counts.values, color=colors)

    # Add data labels on top of each bar
    tallest = monthly_counts.max()
    for month, count in monthly_counts.items():
        plt.text(month, count + tallest * 0.02,
                f"{count:,}", ha='center', fontsize=10)

    # Add seasonal bands (x positions are month numbers; bars sit on integers)
    seasons = [
        {"name": "Winter", "start": 0.5, "end": 2.5, "color": "#e6f2ff"},
        {"name": "Spring", "start": 2.5, "end": 5.5, "color": "#e6ffe6"},
        {"name": "Summer", "start": 5.5, "end": 8.5, "color": "#ffebcc"},
        {"name": "Fall", "start": 8.5, "end": 11.5, "color": "#f2e6ff"},
        {"name": "Winter", "start": 11.5, "end": 12.5, "color": "#e6f2ff"}
    ]

    for season in seasons:
        plt.axvspan(season["start"], season["end"], alpha=0.2, color=season["color"])
        # Only add label if season spans enough space (skips the December sliver)
        if season["end"] - season["start"] > 1:
            plt.text((season["start"] + season["end"]) / 2, tallest * 0.9,
                    season["name"], ha='center', fontsize=10,
                    bbox=dict(boxstyle="round,pad=0.2", fc="white", alpha=0.7))

    # Add average line
    avg_crashes = monthly_counts.mean()
    plt.axhline(y=avg_crashes, color='#333333', linestyle='--', alpha=0.7)
    plt.text(12, avg_crashes * 1.05, f"Monthly Average: {avg_crashes:,.0f}",
            ha='right', fontsize=10,
            bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7))

    # Customize the plot
    plt.title(f'Monthly Crash Counts {f"({year})" if year else ""}', fontsize=14, pad=20)
    plt.xlabel('Month', fontsize=12)
    plt.ylabel('Number of Crashes', fontsize=12)
    plt.xticks(range(1, 13), month_names)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

# %%
def plot_temporal_heatmap(df):
    """Create a heatmap of crashes by hour and day of week.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'CRASH_RECORD_ID' and 'HOUR', plus either a non-numeric
        'DAY_OF_WEEK' column of English day names or a numeric
        'CRASH_DAY_OF_WEEK' column using the Chicago Data Portal encoding
        (1 = Sunday ... 7 = Saturday).

    The grid always has 24 hour rows and 7 day columns ordered Monday to
    Sunday, so the fixed annotation coordinates below point at the intended
    cells. Cell values count non-null CRASH_RECORD_ID entries. The figure is
    shown with ``plt.show()``.
    """
    days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']

    # Work on a narrow copy so the caller's frame is not modified
    df_temp = df[['CRASH_RECORD_ID', 'HOUR']].copy()
    has_day_names = ('DAY_OF_WEEK' in df.columns
                     and not pd.api.types.is_numeric_dtype(df['DAY_OF_WEEK']))
    if has_day_names:
        df_temp['DAY_NAME'] = df['DAY_OF_WEEK'].astype(object)
    else:
        # The dataset numbers days from Sunday=1; translate codes to names
        code_to_day = {1: 'Sunday', 2: 'Monday', 3: 'Tuesday', 4: 'Wednesday',
                       5: 'Thursday', 6: 'Friday', 7: 'Saturday'}
        df_temp['DAY_NAME'] = df['CRASH_DAY_OF_WEEK'].map(code_to_day)

    # Create pivot table
    pivot = pd.pivot_table(
        df_temp, values='CRASH_RECORD_ID', index='HOUR',
        columns='DAY_NAME', aggfunc='count', fill_value=0
    )

    # Ensure we have all hours and all days, in display order
    pivot = pivot.reindex(index=range(0, 24), columns=days, fill_value=0)

    plt.figure(figsize=(12, 8))
    # Create heatmap with improved colormap - updated syntax
    sns.heatmap(
        pivot,
        cmap='viridis',  # Using string name instead of plt.cm.viridis
        linewidths=0.5,
        annot=False,  # Too cluttered with annotations
        fmt='d',
        cbar_kws={'label': 'Number of Crashes'}
    )

    # Customize the plot
    plt.title('Crashes by Hour and Day of Week', fontsize=16, pad=20)
    plt.xlabel('Day of Week', fontsize=12)
    plt.ylabel('Hour of Day (24h)', fontsize=12)

    # Add annotations for key patterns
    # (x = column position with Monday at 0, y = hour row)
    # Morning rush hour
    plt.annotate(
        'Morning\nRush Hour',
        xy=(3.5, 8), xytext=(3.5, 6),
        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"),
        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7),
        ha='center', va='center'
    )

    # Evening rush hour
    plt.annotate(
        'Evening\nRush Hour',
        xy=(3.5, 17), xytext=(3.5, 20),
        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-.2"),
        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7),
        ha='center', va='center'
    )

    # Weekend nighttime
    plt.annotate(
        'Weekend\nNight Activity',
        xy=(5.5, 1), xytext=(4, 3),
        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"),
        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7),
        ha='center', va='center'
    )

    plt.tight_layout()
    plt.show()

# %%
def plot_crash_types(df, n=10):
    """Draw a horizontal bar chart of the ``n`` most frequent crash types.

    Counts come from the 'FIRST_CRASH_TYPE' column; the most frequent type is
    placed at the top, each bar is labeled with its count, and the figure is
    shown with ``plt.show()``.
    """
    top_types = df['FIRST_CRASH_TYPE'].value_counts().nlargest(n)

    plt.figure(figsize=(12, 6))

    # One viridis shade per bar
    gradient = plt.cm.viridis(np.linspace(0, 0.8, len(top_types)))

    # barh stacks upward, so feed the data smallest-first to end with the
    # largest category on top
    bars = plt.barh(
        y=top_types.index[::-1],
        width=top_types.values[::-1],
        color=gradient
    )

    # Count label just past the end of each bar
    label_gap = top_types.max() * 0.01
    for bar in bars:
        bar_length = bar.get_width()
        bar_middle = bar.get_y() + bar.get_height()/2
        plt.text(
            bar_length + label_gap,
            bar_middle,
            f"{bar_length:,}",
            ha='left', va='center', fontsize=10
        )

    # Titles, axis labels and grid
    plt.title(f'Top {n} Crash Types', fontsize=14, pad=20)
    plt.xlabel('Number of Crashes', fontsize=12)
    plt.ylabel('Crash Type', fontsize=12)
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

# %%
def plot_severity_by_factor(df, factor_col, top_n=None, normalize=True):
    """
    Plot severity proportions by a given factor.
    
    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing the data; must include ``factor_col`` and a
        'SEVERE' column holding 0 (non-severe) / 1 (severe)
    factor_col : str
        Column name to group by
    top_n : int, optional
        Limit to top N categories by count
    normalize : bool, default=True
        Whether to normalize to show proportions instead of counts

    Returns:
    --------
    pandas.DataFrame
        The cross-tabulation that was plotted: one row per category, sorted
        by the severe column in descending order, with column 0 before
        column 1.
    """
    # Get factor counts
    factor_counts = df[factor_col].value_counts()
    
    # Limit to top N if specified
    if top_n is not None:
        top_factors = factor_counts.nlargest(top_n).index
        df_plot = df[df[factor_col].isin(top_factors)]
    else:
        df_plot = df
    
    # Create cross-tabulation
    if normalize:
        ct = pd.crosstab(df_plot[factor_col], df_plot['SEVERE'], normalize='index')
        ylabel = 'Proportion'
    else:
        ct = pd.crosstab(df_plot[factor_col], df_plot['SEVERE'])
        ylabel = 'Count'
    
    # Ensure we have both severity classes
    for severity_class in (0, 1):
        if severity_class not in ct.columns:
            ct[severity_class] = 0
    # A class added above is appended last; restore ascending column order so
    # the positional colors and legend labels below ("Non-Severe" first,
    # "Severe" second) stay attached to the right class.
    ct = ct[sorted(ct.columns)]
    
    # Sort by severity proportion
    ct = ct.sort_values(by=1, ascending=False)
    
    # Plot
    plt.figure(figsize=(12, 6))
    
    # Create the stacked bar chart with viridis colors
    ct.plot(
        kind='bar', 
        stacked=True, 
        color=[VIRIDIS_COLORS['tertiary'], VIRIDIS_COLORS['main']], 
        ax=plt.gca()
    )
    
    # Customize the plot
    plt.title(f'Crash Severity by {factor_col.replace("_", " ").title()}', fontsize=14, pad=20)
    plt.xlabel(factor_col.replace('_', ' ').title(), fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.legend(['Non-Severe', 'Severe'], title='Severity')
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    return ct

# %%
def plot_map_hotspots(df, zoom=10):
    """
    Create an interactive map showing crash hotspots.
    Requires plotly for interactive visualization.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'LATITUDE', 'LONGITUDE' and 'SEVERE' columns.
    zoom : int, default=10
        Initial map zoom level passed to plotly.

    Returns
    -------
    plotly figure
        Density map weighted by 'SEVERE' and centered on the mean coordinate
        of the plotted rows. The figure is returned, not shown.
    """
    # Filter records with valid coordinates
    map_df = df.dropna(subset=['LATITUDE', 'LONGITUDE'])
    
    # Check for invalid coordinates (sometimes 0,0 is used for missing)
    map_df = map_df[(map_df['LATITUDE'] != 0) & (map_df['LONGITUDE'] != 0)]
    
    # No per-row color column is built here: the density layer is colored by
    # the continuous scale below, and writing a column onto the filtered
    # slice only triggered pandas' SettingWithCopyWarning.
    
    # Create interactive map
    fig = px.density_mapbox(
        map_df, 
        lat='LATITUDE', 
        lon='LONGITUDE', 
        z='SEVERE',
        radius=10,
        center=dict(lat=map_df['LATITUDE'].mean(), lon=map_df['LONGITUDE'].mean()),
        zoom=zoom,
        mapbox_style="carto-positron",
        title='Crash Hotspots by Severity',
        opacity=0.7,
        color_continuous_scale='viridis'
    )
    
    fig.update_layout(
        margin={"r":0,"t":50,"l":0,"b":0},
        coloraxis_colorbar=dict(
            title="Severity Density",
            thicknessmode="pixels", thickness=20,
            lenmode="pixels", len=300
        )
    )
    
    return fig

def plot_injuries_by_speed_bin(df):
    """Plot boxplot of injuries by speed limit bin.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'INJURIES_TOTAL' and either a ready-made 'SPD_BIN' column or a
        'POSTED_SPEED_LIMIT' column from which the bins are derived. The
        caller's frame is not modified.

    Only crashes with at least one injury are plotted. The figure is shown
    with ``plt.show()``.
    """
    if 'SPD_BIN' not in df.columns and 'POSTED_SPEED_LIMIT' in df.columns:
        # Create speed bins if they don't exist. Bins are left-closed
        # ([20, 30) etc.) and the last one is open-ended so that "60+" really
        # covers every limit of 60 and above instead of stopping at 100.
        bins = [0, 20, 30, 40, 50, 60, np.inf]
        labels = ['<20', '20–30', '30–40', '40–50', '50–60', '60+']
        df = df.copy()
        df['SPD_BIN'] = pd.cut(df['POSTED_SPEED_LIMIT'], bins=bins, labels=labels, right=False)
    
    plt.figure(figsize=(8, 4))
    # Filter to only incidents with injuries
    injured_df = df[df['INJURIES_TOTAL'] > 0]
    
    # Create boxplot
    sns.boxplot(x='SPD_BIN', y='INJURIES_TOTAL', data=injured_df, 
                palette=VIRIDIS_COLORS['categorical'])
    
    # Customize plot
    plt.xlabel('Speed Limit Bin (mph)', fontsize=12)
    plt.ylabel('Total Injuries', fontsize=12)
    plt.title('Injuries by Speed Limit Bin', fontsize=14, pad=20)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()
    
def plot_injury_rate_by_crash_type(df, top_n=10):
    """Plot injury rates by crash type.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'FIRST_CRASH_TYPE' and either an 'injury_flag' column (0/1) or
        'INJURIES_TOTAL', from which the flag is derived on a copy.
    top_n : int or None, default=10
        Keep only the ``top_n`` crash types with the highest injury rate;
        ``None`` keeps all of them.

    Returns
    -------
    pandas.Series or None
        Injury rate (fraction of crashes with at least one injury) per crash
        type, highest first. ``None`` is returned, after printing a message,
        when the required columns are missing.
    """
    if 'FIRST_CRASH_TYPE' not in df.columns:
        print("Required columns not found in DataFrame")
        return

    if 'injury_flag' not in df.columns:
        if 'INJURIES_TOTAL' not in df.columns:
            print("Required columns not found in DataFrame")
            return
        # The flag can be derived, so proceed quietly: the "not found"
        # message is reserved for the cases where nothing can be plotted.
        df = df.copy()
        df['injury_flag'] = (df['INJURIES_TOTAL'] > 0).astype(int)
    
    # Compute injury rate by crash type
    injury_by_type = (
        df
        .groupby('FIRST_CRASH_TYPE')['injury_flag']
        .mean()
        .sort_values(ascending=False)
    )
    
    # Limit to top N types
    if top_n is not None:
        injury_by_type = injury_by_type.head(top_n)
    
    plt.figure(figsize=(10, 5))
    # Horizontal bars colored with seaborn's 'rocket' palette
    sns.barplot(
        x=injury_by_type.values * 100,  # convert to percent
        y=injury_by_type.index,
        palette='rocket'
    )
    
    # Customize plot
    plt.xlabel('Percent of Crashes with ≥1 Injury', fontsize=12)
    plt.ylabel('Crash Type', fontsize=12)
    plt.title('Injury Rate by First Crash Type', fontsize=14, pad=20)
    plt.xlim(0, injury_by_type.max() * 100 + 5)
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    return injury_by_type

def plot_injury_rate_by_cause(df, top_n=10):
    """Chart the injury rate of the ``top_n`` most frequent contributory causes.

    Causes are ranked by how often they occur in 'PRIM_CONTRIBUTORY_CAUSE';
    for those, the mean of 'injury_flag' is computed (the flag is derived
    from 'INJURIES_TOTAL' on a copy when it is absent). Returns the rates as
    a Series sorted from highest to lowest, or ``None`` (after printing a
    message) when the cause column is missing.
    """
    cause_col = 'PRIM_CONTRIBUTORY_CAUSE'
    if cause_col not in df.columns:
        print("Primary contributory cause column not found")
        return

    # Derive the 0/1 injury flag when the frame does not carry one yet
    flag_missing = 'injury_flag' not in df.columns
    if flag_missing and 'INJURIES_TOTAL' in df.columns:
        df = df.copy()
        df['injury_flag'] = (df['INJURIES_TOTAL'] > 0).astype(int)

    # Restrict to the most frequent causes, then average the flag per cause
    frequent_causes = df[cause_col].value_counts().nlargest(top_n).index
    frequent_rows = df[df[cause_col].isin(frequent_causes)]
    rate_per_cause = frequent_rows.groupby(cause_col)['injury_flag'].mean()
    injury_by_cause = rate_per_cause.sort_values(ascending=False)

    plt.figure(figsize=(10, 5))
    sns.barplot(
        x=injury_by_cause.values * 100,
        y=injury_by_cause.index,
        palette='mako'
    )

    # Labels, limits and grid
    plt.xlabel('Percent of Crashes with ≥1 Injury', fontsize=12)
    plt.ylabel('Primary Contributory Cause', fontsize=12)
    plt.title(f'Injury Rate by Top {top_n} Contributory Causes', fontsize=14, pad=20)
    plt.xlim(0, injury_by_cause.max() * 100 + 5)
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return injury_by_cause
def plot_injury_severity_by_weather(df, top_n=6):
    """Show how injury severity is distributed within each weather condition.

    Uses the ``top_n`` most common values of 'WEATHER_CONDITION' and plots,
    for each, the share of every 'MOST_SEVERE_INJURY' level as a horizontal
    stacked bar. Returns the row-normalized cross-tabulation, or ``None``
    (after printing a message) when either column is missing.
    """
    required = ('WEATHER_CONDITION', 'MOST_SEVERE_INJURY')
    if any(name not in df.columns for name in required):
        print("Required columns not found in DataFrame")
        return

    # Keep only rows belonging to the most common weather conditions
    common_weather = df['WEATHER_CONDITION'].value_counts().nlargest(top_n).index.tolist()
    subset = df[df['WEATHER_CONDITION'].isin(common_weather)]

    # Row-normalized table: each weather condition sums to 1
    ct = pd.crosstab(
        subset['WEATHER_CONDITION'],
        subset['MOST_SEVERE_INJURY'],
        normalize='index'
    )

    # Order rows by fatal share when that level exists
    if 'FATAL' in ct.columns:
        ct = ct.sort_values(by='FATAL', ascending=False)

    # One viridis shade per severity level
    segment_colors = sns.color_palette("viridis", len(ct.columns))

    ax = ct.plot(
        kind='barh',
        stacked=True,
        figsize=(12, 8),
        color=segment_colors,
        linewidth=0.5,
        edgecolor='white'  # thin separator between stacked segments
    )

    # Percentage label in the middle of every segment of at least 5%
    for row_pos, shares in enumerate(ct.values):
        left_edge = 0
        for share in shares:
            if share >= 0.05:
                ax.text(left_edge + share/2, row_pos, f'{share:.0%}',
                        ha='center', va='center', color='white', fontweight='bold')
            left_edge += share

    # Labels, legend and grid
    plt.xlabel('Proportion of Crashes', fontsize=12)
    plt.ylabel('Weather Condition', fontsize=12)
    plt.title('Crash Severity Distribution by Weather Condition', fontsize=14, pad=20)
    plt.legend(title='Most Severe Injury', bbox_to_anchor=(1.05, 1.0))
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return ct

# === visualization for time series patterns ===
def _binary_flag_counts(flag):
    """Count the 0 and 1 values of a flag Series, always returning both.

    Missing values are ignored. A class that never occurs gets a count of 0,
    so the result always has exactly two entries, ordered [0, 1].
    """
    # assumes the flag holds 0/1 or boolean values — TODO confirm for flag
    # columns that were created outside this function
    observed = flag.dropna().astype(int).value_counts()
    return observed.reindex([0, 1], fill_value=0)


def _draw_unavailable(ax, message):
    """Write a centred placeholder message on an otherwise empty panel."""
    ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)


def _draw_count_bars(ax, counts, tick_labels, colors, title, xlabel,
                     label_fontsize=12, annotate=True):
    """Draw one bar per entry of `counts` and label it with count and share."""
    positions = list(range(len(counts)))
    ax.bar(positions, counts.values, color=colors)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('Number of Crashes', fontsize=12)
    # Pin the tick positions before naming them so labels always line up
    # with their bars.
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)

    if not annotate:
        return
    total = counts.sum()
    for pos, count in zip(positions, counts.values):
        if count == 0:
            continue  # an empty bar has no room for a label
        ax.text(pos, count / 2, f'{count:,}\n({count / total * 100:.1f}%)',
                ha='center', va='center', fontsize=label_fontsize,
                color='white', fontweight='bold')


def _draw_flag_panel(ax, df, column, colors, tick_labels, title, xlabel,
                     missing_message):
    """Draw a two-bar panel for a 0/1 flag column, or a placeholder if absent."""
    if column is None:
        _draw_unavailable(ax, missing_message)
        return
    _draw_count_bars(ax, _binary_flag_counts(df[column]), tick_labels, colors,
                     title, xlabel)


def create_time_period_visualizations(df):
    """
    Draw a 2x2 grid of crash counts by weekend, rush hour, night and speed bin.

    Each panel uses an existing flag column when one is present and otherwise
    derives it from a raw column. A panel whose inputs are all missing shows
    a "not available" message instead of bars.

    Note that derived columns ('is_weekend', 'is_rush_hour', 'night_flag',
    'SPD_BIN') are written into `df` itself, so the caller's DataFrame gains
    those columns.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data

    Returns:
    --------
    matplotlib.figure.Figure
        The figure holding the four panels
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # Two contrasting viridis shades for the binary panels
    two_colors = [plt.cm.viridis(0.2), plt.cm.viridis(0.7)]

    # 1. Weekend vs weekday
    if 'is_weekend' in df.columns:
        weekend_col = 'is_weekend'
    elif 'IS_WEEKEND' in df.columns:
        weekend_col = 'IS_WEEKEND'
    elif 'CRASH_DAY_OF_WEEK' in df.columns:
        # NOTE(review): this treats 5 and 6 as the weekend, while
        # plot_crashes_by_weekday_donut reads the same column as 1-7.
        # Confirm which encoding the data actually uses.
        df['is_weekend'] = df['CRASH_DAY_OF_WEEK'].isin([5, 6]).astype(int)
        weekend_col = 'is_weekend'
    elif 'DAY_OF_WEEK' in df.columns:
        df['is_weekend'] = df['DAY_OF_WEEK'].isin(['Saturday', 'Sunday']).astype(int)
        weekend_col = 'is_weekend'
    else:
        weekend_col = None

    _draw_flag_panel(
        axes[0, 0], df, weekend_col, two_colors, ['Weekday', 'Weekend'],
        'Crashes by Day Type', 'Weekend Flag', "Weekend data not available"
    )

    # 2. Rush hour vs off hours (07-09 and 16-18 inclusive count as rush hour)
    if 'is_rush_hour' in df.columns:
        rush_hour_col = 'is_rush_hour'
    elif 'IS_RUSH_HOUR' in df.columns:
        rush_hour_col = 'IS_RUSH_HOUR'
    else:
        hour_source = next(
            (name for name in ('CRASH_HOUR', 'HOUR') if name in df.columns), None
        )
        if hour_source is None:
            rush_hour_col = None
        else:
            df['is_rush_hour'] = (
                df[hour_source].between(7, 9) |
                df[hour_source].between(16, 18)
            ).astype(int)
            rush_hour_col = 'is_rush_hour'

    _draw_flag_panel(
        axes[0, 1], df, rush_hour_col, two_colors, ['Off Hours', 'Rush Hour'],
        'Crashes by Time of Day', 'Rush Hour Flag', "Rush hour data not available"
    )

    # 3. Night vs daylight
    if 'night_flag' in df.columns:
        night_col = 'night_flag'
    elif 'IS_NIGHTTIME' in df.columns:
        night_col = 'IS_NIGHTTIME'
    elif 'LIGHTING_CONDITION' in df.columns:
        # Anything that is not exactly 'DAYLIGHT' (including missing values)
        # is counted as night/dark.
        df['night_flag'] = df['LIGHTING_CONDITION'].ne('DAYLIGHT').astype(int)
        night_col = 'night_flag'
    elif 'CRASH_HOUR' in df.columns:
        df['night_flag'] = ((df['CRASH_HOUR'] < 6) | (df['CRASH_HOUR'] >= 20)).astype(int)
        night_col = 'night_flag'
    elif 'HOUR' in df.columns:
        df['night_flag'] = ((df['HOUR'] < 6) | (df['HOUR'] >= 20)).astype(int)
        night_col = 'night_flag'
    else:
        night_col = None

    _draw_flag_panel(
        axes[1, 0], df, night_col, two_colors, ['Daylight', 'Night/Dark'],
        'Crashes by Lighting Condition', 'Night Flag', "Night/day data not available"
    )

    # 4. Speed-limit bins
    if 'SPD_BIN' in df.columns:
        speed_bin_col = 'SPD_BIN'
    elif 'POSTED_SPEED_LIMIT' in df.columns:
        bins = [0, 20, 30, 40, 50, 60, 100]
        labels = ['<20', '20-30', '30-40', '40-50', '50-60', '60+']
        # Left-closed bins: a limit of exactly 30 falls into '30-40'.
        df['SPD_BIN'] = pd.cut(df['POSTED_SPEED_LIMIT'], bins=bins, labels=labels, right=False)
        speed_bin_col = 'SPD_BIN'
    else:
        speed_bin_col = None

    if speed_bin_col and not df[speed_bin_col].isna().all():
        speed_counts = df[speed_bin_col].value_counts().sort_index()
        speed_colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(speed_counts)))
        _draw_count_bars(
            axes[1, 1], speed_counts,
            [str(label) for label in speed_counts.index], speed_colors,
            'Crashes by Speed Limit', 'Speed Limit (mph)',
            label_fontsize=10,
            annotate=len(speed_counts) <= 10,  # skip labels when too crowded
        )
    else:
        _draw_unavailable(axes[1, 1], "Speed bin data not available")

    # Shared cosmetics: light grid and a y-axis anchored at zero
    for ax in axes.flatten():
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)

    fig.suptitle('Crash Distribution by Key Factors', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)
    plt.show()

    return fig
    
# === Static hexbin map of crash density ===
def plot_hex_crash_density(df):
    """
    Draw a log-scaled hexbin map of crash density inside a Chicago bounding box.

    NOTE: a second definition of plot_hex_crash_density appears later in this
    file and replaces this one when the file is executed top to bottom.

    Rows are kept only when 41.5 < LATITUDE < 42.1 and
    -88.0 < LONGITUDE < -87.4; rows with missing coordinates fail those
    comparisons and are dropped. Hexagons containing fewer than 5 crashes are
    not drawn.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame with 'LATITUDE' and 'LONGITUDE' columns

    Returns:
    --------
    None
        A message is printed instead of plotting when the coordinate columns
        are missing or no row falls inside the bounding box.
    """
    if 'LATITUDE' not in df.columns or 'LONGITUDE' not in df.columns:
        print("Coordinates not found in DataFrame")
        return

    # Filter to valid bounds (for Chicago)
    latitude = df['LATITUDE']
    longitude = df['LONGITUDE']
    in_bounds = (
        (latitude > 41.5) & (latitude < 42.1) &
        (longitude > -88.0) & (longitude < -87.4)
    )
    geo = df[in_bounds]

    # With no points the logarithmic colour scale has nothing to scale to,
    # so stop before creating a figure.
    if geo.empty:
        print("No crashes found within the Chicago bounding box")
        return

    # Use matplotlib colors for normalization
    import matplotlib.colors as mcolors

    plt.figure(figsize=(10, 10))

    hexbins = plt.hexbin(
        geo['LONGITUDE'],
        geo['LATITUDE'],
        gridsize=50,                       # Adjust grid size
        mincnt=5,                          # Only show bins with ≥5 crashes
        cmap='magma',                      # Darker colormap
        norm=mcolors.LogNorm(),            # Logarithmic color scaling
        linewidths=0.2,                    # Thin lines between hexagons
        edgecolors='black',                # Black edges for better contrast
        alpha=1.0                          # Full opacity for rich colors
    )

    # Named `colorbar` rather than `cb` so it does not reuse the alias the
    # file gives to catboost.
    colorbar = plt.colorbar(hexbins, fraction=0.046, pad=0.04)
    colorbar.set_label('Crash Count (log scale)', fontsize=12)
    colorbar.ax.tick_params(labelsize=10)

    plt.title('Crash Density Map of Chicago', fontsize=16, pad=20)
    plt.xlabel('Longitude', fontsize=12)
    plt.ylabel('Latitude', fontsize=12)

    # Summary box: the total counts every in-bounds row, including rows in
    # hexagons hidden by mincnt.
    summary_text = f"Total Crashes: {len(geo):,}\nArea: Chicago City Limits"
    props = dict(boxstyle='round', facecolor='white', alpha=0.7)
    plt.annotate(summary_text, xy=(0.05, 0.05), xycoords='axes fraction',
                 fontsize=10, bbox=props)

    plt.tight_layout()
    plt.show()

def plot_crashes_by_weekday_donut(df):
    """
    Create a donut chart showing crash distribution by day of week.

    The day is taken from the first of these columns that exists:
    'CRASH_WEEKDAY' (numbers 0-6, 0 = Mon), 'DAY_OF_WEEK' (day names) or
    'CRASH_DAY_OF_WEEK' (numbers 1-7, 1 = Mon). Values that cannot be mapped
    to a weekday are ignored.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data

    Returns:
    --------
    matplotlib.figure.Figure or None
        The figure, or None when no usable day-of-week data is found
    """
    weekday_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']

    def number_to_name(value, first_day_number):
        """Map a day number to its abbreviation, or None when out of range."""
        if pd.isna(value):
            return None
        try:
            # int() so float-typed columns (common once NaN is present) can
            # still be used as a list index.
            position = int(value) - first_day_number
        except (TypeError, ValueError):
            return None
        return weekday_names[position] if 0 <= position < 7 else None

    if 'CRASH_WEEKDAY' in df.columns:
        day_labels = df['CRASH_WEEKDAY'].map(lambda v: number_to_name(v, 0))
    elif 'DAY_OF_WEEK' in df.columns:
        # Reduce names to their first three letters so both 'Saturday' and
        # 'Sat' (in any letter case) match the abbreviations used here.
        day_labels = (
            df['DAY_OF_WEEK'].dropna().astype(str)
            .str.strip().str[:3].str.title()
        )
    elif 'CRASH_DAY_OF_WEEK' in df.columns:
        # NOTE(review): treats 1 as Monday — confirm against the data source.
        day_labels = df['CRASH_DAY_OF_WEEK'].map(lambda v: number_to_name(v, 1))
    else:
        print("No day of week column found")
        return None

    # Reindexing keeps Mon..Sun order and drops unrecognised labels; days
    # with no crashes become 0.
    by_wd = day_labels.value_counts().reindex(weekday_names).fillna(0)

    if (by_wd == 0).all():
        print("Warning: Not enough valid day of week data for donut chart")
        return None

    fig, ax = plt.subplots(figsize=(6, 6))

    ax.pie(
        by_wd.values,
        labels=by_wd.index,
        autopct='%1.1f%%',
        startangle=90,
        pctdistance=0.75,
        colors=plt.cm.viridis(np.linspace(0.1, 0.9, len(by_wd)))
    )

    # A white disc over the centre turns the pie into a donut
    ax.add_artist(plt.Circle((0, 0), 0.50, fc='white'))

    # Equal aspect ratio keeps the chart circular
    ax.axis('equal')

    ax.set_title('Proportion of Crashes by Day of Week')
    plt.tight_layout()
    plt.show()

    return fig
    
def analyze_seasonal_patterns(df, column='CRASH_DATE'):
    """
    Analyze seasonal patterns using time series decomposition.

    Crashes are counted per calendar month, the monthly series is split into
    trend, seasonal and residual parts with an additive model (period 12),
    and two figures are shown: the four components, and the average seasonal
    effect per calendar month.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data
    column : str
        Name of datetime column to use. When it is not present,
        'CRASH_DATETIME' is used instead if available.

    Returns:
    --------
    tuple
        (decomposition, monthly_pattern): the object returned by
        seasonal_decompose and a DataFrame with columns 'month' and
        'seasonal_factor'

    Raises:
    -------
    KeyError
        If no usable datetime column is found
    ValueError
        If the data span fewer than 24 months (two full yearly cycles)
    """
    period = 12  # months per seasonal cycle

    # Fall back only when the requested column itself is missing
    if column not in df.columns and 'CRASH_DATETIME' in df.columns:
        column = 'CRASH_DATETIME'
    if column not in df.columns:
        raise KeyError(f"Datetime column '{column}' not found in DataFrame")

    # Month-end bins. An offset object is used instead of the 'M' alias,
    # which newer pandas versions deprecate. Months without any crash get a
    # count of 0, so no gap filling is needed afterwards.
    monthly_crashes = df.set_index(column).resample(pd.offsets.MonthEnd()).size()

    if len(monthly_crashes) < 2 * period:
        raise ValueError(
            f"Seasonal decomposition needs at least {2 * period} months of data, "
            f"got {len(monthly_crashes)}"
        )

    # Imported here so the input checks above do not depend on statsmodels
    from statsmodels.tsa.seasonal import seasonal_decompose

    decomposition = seasonal_decompose(monthly_crashes, model='additive', period=period)

    # One stacked panel per component
    components = [
        (decomposition.observed, 'Observed'),
        (decomposition.trend, 'Trend'),
        (decomposition.seasonal, 'Seasonality'),
        (decomposition.resid, 'Residuals'),
    ]
    fig, axes = plt.subplots(len(components), 1, figsize=(12, 12))
    for ax, (series, title) in zip(axes, components):
        series.plot(ax=ax)
        ax.set_title(title, fontsize=14)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Average seasonal component per calendar month
    seasonal = decomposition.seasonal
    monthly_pattern = pd.DataFrame({
        'month': range(1, 13),
        'seasonal_factor': seasonal.groupby(seasonal.index.month).mean()
    })

    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    plt.figure(figsize=(10, 6))
    plt.bar(month_names, monthly_pattern['seasonal_factor'],
            color=plt.cm.viridis(np.linspace(0, 1, 12)))

    plt.title('Average Seasonal Effect by Month', fontsize=14)
    plt.ylabel('Seasonal Factor (crashes above/below trend)', fontsize=12)
    plt.grid(axis='y', alpha=0.3)

    # Anchor each label to the end of its bar: above positive bars, below
    # negative ones. This works for any data scale, unlike a fixed offset.
    for position, factor in enumerate(monthly_pattern['seasonal_factor']):
        plt.text(position, factor, f'{factor:.1f}', ha='center',
                 va='bottom' if factor >= 0 else 'top', fontsize=10)

    plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return decomposition, monthly_pattern

def analyze_weather_seasonal_effects(df):
    """
    Analyze how weather conditions vary across seasons and affect crash rates.

    Shows a stacked bar chart of the five most common weather conditions as a
    share of each season's crashes. When a 'SEVERE' column exists, a heatmap
    of the mean of 'SEVERE' (as a percentage) by season and weather is shown
    as well.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data with weather information

    Returns:
    --------
    tuple or None
        (weather_props, severity_pivot). weather_props holds the share of
        every weather condition per season; severity_pivot is None when the
        DataFrame has no 'SEVERE' column. None is returned when
        'WEATHER_CONDITION' or 'SEASON' is missing.
    """
    if 'WEATHER_CONDITION' not in df.columns or 'SEASON' not in df.columns:
        print("Required columns not found in DataFrame")
        return

    # Share of each weather condition within a season (rows sum to 1)
    weather_by_season = pd.crosstab(df['SEASON'], df['WEATHER_CONDITION'])
    weather_props = weather_by_season.div(weather_by_season.sum(axis=1), axis=0)

    # Keep only the most common conditions for readability. reindex (rather
    # than [] selection) tolerates a top condition that never appears
    # together with a known season.
    top_weather = df['WEATHER_CONDITION'].value_counts().nlargest(5).index
    weather_props_filtered = weather_props.reindex(columns=top_weather, fill_value=0)

    # Draw on an explicit axes. DataFrame.plot would otherwise open its own
    # figure and leave a previously created one empty.
    fig, ax = plt.subplots(figsize=(12, 6))
    weather_props_filtered.plot(
        kind='bar',
        stacked=True,
        colormap='viridis',
        ax=ax
    )

    ax.set_title('Weather Conditions by Season', fontsize=14, pad=20)
    ax.set_xlabel('Season', fontsize=12)
    ax.set_ylabel('Proportion', fontsize=12)
    ax.grid(axis='y', alpha=0.3)
    ax.tick_params(axis='x', rotation=0)
    ax.legend(title='Weather Condition', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

    # Severity rates by season and weather (only when 'SEVERE' is available)
    severity_pivot = None
    if 'SEVERE' in df.columns:
        severity_pivot = df.pivot_table(
            values='SEVERE',
            index='SEASON',
            columns='WEATHER_CONDITION',
            aggfunc='mean'
        )

        # A top condition without any usable 'SEVERE' value shows up as an
        # empty column instead of raising.
        severity_pivot_filtered = severity_pivot.reindex(columns=top_weather)

        plt.figure(figsize=(12, 8))
        sns.heatmap(
            severity_pivot_filtered * 100,  # Convert to percentage
            annot=True,
            fmt='.1f',
            cmap='viridis',
            linewidths=0.5,
            cbar_kws={'label': 'Severe Crash Percentage'}
        )

        plt.title('Crash Severity Rate (%) by Season and Weather', fontsize=14, pad=20)
        plt.tight_layout()
        plt.show()

    return weather_props, severity_pivot

def analyze_hourly_patterns_by_season(df):
    """
    Analyze how hourly crash patterns change across seasons.

    Shows a line chart of each season's crashes by hour (as a share of that
    season's total) and a bar chart comparing the share of crashes in the
    morning rush (07-09), evening rush (16-18) and night hours (00-04, 22-23)
    per season.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data

    Returns:
    --------
    tuple or None
        (hourly_season, rush_effects): crash counts by hour and season, and
        the per-season share of crashes in each time period. None is returned
        when the hour column ('CRASH_HOUR' or 'HOUR') or 'SEASON' is missing.
    """
    if 'CRASH_HOUR' in df.columns:
        hour_col = 'CRASH_HOUR'
    elif 'HOUR' in df.columns:
        hour_col = 'HOUR'
    else:
        print("Hour column not found in DataFrame")
        return

    if 'SEASON' not in df.columns:
        print("Season column not found in DataFrame")
        return

    # Counts by hour (rows) and season (columns)
    hourly_season = pd.crosstab(df[hour_col], df['SEASON'])

    # Divide by each season's total so seasons of different size compare fairly
    season_totals = hourly_season.sum()
    hourly_season_pct = hourly_season.div(season_totals)

    fig, ax = plt.subplots(figsize=(12, 6))

    # One viridis shade per season
    season_colors = plt.cm.viridis(np.linspace(0, 1, len(hourly_season_pct.columns)))
    for shade, season in zip(season_colors, hourly_season_pct.columns):
        ax.plot(
            hourly_season_pct.index,
            hourly_season_pct[season],
            label=season,
            color=shade,
            linewidth=2,
            marker='o',
            markersize=4
        )

    # Shaded reference periods; the leading underscore keeps them out of the
    # legend.
    # NOTE(review): the shaded night bands (0-5 and 21-24) are wider than the
    # hours summed for 'Night Time' below (0-4, 22, 23) — confirm which
    # definition is intended.
    ax.axvspan(7, 9, alpha=0.1, color='gray', label='_Morning Rush')
    ax.axvspan(16, 18, alpha=0.1, color='gray', label='_Evening Rush')
    ax.axvspan(0, 5, alpha=0.1, color='darkblue', label='_Night')
    ax.axvspan(21, 24, alpha=0.1, color='darkblue', label='_Night')

    # Rush-hour captions sit just below the tallest point of any line
    caption_height = hourly_season_pct.max().max() * 0.95
    for hour, caption in ((8, "Morning Rush"), (17, "Evening Rush")):
        ax.text(hour, caption_height, caption,
                ha='center', va='top', bbox=dict(facecolor='white', alpha=0.7))

    ax.set_title('Hourly Crash Patterns by Season', fontsize=14, pad=20)
    ax.set_xlabel('Hour of Day', fontsize=12)
    ax.set_ylabel('Proportion of Daily Crashes', fontsize=12)
    ax.set_xticks(range(0, 24, 2))
    ax.set_xlim(-0.5, 23.5)
    ax.grid(True, alpha=0.3)
    ax.legend(title='Season')
    plt.tight_layout()
    plt.show()

    def share_of(hours):
        """Share of each season's crashes that fall in the given hours."""
        # reindex treats an hour with no recorded crash as 0 instead of
        # raising KeyError as .loc[list] would.
        return hourly_season.reindex(list(hours), fill_value=0).sum() / season_totals

    rush_effects = pd.DataFrame({
        'Morning Rush': share_of(range(7, 10)),
        'Evening Rush': share_of(range(16, 19)),
        'Night Time': share_of([0, 1, 2, 3, 4, 22, 23])
    })

    # Draw on an explicit axes so no empty extra figure is left behind
    fig, ax = plt.subplots(figsize=(10, 6))
    rush_effects.plot(kind='bar', colormap='viridis', ax=ax)
    ax.set_title('Time Period Distribution by Season', fontsize=14, pad=20)
    ax.set_xlabel('Season', fontsize=12)
    ax.set_ylabel('Proportion of Crashes', fontsize=12)
    ax.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return hourly_season, rush_effects

def plot_monthly_yoy_comparison(df, date_col='CRASH_DATE', years=None):
    """
    Compare monthly crash patterns across different years.

    Shows one line per year of crashes by calendar month and, when at least
    two years are plotted, a heatmap of the percent change of each month
    against the same month of the previous plotted year.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data
    date_col : str
        Name of datetime column. 'CRASH_DATETIME' is used when this column
        is not present.
    years : list, optional
        List of years to include in the comparison. Defaults to the five
        most recent years found in the data.

    Returns:
    --------
    pandas.DataFrame or None
        Crash counts with months 1-12 as rows and years as columns. A month
        missing from every plotted year is 0; a month missing from only some
        years stays NaN (drawn as a gap). None is returned when no date
        column, no valid date, or none of the requested years is found.
    """
    if date_col not in df.columns:
        if 'CRASH_DATETIME' not in df.columns:
            print("No date column found in DataFrame")
            return
        date_col = 'CRASH_DATETIME'

    # Work on the dates only (no copy of the whole frame) and drop missing
    # or unparseable ones so years stay integers.
    dates = pd.to_datetime(df[date_col], errors='coerce').dropna()
    if dates.empty:
        print("No valid dates found in DataFrame")
        return

    calendar = pd.DataFrame({
        'crash_year': dates.dt.year,
        'crash_month': dates.dt.month
    })
    available_years = sorted(int(year) for year in calendar['crash_year'].unique())

    if years is not None:
        plot_years = [y for y in years if y in available_years]
        if not plot_years:
            print(f"None of the specified years {years} found in data.")
            print(f"Available years: {available_years}")
            return
    else:
        # Latest years (up to 5)
        plot_years = available_years[-5:]

    # Months as rows, years as columns
    selected = calendar[calendar['crash_year'].isin(plot_years)]
    monthly_counts = (
        selected.groupby(['crash_year', 'crash_month']).size().unstack(level=0)
    )

    # Ensure all twelve months are represented
    full_index = pd.Index(range(1, 13), name='crash_month')
    monthly_counts = monthly_counts.reindex(full_index, fill_value=0)

    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    year_colors = plt.cm.viridis(np.linspace(0, 1, len(plot_years)))

    plt.figure(figsize=(12, 6))
    for shade, year in zip(year_colors, plot_years):
        plt.plot(
            range(1, 13),
            monthly_counts[year],
            label=str(year),
            color=shade,
            marker='o',
            linewidth=2,
            markersize=6
        )

    # Background bands per season
    seasons = [
        {"name": "Winter", "months": [12, 1, 2], "color": "#e6f2ff"},
        {"name": "Spring", "months": [3, 4, 5], "color": "#e6ffe6"},
        {"name": "Summer", "months": [6, 7, 8], "color": "#ffebcc"},
        {"name": "Fall", "months": [9, 10, 11], "color": "#f2e6ff"}
    ]
    label_height = monthly_counts.max().max() * 0.95
    for season in seasons:
        for month in season["months"]:
            plt.axvspan(month - 0.5, month + 0.5, alpha=0.2, color=season["color"])

        # One caption per season, placed at its first listed month
        plt.text(
            season["months"][0],
            label_height,
            season["name"],
            ha='center',
            bbox=dict(facecolor='white', alpha=0.7)
        )

    plt.title('Monthly Crashes by Year', fontsize=14, pad=20)
    plt.xlabel('Month', fontsize=12)
    plt.ylabel('Number of Crashes', fontsize=12)
    plt.xticks(range(1, 13), month_names)
    plt.grid(True, alpha=0.3)
    plt.legend(title='Year')
    plt.tight_layout()
    plt.show()

    if len(plot_years) >= 2:
        # fill_method=None: a missing month must not be forward-filled from
        # the previous year, which would show up as a misleading 0 % change.
        pct_changes = monthly_counts.pct_change(axis=1, fill_method=None) * 100
        # A change from 0 crashes is infinite and would wreck the colour
        # scale, so it is shown as missing. The first year has nothing to
        # compare against and is dropped.
        pct_changes = pct_changes.replace([np.inf, -np.inf], np.nan).iloc[:, 1:]

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            pct_changes.T,  # Transpose for better visualization
            annot=True,
            fmt='.1f',
            cmap='RdBu_r',  # Red-Blue diverging colormap
            center=0,       # Center color map at zero
            linewidths=0.5,
            cbar_kws={'label': 'Percent Change (%)'}
        )

        plt.title('Year-over-Year Percent Change by Month', fontsize=14, pad=20)
        plt.xlabel('Month', fontsize=12)
        plt.ylabel('Year Transition', fontsize=12)
        plt.tight_layout()
        plt.show()

    return monthly_counts
def plot_injury_rate_by_crash_type(df, top_n=10):
    """
    Plot injury rates by crash type.

    The injury rate is the mean of 'injury_flag' per 'FIRST_CRASH_TYPE'. When
    'injury_flag' is missing it is derived as INJURIES_TOTAL > 0. The input
    DataFrame is not modified.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data
    top_n : int, optional
        Number of crash types to show, taken from those with the highest
        injury rate. None shows all types.

    Returns:
    --------
    pandas.Series or None
        Injury rate (0-1) per crash type, highest first. None is returned
        when the required columns are missing or there is nothing to plot.
    """
    if 'FIRST_CRASH_TYPE' not in df.columns:
        print("Required columns not found in DataFrame")
        return

    if 'injury_flag' in df.columns:
        injury_flag = df['injury_flag']
    elif 'INJURIES_TOTAL' in df.columns:
        # Derive the flag without copying or touching the caller's frame
        injury_flag = (df['INJURIES_TOTAL'] > 0).astype(int).rename('injury_flag')
    else:
        # Only report a problem when the flag really cannot be obtained
        print("Required columns not found in DataFrame")
        return

    injury_by_type = (
        injury_flag
        .groupby(df['FIRST_CRASH_TYPE'])
        .mean()
        .sort_values(ascending=False)
    )

    # Limit to top N types
    if top_n is not None:
        injury_by_type = injury_by_type.head(top_n)

    if injury_by_type.empty:
        print("No crash types available to plot")
        return

    plt.figure(figsize=(10, 5))
    sns.barplot(
        x=injury_by_type.values * 100,  # convert to percent
        y=injury_by_type.index,
        palette='rocket'
    )

    plt.xlabel('Percent of Crashes with ≥1 Injury', fontsize=12)
    plt.ylabel('Crash Type', fontsize=12)
    plt.title('Injury Rate by First Crash Type', fontsize=14, pad=20)
    # Leave a little room to the right of the longest bar
    plt.xlim(0, injury_by_type.max() * 100 + 5)
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return injury_by_type
def plot_injury_rate_by_cause(df, top_n=10):
    """
    Plot injury rates by primary contributory cause.

    Only the `top_n` most frequent causes are considered; for each, the
    injury rate is the mean of 'injury_flag'. When 'injury_flag' is missing
    it is derived as INJURIES_TOTAL > 0. The input DataFrame is not modified.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data
    top_n : int
        Number of most frequent causes to include

    Returns:
    --------
    pandas.Series or None
        Injury rate (0-1) per cause, highest first. None is returned when
        the cause column or the injury information is missing, or there is
        nothing to plot.
    """
    if 'PRIM_CONTRIBUTORY_CAUSE' not in df.columns:
        print("Primary contributory cause column not found")
        return

    if 'injury_flag' in df.columns:
        injury_flag = df['injury_flag']
    elif 'INJURIES_TOTAL' in df.columns:
        injury_flag = (df['INJURIES_TOTAL'] > 0).astype(int).rename('injury_flag')
    else:
        # Without either column the rate cannot be computed; report it
        # instead of failing later with a KeyError.
        print("Injury columns not found in DataFrame")
        return

    # Restrict to the most frequent causes first, then rank those by rate
    causes = df['PRIM_CONTRIBUTORY_CAUSE']
    top_causes = causes.value_counts().nlargest(top_n).index
    is_top_cause = causes.isin(top_causes)

    injury_by_cause = (
        injury_flag[is_top_cause]
        .groupby(causes[is_top_cause])
        .mean()
        .sort_values(ascending=False)
    )

    if injury_by_cause.empty:
        print("No contributory causes available to plot")
        return

    plt.figure(figsize=(10, 5))
    sns.barplot(
        x=injury_by_cause.values * 100,  # convert to percent
        y=injury_by_cause.index,
        palette='mako'
    )

    plt.xlabel('Percent of Crashes with ≥1 Injury', fontsize=12)
    plt.ylabel('Primary Contributory Cause', fontsize=12)
    plt.title(f'Injury Rate by Top {top_n} Contributory Causes', fontsize=14, pad=20)
    # Leave a little room to the right of the longest bar
    plt.xlim(0, injury_by_cause.max() * 100 + 5)
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return injury_by_cause
def plot_injury_severity_by_weather(df, top_n=6):
    """
    Plot proportions of injury severity levels per weather condition.

    For the `top_n` most common weather conditions, a horizontal stacked bar
    shows how that condition's crashes split across the values of
    'MOST_SEVERE_INJURY'. Segments of at least 5% are labelled.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame with 'WEATHER_CONDITION' and 'MOST_SEVERE_INJURY' columns
    top_n : int
        Number of most common weather conditions to include

    Returns:
    --------
    pandas.DataFrame or None
        Row-normalised cross-tab (each row sums to 1), sorted by the 'FATAL'
        share when that column exists. None is returned when a required
        column is missing or no row has both values.
    """
    required = ['WEATHER_CONDITION', 'MOST_SEVERE_INJURY']
    if any(name not in df.columns for name in required):
        print("Required columns not found in DataFrame")
        return

    # Keep only the most common weather conditions
    top_weather = df['WEATHER_CONDITION'].value_counts().nlargest(top_n).index.tolist()
    weather_filtered = df[df['WEATHER_CONDITION'].isin(top_weather)]

    # The cross-tab ignores rows missing either value; with none left there
    # is nothing to plot, so stop here instead of failing in the plot call.
    if weather_filtered[required].dropna().empty:
        print("No rows with both weather condition and injury severity to plot")
        return

    ct = pd.crosstab(
        weather_filtered['WEATHER_CONDITION'],
        weather_filtered['MOST_SEVERE_INJURY'],
        normalize='index'  # each row sums to 1 → proportions
    )

    # One viridis shade per severity level
    colors = sns.color_palette("viridis", len(ct.columns))

    # Order conditions by their share of fatal crashes when available
    if 'FATAL' in ct.columns:
        ct = ct.sort_values(by='FATAL', ascending=False)

    ax = ct.plot(
        kind='barh',
        stacked=True,
        figsize=(12, 8),
        color=colors,
        linewidth=0.5,
        edgecolor='white'  # thin white edge between segments
    )

    # Label each segment of at least 5% at its horizontal midpoint
    for row_position, shares in enumerate(ct.values):
        left_edge = 0
        for share in shares:
            if share >= 0.05:
                ax.text(left_edge + share / 2, row_position, f'{share:.0%}',
                        ha='center', va='center', color='white', fontweight='bold')
            left_edge += share

    ax.set_xlabel('Proportion of Crashes', fontsize=12)
    ax.set_ylabel('Weather Condition', fontsize=12)
    ax.set_title('Crash Severity Distribution by Weather Condition', fontsize=14, pad=20)
    ax.legend(title='Most Severe Injury', bbox_to_anchor=(1.05, 1.0))
    ax.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

    return ct
def plot_hex_crash_density(df):
    """
    Draw a log-scaled hexbin map of crash density inside a Chicago bounding box.

    NOTE: this is the second definition of plot_hex_crash_density in this
    file; it replaces the earlier one when the file is executed top to bottom.

    Rows are kept only when 41.5 < LATITUDE < 42.1 and
    -88.0 < LONGITUDE < -87.4; rows with missing coordinates fail those
    comparisons and are dropped. Hexagons containing fewer than 5 crashes are
    not drawn.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame with 'LATITUDE' and 'LONGITUDE' columns

    Returns:
    --------
    None
        A message is printed instead of plotting when the coordinate columns
        are missing or no row falls inside the bounding box.
    """
    if 'LATITUDE' not in df.columns or 'LONGITUDE' not in df.columns:
        print("Coordinates not found in DataFrame")
        return

    # Filter to valid bounds (for Chicago)
    lat = df['LATITUDE']
    lon = df['LONGITUDE']
    within_box = (lat > 41.5) & (lat < 42.1) & (lon > -88.0) & (lon < -87.4)
    geo = df[within_box]

    # Without any point the logarithmic colour scale cannot be set up, so
    # bail out before a figure is created.
    if geo.empty:
        print("No crashes found within the Chicago bounding box")
        return

    # Use matplotlib colors for normalization
    import matplotlib.colors as mcolors

    plt.figure(figsize=(10, 10))

    density = plt.hexbin(
        geo['LONGITUDE'],
        geo['LATITUDE'],
        gridsize=50,  # Adjust grid size
        mincnt=5,     # Only show bins with ≥5 crashes
        cmap='magma', # Darker colormap
        norm=mcolors.LogNorm(),  # Logarithmic color scaling
        linewidths=0.2,          # Thin lines between hexagons
        edgecolors='black',      # Black edges for better contrast
        alpha=1.0                # Full opacity for rich colors
    )

    # `count_bar` rather than `cb`, which the file uses as the catboost alias
    count_bar = plt.colorbar(density, fraction=0.046, pad=0.04)
    count_bar.set_label('Crash Count (log scale)', fontsize=12)
    count_bar.ax.tick_params(labelsize=10)

    plt.title('Crash Density Map of Chicago', fontsize=16, pad=20)
    plt.xlabel('Longitude', fontsize=12)
    plt.ylabel('Latitude', fontsize=12)

    # Summary box: the total counts every in-bounds row, including rows in
    # hexagons hidden by mincnt.
    summary_text = f"Total Crashes: {len(geo):,}\nArea: Chicago City Limits"
    props = dict(boxstyle='round', facecolor='white', alpha=0.7)
    plt.annotate(summary_text, xy=(0.05, 0.05), xycoords='axes fraction',
                 fontsize=10, bbox=props)

    plt.tight_layout()
    plt.show()
def analyze_seasonal_factors(df, include_weather=True):
    """
    Analyze seasonal factors affecting crash patterns.

    Aggregates crashes by month, draws a three-panel figure (crash counts,
    average injuries per crash, severity rate with optional weather overlays)
    and returns the monthly table.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data. 'CRASH_RECORD_ID', 'INJURIES_TOTAL'
        and either 'MONTH' or 'CRASH_MONTH' are required. 'SEVERE',
        'BAD_WEATHER', 'BAD_SURFACE' and 'IS_WEEKEND' are optional; a missing
        one is reported as 0.0 in the result.
    include_weather : bool
        Whether to include weather analysis

    Returns:
    --------
    pandas.DataFrame
        One row per month with the crash count, injury total, mean of each
        rate column, 'AVG_INJURIES_PER_CRASH' and 'MONTH_NAME'.
    """
    # Create a copy for analysis
    df_analysis = df.copy()

    # Make sure we have required columns
    if 'MONTH' not in df_analysis.columns and 'CRASH_MONTH' in df_analysis.columns:
        df_analysis['MONTH'] = df_analysis['CRASH_MONTH']

    # Only aggregate rate columns that exist: naming an absent column in the
    # agg spec makes pandas raise KeyError before any fallback could apply.
    rate_cols = ['SEVERE', 'BAD_WEATHER', 'BAD_SURFACE', 'IS_WEEKEND']
    agg_spec = {'CRASH_RECORD_ID': 'count', 'INJURIES_TOTAL': 'sum'}
    agg_spec.update({col: 'mean' for col in rate_cols if col in df_analysis.columns})

    # Group data by month
    monthly_stats = df_analysis.groupby('MONTH').agg(agg_spec).reset_index()

    # Keep a stable column layout; rate columns that were absent become 0.0
    monthly_stats = monthly_stats.reindex(
        columns=['MONTH', 'CRASH_RECORD_ID', 'INJURIES_TOTAL'] + rate_cols,
        fill_value=0.0
    )

    monthly_stats['AVG_INJURIES_PER_CRASH'] = (
        monthly_stats['INJURIES_TOTAL'] / monthly_stats['CRASH_RECORD_ID']
    )

    # Add month names (int() so float-typed month numbers still index the list)
    month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                   'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    monthly_stats['MONTH_NAME'] = monthly_stats['MONTH'].apply(
        lambda m: month_names[int(m) - 1]
    )

    # Create a multi-metric comparison plot
    fig, axes = plt.subplots(3, 1, figsize=(12, 12), sharex=True)
    count_ax, injury_ax, severity_ax = axes

    # Plot crash counts
    count_ax.bar(
        monthly_stats['MONTH'],
        monthly_stats['CRASH_RECORD_ID'],
        color=plt.cm.viridis(np.linspace(0, 1, 12))
    )
    count_ax.set_title('Monthly Crash Counts', fontsize=12)
    count_ax.set_ylabel('Number of Crashes', fontsize=10)
    count_ax.grid(True, alpha=0.3)

    # Plot injury metrics
    injury_ax.plot(
        monthly_stats['MONTH'],
        monthly_stats['AVG_INJURIES_PER_CRASH'],
        'o-',
        color=VIRIDIS_COLORS['main'],
        linewidth=2
    )
    injury_ax.set_title('Average Injuries per Crash', fontsize=12)
    injury_ax.set_ylabel('Avg Injuries', fontsize=10)
    injury_ax.grid(True, alpha=0.3)

    # Plot severity rate
    severity_pct = monthly_stats['SEVERE'] * 100
    severity_ax.plot(
        monthly_stats['MONTH'],
        severity_pct,
        'o-',
        color=VIRIDIS_COLORS['tertiary'],
        linewidth=2,
        label='Severity Rate'
    )
    severity_ax.set_title('Crash Severity Rate (%)', fontsize=12)
    severity_ax.set_ylabel('Percent Severe', fontsize=10)
    upper = min(100, severity_pct.max() * 1.2)
    # A zero-height axis (all rates 0) is degenerate; let matplotlib autoscale
    if upper > 0:
        severity_ax.set_ylim(0, upper)

    # Add weather factors only when the input actually carried them
    if include_weather and 'BAD_WEATHER' in df_analysis.columns:
        weather_ax = severity_ax.twinx()
        weather_ax.plot(
            monthly_stats['MONTH'],
            monthly_stats['BAD_WEATHER'] * 100,
            's--',
            color='darkblue',
            alpha=0.7,
            label='Bad Weather'
        )
        if 'BAD_SURFACE' in df_analysis.columns:
            weather_ax.plot(
                monthly_stats['MONTH'],
                monthly_stats['BAD_SURFACE'] * 100,
                '^--',
                color='darkred',
                alpha=0.7,
                label='Bad Surface'
            )
        weather_ax.set_ylabel('Weather Factors (%)', fontsize=10)
        weather_ax.legend(loc='upper right')

    severity_ax.grid(True, alpha=0.3)
    severity_ax.legend(loc='upper left')

    # Set shared x-axis labels
    severity_ax.set_xlabel('Month', fontsize=12)
    severity_ax.set_xticks(range(1, 13))
    severity_ax.set_xticklabels(month_names)

    # Add seasonal background; winter needs two spans because it wraps the year
    season_spans = [
        (0.5, 2.5, '#e6f2ff'),    # Winter (Jan-Feb)
        (11.5, 12.5, '#e6f2ff'),  # Winter (Dec)
        (2.5, 5.5, '#e6ffe6'),    # Spring (Mar-May)
        (5.5, 8.5, '#ffebcc'),    # Summer (Jun-Aug)
        (8.5, 11.5, '#f2e6ff'),   # Fall (Sep-Nov)
    ]
    for ax in axes:
        for start, end, shade in season_spans:
            ax.axvspan(start, end, alpha=0.1, color=shade)

    # Add season labels to top plot, anchored on the middle month of each season
    label_height = monthly_stats['CRASH_RECORD_ID'].max() * 0.9
    for season, month_pos in zip(['Winter', 'Spring', 'Summer', 'Fall'], [1, 4, 7, 10]):
        count_ax.text(
            month_pos,
            label_height,
            season,
            ha='center',
            bbox=dict(facecolor='white', alpha=0.7)
        )

    plt.suptitle('Seasonal Patterns in Crash Metrics', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)
    plt.show()

    return monthly_stats

#### Part 6: Feature Engineering & Selection

In [None]:
# === Feature Engineering & Selection ===
def select_features(X, y, methods=None):
    """
    Rank features with several selection techniques and pool the winners.

    Parameters:
    -----------
    X : pandas.DataFrame
        Feature matrix
    y : pandas.Series
        Target variable
    methods : list, optional
        Techniques to run: any of 'mutual_info', 'rfe', 'permutation'.
        All three run when omitted; other names are ignored.

    Returns:
    --------
    dict
        One entry per technique that ran ('mutual_info' and 'permutation'
        hold score Series sorted high to low, 'rfe' holds the list of kept
        feature names) plus 'consensus', the de-duplicated union of the top 5
        mutual-info features, every RFE feature and the top 5 permutation
        features, in that order.
    """
    active = ['mutual_info', 'rfe', 'permutation'] if methods is None else methods
    results = {}

    # Filter method: mutual information between each feature and the target
    if 'mutual_info' in active:
        mi_scores = pd.Series(
            mutual_info_classif(X, y, random_state=random_seed),
            index=X.columns,
        )
        results['mutual_info'] = mi_scores.sort_values(ascending=False)

    # Wrapper method: recursive elimination down to 10 features
    if 'rfe' in active:
        eliminator = RFE(
            LogisticRegression(solver='liblinear', max_iter=1000, random_state=random_seed),
            n_features_to_select=10,
            step=1,
        )
        eliminator.fit(X, y)
        kept = pd.Series(eliminator.support_, index=X.columns)
        results['rfe'] = kept[kept].index.tolist()

    # Permutation importance of a random forest, scored on the data it was fit on
    if 'permutation' in active:
        forest = RandomForestClassifier(n_estimators=100, random_state=random_seed, n_jobs=-1)
        forest.fit(X, y)
        shuffled = permutation_importance(
            forest, X, y, n_repeats=5, random_state=random_seed, n_jobs=-1
        )
        results['permutation'] = pd.Series(
            shuffled.importances_mean, index=X.columns
        ).sort_values(ascending=False)

    # Pool candidates in a fixed order: mutual info, RFE, permutation
    pooled = []
    for key in ('mutual_info', 'rfe', 'permutation'):
        if key not in results:
            continue
        if key == 'rfe':
            pooled.extend(results[key])
        else:
            pooled.extend(results[key].nlargest(5).index.tolist())

    # Drop repeats while keeping first-seen order
    seen = set()
    consensus = []
    for feature in pooled:
        if feature not in seen:
            seen.add(feature)
            consensus.append(feature)
    results['consensus'] = consensus

    return results

def engineer_features(df):
    """Create engineered features for analysis and modeling.

    Works on a copy of ``df``. Each feature is derived only when the columns
    it needs are present, so missing inputs are skipped without an error.
    Derived flags are stored as integers (0/1).

    Parameters:
    -----------
    df : pandas.DataFrame
        Crash records. 'CRASH_DATETIME', when present, must be a datetime
        column because the ``.dt`` accessor is used on it.

    Returns:
    --------
    pandas.DataFrame
        Copy of ``df`` with the engineered columns added. The index is kept
        as-is and may contain duplicate labels.
    """
    df_eng = df.copy()

    def has(*cols):
        """Return True when every named column currently exists on df_eng."""
        return all(col in df_eng.columns for col in cols)

    # 1. Original time-based flags
    # NOTE(review): [5, 6] means Saturday/Sunday only under a Monday=0
    # encoding (as produced by pandas dt.weekday). Confirm that both source
    # columns use that encoding; a Sunday=1..Saturday=7 column would need [1, 7].
    if has('CRASH_WEEKDAY'):
        df_eng['is_weekend'] = df_eng['CRASH_WEEKDAY'].isin([5, 6]).astype(int)
    elif has('CRASH_DAY_OF_WEEK'):
        df_eng['is_weekend'] = df_eng['CRASH_DAY_OF_WEEK'].isin([5, 6]).astype(int)

    if has('CRASH_HOUR'):
        # between() is inclusive: hours 7-9 and 16-18
        df_eng['is_rush_hour'] = (
            df_eng['CRASH_HOUR'].between(7, 9) |
            df_eng['CRASH_HOUR'].between(16, 18)
        ).astype(int)

    # 2. Original environmental conditions
    if has('LIGHTING_CONDITION'):
        # Anything other than 'DAYLIGHT' (including missing values) counts as night
        df_eng['night_flag'] = (~df_eng['LIGHTING_CONDITION'].eq('DAYLIGHT')).astype(int)

    # 3. Original speed limit bins (left-closed: 30 falls in '30-40')
    if has('POSTED_SPEED_LIMIT'):
        bins = [0, 20, 30, 40, 50, 60, 100]
        labels = ['<20', '20-30', '30-40', '40-50', '50-60', '60+']
        df_eng['SPD_BIN'] = pd.cut(df_eng['POSTED_SPEED_LIMIT'],
                                   bins=bins, labels=labels, right=False)

    # 4. Original target creation
    if has('INJURIES_TOTAL'):
        df_eng['injury_flag'] = (df_eng['INJURIES_TOTAL'] > 0).astype(int)

    # 5. Original weather conditions
    if has('WEATHER_CONDITION'):
        bad_weather = ['RAIN', 'SNOW', 'SLEET', 'FREEZING RAIN']
        df_eng['BAD_WEATHER'] = df_eng['WEATHER_CONDITION'].isin(bad_weather).astype(int)

    # 6. Original road surface conditions
    if has('ROADWAY_SURFACE_COND'):
        bad_surface = ['ICE', 'SNOW OR SLUSH', 'WET']
        df_eng['BAD_SURFACE'] = df_eng['ROADWAY_SURFACE_COND'].isin(bad_surface).astype(int)

    # 7. Original location type
    if has('FIRST_CRASH_TYPE'):
        df_eng['IS_VULNERABLE'] = df_eng['FIRST_CRASH_TYPE'].isin(
            ['PEDESTRIAN', 'PEDALCYCLIST']).astype(int)

    # NEW FEATURES BELOW
    # 8. Temporal interaction features
    if has('CRASH_HOUR', 'is_weekend'):
        # Weekend evening/night interaction
        df_eng['WEEKEND_NIGHT'] = (
            (df_eng['is_weekend'] == 1) &
            ((df_eng['CRASH_HOUR'] < 6) | (df_eng['CRASH_HOUR'] >= 20))
        ).astype(int)

        # Rush hour specific to weekdays (is_rush_hour was derived from
        # CRASH_HOUR above, so it is guaranteed to exist here)
        df_eng['WEEKDAY_RUSH_HOUR'] = (
            (df_eng['is_weekend'] == 0) &
            (df_eng['is_rush_hour'] == 1)
        ).astype(int)

    # 9. Weather and time interactions
    if has('BAD_WEATHER', 'night_flag'):
        # Bad weather during night
        df_eng['NIGHT_BAD_WEATHER'] = (
            (df_eng['BAD_WEATHER'] == 1) &
            (df_eng['night_flag'] == 1)
        ).astype(int)

        # Bad weather during rush hour
        if has('is_rush_hour'):
            df_eng['RUSH_HOUR_BAD_WEATHER'] = (
                (df_eng['BAD_WEATHER'] == 1) &
                (df_eng['is_rush_hour'] == 1)
            ).astype(int)

    # 10. Speed-related features
    if has('POSTED_SPEED_LIMIT'):
        # High speed indicator (over 40 mph)
        df_eng['HIGH_SPEED_ROAD'] = (df_eng['POSTED_SPEED_LIMIT'] > 40).astype(int)

        # Speed and weather interaction
        if has('BAD_WEATHER'):
            df_eng['HIGH_SPEED_BAD_WEATHER'] = (
                (df_eng['HIGH_SPEED_ROAD'] == 1) &
                (df_eng['BAD_WEATHER'] == 1)
            ).astype(int)

        # Speed and intersection interaction
        if has('AT_INTERSECTION'):
            df_eng['HIGH_SPEED_INTERSECTION'] = (
                (df_eng['HIGH_SPEED_ROAD'] == 1) &
                (df_eng['AT_INTERSECTION'] == 1)
            ).astype(int)

    # 11. Vehicle count and crash complexity
    if has('NUM_UNITS'):
        # Multi-vehicle crash (3+ vehicles)
        df_eng['MULTI_VEHICLE_CRASH'] = (df_eng['NUM_UNITS'] >= 3).astype(int)
        # Single vehicle crash
        df_eng['SINGLE_VEHICLE_CRASH'] = (df_eng['NUM_UNITS'] == 1).astype(int)

    # 12. Vulnerable road user features
    if has('IS_VULNERABLE'):
        # Vulnerable user in high-speed area
        if has('HIGH_SPEED_ROAD'):
            df_eng['VULNERABLE_HIGH_SPEED'] = (
                (df_eng['IS_VULNERABLE'] == 1) &
                (df_eng['HIGH_SPEED_ROAD'] == 1)
            ).astype(int)

        # Vulnerable user at night
        if has('night_flag'):
            df_eng['VULNERABLE_NIGHT'] = (
                (df_eng['IS_VULNERABLE'] == 1) &
                (df_eng['night_flag'] == 1)
            ).astype(int)

    # 13. Seasonal features
    if has('SEASON'):
        # One-hot encode season. Drop dummy columns left over from an earlier
        # run first, otherwise concat would create duplicate column names.
        season_dummies = pd.get_dummies(df_eng['SEASON'], prefix='SEASON')
        stale = [col for col in season_dummies.columns if col in df_eng.columns]
        df_eng = pd.concat([df_eng.drop(columns=stale), season_dummies], axis=1)

        # Winter conditions interaction (winter and bad surface)
        if has('BAD_SURFACE', 'SEASON_Winter'):
            df_eng['WINTER_BAD_SURFACE'] = (
                (df_eng['SEASON_Winter'] == 1) &
                (df_eng['BAD_SURFACE'] == 1)
            ).astype(int)

    # 14. Time since previous crash (if timestamp information is available)
    if has('CRASH_DATETIME'):
        # Work on row positions rather than index labels: writing the sorted
        # result back through label alignment fails when the index contains
        # duplicate labels (common after concatenating several files).
        crash_times = df_eng['CRASH_DATETIME'].reset_index(drop=True)
        gap_hours = (
            crash_times.sort_values(kind='mergesort').diff().dt.total_seconds() / 3600
        )
        # sort_index() restores the original row order; the earliest crash
        # (and any row without a timestamp) has no gap and is assumed 24 hours
        df_eng['TIME_SINCE_PREV_CRASH'] = gap_hours.sort_index().fillna(24).to_numpy()
        # Flag crashes that happened less than 2 hours after the previous one
        df_eng['RECENT_CRASH_AREA'] = (df_eng['TIME_SINCE_PREV_CRASH'] < 2).astype(int)

    # 15. Holiday indicators (optional - handle gracefully if holidays module not available)
    if has('CRASH_DATE'):
        try:
            import holidays
            us_holidays = holidays.US()
            # Check if crash occurred on a holiday
            df_eng['IS_HOLIDAY'] = df_eng['CRASH_DATE'].apply(lambda x: x in us_holidays).astype(int)
            # Check if crash occurred day before or after holiday
            df_eng['NEAR_HOLIDAY'] = df_eng['CRASH_DATE'].apply(
                lambda x: (x + pd.Timedelta(days=1) in us_holidays) or
                          (x - pd.Timedelta(days=1) in us_holidays)
            ).astype(int)
        except ImportError:
            print("Note: 'holidays' module not available - holiday features will not be created")
            df_eng['IS_HOLIDAY'] = 0  # Default values when module not available
            df_eng['NEAR_HOLIDAY'] = 0

    # 16. Road condition and vehicle interactions
    if has('ROADWAY_SURFACE_COND', 'NUM_UNITS', 'BAD_SURFACE'):
        # Two or more units on a bad road surface (note: a lower threshold
        # than MULTI_VEHICLE_CRASH, which requires 3+)
        df_eng['MULTI_VEHICLE_BAD_SURFACE'] = (
            (df_eng['NUM_UNITS'] >= 2) &
            (df_eng['BAD_SURFACE'] == 1)
        ).astype(int)

    # 17. Time of year and daylight interaction
    if has('CRASH_MONTH', 'night_flag'):
        # Winter months with darkness (November is included here)
        winter_months = [11, 12, 1, 2]
        df_eng['WINTER_DARKNESS'] = (
            (df_eng['CRASH_MONTH'].isin(winter_months)) &
            (df_eng['night_flag'] == 1)
        ).astype(int)

    # 18. Traffic signals and intersection features
    if has('TRAFFIC_CONTROL_DEVICE', 'AT_INTERSECTION'):
        # No traffic control at intersection
        no_signal = ['NONE', 'NO CONTROLS']
        df_eng['UNCONTROLLED_INTERSECTION'] = (
            (df_eng['AT_INTERSECTION'] == 1) &
            (df_eng['TRAFFIC_CONTROL_DEVICE'].isin(no_signal))
        ).astype(int)

    # 19. Age-based features if available
    if has('DRIVER_AGE'):
        # Young driver indicator
        df_eng['YOUNG_DRIVER'] = (df_eng['DRIVER_AGE'] < 25).astype(int)
        # Senior driver indicator
        df_eng['SENIOR_DRIVER'] = (df_eng['DRIVER_AGE'] >= 65).astype(int)

        # Young driver at night
        if has('night_flag'):
            df_eng['YOUNG_DRIVER_NIGHT'] = (
                (df_eng['YOUNG_DRIVER'] == 1) &
                (df_eng['night_flag'] == 1)
            ).astype(int)

    # 20. Alcohol and time features
    if has('ALCOHOL_INVOLVED', 'CRASH_HOUR'):
        # Late night alcohol crashes
        df_eng['LATE_NIGHT_ALCOHOL'] = (
            (df_eng['ALCOHOL_INVOLVED'] == 1) &
            ((df_eng['CRASH_HOUR'] >= 22) | (df_eng['CRASH_HOUR'] <= 4))
        ).astype(int)

        # Weekend alcohol crashes
        if has('is_weekend'):
            df_eng['WEEKEND_ALCOHOL'] = (
                (df_eng['ALCOHOL_INVOLVED'] == 1) &
                (df_eng['is_weekend'] == 1)
            ).astype(int)

    return df_eng

def analyze_model_with_shap(model, X_test, feature_names=None):
    """
    Analyze model predictions using SHAP values for interpretability.

    Parameters:
    -----------
    model : trained model instance
        The model to analyze
    X_test : pandas.DataFrame or numpy.ndarray
        Test feature matrix
    feature_names : list, optional
        List of feature names if X_test is not a DataFrame

    Returns:
    --------
    dict
        On success: 'shap_values' (2-D, samples x features), 'explainer' and
        'data_sample', plus 'feature_importance' (DataFrame sorted by mean
        absolute SHAP value) when feature names are known. On failure the
        exception is not raised; its message is stored under 'error'.
    """

    # Get feature names
    if feature_names is None and hasattr(X_test, 'columns'):
        feature_names = X_test.columns

    results = {}

    # The explainer is chosen from the model's class name
    model_name = model.__class__.__name__
    print(f"Analyzing {model_name} with SHAP...")

    tree_models = ('LGBMClassifier', 'XGBClassifier', 'CatBoostClassifier', 'RandomForestClassifier')
    linear_models = ('LogisticRegression', 'LinearRegression', 'Ridge', 'Lasso')

    try:
        # Limit the SHAP analysis to the first 1000 rows (for efficiency)
        sample_size = min(1000, X_test.shape[0])
        X_sample = X_test.iloc[:sample_size] if hasattr(X_test, 'iloc') else X_test[:sample_size]

        # Different explainers for different model types
        if model_name in tree_models:
            explainer = shap.TreeExplainer(model)
        elif model_name in linear_models:
            background = shap.sample(X_sample, 100)
            explainer = shap.LinearExplainer(model, background)
        else:
            # For other model types, use kernel explainer
            background = shap.sample(X_sample, 100)
            predict_fn = model.predict_proba if hasattr(model, 'predict_proba') else model.predict
            explainer = shap.KernelExplainer(predict_fn, background)

        # Generate SHAP values
        shap_values = explainer.shap_values(X_sample)

        # Reduce per-class output to a single 2-D (samples x features) matrix.
        # Older SHAP releases return a list with one array per class; newer
        # ones return a 3-D array whose last axis is the class. In both
        # layouts index 1 is the second (positive, for binary) class.
        if isinstance(shap_values, list):
            shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
        elif getattr(shap_values, 'ndim', 2) == 3:
            class_idx = 1 if shap_values.shape[-1] > 1 else 0
            shap_values = shap_values[:, :, class_idx]

        results['shap_values'] = shap_values
        results['explainer'] = explainer
        results['data_sample'] = X_sample

        # Generate feature importance based on SHAP
        if feature_names is not None:
            feature_importance = np.abs(shap_values).mean(0)
            importance_df = pd.DataFrame({
                'Feature': feature_names,
                'Importance': feature_importance
            })
            results['feature_importance'] = importance_df.sort_values('Importance', ascending=False)

            # Print top 10 features
            print("\nTop 10 features by SHAP importance:")
            top = results['feature_importance'].head(10)
            for rank, (feature, importance) in enumerate(
                    zip(top['Feature'], top['Importance']), start=1):
                print(f"{rank}. {feature}: {importance:.4f}")

    except Exception as e:
        # Deliberate best-effort: interpretability must not abort the caller
        print(f"Error generating SHAP values: {str(e)}")
        results['error'] = str(e)

    return results

#### Part 7: Model Training & Evaluation


In [None]:
def train_evaluate_models(X_train, X_test, y_train, y_test, models=None):
    """
    Train and evaluate multiple models.
    Parameters:
    -----------
    X_train, X_test : pandas.DataFrame
        Training and test feature matrices
    y_train, y_test : pandas.Series
        Training and test target variables (binary: column 1 of
        predict_proba is taken as the positive class)
    models : dict, optional
        Dictionary of model instances to train, keyed by display name.
        The instances are fitted in place. A default suite is used when omitted.
    Returns:
    --------
    dict
        Per model name: 'model' (fitted instance), 'metrics' (accuracy,
        precision, recall, f1, plus roc_auc and pr_auc when the model exposes
        predict_proba or decision_function), 'predictions' and
        'probabilities' (None when the model has no predict_proba).
    """

    # Define default models if none provided
    if models is None:
        models = {
            # Original models
            'Logistic Regression': LogisticRegression(class_weight='balanced', max_iter=1000, random_state=random_seed),
            'Random Forest': RandomForestClassifier(class_weight='balanced', random_state=random_seed, n_jobs=-1),
            'LightGBM': lgb.LGBMClassifier(class_weight='balanced', random_state=random_seed, n_jobs=-1),
            'XGBoost': xgb.XGBClassifier(random_state=random_seed, n_jobs=-1),

            # New models
            'CatBoost': cb.CatBoostClassifier(random_seed=random_seed, verbose=0, thread_count=-1),
            'SVM': LinearSVC(class_weight='balanced', random_state=random_seed),
            'Neural Network': MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=random_seed),
            'Naive Bayes': GaussianNB(),

            # Ensemble model
            'Voting Classifier': VotingClassifier(
                estimators=[
                    ('lr', LogisticRegression(random_state=random_seed)),
                    ('rf', RandomForestClassifier(n_estimators=100, random_state=random_seed)),
                    ('lgb', lgb.LGBMClassifier(random_state=random_seed))
                ],
                voting='soft'
            )
        }

    results = {}

    for name, model in models.items():
        print(f"Training {name}...")
        # Train the model
        model.fit(X_train, y_train)

        # Make predictions
        y_pred = model.predict(X_test)
        y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None

        # Ranking metrics only need a monotone score, so models without
        # predict_proba (e.g. LinearSVC) fall back to decision_function.
        if y_prob is not None:
            y_score = y_prob
        elif hasattr(model, 'decision_function'):
            y_score = model.decision_function(X_test)
        else:
            y_score = None

        # Calculate metrics
        metrics = {
            'accuracy': accuracy_score(y_test, y_pred),
            'precision': precision_score(y_test, y_pred),
            'recall': recall_score(y_test, y_pred),
            'f1': f1_score(y_test, y_pred)
        }

        # Add score-based metrics if available
        if y_score is not None:
            metrics.update({
                'roc_auc': roc_auc_score(y_test, y_score),
                'pr_auc': average_precision_score(y_test, y_score)
            })

        # Store results ('probabilities' stays None without predict_proba so
        # callers never mistake raw decision scores for probabilities)
        results[name] = {
            'model': model,
            'metrics': metrics,
            'predictions': y_pred,
            'probabilities': y_prob
        }

        # Print results
        print(f" Accuracy: {metrics['accuracy']:.4f}")
        print(f" Precision: {metrics['precision']:.4f}")
        print(f" Recall: {metrics['recall']:.4f}")
        print(f" F1 Score: {metrics['f1']:.4f}")
        if 'roc_auc' in metrics:
            print(f" ROC AUC: {metrics['roc_auc']:.4f}")
            print(f" PR AUC: {metrics['pr_auc']:.4f}")
        print()

    return results

def evaluate_regression_models(X_train, X_test, y_train, y_test, models=None):
    """
    Train and evaluate multiple regression models with consistent metrics.
    Parameters:
    -----------
    X_train, X_test : pandas.DataFrame
        Training and test feature matrices
    y_train, y_test : pandas.Series
        Training and test target variables
    models : dict, optional
        Dictionary of model instances to train, keyed by display name.
        The instances are fitted in place. A default suite is used when omitted.
    Returns:
    --------
    pandas.DataFrame
        One row per model, indexed by 'Model', with columns 'MAE', 'RMSE',
        'MdAE' and 'R²', sorted by ascending RMSE. Empty (with the same
        columns) when ``models`` is an empty dict.
    """

    # Define default models if none provided
    if models is None:
        models = {
            # Original models
            "Baseline Mean": DummyRegressor(strategy="mean"),
            "Linear Regression": LinearRegression(),
            "Random Forest": RandomForestRegressor(n_estimators=100, random_state=random_seed, n_jobs=-1),

            # New models
            "Ridge": Ridge(alpha=1.0, random_state=random_seed),
            "Lasso": Lasso(alpha=0.1, random_state=random_seed),
            "ElasticNet": ElasticNet(alpha=0.1, l1_ratio=0.5, random_state=random_seed),
            "LightGBM Regressor": lgb.LGBMRegressor(n_estimators=100, random_state=random_seed, n_jobs=-1),
            "XGBoost Regressor": xgb.XGBRegressor(n_estimators=100, random_state=random_seed, n_jobs=-1),
            "CatBoost Regressor": cb.CatBoostRegressor(iterations=100, random_seed=random_seed, verbose=0),
            "Neural Network Regressor": MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=random_seed),
            "Gradient Boosting": GradientBoostingRegressor(n_estimators=100, random_state=random_seed),
            "Poisson Regressor": PoissonRegressor(alpha=1.0, max_iter=1000)
        }

    results = []

    for name, model in models.items():
        print(f"Training {name}...")
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        # Calculate metrics
        mae = mean_absolute_error(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        r2 = r2_score(y_test, y_pred)
        mdae = median_absolute_error(y_test, y_pred)

        results.append({
            "Model": name,
            "MAE": mae,
            "RMSE": rmse,
            "MdAE": mdae,
            "R²": r2
        })

        print(f" MAE: {mae:.3f}, RMSE: {rmse:.3f}, MdAE: {mdae:.3f}, R²: {r2:.4f}")

    # Declare the columns explicitly so an empty result list still yields a
    # frame with a 'Model' column to index on and an 'RMSE' column to sort by.
    columns = ["Model", "MAE", "RMSE", "MdAE", "R²"]
    results_df = pd.DataFrame(results, columns=columns).set_index("Model")

    return results_df.sort_values("RMSE")

def train_time_series_models(df, date_col='CRASH_DATE', target_col='crash_count', horizon=30):
    """
    Train various time series models for crash prediction.

    The series is ordered by date, the first 80% of observations are used for
    fitting and the remaining 20% for scoring. Each model is fitted inside its
    own try/except, so a model that fails is reported and left out of the
    result instead of aborting the others.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data
    date_col : str
        Name of the date column
    target_col : str
        Name of the target column (crash counts). When the column is absent,
        daily crash counts are built by resampling ``date_col``.
    horizon : int
        Forecast horizon in days. Not used by the current implementation:
        every model forecasts the full hold-out span.

    Returns:
    --------
    dict
        Per model name ('ARIMA', 'SARIMA', 'ETS', 'Prophet'): 'model',
        'forecast', 'mse' and 'rmse' on the hold-out span.
    """
    # NOTE(review): ARIMA, SARIMAX, ExponentialSmoothing and Prophet are not
    # imported in the visible part of this file; confirm they are imported
    # elsewhere. If they are not, each branch below only prints a NameError.

    # Create time series data
    if target_col not in df.columns:
        # Create daily crash counts
        ts_data = df.set_index(date_col).resample('D').size().reset_index()
        ts_data.columns = [date_col, 'crash_count']
    else:
        ts_data = df[[date_col, target_col]].copy()

    # Set date as index; order chronologically so the positional split below
    # really separates past (train) from future (test)
    ts_data = ts_data.set_index(date_col).sort_index()

    # Split into train and test
    train_size = int(len(ts_data) * 0.8)
    train_data = ts_data.iloc[:train_size]
    test_data = ts_data.iloc[train_size:]
    actual = test_data.values.flatten()
    n_test = len(test_data)

    print(f"Training with {len(train_data)} observations and testing with {len(test_data)} observations")

    def _package(fitted, forecast):
        """Bundle a fitted model with its forecast and hold-out MSE/RMSE."""
        mse = np.mean((actual - np.asarray(forecast)) ** 2)
        return {
            'model': fitted,
            'forecast': forecast,
            'mse': mse,
            'rmse': np.sqrt(mse)
        }

    models = {}

    # 1. ARIMA model
    print("Training ARIMA model...")
    try:
        arima_results = ARIMA(train_data, order=(2, 1, 2)).fit()
        models['ARIMA'] = _package(arima_results, arima_results.forecast(steps=n_test))
        print(f" ARIMA RMSE: {models['ARIMA']['rmse']:.3f}")
    except Exception as e:
        print(f" Error training ARIMA: {str(e)}")

    # 2. SARIMA model (with weekly seasonality)
    print("Training SARIMA model...")
    try:
        sarima_results = SARIMAX(
            train_data, order=(2, 1, 2), seasonal_order=(1, 1, 1, 7)
        ).fit(disp=False)
        models['SARIMA'] = _package(sarima_results, sarima_results.forecast(steps=n_test))
        print(f" SARIMA RMSE: {models['SARIMA']['rmse']:.3f}")
    except Exception as e:
        print(f" Error training SARIMA: {str(e)}")

    # 3. Exponential Smoothing
    print("Training Exponential Smoothing model...")
    try:
        ets_results = ExponentialSmoothing(
            train_data,
            trend='add',
            seasonal='add',
            seasonal_periods=7
        ).fit()
        models['ETS'] = _package(ets_results, ets_results.forecast(n_test))
        print(f" ETS RMSE: {models['ETS']['rmse']:.3f}")
    except Exception as e:
        print(f" Error training ETS: {str(e)}")

    # 4. Prophet model
    print("Training Prophet model...")
    try:
        # Prophet requires specific column names
        prophet_data = ts_data.reset_index()
        prophet_data.columns = ['ds', 'y']

        prophet_train = prophet_data.iloc[:train_size]
        prophet_test = prophet_data.iloc[train_size:]

        prophet_model = Prophet(
            yearly_seasonality=True,
            weekly_seasonality=True,
            daily_seasonality=True,
            seasonality_mode='additive'
        )
        prophet_model.fit(prophet_train)

        # Predict on the actual hold-out dates instead of generating a
        # gap-free daily range, so forecasts stay aligned with the test rows
        # even when the series has missing days.
        prophet_forecast = prophet_model.predict(prophet_test[['ds']])['yhat'].values

        models['Prophet'] = _package(prophet_model, prophet_forecast)
        print(f" Prophet RMSE: {models['Prophet']['rmse']:.3f}")
    except Exception as e:
        print(f" Error training Prophet: {str(e)}")

    return models

def train_spatial_models(df, target_col='SEVERE', coords_cols=('LATITUDE', 'LONGITUDE')):
    """
    Train geospatial models for crash severity prediction.

    Rows with missing coordinates, or with coordinates outside a bounding box
    around Chicago, are discarded. The remaining rows are one-hot encoded,
    split 80/20 (stratified on the target) and used to fit a distance-weighted
    K-Nearest Neighbors classifier and a Random Forest classifier. Only the
    coordinate columns are standardized; every other feature is used as-is.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data with spatial coordinates. In addition
        to ``coords_cols`` and ``target_col`` it must contain 'LATITUDE',
        'LONGITUDE', 'POSTED_SPEED_LIMIT', 'WEATHER_CONDITION',
        'ROADWAY_SURFACE_COND', 'CRASH_HOUR' and 'CRASH_DAY_OF_WEEK'.
    target_col : str
        Name of the target column
    coords_cols : sequence of str
        Column names containing spatial coordinates

    Returns:
    --------
    dict
        Maps 'KNN' and 'RandomForest_Spatial' to a dict with the fitted
        'model', its test-set 'metrics' (accuracy, precision, recall, f1)
        and its test-set 'predictions'.

    Raises:
    -------
    ValueError
        If no row has coordinates inside the bounding box.
    """
    # Imported locally because KNeighborsClassifier does not appear among the
    # scikit-learn imports at the top of the file.
    from sklearn.neighbors import KNeighborsClassifier

    # Accept any sequence (the default is an immutable tuple) but work with a list.
    coords_cols = list(coords_cols)

    # Filter to valid coordinates
    spatial_df = df.dropna(subset=coords_cols)

    # Define valid coordinate bounds for Chicago.
    # NOTE(review): the bounds are always applied to the 'LATITUDE' and
    # 'LONGITUDE' columns, even when coords_cols names other columns.
    in_bounds = (
        (spatial_df['LATITUDE'] > 41.5) & (spatial_df['LATITUDE'] < 42.1) &
        (spatial_df['LONGITUDE'] > -88.0) & (spatial_df['LONGITUDE'] < -87.4)
    )
    spatial_df = spatial_df[in_bounds]

    print(f"Using {len(spatial_df)} records with valid spatial coordinates")
    if spatial_df.empty:
        raise ValueError("No records with valid spatial coordinates; cannot train spatial models")

    # Create features including spatial features
    feature_cols = coords_cols + [
        'POSTED_SPEED_LIMIT', 'WEATHER_CONDITION', 'ROADWAY_SURFACE_COND',
        'CRASH_HOUR', 'CRASH_DAY_OF_WEEK'
    ]

    # Handle categorical features
    X = pd.get_dummies(spatial_df[feature_cols], drop_first=True)
    y = spatial_df[target_col]

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=random_seed, stratify=y
    )

    # Work on explicit float copies: a plain `.values` can yield an object
    # array (bool dummies mixed with floats) or a non-writable view, and the
    # coordinate columns are overwritten in place below.
    coord_indices = [X.columns.get_loc(col) for col in coords_cols]
    X_train_arr = X_train.to_numpy(dtype=float, copy=True)
    X_test_arr = X_test.to_numpy(dtype=float, copy=True)

    # Scale coordinate features; the scaler is fitted on the training split only
    scaler = StandardScaler()
    X_train_arr[:, coord_indices] = scaler.fit_transform(X_train_arr[:, coord_indices])
    X_test_arr[:, coord_indices] = scaler.transform(X_test_arr[:, coord_indices])

    # (result key, label used in the summary line, progress message, estimator)
    candidates = [
        (
            'KNN', 'KNN', "Training KNN spatial model...",
            KNeighborsClassifier(n_neighbors=15, weights='distance'),
        ),
        (
            'RandomForest_Spatial', 'RF Spatial',
            "Training Random Forest with spatial features...",
            RandomForestClassifier(n_estimators=100, random_state=random_seed),
        ),
    ]

    spatial_models = {}
    for key, label, banner, estimator in candidates:
        print(banner)
        estimator.fit(X_train_arr, y_train)
        predictions = estimator.predict(X_test_arr)
        # zero_division=0 keeps the 0.0 result sklearn would give anyway when a
        # model predicts no positives, without emitting a warning.
        metrics = {
            'accuracy': accuracy_score(y_test, predictions),
            'precision': precision_score(y_test, predictions, zero_division=0),
            'recall': recall_score(y_test, predictions, zero_division=0),
            'f1': f1_score(y_test, predictions, zero_division=0)
        }
        spatial_models[key] = {
            'model': estimator,
            'metrics': metrics,
            'predictions': predictions
        }
        print(f" {label} Accuracy: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}")

    return spatial_models

def build_multitask_model(X, y_severity, y_injuries, test_size=0.2):
    """
    Build and train a multi-task learning model that predicts both crash severity
    and injury counts simultaneously.

    The network has two shared hidden blocks (Dense -> BatchNormalization ->
    Dropout) followed by a sigmoid head for severity and a linear head for
    injuries. Features are standardized with a scaler fitted on the training
    split only. Training uses early stopping on the validation loss.

    Parameters:
    -----------
    X : pandas.DataFrame
        Feature matrix
    y_severity : pandas.Series
        Binary target variable for crash severity
    y_injuries : pandas.Series
        Continuous target variable for injury counts
    test_size : float
        Proportion of data to use for testing

    Returns:
    --------
    dict
        'model' (trained Keras model), 'scaler' (fitted StandardScaler),
        'predictions' ('severity' labels, 'severity_proba' and 'injuries'
        as returned by the model), 'metrics' (severity_accuracy, severity_f1,
        injuries_rmse, injuries_mae) and 'history' (Keras training history).
    """
    # Imported locally because only `tensorflow as tf` appears among the
    # imports at the top of the file; the bare Keras names used below would
    # otherwise be undefined.
    from tensorflow.keras.callbacks import EarlyStopping
    from tensorflow.keras.layers import BatchNormalization, Dense, Dropout, Input
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam

    # Split data (stratified on severity so both splits keep the class balance)
    (X_train, X_test,
     y_severity_train, y_severity_test,
     y_injuries_train, y_injuries_test) = train_test_split(
        X, y_severity, y_injuries,
        test_size=test_size, random_state=random_seed, stratify=y_severity
    )

    # Scale features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Convert to proper format for Keras. Targets are reshaped to (n, 1) so
    # they match the shape produced by each Dense(1) output head.
    X_train_tensor = tf.convert_to_tensor(X_train_scaled, dtype=tf.float32)
    X_test_tensor = tf.convert_to_tensor(X_test_scaled, dtype=tf.float32)
    y_severity_train_tensor = tf.convert_to_tensor(
        np.asarray(y_severity_train, dtype=np.float32).reshape(-1, 1)
    )
    y_injuries_train_tensor = tf.convert_to_tensor(
        np.asarray(y_injuries_train, dtype=np.float32).reshape(-1, 1)
    )

    # Build multi-task model
    n_features = X_train_scaled.shape[1]
    inputs = Input(shape=(n_features,))

    # Shared layers
    x = inputs
    for units in (128, 64):
        x = Dense(units, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.3)(x)

    # Task-specific layers
    severity_output = Dense(1, activation='sigmoid', name='severity_output')(x)
    injuries_output = Dense(1, activation='linear', name='injuries_output')(x)

    # Create model
    model = Model(inputs=inputs, outputs=[severity_output, injuries_output])

    # Compile model with different loss functions for each task
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss={
            'severity_output': 'binary_crossentropy',
            'injuries_output': 'mean_squared_error'
        },
        metrics={
            'severity_output': ['accuracy'],
            'injuries_output': ['mae']
        }
    )

    # Train model with early stopping
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    )

    # NOTE(review): weight initialization and dropout are not seeded in this
    # function, so results vary between runs unless a global TensorFlow seed
    # is set elsewhere.
    print("Training multi-task model...")
    history = model.fit(
        X_train_tensor,
        {
            'severity_output': y_severity_train_tensor,
            'injuries_output': y_injuries_train_tensor
        },
        epochs=100,
        batch_size=64,
        validation_split=0.2,
        callbacks=[early_stopping],
        verbose=0
    )

    # Make predictions
    severity_preds_proba, injuries_preds = model.predict(X_test_tensor)
    severity_preds = (severity_preds_proba > 0.5).astype(int).flatten()
    # 1-D view of the regression output, used for the metrics only; the
    # returned 'injuries' entry keeps the shape produced by the model.
    injuries_flat = np.asarray(injuries_preds).ravel()

    # Evaluate performance
    severity_accuracy = accuracy_score(y_severity_test, severity_preds)
    severity_f1 = f1_score(y_severity_test, severity_preds)
    injuries_rmse = np.sqrt(mean_squared_error(y_injuries_test, injuries_flat))
    injuries_mae = mean_absolute_error(y_injuries_test, injuries_flat)

    print("\nMulti-task Model Performance:")
    print(f"Severity Classification - Accuracy: {severity_accuracy:.4f}, F1: {severity_f1:.4f}")
    print(f"Injury Count Regression - RMSE: {injuries_rmse:.4f}, MAE: {injuries_mae:.4f}")

    return {
        'model': model,
        'scaler': scaler,
        'predictions': {
            'severity': severity_preds,
            'severity_proba': severity_preds_proba,
            'injuries': injuries_preds
        },
        'metrics': {
            'severity_accuracy': severity_accuracy,
            'severity_f1': severity_f1,
            'injuries_rmse': injuries_rmse,
            'injuries_mae': injuries_mae
        },
        'history': history.history
    }

def select_best_model(X, y, cv=5, scoring='f1', model_type='classification'):
    """
    Select the best model with hyperparameter tuning.

    Every candidate is tuned with an exhaustive grid search over the same
    cross-validation folds. Candidates without tunable parameters are scored
    with plain cross-validation on those folds, so all 'best_score' values are
    comparable. The candidate with the highest score wins.

    Parameters:
    -----------
    X : pandas.DataFrame
        Feature matrix
    y : pandas.Series
        Target variable
    cv : int
        Number of cross-validation folds
    scoring : str
        Scoring metric to optimize (higher must mean better). The default
        'f1' is undefined for regression; when it is combined with a
        non-classification ``model_type``, 'neg_root_mean_squared_error' is
        used instead.
    model_type : str
        Type of model ('classification' or 'regression'). Any value other
        than 'classification' selects the regression candidates.

    Returns:
    --------
    dict
        'best_model_name', 'best_model' (refitted on all of X, y),
        'best_score' and 'all_results' (per-model dict with 'model',
        'best_params', 'best_score' and either 'cv_results' from the grid
        search or 'cv_scores' from plain cross-validation).
    """
    is_classification = model_type == 'classification'

    if not is_classification and scoring == 'f1':
        scoring = 'neg_root_mean_squared_error'
        print(f"Scoring 'f1' is not defined for regression; using '{scoring}' instead")

    # Define cross-validation strategy
    if is_classification:
        cv_strategy = StratifiedKFold(n_splits=cv, shuffle=True, random_state=random_seed)
    else:
        cv_strategy = KFold(n_splits=cv, shuffle=True, random_state=random_seed)

    # Grids shared by the classification and regression variants of a model
    forest_grid = {
        'n_estimators': [100, 200, 300],
        'max_depth': [None, 10, 20, 30],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4]
    }
    boosting_grid = {
        'n_estimators': [100, 200, 300],
        'max_depth': [3, 5, 7],
        'learning_rate': [0.01, 0.05, 0.1],
        'subsample': [0.7, 0.8, 1.0],
        'colsample_bytree': [0.7, 0.8, 1.0]
    }
    class_weights = {'class_weight': [None, 'balanced']}

    # Define models and parameter grids
    if is_classification:
        models = {
            'Logistic Regression': {
                'model': LogisticRegression(random_state=random_seed, max_iter=1000),
                'params': {
                    'C': [0.01, 0.1, 1.0, 10.0],
                    'penalty': ['l1', 'l2'],
                    'solver': ['liblinear', 'saga'],
                    **class_weights
                }
            },
            'Random Forest': {
                'model': RandomForestClassifier(random_state=random_seed),
                'params': {**forest_grid, **class_weights}
            },
            'LightGBM': {
                'model': lgb.LGBMClassifier(random_state=random_seed),
                'params': {**boosting_grid, **class_weights}
            },
            'XGBoost': {
                'model': xgb.XGBClassifier(random_state=random_seed),
                'params': dict(boosting_grid)
            }
        }
    else:  # Regression models
        models = {
            'Linear Regression': {
                'model': LinearRegression(),
                'params': {}  # nothing to tune; scored with plain cross-validation
            },
            'Ridge': {
                'model': Ridge(random_state=random_seed),
                'params': {
                    'alpha': [0.01, 0.1, 1.0, 10.0, 100.0],
                    'solver': ['auto', 'svd', 'cholesky', 'lsqr', 'sparse_cg', 'sag', 'saga']
                }
            },
            'Lasso': {
                'model': Lasso(random_state=random_seed),
                'params': {
                    'alpha': [0.01, 0.1, 1.0, 10.0, 100.0],
                    'selection': ['cyclic', 'random']
                }
            },
            'Random Forest': {
                'model': RandomForestRegressor(random_state=random_seed),
                'params': dict(forest_grid)
            },
            'LightGBM': {
                'model': lgb.LGBMRegressor(random_state=random_seed),
                'params': dict(boosting_grid)
            },
            'XGBoost': {
                'model': xgb.XGBRegressor(random_state=random_seed),
                'params': dict(boosting_grid)
            }
        }

    # Run grid search for each model
    results = {}

    for name, config in models.items():
        print(f"Tuning {name}...")

        if not config['params']:
            # No grid to search. Score the model on the same folds as the
            # others: a placeholder score such as 0.0 would beat every
            # candidate under negative-valued scorers.
            print(f"  No parameters to tune for {name}, skipping grid search")
            cv_scores = cross_val_score(
                config['model'], X, y,
                cv=cv_strategy, scoring=scoring, n_jobs=-1
            )
            model = config['model']
            model.fit(X, y)
            mean_score = float(np.mean(cv_scores))
            results[name] = {
                'model': model,
                'best_params': {},
                'best_score': mean_score,
                'cv_scores': cv_scores
            }
            print(f"  Best {scoring} score: {mean_score:.4f}")
            continue

        # Perform grid search
        grid_search = GridSearchCV(
            config['model'],
            config['params'],
            cv=cv_strategy,
            scoring=scoring,
            n_jobs=-1,
            verbose=0
        )

        # Fit the model
        grid_search.fit(X, y)

        # Store results
        results[name] = {
            'model': grid_search.best_estimator_,
            'best_params': grid_search.best_params_,
            'best_score': grid_search.best_score_,
            'cv_results': grid_search.cv_results_
        }

        print(f"  Best {scoring} score: {grid_search.best_score_:.4f}")
        print(f"  Best parameters: {grid_search.best_params_}")

    def _ranking_key(model_name):
        # A NaN score (every fit failed) must never win the comparison.
        score = results[model_name]['best_score']
        return -np.inf if np.isnan(score) else score

    # Find best model overall
    best_model_name = max(results, key=_ranking_key)
    best_model = results[best_model_name]['model']
    best_score = results[best_model_name]['best_score']

    print(f"\nBest model overall: {best_model_name}")
    print(f"Best {scoring} score: {best_score:.4f}")

    return {
        'best_model_name': best_model_name,
        'best_model': best_model,
        'best_score': best_score,
        'all_results': results
    }
def plot_classification_metrics_comparison(model_results, figsize=(12, 8)):
    """
    Create a comprehensive plot comparing various classification metrics across models.

    The top panel shows one group of bars per model (one bar per metric); the
    bottom panel shows the mean of the plotted metrics per model. Both panels
    list the models in the same order, ranked by that mean.

    Parameters:
    -----------
    model_results : dict
        Dictionary of model results from train_evaluate_models(); each value
        must provide a 'metrics' mapping of metric name to score
    figsize : tuple, optional
        Figure size for the plot

    Returns:
    --------
    tuple
        (fig, metrics_df): the matplotlib Figure and a DataFrame indexed by
        model name holding the metrics plus an 'overall_score' column,
        sorted by 'overall_score' in descending order

    Raises:
    -------
    ValueError
        If none of accuracy, precision, recall, f1 or roc_auc is available.
    """
    # NOTE(review): this function is defined twice in this file; the later
    # definition is the one that takes effect.

    # Extract metrics into a DataFrame (one row per model)
    metrics_df = pd.DataFrame({
        name: results['metrics']
        for name, results in model_results.items()
    }).T

    # Metrics to focus on
    focus_metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
    metrics_to_plot = [m for m in focus_metrics if m in metrics_df.columns]
    if not metrics_to_plot:
        raise ValueError(
            f"model_results contains none of the supported metrics: {focus_metrics}"
        )

    # Rank the models before drawing anything so that the bar groups in the
    # top panel and the bars in the bottom panel follow the same order.
    metrics_df['overall_score'] = metrics_df[metrics_to_plot].mean(axis=1)
    metrics_df = metrics_df.sort_values('overall_score', ascending=False)

    # Create figure
    fig, (ax_metrics, ax_overall) = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={'height_ratios': [3, 1]}
    )

    # Get viridis colors for metrics
    colors = plt.cm.viridis(np.linspace(0, 0.8, len(metrics_to_plot)))

    # Plot metrics as grouped bars centered on each model's tick
    group_centers = np.arange(len(metrics_df))
    bar_width = 0.8 / len(metrics_to_plot)
    for i, metric in enumerate(metrics_to_plot):
        offset = (i - (len(metrics_to_plot) - 1) / 2) * bar_width
        bars = ax_metrics.bar(group_centers + offset, metrics_df[metric],
                              width=bar_width,
                              label=metric.upper(),
                              color=colors[i],
                              alpha=0.8)

        # Add data labels (a missing metric has no finite height to anchor to)
        for bar in bars:
            height = bar.get_height()
            if not np.isfinite(height):
                continue
            ax_metrics.annotate(f'{height:.3f}',
                                xy=(bar.get_x() + bar.get_width() / 2, height),
                                xytext=(0, 3),  # 3 points vertical offset
                                textcoords="offset points",
                                ha='center', va='bottom',
                                fontsize=8, rotation=90)

    # Plot overall score in the second subplot
    bars = ax_overall.barh(metrics_df.index, metrics_df['overall_score'],
                           color=plt.cm.viridis(0.5),
                           alpha=0.8)

    # Add data labels for overall score
    for bar in bars:
        width = bar.get_width()
        if not np.isfinite(width):
            continue
        ax_overall.annotate(f'{width:.3f}',
                            xy=(width, bar.get_y() + bar.get_height() / 2),
                            xytext=(3, 0),  # 3 points horizontal offset
                            textcoords="offset points",
                            ha='left', va='center',
                            fontsize=9)

    # Customize first subplot; each bar group is labelled with its model name
    ax_metrics.set_title('Classification Metrics by Model', fontsize=14, pad=20)
    ax_metrics.set_ylabel('Score', fontsize=12)
    ax_metrics.set_ylim(0, 1)
    ax_metrics.set_xticks(group_centers)
    ax_metrics.set_xticklabels(metrics_df.index, rotation=20, ha='right')
    ax_metrics.grid(axis='y', alpha=0.3)
    # Legend sits outside the axes so it cannot cover the model names
    ax_metrics.legend(loc='upper left', bbox_to_anchor=(1.01, 1.0))

    # Customize second subplot (barh already labels the y axis with model names)
    ax_overall.set_title('Overall Model Performance Score', fontsize=12)
    ax_overall.set_xlabel('Average Score', fontsize=10)
    ax_overall.set_xlim(0, 1)
    ax_overall.grid(axis='x', alpha=0.3)

    # Better layout
    plt.tight_layout()

    return fig, metrics_df

def plot_regression_metrics_comparison(results_df, figsize=(12, 10)):
    """
    Draw a 2x2 dashboard comparing regression models.

    The panels show RMSE, MAE and MdAE (lower is better) and R² (higher is
    better) as horizontal bars, one bar per model, with models ordered by
    ascending RMSE in every panel.

    Parameters:
    -----------
    results_df : pandas.DataFrame
        DataFrame containing regression metrics (output from
        evaluate_regression_models), indexed by model and providing the
        columns 'RMSE', 'MAE', 'MdAE' and 'R²'
    figsize : tuple, optional
        Figure size for the plot

    Returns:
    --------
    matplotlib.figure.Figure
        The figure object containing the plot
    """
    # NOTE(review): this function is defined twice in this file; the later
    # definition is the one that takes effect.

    def _label_bars(axis, rects, signed=False):
        # Write each bar's value next to its tip; with signed=True the label
        # of a bar that is not >= 0 is placed on the left of the tip instead.
        for rect in rects:
            value = rect.get_width()
            flip = signed and not (value >= 0)
            axis.annotate(
                f'{value:.3f}',
                xy=(value, rect.get_y() + rect.get_height() / 2),
                xytext=(-3 if flip else 3, 0),
                textcoords="offset points",
                ha='right' if flip else 'left', va='center',
                fontsize=9,
            )

    # Best (lowest) RMSE first
    ranked = results_df.sort_values('RMSE')

    fig, grid = plt.subplots(2, 2, figsize=figsize)

    # One viridis shade per model, shared by the three error panels
    palette = plt.cm.viridis(np.linspace(0, 0.8, len(ranked)))

    # Panels 1-3: error metrics (top left, top right, bottom left)
    error_panels = (
        (grid[0, 0], 'RMSE', 'Root Mean Squared Error (RMSE)'),
        (grid[0, 1], 'MAE', 'Mean Absolute Error (MAE)'),
        (grid[1, 0], 'MdAE', 'Median Absolute Error (MdAE)'),
    )
    for axis, column, heading in error_panels:
        rects = axis.barh(ranked.index, ranked[column], color=palette)
        axis.set_title(heading, fontsize=12)
        axis.set_xlabel(f'{column} (lower is better)', fontsize=10)
        axis.grid(axis='x', alpha=0.3)
        _label_bars(axis, rects)

    # Panel 4: R² (bottom right), red for negative scores and green otherwise
    r2_axis = grid[1, 1]
    r2_scores = ranked['R²']
    sign_colors = ['#d73027' if score < 0 else '#1a9850' for score in r2_scores]
    r2_rects = r2_axis.barh(ranked.index, r2_scores, color=sign_colors)
    r2_axis.set_title('R² Score', fontsize=12)
    r2_axis.set_xlabel('R² (higher is better)', fontsize=10)
    r2_axis.grid(axis='x', alpha=0.3)
    # Reference line at zero
    r2_axis.axvline(x=0, color='black', linestyle='-', alpha=0.3)
    _label_bars(r2_axis, r2_rects, signed=True)

    # Main title, then tighten the layout and leave room for the title
    fig.suptitle('Regression Model Performance Comparison', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)

    return fig
def plot_classification_metrics_comparison(model_results, figsize=(12, 8)):
    """
    Create a comprehensive plot comparing various classification metrics across models.

    The top panel shows one group of bars per model (one bar per metric); the
    bottom panel shows the mean of the plotted metrics per model. Both panels
    list the models in the same order, ranked by that mean.

    Parameters:
    -----------
    model_results : dict
        Dictionary of model results from train_evaluate_models(); each value
        must provide a 'metrics' mapping of metric name to score
    figsize : tuple, optional
        Figure size for the plot

    Returns:
    --------
    tuple
        (fig, metrics_df): the matplotlib Figure and a DataFrame indexed by
        model name holding the metrics plus an 'overall_score' column,
        sorted by 'overall_score' in descending order

    Raises:
    -------
    ValueError
        If none of accuracy, precision, recall, f1 or roc_auc is available.
    """
    # NOTE(review): this function is defined twice in this file; the later
    # definition is the one that takes effect.

    # Extract metrics into a DataFrame (one row per model)
    metrics_df = pd.DataFrame({
        name: results['metrics']
        for name, results in model_results.items()
    }).T

    # Metrics to focus on
    focus_metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
    metrics_to_plot = [m for m in focus_metrics if m in metrics_df.columns]
    if not metrics_to_plot:
        raise ValueError(
            f"model_results contains none of the supported metrics: {focus_metrics}"
        )

    # Rank the models before drawing anything so that the bar groups in the
    # top panel and the bars in the bottom panel follow the same order.
    metrics_df['overall_score'] = metrics_df[metrics_to_plot].mean(axis=1)
    metrics_df = metrics_df.sort_values('overall_score', ascending=False)

    # Create figure
    fig, (ax_metrics, ax_overall) = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={'height_ratios': [3, 1]}
    )

    # Get viridis colors for metrics
    colors = plt.cm.viridis(np.linspace(0, 0.8, len(metrics_to_plot)))

    # Plot metrics as grouped bars centered on each model's tick
    group_centers = np.arange(len(metrics_df))
    bar_width = 0.8 / len(metrics_to_plot)
    for i, metric in enumerate(metrics_to_plot):
        offset = (i - (len(metrics_to_plot) - 1) / 2) * bar_width
        bars = ax_metrics.bar(group_centers + offset, metrics_df[metric],
                              width=bar_width,
                              label=metric.upper(),
                              color=colors[i],
                              alpha=0.8)

        # Add data labels (a missing metric has no finite height to anchor to)
        for bar in bars:
            height = bar.get_height()
            if not np.isfinite(height):
                continue
            ax_metrics.annotate(f'{height:.3f}',
                                xy=(bar.get_x() + bar.get_width() / 2, height),
                                xytext=(0, 3),  # 3 points vertical offset
                                textcoords="offset points",
                                ha='center', va='bottom',
                                fontsize=8, rotation=90)

    # Plot overall score in the second subplot
    bars = ax_overall.barh(metrics_df.index, metrics_df['overall_score'],
                           color=plt.cm.viridis(0.5),
                           alpha=0.8)

    # Add data labels for overall score
    for bar in bars:
        width = bar.get_width()
        if not np.isfinite(width):
            continue
        ax_overall.annotate(f'{width:.3f}',
                            xy=(width, bar.get_y() + bar.get_height() / 2),
                            xytext=(3, 0),  # 3 points horizontal offset
                            textcoords="offset points",
                            ha='left', va='center',
                            fontsize=9)

    # Customize first subplot; each bar group is labelled with its model name
    ax_metrics.set_title('Classification Metrics by Model', fontsize=14, pad=20)
    ax_metrics.set_ylabel('Score', fontsize=12)
    ax_metrics.set_ylim(0, 1)
    ax_metrics.set_xticks(group_centers)
    ax_metrics.set_xticklabels(metrics_df.index, rotation=20, ha='right')
    ax_metrics.grid(axis='y', alpha=0.3)
    # Legend sits outside the axes so it cannot cover the model names
    ax_metrics.legend(loc='upper left', bbox_to_anchor=(1.01, 1.0))

    # Customize second subplot (barh already labels the y axis with model names)
    ax_overall.set_title('Overall Model Performance Score', fontsize=12)
    ax_overall.set_xlabel('Average Score', fontsize=10)
    ax_overall.set_xlim(0, 1)
    ax_overall.grid(axis='x', alpha=0.3)

    # Better layout
    plt.tight_layout()

    return fig, metrics_df

def plot_regression_metrics_comparison(results_df, figsize=(12, 10)):
    """
    Draw a 2x2 dashboard comparing regression models.

    The panels show RMSE, MAE and MdAE (lower is better) and R² (higher is
    better) as horizontal bars, one bar per model, with models ordered by
    ascending RMSE in every panel.

    Parameters:
    -----------
    results_df : pandas.DataFrame
        DataFrame containing regression metrics (output from
        evaluate_regression_models), indexed by model and providing the
        columns 'RMSE', 'MAE', 'MdAE' and 'R²'
    figsize : tuple, optional
        Figure size for the plot

    Returns:
    --------
    matplotlib.figure.Figure
        The figure object containing the plot
    """
    # NOTE(review): this function is defined twice in this file; the later
    # definition is the one that takes effect.

    def _label_bars(axis, rects, signed=False):
        # Write each bar's value next to its tip; with signed=True the label
        # of a bar that is not >= 0 is placed on the left of the tip instead.
        for rect in rects:
            value = rect.get_width()
            flip = signed and not (value >= 0)
            axis.annotate(
                f'{value:.3f}',
                xy=(value, rect.get_y() + rect.get_height() / 2),
                xytext=(-3 if flip else 3, 0),
                textcoords="offset points",
                ha='right' if flip else 'left', va='center',
                fontsize=9,
            )

    # Best (lowest) RMSE first
    ranked = results_df.sort_values('RMSE')

    fig, grid = plt.subplots(2, 2, figsize=figsize)

    # One viridis shade per model, shared by the three error panels
    palette = plt.cm.viridis(np.linspace(0, 0.8, len(ranked)))

    # Panels 1-3: error metrics (top left, top right, bottom left)
    error_panels = (
        (grid[0, 0], 'RMSE', 'Root Mean Squared Error (RMSE)'),
        (grid[0, 1], 'MAE', 'Mean Absolute Error (MAE)'),
        (grid[1, 0], 'MdAE', 'Median Absolute Error (MdAE)'),
    )
    for axis, column, heading in error_panels:
        rects = axis.barh(ranked.index, ranked[column], color=palette)
        axis.set_title(heading, fontsize=12)
        axis.set_xlabel(f'{column} (lower is better)', fontsize=10)
        axis.grid(axis='x', alpha=0.3)
        _label_bars(axis, rects)

    # Panel 4: R² (bottom right), red for negative scores and green otherwise
    r2_axis = grid[1, 1]
    r2_scores = ranked['R²']
    sign_colors = ['#d73027' if score < 0 else '#1a9850' for score in r2_scores]
    r2_rects = r2_axis.barh(ranked.index, r2_scores, color=sign_colors)
    r2_axis.set_title('R² Score', fontsize=12)
    r2_axis.set_xlabel('R² (higher is better)', fontsize=10)
    r2_axis.grid(axis='x', alpha=0.3)
    # Reference line at zero
    r2_axis.axvline(x=0, color='black', linestyle='-', alpha=0.3)
    _label_bars(r2_axis, r2_rects, signed=True)

    # Main title, then tighten the layout and leave room for the title
    fig.suptitle('Regression Model Performance Comparison', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)

    return fig
# %%

# This code focuses on fixing the curve plotting part of the plot_model_performance function

def fix_curve_plotting(model_results, y_test):
    """
    Diagnose ROC and precision-recall curve plotting issues.

    For every model that provides non-None 'probabilities', prints the shapes
    involved, computes the ROC and PR curves against y_test, and reports
    success or the error raised. Curves of the models that succeeded are then
    drawn on a two-panel debug figure.

    Parameters:
    -----------
    model_results : dict
        Maps model name to a result mapping that may hold 'probabilities'
    y_test : array-like
        True labels the probabilities are scored against

    Returns:
    --------
    matplotlib.figure.Figure
        Figure with the ROC panel on the left and the PR panel on the right
    """
    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 5))

    # Titles and axis labels for both panels
    panel_text = (
        (ax_roc, 'ROC Curves (Debug)', 'False Positive Rate', 'True Positive Rate'),
        (ax_pr, 'Precision-Recall Curves (Debug)', 'Recall', 'Precision'),
    )
    for axis, heading, x_text, y_text in panel_text:
        axis.set_title(heading)
        axis.set_xlabel(x_text)
        axis.set_ylabel(y_text)

    # Chance-level diagonal goes in first so it leads the ROC legend
    ax_roc.plot([0, 1], [0, 1], 'k--', label='Random')

    # Both panels span the unit square
    for axis in (ax_roc, ax_pr):
        axis.set_xlim(0, 1)
        axis.set_ylim(0, 1)

    # One color per entry in model_results, whether or not it ends up plotted
    palette = plt.cm.viridis(np.linspace(0, 0.8, len(model_results)))

    # Phase 1: compute curves, keeping only the models that succeed
    computed = []
    for idx, (name, results) in enumerate(model_results.items()):
        if 'probabilities' not in results or results['probabilities'] is None:
            continue

        y_prob = results['probabilities']

        # Shapes for debugging
        print(f"Model: {name}")
        print(f"  probabilities shape: {getattr(y_prob, 'shape', 'not array')}")
        print(f"  y_test shape: {getattr(y_test, 'shape', 'not array')}")

        try:
            # Some models output one column per class; keep the second column
            prob_shape = getattr(y_prob, 'shape', ())
            if len(prob_shape) > 1 and prob_shape[1] > 1:
                print(f"  Probabilities are 2D with shape {y_prob.shape}, using column 1")
                y_prob = y_prob[:, 1]

            fpr, tpr, _ = roc_curve(y_test, y_prob)
            roc_auc = roc_auc_score(y_test, y_prob)

            precision, recall, _ = precision_recall_curve(y_test, y_prob)
            pr_auc = average_precision_score(y_test, y_prob)

            computed.append(
                (name, palette[idx], (fpr, tpr, roc_auc), (recall, precision, pr_auc))
            )

            print(f"  Successfully calculated curves, ROC AUC: {roc_auc:.3f}, PR AUC: {pr_auc:.3f}")

        except Exception as e:
            print(f"  ERROR calculating curves: {str(e)}")

    # Phase 2: draw what was computed
    print(f"\nPlotting curves for {len(computed)} valid models")

    for name, shade, (fpr, tpr, roc_auc), (recall, precision, pr_auc) in computed:
        ax_roc.plot(fpr, tpr, color=shade, lw=2, label=f"{name} ({roc_auc:.3f})")
        ax_pr.plot(recall, precision, color=shade, lw=2, label=f"{name} ({pr_auc:.3f})")

    # Legends and light grids
    ax_roc.legend(loc='lower right')
    ax_pr.legend(loc='upper right')
    for axis in (ax_roc, ax_pr):
        axis.grid(alpha=0.3)

    plt.tight_layout()
    return fig

def plot_monthly_crash_trends_by_year(df, start_year=2018, end_year=None):
    """
    Plot the number of crashes per calendar month between two years.

    The chart shows the monthly counts as a line, a dashed line for the
    overall monthly average, a shaded band and label for every year in the
    range, value labels on notable local peaks/valleys, and a "COVID-19"
    callout pointing at April 2020 when that month is present in the data.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data. The timestamp is read from
        'CRASH_DATE', or from 'CRASH_DATETIME' when 'CRASH_DATE' is absent.
        The input frame is not modified.
    start_year : int
        First year (inclusive) to include in the analysis
    end_year : int, optional
        Last year (inclusive) to include (defaults to the current year)

    Returns:
    --------
    matplotlib.figure.Figure
        The figure containing the trend plot
    """
    # Prefer CRASH_DATE; use CRASH_DATETIME only when CRASH_DATE is absent
    if 'CRASH_DATE' not in df.columns and 'CRASH_DATETIME' in df.columns:
        date_col = 'CRASH_DATETIME'
    else:
        date_col = 'CRASH_DATE'

    # Remember whether the range is open-ended so the title can say so
    open_ended = end_year is None
    if open_ended:
        end_year = datetime.now().year

    # Parse only the date column instead of copying the whole frame. Working on
    # a standalone Series also avoids assigning new columns into a filtered
    # slice, which is what triggers pandas' SettingWithCopyWarning.
    crash_dates = pd.to_datetime(df[date_col])
    crash_years = crash_dates.dt.year
    in_range = crash_dates[(crash_years >= start_year) & (crash_years <= end_year)]

    # Count crashes per (year, month). Missing dates (NaT) never satisfy the
    # year comparison above, so the integer casts below are safe.
    monthly_counts = (
        pd.DataFrame({
            'year': in_range.dt.year.astype(int),
            'month': in_range.dt.month.astype(int),
        })
        .groupby(['year', 'month'])
        .size()
        .reset_index(name='count')
    )

    # First day of each month is used as the x coordinate
    monthly_counts['date'] = pd.to_datetime(
        monthly_counts[['year', 'month']].assign(day=1)
    )

    # Chronological order with a clean positional index
    monthly_counts = monthly_counts.sort_values('date').reset_index(drop=True)

    monthly_avg = monthly_counts['count'].mean()

    # Create the plot with enhanced styling
    fig = plt.figure(figsize=(15, 8))
    ax = plt.gca()

    ax.plot(
        monthly_counts['date'],
        monthly_counts['count'],
        marker='o',
        linestyle='-',
        linewidth=2,
        markersize=6,
        color='#6a50a7'
    )

    # Horizontal reference line for the monthly average
    ax.axhline(
        y=monthly_avg,
        color='red',
        linestyle='--',
        alpha=0.7,
        label=f'Monthly Average: {monthly_avg:.0f}'
    )

    # Shade and label each year in the requested range
    years = range(start_year, end_year + 1)
    colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(years)))

    for i, year in enumerate(years):
        start_date = pd.Timestamp(f"{year}-01-01")
        end_date = pd.Timestamp(f"{year}-12-31")
        if year == end_year:
            # Do not shade into the future for the final (possibly current) year
            end_date = min(pd.Timestamp.now(), end_date)

        ax.axvspan(start_date, end_date, alpha=0.1, color=colors[i])

        mid_point = start_date + (end_date - start_date) / 2
        ax.text(
            mid_point,
            ax.get_ylim()[1] * 0.95,
            str(year),
            ha='center',
            va='top',
            fontweight='bold',
            bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.3')
        )

    # Label points that are both far from the average (more than 10%) and the
    # extreme value of their surrounding 5-month window. Positions come from
    # enumerate, so the window never depends on the frame's index labels.
    counts = monthly_counts['count'].to_numpy()
    y_low, y_high = ax.get_ylim()
    label_offset = (y_high - y_low) * 0.03

    for pos, (date, count) in enumerate(zip(monthly_counts['date'], counts)):
        if not (count > monthly_avg * 1.1 or count < monthly_avg * 0.9):
            continue

        local_window = counts[max(0, pos - 2):pos + 3]
        if count == local_window.max() or count == local_window.min():
            ax.text(
                date,
                count + label_offset,
                f'{count:.0f}',
                ha='center',
                fontsize=9,
                bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.1')
            )

    # "COVID-19" callout on April 2020. The filtered data can only contain that
    # month when 2020 lies inside the requested range, so no extra range check
    # is needed.
    covid_month = monthly_counts[
        (monthly_counts['year'] == 2020) &
        (monthly_counts['month'] == 4)
    ]

    if not covid_month.empty:
        covid_low = covid_month.iloc[0]
        ax.annotate(
            "COVID-19",
            xy=(covid_low['date'], covid_low['count']),
            # Put the text slightly below the lowest point, one month earlier
            xytext=(pd.Timestamp("2020-03-01"), monthly_counts['count'].min() * 0.9),
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", color='red'),
            color='red',
            bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7)
        )

    # The title reflects the range that was actually plotted
    if open_ended:
        title = f'Monthly Crash Trends ({start_year} onwards)'
    else:
        title = f'Monthly Crash Trends ({start_year}-{end_year})'

    ax.set_title(title, fontsize=16, pad=20)
    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Number of Crashes', fontsize=12)
    ax.grid(True, alpha=0.3)

    # Show the legend so the labelled average line is actually identified
    ax.legend(loc='best')

    # Format x-axis with month and year
    ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%b\n%Y'))

    plt.tight_layout()
    return fig


def plot_injury_analysis_by_street(df, top_n=10):
    """
    Plot the streets with the highest injury totals.

    The left panel is a bar chart of the `top_n` streets ranked by summed
    'INJURIES_TOTAL'; the right panel is a pie chart showing how the injuries
    of the top five of those streets are split between them. Streets with
    fewer than 5 crashes and blank street names are ignored.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing crash data. It is only read, never modified.
    top_n : int
        Number of top streets to include in the analysis

    Returns:
    --------
    matplotlib.figure.Figure or None
        The figure, or None (after printing a message) when the data has no
        street column, no 'INJURIES_TOTAL' column, or no street that passes
        the filters.
    """
    # Known street column names, in order of preference
    potential_street_cols = [
        'STREET_NAME', 'STREET_NO', 'STREET', 'PRIMARY_STREET',
        'STREET_NAME_PRIMARY', 'STREET_NO_PRIMARY'
    ]

    street_col = next((col for col in potential_street_cols if col in df.columns), None)

    # Otherwise fall back to the first column whose name mentions 'STREET'
    if street_col is None:
        street_col = next(
            (col for col in df.columns if 'STREET' in str(col).upper()),
            None
        )

    if street_col is None:
        print("No street information found in the dataset")
        return None

    if 'INJURIES_TOTAL' not in df.columns:
        print("Injury information not found in the dataset")
        return None

    street_injuries = df.groupby(street_col)['INJURIES_TOTAL'].sum()
    street_counts = df.groupby(street_col).size()

    # Filter out streets with too few crashes (likely data errors)
    min_crashes = 5
    valid_streets = street_counts[street_counts >= min_crashes].index
    street_injuries = street_injuries[street_injuries.index.isin(valid_streets)]

    # Filter out blank street names. The keys are converted to text first so a
    # numeric street column (e.g. a street number) does not break the
    # string accessor.
    street_labels = street_injuries.index.astype(str).str.strip()
    street_injuries = street_injuries[street_labels != '']

    # nlargest already returns the values in descending order
    top_streets = street_injuries.nlargest(top_n)

    if top_streets.empty:
        print(f"No streets with at least {min_crashes} crashes found in the dataset")
        return None

    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # 1. Bar chart for high-risk corridors
    positions = range(len(top_streets))
    ax1.bar(
        positions,
        top_streets.values,
        color=plt.cm.viridis(np.linspace(0.1, 0.9, len(top_streets)))
    )

    ax1.set_title('High-Risk Corridors by Injuries Total', fontsize=12)
    ax1.set_ylabel('Injuries Total', fontsize=10)
    ax1.set_xlabel('Street Name', fontsize=10)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(top_streets.index, rotation=45, ha='right', fontsize=8)
    ax1.grid(axis='y', alpha=0.3)

    # 2. Pie chart for top 5 streets proportion
    top5_streets = top_streets[:5]
    if top5_streets.sum() > 0:
        ax2.pie(
            top5_streets.values,
            labels=top5_streets.index,
            autopct='%1.1f%%',
            startangle=90,
            colors=plt.cm.viridis(np.linspace(0.1, 0.9, len(top5_streets))),
            wedgeprops={'edgecolor': 'white', 'linewidth': 1}
        )
    else:
        # A pie of all-zero values has no defined proportions
        ax2.text(0.5, 0.5, "No injuries recorded", ha='center', va='center')
        ax2.axis('off')
    ax2.set_title('Proportional Distribution of Top 5 High-Frequency Corridors', fontsize=12)

    # Add a main title for the entire figure
    plt.suptitle('Chicago Traffic Injury Analysis by Street', fontsize=14, y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)

    return fig

def plot_model_performance(model_results, y_true=None):
    """
    Create comparative visualizations of classification model performance.

    The figure has four panels: a grouped bar chart of every metric per model,
    ROC curves, precision-recall curves, and a horizontal ranking of the models
    by 'roc_auc' (or 'f1', or the first metric column when neither exists).

    Parameters:
    -----------
    model_results : dict
        Dictionary of model results from train_evaluate_models(). Each value
        must hold a 'metrics' mapping and may hold 'probabilities', 'y_true'
        and 'predictions'.
    y_true : array-like, optional
        Ground truth labels (if not included in model_results)

    Returns:
    --------
    tuple
        (fig, metrics_df). When model_results is None or empty a placeholder
        figure is returned together with None.
    """
    import warnings

    # Nothing to compare: return a placeholder figure instead of failing later
    if model_results is None or len(model_results) == 0:
        reason = "None" if model_results is None else "empty"
        print(f"Error: model_results is {reason}, cannot create performance plots")
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, "No model results available for plotting",
                 ha='center', va='center', fontsize=14)
        plt.tight_layout()
        return fig, None

    # Suppress UserWarnings raised while building the figure, but only for the
    # duration of this call: the caller's warning filters are restored on exit.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        return _draw_model_performance(model_results, y_true)


def _resolve_curve_inputs(name, results, y_true):
    """Return (labels, scores) arrays for one model's curves, or None to skip it."""
    # Skip if no probabilities
    if 'probabilities' not in results or results['probabilities'] is None:
        return None

    y_prob = results['probabilities']
    if hasattr(y_prob, 'size') and y_prob.size == 0:
        return None

    # Use provided y_true or get it from the model's own results
    if y_true is not None:
        labels = y_true
    elif 'y_true' in results:
        labels = results['y_true']
    elif 'predictions' in results:
        # Last resort kept from the original behavior; say so explicitly
        # because curves scored against predictions are not a real evaluation.
        print(f"Warning: no ground truth available for {name}; "
              f"using its predictions as labels")
        labels = results['predictions']
    else:
        # Can't proceed without ground truth
        return None

    try:
        scores = np.asarray(y_prob)
        labels = np.asarray(labels)
    except Exception as e:
        print(f"Error converting to numpy arrays for {name}: {str(e)}")
        return None

    # Two-column output holds one column per class: keep the positive class
    if scores.ndim > 1:
        scores = scores[:, 1] if scores.shape[1] > 1 else scores.ravel()

    # Make sure we have more than 1-2 points for nice curves
    if len(scores) < 3:
        print(f"Not enough points for {name}, skipping")
        return None

    return labels, scores


def _draw_model_performance(model_results, y_true):
    """Build the four-panel performance figure and return (fig, metrics_df)."""
    # One row per model, one column per metric
    metrics_df = pd.DataFrame({
        name: results['metrics']
        for name, results in model_results.items()
    }).T

    fig = plt.figure(figsize=(15, 12))
    gs = plt.GridSpec(2, 3, height_ratios=[1, 1], width_ratios=[2, 1, 1])

    # 1. Main metrics comparison (larger, left side)
    ax_metrics = fig.add_subplot(gs[0, :2])
    colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(metrics_df.columns)))

    metrics_df.plot(
        kind='bar',
        color=colors,
        ax=ax_metrics,
        width=0.8,
        alpha=0.9
    )

    for container in ax_metrics.containers:
        ax_metrics.bar_label(container, fmt='%.2f', fontsize=8, rotation=90)

    ax_metrics.set_title('Model Performance Metrics Comparison', fontsize=14, pad=20)
    ax_metrics.set_xlabel('Model', fontsize=12)
    ax_metrics.set_ylabel('Score', fontsize=12)
    ax_metrics.set_xticklabels(ax_metrics.get_xticklabels(), rotation=45, ha='right')
    ax_metrics.grid(axis='y', alpha=0.3)
    ax_metrics.set_ylim(0, 1)

    # Move legend below the plot so it does not cover the bars
    ax_metrics.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=5)

    # 2. ROC curves and 3. precision-recall curves
    ax_roc = fig.add_subplot(gs[0, 2])
    ax_pr = fig.add_subplot(gs[1, 0])

    # Diagonal reference line for a random classifier
    ax_roc.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.7, label='Random')
    ax_roc.set_title('ROC Curves', fontsize=14, pad=20)
    ax_roc.set_xlabel('False Positive Rate', fontsize=12)
    ax_roc.set_ylabel('True Positive Rate', fontsize=12)
    ax_roc.grid(alpha=0.3)
    ax_roc.set_xlim(0, 1)
    ax_roc.set_ylim(0, 1)

    ax_pr.set_title('Precision-Recall Curves', fontsize=14, pad=20)
    ax_pr.set_xlabel('Recall', fontsize=12)
    ax_pr.set_ylabel('Precision', fontsize=12)
    ax_pr.grid(alpha=0.3)
    ax_pr.set_xlim(0, 1)
    ax_pr.set_ylim(0, 1)

    # One fixed color per model so both curve panels agree
    model_colors = {
        name: plt.cm.viridis(i / len(model_results))
        for i, name in enumerate(model_results.keys())
    }

    # Only successfully drawn curves end up in the legends
    valid_roc_lines = []
    valid_roc_labels = []
    valid_pr_lines = []
    valid_pr_labels = []

    for name, results in model_results.items():
        curve_inputs = _resolve_curve_inputs(name, results, y_true)
        if curve_inputs is None:
            continue
        y_true_local, y_prob = curve_inputs

        try:
            fpr, tpr, _ = roc_curve(y_true_local, y_prob)

            # With only the (0,0) and (1,1) end points, draw the diagonal
            # baseline on 20 evenly spaced points instead
            if len(fpr) <= 2:
                fpr = np.linspace(0, 1, 20)
                tpr = np.linspace(0, 1, 20)

            # Prefer the AUC already stored in the metrics
            if 'roc_auc' in results['metrics']:
                roc_auc = results['metrics']['roc_auc']
            else:
                try:
                    roc_auc = roc_auc_score(y_true_local, y_prob)
                except Exception:
                    # Best effort: integrate the curve that will be drawn
                    roc_auc = auc(fpr, tpr)

            if not np.isnan(fpr).any() and not np.isnan(tpr).any():
                roc_line, = ax_roc.plot(
                    fpr, tpr,
                    lw=2,
                    color=model_colors[name]
                )
                valid_roc_lines.append(roc_line)
                valid_roc_labels.append(f"{name} ({roc_auc:.3f})")
        except Exception as e:
            print(f"Error plotting ROC curve for {name}: {str(e)}")

        try:
            precision, recall, _ = precision_recall_curve(y_true_local, y_prob)

            # Too few points for a curve: fall back to 20 evenly spaced points
            if len(precision) <= 2:
                precision = np.linspace(0, 1, 20)
                recall = np.linspace(0, 1, 20)

            # Prefer the average precision already stored in the metrics
            if 'pr_auc' in results['metrics']:
                pr_auc = results['metrics']['pr_auc']
            else:
                try:
                    pr_auc = average_precision_score(y_true_local, y_prob)
                except Exception:
                    # Best effort: integrate the curve that will be drawn
                    pr_auc = auc(recall, precision)

            if not np.isnan(precision).any() and not np.isnan(recall).any():
                pr_line, = ax_pr.plot(
                    recall, precision,
                    lw=2,
                    color=model_colors[name]
                )
                valid_pr_lines.append(pr_line)
                valid_pr_labels.append(f"{name} ({pr_auc:.3f})")
        except Exception as e:
            print(f"Error plotting PR curve for {name}: {str(e)}")

    # Add legends if we have valid curves
    if valid_roc_lines:
        ax_roc.legend(valid_roc_lines, valid_roc_labels, loc='lower right', fontsize=8, framealpha=0.7)
    else:
        ax_roc.text(0.5, 0.5, "No valid ROC curves", ha='center', va='center')

    if valid_pr_lines:
        ax_pr.legend(valid_pr_lines, valid_pr_labels, loc='upper right', fontsize=8, framealpha=0.7)
    else:
        ax_pr.text(0.5, 0.5, "No valid PR curves", ha='center', va='center')

    # 4. Model ranking - sort by ROC AUC, then F1, then the first metric
    ax_rank = fig.add_subplot(gs[1, 1:])

    if 'roc_auc' in metrics_df.columns:
        sort_metric = 'roc_auc'
    elif 'f1' in metrics_df.columns:
        sort_metric = 'f1'
    else:
        sort_metric = metrics_df.columns[0]

    sorted_metrics = metrics_df.sort_values(by=sort_metric, ascending=False)

    min_val = sorted_metrics[sort_metric].min()
    max_val = sorted_metrics[sort_metric].max()
    range_val = max_val - min_val

    # Zoom the axis in when the models are close together, otherwise the
    # differences between the bars would be invisible on a 0-1 scale
    if range_val < 0.2:
        padding = range_val * 0.1  # 10% padding
        x_min = max(0, min_val - padding)
        x_max = min(1, max_val + padding)

        # Keep at least a 0.1-wide window around the center
        if x_max - x_min < 0.05:
            center = (x_min + x_max) / 2
            x_min = max(0, center - 0.05)
            x_max = min(1, center + 0.05)
    else:
        x_min = 0
        x_max = 1

    bar_colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_metrics)))

    bars = ax_rank.barh(
        sorted_metrics.index,
        sorted_metrics[sort_metric],
        height=0.6,
        color=bar_colors
    )

    # Value label just past the end of each bar
    for bar in bars:
        width = bar.get_width()
        ax_rank.text(
            width + (x_max - x_min) * 0.01,
            bar.get_y() + bar.get_height()/2,
            f'{width:.3f}',
            ha='left', va='center', fontsize=10
        )

    ax_rank.set_title(f'Models Ranked by {sort_metric.upper()}', fontsize=14, pad=20)
    ax_rank.set_xlabel(f'{sort_metric.upper()} Score', fontsize=12)
    ax_rank.set_xlim(x_min, x_max)
    ax_rank.grid(axis='x', alpha=0.3)

    plt.suptitle('Classification Model Performance Analysis', fontsize=16, y=0.98)

    # Adjust layout with more space
    plt.tight_layout()
    plt.subplots_adjust(top=0.92, hspace=0.35, wspace=0.35)

    return fig, metrics_df

def plot_prediction_distributions(model_results, y_test):
    """
    Plot the distribution of prediction probabilities for each model.

    One histogram panel is drawn per model, with the scores of true positives
    (label 1) and true negatives (label 0) overlaid. A short text summary of
    how well the two classes are separated is printed afterwards.

    Parameters:
    -----------
    model_results : dict
        Dictionary of model results from train_evaluate_models(). Models
        without a 'probabilities' entry, or with it set to None, are skipped.
    y_test : array-like
        True labels for test data

    Returns:
    --------
    matplotlib.figure.Figure or None
        The figure, or None when no model provides probabilities.
    """
    # Keep only models that actually carry probability predictions
    models_with_probs = [
        name for name, results in model_results.items()
        if 'probabilities' in results and results['probabilities'] is not None
    ]

    if not models_with_probs:
        print("No probability predictions available for visualization")
        return

    # Plain arrays make the boolean masks below positional, whatever index
    # the labels were passed with
    y_arr = np.asarray(y_test)
    pos_mask = y_arr == 1
    neg_mask = y_arr == 0

    # Normalize every model's scores to a 1-D positive-class array once
    scores_by_model = {}
    for model_name in models_with_probs:
        probs = np.asarray(model_results[model_name]['probabilities'])
        if probs.ndim > 1:
            # Two-column output holds one column per class: keep the positive class
            probs = probs[:, 1] if probs.shape[1] > 1 else probs.ravel()
        scores_by_model[model_name] = probs

    # squeeze=False always yields a 2-D array, so a single model needs no special case
    n_models = len(models_with_probs)
    fig, axes = plt.subplots(1, n_models, figsize=(n_models*5, 5), sharey=True, squeeze=False)
    axes = axes[0]

    # Use viridis colors
    colors = get_viridis_colors(2)  # Two colors for binary classification

    for i, model_name in enumerate(models_with_probs):
        probs = scores_by_model[model_name]

        # Plot positive class (class 1)
        sns.histplot(
            pd.Series(probs[pos_mask], name='probability'),
            bins=20,
            alpha=0.7,
            ax=axes[i],
            color=colors[1],
            label='Positive Class'
        )

        # Plot negative class (class 0)
        sns.histplot(
            pd.Series(probs[neg_mask], name='probability'),
            bins=20,
            alpha=0.7,
            ax=axes[i],
            color=colors[0],
            label='Negative Class'
        )

        axes[i].set_title(f'{model_name} Predictions', fontsize=12, pad=10)
        axes[i].set_xlabel('Prediction Probability', fontsize=10)
        if i == 0:
            axes[i].set_ylabel('Count', fontsize=10)

        axes[i].legend()
        axes[i].grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Calculate and print separation metrics
    print("\nPrediction Distribution Analysis:")
    for model_name in models_with_probs:
        probs = scores_by_model[model_name]
        pos_probs = probs[pos_mask]
        neg_probs = probs[neg_mask]

        # A class that is absent from y_test yields NaN instead of a
        # divide-by-zero warning
        pos_mean = np.mean(pos_probs) if pos_probs.size else float('nan')
        neg_mean = np.mean(neg_probs) if neg_probs.size else float('nan')
        separation = pos_mean - neg_mean

        print(f"\n{model_name}:")
        print(f"  Positive class average probability: {pos_mean:.4f}")
        print(f"  Negative class average probability: {neg_mean:.4f}")
        print(f"  Class separation: {separation:.4f}")

        # Share of each class on the correct side of the 0.5 threshold
        if pos_probs.size:
            correct_pos = np.sum(pos_probs > 0.5) / len(pos_probs) * 100
        else:
            correct_pos = float('nan')
        if neg_probs.size:
            correct_neg = np.sum(neg_probs < 0.5) / len(neg_probs) * 100
        else:
            correct_neg = float('nan')
        print(f"  Correctly predicted positives: {correct_pos:.1f}%")
        print(f"  Correctly predicted negatives: {correct_neg:.1f}%")

    return fig

# %%
def plot_feature_importance(model, feature_names, top_n=10, model_name=None):
    """
    Draw a horizontal bar chart of a fitted model's most important features.

    Importances are taken from `feature_importances_` when the model has it,
    otherwise from the absolute values of `coef_` (first row when `coef_` is
    two-dimensional). The chart is displayed with plt.show(); nothing is
    returned.

    Parameters:
    -----------
    model : fitted model object
        Model exposing either feature_importances_ or coef_
    feature_names : list
        Feature names, in the same order as the model's importances
    top_n : int, optional
        How many of the highest-ranked features to display
    model_name : str, optional
        Label used in the plot title; defaults to the model's class name
    """
    missing = object()

    # Tree-based models expose importances directly
    scores = getattr(model, 'feature_importances_', missing)
    if scores is missing:
        weights = getattr(model, 'coef_', missing)
        if weights is missing:
            print("Model doesn't have standard feature importance attributes.")
            return
        # Linear models: coefficient magnitude stands in for importance
        if len(weights.shape) > 1:
            weights = weights[0]
        scores = np.abs(weights)

    # Rank the features and keep the strongest top_n
    ranked = pd.DataFrame({
        'Feature': feature_names,
        'Importance': scores
    })
    ranked = ranked.sort_values('Importance', ascending=False)
    ranked = ranked.head(top_n)

    plt.figure(figsize=(10, 8))

    # One bar per feature, shaded along the viridis gradient
    gradient = plt.cm.viridis(np.linspace(0, 0.8, len(ranked)))
    drawn_bars = plt.barh(
        y=ranked['Feature'],
        width=ranked['Importance'],
        color=gradient
    )

    # Print each value just beyond the end of its bar
    for rect in drawn_bars:
        value = rect.get_width()
        plt.text(
            value * 1.01,
            rect.get_y() + rect.get_height()/2,
            f'{value:.3f}',
            va='center',
            fontsize=10
        )

    # Titles, axis label and grid
    if model_name:
        model_title = model_name
    else:
        model_title = model.__class__.__name__
    plt.title(f'Feature Importance: {model_title}', fontsize=14, pad=20)
    plt.xlabel('Importance', fontsize=12)
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

# %%
def plot_confusion_matrix(y_true, y_pred, classes=None, normalize=False, title=None):
    """
    Plot a confusion matrix as an annotated heatmap and display it.

    The overall accuracy is printed below the matrix. It is always computed
    from the raw counts, so it is the same whether or not the displayed matrix
    is normalized, and it is valid for any number of classes.

    Parameters:
    -----------
    y_true : array-like
        Ground truth labels
    y_pred : array-like
        Predicted labels
    classes : list, optional
        List of class names used as tick labels. When omitted, a 2x2 matrix is
        labelled 'Negative'/'Positive' and larger matrices use automatic labels.
    normalize : bool, default=False
        Whether to show each row as fractions of that row's total
    title : str, optional
        Title for the plot
    """
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    # Accuracy = correct predictions (the diagonal) over all predictions. Taken
    # before normalization: on a row-normalized matrix the same formula would
    # give the mean per-class recall instead.
    total = cm.sum()
    accuracy = np.trace(cm) / total if total else float('nan')

    # Normalize if requested
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without any true samples stay at zero instead of becoming NaN
        safe_totals = np.where(row_totals == 0, 1, row_totals)
        cm = cm.astype('float') / safe_totals
        fmt = '.2f'
    else:
        fmt = 'd'

    # Tick labels: explicit names win; the binary default only fits a 2x2 matrix
    if classes is not None and len(classes) > 0:
        tick_labels = list(classes)
    elif cm.shape == (2, 2):
        tick_labels = ['Negative', 'Positive']
    else:
        tick_labels = 'auto'

    # Set up plot
    plt.figure(figsize=(8, 6))

    # Use viridis colormap
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='viridis',
        cbar=True,
        square=True,
        xticklabels=tick_labels,
        yticklabels=tick_labels
    )

    # Customize plot
    plt.title(title if title else 'Confusion Matrix', fontsize=14, pad=20)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)

    # Add accuracy in the bottom right, just below the last row
    plt.text(
        cm.shape[1] - 0.5,
        cm.shape[0] + 0.2,
        f'Accuracy: {accuracy:.4f}',
        fontsize=10,
        ha='center'
    )

    # Better layout
    plt.tight_layout()
    plt.show()
    
    
def plot_prediction_distributions_enhanced(model_results, y_test, figsize=(14, 10)):
    """
    Plot enhanced distribution of prediction probabilities for each model with separation analysis.

    Each model gets a histogram panel (up to three per row) with the scores of
    true positives (label 1) and true negatives (label 0) overlaid, a vertical
    line at the threshold maximizing Youden's J statistic on a 100-point grid,
    and a text box with the class means. A summary table of all models is
    drawn underneath.

    Parameters:
    -----------
    model_results : dict
        Dictionary of model results from train_evaluate_models(). Models
        without a 'probabilities' entry, or with it set to None, are skipped.
    y_test : array-like
        True labels for test data
    figsize : tuple, optional
        Figure size

    Returns:
    --------
    tuple or None
        (fig, metrics_df): the figure and a DataFrame with one row of
        separation metrics per model. None when no model provides
        probabilities.
    """
    # Keep only models that actually carry probability predictions
    models_with_probs = [
        name for name, results in model_results.items()
        if 'probabilities' in results and results['probabilities'] is not None
    ]

    if not models_with_probs:
        print("No probability predictions available for visualization")
        return None

    # Grid layout: up to three histogram panels per row plus one table row
    n_models = len(models_with_probs)
    n_cols = min(3, n_models)
    n_rows = (n_models + n_cols - 1) // n_cols

    fig = plt.figure(figsize=figsize)
    gs = plt.GridSpec(n_rows + 1, n_cols, height_ratios=[3] * n_rows + [2])

    # Use viridis colors
    colors = plt.cm.viridis([0.2, 0.8])  # Two colors for binary classification

    # Plain arrays make the boolean masks positional, whatever index the
    # labels were passed with
    y_arr = np.asarray(y_test)
    pos_mask = y_arr == 1
    neg_mask = y_arr == 0

    # Candidate thresholds for Youden's J statistic
    thresholds = np.linspace(0, 1, 100)

    # Collect separation metrics for all models
    separation_metrics = []

    for i, model_name in enumerate(models_with_probs):
        row, col = divmod(i, n_cols)
        ax = fig.add_subplot(gs[row, col])

        probs = np.asarray(model_results[model_name]['probabilities'])
        if probs.ndim > 1:
            # Two-column output holds one column per class: keep the positive class
            probs = probs[:, 1] if probs.shape[1] > 1 else probs.ravel()

        pos_probs = probs[pos_mask]
        neg_probs = probs[neg_mask]

        # Plot positive class (class 1)
        sns.histplot(
            pd.Series(pos_probs, name='probability'),
            bins=20,
            alpha=0.7,
            ax=ax,
            color=colors[1],
            label='Positive Class'
        )

        # Plot negative class (class 0)
        sns.histplot(
            pd.Series(neg_probs, name='probability'),
            bins=20,
            alpha=0.7,
            ax=ax,
            color=colors[0],
            label='Negative Class'
        )

        # Class means; a class absent from y_test yields NaN instead of a warning
        pos_mean = np.mean(pos_probs) if pos_probs.size else float('nan')
        neg_mean = np.mean(neg_probs) if neg_probs.size else float('nan')
        separation = pos_mean - neg_mean

        # Share of each class on the correct side of the 0.5 threshold
        if pos_probs.size:
            correct_pos = np.sum(pos_probs > 0.5) / len(pos_probs) * 100
        else:
            correct_pos = float('nan')
        if neg_probs.size:
            correct_neg = np.sum(neg_probs < 0.5) / len(neg_probs) * 100
        else:
            correct_neg = float('nan')

        # Youden's J = sensitivity + specificity - 1 at each candidate threshold.
        # Sensitivity is the share of positives scored >= threshold and
        # specificity the share of negatives scored below it; an absent class
        # contributes 0.
        if pos_probs.size:
            sensitivity = np.array(
                [np.count_nonzero(pos_probs >= t) / pos_probs.size for t in thresholds]
            )
        else:
            sensitivity = np.zeros_like(thresholds)
        if neg_probs.size:
            specificity = np.array(
                [np.count_nonzero(neg_probs < t) / neg_probs.size for t in thresholds]
            )
        else:
            specificity = np.zeros_like(thresholds)

        j_values = sensitivity + specificity - 1
        optimal_threshold = thresholds[np.argmax(j_values)]

        # Add vertical line for optimal threshold
        ax.axvline(optimal_threshold, color='red', linestyle='--', alpha=0.7,
                   label=f'Optimal threshold: {optimal_threshold:.2f}')

        # Add separation metrics to the chart
        ax.text(0.05, 0.95,
                f"Separation: {separation:.3f}\nPos mean: {pos_mean:.3f}\nNeg mean: {neg_mean:.3f}",
                transform=ax.transAxes, fontsize=9,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        # Customize plot
        ax.set_title(f'{model_name}', fontsize=12, pad=10)
        ax.set_xlabel('Prediction Probability', fontsize=10)
        if col == 0:
            ax.set_ylabel('Count', fontsize=10)
        ax.legend(fontsize=8)
        ax.grid(alpha=0.3)

        # ROC AUC is undefined when y_test holds a single class; report NaN
        # for that model rather than aborting the whole figure
        try:
            roc_auc = roc_auc_score(y_arr, probs)
        except ValueError:
            roc_auc = float('nan')

        separation_metrics.append({
            'Model': model_name,
            'AUC': roc_auc,
            'Separation': separation,
            'Pos Mean': pos_mean,
            'Neg Mean': neg_mean,
            'Correct Pos %': correct_pos,
            'Correct Neg %': correct_neg,
            'Optimal Threshold': optimal_threshold
        })

    # Create a summary table at the bottom, spanning all columns
    ax_table = fig.add_subplot(gs[-1, :])

    metrics_df = pd.DataFrame(separation_metrics)

    # Format the table cells
    cell_text = [
        [
            entry['Model'],
            f"{entry['AUC']:.3f}",
            f"{entry['Separation']:.3f}",
            f"{entry['Pos Mean']:.3f}",
            f"{entry['Neg Mean']:.3f}",
            f"{entry['Correct Pos %']:.1f}%",
            f"{entry['Correct Neg %']:.1f}%",
            f"{entry['Optimal Threshold']:.3f}"
        ]
        for entry in separation_metrics
    ]

    column_labels = ['Model', 'AUC', 'Separation', 'Pos Mean', 'Neg Mean',
                     'Correct Pos %', 'Correct Neg %', 'Opt. Threshold']

    # Hide axes: this subplot only hosts the table
    ax_table.axis('tight')
    ax_table.axis('off')

    table = ax_table.table(
        cellText=cell_text,
        colLabels=column_labels,
        loc='center',
        cellLoc='center',
        colColours=['#f0f0f0'] * len(column_labels)
    )

    # Style the table
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.5)

    # Add a title to the table
    ax_table.set_title('Model Separation Metrics Comparison', fontsize=14, pad=20)

    # Add main title
    plt.suptitle('Prediction Probability Distributions by Class', fontsize=16, y=0.98)

    # Better layout
    plt.tight_layout()
    plt.subplots_adjust(top=0.92, hspace=0.4, wspace=0.3)

    return fig, metrics_df


#### Part 8: Interactive Dashboard

In [None]:
def _hour_by_day_matrix(df, hour_col, days):
    """
    Build a table of crash counts with one row per hour (0-23) and one
    column per day of the week.

    Parameters:
    -----------
    df : pandas.DataFrame
        Crash data containing ``hour_col`` and either 'DAY_OF_WEEK' or
        'CRASH_DAY_OF_WEEK'
    hour_col : str
        Name of the column holding the hour of day
    days : list of str
        Day names in display order

    Returns:
    --------
    pandas.DataFrame
        Counts indexed by hour 0-23; combinations with no crashes are 0
    """
    if 'DAY_OF_WEEK' in df.columns:
        counts = df.groupby([hour_col, 'DAY_OF_WEEK']).size().unstack('DAY_OF_WEEK')
        counts = counts.reindex(columns=days)
    else:
        counts = df.groupby([hour_col, 'CRASH_DAY_OF_WEEK']).size().unstack('CRASH_DAY_OF_WEEK')
        # NOTE(review): assumes 1 = Monday; verify against the dataset's encoding
        day_mapping = {i + 1: day for i, day in enumerate(days)}
        counts = counts.rename(columns=day_mapping)

    # Always expose all 24 hours so the rows line up with the y axis values
    return counts.reindex(index=range(24)).fillna(0)


def create_dashboard(df):
    """
    Create a comprehensive interactive dashboard for Chicago traffic crash analysis
    that includes all visualizations (original and new) in a single dashboard.

    The figure is a 5 x 2 grid. Panels whose source columns are missing are left
    empty. The bottom-left panel is a crash density map when 'LATITUDE' and
    'LONGITUDE' contain points inside the Chicago bounding box; otherwise it is
    an hour-by-weekday heatmap (no coordinate columns) or a text placeholder.

    Parameters:
    -----------
    df : pandas.DataFrame
        Processed dataframe with crash data. Must contain an hour column
        ('CRASH_HOUR' or 'HOUR'), a day-of-week column ('DAY_OF_WEEK' or
        'CRASH_DAY_OF_WEEK') and a date column ('CRASH_DATE' or
        'CRASH_DATETIME'). The input is not modified.

    Returns:
    --------
    plotly.graph_objects.Figure
        Interactive dashboard figure

    Raises:
    -------
    KeyError
        If a required hour, day-of-week or date column is missing
    """
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    # Light viridis-style palette shared by all panels
    colors = ['rgba(106, 80, 167, 0.8)', 'rgba(82, 194, 196, 0.8)',
              'rgba(141, 220, 110, 0.8)', 'rgba(254, 224, 76, 0.8)',
              'rgba(156, 110, 177, 0.8)', 'rgba(127, 191, 123, 0.8)']
    purple, teal, green, yellow = colors[:4]

    days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']

    # Work on a copy so helper columns added below never leak to the caller
    df = df.copy()

    # Hourly counts, padded to all 24 hours so they align with the x axis
    hour_col = 'CRASH_HOUR' if 'CRASH_HOUR' in df.columns else 'HOUR'
    hour_counts = (
        pd.to_numeric(df[hour_col], errors='coerce')
        .value_counts()
        .reindex(range(24), fill_value=0)
    )

    # Decide what the bottom-left panel shows BEFORE building the grid: a
    # subplot's type is fixed by make_subplots, and cartesian traces
    # (Scatter/Heatmap) cannot be added to a "mapbox" subplot afterwards.
    bottom_left_title = 'Crash Density Map'
    valid_coords = None
    hour_day_matrix = None
    placeholder_text = None
    if 'LATITUDE' in df.columns and 'LONGITUDE' in df.columns:
        # Filter to valid coordinates (rough Chicago bounding box)
        valid_coords = df[(df['LATITUDE'] > 41.5) & (df['LATITUDE'] < 42.1) &
                          (df['LONGITUDE'] > -88.0) & (df['LONGITUDE'] < -87.4)]
        if len(valid_coords) == 0:
            placeholder_text = 'No valid coordinate data available'
    else:
        # If no coordinates, show hour vs day heatmap instead
        try:
            hour_day_matrix = _hour_by_day_matrix(df, hour_col, days)
            bottom_left_title = 'Hour of Day vs. Day of Week Heatmap'
        except Exception as e:
            print(f"Could not create heatmap: {str(e)}")
            placeholder_text = 'No spatial or temporal data available for visualization'
    show_map = valid_coords is not None and len(valid_coords) > 0

    # Create a comprehensive dashboard with multiple sections
    fig = make_subplots(
        rows=5, cols=2,
        subplot_titles=[
            'Crashes by Hour of Day', 'Crashes by Day of Week',
            'Monthly Trend', 'Top Crash Types',
            'Injury Rates by Crash Type', 'Severity by Weather',
            'Monthly Crash Trends by Year', 'High-Risk Corridors by Injuries',
            bottom_left_title, 'Top 5 Streets Proportional Distribution'
        ],
        specs=[
            [{"type": "xy"}, {"type": "xy"}],
            [{"type": "xy"}, {"type": "xy"}],
            [{"type": "xy"}, {"type": "xy"}],
            [{"type": "xy"}, {"type": "xy"}],
            [{"type": "mapbox" if show_map else "xy"}, {"type": "pie"}]
        ],
        vertical_spacing=0.08,
        horizontal_spacing=0.05
    )

    # 1. Crashes by Hour
    fig.add_trace(
        go.Bar(
            x=list(range(24)),
            y=hour_counts.values,
            marker_color=purple,
            name='Hourly Crashes'
        ),
        row=1, col=1
    )

    # 2. Crashes by Day of Week
    if 'DAY_OF_WEEK' in df.columns:
        # Ensure days are in correct order
        dow_counts = df.groupby('DAY_OF_WEEK').size().reindex(days).fillna(0)
    else:
        dow_counts = df.groupby('CRASH_DAY_OF_WEEK').size().sort_index()
        # Map numeric days to names if needed
        # NOTE(review): assumes 1 = Monday; verify against the dataset's encoding
        if pd.api.types.is_integer_dtype(dow_counts.index):
            dow_counts.index = [days[i - 1] for i in dow_counts.index]

    fig.add_trace(
        go.Bar(
            x=dow_counts.index,
            y=dow_counts.values,
            marker_color=teal,
            name='Daily Crashes'
        ),
        row=1, col=2
    )

    # 3. Monthly Trend
    # Handle either date column name
    date_col = 'CRASH_DATE'
    if date_col not in df.columns and 'CRASH_DATETIME' in df.columns:
        date_col = 'CRASH_DATETIME'

    # Make sure date column is datetime
    df[date_col] = pd.to_datetime(df[date_col])

    # Create year-month string for grouping
    df['YearMonth'] = df[date_col].dt.strftime('%Y-%m')
    monthly_counts = df.groupby('YearMonth').size()

    fig.add_trace(
        go.Scatter(
            x=monthly_counts.index,
            y=monthly_counts.values,
            mode='lines+markers',
            line=dict(color=green, width=2),
            marker=dict(size=6),
            name='Monthly Crashes'
        ),
        row=2, col=1
    )

    # 4. Top Crash Types
    if 'FIRST_CRASH_TYPE' in df.columns:
        crash_types = df['FIRST_CRASH_TYPE'].value_counts().nlargest(10)
        fig.add_trace(
            go.Bar(
                y=crash_types.index,
                x=crash_types.values,
                orientation='h',
                marker_color=yellow,
                name='Crash Types'
            ),
            row=2, col=2
        )

    # 5. Injury Rates by Crash Type
    if 'FIRST_CRASH_TYPE' in df.columns and 'INJURIES_TOTAL' in df.columns:
        df['injury_flag'] = (df['INJURIES_TOTAL'] > 0).astype(int)
        injury_by_type = (
            df.groupby('FIRST_CRASH_TYPE')['injury_flag']
            .mean()
            .sort_values(ascending=False)
            .head(10)
        )

        fig.add_trace(
            go.Bar(
                y=injury_by_type.index,
                x=injury_by_type.values * 100,  # Convert to percentage
                orientation='h',
                marker_color=purple,
                name='Injury Rate'
            ),
            row=3, col=1
        )

    # 6. Severity by Weather
    if 'WEATHER_CONDITION' in df.columns and 'SEVERE' in df.columns:
        weather_severity = (
            df.groupby('WEATHER_CONDITION')['SEVERE']
            .mean()
            .sort_values(ascending=False)
            .head(8)
        )

        fig.add_trace(
            go.Bar(
                y=weather_severity.index,
                x=weather_severity.values * 100,  # Convert to percentage
                orientation='h',
                marker_color=teal,
                name='Severity Rate'
            ),
            row=3, col=2
        )

    # 7. Monthly Crash Trends by Year
    df['Year'] = df[date_col].dt.year
    df['Month'] = df[date_col].dt.month
    month_numbers = list(range(1, 13))

    # Unparseable dates yield NaN years; skip them instead of plotting a 'nan' line
    years = sorted(df['Year'].dropna().unique())

    for i, year in enumerate(years):
        year_data = df[df['Year'] == year]
        # Make sure all months are represented
        monthly = year_data.groupby('Month').size().reindex(month_numbers, fill_value=0)

        fig.add_trace(
            go.Scatter(
                x=month_numbers,
                y=monthly.values,
                mode='lines+markers',
                name=str(int(year)),
                line=dict(color=colors[i % len(colors)]),
                marker=dict(size=8)
            ),
            row=4, col=1
        )

    # Add average line for monthly trends. The average is taken over individual
    # (year, month) counts so the reference line is on the same scale as the
    # per-year series drawn above (pooling all years would inflate it).
    monthly_average = df.groupby(['Year', 'Month']).size().mean()
    fig.add_trace(
        go.Scatter(
            x=month_numbers,
            y=[monthly_average] * 12,
            mode='lines',
            line=dict(color='red', dash='dash'),
            name='Monthly Average'
        ),
        row=4, col=1
    )

    # 8. Street Injury Analysis
    if 'INJURIES_TOTAL' in df.columns:
        # Find street column: preferred names first, then any column mentioning STREET
        potential_street_cols = [
            'STREET_NAME', 'STREET_NO', 'STREET', 'PRIMARY_STREET',
            'STREET_NAME_PRIMARY', 'STREET_NO_PRIMARY', 'STREET_NO_1'
        ]
        street_col = next((col for col in potential_street_cols if col in df.columns), None)
        if street_col is None:
            street_col = next((col for col in df.columns if 'STREET' in col.upper()), None)

        if street_col:
            street_injuries = (
                df.groupby(street_col)['INJURIES_TOTAL']
                .sum()
                .sort_values(ascending=False)
            )

            # Filter out empty street names (labels are compared as text so the
            # check also works for non-object index dtypes)
            street_labels = street_injuries.index.astype(str).str.strip()
            street_injuries = street_injuries[street_labels != '']

            top_streets = street_injuries.nlargest(10)

            # Bar chart for high-risk corridors
            fig.add_trace(
                go.Bar(
                    y=top_streets.index,
                    x=top_streets.values,
                    orientation='h',
                    marker_color=purple,
                    name='Street Injuries'
                ),
                row=4, col=2
            )

            # Pie chart for top 5 streets distribution
            top5_streets = top_streets[:5]
            fig.add_trace(
                go.Pie(
                    labels=top5_streets.index,
                    values=top5_streets.values,
                    textinfo='label+percent',
                    hole=0.3,
                    marker=dict(colors=colors[:5]),
                    name='Top 5 Streets'
                ),
                row=5, col=2
            )

    # 9. Crash Map (or its fallback)
    if show_map:
        fig.add_trace(
            go.Densitymapbox(
                lat=valid_coords['LATITUDE'],
                lon=valid_coords['LONGITUDE'],
                z=valid_coords['INJURIES_TOTAL'] if 'INJURIES_TOTAL' in valid_coords.columns else None,
                radius=10,
                colorscale='Viridis',
                name='Crash Density'
            ),
            row=5, col=1
        )

        # Update mapbox configuration
        fig.update_layout(
            mapbox=dict(
                style="carto-positron",
                center=dict(
                    lat=valid_coords['LATITUDE'].mean(),
                    lon=valid_coords['LONGITUDE'].mean()
                ),
                zoom=9
            )
        )
    elif hour_day_matrix is not None:
        fig.add_trace(
            go.Heatmap(
                z=hour_day_matrix.values,
                x=hour_day_matrix.columns,
                y=list(range(24)),
                colorscale='Viridis',
                colorbar=dict(title='Count'),
                name='Hourly Pattern'
            ),
            row=5, col=1
        )
        fig.update_xaxes(title_text="Day of Week", row=5, col=1)
        fig.update_yaxes(title_text="Hour of Day", row=5, col=1)
    else:
        # Fallback - simple text
        fig.add_trace(
            go.Scatter(
                x=[0.5],
                y=[0.5],
                mode='text',
                text=[placeholder_text],
                textposition='middle center',
            ),
            row=5, col=1
        )

    # Update layout for better appearance
    fig.update_layout(
        height=1600,
        width=1200,
        title_text="Chicago Traffic Crash Analysis Dashboard",
        template="plotly_white",
        showlegend=False
    )

    # Update axes labels
    fig.update_xaxes(title_text="Hour of Day", row=1, col=1)
    fig.update_yaxes(title_text="Number of Crashes", row=1, col=1)

    fig.update_xaxes(title_text="Day of Week", row=1, col=2)
    fig.update_yaxes(title_text="Number of Crashes", row=1, col=2)

    fig.update_xaxes(title_text="Month", row=2, col=1)
    fig.update_yaxes(title_text="Number of Crashes", row=2, col=1)

    fig.update_xaxes(title_text="Count", row=2, col=2)
    fig.update_yaxes(title_text="Crash Type", row=2, col=2)

    fig.update_xaxes(title_text="Injury Rate (%)", row=3, col=1)
    fig.update_yaxes(title_text="Crash Type", row=3, col=1)

    fig.update_xaxes(title_text="Severity Rate (%)", row=3, col=2)
    fig.update_yaxes(title_text="Weather Condition", row=3, col=2)

    fig.update_xaxes(title_text="Month", tickvals=month_numbers,
                     ticktext=['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'],
                     row=4, col=1)
    fig.update_yaxes(title_text="Number of Crashes", row=4, col=1)

    fig.update_xaxes(title_text="Total Injuries", row=4, col=2)
    fig.update_yaxes(title_text="Street", row=4, col=2)

    return fig

In [None]:
def main(file_path, sample_size=None, random_seed=45665456, use_advanced_models=False):
    """
    Run the complete analysis workflow with enhanced visualizations.

    The workflow loads and preprocesses the data, engineers features, draws the
    exploratory plots, trains classification models for the 'SEVERE' target,
    optionally fits injury-count regression models, and finally builds the
    interactive dashboard.

    Parameters:
    -----------
    file_path : str
        Path to the crash data CSV
    sample_size : int, optional
        Number of records to sample (for faster processing)
    random_seed : int, default=45665456
        Random seed for reproducibility
    use_advanced_models : bool, default=False
        Whether to use advanced models and techniques

    Returns:
    --------
    dict
        On success: 'data', 'model_results', 'feature_results',
        'regression_results', 'dashboard', 'time_series_models',
        'spatial_models', 'multitask_model', 'shap_results',
        'monthly_trends_fig', 'street_injury_fig', plus the optional keys
        'classification_metrics_fig', 'prediction_distributions_fig',
        'regression_metrics_fig' and 'folium_map' when those were produced.
        If no base features are available or the pipeline raises, a smaller
        dict containing 'data' and 'error' is returned instead.
    """
    print("===== Starting Traffic Crash Analysis =====")
    start_time = time.time()

    # 1. Load data
    df = load_data(file_path)

    # 2. Sample data if needed
    if sample_size is not None and sample_size < len(df):
        df = df.sample(sample_size, random_state=random_seed)
        print(f"Sampled {sample_size} records for analysis")

    # 3. Preprocess data
    print("\n===== Preprocessing Data =====")
    df_processed = preprocess_data(df)

    # 4. Analyze missing data
    print("\n===== Missing Data Analysis =====")
    analyze_missing_data(df_processed)

    # Optional artefacts start as None so later code can test them directly
    # instead of probing locals() for names that may never have been bound.
    monthly_trends_fig = None
    street_injury_fig = None
    folium_map = None
    fig_metrics = None
    fig_distributions = None
    fig_reg_metrics = None
    shap_results = None
    ts_models = None
    spatial_models = None
    multitask_results = None

    # 5. Feature engineering (do this early to have features for plots)
    print("\n===== Feature Engineering =====")
    try:
        df_features = engineer_features(df_processed)

        # 6. All visualizations that don't depend on model training
        print("\n===== Exploratory Data Analysis =====")

        # 6.1 Temporal analysis
        print("Plotting crash counts by hour...")
        plot_crashes_by_hour(df_features)

        print("Plotting crash counts by day of week...")
        plot_crashes_by_day_of_week(df_features)

        print("Plotting crash counts by day of week (donut chart)...")
        plot_crashes_by_weekday_donut(df_features)

        print("Plotting monthly trends...")
        plot_monthly_trend(df_features)

        print("Creating temporal heatmap...")
        plot_temporal_heatmap(df_features)

        print("Creating monthly crash trends by year...")
        monthly_trends_fig = plot_monthly_crash_trends_by_year(df_features, start_year=2019)

        # 6.2 Crash types and causes
        print("Plotting top crash types...")
        plot_crash_types(df_features)

        # 6.3 Feature distributions
        print("Creating time period visualizations...")
        create_time_period_visualizations(df_features)

        # 6.4 Severity and injury analysis
        print("Analyzing severity by factors...")
        factors = ['FIRST_CRASH_TYPE', 'WEATHER_CONDITION', 'ROADWAY_SURFACE_COND',
                   'TIME_OF_DAY', 'SEASON']
        for factor in factors:
            if factor in df_features.columns:  # Check if factor exists
                plot_severity_by_factor(df_features, factor, top_n=10)
            else:
                print(f"Factor {factor} not found in data")

        print("Plotting injury rates by crash type...")
        plot_injury_rate_by_crash_type(df_features)

        print("Plotting injury rates by contributory cause...")
        plot_injury_rate_by_cause(df_features)

        print("Plotting injury severity by weather condition...")
        plot_injury_severity_by_weather(df_features)

        print("Creating street injury analysis...")
        street_injury_fig = plot_injury_analysis_by_street(df_features, top_n=10)

        # 6.5 Speed limit analysis
        print("Plotting injuries by speed limit bin...")
        plot_injuries_by_speed_bin(df_features)

        # 6.6 Spatial analysis
        if 'LATITUDE' in df_features.columns and 'LONGITUDE' in df_features.columns:
            print("\n===== Spatial Analysis =====")
            print("Generating hexbin plot of crash density...")
            plot_hex_crash_density(df_features)

            print("Creating interactive folium heatmap...")

            # Create folium map
            folium_map = folium.Map(location=[41.88, -87.63], zoom_start=11, tiles="CartoDB positron")

            # Filter to valid bounds for Chicago
            mask = (
                (df_features['LATITUDE'] > 41.5) & (df_features['LATITUDE'] < 42.1) &
                (df_features['LONGITUDE'] > -88.0) & (df_features['LONGITUDE'] < -87.4)
            )
            geo = df_features[mask]

            # Create heat data points
            heat_data = list(zip(geo['LATITUDE'].values, geo['LONGITUDE'].values))

            # Add heat map
            HeatMap(heat_data, radius=10, blur=15, min_opacity=0.3).add_to(folium_map)

            # Display the map if in a notebook. Best effort only, but never
            # swallow KeyboardInterrupt/SystemExit the way a bare except would.
            try:
                from IPython.display import display
                display(folium_map)
            except Exception:
                print("Folium map created (visible in notebook environments)")

            # Save map
            try:
                folium_map.save("chicago_crash_heatmap.html")
                print("Saved interactive map to chicago_crash_heatmap.html")
            except Exception as e:
                print(f"Could not save map: {e}")

        # 8. Continue with model preparation
        print("\n===== Preparing for Modeling =====")

        # Define base features
        base_feats = [
            'POSTED_SPEED_LIMIT', 'NUM_UNITS', 'CRASH_HOUR', 'CRASH_DAY_OF_WEEK', 'CRASH_MONTH'
        ]

        # Verify base features exist
        available_base_feats = [col for col in base_feats if col in df_features.columns]
        if len(available_base_feats) < len(base_feats):
            missing = set(base_feats) - set(available_base_feats)
            print(f"Warning: Some base features are missing: {missing}")

        # Define engineered features
        eng_feats = [
            'IS_WEEKEND', 'IS_NIGHTTIME', 'IS_VULNERABLE', 'BAD_WEATHER',
            'BAD_SURFACE', 'AT_INTERSECTION'
        ]

        # Check engineered features
        available_eng_feats = [col for col in eng_feats if col in df_features.columns]

        # Add advanced features if available
        adv_feats = [col for col in df_features.columns if col.startswith((
            'WEEKEND_', 'WEEKDAY_', 'NIGHT_', 'HIGH_SPEED', 'MULTI_',
            'SINGLE_', 'VULNERABLE_', 'WINTER_', 'IS_HOLIDAY', 'SEASON_'))]
        available_eng_feats.extend(adv_feats)

        # Check if we have enough features to proceed
        if not available_base_feats:
            print("Error: No base features available for modeling")
            return {
                'data': df_processed,
                'error': 'Insufficient features for modeling',
                'monthly_trends_fig': monthly_trends_fig,
                'street_injury_fig': street_injury_fig,
                'dashboard': None  # the dashboard is only built at the very end
            }

        # Prepare feature matrix - handling one-hot encoding safely
        # Only use available features
        features_to_use = available_base_feats + available_eng_feats

        # Safe way to create dummy variables
        X = pd.get_dummies(df_features[features_to_use], drop_first=True).fillna(0)

        # Check if target variable exists
        if 'SEVERE' not in df_features.columns:
            print("Target variable 'SEVERE' not found - creating it")
            df_features['SEVERE'] = ((df_features['INJURIES_TOTAL'] > 0) |
                                     (df_features['INJURIES_FATAL'] > 0)).astype(int)
        y = df_features['SEVERE']

        # Optional: Feature selection
        feature_results = None
        if use_advanced_models:
            print("\n===== Feature Selection =====")
            feature_results = select_features(X, y)
            print("\nSelected features:")
            for feature in feature_results['consensus'][:15]:  # Show top 15
                print(f" - {feature}")

            # Use selected features for modeling
            X_selected = X[feature_results['consensus']]
        else:
            X_selected = X

        # 9. Model training and evaluation
        print("\n===== Model Training and Evaluation =====")

        # Split data. The row positions are split together with X and y so any
        # other per-row target (the injury counts used for regression below)
        # can be partitioned into exactly the same train/test rows.
        row_positions = np.arange(len(X_selected))
        X_train, X_test, y_train, y_test, train_pos, test_pos = train_test_split(
            X_selected, y, row_positions,
            test_size=0.2, stratify=y, random_state=random_seed
        )

        # Initialize model_results as None - this is important to prevent issues
        model_results = None

        # Train and evaluate models
        if use_advanced_models:
            print("\n===== Using Advanced Models =====")
            # Option 1: Use model selection to find best model
            best_model_results = select_best_model(X_train, y_train, cv=5, scoring='f1')
            best_model = best_model_results['best_model']
            best_model_name = best_model_results['best_model_name']

            # Evaluate best model
            y_pred = best_model.predict(X_test)
            y_prob = None
            if hasattr(best_model, 'predict_proba'):
                y_prob = best_model.predict_proba(X_test)[:, 1]

            print(f"\nBest model: {best_model_name}")
            print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
            print(f"Precision: {precision_score(y_test, y_pred):.4f}")
            print(f"Recall: {recall_score(y_test, y_pred):.4f}")
            print(f"F1 Score: {f1_score(y_test, y_pred):.4f}")

            # Create simplified model_results for visualization
            model_results = {
                best_model_name: {
                    'model': best_model,
                    'metrics': {
                        'accuracy': accuracy_score(y_test, y_pred),
                        'precision': precision_score(y_test, y_pred),
                        'recall': recall_score(y_test, y_pred),
                        'f1': f1_score(y_test, y_pred)
                    },
                    'predictions': y_pred,
                    'probabilities': y_prob
                }
            }

            # ROC AUC needs probabilities, which not every model exposes
            if y_prob is not None:
                model_results[best_model_name]['metrics']['roc_auc'] = roc_auc_score(y_test, y_prob)
                print(f"ROC AUC: {model_results[best_model_name]['metrics']['roc_auc']:.4f}")

            # Plot performance metrics
            print("\nPlotting model performance...")
            plot_model_performance(model_results, y_test)

            # Use new enhanced function for classification metrics comparison
            print("\nGenerating detailed classification metrics comparison...")
            fig_metrics, _ = plot_classification_metrics_comparison(model_results)

            # Plot prediction distributions
            print("Plotting prediction distributions...")
            if y_prob is not None:
                try:
                    fig_distributions, dist_metrics = plot_prediction_distributions_enhanced(model_results, y_test)
                except Exception as e:
                    print(f"Error plotting prediction distributions: {str(e)}")
                    fig_distributions = None

            # Analyze model with SHAP
            print("\n===== Model Interpretation with SHAP =====")
            try:
                shap_results = analyze_model_with_shap(best_model, X_test, X_selected.columns)
            except Exception as e:
                print(f"SHAP analysis failed: {str(e)}")

            # Plot feature importance for best model
            print(f"\nPlotting feature importance for best model: {best_model_name}")
            if best_model is not None:
                plot_feature_importance(best_model, X_selected.columns, model_name=best_model_name)

            # Plot confusion matrix for best model
            print(f"\nPlotting confusion matrix for best model: {best_model_name}")
            if y_pred is not None:
                plot_confusion_matrix(y_test, y_pred, title=f'Confusion Matrix: {best_model_name}')

            # Optional: Time series modeling if time data is available
            if 'CRASH_DATE' in df_features.columns:
                print("\n===== Time Series Modeling =====")
                try:
                    ts_models = train_time_series_models(df_features)
                except Exception as e:
                    print(f"Time series modeling failed: {str(e)}")

            # Optional: Spatial modeling if coordinates are available
            if 'LATITUDE' in df_features.columns and 'LONGITUDE' in df_features.columns:
                print("\n===== Spatial Modeling =====")
                try:
                    spatial_models = train_spatial_models(df_features)
                except Exception as e:
                    print(f"Spatial modeling failed: {str(e)}")

            # Optional: Multi-task modeling if injury data is available
            if 'INJURIES_TOTAL' in df_features.columns:
                print("\n===== Multi-task Modeling =====")
                try:
                    y_reg = df_features['INJURIES_TOTAL']
                    multitask_results = build_multitask_model(X_selected, y, y_reg)
                except Exception as e:
                    print(f"Multi-task modeling failed: {str(e)}")
        else:
            # Use standard models
            try:
                model_results = train_evaluate_models(X_train, X_test, y_train, y_test)

                # Plot standard model performance
                print("\nPlotting model performance...")
                if model_results is not None:
                    plot_model_performance(model_results, y_test)

                # Also plot enhanced classification metrics comparison
                print("\nGenerating detailed classification metrics comparison...")
                if model_results is not None:
                    try:
                        fig_metrics, metric_ranking = plot_classification_metrics_comparison(model_results)
                    except Exception as e:
                        print(f"Error plotting classification metrics: {str(e)}")
                        fig_metrics = None

                # Plot enhanced prediction distributions
                print("Plotting enhanced prediction distributions...")
                if model_results is not None:
                    try:
                        fig_distributions, dist_metrics = plot_prediction_distributions_enhanced(model_results, y_test)
                    except Exception as e:
                        print(f"Error plotting prediction distributions: {str(e)}")
                        fig_distributions = None

                # Find the best model based on ROC AUC
                best_model_name = None
                best_model = None
                if model_results is not None:
                    try:
                        best_model_name = max(model_results.items(),
                                              key=lambda x: x[1]['metrics']['roc_auc'])[0]
                        best_model = model_results[best_model_name]['model']
                        print(f"\nBest model: {best_model_name}")

                        # Plot feature importance for best model
                        if best_model is not None:
                            try:
                                print(f"\nPlotting feature importance for best model: {best_model_name}")
                                plot_feature_importance(best_model, X_selected.columns, model_name=best_model_name)
                            except Exception as e:
                                print(f"Error plotting feature importance: {str(e)}")

                        # Plot confusion matrix for best model
                        if best_model_name in model_results:
                            try:
                                print(f"\nPlotting confusion matrix for best model: {best_model_name}")
                                best_preds = model_results[best_model_name]['predictions']
                                plot_confusion_matrix(y_test, best_preds, title=f'Confusion Matrix: {best_model_name}')
                            except Exception as e:
                                print(f"Error plotting confusion matrix: {str(e)}")
                    except (KeyError, ValueError, TypeError) as e:
                        print(f"Could not determine best model: {str(e)}")
            except Exception as e:
                print(f"Error in model training and evaluation: {str(e)}")
                model_results = None

        # 10. Injury count prediction (if available)
        reg_results = None
        if 'INJURIES_TOTAL' in df_features.columns:
            print("\n===== Injury Count Prediction =====")
            try:
                y_reg = df_features['INJURIES_TOTAL']

                # Reuse the classification split's row positions. Running a
                # second, unstratified split would select different rows than
                # X_train/X_test and pair features with the wrong targets.
                y_reg_train = y_reg.iloc[train_pos]
                y_reg_test = y_reg.iloc[test_pos]

                # Train and evaluate regression models
                reg_results = evaluate_regression_models(X_train, X_test, y_reg_train, y_reg_test)
                print("\nRegression model performance:")
                print(reg_results)

                # Generate enhanced regression metrics visualization
                print("\nGenerating regression metrics visualization...")
                try:
                    fig_reg_metrics = plot_regression_metrics_comparison(reg_results)
                except Exception as e:
                    print(f"Error plotting regression metrics: {str(e)}")
                    fig_reg_metrics = None
            except Exception as e:
                print(f"Regression modeling failed: {str(e)}")
                reg_results = None

        # Calculate execution time
        end_time = time.time()
        total_time = end_time - start_time
        print(f"\nAnalysis completed in {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

        # Create a dictionary to hold the results
        results = {
            'data': df_features,
            'model_results': model_results,  # May be None if training failed
            'feature_results': feature_results,
            'regression_results': reg_results,
            'dashboard': None,  # filled in below
            'time_series_models': ts_models,
            'spatial_models': spatial_models,
            'multitask_model': multitask_results,
            'shap_results': shap_results,
            'monthly_trends_fig': monthly_trends_fig,
            'street_injury_fig': street_injury_fig
        }

        # Add optional figures if they exist
        if fig_metrics is not None:
            results['classification_metrics_fig'] = fig_metrics
        if fig_distributions is not None:
            results['prediction_distributions_fig'] = fig_distributions
        if fig_reg_metrics is not None:
            results['regression_metrics_fig'] = fig_reg_metrics
        if folium_map is not None:
            results['folium_map'] = folium_map

        # Build the dashboard last; a failure here must not discard the results
        try:
            dashboard = create_dashboard(df_features)
            print("Dashboard created successfully")
            results['dashboard'] = dashboard
        except Exception as e:
            print(f"Error creating dashboard: {str(e)}")
            results['dashboard'] = None
        return results
    except Exception as e:
        print(f"Error in analysis pipeline: {str(e)}")
        import traceback
        traceback.print_exc()
        return {
            'data': df_processed,
            'error': str(e),
            'model_results': None  # Ensure model_results is defined in error case too
        }


In [None]:
def _find_missing_dependencies():
    """Return the pip names of required third-party packages that fail to import."""
    import importlib

    # (importable module, name to show in the "pip install" hint)
    required = (
        ("pandas", "pandas"),
        ("numpy", "numpy"),
        ("matplotlib.pyplot", "matplotlib"),
        ("seaborn", "seaborn"),
        ("sklearn", "scikit-learn"),
        ("plotly.graph_objects", "plotly"),
    )
    missing = []
    for module_name, pip_name in required:
        try:
            importlib.import_module(module_name)
        except ImportError:
            missing.append(pip_name)
    return missing


def _save_stored_figure(plt, fig, out_path):
    """Write an already-built figure to ``out_path`` and close it; never raises."""
    try:
        fig.savefig(out_path, dpi=100, bbox_inches='tight')
        plt.close(fig)
        return True
    except Exception as e:
        print(f"Failed to save {out_path}: {str(e)}")
        return False


def _save_plot_to_file(plt, plot_func, data, out_path, figsize=None,
                       returns_figure=False, **plot_kwargs):
    """Render one plot with ``plot_func(data, **plot_kwargs)`` and save it; never raises.

    When ``figsize`` is given, a figure of that size is opened before the plot
    function is called. When ``returns_figure`` is True the object returned by
    the plot function is saved (and nothing is written if it returns None);
    otherwise the current pyplot figure is saved. Returns True if a file was
    written.
    """
    canvas = plt.figure(figsize=figsize) if figsize is not None else None
    try:
        produced = plot_func(data, **plot_kwargs)
        if returns_figure:
            if produced is None:
                return False
            produced.savefig(out_path, dpi=100, bbox_inches='tight')
            plt.close(produced)
        else:
            plt.tight_layout()
            plt.savefig(out_path, dpi=100, bbox_inches='tight')
            plt.close()
        return True
    except Exception as e:
        # One broken chart must not abort the remaining ones.
        print(f"Failed to save {out_path}: {str(e)}")
        return False
    finally:
        # The plot function may have drawn on a figure of its own, in which
        # case the pre-opened one is still registered with pyplot. Closing a
        # figure that is already closed is harmless.
        if canvas is not None:
            plt.close(canvas)


def _save_crash_visualizations(plt, results, output_folder):
    """Save every static chart of the analysis as a PNG inside ``output_folder``."""
    data = results.get('data')
    have_data = data is not None

    def target(file_name):
        return f"{output_folder}/{file_name}"

    def render(plot_func, file_name, figsize=None, returns_figure=False, **plot_kwargs):
        return _save_plot_to_file(plt, plot_func, data, target(file_name),
                                  figsize=figsize, returns_figure=returns_figure,
                                  **plot_kwargs)

    if not have_data:
        print("No processed data available; skipping data-driven plots")

    if have_data:
        # 1-7: temporal and crash-type overviews
        render(plot_crashes_by_hour, "hourly_crashes.png", figsize=(12, 8))
        render(plot_crashes_by_day_of_week, "daily_crashes.png", figsize=(12, 8))
        render(plot_crashes_by_weekday_donut, "weekday_donut.png", returns_figure=True)
        render(plot_monthly_trend, "monthly_crashes.png", figsize=(12, 8))
        render(plot_temporal_heatmap, "temporal_heatmap.png", figsize=(12, 8))
        render(plot_crash_types, "crash_types.png", figsize=(12, 8))
        render(create_time_period_visualizations, "time_period_viz.png",
               figsize=(14, 12), returns_figure=True)

    # 8-9: reuse the figure produced by the pipeline when there is one,
    # otherwise rebuild it from the data.
    stored_or_fresh = (
        ('monthly_trends_fig', "monthly_trends.png",
         plot_monthly_crash_trends_by_year, (15, 8), {'start_year': 2019}),
        ('street_injury_fig', "street_injury.png",
         plot_injury_analysis_by_street, None, {'top_n': 10}),
    )
    for result_key, file_name, fallback_func, figsize, fallback_kwargs in stored_or_fresh:
        stored = results.get(result_key)
        if stored is not None:
            _save_stored_figure(plt, stored, target(file_name))
        elif have_data:
            render(fallback_func, file_name, figsize=figsize,
                   returns_figure=True, **fallback_kwargs)

    if have_data:
        # 10-12: injury breakdowns
        render(plot_injury_rate_by_crash_type, "injury_by_crash_type.png", figsize=(10, 6))
        render(plot_injury_rate_by_cause, "injury_by_cause.png", figsize=(10, 6))
        render(plot_injury_severity_by_weather, "injury_by_weather.png", figsize=(12, 8))

        # 13: crash density map (only if coordinates exist)
        if 'LATITUDE' in data.columns:
            render(plot_hex_crash_density, "crash_density_map.png", figsize=(10, 10))

    # 14: model performance figures, saved only if the pipeline produced them
    for plot_name in ('classification_metrics', 'prediction_distributions', 'regression_metrics'):
        stored = results.get(f"{plot_name}_fig")
        if stored is not None:
            _save_stored_figure(plt, stored, target(f"{plot_name}.png"))


def run_chicago_crash_analysis(file_path='Traffic_Crashes_-_Crashes.csv',
                             sample_size=100000,
                             use_advanced_models=False,
                             save_results=True,
                             output_folder='crash_analysis_results',
                             show_plots=True,
                             random_seed=45665456):
    """
    Run a comprehensive traffic crash analysis pipeline with enhanced visualizations.

    Checks the required packages, calls ``main`` to run the analysis, optionally
    writes the processed data, model results, dashboard and charts to
    ``output_folder``, and prints a short summary. Failures while saving an
    individual artifact are reported and skipped so the computed results are
    still returned.

    Parameters:
    -----------
    file_path : str
        Path to the crash data CSV
    sample_size : int, optional
        Number of records to sample (for faster processing)
    use_advanced_models : bool
        Whether to use advanced models and techniques
    save_results : bool
        Whether to save the results to disk
    output_folder : str
        Folder to save results to (if save_results is True)
    show_plots : bool
        Whether to display plots during analysis
    random_seed : int
        Seed forwarded to ``main``
    Returns:
    --------
    dict
        Dictionary containing analysis results, models, and visualizations.
        If a required package is missing or the analysis itself fails, a
        dictionary with an ``'error'`` message is returned instead.
    """
    import os
    import pickle
    import time
    import traceback

    # Check for required dependencies
    missing_deps = _find_missing_dependencies()

    # Start timer
    start_time = time.time()

    print("\n========= Chicago Traffic Crash Analysis & Prediction Pipeline =========")
    print("\nInitiating analysis with the following configuration:")
    print(f"- Data source: {file_path}")
    print(f"- Sample size: {sample_size if sample_size else 'Full dataset'}")
    print(f"- Advanced models: {'Enabled' if use_advanced_models else 'Disabled'}")
    print(f"- Save results: {'Enabled' if save_results else 'Disabled'}")
    print(f"- Show plots: {'Enabled' if show_plots else 'Disabled'}")

    # Report missing dependencies
    if missing_deps:
        print("\nWARNING: The following required dependencies are missing:")
        for dep in missing_deps:
            print(f" - {dep}")
        print("Please install them using: pip install " + " ".join(missing_deps))
        print("Aborting analysis.")
        return {"error": f"Missing required dependencies: {', '.join(missing_deps)}"}

    # Safe to import now that the dependency check has passed.
    import matplotlib
    import matplotlib.pyplot as plt

    # Create output folder if needed
    if save_results and not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
        print(f"Created output folder: {output_folder}")

    # Configure matplotlib based on show_plots
    if not show_plots:
        plt.ioff()  # Turn off interactive mode
        matplotlib.use('Agg')  # Use non-interactive backend
    else:
        plt.ion()  # Turn on interactive mode

    # Run the main analysis function
    try:
        results = main(file_path, sample_size=sample_size, random_seed=random_seed,
                       use_advanced_models=use_advanced_models)

        # Check if analysis completed successfully
        if 'error' in results:
            print(f"\nAnalysis encountered an error: {results['error']}")
            return results

        # Create dashboard explicitly here, if not already created
        if results.get('dashboard') is None:
            try:
                dashboard = create_dashboard(results['data'])
                print("Dashboard created successfully")
                results['dashboard'] = dashboard
            except Exception as e:
                print(f"Error creating dashboard: {str(e)}")
                results['dashboard'] = None

        # Save results if requested. Every artifact is saved independently so
        # that one failed write does not discard the finished analysis.
        if save_results:
            # Save processed data
            if results.get('data') is not None:
                try:
                    results['data'].to_csv(f"{output_folder}/processed_data.csv", index=False)
                    print(f"Saved processed data to {output_folder}/processed_data.csv")
                except Exception as e:
                    print(f"Failed to save processed data: {str(e)}")

            # Save model results
            if results.get('model_results') is not None:
                try:
                    with open(f"{output_folder}/model_results.pkl", "wb") as f:
                        pickle.dump(results['model_results'], f)
                    print(f"Saved model results to {output_folder}/model_results.pkl")
                except Exception as e:
                    print(f"Failed to save model results: {str(e)}")

            # Save dashboard if available
            if results.get('dashboard') is not None:
                try:
                    results['dashboard'].write_html(f"{output_folder}/dashboard.html")
                    print(f"Saved dashboard to {output_folder}/dashboard.html")
                except Exception as e:
                    print(f"Failed to save dashboard: {str(e)}")

            # Save visualizations
            print("Saving visualizations...")

            # Render files with the non-interactive backend, remembering the
            # active one so plots keep displaying afterwards.
            previous_backend = matplotlib.get_backend()
            matplotlib.use('Agg')
            try:
                _save_crash_visualizations(plt, results, output_folder)
                print(f"Saved visualizations to {output_folder}/")
            except Exception as e:
                print(f"Error saving visualizations: {str(e)}")
            finally:
                if show_plots and str(previous_backend).lower() != 'agg':
                    try:
                        matplotlib.use(previous_backend)
                        plt.ion()
                    except Exception as e:
                        print(f"Could not restore matplotlib backend '{previous_backend}': {str(e)}")

        # Calculate and print execution time
        elapsed_time = time.time() - start_time
        print(f"\nAnalysis completed successfully in {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

        # Return detailed report
        print("\n========= Analysis Summary =========")
        if results.get('data') is not None:
            print(f"Processed {len(results['data']):,} crash records")

        # Print model performance summary
        if results.get('model_results') is not None:
            if use_advanced_models and 'best_model' in results:
                print(f"Best model: {results['best_model']}")
            else:
                try:
                    # Rank the fitted models by ROC AUC
                    best_model_name = max(results['model_results'].items(),
                                        key=lambda x: x[1]['metrics']['roc_auc'])[0]
                    metrics = results['model_results'][best_model_name]['metrics']
                    print(f"Best model: {best_model_name}")
                    print(f"Accuracy: {metrics['accuracy']:.4f}")
                    print(f"Precision: {metrics['precision']:.4f}")
                    print(f"Recall: {metrics['recall']:.4f}")
                    print(f"F1 Score: {metrics['f1']:.4f}")
                    print(f"ROC AUC: {metrics['roc_auc']:.4f}")
                except (KeyError, ValueError, TypeError) as e:
                    print(f"Could not determine best model: {str(e)}")

        if results.get('regression_results') is not None:
            print("- Injury count prediction model with visualization")

        if use_advanced_models:
            if results.get('time_series_models') is not None:
                print("- Time series forecasting model")
            if results.get('spatial_models') is not None:
                print("- Spatial prediction model")
            if results.get('multitask_model') is not None:
                print("- Multi-task learning model")

        return results
    except Exception as e:
        print(f"\nError during analysis: {str(e)}")
        traceback.print_exc()
        return {"error": str(e)}
    
# Run the whole pipeline on the crash CSV named below, using the baseline
# (non-advanced) models, writing artifacts to the default output folder and
# keeping plots visible in the notebook.
# NOTE(review): sample_size is presumably the full row count of this file - confirm.
results = run_chicago_crash_analysis(
    file_path='Traffic_Crashes_-_Crashes_20250426.csv',
    sample_size=938498,
    use_advanced_models=False,
    save_results=True,
    show_plots=True,
)
