In [None]:
from matplotlib.colors import LinearSegmentedColormap, TwoSlopeNorm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
import mne
import os

# Root directory containing the EEGLAB `.set` recordings.
# Set the EEG_DATASET_DIR environment variable to point elsewhere without editing this cell.
dataset_dir = os.environ.get(
    "EEG_DATASET_DIR",
    "/home/aidrive/zhengxj/projects/eeg_enhanced_driving/applied_statistics_proj/data_final",
)

# Driving conditions, in the order they index the condition axis of the EEG tensors.
# Each name is also the filename prefix: "<condition>_<driver type>_<subject id>.set".
driving_condition_list = [
    "0_Baseline",
    "1_EnterAuxiliaryStreet",
    "2_EnterMainStreet",
    "3_U-Turn",
    "4_StraightDrive",
    "5_RightTurn-P1",
    "6_RightTurn-P2",
    "7_LeftTurn-1",
    "8_LeftTurn-2",
    "9_RightTurn-1",
    "10_RightTurn-2",
    "11_RightTurn-3",
    "12_RightTurn-4",
    "13_LeftTurn-3",
]

# Driver groups: key = driver type (also used in the filenames), value = subject IDs.
# The first key is treated as the expert group and the second as the novice group downstream.
subject_dict = {
    'ExpertDriver': ["E01", "E02", "E03", "E04", "E05", "E06", "E07", "E08", "E09", "E10"],
    'NoviceDriver': ["N01", "N02", "N03", "N04", "N05", "N06", "N07", "N08", "N09", "N10"],
}

# Frequency bands (Hz) used for our visualizations; adjust them as needed.
# The band masks in this notebook select bins with fmin <= f < fmax.
bands = {
    'Delta (1-4 Hz)': (1, 4),
    'Theta (4-8 Hz)': (4, 8),
    'Alpha (8-12 Hz)': (8, 12),
    'Beta (12-30 Hz)': (12, 30),
    'Gamma (30-45 Hz)': (30, 45),
    'Average (1-45Hz)': (1, 45),
}

# Output locations for figures and statistics (override via environment variables).
psd_fig_save_dir = os.environ.get(
    "EEG_PSD_FIG_SAVE_DIR", "The Directory You Want to Save the Visualized Figures"
)
psd_stats_save_dir = os.environ.get(
    "EEG_PSD_STATS_SAVE_DIR", "The Directory You Want to Save the Statistic Outputs"
)

In [None]:
def get_tensor_shape(dataset_dir, driving_condition_list, subject_dict):
    """Infer the dimensions of the EEG data tensors from one sample recording.

    Loads the recording of the first subject of the first driver type in the first
    driving condition, drops the EOG channels (if present) and computes its PSD to
    learn the channel and frequency-bin counts.

    Parameters
    ----------
    dataset_dir : str
        Directory holding the ``.set`` files.
    driving_condition_list : list of str
        Driving-condition names (used as filename prefixes).
    subject_dict : dict
        Maps driver type (used in filenames) to a list of subject IDs.

    Returns
    -------
    tuple of int
        ``(n_eeg_channel, n_freq_bin, n_driving_condition, n_driver, n_driver_type)``.
    """
    # Build the sample path with the same pattern obtain_eeg_data_tensor uses,
    # taking the first configured driver type instead of a hard-coded key.
    first_driver_type = next(iter(subject_dict))
    first_subject_id = subject_dict[first_driver_type][0]
    test_data_path = os.path.join(
        dataset_dir, f"{driving_condition_list[0]}_{first_driver_type}_{first_subject_id}.set"
    )
    test_data = mne.io.read_raw_eeglab(test_data_path, preload=True)

    # Drop the EOG channels only if they are present (MNE raises on missing names).
    channels_to_drop = [ch for ch in ('HEO', 'VEO') if ch in test_data.ch_names]
    if channels_to_drop:
        test_data.drop_channels(channels_to_drop)

    test_data_freq = test_data.compute_psd(fmin=1, fmax=45).get_data()
    test_data_time = test_data.get_data()

    print(f"shape of `test_data_freq` is {test_data_freq.shape}\n shape of `test_data_time` is {test_data_time.shape}")
    n_eeg_channel, n_freq_bin = test_data_freq.shape
    n_driving_condition = len(driving_condition_list)
    # Sized by the largest group so every subject fits (10 for the default config).
    # NOTE(review): with unequal groups the unused slots stay zero downstream.
    n_driver = max(len(subject_ids) for subject_ids in subject_dict.values())
    n_driver_type = len(subject_dict)

    return n_eeg_channel, n_freq_bin, n_driving_condition, n_driver, n_driver_type

