In [None]:

# SUPERIOR CODE

import os
import numpy as np
import mne
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import pandas as pd
from scipy.stats import ttest_1samp, pearsonr

# ------------------------ CONFIG ------------------------
# Input/output locations, relative to the working directory.
preproc_dir, fig_dir = "preprocessed", "figures"
os.makedirs(fig_dir, exist_ok=True)  # make sure figures can be saved later on

# Analysis parameters
n_sub_max = 21            # cap on the number of (sorted) subjects analysed
theta_band = (4, 8)       # Theta band in Hz
baseline = (-0.2, 0.0)    # Baseline period in seconds

# A priori time windows (defined by researcher), in seconds.
time_windows = dict([
    ("Early (100-200 ms)", (0.10, 0.20)),
    ("Mid (200-300 ms)", (0.20, 0.30)),
    ("Late (300-400 ms)", (0.30, 0.40)),
])

# Posterior channels for occipito-temporal analysis, given by channel number.
posterior_channels = [
    f"MEG{number:02d}"
    for number in (
        2, 29, 11, 47, 62, 15, 13, 10, 14,
        25, 48, 56, 61, 64, 52, 59, 12, 26,
        49, 50, 39, 54, 23, 28,
    )
]

# Publication style: reset to matplotlib's defaults, then apply the overrides.
plt.style.use('default')
_publication_rc = {
    # Typography
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 11,
    "axes.titleweight": "bold",
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 8,
    # Strokes and markers
    "axes.linewidth": 0.8,
    "lines.linewidth": 1.5,
    "lines.markersize": 4,
    # Legend frame
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    # Resolution and cropping of displayed/saved figures
    "figure.dpi": 600,
    "savefig.dpi": 600,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.02,
    # Type 42 (TrueType) font embedding in PDF/PS output
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
mpl.rcParams.update(_publication_rc)

# Color palette shared by the figures
COLORS = dict(
    emotional='#E74C3C',     # red
    neutral='#3498DB',       # blue
    difference='#2C3E50',    # dark blue/black
    early='#B2182B',         # dark red – 100–200 ms window
    mid='#EF8A62',           # muted red / rose – 200–300 ms window
    late='#4A6FE3',          # steel blue – 300–400 ms window
    positive='#E74C3C',      # red for positive effects
    negative='#3498DB',      # blue for negative effects
    significance='#27AE60',  # green for significance
)

# ------------------------ DATA LOADING AND PROCESSING ------------------------

def load_subject_epochs(subject_id):
    """Read one subject's emotional and neutral expression epochs.

    Parameters
    ----------
    subject_id : str
        Subject identifier without the ``sub-`` prefix.

    Returns
    -------
    tuple
        ``(emotional_epochs, neutral_epochs, info)``, where ``info`` is the
        measurement info of the emotional epochs. If either file is missing,
        ``(None, None, None)`` is returned after printing a skip message.
    """
    subject_tag = f"sub-{subject_id}"
    expression_dir = os.path.join(preproc_dir, subject_tag, "dimensions", "expression")
    emo_file, neu_file = (
        os.path.join(expression_dir,
                     f"{subject_tag}_ses-01_run-01_expression_{condition}-epo.fif")
        for condition in ("emotional", "neutral")
    )

    if not os.path.exists(emo_file) or not os.path.exists(neu_file):
        print(f"  ╰─ Skipping sub-{subject_id} (missing files)")
        return None, None, None

    emotional_epochs, neutral_epochs = (
        mne.read_epochs(path, preload=True, verbose=False)
        for path in (emo_file, neu_file)
    )
    return emotional_epochs, neutral_epochs, emotional_epochs.info

def compute_theta_power(epochs):
    """Compute single-trial, baseline-normalised theta power in dB.

    Morlet-wavelet power is computed at every integer frequency from
    ``theta_band[0]`` to ``theta_band[1]`` (both included) and averaged over
    those frequencies. Each epoch/channel is then expressed in dB relative to
    its own mean power in the ``baseline`` window.

    Parameters
    ----------
    epochs : mne.Epochs
        Preloaded epochs.

    Returns
    -------
    power_db : np.ndarray
        Shape (n_epochs, n_channels, n_times).
    times : np.ndarray
        ``epochs.times`` (seconds).

    Raises
    ------
    ValueError
        If no sample of ``epochs.times`` falls inside ``baseline``.
    """
    from mne.time_frequency import tfr_array_morlet

    data = epochs.get_data()  # (n_epochs, n_channels, n_times)
    sfreq = epochs.info['sfreq']
    times = epochs.times

    # Check the baseline before the expensive wavelet transform; an empty
    # window would otherwise silently turn every value into NaN.
    baseline_idx = (times >= baseline[0]) & (times <= baseline[1])
    if not np.any(baseline_idx):
        raise ValueError(
            f"No samples fall in the baseline window {baseline[0]}-{baseline[1]} s "
            f"(epochs span {times[0]:.3f}-{times[-1]:.3f} s)"
        )

    freqs = np.arange(theta_band[0], theta_band[1] + 1)
    # n_cycles proportional to frequency gives every wavelet the same temporal
    # extent (n_cycles / f = 1/3 s) across the band.
    n_cycles = freqs / 3.0

    # output='power' -> shape (n_epochs, n_channels, n_freqs, n_times)
    power = tfr_array_morlet(data, sfreq=sfreq, freqs=freqs,
                             n_cycles=n_cycles, output='power',
                             zero_mean=True, verbose=False)

    # Average across the theta frequencies (axis 2), NOT channels (axis 1).
    power = power.mean(axis=2)  # (n_epochs, n_channels, n_times)

    # Per-epoch, per-channel baseline; guard against log/division by zero.
    baseline_power = power[:, :, baseline_idx].mean(axis=2, keepdims=True)
    baseline_power[baseline_power == 0] = 1e-10
    power_db = 10 * np.log10(power / baseline_power)

    return power_db, times

def _subject_sort_key(subject_id):
    """Sort key ordering all-digit IDs numerically ('2' before '10'); others after, by text."""
    if subject_id.isdigit():
        return (0, int(subject_id), subject_id)
    return (1, 0, subject_id)


def _find_common_posterior_channels(subjects):
    """First pass: intersect the posterior channels available in every subject.

    Returns ``(channels, info)``: the sorted list of common channel names
    (empty if none), and the info of the first subject whose files exist
    (``None`` if no subject could be loaded).
    """
    common = None
    sample_info = None

    print("\nFirst pass: Finding common channels...")
    for i, sub in enumerate(subjects, 1):
        emo, _, info = load_subject_epochs(sub)
        if emo is None:
            continue

        available = {ch for ch in emo.ch_names if ch in posterior_channels}
        if common is None:
            common = available
            sample_info = info
        else:
            common &= available

        print(f"  ├─ Subject {i:2d}/{len(subjects)}: sub-{sub} | Common: {len(common)}")

    return sorted(common or ()), sample_info


def collect_subject_differences():
    """Collect per-subject emotional - neutral theta power differences.

    Subjects are the ``sub-*`` entries of ``preproc_dir`` (at most
    ``n_sub_max`` after sorting). Only posterior channels present in every
    loaded subject are kept. Trial counts are matched per subject by keeping
    the first ``min(n_emotional, n_neutral)`` epochs of each condition.

    Returns
    -------
    tuple
        ``(data, times, channels, info, subject_ids)`` where ``data`` has shape
        (n_subjects, n_channels, n_times) in dB. ``info`` is the unpicked info
        of the first loaded subject, so it may contain more channels than
        ``channels``. On failure, all five elements are ``None``.

    Raises
    ------
    ValueError
        If subjects' time vectors differ and therefore cannot be stacked.
    """
    subjects = sorted(
        (d.replace("sub-", "") for d in os.listdir(preproc_dir) if d.startswith("sub-")),
        key=_subject_sort_key,
    )[:n_sub_max]

    print("\n" + "="*70)
    print("PROCESSING SUBJECTS".center(70))
    print("="*70)

    common_channels, sample_info = _find_common_posterior_channels(subjects)
    if not common_channels:
        print("❌ No common channels found!")
        return None, None, None, None, None
    print(f"\n✓ Found {len(common_channels)} common channels")

    differences = []
    subject_ids = []
    all_times = None

    print("\nSecond pass: Computing theta power differences...")
    for i, sub in enumerate(subjects, 1):
        emo, neu, _ = load_subject_epochs(sub)
        if emo is None:
            continue

        # Enforce one explicit channel order so that row k is the same sensor
        # for every subject before the per-subject arrays are stacked.
        for epochs in (emo, neu):
            epochs.pick(common_channels)
            epochs.reorder_channels(common_channels)

        if all_times is None:
            all_times = emo.times
        elif emo.times.shape != all_times.shape or not np.allclose(emo.times, all_times):
            raise ValueError(
                f"sub-{sub}: time vector differs from the first processed subject; "
                "cannot stack subjects"
            )

        # Match number of trials (first n of each condition)
        n_trials = min(len(emo), len(neu))
        emo_power, _ = compute_theta_power(emo[:n_trials])
        neu_power, _ = compute_theta_power(neu[:n_trials])

        # Average across trials, then take emotional - neutral: (channels, times)
        diff = emo_power.mean(axis=0) - neu_power.mean(axis=0)

        differences.append(diff)
        subject_ids.append(sub)
        print(f"  ├─ Processed {i:2d}/{len(subjects)}: sub-{sub} | Channels: {diff.shape[0]}, Times: {diff.shape[1]}")

    if not differences:
        print("❌ No valid data found!")
        return None, None, None, None, None

    data = np.array(differences)  # (n_subjects, n_channels, n_times)

    print(f"\n✓ Successfully processed {len(differences)} subjects")
    print(f"✓ Final data shape: {data.shape}")
    print(f"✓ Time range: {all_times[0]*1000:.0f} to {all_times[-1]*1000:.0f} ms")

    return data, all_times, common_channels, sample_info, subject_ids

# ------------------------ STATISTICAL ANALYSES ------------------------

def run_planned_comparisons(data, times, channels, windows=None):
    """Test emotional > neutral theta power in each a priori time window.

    ``data`` holds per-subject emotional - neutral differences, so a one-sample
    t-test of each subject's window mean against 0 is the paired t-test.

    Parameters
    ----------
    data : np.ndarray
        (n_subjects, n_channels, n_times) difference scores in dB.
    times : np.ndarray
        Time vector in seconds for the last axis of ``data``.
    channels : list of str
        Channel names; not used in the computation (kept for API compatibility).
    windows : dict, optional
        ``{name: (tmin, tmax)}`` in seconds, both edges inclusive. Defaults to
        the module-level ``time_windows``.

    Returns
    -------
    dict
        ``{name: {'t_statistic', 'p_value', 'cohens_d', 'mean_difference',
        'std_difference', 'time_window', 'n_subjects'}}``. The t-test is
        one-tailed (greater), and ``cohens_d`` is mean / SD of the
        per-subject differences (d_z).

    Raises
    ------
    ValueError
        If a window contains no samples of ``times``.
    """
    if windows is None:
        windows = time_windows

    print("\n" + "="*70)
    print("PLANNED COMPARISONS IN A PRIORI TIME WINDOWS".center(70))
    print("="*70)

    results = {}

    for window_name, (tmin, tmax) in windows.items():
        time_indices = (times >= tmin) & (times <= tmax)
        if not np.any(time_indices):
            raise ValueError(
                f"Time window {window_name!r} ({tmin}-{tmax} s) contains no samples"
            )

        # One value per subject: average over posterior channels and window time
        window_data = np.mean(data[:, :, time_indices], axis=(1, 2))

        # One-tailed paired t-test (emotional > neutral)
        t_stat, p_value = stats.ttest_1samp(window_data, 0, alternative='greater')

        mean_diff = np.mean(window_data)
        std_diff = np.std(window_data, ddof=1)
        cohens_d = mean_diff / std_diff

        results[window_name] = {
            't_statistic': t_stat,
            'p_value': p_value,
            'cohens_d': cohens_d,
            'mean_difference': mean_diff,
            'std_difference': std_diff,
            'time_window': (tmin, tmax),
            'n_subjects': len(window_data)
        }

        print(f"\n{window_name}:")
        print(f"  Mean difference: {mean_diff:.6f} dB")
        print(f"  t({len(window_data)-1}) = {t_stat:.3f}, p = {p_value:.3f}")
        print(f"  Cohen's d = {cohens_d:.3f}")
        print(f"  Time range: {tmin*1000:.0f}-{tmax*1000:.0f} ms")

    return results

def run_cluster_permutation_test(data, times, n_permutations=1000, threshold=2.0,
                                 adjacency=None, seed=None):
    """Run a one-tailed cluster-based permutation test (Maris & Oostenveld, 2007).

    ``data`` already holds emotional - neutral differences, so MNE's
    one-sample cluster test against 0 implements the paired design. The test
    is restricted to 100-400 ms.

    Parameters
    ----------
    data : np.ndarray
        (n_subjects, n_channels, n_times) difference scores.
    times : np.ndarray
        Time vector in seconds for the last axis of ``data``.
    n_permutations : int
        Number of sign-flip permutations.
    threshold : float
        Cluster-forming t threshold (default 2.0).
    adjacency : scipy.sparse matrix | None
        Forwarded to ``mne.stats.permutation_cluster_1samp_test``. With
        ``None``, MNE assumes a regular lattice, so channels that are
        neighbours in *index order* are treated as adjacent. Pass a sensor
        adjacency matrix for the analysed channels (e.g. derived from
        ``mne.channels.find_ch_adjacency``) for spatially meaningful clusters.
    seed : int | None
        Seed for the permutations; set it for reproducible cluster p-values.

    Returns
    -------
    dict
        Keys ``t_values``, ``clusters`` (boolean masks of shape
        (n_channels, n_window_times)), ``cluster_p_values``,
        ``significant_clusters``, ``n_clusters``, ``n_significant``,
        ``threshold``, ``n_permutations``, ``analysis_window``,
        ``analysis_times``.

    Raises
    ------
    ValueError
        If no sample of ``times`` falls in the analysis window.
    """
    print("\n" + "="*70)
    print("CLUSTER PERMUTATION TEST".center(70))
    print("="*70)

    analysis_window = (0.10, 0.40)
    time_indices = (times >= analysis_window[0]) & (times <= analysis_window[1])
    if not np.any(time_indices):
        raise ValueError(
            f"No samples fall in the analysis window {analysis_window[0]}-{analysis_window[1]} s"
        )

    analysis_data = data[:, :, time_indices]
    analysis_times = times[time_indices]

    print(f"Analysis window: {analysis_window[0]*1000:.0f}-{analysis_window[1]*1000:.0f} ms")
    print(f"Time points: {analysis_data.shape[2]}")
    print(f"Channels: {analysis_data.shape[1]}")
    print(f"Subjects: {analysis_data.shape[0]}")
    print(f"Permutations: {n_permutations}")
    print("Test direction: One-tailed (emotional > neutral)")

    # One-sample test against 0 because the inputs are already difference scores
    t_vals, clusters, cluster_p_vals, H0 = mne.stats.permutation_cluster_1samp_test(
        analysis_data,
        threshold=threshold,
        n_permutations=n_permutations,
        tail=1,  # One-tailed (greater)
        adjacency=adjacency,
        seed=seed,
        out_type='mask',
        verbose=True
    )

    n_clusters = len(clusters)
    significant_clusters = []

    print("\n✓ Cluster test completed")
    print(f"  Total clusters found: {n_clusters}")

    for i, (cluster_mask, p_val) in enumerate(zip(clusters, cluster_p_vals)):
        # Each mask is (n_channels, n_times): collapse either axis to get extent
        sensors_involved = np.any(cluster_mask, axis=1)
        times_involved = np.any(cluster_mask, axis=0)

        cluster_times = analysis_times[times_involved]
        time_range = (cluster_times[0]*1000, cluster_times[-1]*1000)

        cluster_info = {
            # cluster_id - 1 indexes back into the returned 'clusters' list
            'cluster_id': i + 1,
            'p_value': p_val,
            'sensors_involved': np.where(sensors_involved)[0],
            'n_sensors': np.sum(sensors_involved),
            'time_range_ms': time_range,
            'duration_ms': time_range[1] - time_range[0],
            'cluster_mass': np.sum(t_vals[cluster_mask]),
            'significant': p_val < 0.05
        }

        if cluster_info['significant']:
            significant_clusters.append(cluster_info)
            print(f"  ⭐ Significant cluster {i+1}: p = {p_val:.4f}")
        else:
            print(f"  Cluster {i+1}: p = {p_val:.4f}")

    print(f"\n  Significant clusters (p < 0.05): {len(significant_clusters)}")

    return {
        't_values': t_vals,
        'clusters': clusters,
        'cluster_p_values': cluster_p_vals,
        'significant_clusters': significant_clusters,
        'n_clusters': n_clusters,
        'n_significant': len(significant_clusters),
        'threshold': threshold,
        'n_permutations': n_permutations,
        'analysis_window': analysis_window,
        'analysis_times': analysis_times
    }

# ------------------------ VISUALIZATION FUNCTIONS ------------------------

def create_figure_grand_average_old_style(data, times, channels, cluster_results=None):
    """Plot the grand-average emotional - neutral theta time course (old style).

    The main panel shows the subject- and channel-averaged difference between
    100 and 400 ms with ±1 SEM (across subjects) shading and the a priori time
    windows as coloured bands. Significant clusters from ``cluster_results``
    are drawn as green bars along the bottom edge. A sidebar holds the legend
    and the analysis parameters. The figure is saved as PNG and PDF in
    ``fig_dir`` and then shown.

    Parameters
    ----------
    data : np.ndarray
        (n_subjects, n_channels, n_times) difference scores in dB.
    times : np.ndarray
        Time vector in seconds for the last axis of ``data``.
    channels : list of str
        Channel names (only their count is reported).
    cluster_results : dict, optional
        Output of ``run_cluster_permutation_test``. Its
        ``'significant_clusters'`` and ``'threshold'`` entries are used.
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    grand_avg = np.mean(np.mean(data, axis=0), axis=0)  # over subjects and channels
    sem = stats.sem(np.mean(data, axis=1), axis=0)      # SEM across subjects
    times_ms = times * 1000

    fig = plt.figure(figsize=(13, 8))
    # Axes are placed by hand ([left, bottom, width, height]). tight_layout cannot
    # manage such axes (it only warns), so margins are trimmed by
    # bbox_inches='tight' when saving.
    ax = fig.add_axes([0.1, 0.15, 0.55, 0.75])

    time_mask = (times_ms >= 100) & (times_ms <= 400)
    times_limited = times_ms[time_mask]
    data_limited = grand_avg[time_mask]
    sem_limited = sem[time_mask]

    ax.plot(times_limited, data_limited, 'k-', linewidth=2.5,
            label='Grand average (emotional - neutral)')
    ax.fill_between(times_limited, data_limited - sem_limited,
                    data_limited + sem_limited,
                    alpha=0.3, color='gray', label='±1 SEM')
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    window_colors = {
        'Early (100-200 ms)': '#B2182B',  # deep red
        'Mid (200-300 ms)': '#EF8A62',    # red
        'Late (300-400 ms)': '#4A6FE3'    # Blue
    }
    # x in data units (ms), y as a fraction of the axes height
    x_data_y_axes = ax.get_xaxis_transform()

    # Shade each window and label it with a box in the SAME colour as its band.
    # y = 0.9 of the axes equals 0.35 dB with the fixed y-limits set below.
    for window_name, (tmin, tmax) in time_windows.items():
        color = window_colors[window_name]
        ax.axvspan(tmin * 1000, tmax * 1000, alpha=0.2, color=color)
        ax.text((tmin + tmax) * 500, 0.9, window_name.replace(' (', '\n(', 1),
                transform=x_data_y_axes,
                ha='center', va='top', fontsize=10, fontweight='bold',
                bbox=dict(boxstyle='round', facecolor=color, alpha=0.3))

    significant = cluster_results['significant_clusters'] if cluster_results else []
    for cluster in significant:
        t_start, t_end = cluster['time_range_ms']
        if t_start >= 100 and t_end <= 400:
            # Thin bar at 1-3 % of the axes height (-0.095 to -0.085 dB with the
            # fixed y-limits), with its p-value just above it.
            ax.axvspan(t_start, t_end, ymin=0.01, ymax=0.03,
                       alpha=0.5, color='#27AE60')
            ax.text((t_start + t_end) / 2, 0.04, f'p={cluster["p_value"]:.3f}',
                    transform=x_data_y_axes,
                    ha='center', va='bottom', fontsize=9,
                    fontweight='bold', color='#27AE60')

    ax.set_xlabel('Time (milliseconds, ms)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Theta Power Difference (decibels, dB)\nEmotional - Neutral',
                  fontsize=12, fontweight='bold')
    ax.set_title('Grand Average Theta-Band (4-8 Hz) Time Course\n(100-400 ms post-stimulus)',
                 fontsize=13, fontweight='bold', pad=15)

    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_xlim(100, 400)
    # Fixed y-range as requested; values outside it are clipped.
    ax.set_ylim(-0.1, 0.4)
    ax.yaxis.set_minor_locator(plt.MultipleLocator(0.05))

    legend_elements = [
        Line2D([0], [0], color='black', linewidth=2.5, label='Grand average\n(emotional - neutral)'),
        Patch(facecolor='gray', alpha=0.3, label='±1 SEM'),
    ]
    if significant:
        legend_elements.append(
            Patch(facecolor='#27AE60', alpha=0.5, label="Significant cluster")
        )

    sidebar_ax = fig.add_axes([0.68, 0.15, 0.3, 0.75])
    sidebar_ax.axis('off')
    sidebar_ax.legend(handles=legend_elements, loc='upper left',
                      fontsize=10, framealpha=0.9, borderaxespad=0,
                      title='Legend:', title_fontsize=11)

    # Report the threshold actually used by the cluster test when available
    threshold = cluster_results.get('threshold', 2.0) if cluster_results else 2.0
    params_text = f"""Analysis Parameters:
• Theta band: {theta_band[0]}-{theta_band[1]} Hz
• Baseline: {baseline[0]*1000:.0f}-{baseline[1]*1000:.0f} ms
• Subjects: {data.shape[0]}
• Channels: {len(channels)} posterior sensors
• Time window: 100-400 ms
• Cluster threshold: t > {threshold}"""

    # Parameters box below the legend (legend occupies roughly the top third)
    sidebar_ax.text(0, 0.65, params_text,
                    fontsize=9,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8,
                              edgecolor='black', linewidth=0.5))

    fig.savefig(os.path.join(fig_dir, 'figure_grand_average_old_style.png'),
                dpi=600, bbox_inches='tight')
    fig.savefig(os.path.join(fig_dir, 'figure_grand_average_old_style.pdf'),
                bbox_inches='tight')
    plt.show()

    print(f"✓ Saved old-style figure: {fig_dir}/figure_grand_average_old_style.png")

def create_effect_size_figure_old_style(planned_results):
    """Bar chart of Cohen's d per a priori window with significance markers (old style).

    Markers: *** p<.001, ** p<.01, * p<.05, † p<.1, placed beyond the bar end
    (above positive bars, below negative ones). A sidebar shows an effect-size
    guide and the t/p values. The figure is saved as PNG and PDF in ``fig_dir``
    and then shown.

    Parameters
    ----------
    planned_results : dict
        Output of ``run_planned_comparisons``; must contain the three default
        window names. Nothing is drawn if it is empty or ``None``.
    """
    if not planned_results:
        return

    windows = ['Early (100-200 ms)', 'Mid (200-300 ms)', 'Late (300-400 ms)']
    cohens_d = [planned_results[w]['cohens_d'] for w in windows]
    p_values = [planned_results[w]['p_value'] for w in windows]
    colors = ['#B2182B', '#EF8A62', '#4A6FE3']  # old-style window colours

    fig = plt.figure(figsize=(10, 6))
    # Hand-placed axes: tight_layout cannot lay these out (it would only warn);
    # the savefig.bbox='tight' rcParam trims the margins when saving.
    ax = fig.add_axes([0.1, 0.15, 0.6, 0.75])  # narrower main plot

    x_pos = np.arange(len(windows))
    ax.bar(x_pos, cohens_d, color=colors, alpha=0.7,
           edgecolor='black', linewidth=1)

    # (upper p bound, marker, distance from bar end in d units); first match wins
    significance_levels = (
        (0.001, '***', 0.03),
        (0.01, '**', 0.025),
        (0.05, '*', 0.02),
        (0.1, '†', 0.015),
    )
    for i, (p, d) in enumerate(zip(p_values, cohens_d)):
        level = next(((symbol, offset) for bound, symbol, offset in significance_levels
                      if p < bound), None)
        if level is None:
            continue
        symbol, offset = level
        direction = 1 if d >= 0 else -1
        ax.text(i, d + direction * offset, symbol,
                ha='center', va='bottom' if d >= 0 else 'top',
                fontsize=12, fontweight='bold')

    ax.set_xticks(x_pos)
    ax.set_xticklabels(['Early\n(100-200)', 'Mid\n(200-300)', 'Late\n(300-400)'],
                       fontsize=11)
    ax.set_ylabel("Cohen's d (Effect Size)", fontsize=11, fontweight='bold')
    ax.set_title('Effect Sizes by Time Window\n(Paired t-tests)',
                 fontsize=12, fontweight='bold')
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax.grid(True, alpha=0.3, axis='y')

    sidebar_ax = fig.add_axes([0.75, 0.15, 0.2, 0.75])
    sidebar_ax.axis('off')

    sidebar_ax.text(0, 0.9, 'Effect size interpretation:\n• d = 0.2: Small\n• d = 0.5: Medium\n• d = 0.8: Large',
                    fontsize=9, verticalalignment='top', fontweight='bold',
                    transform=sidebar_ax.transAxes)

    stats_text = f"""Statistical Summary:
n = {planned_results[windows[0]]['n_subjects']} subjects"""
    for w in windows:
        stats_text += f"\n{w.split()[0]}: t={planned_results[w]['t_statistic']:.2f}, p={planned_results[w]['p_value']:.3f}"

    sidebar_ax.text(0, 0.5, stats_text,
                    fontsize=9,
                    verticalalignment='top', fontweight='bold',
                    transform=sidebar_ax.transAxes,
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    fig.savefig(os.path.join(fig_dir, 'figure_effect_sizes_old_style.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_effect_sizes_old_style.pdf'))
    plt.show()

    print(f"✓ Saved old-style effect size figure: {fig_dir}/figure_effect_sizes_old_style.png")

def create_composite_results_figure_old_style(data, times, channels, planned_results,
                                            cluster_results=None):
    """
    Create comprehensive composite figure for main results (OLD STYLE)
    Figure 5 for thesis: Combines time course, effect sizes, and topography

    Panels: A) grand-average difference time course (100-400 ms, ±1 SEM,
    a priori windows, significant clusters along the bottom edge);
    B) Cohen's d per window with significance markers; C-E) topography
    placeholders; bottom row) analysis parameters. Saved as PNG and PDF in
    ``fig_dir`` and then shown.

    Parameters
    ----------
    data : np.ndarray
        (n_subjects, n_channels, n_times) difference scores in dB.
    times : np.ndarray
        Time vector in seconds for the last axis of ``data``.
    channels : list of str
        Channel names (only their count is reported).
    planned_results : dict
        Output of ``run_planned_comparisons``; panel B stays empty if falsy.
    cluster_results : dict, optional
        Output of ``run_cluster_permutation_test``; its significant clusters
        are marked in panel A.
    """
    print("\nCreating composite results figure (old style)...")

    grand_avg = np.mean(np.mean(data, axis=0), axis=0)
    times_ms = times * 1000
    sem = stats.sem(np.mean(data, axis=1), axis=0)

    fig = plt.figure(figsize=(16, 12))
    gs = fig.add_gridspec(3, 4, height_ratios=[2, 1, 1], hspace=0.25, wspace=0.3)

    colors = {'Early (100-200 ms)': '#B2182B',
              'Mid (200-300 ms)': '#EF8A62',
              'Late (300-400 ms)': '#4A6FE3'}

    # ---------------- Panel A: time course (first row, 3 columns) ----------------
    ax1 = fig.add_subplot(gs[0, :3])
    # x in data units (ms), y as a fraction of the axes height: keeps labels and
    # cluster bars in place whatever the data's y-range (even if all negative).
    x_data_y_axes = ax1.get_xaxis_transform()

    time_mask = (times_ms >= 100) & (times_ms <= 400)
    times_limited = times_ms[time_mask]
    data_limited = grand_avg[time_mask]
    sem_limited = sem[time_mask]

    ax1.plot(times_limited, data_limited, 'k-', linewidth=2.5,
             label='Grand average (emotional - neutral)')
    ax1.fill_between(times_limited, data_limited - sem_limited, data_limited + sem_limited,
                     alpha=0.3, color='gray', label='±1 SEM')

    for window_name, color in colors.items():
        if window_name in time_windows:
            tmin, tmax = time_windows[window_name]
            ax1.axvspan(tmin*1000, tmax*1000, alpha=0.2, color=color)
            ax1.text((tmin + tmax) * 500, 0.97, window_name.replace(' (', '\n(', 1),
                     transform=x_data_y_axes,
                     ha='center', va='top', fontsize=10, fontweight='bold',
                     bbox=dict(boxstyle='round', facecolor=color, alpha=0.3))

    significant = cluster_results['significant_clusters'] if cluster_results else []
    for k, cluster in enumerate(significant):
        t_start, t_end = cluster['time_range_ms']
        ax1.axvspan(t_start, t_end, ymin=0.01, ymax=0.04, alpha=0.5, color='#27AE60',
                    label='Significant cluster' if k == 0 else '_nolegend_')
        ax1.text((t_start + t_end) / 2, 0.05, f'p={cluster["p_value"]:.3f}',
                 transform=x_data_y_axes, ha='center', va='bottom',
                 fontsize=9, fontweight='bold', color='#27AE60')

    ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax1.set_xlabel('Time (milliseconds, ms)', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Theta Power Difference (decibels, dB)\nEmotional - Neutral',
                   fontsize=12, fontweight='bold')
    ax1.set_title('A) Grand Average Theta-Band (4-8 Hz) Time Course\n(100-400 ms post-stimulus)',
                  fontsize=13, fontweight='bold')
    ax1.legend(loc='upper right')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(100, 400)

    # ---------------- Panel B: effect sizes ----------------
    ax2 = fig.add_subplot(gs[0, 3])

    if planned_results:
        windows = ['Early (100-200 ms)', 'Mid (200-300 ms)', 'Late (300-400 ms)']
        cohens_d = [planned_results[w]['cohens_d'] for w in windows]
        p_values = [planned_results[w]['p_value'] for w in windows]

        x_pos = np.arange(len(windows))
        ax2.bar(x_pos, cohens_d, color=[colors[w] for w in windows], alpha=0.7,
                edgecolor='black', linewidth=1)

        # (upper p bound, marker, distance from bar end); markers go beyond
        # the bar end on the side of its sign.
        significance_levels = (
            (0.001, '***', 0.03),
            (0.01, '**', 0.025),
            (0.05, '*', 0.02),
            (0.1, '†', 0.015),
        )
        for i, (p, d) in enumerate(zip(p_values, cohens_d)):
            level = next(((symbol, offset) for bound, symbol, offset in significance_levels
                          if p < bound), None)
            if level is None:
                continue
            symbol, offset = level
            direction = 1 if d >= 0 else -1
            ax2.text(i, d + direction * offset, symbol,
                     ha='center', va='bottom' if d >= 0 else 'top',
                     fontsize=12, fontweight='bold')

        ax2.set_xticks(x_pos)
        ax2.set_xticklabels(['Early\n(100-200)', 'Mid\n(200-300)', 'Late\n(300-400)'],
                            fontsize=11)
        ax2.set_ylabel("Cohen's d (Effect Size)", fontsize=11, fontweight='bold')
        ax2.set_title('B) Effect Sizes by Time Window\n(Paired t-tests)',
                      fontsize=12, fontweight='bold')
        ax2.axhline(y=0, color='black', linestyle='-', alpha=0.3)
        ax2.grid(True, alpha=0.3, axis='y')

        ax2.text(0.02, 0.98, 'Effect size interpretation:\n• d = 0.2: Small\n• d = 0.5: Medium\n• d = 0.8: Large',
                 transform=ax2.transAxes, fontsize=9, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    # ---------------- Panels C-E: topography placeholders ----------------
    # TODO: replace the placeholders with real per-window topomaps.
    for i, window_name in enumerate(['Early (100-200 ms)', 'Mid (200-300 ms)', 'Late (300-400 ms)']):
        ax = fig.add_subplot(gs[1, i])
        ax.text(0.5, 0.5, f'Topography: {window_name}\n(To be implemented)',
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title(f'{chr(67+i)}) {window_name}', fontsize=11, fontweight='bold')
        ax.axis('off')

    # Panel F: reserved for a colorbar/legend (currently empty)
    ax6 = fig.add_subplot(gs[1, 3])
    ax6.axis('off')

    # ---------------- Bottom row: analysis parameters ----------------
    ax7 = fig.add_subplot(gs[2, :])
    ax7.axis('off')

    params_text = f"""Analysis Parameters:
