In [None]:
import mne
import os
import numpy as np
import zipfile
from mne.preprocessing import ICA
from mne.time_frequency import tfr_morlet
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from mne_icalabel import label_components
import Rest_State_Crop as rsc
import Test_Crop as tc
from mne.time_frequency import AverageTFR
from mne.stats import permutation_cluster_1samp_test
from neurora.stuff import (
    clusterbased_permutation_1d_1samp_1sided, 
    permutation_test,
    clusterbased_permutation_2d_1samp_2sided,
    clusterbased_permutation_2d_2sided
)
import sys  # Used to exit the program

# Select matplotlib's 'QtAgg' (interactive Qt) backend; plt.show() calls below open windows through it.
matplotlib.use('QtAgg')

# Condition names; each one is also the sub-directory name loaded by link_participants_data in __main__.
stimuli_type = ['Reward_Cases','Punish_Cases','Reward_Avatar','Punish_Avatar']

def channel_cut_todata(power, channels):
    """
    Return the rows of ``power.data`` that belong to the requested channels.

    Parameters:
    - power: MNE TFR object; only ``power.info['ch_names']`` and ``power.data`` are read.
    - channels: List of channel names, in the order the rows should be returned.

    Returns:
    - selected_power_data: ``power.data`` indexed by the channels' positions, so its first axis
      has len(channels) entries (e.g. (len(channels), n_freqs, n_times) for averaged TFR data).

    Raises:
    - ValueError when a requested channel is absent, AttributeError when ``power`` lacks the
      needed attributes; any other error is re-raised too. Each error is printed first.
    """
    try:
        print(f"Requested channels for extraction: {channels}")

        available = power.info['ch_names']
        print(f"Available channels in power object: {available}")

        absent = [name for name in channels if name not in available]
        if absent:
            raise ValueError(f"The following channels are missing: {absent}")

        positions = [available.index(name) for name in channels]
        print(f"Indices of requested channels: {positions}")

        subset = power.data[positions]
        print(f"Extracted data shape: {subset.shape}")
        return subset
    except ValueError as err:
        print(f"Error: {err}")
        raise
    except AttributeError:
        print("Error: The power object does not have the required attributes. Please check the input.")
        raise
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
        raise

def link_participants_data(dir: str, type: str):
    """
    Load every '.fif' epochs file in ``<dir>/<type>`` and compute its Morlet TFR power.

    Parameters:
    - dir: Root data directory.
    - type: Condition sub-directory name (e.g. one of ``stimuli_type``).

    Returns:
    - dict mapping file name -> TFR power object from ``tfr_morlet`` with a (-0.3, -0.1) s
      logratio baseline applied. Files are processed in sorted name order, so the dict order
      (and any average computed over it) is reproducible across runs and file systems.

    Raises:
    - ValueError if ``<dir>/<type>`` is not an existing directory. Errors for individual files
      are printed and that file is skipped (best-effort loading).
    """
    tfa_results = {}
    full_dir = os.path.join(dir, type)
    # isdir rather than exists: a regular file at this path would otherwise pass and fail in listdir.
    if not os.path.isdir(full_dir):
        raise ValueError(f"Directory does not exist: {full_dir}")

    # 10 log-spaced frequencies from 4 to 30 Hz with n_cycles = freqs / 2; identical for every file.
    freqs = np.logspace(*np.log10([4, 30]), num=10)
    n_cycles = freqs / 2.
    baseline = (-0.3, -0.1)  # Adjust baseline time window as needed

    # sorted(): os.listdir order is arbitrary and platform dependent.
    for item in sorted(os.listdir(full_dir)):
        if not item.endswith('.fif'):
            continue
        epochs_fname = os.path.join(full_dir, item)
        print(f"Loading file: {epochs_fname}")
        try:
            epochs = mne.read_epochs(epochs_fname, preload=True)

            # Validate epochs integrity
            if epochs.info is None:
                raise ValueError(f"Epochs info is None for file: {item}")

            power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True, return_itc=False)

            # Validate TFR data integrity
            for attr, message in (("data", "TFR data is None"),
                                  ("info", "TFR info is None"),
                                  ("times", "TFR times are None"),
                                  ("freqs", "TFR freqs are None")):
                if getattr(power, attr) is None:
                    raise ValueError(f"{message} for file: {item}")

            # Apply baseline correction (in place on the TFR object)
            print(f"Applying baseline correction: {baseline}")
            power.apply_baseline(baseline=baseline, mode='logratio')

            tfa_results[item] = power
        except Exception as e:
            print(f"Error processing file {item}: {e}")
    return tfa_results

def save_progress(figname, data, msg="Saving current progress due to error"):
    """
    Best-effort dump of ``data`` to '<figname>_progress.npy' for debugging and recovery.

    ``msg`` is printed first. Any failure while saving (bad path, non-string ``figname``, ...)
    is printed and swallowed, so this is safe to call from inside error handlers. Non-array
    ``data`` such as a dict is stored by ``np.save`` as a pickled object array.
    """
    try:
        print(msg)
        target = figname + "_progress.npy"
        np.save(target, data)
        print(f"Progress saved to {target}")
    except Exception as exc:
        print(f"Error while saving progress: {exc}")

def plot_tfr_results(figname, tfr, freqs, times, p=0.05, clusterp=0.05, clim=[-4, 4]):
    """
    Cluster-test ``tfr`` against zero and save a heatmap of its first-axis mean with the
    significant clusters outlined.

    Parameters:
    figname : Output path. The figure is written to '<dir>/clustered_<basename>'; when the
        basename has no extension, '.svg' is appended (otherwise matplotlib would write its
        default format under an extension-less name).
    tfr : Matrix with shape [n_channels, n_freqs, n_times], representing time-frequency analysis
        results. The first axis is the observation axis of the one-sample test (needs >= 2 rows).
    freqs : Array with shape [n_freqs], frequencies in Hz (y-axis); may be non-uniformly
        (e.g. log-) spaced.
    times : Array with shape [n_times], time points in seconds (x-axis).
    p : A floating-point number, default is 0.05. Two-sided p-value from which the
        cluster-forming t threshold is derived.
    clusterp : A floating-point number, default is 0.05, representing the cluster-level p-value threshold
    clim : A list or array, [minimum, maximum], default is [-4, 4], representing the upper and lower bounds of the color bar

    On any error the input ``tfr`` is saved via ``save_progress`` and the process is stopped
    with ``sys.exit``.
    """
    from scipy import stats  # only needed to turn ``p`` into a t threshold

    try:
        print("=== Debug: Starting plot_tfr_results ===")
        print(f"Input tfr shape: {tfr.shape}")
        print(f"Freqs: {freqs}")
        print(f"Times: {times}")
        print(f"P-value threshold: {p}")
        print(f"Cluster p-value threshold: {clusterp}")
        print(f"Color limits: {clim}")
        print(f"Saving figure as: {figname}")

        n_channels, n_freqs, n_times = tfr.shape
        print(f"Channels: {n_channels}, Freqs: {n_freqs}, Times: {n_times}")
        if n_channels < 2:
            raise ValueError(
                f"Need at least 2 observations along the first axis for the permutation test, got {n_channels}.")

        # Average over channels
        print("Averaging over channels...")
        tfr_mean = np.mean(tfr, axis=0)  # shape will be (n_freqs, n_times)
        print(f"Averaged TFR shape: {tfr_mean.shape}")

        # `threshold` of permutation_cluster_1samp_test is a t statistic, not a p-value: passing p
        # itself (e.g. 0.01) lets almost every point into a cluster. Convert the two-sided p
        # (tail=0) into the critical t for n_channels - 1 degrees of freedom.
        dof = n_channels - 1
        t_threshold = stats.t.ppf(1 - p / 2, df=dof)
        print(f"Cluster-forming t threshold (two-sided p={p}, df={dof}): {t_threshold:.3f}")

        # Statistical analysis
        print("Performing cluster-based permutation test...")
        T_obs, clusters, cluster_p_values, H0 = permutation_cluster_1samp_test(
            tfr, n_permutations=1000, threshold=t_threshold, tail=0, n_jobs=1)
        print("Cluster-based permutation test completed.")
        print(f"Number of clusters found: {len(clusters)}")
        print(f"Cluster p-values: {cluster_p_values}")

        # Create significance matrix
        stats_results = np.zeros((n_freqs, n_times))
        for cl, p_val in zip(clusters, cluster_p_values):
            if p_val < clusterp:
                stats_results[cl] = 1  # Mark significant clusters
        print(f"Significance matrix created. Shape: {stats_results.shape}")

        # Visualize time-frequency analysis results
        print("Creating heatmap and contour plot...")
        fig, ax = plt.subplots(1, 1)

        # Heatmap: pcolormesh places each row/column at its real freq/time value. imshow with an
        # extent assumes uniform spacing, which misplaces log-spaced frequency rows relative to
        # the contour drawn below at the true frequencies.
        im = ax.pcolormesh(times, freqs, tfr_mean, cmap='RdBu_r', shading='nearest',
                           vmin=clim[0], vmax=clim[1])

        # Outline significant regions; the zero padding closes contours at the data edges.
        padsats_results = np.zeros([n_freqs + 2, n_times + 2])
        padsats_results[1:n_freqs + 1, 1:n_times + 1] = stats_results
        x = np.concatenate(([times[0] - 1], times, [times[-1] + 1]))
        y = np.concatenate(([freqs[0] - 1], freqs, [freqs[-1] + 1]))
        X, Y = np.meshgrid(x, y)
        ax.contour(X, Y, padsats_results, [0.5], colors="red", alpha=0.9,
                   linewidths=2, linestyles="dashed")

        # Keep the view on the data range (the contour padding would otherwise widen the axes).
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(freqs[0], freqs[-1])
        ax.set_aspect('auto')
        cbar = fig.colorbar(im)
        cbar.set_label('dB', fontsize=12)
        ax.set_xlabel('Time (s)', fontsize=16)
        ax.set_ylabel('Frequency (Hz)', fontsize=16)

        # Save as '<dir>/clustered_<basename>', defaulting to SVG when no extension was given.
        directory, base_name = os.path.split(figname)
        root, ext = os.path.splitext(base_name)
        clustered_figname = os.path.join(directory, f"clustered_{root}{ext or '.svg'}")
        print(f"Saving heatmap to: {clustered_figname}")
        fig.savefig(clustered_figname, dpi=600)
        plt.close(fig)  # Close the current figure for subsequent plotting

        print("Plotting completed successfully.")
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, tfr)
        print(f"Error in plot_tfr_results: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during plotting. Process stopped and progress saved.")

def plot_TFA(dict_object: dict, channels: list, title: str, figname:str, baseline=None, mode='logratio'):
    """
    Average the TFR objects in ``dict_object``, plot the mean over ``channels`` and save it as
    'TFA/<figname>.svg' (the 'TFA' folder is created if needed). Returns 0 on success.

    Parameters:
    - dict_object: mapping name -> TFR object; info, times and freqs are taken from the first one.
    - channels: channel names averaged together in the plot (``combine='mean'``).
    - title: plot title.
    - figname: file name (without extension) inside 'TFA'.
    - baseline: baseline window forwarded to ``AverageTFR.plot``. Defaults to None because
      ``link_participants_data`` in this cell already applies a (-0.3, -0.1) logratio baseline;
      applying 'logratio' a second time takes log10 of values that can be negative (NaN).
      Pass (-0.3, -0.1) explicitly for data that has not been baseline-corrected.
    - mode: baseline mode forwarded to ``AverageTFR.plot``.

    On any error, progress is saved via ``save_progress`` and the process exits via ``sys.exit``.
    """
    try:
        print(f"Starting plot_TFA for {figname}")
        tfr_list = list(dict_object.values())
        if not tfr_list:
            raise ValueError("Input dictionary is empty. Please check your input data.")
        mean_tfr_data = np.mean([tfr.data for tfr in tfr_list], axis=0)

        first_tfr = tfr_list[0]  # Use the info/times/freqs from the first TFR
        mean_tfr = AverageTFR(first_tfr.info, mean_tfr_data, first_tfr.times,
                              first_tfr.freqs, len(tfr_list))
        fig, ax = plt.subplots(figsize=(10, 6))
        # AverageTFR.plot returns a *list* of figures; keep it apart from `fig` because
        # plt.close() rejects a list (the old code raised TypeError here on every call).
        figs = mean_tfr.plot(picks=channels, baseline=baseline, mode=mode, title=title,
                             combine='mean', axes=ax)
        print(figs)

        # Save before showing: with the interactive backend plt.show() may block until the
        # window is closed.
        os.makedirs('TFA', exist_ok=True)
        fulfigname = os.path.join('TFA', f"{figname}.svg")
        figs[0].savefig(fulfigname)  # Save as SVG format
        plt.show()

        for shown in figs:
            plt.close(shown)
        plt.close(fig)  # no-op if it was already closed above
        return 0
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, dict_object)
        print(f"Error in plot_TFA: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during TFA plotting. Process stopped and progress saved.")

