In [None]:
#!/usr/bin/env python3
"""
Evaluation script for pre-trained GFlowNet models.

This script allows loading trained GFlowNet models to generate trajectories,
analyze rewards, and create visualizations for research papers/theses.
"""

import argparse
import copy
import json
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from scipy.stats import entropy
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utilities.REINFORCE_Support import *
from utilities.SMCMC_Support import *
from utilities.testing_utils import *
from utilities.BayesOpt_Support import * # Import project modules
from src.environments import GeneralEnvironment
from src.mlflow_logger import MLflowLogger
from src.utility_functions import (
    calculate_cosine_diversity, get_policy_dist, load_initial_state,
    load_entire_model, plot_distributions_per_feature,
    setup_logger, simulate_trajectories,seed_all
)

import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import entropy
import numpy as np
import pandas as pd
import warnings
from itertools import combinations
import joblib
try:
    from fastdtw import fastdtw
except ImportError:
    warnings.warn("fastdtw not installed; falling back to naive DTW (slower).")
    def fastdtw(x, y, dist=None):
        # naive DTW for 1-D arrays
        T1, T2 = len(x), len(y)
        dtw_mat = np.full((T1+1, T2+1), np.inf)
        dtw_mat[0, 0] = 0.0
        for i in range(1, T1+1):
            for j in range(1, T2+1):
                cost = abs(x[i-1] - y[j-1])
                dtw_mat[i, j] = cost + min(
                    dtw_mat[i-1, j], dtw_mat[i, j-1], dtw_mat[i-1, j-1]
                )
        return dtw_mat[T1, T2], None
from scipy.spatial.distance import euclidean
from sklearn.cluster import AgglomerativeClustering
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

def average_reward(rewards: np.ndarray) -> float:
    """
    Return the arithmetic mean of ``rewards``.

    Args:
        rewards: Reward values (any array-like accepted by ``np.mean``).

    Returns:
        The mean reward as computed by ``np.mean``.
    """
    mean_value = np.mean(rewards)
    return mean_value

def tail_coverage(predictions: np.ndarray, tau: float) -> float:
    """
    Fraction (in [0, 1]) of trajectories whose predicted return is <= ``tau``.

    Multiply by 100 to obtain a percentage.

    Args:
        predictions: Predicted returns; any array-like (lists included) is accepted.
        tau: Threshold defining the lower tail.

    Returns:
        The fraction of predictions at or below ``tau``; nan for empty input.
    """
    # np.asarray lets plain Python lists work (list <= float raises TypeError).
    in_tail = np.asarray(predictions) <= tau
    if in_tail.size == 0:
        return float("nan")
    return np.mean(in_tail)

def expected_shortfall(predictions: np.ndarray, q: float) -> float:
    """
    Expected Shortfall (ES): mean of the predictions at or below their q-th percentile.

    Args:
        predictions: Predicted returns; any array-like is accepted (it is flattened
            by the boolean mask).
        q: Percentile in percent (0-100) passed to ``np.percentile``.

    Returns:
        Mean of the worst ``q`` percent of predictions, or nan when there are none
        (including empty input).
    """
    values = np.asarray(predictions, dtype=float)
    if values.size == 0:
        # np.percentile raises on empty input; report "no tail" instead.
        return np.nan

    cutoff = np.percentile(values, q)
    worst = values[values <= cutoff]

    # Diagnostics go to the logger instead of stdout so that metric computation
    # does not dump whole arrays during evaluation runs.
    logger = logging.getLogger(__name__)
    logger.debug("Cutoff for ES: %s", cutoff)
    logger.debug("Worst predictions for ES: %s", worst)

    return worst.mean() if worst.size > 0 else np.nan

def calculate_euclidean_diversity(states: np.ndarray, rewards: Optional[np.ndarray] = None) -> tuple:
    """
    Average and normalized diversity based on pairwise Euclidean distances.

    Normalization is by the maximum observed pairwise distance.

    Args:
        states: Array of shape (N, D), one row per state.
        rewards: Ignored. Accepted so that call sites passing per-state rewards
            alongside the states (as ``analyze_cluster_diversity`` does) work
            instead of raising ``TypeError``.

    Returns:
        ``(avg_div, norm_div, distances)`` where ``distances`` is the condensed
        ``pdist`` vector. ``avg_div`` and ``norm_div`` are nan when fewer than two
        states are given. ``norm_div`` is nan when all states coincide.
    """
    distances = pdist(np.asarray(states), metric='euclidean')
    if distances.size == 0:
        # Fewer than two states: no pairs, and .max() on an empty array would raise.
        return np.nan, np.nan, distances

    avg_div = distances.mean()
    max_div = distances.max()
    norm_div = avg_div / max_div if max_div > 0 else np.nan
    return avg_div, norm_div, distances

def coverage_epsilon(states: np.ndarray, eps: float) -> int:
    """
    Count the centers picked by a greedy epsilon-cover.

    Points are visited in order. Every point not yet covered becomes a new
    center, and all points strictly closer than ``eps`` to it are marked covered.

    Args:
        states: Array of shape (N, D).
        eps: Coverage radius (strict inequality).

    Returns:
        Number of centers chosen.
    """
    n_points = states.shape[0]
    covered = np.zeros(n_points, dtype=bool)
    n_centers = 0
    for idx in range(n_points):
        if covered[idx]:
            continue
        n_centers += 1
        gaps = np.linalg.norm(states - states[idx], axis=1)
        covered |= gaps < eps
    return n_centers

def calculate_quality_diversity(states: np.ndarray, rewards: np.ndarray) -> dict:
    """
    Compute pairwise-distance diversity weighted by how consistent the rewards are.

    The reward stability factor is
    ``1 - |max_r - min_r| / (|max_r| + |min_r| + 1e-8)``. It is 1 when all
    rewards are equal and approaches 0 when the rewards span a range as large
    as their magnitude.

    Args:
        states: Array of shape (N, D) representing final states or trajectory summaries.
        rewards: Array-like of shape (N,) with the corresponding rewards.

    Returns:
        A dict with:
        - "average_diversity": mean pairwise Euclidean distance (float; nan if
          fewer than two states)
        - "normalized_score": average_diversity * reward stability factor
        - "distances": condensed vector of pairwise distances from ``pdist``
    """
    distances = pdist(states, metric='euclidean')
    avg_div = float(distances.mean()) if distances.size else float('nan')

    # Rewards are only read (max/min), so no defensive copy is needed.
    max_r = np.max(rewards)
    min_r = np.min(rewards)
    reward_range = np.abs(max_r - min_r)
    reward_scale = np.abs(max_r) + np.abs(min_r) + 1e-8  # robust to sign and small values

    reward_stability = 1 - (reward_range / reward_scale)
    normalized_score = avg_div * reward_stability

    return {
        "average_diversity": avg_div,
        "normalized_score": normalized_score,
        "distances": distances
    }

def cluster_trajectories(
    dist_matrix: np.ndarray,
    n_clusters: int = 3,
    linkage: str = "average"
) -> np.ndarray:
    """
    Cluster trajectories with Agglomerative Clustering on a precomputed distance matrix.

    The matrix is passed to scikit-learn as ``metric="precomputed"``. Each entry
    is treated as a distance between two trajectories, not each row as a
    feature vector.

    Args:
        dist_matrix: Square (N, N) symmetric matrix of pairwise distances.
        n_clusters: Number of clusters to form.
        linkage: Linkage criterion ("average", "complete" or "single").
            "ward" is not supported with precomputed distances.

    Returns:
        Integer cluster label per trajectory, shape (N,).
    """
    dist_matrix = np.asarray(dist_matrix, dtype=float)
    try:
        model = AgglomerativeClustering(
            n_clusters=n_clusters, metric="precomputed", linkage=linkage
        )
    except TypeError:
        # scikit-learn < 1.2 names this parameter ``affinity`` instead of ``metric``.
        model = AgglomerativeClustering(
            n_clusters=n_clusters, affinity="precomputed", linkage=linkage
        )
    return model.fit_predict(dist_matrix)

def euclidean_distance(x, y):
    """
    Return the Euclidean (L2) distance between ``x`` and ``y``.

    Works for numeric scalars (e.g. single points of 1-D series) and for
    equally shaped NumPy arrays.
    """
    delta = x - y
    return np.sqrt(np.sum(np.square(delta)))

def compute_dtw_distance_matrix(trajectories) -> Tuple[np.ndarray, Dict[int, np.ndarray]]:
    """
    Compute pairwise DTW distances between (possibly multivariate) trajectories.

    Each feature is aligned independently with ``fastdtw`` (using
    ``euclidean_distance`` as the point cost). The aggregate distance of a pair
    is the mean of its per-feature DTW distances.

    Args:
        trajectories: Sequence of N trajectories (list of arrays or an
            (N, T, D) array). Each is 2-D with shape (T_i, D). The lengths T_i
            may differ between trajectories. A 1-D trajectory is treated as a
            single feature.

    Returns:
        ``(dist_matrix, feature_dist_matrices)``: the (N, N) aggregate distance
        matrix, and a dict mapping feature index d -> (N, N) DTW distance
        matrix for that feature. Both are symmetric with a zero diagonal. Empty
        input yields a (0, 0) matrix and an empty dict.

    Raises:
        ValueError: If trajectories have different numbers of features.
    """
    # Keep each trajectory as its own 2-D array. Stacking into one 3-D array
    # would reject trajectories of different lengths, which DTW is meant to compare.
    series = []
    for traj in trajectories:
        arr = np.asarray(traj, dtype=float)
        if arr.ndim == 1:
            arr = arr[:, None]
        series.append(arr)

    N = len(series)
    if N == 0:
        return np.zeros((0, 0), dtype=float), {}

    D = series[0].shape[1]  # number of features
    if any(s.shape[1] != D for s in series):
        raise ValueError("All trajectories must have the same number of features")

    dist_matrix = np.zeros((N, N), dtype=float)
    feature_dist_matrices = {d: np.zeros((N, N), dtype=float) for d in range(D)}

    total_pairs = N * (N - 1) // 2
    for i, j in tqdm(combinations(range(N), 2), total=total_pairs, desc="Computing DTW distances"):
        dists = []
        for d in range(D):
            # 1-D time series of feature d for both trajectories
            dist, _ = fastdtw(series[i][:, d], series[j][:, d], dist=euclidean_distance)
            feature_dist_matrices[d][i, j] = dist
            feature_dist_matrices[d][j, i] = dist
            dists.append(dist)
        agg_dist = np.mean(dists)
        dist_matrix[i, j] = agg_dist
        dist_matrix[j, i] = agg_dist

    return dist_matrix, feature_dist_matrices

def dtw_clustering_analysis(trajectories, n_clusters=3):
    """
    Cluster trajectories by DTW distance, both on the aggregate distance and per feature.

    Args:
        trajectories: List or array of trajectories. A list is first converted
            with ``np.array`` (each element via ``np.array``).
        n_clusters: Number of clusters requested from ``cluster_trajectories``.

    Returns:
        ``(df, dist_matrix, feature_labels)``:
        - df: DataFrame with columns ``trajectory`` (row index), ``agg_cluster``
          and one ``feature_{d}_cluster`` column per feature d
        - dist_matrix: aggregate DTW distance matrix
        - feature_labels: dict mapping feature index d -> cluster labels

    Raises:
        ValueError: If there are fewer trajectories than requested clusters.
    """
    if isinstance(trajectories, list):
        trajectories = np.array([np.array(t) for t in trajectories])

    n_traj = len(trajectories)
    if n_traj < n_clusters:
        raise ValueError(f"Number of trajectories ({n_traj}) must be >= number of clusters ({n_clusters})")

    dist_matrix, feature_dists = compute_dtw_distance_matrix(trajectories)

    columns = {
        "trajectory": np.arange(n_traj),
        "agg_cluster": cluster_trajectories(dist_matrix, n_clusters=n_clusters),
    }
    feature_labels = {}
    for feat, fmat in feature_dists.items():
        feature_labels[feat] = cluster_trajectories(fmat, n_clusters=n_clusters)
        columns[f"feature_{feat}_cluster"] = feature_labels[feat]

    return pd.DataFrame(columns), dist_matrix, feature_labels

def plot_trajectories_by_cluster(results_df, trajectories, feature_names, num_clusters=5, save_path=None):
    """
    Plot trajectories for each feature, colored by feature-specific clusters.

    Column 0 of every trajectory is skipped (treated as the timestep). Plot
    panel k shows trajectory column k + 1, coloured by
    ``results_df[f"feature_{k}_cluster"]``. The per-cluster mean is drawn as a
    thick line.

    Args:
        results_df (pd.DataFrame): Must contain a ``trajectory`` column of row
            indices into ``trajectories`` and a ``feature_{k}_cluster`` column
            for every plotted feature k.
        trajectories (np.ndarray): Array of shape (n_trajectories, timesteps, 1 + n_features).
        feature_names (list): Names indexed like the trajectory columns (index 0 is skipped).
        num_clusters (int): Cluster labels 0..num_clusters-1 are drawn; other labels are not plotted.
        save_path (str): Optional output path. When omitted, the figure is written to
            ``trajectory_clusters_by_feature.pdf`` in the working directory.
    """
    # NOTE(review): cluster column k is paired with trajectory column k + 1. If
    # results_df came from clustering trajectories that still include the
    # timestep column, feature_k_cluster describes column k instead -- confirm.
    trajectories = np.asarray(trajectories)
    n_features = trajectories.shape[2] - 1  # excluding timestep column
    n_clusters = num_clusters

    colors = plt.cm.viridis(np.linspace(0, 1, n_clusters))

    # Three panels per row. Keep the original 2x3 layout for up to six features
    # and add rows as needed instead of running out of axes.
    n_cols = 3
    n_rows = max(2, int(np.ceil(n_features / n_cols)))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(40, 12 * n_rows))
    axs = np.asarray(axs).flatten()

    time_points = np.arange(trajectories.shape[1])

    for feature_idx in range(n_features):
        ax = axs[feature_idx]
        cluster_column = f'feature_{feature_idx}_cluster'
        values = trajectories[:, :, feature_idx + 1]  # +1 to skip timestep

        for cluster in range(n_clusters):
            cluster_indices = results_df[results_df[cluster_column] == cluster]['trajectory'].values

            for idx in cluster_indices:
                ax.plot(time_points,
                        values[idx],
                        color=colors[cluster],
                        alpha=0.3,
                        linewidth=1)

            if len(cluster_indices) > 0:
                cluster_mean = values[cluster_indices].mean(axis=0)
                ax.plot(time_points,
                        cluster_mean,
                        color=colors[cluster],
                        linewidth=3,
                        label=f'Cluster {cluster} (n={len(cluster_indices)})')

        ax.set_title(f'{feature_names[feature_idx+1]}', fontsize=12, pad=10)
        ax.set_xlabel('Time Step', fontsize=10)
        ax.set_ylabel('Value', fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=24)

    # Remove unused panels
    for idx in range(n_features, len(axs)):
        fig.delaxes(axs[idx])

    plt.suptitle('Trajectory Clusters by Feature', fontsize=16, y=1.02)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    else:
        plt.savefig('trajectory_clusters_by_feature.pdf', bbox_inches='tight', dpi=300)
    plt.show()
    
def plot_dtw_distance_matrix(distance_matrix, model_name, rewards, save_path=None, vmax=200):
    """
    Plot a DTW distance matrix as a heatmap annotated with diversity metrics.

    The average distance is taken over distinct trajectory pairs (strict upper
    triangle), so the zero self-distances on the diagonal do not bias it
    downwards. This matches the pdist-based average used in
    ``calculate_quality_diversity``.

    The reward stability factor is
    ``1 - |max_r - min_r| / (|max_r| + |min_r| + 1e-8)``.

    Args:
        distance_matrix: Square (N, N) DTW distance matrix.
        model_name: Name of the model (e.g., 'GFlowNet', 'REINFORCE'), used in the title.
        rewards: Array-like of rewards used for the stability factor.
        save_path: Optional path to save the figure.
        vmax: Upper limit of the colour scale (default 200).

    Returns:
        dict with 'average_distance', 'normalized_score' (average distance times
        reward stability) and 'reward_normalization' (the reward stability factor).
    """
    distance_matrix = np.asarray(distance_matrix, dtype=float)
    reward_values = np.asarray(rewards, dtype=float)

    max_r = np.max(reward_values)
    min_r = np.min(reward_values)
    reward_range = np.abs(max_r - min_r)
    reward_scale = np.abs(max_r) + np.abs(min_r) + 1e-8  # robust to sign and small values
    reward_stability = 1 - (reward_range / reward_scale)

    n = distance_matrix.shape[0]
    if n >= 2:
        average_diversity = float(distance_matrix[np.triu_indices(n, k=1)].mean())
    else:
        average_diversity = float('nan')
    normalized_score = average_diversity * reward_stability

    plt.figure(figsize=(12, 10), dpi=300)

    sns.heatmap(distance_matrix,
                cmap='viridis',
                square=True,
                vmin=0, vmax=vmax,
                cbar_kws={
                    'label': 'DTW Distance',
                    'orientation': 'vertical',
                })

    plt.title(f'DTW Distance Matrix - {model_name}\n' +
              f'Avg Distance: {average_diversity:.2f}\n' +
              f'Normalized Quality Score: {normalized_score:.4f}',
              fontsize=16,
              pad=20)

    plt.xlabel('Trajectory Index', fontsize=12)
    plt.ylabel('Trajectory Index', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")

    plt.show()

    return {
        'average_distance': average_diversity,
        'normalized_score': normalized_score,
        'reward_normalization': reward_stability
    }

def find_min_max_distance_pairs(distance_matrix):
    """
    Find pairs of distinct trajectories with minimum and maximum distances.

    Only the strict upper triangle is searched, so the diagonal (self-distances)
    can never be returned, even when two trajectories are at distance 0. Ties
    resolve to the first pair in row-major order.

    Args:
        distance_matrix (array-like): Square matrix of pairwise distances.

    Returns:
        tuple: ((i_min, j_min, min_dist), (i_max, j_max, max_dist)) with i < j.

    Raises:
        ValueError: If the matrix has fewer than two rows (no pairs exist).
    """
    distance_matrix = np.asarray(distance_matrix)
    n = distance_matrix.shape[0]
    if n < 2:
        raise ValueError("distance_matrix must contain at least two trajectories")

    # Upper-triangle coordinates (row-major) guarantee i < j and exclude the diagonal.
    rows, cols = np.triu_indices(n, k=1)
    pair_dists = distance_matrix[rows, cols]

    k_min = int(np.argmin(pair_dists))
    k_max = int(np.argmax(pair_dists))

    return ((rows[k_min], cols[k_min], pair_dists[k_min]),
            (rows[k_max], cols[k_max], pair_dists[k_max]))

def plot_trajectories_over_time(trajectories, rewards, feature_names, n_top=50, alpha_others=0.05, save_path=None):
    """
    Plot trajectories over time, one figure per feature, highlighting top performers.

    The ``n_top`` trajectories with the highest reward are drawn dashed in
    distinct colours; all others are drawn in faint blue. The per-timestep mean
    (black) and median (red) across all trajectories are overlaid.

    Args:
        trajectories (np.ndarray): Array of shape (n_trajectories, timesteps, features).
        rewards (array-like): One reward per trajectory, shape (n_trajectories,).
        feature_names (list): Name for each feature, indexed like the last axis of ``trajectories``.
        n_top (int): Number of top trajectories to highlight; values <= 0 highlight none.
        alpha_others (float): Alpha value for non-top trajectories.
        save_path (str): Optional directory; each figure is saved there as
            ``feature_<name>.pdf``.
    """
    trajectories = np.asarray(trajectories)
    # np.asarray makes argsort return positions rather than pandas labels.
    final_rewards = np.asarray(rewards)

    n_highlight = max(0, min(int(n_top), final_rewards.shape[0]))
    # Highest reward first. The explicit clamp avoids the argsort(...)[-0:]
    # pitfall, which would select every trajectory when n_top == 0.
    top_indices = np.argsort(final_rewards)[::-1][:n_highlight]
    top_colors = sns.color_palette("husl", n_top) if n_highlight else []
    # O(1) lookup of a trajectory's highlight rank (None = not highlighted).
    highlight_rank = {int(idx): rank for rank, idx in enumerate(top_indices)}

    n_features = trajectories.shape[2]

    for feature_idx in range(n_features):
        plt.figure(figsize=(10, 5))

        for traj_idx, trajectory in enumerate(trajectories):
            rank = highlight_rank.get(traj_idx)
            if rank is None:
                plt.plot(trajectory[:, feature_idx], alpha=alpha_others, color='blue')
            else:
                plt.plot(trajectory[:, feature_idx], alpha=1,
                         color=top_colors[rank], linewidth=1, linestyle='dashed')

        mean_feature = np.mean(trajectories[:, :, feature_idx], axis=0)
        median_feature = np.median(trajectories[:, :, feature_idx], axis=0)
        plt.plot(mean_feature, color='black', linewidth=2, label='Mean')
        plt.plot(median_feature, color='red', linewidth=2, label='Median')

        plt.legend(prop={'size': 10}, loc='upper left')
        plt.xlabel("Time Step", fontsize=10)
        plt.ylabel(f"{feature_names[feature_idx]} Value", fontsize=10)
        plt.title(f"Feature: {feature_names[feature_idx]}", fontsize=10)
        plt.grid(True)
        plt.tight_layout()
        if save_path:
            plt.savefig(os.path.join(save_path, f"feature_{feature_names[feature_idx]}.pdf"), dpi=300)
        plt.show()

def compare_trajectories(traj_idx1, traj_idx2, trajectories, feature_names, rewards, save_path=None):
    """
    Compare two trajectories by plotting each feature against time.

    Column 0 of each trajectory is used as its time axis. Every remaining
    column k gets its own subplot, labelled with ``feature_names[k + 1]``.

    Args:
        traj_idx1 (int): Index of first trajectory
        traj_idx2 (int): Index of second trajectory
        trajectories (np.ndarray): Array of trajectories, each (timesteps, 1 + n_features)
        feature_names (list): Feature names indexed like the trajectory columns
        rewards (np.ndarray): Reward per trajectory, shown in the title
        save_path (str): Optional path to save the plot

    Raises:
        ValueError: If the trajectories have no feature column besides the timestamp.
    """
    traj1 = np.asarray(trajectories[traj_idx1])
    traj2 = np.asarray(trajectories[traj_idx2])

    n_features = traj1.shape[1] - 1  # excluding timestamp column
    if n_features < 1:
        raise ValueError("Trajectories need at least one feature column besides the timestamp")

    # squeeze=False keeps a 2-D axes array even for a single feature, so axs[i] always works.
    fig, axs = plt.subplots(n_features, 1, figsize=(15, 4 * n_features), squeeze=False)
    axs = axs[:, 0]

    # Each trajectory is plotted against its own timestamps (they may differ in length).
    t1 = traj1[:, 0]
    t2 = traj2[:, 0]

    for i in range(n_features):
        ax = axs[i]
        ax.plot(t1, traj1[:, i+1], 'b-', label=f'Traj {traj_idx1}', linewidth=2)
        ax.plot(t2, traj2[:, i+1], 'r--', label=f'Traj {traj_idx2}', linewidth=2)

        # Mark start and end points
        ax.plot(t1[0], traj1[0, i+1], 'bo', label='Start 1')
        ax.plot(t1[-1], traj1[-1, i+1], 'b*', label='End 1', markersize=10)
        ax.plot(t2[0], traj2[0, i+1], 'ro', label='Start 2')
        ax.plot(t2[-1], traj2[-1, i+1], 'r*', label='End 2', markersize=10)

        ax.set_xlabel('Time')
        ax.set_ylabel(feature_names[i+1])
        ax.grid(True, alpha=0.3)

        # Only the first subplot gets a legend
        if i == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.suptitle(f'Comparison of Trajectories {traj_idx1} vs {traj_idx2}\n' +
                 f'Final Rewards: {rewards[traj_idx1]:.2f} vs {rewards[traj_idx2]:.2f}',
                 fontsize=16, y=1.02)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()

def plot_action_distributions_by_timestep(trajectories, all_actions, feature_names, save_path=None):
    """
    Plot the distribution of sampled actions for each feature at each timestep.

    One figure is produced per feature. It contains one histogram + KDE panel
    per timestep.

    Args:
        trajectories (np.ndarray): Array of shape (n_trajectories, timesteps, 1 + n_features).
            Only its shape is used: the number of action steps is ``timesteps - 1``.
        all_actions (np.ndarray): Array of shape (n_trajectories, timesteps - 1, features).
        feature_names (list): Feature names indexed like the trajectory columns
            (action feature k is titled with ``feature_names[k + 1]``).
        save_path (str, optional): Directory to save the plots in. If None, plots are not saved.
    """
    trajectories = np.asarray(trajectories)
    all_actions = np.asarray(all_actions)

    n_features = trajectories.shape[2] - 1
    n_timesteps = trajectories.shape[1] - 1

    # Four panels per row. Keep the original 3x4 grid for up to 12 timesteps and
    # grow it for longer horizons instead of running out of axes.
    n_cols = 4
    n_rows = max(3, int(np.ceil(n_timesteps / n_cols)))

    # NOTE(review): this iterates n_features - 1 features, so the last state
    # feature gets no figure -- presumably all_actions has one fewer
    # feature; confirm.
    for feature_idx in range(n_features - 1):
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))
        axs = np.asarray(axs).flatten()

        for timestep in range(n_timesteps):
            ax = axs[timestep]
            actions_per_timestep = all_actions[:, timestep, feature_idx]

            sns.histplot(actions_per_timestep, kde=True, ax=ax, stat="density", bins=50)

            ax.set_title(f"Timestep {timestep + 1}", fontsize=12)
            ax.set_xlabel(f"Action Value", fontsize=10)
            ax.set_ylabel(f"Density", fontsize=10)

        # Drop unused panels
        for unused_ax in axs[n_timesteps:]:
            fig.delaxes(unused_ax)

        fig.suptitle(f"Distribution of Sampled Actions - Feature {feature_names[feature_idx+1]}", fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave room for the super title

        if save_path:
            out_file = os.path.join(save_path, f"distribution_feature_{feature_names[feature_idx+1]}.pdf")
            plt.savefig(out_file, dpi=300, bbox_inches='tight')

        plt.show()
def plot_action_distributions(all_actions, feature_names, save_path=None):
    """
    Plot, for each action dimension, the distribution of actions pooled over all
    trajectories and timesteps, with mean and median markers.

    Args:
        all_actions (np.ndarray): Array of shape (n_trajectories, timesteps, n_features).
        feature_names (list): Names indexed like the state columns. Action
            dimension k is titled with ``feature_names[k + 1]`` (index 0 is
            skipped -- presumably the timestep; confirm).
        save_path (str, optional): Path to save the plot. If None, won't save.
    """
    all_actions = np.asarray(all_actions)
    n_features = all_actions.shape[2]

    # Three panels per row. Keep the original 2x3 grid for up to six features and
    # add rows as needed instead of running out of axes.
    n_cols = 3
    n_rows = max(2, int(np.ceil(n_features / n_cols)))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 6 * n_rows))
    axs = np.asarray(axs).flatten()

    for feature_idx in range(n_features):
        ax = axs[feature_idx]
        feature_actions = all_actions[:, :, feature_idx].flatten()

        sns.histplot(feature_actions, kde=True, ax=ax, stat="density", bins=50)

        mean_val = np.mean(feature_actions)
        median_val = np.median(feature_actions)
        ax.axvline(mean_val, color='red', linestyle='--', label=f'Mean: {mean_val:.2f}')
        ax.axvline(median_val, color='green', linestyle='--', label=f'Median: {median_val:.2f}')

        ax.set_title(f"Distribution of Actions - {feature_names[feature_idx+1]}", fontsize=12)
        ax.set_xlabel("Action Value")
        ax.set_ylabel("Density")
        ax.legend()

    # Drop unused panels
    for unused_ax in axs[n_features:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
    
def plot_reward_distribution(last_step_rewards, save_path=None):
    """
    Draw a histogram + KDE of final rewards with mean/median markers and a summary box.

    Args:
        last_step_rewards (array-like): Final reward of each trajectory.
        save_path (str, optional): If given, the figure is also written there at 300 dpi.
    """
    # Summary statistics, computed once and reused for lines, labels and text box.
    mean_r = np.mean(last_step_rewards)
    median_r = np.median(last_step_rewards)
    std_r = np.std(last_step_rewards)
    max_r = np.max(last_step_rewards)
    min_r = np.min(last_step_rewards)

    plt.figure(figsize=(12, 6))
    sns.histplot(data=last_step_rewards, bins=50, kde=True, stat='density')

    for value, colour, prefix in ((mean_r, 'red', 'Mean'), (median_r, 'green', 'Median')):
        plt.axvline(value, color=colour, linestyle='dashed',
                    linewidth=2, label=f'{prefix}: {value:.2f}')

    plt.xlabel('Final Reward Values', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.title('Distribution of Final Rewards', fontsize=14, pad=20)

    stats_text = "\n".join([
        'Statistics:',
        f'Mean: {mean_r:.2f}',
        f'Median: {median_r:.2f}',
        f'Std: {std_r:.2f}',
        f'Max: {max_r:.2f}',
        f'Min: {min_r:.2f}',
    ])
    plt.text(0.9, 0.75, stats_text, transform=plt.gca().transAxes,
             verticalalignment='center', horizontalalignment='left',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()
from scipy.spatial.distance import squareform
def plot_diversity_metrics(df, feature_names, save_path=None, vmax=200):
    """
    Plot the pairwise Euclidean distance matrix of trajectory states with diversity scores.

    Args:
        df (pd.DataFrame): One row per trajectory. Must contain the
            ``feature_names`` columns and a ``FinalReward`` column.
        feature_names (list): Columns of ``df`` forming the state vector used for distances.
        save_path (str, optional): Path to save the plot.
        vmax (float): Upper limit of the colour scale (default 200).

    Returns:
        tuple: (avg_diversity, quality_score, distances), where
        - avg_diversity is the mean pairwise Euclidean distance,
        - quality_score is the reward-stability-weighted score that
          ``calculate_quality_diversity`` returns as ``"normalized_score"``
          (not a distance normalized to [0, 1]),
        - distances is the condensed pairwise distance vector.
    """
    final_rewards = df['FinalReward']
    res_dict = calculate_quality_diversity(df[feature_names].values, final_rewards)

    avg_euclidean_diversity = res_dict['average_diversity']
    norm_diversity = res_dict['normalized_score']
    distances = res_dict['distances']

    plt.figure(figsize=(12, 10))
    euclidean_distances_matrix = squareform(distances)
    sns.heatmap(euclidean_distances_matrix,
                cmap='viridis',
                vmin=0,
                vmax=vmax,
                cbar_kws={
                    'label': 'Euclidean Distance',
                    'orientation': 'vertical'
                })

    plt.title(f"Euclidean Distance Matrix for Trajectories\n" +
              f"Diversity Score: {avg_euclidean_diversity:.2f}, Normalized: {norm_diversity:.4f}",
              pad=20)

    plt.xlabel("Trajectory Index", fontsize=12)
    plt.ylabel("Trajectory Index", fontsize=12)
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    return avg_euclidean_diversity, norm_diversity, distances

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from tqdm import tqdm

def perform_clustering_analysis(final_states_df, feature_names, n_clusters_range=(2, 10)):
    """
    Choose a K-means cluster count by silhouette score and embed the data with t-SNE.

    Features are standardized first. Every k in ``range(*n_clusters_range)``
    for which the silhouette score is defined (2 <= k <= n_samples - 1) is
    tried. The k with the highest score is used for the final clustering.

    Args:
        final_states_df: 2-D array-like (e.g. DataFrame) of final states, one row per trajectory.
        feature_names (list): Not used by this function; kept for API compatibility.
        n_clusters_range (tuple): Arguments for ``range`` giving candidate k values (upper bound exclusive).

    Returns:
        tuple: (optimal_clusters, cluster_labels, tsne_results)

    Raises:
        ValueError: If no candidate k is valid for the number of samples.
    """
    scaler = StandardScaler()
    scaled_features = scaler.fit_transform(final_states_df)
    n_samples = scaled_features.shape[0]

    # silhouette_score raises unless 2 <= n_labels <= n_samples - 1.
    candidates = [k for k in range(*n_clusters_range) if 2 <= k < n_samples]
    if not candidates:
        raise ValueError(
            f"No valid cluster count in range{tuple(n_clusters_range)} for {n_samples} samples; "
            "silhouette scoring needs 2 <= k <= n_samples - 1."
        )

    silhouette_scores = []
    for k in tqdm(candidates, desc="Finding optimal clusters"):
        kmeans = KMeans(n_clusters=k, random_state=42)
        labels_k = kmeans.fit_predict(scaled_features)
        silhouette_scores.append(silhouette_score(scaled_features, labels_k))

    optimal_clusters = candidates[int(np.argmax(silhouette_scores))]

    kmeans = KMeans(n_clusters=optimal_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(scaled_features)

    # t-SNE requires perplexity < n_samples. Keep the default of 30 whenever the data allow it.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_samples - 1))
    tsne_results = tsne.fit_transform(scaled_features)

    return optimal_clusters, cluster_labels, tsne_results

def plot_reward_distributions_by_cluster(cluster_labels, rewards, optimal_clusters,
                                         save_path=None, figsize=(12, 8), dpi=300):
    """
    Plot reward distributions (KDE + median line) for each cluster with statistics.

    Clusters are labelled 1-based in the plot and in the returned dict (cluster
    label 0 -> ``Cluster_1``). Empty clusters appear in the legend and
    statistics (with nan values) but draw no curve or median line.

    Args:
        cluster_labels (array-like): Cluster assignment (0..optimal_clusters-1) for each trajectory
        rewards (array-like): Reward value for each trajectory
        optimal_clusters (int): Number of clusters
        save_path (str, optional): Path to save the plot
        figsize (tuple): Figure size (width, height)
        dpi (int): Dots per inch for the plot

    Returns:
        dict: ``{'Cluster_<k>': {'median', 'mean', 'std', 'size'}}`` for each cluster
    """
    # Arrays make the boolean masks below work for plain lists too.
    labels_arr = np.asarray(cluster_labels)
    reward_arr = np.asarray(rewards)

    plt.figure(figsize=figsize, dpi=dpi)

    cluster_rewards = [reward_arr[labels_arr == c] for c in range(optimal_clusters)]

    colors = plt.cm.viridis(np.linspace(0, 1, optimal_clusters))
    legend_labels = [f'Cluster {c+1} (n={len(vals)})' for c, vals in enumerate(cluster_rewards)]

    for c, vals in enumerate(cluster_rewards):
        if vals.size == 0:
            continue  # no density to estimate for an empty cluster
        sns.kdeplot(data=vals,
                    color=colors[c],
                    alpha=0.7,
                    linewidth=2,
                    label=legend_labels[c])

    for c, vals in enumerate(cluster_rewards):
        if vals.size == 0:
            continue  # median is undefined for an empty cluster
        plt.axvline(np.median(vals),
                    color=colors[c],
                    linestyle='--',
                    alpha=0.5,
                    linewidth=1.5)

    plt.title('Reward Distribution by Cluster', fontsize=14, pad=20)
    plt.xlabel('Reward Value', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.grid(True, alpha=0.2)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    cluster_stats = {}
    stats_text = 'Cluster Statistics:\n\n'
    for c, vals in enumerate(cluster_rewards):
        has_data = vals.size > 0
        stats = {
            'median': np.median(vals) if has_data else np.nan,
            'mean': np.mean(vals) if has_data else np.nan,
            'std': np.std(vals) if has_data else np.nan,
            'size': len(vals)
        }
        cluster_stats[f'Cluster_{c+1}'] = stats
        stats_text += (f'Cluster {c+1}:\n'
                       f'Median: {stats["median"]:.2f}\n'
                       f'Mean: {stats["mean"]:.2f}\n'
                       f'Std: {stats["std"]:.2f}\n')

    plt.text(1.35, 0.5, stats_text,
             transform=plt.gca().transAxes,
             bbox=dict(facecolor='white', alpha=0.8, edgecolor='none'),
             fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=dpi, bbox_inches='tight')

    plt.show()

    return cluster_stats

def analyze_cluster_diversity(trajectories, cluster_labels, optimal_clusters, last_step_rewards, feature_names):
    """
    Score the final-state diversity of every cluster and plot one distance heatmap per cluster.

    For each cluster id in ``range(optimal_clusters)`` the last timestep of every member
    trajectory (dropping feature column 0, the timestep) is passed together with the
    members' final rewards to ``calculate_euclidean_diversity``. The distances it returns
    are expanded with ``squareform`` and drawn as a heatmap.

    Args:
        trajectories (np.ndarray): Array of shape (n_trajectories, timesteps, features).
        cluster_labels (np.ndarray): Cluster id of each trajectory.
        optimal_clusters (int): Number of clusters to iterate over.
        last_step_rewards (np.ndarray): Final reward of each trajectory.
        feature_names (list): Names of features (not used by this function).

    Returns:
        list: ``(cluster_id, diversity_score, normalized_diversity)`` tuples, one per cluster.
    """
    from scipy.spatial.distance import squareform

    scores = []
    for cluster_id in range(optimal_clusters):
        members = np.where(cluster_labels == cluster_id)[0]
        member_trajectories = trajectories[members]
        # Final timestep of each member, timestep column excluded.
        final_states = member_trajectories[:, -1, 1:]

        diversity, diversity_norm, distances = calculate_euclidean_diversity(
            final_states,
            last_step_rewards[members],
        )
        scores.append((cluster_id, diversity, diversity_norm))

        # NOTE(review): squareform() is assumed to receive a condensed (pdist-style)
        # distance vector here — confirm against calculate_euclidean_diversity.
        distance_matrix = squareform(distances)
        title = (f"Cluster {cluster_id + 1} - Euclidean Distance Matrix\n"
                 f"Diversity Score: {diversity:.4f}, "
                 f"Normalized: {diversity_norm:.4f}")

        plt.figure(figsize=(8, 8))
        sns.heatmap(distance_matrix, cmap='viridis', fmt=".2f",
                    cbar_kws={'label': 'Euclidean Distance'})
        plt.title(title)
        plt.xlabel("Trajectory Index")
        plt.ylabel("Trajectory Index")
        plt.tight_layout()
        plt.show()

    return scores

def analyze_cluster_feature_variability(final_states_df, cluster_labels, optimal_clusters, feature_names):
    """
    Summarise how much each feature varies inside every cluster and plot it three ways.

    The standard deviation of every column of ``final_states_df`` is computed per cluster
    and paired, by position, with ``feature_names[1:]`` (the first name is skipped). The
    resulting table is shown as a heatmap, a polar "radar" plot and a seaborn box plot.

    Args:
        final_states_df (pd.DataFrame): Final states, one row per trajectory. Its columns
            are paired positionally with ``feature_names[1:]``.
        cluster_labels (np.ndarray): Cluster id of each row.
        optimal_clusters (int): Number of clusters (ids ``0 .. optimal_clusters - 1``).
        feature_names (list): Feature names; entry 0 is not used as a column label.

    Returns:
        pd.DataFrame: One row per cluster with 'Cluster', 'Size' and one
        standard-deviation column per name in ``feature_names[1:]``.
    """
    feature_cols = feature_names[1:]

    # One record per cluster: label, member count, then per-feature std.
    records = []
    for cluster_id in range(optimal_clusters):
        members = np.where(cluster_labels == cluster_id)[0]
        column_stds = final_states_df.iloc[members].std()
        record = {'Cluster': f'Cluster {cluster_id + 1}', 'Size': len(members)}
        record.update(zip(feature_cols, column_stds))
        records.append(record)
    feature_std_df = pd.DataFrame(records)

    # --- Heatmap: clusters x features ---
    std_matrix = feature_std_df.set_index('Cluster')[feature_cols]
    plt.figure(figsize=(12, 6))
    sns.heatmap(std_matrix, annot=True, fmt='.2f', cmap='viridis',
                cbar_kws={'label': 'Standard Deviation'})
    plt.title('Feature-wise Standard Deviations by Cluster')
    plt.xlabel('Features')
    plt.ylabel('Clusters')
    plt.tight_layout()
    plt.show()

    # --- Radar plot: one closed polygon per cluster ---
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='polar')
    spokes = np.linspace(0, 2 * np.pi, len(feature_cols), endpoint=False)
    closed_spokes = np.concatenate((spokes, [spokes[0]]))

    for cluster_id in range(optimal_clusters):
        radii = feature_std_df.iloc[cluster_id][feature_cols].values
        radii = np.concatenate((radii, [radii[0]]))
        ax.plot(closed_spokes, radii, 'o-', linewidth=2, label=f'Cluster {cluster_id + 1}')
        ax.fill(closed_spokes, radii, alpha=0.25)

    ax.set_xticks(spokes)
    ax.set_xticklabels(feature_cols, size=8)
    ax.set_title('Feature Standard Deviations by Cluster (Radar Plot)')
    plt.legend(bbox_to_anchor=(1.2, 1))
    plt.tight_layout()
    plt.show()

    # --- Box plot over the long-format table ---
    long_df = pd.melt(feature_std_df,
                      id_vars=['Cluster', 'Size'],
                      value_vars=feature_cols,
                      var_name='Feature',
                      value_name='Standard Deviation')
    plt.figure(figsize=(15, 6))
    sns.boxplot(data=long_df, x='Feature', y='Standard Deviation', hue='Cluster')
    plt.xticks(rotation=45)
    plt.title('Distribution of Feature Standard Deviations by Cluster')
    plt.tight_layout()
    plt.show()

    return feature_std_df

def plot_phase_portraits(trajectories, feature_names, n_trajectories=100, figsize=(30, 30), alpha=0.1):
    """
    Create phase portraits for every pair of (non-timestep) features.

    ``feature_names[0]`` and column 0 of each trajectory are treated as the timestep and
    skipped. Feature ``k`` (0-based among the real features) therefore lives in trajectory
    column ``k + 1`` and is labelled ``feature_names[k + 1]``. Only the upper triangle of
    the ``n_features x n_features`` subplot grid is filled.

    Args:
        trajectories (np.ndarray): Array of shape (n_trajectories, timesteps, features),
            with the timestep in feature column 0.
        feature_names (list): Feature names including the timestep name at index 0.
        n_trajectories (int): Number of trajectories to plot (taken from the front).
        figsize (tuple): Figure size (width, height).
        alpha (float): Transparency of trajectory lines.

    Returns:
        matplotlib.figure.Figure: The figure holding the subplot grid.
    """
    # The original iterated over len(feature_names) indices and then shifted by +1,
    # which ran one past the last feature (IndexError on labels/columns). Exclude the
    # timestep entry from the count instead.
    n_features = max(len(feature_names) - 1, 0)
    fig = plt.figure(figsize=figsize)
    subset = trajectories[:n_trajectories]

    for i in range(n_features):
        for j in range(i + 1, n_features):
            col_x, col_y = i + 1, j + 1  # +1 to skip the timestep column
            plt.subplot(n_features, n_features, i * n_features + j + 1)
            for traj in subset:
                plt.plot(traj[:, col_x], traj[:, col_y], 'b-', alpha=alpha)
                plt.plot(traj[0, col_x], traj[0, col_y], 'go', markersize=2)   # Start
                plt.plot(traj[-1, col_x], traj[-1, col_y], 'ro', markersize=2)  # End

            plt.xlabel(f"{feature_names[col_x]}")
            plt.ylabel(f"{feature_names[col_y]}")

    plt.tight_layout()
    return fig

def plot_multi_model_reward_distributions_with_ci_corrected(model_results_dict, save_path=None, confidence_level=0.95):
    """
    Plot reward distributions for multiple models with confidence intervals for their means.

    Each run's final rewards yield one KDE (on an x grid shared by all models) and one
    per-run mean. For every model the plot shows:

    * the average KDE across runs;
    * a band of +/- 0.5 standard deviations of the KDEs across runs (only with >= 2 KDEs);
    * a t-based confidence interval for the mean of the per-run means. This interval
      is about the MEAN VALUE, not the density.

    Args:
        model_results_dict (dict): Model name -> indexable collection of runs (a list,
            or a dict keyed ``0..n-1``). Each run is a dict whose 'rewards' entry is a
            list/tuple/ndarray of final rewards (1-D), or a 2-D array whose last column
            is used. Other shapes/types are skipped.
        save_path (str, optional): Path to save the plot.
        confidence_level (float): Confidence level for intervals (default 0.95 for 95% CI).

    Returns:
        dict: Model name -> {'mean', 'std', 'ci_lower', 'ci_upper', 'n_runs'} over the
        per-run means. Models without usable runs are omitted; empty if none is usable.
    """
    from scipy.stats import gaussian_kde, t

    # Define color palette for different models
    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    model_stats = {}

    # First pass: clean every run's final rewards.
    all_model_data = {}
    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.asarray(rewards, dtype=float)
            except (TypeError, ValueError):
                # Ragged or non-numeric reward containers cannot be analysed.
                continue

            if rewards.ndim == 2:
                rewards = rewards[:, -1]
            elif rewards.ndim > 2:
                continue

            rewards = rewards[np.isfinite(rewards)]
            if rewards.size > 0:
                all_rewards.append(rewards)

        if all_rewards:
            all_model_data[model_name] = all_rewards

    if not all_model_data:
        # Without any data the x range would be built from +/-inf; nothing to plot.
        return model_stats

    # Shared x grid for all KDEs, padded by 5% of the pooled reward range.
    pooled = np.concatenate([r for runs in all_model_data.values() for r in runs])
    x_min, x_max = pooled.min(), pooled.max()
    x_padding = (x_max - x_min) * 0.05
    common_x = np.linspace(x_min - x_padding, x_max + x_padding, 300)

    # Create professional distribution plot
    plt.figure(figsize=(14, 10), dpi=300)

    # Second pass: create plots for each model
    for model_name, all_rewards in all_model_data.items():
        color = colors.get(model_name, 'gray')

        run_means = np.array([np.mean(r) for r in all_rewards])
        n_runs = len(run_means)
        overall_mean = np.mean(run_means)
        # Sample SD of the run means is undefined for a single run.
        overall_std = np.std(run_means, ddof=1) if n_runs > 1 else np.nan

        # Confidence interval for the MEAN (t-distribution for small samples).
        if n_runs > 1:
            t_critical = t.ppf((1 + confidence_level) / 2, df=n_runs - 1)
            margin_error = t_critical * (overall_std / np.sqrt(n_runs))
            ci_lower = overall_mean - margin_error
            ci_upper = overall_mean + margin_error
        else:
            ci_lower = ci_upper = overall_mean

        # One KDE per run; previously a model with a single run got no curve at all.
        kde_curves = []
        for rewards in all_rewards:
            if rewards.size < 2:
                continue
            try:
                kde_curves.append(gaussian_kde(rewards)(common_x))
            except (np.linalg.LinAlgError, ValueError):
                # Constant rewards give a singular covariance; skip that run.
                continue

        if kde_curves:
            kde_curves = np.array(kde_curves)
            mean_kde_curve = kde_curves.mean(axis=0)

            plt.plot(common_x, mean_kde_curve, color=color, linewidth=3,
                     label=f'{model_name.upper()}: μ={overall_mean:.2f}±{overall_std:.2f}')

            # Band showing between-run variability of the density (needs >= 2 curves).
            if len(kde_curves) > 1:
                std_kde_curve = kde_curves.std(axis=0)
                plt.fill_between(common_x,
                                 np.maximum(mean_kde_curve - 0.5 * std_kde_curve, 0),
                                 mean_kde_curve + 0.5 * std_kde_curve,
                                 color=color, alpha=0.2,
                                 label=f'{model_name.upper()} KDE variability')

        # Vertical lines and shading for the mean and its confidence interval
        plt.axvline(overall_mean, color=color, linestyle='-', linewidth=2, alpha=0.8)
        plt.axvline(ci_lower, color=color, linestyle=':', alpha=0.6, linewidth=1.5)
        plt.axvline(ci_upper, color=color, linestyle=':', alpha=0.6, linewidth=1.5)
        plt.axvspan(ci_lower, ci_upper, color=color, alpha=0.1)

        model_stats[model_name] = {
            'mean': overall_mean,
            'std': overall_std,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'n_runs': n_runs
        }

    # Styling
    plt.title(f'Model Comparison: Mean Distributions with {int(confidence_level*100)}% Confidence Intervals\n'
              f'Solid lines = Mean KDE, Dotted lines = CI bounds for means, Shaded = KDE variability',
              fontsize=16, fontweight='bold', pad=25)
    plt.xlabel('Final Reward Values', fontsize=14, fontweight='bold')
    plt.ylabel('Probability Density', fontsize=14, fontweight='bold')

    plt.grid(True, alpha=0.3, linestyle='--')
    plt.legend(fontsize=10, frameon=True, framealpha=0.9, loc='upper left')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Statistics text box
    stats_text = f"{int(confidence_level*100)}% Confidence Intervals for Means:\n"
    for model_name, model_stat in model_stats.items():
        stats_text += f"{model_name.upper()}: [{model_stat['ci_lower']:.2f}, {model_stat['ci_upper']:.2f}]\n"

    plt.text(0.98, 0.98, stats_text, transform=plt.gca().transAxes,
             verticalalignment='top', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, edgecolor='black'),
             fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    return model_stats

# Much cleaner version - just show the mean KDE without confusing "confidence bounds"
def plot_multi_model_clean_kde(model_results_dict, save_path=None):
    """
    Plot the average per-run KDE of final rewards for each model, without CI shading.

    Each run contributes one KDE on an x grid shared by all models. The average KDE is
    drawn with a dashed vertical line at the model's mean of per-run means.

    Args:
        model_results_dict (dict): Model name -> indexable collection of runs (a list,
            or a dict keyed ``0..n-1``). Each run is a dict whose 'rewards' entry is a
            list/tuple/ndarray of final rewards (1-D), or a 2-D array whose last column
            is used. Other shapes/types are skipped.
        save_path (str, optional): Path to save the plot.

    Returns:
        dict: Model name -> {'mean', 'std', 'n_runs'} over the per-run mean rewards.
        Models without usable runs are omitted; empty if none is usable.
    """
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    model_stats = {}

    # Collect cleaned final rewards per run.
    all_model_data = {}
    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.asarray(rewards, dtype=float)
            except (TypeError, ValueError):
                # Ragged or non-numeric reward containers cannot be analysed.
                continue

            if rewards.ndim == 2:
                rewards = rewards[:, -1]
            elif rewards.ndim > 2:
                continue

            rewards = rewards[np.isfinite(rewards)]
            if rewards.size > 0:
                all_rewards.append(rewards)

        if all_rewards:
            all_model_data[model_name] = all_rewards

    if not all_model_data:
        # Without any data the x range would be built from +/-inf; nothing to plot.
        return model_stats

    # Shared x grid, padded by 5% of the pooled reward range.
    pooled = np.concatenate([r for runs in all_model_data.values() for r in runs])
    x_min, x_max = pooled.min(), pooled.max()
    x_padding = (x_max - x_min) * 0.05
    common_x = np.linspace(x_min - x_padding, x_max + x_padding, 300)

    plt.figure(figsize=(14, 8), dpi=300)

    for model_name, all_rewards in all_model_data.items():
        color = colors.get(model_name, 'gray')
        run_means = np.array([np.mean(r) for r in all_rewards])
        overall_mean = np.mean(run_means)
        # Sample SD of the run means is undefined for a single run.
        overall_std = np.std(run_means, ddof=1) if len(run_means) > 1 else np.nan

        # One KDE per run; previously a model with a single run got no curve at all.
        kde_curves = []
        for rewards in all_rewards:
            if rewards.size < 2:
                continue
            try:
                kde_curves.append(gaussian_kde(rewards)(common_x))
            except (np.linalg.LinAlgError, ValueError):
                # Constant rewards give a singular covariance; skip that run.
                continue

        if kde_curves:
            mean_kde_curve = np.mean(kde_curves, axis=0)

            plt.plot(common_x, mean_kde_curve, color=color, linewidth=3,
                     label=f'{model_name.upper()}: μ={overall_mean:.2f}±{overall_std:.2f}')

            # Mean line
            plt.axvline(overall_mean, color=color, linestyle='--',
                        linewidth=2, alpha=0.7)

        model_stats[model_name] = {
            'mean': overall_mean,
            'std': overall_std,
            'n_runs': len(run_means)
        }

    plt.title('Model Performance Comparison: Average Reward Distributions',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Final Reward', fontsize=14, fontweight='bold')
    plt.ylabel('Probability Density', fontsize=14, fontweight='bold')

    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12, frameon=True, framealpha=0.9)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    return model_stats

import numpy as np
from scipy import stats
from scipy.stats import mannwhitneyu, wilcoxon, kruskal, friedmanchisquare
import pandas as pd
import seaborn as sns
from itertools import combinations
import warnings

import matplotlib.pyplot as plt
# NOTE(review): this disables every warning category process-wide from this point
# on (affecting all later code, not only this module). Consider scoping it with
# `warnings.catch_warnings()` around the specific calls that are noisy.
warnings.filterwarnings('ignore')

def extract_run_means(model_results_dict, n_runs=None):
    """
    Extract the mean final reward of each run for all models.

    Args:
        model_results_dict (dict): Model name -> indexable collection of runs (a list,
            or a dict keyed ``0..n-1``). Each run is a dict whose 'rewards' entry is
            array-like. A 1-D array holds final rewards directly; for a 2-D array the
            last column is used. Other ranks are skipped.
        n_runs (int, optional): Use only the first ``n_runs`` runs of each model.
            Defaults to all available runs. The previous hard-coded 30 raised
            KeyError/IndexError for fewer runs and silently ignored extra ones.

    Returns:
        dict: Model name -> float array of per-run means. Non-finite rewards are
        dropped before averaging; runs left with no rewards are skipped.
    """
    run_means = {}

    for model_name, runs_dict in model_results_dict.items():
        available = len(runs_dict)
        total = available if n_runs is None else min(n_runs, available)

        means = []
        for run_idx in range(total):
            # np.asarray also accepts scalars and other array-likes that the
            # old list/ndarray check let through to a failing `.ndim` access.
            rewards = np.asarray(runs_dict[run_idx]['rewards'], dtype=float)

            if rewards.ndim == 1:
                final_rewards = rewards
            elif rewards.ndim == 2:
                final_rewards = rewards[:, -1]
            else:
                continue

            final_rewards = final_rewards[np.isfinite(final_rewards)]
            if final_rewards.size > 0:
                means.append(np.mean(final_rewards))

        run_means[model_name] = np.array(means)

    return run_means

def _cliffs_delta(x, y):
    """
    Return Cliff's delta between samples *x* and *y*.

    Defined as (#pairs with x > y - #pairs with x < y) / (len(x) * len(y)), a
    non-parametric effect size in [-1, 1]. Returns NaN if either sample is empty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size == 0 or y.size == 0:
        return np.nan
    # Broadcast to an (n1, n2) matrix of pairwise differences instead of a Python double loop.
    return float(np.sign(x[:, None] - y[None, :]).sum() / (x.size * y.size))


def perform_pairwise_tests(run_means, alpha=0.05):
    """
    Perform pairwise statistical tests between all method pairs.

    For every unordered pair (in ``run_means`` key order) this runs a two-sided
    Mann-Whitney U test and Welch's t-test. It also computes Cohen's d (pooled SD
    weighted by degrees of freedom) and Cliff's delta, and prints a summary line per pair.

    Args:
        run_means (dict): Dictionary with model names and their run means
        alpha (float): Significance level for the 'MW_significant' /
            'Welch_significant' flags in the summary rows.

    Returns:
        dict: 'mann_whitney', 'welch_t_test' and 'effect_sizes' keyed by
        "<model1>_vs_<model2>", plus 'summary', a list of row dicts.
    """
    results = {
        'mann_whitney': {},
        'welch_t_test': {},
        'effect_sizes': {},
        'summary': []
    }

    def stars(p):
        return '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'ns'

    print("PAIRWISE STATISTICAL TESTS")
    print("=" * 80)

    for model1, model2 in combinations(list(run_means.keys()), 2):
        data1 = np.asarray(run_means[model1], dtype=float)
        data2 = np.asarray(run_means[model2], dtype=float)
        n1, n2 = len(data1), len(data2)
        mean1, mean2 = np.mean(data1), np.mean(data2)
        var1, var2 = np.var(data1, ddof=1), np.var(data2, ddof=1)
        std1, std2 = np.sqrt(var1), np.sqrt(var2)

        # Mann-Whitney U test (non-parametric)
        mw_stat, mw_p = mannwhitneyu(data1, data2, alternative='two-sided')

        # Welch's t-test (assumes unequal variances)
        welch_stat, welch_p = stats.ttest_ind(data1, data2, equal_var=False)

        # Cohen's d with the df-weighted pooled SD; identical to the simple average of
        # variances when n1 == n2, but correct when run counts differ. A zero pooled SD
        # yields inf/nan (numpy float division) rather than an exception.
        pooled_std = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2))
        cohens_d = (mean1 - mean2) / pooled_std

        # Cliff's delta (non-parametric effect size)
        cliff_delta = _cliffs_delta(data1, data2)

        pair_key = f"{model1}_vs_{model2}"
        results['mann_whitney'][pair_key] = {'statistic': mw_stat, 'p_value': mw_p}
        results['welch_t_test'][pair_key] = {'statistic': welch_stat, 'p_value': welch_p}
        results['effect_sizes'][pair_key] = {'cohens_d': cohens_d, 'cliffs_delta': cliff_delta}

        results['summary'].append({
            'Model_1': model1,
            'Model_2': model2,
            'Mean_1': mean1,
            'Mean_2': mean2,
            'Std_1': std1,
            'Std_2': std2,
            'Mann_Whitney_p': mw_p,
            'Welch_t_p': welch_p,
            'Cohens_d': cohens_d,
            'Cliffs_delta': cliff_delta,
            'MW_significant': mw_p < alpha,
            'Welch_significant': welch_p < alpha
        })

        print(f"\n{model1.upper()} vs {model2.upper()}:")
        print(f"  Means: {mean1:.3f} ± {std1:.3f} vs {mean2:.3f} ± {std2:.3f}")
        print(f"  Mann-Whitney U: p = {mw_p:.6f} {stars(mw_p)}")
        print(f"  Welch's t-test: p = {welch_p:.6f} {stars(welch_p)}")
        print(f"  Effect sizes: Cohen's d = {cohens_d:.3f}, Cliff's δ = {cliff_delta:.3f}")

    return results

def perform_omnibus_tests(run_means, alpha=0.05):
    """
    Perform omnibus tests to check if there are any significant differences between groups.

    Runs the Kruskal-Wallis H test (non-parametric) and a one-way ANOVA (parametric)
    over all groups in ``run_means`` and prints each result with an interpretation.

    Args:
        run_means (dict): Model name -> array of per-run mean rewards. SciPy requires
            at least two groups for both tests.
        alpha (float): Significance level used for the printed interpretation only
            (previously hard-coded to 0.05).

    Returns:
        dict: {'kruskal_wallis': (H, p), 'anova': (F, p)}.
    """
    data_arrays = list(run_means.values())

    def interpret(p):
        return ('Significant differences exist between groups' if p < alpha
                else 'No significant differences between groups')

    print("\n" + "=" * 80)
    print("OMNIBUS TESTS")
    print("=" * 80)

    # Kruskal-Wallis test (non-parametric ANOVA)
    kw_stat, kw_p = kruskal(*data_arrays)
    print(f"Kruskal-Wallis test: H = {kw_stat:.3f}, p = {kw_p:.6f}")
    print(f"Interpretation: {interpret(kw_p)}")

    # One-way ANOVA (parametric)
    f_stat, f_p = stats.f_oneway(*data_arrays)
    print(f"One-way ANOVA: F = {f_stat:.3f}, p = {f_p:.6f}")
    print(f"Interpretation: {interpret(f_p)}")

    return {'kruskal_wallis': (kw_stat, kw_p), 'anova': (f_stat, f_p)}

def bonferroni_correction(results, alpha=0.05):
    """
    Apply Bonferroni correction for multiple comparisons.

    Each pairwise p-value in ``results['summary']`` is re-tested against
    ``alpha / n_comparisons``. Two boolean columns are added:
    'MW_significant_bonferroni' (Mann-Whitney U) and
    'Welch_significant_bonferroni' (Welch's t-test).

    Args:
        results (dict): Must contain 'summary', a list of row dicts with at least
            'Model_1', 'Model_2', 'Mann_Whitney_p' and 'Welch_t_p'.
        alpha (float): Family-wise significance level.

    Returns:
        tuple[pd.DataFrame, float]: The summary table with the two extra columns,
        and the corrected alpha (``alpha`` itself when there are no comparisons).
    """
    summary_df = pd.DataFrame(results['summary'])
    n_comparisons = len(summary_df)
    # With fewer than two models there are no comparisons; avoid ZeroDivisionError.
    corrected_alpha = alpha / n_comparisons if n_comparisons else alpha

    print("\n" + "=" * 80)
    print("BONFERRONI CORRECTION")
    print("=" * 80)
    print(f"Original α = {alpha}")
    print(f"Number of comparisons = {n_comparisons}")
    print(f"Bonferroni-corrected α = {corrected_alpha:.6f}")

    if n_comparisons == 0:
        # An empty frame has no p-value columns; add empty flag columns so callers
        # can still rely on them being present.
        summary_df['MW_significant_bonferroni'] = pd.Series(dtype=bool)
        summary_df['Welch_significant_bonferroni'] = pd.Series(dtype=bool)
        print("\nNo pairwise comparisons to correct.")
        return summary_df, corrected_alpha

    summary_df['MW_significant_bonferroni'] = summary_df['Mann_Whitney_p'] < corrected_alpha
    summary_df['Welch_significant_bonferroni'] = summary_df['Welch_t_p'] < corrected_alpha

    significant_mw = summary_df[summary_df['MW_significant_bonferroni']]
    significant_welch = summary_df[summary_df['Welch_significant_bonferroni']]

    print("\nSignificant comparisons after Bonferroni correction:")
    print(f"Mann-Whitney U test: {len(significant_mw)}/{n_comparisons} significant")
    print(f"Welch's t-test: {len(significant_welch)}/{n_comparisons} significant")

    if len(significant_mw) > 0:
        print("\nSignificant Mann-Whitney comparisons (Bonferroni-corrected):")
        for _, row in significant_mw.iterrows():
            print(f"  {row['Model_1']} vs {row['Model_2']}: p = {row['Mann_Whitney_p']:.6f}")

    return summary_df, corrected_alpha

def create_statistical_plots(run_means, results, save_path=None):
    """
    Create a 2x2 figure summarising the pairwise comparison of methods.

    Panels:
        top-left: box plot of the per-run means of each method;
        top-right: -log10 of the Mann-Whitney U p-values for every pair
            (diagonal fixed at p = 1, i.e. 0 on the log scale);
        bottom-left: Cohen's d per pair, coloured by magnitude
            (|d| > 0.8 red, > 0.5 orange, otherwise green);
        bottom-right: how many uncorrected Mann-Whitney comparisons
            (the ``MW_significant`` flag) each method is part of.

    Args:
        run_means (dict): Model name -> array of per-run mean rewards.
        results (dict): Pairwise results with 'mann_whitney', 'effect_sizes' and
            'summary' entries, in the layout built by ``perform_pairwise_tests``.
        save_path (str, optional): If given, the figure is saved there at 300 dpi.
    """
    models = list(run_means.keys())
    n_models = len(models)
    palette = ['#2ecc71', '#3498db', '#f39c12', '#9b59b6', '#e74c3c']
    # Cycle the palette so methods beyond the fifth still get a fill colour
    # (zip() against the 5-entry list used to leave them unfilled).
    model_colors = [palette[k % len(palette)] for k in range(n_models)]

    fig, axes = plt.subplots(2, 2, figsize=(16, 12), dpi=300)

    # Plot 1: Box plot comparison
    ax1 = axes[0, 0]
    bp = ax1.boxplot([run_means[m] for m in models],
                     labels=[m.upper() for m in models], patch_artist=True)
    for patch, color in zip(bp['boxes'], model_colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax1.set_title('Distribution of Run Means by Method', fontsize=14, fontweight='bold')
    ax1.set_ylabel('Mean Final Reward', fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.tick_params(axis='x', rotation=45)

    # Plot 2: symmetric p-value heatmap
    ax2 = axes[0, 1]
    p_matrix = np.ones((n_models, n_models))
    for i, j in combinations(range(n_models), 2):
        pair_key = f"{models[i]}_vs_{models[j]}"
        if pair_key in results['mann_whitney']:
            p_val = results['mann_whitney'][pair_key]['p_value']
            p_matrix[i, j] = p_matrix[j, i] = p_val

    # -log10(p) for readability; the small offset avoids log(0).
    log_p_matrix = -np.log10(p_matrix + 1e-10)

    im = ax2.imshow(log_p_matrix, cmap='Reds', aspect='auto')
    ax2.set_xticks(range(n_models))
    ax2.set_yticks(range(n_models))
    ax2.set_xticklabels([m.upper() for m in models], rotation=45)
    ax2.set_yticklabels([m.upper() for m in models])
    ax2.set_title('-log₁₀(p-value) Heatmap\n(Mann-Whitney U Test)', fontsize=14, fontweight='bold')

    cbar = plt.colorbar(im, ax=ax2)
    cbar.set_label('-log₁₀(p-value)', fontsize=10)

    # Dashed outline along the top and left edges of the heatmap.
    ax2.axhline(y=-0.5, color='black', linestyle='--', alpha=0.5)
    ax2.axvline(x=-0.5, color='black', linestyle='--', alpha=0.5)

    # Plot 3: Effect sizes
    ax3 = axes[1, 0]
    cohens_d_values = [results['effect_sizes'][pair]['cohens_d'] for pair in results['effect_sizes']]
    pair_labels = [pair.replace('_vs_', ' vs ').replace('_', ' ').title() for pair in results['effect_sizes']]

    colors_effect = ['red' if abs(d) > 0.8 else 'orange' if abs(d) > 0.5 else 'green' for d in cohens_d_values]

    ax3.barh(range(len(cohens_d_values)), cohens_d_values, color=colors_effect, alpha=0.7)
    ax3.set_yticks(range(len(pair_labels)))
    ax3.set_yticklabels(pair_labels, fontsize=10)
    ax3.set_xlabel("Cohen's d", fontsize=12)
    ax3.set_title("Effect Sizes (Cohen's d)", fontsize=14, fontweight='bold')
    ax3.grid(True, alpha=0.3, axis='x')
    ax3.axvline(x=0, color='black', linestyle='-', alpha=0.8)
    # Conventional small / medium / large thresholds on both sides of zero.
    ax3.axvline(x=0.2, color='gray', linestyle='--', alpha=0.5, label='Small effect')
    ax3.axvline(x=0.5, color='gray', linestyle='--', alpha=0.5, label='Medium effect')
    ax3.axvline(x=0.8, color='gray', linestyle='--', alpha=0.5, label='Large effect')
    ax3.axvline(x=-0.2, color='gray', linestyle='--', alpha=0.5)
    ax3.axvline(x=-0.5, color='gray', linestyle='--', alpha=0.5)
    ax3.axvline(x=-0.8, color='gray', linestyle='--', alpha=0.5)

    # Plot 4: count of (uncorrected) significant pairwise comparisons per method
    ax4 = axes[1, 1]
    sig_counts = [
        sum(1 for row in results['summary']
            if model in (row['Model_1'], row['Model_2']) and row['MW_significant'])
        for model in models
    ]

    bars = ax4.bar(models, sig_counts, color=model_colors, alpha=0.7)
    ax4.set_title('Number of Significant Pairwise Comparisons\n(Mann-Whitney U, p < 0.05)',
                  fontsize=14, fontweight='bold')
    ax4.set_ylabel('Count of Significant Comparisons', fontsize=12)
    ax4.set_xlabel('Methods', fontsize=12)

    for bar, count in zip(bars, sig_counts):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                 str(count), ha='center', va='bottom', fontweight='bold')

    # set_ylim(0, 0) gave a singular axis when nothing was significant; keep a
    # minimum height of 1 (and handle an empty model list).
    ax4.set_ylim(0, max(max(sig_counts, default=0), 1) * 1.2)
    ax4.tick_params(axis='x', rotation=45)

    plt.suptitle('Statistical Analysis: Method Comparison Results',
                 fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def generate_statistical_report(run_means, results, corrected_alpha, save_path=None):
    """
    Generate a comprehensive statistical report, print it, and optionally save it.

    Args:
        run_means (dict): Method name -> sequence of per-run mean rewards.
        results (dict): Pairwise-test results. ``results['summary']`` is turned
            into a DataFrame and must provide the columns ``Model_1``,
            ``Model_2``, ``MW_significant``, ``MW_significant_bonferroni``,
            ``Mann_Whitney_p`` and ``Cohens_d``.
        corrected_alpha (float): Bonferroni-corrected significance level.
        save_path (str, optional): If given, the report is also written to this
            path. UTF-8 is used explicitly because the report contains
            non-ASCII symbols (±, α, δ, ≤, ✓, ✗) that locale encodings such as
            cp1252 cannot represent.

    Returns:
        pandas.DataFrame: ``results['summary']`` as a DataFrame.
    """
    summary_df = pd.DataFrame(results['summary'])

    # Report the actual number of runs instead of assuming 30.
    run_counts = sorted({len(means) for means in run_means.values()})
    if not run_counts:
        runs_text = "0"
    elif len(run_counts) == 1:
        runs_text = str(run_counts[0])
    else:
        runs_text = f"{run_counts[0]}-{run_counts[-1]} (varies by method)"

    report = [
        "COMPREHENSIVE STATISTICAL ANALYSIS REPORT",
        "=" * 80,
        f"Analysis Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Number of methods compared: {len(run_means)}",
        f"Number of independent runs per method: {runs_text}",
        f"Total pairwise comparisons: {len(summary_df)}",
        "Significance level (α): 0.05",
        f"Bonferroni-corrected α: {corrected_alpha:.6f}",
        "",
    ]

    # Descriptive statistics (sample SD, ddof=1)
    report.append("DESCRIPTIVE STATISTICS")
    report.append("-" * 40)
    for model, means in run_means.items():
        report.append(f"{model.upper()}:")
        report.append(f"  Mean ± SD: {np.mean(means):.3f} ± {np.std(means, ddof=1):.3f}")
        report.append(f"  Median [IQR]: {np.median(means):.3f} [{np.percentile(means, 25):.3f}, {np.percentile(means, 75):.3f}]")
        report.append(f"  Range: [{np.min(means):.3f}, {np.max(means):.3f}]")
        report.append("")

    # Omnibus test results: each entry is indexed as (statistic, p-value)
    omnibus_results = perform_omnibus_tests(run_means)
    report.append("OMNIBUS TEST RESULTS")
    report.append("-" * 40)
    report.append(f"Kruskal-Wallis H = {omnibus_results['kruskal_wallis'][0]:.3f}, p = {omnibus_results['kruskal_wallis'][1]:.6f}")
    report.append(f"One-way ANOVA F = {omnibus_results['anova'][0]:.3f}, p = {omnibus_results['anova'][1]:.6f}")
    report.append("")

    # Significant comparisons (uncorrected vs Bonferroni-corrected flags)
    significant_mw = summary_df[summary_df['MW_significant']]
    significant_bonf = summary_df[summary_df['MW_significant_bonferroni']]

    report.append("SIGNIFICANT PAIRWISE COMPARISONS")
    report.append("-" * 40)
    report.append(f"Without correction: {len(significant_mw)}/{len(summary_df)} significant")
    report.append(f"With Bonferroni correction: {len(significant_bonf)}/{len(summary_df)} significant")
    report.append("")

    if len(significant_bonf) > 0:
        report.append("Bonferroni-corrected significant comparisons:")
        for _, row in significant_bonf.iterrows():
            report.append(f"  {row['Model_1']} vs {row['Model_2']}: p = {row['Mann_Whitney_p']:.6f}, d = {row['Cohens_d']:.3f}")

    # Effect size interpretation
    report.append("")
    report.append("EFFECT SIZE INTERPRETATION")
    report.append("-" * 40)
    report.append("Cohen's d: 0.2 (small), 0.5 (medium), 0.8 (large)")
    report.append("Cliff's δ: 0.11 (small), 0.28 (medium), 0.43 (large)")
    report.append("")

    abs_d = summary_df['Cohens_d'].abs()
    large_effects = summary_df[abs_d > 0.8]
    medium_effects = summary_df[(abs_d > 0.5) & (abs_d <= 0.8)]

    report.append(f"Large effect sizes (|d| > 0.8): {len(large_effects)}")
    report.append(f"Medium effect sizes (0.5 < |d| ≤ 0.8): {len(medium_effects)}")

    # Recommendations
    report.append("")
    report.append("STATISTICAL CONCLUSIONS")
    report.append("-" * 40)

    if omnibus_results['kruskal_wallis'][1] < 0.05:
        report.append("✓ Omnibus tests confirm significant differences exist between methods")
    else:
        report.append("✗ Omnibus tests do not detect significant differences between methods")

    if len(significant_bonf) > 0:
        report.append(f"✓ {len(significant_bonf)} pairwise comparisons remain significant after Bonferroni correction")
        # Recommend the highest-mean method among all methods that take part in
        # at least one Bonferroni-significant comparison. Looking only at the
        # Model_1 column would never select a method that appears solely as
        # Model_2. dict.fromkeys keeps first-seen order for deterministic ties.
        involved = dict.fromkeys(
            list(significant_bonf['Model_1']) + list(significant_bonf['Model_2'])
        )
        candidates = [m for m in involved if m in run_means]
        best_method = (
            max(candidates, key=lambda m: np.mean(run_means[m])) if candidates else "N/A"
        )
        report.append(f"✓ Recommended method based on statistical analysis: {best_method}")
    else:
        report.append("✗ No pairwise comparisons remain significant after Bonferroni correction")

    report_text = '\n'.join(report)

    # Write report to file if path provided
    if save_path:
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(report_text)
        print(f"Statistical report saved to: {save_path}")

    # Print report
    print(report_text)

    return summary_df

# Main entry point: run the full statistical comparison pipeline
def perform_comprehensive_statistical_analysis(model_results_dict, save_dir=None):
    """
    Perform comprehensive statistical analysis on model comparison results.

    The pipeline runs in this order:
    1. Extract per-run means.
    2. Run pairwise tests.
    3. Apply the Bonferroni correction.
    4. Draw the plots.
    5. Print the report.
    6. Optionally save a CSV of the pairwise results.

    Args:
        model_results_dict (dict): Model name -> per-run results; passed to
            ``extract_run_means``.
        save_dir (str, optional): Output directory, created if it does not
            exist. It receives ``statistical_analysis_plots.pdf``,
            ``statistical_analysis_report.txt`` and
            ``pairwise_comparison_results.csv``. If omitted, nothing is
            written to disk.

    Returns:
        tuple: ``(summary_df, results, run_means)``, where ``summary_df`` is
        the DataFrame returned by ``bonferroni_correction``.
    """
    print("COMPREHENSIVE STATISTICAL ANALYSIS")
    print("="*80)
    print("Extracting run means from model results...")

    # Extract run means, test every pair, then correct for multiple comparisons.
    run_means = extract_run_means(model_results_dict)
    results = perform_pairwise_tests(run_means)
    summary_df, corrected_alpha = bonferroni_correction(results)

    # Resolve all output paths once; None means "do not save".
    plot_path = report_path = csv_path = None
    if save_dir:
        # Create the directory up front so none of the writes below fail.
        os.makedirs(save_dir, exist_ok=True)
        plot_path = os.path.join(save_dir, "statistical_analysis_plots.pdf")
        report_path = os.path.join(save_dir, "statistical_analysis_report.txt")
        csv_path = os.path.join(save_dir, "pairwise_comparison_results.csv")

    create_statistical_plots(run_means, results, save_path=plot_path)
    generate_statistical_report(run_means, results, corrected_alpha, save_path=report_path)

    # Save summary DataFrame
    if csv_path:
        summary_df.to_csv(csv_path, index=False)
        print(f"Detailed results saved to: {csv_path}")

    return summary_df, results, run_means


In [None]:
# Per-run training outputs for each method and oracle.
# The REINFORCE-with-baseline runs have their own files
# (reinforce_baseline_results_runs_oracle_N). Reusing reinforce_results_* here
# made every reinforce_baseline_N identical to reinforce_N.
reinforce_1 = joblib.load(r"results/reinforce_results_runs_oracle_1.joblib")
gflow_1 = joblib.load(r"results/results_run_oracle_1.joblib")
reinforce_baseline_1 = joblib.load(r"results/reinforce_baseline_results_runs_oracle_1.joblib")
sac_1 = joblib.load(r"results/sac_results_runs_oracle_1.joblib")
reinforce_2 = joblib.load(r"results/reinforce_results_runs_oracle_2.joblib")
gflow_2 = joblib.load(r"results/results_run_oracle_2.joblib")
sac_2 = joblib.load(r"results/sac_results_runs_oracle_2.joblib")
reinforce_baseline_2 = joblib.load(r"results/reinforce_baseline_results_runs_oracle_2.joblib")
reinforce_3 = joblib.load(r"results/reinforce_results_runs_oracle_3.joblib")
gflow_3 = joblib.load(r"results/results_run_oracle_3.joblib")
sac_3 = joblib.load(r"results/sac_results_runs_oracle_3.joblib")
reinforce_baseline_3 = joblib.load(r"results/reinforce_baseline_results_runs_oracle_3.joblib")

In [None]:
# Generate high-reward trajectories
# Re-run each trained SAC agent for oracle 3 (agent and seed taken from sac_3)
# and collect 200 trajectories per run.
CONFIG_PATH = r'configs/run_params.yaml'
MODEL_PATH = r"oracles/oracle_3.joblib"
# Make the parent directory importable once. Appending inside the loop added a
# duplicate sys.path entry on every iteration.
if '../.' not in sys.path:
    sys.path.append('../.')

results_dict = {}

for run in tqdm(range(30)):
    # The environment and logger are rebuilt for every run.
    config, env, logger, training_parameters, initial_state, feature_names = setup_environment_and_logger(CONFIG_PATH, MODEL_PATH)
    # The environment state carries one extra (time-step) entry; actions cover
    # only the features themselves.
    state_dim = len(initial_state) + 1
    action_dim = len(initial_state)
    trained_agent = sac_3[run]['trained_agent']
    seed = sac_3[run]['seed']
    seed_all(seed)

    high_reward_trajectories, rewards, all_actions, f_names = generate_high_reward_trajectories(
        env,
        trained_agent=trained_agent,
        num_trajectories=200,
        max_steps=training_parameters["trajectory_length"],
    )

    results_dict[run] = {
        'trajectories': high_reward_trajectories,
        'rewards': rewards,
        'all_actions': all_actions,
        'seed': seed
    }
    print(f"Run {run+1} completed with seed {seed} - mean reward: {np.mean(rewards)}.")

# NOTE(review): this writes under oracles/, while the other cells read SAC
# results from results/sac_results_runs_oracle_3.joblib. Confirm the destination.
joblib.dump(results_dict, r'oracles/sac_results_runs_oracle_3.joblib')





In [None]:
from utility_functions import seed_all

# Re-simulate trajectories for each of the 30 trained GFlowNet runs (oracle 3).
# Settings that do not depend on the run are set up once, before the loop.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trajectory_length = 12
n_trajectories = 200
input_dim = 7
mixture_components = 15  # Number of mixture components for the beta distribution
top_50_cases_colors = sns.color_palette("husl", 50)

results_dict = {}
for i in range(30):
    # Pair each run's seed with that run's own checkpoints. Indexing run 0 here
    # would evaluate the same model under 30 different seeds.
    fwd_model_path = gflow_3[i]['fwd_model_path']
    bwd_model_path = gflow_3[i]['bwd_model_path']
    seed = gflow_3[i]['seed']
    print(f"Running simulation for seed {seed}...")
    seed_all(seed)

    # Re-read the config every run so that in-place changes made by the
    # helpers below cannot leak between runs. `with` closes the file handle.
    # NOTE(review): the SAC cell reads 'configs/run_params.yaml'. Confirm which
    # directory name is correct.
    with open(r'config/run_params.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Load the trained models
    forward_model = load_entire_model(fwd_model_path, device)
    backward_model = load_entire_model(bwd_model_path, device)
    logger, log_file_name = setup_logger('experiment')
    initial_state, feature_names = load_initial_state(config)
    # Fresh dict per run in case the simulator mutates it.
    extra_parameters = {
        'mixture_components': mixture_components,
        'num_variables': input_dim - 1,
    }
    model_path = fr"{config['oracle']['model_path']}"

    trajectories, rewards, all_actions, _ = simulate_trajectories(
        env_class=GeneralEnvironment,
        forward_model=forward_model,
        backward_model=backward_model,
        initial_state=initial_state,
        config=config,
        trajectory_length=trajectory_length,
        n_trajectories=n_trajectories,
        device=device,
        logger=logger,
        model_path=model_path,
        distribution='mixture_beta',
        extra_parameters=extra_parameters,
    )
    results_dict[i] = {
        'trajectories': trajectories,
        'rewards': np.array(rewards),
        'all_actions': all_actions,
        'seed': seed
    }
joblib.dump(results_dict, r'results/gflow_results_runs_oracle_3.joblib')


In [None]:
# Peek at the fields stored under run 0's SMCMC 'results_dict'.
# NOTE(review): smcmc_1 is only loaded in the "# Oracle 1" cell further down.
# Run that cell first.
smcmc_1[0]['results_dict'].keys()

In [None]:
import joblib

In [None]:
def _flatten_smcmc_runs(smcmc_runs):
    """Lift each SMCMC run's nested 'results_dict' to the per-run layout of the other methods.

    The nested ``'actions'`` entry is exposed as ``'all_actions'``, and the
    run-level ``'seed'`` is carried over unchanged.
    """
    flattened = {}
    for run_key, run_record in smcmc_runs.items():
        inner = run_record['results_dict']
        flattened[run_key] = {
            'rewards': inner['rewards'],
            'trajectories': inner['trajectories'],
            'all_actions': inner['actions'],
            'seed': run_record['seed'],
        }
    return flattened


def _bundle_model_results(gflow, reinforce, reinforce_baseline, sac, smcmc):
    """Group one oracle's per-method run results under the method names used by the plots."""
    return {
        'gflow': gflow,
        'reinforce': reinforce,
        'reinforce_baseline': reinforce_baseline,
        'sac': sac,
        'smcmc': smcmc,
    }


# Oracle 1
res_g1 = joblib.load(r"results/gflow_results_runs_oracle_1.joblib")
res_reinforce_1 = joblib.load(r"results/reinforce_results_runs_oracle_1.joblib")
res_reinforce_baseline_1 = joblib.load(r"results/reinforce_baseline_results_runs_oracle_1.joblib")
res_sac_1 = joblib.load(r"results/sac_results_runs_oracle_1.joblib")
smcmc_1 = joblib.load(r"results/PaperPlots_O1_SMCMC/mcmc_results_runs_oracle_1.joblib")
new_smcmc_1 = _flatten_smcmc_runs(smcmc_1)
model_results_dict_1 = _bundle_model_results(
    res_g1, res_reinforce_1, res_reinforce_baseline_1, res_sac_1, new_smcmc_1
)

# Oracle 2
res_g2 = joblib.load(r"results/gflow_results_runs_oracle_2.joblib")
res_reinforce_2 = joblib.load(r"results/reinforce_results_runs_oracle_2.joblib")
res_reinforce_baseline_2 = joblib.load(r"results/reinforce_baseline_results_runs_oracle_2.joblib")
res_sac_2 = joblib.load(r"results/sac_results_runs_oracle_2.joblib")
smcmc_2 = joblib.load(r"results/PaperPlots_O2_SMCMC/mcmc_results_runs_oracle_2.joblib")
new_smcmc_2 = _flatten_smcmc_runs(smcmc_2)
model_results_dict_2 = _bundle_model_results(
    res_g2, res_reinforce_2, res_reinforce_baseline_2, res_sac_2, new_smcmc_2
)

# Oracle 3
res_g3 = joblib.load(r"results/gflow_results_runs_oracle_3.joblib")
res_reinforce_3 = joblib.load(r"results/reinforce_results_runs_oracle_3.joblib")
res_reinforce_baseline_3 = joblib.load(r"results/reinforce_baseline_results_runs_oracle_3.joblib")
res_sac_3 = joblib.load(r"results/sac_results_runs_oracle_3.joblib")
smcmc_3 = joblib.load(r"results/PaperPlots_O3_SMCMC/mcmc_results_runs_oracle_3.joblib")
new_smcmc_3 = _flatten_smcmc_runs(smcmc_3)
model_results_dict_3 = _bundle_model_results(
    res_g3, res_reinforce_3, res_reinforce_baseline_3, res_sac_3, new_smcmc_3
)

In [None]:
# Quick check of the reward array shape for GFlowNet run 0 on oracle 1.
# The plotting functions below accept 1-D rewards, or 2-D rewards (they use the last column).
res_g1[0]['rewards'].shape

In [None]:
def plot_multi_model_reward_distributions(model_results_dict, save_path=None):
    """
    Plot reward distributions for multiple models across multiple runs.

    Every run is drawn as its own KDE curve in the model's colour. A dashed
    vertical line marks each model's mean over all of its runs.

    Args:
        model_results_dict (dict): Keys are model names; values map run index
            (0..n-1) to a run dict with a ``'rewards'`` entry. Rewards may be
            lists or arrays. For 2-D rewards the last column is used,
            non-finite values are dropped, and >2-D rewards are skipped.
            Models with no usable run are left out of the plot.
        save_path (str, optional): Path to save the plot.

    Returns:
        dict: model name -> ``{'mean', 'std', 'color', 'total_trajectories', 'n_runs'}``.
    """
    plt.figure(figsize=(14, 7), dpi=300)

    # One colour per model, spread across the Set1 colormap.
    colors = plt.cm.Set1(np.linspace(0, 1, len(model_results_dict)))

    model_stats = {}

    for model_idx, (model_name, results_dict) in enumerate(model_results_dict.items()):
        color = colors[model_idx]

        # Collect the final rewards of every run for this model.
        all_rewards = []
        for run_idx in range(len(results_dict)):
            # np.asarray also accepts plain lists, which have no `.ndim`.
            rewards = np.asarray(results_dict[run_idx]['rewards'], dtype=float)
            if rewards.ndim > 2:
                continue
            if rewards.ndim == 2:
                # 2-D rewards: keep only the last column.
                rewards = rewards[:, -1]
            # Drop NaN/inf so one bad value cannot turn the model mean into NaN.
            rewards = rewards[np.isfinite(rewards)]
            if rewards.size:
                all_rewards.append(rewards)

        if not all_rewards:
            continue

        # Plot each run's KDE distribution
        for rewards in all_rewards:
            sns.kdeplot(data=rewards, color=color, alpha=0.5, linewidth=2)

        all_values = np.concatenate(all_rewards)
        overall_mean = np.mean(all_values)
        overall_std = np.std(all_values)

        model_stats[model_name] = {
            'mean': overall_mean,
            'std': overall_std,
            'color': color,
            'total_trajectories': int(all_values.size),
            'n_runs': len(all_rewards),
        }

        # Add mean line for this model
        plt.axvline(overall_mean, color=color, linestyle='--', linewidth=3, alpha=0.8)

    # Custom legend: one solid line per model with its mean ± std.
    legend_elements = [
        plt.Line2D([0], [0], color=s['color'], lw=3, alpha=0.8,
                   label=f'{name}: μ={s["mean"]:.2f}±{s["std"]:.2f}')
        for name, s in model_stats.items()
    ]

    # Use the real number of runs in the labels instead of a hard-coded 30.
    run_counts = sorted({s['n_runs'] for s in model_stats.values()}) or [0]
    runs_label = (str(run_counts[0]) if len(run_counts) == 1
                  else f"{run_counts[0]}-{run_counts[-1]}")

    plt.title(f'Model Comparison: Reward Distributions Across {runs_label} Independent Runs',
              fontsize=20, fontweight='bold', pad=25)
    plt.xlabel('Final Reward Values', fontsize=16, fontweight='bold')
    plt.ylabel('Density', fontsize=16, fontweight='bold')

    plt.grid(True, alpha=0.3, linestyle='--')
    plt.legend(handles=legend_elements, fontsize=14, frameon=True,
               framealpha=0.9, loc='upper right', fancybox=True, shadow=True)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")

    plt.show()

    # Print summary statistics
    print("\n" + "="*80)
    print(f"MODEL COMPARISON SUMMARY ({runs_label} independent runs each)")
    print("="*80)
    for model_name, s in model_stats.items():
        print(f"\n{model_name}:")
        print(f"  Mean: {s['mean']:.3f} ± {s['std']:.3f}")
        print(f"  Total Trajectories: {s['total_trajectories']}")

    return model_stats


# Overlay every method's per-run reward KDEs for oracle 1 and save the figure.
stats = plot_multi_model_reward_distributions(
    model_results_dict=model_results_dict_1,
    save_path='multi_model_reward_distributions.pdf',
)

In [None]:
def plot_multi_model_small_multiples(model_results_dict, save_path=None):
    """
    Small multiples showing each method's run variability with mean distributions and confidence intervals.
    All subplots share the same scale for easy comparison.

    Each model gets one panel containing:
    - a thin KDE curve per run;
    - the pointwise mean of those curves (thick line) with ±0.5σ and ±1σ bands
      across runs;
    - a dashed line at the model's overall mean reward.

    The panel border colour encodes the consistency score ``1 / (1 + CV)``,
    with ``CV = std(run means) / mean(run means)``: green if > 0.8, orange if
    > 0.6, red otherwise.

    Args:
        model_results_dict (dict): Model name -> {run index (0..n-1) -> run
            dict with a ``'rewards'`` entry}. Rewards must be a list, tuple or
            ndarray. For 2-D rewards the last column is used and non-finite
            values are dropped. Runs that are >2-D or non-numeric are skipped.
        save_path (str, optional): If given, the figure is saved there at 300 dpi.

    Returns:
        dict: Model name -> ``{'mean', 'consistency', 'run_std', 'cv', 'n_runs'}``
        for every model with at least one usable run. Returns an empty dict
        (and draws nothing) when no model has usable rewards.
    """
    import warnings
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # ---- Pass 1: clean every run's final rewards and find the shared x range.
    all_model_data = {}
    global_x_min = float('inf')
    global_x_max = float('-inf')
    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.asarray(rewards, dtype=float)
                if rewards.ndim > 2:
                    continue
                if rewards.ndim == 2:
                    rewards = rewards[:, -1]  # last column, plotted as 'Final Reward'
            except (TypeError, ValueError, IndexError):
                # Ragged/non-numeric rewards, or a 2-D array with no columns.
                continue
            rewards = rewards[np.isfinite(rewards)]
            if len(rewards) > 0:
                all_rewards.append(rewards)

        if all_rewards:
            all_model_data[model_name] = {
                'all_rewards': all_rewards,
                'run_means': np.array([np.mean(r) for r in all_rewards]),
            }
            global_x_min = min(global_x_min, min(np.min(r) for r in all_rewards))
            global_x_max = max(global_x_max, max(np.max(r) for r in all_rewards))

    if not all_model_data:
        warnings.warn("plot_multi_model_small_multiples: no usable rewards to plot.")
        return {}

    # Pad the shared range by 5%. If every reward is identical, pad by 1.0
    # instead, since a zero-width axis is invalid.
    x_span = global_x_max - global_x_min
    x_padding = x_span * 0.05 if x_span > 0 else 1.0
    global_x_min -= x_padding
    global_x_max += x_padding

    # Common x grid on which every KDE is evaluated.
    common_x = np.linspace(global_x_min, global_x_max, 300)

    n_models = len(model_results_dict)
    # squeeze=False keeps `axes` indexable even when there is a single model.
    _, axes = plt.subplots(1, n_models, figsize=(4*n_models, 6), dpi=300,
                           sharey=True, squeeze=False)
    axes = axes[0]

    model_stats = {}
    global_y_max = 0

    # ---- Pass 2: draw each model's panel on the shared scales.
    for idx, (model_name, data) in enumerate(all_model_data.items()):
        ax = axes[idx]
        all_rewards = data['all_rewards']
        run_means = data['run_means']
        color = colors.get(model_name, 'gray')

        run_std = np.std(run_means)
        cv = run_std / np.mean(run_means)
        consistency = 1 / (1 + cv)

        # Plot styling based on consistency
        if consistency > 0.8:
            alpha_individual, linewidth_individual = 0.3, 1.5
            consistency_label, border_color = "HIGH", 'green'
        elif consistency > 0.6:
            alpha_individual, linewidth_individual = 0.25, 1.2
            consistency_label, border_color = "MEDIUM", 'orange'
        else:
            alpha_individual, linewidth_individual = 0.2, 1.0
            consistency_label, border_color = "LOW", 'red'

        # Fit each run's KDE once. Reuse it for both the thin per-run curve and
        # the mean/band computation.
        kde_curves = []
        for rewards in all_rewards:
            if len(rewards) < 2:
                continue
            try:
                density = gaussian_kde(rewards)(common_x)
            except (np.linalg.LinAlgError, ValueError):
                # e.g. a run whose rewards are all identical (singular covariance).
                continue
            kde_curves.append(density)
            ax.plot(common_x, density, color=color,
                    alpha=alpha_individual, linewidth=linewidth_individual)
            global_y_max = max(global_y_max, np.max(density))

        # Mean distribution with ±1σ (outer) and ±0.5σ (inner) bands across runs.
        if len(all_rewards) > 1 and kde_curves:
            kde_curves = np.array(kde_curves)
            mean_curve = np.mean(kde_curves, axis=0)
            std_curve = np.std(kde_curves, axis=0)

            ax.plot(common_x, mean_curve, color=color, linewidth=4, alpha=0.9)
            ax.fill_between(common_x,
                            mean_curve - std_curve,
                            mean_curve + std_curve,
                            color=color, alpha=0.15)
            ax.fill_between(common_x,
                            mean_curve - 0.5*std_curve,
                            mean_curve + 0.5*std_curve,
                            color=color, alpha=0.25)

        # Overall mean line
        overall_mean = np.mean(np.concatenate(all_rewards))
        ax.axvline(overall_mean, color='black', linestyle='--', linewidth=3, alpha=0.8)

        ax.set_xlim(global_x_min, global_x_max)
        ax.set_title(f'{model_name.upper()}\nConsistency: {consistency_label}',
                     fontsize=14, fontweight='bold', pad=15)
        ax.set_xlabel('Final Reward', fontsize=12)
        if idx == 0:
            ax.set_ylabel('Density', fontsize=12)

        # Coloured border indicates consistency
        for spine in ax.spines.values():
            spine.set_edgecolor(border_color)
            spine.set_linewidth(3)

        stats_text = f'μ={overall_mean:.1f}\nσ_runs={run_std:.2f}\nCV={cv:.3f}'
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                fontsize=10)

        ax.grid(True, alpha=0.3)

        model_stats[model_name] = {
            'mean': overall_mean,
            'consistency': consistency,
            'run_std': run_std,
            'cv': cv,
            'n_runs': len(run_means),
        }

    # Consistent y-axis limits for all subplots. Skip this if no KDE could be
    # drawn, to avoid a zero-height axis.
    if global_y_max > 0:
        y_padding = global_y_max * 0.05
        for ax in axes:
            ax.set_ylim(0, global_y_max + y_padding)

    plt.suptitle('Model Consistency Analysis: Individual Runs + Mean Distribution with Confidence Intervals\n'
                 'Border Color: Green=Consistent, Orange=Moderate, Red=Variable',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary statistics, most consistent model first
    print("\n" + "="*80)
    print("MODEL CONSISTENCY SUMMARY")
    print("="*80)
    sorted_models = sorted(model_stats.items(), key=lambda x: x[1]['consistency'], reverse=True)
    for rank, (model_name, stats) in enumerate(sorted_models, 1):
        print(f"\n{rank}. {model_name.upper()}:")
        print(f"   Consistency Score: {stats['consistency']:.3f}")
        print(f"   Mean Reward: {stats['mean']:.2f}")
        print(f"   Run Std Dev: {stats['run_std']:.2f}")
        print(f"   Coefficient of Variation: {stats['cv']:.3f}")
        print(f"   Number of Runs: {stats['n_runs']}")

    return model_stats

# Enhanced version with shared scales
def plot_multi_model_small_multiples_enhanced(model_results_dict, save_path=None, show_percentiles=True):
    """
    Enhanced version with percentile bands and shared scales for easy comparison.

    Top row, one panel per model:
    - a KDE curve per run;
    - the pointwise mean KDE;
    - optionally, the 10–90% and 25–75% percentile bands of the per-run curves.
    The x and y ranges are shared by all models.

    Bottom row: a box plot of per-run mean rewards, on a y range shared by all
    models.

    Args:
        model_results_dict (dict): Model name -> {run index (0..n-1) -> run
            dict with a ``'rewards'`` entry}. Rewards must be a list, tuple or
            ndarray. For 2-D rewards the last column is used and non-finite
            values are dropped. Runs that are >2-D or non-numeric are skipped.
        save_path (str, optional): Base path. The figure is written with
            ``_enhanced`` inserted before the extension (``x.pdf`` ->
            ``x_enhanced.pdf``, ``x.png`` -> ``x_enhanced.png``), so it does not
            overwrite a plot saved to ``save_path`` itself.
        show_percentiles (bool): Whether to draw the percentile bands.

    Returns:
        dict: model name -> ``{'mean', 'consistency', 'run_std', 'cv', 'n_runs'}``,
        where ``'mean'`` is the mean of the per-run means. Returns an empty dict
        (and draws nothing) when no model has usable rewards.
    """
    import warnings
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # ---- Pass 1: clean rewards, collect per-run means and the shared ranges.
    all_model_data = {}
    global_x_min = float('inf')
    global_x_max = float('-inf')
    global_run_means_min = float('inf')
    global_run_means_max = float('-inf')
    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.asarray(rewards, dtype=float)
                if rewards.ndim > 2:
                    continue
                if rewards.ndim == 2:
                    rewards = rewards[:, -1]
            except (TypeError, ValueError, IndexError):
                # Ragged/non-numeric rewards, or a 2-D array with no columns.
                continue
            rewards = rewards[np.isfinite(rewards)]
            if len(rewards) > 0:
                all_rewards.append(rewards)

        if not all_rewards:
            continue

        run_means = np.array([np.mean(r) for r in all_rewards])
        all_model_data[model_name] = {'all_rewards': all_rewards, 'run_means': run_means}
        global_x_min = min(global_x_min, min(np.min(r) for r in all_rewards))
        global_x_max = max(global_x_max, max(np.max(r) for r in all_rewards))
        global_run_means_min = min(global_run_means_min, np.min(run_means))
        global_run_means_max = max(global_run_means_max, np.max(run_means))

    if not all_model_data:
        warnings.warn("plot_multi_model_small_multiples_enhanced: no usable rewards to plot.")
        return {}

    # Pad the shared ranges. A degenerate (zero-width) range gets a fixed pad
    # of 1.0 so matplotlib does not receive identical limits.
    x_span = global_x_max - global_x_min
    x_padding = x_span * 0.05 if x_span > 0 else 1.0
    global_x_min -= x_padding
    global_x_max += x_padding

    means_span = global_run_means_max - global_run_means_min
    means_padding = means_span * 0.1 if means_span > 0 else 1.0
    global_run_means_min -= means_padding
    global_run_means_max += means_padding

    common_x = np.linspace(global_x_min, global_x_max, 300)

    n_models = len(model_results_dict)
    # squeeze=False gives a (2, n_models) array even for a single model.
    _, axes = plt.subplots(2, n_models, figsize=(4*n_models, 10), dpi=300, squeeze=False)

    model_stats = {}
    global_y_max = 0

    # ---- Pass 2: draw each model's pair of panels.
    for idx, (model_name, data) in enumerate(all_model_data.items()):
        ax_top, ax_bottom = axes[0, idx], axes[1, idx]
        all_rewards = data['all_rewards']
        run_means = data['run_means']
        color = colors.get(model_name, 'gray')

        mean_of_means = np.mean(run_means)
        run_std = np.std(run_means)
        cv = run_std / mean_of_means
        consistency = 1 / (1 + cv)

        # Fit each run's KDE once and reuse it for the bands and mean curve.
        kde_curves = []
        for rewards in all_rewards:
            if len(rewards) < 2:
                continue
            try:
                density = gaussian_kde(rewards)(common_x)
            except (np.linalg.LinAlgError, ValueError):
                # e.g. a run whose rewards are all identical (singular covariance).
                continue
            kde_curves.append(density)
            ax_top.plot(common_x, density, color=color, alpha=0.2, linewidth=1)
            global_y_max = max(global_y_max, np.max(density))

        if len(all_rewards) > 1 and kde_curves:
            kde_curves = np.array(kde_curves)
            mean_curve = np.mean(kde_curves, axis=0)

            if show_percentiles:
                p10_curve, p25_curve, p75_curve, p90_curve = np.percentile(
                    kde_curves, [10, 25, 75, 90], axis=0)
                ax_top.fill_between(common_x, p10_curve, p90_curve,
                                    color=color, alpha=0.1)
                ax_top.fill_between(common_x, p25_curve, p75_curve,
                                    color=color, alpha=0.2)

            ax_top.plot(common_x, mean_curve, color=color, linewidth=4, alpha=0.9)

        ax_top.set_xlim(global_x_min, global_x_max)
        ax_top.set_title(f'{model_name.upper()}\nConsistency: {consistency:.3f}',
                         fontsize=12, fontweight='bold')
        ax_top.grid(True, alpha=0.3)

        # Bottom subplot: box plot of run means on the shared scale
        ax_bottom.boxplot([run_means], positions=[0], widths=0.6, patch_artist=True,
                          boxprops=dict(facecolor=color, alpha=0.7))
        ax_bottom.set_xlim(-0.5, 0.5)
        ax_bottom.set_ylim(global_run_means_min, global_run_means_max)
        ax_bottom.set_ylabel('Run Means', fontsize=10)
        ax_bottom.grid(True, alpha=0.3)

        stats_text = f'Mean: {mean_of_means:.1f}\nStd: {run_std:.2f}\nCV: {cv:.3f}'
        ax_bottom.text(0.02, 0.98, stats_text, transform=ax_bottom.transAxes,
                       verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                       fontsize=9)

        model_stats[model_name] = {
            'mean': mean_of_means,
            'consistency': consistency,
            'run_std': run_std,
            'cv': cv,
            'n_runs': len(run_means),
        }

    # Shared y-limits for the top row; the density label goes on the leftmost panel only.
    for idx in range(n_models):
        ax_top = axes[0, idx]
        if global_y_max > 0:
            ax_top.set_ylim(0, global_y_max + global_y_max * 0.05)
        if idx == 0:
            ax_top.set_ylabel('Density', fontsize=12)

    plt.suptitle('Enhanced Model Consistency Analysis with Shared Scales for Easy Comparison',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    if save_path:
        root, ext = os.path.splitext(save_path)
        plt.savefig(f"{root}_enhanced{ext}", dpi=300, bbox_inches='tight')
    plt.show()

    return model_stats

# Draw both small-multiples variants for oracle 1 (basic, then enhanced);
# each uses axis scales shared across all models.
stats_small_multiples = plot_multi_model_small_multiples(
    model_results_dict=model_results_dict_1,
    save_path='multi_model_small_multiples_shared_scale.pdf',
)
stats_enhanced = plot_multi_model_small_multiples_enhanced(
    model_results_dict=model_results_dict_1,
    save_path='multi_model_small_multiples_enhanced_shared_scale.pdf',
)

In [None]:
def plot_multi_model_combined_comparison(model_results_dict, save_path=None):
    """
    Single plot showing all models together with mean distributions and confidence intervals.
    All models share the same scale for direct comparison.

    For every model, the final rewards of each run are turned into a Gaussian KDE
    evaluated on one x-grid shared by all models. Each run's KDE is drawn as a thin
    line. The pointwise mean of a model's run KDEs is drawn as a thick line. Symmetric
    +/-1 sigma and +/-0.5 sigma bands (pointwise std across runs, lower edge clipped
    at zero density) are drawn as shaded areas. A dashed vertical line marks the
    mean of all of the model's pooled final rewards.

    Args:
        model_results_dict: Mapping ``model_name -> results_dict``. For every
            ``run_idx`` in ``range(len(results_dict))``,
            ``results_dict[run_idx]['rewards']`` is read.
            - Lists, tuples and ndarrays are used.
            - Other types, ragged or non-numeric data, and arrays with more than
              two dimensions are skipped.
            - For 2-D arrays only the last column is kept.
            - Non-finite values are dropped.
        save_path: If truthy, the figure is written there via ``plt.savefig``.

    Returns:
        dict: ``model_name -> {'mean', 'consistency', 'run_std', 'cv', 'n_runs'}``.
        ``cv`` is the std of the per-run mean rewards divided by their mean, and
        ``consistency = 1 / (1 + cv)``. Models without a usable run are omitted.
        If no model has usable data, a warning is issued, nothing is plotted and
        ``{}`` is returned.
    """
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # First pass: clean every run and find the x range shared by all models
    global_x_min = float('inf')
    global_x_max = float('-inf')
    all_model_data = {}

    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        run_means = []

        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.array(rewards, dtype=float)
                if rewards.ndim > 2:
                    continue
                if rewards.ndim == 2:
                    rewards = rewards[:, -1]  # final reward of each trajectory
            except (TypeError, ValueError, IndexError):
                # Ragged/non-numeric data or a 2-D array without columns: skip the run
                continue

            rewards = rewards[np.isfinite(rewards)]
            if rewards.size > 0:
                all_rewards.append(rewards)
                run_means.append(np.mean(rewards))

        if all_rewards:
            global_x_min = min(global_x_min, min(np.min(r) for r in all_rewards))
            global_x_max = max(global_x_max, max(np.max(r) for r in all_rewards))
            all_model_data[model_name] = {
                'all_rewards': all_rewards,
                'run_means': np.array(run_means)
            }

    if not all_model_data:
        # Without data the global range stays at +/-inf and the grid below would be NaN
        warnings.warn("plot_multi_model_combined_comparison: no finite rewards found; nothing to plot.")
        return {}

    # Pad the shared range by 5%. Use a unit-scale span when all rewards are identical,
    # otherwise the grid would collapse to a single point.
    x_span = global_x_max - global_x_min
    if x_span == 0:
        x_span = max(abs(global_x_min), 1.0)
    x_padding = x_span * 0.05
    global_x_min -= x_padding
    global_x_max += x_padding

    # Common grid so per-run KDE curves can be averaged pointwise
    common_x = np.linspace(global_x_min, global_x_max, 300)

    plt.figure(figsize=(16, 10), dpi=300)
    model_stats = {}

    # Second pass: draw each model
    for model_name, data in all_model_data.items():
        all_rewards = data['all_rewards']
        run_means = data['run_means']
        color = colors.get(model_name, 'gray')

        run_std = np.std(run_means)
        cv = run_std / np.mean(run_means)
        consistency = 1 / (1 + cv)
        # Computed for every model. It used to be set only when KDEs were drawn, which left
        # it undefined, or stale from the previous model, for single-run models.
        overall_mean = np.mean(np.concatenate(all_rewards))

        # Less consistent models get thinner individual-run lines
        if consistency > 0.8:
            linewidth_individual, consistency_label = 1.0, "HIGH"
        elif consistency > 0.6:
            linewidth_individual, consistency_label = 0.8, "MEDIUM"
        else:
            linewidth_individual, consistency_label = 0.6, "LOW"
        alpha_individual = 0.6

        # One KDE per run, computed once and reused for the thin lines and the mean curve
        kde_curves = []
        for rewards in all_rewards:
            if len(rewards) < 2:
                continue  # a KDE needs at least two points
            try:
                kde_curves.append(gaussian_kde(rewards)(common_x))
            except (np.linalg.LinAlgError, ValueError):
                continue  # degenerate run, e.g. all rewards identical

        # Individual runs with low alpha
        for curve in kde_curves:
            plt.plot(common_x, curve, color=color,
                     alpha=alpha_individual, linewidth=linewidth_individual)

        if kde_curves:
            kde_curves = np.array(kde_curves)
            mean_curve = np.mean(kde_curves, axis=0)
            std_curve = np.std(kde_curves, axis=0)

            # Symmetric bands, drawn before the mean line so they sit behind it.
            # Density cannot be negative, so the lower edge is clipped at zero.
            plt.fill_between(common_x,
                             np.clip(mean_curve - std_curve, 0, None),
                             mean_curve + std_curve,
                             color=color, alpha=0.2,
                             label=f'{model_name.upper()} ±1σ')
            plt.fill_between(common_x,
                             np.clip(mean_curve - 0.5*std_curve, 0, None),
                             mean_curve + 0.5*std_curve,
                             color=color, alpha=0.3)

            # Mean distribution (thicker line)
            plt.plot(common_x, mean_curve,
                     color=color, linewidth=4, alpha=0.9,
                     label=f'{model_name.upper()}: μ={overall_mean:.1f}, {consistency_label} consistency')

            # Overall mean line
            plt.axvline(overall_mean, color=color, linestyle='--',
                        linewidth=2, alpha=0.7)

        model_stats[model_name] = {
            'mean': overall_mean,
            'consistency': consistency,
            'run_std': run_std,
            'cv': cv,
            'n_runs': len(run_means)
        }

    # Styling
    plt.title('Model Comparison: Individual Runs + Mean Distributions with Confidence Intervals\n'
              'Thick lines = Mean distribution, Dashed lines = Overall means, Shaded areas = ±1σ confidence',
              fontsize=18, fontweight='bold', pad=25)
    plt.xlabel('Final Reward', fontsize=16, fontweight='bold')
    plt.ylabel('Density', fontsize=16, fontweight='bold')

    plt.grid(True, alpha=0.3, linestyle='--')
    plt.legend(fontsize=12, frameon=True, framealpha=0.9, loc='upper left')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    # Consistency ranking text box (most consistent first)
    stats_text = "Consistency Ranking:\n"
    sorted_models = sorted(model_stats.items(), key=lambda x: x[1]['consistency'], reverse=True)
    for rank, (model_name, stats) in enumerate(sorted_models, 1):
        stats_text += f"{rank}. {model_name.upper()}: {stats['consistency']:.3f}\n"

    plt.text(0.98, 0.98, stats_text, transform=plt.gca().transAxes,
             verticalalignment='top', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, edgecolor='black'),
             fontsize=11, fontweight='bold')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary statistics
    print("\n" + "="*80)
    print("MODEL CONSISTENCY SUMMARY - COMBINED COMPARISON")
    print("="*80)
    for rank, (model_name, stats) in enumerate(sorted_models, 1):
        print(f"\n{rank}. {model_name.upper()}:")
        print(f"   Consistency Score: {stats['consistency']:.3f}")
        print(f"   Mean Reward: {stats['mean']:.2f}")
        print(f"   Run Std Dev: {stats['run_std']:.2f}")
        print(f"   Coefficient of Variation: {stats['cv']:.3f}")
        print(f"   Number of Runs: {stats['n_runs']}")

    return model_stats

# Alternative version with even cleaner styling
def plot_multi_model_combined_clean(model_results_dict, save_path=None):
    """
    Clean version focusing just on mean distributions, with each model's overall mean in the legend.

    For each model, the per-run KDEs of the final rewards are evaluated on a grid
    shared by all models and averaged pointwise. The mean curve is drawn with a
    light symmetric +/-0.5 sigma band (lower edge clipped at zero density). The
    legend reports the pooled mean reward and the run-to-run coefficient of
    variation.

    Args:
        model_results_dict: Mapping ``model_name -> results_dict``. For every
            ``run_idx`` in ``range(len(results_dict))``,
            ``results_dict[run_idx]['rewards']`` is read.
            - Lists, tuples and ndarrays are used.
            - Other types, ragged or non-numeric data, and arrays with more than
              two dimensions are skipped.
            - For 2-D arrays only the last column is kept.
            - Non-finite values are dropped.
        save_path: If truthy, the figure is saved to ``save_path`` with
            ``'.pdf'`` replaced by ``'_clean.pdf'``.

    Returns:
        dict: ``model_name -> {'mean', 'consistency', 'cv'}``, where ``cv`` is
        the std of the per-run means divided by their mean and
        ``consistency = 1 / (1 + cv)``. If no model has usable data, a warning
        is issued, nothing is plotted and ``{}`` is returned.
    """
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # Collect cleaned per-run final rewards and the global x range
    global_x_min = float('inf')
    global_x_max = float('-inf')
    all_model_data = {}

    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        run_means = []

        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.array(rewards, dtype=float)
                if rewards.ndim > 2:
                    continue
                if rewards.ndim == 2:
                    rewards = rewards[:, -1]  # final reward of each trajectory
            except (TypeError, ValueError, IndexError):
                continue  # ragged/non-numeric data or 2-D array without columns

            rewards = rewards[np.isfinite(rewards)]
            if rewards.size > 0:
                all_rewards.append(rewards)
                run_means.append(np.mean(rewards))

        if all_rewards:
            global_x_min = min(global_x_min, min(np.min(r) for r in all_rewards))
            global_x_max = max(global_x_max, max(np.max(r) for r in all_rewards))
            all_model_data[model_name] = {
                'all_rewards': all_rewards,
                'run_means': np.array(run_means)
            }

    if not all_model_data:
        warnings.warn("plot_multi_model_combined_clean: no finite rewards found; nothing to plot.")
        return {}

    # 5% padding. Use a unit-scale span if every reward is identical.
    x_span = global_x_max - global_x_min
    if x_span == 0:
        x_span = max(abs(global_x_min), 1.0)
    x_padding = x_span * 0.05
    global_x_min -= x_padding
    global_x_max += x_padding

    common_x = np.linspace(global_x_min, global_x_max, 300)

    plt.figure(figsize=(14, 8), dpi=300)
    model_stats = {}

    # Plot each model
    for model_name, data in all_model_data.items():
        all_rewards = data['all_rewards']
        run_means = data['run_means']

        color = colors.get(model_name, 'gray')
        overall_mean = np.mean(np.concatenate(all_rewards))
        cv = np.std(run_means) / np.mean(run_means)
        consistency = 1 / (1 + cv)

        # One KDE per run with at least two points. Degenerate runs are skipped.
        kde_curves = []
        for rewards in all_rewards:
            if len(rewards) < 2:
                continue
            try:
                kde_curves.append(gaussian_kde(rewards)(common_x))
            except (np.linalg.LinAlgError, ValueError):
                continue

        # Plot whenever at least one KDE exists. Single-run models used to be dropped silently.
        if kde_curves:
            kde_curves = np.array(kde_curves)
            mean_curve = np.mean(kde_curves, axis=0)
            std_curve = np.std(kde_curves, axis=0)

            plt.plot(common_x, mean_curve,
                     color=color, linewidth=3, alpha=0.9,
                     label=f'{model_name.upper()}: μ={overall_mean:.1f} (CV={cv:.3f})')

            # Light symmetric band. Density cannot go below zero.
            plt.fill_between(common_x,
                             np.clip(mean_curve - 0.5*std_curve, 0, None),
                             mean_curve + 0.5*std_curve,
                             color=color, alpha=0.2)

        model_stats[model_name] = {
            'mean': overall_mean,
            'consistency': consistency,
            'cv': cv
        }

    # Report the actual number of runs instead of a hard-coded count
    run_counts = sorted({len(data['run_means']) for data in all_model_data.values()})
    runs_desc = (str(run_counts[0]) if len(run_counts) == 1
                 else f'{run_counts[0]}-{run_counts[-1]}')

    plt.title('Model Performance Comparison: Mean Reward Distributions\n'
              f'Lines show average distribution across {runs_desc} independent runs',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Final Reward', fontsize=14, fontweight='bold')
    plt.ylabel('Density', fontsize=14, fontweight='bold')

    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12, frameon=True, framealpha=0.9)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path.replace('.pdf', '_clean.pdf'), dpi=300, bbox_inches='tight')
    plt.show()

    return model_stats

# Call the functions
# NOTE(review): `model_results_dict_1` is expected to come from an earlier notebook cell — confirm.
# plot_multi_model_combined_comparison saves to save_path unchanged.
stats_combined = plot_multi_model_combined_comparison(
    model_results_dict_1, 
    save_path='multi_model_combined_comparison.pdf'
)

# Clean version
# plot_multi_model_combined_clean replaces '.pdf' with '_clean.pdf' in save_path, so the
# file written here is 'multi_model_combined_clean_clean.pdf'.
# Note: `stats_clean` is reassigned by the plot_multi_model_clean_kde call in a later cell.
stats_clean = plot_multi_model_combined_clean(
    model_results_dict_1,
    save_path='multi_model_combined_clean.pdf'
)

In [None]:
def plot_multi_model_combined_boxplot_distribution(model_results_dict, save_path=None):
    """
    Two-panel plot: boxplots on top, distributions on bottom with shared x-axis.

    Top panel: one horizontal boxplot per model over its per-run mean rewards,
    with the individual run means overlaid as points. Bottom panel: the pointwise
    mean of the per-run KDEs on a shared grid, symmetric +/-1 sigma and
    +/-0.5 sigma bands (lower edge clipped at zero density), and a dashed line at
    the model's pooled mean reward.

    Args:
        model_results_dict: Mapping ``model_name -> results_dict``. For every
            ``run_idx`` in ``range(len(results_dict))``,
            ``results_dict[run_idx]['rewards']`` is read.
            - Lists, tuples and ndarrays are used.
            - Other types, ragged or non-numeric data, and arrays with more than
              two dimensions are skipped.
            - For 2-D arrays only the last column is kept.
            - Non-finite values are dropped.
        save_path: If truthy, the figure is written there via ``plt.savefig``.

    Returns:
        dict: ``model_name -> {'mean', 'consistency', 'run_std', 'cv', 'n_runs',
        'run_means'}``. If no model has usable data, a warning is issued, nothing
        is plotted and ``{}`` is returned.
    """
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # First pass: clean every run and find the x range shared by all models
    global_x_min = float('inf')
    global_x_max = float('-inf')
    all_model_data = {}

    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        run_means = []

        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.array(rewards, dtype=float)
                if rewards.ndim > 2:
                    continue
                if rewards.ndim == 2:
                    rewards = rewards[:, -1]  # final reward of each trajectory
            except (TypeError, ValueError, IndexError):
                continue  # ragged/non-numeric data or 2-D array without columns

            rewards = rewards[np.isfinite(rewards)]
            if rewards.size > 0:
                all_rewards.append(rewards)
                run_means.append(np.mean(rewards))

        if all_rewards:
            global_x_min = min(global_x_min, min(np.min(r) for r in all_rewards))
            global_x_max = max(global_x_max, max(np.max(r) for r in all_rewards))
            all_model_data[model_name] = {
                'all_rewards': all_rewards,
                'run_means': np.array(run_means)
            }

    if not all_model_data:
        warnings.warn("plot_multi_model_combined_boxplot_distribution: no finite rewards found; nothing to plot.")
        return {}

    # 5% padding. Use a unit-scale span if every reward is identical.
    x_span = global_x_max - global_x_min
    if x_span == 0:
        x_span = max(abs(global_x_min), 1.0)
    x_padding = x_span * 0.05
    global_x_min -= x_padding
    global_x_max += x_padding

    # Common grid for the KDE curves
    common_x = np.linspace(global_x_min, global_x_max, 300)

    _, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(16, 12), dpi=300,
                                          sharex=True, gridspec_kw={'height_ratios': [1, 1.5]})

    model_stats = {}

    # TOP PANEL: horizontal boxplot of per-run means, one row per model
    for y_pos, (model_name, data) in enumerate(all_model_data.items()):
        all_rewards = data['all_rewards']
        run_means = data['run_means']
        color = colors.get(model_name, 'gray')

        ax_top.boxplot([run_means], positions=[y_pos], vert=False, widths=0.6,
                       patch_artist=True, showfliers=True,
                       boxprops=dict(facecolor=color, alpha=0.7),
                       medianprops=dict(color='black', linewidth=2),
                       flierprops=dict(marker='o', markerfacecolor=color,
                                       markersize=4, alpha=0.6))

        # Overlay individual run means so small samples stay visible
        ax_top.scatter(run_means, [y_pos] * len(run_means),
                       color=color, alpha=0.4, s=20, zorder=10)

        run_std = np.std(run_means)
        cv = run_std / np.mean(run_means)
        model_stats[model_name] = {
            'mean': np.mean(np.concatenate(all_rewards)),
            'consistency': 1 / (1 + cv),
            'run_std': run_std,
            'cv': cv,
            'n_runs': len(run_means),
            'run_means': run_means
        }

    # Report the actual number of runs instead of a hard-coded count
    run_counts = sorted({stats['n_runs'] for stats in model_stats.values()})
    runs_desc = (str(run_counts[0]) if len(run_counts) == 1
                 else f'{run_counts[0]}-{run_counts[-1]}')

    # Style top panel
    ax_top.set_yticks(range(len(all_model_data)))
    ax_top.set_yticklabels([name.upper() for name in all_model_data.keys()], fontsize=12)
    ax_top.set_ylabel('Methods', fontsize=14, fontweight='bold')
    ax_top.set_title('Run-Level Performance Distribution (Boxplots)\n'
                     f'Each box shows distribution of mean rewards across {runs_desc} runs',
                     fontsize=16, fontweight='bold', pad=20)
    ax_top.grid(True, alpha=0.3, axis='x')
    ax_top.set_xlim(global_x_min, global_x_max)

    # Ranking text box (most consistent first)
    stats_text = "Run Statistics:\n"
    sorted_models = sorted(model_stats.items(), key=lambda x: x[1]['consistency'], reverse=True)
    for rank, (model_name, stats) in enumerate(sorted_models, 1):
        stats_text += f"{rank}. {model_name.upper()}: CV={stats['cv']:.3f}\n"

    ax_top.text(0.98, 0.98, stats_text, transform=ax_top.transAxes,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, edgecolor='black'),
                fontsize=10, fontweight='bold')

    # BOTTOM PANEL: mean KDE per model with variability bands
    for model_name, data in all_model_data.items():
        color = colors.get(model_name, 'gray')
        overall_mean = model_stats[model_name]['mean']

        kde_curves = []
        for rewards in data['all_rewards']:
            if len(rewards) < 2:
                continue  # a KDE needs at least two points
            try:
                kde_curves.append(gaussian_kde(rewards)(common_x))
            except (np.linalg.LinAlgError, ValueError):
                continue  # degenerate run, e.g. all rewards identical

        # Single-run models are drawn too; previously they were left out of this panel
        if not kde_curves:
            continue

        kde_curves = np.array(kde_curves)
        mean_curve = np.mean(kde_curves, axis=0)
        std_curve = np.std(kde_curves, axis=0)

        # Variability bands. Density is non-negative, so the lower edge is clipped at 0.
        ax_bottom.fill_between(common_x,
                               np.clip(mean_curve - std_curve, 0, None),
                               mean_curve + std_curve,
                               color=color, alpha=0.15)
        ax_bottom.fill_between(common_x,
                               np.clip(mean_curve - 0.5*std_curve, 0, None),
                               mean_curve + 0.5*std_curve,
                               color=color, alpha=0.25)

        # Mean distribution (thick line) and overall mean marker
        ax_bottom.plot(common_x, mean_curve,
                       color=color, linewidth=4, alpha=0.9,
                       label=f'{model_name.upper()}: μ={overall_mean:.1f}')
        ax_bottom.axvline(overall_mean, color=color, linestyle='--',
                          linewidth=2, alpha=0.7)

    # Style bottom panel
    ax_bottom.set_xlabel('Final Reward', fontsize=14, fontweight='bold')
    ax_bottom.set_ylabel('Density', fontsize=14, fontweight='bold')
    ax_bottom.set_title('Aggregate Reward Distributions\n'
                        'Mean distribution across all runs with confidence intervals',
                        fontsize=16, fontweight='bold', pad=20)
    ax_bottom.grid(True, alpha=0.3)
    ax_bottom.legend(fontsize=12, frameon=True, framealpha=0.9, loc='upper left')
    ax_bottom.set_xlim(global_x_min, global_x_max)

    plt.suptitle('Model Performance Analysis: Run-Level Variability vs Aggregate Distributions\n'
                 'Top: Individual run performance | Bottom: Overall reward distributions',
                 fontsize=18, fontweight='bold', y=0.98)

    for ax in (ax_top, ax_bottom):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    plt.tight_layout()
    plt.subplots_adjust(top=0.90)  # make room for the suptitle

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary statistics
    print("\n" + "="*80)
    print("MODEL PERFORMANCE SUMMARY - BOXPLOT + DISTRIBUTION ANALYSIS")
    print("="*80)
    for rank, (model_name, stats) in enumerate(sorted_models, 1):
        print(f"\n{rank}. {model_name.upper()}:")
        print(f"   Mean Reward: {stats['mean']:.2f}")
        print(f"   Run-to-Run CV: {stats['cv']:.3f}")
        print(f"   Consistency Score: {stats['consistency']:.3f}")
        print(f"   Run Std Dev: {stats['run_std']:.2f}")
        print(f"   Number of Runs: {stats['n_runs']}")

    return model_stats

# Alternative version with violin plots instead of boxplots
def plot_multi_model_combined_violin_distribution(model_results_dict, save_path=None):
    """
    Two-panel plot: violin plots on top, distributions on bottom with shared x-axis.

    Top panel: one horizontal violin per model over its per-run mean rewards,
    with mean and median markers. Bottom panel: the pointwise mean of the per-run
    KDEs on a shared grid, a symmetric +/-0.5 sigma band (lower edge clipped at
    zero density), and a dashed line at the model's pooled mean reward.

    Args:
        model_results_dict: Mapping ``model_name -> results_dict``. For every
            ``run_idx`` in ``range(len(results_dict))``,
            ``results_dict[run_idx]['rewards']`` is read.
            - Lists, tuples and ndarrays are used.
            - Other types, ragged or non-numeric data, and arrays with more than
              two dimensions are skipped.
            - For 2-D arrays only the last column is kept.
            - Non-finite values are dropped.
        save_path: If truthy, the figure is saved to ``save_path`` with
            ``'.pdf'`` replaced by ``'_violin.pdf'``.

    Returns:
        dict: ``model_name -> {'mean', 'consistency'}``. If no model has usable
        data, a warning is issued, nothing is plotted and ``{}`` is returned.
    """
    from scipy.stats import gaussian_kde

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # Clean every run (final rewards, finite values only) and find the shared x range
    all_model_data = {}
    global_x_min = float('inf')
    global_x_max = float('-inf')

    for model_name, results_dict in model_results_dict.items():
        all_rewards = []
        run_means = []

        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.array(rewards, dtype=float)
                if rewards.ndim > 2:
                    continue
                if rewards.ndim == 2:
                    rewards = rewards[:, -1]  # final reward of each trajectory
            except (TypeError, ValueError, IndexError):
                continue  # ragged/non-numeric data or 2-D array without columns

            rewards = rewards[np.isfinite(rewards)]
            if rewards.size > 0:
                all_rewards.append(rewards)
                run_means.append(np.mean(rewards))

        if all_rewards:
            global_x_min = min(global_x_min, min(np.min(r) for r in all_rewards))
            global_x_max = max(global_x_max, max(np.max(r) for r in all_rewards))
            all_model_data[model_name] = {
                'all_rewards': all_rewards,
                'run_means': np.array(run_means)
            }

    if not all_model_data:
        # Otherwise set_xlim would receive NaN limits and raise
        warnings.warn("plot_multi_model_combined_violin_distribution: no finite rewards found; nothing to plot.")
        return {}

    # 5% padding. Use a unit-scale span if every reward is identical.
    x_span = global_x_max - global_x_min
    if x_span == 0:
        x_span = max(abs(global_x_min), 1.0)
    x_padding = x_span * 0.05
    global_x_min -= x_padding
    global_x_max += x_padding
    common_x = np.linspace(global_x_min, global_x_max, 300)

    _, (ax_top, ax_bottom) = plt.subplots(2, 1, figsize=(16, 12), dpi=300,
                                          sharex=True, gridspec_kw={'height_ratios': [1, 1.5]})

    # TOP PANEL: horizontal violins of the per-run means
    model_names = list(all_model_data)
    violin_data = [all_model_data[name]['run_means'] for name in model_names]
    y_positions = list(range(len(model_names)))

    parts = ax_top.violinplot(violin_data, positions=y_positions, vert=False,
                              showmeans=True, showmedians=True)
    for pc, name in zip(parts['bodies'], model_names):
        pc.set_facecolor(colors.get(name, 'gray'))
        pc.set_alpha(0.7)

    ax_top.set_yticks(y_positions)
    ax_top.set_yticklabels([name.upper() for name in model_names], fontsize=12)
    ax_top.set_ylabel('Methods', fontsize=14, fontweight='bold')
    ax_top.set_title('Run-Level Performance Distribution (Violin Plots)',
                     fontsize=16, fontweight='bold', pad=20)
    ax_top.grid(True, alpha=0.3, axis='x')
    ax_top.set_xlim(global_x_min, global_x_max)

    # BOTTOM PANEL: mean KDE per model with a +/-0.5 sigma band
    model_stats = {}
    for model_name, data in all_model_data.items():
        all_rewards = data['all_rewards']
        run_means = data['run_means']

        color = colors.get(model_name, 'gray')
        overall_mean = np.mean(np.concatenate(all_rewards))
        consistency = 1 / (1 + np.std(run_means) / np.mean(run_means))

        kde_curves = []
        for rewards in all_rewards:
            if len(rewards) < 2:
                continue  # a KDE needs at least two points
            try:
                kde_curves.append(gaussian_kde(rewards)(common_x))
            except (np.linalg.LinAlgError, ValueError):
                continue  # degenerate run, e.g. all rewards identical

        # Drawn even for single-run models, which were previously left out
        if kde_curves:
            kde_curves = np.array(kde_curves)
            mean_curve = np.mean(kde_curves, axis=0)
            std_curve = np.std(kde_curves, axis=0)

            # Density is non-negative, so the lower edge is clipped at 0
            ax_bottom.fill_between(common_x,
                                   np.clip(mean_curve - 0.5*std_curve, 0, None),
                                   mean_curve + 0.5*std_curve,
                                   color=color, alpha=0.25)

            ax_bottom.plot(common_x, mean_curve,
                           color=color, linewidth=4, alpha=0.9,
                           label=f'{model_name.upper()}: μ={overall_mean:.1f}')

            ax_bottom.axvline(overall_mean, color=color, linestyle='--',
                              linewidth=2, alpha=0.7)

        model_stats[model_name] = {
            'mean': overall_mean,
            'consistency': consistency
        }

    ax_bottom.set_xlabel('Final Reward', fontsize=14, fontweight='bold')
    ax_bottom.set_ylabel('Density', fontsize=14, fontweight='bold')
    ax_bottom.set_title('Aggregate Reward Distributions',
                        fontsize=16, fontweight='bold', pad=20)
    ax_bottom.grid(True, alpha=0.3)
    ax_bottom.legend(fontsize=12, frameon=True, framealpha=0.9)
    ax_bottom.set_xlim(global_x_min, global_x_max)

    plt.suptitle('Model Performance: Violin + Distribution Analysis',
                 fontsize=18, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path.replace('.pdf', '_violin.pdf'), dpi=300, bbox_inches='tight')
    plt.show()

    return model_stats

# # Call the functions
# stats_boxplot = plot_multi_model_combined_boxplot_distribution(
#     model_results_dict_1, 
#     save_path='multi_model_boxplot_distribution.pdf'
# )

# # Violin version
# stats_violin = plot_multi_model_combined_violin_distribution(
#     model_results_dict_1,
#     save_path='multi_model_violin_distribution.pdf'
# )

In [None]:


# # Call the corrected functions
# stats_corrected = plot_multi_model_reward_distributions_with_ci_corrected(
#     model_results_dict_1, 
#     save_path='multi_model_distributions_corrected.pdf'
# )

# Or use the clean version
# NOTE(review): `plot_multi_model_clean_kde` is not defined in the visible cells (the
# cell above is empty) — confirm it is defined earlier in the notebook, otherwise this
# raises NameError. This call also overwrites the `stats_clean` value returned by the
# plot_multi_model_combined_clean call in an earlier cell.
stats_clean = plot_multi_model_clean_kde(
    model_results_dict_1,
    save_path='multi_model_distributions_clean.pdf'
)

In [None]:
def plot_multi_model_reward_boxviolin(model_results_dict, save_path=None):
    """
    Create box + violin plot showing distribution shapes and statistics.

    All final rewards from all runs of a model are pooled into one sample.
    Each model gets a vertical violin (with mean and median markers) and a narrow
    translucent boxplot on top of it (outliers hidden). A text box in the
    lower-left corner lists each model's pooled mean and std. Colors follow the
    per-method palette used by the other comparison plots; unknown methods fall
    back to the Set1 colormap.

    Args:
        model_results_dict: Mapping ``model_name -> results_dict``. For every
            ``run_idx`` in ``range(len(results_dict))``,
            ``results_dict[run_idx]['rewards']`` is read.
            - Lists, tuples and ndarrays are used.
            - Other types, ragged or non-numeric data, and arrays with more than
              two dimensions are skipped. Previously these crashed, unlike the
              sibling plots, which skip them.
            - For 2-D arrays only the last column is kept.
            - Non-finite values are dropped.
        save_path: If truthy, the figure is written there via ``plt.savefig``.

    Returns:
        None. If no model has usable data, a warning is issued and nothing is plotted.
    """
    palette = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
               'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    # Pool each model's cleaned final rewards across runs
    all_data = []
    method_names = []

    for model_name, results_dict in model_results_dict.items():
        run_rewards = []

        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            if not isinstance(rewards, (list, tuple, np.ndarray)):
                continue
            try:
                rewards = np.array(rewards, dtype=float)
                if rewards.ndim == 1:
                    final_rewards = rewards
                elif rewards.ndim == 2:
                    final_rewards = rewards[:, -1]  # final reward of each trajectory
                else:
                    continue
            except (TypeError, ValueError, IndexError):
                continue  # ragged/non-numeric data or 2-D array without columns

            final_rewards = final_rewards[np.isfinite(final_rewards)]
            if final_rewards.size > 0:
                run_rewards.append(final_rewards)

        if run_rewards:
            all_data.append(np.concatenate(run_rewards))
            method_names.append(model_name)

    if not all_data:
        # violinplot cannot draw an empty list of datasets
        warnings.warn("plot_multi_model_reward_boxviolin: no finite rewards found; nothing to plot.")
        return

    plt.figure(figsize=(14, 8), dpi=300)
    positions = range(len(method_names))

    # Same per-method colors as the other comparison plots. Set1 is only the fallback.
    fallback_colors = plt.cm.Set1(np.linspace(0, 1, len(method_names)))
    colors = [palette.get(name, fallback_colors[i]) for i, name in enumerate(method_names)]

    # Violin plot
    parts = plt.violinplot(all_data, positions=positions,
                           showmeans=True, showmedians=True)
    for pc, color in zip(parts['bodies'], colors):
        pc.set_facecolor(color)
        pc.set_alpha(0.7)

    # Box plot overlay (outliers hidden for a cleaner look)
    bp = plt.boxplot(all_data, positions=positions,
                     widths=0.3, patch_artist=True,
                     boxprops=dict(alpha=0.3),
                     showfliers=False)
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    # Styling
    plt.title('Model Comparison: Distribution Shapes and Statistics',
              fontsize=18, fontweight='bold', pad=20)
    plt.xlabel('Methods', fontsize=16, fontweight='bold')
    plt.ylabel('Final Reward Values', fontsize=16, fontweight='bold')

    plt.xticks(positions, method_names, fontsize=14)
    plt.yticks(fontsize=14)
    plt.grid(True, alpha=0.3, axis='y')

    # Statistics annotation
    stats_lines = ["Statistics Summary:"]
    for method, data in zip(method_names, all_data):
        stats_lines.append(f"{method}: μ={np.mean(data):.1f}±{np.std(data):.1f}")
    stats_text = "\n".join(stats_lines)

    # Anchor the box by its bottom edge inside the axes. The old top-anchored box at
    # y=0.12 grew downward past the x-axis and overlapped the tick labels.
    plt.text(0.02, 0.02, stats_text, transform=plt.gca().transAxes,
             verticalalignment='bottom', horizontalalignment='left',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9),
             fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

# Render the pooled box + violin comparison; the figure is saved to save_path unchanged.
# NOTE(review): relies on `model_results_dict_1` from an earlier notebook cell — confirm.
plot_multi_model_reward_boxviolin(
    model_results_dict_1,
    save_path='multi_model_reward_boxviolin.pdf'
)

In [None]:
def plot_multi_model_ridgeline(model_results_dict, save_path=None):
    """
    Create ridgeline plot with stacked distributions showing mean ± std.
    """
    fig, axes = plt.subplots(len(model_results_dict), 1, 
                            figsize=(12, 2*len(model_results_dict)), 
                            dpi=300, sharex=True)
    
    if len(model_results_dict) == 1:
        axes = [axes]
    
    # Use the same color scheme as your other plots
    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12', 
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}
    
    for idx, (model_name, results_dict) in enumerate(model_results_dict.items()):
        # Collect all rewards
        all_rewards = []
        
        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            
            try:
                if isinstance(rewards, (list, tuple)):
                    rewards = np.array(rewards, dtype=float)
                elif isinstance(rewards, np.ndarray):
                    rewards = rewards.astype(float)
                else:
                    continue
                
                # Extract final rewards based on dimensionality
                if rewards.ndim == 1:
                    final_rewards = rewards
                elif rewards.ndim == 2:
                    final_rewards = rewards[:, -1]  # Take last column
                else:
                    continue
                
                # Remove any NaN or infinite values
                final_rewards = final_rewards[np.isfinite(final_rewards)]
                if len(final_rewards) > 0:
                    all_rewards.extend(final_rewards)
                    
            except Exception as e:
                continue
        
        if all_rewards:
            all_rewards = np.array(all_rewards)
            color = colors.get(model_name, 'gray')
            
            # Plot distribution
            sns.kdeplot(data=all_rewards, ax=axes[idx], color=color, fill=True, alpha=0.7)
            
            # Calculate statistics
            mean_val = np.mean(all_rewards)
            std_val = np.std(all_rewards)
            
            # Add mean line
            axes[idx].axvline(mean_val, color='black', linestyle='--', linewidth=2, alpha=0.8)
            
            # Add ±1 std shaded region
            axes[idx].axvspan(mean_val - std_val, mean_val + std_val, 
                             color=color, alpha=0.2)
            
            # Styling for each subplot with mean ± std
            axes[idx].set_ylabel(f'{model_name.upper()}\n(μ={mean_val:.2f}±{std_val:.2f})', 
                               fontsize=12, fontweight='bold')
            axes[idx].grid(True, alpha=0.3)
            axes[idx].set_yticks([])  # Remove y-ticks for cleaner look
            
            # Add text box with additional stats
            stats_text = f'n={len(all_rewards)}'
            axes[idx].text(0.02, 0.98, stats_text, transform=axes[idx].transAxes,
                          verticalalignment='top', horizontalalignment='left',
                          bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                          fontsize=10)
    
    # Overall styling
    axes[-1].set_xlabel('Final Reward Values', fontsize=12, fontweight='bold')
    plt.suptitle('Ridgeline Plot: Model Performance Distributions\n'
                 'Each distribution represents 6000 trajectories (30 models × 200 simulations each)',
                fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()

plot_multi_model_ridgeline(model_results_dict=model_results_dict_1,
                           save_path='multi_model_ridgeline.pdf')

In [None]:

# # Run the analysis
# save_dir = 'analysis_outputs/statistical_analysis_oracle_1'
# os.makedirs(save_dir, exist_ok=True)
# summary_df, test_results, run_means = perform_comprehensive_statistical_analysis(
#     model_results_dict_1, 
#     save_dir=save_dir
# )

In [None]:
import json

# Collect the random seed of every run of every model, for each oracle, and save
# them to seeds.json as {"<model>_oracle_<k>": {"run_<idx>": seed, ...}, ...}.
all_seeds = {}
for oracle_tag, results_by_model in (
    ("oracle_1", model_results_dict_1),
    ("oracle_2", model_results_dict_2),
    ("oracle_3", model_results_dict_3),
):
    for model_name, runs in results_by_model.items():
        seeds = {}
        for run_idx, run_data in runs.items():
            seed = run_data['seed']
            # Seeds restored from joblib files may be NumPy scalars, which json.dump
            # rejects; .item() converts them to the equivalent Python number.
            seeds[f"run_{run_idx}"] = seed.item() if isinstance(seed, np.generic) else seed
        all_seeds[f"{model_name}_{oracle_tag}"] = seeds

# Save to JSON file
with open('seeds.json', 'w') as f:
    json.dump(all_seeds, f, indent=2)

print("Seeds saved to seeds.json")
print(f"Total models across all oracles: {len(all_seeds)}")
for model_key, seeds in all_seeds.items():
    print(f"{model_key}: {len(seeds)} runs")

In [None]:
def plot_multi_model_reward_distributions_with_ci(model_results_dict, save_path=None, confidence_level=0.95):
    """
    Plot reward distributions for multiple models with confidence intervals.
    Shows single mean line per method with shaded confidence intervals.
    
    Args:
        model_results_dict (dict): Dictionary where keys are model names and values are 
                                 dictionaries with run results containing 'rewards' arrays
        save_path (str, optional): Path to save the plot
        confidence_level (float): Confidence level for intervals (default 0.95 for 95% CI)
    """
    from scipy.stats import gaussian_kde
    import numpy as np
    
    # Create professional distribution plot
    plt.figure(figsize=(14, 10), dpi=300)
    
    # Define color palette for different models
    colors = plt.cm.Set1(np.linspace(0, 1, len(model_results_dict)))
    
    model_stats = {}
    
    for model_idx, (model_name, results_dict) in enumerate(model_results_dict.items()):
        # Collect all rewards from all runs for this model
        all_rewards = []
        run_means = []  # Store mean of each run for CI calculation
        
        for run_idx in range(len(results_dict)):
            rewards = results_dict[run_idx]['rewards']
            
            # Handle different reward shapes and ensure numeric type
            if isinstance(rewards, (list, tuple)):
                rewards = np.array(rewards, dtype=float)
            elif isinstance(rewards, np.ndarray):
                rewards = rewards.astype(float)
            
            # Extract final rewards based on dimensionality
            if rewards.ndim == 1:
                final_rewards = rewards
            elif rewards.ndim == 2:
                final_rewards = rewards[:, -1]  # Take last column
            else:
                raise ValueError(f"Unexpected rewards shape: {rewards.shape}")
            
            # Remove any NaN or infinite values
            final_rewards = final_rewards[np.isfinite(final_rewards)]
            
            if len(final_rewards) == 0:
                print(f"Warning: No valid rewards found for {model_name}, run {run_idx}")
                continue
                
            all_rewards.append(final_rewards)
            run_means.append(np.mean(final_rewards))
        
        if len(all_rewards) == 0:
            print(f"Warning: No valid data for {model_name}, skipping...")
            continue
        
        # Calculate overall statistics
        all_values = np.concatenate(all_rewards)
        
        # Ensure all_values is numeric and finite
        all_values = all_values[np.isfinite(all_values)]
        
        if len(all_values) == 0:
            print(f"Warning: No finite values for {model_name}, skipping...")
            continue
            
        overall_mean = np.mean(all_values)
        overall_std = np.std(all_values)
        
        # Calculate confidence intervals from run means
        run_means = np.array(run_means)
        run_means = run_means[np.isfinite(run_means)]  # Remove any NaN values
        
        if len(run_means) == 0:
            print(f"Warning: No valid run means for {model_name}, skipping...")
            continue
            
        run_mean_avg = np.mean(run_means)
        run_std_error = np.std(run_means) / np.sqrt(len(run_means))
        
        # Calculate confidence interval (95% by default)
        z_score = 1.96 if confidence_level == 0.95 else 2.576 if confidence_level == 0.99 else 1.645
        ci_lower = run_mean_avg - z_score * run_std_error
        ci_upper = run_mean_avg + z_score * run_std_error
        
        # Create KDE for the overall distribution
        try:
            kde = gaussian_kde(all_values)
            x_range = np.linspace(all_values.min(), all_values.max(), 1000)
            density = kde(x_range)
        except Exception as e:
            print(f"Warning: KDE failed for {model_name}: {e}")
            print(f"Data type: {all_values.dtype}, shape: {all_values.shape}")
            print(f"Sample values: {all_values[:5] if len(all_values) >= 5 else all_values}")
            continue
        
        # Plot main distribution line
        plt.plot(x_range, density, color=colors[model_idx], linewidth=3, 
                label=f'{model_name}: μ={overall_mean:.2f}±{overall_std:.2f}')
        
        # Create confidence interval bounds for the density
        density_std = np.std(density)
        upper_bound = density + 0.3 * density_std
        lower_bound = np.maximum(density - 0.3 * density_std, 0)
        
        # Fill between for confidence interval visualization
        plt.fill_between(x_range, lower_bound, upper_bound, 
                        color=colors[model_idx], alpha=0.2)
        
        # Add vertical lines for confidence interval of means
        plt.axvline(ci_lower, color=colors[model_idx], linestyle=':', alpha=0.7, linewidth=2)
        plt.axvline(ci_upper, color=colors[model_idx], linestyle=':', alpha=0.7, linewidth=2)
        plt.axvline(overall_mean, color=colors[model_idx], linestyle='--', 
                   linewidth=3, alpha=0.8)
        
        # Store stats for summary
        model_stats[model_name] = {
            'mean': overall_mean,
            'std': overall_std,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'color': colors[model_idx],
            'total_trajectories': sum(len(rewards) for rewards in all_rewards),
            'n_runs': len(run_means)
        }
    
    # Only proceed with plotting if we have valid data
    if not model_stats:
        print("Error: No valid data found for any model!")
        return {}
    
    # Styling
    plt.title(f'Model Comparison: Mean Distributions with {int(confidence_level*100)}% Confidence Intervals\n'
              f'(Based on Independent Runs)', 
              fontsize=18, fontweight='bold', pad=25)
    plt.xlabel('Final Reward Values', fontsize=16, fontweight='bold')
    plt.ylabel('Density', fontsize=16, fontweight='bold')
    
    # Professional styling
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.legend(fontsize=12, frameon=True, framealpha=0.9, loc='upper right')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    
    # Increase tick label size
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")
    
    plt.show()
    
    # Print summary statistics with confidence intervals
    print("\n" + "="*80)
    print(f"MODEL COMPARISON SUMMARY ({int(confidence_level*100)}% confidence intervals)")
    print("="*80)
    for model_name, stats in model_stats.items():
        print(f"\n{model_name}:")
        print(f"  Mean: {stats['mean']:.3f} ± {stats['std']:.3f}")
        print(f"  {int(confidence_level*100)}% CI: [{stats['ci_lower']:.3f}, {stats['ci_upper']:.3f}]")
        print(f"  Runs: {stats['n_runs']}, Total Trajectories: {stats['total_trajectories']}")
    
    return model_stats

# Alternative version with error bars instead of shaded regions
def plot_multi_model_reward_distributions_errorbar(model_results_dict, save_path=None):
    """
    Alternative approach using error bars on mean values.
    """
    plt.figure(figsize=(14, 8), dpi=300)
    
    colors = plt.cm.Set1(np.linspace(0, 1, len(model_results_dict)))
    
    positions = range(len(model_results_dict))
    means = []
    errors = []
    labels = []
    
    for model_idx, (model_name, results_dict) in enumerate(model_results_dict.items()):
        # Calculate mean reward per run
        run_means = []
        for run_idx in range(len(results_dict)):
            if results_dict[run_idx]['rewards'].ndim == 1:
                rewards = results_dict[run_idx]['rewards']
            else:
                rewards = results_dict[run_idx]['rewards'][:, -1]
            run_means.append(np.mean(rewards))
        
        run_means = np.array(run_means)
        mean_of_means = np.mean(run_means)
        std_error = np.std(run_means) / np.sqrt(len(run_means))
        ci_95 = 1.96 * std_error
        
        means.append(mean_of_means)
        errors.append(ci_95)
        labels.append(f'{model_name}\n(μ={mean_of_means:.2f}±{ci_95:.2f})')
    
    # Create bar plot with error bars
    bars = plt.bar(positions, means, yerr=errors, capsize=10, 
                   color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    
    # Customize plot
    plt.title('Model Comparison: Mean Final Rewards with 95% Confidence Intervals', 
              fontsize=18, fontweight='bold', pad=20)
    plt.ylabel('Mean Final Reward', fontsize=16, fontweight='bold')
    plt.xlabel('Methods', fontsize=16, fontweight='bold')
    
    plt.xticks(positions, [name for name, _ in model_results_dict.items()], 
               rotation=45, ha='right', fontsize=14)
    plt.yticks(fontsize=14)
    
    plt.grid(True, alpha=0.3, axis='y')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path.replace('.pdf', '_errorbar.pdf'), dpi=300, bbox_inches='tight')
    
    plt.show()

# Usage examples.
# Version 1: density plots with confidence intervals
stats_ci = plot_multi_model_reward_distributions_with_ci(
    model_results_dict=model_results_dict_1,
    save_path='multi_model_distributions_with_ci.pdf',
    confidence_level=0.95,
)

# Version 2: bar chart of mean final rewards with error bars
plot_multi_model_reward_distributions_errorbar(
    model_results_dict=model_results_dict_1,
    save_path='multi_model_distributions_errorbar.pdf',
)

In [None]:
# Scratch: convert whatever `rewards` currently holds in the notebook namespace to a NumPy array.
rewards = np.array(rewards)

In [None]:
# Scratch: display the last column of `rewards` — presumably the final-step reward of each
# trajectory when rewards is (trajectories × steps); confirm the layout.
rewards[:,-1]


In [None]:
# NOTE(review): this loads the SAC results file for oracle 3 into a variable named
# `results_smcmc`, and the overlay-plot cell below titles these rewards "SMCMC" —
# confirm which file / label is intended. joblib.load unpickles: only load trusted files.
results_smcmc = joblib.load('results/PaperPlots_O3_SAC/sac_results_runs_oracle_3.joblib')

In [None]:
# Scratch: inspect the 'trained_agent' entry stored for run 0 of the loaded results.
results_smcmc[0]['trained_agent']

In [None]:
# Scratch: load the results_run.joblib artifact stored next to the iteration-500 forward model
# of an mlruns/4 run (joblib.load unpickles: only load trusted files).
models = joblib.load(r"mlruns/4/1fefba2df9fe4438bac2dfada24da7f4/artifacts/forward_model_iteration_500/results_run.joblib")

In [None]:
# Scratch: inspect the entry at key/index 1 of the loaded object.
models[1]

In [None]:
from scipy.stats import gaussian_kde

# Overlay the final-reward KDE of every run stored in `results_smcmc`, plus the pooled mean.
# The number of runs is read from the loaded results instead of being hard-coded to 30.
n_smcmc_runs = len(results_smcmc)

# Final-step reward of every trajectory, one array per run. A comprehension is used so
# no loop variable named `rewards` leaks into (and overwrites) the notebook namespace.
all_rewards = [
    results_smcmc[run_idx]['results_dict']['rewards'][:, -1]
    for run_idx in range(n_smcmc_runs)
]

# Create professional distribution plot
plt.figure(figsize=(12, 8), dpi=300)

# Plot each run's KDE distribution (faint, so the overlap shows run-to-run spread)
for run_rewards in all_rewards:
    sns.kdeplot(data=run_rewards,
                color='steelblue',
                alpha=0.1,
                linewidth=0.1)

# Calculate overall statistics for the plot
all_values = np.concatenate(all_rewards)
overall_mean = np.mean(all_values)
overall_std = np.std(all_values)

# Add mean line
plt.axvline(overall_mean, color='red', linestyle='--',
            linewidth=3, label=f'Overall Mean: {overall_mean:.2f}')

# Styling
plt.title(f'SMCMC Reward Distributions - {n_smcmc_runs} Independent Runs',
          fontsize=18, fontweight='bold', pad=20)
plt.xlabel('Final Reward Values', fontsize=14, fontweight='bold')
plt.ylabel('Density', fontsize=14, fontweight='bold')

# Add statistics text box
total_trajectories = sum(len(run_rewards) for run_rewards in all_rewards)
stats_text = (f'{n_smcmc_runs} Independent Runs\n'
              f'Mean: {overall_mean:.2f} ± {overall_std:.2f}\n'
              f'Total Trajectories: {total_trajectories}')

plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
         verticalalignment='top', horizontalalignment='left',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.9,
                   edgecolor='black', linewidth=1),
         fontsize=12)

# Professional styling
plt.grid(True, alpha=0.3, linestyle='--')
plt.legend(fontsize=12, frameon=True, framealpha=0.9)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig('smcmc_reward_distributions_overlay.pdf',
            dpi=300, bbox_inches='tight')
plt.show()

# Print summary statistics
print(f"SMCMC Results Summary ({n_smcmc_runs} independent runs):")
print(f"Overall Mean: {overall_mean:.3f}")
print(f"Overall Std: {overall_std:.3f}")
print(f"Min: {np.min(all_values):.3f}")
print(f"Max: {np.max(all_values):.3f}")


In [None]:
import random
from utility_functions import seed_all
# Load forward and backward models
root = r'mlruns/2'
# run_id = '78263dd8c45e435d8579682b76a22ef9' # Current Best
# run_id = '94e853e9e6014908b56273897438ef32' # Gaussian 
# run_id = '1b62a3f850e64697b4191238b3220b24' # Betak = 1 and Current Best
# run_id = '940d0167c8784fe7a342794d0f36a2f9' # Gaussian k = 5
# run_id = 'b31630e9983a4c79af6d2136fa9b3c0d' # Gaussian k = 10
# run_id='78263dd8c45e435d8579682b76a22ef9' # gaussian k = 15
# run_id = '17ed5f6c45274eb493c3085ab1b6524f' # beta k = 5
# run_id = 'c0bcf8b0df2b48ad8a338c8565ab426d' # beta k = 10
# run_id = '5d6a6a7413c442a0a892224b73aa4210' # beta k = 15
# run_id = 'd6f0a6ce09274afa92685eba82a94b6f' # O2 beta k = 15
# run_id = 'b3309f6144b646559887504bf74dcb3f' # O1 beta k = 15
run_id = '5db11b78fc7d49dc9fd975ddf93c7e55' # Testing bootstrapping results
iterations = 500
path_fwd = fr'{root}/{run_id}/artifacts/forward_model_iteration_{iterations}/data/model.pth'
path_bwd = fr'{root}/{run_id}/artifacts/backward_model_iteration_{iterations}/data/model.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the forward and backward models
forward_model = load_entire_model(path_fwd, device=device)
backward_model = load_entire_model(path_bwd, device=device)
# Load the config the run was trained with
config_path = fr'{root}/{run_id}/artifacts/run_params_config.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

initial_state, feature_names = load_initial_state(config)
trajectory_length = config['model']['training_parameters']['trajectory_length']
n_trajectories = 200
n_runs = 30
input_dim = forward_model.input_layer.in_features
model_path = config['oracle']['model_path']
distribution = config['model']['training_parameters']['distribution_type']
extra_parameters = {
    'mixture_components': config['model']['training_parameters']['mixture_components'],
    'num_variables': input_dim - 1,
}
environment = GeneralEnvironment(
    initial_state=initial_state,
    config=config,
    model=forward_model,
    input_dim=input_dim,
    max_steps=trajectory_length,
    model_path=model_path)

logger, log_file_name = setup_logger('evaluation_logger')

# Draw every run's seed up front, without replacement, so the n_runs runs are
# guaranteed to use distinct seeds. Drawing inside the loop with random.randint
# could repeat a seed (30 draws from 10001 values collide ~4% of the time), and if
# seed_all() also reseeds Python's `random` (verify), each later draw would be a
# deterministic function of the previous seed.
run_seeds = random.sample(range(10001), n_runs)

results_runs = {}
for run in range(n_runs):
    print(f"Run {run+1}/{n_runs}")
    random_seed = run_seeds[run]
    seed_all(random_seed)
    # Simulate trajectories using the forward and backward models
    trajectories, rewards, all_actions, _ = simulate_trajectories(
        env_class=GeneralEnvironment,
        forward_model=forward_model,
        backward_model=backward_model,
        initial_state=initial_state,
        config=config,
        trajectory_length=trajectory_length,
        n_trajectories=n_trajectories,
        device=device,
        logger=logger,
        model_path=model_path,
        distribution=distribution,
        extra_parameters=extra_parameters
    )
    results_runs[run] = {'trajectories': trajectories,
                         'rewards': rewards,
                         'all_actions': all_actions,
                         # 4th return value of simulate_trajectories; the save cell
                         # stores it as 'feature_names' (presumably that) — confirm
                         '_': _,
                         'seed': random_seed}






In [None]:
# Per-run and pooled statistics of the final rewards stored in `results_runs`.
# The per-run loop variable is deliberately not named `rewards`: that global still
# holds the last simulated run's reward history, which the save cell serialises,
# and reusing the name here would overwrite it with a 1-D slice.
run_ids = sorted(results_runs)
n_eval_runs = len(run_ids)
run_numbers = range(1, n_eval_runs + 1)  # 1-based x positions for the per-run plots

all_run_rewards = []
run_statistics = []
for run in run_ids:
    # Final-step reward of every trajectory (last column of the per-step rewards)
    final_rewards = np.array(results_runs[run]['rewards'])[:, -1]
    all_run_rewards.append(final_rewards)
    run_statistics.append({
        'run': run,
        'mean': np.mean(final_rewards),
        'std': np.std(final_rewards),
        'median': np.median(final_rewards),
        'min': np.min(final_rewards),
        'max': np.max(final_rewards),
        'q25': np.percentile(final_rewards, 25),
        'q75': np.percentile(final_rewards, 75),
    })

# Convert to DataFrame for easier handling
stats_df = pd.DataFrame(run_statistics)

# Pooled statistics; the trajectory count is taken from the data rather than
# assumed to be n_runs * n_trajectories
all_rewards_combined = np.concatenate(all_run_rewards)
total_trajectories = all_rewards_combined.size
overall_mean = np.mean(all_rewards_combined)
overall_std = np.std(all_rewards_combined)

# --- Figure 1: one KDE per run ---
plt.figure(figsize=(14, 8), dpi=300)
colors = plt.cm.viridis(np.linspace(0, 1, n_eval_runs))
for i, (final_rewards, color) in enumerate(zip(all_run_rewards, colors)):
    sns.kdeplot(data=final_rewards,
                color=color,
                alpha=0.6,
                linewidth=1.5,
                label=f'Run {i+1}' if i < 5 else "")  # Only label first 5 to avoid clutter

# Add overall mean line
plt.axvline(overall_mean, color='red', linestyle='--', linewidth=3,
            label=f'Overall Mean: {overall_mean:.2f}')

# Styling
plt.title('Distribution of Final Rewards Across All Runs',
          fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Final Reward Values', fontsize=14, fontweight='bold')
plt.ylabel('Density', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3, linestyle='--')

# Add statistics text box
stats_text = (f'Overall Statistics:\n'
              f'Mean: {overall_mean:.2f} ± {overall_std:.2f}\n'
              f'Runs: {n_eval_runs}\n'
              f'Total Trajectories: {total_trajectories}')

plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
         verticalalignment='top', horizontalalignment='left',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, edgecolor='black'),
         fontsize=12)

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
plt.tight_layout()
plt.savefig('reward_distributions_all_runs.pdf', dpi=300, bbox_inches='tight')
plt.show()

# --- Figure 2: per-run statistics ---
fig, axes = plt.subplots(2, 2, figsize=(16, 12), dpi=300)
fig.suptitle('Statistics Across Runs with Error Bars', fontsize=18, fontweight='bold')

# Plot 1: Mean with std as error bars
axes[0, 0].errorbar(run_numbers, stats_df['mean'],
                    yerr=stats_df['std'], fmt='o-', capsize=5,
                    color='blue', ecolor='lightblue', linewidth=2, markersize=6)
axes[0, 0].set_title('Mean Final Reward by Run', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Run Number', fontsize=12)
axes[0, 0].set_ylabel('Mean Final Reward', fontsize=12)
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Median with IQR as asymmetric error bars
iqr_lower = stats_df['median'] - stats_df['q25']
iqr_upper = stats_df['q75'] - stats_df['median']
axes[0, 1].errorbar(run_numbers, stats_df['median'],
                    yerr=[iqr_lower, iqr_upper], fmt='s-', capsize=5,
                    color='green', ecolor='lightgreen', linewidth=2, markersize=6)
axes[0, 1].set_title('Median Final Reward by Run', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Run Number', fontsize=12)
axes[0, 1].set_ylabel('Median Final Reward', fontsize=12)
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Min and Max
axes[1, 0].plot(run_numbers, stats_df['min'], 'v-',
                color='red', label='Minimum', linewidth=2, markersize=6)
axes[1, 0].plot(run_numbers, stats_df['max'], '^-',
                color='orange', label='Maximum', linewidth=2, markersize=6)
axes[1, 0].fill_between(run_numbers, stats_df['min'], stats_df['max'],
                        alpha=0.2, color='gray')
axes[1, 0].set_title('Min/Max Final Reward by Run', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Run Number', fontsize=12)
axes[1, 0].set_ylabel('Final Reward', fontsize=12)
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Standard deviation
axes[1, 1].bar(run_numbers, stats_df['std'],
               color='purple', alpha=0.7, edgecolor='black')
axes[1, 1].set_title('Standard Deviation by Run', fontsize=14, fontweight='bold')
axes[1, 1].set_xlabel('Run Number', fontsize=12)
axes[1, 1].set_ylabel('Standard Deviation', fontsize=12)
axes[1, 1].grid(True, alpha=0.3, axis='y')

# Overall statistics as text (pandas .std() is the sample std, ddof=1)
overall_stats_text = (f'Overall Statistics Across {n_eval_runs} Runs:\n'
                      f'Mean of Means: {stats_df["mean"].mean():.2f} ± {stats_df["mean"].std():.2f}\n'
                      f'Mean of Medians: {stats_df["median"].mean():.2f} ± {stats_df["median"].std():.2f}\n'
                      f'Mean Std Dev: {stats_df["std"].mean():.2f} ± {stats_df["std"].std():.2f}')

plt.figtext(0.02, 0.02, overall_stats_text,
            bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8),
            fontsize=11)

plt.tight_layout()
plt.savefig('statistics_across_runs.pdf', dpi=300, bbox_inches='tight')
plt.show()

# Print summary statistics
print("\n" + "="*60)
print("SUMMARY STATISTICS ACROSS ALL RUNS")
print("="*60)
print(f"Number of runs: {n_eval_runs}")
print(f"Trajectories per run: {n_trajectories}")
print(f"Total trajectories: {total_trajectories}")
print("\nMean Final Rewards:")
print(f"  Overall mean: {stats_df['mean'].mean():.3f} ± {stats_df['mean'].std():.3f}")
print(f"  Min mean: {stats_df['mean'].min():.3f}")
print(f"  Max mean: {stats_df['mean'].max():.3f}")
print("\nMedian Final Rewards:")
print(f"  Overall median: {stats_df['median'].mean():.3f} ± {stats_df['median'].std():.3f}")
print(f"  Min median: {stats_df['median'].min():.3f}")
print(f"  Max median: {stats_df['median'].max():.3f}")
print("\nStandard Deviations:")
print(f"  Mean std: {stats_df['std'].mean():.3f} ± {stats_df['std'].std():.3f}")
print(f"  Min std: {stats_df['std'].min():.3f}")
print(f"  Max std: {stats_df['std'].max():.3f}")


In [None]:
# Save results to joblib like the others 
import joblib
import os
# Persist the LAST simulated run in the single-run layout the other methods' result
# files use (trajectories / rewards / actions / last_states / feature_names).
# The run is read back from `results_runs` instead of relying on the loop variables
# leaked by the simulation cell: those globals are fragile (the statistics cell used
# to overwrite `rewards`, and `_` doubles as IPython's output-history name).
# NOTE(review): only one of the simulated runs is saved — confirm this is intended.
last_run = results_runs[max(results_runs)]
trajectories = np.array(last_run['trajectories'])
rewards = np.array(last_run['rewards'])
all_actions = np.array(last_run['all_actions'])
results_dict = {
    'trajectories': trajectories,
    'rewards': rewards,
    'actions': all_actions,
    # Last element along each trajectory's first axis
    'last_states': np.array([traj[-1] for traj in trajectories]),
    # 4th return value of simulate_trajectories (stored under '_'); presumably the
    # feature names — confirm
    'feature_names': last_run['_'],
}

# Save the results
output_path = r"results/Ablation"
os.makedirs(output_path, exist_ok=True)
output_file = os.path.join(output_path, 'gflownet_results_oracle_1.joblib')
joblib.dump(results_dict, output_file)

In [None]:
from utility_functions import *
# trajectories, rewards,all_actions,_ = simulate_trajectories(
#         env_class=GeneralEnvironment,
#         forward_model=forward_model,
#         backward_model=backward_model,
#         initial_state=initial_state,
#         config=config,
#         trajectory_length=trajectory_length,
#         n_trajectories=n_trajectories,
#         device=device,
#         logger = logger,
#         model_path = model_path,
#         distribution=distribution,
#         extra_parameters=extra_parameters

#     )



In [None]:
# Scratch: list the keys stored for run 0 of the 'gflow' model in model_results_dict_1.
model_results_dict_1['gflow'][0].keys()

In [None]:
# Evaluate one saved sample file: load the run config, rebuild the environment,
# then compute reward/tail metrics, DTW clustering, diversity metrics and plots
# for the top-k trajectories, writing everything under a timestamped directory.
import joblib
import numpy as np
import pandas as pd
import os
import warnings
from datetime import datetime

config_path = fr'{root}/{run_id}/artifacts/run_params_config.yaml'
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

initial_state, feature_names = load_initial_state(config)
trajectory_length = config['model']['training_parameters']['trajectory_length']
n_trajectories = 200
input_dim = forward_model.input_layer.in_features
model_path = config['oracle']['model_path']
distribution = config['model']['training_parameters']['distribution_type']
extra_parameters = {
    'mixture_components': config['model']['training_parameters']['mixture_components'],
    'num_variables': input_dim - 1,
}
environment = GeneralEnvironment(
    initial_state=initial_state,
    config=config,
    model=forward_model,
    input_dim=input_dim,
    max_steps=trajectory_length,
    model_path=model_path,
)

# Create main output directory
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_type = "GFlowNet"  # or "REINFORCE", "SMCMC", etc. GFlowNet REINFORCE_B
use_case = 'O1'
output_path = r"results/PaperPlots_O1"
base_output_dir = fr"results/Comparison/PaperPlots_{use_case}/analysis_{model_type}_{timestamp}"
os.makedirs(base_output_dir, exist_ok=True)

# Create subdirectories for different types of outputs
plots_dir = os.path.join(base_output_dir, 'plots')
metrics_dir = os.path.join(base_output_dir, 'metrics')
action_plots_dir = os.path.join(base_output_dir, 'action_distributions')
os.makedirs(plots_dir, exist_ok=True)
os.makedirs(metrics_dir, exist_ok=True)
os.makedirs(action_plots_dir, exist_ok=True)

# Load data
reinforce_sample = joblib.load(r"results/Ablation/gflownet_results_oracle_1.joblib")
# reinforce_sample = joblib.load(fr'results/reinforce_results_oracle_1.joblib')
# reinforce_sample = joblib.load(fr'results/reinforce_baseline_results_oracle_1.joblib')
# reinforce_sample = joblib.load(fr'results/sac_results_oracle_1.joblib')
# reinforce_sample = joblib.load(fr'results/mcmc_results_oracle_{use_case[-1]}.joblib')
trajectories = reinforce_sample['trajectories']
rewards = reinforce_sample['rewards']

# Take the actions from the SAME sample file as the trajectories. Otherwise the
# action plots below would silently reuse whatever ``all_actions`` an earlier
# cell left in the namespace (i.e. actions from a different run).
if 'actions' in reinforce_sample:
    all_actions = reinforce_sample['actions']
    if model_type == 'SMCMC':
        # Drop the first action column for SMCMC, as the earlier version of this cell did.
        all_actions = all_actions[:, :, 1:]
else:
    all_actions = None
    warnings.warn("Sample file has no 'actions' entry; skipping action-distribution plots.")

# Column labels of a trajectory state vector: [Timestamp, *feature_names]
trajectory_columns = ['Timestamp'] + feature_names
top_k = 100

# Process rewards and predictions (prediction = -reward)
rewards = np.array(rewards) if isinstance(rewards, list) else rewards
final_rewards = np.array([r[-1] for r in rewards])
final_predictions = -final_rewards

# Top-k trajectories: lowest prediction, i.e. highest final reward
top_100_trajectories_indices = np.argsort(final_predictions)[:top_k]
trajectories_top_100 = np.array([trajectories[i] for i in top_100_trajectories_indices])
final_rewards_top_100 = final_rewards[top_100_trajectories_indices]
final_predictions_top_100 = final_predictions[top_100_trajectories_indices]

# Calculate and save metrics
avg_reward = average_reward(final_rewards_top_100)
tc_20 = tail_coverage(final_predictions_top_100, tau=-20)
tc_40 = tail_coverage(final_predictions_top_100, tau=-40)
tc_60 = tail_coverage(final_predictions_top_100, tau=-60)
es_5 = expected_shortfall(final_predictions_top_100, q=5)
es_10 = expected_shortfall(final_predictions_top_100, q=10)
es_20 = expected_shortfall(final_predictions_top_100, q=20)

metrics_dict = {
    'average_reward': avg_reward,
    'tail_coverage_20': tc_20,
    'tail_coverage_40': tc_40,
    'tail_coverage_60': tc_60,
    'expected_shortfall': es_5,
    'expected_shortfall_10': es_10,
    'expected_shortfall_20': es_20
}

# Save metrics to file
with open(os.path.join(metrics_dir, 'performance_metrics.txt'), 'w') as f:
    for metric, value in metrics_dict.items():
        f.write(f"{metric}: {value:.2f}\n")

# Summary statistics of the top-k final rewards
rewards_dict = {
    'min': np.min(final_rewards_top_100),
    'max': np.max(final_rewards_top_100),
    'mean': np.mean(final_rewards_top_100),
    'std': np.std(final_rewards_top_100),
    'median': np.median(final_rewards_top_100)
}
with open(os.path.join(metrics_dir, 'rewards_summary.txt'), 'w') as f:
    for metric, value in rewards_dict.items():
        f.write(f"{metric}: {value:.2f}\n")

results_df, distance_matrix, feature_clusters = dtw_clustering_analysis(
    trajectories=trajectories_top_100,
    n_clusters=3
)

# Save clustering plots. results_df was computed on trajectories_top_100, so the
# plot must receive that same set (not the full ``trajectories``).
plot_trajectories_by_cluster(results_df, trajectories_top_100, trajectory_columns,
                             num_clusters=3,
                             save_path=os.path.join(plots_dir, 'cluster_trajectories.pdf'))

res_dict_dtw = plot_dtw_distance_matrix(distance_matrix, model_name=f'{model_type}',
                                        rewards=final_rewards_top_100,
                                        save_path=os.path.join(plots_dir, 'dtw_distance_matrix.pdf'))

avg_trajectory_distance = res_dict_dtw['average_distance']
normalized_score = res_dict_dtw['normalized_score']
reward_norm = res_dict_dtw['reward_normalization']
# Write metrics to file
with open(os.path.join(metrics_dir, 'dtw_metrics.txt'), 'w') as f:
    f.write(f"Average Distance: {avg_trajectory_distance:.2f}\n")
    f.write(f"Normalized Score: {normalized_score:.4f}\n")
    f.write(f"Reward Normalization: {reward_norm:.4f}\n")

# Min-max trajectory comparisons
min_pair, max_pair = find_min_max_distance_pairs(distance_matrix)
compare_trajectories(min_pair[0], min_pair[1], trajectories_top_100,
                     trajectory_columns, final_rewards_top_100,
                     save_path=os.path.join(plots_dir, 'min_distance_pair.pdf'))
compare_trajectories(max_pair[0], max_pair[1], trajectories_top_100,
                     trajectory_columns, final_rewards_top_100,
                     save_path=os.path.join(plots_dir, 'max_distance_pair.pdf'))

# Trajectory and action distribution plots
plot_trajectories_over_time(
    trajectories=trajectories_top_100,
    rewards=final_rewards_top_100,
    feature_names=trajectory_columns,
    n_top=50,
    alpha_others=0.05,
    save_path=plots_dir
)

if all_actions is not None:
    plot_action_distributions_by_timestep(
        trajectories=trajectories,
        all_actions=all_actions,
        feature_names=trajectory_columns,  # Remove Timestamps for SMCMC
        save_path=action_plots_dir
    )

    plot_action_distributions(
        all_actions=all_actions,
        feature_names=trajectory_columns,
        save_path=os.path.join(plots_dir, 'distribution_actions.pdf')
    )

# Create and save summary DataFrame (final state of each top-k trajectory).
# State vectors are [Timestamp, *feature_names], so take the trailing
# len(feature_names) entries; indexing from 0 would shift every feature by one
# column (Timestamp under the first name, last feature dropped).
n_features = len(feature_names)
summary_rows = []
for rank, idx in enumerate(top_100_trajectories_indices, start=1):
    final_values = np.asarray(trajectories[idx])[-1, -n_features:] if n_features else []
    row_dict = {
        "Rank": rank,
        "CaseIndex": idx,
        "FinalReward": final_rewards[idx]
    }
    row_dict.update(zip(feature_names, final_values))
    summary_rows.append(row_dict)

top100_df = pd.DataFrame(summary_rows).sort_values(by="FinalReward", ascending=False)
top100_df.to_csv(os.path.join(metrics_dir, 'top_100_summary.csv'), index=False)

# Diversity metrics and plots
avg_div, norm_div, distances = plot_diversity_metrics(
    df=top100_df,
    feature_names=feature_names,
    save_path=os.path.join(plots_dir, 'euclidean_distance_matrix.pdf')
)

with open(os.path.join(metrics_dir, 'diversity_metrics.txt'), 'w') as f:
    f.write(f"Average Diversity - Last State: {avg_div:.2f}\n")
    f.write(f"Normalized Diversity - Last State: {norm_div:.4f}\n")

# Final reward distribution plot
plot_reward_distribution(final_rewards_top_100,
                         save_path=os.path.join(plots_dir, 'reward_distribution.pdf'))

In [None]:
def calculate_metrics_across_runs(model_results_dict, confidence_level=0.95, top_k=100, n_runs=30):
    """
    Calculate performance metrics across multiple independent runs with confidence intervals.

    For every model and every run index ``0 .. n_runs-1`` the final rewards are
    extracted (last column of a 2-D reward array, or the array itself if 1-D),
    non-finite values are dropped and the ``top_k`` highest are kept. Summary
    statistics, tail coverage and expected shortfall (on predictions = -reward)
    are computed on those. A run contributes to the results only if ALL of its
    metrics were computed, so every metric's ``n_runs`` refers to the same runs.
    Failing runs are reported with ``print`` and skipped.

    Args:
        model_results_dict (dict): ``{model_name: runs}`` where ``runs[i]['rewards']``
            holds the rewards of run ``i``.
        confidence_level (float): Confidence level for intervals (default 0.95)
        top_k (int): Number of top trajectories to consider (default 100)
        n_runs (int): Number of run indices (0 .. n_runs-1) to read per model (default 30)

    Returns:
        dict: ``{model_name: {metric_name: {'mean', 'std', 'ci_lower', 'ci_upper',
        'n_runs', 'raw_values'}}}``. Intervals use Student's t over the runs;
        ``std`` is the sample std (NaN when only one run contributed).
    """
    from scipy.stats import t
    import numpy as np

    metric_names = (
        'avg_reward', 'max_reward', 'min_reward', 'median_reward', 'std_reward',
        'tc_20', 'tc_40', 'tc_60',
        'es_5', 'es_10', 'es_20',
        'top_10_mean', 'top_50_mean',
        'percentile_90', 'percentile_10',
        'n_trajectories',
    )

    results = {}

    for model_name, runs_dict in model_results_dict.items():
        print(f"Processing {model_name}...")

        # Per-metric lists, one entry per successfully processed run
        run_metrics = {name: [] for name in metric_names}

        for run_idx in range(n_runs):
            try:
                rewards = runs_dict[run_idx]['rewards']

                # Handle different reward formats
                if isinstance(rewards, (list, tuple)):
                    rewards = np.array(rewards, dtype=float)
                elif isinstance(rewards, np.ndarray):
                    rewards = rewards.astype(float)

                # Extract final rewards
                if rewards.ndim == 2:
                    final_rewards = rewards[:, -1]
                elif rewards.ndim == 1:
                    final_rewards = rewards
                else:
                    continue

                # Remove non-finite values
                final_rewards = final_rewards[np.isfinite(final_rewards)]
                if len(final_rewards) == 0:
                    continue

                # Top-k rewards of this run, ascending. An explicit start index
                # avoids the ``[-0:]`` pitfall (which would select everything).
                top_k_actual = min(top_k, len(final_rewards))
                ascending = np.sort(final_rewards)
                top_rewards = ascending[len(ascending) - top_k_actual:]
                descending = top_rewards[::-1]

                final_predictions = -top_rewards  # predictions are negative rewards

                this_run = {
                    'avg_reward': np.mean(top_rewards),
                    'max_reward': np.max(top_rewards),
                    'min_reward': np.min(top_rewards),
                    'median_reward': np.median(top_rewards),
                    'std_reward': np.std(top_rewards),
                    'tc_20': tail_coverage(final_predictions, tau=-20),
                    'tc_40': tail_coverage(final_predictions, tau=-40),
                    'tc_60': tail_coverage(final_predictions, tau=-60),
                    'es_5': expected_shortfall(final_predictions, q=5),
                    'es_10': expected_shortfall(final_predictions, q=10),
                    'es_20': expected_shortfall(final_predictions, q=20),
                    'top_10_mean': np.mean(descending[:10]),
                    'top_50_mean': np.mean(descending[:50]),
                    'percentile_90': np.percentile(top_rewards, 90),
                    'percentile_10': np.percentile(top_rewards, 10),
                    'n_trajectories': len(final_rewards),
                }
            except Exception as e:
                print(f"Error processing {model_name} run {run_idx}: {e}")
                continue

            # Commit only after every metric succeeded, so the per-metric
            # lists stay aligned run-for-run.
            for name, value in this_run.items():
                run_metrics[name].append(value)

        # Calculate statistics across runs with confidence intervals
        model_results = {}

        for metric_name, values in run_metrics.items():
            values = np.array(values)
            values = values[np.isfinite(values)]  # Remove any NaN / inf values

            n_valid = len(values)
            if n_valid == 0:
                continue

            mean_val = np.mean(values)

            if n_valid > 1:
                std_val = np.std(values, ddof=1)  # Sample standard deviation
                # Two-sided Student-t interval for the mean
                t_critical = t.ppf((1 + confidence_level) / 2, df=n_valid - 1)
                margin_error = t_critical * (std_val / np.sqrt(n_valid))
                ci_lower = mean_val - margin_error
                ci_upper = mean_val + margin_error
            else:
                # Sample std is undefined for a single run; report NaN without
                # triggering numpy's "degrees of freedom <= 0" warning.
                std_val = float('nan')
                ci_lower = ci_upper = mean_val

            model_results[metric_name] = {
                'mean': mean_val,
                'std': std_val,
                'ci_lower': ci_lower,
                'ci_upper': ci_upper,
                'n_runs': n_valid,
                'raw_values': values.tolist()
            }

        results[model_name] = model_results

    return results

def save_metrics_to_files(metrics_results, save_dir, confidence_level=0.95):
    """
    Save metrics results to various file formats for easy access.

    Files written to ``save_dir`` (created if missing):
      - metrics_summary.csv: one row per model with ``<metric>_{mean,std,ci_lower,ci_upper,n_runs}``
      - metrics_formatted.txt: "mean ± std [ci_lower, ci_upper] (n=...)" per key metric
      - metrics_latex_table.txt: LaTeX ``tabular`` of mean ± std for the key metrics
      - metrics_raw.json: ``metrics_results`` dumped as JSON
      - metrics_comparison.csv: metrics as rows, models as columns ("mean ± std")

    Args:
        metrics_results (dict): Results from calculate_metrics_across_runs
        save_dir (str): Directory to save files
        confidence_level (float): Confidence level used (only printed in the text header)
    """
    import os
    import pandas as pd
    import json

    key_metrics = ['avg_reward', 'max_reward', 'tc_20', 'es_10', 'es_20']

    def _latex_escape(text):
        # Model names such as "reinforce_baseline" contain characters that are
        # special in LaTeX text mode ("_" only works in math mode) and break
        # compilation of the table if written verbatim.
        specials = {
            '\\': r'\textbackslash{}', '&': r'\&', '%': r'\%', '$': r'\$',
            '#': r'\#', '_': r'\_', '{': r'\{', '}': r'\}',
            '~': r'\textasciitilde{}', '^': r'\textasciicircum{}',
        }
        return ''.join(specials.get(ch, ch) for ch in str(text))

    os.makedirs(save_dir, exist_ok=True)

    # 1. Save comprehensive summary as CSV
    summary_data = []
    for model_name, model_metrics in metrics_results.items():
        row = {'Model': model_name}
        for metric_name, metric_data in model_metrics.items():
            if isinstance(metric_data, dict) and 'mean' in metric_data:
                row[f'{metric_name}_mean'] = metric_data['mean']
                row[f'{metric_name}_std'] = metric_data['std']
                row[f'{metric_name}_ci_lower'] = metric_data['ci_lower']
                row[f'{metric_name}_ci_upper'] = metric_data['ci_upper']
                row[f'{metric_name}_n_runs'] = metric_data['n_runs']
        summary_data.append(row)

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(os.path.join(save_dir, 'metrics_summary.csv'), index=False)

    # 2. Save formatted results for papers/reports. utf-8 explicitly, because
    # "±" is not encodable under every platform default encoding.
    with open(os.path.join(save_dir, 'metrics_formatted.txt'), 'w', encoding='utf-8') as f:
        # ``:g`` avoids float truncation such as int(0.29 * 100) == 28
        f.write(f"PERFORMANCE METRICS WITH {confidence_level * 100:g}% CONFIDENCE INTERVALS\n")
        f.write("=" * 80 + "\n\n")

        for model_name, model_metrics in metrics_results.items():
            f.write(f"{model_name.upper()}:\n")
            f.write("-" * 40 + "\n")

            for metric in key_metrics:
                if metric in model_metrics:
                    data = model_metrics[metric]
                    f.write(f"{metric:15}: {data['mean']:7.2f} ± {data['std']:5.2f} "
                            f"[{data['ci_lower']:6.2f}, {data['ci_upper']:6.2f}] "
                            f"(n={data['n_runs']})\n")
            f.write("\n")

    # 3. Save LaTeX table format
    with open(os.path.join(save_dir, 'metrics_latex_table.txt'), 'w', encoding='utf-8') as f:
        f.write("% LaTeX Table Format\n")
        f.write("\\begin{table}[htbp]\n")
        f.write("\\centering\n")
        f.write("\\caption{Performance Metrics with Confidence Intervals}\n")
        f.write("\\begin{tabular}{l|ccccc}\n")
        f.write("\\hline\n")
        f.write("Method & Avg Reward & Max Reward & TC@20 & ES@10\\% & ES@20\\% \\\\\n")
        f.write("\\hline\n")

        for model_name, model_metrics in metrics_results.items():
            row = f"{_latex_escape(model_name):<15}"
            for metric in key_metrics:
                if metric in model_metrics:
                    data = model_metrics[metric]
                    row += f" & ${data['mean']:5.1f} \\pm {data['std']:4.1f}$"
                else:
                    row += " & -"
            row += " \\\\\n"
            f.write(row)

        f.write("\\hline\n")
        f.write("\\end{tabular}\n")
        f.write("\\end{table}\n")

    # 4. Save raw data as JSON for programmatic access
    with open(os.path.join(save_dir, 'metrics_raw.json'), 'w', encoding='utf-8') as f:
        json.dump(metrics_results, f, indent=2)

    # 5. Save comparison table (rows = metrics, columns = models)
    comparison_df = pd.DataFrame()
    for model_name, model_metrics in metrics_results.items():
        model_row = {}
        for metric_name, metric_data in model_metrics.items():
            if isinstance(metric_data, dict) and 'mean' in metric_data:
                model_row[metric_name] = f"{metric_data['mean']:.2f} ± {metric_data['std']:.2f}"
        comparison_df[model_name] = pd.Series(model_row)

    comparison_df.to_csv(os.path.join(save_dir, 'metrics_comparison.csv'))

    print(f"Metrics saved to {save_dir}/")
    print("Files created:")
    print("  - metrics_summary.csv: Complete numerical data")
    print("  - metrics_formatted.txt: Human-readable format")
    print("  - metrics_latex_table.txt: LaTeX table format")
    print("  - metrics_raw.json: Raw data for programming")
    print("  - metrics_comparison.csv: Side-by-side comparison")

def plot_metrics_comparison(metrics_results, save_path=None, confidence_level=0.95):
    """
    Create publication-ready bar plots comparing key performance metrics across models.

    One subplot per key metric; bar height is the across-run mean and the error
    bar is the half-width of the stored confidence interval.

    Args:
        metrics_results (dict): Output of ``calculate_metrics_across_runs``
            (``{model: {metric: {'mean', 'ci_lower', 'ci_upper', ...}}}``).
        save_path (str, optional): If given, the figure is also saved there.
        confidence_level (float): Level the intervals were computed at. Only
            used in the subtitle (previously hard-coded to 95%). Default 0.95.
    """
    import matplotlib.pyplot as plt

    # Key metrics to plot
    key_metrics = {
        'avg_reward': 'Average Reward',
        'max_reward': 'Maximum Reward',
        'tc_20': 'Tail Coverage @ -20',
        'es_10': 'Expected Shortfall @ 10%',
        'es_20': 'Expected Shortfall @ 20%'
    }

    fig, axes = plt.subplots(2, 3, figsize=(18, 12), dpi=300)
    axes = axes.flatten()

    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    for ax, (metric_key, metric_name) in zip(axes, key_metrics.items()):
        models, means, errors = [], [], []

        for model_name, model_metrics in metrics_results.items():
            if metric_key in model_metrics:
                data = model_metrics[metric_key]
                models.append(model_name.upper())
                means.append(data['mean'])
                # Half-width of the confidence interval as the error bar
                errors.append((data['ci_upper'] - data['ci_lower']) / 2)

        if not means:
            continue

        bars = ax.bar(range(len(models)), means, yerr=errors,
                      capsize=10, alpha=0.7,
                      color=[colors.get(m.lower(), 'gray') for m in models])

        ax.set_title(metric_name, fontsize=14, fontweight='bold')
        ax.set_xticks(range(len(models)))
        ax.set_xticklabels(models, rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Add value labels on top of each error bar
        for bar, mean, error in zip(bars, means, errors):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + error,
                    f'{mean:.1f}', ha='center', va='bottom', fontsize=10)

    # Remove every grid cell that has no metric assigned
    for unused_ax in axes[len(key_metrics):]:
        fig.delaxes(unused_ax)

    plt.suptitle('Performance Metrics Comparison with Confidence Intervals\n'
                 f'Error bars show {confidence_level * 100:g}% confidence intervals',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Usage example: aggregate metrics for every model in model_results_dict_3,
# persist them, plot them and echo a short console summary.
print("Calculating metrics across all runs...")
metrics_results = calculate_metrics_across_runs(model_results_dict_3, confidence_level=0.95, top_k=100)

save_dir = "results/Comparison/metrics_analysis_3"
save_metrics_to_files(metrics_results, save_dir)
plot_metrics_comparison(
    metrics_results,
    save_path=os.path.join(save_dir, 'metrics_comparison_plot.pdf'),
)

rule = "=" * 80
print("\n" + rule)
print("METRICS CALCULATION COMPLETED")
print(rule)
for model_name, model_metrics in metrics_results.items():
    print(f"\n{model_name.upper()}:")
    if 'avg_reward' in model_metrics:
        data = model_metrics['avg_reward']
        print(f"  Average Reward: {data['mean']:.2f} ± {data['std']:.2f} "
              f"[{data['ci_lower']:.2f}, {data['ci_upper']:.2f}]")
    # Remaining headline metrics share the "mean ± std" format
    for metric_key, metric_label in (('tc_20', 'Tail Coverage @ -20'),
                                     ('es_10', 'Expected Shortfall @ 10%')):
        if metric_key in model_metrics:
            data = model_metrics[metric_key]
            print(f"  {metric_label}: {data['mean']:.2f} ± {data['std']:.2f}")

In [None]:
# Quick inspection: shape of the first trajectory of the first GFlowNet ('gflow') run.
model_results_dict_1['gflow'][0]['trajectories'][0].shape

In [None]:
# Quick inspection: shape of the trajectories array of the first REINFORCE run.
model_results_dict_1['reinforce'][0]['trajectories'].shape

In [None]:
def calculate_diversity_metrics_across_runs(model_results_dict, confidence_level=0.95, top_k=100, n_runs=30):
    """
    Calculate diversity metrics across multiple independent runs with confidence intervals.

    For each model and each run index ``0 .. n_runs-1`` the ``top_k`` trajectories
    with the highest finite final reward are selected, then three metric groups
    are computed:

      * DTW: mean pairwise distance from ``compute_dtw_distance_matrix``, a
        reward-stability factor ``1 - |max-min| / (|max|+|min|)`` and their product.
      * Euclidean: ``calculate_quality_diversity`` on the final states (timestamp
        column excluded, non-finite rows dropped) plus the stability factor of
        the rewards of the kept rows.
      * Coverage: ``coverage_epsilon`` on those final states at eps = 0.1, 0.5, 1.0.

    A group is recorded for a run only if the whole group succeeded. A failing
    group also skips the remaining groups for that run. Failures are printed,
    not raised.

    Args:
        model_results_dict (dict): ``{model_name: runs}`` where ``runs[i]['trajectories']``
            and ``runs[i]['rewards']`` hold the data of run ``i``.
        confidence_level (float): Confidence level for intervals (default 0.95)
        top_k (int): Number of top trajectories to consider (default 100)
        n_runs (int): Number of run indices (0 .. n_runs-1) to read per model (default 30)

    Returns:
        dict: ``{model_name: {metric_name: {'mean', 'std', 'ci_lower', 'ci_upper',
        'n_runs', 'raw_values'}}}``. Intervals use Student's t over the runs;
        ``std`` is the sample std (NaN when only one run contributed).
    """
    from scipy.stats import t
    import numpy as np

    metric_names = (
        'avg_dtw_distance', 'normalized_dtw_score',
        'avg_euclidean_diversity', 'normalized_euclidean_diversity',
        'dtw_reward_stability', 'euclidean_reward_stability',
        'coverage_eps_01', 'coverage_eps_05', 'coverage_eps_10',
    )
    coverage_levels = ((0.1, 'coverage_eps_01'), (0.5, 'coverage_eps_05'), (1.0, 'coverage_eps_10'))

    def _reward_stability(reward_values):
        # 1 - |max - min| / (|max| + |min|): close to 1 when rewards are tightly grouped.
        max_r = np.max(reward_values)
        min_r = np.min(reward_values)
        return 1 - (np.abs(max_r - min_r) / (np.abs(max_r) + np.abs(min_r) + 1e-8))

    results = {}

    for model_name, runs_dict in model_results_dict.items():
        print(f"Processing diversity metrics for {model_name}...")

        run_diversity_metrics = {name: [] for name in metric_names}

        for run_idx in range(n_runs):
            # --- Select this run's top-k trajectories by final reward ---
            try:
                trajectories = runs_dict[run_idx]['trajectories']
                rewards = runs_dict[run_idx]['rewards']

                # Handle different reward formats
                if isinstance(rewards, (list, tuple)):
                    rewards = np.array(rewards, dtype=float)
                elif isinstance(rewards, np.ndarray):
                    rewards = rewards.astype(float)

                # Extract final rewards
                if rewards.ndim == 2:
                    final_rewards = rewards[:, -1]
                elif rewards.ndim == 1:
                    final_rewards = rewards
                else:
                    continue

                # Keep positions in the ORIGINAL arrays: selecting on a filtered
                # reward array and then indexing the unfiltered trajectories
                # would pair rewards with the wrong trajectories whenever a
                # reward is non-finite.
                finite_idx = np.flatnonzero(np.isfinite(final_rewards))
                if finite_idx.size == 0:
                    continue

                top_k_actual = min(top_k, finite_idx.size)
                ascending_idx = finite_idx[np.argsort(final_rewards[finite_idx])]
                top_indices = ascending_idx[ascending_idx.size - top_k_actual:]
                top_trajectories = np.array([trajectories[i] for i in top_indices])
                top_rewards = final_rewards[top_indices]
            except Exception as e:
                print(f"Error processing {model_name} run {run_idx}: {e}")
                continue

            # --- DTW diversity ---
            try:
                distance_matrix, _feature_dist_matrices = compute_dtw_distance_matrix(top_trajectories)
                avg_dtw_distance = np.mean(distance_matrix[np.triu_indices_from(distance_matrix, k=1)])
                dtw_reward_stability = _reward_stability(top_rewards)
            except Exception as e:
                print(f"DTW calculation failed for {model_name} run {run_idx}: {e}")
                continue
            run_diversity_metrics['avg_dtw_distance'].append(avg_dtw_distance)
            run_diversity_metrics['normalized_dtw_score'].append(avg_dtw_distance * dtw_reward_stability)
            run_diversity_metrics['dtw_reward_stability'].append(dtw_reward_stability)

            # --- Euclidean diversity on final states ---
            try:
                # Last timestep, exclude the leading timestamp column
                final_states = top_trajectories[:, -1, 1:].astype(float)
                # Build the finite-row mask ONCE and apply it to states and
                # rewards alike (recomputing it on the filtered states gives a
                # mask of the wrong length).
                finite_rows = np.isfinite(final_states).all(axis=1)
                final_states = final_states[finite_rows]
                state_rewards = top_rewards[finite_rows]
                if len(final_states) == 0:
                    continue
                euclidean_res = calculate_quality_diversity(final_states, state_rewards)
                avg_euclidean_diversity = euclidean_res['average_diversity']
                normalized_euclidean_score = euclidean_res['normalized_score']
                euclidean_reward_stability = _reward_stability(state_rewards)
            except Exception as e:
                print(f"Euclidean diversity calculation failed for {model_name} run {run_idx}: {e}")
                continue
            run_diversity_metrics['avg_euclidean_diversity'].append(avg_euclidean_diversity)
            run_diversity_metrics['normalized_euclidean_diversity'].append(normalized_euclidean_score)
            run_diversity_metrics['euclidean_reward_stability'].append(euclidean_reward_stability)

            # --- Coverage at several epsilon values (all-or-nothing) ---
            try:
                coverages = {key: coverage_epsilon(final_states, eps) for eps, key in coverage_levels}
            except Exception as e:
                print(f"Coverage calculation failed for {model_name} run {run_idx}: {e}")
                continue
            for key, value in coverages.items():
                run_diversity_metrics[key].append(value)

        # Calculate statistics across runs with confidence intervals
        model_diversity_results = {}

        for metric_name, values in run_diversity_metrics.items():
            values = np.array(values)
            values = values[np.isfinite(values)]  # Remove any NaN / inf values

            n_valid = len(values)
            if n_valid == 0:
                continue

            mean_val = np.mean(values)

            if n_valid > 1:
                std_val = np.std(values, ddof=1)  # Sample standard deviation
                # Two-sided Student-t interval for the mean
                t_critical = t.ppf((1 + confidence_level) / 2, df=n_valid - 1)
                margin_error = t_critical * (std_val / np.sqrt(n_valid))
                ci_lower = mean_val - margin_error
                ci_upper = mean_val + margin_error
            else:
                # Sample std is undefined for a single run; report NaN without
                # triggering numpy's "degrees of freedom <= 0" warning.
                std_val = float('nan')
                ci_lower = ci_upper = mean_val

            model_diversity_results[metric_name] = {
                'mean': mean_val,
                'std': std_val,
                'ci_lower': ci_lower,
                'ci_upper': ci_upper,
                'n_runs': n_valid,
                'raw_values': values.tolist()
            }

        results[model_name] = model_diversity_results

    return results

def save_diversity_metrics_to_files(diversity_results, save_dir, confidence_level=0.95):
    """
    Save diversity metrics results to various file formats.

    Files written to ``save_dir`` (created if missing):
      - diversity_metrics_summary.csv: one row per model with ``<metric>_{mean,std,ci_lower,ci_upper,n_runs}``
      - diversity_metrics_formatted.txt: "mean ± std [ci_lower, ci_upper] (n=...)" per key metric
      - diversity_metrics_latex_table.txt: LaTeX ``tabular`` of mean ± std for the key metrics
      - diversity_metrics_raw.json: ``diversity_results`` dumped as JSON

    Args:
        diversity_results (dict): Results from calculate_diversity_metrics_across_runs
        save_dir (str): Directory to save files
        confidence_level (float): Confidence level used (only printed in the text header)
    """
    import os
    import pandas as pd
    import json

    key_metrics = ['avg_dtw_distance', 'normalized_dtw_score', 'avg_euclidean_diversity',
                   'normalized_euclidean_diversity', 'coverage_eps_05']

    def _latex_escape(text):
        # Model names such as "reinforce_baseline" contain characters that are
        # special in LaTeX text mode ("_" only works in math mode) and break
        # compilation of the table if written verbatim.
        specials = {
            '\\': r'\textbackslash{}', '&': r'\&', '%': r'\%', '$': r'\$',
            '#': r'\#', '_': r'\_', '{': r'\{', '}': r'\}',
            '~': r'\textasciitilde{}', '^': r'\textasciicircum{}',
        }
        return ''.join(specials.get(ch, ch) for ch in str(text))

    os.makedirs(save_dir, exist_ok=True)

    # 1. Save comprehensive summary as CSV
    summary_data = []
    for model_name, model_metrics in diversity_results.items():
        row = {'Model': model_name}
        for metric_name, metric_data in model_metrics.items():
            if isinstance(metric_data, dict) and 'mean' in metric_data:
                row[f'{metric_name}_mean'] = metric_data['mean']
                row[f'{metric_name}_std'] = metric_data['std']
                row[f'{metric_name}_ci_lower'] = metric_data['ci_lower']
                row[f'{metric_name}_ci_upper'] = metric_data['ci_upper']
                row[f'{metric_name}_n_runs'] = metric_data['n_runs']
        summary_data.append(row)

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(os.path.join(save_dir, 'diversity_metrics_summary.csv'), index=False)

    # 2. Save formatted results for papers/reports. utf-8 explicitly, because
    # "±" is not encodable under every platform default encoding.
    with open(os.path.join(save_dir, 'diversity_metrics_formatted.txt'), 'w', encoding='utf-8') as f:
        # ``:g`` avoids float truncation such as int(0.29 * 100) == 28
        f.write(f"DIVERSITY METRICS WITH {confidence_level * 100:g}% CONFIDENCE INTERVALS\n")
        f.write("=" * 80 + "\n\n")

        for model_name, model_metrics in diversity_results.items():
            f.write(f"{model_name.upper()}:\n")
            f.write("-" * 40 + "\n")

            for metric in key_metrics:
                if metric in model_metrics:
                    data = model_metrics[metric]
                    f.write(f"{metric:25}: {data['mean']:7.3f} ± {data['std']:5.3f} "
                            f"[{data['ci_lower']:6.3f}, {data['ci_upper']:6.3f}] "
                            f"(n={data['n_runs']})\n")
            f.write("\n")

    # 3. Save LaTeX table format for diversity metrics
    with open(os.path.join(save_dir, 'diversity_metrics_latex_table.txt'), 'w', encoding='utf-8') as f:
        f.write("% LaTeX Table Format for Diversity Metrics\n")
        f.write("\\begin{table}[htbp]\n")
        f.write("\\centering\n")
        f.write("\\caption{Diversity Metrics with Confidence Intervals}\n")
        f.write("\\begin{tabular}{l|ccccc}\n")
        f.write("\\hline\n")
        f.write("Method & DTW Distance & DTW Normalized & Euclidean Div & Euclidean Norm & Coverage@0.5 \\\\\n")
        f.write("\\hline\n")

        for model_name, model_metrics in diversity_results.items():
            row = f"{_latex_escape(model_name):<15}"
            for metric in key_metrics:
                if metric in model_metrics:
                    data = model_metrics[metric]
                    row += f" & ${data['mean']:5.2f} \\pm {data['std']:4.2f}$"
                else:
                    row += " & -"
            row += " \\\\\n"
            f.write(row)

        f.write("\\hline\n")
        f.write("\\end{tabular}\n")
        f.write("\\end{table}\n")

    # 4. Save raw data as JSON
    with open(os.path.join(save_dir, 'diversity_metrics_raw.json'), 'w', encoding='utf-8') as f:
        json.dump(diversity_results, f, indent=2)

    print(f"Diversity metrics saved to {save_dir}/")

def plot_diversity_metrics_comparison(diversity_results, save_path=None):
    """
    Create publication-ready bar charts comparing diversity metrics across models.

    Draws one subplot per key metric on a 2x3 grid. Each bar is a model's
    'mean' value; its error bar is half the width of the stored confidence
    interval, ``(ci_upper - ci_lower) / 2``. Subplots that end up with no data
    (surplus grid cells, or metrics no model reports) are removed.

    Args:
        diversity_results: mapping ``model_name -> {metric_key -> stats}``; each
            ``stats`` dict must provide 'mean', 'ci_lower' and 'ci_upper'.
            Models that lack a metric are simply left out of that subplot.
        save_path: optional output file; when given, the figure is saved there
            at 300 dpi before being shown.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Key diversity metrics to plot (result key -> subplot title)
    key_metrics = {
        'avg_dtw_distance': 'Average DTW Distance',
        'normalized_dtw_score': 'Normalized DTW Score',
        'avg_euclidean_diversity': 'Average Euclidean Diversity',
        'normalized_euclidean_diversity': 'Normalized Euclidean Diversity',
        'coverage_eps_05': 'Coverage @ ε=0.5'
    }

    fig, axes = plt.subplots(2, 3, figsize=(18, 12), dpi=300)
    axes = axes.flatten()

    # Bar colours keyed by lower-case model name; unknown models are drawn grey
    colors = {'gflow': '#2ecc71', 'reinforce': '#3498db', 'reinforce_baseline': '#f39c12',
              'sac': '#9b59b6', 'smcmc': '#e74c3c'}

    used_axes = set()
    for idx, (metric_key, metric_name) in enumerate(key_metrics.items()):
        ax = axes[idx]

        models, means, errors = [], [], []
        for model_name, model_metrics in diversity_results.items():
            if metric_key not in model_metrics:
                continue
            data = model_metrics[metric_key]
            models.append(model_name.upper())
            means.append(data['mean'])
            # Half the CI width, drawn symmetrically around the mean. A
            # non-finite interval (e.g. NaN) would break both the error bar and
            # the label position, so such a bar is drawn without an error bar.
            error = (data['ci_upper'] - data['ci_lower']) / 2
            errors.append(error if np.isfinite(error) else 0.0)

        if not means:
            continue  # no model reports this metric; its empty axis is removed below
        used_axes.add(idx)

        bars = ax.bar(range(len(models)), means, yerr=errors,
                      capsize=10, alpha=0.7,
                      color=[colors.get(m.lower(), 'gray') for m in models])

        ax.set_title(metric_name, fontsize=14, fontweight='bold')
        ax.set_xticks(range(len(models)))
        ax.set_xticklabels(models, rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Value labels just above the top of each error bar
        for bar, mean, error in zip(bars, means, errors):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + error,
                    f'{mean:.2f}', ha='center', va='bottom', fontsize=10)

    # Remove every subplot that stayed empty (surplus grid cells and metrics without data)
    for idx, ax in enumerate(axes):
        if idx not in used_axes:
            fig.delaxes(ax)

    fig.suptitle('Diversity Metrics Comparison with Confidence Intervals\n'
                 'Error bars show 95% confidence intervals',
                 fontsize=16, fontweight='bold')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def _print_diversity_summary(diversity_results):
    """Print the headline diversity metrics (mean ± std) for every model."""
    print("\n" + "=" * 80)
    print("DIVERSITY METRICS CALCULATION COMPLETED")
    print("=" * 80)
    for model_name, model_metrics in diversity_results.items():
        print(f"\n{model_name.upper()}:")
        if 'avg_dtw_distance' in model_metrics:
            data = model_metrics['avg_dtw_distance']
            print(f"  Average DTW Distance: {data['mean']:.3f} ± {data['std']:.3f} "
                  f"[{data['ci_lower']:.3f}, {data['ci_upper']:.3f}]")
        if 'normalized_dtw_score' in model_metrics:
            data = model_metrics['normalized_dtw_score']
            print(f"  Normalized DTW Score: {data['mean']:.3f} ± {data['std']:.3f}")
        if 'avg_euclidean_diversity' in model_metrics:
            data = model_metrics['avg_euclidean_diversity']
            print(f"  Average Euclidean Diversity: {data['mean']:.3f} ± {data['std']:.3f}")


# Usage example: run the identical diversity analysis (compute -> save -> plot ->
# summary) for each experiment set, writing to results/Comparison/diversity_analysis_<n>.
# After the loop, `diversity_results` / `diversity_save_dir` hold the last set.
for _model_results, diversity_save_dir in (
    (model_results_dict_1, "results/Comparison/diversity_analysis_1"),
    (model_results_dict_2, "results/Comparison/diversity_analysis_2"),
    (model_results_dict_3, "results/Comparison/diversity_analysis_3"),
):
    print("Calculating diversity metrics across all runs...")
    diversity_results = calculate_diversity_metrics_across_runs(
        _model_results,
        confidence_level=0.95,
        top_k=100
    )

    # Save results
    save_diversity_metrics_to_files(diversity_results, diversity_save_dir)

    # Create comparison plots
    plot_diversity_metrics_comparison(
        diversity_results,
        save_path=os.path.join(diversity_save_dir, 'diversity_comparison_plot.pdf'))

    # Print summary
    _print_diversity_summary(diversity_results)

In [None]:
import joblib

# joblib.load needs the path of a saved file, e.g.
#     results = joblib.load("results/<experiment>/trajectories_results.joblib")
# A bare `joblib.load()` call always raises TypeError (missing filename).
# joblib files are pickle-based: only load files you produced yourself.

In [None]:
# Oracle O2: load the top-100 final-reward summaries of every method
smcmc = pd.read_csv("results/Comparison/PaperPlots_O2/analysis_SMCMC_20250507_223523/metrics/top_100_summary.csv")
sac = pd.read_csv("results/Comparison/PaperPlots_O2/analysis_SAC_20250506_211020/metrics/top_100_summary.csv")
reinforce_with_baseline = pd.read_csv("results/Comparison/PaperPlots_O2/analysis_REINFORCE_B_20250506_210844/metrics/top_100_summary.csv")
reinforce = pd.read_csv("results/Comparison/PaperPlots_O2/analysis_REINFORCE_20250506_210720/metrics/top_100_summary.csv")
gflownet = pd.read_csv("results/Comparison/PaperPlots_O2/analysis_GFlowNet_20250506_211515/metrics/top_100_summary.csv")

# Clamp only the reward column into [0, 100]. Clipping the whole frame would also
# clamp the 'index' column, which later cells use to select trajectories.
# NOTE(review): if other reward columns in these CSVs also need clamping, add them here.
for _summary in (smcmc, reinforce, gflownet, sac, reinforce_with_baseline):
    _summary['FinalReward'] = _summary['FinalReward'].clip(lower=0, upper=100)

In [None]:
# Oracle O1: load the top-100 final-reward summaries of every method
smcmc = pd.read_csv("results/Comparison/PaperPlots_O1/analysis_SMCMC_20250507_224955/metrics/top_100_summary.csv")
sac = pd.read_csv("results/Comparison/PaperPlots_O1/analysis_SAC_20250506_211947/metrics/top_100_summary.csv")
reinforce_with_baseline = pd.read_csv("results/Comparison/PaperPlots_O1/analysis_REINFORCE_B_20250506_212549/metrics/top_100_summary.csv")
reinforce = pd.read_csv("results/Comparison/PaperPlots_O1/analysis_REINFORCE_20250506_212417/metrics/top_100_summary.csv")
gflownet = pd.read_csv("results/Comparison/PaperPlots_O1/analysis_GFlowNet_20250506_211812/metrics/top_100_summary.csv")

# Clamp only the reward column into [0, 100]. Clipping the whole frame would also
# clamp the 'index' column, which later cells use to select trajectories.
# NOTE(review): if other reward columns in these CSVs also need clamping, add them here.
for _summary in (smcmc, reinforce, gflownet, sac, reinforce_with_baseline):
    _summary['FinalReward'] = _summary['FinalReward'].clip(lower=0, upper=100)


In [None]:
# Kernel-density comparison of every method's final-reward distribution, with
# each method's mean (dashed line) and a shaded ±1 std band overlaid.
plt.figure(figsize=(12, 8), dpi=300)

# Colour scheme per method
colors = {
    'GFlowNet': '#2ecc71',      # Emerald Green
    'SMCMC': '#e74c3c',         # Pomegranate Red
    'REINFORCE': '#3498db',     # Peter River Blue
    'SAC': '#9b59b6',           # Amethyst Purple
    'REINFORCE_B': '#f39c12'    # Orange
}

# (legend name, reward series, colour) for every method; shared by both loops below
method_rewards = [
    ('GFlowNet', gflownet['FinalReward'], colors['GFlowNet']),
    ('SMCMC', smcmc['FinalReward'], colors['SMCMC']),
    ('REINFORCE', reinforce['FinalReward'], colors['REINFORCE']),
    ('SAC', sac['FinalReward'], colors['SAC']),
    ('REINFORCE with Baseline', reinforce_with_baseline['FinalReward'], colors['REINFORCE_B'])
]

# Density curves; the legend label reports mean ± std of the raw rewards
for name, data, color in method_rewards:
    mean = data.mean()
    std = data.std()
    label = f'{name} (μ={mean:.1f}±{std:.1f})'
    sns.kdeplot(data=data,
                label=label,
                color=color,
                linewidth=3.5,
                alpha=0.7)

# Vertical line at each mean with a shaded ±1 std region
for name, data, color in method_rewards:
    mean = data.mean()
    std = data.std()
    plt.axvline(mean, color=color, linestyle='--', alpha=0.5, linewidth=2)
    plt.axvspan(mean - std, mean + std, color=color, alpha=0.1)

# Titles and axis labels with large text
plt.title('Comparison of Final Rewards Distributions',
          fontsize=24,
          pad=20,
          fontweight='bold')
plt.xlabel('Final Reward Values', fontsize=20, fontweight='bold')
# A KDE's y-axis is a probability density, not a percentage
plt.ylabel('Density', fontsize=20, fontweight='bold')

plt.grid(True, alpha=0.2, linestyle='--')

plt.legend(title='Methods',
           title_fontsize=18,
           fontsize=16,
           loc='upper left',
           frameon=True,
           framealpha=0.9,
           edgecolor='black')

plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

# White background, no top/right spines
plt.gca().set_facecolor('white')
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)

plt.tight_layout()

# Save as both vector (PDF) and raster (PNG) at 600 dpi
plt.savefig('final_rewards_comparison.pdf',
            dpi=600,
            bbox_inches='tight',
            format='pdf',
            metadata={'Creator': 'Python'})

plt.savefig('final_rewards_comparison.png',
            dpi=600,
            bbox_inches='tight',
            format='png',
            metadata={'Creator': 'Python'})

plt.show()


In [None]:
# Reload the O1 GFlowNet and SAC summaries (unclipped) for a direct comparison.
# NOTE(review): this rebinds the global `sac`, replacing any clipped frame loaded earlier.
gflow, sac = (
    pd.read_csv(r"results/Comparison/PaperPlots_O1/analysis_GFlowNet_20250506_211812/metrics/top_100_summary.csv"),
    pd.read_csv(r"results/Comparison/PaperPlots_O1/analysis_SAC_20250506_211947/metrics/top_100_summary.csv"),
)

In [None]:
# Show GFlowNet's top-100 final rewards
gflow.loc[:, 'FinalReward']

In [None]:
# Show SAC's top-100 final rewards
sac.loc[:, 'FinalReward']

In [None]:
# Overlap between the two reward distributions:
#   (1) share of SAC samples scoring below GFlowNet's worst sample
#   (2) share of GFlowNet samples scoring below SAC's worst sample
# Each mask is built from the same frame it summarises, so no cross-frame index
# alignment is involved (the old second term masked `gflow` with a mask built
# from `sac`'s rows).
(sac['FinalReward'] < gflow['FinalReward'].min()).mean(), (gflow['FinalReward'] < sac['FinalReward'].min()).mean()

In [None]:
def find_similar_trajectories(reference_trajectory, trajectories, feature_names, n_closest=5, save_path=None):
    """
    Find trajectories most similar to a reference trajectory using DTW distance.

    Args:
        reference_trajectory (np.ndarray): The reference trajectory, indexed as
            (timestep, column); column 0 holds the timestamps.
        trajectories (np.ndarray): Candidate trajectories, indexed as
            (trajectory, timestep, column) with the same column layout. When
            plotting, candidates are drawn against the reference's timestamps,
            so they are expected to have the same number of timesteps.
        feature_names (list): Column names; entry 0 names the timestamp column.
        n_closest (int): Number of closest trajectories to return.
        save_path (str): Optional existing directory. When given, one comparison
            PDF per close trajectory and a histogram of all distances are saved there.

    Returns:
        tuple: (similar_indices, distances, similarity_stats) where
            ``similar_indices`` are positions in ``trajectories`` ordered by
            increasing DTW distance, ``distances`` are those DTW distances, and
            ``similarity_stats`` summarises the distances to all candidates.

    Raises:
        ValueError: if ``trajectories`` is empty.
    """
    if len(trajectories) == 0:
        raise ValueError("find_similar_trajectories: no candidate trajectories given")

    # DTW distance between the reference and each candidate, computed pairwise
    distances = []
    for candidate in trajectories:
        dist_matrix, _ = compute_dtw_distance_matrix(
            np.array([reference_trajectory, candidate])
        )
        distances.append(dist_matrix[0, 1])  # off-diagonal entry: reference vs candidate
    distances = np.array(distances)

    # Indices of the n closest trajectories
    similar_indices = np.argsort(distances)[:n_closest]

    similarity_stats = {
        'mean_distance': np.mean(distances),
        'std_distance': np.std(distances),
        'min_distance': np.min(distances),
        'max_distance': np.max(distances),
        'median_distance': np.median(distances)
    }

    if save_path:
        n_features = reference_trajectory.shape[1] - 1  # column 0 holds the timestamps
        timestamps = reference_trajectory[:, 0]

        # One figure per close trajectory, one subplot per feature
        for rank, similar_idx in enumerate(similar_indices, start=1):
            # squeeze=False keeps a 2-D axes array even for a single feature
            fig, axes = plt.subplots(n_features, 1, figsize=(15, 4 * n_features), squeeze=False)

            for i, ax in enumerate(axes[:, 0]):
                ax.plot(timestamps, reference_trajectory[:, i + 1], 'b-',
                        label='Reference', linewidth=2)
                ax.plot(timestamps, trajectories[similar_idx, :, i + 1], 'r--',
                        label=f'Similar Trajectory {similar_idx}', linewidth=2)

                ax.set_xlabel('Time')
                ax.set_ylabel(feature_names[i + 1])
                ax.grid(True, alpha=0.3)
                ax.legend()

            plt.suptitle(f'Comparison with Similar Trajectory {similar_idx}\n' +
                         f'DTW Distance: {distances[similar_idx]:.2f}')
            plt.tight_layout()
            plt.savefig(os.path.join(save_path, f'similar_trajectory_{rank}.pdf'))
            plt.close(fig)

        # Distribution of all distances, marking mean, median and the selected ones
        plt.figure(figsize=(10, 6))
        sns.histplot(distances, kde=True)
        plt.axvline(np.mean(distances), color='r', linestyle='--', label='Mean')
        plt.axvline(np.median(distances), color='g', linestyle='--', label='Median')
        for dist in distances[similar_indices]:
            plt.axvline(dist, color='b', linestyle=':', alpha=0.5)
        plt.title('Distribution of DTW Distances to Reference Trajectory')
        plt.xlabel('DTW Distance')
        plt.ylabel('Count')
        plt.legend()
        plt.savefig(os.path.join(save_path, 'distance_distribution.pdf'))
        plt.close()

    return similar_indices, distances[similar_indices], similarity_stats

# Example usage: rank all other trajectories by DTW distance to trajectory 0.
reference_traj = trajectories[0]  # Use the first trajectory as reference
other_trajectories = trajectories[1:]  # Use the rest as candidates
output_dir = "similarity_analysis"
os.makedirs(output_dir, exist_ok=True)

similar_indices, distances, stats = find_similar_trajectories(
    reference_trajectory=reference_traj,
    trajectories=other_trajectories,
    feature_names=feature_names,
    n_closest=5,
    save_path=output_dir
)

# `similar_indices` are positions in `other_trajectories`, which starts at
# trajectories[1]. Add 1 to report the position in the full `trajectories` array.
# The candidate position is also shown because the saved plots are labelled with it.
print("\nMost similar trajectories:")
for idx, (similar_idx, distance) in enumerate(zip(similar_indices, distances)):
    print(f"{idx+1}. Trajectory {similar_idx + 1} (candidate {similar_idx}): DTW distance = {distance:.2f}")

print("\nSimilarity Statistics:")
for stat_name, stat_value in stats.items():
    print(f"{stat_name}: {stat_value:.2f}")

In [None]:
# Display the feature (column) names. Entry 0 is the 'TimeStep' column, which the
# analysis cells skip via feature_names[1:].
feature_names

In [None]:
# Summary statistics of each method's final-reward distribution, printed as
# LaTeX table rows (cells separated by '&', rows terminated by '\\').
methods = ['GFlowNet', 'SMCMC', 'REINFORCE', 'SAC', 'REINFORCE with Baseline']
# The top-100 summary CSVs store the reward under 'FinalReward' (the column the
# other cells read); there is no 'Final Rewards' column, which raised KeyError.
# NOTE(review): if the GFlowNet-vs-SAC cell ran after the loader cells, `sac` here
# is its unclipped O1 frame - confirm that is the intended data.
dfs = [gflownet['FinalReward'], smcmc['FinalReward'],
       reinforce['FinalReward'], sac['FinalReward'],
       reinforce_with_baseline['FinalReward']]

stats = {}
for method, df in zip(methods, dfs):
    sorted_rewards = sorted(df.values, reverse=True)  # best rewards first
    stats[method] = {
        'Top 10 Mean': np.mean(sorted_rewards[:10]),
        'Top 50 Mean': np.mean(sorted_rewards[:50]),
        'Top 100 Mean': np.mean(sorted_rewards[:100]),
        '10th Percentile': np.percentile(sorted_rewards, 10),
        '90th Percentile': np.percentile(sorted_rewards, 90)
    }

# Print results in a copy-friendly format
print("Method & Top 10 Mean & Top 50 Mean & Top 100 Mean & 10th Percentile & 90th Percentile \\\\")
print("\\hline")
for method in methods:
    row = f"{method:<20} & " + \
          f"{stats[method]['Top 10 Mean']:>8.2f} & " + \
          f"{stats[method]['Top 50 Mean']:>8.2f} & " + \
          f"{stats[method]['Top 100 Mean']:>8.2f} & " + \
          f"{stats[method]['10th Percentile']:>8.2f} & " + \
          f"{stats[method]['90th Percentile']:>8.2f} \\\\"
    print(row)



In [None]:
# Final state of every trajectory (last timestep, 'TimeStep' column dropped),
# restricted to the rows listed in gflownet['index'] and renumbered from 0.
# NOTE(review): relies on gflownet['index'] holding row positions into
# `trajectories`; make sure the loader cell did not clip that column.
final_states_df = (
    pd.DataFrame(trajectories[:, -1, 1:], columns=feature_names[1:])
    .iloc[gflownet['index'], :]
    .reset_index(drop=True)
)

# Cluster the final states over the candidate cluster-count range (2, 10)
optimal_clusters, cluster_labels, tsne_results = perform_clustering_analysis(
    final_states_df,
    feature_names,
    n_clusters_range=(2, 10),
)


In [None]:
# 2-D t-SNE embedding of the final states, coloured by cluster label
plt.figure(figsize=(12, 8))
scatter = plt.scatter(
    x=tsne_results[:, 0],
    y=tsne_results[:, 1],
    c=cluster_labels,
    cmap='viridis',
    alpha=0.6,
)
plt.colorbar(scatter, label='Cluster')

plt.title(f't-SNE Visualization with {optimal_clusters} Clusters')
plt.xlabel('t-SNE Component 1')
plt.ylabel('t-SNE Component 2')


In [None]:


# Reward at the last step of each trajectory, then per-cluster reward distributions.
# NOTE(review): `cluster_labels` comes from clustering `final_states_df` (only the
# gflownet['index'] rows) while `rewards` may cover every trajectory - confirm they align.
last_step_rewards = [reward_seq[-1] for reward_seq in rewards]
cluster_stats = plot_reward_distributions_by_cluster(
    cluster_labels=cluster_labels,
    rewards=last_step_rewards,
    optimal_clusters=optimal_clusters,
    save_path='cluster_reward_distributions.pdf',
)

# One "metric: value" line per statistic, grouped under each cluster heading
print("\nCluster Summary Statistics:")
for cluster, stats in cluster_stats.items():
    print(f"\n{cluster}:")
    if stats:
        print("\n".join(f"{metric}: {value:.2f}" for metric, value in stats.items()))

In [None]:


# Per-cluster trajectory diversity.
# NOTE(review): `trajectories` is the full array while `cluster_labels` was computed on
# the gflownet['index'] subset (cf. `sub_traj`); confirm the lengths match.
diversity_scores = analyze_cluster_diversity(
    trajectories=trajectories,
    cluster_labels=cluster_labels,
    optimal_clusters=optimal_clusters,
    last_step_rewards=last_step_rewards,
    feature_names=feature_names,
)

# Cluster ids are printed 1-based
print("\nCluster Diversity Summary:")
for cluster_id, div_score, norm_div in diversity_scores:
    print(f"Cluster {cluster_id + 1}:\n"
          f"  Raw Diversity Score: {div_score:.4f}\n"
          f"  Normalized Diversity: {norm_div:.4f}")


In [None]:


# Spread of each feature's final value within every cluster
feature_std_df = analyze_cluster_feature_variability(
    final_states_df=final_states_df,
    cluster_labels=cluster_labels,
    optimal_clusters=optimal_clusters,
    feature_names=feature_names,
)

# One block per cluster: heading with its size, then one line per feature
print("\nCluster Feature Variability Summary:")
for _, row in feature_std_df.iterrows():
    report_lines = [f"\n{row['Cluster']} (Size: {row['Size']}):"]
    report_lines += [f"  {feature}: {row[feature]:.4f}" for feature in feature_names[1:]]
    print("\n".join(report_lines))


# 6. Trajectory Phase Portraits


In [None]:
# Trajectories for the GFlowNet summary rows (positions taken from gflownet['index'])
sub_traj = trajectories[gflownet['index'].to_numpy()]

In [None]:


# Phase portraits of the selected GFlowNet trajectories
fig = plot_phase_portraits(
    trajectories=sub_traj,
    feature_names=feature_names,
    n_trajectories=100,
    alpha=0.1,
)

fig.savefig('phase_portraits.pdf', dpi=300, bbox_inches='tight')
plt.show()

# All at once

In [None]:
import joblib

# Results of the `smcmc_1_ens` experiment folder (method_oracle_modeltype naming).
# joblib files are pickle-based; only load files produced locally.
data = joblib.load(
    r"results/All_Experiments_Results/smcmc_1_ens/trajectories_results.joblib"
)

In [None]:
# Unwrap run 0's nested 'results_dict' in place. Guarded so that re-running this
# cell after the unwrap does not raise KeyError.
if isinstance(data[0], dict) and 'results_dict' in data[0]:
    data[0] = data[0]['results_dict']

In [None]:
# Inspect which entries run 0's results expose
data[0].keys()

In [None]:
# Fallback file names tried when a folder has no *trajectories_results.joblib file.
_RESULT_FILE_FALLBACKS = {
    'smcmc': 'smcmc_results.joblib',
    'reinforce_baseline': 'reinforce_baseline_results.joblib',
}


def _find_result_file(folder_path, method):
    """Return the results file to load from *folder_path*, or None if there is none."""
    # Sorted so the choice is deterministic when several files match.
    candidates = sorted(folder_path.glob("*trajectories_results.joblib"))
    if not candidates and method in _RESULT_FILE_FALLBACKS:
        candidates = sorted(folder_path.glob(_RESULT_FILE_FALLBACKS[method]))
    return candidates[0] if candidates else None


def _is_trajectory_data(data, folder_name):
    """Return True if *data* has a recognised results layout; print the verdict."""
    if not isinstance(data, dict):
        print(f"⚠ {folder_name}: Data is not a dictionary")
        return False

    # Multi-run layout: {run_id (int): run_results (dict), ...}
    if all(isinstance(k, int) for k in data if isinstance(k, (int, str))):
        sample_run = next(
            (value for key, value in data.items()
             if isinstance(key, int) and isinstance(value, dict)),
            None,
        )
        if sample_run and 'trajectories' in sample_run:
            print(f"✓ Loaded: {folder_name} (30-run structure)")
            return True
        # Some runs wrap their results one level deeper under 'results_dict'.
        if isinstance(sample_run, dict) and 'results_dict' in sample_run:
            print(f"✓ Loaded: {folder_name} (30-run structure) with nested results_dict")
            return True
        print(f"⚠ {folder_name}: No trajectory data in runs structure")
        return False

    # Single-result layout with a top-level 'trajectories' entry
    if 'trajectories' in data:
        print(f"✓ Loaded: {folder_name} (direct structure)")
        return True

    print(f"⚠ {folder_name}: No trajectories found in data structure")
    print(f"    Available keys: {list(data.keys())}")
    return False


def _load_result_folder(folder_path, folder_name, method):
    """Load and validate one combination folder; return its data or None (with a printed reason)."""
    if not folder_path.exists():
        print(f"⚠ Folder not found: {folder_name}")
        return None

    result_file = _find_result_file(folder_path, method)
    if result_file is None:
        print(f"⚠ No valid result files found in: {folder_name}")
        # List available files for debugging
        available_files = list(folder_path.glob("*.joblib"))
        if available_files:
            print(f"    Available files: {[f.name for f in available_files]}")
        return None

    import joblib  # imported lazily: only needed once there is a file to read

    try:
        data = joblib.load(result_file)
        accepted = _is_trajectory_data(data, folder_name)
    except Exception as e:  # best-effort loader: report and skip this combination
        print(f"✗ Error loading {folder_name}: {e}")
        return None
    return data if accepted else None


def load_all_experimental_results(base_path="results/All_Experiments_Results", model_types=None,
                                  methods=None, oracles=None):
    """
    Load all experimental results from the structured results folder.

    Each (method, oracle, model type) combination lives in
    ``<base_path>/<method>_<oracle>_<model_type>/``. Inside it, the first file
    (alphabetically) matching ``*trajectories_results.joblib`` is loaded. If
    there is none, ``smcmc_results.joblib`` (method 'smcmc') or
    ``reinforce_baseline_results.joblib`` (method 'reinforce_baseline') is
    tried instead.

    A loaded object is kept only if it is a dict that either
      * is keyed by run number (int) and whose first dict-valued run holds a
        'trajectories' entry or a nested 'results_dict' entry, or
      * has a top-level 'trajectories' entry.
    Problems are reported with print() and that combination is skipped.

    Args:
        base_path: Root folder holding one sub-folder per combination.
        model_types: Model types to load; defaults to
            ['gbr', 'mlp', 'rfr', 'elastic', 'ens'].
        methods: Method folder prefixes to load; defaults to
            ['gflow', 'reinforce', 'sac', 'smcmc'].
        oracles: Oracle numbers to load; defaults to [1, 2, 3].

    Returns:
        dict: Nested dictionary ``results[oracle][model_type][method] = data``.
        Every requested oracle/model_type pair has an entry, possibly empty.
    """
    from pathlib import Path

    oracles = [1, 2, 3] if oracles is None else oracles
    methods = ['gflow', 'reinforce', 'sac', 'smcmc'] if methods is None else methods
    model_types = ['gbr', 'mlp', 'rfr', 'elastic', 'ens'] if model_types is None else model_types
    base_path = Path(base_path)

    results = {}
    for oracle in oracles:
        results[oracle] = {}
        for model_type in model_types:
            results[oracle][model_type] = {}
            for method in methods:
                folder_name = f"{method}_{oracle}_{model_type}"
                data = _load_result_folder(base_path / folder_name, folder_name, method)
                if data is not None:
                    results[oracle][model_type][method] = data

    return results

# Load every oracle / model-type / method combination from the default results folder
test = load_all_experimental_results(base_path="results/All_Experiments_Results")

In [None]:

def run_comprehensive_analysis_for_combination(oracle, model_type, method_data,
                                               base_output_dir, feature_names,
                                               sample_percentage=0.05, max_top=100):
    """
    Run comprehensive analysis for a specific oracle-model_type-method combination.

    For every method in ``method_data`` this writes metrics text files and a
    top-N summary CSV under ``base_output_dir/<method_name>/metrics``. It also
    produces figures through plotting helpers defined elsewhere in this
    notebook. Each plotting step is wrapped in try/except, so a failing plot
    only prints a warning.

    Two per-method input layouts are supported:

    * multi-run: ``{run_id (int): run_dict, ...}``. Each ``run_dict`` holds
      ``'trajectories'`` and ``'rewards'``, either directly or nested under
      ``'results_dict'``. From each run, a random ``sample_percentage`` of the
      finite-reward trajectories is drawn, and the samples are pooled.
    * single dataset: one dict with ``'trajectories'`` and ``'rewards'``.
      All finite-reward trajectories are used.

    NOTE(review): a later cell redefines this function (with
    ``sample_percentage = 1``); whichever cell ran last is the one in effect.

    Args:
        oracle: Oracle number (1, 2, or 3). Not used in the computation.
        model_type: Model type (e.g. 'gbr', 'mlp', 'rfr'). Not used in the computation.
        method_data: Dict mapping method name -> per-method results (see above).
            ``None`` values are skipped.
        base_output_dir: Base directory to save results.
        feature_names: List of feature names (without the leading 'Timestamp').
        sample_percentage: Fraction of each run's trajectories sampled in the
            multi-run layout (default 0.05). At least one trajectory is taken
            per run, and never more than the run contains.
        max_top: Maximum number of best-reward trajectories used for the
            DTW/trajectory plots and the top-N summary table (default 100).

    Returns:
        dict: method name -> {'metrics', 'rewards_summary', 'output_dir',
        'n_trajectories', 'n_top_analyzed', 'sample_percentage_used',
        'selection_method'}. Methods whose processing raised are omitted.
    """
    import numpy as np
    import pandas as pd
    import os

    combination_results = {}

    for method_name, runs_data in method_data.items():
        if runs_data is None:
            continue

        print(f"  Analyzing {method_name.upper()}...")

        # Create method-specific output directory
        method_output_dir = os.path.join(base_output_dir, f"{method_name}")
        plots_dir = os.path.join(method_output_dir, 'plots')
        metrics_dir = os.path.join(method_output_dir, 'metrics')
        action_plots_dir = os.path.join(method_output_dir, 'action_distributions')

        for dir_path in [method_output_dir, plots_dir, metrics_dir, action_plots_dir]:
            os.makedirs(dir_path, exist_ok=True)

        try:
            # Multi-run structure {0: {...}, 1: {...}, ...}: every key is an int run id.
            if isinstance(runs_data, dict) and all(isinstance(k, int) for k in runs_data.keys()):
                print(f"    Found 30-run structure with {len(runs_data)} runs")
                method_sample_percentage = sample_percentage
                selection_method = 'random_per_run'

                # Pool a random `method_sample_percentage` slice of every run.
                all_sampled_trajectories = []
                all_sampled_rewards = []
                all_sampled_actions = []
                n_runs = len(runs_data)

                # Iterate over the real run ids: range(len(runs_data)) would
                # silently drop runs whenever the ids are not exactly 0..n-1.
                for position, run_idx in enumerate(sorted(runs_data), start=1):
                    print(f"    Processing run {position}/{n_runs}")

                    run_data = runs_data[run_idx]
                    # Some runs nest their payload under 'results_dict'.
                    if isinstance(run_data, dict) and 'trajectories' not in run_data:
                        run_data = run_data.get('results_dict')
                    if not isinstance(run_data, dict) or 'trajectories' not in run_data:
                        print(f"    ⚠ No trajectory data in run {run_idx}, skipping...")
                        continue

                    trajectories = run_data['trajectories']
                    rewards = run_data['rewards']

                    # Handle actions with different key names
                    if 'all_actions' in run_data:
                        actions = run_data['all_actions']
                    elif 'actions' in run_data:
                        actions = run_data['actions']
                    else:
                        actions = []
                    if actions is None:
                        actions = []

                    # Convert to numpy arrays for this run
                    if isinstance(trajectories, (list, np.ndarray)):
                        trajectories = np.array(trajectories)
                    if isinstance(rewards, (list, np.ndarray)):
                        rewards = np.array(rewards)
                    if len(actions) > 0:
                        actions = np.array(actions)

                    # Reduce rewards to one final value per trajectory.
                    if method_name in ['reinforce', 'sac']:
                        # RL methods: rewards are already final values (1D or (N, 1))
                        if rewards.ndim == 1:
                            final_rewards_run = rewards
                        elif rewards.ndim == 2 and rewards.shape[1] == 1:
                            final_rewards_run = rewards[:, 0]
                        else:
                            print(f"    ⚠ Unexpected reward format for RL method {method_name}: {rewards.shape}")
                            continue
                    else:
                        # Other methods: 2D rewards are per-timestep; keep the last step.
                        if rewards.ndim == 2:
                            final_rewards_run = rewards[:, -1]
                        elif rewards.ndim == 1:
                            final_rewards_run = rewards
                        else:
                            print(f"    ⚠ Unexpected reward format for method {method_name}: {rewards.shape}")
                            continue

                    # Ensure a numeric dtype before calling np.isfinite().
                    try:
                        final_rewards_run = np.array(final_rewards_run, dtype=np.float64)
                    except (ValueError, TypeError) as e:
                        print(f"    ⚠ Cannot convert rewards to numeric for {method_name} run {run_idx}: {e}")
                        print(f"    Rewards sample: {final_rewards_run[:5] if len(final_rewards_run) > 0 else 'empty'}")
                        continue

                    # Drop trajectories whose final reward is NaN/inf; actions share the mask.
                    finite_mask = np.isfinite(final_rewards_run)
                    final_rewards_run = final_rewards_run[finite_mask]
                    trajectories_run = trajectories[finite_mask]
                    actions_run = actions[finite_mask] if len(actions) > 0 else []

                    if len(final_rewards_run) == 0:
                        print(f"    ⚠ No valid rewards for run {run_idx}")
                        continue

                    n_trajectories_run = len(final_rewards_run)
                    # At least one sample, and never more than the run holds
                    # (np.random.choice without replacement would raise).
                    n_sample_run = min(n_trajectories_run,
                                       max(1, int(n_trajectories_run * method_sample_percentage)))

                    # Reproducible but different per run (reseeds NumPy's global RNG).
                    np.random.seed(42 + run_idx)
                    sample_indices_run = np.random.choice(n_trajectories_run,
                                                          size=n_sample_run,
                                                          replace=False)

                    all_sampled_trajectories.extend(trajectories_run[sample_indices_run])
                    all_sampled_rewards.extend(final_rewards_run[sample_indices_run])
                    if len(actions_run) > 0:
                        all_sampled_actions.extend(actions_run[sample_indices_run])

                    print(f"    Run {position}: Randomly sampled {n_sample_run}/{n_trajectories_run} trajectories ({method_sample_percentage*100:.0f}%)")

                # Convert pooled samples to numpy arrays
                trajectories = np.array(all_sampled_trajectories)
                final_rewards = np.array(all_sampled_rewards, dtype=np.float64)
                all_actions = np.array(all_sampled_actions) if all_sampled_actions else []

                print(f"    Combined sampled trajectories: {len(trajectories)} from all runs")
                print(f"    Combined sampled rewards: {len(final_rewards)} from all runs")

            else:
                # Single dataset {'trajectories': ..., 'rewards': ...}: use all of it.
                print(f"    Found single dataset structure")
                method_sample_percentage = 1.0
                selection_method = 'full_dataset'
                trajectories = runs_data['trajectories']
                rewards = runs_data['rewards']
                all_actions = runs_data.get('all_actions', runs_data.get('actions', []))
                if all_actions is None:
                    all_actions = []

                rewards = np.array(rewards) if isinstance(rewards, list) else rewards
                if method_name in ['reinforce', 'sac']:
                    if rewards.ndim == 1:
                        final_rewards = rewards
                    elif rewards.ndim == 2 and rewards.shape[1] == 1:
                        final_rewards = rewards[:, 0]
                    else:
                        final_rewards = rewards.flatten()
                else:
                    if rewards.ndim == 2:
                        final_rewards = rewards[:, -1]
                    elif rewards.ndim == 1:
                        final_rewards = rewards
                    else:
                        final_rewards = np.array([r[-1] if isinstance(r, (list, np.ndarray)) and len(r) > 0 else r for r in rewards])

                # Same numeric/finite cleaning as the multi-run branch.
                final_rewards = np.asarray(final_rewards, dtype=np.float64)
                if not isinstance(trajectories, np.ndarray):
                    trajectories = np.array(trajectories)
                finite_mask = np.isfinite(final_rewards)
                final_rewards = final_rewards[finite_mask]
                trajectories = trajectories[finite_mask]

            if len(trajectories) == 0:
                print(f"    ⚠ No trajectories found for {method_name}")
                continue

            # Negated rewards: ascending order == best rewards first.
            final_predictions = -final_rewards
            print(f"Final rewards shape: {final_rewards.shape}, Final predictions shape: {final_predictions.shape}")

            # Indices of the best `n_top` trajectories, always defined so the
            # summary table below never hits an unbound (or stale) variable.
            n_top = min(max_top, len(final_rewards))
            top_indices = np.argsort(final_predictions)[:n_top]
            if n_top < len(final_rewards):
                trajectories_top = trajectories[top_indices]
                final_rewards_top = final_rewards[top_indices]
            else:
                # Use all trajectories, in their original order
                trajectories_top = trajectories
                final_rewards_top = final_rewards

            print(f"    Using {len(trajectories_top)} trajectories for analysis")

            # Metrics on the full (sampled) set
            n_final = len(final_predictions)
            sorted_predictions = np.sort(final_predictions)
            metrics_dict = {
                'average_reward': np.mean(final_rewards),
                'max_reward': np.max(final_rewards),
                'min_reward': np.min(final_rewards),
                'std_reward': np.std(final_rewards),
                'tail_coverage_20': np.mean(final_predictions >= -20),
                'tail_coverage_40': np.mean(final_predictions >= -40),
                'tail_coverage_60': np.mean(final_predictions >= -60),
                'expected_shortfall_5': np.mean(sorted_predictions[:max(1, n_final // 20)]) if n_final >= 5 else np.mean(final_predictions),
                'expected_shortfall_10': np.mean(sorted_predictions[:max(1, n_final // 10)]) if n_final >= 10 else np.mean(final_predictions),
                'expected_shortfall_20': np.mean(sorted_predictions[:max(1, n_final // 5)]) if n_final >= 20 else np.mean(final_predictions)
            }

            # Save metrics
            with open(os.path.join(metrics_dir, 'performance_metrics.txt'), 'w') as f:
                for metric, value in metrics_dict.items():
                    f.write(f"{metric}: {value:.4f}\n")
                f.write(f"sample_percentage_per_run: {method_sample_percentage:.2f}\n")
                f.write(f"total_analyzed_trajectories: {len(trajectories)}\n")
                f.write(f"sampling_method: {selection_method}\n")

            # Rewards summary
            rewards_dict = {
                'min': np.min(final_rewards),
                'max': np.max(final_rewards),
                'mean': np.mean(final_rewards),
                'std': np.std(final_rewards),
                'median': np.median(final_rewards)
            }

            with open(os.path.join(metrics_dir, 'rewards_summary.txt'), 'w') as f:
                for metric, value in rewards_dict.items():
                    f.write(f"{metric}: {value:.4f}\n")

            # DTW clustering analysis on the top trajectories (kept small for cost)
            if len(trajectories_top) >= 3:
                try:
                    results_df, distance_matrix, feature_clusters = dtw_clustering_analysis(
                        trajectories=trajectories_top,
                        n_clusters=min(3, len(trajectories_top))
                    )

                    # Clustering plots
                    plot_trajectories_by_cluster(results_df, trajectories_top, ['Timestamp']+feature_names,
                                               num_clusters=min(3, len(trajectories_top)),
                                               save_path=os.path.join(plots_dir, 'cluster_trajectories.pdf'))

                    # DTW distance matrix
                    res_dict_dtw = plot_dtw_distance_matrix(distance_matrix, model_name=method_name,
                                                          rewards=final_rewards_top,
                                                          save_path=os.path.join(plots_dir, 'dtw_distance_matrix.pdf'))

                    # Save DTW metrics
                    with open(os.path.join(metrics_dir, 'dtw_metrics.txt'), 'w') as f:
                        f.write(f"Average Distance: {res_dict_dtw['average_distance']:.4f}\n")
                        f.write(f"Normalized Score: {res_dict_dtw['normalized_score']:.4f}\n")
                        f.write(f"Reward Normalization: {res_dict_dtw['reward_normalization']:.4f}\n")

                except Exception as e:
                    print(f"    ⚠ DTW analysis failed for {method_name}: {e}")

            # Trajectory plots - use top trajectories for visualization
            try:
                plot_trajectories_over_time(
                    trajectories=trajectories_top,
                    rewards=final_rewards_top,
                    feature_names=['Timestamp'] + feature_names,
                    n_top=min(50, len(trajectories_top)),
                    alpha_others=0.05,
                    save_path=plots_dir
                )
            except Exception as e:
                print(f"    ⚠ Trajectory plotting failed for {method_name}: {e}")

            # Action distribution plots (if actions available) - use all sampled data
            if len(all_actions) > 0:
                try:
                    feature_names_for_actions = ['Timestamp'] + feature_names if method_name != 'smcmc' else feature_names

                    # Per-timestep plots pair actions with trajectories row by row;
                    # skip them when runs without actions left the two misaligned.
                    if len(all_actions) == len(trajectories):
                        plot_action_distributions_by_timestep(
                            trajectories=trajectories,
                            all_actions=all_actions,
                            feature_names=feature_names_for_actions,
                            save_path=action_plots_dir
                        )
                    else:
                        print(f"    ⚠ {method_name}: {len(all_actions)} action rows vs {len(trajectories)} trajectories; skipping per-timestep action plots")

                    plot_action_distributions(
                        all_actions=all_actions,
                        feature_names=feature_names_for_actions,
                        save_path=os.path.join(plots_dir, 'distribution_actions.pdf')
                    )
                except Exception as e:
                    print(f"    ⚠ Action distribution plotting failed for {method_name}: {e}")

            # Summary DataFrame of the top trajectories, ranked by reward
            try:
                summary_rows = []
                for rank, idx in enumerate(top_indices):
                    row_dict = {
                        "Rank": rank + 1,
                        "CaseIndex": idx,
                        "FinalReward": final_rewards[idx]
                    }

                    # Final state: last row of a (timesteps, features) trajectory
                    if trajectories[idx].ndim == 2:
                        final_values = trajectories[idx][-1, :]
                    else:
                        final_values = trajectories[idx]

                    # Skip a leading timestamp column when present
                    feature_start_idx = 1 if len(final_values) == len(feature_names) + 1 else 0

                    for f_i, f_name in enumerate(feature_names):
                        if feature_start_idx + f_i < len(final_values):
                            row_dict[f_name] = final_values[feature_start_idx + f_i]

                    summary_rows.append(row_dict)

                top_df = pd.DataFrame(summary_rows).sort_values(by="FinalReward", ascending=False)
                top_df.to_csv(os.path.join(metrics_dir, f'top_{n_top}_summary.csv'), index=False)

                # Diversity metrics on the top trajectories' final states
                if len(feature_names) > 0:
                    try:
                        avg_div, norm_div, distances = plot_diversity_metrics(
                            df=top_df,
                            feature_names=feature_names,
                            save_path=os.path.join(plots_dir, 'euclidean_distance_matrix.pdf')
                        )

                        with open(os.path.join(metrics_dir, 'diversity_metrics.txt'), 'w') as f:
                            f.write(f"Average Diversity - Last State: {avg_div:.4f}\n")
                            f.write(f"Normalized Diversity - Last State: {norm_div:.4f}\n")
                    except Exception as e:
                        print(f"    ⚠ Diversity analysis failed for {method_name}: {e}")

            except Exception as e:
                print(f"    ⚠ Summary analysis failed for {method_name}: {e}")

            # Reward distribution plot - use all sampled data
            try:
                plot_reward_distribution(final_rewards,
                                        save_path=os.path.join(plots_dir, 'reward_distribution.pdf'))
            except Exception as e:
                print(f"    ⚠ Reward distribution plotting failed for {method_name}: {e}")

            combination_results[method_name] = {
                'metrics': metrics_dict,
                'rewards_summary': rewards_dict,
                'output_dir': method_output_dir,
                'n_trajectories': len(trajectories),
                'n_top_analyzed': len(trajectories_top),
                'sample_percentage_used': method_sample_percentage,
                'selection_method': selection_method
            }

            print(f"    ✓ {method_name.upper()}: avg_reward={metrics_dict['average_reward']:.2f} (analyzed {len(trajectories)} randomly sampled trajectories from {method_sample_percentage*100:.0f}% of each run)")

        except Exception as e:
            print(f"    ✗ Error processing {method_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    return combination_results

def generate_comparison_tables(all_results, output_dir):
    """
    Write cross-oracle / cross-model comparison CSVs and return two summary frames.

    Files written under ``<output_dir>/comparison_tables/``:

    * ``<metric>_comparison_table.csv`` for each key metric. There is one row
      per (oracle, model type) pair on a fixed 3 x 5 grid and one upper-cased
      column per method. Missing combinations or methods are NaN.
    * ``master_summary_table.csv``: one row per (combination, method).
    * ``method_rankings.csv``: for each combination, the best, second and
      worst method by ``average_reward``. A missing value ranks as ``-inf``.

    Args:
        all_results: Mapping ``"Oracle_<n>_<MODEL>"`` -> {method: {'metrics': {...}, ...}}.
        output_dir: Directory under which ``comparison_tables`` is created.

    Returns:
        tuple: (summary_df, ranking_df) pandas DataFrames.
    """
    import pandas as pd
    import os
    import numpy as np

    tables_dir = os.path.join(output_dir, 'comparison_tables')
    os.makedirs(tables_dir, exist_ok=True)

    key_metrics = ['average_reward', 'max_reward', 'tail_coverage_20', 'expected_shortfall_10', 'expected_shortfall_20']
    methods = ['gflow', 'reinforce', 'reinforce_baseline', 'sac', 'smcmc']
    grid = [
        (oracle, model_type)
        for oracle in (1, 2, 3)
        for model_type in ('gbr', 'mlp', 'rfr', 'elastic', 'ens')
    ]

    def parse_key(combination_key):
        # "Oracle_<n>_<MODEL>" -> ("<n>", "<MODEL>"), both left as strings
        pieces = combination_key.split('_')
        return pieces[1], pieces[2]

    # --- 1. one wide table per metric --------------------------------------
    for metric in key_metrics:
        print(f"Creating comparison table for {metric}...")
        table_rows = []
        for oracle, model_type in grid:
            combo = all_results.get(f"Oracle_{oracle}_{model_type.upper()}", {})
            entry = {'Oracle': oracle, 'Model_Type': model_type}
            entry.update({
                method.upper(): (combo[method]['metrics'].get(metric, np.nan) if method in combo else np.nan)
                for method in methods
            })
            table_rows.append(entry)
        pd.DataFrame(table_rows).to_csv(
            os.path.join(tables_dir, f'{metric}_comparison_table.csv'), index=False
        )

    # --- 2. long-format master summary ------------------------------------
    summary_columns = [
        ('Avg_Reward', 'average_reward'),
        ('Max_Reward', 'max_reward'),
        ('Std_Reward', 'std_reward'),
        ('Tail_Coverage_20', 'tail_coverage_20'),
        ('Expected_Shortfall_10', 'expected_shortfall_10'),
        ('Expected_Shortfall_20', 'expected_shortfall_20'),
    ]
    summary_records = []
    for combination_key, combination_data in all_results.items():
        oracle, model_type = parse_key(combination_key)
        for method, method_data in combination_data.items():
            metrics = method_data['metrics']
            record = {'Oracle': oracle, 'Model_Type': model_type, 'Method': method.upper()}
            for column, metric_key in summary_columns:
                record[column] = metrics.get(metric_key, np.nan)
            record['N_Trajectories'] = method_data.get('n_trajectories', 'N/A')
            record['N_Top_Analyzed'] = method_data.get('n_top_analyzed', 'N/A')
            summary_records.append(record)

    summary_df = pd.DataFrame(summary_records)
    summary_df.to_csv(os.path.join(tables_dir, 'master_summary_table.csv'), index=False)

    # --- 3. per-combination ranking by average reward ---------------------
    ranking_records = []
    for combination_key, combination_data in all_results.items():
        oracle, model_type = parse_key(combination_key)
        ranked = sorted(
            ((method, data['metrics'].get('average_reward', -np.inf))
             for method, data in combination_data.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        count = len(ranked)
        best = ranked[0] if count else None
        runner_up = ranked[1] if count > 1 else None
        worst = ranked[-1] if count else None

        ranking_records.append({
            'Oracle': oracle,
            'Model_Type': model_type,
            'Best_Method': best[0].upper() if best is not None else 'N/A',
            'Best_Reward': best[1] if best is not None else np.nan,
            'Second_Method': runner_up[0].upper() if runner_up is not None else 'N/A',
            'Second_Reward': runner_up[1] if runner_up is not None else np.nan,
            'Worst_Method': worst[0].upper() if worst is not None else 'N/A',
            'Worst_Reward': worst[1] if worst is not None else np.nan,
            'Methods_Count': count,
        })

    ranking_df = pd.DataFrame(ranking_records)
    ranking_df.to_csv(os.path.join(tables_dir, 'method_rankings.csv'), index=False)

    print(f"✓ All comparison tables saved to {tables_dir}")
    return summary_df, ranking_df

def main(model_types=None):
    """
    Main function to run the comprehensive analysis.

    Steps:
    1. Load results with ``load_all_experimental_results``.
    2. Analyse each (oracle, model type) combination with
       ``run_comprehensive_analysis_for_combination``.
    3. Build comparison tables with ``generate_comparison_tables``.
    4. Print a summary.

    Everything is written under
    ``analysis_outputs/Comprehensive_Analysis_<timestamp>``.

    Args:
        model_types: Model types to load *and* analyse (default ``['ens']``).
            A single list drives both steps, so combinations that were never
            loaded are not reported as failed analyses.

    Returns:
        tuple: (analysis_results, main_output_dir). ``analysis_results`` maps
        ``"Oracle_<n>_<MODEL>"`` to the per-method results of every
        successfully processed combination.
    """
    if model_types is None:
        model_types = ['ens']
    model_types = list(model_types)

    print("COMPREHENSIVE EXPERIMENTAL ANALYSIS")
    print("="*80)

    # Step 1: Load all experimental results
    print("Step 1: Loading all experimental results...")
    all_results = load_all_experimental_results(model_types=model_types)

    # Step 2: Setup feature names
    feature_names = ['Volume_spx', 'Close_ndx', 'Volume_ndx', 'Close_vix', 'IRLTCT01USM156N', 'BAMLH0A3HYCEY']

    # Step 3: Create output directory
    from datetime import datetime
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    main_output_dir = f"analysis_outputs/Comprehensive_Analysis_{timestamp}"
    os.makedirs(main_output_dir, exist_ok=True)

    print(f"Output directory: {main_output_dir}")

    # Step 4: Run analysis for each combination
    print("Step 4: Running comprehensive analysis for all combinations...")

    analysis_results = {}
    successful = []
    failed = []

    for oracle in [1, 2, 3]:
        # Iterate only over the model types that were actually loaded above.
        for model_type in model_types:
            combination_key = f"Oracle_{oracle}_{model_type.upper()}"
            print(f"\n{'='*60}")
            print(f"PROCESSING: {combination_key}")
            print(f"{'='*60}")

            try:
                # Create combination output directory
                combo_output_dir = os.path.join(main_output_dir, combination_key)
                os.makedirs(combo_output_dir, exist_ok=True)

                # Get method data for this combination
                method_data = all_results.get(oracle, {}).get(model_type, {})

                if method_data:
                    # Run analysis for this combination
                    combination_results = run_comprehensive_analysis_for_combination(
                        oracle, model_type, method_data, combo_output_dir, feature_names
                    )

                    if combination_results:  # Only add if we got some results
                        analysis_results[combination_key] = combination_results
                        successful.append(combination_key)
                        print(f"✓ Successfully processed {combination_key} with {len(combination_results)} methods")
                    else:
                        print(f"⚠ No results generated for {combination_key}")
                        failed.append(combination_key)
                else:
                    print(f"⚠ No data found for {combination_key}")
                    failed.append(combination_key)

            except Exception as e:
                print(f"✗ Failed to process {combination_key}: {e}")
                import traceback
                traceback.print_exc()
                failed.append(combination_key)
                continue

    # Step 5: Generate comparison tables
    print(f"\n{'='*60}")
    print("GENERATING COMPARISON TABLES")
    print(f"{'='*60}")

    if analysis_results:
        summary_df, ranking_df = generate_comparison_tables(analysis_results, main_output_dir)

        # Print some summary statistics
        print(f"\nSummary of results:")
        print(f"Total method-combination pairs analyzed: {len(summary_df)}")
        print(f"Methods found across all combinations:")
        for method in summary_df['Method'].unique():
            count = len(summary_df[summary_df['Method'] == method])
            print(f"  {method}: {count} combinations")
    else:
        print("⚠ No results to generate tables from")

    # Step 6: Print final summary
    print(f"\n{'='*80}")
    print("ANALYSIS COMPLETE - FINAL SUMMARY")
    print(f"{'='*80}")
    print(f"Results saved to: {main_output_dir}")
    print(f"Successful analyses: {len(successful)}")
    print(f"Failed analyses: {len(failed)}")

    print("\nSuccessful combinations:")
    for success in successful:
        methods_count = len(analysis_results.get(success, {}))
        print(f"  ✓ {success} ({methods_count} methods)")

    if failed:
        print("\nFailed combinations:")
        for failure in failed:
            print(f"  ✗ {failure}")

    return analysis_results, main_output_dir

# Methods whose stored rewards are already one final value per trajectory.
_RL_REWARD_METHODS = ('reinforce', 'sac')


def _final_rewards_from_run(rewards, method_name):
    """Return the 1-D final reward per trajectory for one run, or None.

    RL methods ('reinforce', 'sac') must store shape (n,) or (n, 1); other
    methods may store (n, T) per-step rewards, in which case the last column
    is used, or (n,) final values. Any other shape yields None.
    """
    if method_name in _RL_REWARD_METHODS:
        if rewards.ndim == 1:
            return rewards
        if rewards.ndim == 2 and rewards.shape[1] == 1:
            return rewards[:, 0]
        return None
    if rewards.ndim == 2:
        return rewards[:, -1]
    if rewards.ndim == 1:
        return rewards
    return None


def _final_state_features(trajectory, feature_names):
    """Map each feature name to its value in the last state of ``trajectory``.

    A 2-D trajectory is read as (timesteps, features) and its last row is
    used; anything else is used as-is. When the state has exactly one more
    entry than ``feature_names`` the leading entry (timestamp column) is
    skipped. Feature names past the end of the state are omitted.
    """
    trajectory = np.asarray(trajectory)
    final_values = trajectory[-1, :] if trajectory.ndim == 2 else trajectory
    offset = 1 if len(final_values) == len(feature_names) + 1 else 0
    return {
        f_name: final_values[offset + f_i]
        for f_i, f_name in enumerate(feature_names)
        if offset + f_i < len(final_values)
    }


def _reward_metrics(final_rewards):
    """Summary statistics for a non-empty 1-D array of final rewards.

    Tail metrics use predictions = -rewards: ``tail_coverage_k`` is the
    fraction of predictions >= -k (i.e. rewards <= k) and
    ``expected_shortfall_*`` is the mean of the lowest predictions (the
    negated best rewards) over the best n//20, n//10 or n//5 items (at least
    one), falling back to the mean of all predictions when n is below
    5, 10 or 20 respectively.
    """
    final_rewards = np.asarray(final_rewards, dtype=np.float64)
    predictions = -final_rewards
    n = len(predictions)
    sorted_predictions = np.sort(predictions)
    q10, q25, q75, q90 = np.percentile(final_rewards, [10, 25, 75, 90])

    def shortfall(divisor, min_count):
        if n >= min_count:
            return np.mean(sorted_predictions[:max(1, n // divisor)])
        return np.mean(predictions)

    return {
        'average_reward': np.mean(final_rewards),
        'median_reward': np.median(final_rewards),
        'max_reward': np.max(final_rewards),
        'min_reward': np.min(final_rewards),
        'std_reward': np.std(final_rewards),
        'quantile_10': q10,
        'quantile_25': q25,
        'quantile_75': q75,
        'quantile_90': q90,
        'iqr': q75 - q25,
        'tail_coverage_20': np.mean(predictions >= -20),
        'tail_coverage_40': np.mean(predictions >= -40),
        'tail_coverage_60': np.mean(predictions >= -60),
        'expected_shortfall_5': shortfall(20, 5),
        'expected_shortfall_10': shortfall(10, 10),
        'expected_shortfall_20': shortfall(5, 20),
    }


def _sample_from_runs(runs_data, method_name, sample_percentage, reward_cap):
    """Pool a random fraction of the valid trajectories of every run.

    Runs are visited in sorted key order (so non-contiguous run indices are
    all processed). Per run: the payload may be nested under 'results_dict';
    trajectories with a non-finite final reward are dropped; rewards are
    clipped at ``reward_cap``; and ``sample_percentage`` (a fraction, 1 = all)
    of the remaining trajectories is drawn without replacement from a private
    ``RandomState(42 + run_idx)``, leaving the global NumPy RNG untouched.
    Malformed runs are reported and skipped rather than aborting the method.

    Returns:
        tuple: (trajectories ndarray, final rewards float64 ndarray,
        actions ndarray, or [] when no run supplied usable actions).
    """
    sampled_trajectories, sampled_rewards, sampled_actions = [], [], []
    run_keys = sorted(runs_data)
    n_runs = len(run_keys)

    for position, run_idx in enumerate(run_keys, start=1):
        print(f"    Processing run {position}/{n_runs}")
        run_data = runs_data[run_idx]
        if run_data is not None and 'trajectories' not in run_data:
            try:
                run_data = run_data['results_dict']
            except (KeyError, IndexError, TypeError):
                run_data = None
        if run_data is None or 'trajectories' not in run_data:
            print(f"    ⚠ No trajectory data in run {run_idx}, skipping...")
            continue
        if 'rewards' not in run_data:
            print(f"    ⚠ No reward data in run {run_idx}, skipping...")
            continue

        trajectories = run_data['trajectories']
        rewards = run_data['rewards']
        if isinstance(trajectories, (list, tuple, np.ndarray)):
            trajectories = np.array(trajectories)
        if isinstance(rewards, (list, tuple, np.ndarray)):
            rewards = np.array(rewards)

        final_rewards_run = _final_rewards_from_run(rewards, method_name)
        if final_rewards_run is None:
            kind = "RL method" if method_name in _RL_REWARD_METHODS else "method"
            print(f"    ⚠ Unexpected reward format for {kind} {method_name}: {rewards.shape}")
            continue

        # Rewards may arrive with an object dtype; np.isfinite needs numbers.
        try:
            final_rewards_run = np.array(final_rewards_run, dtype=np.float64)
        except (ValueError, TypeError) as e:
            print(f"    ⚠ Cannot convert rewards to numeric for {method_name} run {run_idx}: {e}")
            print(f"    Rewards sample: {final_rewards_run[:5] if len(final_rewards_run) > 0 else 'empty'}")
            continue

        if len(trajectories) != len(final_rewards_run):
            print(f"    ⚠ {len(trajectories)} trajectories but {len(final_rewards_run)} rewards "
                  f"in run {run_idx}, skipping...")
            continue

        finite_mask = np.isfinite(final_rewards_run)
        final_rewards_run = np.clip(final_rewards_run[finite_mask], -np.inf, reward_cap)
        trajectories_run = trajectories[finite_mask]

        if 'all_actions' in run_data:
            actions = run_data['all_actions']
        elif 'actions' in run_data:
            actions = run_data['actions']
        else:
            actions = []
        actions_run = []
        if actions is not None and len(actions) > 0:
            actions = np.array(actions)
            if len(actions) == len(finite_mask):
                actions_run = actions[finite_mask]
            else:
                # Actions are auxiliary; a mismatch must not discard the run.
                print(f"    ⚠ Ignoring actions in run {run_idx}: count does not match rewards")

        n_trajectories_run = len(final_rewards_run)
        if n_trajectories_run == 0:
            print(f"    ⚠ No valid rewards for run {run_idx}")
            continue

        n_sample_run = min(n_trajectories_run,
                           max(1, int(n_trajectories_run * sample_percentage)))
        # Private generator: same draws as seeding the global RNG, no side effect.
        rng = np.random.RandomState(42 + run_idx)
        sample_indices_run = rng.choice(n_trajectories_run, size=n_sample_run, replace=False)

        sampled_trajectories.extend(trajectories_run[sample_indices_run])
        sampled_rewards.extend(final_rewards_run[sample_indices_run])
        if len(actions_run) > 0:
            sampled_actions.extend(actions_run[sample_indices_run])

        print(f"    Run {position}: Randomly sampled {n_sample_run}/{n_trajectories_run} "
              f"trajectories ({sample_percentage*100:.0f}%)")

    trajectories = np.array(sampled_trajectories)
    final_rewards = np.array(sampled_rewards, dtype=np.float64)
    all_actions = np.array(sampled_actions) if sampled_actions else []
    return trajectories, final_rewards, all_actions


def _load_single_dataset(data, method_name, reward_cap):
    """Extract (trajectories, final rewards, actions) from one flat result dict.

    Uses the same final-reward conventions as the multi-run path, falling back
    to flattening (RL methods) or last-element extraction (others) for other
    shapes. Non-finite rewards and their trajectories are dropped, and rewards
    are clipped at ``reward_cap``.
    """
    trajectories = data['trajectories']
    actions = data.get('all_actions', data.get('actions', []))
    rewards = data['rewards']
    if isinstance(rewards, (list, tuple)):
        rewards = np.array(rewards)

    final = _final_rewards_from_run(rewards, method_name)
    if final is None:
        if method_name in _RL_REWARD_METHODS:
            final = rewards.flatten()
        else:
            final = np.array([r[-1] if isinstance(r, (list, np.ndarray)) and len(r) > 0 else r
                              for r in rewards])

    final = np.asarray(final, dtype=np.float64)
    keep = np.flatnonzero(np.isfinite(final))
    if len(keep) < len(final):
        final = final[keep]
        if isinstance(trajectories, np.ndarray):
            trajectories = trajectories[keep]
        else:
            trajectories = [trajectories[i] for i in keep]
    return trajectories, np.clip(final, -np.inf, reward_cap), actions


def run_comprehensive_analysis_for_combination(oracle, model_type, method_data,
                                               base_output_dir, feature_names,
                                               sample_percentage=1, reward_cap=100,
                                               max_top=100):
    """
    Run comprehensive analysis for one oracle / model-type combination.

    For each method in ``method_data`` this gathers trajectories and final
    rewards, computes reward and tail metrics, writes metrics text files and
    summary CSVs under ``base_output_dir/<method_name>/``, and attempts DTW,
    diversity and reward-distribution plots (plot failures are reported and
    skipped). A failure for one method is printed and does not stop the others.

    Args:
        oracle: Oracle number (1, 2, or 3). Not used in the body; kept for
            interface compatibility.
        model_type: Model type ('gbr', 'mlp', 'rfr', ...). Not used in the body.
        method_data: Dict mapping method name to either
            * a multi-run dict keyed by integer run index, each run holding
              'trajectories' and 'rewards' (optionally nested under
              'results_dict') plus optional 'all_actions' / 'actions'; or
            * a single dataset dict with 'trajectories' and 'rewards'.
            ``None`` entries are skipped.
        base_output_dir: Base directory to save results.
        feature_names: List of feature names.
        sample_percentage: Fraction (0, 1] of each run's valid trajectories to
            sample at random (default 1 = all of them).
        reward_cap: Upper clip applied to final rewards (default 100).
        max_top: Number of best trajectories used for DTW / top summaries.

    Returns:
        dict: method name -> {'metrics', 'rewards_summary', 'output_dir',
        'n_trajectories', 'n_top_analyzed', 'sample_percentage_used',
        'selection_method'}. Methods that fail or have no data are omitted.
    """
    import os

    import numpy as np
    import pandas as pd

    combination_results = {}

    for method_name, runs_data in method_data.items():
        if runs_data is None:
            continue

        print(f"  Analyzing {method_name.upper()}...")

        # Create method-specific output directories
        method_output_dir = os.path.join(base_output_dir, f"{method_name}")
        plots_dir = os.path.join(method_output_dir, 'plots')
        metrics_dir = os.path.join(method_output_dir, 'metrics')
        action_plots_dir = os.path.join(method_output_dir, 'action_distributions')
        for dir_path in (method_output_dir, plots_dir, metrics_dir, action_plots_dir):
            os.makedirs(dir_path, exist_ok=True)

        try:
            is_multi_run = isinstance(runs_data, dict) and all(
                isinstance(k, (int, np.integer)) for k in runs_data
            )
            if is_multi_run:
                print(f"    Found 30-run structure with {len(runs_data)} runs")
                trajectories, final_rewards, all_actions = _sample_from_runs(
                    runs_data, method_name, sample_percentage, reward_cap)
                sample_fraction_used = sample_percentage
                selection_method = 'random_per_run'
                print(f"    Combined sampled trajectories: {len(trajectories)} from all runs")
                print(f"    Combined sampled rewards: {len(final_rewards)} from all runs")
            else:
                print(f"    Found single dataset structure")
                trajectories, final_rewards, all_actions = _load_single_dataset(
                    runs_data, method_name, reward_cap)
                sample_fraction_used = 1.0
                selection_method = 'full_dataset'

            if len(trajectories) == 0 or len(final_rewards) == 0:
                print(f"    ⚠ No trajectories found for {method_name}")
                continue

            final_predictions = -final_rewards
            print(f"Final rewards shape: {final_rewards.shape}, Final predictions shape: {final_predictions.shape}")

            # Indices of the n_top highest rewards, best first
            # (lowest prediction == highest reward).
            n_top = min(max_top, len(final_rewards))
            top_indices = np.argsort(final_predictions)[:n_top]
            if n_top < len(final_rewards):
                trajectories_top = np.array([trajectories[i] for i in top_indices])
                final_rewards_top = final_rewards[top_indices]
            else:
                # Every trajectory is analysed; keep the original order for plots.
                trajectories_top = trajectories
                final_rewards_top = final_rewards

            print(f"    Using {len(trajectories_top)} trajectories for analysis")

            # Metrics over the full (sampled) reward set
            metrics_dict = _reward_metrics(final_rewards)
            with open(os.path.join(metrics_dir, 'performance_metrics.txt'), 'w') as f:
                for metric, value in metrics_dict.items():
                    f.write(f"{metric}: {value:.4f}\n")
                f.write(f"sample_percentage_per_run: {sample_fraction_used:.2f}\n")
                f.write(f"total_analyzed_trajectories: {len(trajectories)}\n")
                f.write(f"sampling_method: {selection_method}\n")

            rewards_dict = {
                'min': metrics_dict['min_reward'],
                'max': metrics_dict['max_reward'],
                'mean': metrics_dict['average_reward'],
                'median': metrics_dict['median_reward'],
                'std': metrics_dict['std_reward'],
                'quantile_10': metrics_dict['quantile_10'],
                'quantile_25': metrics_dict['quantile_25'],
                'quantile_75': metrics_dict['quantile_75'],
                'quantile_90': metrics_dict['quantile_90'],
                'iqr': metrics_dict['iqr'],
            }
            with open(os.path.join(metrics_dir, 'rewards_summary.txt'), 'w') as f:
                for metric, value in rewards_dict.items():
                    f.write(f"{metric}: {value:.4f}\n")

            # DTW clustering on the top trajectories only (pairwise DTW is costly)
            if len(trajectories_top) >= 3:
                try:
                    results_df, distance_matrix, feature_clusters = dtw_clustering_analysis(
                        trajectories=trajectories_top,
                        n_clusters=min(3, len(trajectories_top))
                    )
                    # plot_trajectories_by_cluster(results_df, trajectories_top, ['Timestamp'] + feature_names,
                    #                              num_clusters=min(3, len(trajectories_top)),
                    #                              save_path=os.path.join(plots_dir, 'cluster_trajectories.pdf'))
                    res_dict_dtw = plot_dtw_distance_matrix(distance_matrix, model_name=method_name,
                                                            rewards=final_rewards_top,
                                                            save_path=os.path.join(plots_dir, 'dtw_distance_matrix.pdf'))
                    with open(os.path.join(metrics_dir, 'dtw_metrics.txt'), 'w') as f:
                        f.write(f"Average Distance: {res_dict_dtw['average_distance']:.4f}\n")
                        f.write(f"Normalized Score: {res_dict_dtw['normalized_score']:.4f}\n")
                        f.write(f"Reward Normalization: {res_dict_dtw['reward_normalization']:.4f}\n")
                except Exception as e:
                    print(f"    ⚠ DTW analysis failed for {method_name}: {e}")

            # Trajectory and action-distribution plots are currently disabled:
            # plot_trajectories_over_time(trajectories=trajectories_top, rewards=final_rewards_top,
            #                             feature_names=['Timestamp'] + feature_names,
            #                             n_top=min(50, len(trajectories_top)), alpha_others=0.05,
            #                             save_path=plots_dir)
            # if len(all_actions) > 0:
            #     feature_names_for_actions = ['Timestamp'] + feature_names if method_name != 'smcmc' else feature_names
            #     plot_action_distributions_by_timestep(trajectories=trajectories, all_actions=all_actions,
            #                                           feature_names=feature_names_for_actions,
            #                                           save_path=action_plots_dir)
            #     plot_action_distributions(all_actions=all_actions, feature_names=feature_names_for_actions,
            #                               save_path=os.path.join(plots_dir, 'distribution_actions.pdf'))

            # Summary tables: top trajectories (ranked by reward) and all trajectories
            try:
                summary_rows = []
                for rank, idx in enumerate(top_indices, start=1):
                    row_dict = {"Rank": rank, "CaseIndex": idx, "FinalReward": final_rewards[idx]}
                    row_dict.update(_final_state_features(trajectories[idx], feature_names))
                    summary_rows.append(row_dict)
                top_df = pd.DataFrame(summary_rows).sort_values(by="FinalReward", ascending=False)
                top_df.to_csv(os.path.join(metrics_dir, f'top_{n_top}_summary.csv'), index=False)

                all_summary_rows = [
                    {"CaseIndex": idx, "FinalReward": final_rewards[idx],
                     **_final_state_features(trajectories[idx], feature_names)}
                    for idx in range(len(trajectories))
                ]
                all_df = pd.DataFrame(all_summary_rows).sort_values(by="FinalReward", ascending=False)
                all_df.insert(0, 'Rank', range(1, len(all_df) + 1))
                all_df.to_csv(os.path.join(metrics_dir, 'all_trajectories_summary.csv'), index=False)

                # Diversity of final states among the top trajectories
                if len(feature_names) > 0:
                    try:
                        avg_div, norm_div, distances = plot_diversity_metrics(
                            df=top_df,
                            feature_names=feature_names,
                            save_path=os.path.join(plots_dir, 'euclidean_distance_matrix.pdf')
                        )
                        with open(os.path.join(metrics_dir, 'diversity_metrics.txt'), 'w') as f:
                            f.write(f"Average Diversity - Last State: {avg_div:.4f}\n")
                            f.write(f"Normalized Diversity - Last State: {norm_div:.4f}\n")
                    except Exception as e:
                        print(f"    ⚠ Diversity analysis failed for {method_name}: {e}")
            except Exception as e:
                print(f"    ⚠ Summary analysis failed for {method_name}: {e}")

            # Reward distribution over all sampled rewards
            try:
                plot_reward_distribution(final_rewards,
                                         save_path=os.path.join(plots_dir, 'reward_distribution.pdf'))
            except Exception as e:
                print(f"    ⚠ Reward distribution plotting failed for {method_name}: {e}")

            combination_results[method_name] = {
                'metrics': metrics_dict,
                'rewards_summary': rewards_dict,
                'output_dir': method_output_dir,
                'n_trajectories': len(trajectories),
                'n_top_analyzed': len(trajectories_top),
                'sample_percentage_used': sample_fraction_used,
                'selection_method': selection_method,
            }

            if selection_method == 'random_per_run':
                detail = (f"analyzed {len(trajectories)} randomly sampled trajectories "
                          f"from {sample_fraction_used*100:.0f}% of each run")
            else:
                detail = f"analyzed all {len(trajectories)} trajectories of the single dataset"
            print(f"    ✓ {method_name.upper()}: avg_reward={metrics_dict['average_reward']:.2f}, "
                  f"median_reward={metrics_dict['median_reward']:.2f} ({detail})")

        except Exception as e:
            # Best-effort per method: report and move on to the next one.
            print(f"    ✗ Error processing {method_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    return combination_results

def generate_comparison_tables(all_results, output_dir):
    """
    Write cross-combination comparison CSVs and return the summary tables.

    Files are written under ``<output_dir>/comparison_tables``:
      * ``<metric>_comparison_table.csv`` for each key metric - one row per
        (oracle 1-3, model type) pair and one upper-cased column per method,
        NaN where the combination or method is absent;
      * ``master_summary_table.csv`` - one row per (combination, method);
      * ``method_rankings.csv`` - best / second / worst method by average
        reward and best method by median reward for each combination.

    Args:
        all_results: Mapping ``"Oracle_<n>_<MODEL>" -> {method: {'metrics': {...},
            'n_trajectories': ..., 'n_top_analyzed': ...}}``.
        output_dir: Directory under which ``comparison_tables`` is created.

    Returns:
        tuple: (summary_df, ranking_df), the master summary and ranking DataFrames.
    """
    import os

    import numpy as np
    import pandas as pd

    tables_dir = os.path.join(output_dir, 'comparison_tables')
    os.makedirs(tables_dir, exist_ok=True)

    compared_metrics = (
        'average_reward', 'median_reward', 'max_reward', 'quantile_90', 'quantile_75',
        'quantile_25', 'quantile_10', 'iqr', 'tail_coverage_20',
        'expected_shortfall_10', 'expected_shortfall_20',
    )
    method_names = ('gflow', 'reinforce', 'reinforce_baseline', 'sac', 'smcmc')
    oracle_ids = (1, 2, 3)
    model_names = ('gbr', 'mlp', 'rfr', 'elastic', 'ens')

    def metric_or_nan(combo, method, metric):
        # NaN when this method has no entry for the combination.
        if method not in combo:
            return np.nan
        return combo[method]['metrics'].get(metric, np.nan)

    def split_key(key):
        # "Oracle_<n>_<MODEL>" -> ("<n>", "<MODEL>") as strings.
        pieces = key.split('_')
        return pieces[1], pieces[2]

    # 1. One comparison CSV per metric over the fixed oracle x model grid.
    for metric in compared_metrics:
        print(f"Creating comparison table for {metric}...")
        grid_rows = []
        for oracle_id in oracle_ids:
            for model_name in model_names:
                lookup_key = f"Oracle_{oracle_id}_{model_name.upper()}"
                combo = all_results[lookup_key] if lookup_key in all_results else {}
                grid_row = {'Oracle': oracle_id, 'Model_Type': model_name}
                grid_row.update(
                    {name.upper(): metric_or_nan(combo, name, metric) for name in method_names}
                )
                grid_rows.append(grid_row)
        pd.DataFrame(grid_rows).to_csv(
            os.path.join(tables_dir, f'{metric}_comparison_table.csv'), index=False
        )

    # 2. Master summary: (output column, metrics key) in column order.
    summary_fields = (
        ('Avg_Reward', 'average_reward'),
        ('Median_Reward', 'median_reward'),
        ('Max_Reward', 'max_reward'),
        ('Std_Reward', 'std_reward'),
        ('Q10', 'quantile_10'),
        ('Q25', 'quantile_25'),
        ('Q75', 'quantile_75'),
        ('Q90', 'quantile_90'),
        ('IQR', 'iqr'),
        ('Tail_Coverage_20', 'tail_coverage_20'),
        ('Expected_Shortfall_10', 'expected_shortfall_10'),
        ('Expected_Shortfall_20', 'expected_shortfall_20'),
    )
    master_rows = []
    for combo_key, combo in all_results.items():
        oracle_label, model_label = split_key(combo_key)
        for name, payload in combo.items():
            stats = payload['metrics']
            entry = {'Oracle': oracle_label, 'Model_Type': model_label, 'Method': name.upper()}
            for column, stat_key in summary_fields:
                entry[column] = stats.get(stat_key, np.nan)
            entry['N_Trajectories'] = payload.get('n_trajectories', 'N/A')
            entry['N_Top_Analyzed'] = payload.get('n_top_analyzed', 'N/A')
            master_rows.append(entry)

    summary_df = pd.DataFrame(master_rows)
    summary_df.to_csv(os.path.join(tables_dir, 'master_summary_table.csv'), index=False)

    # 3. Rankings by mean and by median reward (stable, descending).
    ranking_rows = []
    for combo_key, combo in all_results.items():
        oracle_label, model_label = split_key(combo_key)
        by_mean = sorted(
            ((name, payload['metrics'].get('average_reward', -np.inf))
             for name, payload in combo.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        by_median = sorted(
            ((name, payload['metrics'].get('median_reward', -np.inf))
             for name, payload in combo.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        has_methods = len(by_mean) > 0
        has_runner_up = len(by_mean) > 1
        has_medians = len(by_median) > 0
        ranking_rows.append({
            'Oracle': oracle_label,
            'Model_Type': model_label,
            'Best_Method_Mean': by_mean[0][0].upper() if has_methods else 'N/A',
            'Best_Mean_Reward': by_mean[0][1] if has_methods else np.nan,
            'Best_Method_Median': by_median[0][0].upper() if has_medians else 'N/A',
            'Best_Median_Reward': by_median[0][1] if has_medians else np.nan,
            'Second_Method': by_mean[1][0].upper() if has_runner_up else 'N/A',
            'Second_Reward': by_mean[1][1] if has_runner_up else np.nan,
            'Worst_Method': by_mean[-1][0].upper() if has_methods else 'N/A',
            'Worst_Reward': by_mean[-1][1] if has_methods else np.nan,
            'Methods_Count': len(by_mean),
        })

    ranking_df = pd.DataFrame(ranking_rows)
    ranking_df.to_csv(os.path.join(tables_dir, 'method_rankings.csv'), index=False)

    print(f"✓ All comparison tables saved to {tables_dir}")
    return summary_df, ranking_df


In [None]:
# Run the full analysis pipeline; main() hands back the per-combination
# results dict together with the directory its outputs were written to.
(analysis_results, output_dir) = main()

In [None]:
def generate_oracle_model_specific_latex_tables(analysis_base_path, output_path=None):
    """
    Generate LaTeX tables for each oracle-model combination to analyze generalization.

    For every ``Oracle_<n>_<MODEL>`` folder found under ``analysis_base_path``, the
    per-method files ``metrics/dtw_metrics.txt``, ``metrics/diversity_metrics.txt``
    and ``metrics/performance_metrics.txt`` are parsed as ``key: value`` lines
    (lines starting with ``#`` are ignored; for ``mean ± std`` values only the mean
    is kept; non-numeric values are skipped). Two tables are rendered per
    combination, diversity and performance, with the best value of each column in
    bold (lower is better for the expected-shortfall columns, higher otherwise).

    Args:
        analysis_base_path: Path to comprehensive analysis results folder
        output_path: Optional directory to save the LaTeX files. It is created
            (including missing parent directories) if needed; each combination gets
            its own sub-folder, and all tables are also concatenated into
            ``all_oracle_model_tables.tex``.

    Returns:
        dict: ``{combination_key: {'diversity_table': str, 'performance_table': str}}``
    """
    from pathlib import Path

    analysis_path = Path(analysis_base_path)

    # Define structure
    oracles = {1: "2002", 2: "2008", 3: "2021"}
    model_types = ["GBR", "MLP", "RFR", 'ELASTIC', 'ENS']
    methods = {
        "gflow": "GFlowNet",
        "reinforce": "REINFORCE",
        "sac": "SAC",
        "smcmc": "SMCMC"
    }
    # Row order used in every rendered table.
    method_order = list(methods.values())

    # Diversity columns (all higher-is-better), keyed by the name used in the metrics files.
    diversity_metrics = {
        "Average Distance": "DTW Distance",
        "Normalized Score": "DTW Normalized",
        "Average Diversity - Last State": "Terminal Diversity",
        "Normalized Diversity - Last State": "Terminal Normalized"
    }

    # Performance columns, keyed by the name used in the metrics files.
    performance_metrics = {
        "median_reward": "Median Reward",
        "quantile_90": "Q90 Reward",
        "tail_coverage_20": "TC@20",
        "expected_shortfall_10": "ES@10%",
        "expected_shortfall_20": "ES@20%"
    }
    # Every other performance column is treated as lower-is-better.
    higher_better_performance = {"median_reward", "quantile_90", "tail_coverage_20"}

    table_footer = """\\bottomrule
\\end{tabular}
\\end{table}

"""

    def parse_metrics_file(filepath):
        """Parse a metrics file and extract numeric key-value pairs."""
        metrics = {}
        if not filepath.exists():
            return metrics

        with open(filepath, 'r') as f:
            for line in f:
                line = line.strip()
                if ':' not in line or line.startswith('#'):
                    continue
                key, value = line.split(':', 1)
                # "mean ± std" entries keep only the mean part.
                value = value.strip().split('±')[0].strip()
                try:
                    metrics[key.strip()] = float(value)
                except ValueError:
                    continue
        return metrics

    def find_best_method(data_dict, metric_key, higher_better=True):
        """Return the method with the best value for ``metric_key`` (first wins ties), or None."""
        method_values = [
            (method_name, metrics[metric_key])
            for method_name, metrics in data_dict.items()
            if metric_key in metrics
        ]
        if not method_values:
            return None
        pick = max if higher_better else min
        return pick(method_values, key=lambda x: x[1])[0]

    def render_rows(data, metric_keys, higher_better_keys):
        """Render one LaTeX row per method present in ``data``, bolding each column's best value."""
        # The winner depends only on the column, so compute it once per column.
        best_by_metric = {
            key: find_best_method(data, key, higher_better=key in higher_better_keys)
            for key in metric_keys
        }
        rows = ""
        for method_name in method_order:
            if method_name not in data:
                continue
            cells = []
            for key in metric_keys:
                if key not in data[method_name]:
                    cells.append("--")
                    continue
                cell = f"{data[method_name][key]:.2f}"
                if best_by_metric[key] == method_name:
                    cell = f"\\textbf{{{cell}}}"
                cells.append(cell)
            rows += f"{method_name:<12} & {' & '.join(cells)} \\\\\n"
        return rows

    # Collect all metrics data organized by oracle-model combination
    oracle_model_data = {}

    print("Collecting data for oracle-model combinations...")

    for oracle_num, oracle_year in oracles.items():
        for model_type in model_types:
            combination_key = f"Oracle_{oracle_num}_{model_type}"
            oracle_folder = analysis_path / combination_key

            if not oracle_folder.exists():
                print(f"⚠ Oracle folder not found: {oracle_folder}")
                continue

            print(f"Processing {combination_key}...")
            oracle_model_data[combination_key] = {
                'oracle_year': oracle_year,
                'model_type': model_type,
                'diversity_data': {},
                'performance_data': {}
            }

            for method_key, method_name in methods.items():
                method_folder = oracle_folder / method_key

                if not method_folder.exists():
                    print(f"  Method folder not found: {method_folder}")
                    continue

                metrics_dir = method_folder / "metrics"
                # DTW and terminal-state diversity metrics share one table.
                combined_div_metrics = {
                    **parse_metrics_file(metrics_dir / "dtw_metrics.txt"),
                    **parse_metrics_file(metrics_dir / "diversity_metrics.txt"),
                }
                perf_metrics = parse_metrics_file(metrics_dir / "performance_metrics.txt")

                if combined_div_metrics:
                    oracle_model_data[combination_key]['diversity_data'][method_name] = combined_div_metrics
                if perf_metrics:
                    oracle_model_data[combination_key]['performance_data'][method_name] = perf_metrics

    # Generate tables for each oracle-model combination
    all_tables = {}

    for combination_key, combination_data in oracle_model_data.items():
        oracle_year = combination_data['oracle_year']
        model_type = combination_data['model_type']

        print(f"Generating tables for {combination_key}...")

        # DIVERSITY TABLE for this combination
        diversity_latex = f"""\\begin{{table}}[t]
\\caption{{Diversity metrics for {model_type} oracle on {oracle_year} financial crash. Methods were trained using {model_type} oracle to test generalization. Higher values indicate better diversity. Results from 100 trajectories per method.}}
\\label{{tab:diversity_{combination_key.lower()}}}
\\centering
\\begin{{tabular}}{{l|cccc}}
\\toprule
\\textbf{{Method}} & \\textbf{{DTW Distance}} & \\textbf{{DTW Normalized}} & \\textbf{{Terminal Diversity}} & \\textbf{{Terminal Normalized}} \\\\
\\midrule
"""
        diversity_latex += render_rows(
            combination_data['diversity_data'],
            list(diversity_metrics),
            set(diversity_metrics),
        )
        diversity_latex += table_footer

        # PERFORMANCE TABLE for this combination
        performance_latex = f"""\\begin{{table}}[t]
\\caption{{Performance metrics for {model_type} oracle on {oracle_year} financial crash. Methods were trained using {model_type} oracle to test generalization. Higher is better for Median/Q90 Reward and TC@20; lower is better for ES metrics.}}
\\label{{tab:performance_{combination_key.lower()}}}
\\centering
\\begin{{tabular}}{{l|ccccc}}
\\toprule
\\textbf{{Method}} & \\textbf{{Median Reward}} & \\textbf{{Q90 Reward}} & \\textbf{{TC@20}} & \\textbf{{ES@10\\%}} & \\textbf{{ES@20\\%}} \\\\
\\midrule
"""
        performance_latex += render_rows(
            combination_data['performance_data'],
            list(performance_metrics),
            higher_better_performance,
        )
        performance_latex += table_footer

        all_tables[combination_key] = {
            'diversity_table': diversity_latex,
            'performance_table': performance_latex
        }

    # Save to files if output path provided
    if output_path:
        output_path = Path(output_path)
        # parents=True: the default target (e.g. "results/oracle_model_tables") may not exist yet.
        output_path.mkdir(parents=True, exist_ok=True)

        # Save individual tables
        for combination_key, tables in all_tables.items():
            combo_dir = output_path / combination_key
            combo_dir.mkdir(exist_ok=True)

            with open(combo_dir / "diversity_table.tex", 'w') as f:
                f.write(tables['diversity_table'])

            with open(combo_dir / "performance_table.tex", 'w') as f:
                f.write(tables['performance_table'])

        # Create combined file with all tables, each group preceded by a LaTeX comment banner.
        separator = "% " + "=" * 60 + "\n\n"
        combined_latex = ""
        for combination_key in sorted(all_tables):
            key_parts = combination_key.split('_')
            oracle_num, model_type = key_parts[1], key_parts[2]

            combined_latex += f"% Tables for Oracle {oracle_num} ({oracles[int(oracle_num)]}) - {model_type} Model\n"
            combined_latex += separator
            combined_latex += all_tables[combination_key]['diversity_table']
            combined_latex += "\n"
            combined_latex += all_tables[combination_key]['performance_table']
            combined_latex += "\n\\clearpage\n\n"

        with open(output_path / "all_oracle_model_tables.tex", 'w') as f:
            f.write(combined_latex)

        print(f"\nTables saved to {output_path}")
        print("Individual tables saved in subfolders")
        print("Combined file: all_oracle_model_tables.tex")

    return all_tables

def generate_generalization_analysis_tables(analysis_base_path, output_path=None):
    """
    Generate specialized tables to analyze cross-oracle generalization.

    First builds (and, when ``output_path`` is set, saves) the per-combination tables
    via ``generate_oracle_model_specific_latex_tables``. If ``output_path`` is given,
    it also writes a skeleton ``generalization_analysis.tex``. For every oracle with
    at least one combination, that file has one placeholder row per method whose
    numeric cells are still ``--``.

    Returns:
        dict: the per-combination tables produced by
        ``generate_oracle_model_specific_latex_tables``.
    """
    all_tables = generate_oracle_model_specific_latex_tables(analysis_base_path, output_path)

    if not output_path:
        return all_tables

    crisis_years = {1: "2002", 2: "2008", 3: "2021"}
    candidate_models = ["GBR", "MLP", "RFR", 'ELASTIC', 'ENS']
    method_labels = ["GFlowNet", "REINFORCE", "SAC", "SMCMC"]

    pieces = ["""\\begin{table}[t]
\\caption{Cross-Oracle Generalization Analysis: Performance of methods across different oracle types. Each row shows a method's performance when applied to different oracle types (trained on one, tested on all).}
\\label{tab:generalization_analysis}
\\scriptsize
\\centering
\\resizebox{\\textwidth}{!}{%
\\begin{tabular}{l|l|ccc|ccc}
\\toprule
\\multirow{2}{*}{\\textbf{Method}} & \\multirow{2}{*}{\\textbf{Trained On}} & \\multicolumn{3}{c|}{\\textbf{Median Reward}} & \\multicolumn{3}{c}{\\textbf{Q90 Reward}} \\\\
& & \\textbf{GBR} & \\textbf{MLP} & \\textbf{RFR} & \\textbf{GBR} & \\textbf{MLP} & \\textbf{RFR} \\\\
\\midrule
"""]

    for oracle_idx, year in crisis_years.items():
        pieces.append(f"\\multicolumn{{8}}{{c}}{{\\textbf{{{year} Financial Crisis}}}} \\\\\n")
        pieces.append("\\midrule\n")

        # Placeholder rows are emitted for every method as soon as any model type
        # produced tables for this oracle; the numeric cells are not filled in yet.
        oracle_has_tables = any(
            f"Oracle_{oracle_idx}_{model}" in all_tables for model in candidate_models
        )
        if not oracle_has_tables:
            continue
        pieces.extend(
            f"{label:<12} & Various & -- & -- & -- & -- & -- & -- \\\\\n"
            for label in method_labels
        )

    pieces.append("""\\bottomrule
\\end{tabular}%
}
\\end{table}
""")

    target = Path(output_path) / "generalization_analysis.tex"
    with open(target, 'w') as handle:
        handle.write("".join(pieces))

    return all_tables

# Usage: build all per-combination LaTeX tables and save them to disk.
tables = generate_oracle_model_specific_latex_tables(
    "analysis_outputs/Comprehensive_Analysis_20250925_132228",
    output_path="results/oracle_model_tables",
)

# Report which oracle/model combinations produced tables.
print("\nGenerated tables for oracle-model combinations:")
for combination, table_data in tables.items():
    # Keys have the form "Oracle_<n>_<MODEL>"; fields 1 and 2 are the oracle number and model.
    oracle_num, model_type = combination.split('_')[1:3]
    print(f"  ✓ {combination}: Oracle {oracle_num} with {model_type} model")

print(f"\nTotal combinations: {len(tables)}")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
from scipy import stats


In [None]:
def plot_oracle_specific_reward_distributions(analysis_base_path, save_dir=None, confidence_level=0.95):
    """
    Create professional reward distribution plots for each oracle (use case),
    showing all methods across different model types with family colors.

    For every oracle, the ``FinalReward`` column of
    ``<analysis_base_path>/Oracle_<n>_<MODEL>/<method>/metrics/top_100_summary.csv``
    is loaded for each method-model pair, and non-finite values are dropped. Each
    oracle gets one KDE figure, with a dashed vertical line at every pair's mean.
    A cross-oracle bar chart follows (``create_oracle_comparison_plot``). When
    ``save_dir`` is given, figures are saved as PDF and statistics are written via
    ``save_summary_statistics``.

    Args:
        analysis_base_path: Path to comprehensive analysis results folder
        save_dir: Directory to save the plots; created (with parents) if missing
        confidence_level: Not used by the current implementation; accepted for
            backward compatibility.

    Returns:
        dict: ``{oracle_num: {'oracle_name': str, 'statistics': list[dict],
        'best_combination': {'method', 'model_type', 'mean_reward'}}}`` for every
        oracle that had data.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from pathlib import Path

    analysis_path = Path(analysis_base_path)
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    # Define structure and colors
    oracles = {1: "2002 Financial Crisis", 2: "2008 Financial Crisis", 3: "2021 Market Volatility"}
    model_types = ['GBR', 'MLP', 'RFR', 'ELASTIC', 'ENS']
    methods = ['gflow', 'reinforce', 'sac', 'smcmc']

    # Display names; their insertion order is also the legend family order.
    method_display_names = {
        'gflow': 'GFlowNet',
        'reinforce': 'REINFORCE',
        'sac': 'SAC',
        'smcmc': 'SMCMC'
    }

    # Define professional color families
    color_families = {
        'gflow': {
            'GBR': '#27ae60',    # Dark green
            'MLP': '#2ecc71',    # Medium green
            'RFR': '#58d68d',    # Light green
            'ELASTIC': '#1e8449',
            'ENS': '#52be80'
        },
        'reinforce': {
            'GBR': '#2980b9',    # Dark blue
            'MLP': '#3498db',    # Medium blue
            'RFR': '#85c1e9',    # Light blue
            'ELASTIC': '#1f618d',
            'ENS': '#3498db'
        },
        'sac': {
            'GBR': '#8e44ad',    # Dark purple
            'MLP': '#9b59b6',    # Medium purple
            'RFR': '#bb8fce',    # Light purple
            'ELASTIC': '#6c3483',
            'ENS': '#a569bd'
        },
        'smcmc': {
            'GBR': '#c0392b',    # Dark red
            'MLP': '#e74c3c',    # Medium red
            'RFR': '#f1948a',    # Light red
            'ELASTIC': '#922b21',
            'ENS': '#cd6155'
        }
    }

    # Collect all data
    print("Collecting reward data from all combinations...")
    oracle_data = {}

    for oracle_num, oracle_name in oracles.items():
        print(f"Processing Oracle {oracle_num} ({oracle_name})...")
        method_model_data = {}
        oracle_data[oracle_num] = {
            'name': oracle_name,
            'method_model_data': method_model_data,
            'statistics': {}
        }

        for model_type in model_types:
            combination_key = f"Oracle_{oracle_num}_{model_type}"
            for method in methods:
                csv_file = analysis_path / combination_key / method / "metrics" / "top_100_summary.csv"

                if not csv_file.exists():
                    print(f"  - No data file: {csv_file}")
                    continue

                try:
                    df = pd.read_csv(csv_file)
                    if 'FinalReward' not in df.columns:
                        print(f"  ⚠ No 'FinalReward' column in {csv_file}")
                        continue
                    rewards = df['FinalReward'].values
                    rewards = rewards[np.isfinite(rewards)]  # Drop NaN/inf values

                    if len(rewards) > 0:
                        key = f"{method}_{model_type}"
                        method_model_data[key] = {
                            'rewards': rewards,
                            'method': method,
                            'model_type': model_type,
                            'color': color_families[method][model_type],
                            'n_samples': len(rewards)
                        }
                        print(f"  ✓ Loaded {key}: {len(rewards)} samples")
                except Exception as e:
                    print(f"  ✗ Error loading {csv_file}: {e}")

    # Create plots for each oracle
    summary_stats = {}

    for oracle_num, oracle_info in oracle_data.items():
        combos = oracle_info['method_model_data']
        if not combos:
            print(f"No data for Oracle {oracle_num}, skipping...")
            continue

        print(f"\nCreating plot for Oracle {oracle_num}...")

        fig = plt.figure(figsize=(16, 10), dpi=300)

        # Set up the plot style
        sns.set_style("whitegrid")
        plt.rcParams.update({'font.size': 12})

        # Legend handles grouped by method family, plus per-combination statistics.
        handles_by_method = {}
        stats_data = []

        for data in combos.values():
            method = data['method']
            model_type = data['model_type']
            rewards = data['rewards']
            color = data['color']
            method_name = method_display_names[method]

            sns.kdeplot(data=rewards, color=color, linewidth=3, alpha=0.8)

            mean_reward = np.mean(rewards)
            std_reward = np.std(rewards, ddof=1)  # NaN when only one sample is available
            median_reward = np.median(rewards)
            q25, q75 = np.percentile(rewards, [25, 75])

            # Dashed vertical line marks the mean of this combination.
            plt.axvline(mean_reward, color=color, linestyle='--', alpha=0.6, linewidth=2)

            handles_by_method.setdefault(method_name, []).append(
                plt.Line2D([0], [0], color=color, linewidth=3,
                           label=f'{method_name}-{model_type} (μ={mean_reward:.1f}±{std_reward:.1f})')
            )

            stats_data.append({
                'oracle': oracle_num,
                'method': method_name,
                'model_type': model_type,
                'mean': mean_reward,
                'std': std_reward,
                'median': median_reward,
                'q25': q25,
                'q75': q75,
                'n_samples': len(rewards)
            })

        # Customize the plot
        plt.title(f'Reward Distributions: {oracle_info["name"]}\n'
                  f'Comparison across Methods and Oracle Model Types',
                  fontsize=18, fontweight='bold', pad=30)

        plt.xlabel('Final Reward Values', fontsize=14, fontweight='bold')
        plt.ylabel('Density', fontsize=14, fontweight='bold')
        plt.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)

        # Legend grouped by method family, with an invisible spacer only *between* families.
        legend_handles = []
        for method_name in method_display_names.values():
            family_handles = handles_by_method.get(method_name)
            if not family_handles:
                continue
            if legend_handles:
                legend_handles.append(plt.Line2D([0], [0], color='white', linewidth=0, label=''))
            legend_handles.extend(family_handles)

        legend = plt.legend(handles=legend_handles, bbox_to_anchor=(1.02, 1), loc='upper left',
                            frameon=True, framealpha=0.95, edgecolor='black',
                            title='Method-Model Combinations', title_fontsize=13)
        legend.get_title().set_fontweight('bold')

        # Remove spines for cleaner look
        ax = plt.gca()
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(0.5)
        ax.spines['bottom'].set_linewidth(0.5)

        # Best combination by mean reward (first one wins ties).
        best_data = max(combos.values(), key=lambda d: np.mean(d['rewards']))
        best_method = method_display_names[best_data['method']]
        best_model = best_data['model_type']
        best_reward = np.mean(best_data['rewards'])

        stats_text = (
            f"Oracle {oracle_num} Summary:\n"
            f"Total Combinations: {len(combos)}\n"
            f"Best: {best_method}-{best_model} ({best_reward:.1f})"
        )

        plt.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                 verticalalignment='top', horizontalalignment='left',
                 bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.9,
                           edgecolor='gray', linewidth=1),
                 fontsize=11, fontweight='bold')

        plt.tight_layout()

        if save_dir:
            filename = f"oracle_{oracle_num}_reward_distributions.pdf"
            fig.savefig(save_dir / filename, dpi=300, bbox_inches='tight')
            print(f"  ✓ Saved plot: {filename}")

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

        summary_stats[oracle_num] = {
            'oracle_name': oracle_info['name'],
            'statistics': stats_data,
            'best_combination': {
                'method': best_method,
                'model_type': best_model,
                'mean_reward': best_reward
            }
        }

    # Create summary comparison plot across all oracles
    create_oracle_comparison_plot(summary_stats, color_families, save_dir)

    # Save summary statistics
    if save_dir:
        save_summary_statistics(summary_stats, save_dir)

    return summary_stats

def create_oracle_comparison_plot(summary_stats, color_families, save_dir=None):
    """
    Create a comparison plot showing best methods across all oracles.

    Draws one bar subplot per oracle with one bar per method-model combination.
    Bar height is the mean reward and the error bar is the std. Bars are grouped by
    method family in the order GFlowNet, REINFORCE, SAC, SMCMC.

    Args:
        summary_stats: Mapping ``oracle_num -> {'statistics': [...], ...}``. Each
            statistic carries the display method name (e.g. ``'GFlowNet'``),
            ``model_type``, ``mean`` and ``std``.
        color_families: Mapping ``method_key -> {model_type: color}`` keyed by the
            short method keys (``'gflow'``, ``'reinforce'``, ``'sac'``, ``'smcmc'``).
        save_dir: Optional directory (``str`` or ``Path``); if given, the figure is
            saved there as ``oracle_comparison_bars.pdf``.
    """
    if not summary_stats:
        print("No oracle statistics available; skipping comparison plot.")
        return

    n_panels = len(summary_stats)
    # squeeze=False keeps `axes` 2-D so indexing works for any number of oracles.
    fig, axes = plt.subplots(1, n_panels, figsize=(8 * n_panels, 8), dpi=300,
                             sharey=True, squeeze=False)
    fig.suptitle('Cross-Oracle Performance Comparison: Method Families Across Financial Crises',
                 fontsize=20, fontweight='bold', y=0.95)

    oracle_names = {1: "2002 Crisis", 2: "2008 Crisis", 3: "2021 Volatility"}
    family_order = ['gflow', 'reinforce', 'sac', 'smcmc']

    for idx, (ax, (oracle_num, oracle_data)) in enumerate(zip(axes[0], summary_stats.items())):
        # Map display names back to color_families keys:
        # 'GFlowNet' -> 'gflow', 'REINFORCE' -> 'reinforce', 'SAC' -> 'sac', 'SMCMC' -> 'smcmc'.
        method_groups = {}
        for stat in oracle_data['statistics']:
            method_key = stat['method'].lower().replace('flownet', 'flow')
            method_groups.setdefault(method_key, []).append(stat)

        positions = []
        labels = []

        pos = 0
        for method_key in family_order:
            if method_key not in method_groups:
                continue

            for model_stat in method_groups[method_key]:
                model_type = model_stat['model_type']
                mean_val = model_stat['mean']
                std_val = model_stat['std']
                # std is NaN for single-sample combinations: draw no error bar and
                # keep the value label at a finite position.
                err = std_val if np.isfinite(std_val) else 0.0

                color = color_families[method_key][model_type]

                ax.bar(pos, mean_val, yerr=err, capsize=5,
                       color=color, alpha=0.8, edgecolor='black', linewidth=1)

                positions.append(pos)
                labels.append(f"{model_stat['method']}\n{model_type}")

                # Value label just above the error bar.
                ax.text(pos, mean_val + err + 1, f'{mean_val:.1f}',
                        ha='center', va='bottom', fontsize=10, fontweight='bold')

                pos += 1

            pos += 0.5  # Gap between method families

        ax.set_title(oracle_names.get(oracle_num, f"Oracle {oracle_num}"),
                     fontsize=16, fontweight='bold', pad=20)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10)
        ax.set_ylabel('Mean Final Reward' if idx == 0 else '', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    fig.tight_layout()

    if save_dir:
        fig.savefig(Path(save_dir) / "oracle_comparison_bars.pdf", dpi=300, bbox_inches='tight')
        print("✓ Saved oracle comparison plot")

    plt.show()
    plt.close(fig)

def save_summary_statistics(summary_stats, save_dir):
    """
    Save detailed statistics to CSV and text files.

    Writes two files into ``save_dir``:

    * ``reward_distribution_statistics.csv``: one row per oracle/method/model
      statistic, with the oracle's display name added as ``oracle_name``.
    * ``oracle_analysis_summary.txt``: a per-oracle report listing the best
      combination and all combinations sorted by mean reward (descending).

    Args:
        summary_stats: Mapping ``oracle_num -> {'oracle_name', 'statistics',
            'best_combination'}``.
        save_dir: Existing output directory, as ``str`` or ``Path``.
    """
    save_dir = Path(save_dir)

    # Flatten all per-oracle statistics into one table, tagging each row with its oracle name.
    all_stats = [
        {**stat, 'oracle_name': oracle_data['oracle_name']}
        for oracle_data in summary_stats.values()
        for stat in oracle_data['statistics']
    ]
    pd.DataFrame(all_stats).to_csv(save_dir / "reward_distribution_statistics.csv", index=False)

    # The report contains 'μ', which some platform default encodings (e.g. cp1252)
    # cannot represent, so always write UTF-8.
    with open(save_dir / "oracle_analysis_summary.txt", 'w', encoding='utf-8') as f:
        f.write("ORACLE-SPECIFIC REWARD DISTRIBUTION ANALYSIS SUMMARY\n")
        f.write("=" * 60 + "\n\n")

        for oracle_num, oracle_data in summary_stats.items():
            f.write(f"ORACLE {oracle_num}: {oracle_data['oracle_name']}\n")
            f.write("-" * 40 + "\n")

            best = oracle_data['best_combination']
            f.write(f"Best Combination: {best['method']}-{best['model_type']} "
                    f"(Mean Reward: {best['mean_reward']:.2f})\n\n")

            f.write("All Combinations:\n")
            sorted_stats = sorted(oracle_data['statistics'],
                                  key=lambda x: x['mean'], reverse=True)

            for i, stat in enumerate(sorted_stats, 1):
                f.write(f"{i:2d}. {stat['method']}-{stat['model_type']:3s}: "
                        f"μ={stat['mean']:6.2f} ± {stat['std']:5.2f} "
                        f"(n={stat['n_samples']})\n")
            f.write("\n")

    print("✓ Saved detailed statistics and summary report")

# Usage example: build per-oracle reward-distribution plots and summary files.
summary_results = plot_oracle_specific_reward_distributions(
    analysis_base_path="analysis_outputs/Comprehensive_Analysis_20250925_132228",
    save_dir="analysis_outputs/oracle_reward_distributions",
    confidence_level=0.95,
)

rule = "=" * 60
print("\n" + rule)
print("ORACLE REWARD DISTRIBUTION ANALYSIS COMPLETED")
print(rule)

# One short report per oracle that produced data.
for oracle_num, oracle_data in summary_results.items():
    best = oracle_data['best_combination']
    n_combos = len(oracle_data['statistics'])
    print(
        f"Oracle {oracle_num} ({oracle_data['oracle_name']}):\n"
        f"  Best: {best['method']}-{best['model_type']} (μ={best['mean_reward']:.2f})\n"
        f"  Total combinations analyzed: {n_combos}"
    )

In [None]:
def plot_individual_combination_distributions(analysis_base_path, save_dir=None):
    """
    Create one method-comparison KDE plot per oracle/model-type combination.

    For every ``Oracle_<n>_<MODEL>`` folder under ``analysis_base_path`` and every
    method (gflow, reinforce, sac, smcmc), the ``FinalReward`` column of
    ``<method>/metrics/all_trajectories_summary.csv`` is loaded.
    Non-finite values are dropped. The remaining rewards are summarised and drawn
    as a KDE with a dashed mean line and a shaded ±1 std band. Combinations
    without any data are reported and skipped.

    Afterwards a summary heatmap is drawn. When ``save_dir`` is given, the
    per-combination statistics are also written to disk.

    Args:
        analysis_base_path: Path to the comprehensive analysis results folder.
        save_dir: Optional directory for the PDF plots and statistics files. It is
            created (including missing parent directories) if needed. When falsy,
            nothing is saved.

    Returns:
        dict: Maps ``"Oracle_<n>_<MODEL>"`` to a dict with these keys:
        ``oracle_name``, ``model_type``, ``statistics`` (one stats dict per
        method), ``best_method``, ``best_mean``, ``performance_spread`` (max minus
        min of the method means) and ``n_methods``.
    """
    import numpy as np
    import pandas as pd
    from pathlib import Path

    analysis_path = Path(analysis_base_path)
    if save_dir:
        save_dir = Path(save_dir)
        # parents=True so a brand-new output tree can be created in one call
        save_dir.mkdir(parents=True, exist_ok=True)

    # Experiment grid; oracle display names are currently empty strings
    oracles = {1: "", 2: "", 3: ""}
    model_types = ['GBR', 'MLP', 'RFR', 'ELASTIC', 'ENS']
    methods = ['gflow', 'reinforce', 'sac', 'smcmc']

    # Consistent colors and labels for each method across all combinations
    method_colors = {
        'gflow': '#2ecc71',      # Green
        'reinforce': '#3498db',  # Blue
        'sac': '#9b59b6',        # Purple
        'smcmc': '#e74c3c',      # Red
    }
    method_names = {
        'gflow': 'GFlowNet',
        'reinforce': 'REINFORCE',
        'sac': 'SAC',
        'smcmc': 'SMCMC',
    }

    all_combination_stats = {}

    print("Creating individual combination distribution plots...")

    for oracle_num, oracle_name in oracles.items():
        for model_type in model_types:
            combination_key = f"Oracle_{oracle_num}_{model_type}"
            print(f"\nProcessing {combination_key}...")

            combination_data = {}
            combination_stats = []

            for method in methods:
                csv_file = (analysis_path / combination_key / method
                            / "metrics" / "all_trajectories_summary.csv")
                if not csv_file.exists():
                    print(f"  - No data for {method}")
                    continue

                try:
                    df = pd.read_csv(csv_file)
                    if 'FinalReward' not in df.columns:
                        continue
                    rewards = df['FinalReward'].to_numpy(dtype=float)
                except Exception as e:  # best-effort: one unreadable CSV must not abort the sweep
                    print(f"  ✗ Error loading {method}: {e}")
                    continue

                rewards = rewards[np.isfinite(rewards)]
                if rewards.size == 0:
                    continue

                mean_reward = float(np.mean(rewards))
                # Sample std needs >= 2 points; fall back to ddof=0 instead of producing NaN
                std_reward = float(np.std(rewards, ddof=1 if rewards.size > 1 else 0))
                q25, q75 = np.percentile(rewards, [25, 75])

                # Cache mean/std so the plot and the "best method" pick reuse them
                combination_data[method] = {
                    'rewards': rewards,
                    'color': method_colors[method],
                    'name': method_names[method],
                    'n_samples': len(rewards),
                    'mean': mean_reward,
                    'std': std_reward,
                }
                combination_stats.append({
                    'oracle': oracle_num,
                    'model_type': model_type,
                    'method': method_names[method],
                    'mean': mean_reward,
                    'std': std_reward,
                    'median': np.median(rewards),
                    'q25': q25,
                    'q75': q75,
                    'min': np.min(rewards),
                    'max': np.max(rewards),
                    'n_samples': len(rewards),
                })
                print(f"  ✓ Loaded {method}: {len(rewards)} samples (μ={mean_reward:.2f})")

            if not combination_data:
                print(f"  ⚠ No data found for {combination_key}")
                continue

            best_name, best_mean, performance_spread = _plot_combination_methods_kde(
                combination_key, combination_data, save_dir
            )

            all_combination_stats[combination_key] = {
                'oracle_name': oracle_name,
                'model_type': model_type,
                'statistics': combination_stats,
                'best_method': best_name,
                'best_mean': best_mean,
                'performance_spread': performance_spread,
                'n_methods': len(combination_data),
            }

    # Summary comparison across all combinations
    create_combination_summary_plot(all_combination_stats, save_dir)

    if save_dir:
        save_combination_statistics(all_combination_stats, save_dir)

    return all_combination_stats


def _plot_combination_methods_kde(combination_key, combination_data, save_dir):
    """
    Draw, show and (optionally) save the KDE method comparison for one combination.

    Args:
        combination_key: ``"Oracle_<n>_<MODEL>"`` label, used for the PDF filename.
        combination_data: Mapping method -> dict with ``rewards``, ``color``,
            ``name``, ``mean`` and ``std``.
        save_dir: ``Path`` to save the PDF into, or a falsy value to skip saving.

    Returns:
        tuple: ``(best_method_name, best_mean, performance_spread)``, where the
        spread is the max minus min of the per-method mean rewards.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig = plt.figure(figsize=(12, 6), dpi=300)
    sns.set_style("whitegrid")
    plt.rcParams.update({'font.size': 12})

    for data in combination_data.values():
        mean_reward, std_reward, color = data['mean'], data['std'], data['color']
        sns.kdeplot(data=data['rewards'], color=color, linewidth=4, alpha=0.8,
                    label=f"{data['name']} (μ={mean_reward:.1f}±{std_reward:.1f})")
        # Dashed line at the mean and a shaded ±1 std band
        plt.axvline(mean_reward, color=color, linestyle='--', alpha=0.7, linewidth=2)
        plt.axvspan(mean_reward - std_reward, mean_reward + std_reward,
                    color=color, alpha=0.15)

    plt.title('', fontsize=18, fontweight='bold', pad=30)
    plt.xlabel('Final Reward Values', fontsize=14, fontweight='bold')
    plt.ylabel('Density', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)

    legend = plt.legend(fontsize=13, frameon=True, framealpha=0.95,
                        edgecolor='black', loc='upper left',
                        title='Methods Performance', title_fontsize=14)
    legend.get_title().set_fontweight('bold')

    # Remove top and right spines
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1)
    ax.spines['bottom'].set_linewidth(1)

    plt.tight_layout()

    if save_dir:
        filename = f"{combination_key}_methods_comparison.pdf"
        plt.savefig(save_dir / filename, dpi=300, bbox_inches='tight')
        print(f"  ✓ Saved plot: {filename}")

    plt.show()
    # Release the figure explicitly; non-inline backends otherwise keep every figure open
    plt.close(fig)

    best = max(combination_data.values(), key=lambda d: d['mean'])
    means = [d['mean'] for d in combination_data.values()]
    return best['name'], best['mean'], max(means) - min(means)

def create_combination_summary_plot(all_stats, save_dir=None):
    """
    Plot a heatmap of the best method's mean reward for each oracle x model-type cell.

    Each cell holds the ``best_mean`` of the matching ``Oracle_<n>_<MODEL>`` entry
    and is annotated with its ``best_method``. Combinations missing from
    ``all_stats`` are left blank (NaN, masked by seaborn) instead of being drawn
    as a misleading 0.0. Nothing is plotted when no cell has data.

    Args:
        all_stats: Mapping ``"Oracle_<n>_<MODEL>"`` -> dict with at least
            ``best_mean`` and ``best_method``.
        save_dir: Optional output directory (str or Path). When given, the heatmap
            is saved there as ``combination_summary_heatmap.pdf``.
    """
    oracles = [1, 2, 3]
    model_types = ['GBR', 'MLP', 'RFR', 'ELASTIC', 'ENS']
    # Year shown next to each oracle on the y-axis
    oracle_years = {1: 2002, 2: 2008, 3: 2021}

    # NaN marks "no data" so absent combinations are not confused with a score of 0
    performance_matrix = np.full((len(oracles), len(model_types)), np.nan)
    best_method_matrix = [[''] * len(model_types) for _ in oracles]

    for i, oracle in enumerate(oracles):
        for j, model_type in enumerate(model_types):
            combo = all_stats.get(f"Oracle_{oracle}_{model_type}")
            if combo is not None:
                performance_matrix[i, j] = combo['best_mean']
                best_method_matrix[i][j] = combo['best_method']

    if np.isnan(performance_matrix).all():
        print("⚠ No combination statistics to plot; skipping summary heatmap")
        return

    fig = plt.figure(figsize=(20, 8), dpi=300)

    ax = sns.heatmap(performance_matrix,
                     annot=True, fmt='.1f',
                     xticklabels=model_types,
                     yticklabels=[f"Oracle {i} ({oracle_years[i]})" for i in oracles],
                     cmap='RdYlGn', center=50,
                     cbar_kws={'label': 'Best Method Performance'})

    # Write the best method's name below each cell's numeric annotation
    for i, row in enumerate(best_method_matrix):
        for j, method_name in enumerate(row):
            if method_name:
                ax.text(j + 0.5, i + 0.7, method_name,
                        ha='center', va='center', fontsize=10, fontweight='bold',
                        color='white' if performance_matrix[i, j] < 50 else 'black')

    plt.title('Best Method Performance Across All Combinations\n'
              'Numbers show best performance, text shows best method',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Oracle Model Type', fontsize=14, fontweight='bold')
    plt.ylabel('Use Case (Financial Crisis)', fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_dir:
        plt.savefig(Path(save_dir) / "combination_summary_heatmap.pdf", dpi=300, bbox_inches='tight')
        print("✓ Saved combination summary heatmap")

    plt.show()
    plt.close(fig)

def save_combination_statistics(all_stats, save_dir):
    """
    Write per-combination method statistics to a CSV file and a ranked text report.

    Produces two files:

    - ``individual_combination_statistics.csv``: one row per (combination, method),
      with ``combination`` and ``oracle_name`` columns added.
    - ``combination_analysis_summary.txt``: methods ranked by mean reward within
      each combination.

    Args:
        all_stats: Mapping combination key -> dict with these keys:
            ``oracle_name``, ``model_type``, ``best_method``, ``best_mean``,
            ``performance_spread``, ``n_methods`` and ``statistics``.
            ``statistics`` is a list of per-method dicts with ``method``,
            ``mean``, ``std``, ``min``, ``max`` and ``n_samples``.
        save_dir: Existing output directory (str or Path).
    """
    save_dir = Path(save_dir)

    all_rows = [
        {**stat, 'combination': combination_key, 'oracle_name': combo_data['oracle_name']}
        for combination_key, combo_data in all_stats.items()
        for stat in combo_data['statistics']
    ]
    pd.DataFrame(all_rows).to_csv(save_dir / "individual_combination_statistics.csv", index=False)

    # Explicit UTF-8: the report contains "μ" and "±", which the platform's default
    # locale encoding may not be able to represent.
    with open(save_dir / "combination_analysis_summary.txt", 'w', encoding='utf-8') as f:
        f.write("INDIVIDUAL COMBINATION ANALYSIS SUMMARY\n")
        f.write("=" * 60 + "\n\n")

        for combination_key, combo_data in all_stats.items():
            f.write(f"COMBINATION: {combination_key}\n")
            f.write(f"Oracle: {combo_data['oracle_name']}\n")
            f.write(f"Model Type: {combo_data['model_type']}\n")
            f.write("-" * 40 + "\n")

            f.write(f"Best Method: {combo_data['best_method']} ({combo_data['best_mean']:.2f})\n")
            f.write(f"Performance Spread: {combo_data['performance_spread']:.2f}\n")
            f.write(f"Methods Analyzed: {combo_data['n_methods']}\n\n")

            f.write("Method Rankings:\n")
            ranked = sorted(combo_data['statistics'], key=lambda s: s['mean'], reverse=True)
            for rank, stat in enumerate(ranked, 1):
                f.write(f"{rank}. {stat['method']}: μ={stat['mean']:6.2f} ± {stat['std']:5.2f} "
                        f"[{stat['min']:5.1f}, {stat['max']:5.1f}] (n={stat['n_samples']})\n")
            f.write("\n")

    print("✓ Saved individual combination statistics and summary")

# Usage example: plot every oracle/model combination and save figures + statistics.
individual_stats = plot_individual_combination_distributions(
    analysis_base_path="analysis_outputs/Comprehensive_Analysis_20250925_132228",
    save_dir="analysis_outputs/individual_combination_plots",
)

print("\n" + "="*60)
print("INDIVIDUAL COMBINATION ANALYSIS COMPLETED")
print("="*60)

# Print summary for each combination
for combination_key, combo_data in individual_stats.items():
    print(f"\n{combination_key}:")
    print(f"  Best: {combo_data['best_method']} (μ={combo_data['best_mean']:.2f})")
    print(f"  Methods: {combo_data['n_methods']}, Spread: {combo_data['performance_spread']:.1f}")

# Similarity and differences between top trajectories


In [None]:
import joblib 


In [None]:
# Load the saved trajectory results of the gflow_2_mlp experiment.
# joblib.load unpickles the file, so only load result files you produced yourself.
data = joblib.load(
    "results/All_Experiments_Results/gflow_2_mlp/gflow_2_mlp_trajectories_results.joblib"
)

In [None]:
import numpy as np
import seaborn as sns
import pandas as pd

import matplotlib.pyplot as plt

# Extract data from the first run (run 0)
run_data = data[0]
trajectories = np.array(run_data['trajectories'])
rewards = np.array(run_data['rewards'])

# Final reward per trajectory: the last timestep when rewards are per-step (2-D)
if rewards.ndim == 2:
    final_rewards = rewards[:, -1]
else:
    final_rewards = rewards
final_rewards = np.asarray(final_rewards, dtype=float)

# Indices of the 10 highest final rewards (ascending order, best last).
# np.argsort sorts NaN to the end, which would put NaN rewards into the "top 10";
# rank NaN as -inf instead so they can only be picked if nothing else is left.
ranking_keys = np.where(np.isnan(final_rewards), -np.inf, final_rewards)
top_10_indices = np.argsort(ranking_keys)[-10:]
top_10_trajectories = trajectories[top_10_indices]
top_10_rewards = final_rewards[top_10_indices]

print(f"Top 10 trajectory rewards: {top_10_rewards}")
print(f"Top 10 trajectory shape: {top_10_trajectories.shape}")

# Define feature names (assuming standard format)
feature_names = ['Volume_spx', 'Close_ndx', 'Volume_ndx', 'Close_vix', 'IRLTCT01USM156N', 'BAMLH0A3HYCEY']

# 1. Plot trajectory features over time
def plot_top_trajectories_grid(trajectories, rewards, feature_names, title_prefix="Top 10"):
    """
    Plot every feature of the given trajectories, one subplot per feature.

    Subplots are laid out three per row (2 x 3 for six features), so any number
    of features fits. Grid cells beyond ``len(feature_names)`` are hidden.

    Args:
        trajectories: Sequence of 2-D arrays indexed as [timestep, column]. When a
            trajectory has more columns than ``feature_names``, column 0 is treated
            as the timestamp / x-axis (assumption carried over from the original
            code; confirm against the trajectory format).
        rewards: One reward per trajectory, shown in the legend labels.
        feature_names: Feature names in column order (after any timestamp column).
        title_prefix: Prefix for the figure's super-title.
    """
    n_features = len(feature_names)
    n_trajectories = len(trajectories)

    # Three panels per row; a fixed 2x3 grid would raise IndexError for >6 features
    n_cols = max(1, min(3, n_features))
    n_rows = max(1, (n_features + 2) // 3)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows),
                             dpi=300, squeeze=False)
    axes = axes.flatten()

    # One color per trajectory
    colors = plt.cm.viridis(np.linspace(0, 1, n_trajectories))

    for feat_idx, feature_name in enumerate(feature_names):
        ax = axes[feat_idx]

        for traj_idx, (traj, reward) in enumerate(zip(trajectories, rewards)):
            if traj.shape[1] > n_features:
                # Extra leading column assumed to be a timestamp
                feature_values = traj[:, feat_idx + 1]
                timesteps = traj[:, 0]
            else:
                feature_values = traj[:, feat_idx]
                timesteps = np.arange(len(feature_values))

            ax.plot(timesteps, feature_values,
                    color=colors[traj_idx], alpha=0.8, linewidth=2,
                    label=f'Traj {traj_idx+1} (R={reward:.1f})')

        ax.set_title(f'{feature_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)

        # Legend only on the first subplot
        if feat_idx == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # Hide grid cells that have no feature to show
    for ax in axes[n_features:]:
        ax.set_visible(False)

    plt.suptitle(f'{title_prefix} Trajectories - Feature Evolution',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    plt.close(fig)

# 2. Calculate and plot DTW diversity
def calculate_and_plot_dtw_diversity(trajectories, rewards, title="DTW Diversity"):
    """
    Compute pairwise DTW distances between trajectories, report diversity, and plot them.

    The diversity is the mean DTW distance over distinct pairs (the strict upper
    triangle of the matrix). It is multiplied by a reward-stability factor,
    ``1 - |max - min| / (|max| + |min| + 1e-8)`` over ``rewards``, to give the
    normalized score.

    Args:
        trajectories: Trajectories passed to ``compute_dtw_distance_matrix``.
        rewards: One reward per trajectory, used for stability and tick labels.
        title: Heatmap title prefix.

    Returns:
        dict with ``distance_matrix``, ``avg_distance``, ``normalized_score`` and
        ``reward_stability``.
    """
    print(f"Calculating DTW distances for {len(trajectories)} trajectories...")

    dtw_matrix, _unused = compute_dtw_distance_matrix(trajectories)

    upper_rows, upper_cols = np.triu_indices_from(dtw_matrix, k=1)
    mean_pairwise = np.mean(dtw_matrix[upper_rows, upper_cols])

    # Reward stability shrinks the score when the rewards are spread out
    hi_reward = np.max(rewards)
    lo_reward = np.min(rewards)
    stability = 1 - (np.abs(hi_reward - lo_reward)
                     / (np.abs(hi_reward) + np.abs(lo_reward) + 1e-8))
    score = mean_pairwise * stability

    for caption, value in (("Average DTW Distance", mean_pairwise),
                           ("Reward Stability", stability),
                           ("Normalized DTW Score", score)):
        print(f"{caption}: {value:.4f}")

    plt.figure(figsize=(10, 8), dpi=300)
    tick_labels = [f'T{k+1}\n(R={r:.1f})' for k, r in enumerate(rewards)]
    sns.heatmap(dtw_matrix,
                annot=True, fmt='.2f',
                xticklabels=tick_labels, yticklabels=tick_labels,
                cmap='viridis', square=True,
                cbar_kws={'label': 'DTW Distance'})
    plt.title(f'{title}\nAvg Distance: {mean_pairwise:.3f}, Normalized: {score:.3f}',
              fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    return dict(
        distance_matrix=dtw_matrix,
        avg_distance=mean_pairwise,
        normalized_score=score,
        reward_stability=stability,
    )

# 3. Calculate and plot Euclidean diversity (final states)
def calculate_and_plot_euclidean_diversity(trajectories, rewards, feature_names, title="Euclidean Diversity"):
    """
    Measure diversity of trajectory end points via pairwise Euclidean distances and plot them.

    The last timestep of each trajectory is taken as its final state. Column 0 is
    dropped when there are more columns than ``feature_names``; it is assumed to
    be a timestamp. The quality-diversity score comes from the project helper
    ``calculate_quality_diversity`` (not defined in this cell).

    Args:
        trajectories: 3-D array indexed as [trajectory, timestep, column].
        rewards: One (final) reward per trajectory.
        feature_names: Feature names; used only to detect a leading timestamp column.
        title: Heatmap title prefix.

    Returns:
        dict with these keys:
        ``distance_matrix`` (n x n), ``avg_distance`` (NaN if there are fewer
        than two trajectories), ``quality_diversity`` (the helper's result dict)
        and ``final_states``.
    """
    print(f"Calculating Euclidean diversity for {len(trajectories)} final states...")

    # Final state of every trajectory, without the timestamp column if present
    if trajectories.shape[2] > len(feature_names):
        final_states = trajectories[:, -1, 1:]
    else:
        final_states = trajectories[:, -1, :]
    # Object-dtype trajectory arrays would break the arithmetic below; force floats
    final_states = np.array(final_states, dtype=float)

    # All pairwise Euclidean distances at once via broadcasting (n x n, zero diagonal)
    diffs = final_states[:, np.newaxis, :] - final_states[np.newaxis, :, :]
    distance_matrix = np.linalg.norm(diffs, axis=-1)

    # Mean over distinct pairs; no pairs exist with fewer than two trajectories
    pair_distances = distance_matrix[np.triu_indices_from(distance_matrix, k=1)]
    avg_distance = float(np.mean(pair_distances)) if pair_distances.size else float("nan")

    euclidean_res = calculate_quality_diversity(final_states, rewards)

    print(f"Average Euclidean Distance: {avg_distance:.4f}")
    print(f"Quality-Diversity Score: {euclidean_res['normalized_score']:.4f}")

    fig = plt.figure(figsize=(10, 8), dpi=300)
    labels = [f'T{i+1}\n(R={r:.1f})' for i, r in enumerate(rewards)]
    sns.heatmap(distance_matrix,
                annot=True, fmt='.2f',
                xticklabels=labels, yticklabels=labels,
                cmap='plasma', square=True,
                cbar_kws={'label': 'Euclidean Distance'})
    plt.title(f'{title} (Final States)\nAvg Distance: {avg_distance:.3f}, Quality-Diversity: {euclidean_res["normalized_score"]:.3f}',
              fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    plt.close(fig)

    return {
        'distance_matrix': distance_matrix,
        'avg_distance': avg_distance,
        'quality_diversity': euclidean_res,
        'final_states': final_states
    }

# 4. Summary comparison plot
def plot_diversity_comparison(dtw_results, euclidean_results):
    """
    Show side-by-side histograms of pairwise DTW and Euclidean distances.

    Each panel histograms the strict upper-triangle (distinct-pair) entries of
    the corresponding ``distance_matrix`` and marks ``avg_distance`` with a
    dashed vertical line.

    Args:
        dtw_results: Dict with ``distance_matrix`` and ``avg_distance`` (DTW).
        euclidean_results: Dict with ``distance_matrix`` and ``avg_distance``
            (Euclidean distances between final states).
    """
    _fig, (dtw_ax, eucl_ax) = plt.subplots(1, 2, figsize=(15, 6), dpi=300)

    # (axis, results, bar color, mean-line color, panel title, x label)
    panels = (
        (dtw_ax, dtw_results, 'skyblue', 'red',
         'DTW Distance Distribution', 'DTW Distance'),
        (eucl_ax, euclidean_results, 'lightcoral', 'blue',
         'Euclidean Distance Distribution', 'Euclidean Distance (Final States)'),
    )
    for ax, results, bar_color, line_color, panel_title, x_label in panels:
        matrix = results['distance_matrix']
        pair_values = matrix[np.triu_indices_from(matrix, k=1)]
        ax.hist(pair_values, bins=15, alpha=0.7, color=bar_color, edgecolor='black')
        ax.axvline(results['avg_distance'], color=line_color, linestyle='--', linewidth=2,
                   label=f'Mean: {results["avg_distance"]:.3f}')
        ax.set_title(panel_title, fontsize=14, fontweight='bold')
        ax.set_xlabel(x_label)
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle('Diversity Metrics Comparison - Top 10 Trajectories',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Execute the analysis on the top-10 trajectories selected above
print("=" * 60, "TOP 10 TRAJECTORIES DIVERSITY ANALYSIS", "=" * 60, sep="\n")

# Steps 1-4: feature evolution, DTW diversity, Euclidean diversity, comparison plot
plot_top_trajectories_grid(top_10_trajectories, top_10_rewards, feature_names)
dtw_results = calculate_and_plot_dtw_diversity(top_10_trajectories, top_10_rewards)
euclidean_results = calculate_and_plot_euclidean_diversity(
    top_10_trajectories, top_10_rewards, feature_names
)
plot_diversity_comparison(dtw_results, euclidean_results)

# Step 5: textual summary
print("\n" + "=" * 60, "DIVERSITY ANALYSIS SUMMARY", "=" * 60, sep="\n")
print(f"Number of trajectories analyzed: {len(top_10_trajectories)}")
print(f"Reward range: {np.min(top_10_rewards):.2f} - {np.max(top_10_rewards):.2f}")
print(f"Mean reward: {np.mean(top_10_rewards):.2f} ± {np.std(top_10_rewards):.2f}")

print("", "DTW DIVERSITY METRICS:", sep="\n")
print(f"  Average DTW Distance: {dtw_results['avg_distance']:.4f}")
print(f"  Normalized DTW Score: {dtw_results['normalized_score']:.4f}")
print(f"  Reward Stability: {dtw_results['reward_stability']:.4f}")

print("", "EUCLIDEAN DIVERSITY METRICS:", sep="\n")
print(f"  Average Euclidean Distance: {euclidean_results['avg_distance']:.4f}")
print(f"  Quality-Diversity Score: {euclidean_results['quality_diversity']['normalized_score']:.4f}")
print(f"  Average Diversity: {euclidean_results['quality_diversity']['average_diversity']:.4f}")

In [None]:
import numpy as np
import seaborn as sns
import pandas as pd

import matplotlib.pyplot as plt

# Extract data from the first run (run 0)
run_data = data[0]
trajectories = np.array(run_data['trajectories'])
rewards = np.array(run_data['rewards'])

# Final reward per trajectory: the last timestep when rewards are per-step (2-D)
if rewards.ndim == 2:
    final_rewards = rewards[:, -1]
else:
    final_rewards = rewards
final_rewards = np.asarray(final_rewards, dtype=float)

# Indices of the 10 highest final rewards (ascending order, best last).
# np.argsort sorts NaN to the end, which would put NaN rewards into the "top 10";
# rank NaN as -inf instead so they can only be picked if nothing else is left.
ranking_keys = np.where(np.isnan(final_rewards), -np.inf, final_rewards)
top_10_indices = np.argsort(ranking_keys)[-10:]
top_10_trajectories = trajectories[top_10_indices]
top_10_rewards = final_rewards[top_10_indices]

print(f"Top 10 trajectory rewards: {top_10_rewards}")
print(f"Top 10 trajectory shape: {top_10_trajectories.shape}")

# Define feature names (assuming standard format)
feature_names = ['Volume_spx', 'Close_ndx', 'Volume_ndx', 'Close_vix', 'IRLTCT01USM156N', 'BAMLH0A3HYCEY']

# Helper: coerce every trajectory (e.g. entries of an object array) to float
def convert_trajectories_to_numeric(trajectories):
    """
    Return ``trajectories`` as a NumPy array with every trajectory cast to float.

    Each element is converted on its own with ``np.array(..., dtype=float)``.
    The converted elements are then combined with ``np.array``.
    """
    return np.array([np.array(traj, dtype=float) for traj in trajectories])

# Cast the top-10 selection to a float ndarray before plotting and distance work
top_10_trajectories = convert_trajectories_to_numeric(
    top_10_trajectories
)

# 1. Plot trajectory features over time
def plot_top_trajectories_grid(trajectories, rewards, feature_names, title_prefix="Top 10"):
    """
    Plot every feature of the given trajectories, one subplot per feature.

    Subplots are laid out three per row (2 x 3 for six features), so any number
    of features fits. Grid cells beyond ``len(feature_names)`` are hidden.

    Args:
        trajectories: Sequence of 2-D arrays indexed as [timestep, column]. When a
            trajectory has more columns than ``feature_names``, column 0 is treated
            as the timestamp / x-axis (assumption carried over from the original
            code; confirm against the trajectory format).
        rewards: One reward per trajectory, shown in the legend labels.
        feature_names: Feature names in column order (after any timestamp column).
        title_prefix: Prefix for the figure's super-title.
    """
    n_features = len(feature_names)
    n_trajectories = len(trajectories)

    # Three panels per row; a fixed 2x3 grid would raise IndexError for >6 features
    n_cols = max(1, min(3, n_features))
    n_rows = max(1, (n_features + 2) // 3)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows),
                             dpi=300, squeeze=False)
    axes = axes.flatten()

    # One color per trajectory
    colors = plt.cm.viridis(np.linspace(0, 1, n_trajectories))

    for feat_idx, feature_name in enumerate(feature_names):
        ax = axes[feat_idx]

        for traj_idx, (traj, reward) in enumerate(zip(trajectories, rewards)):
            if traj.shape[1] > n_features:
                # Extra leading column assumed to be a timestamp
                feature_values = traj[:, feat_idx + 1]
                timesteps = traj[:, 0]
            else:
                feature_values = traj[:, feat_idx]
                timesteps = np.arange(len(feature_values))

            ax.plot(timesteps, feature_values,
                    color=colors[traj_idx], alpha=0.8, linewidth=2,
                    label=f'Traj {traj_idx+1} (R={reward:.1f})')

        ax.set_title(f'{feature_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)

        # Legend only on the first subplot
        if feat_idx == 0:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)

    # Hide grid cells that have no feature to show
    for ax in axes[n_features:]:
        ax.set_visible(False)

    plt.suptitle(f'{title_prefix} Trajectories - Feature Evolution',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    plt.close(fig)

# 2. Calculate and plot DTW diversity
def calculate_and_plot_dtw_diversity(trajectories, rewards, title="DTW Diversity"):
    """
    Compute pairwise DTW distances between trajectories, report diversity, and plot them.

    The diversity is the mean DTW distance over distinct pairs (the strict upper
    triangle of the matrix). It is multiplied by a reward-stability factor,
    ``1 - |max - min| / (|max| + |min| + 1e-8)`` over ``rewards``, to give the
    normalized score.

    Args:
        trajectories: Trajectories passed to ``compute_dtw_distance_matrix``.
        rewards: One reward per trajectory, used for stability and tick labels.
        title: Heatmap title prefix.

    Returns:
        dict with ``distance_matrix``, ``avg_distance``, ``normalized_score`` and
        ``reward_stability``.
    """
    print(f"Calculating DTW distances for {len(trajectories)} trajectories...")

    dtw_matrix, _unused = compute_dtw_distance_matrix(trajectories)

    upper_rows, upper_cols = np.triu_indices_from(dtw_matrix, k=1)
    mean_pairwise = np.mean(dtw_matrix[upper_rows, upper_cols])

    # Reward stability shrinks the score when the rewards are spread out
    hi_reward = np.max(rewards)
    lo_reward = np.min(rewards)
    stability = 1 - (np.abs(hi_reward - lo_reward)
                     / (np.abs(hi_reward) + np.abs(lo_reward) + 1e-8))
    score = mean_pairwise * stability

    for caption, value in (("Average DTW Distance", mean_pairwise),
                           ("Reward Stability", stability),
                           ("Normalized DTW Score", score)):
        print(f"{caption}: {value:.4f}")

    plt.figure(figsize=(10, 8), dpi=300)
    tick_labels = [f'T{k+1}\n(R={r:.1f})' for k, r in enumerate(rewards)]
    sns.heatmap(dtw_matrix,
                annot=True, fmt='.2f',
                xticklabels=tick_labels, yticklabels=tick_labels,
                cmap='viridis', square=True,
                cbar_kws={'label': 'DTW Distance'})
    plt.title(f'{title}\nAvg Distance: {mean_pairwise:.3f}, Normalized: {score:.3f}',
              fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    return dict(
        distance_matrix=dtw_matrix,
        avg_distance=mean_pairwise,
        normalized_score=score,
        reward_stability=stability,
    )

# 3. Calculate and plot Euclidean diversity (final states)
def calculate_and_plot_euclidean_diversity(trajectories, rewards, feature_names, title="Euclidean Diversity"):
    """
    Measure diversity of trajectory end points via pairwise Euclidean distances and plot them.

    The last timestep of each trajectory is taken as its final state. Column 0 is
    dropped when there are more columns than ``feature_names``; it is assumed to
    be a timestamp.

    The quality-diversity score is the mean pairwise distance multiplied by the
    reward-stability factor of ``rewards``:
    ``1 - |max - min| / (|max| + |min| + 1e-8)``. This is the same factor the
    DTW analysis uses. It was previously a hard-coded 0.93, which ignored the
    actual rewards.

    Args:
        trajectories: 3-D array indexed as [trajectory, timestep, column].
        rewards: One (final) reward per trajectory.
        feature_names: Feature names; used only to detect a leading timestamp column.
        title: Heatmap title prefix.

    Returns:
        dict with these keys:

        - ``distance_matrix``: n x n matrix.
        - ``avg_distance``: NaN if there are fewer than two trajectories.
        - ``quality_diversity``: dict with ``average_diversity``,
          ``normalized_score``, ``distances`` (condensed pairwise vector) and
          ``reward_stability``.
        - ``final_states``.
    """
    from scipy.spatial.distance import pdist, squareform

    print(f"Calculating Euclidean diversity for {len(trajectories)} final states...")

    # Final state of every trajectory, without the timestamp column if present
    if trajectories.shape[2] > len(feature_names):
        final_states = trajectories[:, -1, 1:]
    else:
        final_states = trajectories[:, -1, :]
    # Object-dtype trajectory arrays would break the distance computation; force floats
    final_states = np.array(final_states, dtype=float)

    # Condensed pairwise distances (upper-triangle order) plus the square form for plotting
    distances = pdist(final_states, metric='euclidean')
    distance_matrix = squareform(distances)
    # No pairs exist with fewer than two trajectories; avoid np.mean's empty-slice warning
    avg_distance = float(np.mean(distances)) if distances.size else float("nan")

    # Reward stability: 1 when all rewards are identical, decreasing toward 0 as
    # their range grows relative to their magnitudes
    max_r = np.max(rewards)
    min_r = np.min(rewards)
    reward_range = np.abs(max_r - min_r)
    reward_scale = np.abs(max_r) + np.abs(min_r) + 1e-8
    reward_stability = 1 - (reward_range / reward_scale)

    euclidean_res = {
        'average_diversity': avg_distance,
        'normalized_score': avg_distance * reward_stability,
        'distances': distances,
        'reward_stability': reward_stability,
    }

    print(f"Average Euclidean Distance: {avg_distance:.4f}")
    print(f"Quality-Diversity Score: {euclidean_res['normalized_score']:.4f}")

    fig = plt.figure(figsize=(10, 8), dpi=300)
    labels = [f'T{i+1}\n(R={r:.1f})' for i, r in enumerate(rewards)]
    sns.heatmap(distance_matrix,
                annot=True, fmt='.2f',
                xticklabels=labels, yticklabels=labels,
                cmap='plasma', square=True,
                cbar_kws={'label': 'Euclidean Distance'})
    plt.title(f'{title} (Final States)\nAvg Distance: {avg_distance:.3f}, Quality-Diversity: {euclidean_res["normalized_score"]:.3f}',
              fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    plt.close(fig)

    return {
        'distance_matrix': distance_matrix,
        'avg_distance': avg_distance,
        'quality_diversity': euclidean_res,
        'final_states': final_states
    }

# 4. Summary comparison plot
def _upper_triangle_distances(matrix):
    """Return the strictly-upper-triangle entries of a square distance matrix.

    Each unordered pair (i, j), i < j, appears exactly once; the diagonal
    (self-distances) is excluded.
    """
    matrix = np.asarray(matrix)
    return matrix[np.triu_indices_from(matrix, k=1)]


def _plot_distance_histogram(ax, distances, mean_value, *, color, mean_color,
                             title, xlabel):
    """Draw one distance histogram with a dashed vertical line at the mean."""
    ax.hist(distances, bins=15, alpha=0.7, color=color, edgecolor='black')
    ax.axvline(mean_value, color=mean_color, linestyle='--', linewidth=2,
               label=f'Mean: {mean_value:.3f}')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_diversity_comparison(dtw_results, euclidean_results, title=None):
    """Create a comparison plot of diversity metrics.

    Draws two side-by-side histograms of pairwise distances (upper triangle
    of each distance matrix, diagonal excluded): DTW on the left, Euclidean
    on the right. Each panel marks the dict's ``'avg_distance'`` value with a
    dashed vertical line. The figure is shown with ``plt.show()``.

    Args:
        dtw_results: dict with at least ``'distance_matrix'`` (square matrix)
            and ``'avg_distance'`` (float).
        euclidean_results: dict with the same two keys, e.g. the value
            returned by ``calculate_and_plot_euclidean_diversity``.
        title: optional figure super-title. When omitted, it is
            ``'Diversity Metrics Comparison - Top N Trajectories'`` where N is
            the number of rows of the DTW distance matrix (previously this
            was hard-coded to 10).

    Returns:
        None.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), dpi=300)

    # DTW diversity distribution
    _plot_distance_histogram(
        ax1,
        _upper_triangle_distances(dtw_results['distance_matrix']),
        dtw_results['avg_distance'],
        color='skyblue', mean_color='red',
        title='DTW Distance Distribution',
        xlabel='DTW Distance',
    )

    # Euclidean diversity distribution
    _plot_distance_histogram(
        ax2,
        _upper_triangle_distances(euclidean_results['distance_matrix']),
        euclidean_results['avg_distance'],
        color='lightcoral', mean_color='blue',
        title='Euclidean Distance Distribution',
        xlabel='Euclidean Distance (Final States)',
    )

    if title is None:
        # Report the real trajectory count instead of assuming "Top 10".
        n_trajectories = np.asarray(dtw_results['distance_matrix']).shape[0]
        title = f'Diversity Metrics Comparison - Top {n_trajectories} Trajectories'

    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Run the diversity analysis end to end on the top-10 trajectories and print
# a textual summary. Results stay bound to module-level names
# (`dtw_results`, `euclidean_results`) for use by later cells.
print("=" * 60, "TOP 10 TRAJECTORIES DIVERSITY ANALYSIS", "=" * 60, sep="\n")

# Step 1 - per-feature evolution of each trajectory.
plot_top_trajectories_grid(top_10_trajectories, top_10_rewards, feature_names)

# Step 2 - pairwise DTW distances.
dtw_results = calculate_and_plot_dtw_diversity(top_10_trajectories, top_10_rewards)

# Step 3 - pairwise Euclidean distances between final states.
euclidean_results = calculate_and_plot_euclidean_diversity(
    top_10_trajectories, top_10_rewards, feature_names
)

# Step 4 - side-by-side histograms of both distance distributions.
plot_diversity_comparison(dtw_results, euclidean_results)

# Step 5 - textual summary. Each metric lookup stays inside its own print so
# any missing key surfaces at the same point of the report as before.
print("", "=" * 60, "DIVERSITY ANALYSIS SUMMARY", "=" * 60, sep="\n")
print(f"Number of trajectories analyzed: {len(top_10_trajectories)}")
print(f"Reward range: {np.min(top_10_rewards):.2f} - {np.max(top_10_rewards):.2f}")
print(f"Mean reward: {np.mean(top_10_rewards):.2f} ± {np.std(top_10_rewards):.2f}")

print("\nDTW DIVERSITY METRICS:")
print(f"  Average DTW Distance: {dtw_results['avg_distance']:.4f}")
print(f"  Normalized DTW Score: {dtw_results['normalized_score']:.4f}")
print(f"  Reward Stability: {dtw_results['reward_stability']:.4f}")

print("\nEUCLIDEAN DIVERSITY METRICS:")
print(f"  Average Euclidean Distance: {euclidean_results['avg_distance']:.4f}")
print(f"  Quality-Diversity Score: {euclidean_results['quality_diversity']['normalized_score']:.4f}")
print(f"  Average Diversity: {euclidean_results['quality_diversity']['average_diversity']:.4f}")