• Theta band: {theta_band[0]}-{theta_band[1]} Hz | Baseline: {baseline[0]*1000:.0f}-{baseline[1]*1000:.0f} ms
• Subjects: {data.shape[0]} | Channels: {len(channels)} posterior sensors
• Statistical tests: One-tailed paired t-tests (emotional > neutral)"""

    ax7.text(0.5, 0.5, params_text, ha='center', va='center',
             transform=ax7.transAxes, fontsize=10,
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # All axes come from the gridspec, so tight_layout can arrange them.
    plt.tight_layout()
    fig.savefig(os.path.join(fig_dir, 'figure_composite_old_style.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_composite_old_style.pdf'))
    plt.show()

    print(f"✓ Saved composite figure (old style): {fig_dir}/figure_composite_old_style.png")

def create_figure_cluster_topography(data, times, channels, info, cluster_results):
    """Plot channel t-values at the cluster peak and the cluster time course.

    Sidebar variant: the cluster statistics and the legend are drawn in a
    narrow axes to the right of the two panels.

    NOTE: this file defines ``create_figure_cluster_topography`` a second time
    further down; that later definition replaces this one when the module is
    executed, so this version only runs if the later one is removed/renamed.

    Despite the name, no scalp topography is drawn. Panel A is a bar plot of
    every channel's t-value at the peak time of the first significant cluster
    (cluster sensors get a yellow bar edge); panel B is the mean +/- SEM
    t-value time course over the cluster's sensors with the cluster window
    shaded.

    Parameters
    ----------
    data, times, info :
        Not used; kept so existing call sites keep working.
    channels : sequence of str
        Tick labels for panel A when it has at least as many entries as there
        are t-value rows.
    cluster_results : dict or None
        Cluster-test output. Keys read: 'significant_clusters', 'clusters',
        't_values', 'threshold', 'analysis_times', 'n_permutations'.

    Saves ``figure_cluster_topography.png``/``.pdf`` into ``fig_dir`` and
    returns None (also when there is nothing to plot).
    """
    if not cluster_results or not cluster_results['significant_clusters']:
        print("No significant clusters to plot")
        return

    t_values = cluster_results['t_values']
    cluster = cluster_results['significant_clusters'][0]
    # Assumed to be a boolean (channels x times) mask aligned with t_values,
    # as the axis-0 reduction below implies -- TODO confirm with the cluster step.
    cluster_mask = cluster_results['clusters'][cluster['cluster_id'] - 1]

    time_indices = np.where(np.any(cluster_mask, axis=0))[0]
    if len(time_indices) == 0:
        print("Warning: No time indices in cluster mask")
        return

    # Peak = time point with the largest mean t over the cluster's member
    # sensors. Boolean-indexing t_values[cluster_mask] flattens to 1-D, which
    # drops the time axis (argmax then always picked the first cluster
    # sample), so average column-wise over the members instead.
    member_sums = np.where(cluster_mask, t_values, 0.0)[:, time_indices].sum(axis=0)
    member_counts = np.sum(cluster_mask, axis=0)[time_indices]  # >= 1 per column
    max_time_idx = time_indices[np.argmax(member_sums / member_counts)]
    t_values_at_peak = t_values[:, max_time_idx]
    n_channels = t_values_at_peak.shape[0]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    # Keep the right 10% of the figure free for the sidebar. tight_layout is
    # deliberately not called: it recomputes the margins and would stretch
    # the panels underneath the sidebar.
    fig.subplots_adjust(left=0.1, right=0.9, wspace=0.3)

    threshold = cluster_results['threshold']
    threshold_label = f'Threshold (t={threshold})'

    # ---------------- Panel A: channel-wise t-values at the peak ----------------
    channel_indices = np.arange(n_channels)
    colors = ['#E74C3C' if t > 0 else '#3498DB' for t in t_values_at_peak]
    bars = ax1.bar(channel_indices, t_values_at_peak,
                   color=colors, alpha=0.8,
                   edgecolor='black', linewidth=0.5)

    # Cluster member sensors are marked by a thick yellow edge.
    cluster_sensors = cluster['sensors_involved']
    for sensor in cluster_sensors:
        if sensor < len(bars):
            bars[sensor].set_edgecolor('yellow')
            bars[sensor].set_linewidth(2)

    ax1.axhline(threshold, color='black', linestyle='--',
                linewidth=1, alpha=0.7, label=threshold_label)

    ax1.set_xlabel('Channel Index', fontweight='bold')
    ax1.set_ylabel('t-value (emotional > neutral)', fontweight='bold')
    peak_ms = cluster_results["analysis_times"][max_time_idx] * 1000
    ax1.set_title(f'A. Channel-wise t-values at Peak Activation\n({peak_ms:.0f} ms)',
                  fontsize=11, fontweight='bold')

    # Tick labels assume t_values rows follow the order of `channels`.
    ax1.set_xticks(channel_indices)
    if len(channels) >= n_channels:
        tick_labels = [channels[i] for i in range(n_channels)]
    else:
        tick_labels = [f'Channel {i}' for i in channel_indices]
    ax1.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax1.grid(True, alpha=0.2, linestyle='--', axis='y')

    # ---------------- Panel B: cluster time course ----------------
    sensor_rows = t_values[cluster_sensors, :] if len(cluster_sensors) > 0 else t_values
    avg_cluster_t = np.mean(sensor_rows, axis=0)
    # stats.sem is NaN for a single row (ddof=1); draw a zero-width band instead.
    if sensor_rows.shape[0] > 1:
        sem_cluster_t = stats.sem(sensor_rows, axis=0)
    else:
        sem_cluster_t = np.zeros_like(avg_cluster_t)

    times_ms = cluster_results['analysis_times'] * 1000
    ax2.plot(times_ms, avg_cluster_t, color='#27AE60', linewidth=2)
    ax2.fill_between(times_ms, avg_cluster_t - sem_cluster_t, avg_cluster_t + sem_cluster_t,
                     color='#27AE60', alpha=0.2)

    ax2.axhline(threshold, color='black', linestyle='--',
                linewidth=1, alpha=0.7, label=threshold_label)

    t_start, t_end = cluster['time_range_ms']
    ax2.axvspan(t_start, t_end, alpha=0.2, color='#27AE60')
    ax2.axhline(0, color='black', linewidth=0.8, alpha=0.5, zorder=0)

    ax2.set_xlabel('Time (ms)', fontweight='bold')
    ax2.set_ylabel('t-value (emotional > neutral)', fontweight='bold')
    ax2.set_title('B. Temporal Profile of Significant Cluster',
                  fontsize=11, fontweight='bold')
    ax2.grid(True, alpha=0.2, linestyle='--')
    ax2.set_xlim(90, 410)

    # ---------------- Sidebar: statistics + legend ----------------
    sidebar_ax = fig.add_axes([0.92, 0.15, 0.06, 0.75])
    sidebar_ax.axis('off')

    stats_text = f"""Cluster Statistics:
p-value: {cluster['p_value']:.4f}
Sensors: {cluster['n_sensors']}
Time: {t_start:.0f}-{t_end:.0f} ms
Duration: {cluster['duration_ms']:.0f} ms
Cluster mass: {cluster['cluster_mass']:.1f}
Permutations: {cluster_results['n_permutations']}"""

    sidebar_ax.text(0, 0.9, stats_text,
                    fontsize=8, verticalalignment='top', fontweight='bold',
                    transform=sidebar_ax.transAxes,
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

    from matplotlib.patches import Patch
    from matplotlib.lines import Line2D

    legend_elements = [
        Patch(facecolor='#27AE60', alpha=0.2, label='Cluster time window'),
        Line2D([0], [0], color='#27AE60', linewidth=2, label='Avg cluster t-values'),
        Line2D([0], [0], color='black', linestyle='--', linewidth=1,
               alpha=0.7, label=threshold_label),
        # Matches panel A, where cluster sensors get a yellow bar *edge*.
        Patch(facecolor='none', edgecolor='yellow', linewidth=2,
              label='Cluster sensors'),
    ]

    sidebar_ax.legend(handles=legend_elements, loc='lower left',
                      fontsize=8, framealpha=0.9,
                      title='Legend:', title_fontsize=9)

    fig.savefig(os.path.join(fig_dir, 'figure_cluster_topography.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_cluster_topography.pdf'))
    plt.show()
    print(f"✓ Saved figure: {fig_dir}/figure_cluster_topography.png")
    
def _mean_sem_ci(values, confidence=0.95):
    """Return (mean, SEM, CI half-width) of ``values`` using Student's t.

    The half-width is t_{(1 + confidence) / 2, n - 1} * SEM, the usual
    small-sample interval for a mean. With a single value, SEM and
    half-width are NaN.
    """
    values = np.asarray(values, dtype=float)
    mean = float(np.mean(values))
    sem = float(stats.sem(values))
    t_crit = float(stats.t.ppf(0.5 + confidence / 2.0, df=values.size - 1))
    return mean, sem, t_crit * sem


def create_figure_individual_differences(data, times, subject_ids):
    """Plot each participant's mean emotional - neutral theta difference.

    Bars are sorted ascending and coloured by sign. The group mean and its
    t-based 95% confidence interval are shown as horizontal lines. A sidebar
    on the right holds summary statistics and the legend.

    Parameters
    ----------
    data : ndarray
        Difference values; axes 1 and 2 are averaged away (channels and
        time, per the original comment), leaving one value per subject on
        axis 0.
    times : ndarray
        Not used; kept for call-site compatibility.
    subject_ids : sequence
        Labels for axis 0 of ``data``, used as x tick labels.

    Saves ``figure_individual_differences.png``/``.pdf`` into ``fig_dir``.
    """
    # One value per subject: mean over axis 1, then over the remaining axis 1.
    subject_means = np.mean(np.mean(data, axis=1), axis=1)

    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_axes([0.1, 0.15, 0.6, 0.75])  # narrower main plot, room for sidebar

    sort_idx = np.argsort(subject_means)
    sorted_means = subject_means[sort_idx]
    sorted_ids = [subject_ids[i] for i in sort_idx]
    positions = range(len(sorted_means))

    bar_colors = [COLORS['positive'] if val > 0 else COLORS['negative']
                  for val in sorted_means]
    ax.bar(positions, sorted_means,
           color=bar_colors, alpha=0.8,
           edgecolor='black', linewidth=0.5)
    ax.scatter(positions, sorted_means, color='black', s=20, zorder=3)

    # t-based interval: with ~20 subjects z=1.96 would understate the CI.
    mean_val, sem_val, ci_95 = _mean_sem_ci(subject_means)

    ax.axhline(mean_val, color='black', linestyle='--', linewidth=1.5)
    ax.axhline(mean_val + ci_95, color='gray', linestyle=':', linewidth=1)
    ax.axhline(mean_val - ci_95, color='gray', linestyle=':', linewidth=1)
    ax.axhline(0, color='black', linewidth=0.8, alpha=0.5, zorder=0)

    ax.set_xlabel('Participant (sorted by difference)', fontweight='bold')
    ax.set_ylabel('Theta Power Difference (dB)\nEmotional - Neutral', fontweight='bold')
    ax.set_title('Individual Participant Differences in Theta Power',
                 fontsize=12, fontweight='bold', pad=15)

    # Label every second participant to avoid crowding.
    tick_positions = range(0, len(sorted_ids), 2)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([sorted_ids[i] for i in tick_positions],
                       rotation=45, ha='right')

    ax.grid(True, alpha=0.2, linestyle='--', axis='y')
    ax.set_xlim(-0.5, len(sorted_means) - 0.5)

    sidebar_ax = fig.add_axes([0.75, 0.15, 0.2, 0.75])
    sidebar_ax.axis('off')

    n_positive = np.sum(subject_means > 0)
    n_negative = np.sum(subject_means < 0)
    # Shapiro-Wilk needs at least 3 observations; report NaN instead of raising.
    if subject_means.size >= 3:
        _, shapiro_p = stats.shapiro(subject_means)
    else:
        shapiro_p = np.nan

    stats_text = f"""Individual Statistics:
Positive (Emotional > Neutral): {n_positive}
Negative (Neutral > Emotional): {n_negative}
Mean ± SEM: {mean_val:.3f} ± {sem_val:.3f} dB
95% CI: [{mean_val-ci_95:.3f}, {mean_val+ci_95:.3f}] dB
Shapiro-Wilk: p = {shapiro_p:.3f}"""

    sidebar_ax.text(0, 0.8, stats_text,
                    fontsize=9, verticalalignment='top', fontweight='bold',
                    transform=sidebar_ax.transAxes,
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

    from matplotlib.patches import Patch
    from matplotlib.lines import Line2D

    legend_elements = [
        Patch(facecolor=COLORS['positive'], alpha=0.8, label='Emotional > Neutral'),
        Patch(facecolor=COLORS['negative'], alpha=0.8, label='Neutral > Emotional'),
        Line2D([0], [0], color='black', linestyle='--', linewidth=1.5,
               label=f'Group mean = {mean_val:.3f} dB'),
        Line2D([0], [0], color='gray', linestyle=':', linewidth=1, label='95% CI'),
    ]

    sidebar_ax.legend(handles=legend_elements, loc='lower left',
                      fontsize=8, framealpha=0.9,
                      title='Legend:', title_fontsize=9)

    # No tight_layout(): every axes is placed by hand, so it could only warn.
    fig.savefig(os.path.join(fig_dir, 'figure_individual_differences.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_individual_differences.pdf'))
    plt.show()
    print(f"✓ Saved figure: {fig_dir}/figure_individual_differences.png")

def _cluster_peak_time_index(t_values, cluster_mask):
    """Return the time index where the mean t over cluster members peaks.

    ``t_values`` and ``cluster_mask`` must be aligned 2-D arrays with time on
    axis 1. Only masked-in entries contribute to each time point's mean.
    Returns None when the mask selects no time point.
    """
    t_values = np.asarray(t_values, dtype=float)
    mask = np.asarray(cluster_mask, dtype=bool)
    time_indices = np.flatnonzero(mask.any(axis=0))
    if time_indices.size == 0:
        return None
    member_mask = mask[:, time_indices]
    member_sums = np.where(member_mask, t_values[:, time_indices], 0.0).sum(axis=0)
    # Every selected column has at least one member, so the division is safe.
    member_means = member_sums / member_mask.sum(axis=0)
    return int(time_indices[np.argmax(member_means)])


def create_figure_cluster_topography(data, times, channels, info, cluster_results):
    """Plot channel t-values at the cluster peak and the cluster time course.

    Despite the name, no scalp topography is drawn. Panel A is a bar plot of
    every channel's t-value at the peak time of the first significant cluster
    (cluster sensors get a yellow bar edge). Panel B is the mean +/- SEM
    t-value time course over the cluster's sensors, with the cluster window
    shaded and a statistics box.

    Parameters
    ----------
    data, times, info :
        Not used; kept so existing call sites keep working.
    channels : sequence of str
        Tick labels for panel A when it has at least as many entries as there
        are t-value rows.
    cluster_results : dict or None
        Cluster-test output. Keys read: 'significant_clusters', 'clusters',
        't_values', 'threshold', 'analysis_times', 'n_permutations'.

    Saves ``figure_cluster_topography.png``/``.pdf`` into ``fig_dir`` and
    returns None (also when there is nothing to plot).
    """
    if not cluster_results or not cluster_results['significant_clusters']:
        print("No significant clusters to plot")
        return

    t_values = cluster_results['t_values']
    cluster = cluster_results['significant_clusters'][0]
    # Assumed to be a boolean (channels x times) mask aligned with t_values
    # -- TODO confirm with the cluster-test step.
    cluster_mask = cluster_results['clusters'][cluster['cluster_id'] - 1]

    # Resolve the peak before creating the figure so the early exit does not
    # leave an empty figure open. (The previous t_values[cluster_mask] lookup
    # flattened to 1-D, so the "peak" was always the first cluster sample.)
    max_time_idx = _cluster_peak_time_index(t_values, cluster_mask)
    if max_time_idx is None:
        print("Warning: No time indices in cluster mask")
        return

    t_values_at_peak = t_values[:, max_time_idx]
    n_channels = t_values_at_peak.shape[0]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    threshold = cluster_results['threshold']
    threshold_label = f'Threshold (t={threshold})'

    # ---------------- Panel A: channel-wise t-values at the peak ----------------
    channel_indices = np.arange(n_channels)
    colors = ['#E74C3C' if t > 0 else '#3498DB' for t in t_values_at_peak]
    bars = ax1.bar(channel_indices, t_values_at_peak,
                   color=colors, alpha=0.8,
                   edgecolor='black', linewidth=0.5)

    cluster_sensors = cluster['sensors_involved']
    for sensor in cluster_sensors:
        if sensor < len(bars):
            bars[sensor].set_edgecolor('yellow')
            bars[sensor].set_linewidth(2)

    ax1.axhline(threshold, color='black', linestyle='--',
                linewidth=1, alpha=0.7, label=threshold_label)

    ax1.set_xlabel('Channel Index', fontweight='bold')
    ax1.set_ylabel('t-value (emotional > neutral)', fontweight='bold')
    peak_ms = cluster_results["analysis_times"][max_time_idx] * 1000
    ax1.set_title(f'A. Channel-wise t-values at Peak Activation\n({peak_ms:.0f} ms)',
                  fontsize=11, fontweight='bold')

    # Tick labels assume t_values rows follow the order of `channels`.
    ax1.set_xticks(channel_indices)
    if len(channels) >= n_channels:
        tick_labels = [channels[i] for i in range(n_channels)]
    else:
        tick_labels = [f'Channel {i}' for i in channel_indices]
    ax1.set_xticklabels(tick_labels, rotation=45, ha='right')

    ax1.grid(True, alpha=0.2, linestyle='--', axis='y')
    ax1.legend(loc='upper right')

    # ---------------- Panel B: cluster time course ----------------
    sensor_rows = t_values[cluster_sensors, :] if len(cluster_sensors) > 0 else t_values
    avg_cluster_t = np.mean(sensor_rows, axis=0)
    # stats.sem is NaN for a single row (ddof=1); draw a zero-width band instead.
    if sensor_rows.shape[0] > 1:
        sem_cluster_t = stats.sem(sensor_rows, axis=0)
    else:
        sem_cluster_t = np.zeros_like(avg_cluster_t)

    times_ms = cluster_results['analysis_times'] * 1000

    ax2.plot(times_ms, avg_cluster_t,
             color='#27AE60', linewidth=2,
             label='Average cluster sensors')
    ax2.fill_between(times_ms, avg_cluster_t - sem_cluster_t, avg_cluster_t + sem_cluster_t,
                     color='#27AE60', alpha=0.2)

    ax2.axhline(threshold, color='black', linestyle='--',
                linewidth=1, alpha=0.7, label=threshold_label)

    t_start, t_end = cluster['time_range_ms']
    ax2.axvspan(t_start, t_end, alpha=0.2, color='#27AE60',
                label=f'Cluster: {t_start:.0f}-{t_end:.0f} ms')

    ax2.axhline(0, color='black', linewidth=0.8, alpha=0.5, zorder=0)

    ax2.set_xlabel('Time (ms)', fontweight='bold')
    ax2.set_ylabel('t-value (emotional > neutral)', fontweight='bold')
    ax2.set_title('B. Temporal Profile of Significant Cluster',
                  fontsize=11, fontweight='bold')

    stats_text = f"""Cluster Statistics:
p-value: {cluster['p_value']:.4f}
Sensors: {cluster['n_sensors']}
Time: {t_start:.0f}-{t_end:.0f} ms
Duration: {cluster['duration_ms']:.0f} ms
Cluster mass: {cluster['cluster_mass']:.1f}
Permutations: {cluster_results['n_permutations']}"""

    ax2.text(0.98, 0.98, stats_text,
             transform=ax2.transAxes, fontsize=8,
             horizontalalignment='right', verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

    ax2.grid(True, alpha=0.2, linestyle='--')
    ax2.set_xlim(90, 410)
    # The statistics box occupies the upper-right corner; keep the legend clear of it.
    ax2.legend(loc='upper left')

    plt.tight_layout()
    fig.savefig(os.path.join(fig_dir, 'figure_cluster_topography.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_cluster_topography.pdf'))
    plt.show()
    print(f"✓ Saved figure: {fig_dir}/figure_cluster_topography.png")

def create_figure_grand_average(data, times, channels, cluster_results=None):
    """Plot the grand-average emotional - neutral theta time course (original style).

    The curve is the mean of ``data`` over subjects and channels. The band is
    the SEM across subjects of each subject's channel-averaged time course.
    The a priori windows from the module-level ``time_windows`` are shaded
    and annotated with their window mean, and significant clusters (if any)
    are overlaid.

    Parameters
    ----------
    data : ndarray
        Difference values indexed (subjects, channels, times) -- the layout
        used by the averaging and by the parameters box below.
    times : ndarray
        Sample times in seconds for the last axis of ``data``.
    channels : sequence
        Not used; kept for call-site compatibility.
    cluster_results : dict, optional
        Cluster-test output. Only ``'significant_clusters'`` entries'
        ``'time_range_ms'`` and ``'p_value'`` are read.

    Saves ``figure_grand_average.png``/``.pdf`` into ``fig_dir``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    grand_avg = np.mean(np.mean(data, axis=0), axis=0)  # over subjects, then channels
    sem = stats.sem(np.mean(data, axis=1), axis=0)      # SEM across subjects

    times_ms = times * 1000
    # Blended transform: x in data units (ms), y as a fraction of the axes
    # height. Annotations placed with it stay inside the panel whatever the
    # sign or magnitude of the dB values (data-unit offsets such as
    # "ylim + 0.1 dB" could land far outside the axes).
    x_data_y_axes = ax.get_xaxis_transform()

    ax.plot(times_ms, grand_avg,
            color=COLORS['difference'], linewidth=2,
            label='Grand average (emotional - neutral)')
    ax.fill_between(times_ms, grand_avg - sem, grand_avg + sem,
                    color=COLORS['difference'], alpha=0.2,
                    label='± SEM')

    # Stimulus onset; the label sits at the bottom so the parameters box in
    # the upper-left corner does not hide it.
    ax.axvline(0, color='black', linestyle='--', linewidth=1, alpha=0.5)
    ax.text(5, 0.02, 'Stimulus onset', transform=x_data_y_axes,
            fontsize=9, va='bottom', style='italic')

    window_colors = [COLORS['early'], COLORS['mid'], COLORS['late']]
    for (window_name, (tmin, tmax)), color in zip(time_windows.items(), window_colors):
        tmin_ms, tmax_ms = tmin * 1000, tmax * 1000
        ax.axvspan(tmin_ms, tmax_ms, alpha=0.15, color=color)

        window_mask = (times >= tmin) & (times <= tmax)
        # np.mean of an empty selection warns and returns NaN; make that explicit.
        window_mean = np.mean(grand_avg[window_mask]) if np.any(window_mask) else np.nan

        ax.text((tmin_ms + tmax_ms) / 2, 0.02,
                f'{window_name.split()[0]}\n{window_mean:.4f} dB',
                transform=x_data_y_axes,
                ha='center', va='bottom', fontsize=8,
                bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8))

    if cluster_results and cluster_results['significant_clusters']:
        for cluster in cluster_results['significant_clusters']:
            t_start, t_end = cluster['time_range_ms']
            ax.axvspan(t_start, t_end, alpha=0.3, color=COLORS['significance'],
                       label=f"Cluster p={cluster['p_value']:.3f}")

    ax.axhline(0, color='black', linewidth=0.8, alpha=0.5, zorder=0)

    params_text = f"""Analysis Parameters:
• Theta band: {theta_band[0]}-{theta_band[1]} Hz
• Baseline: {baseline[0]*1000:.0f}-{baseline[1]*1000:.0f} ms
• Subjects: {data.shape[0]}
• Channels: {data.shape[1]} posterior sensors
• Time points: {data.shape[2]}"""

    ax.text(0.02, 0.98, params_text,
            transform=ax.transAxes, fontsize=8,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    ax.set_xlabel('Time (ms)', fontweight='bold')
    ax.set_ylabel('Theta Power Difference (dB)\nEmotional - Neutral', fontweight='bold')
    ax.set_title('Grand Average Theta Power Difference Across Posterior Sensors',
                 fontsize=12, fontweight='bold', pad=15)

    ax.grid(True, alpha=0.2, linestyle='--')
    ax.set_xlim(-50, 450)
    ax.legend(loc='upper right')

    plt.tight_layout()
    fig.savefig(os.path.join(fig_dir, 'figure_grand_average.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_grand_average.pdf'))
    plt.show()
    print(f"✓ Saved figure: {fig_dir}/figure_grand_average.png")

def _significance_marker(p, d):
    """Map a p-value to (symbol, signed y-offset), or None when p >= 0.1.

    The offset moves the marker away from the bar end: upward for d >= 0,
    downward otherwise. A NaN p-value yields None.
    """
    direction = 1 if d >= 0 else -1
    levels = (
        (0.001, '***', 0.05),
        (0.01, '**', 0.04),
        (0.05, '*', 0.03),
        (0.1, '†', 0.02),
    )
    for cutoff, symbol, offset in levels:
        if p < cutoff:
            return symbol, offset * direction
    return None


def create_figure_effect_size_comparison(planned_results, cluster_results):
    """Bar plot of Cohen's d per planned time window, with a statistics sidebar.

    Parameters
    ----------
    planned_results : dict
        Window name -> result dict. Keys read per window: 'cohens_d',
        'p_value', 't_statistic'. 'n_subjects' is read from the first window.
        Nothing is drawn when empty or None.
    cluster_results : dict or None
        If it has non-empty 'significant_clusters', each cluster's 'p_value'
        and 'time_range_ms' are listed in the sidebar.

    Bars above/below zero get significance markers (*** < .001, ** < .01,
    * < .05, † < .1). Saves ``figure_effect_size_comparison.png``/``.pdf``
    into ``fig_dir``.
    """
    if not planned_results:
        return

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_axes([0.1, 0.15, 0.6, 0.75])  # narrower main plot, room for sidebar

    windows = list(planned_results.keys())
    cohens_d = [planned_results[w]['cohens_d'] for w in windows]
    p_values = [planned_results[w]['p_value'] for w in windows]

    x_pos = np.arange(len(windows))
    # One colour per window in window order; cycles if there are >3 windows.
    palette = [COLORS['early'], COLORS['mid'], COLORS['late']]
    colors = [palette[i % len(palette)] for i in range(len(windows))]
    ax.bar(x_pos, cohens_d, color=colors, alpha=0.8,
           edgecolor='black', linewidth=1)

    for i, (p, d) in enumerate(zip(p_values, cohens_d)):
        marker = _significance_marker(p, d)
        if marker is None:
            continue
        symbol, y_offset = marker
        ax.text(i, d + y_offset, symbol,
                ha='center', va='bottom' if d >= 0 else 'top',
                fontsize=12, fontweight='bold')

    ax.set_xticks(x_pos)
    ax.set_xticklabels([w.split('(')[0].strip() for w in windows], fontsize=10)
    ax.set_ylabel("Cohen's d (Effect Size)", fontweight='bold')
    ax.set_title('Effect Size Comparison by Time Window\n(One-tailed t-tests)',
                 fontsize=11, fontweight='bold')
    ax.axhline(y=0, color='black', linewidth=0.8, alpha=0.5)
    ax.grid(True, alpha=0.2, linestyle='--', axis='y')

    sidebar_ax = fig.add_axes([0.75, 0.15, 0.2, 0.75])
    sidebar_ax.axis('off')
    box_style = dict(boxstyle='round', facecolor='white', alpha=0.8)

    es_text = 'Effect size interpretation:\n• 0.2 = Small\n• 0.5 = Medium\n• 0.8 = Large'
    sidebar_ax.text(0, 0.85, es_text,
                    fontsize=9, verticalalignment='top', fontweight='bold',
                    transform=sidebar_ax.transAxes, bbox=box_style)

    if cluster_results and cluster_results['significant_clusters']:
        cluster_text = "Cluster Permutation Test:"
        for cluster in cluster_results['significant_clusters']:
            cluster_text += f"\n• p = {cluster['p_value']:.3f}"
            cluster_text += f"\n  {cluster['time_range_ms'][0]:.0f}-{cluster['time_range_ms'][1]:.0f} ms"

        sidebar_ax.text(0, 0.5, cluster_text,
                        fontsize=9, verticalalignment='top', fontweight='bold',
                        transform=sidebar_ax.transAxes, bbox=box_style)

    stats_text = f"Statistical Summary:\nn = {planned_results[windows[0]]['n_subjects']} subjects"
    for w in windows:
        stats_text += f"\n{w.split('(')[0].strip()}: t={planned_results[w]['t_statistic']:.2f}"

    sidebar_ax.text(0, 0.2, stats_text,
                    fontsize=9, verticalalignment='top', fontweight='bold',
                    transform=sidebar_ax.transAxes, bbox=box_style)

    # No tight_layout(): every axes here is placed by hand, which tight_layout
    # cannot handle (it only emits an "Axes not compatible" warning).
    fig.savefig(os.path.join(fig_dir, 'figure_effect_size_comparison.png'), dpi=600)
    fig.savefig(os.path.join(fig_dir, 'figure_effect_size_comparison.pdf'))
    plt.show()
    print(f"✓ Saved figure: {fig_dir}/figure_effect_size_comparison.png")

def _window_channel_means(data, times, windows):
    """Average ``data`` over subjects and each window's samples, per channel.

    ``data`` is indexed (subjects, channels, times). ``windows`` maps a name
    to an inclusive (tmin, tmax) in the units of ``times``. A window that
    selects no samples yields an all-NaN vector with a printed warning,
    instead of numpy's empty-mean RuntimeWarning.
    """
    means = {}
    for name, (tmin, tmax) in windows.items():
        in_window = (times >= tmin) & (times <= tmax)
        if not np.any(in_window):
            print(f"⚠️ No samples in window '{name}' ({tmin}-{tmax})")
            means[name] = np.full(data.shape[1], np.nan)
            continue
        means[name] = np.mean(data[:, :, in_window], axis=(0, 2))
    return means


def _sensor_positions(channel_names, info):
    """Return one [x, y, z] row per channel, taken from ``info['chs'][i]['loc']``.

    Channels missing from ``info['ch_names']`` get a dummy [0, 0, 0] with a
    warning. If ``info`` cannot be read at all, points on a unit circle are
    returned instead (deliberate best-effort, so the figure is still drawn).
    """
    try:
        ch_names_in_info = info['ch_names']
        positions = []
        for ch in channel_names:
            if ch in ch_names_in_info:
                idx = ch_names_in_info.index(ch)
                positions.append(info['chs'][idx]['loc'][:3])
            else:
                positions.append([0, 0, 0])
                print(f"⚠️ Channel {ch} not found in info")
        positions = np.array(positions)
        print(f"✓ Extracted positions for {len(positions)} actual channels")
        return positions
    except Exception as e:
        print(f"⚠️ Could not extract channel positions: {e}")
        print("Using dummy positions in a circle")
        n_channels = len(channel_names)
        angles = np.linspace(0, 2*np.pi, n_channels, endpoint=False)
        return np.column_stack([np.cos(angles), np.sin(angles), np.zeros(n_channels)])


def plot_theta_topography_standalone(data, times, channels, info):
    """
    Create a standalone theta topography figure (no time course).

    One topomap is drawn per a priori window (early 100-200, mid 200-300,
    late 300-400 ms) from the subject-averaged emotional - neutral
    differences. All maps share one symmetric colour scale. The figure also
    has a horizontal colorbar, a suptitle and an analysis-parameters box.

    Parameters
    ----------
    data : ndarray
        Differences indexed (subjects, channels, times).
    times : ndarray
        Sample times in seconds for the last axis of ``data``.
    channels : sequence of str
        Names for ``data``'s channel axis; looked up in ``info``.
    info : mne.Info or dict-like
        Read only for ``info['ch_names']`` and ``info['chs'][i]['loc']``.

    Topomaps use the x/y of each sensor's ``loc`` directly (no head-model
    projection). The style overrides apply only inside ``mpl.rc_context``, so
    they no longer leak into figures created afterwards. Saves
    ``figure_theta_topography_standalone.png``/``.pdf`` into ``fig_dir``.
    """
    print("\nCreating standalone topography figure...")

    # A priori windows in seconds; keys double as titles and COLORS keys.
    time_windows_topo = {
        'early': (0.10, 0.20),
        'mid': (0.20, 0.30),
        'late': (0.30, 0.40),
    }

    # Per window: (n_channels,) mean of the emotional - neutral differences.
    topo_data = _window_channel_means(data, times, time_windows_topo)
    print(f"Topography data computed for {len(topo_data)} windows")
    print(f"Data shape per window: {topo_data['early'].shape if 'early' in topo_data else 'N/A'}")

    actual_channels = channels
    ch_positions = _sensor_positions(actual_channels, info)

    figure_style = {
        "font.size": 10,
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
        "axes.linewidth": 0.8,
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    with mpl.rc_context(figure_style):
        fig = plt.figure(figsize=(12, 8))

        # Shared symmetric colour limits across windows (5% headroom).
        all_topo_vals = np.concatenate([topo_data[w][:, np.newaxis] for w in topo_data], axis=1)
        vmax = np.nanmax(np.abs(all_topo_vals)) * 1.05
        if not np.isfinite(vmax) or vmax == 0:
            vmax = 1.0  # all-zero / all-NaN data: keep a usable colour norm
        vmin = -vmax

        # Figure-fraction [left, bottom, width, height] for each topomap.
        topo_positions = [
            [0.05, 0.50, 0.25, 0.45],
            [0.37, 0.50, 0.25, 0.45],
            [0.69, 0.50, 0.25, 0.45],
        ]

        for wname, pos in zip(time_windows_topo, topo_positions):
            tmin, tmax = time_windows_topo[wname]
            # Built up front so the fallback branch always labels *this* window
            # (it used to be unbound on the first failure, stale afterwards).
            time_range = f"{tmin*1000:.0f}-{tmax*1000:.0f} ms"
            ax = fig.add_axes(pos)
            avg = topo_data[wname]

            n_pos = ch_positions.shape[0]
            if len(avg) != n_pos:
                print(f"⚠️ Mismatch: {len(avg)} values but {n_pos} positions for {wname}")
                if len(avg) < n_pos:
                    avg = np.pad(avg, (0, n_pos - len(avg)), 'constant')
                else:
                    avg = avg[:n_pos]

            try:
                from mne.viz import plot_topomap
                try:
                    # Newer MNE: colour limits via `vlim`.
                    plot_topomap(avg, ch_positions[:, :2], axes=ax, show=False,
                                 vlim=(vmin, vmax), cmap="RdBu_r",
                                 contours=0, sensors=True, sphere=None)
                except TypeError:
                    # Older MNE: separate `vmin` / `vmax`.
                    plot_topomap(avg, ch_positions[:, :2], axes=ax, show=False,
                                 vmin=vmin, vmax=vmax, cmap="RdBu_r",
                                 contours=0, sensors=True, sphere=None)
            except Exception as e:
                print(f"⚠️ Could not plot topography for {wname}: {e}")
                # Fallback: plain scatter of sensor values on the same scale.
                ax.scatter(ch_positions[:, 0], ch_positions[:, 1], c=avg,
                           cmap="RdBu_r", vmin=vmin, vmax=vmax, s=100)
                ax.set_aspect('equal')
                ax.axis('off')

            ax.set_title(f"{wname.capitalize()} Window\n({time_range})",
                         fontsize=11, fontweight='bold', color=COLORS[wname], pad=15)

            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_edgecolor('#CCCCCC')
                spine.set_linewidth(0.5)

        # ----------------- COLORBAR -----------------
        cax = fig.add_axes([0.25, 0.35, 0.5, 0.04])
        sm = mpl.cm.ScalarMappable(cmap="RdBu_r", norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax))
        cbar = fig.colorbar(sm, cax=cax, orientation="horizontal")
        cbar.set_label("Theta Power Difference (dB)\nEmotional - Neutral",
                       fontsize=11, fontweight='bold')
        cbar.ax.tick_params(labelsize=10, length=5, width=1)

        # ----------------- TITLE -----------------
        # y > 1 puts the title above the canvas; bbox_inches='tight' keeps it.
        fig.suptitle('Theta-Band (4-8 Hz) Topographic Distribution\nacross A Priori Time Windows',
                     fontsize=13, fontweight='bold', y=1.02)

        # ----------------- ANALYSIS PARAMETERS -----------------
        params_text = f"""Analysis Parameters:
• Theta band: {theta_band[0]}-{theta_band[1]} Hz
• Baseline: {baseline[0]*1000:.0f}-{baseline[1]*1000:.0f} ms
• Subjects: {data.shape[0]}
• Channels: {len(actual_channels)} posterior sensors
• Time windows: 100-200, 200-300, 300-400 ms"""

        ax_params = fig.add_axes([0.05, 0.08, 0.90, 0.12])
        ax_params.axis('off')
        ax_params.text(0, 1, params_text, fontsize=9,
                       verticalalignment='top',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8,
                                 edgecolor='black', linewidth=0.5))

        # Every axes is placed with add_axes, so subplots_adjust/tight_layout
        # would have no effect; bbox_inches='tight' handles the margins.
        fig.savefig(os.path.join(fig_dir, "figure_theta_topography_standalone.png"),
                    dpi=600, bbox_inches='tight')
        fig.savefig(os.path.join(fig_dir, "figure_theta_topography_standalone.pdf"),
                    bbox_inches='tight')
        plt.show()

    print(f"✓ Saved standalone topography figure: {fig_dir}/figure_theta_topography_standalone.png")