def plot_avg_heatmap(dict_object: dict, channels: list, title: str, figname: str, output_dir: str, sfreq: float):
    """
    Average the TFR data over all entries of ``dict_object``, keep ``channels`` and hand the
    result to ``plot_tfr_results``, which saves '<output_dir>/clustered_<figname>.svg'.

    Parameters:
        dict_object (dict): Mapping name -> TFR object (needs .data, .freqs, .times, .info).
        channels (list): Channel names whose rows are tested and plotted.
        title (str): Chart title (currently unused).
        figname (str): Base file name, without extension, for the saved image.
        output_dir (str): Directory for image output.
        sfreq (float): Data sampling frequency (currently unused).

    NOTE(review): the data are averaged across dictionary entries (subjects) *before*
    plot_tfr_results runs its one-sample test, so the selected channels act as the test's
    observations. A subject-level test would stack one channel-averaged map per subject;
    confirm which is intended.

    On any error, progress is saved via ``save_progress`` and the process exits via ``sys.exit``.
    """
    try:
        print(f"Starting plot_avg_heatmap for {figname}")

        # Validate input
        if not dict_object:
            raise ValueError("Input dictionary is empty. Please check your input data.")
        print(f"Input TFR dictionary size: {len(dict_object)}")

        # Explicit None checks: the truth value of a TFR object is not a validity test.
        valid_tfrs = {key: tfr for key, tfr in dict_object.items()
                      if tfr is not None and tfr.data is not None}
        print(f"Number of valid TFR objects: {len(valid_tfrs)}")
        if not valid_tfrs:
            raise ValueError("No valid TFR objects found.")

        # Extract all valid TFR data and calculate the average
        tfr_data = [tfr.data for tfr in valid_tfrs.values()]
        shapes = {data.shape for data in tfr_data}
        if len(shapes) > 1:
            # Averaging a ragged list does not produce a meaningful mean.
            raise ValueError(f"Inconsistent TFR data shapes across entries: {sorted(shapes)}")
        mean_tfr_data = np.mean(tfr_data, axis=0)
        print(f"Mean TFR data shape: {mean_tfr_data.shape}")
        print(f"Mean TFR data (sample): {mean_tfr_data.flatten()[:10]}")

        # Extract frequency and time information from the first valid TFR
        first_tfr = next(iter(valid_tfrs.values()))
        freqs = first_tfr.freqs
        times = first_tfr.times
        print(f"Freqs shape: {freqs.shape}, values: {freqs}")
        print(f"Times shape: {times.shape}, values: {times[:10]}")

        # Validate channel information
        channel_names = first_tfr.info['ch_names']
        print(f"Available channels in TFR: {channel_names}")
        missing_channels = [ch for ch in channels if ch not in channel_names]
        if missing_channels:
            raise ValueError(f"Missing channels from TFR data: {missing_channels}")

        # Extract data for specified channels
        print(f"Extracting data for channels: {channels}")
        cn_indices = [channel_names.index(ch) for ch in channels]
        selected_power_data = mean_tfr_data[cn_indices]
        print(f"Selected power data shape: {selected_power_data.shape}")

        # Plotting
        try:
            print(f"Calling plot_tfr_results for {figname}...")
            # Pass the full '.svg' path: plot_tfr_results saves 'clustered_<basename>', so keeping
            # the extension makes the output an SVG at exactly the path reported below.
            svg_path = os.path.join(output_dir, f"{figname}.svg")
            plot_tfr_results(svg_path, selected_power_data, freqs, times, p=0.01, clusterp=0.05, clim=[-1, 1])
            print(f"Heatmap saved to {os.path.join(output_dir, 'clustered_' + figname + '.svg')}")
        except Exception as e:
            # If an error occurs, save progress and exit the program
            save_progress(figname, selected_power_data)
            print(f"Error while plotting heatmap for {figname}: {e}")
            import traceback
            print(traceback.format_exc())
            sys.exit("Error occurred while plotting heatmap. Process stopped and progress saved.")
    except Exception as e:
        # If an error occurs during data extraction or processing, save progress and exit the program
        save_progress(figname, dict_object)
        print(f"Error in plot_avg_heatmap: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred in plot_avg_heatmap. Process stopped and progress saved.")

def plot_tfr_diff_results(figname, tfr1, tfr2, freqs, times, p=0.05, clusterp=0.05, clim=[-2, 2]):
    """
    Compare ``tfr1`` with ``tfr2`` using neurora's 2-D two-sided cluster permutation test and
    save a heatmap of mean(tfr1) - mean(tfr2) with the test map outlined: red contour at the
    +0.5 level, blue at -0.5 (presumably positive/negative clusters; confirm against neurora).

    Parameters:
    figname : Base path; the figure is saved as '<dir>/<basename without extension>_diff.svg'.
    tfr1 : Matrix with shape [n_subs, n_freqs, n_times], representing time-frequency analysis results for condition 1
    tfr2 : Matrix with shape [n_subs, n_freqs, n_times], representing time-frequency analysis results for condition 2
    freqs : Array with shape [n_freqs], frequencies in Hz (y-axis); may be non-uniformly spaced.
    times : Array with shape [n_times], time points in seconds (x-axis).
    p : A floating-point number, default is 0.05, passed as the point-wise ``p_threshold``
    clusterp : A floating-point number, default is 0.05, representing the cluster-level p-value threshold
    clim : A list or array, [minimum, maximum], default is [-2, 2], representing the upper and lower bounds of the color bar

    On any error both inputs are saved via ``save_progress`` and the process exits via ``sys.exit``.
    """
    try:
        print("=== Debug: Starting plot_tfr_diff_results ===")
        print(f"tfr1 shape: {tfr1.shape}")
        print(f"tfr2 shape: {tfr2.shape}")
        print(f"Freqs: {freqs}")
        print(f"Times: {times}")
        print(f"P-value threshold: {p}")
        print(f"Cluster p-value threshold: {clusterp}")
        print(f"Color limits: {clim}")
        print(f"Saving figure as: {figname}")

        n_freqs = len(freqs)
        n_times = len(times)
        print(f"Number of frequencies: {n_freqs}, Number of times: {n_times}")

        # Statistical analysis
        print("Performing cluster-based permutation test for differences...")
        stats_results = clusterbased_permutation_2d_2sided(
            tfr1, tfr2,
            p_threshold=p,
            clusterp_threshold=clusterp,
            iter=1000
        )
        print("Cluster-based permutation test for differences completed.")
        print(f"Stats results shape: {stats_results.shape}")

        # Calculate TFR difference
        tfr_diff = np.mean(tfr1, axis=0) - np.mean(tfr2, axis=0)
        print(f"TFR difference shape: {tfr_diff.shape}")

        # Visualize time-frequency analysis results
        print("Creating difference heatmap and contour plot...")
        fig, ax = plt.subplots(1, 1)

        # Heatmap: pcolormesh places each row/column at its real freq/time value; imshow with an
        # extent assumes uniform spacing and misplaces log-spaced frequency rows relative to the
        # contours drawn below.
        im = ax.pcolormesh(times, freqs, tfr_diff, cmap='RdBu_r', shading='nearest',
                           vmin=clim[0], vmax=clim[1])

        # Outline significant regions; the zero padding closes contours at the data edges.
        padsats_results = np.zeros([n_freqs + 2, n_times + 2])
        padsats_results[1:n_freqs + 1, 1:n_times + 1] = stats_results
        x = np.concatenate(([times[0] - 1], times, [times[-1] + 1]))
        y = np.concatenate(([freqs[0] - 1], freqs, [freqs[-1] + 1]))
        X, Y = np.meshgrid(x, y)
        ax.contour(X, Y, padsats_results, [0.5], colors="red", alpha=0.9,
                   linewidths=2, linestyles="dashed")
        ax.contour(X, Y, padsats_results, [-0.5], colors="blue", alpha=0.9,
                   linewidths=2, linestyles="dashed")

        # Keep the view on the data range (the contour padding would otherwise widen the axes).
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(freqs[0], freqs[-1])
        ax.set_aspect('auto')
        cbar = fig.colorbar(im)
        cbar.set_label(r'$\Delta$dB', fontsize=12)  # raw string: '\D' is not a valid escape
        ax.set_xlabel('Time (s)', fontsize=16)  # TFR times are in seconds
        ax.set_ylabel('Frequency (Hz)', fontsize=16)

        # Save as '<dir>/<basename without extension>_diff.svg'
        directory = os.path.dirname(figname)
        base_name = os.path.basename(figname)
        diff_name = f"{os.path.splitext(base_name)[0]}_diff.svg"
        diff_figname = os.path.join(directory, diff_name)
        print(f"Saving difference heatmap to: {diff_figname}")
        fig.savefig(diff_figname, dpi=600)
        plt.close(fig)  # Close the current figure

        print("Difference plotting completed successfully.")
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, {'tfr1': tfr1, 'tfr2': tfr2})
        print(f"Error in plot_tfr_diff_results: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during difference plotting. Process stopped and progress saved.")

def plot_avg_heatmap_diff(dict_object1: dict, dict_object2: dict, title: str, figname: str, output_dir: str, sfreq: float):
    """
    Average each condition's TFR data over its dictionary entries and test/plot the
    difference via ``plot_tfr_diff_results``, which saves '<output_dir>/<figname>_diff.svg'.

    Parameters:
        dict_object1 (dict): TFR data dictionary for condition 1.
        dict_object2 (dict): TFR data dictionary for condition 2.
        title (str): Chart title (currently unused).
        figname (str): Base file name, without extension, for the saved image.
        output_dir (str): Directory for image output.
        sfreq (float): Data sampling frequency (currently unused).

    NOTE(review): no channel selection is made, and both conditions are averaged over entries
    (subjects) before testing. plot_tfr_diff_results therefore receives
    [n_channels, n_freqs, n_times] arrays and uses channels where its documented first axis is
    n_subs. Confirm this is intended.

    On any error, progress is saved via ``save_progress`` and the process exits via ``sys.exit``.
    """
    try:
        print(f"Starting plot_avg_heatmap_diff for {figname}")

        if not dict_object1 or not dict_object2:
            raise ValueError("One of the input dictionaries is empty.")
        print("Both input dictionaries are non-empty.")

        # Extract TFR data (both dicts are non-empty, so both lists are too)
        print("Extracting TFR data for both conditions...")
        tfr1_data = [tfr.data for tfr in dict_object1.values()]
        tfr2_data = [tfr.data for tfr in dict_object2.values()]
        print(f"Number of TFR1 data: {len(tfr1_data)}, Number of TFR2 data: {len(tfr2_data)}")

        # Check data consistency
        shapes1 = [tfr.shape for tfr in tfr1_data]
        shapes2 = [tfr.shape for tfr in tfr2_data]
        print(f"TFR1 data shapes: {shapes1}")
        print(f"TFR2 data shapes: {shapes2}")
        if len(set(shapes1)) > 1 or len(set(shapes2)) > 1:
            raise ValueError("Inconsistent TFR data shapes within conditions.")
        print("All TFR data within each condition have consistent shapes.")
        # The two condition means are compared element-wise, so their shapes must match as well.
        if shapes1[0] != shapes2[0]:
            raise ValueError(f"TFR data shapes differ between conditions: {shapes1[0]} vs {shapes2[0]}")

        # Difference calculation
        print("Calculating TFR difference...")
        mean_tfr1_data = np.mean(tfr1_data, axis=0)
        mean_tfr2_data = np.mean(tfr2_data, axis=0)
        print(f"Mean TFR1 shape: {mean_tfr1_data.shape}, Mean TFR2 shape: {mean_tfr2_data.shape}")

        # Validate frequency and time: the axes come from condition 1, so condition 2 must match.
        first_tfr1 = next(iter(dict_object1.values()))
        first_tfr2 = next(iter(dict_object2.values()))
        freqs = first_tfr1.freqs
        times = first_tfr1.times
        if not (np.allclose(freqs, first_tfr2.freqs) and np.allclose(times, first_tfr2.times)):
            raise ValueError("Frequency or time axes differ between the two conditions.")
        print(f"Freqs: {freqs}, Times: {times}")

        try:
            # Plot difference heatmap
            print("Calling plot_tfr_diff_results...")
            # plot_tfr_diff_results drops the extension and appends '_diff.svg', so the path is
            # passed as-is (str.replace('.svg', '') would also hit '.svg' elsewhere in the path).
            diff_figname = os.path.join(output_dir, f"{figname}.svg")
            plot_tfr_diff_results(diff_figname, mean_tfr1_data, mean_tfr2_data, freqs, times, p=0.01, clusterp=0.05, clim=[-2, 2])
            saved_path = os.path.join(output_dir, f"{figname}_diff.svg")
            print(f"Difference heatmap saved to {saved_path}")
        except Exception as e:
            # If an error occurs, save progress and exit the program
            save_progress(figname, {'mean_tfr1': mean_tfr1_data, 'mean_tfr2': mean_tfr2_data})
            print(f"Error while plotting difference heatmap for {figname}: {e}")
            import traceback
            print(traceback.format_exc())
            sys.exit("Error occurred while plotting difference heatmap. Process stopped and progress saved.")
    except Exception as e:
        # If an error occurs during data extraction or processing, save progress and exit the program
        save_progress(figname, {'dict_object1': dict_object1, 'dict_object2': dict_object2})
        print(f"Error in plot_avg_heatmap_diff: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred in plot_avg_heatmap_diff. Process stopped and progress saved.")

