### Jexpresso single-grid EKE time series (200 m grid, SAM microphysics, order 4)

In [None]:
import pandas as pd
import numpy as np
import pyvista as pv
from scipy.interpolate import interp1d
from pathlib import Path
import matplotlib.pyplot as plt
import warnings
import os
warnings.filterwarnings('ignore')


class EKEAnalysis:
    """Analyze JExpresso outputs using the Domain Mean Subtraction method.

    The total state is reconstructed as reference + perturbation for the
    point-data arrays ``'ρ'``, ``'ρu'`` and ``'ρv'``; velocities are obtained
    as ``ρu / ρ`` and ``ρv / ρ``. Points are binned into height levels using
    the second point coordinate (``points[:, 1]``) as height. At each level
    the level-mean velocity is subtracted and the level-mean of
    ``0.5 * ρ * (u'^2 + v'^2)`` is recorded as that level's EKE.

    After :meth:`extract_jexpresso_data`, ``self.jex_data`` is a DataFrame
    with columns ``"Height (m)"`` and ``"EKE (J/m3)"`` sorted by height and
    limited to ``max_height``.
    """

    # Width of the height bins, in the same units as the mesh coordinates.
    HEIGHT_TOLERANCE = 100.0
    # Densities must exceed this (in magnitude for the momentum division) to be used.
    MIN_DENSITY = 1e-10
    # Level-mean EKE values above this are discarded as spurious.
    MAX_LEVEL_EKE = 1000
    # Point-data arrays combined from the reference and perturbation files.
    FIELDS = ('ρ', 'ρu', 'ρv')

    def __init__(self, ref_file, pert_file, max_height=24000):
        """Store input paths; no file is read until extract_jexpresso_data().

        Args:
            ref_file: path of the reference-state mesh readable by ``pv.read``.
            pert_file: path of the perturbation mesh readable by ``pv.read``.
            max_height: levels above this height are dropped from the profile.
        """
        self.ref_file = ref_file
        self.pert_file = pert_file
        self.max_height = max_height
        # Empty until extract_jexpresso_data() runs, so get_mean_eke() returns
        # NaN instead of raising AttributeError when called too early.
        self.jex_data = pd.DataFrame({'Height (m)': [], 'EKE (J/m3)': []})

    @classmethod
    def _sorted_fields(cls, mesh):
        """Return ``(z, ρ, ρu, ρv)`` for *mesh*, all reordered by ascending ``z``."""
        z = mesh.points[:, 1]
        order = np.argsort(z)
        return (z[order], *(mesh.point_data[name][order] for name in cls.FIELDS))

    @staticmethod
    def _interp(z_src, values, z_dst):
        """Linearly interpolate *values* from heights *z_src* onto *z_dst*.

        Heights outside the range of *z_src* receive 0.
        NOTE(review): *z_src* usually contains repeated heights (many points per
        level); interp1d may yield NaN exactly at repeated abscissae. Such points
        fail the density checks later and are dropped. Confirm that is acceptable.
        """
        return interp1d(z_src, values, bounds_error=False, fill_value=0, kind='linear')(z_dst)

    def _combine_fields(self, ref_mesh, pert_mesh):
        """Return ``(z, ρ, ρu, ρv)`` of the total (reference + perturbation) state.

        Meshes with equal point counts are added point-by-point, which assumes
        both store their points in the same order. Otherwise both are sorted by
        height, and the mesh with fewer points is interpolated in height onto
        the other mesh's points.
        """
        if ref_mesh.n_points == pert_mesh.n_points:
            z = ref_mesh.points[:, 1]
            return (z, *(ref_mesh.point_data[name] + pert_mesh.point_data[name]
                         for name in self.FIELDS))

        z_ref, *ref_fields = self._sorted_fields(ref_mesh)
        z_pert, *pert_fields = self._sorted_fields(pert_mesh)
        if len(z_ref) > len(z_pert):
            return (z_ref, *(ref + self._interp(z_pert, pert, z_ref)
                             for ref, pert in zip(ref_fields, pert_fields)))
        return (z_pert, *(self._interp(z_ref, ref, z_pert) + pert
                          for ref, pert in zip(ref_fields, pert_fields)))

    def _level_profile(self, z, rho, u, v):
        """Compute per-level mean EKE after subtracting each level's mean wind.

        Returns:
            ``(heights, ekes)``: two equal-length lists. Each height is the mean
            ``z`` of the points used at that level. Levels with fewer than two
            usable points, or whose mean EKE is negative, non-finite or above
            ``MAX_LEVEL_EKE``, are skipped.
        """
        tol = self.HEIGHT_TOLERANCE
        heights_rounded = np.round(z / tol) * tol
        eke_profile = []
        height_profile = []

        for h_target in np.unique(heights_rounded):
            in_level = np.abs(heights_rounded - h_target) < (tol / 2)
            if np.sum(in_level) < 2:
                continue

            u_lev, v_lev = u[in_level], v[in_level]
            rho_lev, z_lev = rho[in_level], z[in_level]

            usable = (rho_lev > self.MIN_DENSITY) & np.isfinite(u_lev) & np.isfinite(v_lev)
            if np.sum(usable) < 2:
                continue
            u_lev, v_lev = u_lev[usable], v_lev[usable]
            rho_lev, z_lev = rho_lev[usable], z_lev[usable]

            # Domain mean subtraction: deviations from this level's mean wind.
            u_prime = u_lev - np.mean(u_lev)
            v_prime = v_lev - np.mean(v_lev)
            eke_mean = np.mean(0.5 * rho_lev * (u_prime**2 + v_prime**2))

            if eke_mean > self.MAX_LEVEL_EKE or eke_mean < 0 or not np.isfinite(eke_mean):
                continue

            eke_profile.append(eke_mean)
            height_profile.append(np.mean(z_lev))

        return height_profile, eke_profile

    def extract_jexpresso_data(self):
        """Read both meshes and store the EKE height profile in ``self.jex_data``."""
        ref_mesh = pv.read(self.ref_file)
        pert_mesh = pv.read(self.pert_file)

        z_common, rho_total, rho_u_total, rho_v_total = self._combine_fields(ref_mesh, pert_mesh)

        # Velocities only where density is usable; other points stay at 0.
        valid_mask = np.abs(rho_total) > self.MIN_DENSITY
        u_total = np.zeros_like(rho_u_total)
        v_total = np.zeros_like(rho_v_total)
        u_total[valid_mask] = rho_u_total[valid_mask] / rho_total[valid_mask]
        v_total[valid_mask] = rho_v_total[valid_mask] / rho_total[valid_mask]

        height_profile, eke_profile = self._level_profile(z_common, rho_total, u_total, v_total)

        if not eke_profile:
            self.jex_data = pd.DataFrame({'Height (m)': [], 'EKE (J/m3)': []})
            return

        df_raw = pd.DataFrame({
            "Height (m)": height_profile,
            "EKE (J/m3)": eke_profile,
        })
        self.jex_data = (
            df_raw[df_raw["Height (m)"] <= self.max_height]
            .drop_duplicates(subset=["Height (m)"])
            .sort_values('Height (m)')
            .reset_index(drop=True)
        )

    def get_mean_eke(self):
        """Return the mean EKE over the stored profile, or NaN if it is empty."""
        if len(self.jex_data) == 0:
            return np.nan
        return np.nanmean(self.jex_data["EKE (J/m3)"].values)


def process_configuration(ref_file, pert_base_path, max_time_minutes=500,
                          output_interval_minutes=10, max_height=24000):
    """Build a mean-EKE time series for one reference/perturbation configuration.

    Iteration ``i`` of the perturbation output is read from
    ``pert_base_path.format(i)`` and treated as time ``i * output_interval_minutes``
    minutes. Iterations ``0 .. max_time_minutes // output_interval_minutes`` are
    used (inclusive, so the series reaches *max_time_minutes*) when their file exists.

    Args:
        ref_file: reference-state file passed to EKEAnalysis.
        pert_base_path: format string with one ``{}`` placeholder for the iteration index.
        max_time_minutes: last output time (minutes) to include.
        output_interval_minutes: minutes between consecutive iteration files.
        max_height: upper height limit passed to EKEAnalysis.

    Returns:
        ``{'jex': [...], 'times': [...]}`` holding the accepted mean-EKE values and
        their times in minutes. Returns None if the reference file is missing, if no
        perturbation file exists, or if no time step yields an accepted value.
        NaN, negative, or > 500 values are rejected.
    """
    if not Path(ref_file).exists():
        print(f"Reference file not found: {ref_file}")
        return None

    # +1 so the output written exactly at max_time_minutes is included.
    last_iteration = max_time_minutes // output_interval_minutes
    existing_pert_files = [
        i for i in range(last_iteration + 1) if Path(pert_base_path.format(i)).exists()
    ]

    if not existing_pert_files:
        print(f"No perturbation files found at: {pert_base_path}")
        return None

    time_data = {'jex': [], 'times': []}
    for i in existing_pert_files:
        pert_file = pert_base_path.format(i)
        try:
            analysis = EKEAnalysis(ref_file, pert_file, max_height=max_height)
            analysis.extract_jexpresso_data()
            jex_total_eke = analysis.get_mean_eke()
        except Exception as exc:
            # Best effort: one unreadable output should not abort the series,
            # but report it instead of silently dropping the time step.
            print(f"Skipping {pert_file}: {exc}")
            continue

        if np.isnan(jex_total_eke) or jex_total_eke > 500 or jex_total_eke < 0:
            continue

        time_data['jex'].append(jex_total_eke)
        time_data['times'].append(i * output_interval_minutes)

    print(f"Successfully processed {len(time_data['jex'])} time steps")
    return time_data if time_data['jex'] else None