In [None]:
def obtain_eeg_data_tensor(dataset_dir, driving_condition_list, subject_dict):
    """Load every recording and collect per-channel amplitude statistics and PSDs.

    Recordings are read from
    ``<dataset_dir>/<driving_condition>_<driver_type>_<subject_id>.set``; the EOG
    channels ``HEO``/``VEO`` are dropped when present.

    Returns
    -------
    eeg_avg_amp, eeg_max_amp, eeg_min_amp : ndarray
        Shape ``(n_driver_type, n_driver, n_driving_condition, n_eeg_channel)``:
        mean / max / min over time of each channel's signal.
    eeg_psd : ndarray
        Shape ``(n_driver_type, n_driver, n_driving_condition, n_eeg_channel, n_freq_bin)``:
        the ``compute_psd(fmin=1, fmax=45).get_data()`` output of each recording.
    eeg_data : mne.io.Raw
        The last recording loaded; callers use it as a template for channel
        info and the frequency grid.
    """
    (n_eeg_channel, n_freq_bin, n_driving_condition,
     n_driver, n_driver_type) = get_tensor_shape(dataset_dir, driving_condition_list, subject_dict)

    amp_shape = (n_driver_type, n_driver, n_driving_condition, n_eeg_channel)
    eeg_avg_amp = np.zeros(amp_shape)
    eeg_max_amp = np.zeros(amp_shape)
    eeg_min_amp = np.zeros(amp_shape)
    eeg_psd = np.zeros(amp_shape + (n_freq_bin,))

    eeg_data = None
    for subject_type_id, (subject_type, subject_ids) in enumerate(subject_dict.items()):
        for subject_id_id, subject_id in enumerate(subject_ids):
            for driving_condition_id, driving_condition in enumerate(driving_condition_list):
                eeg_path = os.path.join(dataset_dir, f"{driving_condition}_{subject_type}_{subject_id}.set")
                eeg_data = mne.io.read_raw_eeglab(eeg_path, preload=True)
                # Drop the EOG channels only if present (MNE raises on missing names).
                channels_to_drop = [ch for ch in ('HEO', 'VEO') if ch in eeg_data.ch_names]
                if channels_to_drop:
                    eeg_data.drop_channels(channels_to_drop)

                # get_data() returns a copy; fetch it once instead of three times.
                time_series = eeg_data.get_data()
                idx = (subject_type_id, subject_id_id, driving_condition_id)
                eeg_avg_amp[idx] = time_series.mean(axis=1)
                eeg_max_amp[idx] = time_series.max(axis=1)
                eeg_min_amp[idx] = time_series.min(axis=1)
                eeg_psd[idx] = eeg_data.compute_psd(fmin=1, fmax=45).get_data()
    return eeg_avg_amp, eeg_max_amp, eeg_min_amp, eeg_psd, eeg_data

In [None]:
def save_eeg_amp_topomap(avg_eeg_amp_stats, eeg_data, v_min, v_max, fig_save_dir, fig_name, stats_type):
    """Plot one value per electrode as a topomap with a colorbar and save it.

    Parameters
    ----------
    avg_eeg_amp_stats : ndarray, shape (n_eeg_channel,)
        Value drawn at each electrode.
    eeg_data : mne.io.Raw
        Supplies the channel layout via ``eeg_data.info``.
    v_min, v_max : float
        Colour limits (callers pass limits shared by the expert and novice maps).
    fig_save_dir, stats_type, fig_name : str
        The figure is written to ``<fig_save_dir>/<stats_type>/<fig_name>``.
    """
    im, _ = mne.viz.plot_topomap(avg_eeg_amp_stats, eeg_data.info, vlim=(v_min, v_max), show=False, size=5)

    # Use the figure that owns the topomap rather than pyplot's "current" figure.
    fig = im.figure
    cbar = fig.colorbar(im, ax=im.axes, orientation='vertical')
    cbar.set_label('Amplitude')

    fig.savefig(os.path.join(fig_save_dir, stats_type, fig_name))
    plt.close(fig)

