In [None]:
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec
from mne import create_info, EvokedArray
from mne.channels import make_standard_montage
from mne.viz import plot_topomap

In [None]:
# --- Analysis parameters ----------------------------------------------------
# Probability-of-direction (pd) threshold deciding which channel effects are shown.
PD_THRESHOLD = 0.995

# Brain region -> side (Left / Right / Midline) -> channel labels of the BioSemi 32 layout.
regions_dict = {
    'Frontal': {
        'Left': ['Fp1', 'AF3', 'F7', 'F3'],
        'Right': ['Fp2', 'AF4', 'F8', 'F4'],
        'Midline': ['Fz'],
    },
    'Central': {'Midline': ['Cz']},
    'Parietal': {
        'Left': ['FC1', 'C3', 'CP1', 'P3'],
        'Right': ['FC2', 'C4', 'CP2', 'P4'],
        'Midline': ['Pz'],
    },
    'Temporal': {
        'Left': ['FC5', 'T7', 'CP5', 'P7'],
        'Right': ['P8', 'CP6', 'T8', 'FC6'],
    },
    'Occipital': {
        'Left': ['PO3', 'O1'],
        'Right': ['PO4', 'O2'],
        'Midline': ['Oz'],
    },
}

# Flatten the nested mapping into a single channel list (region order, then side order).
all_channels = [
    channel
    for sides in regions_dict.values()
    for channel_list in sides.values()
    for channel in channel_list
]

# Standard BioSemi 32-channel layout; plotted once as a visual sanity check.
montage = make_standard_montage('biosemi32')
montage.plot();

# Info object carrying the channel names and their montage positions for the topomaps.
info = create_info(ch_names=all_channels, sfreq=250, ch_types='eeg')
info.set_montage(montage)

# Define functions
def prepare_data(df, pd_threshold=PD_THRESHOLD, channels=None):
    """Turn per-channel results into value and mask arrays aligned with a channel order.

    Parameters
    ----------
    df : pandas.DataFrame
        Indexed by channel name, with columns ``median_diff`` and ``pd``.
    pd_threshold : float, optional
        A channel is flagged in ``mask`` when its ``pd`` is >= this value.
        Defaults to ``PD_THRESHOLD``.
    channels : sequence of str, optional
        Channel order of the returned arrays. Defaults to the module-level
        ``all_channels`` (the order ``info`` was created with).

    Returns
    -------
    data : numpy.ndarray of float
        ``median_diff`` per channel; channels absent from ``df`` stay 0.
    mask : numpy.ndarray of bool
        ``pd >= pd_threshold`` per channel; channels absent from ``df`` stay False.
        If ``df`` lists a channel more than once, its last row wins.

    Raises
    ------
    ValueError
        If ``df`` contains a channel that is not in ``channels``.
    """
    channel_order = list(all_channels if channels is None else channels)

    data = np.zeros(len(channel_order))
    mask = np.zeros(len(channel_order), dtype=bool)

    for channel, row in df.iterrows():
        # list.index raises ValueError for a channel outside `channel_order`.
        ch_idx = channel_order.index(channel)
        data[ch_idx] = row['median_diff']
        mask[ch_idx] = row['pd'] >= pd_threshold

    return data, mask