def plot_fine_grid_sam(time_data, output_dir):
    """Create a publication-quality plot of the fine-grid (200 m) SAM mean-EKE series.

    Args:
        time_data: dict with 'times' (minutes) and 'jex' (mean EKE) lists, as
            returned by process_configuration.
        output_dir: directory where ``Fine_Grid_200m_SAM_Order4.png`` is written.

    The y-axis is logarithmic. Values below its lower limit, including zeros and
    negatives, are clamped to that limit so they stay visible. Style settings are
    applied through ``plt.rc_context``, so they do not leak into the global
    matplotlib configuration used by other notebook cells.
    """
    style = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans'],
        'font.size': 12,
        'axes.linewidth': 1.5,
        'xtick.major.width': 1.5,
        'ytick.major.width': 1.5,
    }

    times = time_data['times']
    jex_data = time_data['jex']

    # Log-scale limits bracket the positive values with some headroom.
    positive_values = [v for v in jex_data if v > 0]
    if positive_values:
        y_min = max(1e-7, min(positive_values) * 0.3)
        y_max = max(positive_values) * 3.0
    else:
        y_min, y_max = 1e-7, 1e2

    # Clamp zeros/negatives (and anything below y_min) so they plot on a log axis.
    jex_data_plot = [max(v, y_min) for v in jex_data]

    filename = "Fine_Grid_200m_SAM_Order4.png"
    filepath = os.path.join(output_dir, filename)

    with plt.rc_context(style):
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.set_facecolor('#FAFAFA')
            ax.semilogy(times, jex_data_plot,
                        color='#D97330',  # Dark orange
                        linestyle='-',
                        marker='o',
                        markersize=8,
                        linewidth=2.5,
                        label='JExpresso EKE',
                        alpha=0.90,
                        markevery=max(1, len(times) // 15),
                        markeredgewidth=1.5,
                        markeredgecolor='white',
                        zorder=4)

            ax.set_ylim(y_min, y_max)
            ax.set_xlim(0, 500)

            ax.set_xlabel('Time (minutes)', fontsize=14, fontweight='600')
            ax.set_ylabel('Mean EKE (J/m³)', fontsize=14, fontweight='600')
            ax.set_title('Fine Grid (200m) - SAM Microphysics - Order 4\nJExpresso EKE Analysis',
                         fontsize=16, fontweight='bold', pad=15)

            ax.grid(True, alpha=0.15, color='gray', linewidth=0.5, linestyle=':', which='minor')
            ax.grid(True, alpha=0.35, color='gray', linewidth=0.8, linestyle='-', which='major')

            for spine in ax.spines.values():
                spine.set_color('#333333')
                spine.set_linewidth(1.5)

            ax.tick_params(axis='both', which='major', labelsize=11, width=1.5, length=7)
            ax.tick_params(axis='both', which='minor', width=1.0, length=4)

            ax.legend(loc='best', fontsize=12, frameon=True, fancybox=True,
                      framealpha=0.95, edgecolor='#CCCCCC')

            fig.tight_layout()
            fig.savefig(filepath, dpi=400, bbox_inches='tight', facecolor='white', edgecolor='none')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
    print(f"Saved plot: {filename}")


def main():
    """Run the fine-grid (200 m) SAM order-4 EKE analysis and save its plot.

    Paths are hard-coded for this machine. Any exception raised while processing
    or plotting is reported with a traceback instead of propagating.
    """
    output_base = "/Users/olayemiadeyemi/Documents/sponge_on"
    os.makedirs(output_base, exist_ok=True)

    banner = "=" * 80
    rule = "-" * 80

    print(banner)
    print("Starting Fine Grid (200m) SAM Analysis - Order 4")
    print(banner)

    ref_file = "/Users/olayemiadeyemi/Documents/New/reference_state/fine_4/output/iter_0.pvtu"
    pert_file = "/Users/olayemiadeyemi/Documents/sponge_on/f_150_sam_4/output/iter_{}.pvtu"

    print(f"\nReference file: {ref_file}")
    print(f"Perturbation files: {pert_file}")
    print(rule)

    print("\nProcessing fine grid SAM configuration...")
    try:
        time_data = process_configuration(ref_file, pert_file, max_time_minutes=500)

        # Guard clause: nothing to plot.
        if not time_data or len(time_data['jex']) == 0:
            print("✗ No valid data found")
            return

        step_count = len(time_data['jex'])
        print(f"✓ Successfully processed {step_count} time steps")

        print("\nGenerating publication-quality plot...")
        print(rule)
        plot_fine_grid_sam(time_data, output_base)

        print("\n" + banner)
        print("Analysis complete!")
        print(f"Output directory: {output_base}")
        print(banner)
        print(f"  - Plot saved: {output_base}/Fine_Grid_200m_SAM_Order4.png")
        print(banner)
    except Exception as e:
        print(f"✗ Failed to process: {e}")
        import traceback
        traceback.print_exc()


# Entry point. Inside a notebook kernel __name__ is "__main__", so running this
# cell calls main() immediately.
if __name__ == "__main__":
    main()

## Jexpresso multi-grid EKE comparison (orders 4–6, Kessler and cold microphysics)

In [None]:
import pandas as pd
import numpy as np
import pyvista as pv
from scipy.interpolate import interp1d
from pathlib import Path
import matplotlib.pyplot as plt
import warnings
import os
warnings.filterwarnings('ignore')


class EKEAnalysis:
    """Analyze JExpresso outputs using the Domain Mean Subtraction method.

    The total state is reconstructed as reference + perturbation for the
    point-data arrays ``'ρ'``, ``'ρu'`` and ``'ρv'``; velocities are obtained
    as ``ρu / ρ`` and ``ρv / ρ``. Points are binned into height levels using
    the second point coordinate (``points[:, 1]``) as height. At each level
    the level-mean velocity is subtracted and the level-mean of
    ``0.5 * ρ * (u'^2 + v'^2)`` is recorded as that level's EKE.

    After :meth:`extract_jexpresso_data`, ``self.jex_data`` is a DataFrame
    with columns ``"Height (m)"`` and ``"EKE (J/m3)"`` sorted by height and
    limited to ``max_height``.
    """

    # Width of the height bins, in the same units as the mesh coordinates.
    HEIGHT_TOLERANCE = 100.0
    # Densities must exceed this (in magnitude for the momentum division) to be used.
    MIN_DENSITY = 1e-10
    # Level-mean EKE values above this are discarded as spurious.
    MAX_LEVEL_EKE = 1000
    # Point-data arrays combined from the reference and perturbation files.
    FIELDS = ('ρ', 'ρu', 'ρv')

    def __init__(self, ref_file, pert_file, max_height=24000):
        """Store input paths; no file is read until extract_jexpresso_data().

        Args:
            ref_file: path of the reference-state mesh readable by ``pv.read``.
            pert_file: path of the perturbation mesh readable by ``pv.read``.
            max_height: levels above this height are dropped from the profile.
        """
        self.ref_file = ref_file
        self.pert_file = pert_file
        self.max_height = max_height
        # Empty until extract_jexpresso_data() runs, so get_mean_eke() returns
        # NaN instead of raising AttributeError when called too early.
        self.jex_data = pd.DataFrame({'Height (m)': [], 'EKE (J/m3)': []})

    @classmethod
    def _sorted_fields(cls, mesh):
        """Return ``(z, ρ, ρu, ρv)`` for *mesh*, all reordered by ascending ``z``."""
        z = mesh.points[:, 1]
        order = np.argsort(z)
        return (z[order], *(mesh.point_data[name][order] for name in cls.FIELDS))

    @staticmethod
    def _interp(z_src, values, z_dst):
        """Linearly interpolate *values* from heights *z_src* onto *z_dst*.

        Heights outside the range of *z_src* receive 0.
        NOTE(review): *z_src* usually contains repeated heights (many points per
        level); interp1d may yield NaN exactly at repeated abscissae. Such points
        fail the density checks later and are dropped. Confirm that is acceptable.
        """
        return interp1d(z_src, values, bounds_error=False, fill_value=0, kind='linear')(z_dst)

    def _combine_fields(self, ref_mesh, pert_mesh):
        """Return ``(z, ρ, ρu, ρv)`` of the total (reference + perturbation) state.

        Meshes with equal point counts are added point-by-point, which assumes
        both store their points in the same order. Otherwise both are sorted by
        height, and the mesh with fewer points is interpolated in height onto
        the other mesh's points.
        """
        if ref_mesh.n_points == pert_mesh.n_points:
            z = ref_mesh.points[:, 1]
            return (z, *(ref_mesh.point_data[name] + pert_mesh.point_data[name]
                         for name in self.FIELDS))

        z_ref, *ref_fields = self._sorted_fields(ref_mesh)
        z_pert, *pert_fields = self._sorted_fields(pert_mesh)
        if len(z_ref) > len(z_pert):
            return (z_ref, *(ref + self._interp(z_pert, pert, z_ref)
                             for ref, pert in zip(ref_fields, pert_fields)))
        return (z_pert, *(self._interp(z_ref, ref, z_pert) + pert
                          for ref, pert in zip(ref_fields, pert_fields)))

    def _level_profile(self, z, rho, u, v):
        """Compute per-level mean EKE after subtracting each level's mean wind.

        Returns:
            ``(heights, ekes)``: two equal-length lists. Each height is the mean
            ``z`` of the points used at that level. Levels with fewer than two
            usable points, or whose mean EKE is negative, non-finite or above
            ``MAX_LEVEL_EKE``, are skipped.
        """
        tol = self.HEIGHT_TOLERANCE
        heights_rounded = np.round(z / tol) * tol
        eke_profile = []
        height_profile = []

        for h_target in np.unique(heights_rounded):
            in_level = np.abs(heights_rounded - h_target) < (tol / 2)
            if np.sum(in_level) < 2:
                continue

            u_lev, v_lev = u[in_level], v[in_level]
            rho_lev, z_lev = rho[in_level], z[in_level]

            usable = (rho_lev > self.MIN_DENSITY) & np.isfinite(u_lev) & np.isfinite(v_lev)
            if np.sum(usable) < 2:
                continue
            u_lev, v_lev = u_lev[usable], v_lev[usable]
            rho_lev, z_lev = rho_lev[usable], z_lev[usable]

            # Domain mean subtraction: deviations from this level's mean wind.
            u_prime = u_lev - np.mean(u_lev)
            v_prime = v_lev - np.mean(v_lev)
            eke_mean = np.mean(0.5 * rho_lev * (u_prime**2 + v_prime**2))

            if eke_mean > self.MAX_LEVEL_EKE or eke_mean < 0 or not np.isfinite(eke_mean):
                continue

            eke_profile.append(eke_mean)
            height_profile.append(np.mean(z_lev))

        return height_profile, eke_profile

    def extract_jexpresso_data(self):
        """Read both meshes and store the EKE height profile in ``self.jex_data``."""
        ref_mesh = pv.read(self.ref_file)
        pert_mesh = pv.read(self.pert_file)

        z_common, rho_total, rho_u_total, rho_v_total = self._combine_fields(ref_mesh, pert_mesh)

        # Velocities only where density is usable; other points stay at 0.
        valid_mask = np.abs(rho_total) > self.MIN_DENSITY
        u_total = np.zeros_like(rho_u_total)
        v_total = np.zeros_like(rho_v_total)
        u_total[valid_mask] = rho_u_total[valid_mask] / rho_total[valid_mask]
        v_total[valid_mask] = rho_v_total[valid_mask] / rho_total[valid_mask]

        height_profile, eke_profile = self._level_profile(z_common, rho_total, u_total, v_total)

        if not eke_profile:
            self.jex_data = pd.DataFrame({'Height (m)': [], 'EKE (J/m3)': []})
            return

        df_raw = pd.DataFrame({
            "Height (m)": height_profile,
            "EKE (J/m3)": eke_profile,
        })
        self.jex_data = (
            df_raw[df_raw["Height (m)"] <= self.max_height]
            .drop_duplicates(subset=["Height (m)"])
            .sort_values('Height (m)')
            .reset_index(drop=True)
        )

    def get_mean_eke(self):
        """Return the mean EKE over the stored profile, or NaN if it is empty."""
        if len(self.jex_data) == 0:
            return np.nan
        return np.nanmean(self.jex_data["EKE (J/m3)"].values)


def process_configuration(ref_file, pert_base_path, max_time_minutes=500,
                          output_interval_minutes=10, max_height=24000):
    """Build a mean-EKE time series for one reference/perturbation configuration.

    Iteration ``i`` of the perturbation output is read from
    ``pert_base_path.format(i)`` and treated as time ``i * output_interval_minutes``
    minutes. Iterations ``0 .. max_time_minutes // output_interval_minutes`` are
    used (inclusive, so the series reaches *max_time_minutes*) when their file exists.

    Args:
        ref_file: reference-state file passed to EKEAnalysis.
        pert_base_path: format string with one ``{}`` placeholder for the iteration index.
        max_time_minutes: last output time (minutes) to include.
        output_interval_minutes: minutes between consecutive iteration files.
        max_height: upper height limit passed to EKEAnalysis.

    Returns:
        ``{'jex': [...], 'times': [...]}`` holding the accepted mean-EKE values and
        their times in minutes. Returns None if the reference file is missing, if no
        perturbation file exists, or if no time step yields an accepted value.
        NaN, negative, or > 500 values are rejected.
    """
    if not Path(ref_file).exists():
        print(f"Reference file not found: {ref_file}")
        return None

    # +1 so the output written exactly at max_time_minutes is included.
    last_iteration = max_time_minutes // output_interval_minutes
    existing_pert_files = [
        i for i in range(last_iteration + 1) if Path(pert_base_path.format(i)).exists()
    ]

    if not existing_pert_files:
        print(f"No perturbation files found at: {pert_base_path}")
        return None

    time_data = {'jex': [], 'times': []}
    for i in existing_pert_files:
        pert_file = pert_base_path.format(i)
        try:
            analysis = EKEAnalysis(ref_file, pert_file, max_height=max_height)
            analysis.extract_jexpresso_data()
            jex_total_eke = analysis.get_mean_eke()
        except Exception as exc:
            # Best effort: one unreadable output should not abort the series,
            # but report it instead of silently dropping the time step.
            print(f"Skipping {pert_file}: {exc}")
            continue

        if np.isnan(jex_total_eke) or jex_total_eke > 500 or jex_total_eke < 0:
            continue

        time_data['jex'].append(jex_total_eke)
        time_data['times'].append(i * output_interval_minutes)

    print(f"Successfully processed {len(time_data['jex'])} time steps")
    return time_data if time_data['jex'] else None


def plot_comprehensive_grid(data_dict, output_dir):
    """Create a publication-quality 3x2 grid of mean-EKE time series.

    Rows are orders '4', '5', '6'; columns are microphysics schemes 'Kessler'
    and 'Cold'. Each panel draws one line per resolution whose series is present
    (and not None) in *data_dict* under the key
    ``f"Order_{order}_{resolution}_{micro}"``. Each series is a dict with
    'times' and 'jex' lists, as returned by process_configuration.

    All panels share the same log-scale y-limits. The shared legend lists every
    resolution drawn in any panel, in coarse-to-fine order. The figure is saved as
    ``JExpresso_Grid_Publication_Quality.png`` in *output_dir*. Style settings are
    applied via ``plt.rc_context`` rather than by mutating the global ``rcParams``.
    """
    style = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans'],
        'font.size': 10,
        'axes.linewidth': 1.2,
        'xtick.major.width': 1.2,
        'ytick.major.width': 1.2,
    }

    # Color gradient: lighter (coarse) to darker (fine) for Jexpresso
    jex_colors = {
        '4200m': '#FFD9B8',  # Lightest orange
        '1200m': '#FFBC85',  # Light orange
        '800m': '#FFA55C',   # Medium-light orange
        '600m': '#FF8C42',   # Medium orange
        '400m': '#D97330',   # Dark orange
        '200m': '#A65620',   # Darkest orange
    }

    resolutions = ['4200m', '1200m', '800m', '600m', '400m', '200m']

    # Distinct line styles progressing from coarse to fine
    linestyles = {
        '4200m': (0, (1, 1)),             # Dotted
        '1200m': (0, (5, 5)),             # Long dashed
        '800m': (0, (5, 2)),              # Dash
        '600m': (0, (3, 1, 1, 1)),        # Dash-dot
        '400m': (0, (5, 1, 1, 1, 1, 1)),  # Dash-dot-dot
        '200m': '-',                      # Solid (finest resolution)
    }

    # Graduated line widths (thicker for finer resolutions)
    linewidths = {'4200m': 1.8, '1200m': 2.0, '800m': 2.3, '600m': 2.6, '400m': 2.9, '200m': 3.3}

    # Distinct markers for each resolution
    markers = {'4200m': 'x', '1200m': 's', '800m': 'D', '600m': 'o', '400m': '^', '200m': 'v'}

    # Graduated marker sizes (larger for finer resolutions)
    markersizes = {'4200m': 6, '1200m': 7, '800m': 8, '600m': 9, '400m': 10, '200m': 11}

    orders = ['4', '5', '6']
    microphysics_types = ['Kessler', 'Cold']

    def series_for(order, micro, resolution):
        """Return the stored time series for one configuration, or None if absent."""
        return data_dict.get(f"Order_{order}_{resolution}_{micro}")

    # First pass: global y-axis limits from every positive value in every panel.
    positive_values = []
    for order in orders:
        for micro in microphysics_types:
            for resolution in resolutions:
                data = series_for(order, micro, resolution)
                if data is not None:
                    positive_values.extend(v for v in data['jex'] if v > 0)

    if positive_values:
        y_min = max(1e-7, min(positive_values) * 0.3)
        y_max = max(positive_values) * 3.0
    else:
        y_min, y_max = 1e-7, 1e2

    # First line drawn for each resolution in ANY panel, so a resolution that is
    # missing from the first panel still appears in the shared legend.
    legend_lines = {}

    filename = "JExpresso_Grid_Publication_Quality.png"
    filepath = os.path.join(output_dir, filename)

    with plt.rc_context(style):
        fig, axes = plt.subplots(3, 2, figsize=(20, 17))
        try:
            # Second pass: draw every panel with uniform axes.
            for row_idx, order in enumerate(orders):
                for col_idx, micro in enumerate(microphysics_types):
                    ax = axes[row_idx, col_idx]
                    ax.set_facecolor('#FAFAFA')

                    for resolution in resolutions:
                        data = series_for(order, micro, resolution)
                        if data is None:
                            continue

                        times = data['times']
                        # Clamp zeros/negatives so they remain visible on a log axis.
                        jex_data_plot = [max(v, y_min) for v in data['jex']]

                        (line,) = ax.semilogy(times, jex_data_plot,
                                              color=jex_colors[resolution],
                                              linestyle=linestyles[resolution],
                                              marker=markers[resolution],
                                              markersize=markersizes[resolution],
                                              linewidth=linewidths[resolution],
                                              label=f'JEx {resolution}',
                                              alpha=0.90,
                                              markevery=max(1, len(times) // 10),
                                              markeredgewidth=1.2,
                                              markeredgecolor='white',
                                              zorder=4)
                        legend_lines.setdefault(resolution, line)

                    ax.set_ylim(y_min, y_max)
                    ax.set_xlim(0, 500)

                    if row_idx == 2:
                        ax.set_xlabel('Time (minutes)', fontsize=13, fontweight='600')
                    if col_idx == 0:
                        ax.set_ylabel('Mean EKE (J/m³)', fontsize=13, fontweight='600')

                    ax.set_title(f'Order {order} - {micro}', fontsize=14, fontweight='700', pad=12)

                    ax.grid(True, alpha=0.15, color='gray', linewidth=0.5, linestyle=':', which='minor')
                    ax.grid(True, alpha=0.35, color='gray', linewidth=0.8, linestyle='-', which='major')

                    for spine in ax.spines.values():
                        spine.set_color('#333333')
                        spine.set_linewidth(1.3)

                    ax.tick_params(axis='both', which='major', labelsize=10, width=1.2, length=6)
                    ax.tick_params(axis='both', which='minor', width=0.8, length=3)

            fig.suptitle('JExpresso EKE Analysis: Orders 4, 5, 6\nAll Resolutions and Microphysics Schemes',
                         fontsize=18, fontweight='bold', y=0.995)

            # Shared legend below the plots, in coarse-to-fine order.
            legend_order = [r for r in resolutions if r in legend_lines]
            if legend_order:
                fig.legend([legend_lines[r] for r in legend_order],
                           [f'JEx {r}' for r in legend_order],
                           loc='lower center',
                           bbox_to_anchor=(0.5, -0.02),
                           ncol=6,
                           fontsize=10,
                           frameon=True,
                           fancybox=True,
                           framealpha=0.97,
                           edgecolor='#CCCCCC',
                           borderpad=0.8)

            fig.tight_layout(rect=[0, 0.03, 1, 0.99])
            fig.savefig(filepath, dpi=400, bbox_inches='tight', facecolor='white', edgecolor='none')
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)
    print(f"Saved publication-quality grid: {filename}")