In [None]:
def save_sig_diff_electordes_topomap(eeg_data, sig_diff_electrodes_dict, fig_save_dir, driving_condition, stats_type):
    """Draw the sensor layout and highlight electrodes with a significant group difference.

    Parameters
    ----------
    eeg_data : mne.io.Raw
        Supplies sensor positions and channel names.
    sig_diff_electrodes_dict : dict
        Keys ``"expert"``, ``"novice"``, ``"both"`` mapping to lists of channel
        names; each group is drawn in its own colour. Unknown names are skipped.
    fig_save_dir, driving_condition, stats_type : str
        Output is ``<fig_save_dir>/sig_<driving_condition>.svg`` when
        ``stats_type == "none"``, otherwise
        ``<fig_save_dir>/<stats_type>/sig_<driving_condition>_<stats_type>.svg``.
    """
    fig = mne.viz.plot_sensors(eeg_data.info, show_names=False, show=False)
    ax = fig.gca()
    # Private MNE helper; yields 2-D positions on the same scale plot_sensors draws.
    pos = mne.channels.layout._find_topomap_coords(eeg_data.info, picks='all')
    ch_index = {name: i for i, name in enumerate(eeg_data.info['ch_names'])}

    color_dict = {
        "expert": "#0083FF",  # expert mean higher
        "novice": "#FF822F",  # novice mean higher
        "both": "blue",       # equal means
    }
    for sig_type, sig_diff_electrodes in sig_diff_electrodes_dict.items():
        color = color_dict[sig_type]
        for ch_name in sig_diff_electrodes:
            idx = ch_index.get(ch_name)
            if idx is None:
                continue
            x, y = pos[idx]
            # Only the group's first electrode carries a label (no legend is drawn).
            ax.plot(x, y, 'o', color=color, markersize=20,
                    label=ch_name if ch_name == sig_diff_electrodes[0] else "")

    if stats_type == "none":
        out_path = os.path.join(fig_save_dir, f"sig_{driving_condition}.svg")
    else:
        out_path = os.path.join(fig_save_dir, stats_type, f"sig_{driving_condition}_{stats_type}.svg")
    # Save and close the sensor figure explicitly instead of pyplot's current figure.
    fig.savefig(out_path)
    plt.close(fig)

In [None]:
def calculate_p_value_and_plot_topomap_for_all_driving_condition(expert_eeg_amp_stats, novice_eeg_amp_stats, stats_type, eeg_data, p_value_dict, fig_save_dir):
    """Compare experts vs. novices per channel with all driving conditions pooled, then plot.

    Parameters
    ----------
    expert_eeg_amp_stats, novice_eeg_amp_stats : ndarray
        Shape ``(n_driver, n_driving_condition, n_eeg_channel)``.
    stats_type : str
        Statistic name, used in the output sub-directory and file names.
    eeg_data : mne.io.Raw
        Supplies channel names and layout for the topomaps.
    p_value_dict : dict
        Updated in place: ``p_value_dict["all"]`` receives the per-channel p-values.
    fig_save_dir : str
        Root directory for the figures.

    Notes
    -----
    For each channel, every (driver, condition) value of a group is flattened
    into one sample and compared with a two-sided independent t-test.
    NOTE(review): values of the same driver are not independent and no
    multiple-comparison correction is applied across channels, so treat
    p < 0.05 as exploratory.
    """
    n_channels = expert_eeg_amp_stats.shape[-1]
    p_values = np.array([
        stats.ttest_ind(
            expert_eeg_amp_stats[:, :, ch].ravel(),
            novice_eeg_amp_stats[:, :, ch].ravel(),
            alternative="two-sided",
        ).pvalue
        for ch in range(n_channels)
    ])
    p_value_dict["all"] = p_values

    # Classify significant electrodes by which group has the larger mean.
    sig_diff_electrodes = {"expert": [], "novice": [], "both": []}
    for ch, ch_name in enumerate(eeg_data.info["ch_names"]):
        if not p_values[ch] < 0.05:
            continue
        expert_mean = np.mean(expert_eeg_amp_stats[:, :, ch])
        novice_mean = np.mean(novice_eeg_amp_stats[:, :, ch])
        if expert_mean > novice_mean:
            winner = "expert"
        elif expert_mean < novice_mean:
            winner = "novice"
        else:
            winner = "both"
        sig_diff_electrodes[winner].append(ch_name)

    save_sig_diff_electordes_topomap(eeg_data, sig_diff_electrodes, fig_save_dir, driving_condition="all", stats_type=stats_type)

    # Grand average over drivers and conditions: one value per channel.
    expert_topo = np.mean(expert_eeg_amp_stats, axis=(0, 1))
    novice_topo = np.mean(novice_eeg_amp_stats, axis=(0, 1))
    # Shared colour limits so the two topomaps are directly comparable.
    v_min = min(np.min(expert_topo), np.min(novice_topo))
    v_max = max(np.max(expert_topo), np.max(novice_topo))

    save_eeg_amp_topomap(expert_topo, eeg_data, v_min, v_max, fig_save_dir, f"expert_all_{stats_type}.svg", stats_type)
    save_eeg_amp_topomap(novice_topo, eeg_data, v_min, v_max, fig_save_dir, f"novice_all_{stats_type}.svg", stats_type)

