## Setup and Data Loading

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image

%matplotlib inline
# Set matplotlib's default figure resolution (dots per inch) for this session.
plt.rcParams['figure.dpi'] = 100

## Plotting Functions Library

In [None]:
# ============================================================================
# COMPLETE PLOTTING FUNCTIONS LIBRARY
# ============================================================================
# This cell contains ALL plotting functions used in this notebook.
# Functions extracted from: analyzer_config.py, individual_analyzer.py, 
# overall_analyzer.py
# ============================================================================

# =====================
# Imports
# =====================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import io
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import A4
import plotly.io as pio

"""
Configuration and constants for sumobot analyzer
"""

# =====================
# Arena Configuration
# =====================
# Two-component centre point of the arena, kept as a NumPy array.
# NOTE(review): presumably (x, y) in simulator world units — confirm against the log source.
arena_center = np.array([0.24, 1.97])
# Arena radius; presumably expressed in the same units as arena_center — TODO confirm.
arena_radius = 4.73485

# =====================
# Visualization Parameters
# =====================
# NOTE(review): presumably the edge length of one heatmap bin in arena units — confirm at the call site.
tile_size = 0.7   # Larger = bigger heatmap tiles (lower resolution)

# =====================
# Bot Marker Configuration
# =====================
# Map bot names to matplotlib marker shapes for easy visual differentiation
# Every bot must own a distinct marker, otherwise two lines become
# indistinguishable in monochrome output or when the colour cycle wraps.
# (Bot_PPO previously shared the octagon "8" with Bot_MCTS.)
BOT_MARKER_MAP = {
    "Bot_NN": "o",                  # #1: NN - Circle
    "Bot_ML_Classification": "s",   # #2: MLP - Square
    "Bot_MCTS": "8",                # #3: MCTS - Octagon
    "Bot_FuzzyLogic": "^",          # #4: Fuzzy - Triangle up
    "Bot_Primitive": "p",           # #5: Primitive - Pentagon
    "Bot_GA": "h",                  # #6: GA - Hexagon
    "Bot_SLM_ActionGPT": "*",       # #7: SLM - Star
    "Bot_PPO": "<",                 # #8: PPO - Triangle left
    "Bot_BT": "X",                  # #9: BT - X filled
    "Bot_UtilityAI": "P",           # #10: Utility - Plus (filled)
    "Bot_LLM_ActionGPT": "D",       # #11: LLM - Diamond
    "Bot_FSM": "v",                 # #12: FSM - Triangle down
    "Bot_DQN": "d",                 # #13: DQN - Thin diamond
}

# Default marker if bot not in map
DEFAULT_MARKER = "o"

# =====================
# Metric Name Mapping
# =====================
# Map metric/key names to proper display names
# Lookup table: raw column / metric key -> human-readable label used for axis
# labels and titles. Keys that are absent here are displayed verbatim by
# get_metric_name(). The "_L" / "_R" suffixed keys carry a "(Left)" /
# "(Right)" qualifier in their label.
METRIC_NAME_MAP = {
    # Time-related metrics
    "MatchDur": "Match Duration",
    "ActInterval": "Action Interval",
    "Timer": "Timer",
    "Duration": "Duration",

    # Win/Performance metrics
    "WinRate": "Win Rate",
    "WinRate_L": "Win Rate (Left)",
    "WinRate_R": "Win Rate (Right)",
    "Rank": "Rank",

    # Action metrics
    "ActionCounts": "Action Counts",
    "ActionCounts_L": "Action Counts (Left)",
    "ActionCounts_R": "Action Counts (Right)",
    "Actions": "Actions",
    "AvgActions_L": "Avg Actions (Left)",
    "AvgActions_R": "Avg Actions (Right)",

    # Collision metrics
    "Collisions": "Collisions",
    "Collisions_L": "Collisions (Left)",
    "Collisions_R": "Collisions (Right)",
    "TotalCollisions": "Total Collisions",
    "Actor_L": "Actor (Left)",
    "Actor_R": "Actor (Right)",
    "Tie": "Tie",

    # Specific action types
    # NOTE: the "_Act" and "_Dur" variants of an action share one display
    # label, so the surrounding title must say which of the two is shown.
    "Accelerate_Act": "Accelerate",
    "Accelerate_Dur": "Accelerate",
    "Accelerate_Act_L": "Accelerate (Left)",
    "Accelerate_Act_R": "Accelerate (Right)",
    "TurnLeft_Act": "Turn Left",
    "TurnLeft_Dur": "Turn Left",
    "TurnLeft_Act_L": "Turn Left (Left)",
    "TurnLeft_Act_R": "Turn Left (Right)",
    "TurnRight_Act": "Turn Right",
    "TurnRight_Dur": "Turn Right",
    "TurnRight_Act_L": "Turn Right (Left)",
    "TurnRight_Act_R": "Turn Right (Right)",
    # NOTE(review): Dash has no "_L" / "_R" entries, unlike the other
    # actions; such keys fall back to the raw name — confirm this is intended.
    "Dash_Act": "Dash",
    "Dash_Dur": "Dash",

    # Skill actions
    "SkillBoost_Act": "Skill Boost",
    "SkillBoost_Dur": "Skill Boost",
    "SkillBoost_Act_L": "Skill Boost (Left)",
    "SkillBoost_Act_R": "Skill Boost (Right)",
    "SkillStone_Act": "Skill Stone",
    "SkillStone_Dur": "Skill Stone",
    "SkillStone_Act_L": "Skill Stone (Left)",
    "SkillStone_Act_R": "Skill Stone (Right)",
    "TotalSkillAct": "Total Skill Actions",

    # Round/Game metrics
    "Round": "Round",
    "RoundNumeric": "Round",
    "SkillTypeNumeric": "Skill Type",
    "Games": "Games",

    # Bot identifiers
    "Bot": "Bot",
    "Bot_L": "Bot (Left)",
    "Bot_R": "Bot (Right)",
    "Enemy": "Enemy",
    "Left_Side": "Left Side",
    "Right_Side": "Right Side",

    # Skill types
    "Skill": "Skill",
    "SkillType": "Skill Type",
    "SkillLeft": "Skill (Left)",
    "SkillRight": "Skill (Right)",
    "SkillNumeric": "Skill (Numeric)",

    # Time bins
    "TimeBin": "Time Bin",

    # Other metrics
    "AvgDuration": "Avg Duration",
    "MeanCount": "Mean Count",
    "Count": "Count",
    "Action": "Action",
    "Side": "Side",
    "BotWithRank": "Bot (with Rank)",
    "BotWithRankLeft": "Bot (Left, with Rank)",
    "BotWithRankRight": "Bot (Right, with Rank)",
}


def get_metric_name(metric_key):
    """
    Translate a raw metric/column key into its display label.

    Args:
        metric_key: Raw metric/column name

    Returns:
        The label registered in METRIC_NAME_MAP, or ``metric_key`` itself
        when no label is registered for it.
    """
    try:
        display_name = METRIC_NAME_MAP[metric_key]
    except KeyError:
        # Unknown keys are shown verbatim rather than treated as an error.
        display_name = metric_key
    return display_name


def get_bot_marker(bot_name):
    """
    Look up the matplotlib marker assigned to a bot.

    Args:
        bot_name: Name of the bot

    Returns:
        The marker string from BOT_MARKER_MAP, or DEFAULT_MARKER for a bot
        that has no entry.
    """
    if bot_name in BOT_MARKER_MAP:
        return BOT_MARKER_MAP[bot_name]
    return DEFAULT_MARKER




def calculate_legend_padding(ax, x_labels=None, rotation=0, base_padding=-0.15):
    """
    Derive the vertical legend offset from the length of the x tick labels.

    NOTE: a second function with this same name is defined further down in
    this cell with different scaling constants. Because it is defined later,
    that definition is the one callers resolve at run time.

    Args:
        ax: Matplotlib axes object (only consulted when x_labels is None)
        x_labels: List of x-axis labels (optional, read from ax if omitted)
        rotation: Rotation angle of x-axis labels in degrees
        base_padding: Offset returned when no extra room is needed (default: -0.15)

    Returns:
        Offset suitable as the y component of ``bbox_to_anchor``; never
        greater than base_padding.
    """
    labels = x_labels
    if labels is None:
        labels = [tick.get_text() for tick in ax.get_xticklabels()]

    # No labels at all, or only empty strings: nothing to make room for.
    if not labels or not any(labels):
        return base_padding

    longest = max(len(str(text)) for text in labels)

    # Rotated labels (>= 30 degrees) grow downwards with their length, so they
    # start costing space earlier, cost more per character and may push the
    # legend further than horizontal labels.
    if rotation >= 30:
        free_chars, per_char, ceiling = 10, 0.005, 0.15
    else:
        free_chars, per_char, ceiling = 15, 0.003, 0.1

    overflow = (longest - free_chars) * per_char
    overflow = max(0, min(overflow, ceiling))

    return base_padding - overflow

def _grouped_error(std_val, count_val, error_type):
    """Return the error-band half-width for one (bot, group) cell."""
    if error_type not in ("se", "ci"):
        # "std" (and any unrecognised value) -> raw standard deviation
        return std_val
    if count_val <= 0:
        return 0
    standard_error = std_val / np.sqrt(count_val)
    # 95% confidence interval approximated as 1.96 * SE
    return 1.96 * standard_error if error_type == "ci" else standard_error


def plot_grouped(summary, key="WinRate", group_by="ActInterval", width=10, height=7,
                 chart_type="line", error_type="std", show_error_band=False):
    """
    Plot the per-bot mean of a metric, grouped by a configuration variable.

    Both sides of every matchup are pooled: ``{key}_L`` rows are attributed to
    ``Bot_L`` and ``{key}_R`` rows to ``Bot_R`` before averaging.

    Parameters:
        summary: matchup DataFrame with ``Bot_L``/``Bot_R``, ``{key}_L``/``{key}_R``
                 and the grouping column(s); ``Rank_L``/``Rank_R`` are optional
        key: metric stem to average (default "WinRate")
        group_by: one of ["ActInterval", "Timer", "Round", "SkillType"];
                  "SkillType" reads ``SkillLeft`` / ``SkillRight`` per side
        width, height: figure size in inches
        chart_type: "line" for one line per bot; any other value draws
                    grouped bars (one bar per group value for each bot)
        error_type: "se" for standard error, "std" for standard deviation,
                    "ci" for 95% confidence interval; only used when
                    show_error_band is True
        show_error_band: draw a translucent +/- error band around each line
                         (line chart only, off by default)

    Returns:
        The matplotlib Figure.
    """
    # SkillType is stored per side; every other grouping column is shared.
    if group_by == "SkillType":
        left_group_col, right_group_col = "SkillLeft", "SkillRight"
    else:
        left_group_col = right_group_col = group_by

    def _side_frame(side, group_col):
        # One side of the matchup, renamed to side-agnostic column names.
        cols = [f"Bot_{side}", f"{key}_{side}", group_col]
        if f"Rank_{side}" in summary.columns:
            cols.append(f"Rank_{side}")
        return summary[cols].rename(
            columns={
                f"Bot_{side}": "Bot",
                f"{key}_{side}": key,
                f"Rank_{side}": "Rank",
                group_col: group_by,
            }
        )

    combined = pd.concat(
        [_side_frame("L", left_group_col), _side_frame("R", right_group_col)],
        ignore_index=True,
    )

    # Fill missing Rank with large number so unranked bots go last
    if "Rank" not in combined.columns:
        combined["Rank"] = np.nan
    combined["Rank"] = combined["Rank"].fillna(9999)

    # --- Aggregate (with std and count) ---
    grouped = (
        combined.groupby(["Bot", group_by], dropna=False)
        .agg({key: ["mean", "std", "count"], "Rank": "first"})
        .reset_index()
    )

    # Flatten the MultiIndex columns produced by the dict-style agg
    grouped.columns = ["Bot", group_by, f"{key}_mean", f"{key}_std", f"{key}_count", "Rank"]
    grouped[f"{key}_std"] = grouped[f"{key}_std"].fillna(0)  # single-sample groups have no std
    grouped[f"{key}_count"] = grouped[f"{key}_count"].fillna(1)  # Avoid division by zero

    # --- Sort bots by Rank ---
    bot_order = grouped.groupby("Bot")["Rank"].first().sort_values().index.tolist()

    fig, ax = plt.subplots(figsize=(width, height))

    if chart_type == "line":
        tick_rotation = 0
        x_values = sorted(grouped[group_by].unique())
        colors = plt.cm.tab10(np.linspace(0, 1, len(bot_order)))

        for i, bot in enumerate(bot_order):
            bot_data = grouped[grouped["Bot"] == bot].sort_values(group_by)
            rank = int(bot_data["Rank"].iloc[0])

            means, errors = [], []
            for x_val in x_values:
                row = bot_data[bot_data[group_by] == x_val]
                if row.empty:
                    # Bot never played with this setting: leave a gap in the line.
                    means.append(np.nan)
                    errors.append(0)
                    continue
                means.append(row[f"{key}_mean"].values[0])
                errors.append(
                    _grouped_error(
                        row[f"{key}_std"].values[0],
                        row[f"{key}_count"].values[0],
                        error_type,
                    )
                )

            means = np.array(means)
            errors = np.array(errors)

            # Plot line with thicker style and bot-specific marker
            ax.plot(x_values, means, marker=get_bot_marker(bot), linestyle='-',
                    linewidth=2.5, markersize=7,
                    label=f"{bot} (#{rank})", color=colors[i])

            if show_error_band:
                ax.fill_between(x_values, means - errors, means + errors,
                                alpha=0.15, color=colors[i])

        ax.set_xlabel(get_metric_name(group_by), fontsize=12, fontweight='bold')
        ax.set_ylabel(get_metric_name(key), fontsize=12, fontweight='bold')
        ax.set_xticks(x_values)
        ax.set_xticklabels([str(x) for x in x_values])

        # Win rate is a proportion; pin the axis so charts are comparable.
        if key == "WinRate":
            ax.set_ylim(-0.05, 1.05)

    else:
        tick_rotation = 30
        grouped_bar = grouped.rename(columns={f"{key}_mean": key})
        grouped_bar["Bot"] = pd.Categorical(grouped_bar["Bot"], categories=bot_order, ordered=True)
        grouped_bar = grouped_bar.sort_values(["Bot", group_by])

        labels = [
            f"{b} (#{int(grouped_bar[grouped_bar['Bot'] == b]['Rank'].iloc[0])})"
            for b in bot_order
        ]

        groups = sorted(grouped_bar[group_by].unique())
        x = np.arange(len(bot_order))
        # max(..., 1) keeps an empty summary from raising ZeroDivisionError
        bar_width = 0.8 / max(len(groups), 1)

        for i, g in enumerate(groups):
            subset = grouped_bar[grouped_bar[group_by] == g]
            avg_by_bot = subset.set_index("Bot").reindex(bot_order)[key].fillna(0)
            ax.bar(x + i * bar_width, avg_by_bot, width=bar_width, label=str(g))

        # Centre each tick under its cluster of bars
        ax.set_xticks(x + bar_width * (len(groups) - 1) / 2)
        ax.set_xticklabels(labels, rotation=tick_rotation, ha="right")
        ax.set_xlabel("Bots")

    # --- Common styling ---
    ax.set_title(f"{get_metric_name(key)} grouped by {get_metric_name(group_by)}",
                 fontsize=14, fontweight='bold', pad=15)

    # Legend sits below the axes; tell the helper how the tick labels are
    # actually rotated so it reserves enough room for them.
    legend_padding = calculate_legend_padding(ax, rotation=tick_rotation)
    ax.legend(title="Bot (Rank)" if chart_type == "line" else group_by,
              loc='upper center', bbox_to_anchor=(0.5, legend_padding),
              fontsize=10, framealpha=0.9, ncol=3, markerscale=1.2)
    ax.grid(True, linestyle="--", alpha=0.5, linewidth=0.8)
    fig.tight_layout()

    return fig