def main():
    """Run the EKE analysis for every order, resolution and microphysics scheme.

    Each (order, resolution, microphysics) combination is processed with
    process_configuration. Combinations that yield data are stored under
    ``f"Order_{order}_{resolution}_{micro}"``, and the collection is passed to
    plot_comprehensive_grid, which writes its figure into the output directory.
    """
    output_base = "/Users/olayemiadeyemi/Documents/sponge_on"
    os.makedirs(output_base, exist_ok=True)

    banner = "=" * 80
    rule = "-" * 80

    print(banner)
    print("Starting JExpresso EKE Analysis - Publication Quality")
    print(banner)

    # Reference-state location for each order and resolution. Values that do not
    # end in '.pvtu' are run directories; 'output/iter_0.pvtu' is appended below.
    jexpresso_reference = {
        '4': {
            '4200m': "/Users/olayemiadeyemi/Documents/New/reference_state/coarse_4",
            '1200m': "/Users/olayemiadeyemi/Documents/New/reference_state/medium_4",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_800_4/output/iter_0.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_600_4/output/iter_0.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_400_4/output/iter_0.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/New/reference_state/fine_4",
        },
        '5': {
            '4200m': "/Users/olayemiadeyemi/Documents/New/reference_state/coarse_5",
            '1200m': "/Users/olayemiadeyemi/Documents/New/reference_state/medium_5",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_800_5/output/iter_0.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_600_5/output/iter_0.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_400_5/output/iter_0.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/New/reference_state/fine_5",
        },
        '6': {
            '4200m': "/Users/olayemiadeyemi/Documents/New/reference_state/coarse_6",
            '1200m': "/Users/olayemiadeyemi/Documents/New/reference_state/medium_6",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_800_6/output/iter_0.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_600_6/output/iter_0.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_400_6/output/iter_0.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/New/reference_state/fine_6",
        },
    }

    # Perturbation output templates ('{}' receives the iteration index), keyed by
    # microphysics scheme, then order, then resolution.
    # NOTE(review): the 200m order-4 runs use 'f_150_*' directories while orders
    # 5 and 6 use 'fine_*_200_*_c5'; confirm this is intended.
    jexpresso_perturbation = {
        'Kessler': {
            '4': {
                '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_k_4/output/iter_{}.pvtu",
                '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_k_4/output/iter_{}.pvtu",
                '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_800_4_c5/output/iter_{}.pvtu",
                '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_600_4_c5/output/iter_{}.pvtu",
                '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_400_4_c5/output/iter_{}.pvtu",
                '200m': "/Users/olayemiadeyemi/Documents/sponge_on/f_150_k_4/output/iter_{}.pvtu",
            },
            '5': {
                '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_k_5/output/iter_{}.pvtu",
                '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_k_5/output/iter_{}.pvtu",
                '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_800_5_c5/output/iter_{}.pvtu",
                '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_600_5_c5/output/iter_{}.pvtu",
                '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_400_5_c5/output/iter_{}.pvtu",
                '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_200_5_c5/output/iter_{}.pvtu",
            },
            '6': {
                '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_k_6/output/iter_{}.pvtu",
                '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_k_6/output/iter_{}.pvtu",
                '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_800_6_c5/output/iter_{}.pvtu",
                '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_600_6_c5/output/iter_{}.pvtu",
                '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_400_6_c5/output/iter_{}.pvtu",
                '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_200_6_c5/output/iter_{}.pvtu",
            },
        },
        'Cold': {
            '4': {
                '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_sam_4/output/iter_{}.pvtu",
                '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_sam_4/output/iter_{}.pvtu",
                '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_800_4_c5/output/iter_{}.pvtu",
                '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_600_4_c5/output/iter_{}.pvtu",
                '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_400_4_c5/output/iter_{}.pvtu",
                '200m': "/Users/olayemiadeyemi/Documents/sponge_on/f_150_sam_4/output/iter_{}.pvtu",
            },
            '5': {
                '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_sam_5/output/iter_{}.pvtu",
                '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_sam_5/output/iter_{}.pvtu",
                '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_800_5_c5/output/iter_{}.pvtu",
                '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_600_5_c5/output/iter_{}.pvtu",
                '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_400_5_c5/output/iter_{}.pvtu",
                '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_200_5_c5/output/iter_{}.pvtu",
            },
            '6': {
                '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_sam_6/output/iter_{}.pvtu",
                '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_sam_6/output/iter_{}.pvtu",
                '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_800_6_c5/output/iter_{}.pvtu",
                '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_600_6_c5/output/iter_{}.pvtu",
                '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_400_6_c5/output/iter_{}.pvtu",
                '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_200_6_c5/output/iter_{}.pvtu",
            },
        },
    }

    orders = ['4', '5', '6']
    resolutions = ['4200m', '1200m', '800m', '600m', '400m', '200m']
    microphysics_types = ['Kessler', 'Cold']
    # Order-major, then resolution, then microphysics.
    combinations = [
        (order, resolution, micro)
        for order in orders
        for resolution in resolutions
        for micro in microphysics_types
    ]

    print("\nProcessing configurations...")
    print(rule)

    all_data = {}
    for order, resolution, micro in combinations:
        print(f"\nProcessing: Order {order}, Resolution {resolution}, {micro}")

        jex_pert = jexpresso_perturbation[micro][order][resolution]

        ref_file = jexpresso_reference[order][resolution]
        if not ref_file.endswith('.pvtu'):
            # Directory-style reference: point at its first output file.
            ref_file = os.path.join(ref_file, 'output', 'iter_0.pvtu')

        try:
            time_data = process_configuration(ref_file, jex_pert, max_time_minutes=500)
            if not time_data or len(time_data['jex']) == 0:
                print(f"✗ No valid data for Order {order}, {resolution}, {micro}")
            else:
                key = f"Order_{order}_{resolution}_{micro}"
                all_data[key] = time_data
                print(f"✓ Successfully processed {key} ({len(time_data['jex'])} time steps)")
        except Exception as e:
            print(f"✗ Failed to process Order {order}, {resolution}, {micro}: {e}")

    print("\n" + banner)
    print(f"Data processing complete. Successfully processed {len(all_data)} configurations.")
    print(banner)

    print("\nGenerating publication-quality comprehensive grid plot...")
    print(rule)
    plot_comprehensive_grid(all_data, output_base)

    print("\n" + banner)
    print("Publication-quality plot generated successfully!")
    print(f"Output directory: {output_base}")
    print(banner)
    print(f"  - Publication-quality grid: {output_base}/JExpresso_Grid_Publication_Quality.png")
    print(banner)