def get_avg_power(power, channels, wave, tmin=0, tmax=4):
    """
    Return the mean TFR value over ``channels``, a named frequency band and [tmin, tmax].

    Parameters:
    - power: TFR object supporting ``copy()``, ``pick_channels()`` and ``crop()``.
    - channels: channel names forming the region of interest.
    - wave: band name; 'delta' (1-4 Hz), 'theta' (4-8 Hz), 'alpha' (8-13 Hz) or 'beta' (13-30 Hz).
      With this file's 4-30 Hz frequency grid, 'delta' only covers the 4 Hz bin.
    - tmin, tmax: time window passed to ``crop`` (default 0 to 4).

    Returns:
    - scalar mean of the cropped data.

    Raises:
    - ValueError for an unrecognised band name. Previously 'delta' selected 0-0 Hz (no data)
      and any other unknown name was silently treated as beta.
    """
    bands = {
        'delta': (1, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
    }
    try:
        hmin, hmax = bands[wave]
    except KeyError:
        raise ValueError(f"Unknown wave band {wave!r}; expected one of {sorted(bands)}") from None

    # copy() so the caller's TFR object is left untouched; pick/crop then act on the copy.
    roi_power = power.copy().pick_channels(channels)
    roi_band = roi_power.crop(fmin=hmin, fmax=hmax, tmin=tmin, tmax=tmax)
    return roi_band.data.mean()

def get_group_power(dict_type, channels, wave, subjects=range(1, 60), id_offset=300, missing_value=1):
    """
    Compute ``get_avg_power`` for every subject, in subject order.

    Parameters:
    - dict_type: dict mapping 's<subject + id_offset>-epo.fif' -> TFR object.
    - channels, wave: forwarded to ``get_avg_power``.
    - subjects: subject numbers to process (default 1..59, i.e. file IDs 301..359).
    - id_offset: added to each subject number to build the file ID (default 300).
    - missing_value: value appended for subjects absent from ``dict_type`` (default 1).
      NOTE(review): 1 is indistinguishable from a real power value in downstream tables;
      float('nan') is probably safer. Confirm before changing the default.

    Returns:
    - list with one entry per subject.
    """
    powerlist = []
    for subject in subjects:
        print(f"Processing subject {subject}")
        strsub = f"s{subject + id_offset}-epo.fif"
        if strsub not in dict_type:
            print(f"{strsub} not found in dictionary. Assigning value {missing_value}.")
            powerlist.append(missing_value)
            continue
        powerlist.append(get_avg_power(dict_type[strsub], channels=channels, wave=wave))
    return powerlist

if __name__ == '__main__':
    try:
        # Input epochs root; each condition is a sub-directory named as in stimuli_type.
        ascent_air = 'Test_Epochs_ASRed/'
        # NOTE(review): placeholder output directory. The line previously read `output_dir = r`
        # (an undefined name, so the script stopped with NameError before loading anything).
        # Replace 'TFA' with the intended location.
        output_dir = 'TFA'

        # Check if output directory exists
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")

        # Load data
        print("Loading data...")
        R_Case_dict = link_participants_data(ascent_air, 'Reward_Cases')
        P_Case_dict = link_participants_data(ascent_air, 'Punish_Cases')
        R_Avatar_dict = link_participants_data(ascent_air, 'Reward_Avatar')
        P_Avatar_dict = link_participants_data(ascent_air, 'Punish_Avatar')

        # Confirm successful data loading
        if not (R_Case_dict and P_Case_dict and R_Avatar_dict and P_Avatar_dict):
            raise ValueError("One or more data dictionaries are empty. Please check your data source.")
        print(f"Data loaded: {len(R_Case_dict)}, {len(P_Case_dict)}, {len(R_Avatar_dict)}, {len(P_Avatar_dict)}")

        # Define channels
        lTPJ = ['CP5', 'P7', 'P3']
        rTPJ = ['CP6', 'P4', 'P8']
        lFC = ['F3', 'F7', 'FC5']
        rFC = ['F4', 'F8', 'FC6']
        sfreq = 500.0
        # Dictionaries in the same order as the module-level stimuli_type list.
        stimili_dict = [R_Case_dict, P_Case_dict, R_Avatar_dict, P_Avatar_dict]
        channelslist = [lTPJ, rTPJ, lFC, rFC]
        channelsname = ['lTPJ', 'rTPJ', 'lFC', 'rFC']
        wavelist = ['alpha', 'beta', 'theta']

        # Plot heatmaps
        print("Starting to plot heatmaps...")
        plot_avg_heatmap(P_Case_dict, lTPJ + rTPJ, "Punish_Case", "Punish_Case_sig", output_dir, sfreq)
        plot_avg_heatmap(P_Avatar_dict, lTPJ + rTPJ, "Punish_Avatar", "Punish_Ava_sig", output_dir, sfreq)
        plot_avg_heatmap(R_Case_dict, lTPJ + rTPJ, "Reward_Case", "Reward_Case_sig", output_dir, sfreq)
        plot_avg_heatmap(R_Avatar_dict, lTPJ + rTPJ, "Reward_Avatar", "Reward_Ava_sig", output_dir, sfreq)

        # Difference heatmaps
        print("Starting to plot difference heatmaps...")
        plot_avg_heatmap_diff(P_Case_dict, P_Avatar_dict, "Punish_Case - Punish_Avatar", "P_Case_P_Avatar", output_dir, sfreq)
        plot_avg_heatmap_diff(R_Case_dict, R_Avatar_dict, "Reward_Case - Reward_Avatar", "R_Case_R_Avatar", output_dir, sfreq)
        plot_avg_heatmap_diff(P_Case_dict, R_Case_dict, "Punish_Case - Reward_Case", "PRCase", output_dir, sfreq)
        plot_avg_heatmap_diff(P_Avatar_dict, R_Avatar_dict, "Punish_Avatar - Reward_Avatar", "PRAvatar", output_dir, sfreq)

        print("All heatmaps plotted successfully.")

        # Data extraction and computation
        print("Starting to process group power...")
        # Subject IDs 301..359: must match get_group_power's default subjects/id_offset.
        sublist = [subject + 300 for subject in range(1, 60)]
        df0 = pd.DataFrame({'sub_ID': sublist})

        for typename, typevalue in zip(stimuli_type, stimili_dict):
            for wave in wavelist:
                for chname, chvalue in zip(channelsname, channelslist):
                    column_name = f"{typename}_{chname}_{wave}"
                    print(f"Processing: {column_name}")
                    # Subjects missing from the dictionary get get_group_power's placeholder
                    # value (1 by default), not NaN and not an error.
                    df0[column_name] = get_group_power(typevalue, chvalue, wave)

        # Save processed data
        output_csv = os.path.join(output_dir, "group_power_results.csv")
        df0.to_csv(output_csv, index=False)
        print(f"Group power results saved to: {output_csv}")

    except Exception as e:
        # Top-level boundary: report any uncaught error and stop (nothing is saved here).
        print(f"An error occurred in the main program: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("An error occurred in the main program. Process stopped.")

In [None]:
import mne
import os
import numpy as np
import zipfile
from mne.preprocessing import ICA
from mne.time_frequency import tfr_morlet
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from mne_icalabel import label_components
import Rest_State_Crop as rsc
import Test_Crop as tc
from mne.time_frequency import AverageTFR
from mne.stats import permutation_cluster_1samp_test
from neurora.stuff import (
    clusterbased_permutation_1d_1samp_1sided, 
    permutation_test,
    clusterbased_permutation_2d_1samp_2sided,
    clusterbased_permutation_2d_2sided
)
import sys  # Used to exit the program

# Select matplotlib's 'QtAgg' (interactive Qt) backend, as in the first cell.
matplotlib.use('QtAgg')

# Condition names (same list as the first cell); used as condition sub-directory names.
stimuli_type = ['Reward_Cases','Punish_Cases','Reward_Avatar','Punish_Avatar']

def channel_cut_todata(power, channels):
    """
    Return the rows of ``power.data`` that belong to the requested channels.

    Parameters:
    - power: MNE TFR object; only ``power.info['ch_names']`` and ``power.data`` are read.
    - channels: List of channel names, in the order the rows should be returned.

    Returns:
    - selected_power_data: ``power.data`` indexed by the channels' positions, so its first axis
      has len(channels) entries (e.g. (len(channels), n_freqs, n_times) for averaged TFR data).

    Raises:
    - ValueError when a requested channel is absent, AttributeError when ``power`` lacks the
      needed attributes; any other error is re-raised too. Each error is printed first.
    """
    try:
        print(f"Requested channels for extraction: {channels}")

        available = power.info['ch_names']
        print(f"Available channels in power object: {available}")

        absent = [name for name in channels if name not in available]
        if absent:
            raise ValueError(f"The following channels are missing: {absent}")

        positions = [available.index(name) for name in channels]
        print(f"Indices of requested channels: {positions}")

        subset = power.data[positions]
        print(f"Extracted data shape: {subset.shape}")
        return subset
    except ValueError as err:
        print(f"Error: {err}")
        raise
    except AttributeError:
        print("Error: The power object does not have the required attributes. Please check the input.")
        raise
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
        raise

def link_participants_data(dir: str, type: str):
    """
    Load every '.fif' epochs file in ``<dir>/<type>`` and compute its Morlet TFR power.

    Parameters:
    - dir: Root data directory.
    - type: Condition sub-directory name (e.g. one of ``stimuli_type``).

    Returns:
    - dict mapping file name -> TFR power object from ``tfr_morlet``, with NO baseline
      correction. Files are processed in sorted name order, so the dict order (and any
      average computed over it) is reproducible across runs and file systems.

    Raises:
    - ValueError if ``<dir>/<type>`` is not an existing directory. Errors for individual files
      are printed and that file is skipped (best-effort loading).
    """
    tfa_results = {}
    full_dir = os.path.join(dir, type)
    # isdir rather than exists: a regular file at this path would otherwise pass and fail in listdir.
    if not os.path.isdir(full_dir):
        raise ValueError(f"Directory does not exist: {full_dir}")

    # 10 log-spaced frequencies from 4 to 30 Hz with n_cycles = freqs / 2; identical for every file.
    freqs = np.logspace(*np.log10([4, 30]), num=10)
    n_cycles = freqs / 2.

    # sorted(): os.listdir order is arbitrary and platform dependent.
    for item in sorted(os.listdir(full_dir)):
        if not item.endswith('.fif'):
            continue
        epochs_fname = os.path.join(full_dir, item)
        print(f"Loading file: {epochs_fname}")
        try:
            epochs = mne.read_epochs(epochs_fname, preload=True)

            # Validate epochs integrity
            if epochs.info is None:
                raise ValueError(f"Epochs info is None for file: {item}")

            power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True, return_itc=False)

            # Validate TFR data integrity
            for attr, message in (("data", "TFR data is None"),
                                  ("info", "TFR info is None"),
                                  ("times", "TFR times are None"),
                                  ("freqs", "TFR freqs are None")):
                if getattr(power, attr) is None:
                    raise ValueError(f"{message} for file: {item}")

            # NOTE: baseline correction is intentionally not applied in this cell; if needed,
            # use e.g. power.apply_baseline(baseline=(-0.2, 0), mode='logratio').
            tfa_results[item] = power
        except Exception as e:
            print(f"Error processing file {item}: {e}")
    return tfa_results

def save_progress(figname, data, msg="Saving current progress due to error"):
    """
    Best-effort dump of ``data`` to '<figname>_progress.npy' for debugging and recovery.

    ``msg`` is printed first. Any failure while saving (bad path, non-string ``figname``, ...)
    is printed and swallowed, so this is safe to call from inside error handlers. Non-array
    ``data`` such as a dict is stored by ``np.save`` as a pickled object array.
    """
    try:
        print(msg)
        target = figname + "_progress.npy"
        np.save(target, data)
        print(f"Progress saved to {target}")
    except Exception as exc:
        print(f"Error while saving progress: {exc}")

def plot_tfr_results(figname, tfr, freqs, times, p=0.05, clusterp=0.05, clim=(-4, 4)):
    """
    Cluster-test ``tfr`` against zero and save a heatmap of its mean with significant
    clusters outlined in red.

    Parameters:
    figname : Output path. 'clustered_' is prepended to its base name before saving; if it
              has no extension, matplotlib uses its default savefig format.
    tfr : Matrix with shape [n_channels, n_freqs, n_times]. Axis 0 is averaged for the
          heatmap and is also the observation axis of the one-sample test, so it needs at
          least two entries.
    freqs : Array with shape [n_freqs]; y-axis values (Hz). May be non-uniformly spaced.
    times : Array with shape [n_times]; x-axis values (s).
    p : Two-sided p-value used to form clusters (default 0.05). It is converted to a t
        threshold with n_channels - 1 degrees of freedom, because MNE's ``threshold``
        argument is a test statistic, not an alpha level.
    clusterp : Cluster-level p-value below which a cluster is marked (default 0.05).
    clim : (minimum, maximum) of the color scale, default (-4, 4).

    On any error the input is written via save_progress() and the process exits.
    """
    try:
        print("=== Debug: Starting plot_tfr_results ===")
        print(f"Input tfr shape: {tfr.shape}")
        print(f"Freqs: {freqs}")
        print(f"Times: {times}")
        print(f"P-value threshold: {p}")
        print(f"Cluster p-value threshold: {clusterp}")
        print(f"Color limits: {clim}")
        print(f"Saving figure as: {figname}")

        n_channels, n_freqs, n_times = tfr.shape
        print(f"Channels: {n_channels}, Freqs: {n_freqs}, Times: {n_times}")
        if n_channels < 2:
            raise ValueError("Need at least two entries along axis 0 for a one-sample t-test.")

        # Average over channels
        print("Averaging over channels...")
        tfr_mean = np.mean(tfr, axis=0)  # shape will be (n_freqs, n_times)
        print(f"Averaged TFR shape: {tfr_mean.shape}")

        # MNE forms clusters from points whose |t| exceeds ``threshold``. Passing the
        # p-value itself (e.g. 0.01) would let almost every point pass, so convert it.
        from scipy import stats
        t_threshold = stats.t.ppf(1.0 - p / 2.0, df=n_channels - 1)
        print(f"Cluster-forming t threshold: {t_threshold:.3f}")

        # Statistical analysis
        print("Performing cluster-based permutation test...")
        _, clusters, cluster_p_values, _ = permutation_cluster_1samp_test(
            tfr, n_permutations=1000, threshold=t_threshold, tail=0, n_jobs=1)
        print("Cluster-based permutation test completed.")
        print(f"Number of clusters found: {len(clusters)}")
        print(f"Cluster p-values: {cluster_p_values}")

        # Create significance matrix (1 inside significant clusters, 0 elsewhere)
        stats_results = np.zeros((n_freqs, n_times))
        for cl, p_val in zip(clusters, cluster_p_values):
            if p_val < clusterp:
                stats_results[cl] = 1
        print(f"Significance matrix created. Shape: {stats_results.shape}")

        print("Creating heatmap and contour plot...")
        fig, ax = plt.subplots(1, 1)

        # pcolormesh places each row at its actual frequency. imshow with an extent spaces
        # rows evenly, which misaligns them with the contour when freqs are log-spaced.
        mesh = ax.pcolormesh(times, freqs, tfr_mean, cmap='RdBu_r', shading='nearest',
                             vmin=clim[0], vmax=clim[1])

        # Zero-pad the mask so clusters touching an edge still get a closed outline.
        padsats_results = np.zeros([n_freqs + 2, n_times + 2])
        padsats_results[1:n_freqs + 1, 1:n_times + 1] = stats_results
        x = np.concatenate(([times[0] - 1], times, [times[-1] + 1]))
        y = np.concatenate(([freqs[0] - 1], freqs, [freqs[-1] + 1]))
        X, Y = np.meshgrid(x, y)
        ax.contour(X, Y, padsats_results, [0.5], colors="red", alpha=0.9,
                   linewidths=2, linestyles="dashed")

        # The padded contour grid extends past the data; keep the view on the data range.
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(freqs[0], freqs[-1])
        cbar = fig.colorbar(mesh, ax=ax)
        cbar.set_label('dB', fontsize=12)
        ax.set_xlabel('Time (s)', fontsize=16)
        ax.set_ylabel('Frequency (Hz)', fontsize=16)

        # Save next to figname with a 'clustered_' prefix on the base name.
        directory = os.path.dirname(figname)
        base_name = os.path.basename(figname)
        clustered_figname = os.path.join(directory, f'clustered_{base_name}')
        print(f"Saving heatmap to: {clustered_figname}")
        fig.savefig(clustered_figname, dpi=600)
        plt.close(fig)  # Close the current figure for subsequent plotting

        print("Plotting completed successfully.")
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, tfr)
        print(f"Error in plot_tfr_results: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during plotting. Process stopped and progress saved.")