# (target column, source column stem, taken from the opponent's side?)
# Order matters: it is the order in which the derived columns are appended.
_PERSPECTIVE_COLUMNS = (
    ("WinRate", "WinRate", False),
    ("Actions", "ActionCounts", False),
    ("Collisions", "Collisions", False),
    ("Collisions_Hit", "Collisions", False),     # collisions where this bot is the actor
    ("Collisions_Struck", "Collisions", True),   # collisions where the opponent is the actor
    ("Duration", "Duration", False),
    ("Accelerate_Act", "Accelerate_Act", False),
    ("TurnLeft_Act", "TurnLeft_Act", False),
    ("TurnRight_Act", "TurnRight_Act", False),
    ("Dash_Act", "Dash_Act", False),
    ("SkillBoost_Act", "SkillBoost_Act", False),
    ("SkillStone_Act", "SkillStone_Act", False),
    ("Accelerate_Dur", "Accelerate_Dur", False),
    ("TurnLeft_Dur", "TurnLeft_Dur", False),
    ("TurnRight_Dur", "TurnRight_Dur", False),
    ("Dash_Dur", "Dash_Dur", False),
)


def _bot_perspective(df, bot_name, side, opponent_side):
    """Rows where ``bot_name`` plays on ``side``, with side-agnostic metric columns added."""
    view = df[df[f'Bot_{side}'] == bot_name].copy()
    for target, stem, from_opponent in _PERSPECTIVE_COLUMNS:
        suffix = opponent_side if from_opponent else side
        view[target] = view[f'{stem}_{suffix}']
    view['SkillType'] = view['SkillLeft' if side == 'L' else 'SkillRight']
    return view


def prepare_individual_bot_data(df, bot_name):
    """
    Prepare data for a specific bot combining left and right perspectives.

    Every matchup row in which the bot takes part is kept once per side it
    played on, and the ``*_L`` / ``*_R`` metrics of that side are copied into
    side-agnostic columns (``WinRate``, ``Actions``, ``Collisions_Hit``, ...).
    All original columns are retained, so an existing ``Collisions_Tie``
    column passes through unchanged; it is not required to be present.

    Args:
        df: Summary matchup dataframe
        bot_name: Name of the bot to analyze

    Returns:
        DataFrame with the bot's rows from all configurations (left-side rows
        first, then right-side rows, index reset), plus the derived columns
        ``RoundNumeric``, ``TotalSkillAct`` and ``SkillTypeNumeric``.

    Raises:
        KeyError: if a required ``*_L`` / ``*_R`` source column is missing.
    """
    bot_data = pd.concat(
        [_bot_perspective(df, bot_name, 'L', 'R'),
         _bot_perspective(df, bot_name, 'R', 'L')],
        ignore_index=True,
    )

    # Numeric encodings so the categorical settings can be correlated;
    # values outside these maps become NaN.
    bot_data['RoundNumeric'] = bot_data['Round'].map({'BestOf1': 1, 'BestOf3': 3})
    bot_data['TotalSkillAct'] = bot_data['SkillBoost_Act'] + bot_data['SkillStone_Act']

    skill_map = {'Stone': 1, 'Boost': 2}
    bot_data['SkillTypeNumeric'] = bot_data['SkillType'].map(skill_map)

    return bot_data


def plot_individual_correlation_scatter(data, x_col, y_col, title, bot_name,
                                       alpha=0.95, figsize=(10, 8), add_jitter=False):
    """
    Create scatter plot with regression line and Pearson correlation for individual bot.

    Rows with NaN in either column are dropped first. The correlation and the
    regression line are always computed on the un-jittered values; jitter
    only affects where the points are drawn. When every x value is identical
    no regression line (and no legend) is drawn, because a least-squares fit
    is undefined for a constant predictor.

    Args:
        data: DataFrame with bot's data
        x_col: Column name for x-axis
        y_col: Column name for y-axis (should be WinRate)
        title: Plot title
        bot_name: Name of the bot
        alpha: Transparency of scatter points
        figsize: Figure size tuple
        add_jitter: If True, add jitter to x-axis for discrete variables
            (at most 5 distinct x values)

    Returns:
        matplotlib figure, or None when fewer than two complete rows remain
    """
    # Remove NaN values
    plot_data = data[[x_col, y_col]].dropna().copy()

    if len(plot_data) < 2:
        return None

    # Calculate Pearson correlation (on original data)
    pearson_r, pearson_p = stats.pearsonr(plot_data[x_col], plot_data[y_col])

    fig, ax = plt.subplots(figsize=figsize)

    x_values = plot_data[x_col].values.copy()
    x_span = x_values.max() - x_values.min()

    # Spread overlapping points of a discrete variable horizontally.
    x_drawn = x_values
    if add_jitter and len(np.unique(x_values)) <= 5:
        jitter_amount = x_span * 0.02 if x_span > 0 else 0.01
        x_drawn = x_values + np.random.normal(0, jitter_amount, size=len(x_values))

    ax.scatter(x_drawn, plot_data[y_col],
               alpha=alpha, s=60, color='steelblue', edgecolors='black', linewidth=0.5)

    # Regression line from the original (non-jittered) data; skipped for a
    # constant x, where polyfit is rank-deficient and the line has no extent.
    has_regression = x_span > 0
    if has_regression:
        slope, intercept = np.polyfit(x_values, plot_data[y_col], 1)
        x_line = np.linspace(x_values.min(), x_values.max(), 100)
        ax.plot(x_line, slope * x_line + intercept, 'r-', linewidth=2.5,
                label='Regression Line')

    # Add correlation info to plot
    corr_text = f'Pearson r = {pearson_r:.3f}\np-value = {pearson_p:.3e}\nn = {len(plot_data)}'
    ax.text(0.05, 0.95, corr_text, transform=ax.transAxes,
            verticalalignment='top', bbox=dict(boxstyle='round',
            facecolor='wheat', alpha=0.8), fontsize=11, family='monospace')

    # Labels and title
    ax.set_xlabel(get_metric_name(x_col), fontsize=12, fontweight='bold')
    ax.set_ylabel(get_metric_name(y_col), fontsize=12, fontweight='bold')
    ax.set_title(f'{title}\n{bot_name}', fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='--')
    if has_regression:
        # The regression line is the only labelled artist.
        ax.legend(loc='best', fontsize=10, framealpha=0.9, markerscale=1.2)

    # Lay out this figure explicitly rather than pyplot's "current" one.
    fig.tight_layout()
    return fig