# Entry point. Inside a notebook kernel __name__ is "__main__", so running this
# cell calls main() immediately.
if __name__ == "__main__":
    main()

### WRF vs Jexpresso

In [None]:
import pandas as pd
import numpy as np
from netCDF4 import Dataset
from wrf import getvar, destagger
import pyvista as pv
from scipy.interpolate import interp1d
from pathlib import Path
import matplotlib.pyplot as plt
import warnings
import traceback
import os
warnings.filterwarnings('ignore', category=UserWarning, module='wrf')


class EKEComparison:
    """Compare WRF and JExpresso outputs using the Domain Mean Subtraction method.

    At every height level the eddy kinetic energy (EKE) is
    ``mean(0.5 * rho * (u'**2 + v'**2))``, where ``u'`` and ``v'`` are the
    winds minus their mean over that level.

    Typical use: ``extract_wrf_data`` (or assign ``wrf_data`` directly),
    ``extract_jexpresso_data``, then ``get_comparison_data``.

    Attributes:
        wrf_file: WRF netCDF output path.
        ref_file: JExpresso reference-state file (read with ``pv.read``).
        pert_file: JExpresso file whose ρ, ρu, ρv are added to the reference
            fields to form the total state.
        max_height: Profiles are truncated to heights <= this value (m).
        R_d: Gas constant used for the WRF density (dry air, J kg^-1 K^-1).
        wrf_data, jex_data: DataFrames with columns "Height (m)" and
            "EKE (J/m3)", set by the ``extract_*`` methods.
    """

    # Point-data arrays read from the JExpresso meshes.
    _FIELDS = ("ρ", "ρu", "ρv")
    # Densities at or below this magnitude are treated as empty/invalid points.
    _RHO_EPS = 1e-10

    def __init__(self, wrf_file, ref_file, pert_file, max_height=24000):
        self.wrf_file = wrf_file
        self.ref_file = ref_file
        self.pert_file = pert_file
        self.max_height = max_height
        self.R_d = 287.05

    # ------------------------------------------------------------------ WRF

    @staticmethod
    def _wrf_eke_profile(u, v, rho, z):
        """Return ``(mean_height, mean_eke)`` arrays with one entry per level.

        All inputs are 3-D arrays indexed ``[level, y, x]`` on the same grid.
        The horizontal mean wind of each level is removed before forming EKE.
        """
        horizontal = (1, 2)
        u_pert = u - u.mean(axis=horizontal, keepdims=True)
        v_pert = v - v.mean(axis=horizontal, keepdims=True)
        eke = 0.5 * rho * (u_pert ** 2 + v_pert ** 2)
        return z.mean(axis=horizontal), eke.mean(axis=horizontal)

    def extract_wrf_data(self, timeidx=0):
        """Compute the WRF EKE profile at output ``timeidx`` into ``self.wrf_data``.

        wrf-python's ``ua``/``va`` are already interpolated to mass points, the
        same grid as ``z``, ``tk``, ``QVAPOR``, ``P`` and ``PB``, so they are
        used as-is. Destaggering them again would average neighbouring
        columns, smoothing the winds and damping small-scale EKE.
        """
        with Dataset(self.wrf_file) as ncfile:
            def read(name):
                return getvar(ncfile, name, timeidx=timeidx, meta=False)

            z = read("z")
            u = read("ua")
            v = read("va")
            t = read("tk")
            qv = read("QVAPOR")
            # Full pressure = perturbation (P) + base state (PB).
            pressure = read("P") + read("PB")

        # Ideal-gas density using virtual temperature to account for moisture.
        tv = t * (1 + 0.61 * qv)
        rho = pressure / (self.R_d * tv)

        heights, eke = self._wrf_eke_profile(u, v, rho, z)
        df = pd.DataFrame({"Height (m)": heights, "EKE (J/m3)": eke})
        self.wrf_data = df[df["Height (m)"] <= self.max_height].reset_index(drop=True)

    # ------------------------------------------------------------ JExpresso

    @classmethod
    def _height_sorted_fields(cls, mesh):
        """Return ``(z, [ρ, ρu, ρv])`` for ``mesh``, all sorted by ascending ``points[:, 1]``."""
        z = mesh.points[:, 1]
        order = np.argsort(z)
        return z[order], [mesh.point_data[name][order] for name in cls._FIELDS]

    def _total_fields(self):
        """Read both meshes and return ``(z, rho, rho_u, rho_v)`` of reference + perturbation."""
        ref_mesh = pv.read(self.ref_file)
        pert_mesh = pv.read(self.pert_file)

        if ref_mesh.n_points == pert_mesh.n_points:
            # Assumes both files share the same point ordering; add point-wise.
            z = ref_mesh.points[:, 1]
            totals = [ref_mesh.point_data[n] + pert_mesh.point_data[n] for n in self._FIELDS]
        else:
            z_ref, ref_fields = self._height_sorted_fields(ref_mesh)
            z_pert, pert_fields = self._height_sorted_fields(pert_mesh)

            def resample(z_src, values, z_dst):
                return interp1d(z_src, values, bounds_error=False,
                                fill_value=0, kind='linear')(z_dst)

            # NOTE(review): the coarser field is interpolated in height only,
            # ignoring horizontal position — confirm this is intended.
            # The denser point set defines the common coordinates.
            if len(z_ref) > len(z_pert):
                z = z_ref
                totals = [r + resample(z_pert, p, z) for r, p in zip(ref_fields, pert_fields)]
            else:
                z = z_pert
                totals = [resample(z_ref, r, z) + p for r, p in zip(ref_fields, pert_fields)]

        rho, rho_u, rho_v = (np.asarray(a) for a in totals)
        return np.asarray(z), rho, rho_u, rho_v

    @staticmethod
    def _binned_eke_profile(z, rho, u, v, height_tolerance=100.0, max_eke=1000.0):
        """Return ``(heights, ekes)`` lists of EKE over horizontal height bins.

        Points are grouped by ``z`` rounded to the nearest ``height_tolerance``.
        Within a bin, points with ``rho <= 1e-10`` or non-finite winds are
        dropped. A bin is skipped when it has fewer than two points (before
        or after that filtering), or when its mean EKE is non-finite,
        negative, or above ``max_eke``. Each kept bin reports the mean ``z``
        of its usable points.
        """
        bins = np.round(z / height_tolerance) * height_tolerance
        heights, ekes = [], []
        for level in np.unique(bins):
            in_bin = np.abs(bins - level) < height_tolerance / 2
            if np.count_nonzero(in_bin) < 2:
                continue

            u_b, v_b, rho_b, z_b = u[in_bin], v[in_bin], rho[in_bin], z[in_bin]
            usable = (rho_b > EKEComparison._RHO_EPS) & np.isfinite(u_b) & np.isfinite(v_b)
            if np.count_nonzero(usable) < 2:
                continue
            u_b, v_b, rho_b, z_b = u_b[usable], v_b[usable], rho_b[usable], z_b[usable]

            u_pert = u_b - np.mean(u_b)
            v_pert = v_b - np.mean(v_b)
            eke_mean = np.mean(0.5 * rho_b * (u_pert ** 2 + v_pert ** 2))
            # Drop unphysical levels rather than letting them dominate the profile.
            if not np.isfinite(eke_mean) or eke_mean < 0 or eke_mean > max_eke:
                continue

            ekes.append(eke_mean)
            heights.append(np.mean(z_b))
        return heights, ekes

    def extract_jexpresso_data(self):
        """Compute the JExpresso EKE profile into ``self.jex_data``.

        Velocities are ``ρu/ρ`` and ``ρv/ρ`` of the total (reference +
        perturbation) state, set to 0 where ``|ρ| <= 1e-10``. The profile is
        limited to ``max_height``, de-duplicated by height and sorted upward.
        It is an empty DataFrame when no height bin qualifies.
        """
        z, rho, rho_u, rho_v = self._total_fields()

        valid = np.abs(rho) > self._RHO_EPS
        u = np.zeros(rho_u.shape)
        v = np.zeros(rho_v.shape)
        u[valid] = rho_u[valid] / rho[valid]
        v[valid] = rho_v[valid] / rho[valid]

        heights, ekes = self._binned_eke_profile(z, rho, u, v)
        if not ekes:
            self.jex_data = pd.DataFrame({'Height (m)': [], 'EKE (J/m3)': []})
            return

        df = pd.DataFrame({"Height (m)": heights, "EKE (J/m3)": ekes})
        self.jex_data = (
            df[df["Height (m)"] <= self.max_height]
            .drop_duplicates(subset=["Height (m)"])
            .sort_values("Height (m)")
            .reset_index(drop=True)
        )

    # ----------------------------------------------------------- comparison

    def interpolate_to_wrf_heights(self):
        """Linearly interpolate the JExpresso EKE profile onto the WRF heights.

        Heights outside the JExpresso range are linearly extrapolated
        (NOTE(review): this can yield negative EKE near the profile ends).

        Returns:
            Float array aligned with ``self.wrf_data["Height (m)"]``. It is
            all-NaN when fewer than two finite JExpresso values exist or
            interpolation fails.
        """
        heights_wrf = np.asarray(self.wrf_data["Height (m)"].values, dtype=float)
        missing = np.full(heights_wrf.shape, np.nan)

        df_jex = self.jex_data
        valid = df_jex["EKE (J/m3)"].notna().to_numpy()
        if np.count_nonzero(valid) < 2:
            return missing

        try:
            f = interp1d(
                df_jex["Height (m)"].to_numpy()[valid],
                df_jex["EKE (J/m3)"].to_numpy()[valid],
                bounds_error=False,
                fill_value="extrapolate",
                kind='linear',
            )
        except ValueError:
            return missing
        return f(heights_wrf)

    def get_comparison_data(self):
        """Return WRF heights, WRF EKE and JExpresso EKE on those heights as a dict."""
        return {
            "heights": self.wrf_data["Height (m)"].values,
            "wrf_eke": self.wrf_data["EKE (J/m3)"].values,
            "jex_eke": self.interpolate_to_wrf_heights(),
        }


