# Plotting 

In [None]:

import os
import json
from glob import glob
from collections import defaultdict


import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from pathlib import Path
from statistics import mean
import matplotlib.pyplot as plt

In [None]:
# ------------------------------- setting start ------------------------------ #
# Shared style settings for every figure cell below. The plt.rc calls change
# matplotlib's global rcParams, so they affect all figures created afterwards.

# color
# These hex codes are matplotlib's default "tab10" colour cycle; the policy
# curves use color_palette[1] (area), [2] (distortion) and [3] (uniform).
color_palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
# NOTE(review): errorbar_color is not referenced by the figure cells in this
# notebook — confirm before removing.
errorbar_color = "#3A3A3A"

# font
# Times-style serif text at size 23. usetex=True renders all text through
# LaTeX, so a working LaTeX installation is required and labels must be
# LaTeX-safe (e.g. '#' has to be written as '\#').
csfont = {'family':'Times New Roman', 'serif': 'Times' , 'size' : 23}
plt.rc('text', usetex=True)
plt.rc('font', **csfont)


# bar plot size
# NOTE(review): the bar_* settings and err_lw appear unused by the
# budget / delta / iteration figures — presumably kept for other plots.
bar_width = 0.4
bar_btw_space = 0.04
bar_space = 0.2

# errorbar plot size
# err_capsize / err_capthick are passed to ax.errorbar as capsize / capthick.
err_lw=1.5
err_capsize=4
err_capthick=1.5

# set fig size
# (width, height) in inches; (6.4, 4.8) is matplotlib's default figure size.
figsize=(6.4, 4.8)
# -------------------------------- setting end ------------------------------- #

## Figure Budget

In [None]:



# [TODO] Use fixed y-axis ranges so that plots can be compared with each other;
# the exact ranges may need to differ between sets of figures.
def set_metric_ylim(ax, metric_key):
    """
    Pin the y-axis of *ax* to a fixed range for the given quality metric.

    Using the same range in every figure keeps plots of one metric directly
    comparable. Metric keys without a configured range leave *ax* unchanged.

    Args:
        ax: matplotlib axes object
        metric_key: string, one of 'PSNR', 'SSIM', 'LPIPS'
    """
    metric_ranges = (
        ('PSNR', 24.5, 35.5),
        ('SSIM', 0.8, 1.0),
        ('LPIPS', 0, 0.2),
    )
    for name, low, high in metric_ranges:
        if metric_key == name:
            ax.set_ylim(low, high)
            break