# ------------------------ MAIN ANALYSIS ------------------------

def _print_planned_comparisons(planned_results):
    """Print the t-test statistics for each a priori time window.

    Each value of ``planned_results`` is read for the keys 'mean_difference',
    'n_subjects', 't_statistic', 'p_value', 'cohens_d' and 'time_window'
    (start, end in seconds).
    """
    print("\nA. PLANNED COMPARISONS IN A PRIORI TIME WINDOWS:")
    for window_name, results in planned_results.items():
        print(f"\n  {window_name}:")
        print(f"    Mean difference: {results['mean_difference']:.4f} dB")
        print(f"    t({results['n_subjects']-1}) = {results['t_statistic']:.3f}, p = {results['p_value']:.3f}")
        print(f"    Cohen's d = {results['cohens_d']:.3f}")
        print(f"    Time range: {results['time_window'][0]*1000:.0f}-{results['time_window'][1]*1000:.0f} ms")


def _print_cluster_summary(cluster_results):
    """Print the cluster permutation test settings and every significant cluster."""
    print("\nB. CLUSTER PERMUTATION TEST:")
    print(f"    Analysis window: {cluster_results['analysis_window'][0]*1000:.0f}-{cluster_results['analysis_window'][1]*1000:.0f} ms")
    print(f"    Cluster-forming threshold: t > {cluster_results['threshold']}")
    print(f"    Permutations: {cluster_results['n_permutations']}")
    print(f"    Total clusters examined: {cluster_results['n_clusters']}")
    print(f"    Significant clusters (p < 0.05): {cluster_results['n_significant']}")

    for cluster in cluster_results['significant_clusters']:
        print(f"\n    Significant cluster {cluster['cluster_id']}:")
        print(f"      p-value: {cluster['p_value']:.4f}")
        print(f"      Time window: {cluster['time_range_ms'][0]:.0f}-{cluster['time_range_ms'][1]:.0f} ms")
        print(f"      Duration: {cluster['duration_ms']:.0f} ms")
        print(f"      Sensors involved: {cluster['n_sensors']}")
        print(f"      Cluster mass: {cluster['cluster_mass']:.1f}")


