# In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple
from dataclasses import dataclass
import matplotlib as mpl

@dataclass
class DirectoryInfo:
    """A results directory to analyze, paired with the label used for it."""
    path: str   # directory path handed to analyze_directory() in main()
    label: str  # key under which main() stores this directory's statistics

@dataclass
class DirectoryStats:
    """Aggregate statistics for one directory, as filled in by analyze_directory()."""
    avg_speak: float         # mean speak count over files whose name lacks 'failed'
    avg_speak_failed: float  # mean speak count over files whose name contains 'failed'
    avg_act: float           # mean act count over files whose name lacks 'failed'
    avg_act_failed: float    # mean act count over files whose name contains 'failed'
    success_rate: float      # valid_files / total_files as a fraction (0 when no files)
    total_files: int         # number of .txt files listed in the directory
    valid_files: int         # processed files whose name lacks 'failed'
    failed_files: int        # processed files whose name contains 'failed'

def detect_format(content: str) -> str:
    """Classify a log as 'new', 'old' or 'unknown'.

    Lines are scanned top to bottom and the first line carrying a marker
    decides the result: a '> speak:', '> act:' or '> think:' marker means
    'new'; a line containing both 'Act' and 'speak:' means 'old'. The new
    markers are checked first on every line. If no line matches, the result
    is 'unknown'.
    """
    new_markers = ('> speak:', '> act:', '> think:')
    for row in content.split('\n'):
        if any(marker in row for marker in new_markers):
            return 'new'
        if 'speak:' in row and 'Act' in row:
            return 'old'
    return 'unknown'

def count_actions_old_format(content: str) -> Tuple[int, int]:
    """Return (speak_count, act_count) for an old-format log.

    Only lines containing 'Act' are considered. Such a line counts as a speak
    when it also contains 'speak:'; otherwise it counts as an act unless it
    contains 'think:'.
    """
    n_speak = 0
    n_act = 0
    for row in content.split('\n'):
        if 'Act' not in row:
            continue
        if 'speak:' in row:
            n_speak += 1
        elif 'think:' not in row:
            n_act += 1
    return n_speak, n_act

def count_actions_new_format(content: str) -> Tuple[int, int]:
    """Return (speak_count, act_count) for a new-format log.

    A line counts as a speak when it contains '> speak:'. Independently, it
    counts as an act when it contains '> act:' and does not contain
    '> think:'; a single line can therefore contribute to both totals.
    """
    n_speak = 0
    n_act = 0
    for row in content.split('\n'):
        if '> speak:' in row:
            n_speak += 1
        if '> act:' in row and '> think:' not in row:
            n_act += 1
    return n_speak, n_act

def analyze_directory(directory_path: str) -> DirectoryStats:
    """Summarize speak/act counts for every ``.txt`` file in a directory.

    Each file is read as UTF-8, classified with ``detect_format`` and counted
    with the matching counter; anything not detected as 'new' is counted with
    the old-format counter. Files whose name contains 'failed' feed the
    ``*_failed`` averages, all others feed the plain averages and count as
    valid. A file that cannot be processed is reported on stdout and left out
    of both groups, although it still counts towards ``total_files``.

    The directory path and the per-format file counts are printed as a side
    effect. If the directory as a whole cannot be processed, the error is
    printed and an all-zero ``DirectoryStats`` is returned.
    """
    def mean_or_zero(values):
        # An empty group averages to 0 instead of dividing by zero.
        return sum(values) / len(values) if values else 0

    # Per-file counts, keyed by whether the filename marks a failed run.
    speak_by_outcome = {False: [], True: []}
    act_by_outcome = {False: [], True: []}
    seen_formats = {'old': 0, 'new': 0, 'unknown': 0}

    try:
        txt_names = [name for name in os.listdir(directory_path)
                     if name.endswith('.txt')]

        for name in txt_names:
            failed_run = 'failed' in name
            try:
                with open(os.path.join(directory_path, name), 'r',
                          encoding='utf-8') as handle:
                    text = handle.read()

                kind = detect_format(text)
                seen_formats[kind] += 1

                # 'old' and 'unknown' both fall back to the old-format counter.
                if kind == 'new':
                    counter = count_actions_new_format
                else:
                    counter = count_actions_old_format
                n_speak, n_act = counter(text)

                speak_by_outcome[failed_run].append(n_speak)
                act_by_outcome[failed_run].append(n_act)
            except Exception as e:
                print(f"Error processing file {name}: {str(e)}")

        print(f"\nDirectory: {directory_path}")
        print(f"Format counts: {seen_formats}")

        # Every processed file lands in exactly one speak bucket, so the
        # bucket sizes are the valid/failed file counts.
        n_total = len(txt_names)
        n_valid = len(speak_by_outcome[False])
        n_failed = len(speak_by_outcome[True])

        return DirectoryStats(
            avg_speak=mean_or_zero(speak_by_outcome[False]),
            avg_speak_failed=mean_or_zero(speak_by_outcome[True]),
            avg_act=mean_or_zero(act_by_outcome[False]),
            avg_act_failed=mean_or_zero(act_by_outcome[True]),
            success_rate=n_valid / n_total if n_total > 0 else 0,
            total_files=n_total,
            valid_files=n_valid,
            failed_files=n_failed
        )

    except Exception as e:
        print(f"Error accessing directory {directory_path}: {str(e)}")
        return DirectoryStats(0, 0, 0, 0, 0, 0, 0, 0)