def budget_policy_curves(scene_name=None):
    """
    Plot quality metrics vs. Gaussian budget for each budgeting policy of one scene.

    One figure is produced per metric, with one errorbar curve per policy:

    - Area-based budgeting - area_*_occlusion (0, 40k, 80k, 160k, 320k, 640k)
    - Distortion-based budgeting - distortion_*_occlusion (0, 40k, 80k, 160k, 320k, 640k)
    - Uniform budgeting - uniform_*_occlusion (0, 40k, 80k, 160k, 320k, 640k)

    The budget-0 point of every curve is the pure-mesh baseline read from
    ``area_1_occlusion``. Every point is the mean of the per-view values
    (entries equal to -1.0 are ignored); the error bar is the standard error
    of that mean. The x position is the run's ``num_splats`` when present,
    otherwise the nominal budget.

    Args:
        scene_name: scene directory under ``./data``. Defaults to the
            module-level ``SCENE_NAME`` global, which keeps the original
            calling convention ``budget_policy_curves()`` working.

    Side effects:
        Writes ``<METRIC>_vs_budget_<scene>.png`` and ``.eps`` into
        ``./plots/budget_policy_curves/<scene>/`` and shows each figure.
    """
    ITERATION = 'ours_15000'  # Last iteration for GS+Mesh
    MESH_ITERATION = 'ours_1'  # Pure mesh uses different key
    RESULT_FILE = 'per_view_gs_mesh.json'

    if scene_name is None:
        scene_name = SCENE_NAME

    input_dir = Path('./data') / scene_name
    output_dir = Path('./plots') / 'budget_policy_curves' / scene_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Budgets (number of Gaussians) evaluated for every policy.
    budgets = [40000, 80000, 160000, 320000, 640000]

    # Quality Metrics to plot
    metrics = {
        'PSNR': {'ylabel': 'PSNR (dB)', 'title': 'PSNR'},
        'SSIM': {'ylabel': 'SSIM', 'title': 'SSIM'},
        # 'LPIPS': {'ylabel': 'LPIPS', 'title': 'LPIPS'}
        # [TODO] draw LPIPS soon
    }

    # (directory prefix, legend label, marker, colour) per policy. Markers match
    # the delta and iteration figures so a policy looks the same in every plot.
    policy_styles = [
        ('area', 'Area-based', 'o', color_palette[1]),
        ('distortion', 'Distortion-based', 's', color_palette[2]),
        ('uniform', 'Uniform', '^', color_palette[3]),
    ]

    def _load_json(path):
        """Return the parsed JSON file at *path*, or None if it does not exist."""
        if not path.exists():
            return None
        with open(path, 'r') as f:
            return json.load(f)

    def _summarize(iter_data, metric_key):
        """Return (mean, std, stderr, n) for one metric, or None if no value is valid."""
        metric_data = iter_data[metric_key]
        if isinstance(metric_data, dict):
            # -1.0 marks an invalid per-view measurement.
            values = [v for v in metric_data.values() if v != -1.0]
        else:
            values = [metric_data]
        if not values:
            return None
        values = np.asarray(values, dtype=float)
        n = values.size
        # The standard error of the mean uses the sample std (ddof=1); a single
        # value carries no spread information, so its error bar is 0.
        std_val = float(np.std(values, ddof=1)) if n > 1 else 0.0
        return float(np.mean(values)), std_val, std_val / np.sqrt(n), n

    # Load every result file once up front; the same data feeds every metric.
    mesh_file = input_dir / 'area_1_occlusion' / RESULT_FILE
    print(f"\nChecking mesh file: {mesh_file}")
    print(f"File exists: {mesh_file.exists()}")
    mesh_iter_data = None
    mesh_json = _load_json(mesh_file)
    if mesh_json is not None:
        print(f"Available keys: {list(mesh_json.keys())}")
        # Try mesh iteration key first, then regular iteration key
        iter_key = MESH_ITERATION if MESH_ITERATION in mesh_json else ITERATION
        if iter_key in mesh_json:
            mesh_iter_data = mesh_json[iter_key]
        else:
            print(f"Neither {MESH_ITERATION} nor {ITERATION} found in mesh file")

    policy_runs = {}
    for prefix, _label, _marker, _color in policy_styles:
        runs = []
        for budget in budgets:
            run_file = input_dir / f'{prefix}_{budget}_occlusion' / RESULT_FILE
            run_json = _load_json(run_file)
            if run_json is None:
                print(f"  Missing: {run_file}")
            elif ITERATION not in run_json:
                print(f"  {ITERATION} not found in {run_file}")
            else:
                runs.append((budget, run_json[ITERATION]))
        policy_runs[prefix] = runs

    # Plot each metric
    for metric_key, metric_info in metrics.items():
        fig, ax = plt.subplots(figsize=figsize)

        print(f"\n{'='*60}")
        print(f"Plotting {metric_key}")
        print(f"{'='*60}")

        # Pure-mesh baseline, shared by every policy as its budget-0 point.
        mesh_point = None
        if mesh_iter_data is not None:
            summary = _summarize(mesh_iter_data, metric_key)
            if summary is not None:
                mesh_mean, std_val, mesh_stderr, n = summary
                mesh_point = (0, mesh_mean, mesh_stderr)
                print(f"Pure Mesh - {metric_key}: mean={mesh_mean:.4f}, std={std_val:.4f}, n={n}")

        for prefix, label, marker, color in policy_styles:
            xs, ys, errs = [], [], []
            if mesh_point is not None:
                xs.append(mesh_point[0])
                ys.append(mesh_point[1])
                errs.append(mesh_point[2])

            print(f"\nProcessing {label} budgeting:")
            for budget, iter_data in policy_runs[prefix]:
                summary = _summarize(iter_data, metric_key)
                if summary is None:
                    print(f"    Budget {budget}: no valid {metric_key} values")
                    continue
                mean_val, std_val, stderr, _n = summary
                num_splats = iter_data.get('num_splats', budget)

                xs.append(num_splats)
                ys.append(mean_val)
                errs.append(stderr)
                print(f"    Budget {budget}: mean={mean_val:.4f}, std={std_val:.4f}, splats={num_splats}")

            if xs:
                print(f"  Plotting {len(xs)} {prefix} points")
                ax.errorbar(xs, ys, yerr=errs,
                            marker=marker, markersize=8, linewidth=2.5,
                            capsize=err_capsize, capthick=err_capthick,
                            color=color, label=label, zorder=2)
            else:
                print(f"  No {prefix} data to plot!")

        # Formatting. Raw string: "\#" is an invalid Python escape sequence
        # (SyntaxWarning on 3.12+), but LaTeX needs the backslash to print '#'.
        ax.set_xlabel(r'Bit Budget (\#Gaussians K)', fontsize=20)
        ax.set_ylabel(f"Quality in {metric_info['ylabel']}", fontsize=20)

        # Set fixed y-axis range
        set_metric_ylim(ax, metric_key)

        # Only draw a legend when at least one curve was plotted.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best', framealpha=0.9, fontsize=18)
        ax.tick_params(labelsize=18)

        fig.set_constrained_layout(True)

        # Save both formats
        base_name = f'{metric_key}_vs_budget_{scene_name}'
        fig.savefig(output_dir / f'{base_name}.png', dpi=300, bbox_inches='tight')
        fig.savefig(output_dir / f'{base_name}.eps', format='eps', bbox_inches='tight')
        print(f"\nSaved: {base_name}.png and .eps\n")

        plt.show()
        plt.close(fig)