def plot_TFA(dict_object: dict, channels: list, title: str, figname:str):
    """
    Average the TFRs in ``dict_object`` and save MNE's TFR plot for ``channels``.

    The averaged TFR is plotted with a log-ratio baseline over (-0.3, -0.1) and the
    channels combined by their mean. It is saved to 'TFA/<figname>.svg' (the directory is
    created if needed) and then shown.

    Returns:
        int: 0 on success. On any error, progress is saved via save_progress() and the
        process exits.
    """
    try:
        print(f"Starting plot_TFA for {figname}")
        if not dict_object:
            raise ValueError("Input dictionary is empty. Please check your input data.")
        tfr_list = list(dict_object.values())
        mean_tfr_data = np.mean([tfr.data for tfr in tfr_list], axis=0)

        first_tfr = tfr_list[0]  # Use info/times/freqs from the first TFR
        # NOTE(review): positional (info, data, times, freqs, nave) — confirm against the
        # installed MNE version's AverageTFR constructor.
        mean_tfr = AverageTFR(first_tfr.info, mean_tfr_data, first_tfr.times,
                              first_tfr.freqs, len(tfr_list))
        canvas_fig, ax = plt.subplots(figsize=(10, 6))
        # AverageTFR.plot returns a list of figures.
        figs = mean_tfr.plot(picks=channels, baseline=(-0.3, -0.1), mode='logratio',
                             title=title, combine='mean', axes=ax)
        print(figs)

        # Save before plt.show(), so the file is written regardless of what happens to the
        # window afterwards.
        out_dir = 'TFA'
        os.makedirs(out_dir, exist_ok=True)
        fulfigname = os.path.join(out_dir, f"{figname}.svg")
        figs[0].savefig(fulfigname)  # Save as SVG format
        plt.show()

        # plt.close() accepts a single Figure, not a list; closing a figure twice is a no-op.
        for fig in [canvas_fig, *figs]:
            plt.close(fig)
        return 0
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, dict_object)
        print(f"Error in plot_TFA: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during TFA plotting. Process stopped and progress saved.")

def plot_avg_heatmap(dict_object: dict, channels: list, title: str, figname: str, output_dir: str, sfreq: float):
    """
    Average the requested channels across all TFR objects and plot a clustered heatmap.

    For every valid TFR, the requested channels are selected by name, so a differing
    channel order between files cannot mix channels. The selections are then averaged
    across TFRs and passed to plot_tfr_results(). That function writes a file whose base
    name is 'clustered_<figname>' into ``output_dir``.

    Parameters:
        dict_object (dict): Dictionary containing TFR objects.
        channels (list): Channel names to extract; every TFR must contain all of them.
        title (str): Chart title (currently not used by the plot).
        figname (str): Base file name for the saved image.
        output_dir (str): Directory for image output.
        sfreq (float): Data sampling frequency (currently unused).

    On any error, progress is saved via save_progress() and the process exits.
    """
    try:
        print(f"Starting plot_avg_heatmap for {figname}")

        if not dict_object:
            raise ValueError("Input dictionary is empty. Please check your input data.")
        print(f"Input TFR dictionary size: {len(dict_object)}")

        # Explicit None checks rather than relying on the TFR object's truthiness.
        valid_tfrs = {key: tfr for key, tfr in dict_object.items()
                      if tfr is not None and tfr.data is not None}
        print(f"Number of valid TFR objects: {len(valid_tfrs)}")
        if not valid_tfrs:
            raise ValueError("No valid TFR objects found.")

        # Frequency and time axes are taken from the first valid TFR.
        first_tfr = next(iter(valid_tfrs.values()))
        freqs = first_tfr.freqs
        times = first_tfr.times
        print(f"Freqs shape: {freqs.shape}, values: {freqs}")
        print(f"Times shape: {times.shape}, values: {times[:10]}")

        # Select the requested channels by name in every TFR before averaging.
        print(f"Extracting data for channels: {channels}")
        per_tfr_selection = []
        for key, tfr in valid_tfrs.items():
            channel_names = tfr.info['ch_names']
            missing_channels = [ch for ch in channels if ch not in channel_names]
            if missing_channels:
                raise ValueError(f"Missing channels from TFR data ({key}): {missing_channels}")
            cn_indices = [channel_names.index(ch) for ch in channels]
            per_tfr_selection.append(tfr.data[cn_indices])

        # Average across TFRs; numpy raises if the per-file (n_freqs, n_times) shapes differ.
        selected_power_data = np.mean(per_tfr_selection, axis=0)
        print(f"Selected power data shape: {selected_power_data.shape}")
        print(f"Selected power data (sample): {selected_power_data.flatten()[:10]}")

        try:
            print(f"Calling plot_tfr_results for {figname}...")
            plot_tfr_results(os.path.join(output_dir, figname), selected_power_data, freqs, times,
                             p=0.01, clusterp=0.05, clim=[-1, 1])
            print(f"Heatmap for {figname} written to {output_dir} with the 'clustered_' prefix")
        except Exception as e:
            # If an error occurs, save progress and exit the program
            save_progress(figname, selected_power_data)
            print(f"Error while plotting heatmap for {figname}: {e}")
            import traceback
            print(traceback.format_exc())
            sys.exit("Error occurred while plotting heatmap. Process stopped and progress saved.")
    except Exception as e:
        # If an error occurs during data extraction or processing, save progress and exit the program
        save_progress(figname, dict_object)
        print(f"Error in plot_avg_heatmap: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred in plot_avg_heatmap. Process stopped and progress saved.")