def _print_individual_variability(data):
    """Print per-participant effect direction, mean ± SEM and a normality test.

    Returns
    -------
    subject_means : np.ndarray
        One Emotional - Neutral value per participant (axis 0 of ``data``),
        averaged over axes 1 and 2 (presumably channels and time; confirm
        against collect_subject_differences).
    shapiro_p : float or None
        Shapiro-Wilk p-value, or None when fewer than 3 participants are
        available (scipy.stats.shapiro raises for n < 3).
    """
    print("\nC. INDIVIDUAL VARIABILITY:")
    subject_means = np.mean(np.mean(data, axis=1), axis=1)
    print(f"    Participants with Emotional > Neutral: {np.sum(subject_means > 0)}")
    print(f"    Participants with Neutral > Emotional: {np.sum(subject_means < 0)}")
    print(f"    Mean ± SEM: {np.mean(subject_means):.4f} ± {stats.sem(subject_means):.4f} dB")

    # Shapiro-Wilk test for normality (undefined for fewer than 3 samples).
    if subject_means.size < 3:
        print("    Shapiro-Wilk normality test: skipped (requires at least 3 participants)")
        return subject_means, None
    shapiro_stat, shapiro_p = stats.shapiro(subject_means)
    print(f"    Shapiro-Wilk normality test: W = {shapiro_stat:.3f}, p = {shapiro_p:.3f}")
    return subject_means, shapiro_p