# Scene plotted by the per-scene figures. budget_policy_curves() and
# policy_iter_curves() read SCENE_NAME as a module-level global, so it must be
# assigned before they are called.
SCENE_NAME = 'ship'

budget_policy_curves()

## Figure Delta (DTGS - Pure Mesh)

In [None]:
def set_metric_ylim_delta(ax, metric_key):
    """
    Set fixed y-axis limits for delta (improvement over pure mesh) plots.

    Deltas are expected to be signed so that positive means improvement for
    every metric. This matches ``policy_budget_delta_curves``, which flips the
    sign for LPIPS (baseline - policy). The LPIPS range is therefore mostly
    positive, like the PSNR/SSIM ranges. Unknown metric keys leave *ax*
    unchanged.

    Args:
        ax: matplotlib axes object
        metric_key: string, one of 'PSNR', 'SSIM', 'LPIPS'
    """
    if metric_key == 'PSNR':
        ax.set_ylim(-0.5, 8.0)  # Delta range for PSNR improvement
    elif metric_key == 'SSIM':
        ax.set_ylim(-0.01, 0.15)  # Delta range for SSIM improvement
    elif metric_key == 'LPIPS':
        # LPIPS deltas are sign-flipped (baseline - policy), so an improvement
        # is positive; a negative-only range would clip every improvement.
        ax.set_ylim(-0.01, 0.15)