def _draw_winrate_panel(ax, data, column, x_label, alpha, marker_size,
                        line_width, text_size, label_size, title_size):
    """Draw one WinRate-vs-``column`` scatter panel with Pearson statistics on ``ax``."""
    plot_data = data[[column, 'WinRate']].dropna()

    if len(plot_data) < 2:
        ax.text(0.5, 0.5, 'Insufficient data',
                ha='center', va='center', transform=ax.transAxes)
        return

    pearson_r, pearson_p = stats.pearsonr(plot_data[column], plot_data['WinRate'])

    ax.scatter(plot_data[column], plot_data['WinRate'],
               alpha=alpha, s=marker_size, color='steelblue',
               edgecolors='black', linewidth=0.5)

    # A least-squares line is only defined when x actually varies.
    if plot_data[column].std() > 0:
        slope, intercept = np.polyfit(plot_data[column], plot_data['WinRate'], 1)
        x_line = np.linspace(plot_data[column].min(), plot_data[column].max(), 100)
        ax.plot(x_line, slope * x_line + intercept, 'r-', linewidth=line_width)

    ax.text(0.05, 0.95, f'r={pearson_r:.3f}\np={pearson_p:.2e}',
            transform=ax.transAxes, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
            fontsize=text_size, family='monospace')

    winrate_label = get_metric_name('WinRate')
    ax.set_xlabel(x_label, fontsize=label_size, fontweight='bold')
    ax.set_ylabel(winrate_label, fontsize=label_size, fontweight='bold')
    ax.set_title(f'{winrate_label} vs {x_label}', fontsize=title_size, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')


def _plot_winrate_panel_grid(data, columns, grid_shape, figsize, suptitle,
                             suptitle_y, alpha, x_labels=None, **panel_style):
    """Build one figure holding a WinRate-vs-column panel for each entry of ``columns``."""
    fig, axes = plt.subplots(*grid_shape, figsize=figsize)

    for ax, column in zip(np.ravel(axes), columns):
        if column not in data.columns:
            # Leave the panel blank when the metric was never recorded.
            continue
        x_label = x_labels[column] if x_labels else get_metric_name(column)
        _draw_winrate_panel(ax, data, column, x_label, alpha=alpha, **panel_style)

    fig.suptitle(suptitle, fontsize=14, fontweight='bold', y=suptitle_y)
    fig.tight_layout()
    return fig


def plot_individual_bot_correlations(df, bot_name, width=10, height=8,alpha=0.2):
    """
    Create all correlation plots for a specific bot.

    Config variables (action interval, round type, timer, skill type) each
    get their own single-axes figure of win rate against the config value.
    Action counts, action durations and collision types each get one
    multi-panel figure, with the data pooled across all configurations.

    Args:
        df: Summary matchup dataframe
        bot_name: Name of the bot to analyze
        width: Base figure width (multi-panel figures scale it up)
        height: Base figure height (multi-panel figures scale it up)
        alpha: Point transparency for the multi-panel figures; the
            single-axes config figures keep the default of
            plot_individual_correlation_scatter

    Returns:
        Dictionary of figures. Keys 'actinterval', 'roundtype', 'timer' and
        'skilltype' are present only when enough data exists for them;
        'actions', 'actions_dur' and 'collisions' are always present.
        Empty dict when the bot has no rows at all.
    """
    data = prepare_individual_bot_data(df, bot_name)

    if data.empty:
        return {}

    figs = {}

    # Spell out the numeric round encoding in the title, e.g. "1=BestOf1, 3=BestOf3".
    round_mapping = data[['Round', 'RoundNumeric']].drop_duplicates().dropna()
    round_labels = ', '.join([f"{int(row['RoundNumeric'])}={row['Round']}"
                              for _, row in round_mapping.sort_values('RoundNumeric').iterrows()])
    round_title = f'Win Rate vs Round Type ({round_labels})' if round_labels else 'Win Rate vs Round Type'

    # a-d. Win rate against each config variable (direct correlation)
    config_plots = (
        ('actinterval', 'ActInterval', 'Win Rate vs Action Interval'),
        ('roundtype', 'RoundNumeric', round_title),
        ('timer', 'Timer', 'Win Rate vs Timer Duration'),
        ('skilltype', 'SkillTypeNumeric', 'Win Rate vs Skill Type (1=Stone, 2=Boost)'),
    )
    for fig_key, x_col, title in config_plots:
        fig = plot_individual_correlation_scatter(
            data,
            x_col=x_col,
            y_col='WinRate',
            title=title,
            bot_name=bot_name,
            figsize=(width, height),
            add_jitter=False
        )
        if fig is not None:
            figs[fig_key] = fig

    small_panel_style = dict(marker_size=50, line_width=2, text_size=9,
                             label_size=10, title_size=11)

    # e. Winrate vs Individual Actions (combined across all configs)
    figs['actions'] = _plot_winrate_panel_grid(
        data,
        ['Accelerate_Act', 'TurnLeft_Act', 'TurnRight_Act',
         'Dash_Act', 'SkillBoost_Act', 'SkillStone_Act'],
        grid_shape=(2, 3),
        figsize=(width*1.8, height*1.2),
        suptitle=f'Win Rate vs Individual Action Types\n{bot_name}',
        suptitle_y=0.995,
        alpha=alpha,
        **small_panel_style
    )

    # f. Winrate vs Individual Actions Duration (combined across all configs)
    figs['actions_dur'] = _plot_winrate_panel_grid(
        data,
        ['Accelerate_Dur', 'TurnLeft_Dur', 'TurnRight_Dur', 'Dash_Dur'],
        grid_shape=(2, 2),
        figsize=(width*1.2, height*1.2),
        suptitle=f'Win Rate vs Individual Action Duration\n{bot_name}',
        suptitle_y=0.995,
        alpha=alpha,
        **small_panel_style
    )

    # g. Winrate vs Collision Types (Hit, Struck, Tie) - Combined across all configs
    collision_labels = {'Collisions_Hit': 'Hit', 'Collisions_Struck': 'Struck', 'Collisions_Tie': 'Tie'}
    figs['collisions'] = _plot_winrate_panel_grid(
        data,
        ['Collisions_Hit', 'Collisions_Struck', 'Collisions_Tie'],
        grid_shape=(1, 3),
        figsize=(width*1.8, height),
        suptitle=f'Win Rate vs Collision Types\n{bot_name}',
        suptitle_y=1.00,
        alpha=alpha,
        x_labels=collision_labels,
        marker_size=60, line_width=2.5, text_size=10,
        label_size=11, title_size=12
    )

    return figs


def calculate_legend_padding(ax, x_labels=None, rotation=0, base_padding=-0.15):
    """
    Work out how far below the axes the legend should be anchored.

    The offset starts at base_padding and is pushed further down by up to
    0.1 once the longest x tick label exceeds a length threshold (10
    characters for labels rotated by 30 degrees or more, 15 otherwise).
    With the 0.1-per-character factor used here the cap is reached as soon
    as the threshold is exceeded by a single character.

    Args:
        ax: Matplotlib axes object (only consulted when x_labels is None)
        x_labels: List of x-axis labels (optional, read from ax if omitted)
        rotation: Rotation angle of x-axis labels in degrees
        base_padding: Base padding value (default: -0.15)

    Returns:
        Adjusted padding value for bbox_to_anchor
    """
    tick_texts = (
        [tick.get_text() for tick in ax.get_xticklabels()]
        if x_labels is None
        else x_labels
    )

    # Nothing (or only blank strings) on the axis: keep the base offset.
    if not tick_texts or not any(tick_texts):
        return base_padding

    widest = max(len(str(text)) for text in tick_texts)

    # Rotated labels eat vertical space sooner than horizontal ones.
    length_allowance = 10 if rotation >= 30 else 15
    surplus = (widest - length_allowance) * 0.1
    surplus = max(0, min(surplus, 0.1))  # clamp to [0, 0.1] in both cases

    return base_padding - surplus


def plot_with_bot_markers(ax, data, x, y, hue, hue_order=None, **kwargs):
    """
    Plot line plot with bot-specific markers.

    One line is drawn per distinct ``hue`` value; its marker is looked up
    from the bot name, i.e. the label with any trailing " (#rank)" removed.

    Args:
        ax: Matplotlib axes object
        data: pandas DataFrame with plot data
        x: Column name for x-axis
        y: Column name for y-axis
        hue: Column name for grouping (bot names or bot names with rank)
        hue_order: Sequence specifying order of hue values (optional; may be
            a list, NumPy array or pandas Index). None or an empty sequence
            means "use the values in order of first appearance".
        **kwargs: Additional plot keywords (linewidth, alpha, etc.);
            these override the defaults linewidth=2, markersize=8

    Example:
        >>> fig, ax = plt.subplots()
        >>> plot_with_bot_markers(ax, data=df, x="Timer", y="WinRate",
        ...                       hue="BotWithRank", hue_order=bot_order)
    """
    # Default plot settings
    plot_kwargs = {'linewidth': 2, 'markersize': 8}
    plot_kwargs.update(kwargs)

    # Materialise as a list first: truth-testing an array/Index directly raises.
    requested_order = list(hue_order) if hue_order is not None else []
    bots_to_plot = requested_order if requested_order else data[hue].unique()

    for bot_label in bots_to_plot:
        bot_data = data[data[hue] == bot_label]
        if bot_data.empty:
            continue

        # Extract original bot name (before " (#rank)" if present); str() so
        # non-string hue values do not break the lookup.
        bot_name = str(bot_label).split(" (")[0]
        marker = get_bot_marker(bot_name)

        ax.plot(bot_data[x], bot_data[y], marker=marker, label=bot_label, **plot_kwargs)

def update_bot_marker_map(new_mappings):
    """
    Register or override marker shapes in the shared BOT_MARKER_MAP.

    The module-level dictionary is modified in place, so the change is seen
    by every later marker lookup.

    Args:
        new_mappings: Dictionary of {bot_name: marker_shape}

    Example:
        >>> update_bot_marker_map({"Bot_NewBot": "H"})
    """
    for bot_name, marker_shape in dict(new_mappings).items():
        BOT_MARKER_MAP[bot_name] = marker_shape


def get_bot_winrates(summary: pd.DataFrame, bot_name: str):
    """
    Return aggregated winrates for one bot against all others.

    The bot's mean win rate per opponent is computed separately for the
    matches it played on the left and on the right; the two per-side means
    are then averaged with equal weight.

    Returns:
        DataFrame with columns ``Enemy`` and ``WinRate``, sorted by
        ``WinRate`` in descending order.
    """
    side_specs = (
        # (column holding our bot, column holding the opponent, our win rate)
        ("Bot_L", "Bot_R", "WinRate_L"),
        ("Bot_R", "Bot_L", "WinRate_R"),
    )

    per_side = []
    for own_col, enemy_col, rate_col in side_specs:
        own_rows = summary[summary[own_col] == bot_name]
        side_mean = own_rows.groupby(enemy_col)[rate_col].mean().reset_index()
        side_mean = side_mean.rename(columns={enemy_col: "Enemy", rate_col: "WinRate"})
        per_side.append(side_mean)

    stacked = pd.concat(per_side)
    per_enemy = stacked.groupby("Enemy")["WinRate"].mean().reset_index()

    return per_enemy.sort_values("WinRate", ascending=False)

def build_winrate_matrix(summary: pd.DataFrame):
    """Return pivot matrix of win rates (row = bot, col = enemy).

    Every matchup row is used twice: once from the left bot's point of view
    and once from the right bot's, so a bot that only ever appears on one
    side is still represented.

    Args:
        summary: Matchup table with the columns Bot_L, Bot_R, WinRate_L,
            WinRate_R, Rank_L and Rank_R.

    Returns:
        DataFrame indexed by "BotWithRankLeft" with "BotWithRankRight"
        columns. Labels have the form "<bot> (#<rank>)"; each cell is the
        mean win rate of the row bot against the column bot, NaN when the
        pair never met.
    """
    # Combine both directions
    left = summary[["Bot_L", "Bot_R", "WinRate_L", "Rank_L"]].rename(
        columns={"Bot_L": "Left_Side", "Bot_R": "Right_Side", "WinRate_L": "WinRate", "Rank_L": "Rank"}
    )
    # The rank column selected here is Rank_R, so that is the name that has
    # to be mapped onto "Rank" (mapping "Rank_L" left the right-side ranks
    # out of the combined "Rank" column).
    right = summary[["Bot_R", "Bot_L", "WinRate_R", "Rank_R"]].rename(
        columns={"Bot_R": "Left_Side", "Bot_L": "Right_Side", "WinRate_R": "WinRate", "Rank_R": "Rank"}
    )

    combined = pd.concat([left, right], ignore_index=True)

    # --- Get bot rank mapping ---
    rank_map = (
        combined.groupby("Left_Side")["Rank"]
        .mean()
        .sort_values()
        .round(0)
        .astype(int)
        .to_dict()
    )

    # --- Rename bot labels with rank ---
    combined["BotWithRankLeft"] = combined["Left_Side"].map(
        lambda b: f"{b} (#{rank_map.get(b, '?')})"
    )
    combined["BotWithRankRight"] = combined["Right_Side"].map(
        lambda b: f"{b} (#{rank_map.get(b, '?')})"
    )

    # Aggregate mean winrate over all configs
    matrix_df = combined.groupby(["BotWithRankLeft", "BotWithRankRight"])["WinRate"].mean().reset_index()

    # Pivot into matrix; pairs that never faced each other stay NaN
    pivot = matrix_df.pivot(index="BotWithRankLeft", columns="BotWithRankRight", values="WinRate")

    return pivot

def plot_winrate_matrix(summary, width=8, height=6):
    """Draw the bot-vs-bot win-rate matrix as an annotated heatmap.

    Args:
        summary: Matchup table passed straight to build_winrate_matrix().
        width: Figure width in inches.
        height: Figure height in inches.

    Returns:
        The matplotlib Figure holding the heatmap.
    """
    fig = plt.figure(figsize=(width, height))
    matrix = build_winrate_matrix(summary)

    ax = fig.gca()
    sns.heatmap(
        matrix,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        center=0.5,
        linewidths=0.5,
        cbar_kws={'label': 'Win Rate'},
        ax=ax,
    )

    ax.set_title(
        f"{get_metric_name('Bot')} vs {get_metric_name('Bot')} {get_metric_name('WinRate')} Matrix"
    )
    ax.set_ylabel(get_metric_name("Bot"))
    ax.set_xlabel(f"{get_metric_name('Enemy')} {get_metric_name('Bot')}")
    fig.tight_layout()
    return fig

def plot_time_related(summary, width=8, height=6):
    """Plot average match duration against the timer, one figure per ActInterval.

    Each figure contains one line per left-side bot, drawn in rank order and
    labelled "<bot> (#<rank>)".

    Args:
        summary: Matchup table with the columns ActInterval, Timer, Bot_L,
            Rank_L and MatchDur.
        width: Figure width in inches.
        height: Figure height in inches.

    Returns:
        List of matplotlib Figures, one for every distinct ActInterval.
    """
    # Mean match duration for every (interval, timer, bot, rank) combination
    durations = (
        summary.groupby(["ActInterval", "Timer", "Bot_L", "Rank_L"], as_index=False)
        .agg({"MatchDur": "mean"})
    )
    durations["AvgDuration"] = durations["MatchDur"]

    # One integer rank per bot (mean of its Rank_L values, rounded)
    mean_rank = durations.groupby("Bot_L")["Rank_L"].mean().sort_values()
    rank_map = mean_rank.round(0).astype(int).to_dict()

    ranked_bots = sorted(rank_map.keys(), key=lambda b: rank_map.get(b, 999))
    legend_labels = {b: f"{b} (#{rank_map.get(b, '?')})" for b in ranked_bots}

    figures = []
    for interval in durations["ActInterval"].unique():
        fig, ax = plt.subplots(figsize=(width, height))
        at_interval = durations[durations["ActInterval"] == interval]

        # Lines are added best rank first so the legend follows the ranking
        for bot in ranked_bots:
            series = at_interval[at_interval["Bot_L"] == bot]
            if series.empty:
                continue
            ax.plot(
                series["Timer"],
                series["AvgDuration"],
                marker=get_bot_marker(bot),
                label=legend_labels[bot],
                markersize=8,
                linewidth=2,
            )

        ax.set_title(f"Avg {get_metric_name('MatchDur')} vs {get_metric_name('Timer')} ({get_metric_name('ActInterval')} = {interval})")
        ax.set_xlabel(f"{get_metric_name('Timer')} (s)")
        ax.set_ylabel(f"Actual {get_metric_name('MatchDur')} (s)")

        # Legend offset comes from the shared padding helper
        legend_padding = calculate_legend_padding(ax, rotation=0)
        ax.legend(
            title='Bot (Rank)',
            loc='upper center',
            bbox_to_anchor=(0.5, legend_padding),
            ncol=3,
            framealpha=0.9,
            markerscale=1.2,
        )
        ax.grid(True, linestyle="--", alpha=0.5)

        # Tick only at the timer values that actually occur for this interval
        timer_ticks = sorted(at_interval["Timer"].unique())
        ax.set_xticks(timer_ticks)
        ax.set_xticklabels([f"{t}s" for t in timer_ticks])

        figures.append(fig)
    return figures

def plot_action_win_related(summary, width=8, height=6):
    """Scatter/regression plot of average actions per game against win rate.

    Side effects:
        - Adds the columns "AvgActions_L" and "AvgActions_R" to ``summary``
          in place (kept for backward compatibility with existing callers).
        - Prints the Pearson correlation to stdout.

    Args:
        summary: Matchup table with the columns Bot_L, Bot_R,
            ActionCounts_L, ActionCounts_R, Games, WinRate_L and WinRate_R.
        width: Figure width in inches.
        height: Figure height in inches.

    Returns:
        The matplotlib Figure holding the regression plot.
    """
    # Step 1: Compute average actions per game
    summary["AvgActions_L"] = summary["ActionCounts_L"] / summary["Games"]
    summary["AvgActions_R"] = summary["ActionCounts_R"] / summary["Games"]

    # Step 2: Convert each matchup into per-bot rows
    left = summary[["Bot_L", "Bot_R", "AvgActions_L", "WinRate_L"]].rename(
        columns={"Bot_L": "Bot", "Bot_R": "Enemy", "AvgActions_L": "Actions", "WinRate_L": "WinRate"}
    )
    right = summary[["Bot_R", "Bot_L", "AvgActions_R", "WinRate_R"]].rename(
        columns={"Bot_R": "Bot", "Bot_L": "Enemy", "AvgActions_R": "Actions", "WinRate_R": "WinRate"}
    )
    combined = pd.concat([left, right], ignore_index=True)

    corr = combined["Actions"].corr(combined["WinRate"])
    print(f"Correlation between Actions and Win Rate: {corr:.3f}")

    fig = plt.figure(figsize=(width, height))
    sns.regplot(data=combined, x="Actions", y="WinRate", scatter_kws={"alpha": 0.6})
    plt.title(f"Correlation Between {get_metric_name('Actions')} and {get_metric_name('WinRate')}")
    plt.xlabel(f"Average {get_metric_name('Actions')} per Game")
    plt.ylabel(get_metric_name("WinRate"))
    plt.grid(alpha=0.3)

    # Same three-decimal formatting as the console message above; the raw
    # float repr is far too long for the small annotation box.
    plt.text(
        0.6, 0.85,
        f"Correlation result: {corr:.3f}.\n"
        "> 0.5 → strong positive relationship (more actions → more wins)\n~0.0 → no clear relation\n< -0.5 → inverse relationship (passive bots win more)",
        transform=plt.gca().transAxes,
        fontsize=6,
        bbox=dict(facecolor='lightyellow', alpha=0.5, edgecolor='gold', boxstyle="round,pad=0.4")
    )

    return fig

def plot_highest_action(summary, width=8, height=6, n_action = 6):
    """Horizontal bar chart of the most frequently used actions of each bot.

    Action counts are read from every column ending in "_Act_L", summed per
    left-side bot, and the ``n_action`` largest totals of each bot are drawn.

    Args:
        summary: Matchup table with a Bot_L column and "<Action>_Act_L"
            count columns. When a Rank_L column is present, bots are
            ordered by rank and labelled "<bot> (#<rank>)".
        width: Figure width in inches.
        height: Figure height in inches.
        n_action: Number of top actions kept per bot.

    Returns:
        The matplotlib Figure holding the bar chart.
    """
    action_cols = [col for col in summary.columns if col.endswith("_Act_L")]

    # Get rank mapping if available
    if "Rank_L" in summary.columns:
        rank_map = summary.groupby("Bot_L")["Rank_L"].first().to_dict()
        bot_order = sorted(rank_map.keys(), key=lambda b: rank_map[b])
    else:
        rank_map = {}
        bot_order = sorted(summary["Bot_L"].unique())

    df_actions = summary.melt(
        id_vars=["Bot_L"],
        value_vars=action_cols,
        var_name="Action",
        value_name="Count"
    )
    # Literal (non-regex) removal of the column suffix
    df_actions["Action"] = df_actions["Action"].str.replace("_Act_L", "", regex=False)
    df_actions = df_actions.groupby(["Bot_L", "Action"])["Count"].sum().reset_index()

    # Add rank to bot names if available
    if rank_map:
        df_actions["BotWithRank"] = df_actions["Bot_L"].map(lambda b: f"{b} (#{int(rank_map.get(b, 999))})")
        hue_col = "BotWithRank"
        hue_order = [f"{b} (#{int(rank_map[b])})" for b in bot_order]
    else:
        hue_col = "Bot_L"
        hue_order = bot_order

    # Top-N per bot without groupby.apply (which is deprecated for functions
    # that touch the grouping column): order by count, keep the first N rows
    # of every bot, then bring the rows back into bot order. Both sorts are
    # stable, so ties keep their original (alphabetical action) order.
    top_actions = (
        df_actions.sort_values("Count", ascending=False, kind="mergesort")
        .groupby("Bot_L")
        .head(n_action)
        .sort_values("Bot_L", kind="mergesort")
        .reset_index(drop=True)
    )

    fig = plt.figure(figsize=(width, height))
    sns.barplot(
        data=top_actions,
        x="Count",
        y="Action",
        hue=hue_col,
        hue_order=hue_order
    )
    # The title reflects the requested number of actions
    plt.title(f"Top {n_action} Actions Taken per Bot")
    plt.xlabel("Action Count")
    plt.ylabel("Action")
    legend_title = "Bot (Rank)" if rank_map else "Bot"
    plt.legend(title=legend_title, loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=3)
    plt.tight_layout()
    return fig

def plot_win_rate_stability_over_timer(summary, width=8, height=6):
    """Heatmap of each bot's mean win rate for every timer setting.

    Both sides of a matchup are counted: WinRate_L is credited to Bot_L and
    WinRate_R to Bot_R.

    Args:
        summary: Matchup table with the columns Bot_L, Bot_R, Timer,
            WinRate_L and WinRate_R.
        width: Figure width in inches.
        height: Figure height in inches.

    Returns:
        The matplotlib Figure holding the heatmap (rows = bots,
        columns = timer values).
    """
    # Long format: one row per (matchup, side)
    long_form = summary.melt(
        id_vars=["Bot_L", "Bot_R", "Timer"],
        value_vars=["WinRate_L", "WinRate_R"],
        var_name="Side",
        value_name="WinRate",
    )

    # The bot that owns the win rate depends on which side the row came from
    is_left = long_form["Side"] == "WinRate_L"
    long_form["Bot"] = np.where(is_left, long_form["Bot_L"], long_form["Bot_R"])

    mean_rates = long_form.groupby(["Bot", "Timer"])["WinRate"].mean().reset_index()
    grid = mean_rates.pivot(index="Bot", columns="Timer", values="WinRate")

    fig = plt.figure(figsize=(width, height))
    ax = fig.gca()
    sns.heatmap(grid, annot=True, cmap="RdYlGn", vmin=0, vmax=1, ax=ax)
    ax.set_title("Win Rate Stability vs Timer (Heatmap)")
    ax.set_xlabel("Timer")
    ax.set_ylabel("Bot")
    return fig

def plot_timebins_intensity(
    df,
    group_by="Bot",
    timer=None,
    act_interval=None,
    round=None,
    mode="total",        # "total" | "per_action" | "select"
    action_name=None,    # used when mode == "select"
    width=10,
    height=6,
    summary_df=None,     # Optional: summary dataframe with Rank_L column for bot ranking
):
    """
    Plot action intensity over time (with optional timer cutoff).

    Modes:
      - "total": average MeanCount per action, then sum across all actions
        -> one line per bot
      - "per_action": show per-action trends; creates one subplot per action
      - "select": plot a single action_name for all bots

    Args:
        df: Time-binned table with the columns TimeBin and MeanCount, the
            ``group_by`` column, and "Action" (needed by "select" and
            "per_action"; "total" falls back to a plain mean without it).
            Timer / ActInterval / Round are needed only when the matching
            filter is given.
        group_by: Column that defines one line per value. For "Bot", labels
            get a "(#rank)" suffix when a rank source is available.
        timer: Keep only rows with this Timer; also used as x-axis cutoff.
        act_interval: Keep only rows with this ActInterval.
        round: Keep only rows with this Round (the name shadows the builtin
            and is kept for interface compatibility).
        mode: One of "total", "per_action", "select".
        action_name: Action to plot when mode == "select".
        width: Figure width in inches.
        height: Figure height in inches ("per_action" grows with row count).
        summary_df: Optional frame with Bot_L / Rank_L used as rank source.

    Returns:
        The matplotlib Figure, or None when filtering leaves no rows.

    Raises:
        ValueError: Unknown ``mode``, or mode == "select" without
            ``action_name``.
    """

    # --- Filters ---
    if timer is not None:
        df = df[df["Timer"] == timer]
    if act_interval is not None:
        df = df[df["ActInterval"] == act_interval]
    if round is not None:
        df = df[df["Round"] == round]

    if df.empty:
        print("⚠️ No data after filtering.")
        return None

    # --- Preprocess TimeBin ---
    df = df.copy()
    df["TimeBin"] = pd.to_numeric(df["TimeBin"], errors="coerce")
    df = df.dropna(subset=["TimeBin"])
    df = df.sort_values("TimeBin")

    # --- Add rank to bot names if group_by is "Bot" ---
    rank_map = None
    if group_by == "Bot":
        # Try to get rank from summary_df first, then from df itself
        if summary_df is not None and "Rank_L" in summary_df.columns:
            rank_map = summary_df.groupby("Bot_L")["Rank_L"].first().to_dict()
        elif "Rank" in df.columns:
            rank_map = df.groupby("Bot")["Rank"].first().to_dict()
        elif "Rank_L" in df.columns:
            rank_map = df.groupby("Bot")["Rank_L"].first().to_dict()

    if rank_map:
        df["BotWithRank"] = df["Bot"].map(lambda b: f"{b} (#{int(rank_map.get(b, 999))})")
        group_by_plot = "BotWithRank"
        # Sort bots by rank
        bot_order = sorted(rank_map.keys(), key=lambda b: rank_map.get(b, 999))
        bot_order_with_rank = [f"{b} (#{int(rank_map[b])})" for b in bot_order]
        # Bots without a rank entry carry the "#999" placeholder label; add
        # them at the end so they are still drawn instead of being dropped.
        listed = set(bot_order_with_rank)
        bot_order_with_rank += sorted(
            label for label in df["BotWithRank"].unique() if label not in listed
        )
    else:
        group_by_plot = group_by
        bot_order_with_rank = None

    # "(Rank)" is only advertised when rank labels are actually in use
    legend_title = "Bot (Rank)" if rank_map else group_by

    # --- Helper: apply x-axis cutoff ---
    def apply_timer_xlim(ax):
        if timer is not None:
            ax.set_xlim(0, timer)
            ax.set_xticks(range(0, int(timer) + 1, max(1, int(timer // 10))))

    # --- Helper: one axes with all groups, legend below the plot ---
    def draw_single_axes(grouped, title, ylabel):
        fig, ax = plt.subplots(figsize=(width, height))
        plot_with_bot_markers(ax, data=grouped, x="TimeBin", y="MeanCount",
                              hue=group_by_plot, hue_order=bot_order_with_rank)
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(ylabel)

        # Calculate dynamic padding for legend
        legend_padding = calculate_legend_padding(ax, rotation=0)
        ax.legend(title=legend_title, loc='upper center', bbox_to_anchor=(0.5, legend_padding), ncol=3,
                  framealpha=0.9, markerscale=1.2)
        ax.grid(True, alpha=0.3)
        apply_timer_xlim(ax)
        fig.tight_layout()
        return fig

    # --- Plot modes ---
    if mode == "select":
        if not action_name:
            raise ValueError("action_name must be provided when mode='select'")
        df_sel = df[df["Action"] == action_name]
        if df_sel.empty:
            print(f"⚠️ No rows for action '{action_name}' after filtering.")
            return None
        grouped = df_sel.groupby([group_by_plot, "TimeBin"], as_index=False)["MeanCount"].mean()
        return draw_single_axes(
            grouped,
            title=f"Mean {get_metric_name(action_name)} over time",
            ylabel="Mean Count",
        )

    elif mode == "total":
        if "Action" in df.columns:
            # Average each action over the remaining rows first, then add the
            # actions up, so the curve matches the "summed over actions" label.
            per_action = df.groupby(
                [group_by_plot, "TimeBin", "Action"], as_index=False
            )["MeanCount"].mean()
            grouped = per_action.groupby(
                [group_by_plot, "TimeBin"], as_index=False
            )["MeanCount"].sum()
        else:
            # No per-action breakdown available: keep the plain mean
            grouped = df.groupby([group_by_plot, "TimeBin"], as_index=False)["MeanCount"].mean()
        return draw_single_axes(
            grouped,
            title="Total action intensity over time",
            ylabel="Mean Count (summed over actions)",
        )

    elif mode == "per_action":
        actions = sorted(df["Action"].unique())
        n = len(actions)
        ncols = min(2, n)
        nrows = (n + ncols - 1) // ncols
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(width, max(height, 2.5 * nrows)),
            squeeze=False
        )
        axes = axes.flatten()

        handles, labels = None, None
        for i, action in enumerate(actions):
            ax = axes[i]
            sub = df[df["Action"] == action].groupby([group_by_plot, "TimeBin"], as_index=False)["MeanCount"].mean()
            if sub.empty:
                ax.set_visible(False)
                continue
            plot_with_bot_markers(ax, data=sub, x="TimeBin", y="MeanCount",
                                  hue=group_by_plot, hue_order=bot_order_with_rank)

            # The shared legend is built from the first subplot's lines
            if i == 0:
                handles, labels = ax.get_legend_handles_labels()

            ax.set_title(get_metric_name(action))
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Mean Count")
            ax.grid(True, alpha=0.3)
            apply_timer_xlim(ax)

        # Hide unused axes
        for j in range(len(actions), len(axes)):
            axes[j].set_visible(False)

        # Add global legend below
        if handles and labels:
            fig.legend(
                handles, labels, title=legend_title,
                loc="upper center", bbox_to_anchor=(0.5, -0.02),
                ncol=min(6, len(labels))
            )
        fig.suptitle("Per-action intensity over timer")
        fig.tight_layout()
        return fig

    else:
        raise ValueError("mode must be one of ['total','per_action','select']")


def plot_full_cross_heatmap_half(df, bot_name="Bot_NN", key="WinRate_L", max_labels=40, lower_triangle=True):
    """Heatmap of a bot's mean metric for every pair of configuration settings.

    For the rows where ``bot_name`` plays on the left, every setting of one
    configuration column ("Timer=30", "Round=3", ...) is crossed with every
    setting of the other configuration columns, and ``key`` is averaged over
    the matches in which both settings were active together.

    Args:
        df: Matchup table with the columns Bot_L, Timer, ActInterval, Round,
            SkillLeft, SkillRight and ``key``.
        bot_name: Left-side bot to analyse.
        key: Metric column that is averaged per setting pair.
        max_labels: Maximum number of settings kept per axis.
        lower_triangle: Currently unused (the triangular mask is disabled);
            accepted for backward compatibility.

    Returns:
        The matplotlib Figure, or None when ``bot_name`` has no rows.
    """
    cfg_cols = ["Timer", "ActInterval", "Round", "SkillLeft", "SkillRight"]
    df_bot = df[df["Bot_L"] == bot_name].copy()

    if df_bot.empty:
        print(f"⚠️ No rows for bot '{bot_name}'.")
        return None

    # Tag every match row: the self-merge below must pair settings that
    # belong to the SAME match. Joining on the metric value instead would
    # also pair settings of unrelated matches that merely share a value.
    row_id = "_row_id"
    df_bot[row_id] = range(len(df_bot))

    # Melt configurations
    melted = df_bot.melt(
        id_vars=[row_id, key],
        value_vars=cfg_cols,
        var_name="ConfigType",
        value_name="ConfigValue"
    )

    # Per-row cross join of settings (self merge on the row id)
    merged = melted.merge(melted.drop(columns=key), on=row_id, suffixes=("_X", "_Y"))
    merged = merged[merged["ConfigType_X"] != merged["ConfigType_Y"]]

    # Aggregate mean metric per setting pair
    grouped = (
        merged.groupby(["ConfigType_X", "ConfigValue_X", "ConfigType_Y", "ConfigValue_Y"])[key]
        .mean()
        .reset_index()
    )

    # Label for axes
    grouped["X"] = grouped["ConfigType_X"] + "=" + grouped["ConfigValue_X"].astype(str)
    grouped["Y"] = grouped["ConfigType_Y"] + "=" + grouped["ConfigValue_Y"].astype(str)

    # Pivot into matrix
    pivot = grouped.pivot(index="Y", columns="X", values=key)

    # Clip to manageable size
    if len(pivot) > max_labels or len(pivot.columns) > max_labels:
        pivot = pivot.iloc[:max_labels, :max_labels]

    # Ensure symmetry (optional, if slightly different values occur)
    pivot = (pivot + pivot.T) / 2

    # Drop all-NaN rows and columns
    pivot = pivot.dropna(axis=0, how="all")
    pivot = pivot.dropna(axis=1, how="all")

    # Plot
    fig, ax = plt.subplots(figsize=(max(10, len(pivot.columns)*0.4), max(8, len(pivot)*0.3)))
    sns.heatmap(
        pivot,
        cmap="Blues",
        annot=True,
        fmt=".2f",
        linewidths=0.5,
        cbar_kws={'label': 'Win Rate'},
        ax=ax
    )

    ax.set_title(f"Cross Configuration Win Rate (Half Matrix) for {bot_name}", fontsize=14, pad=12)
    ax.set_xlabel("Config X", fontsize=12)
    ax.set_ylabel("Config Y", fontsize=12)
    ax.tick_params(axis="x", rotation=45, labelsize=9)
    ax.tick_params(axis="y", labelsize=9)
    fig.tight_layout()
    return fig


def plot_grouped_config_winrates(
    df: pd.DataFrame,
    bot_col: str = "Bot_L",
    metric: str = "WinRate_L",
    config_col: str = "Timer",
    width: int = 10,
    height: int = 9,
    title: str = None,
    ylabel: str = None,
):
    """
    Grouped bar chart of a metric per bot, one bar per configuration value.

    Metrics whose name contains "Collisions", "ActionCounts", "Duration" or
    "MatchDur" are divided by the "Games" column before averaging; every
    other metric (e.g. a win rate) is averaged directly.

    Parameters
    ----------
    df : pd.DataFrame
        Summary dataframe (e.g., matchup_summary)
    bot_col : str
        Column name for bots (default: "Bot_L")
    metric : str
        Metric to plot (default: "WinRate_L")
    config_col : str
        Configuration column to group by (default: "Timer")
        Special case: "Skill" stacks SkillLeft rows (left bot, left metric)
        with SkillRight rows (right bot, right metric)
    width : int
        Figure width
    height : int
        Figure height
    title : str
        Plot title (optional)
    ylabel : str
        Y-axis label (optional)

    Returns
    -------
    matplotlib.figure.Figure
    """

    # Bots on the x-axis, ordered by rank when a rank column exists
    rank_col = "Rank_L" if bot_col == "Bot_L" else "Rank_R"
    rank_map = {}
    if rank_col in df.columns:
        rank_map = df.groupby(bot_col)[rank_col].first().to_dict()
        bots = sorted(df[bot_col].unique(), key=lambda b: rank_map.get(b, 9999))
    else:
        bots = sorted(df[bot_col].unique())

    # Count-like metrics are normalised to a per-game value first
    per_game_metrics = ["Collisions", "ActionCounts", "Duration", "MatchDur"]
    needs_per_game = any(m in metric for m in per_game_metrics)
    value_col = 'metric_per_game' if needs_per_game else metric

    if config_col == "Skill":
        # Stack the left and the right perspective into one long table
        mirror_bot = bot_col.replace("_L", "_R")
        mirror_metric = metric.replace("_L", "_R")
        extra_cols = ["Games"] if needs_per_game else []

        left_data = df[[bot_col, "SkillLeft", metric] + extra_cols].copy()
        left_data = left_data.rename(columns={"SkillLeft": "SkillType"})

        right_data = df[[mirror_bot, "SkillRight", mirror_metric] + extra_cols].copy()
        right_data = right_data.rename(columns={
            mirror_bot: bot_col,
            "SkillRight": "SkillType",
            mirror_metric: metric
        })

        combined = pd.concat([left_data, right_data], ignore_index=True)
        if needs_per_game:
            combined['metric_per_game'] = combined[metric] / combined['Games']
        grouped = combined.groupby([bot_col, "SkillType"])[value_col].agg(['mean', 'std']).reset_index()

        config_values = sorted(grouped["SkillType"].unique())
        config_col_display = "SkillType"
    else:
        config_values = sorted(df[config_col].unique())

        source = df
        if needs_per_game:
            source = df.copy()
            source['metric_per_game'] = source[metric] / source['Games']
        grouped = source.groupby([bot_col, config_col])[value_col].agg(['mean', 'std']).reset_index()

        config_col_display = config_col

    fig, ax = plt.subplots(figsize=(width, height))

    # One fixed colour per configuration value (palette repeats if needed)
    colors = ['#d62728', '#ff7f0e', '#2ca02c', '#17becf', '#9467bd', '#8c564b']
    config_colors = {val: colors[i % len(colors)] for i, val in enumerate(config_values)}

    # Each bot owns a slot of width 0.8 that the configuration bars share
    n_bots = len(bots)
    n_configs = len(config_values)
    bar_width = 0.8 / n_configs
    x_positions = np.arange(n_bots)

    for i, config_val in enumerate(config_values):
        config_data = grouped[grouped[config_col_display] == config_val]

        # Bots without data for this configuration value get a zero bar
        means = []
        for bot in bots:
            bot_rows = config_data[config_data[bot_col] == bot]
            means.append(0 if bot_rows.empty else bot_rows['mean'].values[0])

        offset = (i - n_configs/2 + 0.5) * bar_width
        bars = ax.bar(x_positions + offset, means, bar_width,
                      label=str(config_val), color=config_colors[config_val])

        # Value label in the middle of every bar that has a positive height
        for bar, mean_val in zip(bars, means):
            bar_top = bar.get_height()
            if bar_top > 0:
                ax.text(bar.get_x() + bar.get_width()/2., bar_top/2,
                        f'{mean_val:.1f}',
                        ha='center', va='center', fontsize=7, fontweight='bold', color='white')

    # Axis labels and title
    ax.set_xlabel(get_metric_name('Bot'), fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel if ylabel else get_metric_name(metric.replace('_L', '')), fontsize=12, fontweight='bold')

    if title:
        plot_title = title
    else:
        display_name = get_metric_name(config_col)
        plot_title = f'{get_metric_name(metric.replace("_L", ""))} grouped by {display_name}'

    ax.set_title(plot_title, fontsize=14, fontweight='bold', pad=15)
    ax.set_xticks(x_positions)

    # Tick labels carry the rank when one is known
    if rank_map:
        bot_labels = [f"{bot} (#{int(rank_map[bot])})" for bot in bots]
    else:
        bot_labels = bots
    ax.set_xticklabels(bot_labels, rotation=30, ha='right')

    # Legend offset comes from the shared padding helper
    legend_padding = calculate_legend_padding(ax, x_labels=bot_labels, rotation=30)
    ax.legend(title=config_col_display, loc='upper center', bbox_to_anchor=(0.5, legend_padding),
              ncol=min(6, n_configs), fontsize=10, framealpha=0.9)
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    fig.tight_layout()
    return fig


def plot_overall_bot_metrics(
    df: pd.DataFrame,
    bot_col: str = "Bot_L",
    metric: str = "Collisions_L",
    width: int = 10,
    height: int = 6,
    title: str = None,
    ylabel: str = None,
):
    """
    Bar chart with one bar per bot: the mean of a metric over all rows.

    Metrics whose name contains "Collisions", "ActionCounts", "Duration" or
    "MatchDur" are divided by the "Games" column before averaging; any other
    metric is averaged directly.

    Parameters
    ----------
    df : pd.DataFrame
        Summary dataframe (e.g., matchup_summary)
    bot_col : str
        Column name for bots (default: "Bot_L")
    metric : str
        Metric to plot. Options:
        - "Collisions_L" or "Collisions_R": Total collisions
        - "ActionCounts_L" or "ActionCounts_R": Total action counts
        - "Duration_L" or "Duration_R": Action duration
        - "MatchDur": Match duration
        - "Games": Total games
    width : int
        Figure width
    height : int
        Figure height
    title : str
        Plot title (optional)
    ylabel : str
        Y-axis label (optional)

    Returns
    -------
    matplotlib.figure.Figure

    Examples
    --------
    >>> plot_overall_bot_metrics(df, metric="Collisions_L", title="Mean Collisions per Bot")
    >>> plot_overall_bot_metrics(df, metric="ActionCounts_L", title="Mean Actions per Bot")
    >>> plot_overall_bot_metrics(df, metric="MatchDur", title="Mean Match Duration per Bot")
    """

    # Bots on the x-axis, ordered by rank when a rank column exists
    rank_col = "Rank_L" if bot_col == "Bot_L" else "Rank_R"
    rank_map = {}
    if rank_col in df.columns:
        rank_map = df.groupby(bot_col)[rank_col].first().to_dict()
        bots = sorted(df[bot_col].unique(), key=lambda b: rank_map.get(b, 9999))
    else:
        bots = sorted(df[bot_col].unique())

    # Count-like metrics are normalised to a per-game value first
    per_game_metrics = ["Collisions", "ActionCounts", "Duration", "MatchDur"]
    needs_per_game = any(m in metric for m in per_game_metrics)

    if needs_per_game:
        working = df.copy()
        working['metric_per_game'] = working[metric] / working['Games']
        per_bot = working.groupby(bot_col)['metric_per_game'].mean()
    else:
        per_bot = df.groupby(bot_col)[metric].mean()

    fig, ax = plt.subplots(figsize=(width, height))

    # Bar heights in display order; a bot without a group gets a zero bar
    means = [per_bot.loc[bot] if bot in per_bot.index else 0 for bot in bots]

    x_positions = np.arange(len(bots))
    bars = ax.bar(x_positions, means, width=0.6, color='#2ca02c', alpha=0.8, edgecolor='black', linewidth=1.2)

    # Value label in the middle of every bar
    for bar, mean_val in zip(bars, means):
        bar_top = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., bar_top/2,
                f'{mean_val:.1f}',
                ha='center', va='center', fontsize=9, fontweight='bold', color='white')

    # Axis labels and title
    ax.set_xlabel(get_metric_name('Bot'), fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel if ylabel else get_metric_name(metric.replace('_L', '')), fontsize=12, fontweight='bold')

    plot_title = title if title else (
        f'Mean {get_metric_name(metric.replace("_L", ""))} per {get_metric_name("Bot")} (across all configurations)'
    )
    ax.set_title(plot_title, fontsize=14, fontweight='bold', pad=15)
    ax.set_xticks(x_positions)

    # Tick labels carry the rank when one is known
    if rank_map:
        bot_labels = [f"{bot} (#{int(rank_map[bot])})" for bot in bots]
    else:
        bot_labels = bots
    ax.set_xticklabels(bot_labels, rotation=30, ha='right')
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    fig.tight_layout()
    return fig


def plot_action_radar(df, bot_col="Bot_L", width=14, height=12, scale=None, radial_limit="auto"):
    """
    Radar chart of the mean action counts of every bot.

    Every column ending in "_Act_L" becomes one spoke. "SkillStone" has no
    spoke of its own: its mean is added to "SkillBoost", and the combined
    spoke is labelled "Skill".

    Parameters:
        df: Summary dataframe with the bot column and "<Action>_Act_L" columns.
        bot_col (str): Column naming the bot of each row; also selects the
                    rank column (Rank_L for "Bot_L", otherwise Rank_R).
        width, height: Figure size in inches.
        scale (str): Scale for radial axis (None = linear, "sqrt", "log")
        radial_limit (str or float):
                    - "auto": Set max based on 95th percentile (recommended for better spacing)
                    - "max": Use the absolute max value
                    - float: Manually set the max radial value

    Returns:
        The matplotlib Figure holding the polar plot.
    """
    suffix = "_Act_L"
    action_cols = [col for col in df.columns if col.endswith(suffix)]

    # Bots in rank order when a rank column exists, alphabetical otherwise
    rank_col = "Rank_L" if bot_col == "Bot_L" else "Rank_R"
    rank_map = {}
    if rank_col in df.columns:
        rank_map = df.groupby(bot_col)[rank_col].first().to_dict()
        bots = sorted(df[bot_col].unique(), key=lambda b: rank_map.get(b, 9999))
    else:
        bots = sorted(df[bot_col].unique())

    # Columns that get a spoke (SkillStone is folded into SkillBoost)
    spoke_cols = [col for col in action_cols if col.replace(suffix, "") != "SkillStone"]

    # Spoke labels are only produced when there is at least one bot to draw
    action_names = []
    if bots:
        for col in spoke_cols:
            name = col.replace(suffix, "")
            action_names.append("Skill" if name == "SkillBoost" else name)

    # Raw mean per spoke for every bot
    bot_data_raw = {}
    for bot in bots:
        bot_df = df[df[bot_col] == bot]
        spoke_means = []
        for col in spoke_cols:
            col_mean = bot_df[col].mean()
            if col.replace(suffix, "") == "SkillBoost":
                # A zero column stands in when SkillStone_Act_L is absent
                stone_mean = bot_df.get("SkillStone_Act_L", bot_df[col] * 0).mean()
                col_mean = col_mean + stone_mean
            spoke_means.append(col_mean)
        bot_data_raw[bot] = spoke_means

    # Optional compression of the radial values
    if scale == "sqrt":
        bot_data = {bot: [np.sqrt(v) for v in vals] for bot, vals in bot_data_raw.items()}
    elif scale == "log":
        # +1 keeps zero counts finite
        bot_data = {bot: [np.log10(v + 1) for v in vals] for bot, vals in bot_data_raw.items()}
    else:
        bot_data = dict(bot_data_raw)

    # Spoke angles; the first angle is repeated to close each polygon
    num_vars = len(action_names)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(width, height), subplot_kw=dict(projection='polar'))

    colors = ['#d62728', '#ff7f0e', '#2ca02c', '#17becf', '#9467bd', '#8c564b']
    for i, (bot, values) in enumerate(bot_data.items()):
        closed = values + values[:1]
        if rank_map and bot in rank_map:
            label = f"{bot} (#{int(rank_map[bot])})"
        else:
            label = bot
        marker = get_bot_marker(bot)
        ax.plot(angles, closed, f'{marker}-', linewidth=2.5, markersize=8,
                label=label, color=colors[i % len(colors)])

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(action_names, fontsize=11)

    # Radial limit is derived from the plotted values (closing point excluded)
    all_values = [v for vals in bot_data.values() for v in vals]
    if radial_limit == "auto":
        max_val = np.percentile(all_values, 95)
        ax.set_ylim(0, max_val * 1.15)  # 15% headroom
    elif radial_limit == "max":
        max_val = max(all_values)
        ax.set_ylim(0, max_val * 1.1)
    elif isinstance(radial_limit, (int, float)):
        ax.set_ylim(0, radial_limit)

    # Denser radial grid for readability
    ax.yaxis.set_major_locator(plt.MaxNLocator(8))
    ax.tick_params(axis='y', labelsize=9)

    # Radial axis label names the applied scale
    if scale == "sqrt":
        radial_label = '√(Mean Action Count)'
    elif scale == "log":
        radial_label = 'log₁₀(Mean Action Count + 1)'
    else:
        radial_label = 'Mean Action Count'
    ax.set_ylabel(radial_label, labelpad=35, fontsize=11)

    ax.set_title('Actions Behaviour', size=16, pad=20, fontweight='bold')
    legend_title = "Bot (Rank)" if rank_map else "Bot"
    ax.legend(title=legend_title, loc='upper center', bbox_to_anchor=(0.5, -0.1), fontsize=10, framealpha=0.9, ncol=3)
    ax.grid(True, linestyle='--', linewidth=0.7, alpha=0.7)

    fig.tight_layout()
    return fig


def plot_collision_radar(df, bot_col="Bot_L", width=14, height=12, scale=None):
    """
    Create a triangular radar chart showing collision outcomes per bot.
    Three axes: hit (wins), tie (draws), struck (losses)

    The hit/struck columns follow the side selected by ``bot_col``: for
    ``"Bot_R"`` the bot's own hits are read from ``Collisions_R`` and the hits
    it received from ``Collisions_L``; for any other value ``Collisions_L`` is
    treated as the bot's own hits.

    Parameters:
        df: Summary dataframe with ``bot_col``, ``Collisions_L``,
            ``Collisions_R`` and ``Collisions_Tie`` columns. When the matching
            rank column (``Rank_L`` / ``Rank_R``) is present, bots are ordered
            by rank and the rank is appended to the legend label.
        bot_col (str): Column that identifies the bot (default: "Bot_L").
        width: Figure width
        height: Figure height
        scale (str): Scale for radial axis. Options:
                    - "linear" (or any other value, including None): Raw values
                    - "sqrt": Square root scale (recommended - shows all values clearly)
                    - "log": Logarithmic scale (more aggressive compression)

    Returns:
        matplotlib.figure.Figure
    """
    # Get unique bots and sort by rank
    rank_col = "Rank_L" if bot_col == "Bot_L" else "Rank_R"
    if rank_col in df.columns:
        rank_map = df.groupby(bot_col)[rank_col].first().to_dict()
        bots = sorted(df[bot_col].unique(), key=lambda b: rank_map.get(b, 9999))
    else:
        bots = sorted(df[bot_col].unique())
        rank_map = {}

    # Pick the collision columns from the perspective of the selected side so
    # that "hit" always means collisions credited to the bot being plotted.
    if bot_col == "Bot_R":
        hit_col, struck_col = "Collisions_R", "Collisions_L"
    else:
        hit_col, struck_col = "Collisions_L", "Collisions_R"

    def _rescale(value):
        """Apply the requested radial scale to one raw count."""
        if scale == "sqrt":
            return np.sqrt(value)
        if scale == "log":
            return np.log10(value + 1)  # +1 keeps zero counts finite
        return value  # linear

    # Collision totals per bot, ordered as [hit, tie, struck]
    bot_data = {}
    for bot in bots:
        bot_df = df[df[bot_col] == bot]
        raw_counts = [
            bot_df[hit_col].sum(),
            bot_df["Collisions_Tie"].sum(),
            bot_df[struck_col].sum(),
        ]
        bot_data[bot] = [_rescale(count) for count in raw_counts]

    # Set up triangular radar chart (3 vertices)
    collision_types = ['hit', 'tie', 'struck']
    num_vars = len(collision_types)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[:1]  # Complete the circle

    fig, ax = plt.subplots(figsize=(width, height), subplot_kw=dict(projection='polar'))

    # Plot each bot
    colors = ['#d62728', '#ff7f0e', '#2ca02c', '#17becf', '#9467bd', '#8c564b']
    for i, (bot, values) in enumerate(bot_data.items()):
        # Close the polygon on a new list so the stored values stay untouched
        closed_values = values + values[:1]
        # Format label with rank if available
        if rank_map and bot in rank_map:
            label = f"{bot} (#{int(rank_map[bot])})"
        else:
            label = bot
        # Use bot-specific marker
        marker = get_bot_marker(bot)
        ax.plot(angles, closed_values, f'{marker}-', linewidth=2.5, markersize=8,
                label=label, color=colors[i % len(colors)])

    # Set labels with better styling
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(collision_types, fontsize=11)

    # Add more radial grid lines for better readability
    ax.yaxis.set_major_locator(plt.MaxNLocator(8))
    ax.tick_params(axis='y', labelsize=9)

    # Set y-axis label based on scale
    if scale == "sqrt":
        ax.set_ylabel('√(Collision Count)', labelpad=35, fontsize=11)
    elif scale == "log":
        ax.set_ylabel('log₁₀(Collision Count + 1)', labelpad=35, fontsize=11)
    else:
        ax.set_ylabel('Collision Count', labelpad=35, fontsize=11)

    ax.set_title('Collision Behaviour', size=16, pad=20, fontweight='bold')
    legend_title = "Bot (Rank)" if rank_map else "Bot"
    ax.legend(title=legend_title, loc='upper center', bbox_to_anchor=(0.5, -0.1), fontsize=10, framealpha=0.9, ncol=3)
    ax.grid(True, linestyle='--', linewidth=0.7, alpha=0.7)

    fig.tight_layout()
    return fig


def plot_action_distribution_stacked(df, bot_col="Bot_L", width=10, height=6, normalize=False):
    """
    Create a stacked bar chart showing action type distribution per bot.

    Action counts are read from the side that matches ``bot_col``: the
    ``*_Act_R`` columns when ``bot_col == "Bot_R"``, the ``*_Act_L`` columns
    otherwise. SkillBoost and SkillStone are merged into one "Skill" segment.
    Action columns missing from ``df`` contribute 0.

    Parameters:
        df: Summary dataframe with action counts
        bot_col: Column name for bots (default: "Bot_L")
        width: Figure width
        height: Figure height
        normalize: If True, normalize bars to 100% (show proportions)
                   If False, show absolute counts

    Returns:
        matplotlib.figure.Figure
    """
    # Action columns for the selected side (SkillBoost and SkillStone are
    # merged into "Skill"). Every entry is a list so all actions are summed
    # the same way.
    side = "R" if bot_col == "Bot_R" else "L"
    action_mapping = {
        'Accelerate': [f'Accelerate_Act_{side}'],
        'TurnLeft': [f'TurnLeft_Act_{side}'],
        'TurnRight': [f'TurnRight_Act_{side}'],
        'Dash': [f'Dash_Act_{side}'],
        'Skill': [f'SkillBoost_Act_{side}', f'SkillStone_Act_{side}'],
    }

    # Get unique bots and sort by rank
    rank_col = "Rank_L" if bot_col == "Bot_L" else "Rank_R"
    if rank_col in df.columns:
        rank_map = df.groupby(bot_col)[rank_col].first().to_dict()
        bots = sorted(df[bot_col].unique(), key=lambda b: rank_map.get(b, 9999))
    else:
        bots = sorted(df[bot_col].unique())
        rank_map = {}

    # Prepare data for stacking: one list of per-bot totals per action
    action_data = {action: [] for action in action_mapping}
    for bot in bots:
        bot_df = df[df[bot_col] == bot]
        for action_name, col_names in action_mapping.items():
            total = sum(bot_df[col].sum() for col in col_names if col in bot_df.columns)
            action_data[action_name].append(total)

    # Convert to DataFrame for easier plotting
    data_df = pd.DataFrame(action_data, index=bots)

    # Normalize if requested; a bot with no recorded actions yields 0/0 = NaN,
    # which is drawn as an empty bar instead of propagating NaN heights.
    if normalize:
        data_df = data_df.div(data_df.sum(axis=1), axis=0).fillna(0) * 100

    # Create stacked bar chart
    fig, ax = plt.subplots(figsize=(width, height))

    # Define colors for each action type
    colors = {
        'Accelerate': '#d62728',    # Red
        'TurnLeft': '#ff7f0e',      # Orange
        'TurnRight': '#2ca02c',     # Green
        'Dash': '#17becf',          # Cyan
        'Skill': '#1f77b4'          # Blue
    }

    # Create bot labels with rank (bots without a rank entry keep their name)
    bot_labels = [
        f"{bot} (#{int(rank_map[bot])})" if bot in rank_map else bot
        for bot in bots
    ]

    # Plot stacked bars; `bottom` tracks the running top of each bot's stack
    bottom = np.zeros(len(bots))
    x_pos = np.arange(len(bots))
    for action in action_mapping:
        heights = data_df[action].to_numpy(dtype=float)
        ax.bar(x_pos, heights, bottom=bottom,
               label=action, color=colors[action], width=0.6)
        bottom = bottom + heights

    # Customize plot
    ax.set_xticks(x_pos)
    ax.set_xticklabels(bot_labels)
    ax.set_xlabel('Bots', fontsize=12, fontweight='bold')

    if normalize:
        ax.set_ylabel('Action Distribution (%)', fontsize=12, fontweight='bold')
        ax.set_title('Action Type Distribution per Bot (Normalized)',
                     fontsize=14, fontweight='bold', pad=15)
        ax.set_ylim(0, 100)
    else:
        ax.set_ylabel('Total Action Count', fontsize=12, fontweight='bold')
        ax.set_title('Action Type Distribution per Bot',
                     fontsize=14, fontweight='bold', pad=15)

    legend_padding = calculate_legend_padding(ax, rotation=0)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, legend_padding), ncol=5,
              fontsize=10, framealpha=0.9)
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    # Rotate x-axis labels if many bots
    if len(bots) > 5:
        plt.setp(ax.get_xticklabels(), rotation=30, ha='right')

    fig.tight_layout()
    return fig