In [None]:
def calculate_p_value_and_plot_topomap_for_each_driving_condition(expert_eeg_amp_stat_driving_condition, novice_eeg_amp_stat_driving_condition, stats_type, eeg_data, p_value_dict, driving_condition, fig_save_dir):
    """Compare experts vs. novices per channel within one driving condition, then plot.

    Parameters
    ----------
    expert_eeg_amp_stat_driving_condition, novice_eeg_amp_stat_driving_condition : ndarray
        Shape ``(n_driver, n_eeg_channel)``: one value per driver and channel.
    stats_type : str
        Statistic name, used in the output sub-directory and file names.
    eeg_data : mne.io.Raw
        Supplies channel names and layout for the topomaps.
    p_value_dict : dict
        Updated in place: ``p_value_dict[driving_condition]`` receives the per-channel p-values.
    driving_condition : str
        Condition name, used as the dict key and in file names.
    fig_save_dir : str
        Root directory for the figures.

    Notes
    -----
    Two-sided independent t-test across drivers for each channel with a
    p < 0.05 threshold. NOTE(review): uncorrected for the number of channels.
    """
    expert = expert_eeg_amp_stat_driving_condition
    novice = novice_eeg_amp_stat_driving_condition

    p_values = np.array([
        stats.ttest_ind(expert[:, ch].ravel(), novice[:, ch].ravel(), alternative="two-sided").pvalue
        for ch in range(expert.shape[-1])
    ])
    p_value_dict[driving_condition] = p_values

    # Classify significant electrodes by which group has the larger mean.
    sig_diff_electrodes = {"expert": [], "novice": [], "both": []}
    for ch, ch_name in enumerate(eeg_data.info["ch_names"]):
        if not p_values[ch] < 0.05:
            continue
        expert_mean = np.mean(expert[:, ch])
        novice_mean = np.mean(novice[:, ch])
        if expert_mean > novice_mean:
            winner = "expert"
        elif expert_mean < novice_mean:
            winner = "novice"
        else:
            winner = "both"
        sig_diff_electrodes[winner].append(ch_name)

    save_sig_diff_electordes_topomap(eeg_data, sig_diff_electrodes, fig_save_dir, driving_condition, stats_type)

    # Average over drivers: one value per channel, drawn on a shared colour scale.
    expert_topo = np.mean(expert, axis=0)
    novice_topo = np.mean(novice, axis=0)
    v_min = min(np.min(expert_topo), np.min(novice_topo))
    v_max = max(np.max(expert_topo), np.max(novice_topo))

    save_eeg_amp_topomap(expert_topo, eeg_data, v_min, v_max, fig_save_dir, f"expert_{driving_condition}_{stats_type}.svg", stats_type)
    save_eeg_amp_topomap(novice_topo, eeg_data, v_min, v_max, fig_save_dir, f"novice_{driving_condition}_{stats_type}.svg", stats_type)

In [None]:
def calculate_p_value_and_plot_topomap(eeg_amp_stats, stats_type, eeg_data, driving_condition_list, fig_save_dir, stats_save_dir):
    """Run the expert-vs-novice channel t-tests for one amplitude statistic and save the results.

    Parameters
    ----------
    eeg_amp_stats : ndarray
        Shape ``(n_driver_type, n_driver, n_driving_condition, n_eeg_channel)``;
        index 0 of the first axis is treated as the expert group, index 1 as novices.
    stats_type : str
        Statistic name; outputs go into ``<fig_save_dir>/<stats_type>`` and
        ``<stats_save_dir>/<stats_type>`` (both created if missing).
    eeg_data : mne.io.Raw
        Supplies channel names and layout.
    driving_condition_list : list of str
        Condition names, in the order of the condition axis.
    fig_save_dir, stats_save_dir : str
        Root output directories.

    The CSV ``eeg_<stats_type>_amp_driving_condition.csv`` has one row per channel
    and the columns ``ch_names``, ``all`` and one per driving condition.
    """
    for root_dir in (fig_save_dir, stats_save_dir):
        os.makedirs(os.path.join(root_dir, stats_type), exist_ok=True)

    p_value_dict = {"ch_names": eeg_data.info["ch_names"]}
    expert_stats, novice_stats = eeg_amp_stats[0], eeg_amp_stats[1]

    # Pooled over all conditions first, then each condition on its own.
    calculate_p_value_and_plot_topomap_for_all_driving_condition(expert_stats, novice_stats, stats_type, eeg_data, p_value_dict, fig_save_dir)
    for condition_idx, condition in enumerate(driving_condition_list):
        calculate_p_value_and_plot_topomap_for_each_driving_condition(
            expert_stats[:, condition_idx, :],
            novice_stats[:, condition_idx, :],
            stats_type,
            eeg_data,
            p_value_dict,
            condition,
            fig_save_dir,
        )

    csv_path = os.path.join(stats_save_dir, stats_type, f"eeg_{stats_type}_amp_driving_condition.csv")
    pd.DataFrame(p_value_dict).to_csv(csv_path, index_label="index")