def process_configuration(wrf_file, ref_file, pert_base_path, max_time_minutes=500,
                          output_interval_minutes=10):
    """Build WRF and JExpresso domain-mean EKE time series for one configuration.

    Args:
        wrf_file: WRF netCDF output path.
        ref_file: JExpresso reference-state file path.
        pert_base_path: Format string with one ``{}`` placeholder for the
            output index, e.g. ``".../output/iter_{}.pvtu"``.
        max_time_minutes: Output indices ``0 .. max_time_minutes //
            output_interval_minutes - 1`` are considered (default 500).
        output_interval_minutes: Minutes between consecutive outputs. Index
            ``i`` is used as the WRF ``timeidx`` and labelled as time
            ``i * output_interval_minutes``.

    Returns:
        ``{'wrf': [...], 'jex': [...], 'times': [...]}`` with one entry per
        usable output, or ``None`` when an input is missing or no output
        yields usable data. Outputs whose mean EKE exceeds 500 in either
        model, or whose profiles are all-NaN, are skipped.
    """
    if not Path(wrf_file).exists():
        print(f"WRF file not found: {wrf_file}")
        return None

    if not Path(ref_file).exists():
        print(f"Reference file not found: {ref_file}")
        return None

    n_outputs = max_time_minutes // output_interval_minutes
    existing_pert_files = [i for i in range(n_outputs) if Path(pert_base_path.format(i)).exists()]

    if not existing_pert_files:
        print(f"No perturbation files found at: {pert_base_path}")
        return None

    time_data = {'wrf': [], 'jex': [], 'times': []}
    for i in existing_pert_files:
        pert_file = pert_base_path.format(i)
        try:
            comparison = EKEComparison(wrf_file, ref_file, pert_file, max_height=24000)
            comparison.extract_wrf_data(timeidx=i)
            comparison.extract_jexpresso_data()
            comp_data = comparison.get_comparison_data()
        except Exception as exc:
            # Best effort: one unreadable output must not abort the whole
            # series, but the failure is reported instead of silently dropped.
            print(f"Skipping output {i} ({pert_file}): {exc}")
            continue

        wrf_eke = comp_data['wrf_eke']
        jex_eke = comp_data['jex_eke']
        if len(wrf_eke) == 0 or np.all(np.isnan(wrf_eke)) or np.all(np.isnan(jex_eke)):
            continue

        wrf_total_eke = np.nanmean(wrf_eke)
        jex_total_eke = np.nanmean(jex_eke)

        # Treat implausibly large domain-mean EKE as a failed/blown-up output.
        if wrf_total_eke > 500 or jex_total_eke > 500:
            continue

        time_data['wrf'].append(wrf_total_eke)
        time_data['jex'].append(jex_total_eke)
        time_data['times'].append(i * output_interval_minutes)

    print(f"Successfully processed {len(time_data['wrf'])} time steps")
    return time_data if time_data['wrf'] else None