def policy_budget_delta_curves(scenes=None):
    """
    Plot the average quality improvement over the pure-mesh baseline across scenes.

    Average across the nerf-synthetic scenes
    (by default the 5 scenes we used: ficus, hotdog, lego, mic, ship).
    x-axis: budget number (0, 40k, 80k, 160k, 320k, 640k)
    y-axis: delta of quality in PSNR/SSIM/LPIPS (compared to pure mesh baseline)
    hue: different budgeting policies

    For every scene, the baseline is the mean per-view metric of
    ``area_1_occlusion`` (key ``ours_1``, falling back to ``ours_15000``). Each
    policy value is the mean per-view metric at ``ours_15000``. Per-view
    entries equal to -1.0 are ignored. Deltas are signed so that positive
    always means improvement (the sign is flipped for LPIPS). Each plotted
    point is the mean delta over scenes, with the standard error of that mean
    as error bar. The budget-0 point is 0 by definition.

    Args:
        scenes: iterable of scene directory names under ``./data``. Defaults
            to ['ficus', 'hotdog', 'lego', 'mic', 'ship'].

    Side effects:
        Writes ``<METRIC>_delta_vs_budget_average.png`` and ``.eps`` into
        ``./plots/budget_policy_delta_curves/`` and shows each figure.
    """
    ITERATION = 'ours_15000'
    MESH_ITERATION = 'ours_1'
    RESULT_FILE = 'per_view_gs_mesh.json'

    SCENES = ['ficus', 'hotdog', 'lego', 'mic', 'ship'] if scenes is None else list(scenes)
    budgets = [40000, 80000, 160000, 320000, 640000]
    policies = ['area', 'distortion', 'uniform']

    input_base = Path('./data')
    output_dir = Path('./plots') / 'budget_policy_delta_curves'
    output_dir.mkdir(parents=True, exist_ok=True)

    metrics = {
        'PSNR': {'ylabel': r'$\Delta$ in PSNR (dB)', 'title': 'PSNR Improvement'},
        'SSIM': {'ylabel': r'$\Delta$ in SSIM', 'title': 'SSIM Improvement'},
        # 'LPIPS': {'ylabel': r'$\Delta$ LPIPS', 'title': 'LPIPS Improvement'}
    }

    policy_styles = {
        'area': {'marker': 'o', 'color': color_palette[1], 'label': 'Area-based'},
        'distortion': {'marker': 's', 'color': color_palette[2], 'label': 'Distortion-based'},
        'uniform': {'marker': '^', 'color': color_palette[3], 'label': 'Uniform'}
    }

    def _load_iter_data(path, iter_keys, description):
        """Return data[key] for the first of *iter_keys* present in the JSON at *path*.

        Prints a warning and returns None when the file or every key is missing.
        """
        if not path.exists():
            print(f"  [WARN] Missing {description}")
            return None
        with open(path, 'r') as f:
            data = json.load(f)
        for key in iter_keys:
            if key in data:
                return data[key]
        print(f"  [WARN] No {' / '.join(iter_keys)} in {description}")
        return None

    def _metric_mean(iter_data, metric_key):
        """Mean of one metric over valid views (-1.0 ignored); None if none is valid."""
        metric_data = iter_data[metric_key]
        if not isinstance(metric_data, dict):
            return float(metric_data)
        values = [v for v in metric_data.values() if v != -1.0]
        return float(np.mean(values)) if values else None

    # Load every result file once; the data is reused for every metric and budget.
    mesh_runs = {}
    policy_runs = {}
    for scene in SCENES:
        scene_dir = input_base / scene
        mesh_iter = _load_iter_data(
            scene_dir / 'area_1_occlusion' / RESULT_FILE,
            (MESH_ITERATION, ITERATION),
            f"mesh file for {scene}",
        )
        if mesh_iter is None:
            continue  # without a baseline no delta can be computed for this scene
        mesh_runs[scene] = mesh_iter
        for policy in policies:
            for budget in budgets:
                policy_runs[scene, policy, budget] = _load_iter_data(
                    scene_dir / f'{policy}_{budget}_occlusion' / RESULT_FILE,
                    (ITERATION,),
                    f"{policy} file for {scene} at budget {budget}",
                )

    # For each metric, create one plot
    for metric_key, metric_info in metrics.items():
        fig, ax = plt.subplots(figsize=figsize)

        print(f"\n{'='*60}")
        print(f"Computing Average Delta for {metric_key}")
        print(f"{'='*60}")

        # Pure-mesh baseline of this metric for every scene that has one.
        baselines = {}
        for scene, mesh_iter in mesh_runs.items():
            baseline = _metric_mean(mesh_iter, metric_key)
            if baseline is None:
                print(f"  [WARN] No valid {metric_key} values in mesh file for {scene}")
            else:
                baselines[scene] = baseline

        policy_data = {p: {'xs': [], 'ys': [], 'errs': []} for p in policies}

        for budget in budgets:
            print(f"\nProcessing budget: {budget}")

            # Collect deltas for each policy across all scenes
            policy_deltas = {p: [] for p in policies}
            for scene, baseline in baselines.items():
                for policy in policies:
                    iter_data = policy_runs.get((scene, policy, budget))
                    if iter_data is None:
                        continue
                    policy_val = _metric_mean(iter_data, metric_key)
                    if policy_val is None:
                        print(f"  [WARN] No valid {metric_key} values in {policy} file "
                              f"for {scene} at budget {budget}")
                        continue

                    # Compute delta (improvement over baseline)
                    # For LPIPS, lower is better, so we flip the sign
                    if metric_key == 'LPIPS':
                        delta = baseline - policy_val  # Positive = improvement
                    else:
                        delta = policy_val - baseline  # Positive = improvement

                    policy_deltas[policy].append(delta)
                    print(f"  {scene} {policy}: baseline={baseline:.4f}, policy={policy_val:.4f}, delta={delta:.4f}")

            # Average delta and its standard error across scenes
            for policy in policies:
                deltas = np.asarray(policy_deltas[policy], dtype=float)
                if deltas.size == 0:
                    continue
                mean_delta = float(np.mean(deltas))
                # Sample std (ddof=1) for the standard error across scenes; a
                # single scene gives no spread estimate, so its error bar is 0.
                std_delta = float(np.std(deltas, ddof=1)) if deltas.size > 1 else 0.0
                stderr_delta = std_delta / np.sqrt(deltas.size)

                policy_data[policy]['xs'].append(budget)
                policy_data[policy]['ys'].append(mean_delta)
                policy_data[policy]['errs'].append(stderr_delta)

                print(f"  {policy} average: mean={mean_delta:.4f}, std={std_delta:.4f}, n={deltas.size}")

        # Budget 0 is the pure mesh itself, so its delta is 0 by definition.
        for policy in policies:
            if policy_data[policy]['xs']:
                policy_data[policy]['xs'].insert(0, 0)
                policy_data[policy]['ys'].insert(0, 0.0)
                policy_data[policy]['errs'].insert(0, 0.0)

        # Plot each policy
        for policy in policies:
            series = policy_data[policy]
            if not series['xs']:
                continue
            style = policy_styles[policy]
            ax.errorbar(
                series['xs'],
                series['ys'],
                yerr=series['errs'],
                marker=style['marker'],
                markersize=8,
                linewidth=2.5,
                capsize=err_capsize,
                capthick=err_capthick,
                color=style['color'],
                label=style['label'],
                zorder=2
            )
            print(f"\n[INFO] Plotted {len(series['xs'])} points for {policy}")

        # Add horizontal line at y=0 (no improvement)
        ax.axhline(y=0, color='gray', linestyle='--', linewidth=1, alpha=0.5, zorder=1)

        # Formatting. Raw string: "\#" is an invalid Python escape sequence
        # (SyntaxWarning on 3.12+), but LaTeX needs the backslash to print '#'.
        ax.set_xlabel(r'Bit Budget (\#Gaussians K)', fontsize=20)
        ax.set_ylabel(metric_info['ylabel'], fontsize=20)

        # Set fixed y-axis range
        set_metric_ylim_delta(ax, metric_key)

        # Only draw a legend when at least one curve was plotted.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best', framealpha=0.9, fontsize=18)
        ax.tick_params(labelsize=18)

        fig.set_constrained_layout(True)

        # Save both formats
        base_name = f'{metric_key}_delta_vs_budget_average'
        fig.savefig(output_dir / f'{base_name}.png', dpi=300, bbox_inches='tight')
        fig.savefig(output_dir / f'{base_name}.eps', format='eps', bbox_inches='tight')
        print(f"\n[INFO] Saved: {base_name}.png and .eps\n")

        plt.show()
        plt.close(fig)