In [None]:
def plot_eeg_psd_topomap(expert_psd_data, novice_psd_data, eeg_data, bands, fig_save_dir, driving_condition):
    '''Plot expert and novice band-power topomaps (in dB) on one shared colour scale.

    Parameters
    ----------
    expert_psd_data, novice_psd_data : ndarray, shape (n_eeg_channel, n_freq_bin)
        Group-average PSD in V^2/Hz on the frequency grid of
        ``eeg_data.compute_psd(fmin=1, fmax=45)``.
    eeg_data : mne.io.Raw
        Template recording; its spectrum object is reused for plotting.
    bands : dict
        ``{name: (fmin, fmax)}``; one topomap panel per band.
    fig_save_dir : str
        Existing output directory.
    driving_condition : str
        Used in the output names ``expert_<driving_condition>.svg`` and
        ``novice_<driving_condition>.svg``.

    Raises
    ------
    ValueError
        If ``bands`` is empty.
    '''
    if not bands:
        raise ValueError("`bands` must contain at least one frequency band")

    eeg_spectrum = eeg_data.compute_psd(fmin=1, fmax=45)
    freqs = eeg_spectrum.freqs

    # Shared colour limits: extreme band power across all bands and both groups,
    # scaled from V^2/Hz to uV^2/Hz and converted to dB to match plot_topomap(dB=True).
    band_powers = []
    for fmin, fmax in bands.values():
        band_idx = (freqs >= fmin) & (freqs < fmax)
        for psd in (expert_psd_data, novice_psd_data):
            band_powers.append(psd[:, band_idx].mean(axis=1) * 1e12)
    band_powers = np.concatenate(band_powers)
    vmin = 10 * np.log10(np.min(band_powers))
    vmax = 10 * np.log10(np.max(band_powers))

    for group, psd in (("expert", expert_psd_data), ("novice", novice_psd_data)):
        # NOTE: overwrites a private attribute to plot group averages through the Spectrum API.
        eeg_spectrum._data = psd
        fig = eeg_spectrum.plot_topomap(bands=bands, dB=True, vlim=(vmin, vmax))
        fig.savefig(os.path.join(fig_save_dir, f"{group}_{driving_condition}.svg"))
        # Close each figure; this is called once per condition and figures would pile up.
        plt.close(fig)


def plot_eeg_psd_topomap_white_zero(expert_psd_data, novice_psd_data, eeg_data, bands, fig_save_dir, driving_condition):
    '''Plot expert/novice band-power topomaps (dB) with a blue-white-red map centred at 0 dB.

    Parameters
    ----------
    expert_psd_data, novice_psd_data : ndarray, shape (n_eeg_channel, n_freq_bin)
        Group-average PSD in V^2/Hz on the frequency grid of
        ``eeg_data.compute_psd(fmin=1, fmax=45)``.
    eeg_data : mne.io.Raw
        Template recording; its spectrum object is reused for plotting.
    bands : dict
        ``{name: (fmin, fmax)}``; one topomap panel per band.
    fig_save_dir : str
        Existing output directory.
    driving_condition : str
        Used in the output names ``expert_<driving_condition>.svg`` and
        ``novice_<driving_condition>.svg``.

    Notes
    -----
    ``TwoSlopeNorm`` requires ``vmin < 0 < vmax``. If the data do not straddle
    0 dB, the limits are widened symmetrically to ``+/- max(|vmin|, |vmax|)``
    so white stays at 0 dB instead of raising ``ValueError``.

    Raises
    ------
    ValueError
        If ``bands`` is empty.
    '''
    if not bands:
        raise ValueError("`bands` must contain at least one frequency band")

    eeg_spectrum = eeg_data.compute_psd(fmin=1, fmax=45)
    freqs = eeg_spectrum.freqs

    # Extreme band power across all bands and both groups, in dB re 1 uV^2/Hz.
    band_powers = []
    for fmin, fmax in bands.values():
        band_idx = (freqs >= fmin) & (freqs < fmax)
        for psd in (expert_psd_data, novice_psd_data):
            band_powers.append(psd[:, band_idx].mean(axis=1) * 1e12)
    band_powers = np.concatenate(band_powers)
    vmin = 10 * np.log10(np.min(band_powers))
    vmax = 10 * np.log10(np.max(band_powers))

    if not vmin < 0 < vmax:
        bound = max(abs(vmin), abs(vmax)) or 1.0
        vmin, vmax = -bound, bound

    # Diverging colormap: blue below 0 dB, white at 0 dB, red above.
    colors = [(0, 'blue'), (0.5, 'white'), (1, 'red')]
    cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)
    cnorm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
    print(f"vmin={vmin}, vmax={vmax}")

    for group, psd in (("expert", expert_psd_data), ("novice", novice_psd_data)):
        # NOTE: overwrites a private attribute to plot group averages through the Spectrum API.
        eeg_spectrum._data = psd
        fig = eeg_spectrum.plot_topomap(bands=bands, dB=True, vlim=(vmin, vmax), cmap=(cmap, False), cnorm=cnorm, show=False)
        fig.savefig(os.path.join(fig_save_dir, f"{group}_{driving_condition}.svg"))
        plt.close(fig)


