# In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import json
from scipy.interpolate import interp1d
import seaborn as sns

def read_monitor_files(log_dir):
    """
    Load every ``*.monitor.csv`` file found for one run.

    The files are looked up directly inside ``log_dir`` first; only when none
    are found there is the search repeated one directory level down.

    Args:
        log_dir: Directory holding the monitor files (0.monitor.csv,
            1.monitor.csv, ...), either directly or in immediate subdirectories.

    Returns:
        list: One DataFrame per successfully read file, each extended with a
        ``cumulative_timesteps`` column (running sum of the ``l`` column).
        Files that cannot be read are reported on stdout and left out.

    Raises:
        ValueError: If no monitor file is found in either location.
    """
    pattern = "*.monitor.csv"
    paths = glob(os.path.join(log_dir, pattern)) or glob(os.path.join(log_dir, "*", pattern))
    if len(paths) == 0:
        raise ValueError(f"No monitor files found in {log_dir}")

    frames = []
    for path in paths:
        try:
            with open(path, 'r') as handle:
                # The first line carries metadata, not CSV content, so it is
                # consumed before pandas parses the remainder of the file.
                handle.readline()
                frame = pd.read_csv(handle)
            frame['cumulative_timesteps'] = frame['l'].cumsum()
        except Exception as err:
            # Best effort: one unreadable file must not discard the whole run.
            print(f"Error reading {path}: {err}")
        else:
            frames.append(frame)

    return frames

def aggregate_data(all_data, window_size=100, num_points=200):
    """
    Pool episodes from several monitor DataFrames into one smoothed reward curve.

    All (cumulative_timesteps, r) pairs are merged and sorted by timestep. The
    range [0, max timestep] is split into ``num_points - 1`` equal bins. For
    every bin that contains at least one episode, statistics are computed over
    a trailing window that starts ``window_size // 2`` bins earlier (clamped to
    the first bin) and ends at the end of the current bin.

    Bins are half-open ``[start, end)`` except the last one, which is closed so
    that the episode sitting exactly on the maximum timestep is counted. The
    first bin is reported like any other bin; its window is the bin itself.

    Args:
        all_data: List of DataFrames with ``cumulative_timesteps`` and ``r``
            columns. May be empty.
        window_size: Trailing window length, expressed in bins (half of it is
            used as look-back).
        num_points: Number of bin edges placed between 0 and the max timestep.

    Returns:
        tuple: ``(timesteps, mean_rewards, std_rewards)`` as numpy arrays of
        equal length. ``timesteps`` holds bin centres, ``mean_rewards`` the
        window means, and ``std_rewards`` the standard error of the mean
        (population std divided by sqrt of the window sample count). All three
        are empty when there is no episode to aggregate.
    """
    empty_result = (np.array([]), np.array([]), np.array([]))
    if not all_data:
        return empty_result

    # Force float so header-only frames (object dtype) cannot poison the arrays.
    all_timesteps = np.concatenate(
        [np.asarray(df['cumulative_timesteps'].values, dtype=float) for df in all_data]
    )
    all_rewards = np.concatenate(
        [np.asarray(df['r'].values, dtype=float) for df in all_data]
    )
    if all_timesteps.size == 0:
        return empty_result

    # Stable sort keeps the result deterministic when timesteps tie.
    order = np.argsort(all_timesteps, kind='stable')
    all_timesteps = all_timesteps[order]
    all_rewards = all_rewards[order]

    timestep_bins = np.linspace(0, all_timesteps[-1], num_points)
    num_bins = len(timestep_bins) - 1
    half_window = window_size // 2

    timesteps_plot = []
    mean_rewards = []
    std_rewards = []

    for i in range(num_bins):
        bin_start = timestep_bins[i]
        bin_end = timestep_bins[i + 1]

        # The array is sorted, so each bin is a contiguous slice [first, stop).
        first = np.searchsorted(all_timesteps, bin_start, side='left')
        if i == num_bins - 1:
            # Close the final bin on the right to keep the last episode.
            stop = len(all_timesteps)
        else:
            stop = np.searchsorted(all_timesteps, bin_end, side='left')

        if stop <= first:
            # No episode ended inside this bin: emit no point for it.
            continue

        # Look back to get more samples for smoother statistics.
        lookback_start = min(i, max(0, i - half_window))
        window_first = np.searchsorted(
            all_timesteps, timestep_bins[lookback_start], side='left'
        )
        window_rewards = all_rewards[window_first:stop]

        if len(window_rewards) > 0:
            mean_rewards.append(np.mean(window_rewards))
            # Standard error of the mean, not the raw standard deviation.
            std_rewards.append(np.std(window_rewards) / np.sqrt(len(window_rewards)))
            timesteps_plot.append((bin_start + bin_end) / 2)

    return np.array(timesteps_plot), np.array(mean_rewards), np.array(std_rewards)