def _eke_ylim(values, floor=1e-7, default_max=1e2):
    """Return shared log-axis limits ``(y_min, y_max)`` for EKE values.

    Positive values are padded by x0.3 below (never under ``floor``) and x3
    above. With no positive values, ``(floor, default_max)`` is returned.
    """
    positive = [v for v in values if v > 0]
    if not positive:
        return floor, default_max
    return max(floor, min(positive) * 0.3), max(positive) * 3.0


def plot_comprehensive_grid(data_dict, output_dir):
    """Plot a 3x2 grid of mean-EKE time series, WRF vs JExpresso, and save it as PNG.

    Rows are orders '4', '5', '6'; columns are microphysics 'Kessler' and
    'Cold'. Each panel overlays, for every resolution present, the WRF series
    (blue shades) and the JExpresso series (orange shades with markers) on a
    log y-axis whose limits are shared by all panels. A single legend below
    the grid lists every series that was drawn in any panel.

    Args:
        data_dict: Mapping ``"Order_{order}_{resolution}_{micro}"`` -> dict with
            lists ``'times'`` (minutes), ``'wrf'`` and ``'jex'`` (mean EKE).
            Missing or ``None`` entries are skipped. A series whose length
            differs from ``'times'`` is skipped with a warning instead of
            crashing the plot.
        output_dir: Directory receiving ``Comprehensive_Grid_Publication_Quality.png``.
    """
    resolutions = ['4200m', '1200m', '800m', '600m', '400m', '200m']
    orders = ['4', '5', '6']
    microphysics_types = ['Kessler', 'Cold']

    # Colour gradients: lighter (coarse) -> darker (fine) for both models.
    wrf_colors = {
        '4200m': '#A8C5E8',
        '1200m': '#7AAAD9',
        '800m': '#5290CA',
        '600m': '#2E5CB8',
        '400m': '#1E4A9A',
        '200m': '#0F2F5C',
    }
    jex_colors = {
        '4200m': '#FFD9B8',
        '1200m': '#FFBC85',
        '800m': '#FFA55C',
        '600m': '#FF8C42',
        '400m': '#D97330',
        '200m': '#A65620',
    }
    # Distinct dash patterns from coarse to fine; the finest resolution is solid.
    linestyles = {
        '4200m': (0, (1, 1)),
        '1200m': (0, (5, 5)),
        '800m': (0, (5, 2)),
        '600m': (0, (3, 1, 1, 1)),
        '400m': (0, (5, 1, 1, 1, 1, 1)),
        '200m': '-',
    }
    # Thicker lines and larger markers for finer resolutions.
    linewidths = {'4200m': 1.8, '1200m': 2.0, '800m': 2.3, '600m': 2.6, '400m': 2.9, '200m': 3.3}
    markers = {'4200m': 'x', '1200m': 's', '800m': 'D', '600m': 'o', '400m': '^', '200m': 'v'}
    markersizes = {'4200m': 6, '1200m': 7, '800m': 8, '600m': 9, '400m': 10, '200m': 11}

    def lookup(order, micro, resolution):
        return data_dict.get(f"Order_{order}_{resolution}_{micro}")

    # First pass: global y-limits so every panel shares the same axis.
    all_values = []
    for order in orders:
        for micro in microphysics_types:
            for resolution in resolutions:
                data = lookup(order, micro, resolution)
                if data is not None:
                    all_values.extend(data['wrf'])
                    all_values.extend(data['jex'])
    y_min, y_max = _eke_ylim(all_values)

    # Scoped style: avoids leaking these settings into later plots in the session.
    style = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans'],
        'font.size': 10,
        'axes.linewidth': 1.2,
        'xtick.major.width': 1.2,
        'ytick.major.width': 1.2,
    }

    legend_entries = {}  # label -> first line drawn with that label (any panel)

    with plt.rc_context(style):
        # Tall figure leaves room for the legend below the grid.
        fig, axes = plt.subplots(3, 2, figsize=(20, 17))

        for row_idx, order in enumerate(orders):
            for col_idx, micro in enumerate(microphysics_types):
                ax = axes[row_idx, col_idx]
                ax.set_facecolor('#FAFAFA')

                for resolution in resolutions:
                    data = lookup(order, micro, resolution)
                    if data is None:
                        continue

                    times = data['times']
                    # Clamp non-positive values to the axis floor so they stay visible on a log scale.
                    wrf_plot = [max(val, y_min) for val in data['wrf']]
                    jex_plot = [max(val, y_min) for val in data['jex']]

                    # WRF is the same baseline for every order at a given resolution/scheme.
                    if len(wrf_plot) == len(times):
                        (wrf_line,) = ax.semilogy(
                            times, wrf_plot,
                            color=wrf_colors[resolution],
                            linestyle=linestyles[resolution],
                            linewidth=linewidths[resolution],
                            label=f'WRF {resolution}',
                            alpha=0.95,
                            zorder=5)
                        legend_entries.setdefault(f'WRF {resolution}', wrf_line)
                    else:
                        print(f"Skipping WRF {resolution} in Order {order} {micro}: "
                              f"{len(wrf_plot)} values for {len(times)} times")

                    if len(jex_plot) == len(times):
                        (jex_line,) = ax.semilogy(
                            times, jex_plot,
                            color=jex_colors[resolution],
                            linestyle=linestyles[resolution],
                            marker=markers[resolution],
                            markersize=markersizes[resolution],
                            linewidth=linewidths[resolution],
                            label=f'JEx {resolution}',
                            alpha=0.90,
                            markevery=max(1, len(times) // 10),
                            markeredgewidth=1.2,
                            markeredgecolor='white',
                            zorder=4)
                        legend_entries.setdefault(f'JEx {resolution}', jex_line)
                    else:
                        print(f"Skipping JEx {resolution} in Order {order} {micro}: "
                              f"{len(jex_plot)} values for {len(times)} times")

                ax.set_ylim(y_min, y_max)
                ax.set_xlim(0, 500)

                if row_idx == 2:
                    ax.set_xlabel('Time (minutes)', fontsize=13, fontweight='600')
                if col_idx == 0:
                    ax.set_ylabel('Mean EKE (J/m³)', fontsize=13, fontweight='600')

                ax.set_title(f'Order {order} - {micro}', fontsize=14, fontweight='700', pad=12)

                ax.grid(True, alpha=0.15, color='gray', linewidth=0.5, linestyle=':', which='minor')
                ax.grid(True, alpha=0.35, color='gray', linewidth=0.8, linestyle='-', which='major')

                for spine in ax.spines.values():
                    spine.set_color('#333333')
                    spine.set_linewidth(1.3)

                ax.tick_params(axis='both', which='major', labelsize=10, width=1.2, length=6)
                ax.tick_params(axis='both', which='minor', width=0.8, length=3)

        fig.suptitle('Comprehensive EKE Comparison: WRF vs Jexpresso (Orders 4, 5, 6)\nAll Resolutions and Microphysics Schemes',
                     fontsize=18, fontweight='bold', y=0.995)

        # Legend in canonical order (coarse -> fine, WRF before JEx), covering
        # every series drawn in any panel, not just the first one.
        legend_labels = [f'{source} {res}' for res in resolutions for source in ('WRF', 'JEx')
                         if f'{source} {res}' in legend_entries]
        legend_handles = [legend_entries[label] for label in legend_labels]
        fig.legend(legend_handles, legend_labels,
                   loc='lower center',
                   bbox_to_anchor=(0.5, -0.02),
                   ncol=6,
                   fontsize=10,
                   frameon=True,
                   fancybox=True,
                   framealpha=0.97,
                   edgecolor='#CCCCCC',
                   borderpad=0.8)

        fig.tight_layout(rect=[0, 0.03, 1, 0.99])

        filename = "Comprehensive_Grid_Publication_Quality.png"
        filepath = os.path.join(output_dir, filename)
        fig.savefig(filepath, dpi=400, bbox_inches='tight', facecolor='white', edgecolor='none')

    plt.close(fig)
    print(f"Saved publication-quality grid: {filename}")


def main():
    
    # Create output directory structure - only main directory needed
    output_base = "/Users/olayemiadeyemi/Documents/sponge_on"
    os.makedirs(output_base, exist_ok=True)
    
    print("=" * 80)
    print("Starting EKE Comprehensive Analysis - Publication Quality")
    print("=" * 80)
    
    # Define all file paths
    # JExpresso Reference files for each order
    jexpresso_reference = {
        '4': {
            '4200m': "/Users/olayemiadeyemi/Documents/New/reference_state/coarse_4",
            '1200m': "/Users/olayemiadeyemi/Documents/New/reference_state/medium_4",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_800_4/output/iter_0.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_600_4/output/iter_0.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_400_4/output/iter_0.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/New/reference_state/fine_4"
        },
        '5': {
            '4200m': "/Users/olayemiadeyemi/Documents/New/reference_state/coarse_5",
            '1200m': "/Users/olayemiadeyemi/Documents/New/reference_state/medium_5",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_800_5/output/iter_0.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_600_5/output/iter_0.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_400_5/output/iter_0.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/New/reference_state/fine_5"
        },
        '6': {
            '4200m': "/Users/olayemiadeyemi/Documents/New/reference_state/coarse_6",
            '1200m': "/Users/olayemiadeyemi/Documents/New/reference_state/medium_6",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_800_6/output/iter_0.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_600_6/output/iter_0.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/ref_400_6/output/iter_0.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/New/reference_state/fine_6"
        }
    }
    
    # JExpresso Kessler paths
    jexpresso_kessler = {
        '4': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_k_4/output/iter_{}.pvtu",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_k_4/output/iter_{}.pvtu",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_800_4_c5/output/iter_{}.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_600_4_c5/output/iter_{}.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_400_4_c5/output/iter_{}.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/f_150_k_4/output/iter_{}.pvtu"
        },
        '5': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_k_5/output/iter_{}.pvtu",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_k_5/output/iter_{}.pvtu",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_800_5_c5/output/iter_{}.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_600_5_c5/output/iter_{}.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_400_5_c5/output/iter_{}.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_200_5_c5/output/iter_{}.pvtu"
        },
        '6': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_k_6/output/iter_{}.pvtu",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_k_6/output/iter_{}.pvtu",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_800_6_c5/output/iter_{}.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_600_6_c5/output/iter_{}.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_400_6_c5/output/iter_{}.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_K_200_6_c5/output/iter_{}.pvtu"
        }
    }
    
    # JExpresso Cold paths
    jexpresso_cold = {
        '4': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_sam_4/output/iter_{}.pvtu",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_sam_4/output/iter_{}.pvtu",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_800_4_c5/output/iter_{}.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_600_4_c5/output/iter_{}.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_400_4_c5/output/iter_{}.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/f_150_sam_4/output/iter_{}.pvtu"
        },
        '5': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_sam_5/output/iter_{}.pvtu",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_sam_5/output/iter_{}.pvtu",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_800_5_c5/output/iter_{}.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_600_5_c5/output/iter_{}.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_400_5_c5/output/iter_{}.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_200_5_c5/output/iter_{}.pvtu"
        },
        '6': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/c_150_sam_6/output/iter_{}.pvtu",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/m_150_sam_6/output/iter_{}.pvtu",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_800_6_c5/output/iter_{}.pvtu",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_600_6_c5/output/iter_{}.pvtu",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_400_6_c5/output/iter_{}.pvtu",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/fine_S_200_6_c5/output/iter_{}.pvtu"
        }
    }
    
    # WRF Kessler paths
    wrf_kessler = {
        '4': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6104-01-01_00_00_00",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6105-01-01_00_00_00",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6020-01-01_00_00_00",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6022-01-01_00_00_00",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6024-01-01_00_00_00",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6106-01-01_00_00_00"
        },
        '5': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6104-01-01_00_00_00",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6105-01-01_00_00_00",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6020-01-01_00_00_00",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6022-01-01_00_00_00",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6024-01-01_00_00_00",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6106-01-01_00_00_00"
        },
        '6': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6104-01-01_00_00_00",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6105-01-01_00_00_00",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6020-01-01_00_00_00",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6022-01-01_00_00_00",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6024-01-01_00_00_00",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6106-01-01_00_00_00"
        }
    }
    
    # WRF Cold paths
    wrf_cold = {
        '4': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6107-01-01_00_00_00",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6108-01-01_00_00_00",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6021-01-01_00_00_00",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6023-01-01_00_00_00",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6025-01-01_00_00_00",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6109-01-01_00_00_00"
        },
        '5': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6107-01-01_00_00_00",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6108-01-01_00_00_00",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6021-01-01_00_00_00",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6023-01-01_00_00_00",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6025-01-01_00_00_00",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6109-01-01_00_00_00"
        },
        '6': {
            '4200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6107-01-01_00_00_00",
            '1200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6108-01-01_00_00_00",
            '800m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6021-01-01_00_00_00",
            '600m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6023-01-01_00_00_00",
            '400m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6025-01-01_00_00_00",
            '200m': "/Users/olayemiadeyemi/Documents/sponge_on/wrfout_d01_6109-01-01_00_00_00"
        }
    }
    
    # Process all configurations
    all_data = {}
    wrf_cache = {}  # Cache WRF data to avoid redundant processing
    
    orders = ['4', '5', '6']
    resolutions = ['4200m', '1200m', '800m', '600m', '400m', '200m']
    microphysics_types = ['Kessler', 'Cold']
    
    print("\nProcessing configurations...")
    print("-" * 80)
    
    for resolution in resolutions:
        for micro in microphysics_types:
            # Process WRF once per resolution/microphysics combination
            wrf_cache_key = f"{resolution}_{micro}"
            
            if micro == 'Kessler':
                wrf_file = wrf_kessler['4'][resolution]  # Use order 4 as reference
            else:
                wrf_file = wrf_cold['4'][resolution]  # Use order 4 as reference
            
            # Process WRF data once and cache it
            if Path(wrf_file).exists():
                print(f"\nProcessing WRF data for {resolution} {micro}...")
                try:
                    # Create a dummy comparison object to extract WRF data
                    ref_file_dummy = jexpresso_reference['4'][resolution]
                    if not ref_file_dummy.endswith('.pvtu'):
                        ref_file_dummy = os.path.join(ref_file_dummy, 'output', 'iter_0.pvtu')
                    
                    comparison_wrf = EKEComparison(wrf_file, ref_file_dummy, "", max_height=24000)
                    
                    # Extract WRF time series
                    wrf_time_data = []
                    wrf_times = []
                    max_iterations = 50
                    
                    for i in range(max_iterations):
                        try:
                            comparison_wrf.extract_wrf_data(timeidx=i)
                            wrf_eke = np.nanmean(comparison_wrf.wrf_data["EKE (J/m3)"].values)
                            if wrf_eke <= 500 and not np.isnan(wrf_eke):
                                wrf_time_data.append(wrf_eke)
                                wrf_times.append(i * 10)
                        except:
                            break
                    
                    if len(wrf_time_data) > 0:
                        wrf_cache[wrf_cache_key] = {
                            'wrf_dataframe': comparison_wrf.wrf_data,
                            'wrf_time_series': wrf_time_data,
                            'times': wrf_times
                        }
                        print(f"✓ Cached WRF data for {resolution} {micro} ({len(wrf_time_data)} time steps)")
                    else:
                        print(f"✗ No valid WRF data for {resolution} {micro}")
                except Exception as e:
                    print(f"✗ Failed to process WRF for {resolution} {micro}: {e}")
            
            # Now process Jexpresso for each order
            for order in orders:
                print(f"\nProcessing Jexpresso: Order {order}, Resolution {resolution}, {micro}")
                
                if micro == 'Kessler':
                    jex_pert = jexpresso_kessler[order][resolution]
                else:
                    jex_pert = jexpresso_cold[order][resolution]
                
                ref_file = jexpresso_reference[order][resolution]
                
                # Handle directory-based reference files
                if not ref_file.endswith('.pvtu'):
                    ref_file = os.path.join(ref_file, 'output', 'iter_0.pvtu')
                
                # Check if we have cached WRF data
                if wrf_cache_key not in wrf_cache:
                    print(f"✗ No WRF cache available for {resolution} {micro}")
                    continue
                
                # Process Jexpresso
                try:
                    time_data = {'wrf': wrf_cache[wrf_cache_key]['wrf_time_series'].copy(),
                               'jex': [],
                               'times': wrf_cache[wrf_cache_key]['times'].copy()}
                    
                    max_iterations = 50
                    existing_pert_files = [i for i in range(max_iterations) if Path(jex_pert.format(i)).exists()]
                    
                    if not existing_pert_files:
                        print(f"✗ No perturbation files found")
                        continue
                    
                    for i in existing_pert_files:
                        if i * 10 not in time_data['times']:
                            continue
                        try:
                            comparison_jex = EKEComparison(wrf_file, ref_file, jex_pert.format(i), max_height=24000)
                            comparison_jex.wrf_data = wrf_cache[wrf_cache_key]['wrf_dataframe']
                            comparison_jex.extract_jexpresso_data()
                            comp_data = comparison_jex.get_comparison_data()
                            
                            if len(comp_data['jex_eke']) == 0 or np.all(np.isnan(comp_data['jex_eke'])):
                                continue
                            
                            jex_total_eke = np.nanmean(comp_data['jex_eke'])
                            
                            if jex_total_eke > 500 or jex_total_eke < 0 or not np.isfinite(jex_total_eke):
                                continue
                            
                            time_data['jex'].append(jex_total_eke)
                        except:
                            continue
                    
                    # Trim to match lengths
                    min_len = min(len(time_data['wrf']), len(time_data['jex']))
                    time_data['wrf'] = time_data['wrf'][:min_len]
                    time_data['jex'] = time_data['jex'][:min_len]
                    time_data['times'] = time_data['times'][:min_len]
                    
                    if len(time_data['jex']) > 0:
                        key = f"Order_{order}_{resolution}_{micro}"
                        all_data[key] = time_data
                        print(f"✓ Successfully processed {key} ({len(time_data['jex'])} time steps)")
                    else:
                        print(f"✗ No valid Jexpresso data for Order {order}, {resolution}, {micro}")
                except Exception as e:
                    print(f"✗ Failed to process Order {order}, {resolution}, {micro}: {e}")
    
    print("\n" + "=" * 80)
    print(f"Data processing complete. Successfully processed {len(all_data)} configurations.")
    print("=" * 80)
    
    # Generate publication-quality comprehensive grid plot
    print("\nGenerating publication-quality comprehensive grid plot...")
    print("-" * 80)
    plot_comprehensive_grid(all_data, output_base)
    
    print("\n" + "=" * 80)
    print("Publication-quality plot generated successfully!")
    print(f"Output directory: {output_base}")
    print("=" * 80)
    print(f"  - Publication-quality grid: {output_base}/Comprehensive_Grid_Publication_Quality.png")
    print("=" * 80)


if __name__ == "__main__":
    # Entry point: run the full processing + plotting pipeline defined in main()
    # only when this file is executed directly, not when it is imported.
    # NOTE(review): in a notebook cell __name__ is typically "__main__" as well,
    # so running this cell will also execute main() — confirm that is intended.
    main()