# Average over all scenes; this figure does not depend on SCENE_NAME.
policy_budget_delta_curves()

## Figure Iter

In [None]:

# [TODO] Give this helper a distinct name per plotter: it redefines the
# set_metric_ylim from the "Figure Budget" cell with the same ranges.
def set_metric_ylim(ax, metric_key):
    """
    Clamp the y-axis of *ax* to the shared range of a quality metric.

    A fixed range per metric keeps the iteration plots comparable with each
    other. For any metric key other than 'PSNR', 'SSIM' or 'LPIPS' the axes
    are left as they are.

    Args:
        ax: matplotlib axes object
        metric_key: string, one of 'PSNR', 'SSIM', 'LPIPS'
    """
    if metric_key == 'PSNR':
        bounds = (24.5, 35.5)
    elif metric_key == 'SSIM':
        bounds = (0.8, 1.0)
    elif metric_key == 'LPIPS':
        bounds = (0, 0.2)
    else:
        return
    ax.set_ylim(*bounds)

def policy_iter_curves(scene_name=None, budget=160000):
    """
    Plot metrics vs iteration for different budgeting policies at a fixed budget.

    x-axis: iteration number (1000, 2000, 3000, 4000, 5000, 6000, 7000, 10000, 12000, 15000)
    y-axis: quality in PSNR/SSIM/LPIPS
    hue: different budgeting policies

    - Area-based budgeting - area_<budget>_occlusion
    - Distortion-based budgeting - distortion_<budget>_occlusion
    - Uniform budgeting - uniform_<budget>_occlusion

    Each point is the mean of the per-view values at that iteration (entries
    equal to -1.0 are ignored), with the standard error of the mean as error bar.

    Args:
        scene_name: scene directory under ``./data``. Defaults to the
            module-level ``SCENE_NAME`` global (the original calling convention).
        budget: Gaussian budget whose runs are plotted (default 160000, i.e. 160k).

    Side effects:
        Writes ``<METRIC>_vs_iteration_<scene>_budget<K>k.png`` and ``.eps`` into
        ``./plots/policy_iter_curves/<scene>/`` and shows each figure.
    """
    BUDGET = budget  # Fixed budget
    RESULT_FILE = 'per_view_gs_mesh.json'

    if scene_name is None:
        scene_name = SCENE_NAME

    input_dir = Path('./data') / scene_name
    output_dir = Path('./plots') / 'policy_iter_curves' / scene_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Iterations to plot; the result files key them as 'ours_<iteration>'.
    iter_nums = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 10000, 12000, 15000]

    # Quality Metrics to plot
    metrics = {
        'PSNR': {'ylabel': 'PSNR (dB)', 'title': 'PSNR'},
        'SSIM': {'ylabel': 'SSIM', 'title': 'SSIM'},
        # 'LPIPS': {'ylabel': 'LPIPS', 'title': 'LPIPS'}
    }

    # (directory prefix, legend label, marker, colour) per policy.
    policy_styles = [
        ('area', 'Area-based', 'o', color_palette[1]),
        ('distortion', 'Distortion-based', 's', color_palette[2]),
        ('uniform', 'Uniform', '^', color_palette[3]),
    ]

    def _summarize(iter_data, metric_key):
        """Return (mean, std, stderr, n) for one metric, or None if no value is valid."""
        metric_data = iter_data[metric_key]
        if isinstance(metric_data, dict):
            # -1.0 marks an invalid per-view measurement.
            values = [v for v in metric_data.values() if v != -1.0]
        else:
            values = [metric_data]
        if not values:
            return None
        values = np.asarray(values, dtype=float)
        n = values.size
        # The standard error of the mean uses the sample std (ddof=1); a single
        # value carries no spread information, so its error bar is 0.
        std_val = float(np.std(values, ddof=1)) if n > 1 else 0.0
        return float(np.mean(values)), std_val, std_val / np.sqrt(n), n

    # Load each policy's result file once; it is reused for every metric.
    policy_runs = {}
    for prefix, label, _marker, _color in policy_styles:
        run_file = input_dir / f'{prefix}_{BUDGET}_occlusion' / RESULT_FILE
        print(f"\nLoading {label} budgeting:")
        print(f"  File: {run_file}, exists: {run_file.exists()}")
        if run_file.exists():
            with open(run_file, 'r') as f:
                policy_runs[prefix] = json.load(f)
        else:
            policy_runs[prefix] = None

    # Plot each metric
    for metric_key, metric_info in metrics.items():
        fig, ax = plt.subplots(figsize=figsize)

        print(f"\n{'='*60}")
        print(f"Plotting {metric_key} vs Iteration")
        print(f"{'='*60}")

        for prefix, label, marker, color in policy_styles:
            data = policy_runs[prefix]
            xs, ys, errs = [], [], []

            print(f"\nProcessing {label} budgeting:")
            if data is not None:
                for iter_num in iter_nums:
                    iter_key = f'ours_{iter_num}'
                    if iter_key not in data:
                        print(f"    Iter {iter_num}: {iter_key} not found")
                        continue
                    summary = _summarize(data[iter_key], metric_key)
                    if summary is None:
                        print(f"    Iter {iter_num}: no valid {metric_key} values")
                        continue
                    mean_val, std_val, stderr, n = summary

                    xs.append(iter_num)
                    ys.append(mean_val)
                    errs.append(stderr)
                    print(f"    Iter {iter_num}: mean={mean_val:.4f}, std={std_val:.4f}, n={n}")

            if xs:
                print(f"  Plotting {len(xs)} {prefix} points")
                ax.errorbar(xs, ys, yerr=errs,
                            marker=marker, markersize=8, linewidth=2.5,
                            capsize=err_capsize, capthick=err_capthick,
                            color=color, label=label, zorder=2)
            else:
                print(f"  No {prefix} data to plot!")

        # Formatting
        ax.set_xlabel('Iteration', fontsize=20)
        ax.set_ylabel(f"Quality in {metric_info['ylabel']}", fontsize=20)

        # Set fixed y-axis range
        set_metric_ylim(ax, metric_key)

        # Only draw a legend when at least one curve was plotted.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best', framealpha=0.9, fontsize=18)
        ax.tick_params(labelsize=18)

        fig.set_constrained_layout(True)

        # Save both formats
        base_name = f'{metric_key}_vs_iteration_{scene_name}_budget{BUDGET//1000}k'
        fig.savefig(output_dir / f'{base_name}.png', dpi=300, bbox_inches='tight')
        fig.savefig(output_dir / f'{base_name}.eps', format='eps', bbox_inches='tight')
        print(f"\nSaved: {base_name}.png and .eps\n")

        plt.show()
        plt.close(fig)


# Iteration curves for the scene currently held in SCENE_NAME.
policy_iter_curves()


# To regenerate the iteration curves for every scene, rebind the module-level
# SCENE_NAME before each call (policy_iter_curves reads it as a global):
# SCENE_NAME_LIST = ['ficus', 'hotdog', 'lego', 'mic', 'ship']

# for name in SCENE_NAME_LIST:
#     SCENE_NAME = name
#     policy_iter_curves()