def plot_tfr_diff_results(figname, tfr1, tfr2, freqs, times, p=0.05, clusterp=0.05, clim=(-2, 2)):
    """
    Compare tfr1 with tfr2 using NeuroRA's two-sided 2-D cluster permutation test and save
    a heatmap of their mean difference. Positive clusters are outlined in red and negative
    clusters in blue.

    Parameters:
    figname : Output path; its extension (if any) is replaced by '_diff.svg'.
    tfr1 : Matrix with shape [n_obs, n_freqs, n_times] for condition 1; axis 0 is the
           observation axis of the permutation test.
    tfr2 : Matrix with the same shape, for condition 2.
    freqs : Array with shape [n_freqs]; y-axis values (Hz). May be non-uniformly spaced.
    times : Array with shape [n_times]; x-axis values. MNE TFR times are in seconds.
    p : Point-level p-value threshold passed to the cluster test (default 0.05).
    clusterp : Cluster-level p-value threshold (default 0.05).
    clim : (minimum, maximum) of the color scale, default (-2, 2).

    On any error both inputs are written via save_progress() and the process exits.
    """
    try:
        print("=== Debug: Starting plot_tfr_diff_results ===")
        print(f"tfr1 shape: {tfr1.shape}")
        print(f"tfr2 shape: {tfr2.shape}")
        print(f"Freqs: {freqs}")
        print(f"Times: {times}")
        print(f"P-value threshold: {p}")
        print(f"Cluster p-value threshold: {clusterp}")
        print(f"Color limits: {clim}")
        print(f"Saving figure as: {figname}")

        n_freqs = len(freqs)
        n_times = len(times)
        print(f"Number of frequencies: {n_freqs}, Number of times: {n_times}")

        # Statistical analysis: result is [n_freqs, n_times] with 1 / -1 marking
        # significant positive / negative points and 0 elsewhere.
        print("Performing cluster-based permutation test for differences...")
        stats_results = clusterbased_permutation_2d_2sided(
            tfr1, tfr2,
            p_threshold=p,
            clusterp_threshold=clusterp,
            iter=1000
        )
        print("Cluster-based permutation test for differences completed.")
        print(f"Stats results shape: {stats_results.shape}")

        tfr_diff = np.mean(tfr1, axis=0) - np.mean(tfr2, axis=0)
        print(f"TFR difference shape: {tfr_diff.shape}")

        print("Creating difference heatmap and contour plot...")
        fig, ax = plt.subplots(1, 1)

        # pcolormesh places each row at its actual frequency, so it lines up with the
        # contour below even when freqs are log-spaced (imshow with an extent does not).
        mesh = ax.pcolormesh(times, freqs, tfr_diff, cmap='RdBu_r', shading='nearest',
                             vmin=clim[0], vmax=clim[1])

        # Zero-pad the mask so clusters touching an edge still get a closed outline.
        padsats_results = np.zeros([n_freqs + 2, n_times + 2])
        padsats_results[1:n_freqs + 1, 1:n_times + 1] = stats_results
        x = np.concatenate(([times[0] - 1], times, [times[-1] + 1]))
        y = np.concatenate(([freqs[0] - 1], freqs, [freqs[-1] + 1]))
        X, Y = np.meshgrid(x, y)
        ax.contour(X, Y, padsats_results, [0.5], colors="red", alpha=0.9,
                   linewidths=2, linestyles="dashed")
        ax.contour(X, Y, padsats_results, [-0.5], colors="blue", alpha=0.9,
                   linewidths=2, linestyles="dashed")

        # The padded contour grid extends past the data; keep the view on the data range.
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(freqs[0], freqs[-1])
        cbar = fig.colorbar(mesh, ax=ax)
        cbar.set_label(r'$\Delta$dB', fontsize=12)  # raw string: '\D' is not a valid escape
        ax.set_xlabel('Time (s)', fontsize=16)
        ax.set_ylabel('Frequency (Hz)', fontsize=16)

        # Save as '<figname without extension>_diff.svg' in figname's directory.
        directory = os.path.dirname(figname)
        base_name = os.path.basename(figname)
        diff_name = f"{os.path.splitext(base_name)[0]}_diff.svg"
        diff_figname = os.path.join(directory, diff_name)
        print(f"Saving difference heatmap to: {diff_figname}")
        fig.savefig(diff_figname, dpi=600)
        plt.close(fig)  # Close the current figure

        print("Difference plotting completed successfully.")
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, {'tfr1': tfr1, 'tfr2': tfr2})
        print(f"Error in plot_tfr_diff_results: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during difference plotting. Process stopped and progress saved.")

def plot_avg_heatmap_diff(dict_object1: dict, dict_object2: dict, title: str, figname: str, output_dir: str, sfreq: float):
    """
    Average each condition's TFRs across entries and plot the difference between them.

    Each dictionary's TFR data is averaged across its entries. The two averages are
    passed to plot_tfr_diff_results(), which runs the cluster permutation test and writes
    a '_diff.svg' file into ``output_dir``.

    Parameters:
        dict_object1 (dict): TFR data dictionary for condition 1.
        dict_object2 (dict): TFR data dictionary for condition 2.
        title (str): Chart title (currently not used by the plot).
        figname (str): Base file name for the saved image.
        output_dir (str): Directory for image output.
        sfreq (float): Data sampling frequency (currently unused).

    NOTE(review): after averaging across entries, axis 0 of each array is whatever axis 0
    of ``tfr.data`` is (presumably channels). The downstream test then treats channels,
    not subjects, as its observations. Confirm this is intended.

    On any error, progress is saved via save_progress() and the process exits.
    """
    try:
        print(f"Starting plot_avg_heatmap_diff for {figname}")

        if not dict_object1 or not dict_object2:
            raise ValueError("One of the input dictionaries is empty.")
        print("Both input dictionaries are non-empty.")

        print("Extracting TFR data for both conditions...")
        tfr1_list = list(dict_object1.values())
        tfr2_list = list(dict_object2.values())
        tfr1_data = [tfr.data for tfr in tfr1_list]
        tfr2_data = [tfr.data for tfr in tfr2_list]
        print(f"Number of TFR1 data: {len(tfr1_data)}, Number of TFR2 data: {len(tfr2_data)}")

        # Shapes must agree within each condition and between the two conditions,
        # otherwise the means below cannot be subtracted.
        shapes1 = [data.shape for data in tfr1_data]
        shapes2 = [data.shape for data in tfr2_data]
        print(f"TFR1 data shapes: {shapes1}")
        print(f"TFR2 data shapes: {shapes2}")
        if len(set(shapes1)) > 1 or len(set(shapes2)) > 1:
            raise ValueError("Inconsistent TFR data shapes within conditions.")
        if shapes1[0] != shapes2[0]:
            raise ValueError(f"TFR data shapes differ between conditions: {shapes1[0]} vs {shapes2[0]}")
        print("All TFR data have consistent shapes within and across conditions.")

        # Validate that both conditions share the same frequency and time axes.
        first_tfr1, first_tfr2 = tfr1_list[0], tfr2_list[0]
        freqs = first_tfr1.freqs
        times = first_tfr1.times
        if not (np.allclose(freqs, first_tfr2.freqs) and np.allclose(times, first_tfr2.times)):
            raise ValueError("Frequency or time axes differ between conditions.")
        print(f"Freqs: {freqs}, Times: {times}")

        print("Calculating TFR difference...")
        mean_tfr1_data = np.mean(tfr1_data, axis=0)
        mean_tfr2_data = np.mean(tfr2_data, axis=0)
        print(f"Mean TFR1 shape: {mean_tfr1_data.shape}, Mean TFR2 shape: {mean_tfr2_data.shape}")

        try:
            print("Calling plot_tfr_diff_results...")
            diff_base = os.path.join(output_dir, figname)
            plot_tfr_diff_results(diff_base, mean_tfr1_data, mean_tfr2_data, freqs, times,
                                  p=0.01, clusterp=0.05, clim=[-2, 2])
            print(f"Difference heatmap saved for {diff_base}")
        except Exception as e:
            # If an error occurs, save progress and exit the program
            save_progress(figname, {'mean_tfr1': mean_tfr1_data, 'mean_tfr2': mean_tfr2_data})
            print(f"Error while plotting difference heatmap for {figname}: {e}")
            import traceback
            print(traceback.format_exc())
            sys.exit("Error occurred while plotting difference heatmap. Process stopped and progress saved.")
    except Exception as e:
        # If an error occurs during data extraction or processing, save progress and exit the program
        save_progress(figname, {'dict_object1': dict_object1, 'dict_object2': dict_object2})
        print(f"Error in plot_avg_heatmap_diff: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred in plot_avg_heatmap_diff. Process stopped and progress saved.")