def plot_eeg_psd_topomap_single_band(expert_psd_data, novice_psd_data, eeg_data, bands, fig_save_dir, driving_condition):
    '''Save one expert and one novice topomap per band, each band on its own colour scale.

    Parameters
    ----------
    expert_psd_data, novice_psd_data : ndarray, shape (n_eeg_channel, n_freq_bin)
        Group-average PSD in V^2/Hz on the frequency grid of
        ``eeg_data.compute_psd(fmin=1, fmax=45)``.
    eeg_data : mne.io.Raw
        Template recording; its spectrum object is reused for plotting.
    bands : dict
        ``{name: (fmin, fmax)}``.
    fig_save_dir : str
        Root directory; files go to
        ``<fig_save_dir>/single_band/topomap_<group>_<driving_condition>_<band_name>.svg``
        (the sub-directory is created if needed).
    driving_condition : str
        Used in the output file names.
    '''
    single_band_save_dir = os.path.join(fig_save_dir, "single_band")
    os.makedirs(single_band_save_dir, exist_ok=True)

    eeg_spectrum = eeg_data.compute_psd(fmin=1, fmax=45)
    freqs = eeg_spectrum.freqs

    for band_name, (fmin, fmax) in bands.items():
        band_idx = (freqs >= fmin) & (freqs < fmax)
        # Per-channel band power, scaled from V^2/Hz to uV^2/Hz.
        expert_band = expert_psd_data[:, band_idx].mean(axis=1) * 1e12
        novice_band = novice_psd_data[:, band_idx].mean(axis=1) * 1e12
        # Colour limits in dB, shared by both groups within this band only.
        vmin = 10 * np.log10(min(expert_band.min(), novice_band.min()))
        vmax = 10 * np.log10(max(expert_band.max(), novice_band.max()))

        for group, psd in (("expert", expert_psd_data), ("novice", novice_psd_data)):
            # NOTE: overwrites a private attribute to plot group averages through the Spectrum API.
            eeg_spectrum._data = psd
            fig = eeg_spectrum.plot_topomap(bands={band_name: (fmin, fmax)}, dB=True, vlim=(vmin, vmax))
            fig.savefig(os.path.join(single_band_save_dir, f"topomap_{group}_{driving_condition}_{band_name}.svg"))
            # Two figures per band per call: close them so they do not accumulate.
            plt.close(fig)

In [None]:
def calculate_psd_p_value_and_plot_topomap_for_all_driving_condition(expert_eeg_psd, novice_eeg_psd, bands, p_value_dict, eeg_data, fig_save_dir):
    """Per band and channel, t-test expert vs. novice PSD with all conditions pooled, then plot.

    Parameters
    ----------
    expert_eeg_psd, novice_eeg_psd : ndarray
        Shape ``(n_driver, n_driving_condition, n_eeg_channel, n_freq_bin)``, PSD in V^2/Hz.
    bands : dict
        ``{band_name: (fmin, fmax)}``; bins with ``fmin <= f < fmax`` are used.
    p_value_dict : dict
        Updated in place: ``p_value_dict[f"all_{band_name}"]`` gets the per-channel p-values.
    eeg_data : mne.io.Raw
        Template recording: channel names, layout and frequency grid.
    fig_save_dir : str
        Existing output directory for the figures.

    Notes
    -----
    NOTE(review): each t-test flattens drivers x conditions x frequency bins into one
    sample. These values are not independent, so p-values are likely anti-conservative,
    and no correction across channels/bands is applied.
    """
    # Only the frequency grid of the template recording is needed here.
    freqs = eeg_data.compute_psd(fmin=1, fmax=45).freqs

    n_eeg_channel = expert_eeg_psd.shape[-2]
    for band_name, (fmin, fmax) in bands.items():
        band_idx = (freqs >= fmin) & (freqs < fmax)
        eeg_stats_ttest_p = np.zeros(n_eeg_channel)
        for channel_id in range(n_eeg_channel):
            _, p_value = stats.ttest_ind(
                expert_eeg_psd[:, :, channel_id, band_idx].ravel(),
                novice_eeg_psd[:, :, channel_id, band_idx].ravel(),
                alternative="two-sided",
            )
            eeg_stats_ttest_p[channel_id] = p_value
        p_value_dict[f"all_{band_name}"] = eeg_stats_ttest_p

        # Significant electrodes grouped by the larger mean:
        # "expert" = experts higher, "novice" = novices higher, "both" = equal means.
        sig_diff_electrodes = {"expert": [], "novice": [], "both": []}
        for channel_id, eeg_channel in enumerate(eeg_data.info["ch_names"]):
            if eeg_stats_ttest_p[channel_id] < 0.05:
                expert_avg_psd_band_channel = np.mean(expert_eeg_psd[:, :, channel_id, band_idx])
                novice_avg_psd_band_channel = np.mean(novice_eeg_psd[:, :, channel_id, band_idx])
                if expert_avg_psd_band_channel > novice_avg_psd_band_channel:
                    sig_diff_electrodes["expert"].append(eeg_channel)
                elif expert_avg_psd_band_channel < novice_avg_psd_band_channel:
                    sig_diff_electrodes["novice"].append(eeg_channel)
                else:
                    sig_diff_electrodes["both"].append(eeg_channel)
        save_sig_diff_electordes_topomap(eeg_data, sig_diff_electrodes, fig_save_dir, driving_condition=f"all_{band_name}", stats_type="none")

    # Group-average spectra, shape (n_eeg_channel, n_freq_bin).
    avg_expert_eeg_psd = np.mean(expert_eeg_psd, axis=(0, 1))
    avg_novice_eeg_psd = np.mean(novice_eeg_psd, axis=(0, 1))
    # The plots take 10*log10 of band power. Floor only non-positive bins (at the smallest
    # positive average) instead of adding 1e-12 everywhere: 1e-12 V^2/Hz is 1 uV^2/Hz,
    # which is presumably comparable to or larger than real power in the higher bands.
    stacked = np.stack([avg_expert_eeg_psd, avg_novice_eeg_psd])
    positive = stacked[stacked > 0]
    floor = positive.min() if positive.size else np.finfo(float).tiny
    avg_expert_eeg_psd, avg_novice_eeg_psd = np.where(stacked <= 0, floor, stacked)

    plot_eeg_psd_topomap(avg_expert_eeg_psd, avg_novice_eeg_psd, eeg_data, bands, fig_save_dir, "all")
    # Alternatives with the same arguments: plot_eeg_psd_topomap_white_zero, plot_eeg_psd_topomap_single_band.