def plot_collision_distribution_stacked(df, bot_col="Bot_L", width=10, height=6, normalize=False):
    """
    Create a stacked bar chart showing collision type distribution per bot.

    "Hit" and "Struck" follow the side selected by ``bot_col``: for
    ``"Bot_R"`` Hit is read from ``Collisions_R`` and Struck from
    ``Collisions_L``; for any other value it is the other way round.
    Collision columns missing from ``df`` contribute 0.

    Parameters:
        df: Summary dataframe with collision counts
        bot_col: Column name for bots (default: "Bot_L")
        width: Figure width
        height: Figure height
        normalize: If True, normalize bars to 100% (show proportions)
                   If False, show absolute counts

    Returns:
        matplotlib.figure.Figure
    """
    # Get unique bots and sort by rank
    rank_col = "Rank_L" if bot_col == "Bot_L" else "Rank_R"
    if rank_col in df.columns:
        rank_map = df.groupby(bot_col)[rank_col].first().to_dict()
        bots = sorted(df[bot_col].unique(), key=lambda b: rank_map.get(b, 9999))
    else:
        bots = sorted(df[bot_col].unique())
        rank_map = {}

    # Source column per collision type, seen from the plotted bot's side.
    # Dict order is also the stacking order of the bar segments.
    if bot_col == "Bot_R":
        hit_col, struck_col = 'Collisions_R', 'Collisions_L'
    else:
        hit_col, struck_col = 'Collisions_L', 'Collisions_R'
    collision_columns = {
        'Hit': hit_col,            # Collisions won by the plotted bot
        'Struck': struck_col,      # Collisions lost by the plotted bot
        'Tie': 'Collisions_Tie',   # Tie collisions
    }

    # Prepare data for stacking: one list of per-bot totals per type
    collision_data = {collision_type: [] for collision_type in collision_columns}
    for bot in bots:
        bot_df = df[df[bot_col] == bot]
        for collision_type, column in collision_columns.items():
            total = bot_df[column].sum() if column in bot_df.columns else 0
            collision_data[collision_type].append(total)

    # Convert to DataFrame for easier plotting
    data_df = pd.DataFrame(collision_data, index=bots)

    # Normalize if requested; a bot with no collisions yields 0/0 = NaN,
    # which is drawn as an empty bar instead of propagating NaN heights.
    if normalize:
        data_df = data_df.div(data_df.sum(axis=1), axis=0).fillna(0) * 100

    # Create stacked bar chart
    fig, ax = plt.subplots(figsize=(width, height))

    # Define colors for each collision type
    colors = {
        'Hit': '#2ca02c',         # Green (wins)
        'Struck': '#d62728',   # Red (losses)
        'Tie': '#ff7f0e'          # Orange (ties)
    }

    # Create bot labels with rank (bots without a rank entry keep their name)
    bot_labels = [
        f"{bot} (#{int(rank_map[bot])})" if bot in rank_map else bot
        for bot in bots
    ]

    # Plot stacked bars; `bottom` tracks the running top of each bot's stack
    bottom = np.zeros(len(bots))
    x_pos = np.arange(len(bots))
    for collision_type in collision_columns:
        heights = data_df[collision_type].to_numpy(dtype=float)
        ax.bar(x_pos, heights, bottom=bottom,
               label=collision_type, color=colors[collision_type], width=0.6)
        bottom = bottom + heights

    # Customize plot
    ax.set_xticks(x_pos)
    ax.set_xticklabels(bot_labels)
    ax.set_xlabel('Bots', fontsize=12, fontweight='bold')

    if normalize:
        ax.set_ylabel('Collision Distribution (%)', fontsize=12, fontweight='bold')
        ax.set_title('Collision Type Distribution per Bot (Normalized)',
                     fontsize=14, fontweight='bold', pad=15)
        ax.set_ylim(0, 100)
    else:
        ax.set_ylabel('Total Collision Count', fontsize=12, fontweight='bold')
        ax.set_title('Collision Type Distribution per Bot',
                     fontsize=14, fontweight='bold', pad=15)

    legend_padding = calculate_legend_padding(ax, rotation=0)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, legend_padding), ncol=3,
              fontsize=10, framealpha=0.9)
    ax.grid(axis='y', linestyle='--', alpha=0.3)

    # Rotate x-axis labels if many bots
    if len(bots) > 5:
        plt.setp(ax.get_xticklabels(), rotation=30, ha='right')

    fig.tight_layout()
    return fig