def get_avg_power(power, channels, wave, tmin=-2.1, tmax=-0.2):
    """
    Average TFR power over a channel ROI, a frequency band and a time window.

    Parameters:
        power: TFR object providing ``copy()``, ``pick_channels()``, ``crop()`` and
            ``data``, as the MNE objects built with ``tfr_morlet`` in this file do.
        channels (list): Channel names forming the region of interest.
        wave (str): 'delta', 'theta', 'alpha' or 'beta'. Any other name falls back to
            the beta band.
        tmin, tmax (float): Time window passed to ``crop``.

    Returns:
        float: Mean of the cropped data over channels, frequencies and times.
    """
    # (fmin, fmax) in Hz per band; crop() treats both bounds as inclusive.
    # NOTE(review): the TFRs built in this file start at 4 Hz, so 'delta' covers at most
    # the lowest frequency bin. Confirm the band is meaningful before relying on it.
    bands = {
        'delta': (1, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
    }
    fmin, fmax = bands.get(wave, bands['beta'])

    # One copy suffices: pick_channels() and crop() modify that copy in place.
    roi_band = power.copy().pick_channels(channels).crop(fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax)
    return roi_band.data.mean()

def get_group_power(dict_type, channels, wave, subjects=range(1, 60), id_offset=300, missing_value=1):
    """
    Collect one band-power value per subject, in subject order.

    Parameters:
        dict_type (dict): Maps epochs file names of the form 's<ID>-epo.fif' to TFR objects.
        channels (list): Channel names passed to get_avg_power().
        wave (str): Band name passed to get_avg_power().
        subjects (iterable of int): Subject numbers to visit; defaults to 1..59.
        id_offset (int): Added to each subject number to build the file ID (e.g. 1 -> s301).
        missing_value: Value recorded for subjects without an entry. The default of 1 is
            kept for compatibility, but it is indistinguishable from a real measurement.
            Pass float('nan') to keep missing subjects out of downstream statistics.

    Returns:
        list: One value per subject in ``subjects``.
    """
    powerlist = []
    for subject in subjects:
        print(f"Processing subject {subject}")
        strsub = 's' + str(subject + id_offset) + '-epo.fif'
        if strsub not in dict_type:
            print(f"{strsub} not found in dictionary. Assigning value {missing_value}.")
            powerlist.append(missing_value)
        else:
            powerlist.append(get_avg_power(dict_type[strsub], channels=channels, wave=wave))
    return powerlist

if __name__ == '__main__':
    try:
        # Root folder containing one sub-directory of epochs files per condition.
        ascent_air = 'Test_Epochs_Approach/'
        # Destination for figures and the CSV; override with the TFA_OUTPUT_DIR env variable.
        output_dir = os.environ.get('TFA_OUTPUT_DIR', 'TFA')

        # Check if output directory exists
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
            print(f"Created output directory: {output_dir}")

        # Load data (one dict of file name -> TFR per condition)
        print("Loading data...")
        R_Case_dict = link_participants_data(ascent_air, 'Reward_Cases')
        P_Case_dict = link_participants_data(ascent_air, 'Punish_Cases')
        R_Avatar_dict = link_participants_data(ascent_air, 'Reward_Avatar')
        P_Avatar_dict = link_participants_data(ascent_air, 'Punish_Avatar')

        # Confirm successful data loading
        if not (R_Case_dict and P_Case_dict and R_Avatar_dict and P_Avatar_dict):
            raise ValueError("One or more data dictionaries are empty. Please check your data source.")
        print(f"Data loaded: {len(R_Case_dict)}, {len(P_Case_dict)}, {len(R_Avatar_dict)}, {len(P_Avatar_dict)}")

        # Channel groups (regions of interest)
        lTPJ = ['CP5', 'P7', 'P3']
        rTPJ = ['CP6', 'P4', 'P8']
        lFC = ['F3', 'F7', 'FC5']
        rFC = ['F4', 'F8', 'FC6']
        sfreq = 500.0
        stimuli_type = ['Reward_Cases', 'Punish_Cases', 'Reward_Avatar', 'Punish_Avatar']
        stimili_dict = [R_Case_dict, P_Case_dict, R_Avatar_dict, P_Avatar_dict]
        channelslist = [lTPJ, rTPJ, lFC, rFC]
        channelsname = ['lTPJ', 'rTPJ', 'lFC', 'rFC']
        wavelist = ['alpha', 'beta', 'theta']

        # Plot heatmaps over both TPJ groups
        print("Starting to plot heatmaps...")
        tpj_channels = lTPJ + rTPJ
        plot_avg_heatmap(P_Case_dict, tpj_channels, "Punish_Case", "Punish_Case_sig", output_dir, sfreq)
        plot_avg_heatmap(P_Avatar_dict, tpj_channels, "Punish_Avatar", "Punish_Ava_sig", output_dir, sfreq)
        plot_avg_heatmap(R_Case_dict, tpj_channels, "Reward_Case", "Reward_Case_sig", output_dir, sfreq)
        plot_avg_heatmap(R_Avatar_dict, tpj_channels, "Reward_Avatar", "Reward_Ava_sig", output_dir, sfreq)

        # Difference heatmaps
        print("Starting to plot difference heatmaps...")
        plot_avg_heatmap_diff(P_Case_dict, P_Avatar_dict, "Punish_Case - Punish_Avatar", "P_Case_P_Avatar", output_dir, sfreq)
        plot_avg_heatmap_diff(R_Case_dict, R_Avatar_dict, "Reward_Case - Reward_Avatar", "R_Case_R_Avatar", output_dir, sfreq)
        plot_avg_heatmap_diff(P_Case_dict, R_Case_dict, "Punish_Case - Reward_Case", "PRCase", output_dir, sfreq)
        plot_avg_heatmap_diff(P_Avatar_dict, R_Avatar_dict, "Punish_Avatar - Reward_Avatar", "PRAvatar", output_dir, sfreq)

        print("All heatmaps plotted successfully.")

        # One row per subject ID (301-359), one column per condition x ROI x band.
        print("Starting to process group power...")
        sublist = [subject + 300 for subject in range(1, 60)]
        df0 = pd.DataFrame({'sub_ID': sublist})

        for typename, typevalue in zip(stimuli_type, stimili_dict):
            for wave in wavelist:
                for chname, chvalue in zip(channelsname, channelslist):
                    column_name = f"{typename}_{chname}_{wave}"
                    print(f"Processing: {column_name}")
                    # Errors propagate to the handler below instead of being filled with NaN.
                    df0[column_name] = get_group_power(typevalue, chvalue, wave)

        # Save processed data
        output_csv = os.path.join(output_dir, "group_power_results.csv")
        df0.to_csv(output_csv, index=False)
        print(f"Group power results saved to: {output_csv}")

    except Exception as e:
        # Report any uncaught error and exit; nothing is saved at this level.
        print(f"An error occurred in the main program: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("An error occurred in the main program. Process stopped.")

In [None]:
import mne
import os
import numpy as np
import zipfile
from mne.preprocessing import ICA
from mne.time_frequency import tfr_morlet
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
from mne_icalabel import label_components
import Rest_State_Crop as rsc
import Test_Crop as tc
from mne.time_frequency import AverageTFR
from mne.stats import permutation_cluster_1samp_test
from neurora.stuff import (
    clusterbased_permutation_1d_1samp_1sided, 
    permutation_test,
    clusterbased_permutation_2d_1samp_2sided,
    clusterbased_permutation_2d_2sided
)
import sys  # Used to exit the program

# Select the Qt-based interactive matplotlib backend for figures shown with plt.show().
matplotlib.use('QtAgg')

# Condition labels (also used as data sub-directory names when loading epochs).
stimuli_type = ['Reward_Cases','Punish_Cases','Reward_Avatar','Punish_Avatar']

def channel_cut_todata(power, channels):
    """
    Return the rows of ``power.data`` that correspond to ``channels``, in the requested order.

    Parameters:
    - power: TFR object exposing ``info['ch_names']`` and an indexable ``data`` array.
    - channels: Channel names to extract.

    Returns:
    - Array of shape (len(channels), ...) taken from ``power.data`` along axis 0.

    Every failure is printed with a short message and re-raised unchanged: ValueError
    when a requested channel is absent, AttributeError when ``power`` lacks the needed
    attributes, and anything else as an unexpected error.
    """
    try:
        print(f"Requested channels for extraction: {channels}")

        available = power.info['ch_names']
        print(f"Available channels in power object: {available}")

        absent = [name for name in channels if name not in available]
        if absent:
            raise ValueError(f"The following channels are missing: {absent}")

        positions = [available.index(name) for name in channels]
        print(f"Indices of requested channels: {positions}")

        subset = power.data[positions]
        print(f"Extracted data shape: {subset.shape}")
    except ValueError as err:
        print(f"Error: {err}")
        raise
    except AttributeError:
        print("Error: The power object does not have the required attributes. Please check the input.")
        raise
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
        raise
    return subset

def link_participants_data(dir: str, type: str):
    """
    Load every ``.fif`` epochs file in ``<dir>/<type>`` and compute its Morlet TFR power.

    Parameters:
        dir (str): Root directory containing one sub-directory per condition.
        type (str): Condition sub-directory name (e.g. 'Reward_Cases').

    Returns:
        dict: Maps each ``.fif`` file name to the power object returned by ``tfr_morlet``,
        in sorted file-name order. Files that fail to load or transform are reported and
        skipped.

    Raises:
        ValueError: If ``<dir>/<type>`` does not exist.
    """
    tfa_results = {}
    full_dir = os.path.join(dir, type)
    if not os.path.exists(full_dir):
        raise ValueError(f"Directory does not exist: {full_dir}")

    # The frequency grid (10 log-spaced values from 4 to 30 Hz) is identical for every
    # file, so build it once.
    freqs = np.logspace(*np.log10([4, 30]), num=10)
    n_cycles = freqs / 2.

    # sorted() makes the result order, and hence which entry callers treat as "first",
    # independent of the filesystem's listing order.
    for item in sorted(os.listdir(full_dir)):
        if not item.endswith('.fif'):
            continue
        epochs_fname = os.path.join(full_dir, item)
        print(f"Loading file: {epochs_fname}")
        try:
            epochs = mne.read_epochs(epochs_fname, preload=True)

            # Validate epochs integrity
            if epochs.info is None:
                raise ValueError(f"Epochs info is None for file: {item}")

            power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True, return_itc=False)

            # Validate TFR data integrity
            if power.data is None:
                raise ValueError(f"TFR data is None for file: {item}")
            if power.info is None:
                raise ValueError(f"TFR info is None for file: {item}")
            if power.times is None:
                raise ValueError(f"TFR times are None for file: {item}")
            if power.freqs is None:
                raise ValueError(f"TFR freqs are None for file: {item}")

            # Baseline correction is intentionally disabled; data are stored uncorrected.
            # baseline = (4.0, 4.2)  # Adjust baseline time window as needed
            # power.apply_baseline(baseline=baseline, mode='logratio')

            tfa_results[item] = power
        except Exception as e:
            # Best effort: one bad file should not abort loading the rest.
            print(f"Error processing file {item}: {e}")
    return tfa_results

def save_progress(figname, data, msg="Saving current progress due to error"):
    """
    Persist ``data`` next to ``figname`` so work can be recovered after a failure.

    Writes ``<figname>_progress.npy`` with ``np.save``. Any exception raised while doing
    so (including a non-string ``figname``) is reported and swallowed, because callers
    invoke this from error handlers right before exiting.
    """
    try:
        print(msg)
        progress_path = figname + "_progress.npy"
        np.save(progress_path, data)
        print(f"Progress saved to {progress_path}")
    except Exception as exc:
        print(f"Error while saving progress: {exc}")

def plot_tfr_results(figname, tfr, freqs, times, p=0.05, clusterp=0.05, clim=(-4, 4)):
    """
    Cluster-test ``tfr`` against zero and save a heatmap of its mean with significant
    clusters outlined in red.

    Parameters:
    figname : Output path. 'clustered_' is prepended to its base name before saving; if it
              has no extension, matplotlib uses its default savefig format.
    tfr : Matrix with shape [n_channels, n_freqs, n_times]. Axis 0 is averaged for the
          heatmap and is also the observation axis of the one-sample test, so it needs at
          least two entries.
    freqs : Array with shape [n_freqs]; y-axis values (Hz). May be non-uniformly spaced.
    times : Array with shape [n_times]; x-axis values (s).
    p : Two-sided p-value used to form clusters (default 0.05). It is converted to a t
        threshold with n_channels - 1 degrees of freedom, because MNE's ``threshold``
        argument is a test statistic, not an alpha level.
    clusterp : Cluster-level p-value below which a cluster is marked (default 0.05).
    clim : (minimum, maximum) of the color scale, default (-4, 4).

    On any error the input is written via save_progress() and the process exits.
    """
    try:
        print("=== Debug: Starting plot_tfr_results ===")
        print(f"Input tfr shape: {tfr.shape}")
        print(f"Freqs: {freqs}")
        print(f"Times: {times}")
        print(f"P-value threshold: {p}")
        print(f"Cluster p-value threshold: {clusterp}")
        print(f"Color limits: {clim}")
        print(f"Saving figure as: {figname}")

        n_channels, n_freqs, n_times = tfr.shape
        print(f"Channels: {n_channels}, Freqs: {n_freqs}, Times: {n_times}")
        if n_channels < 2:
            raise ValueError("Need at least two entries along axis 0 for a one-sample t-test.")

        # Average over channels
        print("Averaging over channels...")
        tfr_mean = np.mean(tfr, axis=0)  # shape will be (n_freqs, n_times)
        print(f"Averaged TFR shape: {tfr_mean.shape}")

        # MNE forms clusters from points whose |t| exceeds ``threshold``. Passing the
        # p-value itself (e.g. 0.01) would let almost every point pass, so convert it.
        from scipy import stats
        t_threshold = stats.t.ppf(1.0 - p / 2.0, df=n_channels - 1)
        print(f"Cluster-forming t threshold: {t_threshold:.3f}")

        # Statistical analysis
        print("Performing cluster-based permutation test...")
        _, clusters, cluster_p_values, _ = permutation_cluster_1samp_test(
            tfr, n_permutations=1000, threshold=t_threshold, tail=0, n_jobs=1)
        print("Cluster-based permutation test completed.")
        print(f"Number of clusters found: {len(clusters)}")
        print(f"Cluster p-values: {cluster_p_values}")

        # Create significance matrix (1 inside significant clusters, 0 elsewhere)
        stats_results = np.zeros((n_freqs, n_times))
        for cl, p_val in zip(clusters, cluster_p_values):
            if p_val < clusterp:
                stats_results[cl] = 1
        print(f"Significance matrix created. Shape: {stats_results.shape}")

        print("Creating heatmap and contour plot...")
        fig, ax = plt.subplots(1, 1)

        # pcolormesh places each row at its actual frequency. imshow with an extent spaces
        # rows evenly, which misaligns them with the contour when freqs are log-spaced.
        mesh = ax.pcolormesh(times, freqs, tfr_mean, cmap='RdBu_r', shading='nearest',
                             vmin=clim[0], vmax=clim[1])

        # Zero-pad the mask so clusters touching an edge still get a closed outline.
        padsats_results = np.zeros([n_freqs + 2, n_times + 2])
        padsats_results[1:n_freqs + 1, 1:n_times + 1] = stats_results
        x = np.concatenate(([times[0] - 1], times, [times[-1] + 1]))
        y = np.concatenate(([freqs[0] - 1], freqs, [freqs[-1] + 1]))
        X, Y = np.meshgrid(x, y)
        ax.contour(X, Y, padsats_results, [0.5], colors="red", alpha=0.9,
                   linewidths=2, linestyles="dashed")

        # The padded contour grid extends past the data; keep the view on the data range.
        ax.set_xlim(times[0], times[-1])
        ax.set_ylim(freqs[0], freqs[-1])
        cbar = fig.colorbar(mesh, ax=ax)
        cbar.set_label('dB', fontsize=12)
        ax.set_xlabel('Time (s)', fontsize=16)
        ax.set_ylabel('Frequency (Hz)', fontsize=16)

        # Save next to figname with a 'clustered_' prefix on the base name.
        directory = os.path.dirname(figname)
        base_name = os.path.basename(figname)
        clustered_figname = os.path.join(directory, f'clustered_{base_name}')
        print(f"Saving heatmap to: {clustered_figname}")
        fig.savefig(clustered_figname, dpi=600)
        plt.close(fig)  # Close the current figure for subsequent plotting

        print("Plotting completed successfully.")
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, tfr)
        print(f"Error in plot_tfr_results: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during plotting. Process stopped and progress saved.")