def _build_key_findings(planned_results, cluster_results, subject_means,
                        shapiro_p, alpha=0.05):
    """Return the KEY FINDINGS sentences derived from the computed results.

    Replaces a previously hard-coded summary that could contradict the
    actual statistics (e.g. claiming a significant cluster when none exists).
    """
    findings = []

    # 1. Planned comparisons in the a priori windows.
    sig_windows = [name for name, res in planned_results.items()
                   if res['p_value'] < alpha]
    if sig_windows:
        details = "; ".join(
            f"{name}: d = {planned_results[name]['cohens_d']:.3f}"
            for name in sig_windows
        )
        findings.append(
            f"Significant theta power differences between emotional and neutral "
            f"faces (p < {alpha}) in {len(sig_windows)} of {len(planned_results)} "
            f"a priori time windows ({details})."
        )
    else:
        findings.append(
            f"No a priori time window showed a significant theta power difference "
            f"between emotional and neutral faces (all p >= {alpha})."
        )

    # 2. Cluster permutation test.
    clusters = cluster_results['significant_clusters']
    if clusters:
        spans = ", ".join(
            f"{c['time_range_ms'][0]:.0f}-{c['time_range_ms'][1]:.0f} ms "
            f"(p = {c['p_value']:.4f})"
            for c in clusters
        )
        findings.append(f"{len(clusters)} significant cluster(s) detected: {spans}.")
    else:
        findings.append(
            "No significant cluster detected by the cluster-based permutation test."
        )

    # 3. Direction of the effect across participants.
    n_pos = int(np.sum(subject_means > 0))
    n_neg = int(np.sum(subject_means < 0))
    findings.append(
        f"Participant-level averages: {n_pos} of {subject_means.size} participants "
        f"showed Emotional > Neutral and {n_neg} showed Neutral > Emotional."
    )

    # 4. Distribution of individual differences.
    if shapiro_p is None:
        findings.append(
            "Normality of individual differences was not assessed "
            "(fewer than 3 participants)."
        )
    elif shapiro_p < alpha:
        findings.append(
            f"Individual differences deviate from normality (Shapiro-Wilk "
            f"p = {shapiro_p:.3f}), supporting use of non-parametric statistics."
        )
    else:
        findings.append(
            f"No evidence that individual differences deviate from normality "
            f"(Shapiro-Wilk p = {shapiro_p:.3f})."
        )

    return findings


def _print_interpretation(key_findings):
    """Print interpretation guidance, numbered key findings and future directions."""
    import textwrap

    print("\n" + "="*80)
    print("INTERPRETATION AND FUTURE DIRECTIONS")
    print("="*80)

    lines = [
        "",
        "INTERPRETATION GUIDANCE:",
        "• The cluster permutation test evaluates whether the probability distributions",
        "  for emotional and neutral conditions differ significantly across space and time.",
        "• A significant cluster indicates that the null hypothesis (no difference between",
        "  conditions) can be rejected, but does NOT indicate effects at specific time points.",
        "• Effect sizes should be considered alongside p-values for practical significance.",
        "",
        "KEY FINDINGS:",
    ]
    for number, finding in enumerate(key_findings, start=1):
        # Wrap long data-driven sentences; keep ranges like "100-200 ms" intact.
        lines.append(textwrap.fill(f"{number}. {finding}", width=80,
                                   subsequent_indent="   ",
                                   break_on_hyphens=False))
    lines += [
        "",
        "FUTURE DIRECTIONS:",
        "1. Larger sample sizes needed to reliably detect small effects",
        "2. Explore individual differences more systematically",
        "3. Consider alternative frequency bands or analysis approaches",
        "4. Replicate findings with different emotional stimuli paradigms",
        "",
    ]
    print("\n".join(lines))


def main():
    """Run the full theta-band analysis pipeline and print a summary report.

    Steps:
      1. Load per-subject Emotional - Neutral data (collect_subject_differences).
      2. Planned comparisons in the a priori time windows.
      3. Cluster-based permutation test.
      4. Create all figures.
      5. Print a statistical summary whose key findings are derived from the
         results computed in steps 2-3 (not from hard-coded text).

    Returns None; returns early if data loading fails.
    """
    print("\n" + "="*80)
    print("THETA BAND OSCILLATORY POWER ANALYSIS - EMOTIONAL FACE PROCESSING")
    print("="*80)

    # 1. Load and process data
    print("\n1. LOADING AND PROCESSING DATA")
    data, times, channels, info, subject_ids = collect_subject_differences()

    if data is None:
        print("❌ Failed to load data. Exiting.")
        return

    # 2. Planned comparisons in a priori time windows
    print("\n2. PLANNED COMPARISONS (A PRIORI TIME WINDOWS)")
    planned_results = run_planned_comparisons(data, times, channels)

    # 3. Cluster permutation test
    print("\n3. CLUSTER-BASED PERMUTATION TEST")
    print("   (Maris & Oostenveld, 2007)")
    cluster_results = run_cluster_permutation_test(data, times, n_permutations=1000)

    # 4. Create figures
    print("\n4. CREATING PUBLICATION-QUALITY FIGURES")

    # Figure 1: Original style grand average time course
    create_figure_grand_average(data, times, channels, cluster_results)

    # Figure 2: Old-style grand average time course
    create_figure_grand_average_old_style(data, times, channels, cluster_results)

    # Figure 3: Old-style effect sizes
    create_effect_size_figure_old_style(planned_results)

    # Figure 4: Individual differences
    create_figure_individual_differences(data, times, subject_ids)

    # Figure 5: Effect size comparison
    create_figure_effect_size_comparison(planned_results, cluster_results)

    # Figure 6: Standalone theta topography across the a priori time windows
    plot_theta_topography_standalone(data, times, channels, info)

    # Figure 7: Cluster topography (only meaningful if significant clusters exist)
    if cluster_results['significant_clusters']:
        create_figure_cluster_topography(data, times, channels, info, cluster_results)
    else:
        print("⚠️  No significant clusters found - skipping cluster topography figure")

    # Figure 8: Composite figure (old style)
    create_composite_results_figure_old_style(data, times, channels, planned_results, cluster_results)

    # 5. Generate summary report
    print("\n5. GENERATING STATISTICAL SUMMARY")
    print("\n" + "="*80)
    print("SUMMARY OF KEY FINDINGS")
    print("="*80)

    _print_planned_comparisons(planned_results)
    _print_cluster_summary(cluster_results)
    subject_means, shapiro_p = _print_individual_variability(data)

    key_findings = _build_key_findings(planned_results, cluster_results,
                                       subject_means, shapiro_p)
    _print_interpretation(key_findings)

    print(f"\n✓ Analysis complete! Figures saved to: {fig_dir}/")
    print("✓ Check the directory for publication-ready figures in PNG and PDF format.")

if __name__ == "__main__":
    # Entry point: run the full analysis pipeline only when this file is
    # executed directly, not when it is imported to reuse its functions.
    main()