def set_paper_style():
    """Reset matplotlib to its default style, then apply paper-ready rcParams.

    The overrides enlarge title, axis-label, tick and legend fonts and set
    300 dpi for both on-screen figures and saved files, with tight bounding
    boxes and a small padding on save.
    """
    font_sizes = {
        'font.size': 12,
        'axes.titlesize': 18,
        'axes.labelsize': 16,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 15,
    }
    output_settings = {
        'figure.dpi': 300,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1,
    }

    plt.style.use('default')
    mpl.rcParams.update({**font_sizes, **output_settings})

def create_visualizations(stats: Dict[str, DirectoryStats]):
    """Create publication-ready visualizations from per-directory statistics.

    The drawing code used to be a copy of ``create_visualizations_from_dict``
    (plus two unused ``min``/``max`` locals that raised on empty input). This
    function now only converts each ``DirectoryStats`` into the record shape
    that function expects and delegates to it, so both entry points always
    render the same figure and write the same ``trial_analysis`` files.

    Args:
        stats: Mapping from plot label to the statistics of one directory.
            Insertion order determines the order of the bars.
    """
    stats_dict = {
        label: {
            # DirectoryStats stores a fraction; the dict-based plotter draws
            # the value as-is on a percent axis, so scale it here.
            "success_rate": entry.success_rate * 100,
            "avg_speak_success": entry.avg_speak,
            "avg_speak_failed": entry.avg_speak_failed,
            "avg_act_success": entry.avg_act,
        }
        for label, entry in stats.items()
    }
    create_visualizations_from_dict(stats_dict)

    
    
def create_visualizations_from_dict(stats_dict, ylim=(50, 60),
                                    output_basename='trial_analysis'):
    """Create visualizations directly from statistics dictionary.

    Draws a two-row figure: the success rate per label on top, and below it
    grouped bars of average dialog turns (successful and unsuccessful
    trajectories) and average agent actions (successful trajectories). The
    figure is saved as ``<output_basename>.pdf`` and ``<output_basename>.png``
    and then shown.

    Args:
        stats_dict: Mapping from label to a record with the keys
            ``"success_rate"``, ``"avg_speak_success"``,
            ``"avg_speak_failed"`` and ``"avg_act_success"``. The success
            rate is plotted as given on an axis labelled in percent.
        ylim: ``(low, high)`` window for the success-rate axis. Defaults to
            the previously hard-coded ``(50, 60)``; pass ``None`` to let
            matplotlib autoscale the axis.
        output_basename: Path prefix (without extension) of the saved files.
            Defaults to the previously hard-coded ``'trial_analysis'``.
    """
    set_paper_style()

    # Common settings
    labels = list(stats_dict.keys())
    bar_width = 0.25
    colors = ['#cfebbf', '#f08080', '#88c1f2']

    def annotate(ax, bars, template, offset=0.0):
        # Write each bar's height just above its top edge, centred on the bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + offset,
                    template.format(height), ha='center', va='bottom',
                    fontsize=14)

    # Create figure with row-wise subplots
    fig = plt.figure(figsize=(12, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.3)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])

    # 1. Success Rate Plot
    success_rates = [stats_dict[label]["success_rate"] for label in labels]
    bars1 = ax1.bar(labels, success_rates, color='#e9c8fd', width=0.6)

    # Zoom the y-axis in on the requested window (skipped when ylim is None)
    if ylim is not None:
        ax1.set_ylim(*ylim)
    # The x labels are only shown on the lower plot, which shares the order
    ax1.set_xticklabels([])
    ax1.tick_params(axis='x', length=0)
    ax1.set_ylabel('Success Rate (%)')
    ax1.grid(False)

    annotate(ax1, bars1, '{:.1f}%', offset=0.1)

    # 2. Combined Count Comparison
    x = np.arange(len(labels))
    speaks_success = [stats_dict[label]["avg_speak_success"] for label in labels]
    speaks_failed = [stats_dict[label]["avg_speak_failed"] for label in labels]
    acts_success = [stats_dict[label]["avg_act_success"] for label in labels]

    # Create grouped bars: three series side by side around each x position
    bars2_1 = ax2.bar(x - bar_width, speaks_success, bar_width,
                      label='Avg. Dialog turns (Successful trajectories)',
                      color=colors[0])
    bars2_2 = ax2.bar(x, speaks_failed, bar_width,
                      label='Avg. Dialog turns (Unsuccessful trajectories)',
                      color=colors[1])
    bars2_3 = ax2.bar(x + bar_width, acts_success, bar_width,
                      label='Avg. Agent Actions (Successful trajectories)',
                      color=colors[2])

    ax2.set_ylabel('Average Count')
    ax2.set_xticks(x)
    ax2.set_xticklabels(labels, rotation=30, ha='right')
    ax2.legend(frameon=True, fancybox=True, framealpha=0.9,
               loc='upper right', bbox_to_anchor=(1, 1.02))
    ax2.grid(False)

    for bars in (bars2_1, bars2_2, bars2_3):
        annotate(ax2, bars, '{:.1f}')

    plt.tight_layout()
    plt.savefig(f'{output_basename}.pdf', format='pdf', bbox_inches='tight')
    plt.savefig(f'{output_basename}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
def main():
    """Analyze the configured result directories and plot their statistics."""
    # (path, label) pairs, listed in the order the bars appear in the plots.
    sources = (
        ("./results/React/React-Opt", "No\nDialog"),
        ("./results/Respact/Respact-Opt-Friction-v0", "Probing"),
        ("./results/Respact/Respact-Opt-Friction-v1", "Assumption Reveal"),
        ("./results/Respact/Respact-Opt-Friction-v2", "Overspecification"),
    )
    directories = [DirectoryInfo(path=path, label=label)
                   for path, label in sources]

    # Collect statistics, keyed by plot label
    stats = {}
    for dir_info in directories:
        stats[dir_info.label] = analyze_directory(dir_info.path)

    # Create visualizations
    create_visualizations(stats)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()