def plot_TFA(dict_object: dict, channels: list, title: str, figname:str):
    """
    Average the TFRs in ``dict_object`` and save MNE's TFR plot for ``channels``.

    The averaged TFR is plotted with a log-ratio baseline over (-0.3, -0.1) and the
    channels combined by their mean. It is saved to 'TFA/<figname>.svg' (the directory is
    created if needed) and then shown.

    Returns:
        int: 0 on success. On any error, progress is saved via save_progress() and the
        process exits.
    """
    try:
        print(f"Starting plot_TFA for {figname}")
        if not dict_object:
            raise ValueError("Input dictionary is empty. Please check your input data.")
        tfr_list = list(dict_object.values())
        mean_tfr_data = np.mean([tfr.data for tfr in tfr_list], axis=0)

        first_tfr = tfr_list[0]  # Use info/times/freqs from the first TFR
        # NOTE(review): positional (info, data, times, freqs, nave) — confirm against the
        # installed MNE version's AverageTFR constructor.
        mean_tfr = AverageTFR(first_tfr.info, mean_tfr_data, first_tfr.times,
                              first_tfr.freqs, len(tfr_list))
        canvas_fig, ax = plt.subplots(figsize=(10, 6))
        # AverageTFR.plot returns a list of figures.
        figs = mean_tfr.plot(picks=channels, baseline=(-0.3, -0.1), mode='logratio',
                             title=title, combine='mean', axes=ax)
        print(figs)

        # Save before plt.show(), so the file is written regardless of what happens to the
        # window afterwards.
        out_dir = 'TFA'
        os.makedirs(out_dir, exist_ok=True)
        fulfigname = os.path.join(out_dir, f"{figname}.svg")
        figs[0].savefig(fulfigname)  # Save as SVG format
        plt.show()

        # plt.close() accepts a single Figure, not a list; closing a figure twice is a no-op.
        for fig in [canvas_fig, *figs]:
            plt.close(fig)
        return 0
    except Exception as e:
        # If an error occurs, save progress and exit the program
        save_progress(figname, dict_object)
        print(f"Error in plot_TFA: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during TFA plotting. Process stopped and progress saved.")

def plot_avg_heatmap(dict_object: dict, channels: list, title: str, figname: str, output_dir: str, sfreq: float):
    """
    Plot the group-level TFR heatmap of a channel ROI with cluster-based statistics.

    Each subject's TFR is averaged over ``channels``, giving one (n_freqs, n_times)
    map per subject. The stacked array of shape [n_subs, n_freqs, n_times] is
    handed to ``plot_tfr_results``, which draws and saves the figure.

    Parameters:
        dict_object (dict): Dictionary of TFR objects (one per subject). Entries that
            are None or have no ``data`` attribute/value are skipped.
        channels (list): Non-empty list of channel names forming the ROI.
        title (str): Chart title (currently not forwarded to ``plot_tfr_results``).
        figname (str): Base filename (without extension) for the saved image.
        output_dir (str): Directory for image output.
        sfreq (float): Data sampling frequency (currently unused).

    On any error, progress is saved via ``save_progress`` and the process terminates
    via ``sys.exit``.
    """
    try:
        print(f"Starting plot_avg_heatmap for {figname}")

        # Validate input
        if not dict_object:
            raise ValueError("Input dictionary is empty. Please check your input data.")
        if not channels:
            raise ValueError("No channels requested. Please provide at least one channel.")
        print(f"Input TFR dictionary size: {len(dict_object)}")

        # Keep only real TFR objects. Compare against None explicitly instead of
        # relying on the truthiness of the TFR object itself.
        valid_tfrs = {key: tfr for key, tfr in dict_object.items()
                      if tfr is not None and getattr(tfr, 'data', None) is not None}
        print(f"Number of valid TFR objects: {len(valid_tfrs)}")
        if not valid_tfrs:
            raise ValueError("No valid TFR objects found.")

        # Frequency/time axes are taken from the first valid TFR
        first_tfr = next(iter(valid_tfrs.values()))
        freqs = first_tfr.freqs
        times = first_tfr.times
        print(f"Freqs shape: {freqs.shape}, values: {freqs}")
        print(f"Times shape: {times.shape}, values: {times[:10]}")

        # Validate channel information
        channel_names = list(first_tfr.info['ch_names'])
        print(f"Available channels in TFR: {channel_names}")
        missing_channels = [ch for ch in channels if ch not in channel_names]
        if missing_channels:
            raise ValueError(f"Missing channels from TFR data: {missing_channels}")

        def _roi_map(tfr):
            """Average one subject's TFR over the ROI channels -> (n_freqs, n_times)."""
            names = list(tfr.info['ch_names'])
            # Look channels up per subject so a different channel order cannot misalign them.
            return np.asarray(tfr.data)[[names.index(ch) for ch in channels]].mean(axis=0)

        # NOTE(review): the previous code averaged over subjects first and passed an
        # [n_channels, n_freqs, n_times] array, so any cluster test over axis 0 ran across
        # channels. This assumes plot_tfr_results, like plot_tfr_diff_results, expects
        # [n_subs, n_freqs, n_times]; confirm against its docstring.
        print(f"Extracting data for channels: {channels}")
        subject_maps = np.stack([_roi_map(tfr) for tfr in valid_tfrs.values()])
        print(f"Per-subject ROI data shape: {subject_maps.shape}")

        # Plotting
        try:
            print(f"Calling plot_tfr_results for {figname}...")
            # plot_tfr_results receives the save path without the '.svg' extension
            base_figname = os.path.join(output_dir, figname)
            plot_tfr_results(base_figname, subject_maps, freqs, times, p=0.01, clusterp=0.05, clim=[-1, 1])
            print(f"Heatmap saved to {os.path.join(output_dir, 'clustered_' + figname + '.svg')}")
        except Exception as e:
            # If an error occurs, save progress and exit the program
            save_progress(figname, subject_maps)
            print(f"Error while plotting heatmap for {figname}: {e}")
            import traceback
            print(traceback.format_exc())
            sys.exit("Error occurred while plotting heatmap. Process stopped and progress saved.")
    except Exception as e:
        # If an error occurs during data extraction or processing, save progress and exit the program
        save_progress(figname, dict_object)
        print(f"Error in plot_avg_heatmap: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred in plot_avg_heatmap. Process stopped and progress saved.")

def plot_tfr_diff_results(figname, tfr1, tfr2, freqs, times, p=0.05, clusterp=0.05, clim=(-2, 2), title=None):
    """
    Plot the condition difference (tfr1 - tfr2) and outline significant clusters.

    Parameters:
    tfr1 : array [n_subs, n_freqs, n_times], time-frequency results for condition 1.
    tfr2 : array [n_subs, n_freqs, n_times], time-frequency results for condition 2.
        Must have the same shape as tfr1; presumably row i of both is the same subject
        (paired comparison).
    freqs : array [n_freqs], frequencies of the analysis (y axis).
    times : array [n_times], time points of the analysis (x axis).
    p : float, default 0.05, point-wise p-value threshold.
    clusterp : float, default 0.05, cluster-level p-value threshold.
    clim : sequence [minimum, maximum], default (-2, 2), colour-bar limits.
    title : str or None, optional axis title (no title when None).

    The heatmap is the subject-mean of tfr1 minus the subject-mean of tfr2. Positive
    entries of the cluster map are outlined in red, negative ones in blue (dashed).
    The figure is saved as ``<dirname(figname)>/<basename without extension>_diff.svg``
    at 600 dpi. On any error, progress is saved via ``save_progress`` and the process
    terminates via ``sys.exit``.
    """
    fig = None
    try:
        tfr1 = np.asarray(tfr1)
        tfr2 = np.asarray(tfr2)
        print("=== Debug: Starting plot_tfr_diff_results ===")
        print(f"tfr1 shape: {tfr1.shape}")
        print(f"tfr2 shape: {tfr2.shape}")
        print(f"Freqs: {freqs}")
        print(f"Times: {times}")
        print(f"P-value threshold: {p}")
        print(f"Cluster p-value threshold: {clusterp}")
        print(f"Color limits: {clim}")
        print(f"Saving figure as: {figname}")

        n_freqs = len(freqs)
        n_times = len(times)
        print(f"Number of frequencies: {n_freqs}, Number of times: {n_times}")

        # Validate up front: the cluster test needs two [n_subs, n_freqs, n_times]
        # arrays with matching n_subs, and the contour grid below needs the maps to
        # match (n_freqs, n_times).
        if tfr1.ndim != 3 or tfr1.shape != tfr2.shape:
            raise ValueError(
                f"tfr1 and tfr2 must be 3-D arrays of identical shape, got {tfr1.shape} and {tfr2.shape}")
        if tfr1.shape[1:] != (n_freqs, n_times):
            raise ValueError(
                f"TFR maps have shape {tfr1.shape[1:]}, expected (n_freqs, n_times) = {(n_freqs, n_times)}")

        # Statistical analysis
        print("Performing cluster-based permutation test for differences...")
        stats_results = np.asarray(clusterbased_permutation_2d_2sided(
            tfr1, tfr2,
            p_threshold=p,
            clusterp_threshold=clusterp,
            iter=1000
        ))
        # Guard against a non-array or mis-shaped result before it is used as a contour grid.
        if stats_results.shape != (n_freqs, n_times):
            raise ValueError(
                f"Unexpected cluster-test result shape {stats_results.shape}; expected {(n_freqs, n_times)}")
        print("Cluster-based permutation test for differences completed.")
        print(f"Stats results shape: {stats_results.shape}")

        # Calculate TFR difference (mean over subjects, per condition)
        tfr_diff = tfr1.mean(axis=0) - tfr2.mean(axis=0)
        print(f"TFR difference shape: {tfr_diff.shape}")

        # Visualize time-frequency analysis results
        print("Creating difference heatmap and contour plot...")
        fig, ax = plt.subplots(1, 1)

        # Surround the significance map with a border of zeros so clusters touching the
        # data edges still produce closed contours.
        padded_stats = np.zeros([n_freqs + 2, n_times + 2])
        padded_stats[1:-1, 1:-1] = stats_results
        x = np.concatenate(([times[0] - 1], times, [times[-1] + 1]))
        y = np.concatenate(([freqs[0] - 1], freqs, [freqs[-1] + 1]))
        X, Y = np.meshgrid(x, y)
        ax.contour(X, Y, padded_stats, [0.5], colors="red", alpha=0.9,
                   linewidths=2, linestyles="dashed")
        ax.contour(X, Y, padded_stats, [-0.5], colors="blue", alpha=0.9,
                   linewidths=2, linestyles="dashed")

        # Plot heatmap of time-frequency results
        im = ax.imshow(tfr_diff, cmap='RdBu_r', origin='lower',
                       extent=[times[0], times[-1], freqs[0], freqs[-1]],
                       vmin=clim[0], vmax=clim[1])
        ax.set_aspect('auto')
        if title:
            ax.set_title(title)
        cbar = fig.colorbar(im)
        # Raw string: '\D' is an invalid escape sequence in a normal string literal.
        cbar.set_label(r'$\Delta$dB', fontsize=12)
        # The callers in this file pass TFR ``.times`` (cropped elsewhere with tmin=4.2, tmax=6),
        # i.e. seconds, not milliseconds.
        ax.set_xlabel('Time (s)', fontsize=16)
        ax.set_ylabel('Frequency (Hz)', fontsize=16)

        # Save next to figname as '<basename>_diff.svg' (no duplicated directory or extension)
        directory = os.path.dirname(figname)
        base_name = os.path.splitext(os.path.basename(figname))[0]
        diff_figname = os.path.join(directory, f"{base_name}_diff.svg")
        print(f"Saving difference heatmap to: {diff_figname}")
        fig.savefig(diff_figname, dpi=600)
        plt.close(fig)  # Close the current figure
        fig = None

        print("Difference plotting completed successfully.")
    except Exception as e:
        # If an error occurs, release the figure, save progress and exit the program
        if fig is not None:
            plt.close(fig)
        save_progress(figname, {'tfr1': tfr1, 'tfr2': tfr2})
        print(f"Error in plot_tfr_diff_results: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred during difference plotting. Process stopped and progress saved.")