def get_method_colors(all_method_names):
    """
    Build the method -> hex colour mapping shared by all subplots.

    Methods with a reserved colour always receive it. Every other method is
    given a colour from a small spare palette, assigned in alphabetical order
    and wrapping around when the palette runs out.

    Args:
        all_method_names: Iterable of method names.

    Returns:
        dict: Method name -> hex colour string, keyed in sorted name order.
    """
    # Reserved colours keep the well-known methods identical across figures.
    reserved = {
        'LP': '#d62728',    # Red color for LP
        'AMA': '#ff7f0e',   # Orange color for AMA
        'Ensemble': '#1f77b4',  # Blue color for Ensemble
        'MSE': '#2ca02c',   # Green color for MSE
        'IDF': '#9467bd',   # Purple color for IDF
        'RND': '#8c564b'    # Brown color for RND
    }
    # Spare palette: none of these collide with a reserved colour.
    spare = ['#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    ordered_names = sorted(all_method_names)
    unreserved = [name for name in ordered_names if name not in reserved]
    fallback = {
        name: spare[position % len(spare)]
        for position, name in enumerate(unreserved)
    }

    return {
        name: reserved[name] if name in reserved else fallback[name]
        for name in ordered_names
    }

def get_display_name(method_name):
    """
    Map an internal method name to the label shown in the legend.

    Args:
        method_name: Original method name

    Returns:
        str: 'LPM(ours)' for the 'LP' method; any other name is returned as is.
    """
    return 'LPM(ours)' if method_name == 'LP' else method_name

def _plot_method_curve(ax, method_name, log_dir, color):
    """Draw one method's mean-reward curve and error band on ``ax``; return the line, or None when there is nothing to draw."""
    timesteps, mean_rewards, std_rewards = aggregate_data(read_monitor_files(log_dir))
    if len(timesteps) == 0:
        print(f"    No episodes to plot for {method_name}")
        return None

    # Rescale the x-axis by 1e6 / 16000.
    # NOTE(review): the shared x label reads 'Frames (millions)'; confirm that
    # this factor really converts monitor timesteps into that unit.
    timesteps_steps = timesteps * 1000000 / 16000

    # Plot mean line
    line = ax.plot(timesteps_steps, mean_rewards,
                   color=color,
                   linestyle='-',
                   linewidth=3,
                   label=method_name,
                   alpha=0.9)[0]

    # Shaded band of +/- the spread value returned by aggregate_data
    ax.fill_between(timesteps_steps,
                    mean_rewards - std_rewards,
                    mean_rewards + std_rewards,
                    color=color,
                    alpha=0.1)
    return line


def _plot_random_baseline(ax, random_log_dir):
    """Draw the random run's overall mean reward as a dashed horizontal line; return it, or None when the run has no rewards."""
    all_rewards = [
        reward
        for df in read_monitor_files(random_log_dir)
        for reward in df['r'].values
    ]
    if not all_rewards:
        return None
    return ax.axhline(y=np.mean(all_rewards), color='gray',
                      linestyle='--', linewidth=2,
                      alpha=0.8)


def plot_four_environment_comparison(experiment_dicts, env_titles, save_dir="comparison_plots"):
    """
    Plot reward curves of exploration methods for several environments side by side.

    One subplot is created per entry of ``experiment_dicts`` (four in the
    intended use). Each method is drawn as a mean curve with a shaded band; an
    entry keyed 'random' (any letter case) is drawn as a dashed horizontal
    baseline instead. A single legend is placed above the figure and the
    result is written to
    ``<save_dir>/four_environment_exploration_comparison.pdf`` and shown.

    A method or baseline that fails to load is reported on stdout and skipped,
    so one broken run does not abort the whole figure.

    Note: this function updates the global matplotlib rcParams.

    Args:
        experiment_dicts: List of dictionaries (one per subplot), each mapping
            method names to log directories. An empty dictionary yields a
            "No data available" placeholder panel.
        env_titles: List of subplot titles, paired with ``experiment_dicts``.
        save_dir: Directory to save the plots
    """

    plt.style.use('default')
    plt.rcParams.update({
        'font.size': 10,
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'Times', 'serif'],
        'mathtext.fontset': 'stix',
        'axes.linewidth': 0.8,
        'grid.linewidth': 0.5,
        'lines.linewidth': 1.5,
    })

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    experiment_dicts = list(experiment_dicts)

    # All method names across environments, so colours are consistent.
    # The random baseline is handled separately and needs no palette colour.
    all_methods = {
        method_name
        for exp_dict in experiment_dicts
        for method_name in exp_dict
        if method_name.lower() != 'random'
    }
    method_colors = get_method_colors(all_methods)

    # One subplot per environment; 8 inches each gives 32x7 for four panels.
    num_envs = max(len(experiment_dicts), 1)
    fig, axes = plt.subplots(1, num_envs, figsize=(8 * num_envs, 7), squeeze=False)
    axes = axes[0]

    # Legend handles: the first successfully drawn line per method, taken from
    # whichever subplot draws it first, so a method missing from (or failing
    # in) the leftmost panel still appears in the legend.
    method_handles_dict = {}
    random_handle = None

    # Plot each environment
    for idx, (experiment_dict, env_title) in enumerate(zip(experiment_dicts, env_titles)):
        ax = axes[idx]

        if not experiment_dict:
            ax.text(0.5, 0.5, f'No data available\nfor {env_title}',
                    ha='center', va='center', transform=ax.transAxes,
                    fontsize=15, style='italic')
            ax.set_title(env_title, fontsize=35, fontweight='medium')
            continue

        print(f"\nProcessing {env_title}...")

        random_log_dir = None

        # Plot each method in this environment
        for method_name, log_dir in experiment_dict.items():
            # The random run is drawn afterwards as a baseline, not as a curve
            if method_name.lower() == 'random':
                random_log_dir = log_dir
                continue

            print(f"  Processing {method_name} from {log_dir}...")

            try:
                line = _plot_method_curve(ax, method_name, log_dir,
                                          method_colors[method_name])
            except Exception as e:
                print(f"    Error processing {method_name}: {e}")
                continue

            if line is not None:
                method_handles_dict.setdefault(method_name, line)

        # Add random baseline if available
        if random_log_dir is not None:
            print(f"  Adding random baseline from {random_log_dir}...")

            try:
                baseline_line = _plot_random_baseline(ax, random_log_dir)
            except Exception as e:
                print(f"    Error processing random baseline: {e}")
            else:
                if random_handle is None:
                    random_handle = baseline_line

        # Customize each subplot
        if idx == 0:  # Only leftmost plot gets y-label
            ax.set_ylabel('Extrinsic Reward', fontsize=35, fontweight='medium')
        ax.set_title(env_title, fontsize=35, fontweight='medium')
        ax.grid(True, alpha=0.3)
        ax.set_xlim(left=0)
        ax.tick_params(axis='both', which='major', labelsize=25)
        ax.xaxis.get_offset_text().set_fontsize(30)
        ax.yaxis.get_offset_text().set_fontsize(30)

    # Organize legend with specific ordering: No Int Rew first, LPM(ours) last
    legend_handles = []
    legend_labels = []

    # First add "No Int Rew" if it exists
    if random_handle is not None:
        legend_handles.append(random_handle)
        legend_labels.append('No Int Rew')

    # Then add all methods except LP (sorted)
    for method_name in sorted(method_handles_dict):
        if method_name != 'LP':
            legend_handles.append(method_handles_dict[method_name])
            legend_labels.append(get_display_name(method_name))

    # Finally add LP (as LPM(ours)) at the end if it exists
    if 'LP' in method_handles_dict:
        legend_handles.append(method_handles_dict['LP'])
        legend_labels.append(get_display_name('LP'))

    # Add single legend at the top of the figure
    if legend_handles:
        fig.legend(legend_handles, legend_labels,
                   loc='upper center', bbox_to_anchor=(0.5, 1.08),
                   ncol=len(legend_labels), fontsize=30)

    # Shared x-axis label, placed below the row of subplots
    fig.text(0.5, -0.04, 'Frames (millions)', ha='center', va='bottom',
             fontsize=35, fontweight='medium')

    # Adjust layout to make room for the legend above the panels
    plt.tight_layout()
    plt.subplots_adjust(top=0.85, wspace=0.3)

    # Save the plot
    plt.savefig(os.path.join(save_dir, 'four_environment_exploration_comparison.pdf'),
                dpi=1000, bbox_inches='tight')
    plt.show()
    plt.close()