In [None]:
def calculate_psd_p_value_and_plot_topomap_for_each_driving_condition(expert_eeg_psd, novice_eeg_psd, bands, p_value_dict, eeg_data, driving_condition, fig_save_dir):
    """Per band and channel, t-test expert vs. novice PSD within one driving condition, then plot.

    Parameters
    ----------
    expert_eeg_psd, novice_eeg_psd : ndarray
        Shape ``(n_driver, n_eeg_channel, n_freq_bin)``, PSD in V^2/Hz.
    bands : dict
        ``{band_name: (fmin, fmax)}``; bins with ``fmin <= f < fmax`` are used.
    p_value_dict : dict
        Updated in place: ``p_value_dict[f"{driving_condition}_{band_name}"]`` gets the
        per-channel p-values.
    eeg_data : mne.io.Raw
        Template recording: channel names, layout and frequency grid.
    driving_condition : str
        Condition name, used in dict keys and file names.
    fig_save_dir : str
        Existing output directory for the figures.

    Notes
    -----
    NOTE(review): each t-test flattens drivers x frequency bins into one sample. These
    values are not independent, so p-values are likely anti-conservative, and no
    correction across channels/bands is applied.
    """
    # Only the frequency grid of the template recording is needed here.
    freqs = eeg_data.compute_psd(fmin=1, fmax=45).freqs

    n_eeg_channel = expert_eeg_psd.shape[-2]
    for band_name, (fmin, fmax) in bands.items():
        band_idx = (freqs >= fmin) & (freqs < fmax)
        eeg_stats_ttest_p = np.zeros(n_eeg_channel)
        for channel_id in range(n_eeg_channel):
            _, p_value = stats.ttest_ind(
                expert_eeg_psd[:, channel_id, band_idx].ravel(),
                novice_eeg_psd[:, channel_id, band_idx].ravel(),
                alternative="two-sided",
            )
            eeg_stats_ttest_p[channel_id] = p_value
        p_value_dict[f"{driving_condition}_{band_name}"] = eeg_stats_ttest_p

        # Significant electrodes grouped by the larger mean:
        # "expert" = experts higher, "novice" = novices higher, "both" = equal means.
        sig_diff_electrodes = {"expert": [], "novice": [], "both": []}
        for channel_id, eeg_channel in enumerate(eeg_data.info["ch_names"]):
            if eeg_stats_ttest_p[channel_id] < 0.05:
                expert_avg_psd_band_channel = np.mean(expert_eeg_psd[:, channel_id, band_idx])
                novice_avg_psd_band_channel = np.mean(novice_eeg_psd[:, channel_id, band_idx])
                if expert_avg_psd_band_channel > novice_avg_psd_band_channel:
                    sig_diff_electrodes["expert"].append(eeg_channel)
                elif expert_avg_psd_band_channel < novice_avg_psd_band_channel:
                    sig_diff_electrodes["novice"].append(eeg_channel)
                else:
                    sig_diff_electrodes["both"].append(eeg_channel)
        save_sig_diff_electordes_topomap(eeg_data, sig_diff_electrodes, fig_save_dir, driving_condition=f"{driving_condition}_{band_name}", stats_type="none")

    # Group-average spectra, shape (n_eeg_channel, n_freq_bin).
    avg_expert_eeg_psd = np.mean(expert_eeg_psd, axis=0)
    avg_novice_eeg_psd = np.mean(novice_eeg_psd, axis=0)
    # The plots take 10*log10 of band power. Floor only non-positive bins (at the smallest
    # positive average) instead of adding 1e-12 everywhere: 1e-12 V^2/Hz is 1 uV^2/Hz,
    # which is presumably comparable to or larger than real power in the higher bands.
    stacked = np.stack([avg_expert_eeg_psd, avg_novice_eeg_psd])
    positive = stacked[stacked > 0]
    floor = positive.min() if positive.size else np.finfo(float).tiny
    avg_expert_eeg_psd, avg_novice_eeg_psd = np.where(stacked <= 0, floor, stacked)

    plot_eeg_psd_topomap(avg_expert_eeg_psd, avg_novice_eeg_psd, eeg_data, bands, fig_save_dir, driving_condition)
    # Alternative with the same arguments: plot_eeg_psd_topomap_white_zero.
    plot_eeg_psd_topomap_single_band(avg_expert_eeg_psd, avg_novice_eeg_psd, eeg_data, bands, fig_save_dir, driving_condition)