def plot_topo_grid(df, measures, comparisons, info, file=None, title="", pd_threshold=PD_THRESHOLD):
    """Draw a grid of topomaps: one row per comparison, one column per measure.

    Each panel shows ``median_diff`` for one comparison/measure pair. Channels whose
    ``pd`` is below ``pd_threshold`` are set to 0. All panels share one symmetric
    colour scale, and a single colour bar, derived from the channels in ``df`` that
    pass the threshold.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format results with columns ``comparison``, ``measure``, ``channels``,
        ``median_diff`` and ``pd``. Every comparison/measure pair must contain each
        channel in ``info.ch_names`` (otherwise ``.loc`` raises ``KeyError``).
    measures : sequence of str
        Values of ``measure`` to plot, one grid column each.
    comparisons : sequence of str
        Values of ``comparison`` to plot, one grid row each.
    info : mne.Info
        Channel names and montage used for the topomaps.
    file : str or path-like, optional
        When given, the figure is saved there at 300 dpi.
    title : str, optional
        Figure title.
    pd_threshold : float, optional
        Channels with ``pd >= pd_threshold`` are shown. Defaults to ``PD_THRESHOLD``.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    n_rows = len(comparisons)
    n_cols = len(measures)

    fig = plt.figure(figsize=(3 * n_cols + 1, 3 * n_rows + 1))

    # Grid columns: [comparison labels | one column per measure | colour bar].
    gs = GridSpec(n_rows, n_cols + 2, figure=fig,
                  width_ratios=[1.0] + [2] * n_cols + [0.2], wspace=0.2, hspace=0.2)

    fig.suptitle(title, fontsize=20)

    # Shared symmetric colour scale built from the channels that pass the threshold.
    # Uses the same ">=" test as the per-panel mask, so no displayed value is excluded.
    shown = df.loc[df['pd'] >= pd_threshold, 'median_diff']
    abs_max = shown.abs().max()
    if not np.isfinite(abs_max) or abs_max == 0:
        # Nothing passes the threshold (or every effect is 0): all panels are blank,
        # so use an arbitrary finite scale instead of NaN / degenerate limits.
        abs_max = 1.0
    vmin, vmax = -abs_max, abs_max

    im = None
    for col, measure in enumerate(measures):
        for row, comp in enumerate(comparisons):
            ax = fig.add_subplot(gs[row, col + 1])
            df_comp = df[(df['comparison'] == comp) & (df['measure'] == measure)]

            # Several rows for one channel (e.g. both conditions or both groups still in
            # `df`) would be overwritten silently downstream; make that visible.
            if df_comp['channels'].duplicated().any():
                warnings.warn(
                    f"{comp!r} / {measure!r}: more than one row per channel; "
                    "only the last row for each channel is plotted.",
                    stacklevel=2,
                )

            # Reorder to the channel order of `info`. prepare_data aligns its arrays with
            # `all_channels`, the list `info` was created from.
            df_comp = df_comp.set_index('channels').loc[info.ch_names]
            data, mask = prepare_data(df_comp, pd_threshold=pd_threshold)

            # Zero the channels below the threshold so only passing effects are coloured.
            masked_data = np.where(mask, data, 0.0)

            im, _ = plot_topomap(
                masked_data, info, axes=ax, show=False,
                mask=mask,
                mask_params=dict(marker=None, markerfacecolor='w', markeredgecolor='k', markersize=6),
                cmap='RdBu_r', sensors=True, contours=6,
                outlines='head',
                extrapolate='head',
                vlim=(vmin, vmax),
            )

            # Measure names head the first row only.
            if row == 0:
                ax.set_title(f'{measure.upper()}', fontsize=16)

            # Comparison labels go in the first grid column, once per row.
            if col == 0:
                text_ax = fig.add_subplot(gs[row, 0])
                text_ax.axis('off')
                text_ax.text(0.25, 0.5, comp, fontsize=14, ha='center', va='center', wrap=True)

    # A colour bar needs at least one drawn topomap (no panels when `measures` is empty).
    if im is not None:
        cax = fig.add_subplot(gs[:, -1])
        cbar = fig.colorbar(im, cax=cax, orientation='vertical')
        cbar.set_label('Standard Deviation Median Difference', fontsize=12)

    # Outer margins only: spacing between panels comes from the GridSpec's wspace/hspace.
    fig.subplots_adjust(top=0.92, bottom=0.05, left=0.05, right=0.92)

    if file is not None:
        fig.savefig(file, dpi=300, bbox_inches='tight')

    return fig

## Group Comparisons 

In [None]:
# Load the group-comparison marginal differences for every metric into one
# long-format frame, tagging each row with its metric in a `measure` column.
BRMS_OUTPUT_DIR = Path('c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/output')
metrics = ['btwn', 'clcoef', 'indgr', 'outdgr']

# Read every CSV, then concatenate once. This avoids re-copying a growing frame
# each iteration and seeding concat with an empty DataFrame.
df_group = pd.concat(
    [
        pd.read_csv(BRMS_OUTPUT_DIR / f'{metric}_resp_marginaldiff.csv').assign(measure=metric)
        for metric in metrics
    ],
    ignore_index=True,
)

# Fix the order of the condition levels (EC before EO) with an ordered categorical.
condition_category = pd.CategoricalDtype(categories=["EC", "EO"], ordered=True)
df_group['condition'] = df_group['condition'].astype(condition_category)

# Grid rows / columns for the topomap figures, in order of first appearance.
comparisons = df_group['comparison'].unique()
measures = df_group['measure'].unique()

# Split by eye condition; each subset is plotted as its own grid.
eyes_open_group = df_group[df_group['condition'] == 'EO']
eyes_closed_group = df_group[df_group['condition'] == 'EC']

In [None]:
# Eyes-closed grid: rows are group comparisons, columns are measures; the figure is saved to disk.
EC_group_comp = plot_topo_grid(
    eyes_closed_group,
    measures,
    comparisons,
    info,
    file='c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/figures/EC_group_comp',
    title="Resting Eyes-Closed Group Differences per Timepoint",
)

In [None]:
# Eyes-open grid: rows are group comparisons, columns are measures; the figure is saved to disk.
EO_group_comp = plot_topo_grid(
    eyes_open_group,
    measures,
    comparisons,
    info,
    file='c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/figures/EO_group_comp',
    title="Resting Eyes-Open Group Differences per Timepoint",
)

## Timepoint Comparisons

In [None]:
# Load the timepoint-comparison marginal differences for every metric into one
# long-format frame, tagging each row with its metric in a `measure` column.
BRMS_OUTPUT_DIR = Path('c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/output')
metrics = ['btwn', 'clcoef', 'indgr', 'outdgr']

# Read every CSV, then concatenate once. This avoids re-copying a growing frame
# each iteration and seeding concat with an empty DataFrame.
df_timepoint = pd.concat(
    [
        pd.read_csv(BRMS_OUTPUT_DIR / f'{metric}_timepoint_marginaldiff.csv').assign(measure=metric)
        for metric in metrics
    ],
    ignore_index=True,
)

# Fix the order of the condition levels (EC before EO) with an ordered categorical.
condition_category = pd.CategoricalDtype(categories=["EC", "EO"], ordered=True)
df_timepoint['condition'] = df_timepoint['condition'].astype(condition_category)

# Grid rows / columns for the topomap figures, in order of first appearance.
comparisons = df_timepoint['comparison'].unique()
measures = df_timepoint['measure'].unique()

# Split by group ('RESP' / 'NR'); each subset is plotted as its own grid.
# NOTE(review): these subsets still contain both conditions, so each topomap panel
# may hold two rows per channel. Confirm whether the comparisons encode the condition.
responder_group = df_timepoint[df_timepoint['group'] == 'RESP']
nonresponder_group = df_timepoint[df_timepoint['group'] == 'NR']

In [None]:
# Responder grid: rows are timepoint comparisons, columns are measures; the figure is saved to disk.
resp_timepoint_comp = plot_topo_grid(
    responder_group,
    measures,
    comparisons,
    info,
    file='c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/figures/resp_timepoint_comp',
    title="Responder Timepoint Differences per Condition",
)

In [None]:
# Non-responder grid: rows are timepoint comparisons, columns are measures; the figure is saved to disk.
nonresponder_timepoint_comp = plot_topo_grid(
    nonresponder_group,
    measures,
    comparisons,
    info,
    file='c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/figures/nonresponder_timepoint_comp',
    title="Non-Responder Timepoint Differences per Condition",
)

## Exporting Channel Median Differences

In [None]:
def process_data(df, comparisons, metrics=None, p_threshold=0.01):
    """List the channels that pass a p-value threshold, per metric and comparison.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format results with columns ``measure``, ``comparison``, ``pvalue``,
        ``channels``, ``median_diff``, ``group`` and ``condition``.
    comparisons : iterable of str
        Comparison labels to include, in output order. Any iterable is accepted,
        including a one-shot iterator.
    metrics : iterable of str, optional
        Values of ``measure`` to include, in output order.
        Defaults to ``['btwn', 'clcoef', 'indgr', 'outdgr']``.
    p_threshold : float, optional
        Rows with ``pvalue <= p_threshold`` are kept (default 0.01).
        NOTE(review): presumably chosen to match PD_THRESHOLD = 0.995 via
        p = 2 * (1 - pd); confirm.

    Returns
    -------
    pandas.DataFrame
        One row per kept channel with the following columns:

        - ``group`` and ``condition``: the row's values as strings.
        - ``comparison``: ``"<metric> - <comparison>"``.
        - ``output``: ``"<channel> (<median_diff rounded to 3 decimals>)"``.

        Rows are ordered by metric, then comparison, then their order in ``df``.
    """
    if metrics is None:
        metrics = ['btwn', 'clcoef', 'indgr', 'outdgr']
    # Materialise once: the comparisons are re-iterated for every metric, and a
    # generator/iterator would otherwise be exhausted after the first metric.
    comparisons = list(comparisons)

    records = []
    for metric in metrics:
        for comparison in comparisons:
            selected = df[
                (df['measure'] == metric)
                & (df['comparison'] == comparison)
                & (df['pvalue'] <= p_threshold)
            ]
            for _, row in selected.iterrows():
                records.append({
                    'group': f"{row['group']}",
                    'condition': f"{row['condition']}",
                    'comparison': f"{metric} - {comparison}",
                    'output': f"{row['channels']} ({round(row['median_diff'], 3)})",
                })

    # Explicit columns keep the schema even when no channel passes the threshold.
    return pd.DataFrame(records, columns=['group', 'condition', 'comparison', 'output'])

In [None]:
# Export, for each metric and timepoint comparison, the channels that pass
# process_data's p-value threshold together with their median difference.
# Change the results frame (and its comparisons) to export a different analysis.
comparisons = list(df_timepoint['comparison'].unique())
channel_data = process_data(df_timepoint, comparisons)

# index=False: the default RangeIndex carries no information and would otherwise be
# written as an unnamed first column ("Unnamed: 0" when read back).
channel_data.to_csv(
    'c:/Users/j_m289/Pictures/phd/3. Data Analysis/studies/OKTOS/resting_te/analysis/brms/output/channel_median_differences_timepoint.csv',
    index=False,
)