def plot_avg_heatmap_diff(dict_object1: dict, dict_object2: dict, title: str, figname: str, output_dir: str, sfreq: float, channels: list = None):
    """
    Calculate and plot the paired TFR difference heatmap between two conditions.

    Subjects are paired by dictionary key, and only keys present in both dictionaries
    are used. Each subject's TFR is averaged over ``channels`` (over all channels when
    ``channels`` is None). The resulting per-subject arrays of shape
    [n_subs, n_freqs, n_times] are passed to ``plot_tfr_diff_results``, whose cluster
    test compares row i of both arrays.

    Parameters:
        dict_object1 (dict): TFR data dictionary for condition 1, keyed by subject.
        dict_object2 (dict): TFR data dictionary for condition 2, keyed by subject.
        title (str): Chart title (currently not forwarded to the plot).
        figname (str): Base filename (without extension) for saving the image.
        output_dir (str): Directory for image output.
        sfreq (float): Data sampling frequency (currently unused).
        channels (list, optional): ROI channel names. None (default) averages over all channels.

    On any error, progress is saved via ``save_progress`` and the process terminates
    via ``sys.exit``.
    """
    try:
        print(f"Starting plot_avg_heatmap_diff for {figname}")

        if not dict_object1 or not dict_object2:
            raise ValueError("One of the input dictionaries is empty.")
        print("Both input dictionaries are non-empty.")
        if channels is not None and not channels:
            raise ValueError("channels must be None or a non-empty list of channel names.")

        # Pair subjects by key: the difference test compares row i of both arrays,
        # so both rows must come from the same subject.
        common_keys = sorted(set(dict_object1) & set(dict_object2))
        if not common_keys:
            raise ValueError("No subjects are shared between the two conditions.")
        n_unpaired = len(dict_object1) + len(dict_object2) - 2 * len(common_keys)
        if n_unpaired:
            print(f"Warning: ignoring {n_unpaired} TFR entries without a match in the other condition.")
        print(f"Number of paired subjects: {len(common_keys)}")

        def _subject_map(tfr):
            """Average one subject's TFR over the selected channels -> (n_freqs, n_times)."""
            data = np.asarray(tfr.data)
            if channels is None:
                return data.mean(axis=0)
            names = list(tfr.info['ch_names'])
            missing = [ch for ch in channels if ch not in names]
            if missing:
                raise ValueError(f"Missing channels from TFR data: {missing}")
            return data[[names.index(ch) for ch in channels]].mean(axis=0)

        # Previously the data were averaged over subjects first, so the array handed to
        # the [n_subs, ...] test actually held channels. Build one map per subject instead.
        print("Extracting TFR data for both conditions...")
        # np.stack raises if map shapes are inconsistent within a condition.
        tfr1_subjects = np.stack([_subject_map(dict_object1[key]) for key in common_keys])
        tfr2_subjects = np.stack([_subject_map(dict_object2[key]) for key in common_keys])
        print(f"TFR1 per-subject shape: {tfr1_subjects.shape}, TFR2 per-subject shape: {tfr2_subjects.shape}")
        if tfr1_subjects.shape != tfr2_subjects.shape:
            raise ValueError("Inconsistent TFR data shapes between conditions.")

        # Frequency/time axes are taken from the first paired subject of condition 1
        first_tfr1 = dict_object1[common_keys[0]]
        freqs = first_tfr1.freqs
        times = first_tfr1.times
        print(f"Freqs: {freqs}, Times: {times}")

        try:
            # Plot difference heatmap; the callee receives the path without '.svg'
            print("Calling plot_tfr_diff_results...")
            base_figname = os.path.join(output_dir, figname)
            plot_tfr_diff_results(base_figname, tfr1_subjects, tfr2_subjects, freqs, times, p=0.01, clusterp=0.05, clim=[-2, 2])
            print(f"Difference heatmap saved for {base_figname}")
        except Exception as e:
            # If an error occurs, save progress and exit the program
            save_progress(figname, {'tfr1_subjects': tfr1_subjects, 'tfr2_subjects': tfr2_subjects})
            print(f"Error while plotting difference heatmap for {figname}: {e}")
            import traceback
            print(traceback.format_exc())
            sys.exit("Error occurred while plotting difference heatmap. Process stopped and progress saved.")
    except Exception as e:
        # If an error occurs during data extraction or processing, save progress and exit the program
        save_progress(figname, {'dict_object1': dict_object1, 'dict_object2': dict_object2})
        print(f"Error in plot_avg_heatmap_diff: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("Error occurred in plot_avg_heatmap_diff. Process stopped and progress saved.")

def get_avg_power(power, channels, wave, tmin=4.2, tmax=6):
    """
    Return the mean power over a channel ROI, a frequency band and a time window.

    Parameters:
        power: TFR object exposing ``copy()``, ``pick_channels()``, ``crop()`` and ``.data``
            (the per-subject TFR objects used elsewhere in this script). It is not modified;
            a copy is cropped.
        channels (list): Channel names forming the region of interest.
        wave (str): Band name, case-insensitive: 'delta' (1-4 Hz), 'theta' (4-8 Hz),
            'alpha' (8-13 Hz) or 'beta' (13-30 Hz).
        tmin, tmax (float): Time window passed to ``crop`` (same units as ``power.times``).

    Returns:
        The mean of the cropped data over channels, frequencies and times.

    Raises:
        ValueError: If ``wave`` is not one of the known band names.
    """
    # Previously 'delta' left the band at 0-0 Hz (an empty crop), and any unknown name
    # silently fell back to beta.
    bands = {
        'delta': (1, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
    }
    band_key = str(wave).lower()
    if band_key not in bands:
        raise ValueError(f"Unknown frequency band {wave!r}; expected one of {sorted(bands)}")
    fmin, fmax = bands[band_key]

    # One copy is enough: pick_channels and crop both act on that copy and return it.
    roi_band = power.copy().pick_channels(channels).crop(fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax)
    return roi_band.data.mean()

def get_group_power(dict_type, channels, wave):
    """
    Collect one band-power value per subject for subjects 301 to 359.

    A subject is looked up in ``dict_type`` under the key ``'s<ID>-epo.fif'``
    (ID = 301 ... 359). A present subject contributes
    ``get_avg_power(power, channels=channels, wave=wave)``, with that function's
    default time window. A missing subject contributes the placeholder value 1
    (not NaN), so downstream analysis must treat 1 as "no data".

    Returns:
        list: 59 values in ascending subject-ID order.
    """
    group_values = []
    for subject_index in range(1, 60):
        print(f"Processing subject {subject_index}")
        file_key = f"s{subject_index + 300}-epo.fif"
        if file_key in dict_type:
            subject_power = dict_type.get(file_key)
            group_values.append(get_avg_power(subject_power, channels=channels, wave=wave))
            continue
        # Subject absent from this condition: record the placeholder value
        print(f"{file_key} not found in dictionary. Assigning value 1.")
        group_values.append(1)
    return group_values

if __name__ == '__main__':
    try:
        # Define data directory
        ascent_air = 'Test_Epochs_Leave/'
        # NOTE: left empty in the original code. "" means the current working directory,
        # since os.path.join("", name) == name. Fill this in to redirect all outputs.
        output_dir = r""
        # os.makedirs("") raises FileNotFoundError, so only create a directory that is actually named.
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            print(f"Created output directory: {output_dir}")

        # Load data
        print("Loading data...")
        R_Case_dict = link_participants_data(ascent_air, 'Reward_Cases')
        P_Case_dict = link_participants_data(ascent_air, 'Punish_Cases')
        R_Avatar_dict = link_participants_data(ascent_air, 'Reward_Avatar')
        P_Avatar_dict = link_participants_data(ascent_air, 'Punish_Avatar')

        # Confirm successful data loading
        if not (R_Case_dict and P_Case_dict and R_Avatar_dict and P_Avatar_dict):
            raise ValueError("One or more data dictionaries are empty. Please check your data source.")
        print(f"Data loaded: {len(R_Case_dict)}, {len(P_Case_dict)}, {len(R_Avatar_dict)}, {len(P_Avatar_dict)}")

        # Define channels
        lTPJ = ['CP5', 'P7', 'P3']
        rTPJ = ['CP6', 'P4', 'P8']
        lFC = ['F3', 'F7', 'FC5']
        rFC = ['F4', 'F8', 'FC6']
        sfreq = 500.0
        stimuli_type = ['Reward_Cases', 'Punish_Cases', 'Reward_Avatar', 'Punish_Avatar']
        stimili_dict = [R_Case_dict, P_Case_dict, R_Avatar_dict, P_Avatar_dict]
        channelslist = [lTPJ, rTPJ, lFC, rFC]
        channelsname = ['lTPJ', 'rTPJ', 'lFC', 'rFC']
        wavelist = ['alpha', 'beta', 'theta']

        # Plot heatmaps
        print("Starting to plot heatmaps...")
        plot_avg_heatmap(P_Case_dict, lTPJ + rTPJ, "Punish_Case", "Punish_Case_sig", output_dir, sfreq)
        plot_avg_heatmap(P_Avatar_dict, lTPJ + rTPJ, "Punish_Avatar", "Punish_Ava_sig", output_dir, sfreq)
        plot_avg_heatmap(R_Case_dict, lTPJ + rTPJ, "Reward_Case", "Reward_Case_sig", output_dir, sfreq)
        plot_avg_heatmap(R_Avatar_dict, lTPJ + rTPJ, "Reward_Avatar", "Reward_Ava_sig", output_dir, sfreq)

        # Difference heatmaps
        print("Starting to plot difference heatmaps...")
        plot_avg_heatmap_diff(P_Case_dict, P_Avatar_dict, "Punish_Case - Punish_Avatar", "P_Case_P_Avatar", output_dir, sfreq)
        plot_avg_heatmap_diff(R_Case_dict, R_Avatar_dict, "Reward_Case - Reward_Avatar", "R_Case_R_Avatar", output_dir, sfreq)
        plot_avg_heatmap_diff(P_Case_dict, R_Case_dict, "Punish_Case - Reward_Case", "PRCase", output_dir, sfreq)
        plot_avg_heatmap_diff(P_Avatar_dict, R_Avatar_dict, "Punish_Avatar - Reward_Avatar", "PRAvatar", output_dir, sfreq)

        print("All heatmaps plotted successfully.")

        # Data extraction and computation
        print("Starting to process group power...")
        # Must mirror the subject range used by get_group_power (IDs 301-359)
        sublist = [subject + 300 for subject in range(1, 60)]
        df0 = pd.DataFrame({'sub_ID': sublist})

        for typename, typevalue in zip(stimuli_type, stimili_dict):
            for wave in wavelist:
                for chname, chvalue in zip(channelsname, channelslist):
                    column_name = f"{typename}_{chname}_{wave}"
                    print(f"Processing: {column_name}")
                    # get_group_power fills subjects missing from the dict with 1;
                    # errors while computing power propagate to the handler below.
                    df0[column_name] = get_group_power(typevalue, chvalue, wave)

        # Save processed data
        output_csv = os.path.join(output_dir, "group_power_results.csv")
        df0.to_csv(output_csv, index=False)
        print(f"Group power results saved to: {output_csv}")

    except Exception as e:
        # If an uncaught error occurs in the main program, report it and exit
        print(f"An error occurred in the main program: {e}")
        import traceback
        print(traceback.format_exc())
        sys.exit("An error occurred in the main program. Process stopped and progress saved.")