def plot_collision_timebins_intensity(
    df,
    group_by="Bot_L",  # "Bot_L" or "Bot_R"
    timer=None,
    act_interval=None,
    round=None,
    mode="total",        # "total" | "per_type" | "select"
    collision_type=None,  # "Actor_L" | "Actor_R" | "Tie" (used when mode == "select")
    width=10,
    height=6,
    summary_df=None,     # Optional: summary dataframe with Rank_L column for bot ranking
):
    """
    Plot collision intensity over time (with optional timer cutoff).

    Modes:
      - "total": sum all collision types -> one line per group
      - "per_type": show per-collision-type trends; creates one subplot per type (Actor_L, Actor_R, Tie)
      - "select": plot a single collision_type for all groups

    Args:
        df: DataFrame from summary_collision_timebins.csv
        group_by: "Bot_L" or "Bot_R" to group by bot
        timer: Filter by specific timer value; also used as the x-axis cutoff
        act_interval: Filter by specific action interval
        round: Filter by specific round
        mode: Visualization mode
        collision_type: Which collision type to show (for mode="select")
        width, height: Figure dimensions
        summary_df: Optional summary dataframe; when it has a "Rank_L" column
            its Bot_L -> Rank_L pairs are used to rank and label the bots.
            Otherwise ranks are looked up in ``df`` itself.

    Returns:
        matplotlib.figure.Figure, or None when no rows survive the filters.

    Raises:
        ValueError: If ``mode`` is unknown, or ``collision_type`` is missing or
            invalid while ``mode == "select"``.
    """

    # --- Filters ---
    if timer is not None:
        df = df[df["Timer"] == timer]
    if act_interval is not None:
        df = df[df["ActInterval"] == act_interval]
    if round is not None:
        df = df[df["Round"] == round]

    if df.empty:
        print("⚠️ No data after filtering.")
        return None

    # --- Preprocess TimeBin ---
    df = df.copy()
    df["TimeBin"] = pd.to_numeric(df["TimeBin"], errors="coerce")
    df = df.dropna(subset=["TimeBin"])
    df = df.sort_values("TimeBin")

    # --- Add rank to bot names if group_by is a bot column ---
    rank_map = None
    if group_by in ["Bot_L", "Bot_R"]:
        if summary_df is not None and "Rank_L" in summary_df.columns:
            rank_map = summary_df.groupby("Bot_L")["Rank_L"].first().to_dict()
        else:
            # Pair the grouped bot column with the rank column of the SAME
            # side; pairing Bot_R with Rank_L would report the opponent's rank.
            side_rank_col = "Rank_L" if group_by == "Bot_L" else "Rank_R"
            if side_rank_col in df.columns:
                rank_map = df.groupby(group_by)[side_rank_col].first().to_dict()
            elif "Rank_L" in df.columns and "Bot_L" in df.columns:
                # No Rank_R available: reuse the left-side name -> rank pairs.
                rank_map = df.groupby("Bot_L")["Rank_L"].first().to_dict()

    has_ranks = bool(rank_map)
    if has_ranks:
        df["BotWithRank"] = df[group_by].map(lambda b: f"{b} (#{int(rank_map.get(b, 999))})")
        group_by_plot = "BotWithRank"
        # Sort bots by rank
        bot_order = sorted(rank_map.keys(), key=lambda b: rank_map.get(b, 999))
        bot_order_with_rank = [f"{b} (#{int(rank_map[b])})" for b in bot_order]
    else:
        group_by_plot = group_by
        bot_order_with_rank = None

    # The title only advertises ranks when the labels actually carry them.
    legend_title = "Bot (Rank)" if has_ranks else group_by

    # --- Helper: apply x-axis cutoff ---
    def apply_timer_xlim(ax):
        if timer is not None:
            ax.set_xlim(0, timer)
            # Roughly ten ticks across the timer range, never a zero step
            ax.set_xticks(range(0, int(timer) + 1, max(1, int(timer // 10))))

    # --- Helper: one axis with one line per group (modes "select"/"total") ---
    def plot_single_metric(value_col, title, y_label):
        grouped = df.groupby([group_by_plot, "TimeBin"], as_index=False)[value_col].mean()

        fig, ax = plt.subplots(figsize=(width, height))
        plot_with_bot_markers(ax, data=grouped, x="TimeBin", y=value_col,
                              hue=group_by_plot, hue_order=bot_order_with_rank)
        ax.set_title(title)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(y_label)

        # Calculate dynamic padding for legend
        legend_padding = calculate_legend_padding(ax, rotation=0)
        ax.legend(title=legend_title, loc='upper center', bbox_to_anchor=(0.5, legend_padding), ncol=3,
                  framealpha=0.9, markerscale=1.2)
        ax.grid(True, alpha=0.3)
        apply_timer_xlim(ax)
        fig.tight_layout()
        return fig

    # --- Plot modes ---
    if mode == "select":
        if not collision_type:
            raise ValueError("collision_type must be provided when mode='select'")
        if collision_type not in ["Actor_L", "Actor_R", "Tie"]:
            raise ValueError("collision_type must be one of ['Actor_L', 'Actor_R', 'Tie']")

        return plot_single_metric(
            collision_type,
            f"Mean {collision_type} collisions over time",
            "Mean Count",
        )

    elif mode == "total":
        # Sum all collision types
        df["TotalCollisions"] = df["Actor_L"] + df["Actor_R"] + df["Tie"]
        return plot_single_metric(
            "TotalCollisions",
            "Total collision intensity over time",
            "Mean Count (summed over collision types)",
        )

    elif mode == "per_type":
        collision_types = ["Actor_L", "Actor_R", "Tie"]
        n = len(collision_types)
        fig, axes = plt.subplots(
            nrows=1,
            ncols=n,
            figsize=(width, height),
            squeeze=False
        )
        axes = axes.flatten()

        handles, labels = None, None
        for ax, ctype in zip(axes, collision_types):
            sub = df.groupby([group_by_plot, "TimeBin"], as_index=False)[ctype].mean()
            if sub.empty:
                ax.set_visible(False)
                continue

            # Plot with bot markers
            plot_with_bot_markers(ax, data=sub, x="TimeBin", y=ctype,
                                  hue=group_by_plot, hue_order=bot_order_with_rank)

            # Reuse the first plotted subplot's entries for the shared legend
            if handles is None:
                handles, labels = ax.get_legend_handles_labels()

            # Drop any per-axes legend (if the plotting helper created one) on
            # every subplot so only the shared figure legend remains.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

            ax.set_title(ctype)
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("Mean Count")
            ax.grid(True, alpha=0.3)
            apply_timer_xlim(ax)

        # Add global legend below
        if handles and labels:
            fig.legend(
                handles, labels, title=legend_title,
                loc="upper center", bbox_to_anchor=(0.5, -0.02),
                ncol=min(6, len(labels))
            )
        fig.suptitle("Per-collision-type intensity over timer")
        fig.tight_layout()
        return fig

    else:
        raise ValueError("mode must be one of ['total','per_type','select']")


def prepare_correlation_data(df):
    """
    Prepare data for correlation analysis by combining left and right perspectives.

    Each input row is emitted twice: once with the side-neutral columns filled
    from the ``*_L`` columns and once from the ``*_R`` columns (left rows come
    first). All original columns are kept. The input frame is not modified.

    Side-neutral columns added: Bot, WinRate, Actions (from ActionCounts_*),
    Collisions, Duration, the six ``*_Act`` counts and the Accelerate /
    TurnLeft / TurnRight / Dash ``*_Dur`` columns. SkillBoost/SkillStone
    duration columns are not copied.

    Derived columns:
        RoundNumeric: trailing integer of the ``Round`` label
            ('BestOf1' -> 1, 'BestOf3' -> 3, 'BestOf5' -> 5, ...); NaN when
            the label is not a string ending in digits.
        TotalSkillAct: SkillBoost_Act + SkillStone_Act.

    Args:
        df: Summary matchup dataframe

    Returns:
        Combined dataframe with all bots' data

    Raises:
        KeyError: If one of the required ``*_L`` / ``*_R`` columns or
            ``Round`` is missing from ``df``.
    """
    # (side-neutral name, per-side column stem); order fixes the column order
    column_sources = [
        ('Bot', 'Bot'),
        ('WinRate', 'WinRate'),
        ('Actions', 'ActionCounts'),
        ('Collisions', 'Collisions'),
        ('Duration', 'Duration'),
        ('Accelerate_Act', 'Accelerate_Act'),
        ('TurnLeft_Act', 'TurnLeft_Act'),
        ('TurnRight_Act', 'TurnRight_Act'),
        ('Dash_Act', 'Dash_Act'),
        ('SkillBoost_Act', 'SkillBoost_Act'),
        ('SkillStone_Act', 'SkillStone_Act'),
        ('Accelerate_Dur', 'Accelerate_Dur'),
        ('TurnLeft_Dur', 'TurnLeft_Dur'),
        ('TurnRight_Dur', 'TurnRight_Dur'),
        ('Dash_Dur', 'Dash_Dur'),
    ]

    # One copy of the frame per perspective (left first, then right)
    perspectives = []
    for side in ('L', 'R'):
        view = df.copy()
        for target, stem in column_sources:
            view[target] = view[f'{stem}_{side}']
        perspectives.append(view)

    # Combine both perspectives
    df_combined = pd.concat(perspectives, ignore_index=True)

    def _round_to_number(label):
        """Return the integer that ends a round label, or NaN if there is none."""
        if isinstance(label, str):
            digits = label[len(label.rstrip('0123456789')):]
            if digits:
                return int(digits)
        return np.nan

    # Add derived columns
    df_combined['RoundNumeric'] = df_combined['Round'].map(_round_to_number)
    df_combined['TotalSkillAct'] = df_combined['SkillBoost_Act'] + df_combined['SkillStone_Act']

    return df_combined


def plot_correlation_scatter(data, x_col, y_col, title, color_by='Bot',
                            alpha=0.95, figsize=(10, 8), add_jitter=False, add_per_bot_regression=False):
    """
    Create scatter plot with regression line and Pearson correlation.

    Rows with a missing value in ``x_col``, ``y_col`` or ``color_by`` are
    dropped first. The Pearson correlation and the overall regression line are
    always computed on the original (non-jittered) x values.

    Args:
        data: DataFrame with data to plot
        x_col: Column name for x-axis
        y_col: Column name for y-axis (should be WinRate)
        title: Plot title
        color_by: Column to color points by (default: 'Bot')
        alpha: Transparency of scatter points
        figsize: Figure size tuple
        add_jitter: If True, add random jitter to the x positions of the
            points when x has at most 5 distinct values
        add_per_bot_regression: If True, add regression line for each bot

    Returns:
        matplotlib figure, or None when fewer than two complete rows remain
        (a correlation and a regression line need at least two points).
    """
    # Remove NaN values
    plot_data = data[[x_col, y_col, color_by]].dropna().copy()

    # pearsonr rejects inputs shorter than 2, so treat that like "no data"
    if len(plot_data) < 2:
        return None

    # Calculate Pearson correlation (on original data)
    pearson_r, pearson_p = stats.pearsonr(plot_data[x_col], plot_data[y_col])

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Add jitter for discrete variables to spread out points
    x_values = plot_data[x_col].values.copy()
    x_plotted = x_values
    if add_jitter and len(np.unique(x_values)) <= 5:  # Discrete variable
        x_range = x_values.max() - x_values.min()
        jitter_amount = x_range * 0.02 if x_range > 0 else 0.01
        x_plotted = x_values + np.random.normal(0, jitter_amount, size=len(x_values))
    plot_data['x_jittered'] = x_plotted

    # Get bot rankings if available
    rank_map = {}
    if 'Rank_L' in data.columns and color_by == 'Bot':
        rank_map = data.groupby('Bot')['Rank_L'].first().to_dict()

    # One scatter series per distinct color_by value.
    # Sort by rank if available, otherwise alphabetically.
    unique_values = plot_data[color_by].unique()
    if rank_map:
        unique_values = sorted(unique_values, key=lambda v: rank_map.get(v, 999))
    else:
        unique_values = sorted(unique_values)
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_values)))

    for idx, value in enumerate(unique_values):
        subset = plot_data[plot_data[color_by] == value]

        # Get bot marker
        marker = get_bot_marker(value) if color_by == 'Bot' else 'o'

        # Create label with rank if available
        if rank_map and value in rank_map:
            label = f"{value} (#{int(rank_map[value])})"
        else:
            label = value

        ax.scatter(subset['x_jittered'], subset[y_col],
                   label=label, alpha=alpha, s=60, color=colors[idx],
                   marker=marker, edgecolors='black', linewidth=0.5)

        # Add per-bot regression line if requested
        if add_per_bot_regression:
            bot_x = subset[x_col].values
            bot_y = subset[y_col].values
            if len(bot_x) > 1 and bot_x.std() > 0:  # Need at least 2 points and variance for regression
                bot_slope, bot_intercept = np.polyfit(bot_x, bot_y, 1)
                bot_x_line = np.linspace(bot_x.min(), bot_x.max(), 100)
                bot_y_line = bot_slope * bot_x_line + bot_intercept
                ax.plot(bot_x_line, bot_y_line, '--', linewidth=1.5, color=colors[idx], alpha=0.7)

    # Add overall regression line (using original non-jittered data)
    slope, intercept = np.polyfit(x_values, plot_data[y_col], 1)
    x_line = np.linspace(x_values.min(), x_values.max(), 100)
    y_line = slope * x_line + intercept
    ax.plot(x_line, y_line, 'r-', linewidth=2.5, label='Overall Regression')

    # Add correlation info to plot
    corr_text = f'Pearson r = {pearson_r:.3f}\np-value = {pearson_p:.3e}\nn = {len(plot_data)}'
    ax.text(0.05, 0.95, corr_text, transform=ax.transAxes,
            verticalalignment='top', bbox=dict(boxstyle='round',
            facecolor='wheat', alpha=0.8), fontsize=10, family='monospace')

    # Labels and title
    ax.set_xlabel(get_metric_name(x_col), fontsize=12, fontweight='bold')
    ax.set_ylabel(get_metric_name(y_col), fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='--')

    # Legend - position below x-axis
    legend_title = "Bot (Rank)" if (color_by == 'Bot' and rank_map) else color_by
    legend_padding = calculate_legend_padding(ax, rotation=0)
    ax.legend(title=legend_title, loc='upper center', bbox_to_anchor=(0.5, legend_padding),
              fontsize=8, framealpha=0.9, ncol=3, markerscale=1.2)

    # Lay out this figure explicitly rather than whichever one is "current"
    fig.tight_layout()
    return fig


def _bot_rank_map(data):
    """Map each bot to its first 'Rank_L' value; empty dict when the column is missing."""
    if 'Rank_L' not in data.columns:
        return {}
    return data.groupby('Bot')['Rank_L'].first().to_dict()


def _draw_winrate_panel(ax, data, x_col, x_label, rank_map, alpha, empty_text, style):
    """Draw one Win Rate vs ``x_col`` panel on ``ax``: per-bot scatter, regression line(s), Pearson box."""
    plot_data = data[[x_col, 'WinRate', 'Bot']].dropna()

    # stats.pearsonr and a degree-1 fit both need at least two observations;
    # show a placeholder instead of raising when the column is (nearly) empty.
    if len(plot_data) < 2:
        message = empty_text if plot_data.empty else 'Insufficient data'
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
        return

    pearson_r, pearson_p = stats.pearsonr(plot_data[x_col], plot_data['WinRate'])

    # Bot order drives colour assignment: by rank when rankings exist,
    # otherwise alphabetical (if requested) or order of first appearance.
    unique_bots = plot_data['Bot'].unique()
    if rank_map:
        unique_bots = sorted(unique_bots, key=lambda b: rank_map.get(b, 999))
    elif style['sort_unranked']:
        unique_bots = sorted(unique_bots)
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_bots)))

    for bot_idx, bot in enumerate(unique_bots):
        bot_data = plot_data[plot_data['Bot'] == bot]

        # Legend label carries the rank when one is known for this bot
        if rank_map and bot in rank_map:
            label = f"{bot} (#{int(rank_map[bot])})"
        else:
            label = bot

        ax.scatter(bot_data[x_col], bot_data['WinRate'],
                   label=label, alpha=alpha, s=style['marker_size'],
                   color=colors[bot_idx], marker=get_bot_marker(bot),
                   edgecolors='black', linewidth=0.5)

        if style['per_bot_regression']:
            bot_x = bot_data[x_col].values
            bot_y = bot_data['WinRate'].values
            # Need at least 2 points and some variance in x for a line fit
            if len(bot_x) > 1 and bot_x.std() > 0:
                bot_slope, bot_intercept = np.polyfit(bot_x, bot_y, 1)
                bot_x_line = np.linspace(bot_x.min(), bot_x.max(), 100)
                ax.plot(bot_x_line, bot_slope * bot_x_line + bot_intercept, '--',
                        linewidth=1.5, color=colors[bot_idx], alpha=0.7)

    # Overall regression line; skipped when x is constant because the fit is
    # ill-conditioned and the resulting "line" would collapse to a single point.
    if plot_data[x_col].std() > 0:
        slope, intercept = np.polyfit(plot_data[x_col], plot_data['WinRate'], 1)
        x_line = np.linspace(plot_data[x_col].min(), plot_data[x_col].max(), 100)
        line_kwargs = {'label': style['regression_label']} if style['regression_label'] else {}
        ax.plot(x_line, slope * x_line + intercept, 'r-', linewidth=2.5, **line_kwargs)

    # Correlation info box in the top-left corner
    corr_text = f'r={pearson_r:.3f}\np={pearson_p:.2e}'
    ax.text(0.05, 0.95, corr_text, transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
            fontsize=style['corr_fontsize'], family='monospace')

    ax.set_xlabel(x_label, **style['label_kwargs'])
    ax.set_ylabel(get_metric_name('WinRate'), **style['label_kwargs'])
    ax.set_title(f'{get_metric_name("WinRate")} vs {x_label}',
                 fontsize=style['title_fontsize'], fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')


def _add_shared_bot_legend(fig, rank_map):
    """Attach one de-duplicated legend, collected from every axes of ``fig``, below the figure."""
    handles, labels = [], []
    for ax in fig.axes:
        h, l = ax.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(l)

    # Deduplicate by label while preserving first-seen order
    unique = dict(zip(labels, handles))
    fig.legend(list(unique.values()), list(unique.keys()),
               title='Bot (Rank)' if rank_map else 'Bot', fontsize=6,
               loc='upper center',
               bbox_to_anchor=(0.5, -0.05),
               framealpha=0.7, ncol=3)


def _plot_winrate_grid(data, columns, grid_shape, figsize, suptitle, suptitle_y,
                       rank_map, alpha, label_for, empty_text_for, style):
    """Build a subplot grid with one Win Rate panel per entry of ``columns`` and return the figure."""
    fig, axes = plt.subplots(*grid_shape, figsize=figsize)
    axes = np.atleast_1d(axes).flatten()

    for ax, column in zip(axes, columns):
        # Columns absent from the data leave their axes untouched
        if column not in data.columns:
            continue
        _draw_winrate_panel(ax, data, column, label_for(column), rank_map, alpha,
                            empty_text_for(column), style)

    _add_shared_bot_legend(fig, rank_map)
    fig.suptitle(suptitle, fontsize=14, fontweight='bold', y=suptitle_y)
    fig.tight_layout()
    return fig


def plot_all_correlations(df, width=10, height=8, alpha=0.2):
    """
    Create all correlation plots for win rate analysis.

    Four single-axes scatter plots (action interval, round type, timer and,
    when 'SkillLeft' is present, skill type) are produced through
    plot_correlation_scatter; three multi-panel figures (action counts,
    action durations, collision types) are drawn here.

    Args:
        df: Summary matchup dataframe (passed to prepare_correlation_data)
        width: Figure width
        height: Figure height
        alpha: Opacity of the scatter markers

    Returns:
        Flat dictionary of figures. Possible keys: 'actinterval', 'roundtype',
        'timer', 'skilltype' (each only when plot_correlation_scatter returned
        a figure), plus 'actions', 'actions_dur' and 'collisions'.
    """
    # Prepare data
    data = prepare_correlation_data(df)

    # d-prep. Numeric encoding for skill type. assign() returns a new frame, so
    # the object handed back by prepare_correlation_data is not mutated.
    has_skill = 'SkillLeft' in data.columns
    if has_skill:
        skill_map = {'Stone': 1, 'Boost': 2}
        data = data.assign(SkillNumeric=data['SkillLeft'].map(skill_map))

    # Build dynamic round type mapping for the round-type title
    round_mapping = data[['Round', 'RoundNumeric']].drop_duplicates().dropna()
    round_labels = ', '.join([f"{int(row['RoundNumeric'])}={row['Round']}"
                              for _, row in round_mapping.sort_values('RoundNumeric').iterrows()])
    round_title = (f'Win Rate vs Round Type ({round_labels})\n(All Bots Combined)'
                   if round_labels else 'Win Rate vs Round Type\n(All Bots Combined)')

    # a-d. Win rate vs each configuration variable (with per-bot regression)
    config_plots = [
        ('actinterval', 'ActInterval', 'Win Rate vs Action Interval\n(All Bots Combined)'),
        ('roundtype', 'RoundNumeric', round_title),
        ('timer', 'Timer', 'Win Rate vs Timer Duration\n(All Bots Combined)'),
    ]
    if has_skill:
        config_plots.append(
            ('skilltype', 'SkillNumeric',
             'Win Rate vs Skill Type (1=Stone, 2=Boost)\n(All Bots Combined)'))

    figs = {}
    for key, x_col, title in config_plots:
        fig = plot_correlation_scatter(
            data,
            x_col=x_col,
            y_col='WinRate',
            title=title,
            figsize=(width, height),
            add_jitter=False,
            add_per_bot_regression=True,
            alpha=alpha,
        )
        # A falsy return is treated as "nothing was plotted" and the key is omitted
        if fig:
            figs[key] = fig

    # Bot rankings (if available) are shared by every multi-panel figure below
    rank_map = _bot_rank_map(data)

    action_style = {
        'marker_size': 30,
        'per_bot_regression': True,
        'sort_unranked': False,
        'regression_label': None,
        'corr_fontsize': 8,
        'label_kwargs': {'fontsize': 10},
        'title_fontsize': 11,
    }
    collision_style = {
        'marker_size': 60,
        'per_bot_regression': False,
        'sort_unranked': True,
        'regression_label': 'Overall Regression',
        'corr_fontsize': 10,
        'label_kwargs': {'fontsize': 11, 'fontweight': 'bold'},
        'title_fontsize': 12,
    }

    # e. Win rate vs individual action counts
    figs['actions'] = _plot_winrate_grid(
        data,
        ['Accelerate_Act', 'TurnLeft_Act', 'TurnRight_Act',
         'Dash_Act', 'SkillBoost_Act', 'SkillStone_Act'],
        grid_shape=(2, 3), figsize=(width*1.2, height*1.5),
        suptitle='Win Rate vs Individual Action Types\n(All Bots Combined)',
        suptitle_y=0.995, rank_map=rank_map, alpha=alpha,
        label_for=get_metric_name,
        empty_text_for=lambda action: f'No data for {action}',
        style=action_style,
    )

    # f. Win rate vs individual action durations
    figs['actions_dur'] = _plot_winrate_grid(
        data,
        ['Accelerate_Dur', 'TurnLeft_Dur', 'TurnRight_Dur', 'Dash_Dur'],
        grid_shape=(2, 2), figsize=(width*1.2, height*1.5),
        suptitle='Win Rate vs Individual Action Duration\n(All Bots Combined)',
        suptitle_y=0.995, rank_map=rank_map, alpha=alpha,
        label_for=get_metric_name,
        empty_text_for=lambda action: f'No data for {action}',
        style=action_style,
    )

    # g. Win rate vs collision types (Hit, Struck, Tie) - combined across all configs
    collision_labels = {'Collisions_L': 'Hit', 'Collisions_R': 'Struck', 'Collisions_Tie': 'Tie'}
    figs['collisions'] = _plot_winrate_grid(
        data,
        ['Collisions_L', 'Collisions_R', 'Collisions_Tie'],
        grid_shape=(1, 3), figsize=(width*1.8, height),
        suptitle='Win Rate vs Collision Types\n(All Bots Combined)',
        suptitle_y=1.00, rank_map=rank_map, alpha=alpha,
        label_for=collision_labels.__getitem__,
        empty_text_for=lambda _col: 'Insufficient data',
        style=collision_style,
    )

    return figs


In [None]:
# Read the four summary CSV exports
df_sum = pd.read_csv("summary_bot.csv")
df_sum = df_sum.rename(columns={"Duration": "Duration (ms)"})
df = pd.read_csv("summary_matchup.csv")
df_timebins = pd.read_csv("summary_action_timebins.csv")
df_collision_timebins = pd.read_csv("summary_collision_timebins.csv")

# Experiment configuration: sorted distinct values of each matchup column
cfg = {
    name: sorted(df[column].unique().tolist())
    for name, column in (
        ("Timer", "Timer"),
        ("ActInterval", "ActInterval"),
        ("Round", "Round"),
        ("SkillLeft", "SkillLeft"),
        ("SkillRight", "SkillRight"),
        ("Bots", "Bot_L"),
    )
}
bots = ", ".join(cfg["Bots"])

# Default figure size used by the plotting cells below
width, height = 10, 6

print("Data loaded successfully!")
print(f"\nBots in experiment: {bots}")
print("\nConfiguration:")
for key, value in cfg.items():
    print(f"  {key}: {value}")

## Summary Bot Data

In [None]:
# Show the per-bot summary table loaded from summary_bot.csv
display(df_sum)

## Complete Matchup Data

In [None]:
# Show the full matchup table loaded from summary_matchup.csv
display(df)

# Overall Analysis

Analyze bot agents facing other agents with similar configurations

## Bot Behaviour Overview

### Actions Behaviour
Mean action counts per bot across all configurations

In [None]:
# Build the action radar figure from the matchup summary and render it
fig = plot_action_radar(df)
plt.show()

### Collision Behaviour
Hit/Struck/Tie distribution per bot

In [None]:
# Build the collision radar figure from the matchup summary and render it
fig = plot_collision_radar(df)
plt.show()

## Win Rate Matrix

Shows how often each bot wins against each of the others across the different matchups.
Each value is calculated by taking the mean over every configuration (each matchup runs 10 games per configuration), resulting in 240 games in total.

In [None]:
# Win-rate matrix drawn at the notebook-wide figure size (width x height)
fig = plot_winrate_matrix(df, width, height)
plt.show()

## Action Taken (All Configurations)

In [None]:
# Overall per-bot chart of the "ActionCounts_L" metric
fig = plot_overall_bot_metrics(df, metric="ActionCounts_L", title="Mean Action per Bot")
plt.show()

## Action Duration (All Configurations)

In [None]:
# Overall per-bot chart of the "Duration_L" metric
fig = plot_overall_bot_metrics(df, metric="Duration_L", title="Mean Action Duration per Bot")
plt.show()

## Collision (All Configurations)

In [None]:
# Overall per-bot chart of the "Collisions_L" metric
fig = plot_overall_bot_metrics(df, metric="Collisions_L", title="Mean Collisions per Bot")
plt.show()

## Match Duration (All Configurations)

In [None]:
# Overall per-bot chart of the "MatchDur" metric
fig = plot_overall_bot_metrics(df, metric="MatchDur", title="Mean Match Duration per Bot")
plt.show()

## Win Rate Grouped by Timer

In [None]:
# No `metric` argument is passed, so the helper's default metric is charted, grouped by "Timer"
fig = plot_grouped_config_winrates(df, config_col="Timer")
plt.show()

## Win Rate Grouped by Action Interval

In [None]:
# No `metric` argument is passed, so the helper's default metric is charted, grouped by "ActInterval"
fig = plot_grouped_config_winrates(df, config_col="ActInterval")
plt.show()

## Win Rate Grouped by Round

In [None]:
# No `metric` argument is passed, so the helper's default metric is charted, grouped by "Round"
fig = plot_grouped_config_winrates(df, config_col="Round")
plt.show()

## Win Rate Grouped by Skill

In [None]:
# Default metric grouped by "Skill".
# NOTE(review): the config cell reads "SkillLeft"/"SkillRight" from df; confirm the helper resolves config_col="Skill".
fig = plot_grouped_config_winrates(df, config_col="Skill")
plt.show()

## Collision Grouped by Timer

In [None]:
# The "winrates" helper is reused with metric="Collisions_L", grouped by "Timer"
fig = plot_grouped_config_winrates(df, metric="Collisions_L", config_col="Timer")
plt.show()

## Collision Grouped by Action Interval

In [None]:
# The "winrates" helper is reused with metric="Collisions_L", grouped by "ActInterval"
fig = plot_grouped_config_winrates(df, metric="Collisions_L", config_col="ActInterval")
plt.show()

## Collision Grouped by Round

In [None]:
# The "winrates" helper is reused with metric="Collisions_L", grouped by "Round"
fig = plot_grouped_config_winrates(df, metric="Collisions_L", config_col="Round")
plt.show()

## Collision Grouped by Skill

In [None]:
# The "winrates" helper is reused with metric="Collisions_L", grouped by "Skill"
# NOTE(review): df exposes "SkillLeft"/"SkillRight"; confirm the helper resolves config_col="Skill".
fig = plot_grouped_config_winrates(df, metric="Collisions_L", config_col="Skill")
plt.show()

## Action Taken Grouped by Timer

In [None]:
# The "winrates" helper is reused with metric="ActionCounts_L", grouped by "Timer"
fig = plot_grouped_config_winrates(df, metric="ActionCounts_L", config_col="Timer")
plt.show()

## Action Taken Grouped by Action Interval

In [None]:
# The "winrates" helper is reused with metric="ActionCounts_L", grouped by "ActInterval"
fig = plot_grouped_config_winrates(df, metric="ActionCounts_L", config_col="ActInterval")
plt.show()

## Action Taken Grouped by Round

In [None]:
# The "winrates" helper is reused with metric="ActionCounts_L", grouped by "Round"
fig = plot_grouped_config_winrates(df, metric="ActionCounts_L", config_col="Round")
plt.show()

## Action Taken Grouped by Skill

In [None]:
# The "winrates" helper is reused with metric="ActionCounts_L", grouped by "Skill"
# NOTE(review): df exposes "SkillLeft"/"SkillRight"; confirm the helper resolves config_col="Skill".
fig = plot_grouped_config_winrates(df, metric="ActionCounts_L", config_col="Skill")
plt.show()

## Action Duration Grouped by Timer

In [None]:
# The "winrates" helper is reused with metric="Duration_L", grouped by "Timer"
fig = plot_grouped_config_winrates(df, metric="Duration_L", config_col="Timer")
plt.show()

## Action Duration Grouped by Action Interval

In [None]:
# The "winrates" helper is reused with metric="Duration_L", grouped by "ActInterval"
fig = plot_grouped_config_winrates(df, metric="Duration_L", config_col="ActInterval")
plt.show()

## Action Duration Grouped by Round

In [None]:
# The "winrates" helper is reused with metric="Duration_L", grouped by "Round"
fig = plot_grouped_config_winrates(df, metric="Duration_L", config_col="Round")
plt.show()

## Action Duration Grouped by Skill

In [None]:
# The "winrates" helper is reused with metric="Duration_L", grouped by "Skill"
# NOTE(review): df exposes "SkillLeft"/"SkillRight"; confirm the helper resolves config_col="Skill".
fig = plot_grouped_config_winrates(df, metric="Duration_L", config_col="Skill")
plt.show()

## Match Duration Grouped by Timer

In [None]:
# The "winrates" helper is reused with metric="MatchDur", grouped by "Timer"
fig = plot_grouped_config_winrates(df, metric="MatchDur", config_col="Timer")
plt.show()

## Match Duration Grouped by Action Interval

In [None]:
# The "winrates" helper is reused with metric="MatchDur", grouped by "ActInterval"
fig = plot_grouped_config_winrates(df, metric="MatchDur", config_col="ActInterval")
plt.show()

## Match Duration Grouped by Round

In [None]:
# The "winrates" helper is reused with metric="MatchDur", grouped by "Round"
fig = plot_grouped_config_winrates(df, metric="MatchDur", config_col="Round")
plt.show()

## Match Duration Grouped by Skill

In [None]:
# The "winrates" helper is reused with metric="MatchDur", grouped by "Skill"
# NOTE(review): df exposes "SkillLeft"/"SkillRight"; confirm the helper resolves config_col="Skill".
fig = plot_grouped_config_winrates(df, metric="MatchDur", config_col="Skill")
plt.show()

## Time-Related Trends

Analyzes bot aggressiveness over the course of a game by relating the duration of the actions taken to the overall game duration (the Timer setting).
Higher timers don't always lead to longer matches: some matchups finish their fights early regardless of the time limit.

In [None]:
# Build the time-related figures, then call plt.show() once per returned entry.
# NOTE(review): plt.show() takes no figure argument, so `fig` is not used inside the loop body.
figs = plot_time_related(df, width, height)
for fig in figs:
    plt.show()

## Action Distribution per Bot

In [None]:
# Stacked action-distribution chart, requested in normalized form (normalize=True)
fig = plot_action_distribution_stacked(df, normalize=True)
plt.show()

## Action Intensity Over Time (Per Configuration)

Shows action intensity over time for different timer and action interval configurations

In [None]:
# Walk every (Timer, ActInterval) pair found in the data
for timI in cfg["Timer"]:
    for actI in cfg["ActInterval"]:
        print(f"\n--- Timer={timI}, ActionInterval={actI} ---")

        # Total action intensity first, then the per-action breakdown
        for intensity_mode in ("total", "per_action"):
            fig = plot_timebins_intensity(
                df_timebins,
                timer=timI,
                act_interval=actI,
                mode=intensity_mode,
                summary_df=df,
            )
            # A falsy result means there is nothing to render for this pair
            if fig:
                plt.show()

## Action Intensity Over All Configurations

In [None]:
# Total action intensity with no act_interval filter passed.
# NOTE(review): timer is fixed at 60 here although the heading says "all configurations" — confirm intent.
fig = plot_timebins_intensity(df_timebins, mode="total", timer=60, summary_df=df)
if fig:
    plt.show()

In [None]:
# Per-action intensity with no act_interval filter passed.
# NOTE(review): timer is fixed at 60 here although the heading says "all configurations" — confirm intent.
fig = plot_timebins_intensity(df_timebins, mode="per_action", timer=60, summary_df=df)
if fig:
    plt.show()

## Collision Intensity Over Time (Per Configuration)

Shows collision intensity over time for different timer and action interval configurations

In [None]:
# Walk every (Timer, ActInterval) pair found in the data
for timI in cfg["Timer"]:
    for actI in cfg["ActInterval"]:
        print(f"\n--- Timer={timI}, ActionInterval={actI} ---")

        # Total collision intensity first, then the per-type breakdown
        for intensity_mode in ("total", "per_type"):
            fig = plot_collision_timebins_intensity(
                df_collision_timebins,
                timer=timI,
                act_interval=actI,
                mode=intensity_mode,
                summary_df=df,
            )
            # A falsy result means there is nothing to render for this pair
            if fig:
                plt.show()

## Collision Detail Distribution per Bot

In [None]:
# Stacked collision-distribution chart, requested in normalized form (normalize=True)
fig = plot_collision_distribution_stacked(df, normalize=True)
plt.show()

## Collision Intensity Over All Configurations

In [None]:
# Total collision intensity with no act_interval filter passed.
# NOTE(review): timer is fixed at 60 here although the heading says "all configurations" — confirm intent.
fig = plot_collision_timebins_intensity(df_collision_timebins, mode="total", timer=60, summary_df=df)
if fig:
    plt.show()

In [None]:
# Per-type collision intensity with no act_interval filter passed.
# NOTE(review): timer is fixed at 60 here although the heading says "all configurations" — confirm intent.
fig = plot_collision_timebins_intensity(df_collision_timebins, mode="per_type", timer=60, summary_df=df)
if fig:
    plt.show()

## Action Taken vs. Win Relation

Does taking more actions (playing aggressively) lead to a win?
This compares the mean number of actions taken per game against the win rate.

In [None]:
# Actions-taken vs win-rate figure at the notebook-wide figure size
fig = plot_action_win_related(df, width, height)
plt.show()

## Pearson Correlation Analysis (Overall)

Correlation analysis using Pearson coefficient with scatter plots and regression lines.
All data from all bots combined, separated by configuration

In [None]:
# Build every overall correlation figure; alpha is left at the function default.
# No plt.show() here: the notebook relies on %matplotlib inline to render the figures still open at cell end.
correlation_figs = plot_all_correlations(df, width, height)

# Individual Bot Analysis

Analyze each bot agent across its different configurations.
Each report (Win Rate, Collision, Action-Taken, Duration) is calculated by averaging the matchup data over both the left and right positions.

## Pearson Correlation Analysis (Per Bot)

Detailed plots for individual bots, separated by configuration

In [None]:
# Sorted distinct bot names taken from the left-side column "Bot_L"
# (the same source the config cell uses for cfg["Bots"])
bots_list = sorted(df['Bot_L'].unique())
print(f"Analyzing {len(bots_list)} bots: {bots_list}")

In [None]:
# Individual bot correlation analysis
#
# plt.show() takes no figure argument: under %matplotlib inline the first call
# renders every open figure and the later calls find nothing left to draw, so
# all plots end up under the first heading. Each figure is therefore displayed
# explicitly under its own heading and closed right after.

# (key in the returned dict, heading printed above the figure), in display order
correlation_sections = [
    ('actinterval', "Win Rate vs Action Interval Configuration"),
    ('roundtype', "Win Rate vs Round Type Configuration"),
    ('timer', "Win Rate vs Timer Configuration"),
    ('skilltype', "Win Rate vs Skill Type Configuration"),
    ('actions', "Win Rate vs Individual Action Types"),
    ('actions_dur', "Win Rate vs Individual Action Duration"),
    ('collisions', "Win Rate vs Collision Types (Hit, Struck, Tie)"),
]


def _display_and_close(figs):
    """Render a figure (or every figure inside a nested dict/list) in place, then close it."""
    if isinstance(figs, dict):
        figs = list(figs.values())
    if isinstance(figs, (list, tuple)):
        for item in figs:
            _display_and_close(item)
        return
    display(figs)
    # Closing keeps the inline backend from drawing the figure again at cell end
    if isinstance(figs, plt.Figure):
        plt.close(figs)


for bot in bots_list:
    print(f"\n{'='*60}")
    print(f"Analyzing correlations for {bot}")
    print(f"{'='*60}")

    correlation_figs = plot_individual_bot_correlations(df, bot, width, height)

    if not correlation_figs:
        print(f"No data available for {bot}")
        continue

    for section_key, heading in correlation_sections:
        if section_key in correlation_figs:
            print(f"\n--- {heading} ---")
            _display_and_close(correlation_figs[section_key])

# Arena Heatmaps - Bot Movement Analysis

Visualize bot movement patterns across different game phases (Early, Mid, Late)

In [None]:
# Show the pre-rendered arena heatmap images for every bot, ordered by rank.
# Images are read from arena_heatmaps/<bot>/; nothing is generated here.
heatmap_dir = "arena_heatmaps"

# isdir (rather than exists) so a stray regular file with this name cannot reach os.listdir
if os.path.isdir(heatmap_dir):
    # Every sub-directory is treated as one bot
    bot_dirs = [d for d in os.listdir(heatmap_dir)
                if os.path.isdir(os.path.join(heatmap_dir, d))]

    # Sort bot directories by rank from df_sum; unknown bots go last.
    # Without rank information fall back to alphabetical order.
    if "Rank" in df_sum.columns and "Bot" in df_sum.columns:
        rank_map = df_sum.groupby("Bot")["Rank"].first().to_dict()
        bot_dirs = sorted(bot_dirs, key=lambda b: rank_map.get(b, 9999))
    else:
        bot_dirs = sorted(bot_dirs)

    if bot_dirs:
        phase_names = ["window_2.5-15s.png", "window_15-30s.png", "window_30-45s.png", "window_45-60s.png"]

        # (image file name, caption printed above it)
        distribution_images = [
            ("position_distribution.png", "\nPosition Distribution (X & Y Overlayed)"),
            ("distance_distribution.png", "\nDistance Distribution"),
        ]

        # `position` is the 1-based place in the sorted list (enumerate avoids
        # a linear bot_dirs.index() lookup on every iteration)
        for position, bot_name in enumerate(bot_dirs, start=1):
            print(f"\n{'='*60}")
            print(f"{bot_name} (#{position})")
            print(f"{'='*60}")
            bot_dir = os.path.join(heatmap_dir, bot_name)

            # Phase heatmaps side by side, one axes per time window
            fig, axes = plt.subplots(1, len(phase_names), figsize=(20, 5))
            for ax, phase_name in zip(axes, phase_names):
                image_path = os.path.join(bot_dir, phase_name)
                if os.path.exists(image_path):
                    # imshow copies the pixel data, so the file can be closed right away
                    with Image.open(image_path) as image:
                        ax.imshow(image)
                    ax.set_title(phase_name)
                else:
                    ax.text(0.5, 0.5, f"Image not found:\n{phase_name}",
                            ha='center', va='center')
                ax.axis('off')
            plt.tight_layout()
            plt.show()

            # Position and distance distribution images (each shown only if present)
            for file_name, caption in distribution_images:
                dist_path = os.path.join(bot_dir, file_name)
                if os.path.exists(dist_path):
                    print(caption)
                    plt.figure(figsize=(10, 6))
                    with Image.open(dist_path) as dist_image:
                        plt.imshow(dist_image)
                    plt.axis('off')
                    plt.show()

            print("\nFull Configuration Analysis")
            fig = plot_full_cross_heatmap_half(df, bot_name=bot_name, lower_triangle=True)
            plt.show()
    else:
        print("No bot heatmaps found in directory")
        print("Run: `python detailed_analyzer.py all` to generate heatmaps")
else:
    print(f"Heatmap directory not found: {heatmap_dir}")
    print("Run: `python detailed_analyzer.py all` to generate heatmaps for all bots")