In [None]:
def calculate_psd_p_value_and_plot_topomap(eeg_psd, eeg_data, driving_condition_list, bands, fig_save_dir, stats_save_dir):
    """Run all expert-vs-novice PSD band tests, save the figures and a p-value CSV.

    Parameters
    ----------
    eeg_psd : ndarray
        Shape ``(n_driver_type, n_driver, n_driving_condition, n_eeg_channel, n_freq_bin)``;
        index 0 of the first axis is treated as the expert group, index 1 as novices.
    eeg_data : mne.io.Raw
        Template recording (channel names, layout, frequency grid).
    driving_condition_list : list of str
        Condition names, in the order of the condition axis.
    bands : dict
        ``{band_name: (fmin, fmax)}``.
    fig_save_dir, stats_save_dir : str
        Output directories, created if missing. ``psd_stats.csv`` in
        ``stats_save_dir`` has one row per channel, a ``ch_names`` column and one
        column per ``all_<band>`` / ``<condition>_<band>`` test.
    """
    # Figures are written straight into fig_save_dir and the CSV into stats_save_dir;
    # without this, a fresh run fails at the first savefig / to_csv.
    os.makedirs(fig_save_dir, exist_ok=True)
    os.makedirs(stats_save_dir, exist_ok=True)

    p_value_dict = {"ch_names": eeg_data.info["ch_names"]}
    expert_psd, novice_psd = eeg_psd[0], eeg_psd[1]
    calculate_psd_p_value_and_plot_topomap_for_all_driving_condition(expert_psd, novice_psd, bands, p_value_dict, eeg_data, fig_save_dir)

    for driving_condition_id, driving_condition in enumerate(driving_condition_list):
        calculate_psd_p_value_and_plot_topomap_for_each_driving_condition(
            expert_psd[:, driving_condition_id, :, :],
            novice_psd[:, driving_condition_id, :, :],
            bands,
            p_value_dict,
            eeg_data,
            driving_condition,
            fig_save_dir,
        )

    df = pd.DataFrame(p_value_dict)
    df.to_csv(os.path.join(stats_save_dir, "psd_stats.csv"))

In [None]:
def eeg_freq_stats_test(dataset_dir, driving_condition_list, subject_dict, bands, psd_fig_save_dir, psd_stats_save_dir):
    """Load all recordings, then run the PSD expert-vs-novice tests and save figures and CSV.

    Only the PSD tensor and the template recording returned by
    ``obtain_eeg_data_tensor`` are used; the amplitude statistics are discarded.
    """
    (_avg_amp, _max_amp, _min_amp,
     eeg_psd, eeg_data) = obtain_eeg_data_tensor(dataset_dir, driving_condition_list, subject_dict)
    calculate_psd_p_value_and_plot_topomap(
        eeg_psd,
        eeg_data,
        driving_condition_list,
        bands,
        fig_save_dir=psd_fig_save_dir,
        stats_save_dir=psd_stats_save_dir,
    )

In [None]:
# Run the full PSD pipeline: load every recording, t-test experts vs. novices,
# and write the topomaps and the p-value CSV.
eeg_freq_stats_test(
    dataset_dir=dataset_dir,
    driving_condition_list=driving_condition_list,
    subject_dict=subject_dict,
    bands=bands,
    psd_fig_save_dir=psd_fig_save_dir,
    psd_stats_save_dir=psd_stats_save_dir,
)