def print_exploration_summary_statistics(experiment_dicts, env_titles):
    """
    Print a table of each method's final smoothed reward per environment.

    Rows are methods (sorted by name, the 'random' baseline excluded in any
    letter case); columns follow ``experiment_dicts``. Each cell shows the last
    point of the aggregated curve as ``mean ± spread``, ``N/A`` when the method
    is absent from that environment or has no data points, and ``Error`` when
    loading or aggregating its logs fails.

    Args:
        experiment_dicts: List of dictionaries mapping method names to log
            directories, one per environment.
        env_titles: Column titles, in the same order as ``experiment_dicts``.
    """
    method_width = 15
    # Header and row cells share one width so the columns line up; the width
    # grows when a title is longer than the 20-character default.
    column_width = max([20] + [len(str(title)) for title in env_titles])

    print("\n" + "="*100)
    print("SUMMARY STATISTICS - FINAL EXPLORATION PERFORMANCE ACROSS ENVIRONMENTS")
    print("="*100)

    # All unique methods; the random baseline is not a row of its own.
    all_methods = {
        method_name
        for exp_dict in experiment_dicts
        for method_name in exp_dict
        if method_name.lower() != 'random'
    }

    header = f"{'Method':<{method_width}}" + "".join(
        f" {title:<{column_width}}" for title in env_titles
    )
    print(header)
    print("-" * len(header))

    for method_name in sorted(all_methods):
        cells = []

        for exp_dict in experiment_dicts:
            if method_name not in exp_dict:
                cells.append("N/A")
                continue

            try:
                all_data = read_monitor_files(exp_dict[method_name])
                _, mean_rewards, std_rewards = aggregate_data(all_data)
            except Exception:
                # Best effort: a broken run is shown as "Error" and the table
                # continues. Interrupts (Ctrl-C) are deliberately not caught.
                cells.append("Error")
                continue

            if len(mean_rewards) > 0:
                cells.append(f"{mean_rewards[-1]:.1f} ± {std_rewards[-1]:.1f}")
            else:
                cells.append("N/A")

        print(f"{method_name:<{method_width}}" + "".join(
            f" {cell:<{column_width}}" for cell in cells
        ))

    print("="*100)


def main():
    """
    Build the four-environment comparison figure and print the summary table.

    The experiment dictionaries below map each method name to the log
    directory of one concrete run; edit the paths to point at your own logs.
    The key 'random' marks the run drawn as the dashed baseline.
    """
    # Single source of truth for the output directory: used for the plot call
    # and for the closing message.
    save_dir = "comparison_plots"

    experiment_dict_env1 = {
        'LP': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/improve_20250605_210532',
        'random': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/random_20250605_210532',
        'AMA': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/ama_20250605_210532',
        'MSE': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/mse_20250605_234534',
        'RND': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/rnd_w_no_noise',
        'IDF': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/icm_w_no_noise_w_clip_10eta',
        "Ensemble": '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ensemble_w_no_noise',
        'EDT': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/atari_spaceinvader_ttd_wo_noise',
        'EME': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/space_invader_eme_wo_noise'
    }

    experiment_dict_env2 = {
        # 'pure': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ppo_pure_05-29',
        'LP': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ppo_improve_w_seed_06_03_noisy',
        'random': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/random_noisy_06-03',
        # 'ama': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ama_05-29',
        # 'ama-with-clip': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ama_with_clip_06_03_w_seed',
        'AMA': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ama_with_clip_06_03_w_seed_noisy',
        'MSE': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/mse_noisy_20250606_014834',
        'RND': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/rnd_w_noise_w_clip',
        # NOTE(review): this is the noisy environment but the directory name
        # says "no_noise" -- confirm this is the intended run.
        'IDF': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/idf_w_no_noise_w_clip_10eta',
        "Ensemble": '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ensemble_w_noise',
        "EDT": '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/New_spaceinvader/atari_spaceinvader_ttd_w_noise',
        'EME': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/space_invader_eme_w_noise'
    }

    experiment_dict_env3 = {
        'LP': '/content/drive/MyDrive/MercedLab/Optimization_explore/log_v2/deterministic/MsPacman/ppo-improvement_20250630_010602',
        'random': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v2/random_20250610_223302',
        # 'ama': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ama_05-29',
        # 'ama-with-clip': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/SpaceInvaders/ama_with_clip_06_03_w_seed',
        'AMA': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v2/ama_v2_20250610_223302',
        'MSE': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v2/mse_20250610_223302',
        # 'unet_larger_clip': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3/unet_20250611_190433',
        'IDF': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3/IDF_w_no_noise',
        'Ensemble': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3/Ensemble_w_no_noise',
        'RND': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3/RND_w_no_noise',
        'EDT': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/MsPacman/pacman_tdd_wo_noise',
        'EME': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/MsPacman/pacman_eme_wo_noise'
    }

    experiment_dict_env4 = {
        'random': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v2_noisy/random_20250611_021817',
        'AMA': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v2_noisy/ama_v2_20250611_021817',
        'MSE': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v2_noisy/mse_20250611_021817',
        'LP': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v4_noise/improve_20250701_194016',
        # 'ppo': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3_noisy/ppo_20250613_190232',
        'RND': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3/RND_w_noise',
        'Ensemble': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v3/ensemble_w_noise',
        'IDF': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/pacman_v4/idf_w_noise',
        'EDT': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/MsPacman/pacman_tdd_w_noise',
        # NOTE(review): same "wo_noise" directory as the noise-free Ms PacMan
        # entry above, although this is the noisy environment -- looks like a
        # copy-paste slip; confirm the intended run.
        'EME': '/content/drive/MyDrive/MercedLab/Optimization_explore/logs/MsPacman/pacman_eme_wo_noise'
    }

    # Environment titles, in the same order as the dictionaries below
    env_titles = ["Space Invader", "Space Invader w Noise", "Ms PacMan", "Ms PacMan w Noise"]

    # List of experiment dictionaries
    experiment_dicts = [experiment_dict_env1, experiment_dict_env2, experiment_dict_env3, experiment_dict_env4]

    print("Creating four-environment exploration comparison plots...")
    print("="*80)

    # Create comparison plots
    plot_four_environment_comparison(experiment_dicts, env_titles, save_dir=save_dir)

    # Print summary statistics
    print_exploration_summary_statistics(experiment_dicts, env_titles)

    print(f"\nComparison plots saved in '{save_dir}' directory")

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()