In [None]:
# Set up the environment and define functions

# Import packages
import gc
import numpy as np
import matplotlib.pyplot as plt
import os
import os.path as op
import time
import pandas as pd
import glob
import csv
import mne
from mne.preprocessing.nirs import tddr
from nilearn.glm.first_level import make_first_level_design_matrix  
from mne_nirs.channels import get_long_channels
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests
import seaborn as sns
from scipy import signal
from scipy.stats import ttest_rel, zscore
import mne_nirs

mne.viz.set_browser_backend('matplotlib')

######### Set these variables as appropriate:
raw_path = '../../data'
proc_path = '../../processed'
results_path = '../../results'
subjects_dir = '../../subjects'
subject_group_mapping = pd.read_csv('../../subject_group_mapping.csv')
behavior_results_path = '../../fnirs-behavior-results'
behavior_file = '../../behavior_diff_data.csv'
output_suffix = "final" # used for all file names that are created

# Build the subject -> group lookup; rows without a subject ID are ignored
subject_group_mapping = subject_group_mapping.dropna(subset=['Subject'])
subject_group_mapping['Subject'] = subject_group_mapping['Subject'].astype(int)
subject_to_group = {
    subject_id: group_label
    for subject_id, group_label in zip(subject_group_mapping['Subject'],
                                       subject_group_mapping['Group'])
}
# Subject IDs are kept as strings for building file names
subjects = [str(subject_id) for subject_id in subject_group_mapping['Subject']]

# Acquisition / design constants
sfreq = 4.807692
conditions = ('A', 'V', 'AV', 'W')
groups = ('trained', 'control')
days = ('1', '3')
runs = (1, 2)
duration = 1.8
design = 'event'
filt_kwargs = {'l_freq': 0.01, 'h_freq': 0.2}
n_jobs = 4  # for GLM

# Make sure every output directory exists before anything is written
for _out_dir in (proc_path, results_path, subjects_dir):
    os.makedirs(_out_dir, exist_ok=True)
# mne.datasets.fetch_fsaverage(subjects_dir=subjects_dir, verbose=True)  # Only need to run once

use = None
all_sci = []
# Use one small font size for titles, axis labels and tick labels
plt.rcParams.update({
    'axes.titlesize': 8,
    'axes.labelsize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

# Prep making bad channels report: (re)create the CSV with only its header row
bad_channels_filename = op.join(results_path, f'bad_channels_report_{output_suffix}.csv')

with open(bad_channels_filename, mode='w', newline='') as report:
    csv.writer(report).writerow(['Subject', 'Day', 'Run', 'Percent Bad'])

def add_bad_channel_entry(subject, day, run, percentage_bad):
    """Append one recording's row to the bad-channel CSV report.

    The row is written to the module-level ``bad_channels_filename``; the
    percentage is stored as text with two decimals and a trailing ``%``.
    """
    with open(bad_channels_filename, mode='a', newline='') as report:
        row = [subject, day, run, f'{percentage_bad:.2f}%']
        csv.writer(report).writerow(row)

def normalize_channel_names(channels_set):
    """Return the set of channel names reduced to their first whitespace-separated token.

    For example ``'S1_D1 hbo'`` and ``'S1_D1 hbr'`` both become ``'S1_D1'``.
    """
    normalized = set()
    for full_name in channels_set:
        first_token = full_name.split()[0]
        normalized.add(first_token)
    return normalized

# Sanity check: the subject list and the subject -> group lookup must cover
# exactly the same IDs
subjects_check = set(map(int, subjects))
subject_to_group_check = set(subject_to_group)
if subjects_check != subject_to_group_check:
    print("Error loading subject info")
else:
    print("N=" + str(len(subjects)))
    # The helper sets are only cleaned up when the check passes
    del subjects_check, subject_to_group_check

In [None]:
# Define the preprocessing pipeline applied to each raw recording

def preprocess_fnirs_data(raw_intensity, proc_path, base):
    """Preprocess one raw fNIRS intensity recording and save the result.

    Steps: optical density -> bad-channel detection (flat signal or low scalp
    coupling index) -> TDDR -> band-pass filter -> short-channel regression
    (skipped when it cannot be run) -> Beer-Lambert conversion -> long-channel
    selection -> nearest-neighbour interpolation of the bad channels.

    Uses the module-level ``filt_kwargs`` (filter band) and ``output_suffix``
    (file-name suffix).

    Parameters
    ----------
    raw_intensity : mne Raw object
        Raw intensity recording to process.
    proc_path : str
        Directory the processed recording is written to.
    base : str
        Recording identifier. It is used in log messages and as the stem of
        the saved file, ``{base}_hbo_{output_suffix}_raw.fif``. Deriving the
        file name from this argument (instead of the notebook's loop globals)
        keeps the function usable outside the loading loop.

    Returns
    -------
    raw_h_interp : mne Raw object
        Long-channel haemoglobin data with bad channels interpolated.
    percentage_bad : float
        Percentage of optical-density channels with a scalp coupling index
        below 0.25 (flat channels are not included in this number).
    bads : list of str
        Sorted optical-density channel names marked bad (low SCI or flat).
    """
    # 1. Convert to optical density:
    print(f'    Analyzing {base}')
    raw_od = mne.preprocessing.nirs.optical_density(raw_intensity, verbose='error')

    # 2. Identify bad channels based on flat signal and scalp coupling index:
    peaks = np.ptp(raw_od.get_data('fnirs'), axis=-1)
    flat_names = {raw_od.ch_names[f].split(' ')[0] for f in np.where(peaks < 0.001)[0]}
    sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od)
    sci_mask = (sci < 0.25)
    got = np.where(sci_mask)[0]
    percentage_bad = (len(got) / len(raw_od.ch_names)) * 100
    assert raw_od.info['bads'] == []
    bads = set(raw_od.ch_names[pick] for pick in got)
    # A flat channel marks every channel sharing its source-detector prefix
    bads = bads | set(ch_name for ch_name in raw_od.ch_names if ch_name.split(' ')[0] in flat_names)
    bads = sorted(bads)

    # 3. Apply temporal derivative distribution repair (TDDR), bandpass filter, apply bad channels:
    raw_tddr = tddr(raw_od)
    raw_tddr_bp = raw_tddr.copy().filter(**filt_kwargs)
    raw_tddr_bp.info['bads'] = bads

    # 4. Short channel regression (if present). Only ordinary exceptions are
    #    treated as "cannot regress"; KeyboardInterrupt/SystemExit propagate,
    #    and the actual reason is reported instead of being hidden.
    try:
        raw_tddr_bp = mne_nirs.signal_enhancement.short_channel_regression(raw_tddr_bp)
    except Exception as exc:
        print(f"Short channel regression skipped for {base}: {exc}")

    # 5. Convert to hemoglobin concentration:
    raw_h = mne.preprocessing.nirs.beer_lambert_law(raw_tddr_bp, 6.)

    # 6. Carry the bad channels over to the renamed haemoglobin channels by
    #    matching on the source-detector prefix
    bad_prefixes = {bad.split(' ')[0] for bad in bads}
    h_bads = [ch_name for ch_name in raw_h.ch_names if ch_name.split(' ')[0] in bad_prefixes]
    raw_h.info['bads'] = h_bads
    raw_h.info._check_consistency()

    # 7. Select long channels and verify that the signal is not flat:
    raw_h = get_long_channels(raw_h)
    picks = mne.pick_types(raw_h.info, fnirs=True)
    peaks = np.ptp(raw_h.get_data(picks), axis=-1)
    assert (peaks > 1e-9).all()

    # 8. Interpolate bad channels and save. reset_bads=True clears
    #    info['bads'] in the saved file.
    raw_h_interp = raw_h.copy().interpolate_bads(reset_bads=True, method=dict(fnirs='nearest'))
    raw_h_interp.save(op.join(proc_path, f'{base}_hbo_{output_suffix}_raw.fif'), overwrite=True)
    assert len(raw_h.ch_names) == len(raw_h_interp.ch_names)

    return raw_h_interp, percentage_bad, bads

In [None]:
# Load participant data

#subjects = ['223'] #testing

for subject in subjects:
    for day in days:
        for run in runs:
            group = subject_to_group.get(int(subject), "unknown")
            base = f'{subject}_{day}_{run:03d}'
            # Recording folders live in <raw_path>/Day<day>/<subject>_<day>/
            # and end in _<run>; the leading part is matched with a wildcard.
            fname_base = op.join(raw_path, f'Day{day}', f'{subject}_{day}', f'*-*-*_{run:03d}')
            # Sort so the choice is deterministic if several folders match
            fname = sorted(glob.glob(fname_base))
            if not fname:
                # Fail with the offending recording named, rather than an
                # IndexError on an empty glob result
                raise FileNotFoundError(
                    f'No recording found for {base} (searched {fname_base})')
            raw_intensity = mne.io.read_raw_nirx(fname[0])
            raw_intensity, percentage_bad_long, bads_long = preprocess_fnirs_data(raw_intensity, proc_path, base + '_long')
            add_bad_channel_entry(subject, day, run, percentage_bad_long)
            # Free the recording before loading the next one
            del raw_intensity, percentage_bad_long, bads_long
            gc.collect()


In [None]:
# Remove subjects with >30% bad channels 

bad_channels_df = pd.read_csv(bad_channels_filename)
# 'Percent Bad' is stored as text such as '12.50%'; convert it back to a number
bad_channels_df['Percent Bad'] = bad_channels_df['Percent Bad'].str.rstrip('%').astype(float)
# Average over all of a subject's recordings (days and runs)
average_bad_channels = bad_channels_df.groupby('Subject')['Percent Bad'].mean()

# Find subjects with more than 30% bad channels
bad_subjects = average_bad_channels[average_bad_channels > 30].index.tolist()
print("Subjects with more than 30% bad channels:", bad_subjects)

# Initialize counters for each group
removed_trained = 0
removed_control = 0

# Count the excluded subjects per group and remove them from the lookup
for subject in bad_subjects:
    subject_int = int(subject)
    if subject_int in subject_to_group:
        if subject_to_group[subject_int] == "trained":
            removed_trained += 1
        elif subject_to_group[subject_int] == "control":
            removed_control += 1
        subject_to_group.pop(subject_int, None)

# Update the subjects list. `subjects` holds ID strings while the IDs read
# back from the CSV are integers, so compare on the integer value; comparing
# the raw values never matches and leaves every bad subject in the list.
bad_subject_ids = {int(subject) for subject in bad_subjects}
subjects = [subject for subject in subjects if int(subject) not in bad_subject_ids]

# Count who is left in each group
remaining_trained = sum(label == "trained" for label in subject_to_group.values())
remaining_control = sum(label == "control" for label in subject_to_group.values())

# Output the results
print(" ")
print(f'Removed {removed_trained} trained subjects.')
print(f'Removed {removed_control} control subjects.')
print(" ")
print(f'Remaining trained subjects: {remaining_trained}')
print(f'Remaining control subjects: {remaining_control}')

In [None]:
# Clean events and make design matrix

def make_design(raw_h_long, design, subject=None, run=None, day=None, group=None):
    """Clean a recording's events and build its stimulus and design matrices.

    Note that the '255.0' annotations are deleted from ``raw_h_long`` in place.
    The module-level ``conditions`` and ``sfreq`` are used to label events and
    to size the stimulus boxcar.

    Parameters
    ----------
    raw_h_long : mne Raw object
        Recording whose annotations describe the stimulus events.
    design : {'block', 'event'}
        'block' keeps every fifth event with a 20 s duration; 'event' keeps
        all 100 events with a 1.8 s duration.
    subject, run, day, group : optional
        Accepted for caller convenience; not used by the function body.

    Returns
    -------
    stim : ndarray, shape (n_times, 4)
        Boxcar stimulus time course, one column per condition.
    dm : pandas.DataFrame
        First-level design matrix from nilearn (Glover HRF, constant drift).
    events : ndarray
        The cleaned events, with the code column shifted down by one.

    Raises
    ------
    ValueError
        If ``design`` is neither 'block' nor 'event'.
    AssertionError
        If the events do not match the expected 100-trial layout.
    """
    # Reject an unknown design up front with a clear message, before the
    # recording's annotations are modified.
    if design not in ('block', 'event'):
        raise ValueError(f"design must be 'block' or 'event', got {design!r}")

    annotations_to_remove = raw_h_long.annotations.description == '255.0'
    raw_h_long.annotations.delete(annotations_to_remove)
    events, _ = mne.events_from_annotations(raw_h_long)

    # Fix mis-codings: drop events coded 1 and, if one extra event remains,
    # drop the first one
    rows_to_remove = events[:, -1] == 1
    events = events[~rows_to_remove]
    if len(events) == 101:
        events = events[1:]

    n_times = len(raw_h_long.times)
    stim = np.zeros((n_times, 4))
    # Shift codes down by one; the checks below require codes 1-4, 25 of each
    events[:, 2] -= 1
    assert len(events) == 100, len(events)
    want = [0] + [25] * 4
    count = np.bincount(events[:, 2])
    assert np.array_equal(count, want), count
    assert events.shape == (100, 3), events.shape

    if design == 'block':
        # One event per block of five trials
        events = events[0::5]
        duration = 20.
        assert np.array_equal(np.bincount(events[:, 2]), [0] + [5] * 4)
    else:
        assert len(events) == 100
        duration = 1.8
        assert events.shape == (100, 3)
        # Every run of five consecutive trials must share one condition
        events_r = events[:, 2].reshape(20, 5)
        assert (events_r == events_r[:, :1]).all()
        del events_r

    # Row 0: onset sample, row 1: zero-based condition index
    idx = (events[:, [0, 2]] - [0, 1]).T
    # np.isin replaces the deprecated np.in1d (idx[1] is 1-D, same result)
    assert np.isin(idx[1], np.arange(len(conditions))).all()
    stim[tuple(idx)] = 1

    # Turn each onset impulse into a boxcar lasting `duration` seconds
    n_block = int(np.ceil(duration * sfreq))
    stim = signal.fftconvolve(stim, np.ones((n_block, 1)), axes=0)[:n_times]
    dm_events = pd.DataFrame({
        'trial_type': [conditions[ii] for ii in idx[1]],
        'onset': idx[0] / raw_h_long.info['sfreq'],
        'duration': n_block / raw_h_long.info['sfreq']})
    dm = make_first_level_design_matrix(
        raw_h_long.times, dm_events, hrf_model='glover',
        drift_model='polynomial', drift_order=0)

    return stim, dm, events


In [None]:
# Change the subject, day, and run to plot different waveforms

plot_subject = '223'
plot_day = 1
plot_run = 1

fname2 = op.join(proc_path, f'{plot_subject}_{plot_day}_{plot_run:03d}_long_hbo_{output_suffix}_raw.fif')
use = mne.io.read_raw_fif(fname2, preload=True)
events, _ = mne.events_from_annotations(use)
# Remove only a literal ' hbo' suffix. str.rstrip(' hbo') strips any trailing
# run of the characters ' ', 'h', 'b', 'o' and is not a suffix removal.
ch_names = [ch_name[:-len(' hbo')] if ch_name.endswith(' hbo') else ch_name
            for ch_name in use.ch_names]
info = use.info

fig, axes = plt.subplots(2, 1, figsize=(6., 3), constrained_layout=True)
ax = axes[0]
raw_h = use
# Note: make_design deletes the '255.0' annotations from `use` in place
stim, dm, _ = make_design(raw_h, design)

colors = dict(
    A='#4477AA',  # blue
    AV='#CCBB44',  # yellow
    V='#EE7733',  # orange
    W='#AA3377',  # purple
)

# Top panel: stimulus boxcars (shaded) and the modelled response per condition
for ci, condition in enumerate(conditions):
    color = colors[condition]
    ax.fill_between(
        raw_h.times, stim[:, ci], 0, edgecolor='none', facecolor='k',
        alpha=0.5)
    model = dm[conditions[ci]].to_numpy()
    ax.plot(raw_h.times, model, ls='-', lw=1, color=color)
    # Label the condition near the first sample where its model is positive
    x = raw_h.times[np.where(model > 0)[0][0]]
    ax.text(
        x + 10, 1.1, condition, color=color, fontweight='bold', ha='center')
ax.set(ylabel='Modeled\noxyHb', xlabel='', xlim=raw_h.times[[0, -1]])

# HbO/HbR
ax = axes[1]
picks = [pi for pi, ch_name in enumerate(raw_h.ch_names)
         if 'S7_D19' in ch_name]
colors = dict(hbo='r', hbr='b')
ylim = np.array([-1, 1])
for pick in picks:
    # The last three characters of the channel name select the colour
    color = colors[raw_h.ch_names[pick][-3:]]
    data = raw_h.get_data(pick)[0] * 1e6
    val = np.ptp(data)
    assert val > 0.01
    ax.plot(raw_h.times, data, color=color, lw=1.)
ax.set(ylim=ylim, xlabel='Time (s)', ylabel='μM',
       xlim=raw_h.times[[0, -1]])
for ax in axes:
    for key in ('top', 'right'):
        ax.spines[key].set_visible(False)
plt.savefig(op.join(results_path, f'figure_1_{output_suffix}.png'))


In [None]:
# Run GLM analysis and epoching

subj_cha_list = []
for subject in subjects:
    group = subject_to_group.get(int(subject), "unknown")
    for day in days:
        for run in runs:
            fname_long = op.join(
                proc_path,
                f'{subject}_{day}_{run:03d}_long_hbo_{output_suffix}_raw.fif')
            raw_h_long = mne.io.read_raw_fif(fname_long)
            # Only the design matrix (second return value) is needed here
            dm = make_design(raw_h_long, design, subject, run, day, group)[1]
            glm_est = mne_nirs.statistics.run_glm(
                raw_h_long, dm, noise_model='ols', n_jobs=n_jobs)
            cha = glm_est.to_dataframe()
            # Tag every row with the recording it came from
            for column, value in (('subject', subject), ('run', run),
                                  ('day', day), ('group', group)):
                cha[column] = value
            subj_cha_list.append(cha)
            del raw_h_long, dm, glm_est, cha
            gc.collect()  #
        print(f'***Finished processing subject {subject} day {day}.')

# Stack all recordings into one table; ignore_index=True already yields a
# fresh 0..n-1 index, so no separate index reset is needed
df_cha = pd.concat(subj_cha_list, ignore_index=True)


In [None]:
# Block averages

# Event codes 1..len(conditions), matching the codes make_design returns
event_id = {condition: ci for ci, condition in enumerate(conditions, 1)}
# NOTE(review): evokeds is keyed by subject only, so the entries for a later
# day replace those of an earlier day — confirm whether that is intended.
evokeds = {condition: dict() for condition in conditions}
for day in days:
    for subject in subjects:
        # Look the group up here instead of relying on a value left over
        # from an earlier cell
        group = subject_to_group.get(int(subject), "unknown")
        fname = op.join(proc_path, f'{subject}_{day}_{output_suffix}-ave.fif')
        tmin, tmax = -2, 38
        baseline = (None, 0)
        t0 = time.time()
        print(f'Creating block average for {subject} day {day}... ', end='')
        raws = list()
        events = list()
        for run in runs:
            fname2 = op.join(proc_path, f'{subject}_{day}_{run:03d}_long_hbo_{output_suffix}_raw.fif')
            raw_h = mne.io.read_raw_fif(fname2)
            # 'block' must be the second positional argument (design).
            # Passing None there made make_design fail its design check.
            events.append(make_design(raw_h, 'block', subject, run, day, group)[2])
            raws.append(raw_h)
        # Use the union of the runs' bad channels for every run
        bads = sorted(set(sum((r.info['bads'] for r in raws), [])))
        for r in raws:
            r.info['bads'] = bads
        raw_h, events = mne.concatenate_raws(raws, events_list=events)
        epochs = mne.Epochs(raw_h, events, event_id, tmin=tmin, tmax=tmax,
                            baseline=baseline)
        this_ev = [epochs[condition].average() for condition in conditions]
        assert all(ev.nave > 0 for ev in this_ev)
        mne.write_evokeds(fname, this_ev, overwrite=True)
        print(f'{time.time() - t0:0.1f} sec')
        for condition in conditions:
            evokeds[condition][subject] = mne.read_evokeds(fname, condition)
        # The average covers all runs of the day, so no single run is named
        print(f'Done for {group} {subject} day {day}.')
        del raws, events, raw_h, epochs, this_ev
        gc.collect()  #

# Mark bad channels
# Channel positions are resolved against `use`, the recording loaded in the
# plotting cell, so that cell must have been run first.
bad = dict()
bb = dict()

for day in days:
    for subject in subjects:
        for run in runs:
            fname2 = op.join(proc_path, f'{subject}_{day}_{run:03d}_long_hbo_{output_suffix}_raw.fif')
            this_info = mne.io.read_info(fname2)
            # Positions of this recording's bad channels, in ascending order
            bad_channels = sorted(
                this_info['ch_names'].index(name) for name in this_info['bads'])
            # Keep only positions that exist in the reference recording
            bb = [b for b in bad_channels if b < len(use.ch_names)]
            bad[(subject, run, day)] = bb
        assert np.isin(bad[(subject, run, day)], np.arange(len(use.ch_names))).all()

# Merge the runs of each (subject, day). The keys stay tuples, so the loop
# below can unpack them; keying by the bare subject string made the later
# three-way unpacking split the ID into characters.
bad_combo = dict()
for (subject, run, day), bb in bad.items():
    key = (subject, day)
    bad_combo[key] = sorted(set(bad_combo.get(key, [])) | set(bb))
bad = bad_combo

start = len(df_cha)
n_drop = 0
for (subject, day), bb in bad.items():
    if not len(bb):
        continue
    drop_names = [use.ch_names[b] for b in bb]
    is_subject = (df_cha['subject'] == subject)
    is_day = (df_cha['day'] == day)
    drop = df_cha.index[
        is_subject &
        is_day &
        np.isin(df_cha['ch_name'], drop_names)]
    n_drop += len(drop)
    if len(drop):
        print(f'Dropping {len(drop)} for {subject} day {day}')
        df_cha.drop(drop, inplace=True)
end = len(df_cha)
assert n_drop == start - end, (n_drop, start - end)

# Combine runs by averaging
# After sorting with 'run' as the last key, the rows of one
# (subject, channel, chroma, condition, group, day) combination are adjacent,
# so consecutive groups of len(runs) thetas are averaged together.
sorts = ['subject', 'ch_name', 'Chroma', 'Condition', 'group', 'day', 'run']
df_cha = df_cha.sort_values(sorts)
theta = df_cha['theta'].to_numpy().reshape(-1, len(runs)).mean(axis=-1)
# Keep only the identifying columns, then one row per combination
id_columns = [col for col in df_cha.columns if col in sorts[:-1]]
df_cha = df_cha[id_columns].iloc[::len(runs)].reset_index(drop=True)
df_cha['theta'] = theta

In [None]:
# Calculate HbDiff

# Work on a copy whose channel names have their last four characters
# (the chroma suffix) removed, so HbO and HbR rows share a channel key
df_cha_nolabels = df_cha.copy()
df_cha_nolabels['ch_name'] = df_cha_nolabels['ch_name'].str[:-4]

# Separate HbO and HbR, indexed identically so they can be aligned
index_cols = ['subject', 'Condition', 'group', 'day', 'ch_name']
is_hbo = df_cha_nolabels['Chroma'].str.endswith('hbo')
is_hbr = df_cha_nolabels['Chroma'].str.endswith('hbr')
df_hbo = df_cha_nolabels[is_hbo].set_index(index_cols).sort_index()
df_hbr = df_cha_nolabels[is_hbr].set_index(index_cols).sort_index()

# Compute the difference channel by channel
df_cha_diff_list = []
every = slice(None)
for ch_name in df_hbo.index.get_level_values('ch_name').unique():
    # All rows of this channel, whatever the other index levels are
    selector = (every, every, every, every, ch_name)
    df_hbo_ch = df_hbo.loc[selector, :].sort_index()
    df_hbr_ch = df_hbr.loc[selector, :].sort_index()

    # Restrict both to the rows present in each of them
    common_index = df_hbo_ch.index.intersection(df_hbr_ch.index)
    df_hbo_ch = df_hbo_ch.loc[common_index]
    df_hbr_ch = df_hbr_ch.loc[common_index]

    # HbO minus HbR
    df_diff = df_hbo_ch[['theta']] - df_hbr_ch[['theta']]

    # Reuse the HbO rows as a template and relabel them as the difference
    df_cha_ch = df_hbo_ch.reset_index()
    df_cha_ch['theta'] = df_diff.values
    df_cha_ch['Chroma'] = 'hbdiff'
    df_cha_ch['ch_name'] = df_cha_ch['ch_name'] + ' hbdiff'

    if df_cha_ch.empty:
        continue
    df_cha_diff_list.append(df_cha_ch)

df_cha_diff_concat = pd.concat(df_cha_diff_list, ignore_index=True)

# Append the difference rows to the original table and save everything
df_final = pd.concat([df_cha, df_cha_diff_concat], ignore_index=True)
df_final.to_csv(op.join(results_path, f'df_combined_final_cha_{output_suffix}.csv'), index=False)


Run correlational analyses below.

In [None]:
import os
import os.path as op
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests
from statsmodels.tools.sm_exceptions import ValueWarning
import warnings
warnings.filterwarnings("ignore", category=ValueWarning) # type: ignore
warnings.filterwarnings("ignore", category=UserWarning) # type: ignore

# Load the datasets
behavior_file = '../../behavior_diff_data.csv'
behavior_df = pd.read_csv(behavior_file)
# Subjects are compared as strings when the tables are merged
behavior_df['subject'] = behavior_df['subject'].astype(str)
df_final = pd.read_csv(op.join('../../results/df_combined_final_cha_final.csv'))
theta_df = df_final.copy()
theta_df['subject'] = theta_df['subject'].astype(str)
# From here on, figures and model tables go to a dedicated sub-directory.
# Create it up front; otherwise the first savefig/to_csv fails on a fresh
# checkout.
results_path = '../../results/correlations'
os.makedirs(results_path, exist_ok=True)

# Conditions, outcomes, chromas and days covered by the correlation analyses.
# NOTE: `conditions` and `days` replace the values used by the preprocessing
# cells above (different members/order, and integer instead of string days).
conditions = ['A', 'AV', 'V']
response_vars = ['AO_WR', 'AV_WR', 'TBW']
chromas = ['hbo', 'hbr', 'hbdiff']
days = [1, 3]

In [None]:
# Baseline theta vs. WR changes

def perform_analysis(group, output_suffix):
    """Relate per-channel theta on each day to behavioural change for one group.

    For every day, chroma, condition and response variable, an OLS model
    ``response_var ~ channel`` is fitted per channel on the rows of ``group``.
    The channel p-values of each such combination are FDR-corrected together
    (Benjamini-Hochberg). Models with a corrected p-value below 0.05 are
    printed, plotted (both groups shown) and written to
    ``{output_suffix}_{group}_models.csv`` in ``results_path``.

    Reads the module-level ``theta_df``, ``behavior_df``, ``days``,
    ``chromas``, ``conditions``, ``response_vars`` and ``results_path``.

    Parameters
    ----------
    group : str
        Group whose rows the models are fitted on ('trained' or 'control').
    output_suffix : str
        Prefix of the generated figure and CSV file names.
    """
    behavior_columns = ['subject', 'group', 'TBW', 'AO_WR', 'AV_WR', 'VO_WR', 'age',
                        'AO_WR_1', 'AV_WR_1', 'VO_WR_1', 'TBW_1']
    xlabel_prefixes = {'hbo': '[HbO]', 'hbr': '[HbR]', 'hbdiff': '[HbDiff]'}
    ylabels = {
        'AO_WR': 'Change in Auditory Word Recognition',
        'AV_WR': 'Change in Audiovisual Word Recognition',
        'VO_WR': 'Change in Visual Word Recognition',
    }
    plot_groups = (('trained', 'Trained', '#92b6f0'), ('control', 'Control', 'gray'))
    significant_models = []

    for day in days:
        theta_df_filtered = theta_df[theta_df['day'] == day].copy()
        # Drop the chroma suffix so channel names are valid formula terms
        theta_df_filtered['ch_name'] = theta_df_filtered['ch_name'].str.split(' ').str[0]

        for chroma in chromas:
            theta_dataset = theta_df_filtered[theta_df_filtered['Chroma'] == chroma]

            for condition in conditions:
                theta_df_condition = theta_dataset[theta_dataset['Condition'] == condition]

                # One row per subject with one column per channel; this does
                # not depend on the response variable, so build it once
                theta_pivot = theta_df_condition.pivot_table(index=['subject', 'group', 'Condition'], columns='ch_name', values='theta').reset_index()
                theta_pivot['subject'] = theta_pivot['subject'].astype(str)  # Ensure subject is string
                merged_df = pd.merge(theta_pivot, behavior_df[behavior_columns], on=['subject', 'group'])
                channels = theta_df_condition['ch_name'].unique()  # list of all channel names

                for response_var in response_vars:
                    fitted = []  # (channel, model, r_sq, p_value, df)
                    for channel in channels:
                        if channel not in merged_df.columns:
                            continue  # pivot_table drops all-NaN channels
                        df = merged_df[[channel, response_var, 'group']].dropna()  # drop rows with missing values
                        group_df = df[df['group'] == group]
                        # A slope p-value needs at least three observations
                        # in the group that is actually being modelled
                        if len(group_df) < 3:
                            continue
                        model = smf.ols(f"{response_var} ~ {channel}", group_df).fit()
                        p_value_channel = model.pvalues[channel]  # p-value for the channel
                        if pd.isna(p_value_channel):
                            # An undefined p-value would turn every corrected
                            # p-value of this family into NaN
                            continue
                        fitted.append((channel, model, model.rsquared, p_value_channel, df))

                    if not fitted:
                        continue

                    # Apply FDR correction across the channels of this
                    # day / chroma / condition / outcome combination
                    p_values = [entry[3] for entry in fitted]
                    _, p_values_corrected, _, _ = multipletests(p_values, alpha=0.05, method='fdr_bh')

                    for (channel, model, r_sq, p_value, df), p_val_corr in zip(fitted, p_values_corrected):
                        if not p_val_corr < 0.05:
                            continue
                        print(f"Group: {group}, Day: {day}, Chroma: {chroma}, Condition: {condition}, Channel: {channel}, Outcome: {response_var}\n   R-squared: {r_sq}, corrected p-value: {p_val_corr}\n")
                        significant_models.append({
                            'Condition': condition,
                            'Channel': channel,
                            'Response Variable': response_var,
                            'R-squared': r_sq,
                            'P-value': p_value,
                            'P-value Corrected': p_val_corr,
                            'Model Summary': model.summary().as_text(),
                            'Chroma': chroma,
                            'Day': day
                        })

                        # Plot the significant result: scatter plus fitted
                        # line for each group that has data
                        plt.figure(figsize=(8, 6))
                        for group_name, label, color in plot_groups:
                            group_points = df[df['group'] == group_name]
                            if group_points.empty:
                                continue
                            group_fit = smf.ols(f"{response_var} ~ {channel}", group_points).fit()
                            sns.scatterplot(x=group_points[channel], y=group_points[response_var], label=label, color=color, s=100)
                            sns.lineplot(x=group_points[channel], y=group_fit.predict(group_points), color=color, linewidth=2)

                        xlabel = xlabel_prefixes.get(chroma, chroma.upper()) + ' on Day ' + str(day)
                        ylabel = ylabels.get(response_var, f'Change in {response_var}')
                        plt.legend(loc='upper right')
                        plt.xlabel(xlabel, fontsize=16)
                        plt.ylabel(ylabel, fontsize=16)
                        plt.title(f'{ylabel} vs.\nCortical Response to {condition} Speech ({channel})', fontsize=16)
                        plt.savefig(op.join(results_path, f'{output_suffix}__{group}_day{day}_{chroma}_{condition}_{response_var}_{channel}_plot.png'))
                        plt.close()

    if significant_models:
        significant_models_df = pd.DataFrame(significant_models).sort_values(by='R-squared', ascending=False)
        significant_models_df.to_csv(op.join(results_path, f'{output_suffix}_{group}_models.csv'), index=False)
    else:
        print(f'No significant models found for {group} group')

# Perform analysis for the trained group
perform_analysis('trained', 'theta_vs_behavior')

# Perform analysis for the control group
perform_analysis('control', 'theta_vs_behavior')


In [None]:
# Changes in theta vs. WR changes

# One row per subject / group / condition / chroma / channel, with one theta
# column per day, plus the Day 3 - Day 1 change
theta_df_filtered = theta_df.pivot_table(index=['subject', 'group', 'Condition', 'Chroma', 'ch_name'], columns='day', values='theta').reset_index()
theta_df_filtered['theta_diff'] = theta_df_filtered[3] - theta_df_filtered[1]
# Drop the chroma suffix so channel names are valid formula terms. This is
# done once here, so the function below only reads the table.
theta_df_filtered['ch_name'] = theta_df_filtered['ch_name'].str.split(' ').str[0]
# NOTE(review): this rebinds the `output_suffix` that the preprocessing cells
# use in their file names; re-running those cells afterwards would pick up
# this value — confirm that this is intended.
output_suffix = "theta_diff_final"

def perform_analysis(group, output_suffix):
    """Relate the Day 3 - Day 1 theta change to behavioural change for one group.

    For every chroma, condition and response variable, an OLS model
    ``response_var ~ channel`` is fitted per channel on the rows of ``group``,
    with the channel's ``theta_diff`` as predictor. The channel p-values of
    each combination are FDR-corrected together (Benjamini-Hochberg). Models
    with a corrected p-value below 0.05 are printed, plotted (both groups
    shown) and written to ``{output_suffix}_{group}_models.csv`` in
    ``results_path``; the CSV is overwritten on every call so that re-running
    the cell does not append duplicate rows.

    Reads the module-level ``theta_df_filtered``, ``behavior_df``,
    ``chromas``, ``conditions``, ``response_vars`` and ``results_path``.

    Parameters
    ----------
    group : str
        Group whose rows the models are fitted on ('trained' or 'control').
    output_suffix : str
        Prefix of the generated figure and CSV file names.
    """
    behavior_columns = ['subject', 'group', 'TBW', 'AO_WR', 'AV_WR', 'VO_WR', 'age',
                        'AO_WR_1', 'AV_WR_1', 'VO_WR_1', 'TBW_1']
    ylabels = {
        'AO_WR': 'Change in Auditory Word Recognition',
        'AV_WR': 'Change in Audiovisual Word Recognition',
        'VO_WR': 'Change in Visual Word Recognition',
    }
    plot_groups = (('trained', 'Trained', '#92b6f0'), ('control', 'Control', 'gray'))
    significant_models = []

    for chroma in chromas:
        theta_dataset = theta_df_filtered[theta_df_filtered['Chroma'] == chroma]

        for condition in conditions:
            theta_df_condition = theta_dataset[theta_dataset['Condition'] == condition]

            # One row per subject with one column per channel; this does not
            # depend on the response variable, so build it once
            theta_pivot = theta_df_condition.pivot_table(index=['subject', 'group', 'Condition'], columns='ch_name', values='theta_diff').reset_index()
            theta_pivot['subject'] = theta_pivot['subject'].astype(str)  # Ensure subject is string
            merged_df = pd.merge(theta_pivot, behavior_df[behavior_columns], on=['subject', 'group'])
            channels = theta_df_condition['ch_name'].unique()  # list of all channel names

            for response_var in response_vars:
                fitted = []  # (channel, model, r_sq, p_value, df)
                for channel in channels:
                    if channel not in merged_df.columns:
                        continue  # pivot_table drops all-NaN channels
                    df = merged_df[[channel, response_var, 'group']].dropna()  # drop rows with missing values
                    group_df = df[df['group'] == group]
                    # A slope p-value needs at least three observations in
                    # the group that is actually being modelled
                    if len(group_df) < 3:
                        continue
                    model = smf.ols(f"{response_var} ~ {channel}", group_df).fit()
                    p_value_channel = model.pvalues[channel]  # p-value for the channel
                    if pd.isna(p_value_channel):
                        # An undefined p-value would turn every corrected
                        # p-value of this family into NaN
                        continue
                    fitted.append((channel, model, model.rsquared, p_value_channel, df))

                if not fitted:
                    continue

                # Apply FDR correction across the channels of this
                # chroma / condition / outcome combination
                p_values = [entry[3] for entry in fitted]
                _, p_values_corrected, _, _ = multipletests(p_values, alpha=0.05, method='fdr_bh')

                for (channel, model, r_sq, p_value, df), p_val_corr in zip(fitted, p_values_corrected):
                    if not p_val_corr < 0.05:
                        continue
                    print(f"Group: {group}, Chroma: {chroma}, Condition: {condition}, Channel: {channel}, Outcome: {response_var}\n   R-squared: {r_sq}, p-value: {p_value}, corrected p-value: {p_val_corr}\n")
                    significant_models.append({
                        'Condition': condition,
                        'Channel': channel,
                        'Response Variable': response_var,
                        'R-squared': r_sq,
                        'P-value': p_value,
                        'P-value Corrected': p_val_corr,
                        'Model Summary': model.summary().as_text(),
                        'Chroma': chroma,
                    })

                    # Plot the significant result: scatter plus fitted line
                    # for each group that has data
                    plt.figure(figsize=(8, 6))
                    for group_name, label, color in plot_groups:
                        group_points = df[df['group'] == group_name]
                        if group_points.empty:
                            continue
                        group_fit = smf.ols(f"{response_var} ~ {channel}", group_points).fit()
                        sns.scatterplot(x=group_points[channel], y=group_points[response_var], label=label, color=color, s=100)
                        sns.lineplot(x=group_points[channel], y=group_fit.predict(group_points), color=color, linewidth=2)

                    xlabel = f'{chroma.upper()} Change (Day 3 - Day 1) on {channel}'
                    ylabel = ylabels.get(response_var, f'Change in {response_var}')
                    plt.legend(loc='upper right')
                    plt.xlabel(xlabel, fontsize=16)
                    plt.ylabel(ylabel, fontsize=16)
                    plt.title(f'{ylabel} vs.\nCortical Response to {condition} Speech ({channel})', fontsize=16)
                    plt.savefig(op.join(results_path, f'{output_suffix}__{group}_{chroma}_{condition}_{response_var}_{channel}_plot.png'))
                    plt.close()

    if significant_models:
        significant_models_df = pd.DataFrame(significant_models).sort_values(by='R-squared', ascending=False)
        csv_file = op.join(results_path, f'{output_suffix}_{group}_models.csv')
        # Always start from a fresh file: the file name is unique per group,
        # so appending could only duplicate rows from a previous run
        significant_models_df.to_csv(csv_file, index=False, mode='w')
    else:
        print(f'No significant models found for {group} group')


# Perform analysis for the trained group
perform_analysis('trained', 'theta_diff_vs_behavior')

# Perform analysis for the control group
perform_analysis('control', 'theta_diff_vs_behavior')


In [None]:
# Run same analysis for baseline theta vs. change in theta values

# Reshape to one theta column per day, then derive the Day 3 - Day 1 change
# (theta_diff) and the Day 1 baseline (theta_baseline)
theta_df_filtered = (
    theta_df
    .pivot_table(index=['subject', 'group', 'Condition', 'Chroma', 'ch_name'],
                 columns='day', values='theta')
    .reset_index()
    .assign(
        theta_diff=lambda wide: wide[3] - wide[1],
        theta_baseline=lambda wide: wide[1],
    )
)

# Specify the channels you want to analyze
channels_to_analyze = [
    'S19_D4', 'S19_D6', 'S25_D14', 'S15_D14', 'S21_D10', 'S3_D4', 'S4_D3',
]  # Example channels

# Conditions, outcome and chromas used by this analysis
conditions = ['A', 'AV', 'V']
response_vars = ['theta_diff']
chromas = ['hbo', 'hbr', 'hbdiff']

def perform_analysis(group, output_suffix, channels):
    """Regress theta change (Day 3 - Day 1) on baseline theta (Day 1) per channel.

    For every chroma in the module-level ``chromas`` and every condition in
    ``conditions``, an OLS model ``theta_diff ~ theta_baseline`` is fitted for
    each requested channel, using only the rows that belong to ``group``.
    P-values are FDR-corrected (Benjamini-Hochberg) across all
    condition/channel models of one chroma. Every model whose corrected
    p-value is below 0.05 is printed, plotted (trained and control data are
    both drawn, each with its own fitted line) and written to a CSV file in
    ``results_path``.

    Parameters
    ----------
    group : str
        Group label matched against the 'group' column (e.g. 'trained').
    output_suffix : str
        Prefix used for the plot and CSV file names.
    channels : list of str
        Channel names to analyse, compared against 'ch_name' after its text
        from the first space onwards has been removed.

    Notes
    -----
    Reads the module-level ``theta_df_filtered`` and rewrites its 'ch_name'
    column in place. The CSV is appended to when it already exists, so
    re-running the cell adds duplicate rows. The name ``perform_analysis`` is
    defined again in a later cell, which replaces this definition.
    """
    # Keep only the text before the first space of each channel name so it can
    # be matched against `channels`.
    theta_df_filtered['ch_name'] = theta_df_filtered['ch_name'].str.split(' ').str[0]
    significant_models = []

    total_channels_analyzed = 0
    significant_channels_count = 0

    for chroma in chromas:
        all_p_values = []
        # One tuple per fitted model; it carries everything the plotting step
        # needs, so that step never depends on leftover loop variables.
        model_data = []

        for condition in conditions:
            print(f"\nStarting analysis for condition: {condition}, chroma: {chroma}, group: {group}\n")
            # Rows for the current chroma/condition, restricted to the requested channels
            theta_df_condition = theta_df_filtered[(theta_df_filtered['Chroma'] == chroma) & (theta_df_filtered['Condition'] == condition)]
            theta_df_condition = theta_df_condition[theta_df_condition['ch_name'].isin(channels)]

            # One row per subject/group, one column per (value, channel) pair
            theta_pivot = theta_df_condition.pivot_table(index=['subject', 'group'], columns='ch_name', values=['theta_diff', 'theta_baseline'])
            theta_pivot.columns = ['_'.join(col).strip() if isinstance(col, tuple) else col for col in theta_pivot.columns.values]  # Flatten the MultiIndex
            theta_pivot.reset_index(inplace=True)

            for channel in channels:
                baseline_column = f'theta_baseline_{channel}'
                diff_column = f'theta_diff_{channel}'
                if baseline_column not in theta_pivot.columns or diff_column not in theta_pivot.columns:
                    continue

                df = theta_pivot[[diff_column, baseline_column, 'group']].dropna()  # drop rows with missing values
                # The model is fitted on the requested group only, so that
                # subset (not just the pooled frame) must be non-empty.
                group_df = df[df['group'] == group]
                if group_df.empty:
                    print(f"  No data available for channel: {channel}, skipping.")
                    continue

                total_channels_analyzed += 1

                # Regression model: theta_diff ~ theta_baseline
                model = smf.ols(f"{diff_column} ~ {baseline_column}", group_df).fit()
                r_sq = model.rsquared
                p_value_channel = model.pvalues[baseline_column]

                all_p_values.append(p_value_channel)
                model_data.append((condition, channel, r_sq, p_value_channel, model, baseline_column, diff_column, df))

        # Nothing was fitted for this chroma: skip the correction, which
        # cannot be computed for zero tests.
        if not all_p_values:
            continue

        _, p_values_corrected, _, _ = multipletests(all_p_values, alpha=0.05, method='fdr_bh')

        # Report and plot only the models that survive FDR correction
        for p_corrected, (condition, channel, r_sq, p_value, model, baseline_column, diff_column, df) in zip(p_values_corrected, model_data):
            if not p_corrected < 0.05:
                continue

            print(f"Channel: {channel}, {condition}, chroma: {chroma}, group: {group}")
            print(f"   p-value: {p_value}, corrected p-value: {p_corrected}, R-squared: {r_sq}\n")
            significant_channels_count += 1
            significant_models.append({
                'Condition': condition,
                'Channel': channel,
                'R-squared': r_sq,
                'P-value': p_value,
                'P-value Corrected': p_corrected,
                'Model Summary': model.summary().as_text(),
                'Chroma': chroma,
                'Group': group
            })

            plt.figure(figsize=(8, 6))

            # Scatter plus fitted line for each group that has data
            for plot_group, legend_label, color in (('trained', 'Trained', '#92b6f0'), ('control', 'Control', 'gray')):
                plot_df = df[df['group'] == plot_group]
                if plot_df.empty:
                    continue
                plot_model = smf.ols(f"{diff_column} ~ {baseline_column}", plot_df).fit()
                sns.scatterplot(x=plot_df[baseline_column], y=plot_df[diff_column], label=legend_label, color=color, s=100)
                sns.lineplot(x=plot_df[baseline_column], y=plot_model.predict(plot_df), color=color, linewidth=2)

            xlabel = f'Baseline {chroma.upper()} (Day 1) on {channel}'
            ylabel = f'{chroma.upper()} Change (Day 3 - Day 1)'
            plt.legend(loc='upper right')
            plt.xlabel(xlabel, fontsize=16)
            plt.ylabel(ylabel, fontsize=16)
            plt.title(f'{ylabel} vs.\nBaseline Cortical Response to {condition} Speech ({channel})', fontsize=16)
            plt.savefig(op.join(results_path, f'{output_suffix}__{group}_{chroma}_{condition}_{channel}_plot.png'))
            plt.close()

    # Save significant models to CSV (append when the file already exists)
    if significant_models:
        significant_models_df = pd.DataFrame(significant_models).sort_values(by='R-squared', ascending=False)
        csv_file = op.join(results_path, f'{output_suffix}_{group}_models.csv')
        if not op.isfile(csv_file):
            significant_models_df.to_csv(csv_file, index=False, mode='w')
        else:
            significant_models_df.to_csv(csv_file, index=False, mode='a', header=False)
    else:
        print(f'No significant models found for group {group}')

    # Print the counter
    print(f"\n{significant_channels_count} out of {total_channels_analyzed} channels were significantly correlated for group {group}.")

# Baseline-vs-change regression, restricted to the trained group
perform_analysis(group='trained', output_suffix='theta_diff_vs_baseline', channels=channels_to_analyze)


In [None]:
# Run same analysis for day 1 vs day 3 theta values

# Load the datasets
# NOTE(review): behavior_df is loaded here but not used anywhere in this cell.
behavior_df = pd.read_csv(behavior_file)
behavior_df['subject'] = behavior_df['subject'].astype(str)

theta_df_filtered = df_final.copy()
theta_df_filtered['subject'] = theta_df_filtered['subject'].astype(str)
# Force the day labels to strings so the pivoted columns are always '1' and
# '3', whether df_final stores the day as text or as an integer. The lines
# below and perform_analysis look the columns up by those string keys.
theta_df_filtered['day'] = theta_df_filtered['day'].astype(str)

# Pivot table to get both Day 1 theta and theta_diff (Day 3 - Day 1)
theta_df_filtered = theta_df_filtered.pivot_table(index=['subject', 'group', 'Condition', 'Chroma', 'ch_name'], columns='day', values='theta').reset_index()
theta_df_filtered['theta_diff'] = theta_df_filtered['3'] - theta_df_filtered['1']  # Calculate theta_diff
theta_df_filtered['theta_baseline'] = theta_df_filtered['1']  # Baseline theta (Day 1)

# Specify the channels you want to analyze
channels_to_analyze = ['S19_D4', 'S19_D6', 'S25_D14', 'S15_D14', 'S21_D10', 'S3_D4', 'S4_D3', 'S8_D18']  # Example channels

# Conditions and chromophores to iterate over
conditions = ['A', 'AV', 'V']
chromas = ['hbo', 'hbr', 'hbdiff']

def perform_analysis(group, output_suffix, channels):
    """Regress Day 3 theta on Day 1 theta for each channel.

    For every chroma in the module-level ``chromas`` and every condition in
    ``conditions``, an OLS model ``Day 3 ~ Day 1`` is fitted per requested
    channel, using only the rows that belong to ``group``. P-values are
    FDR-corrected (Benjamini-Hochberg) across all condition/channel models of
    one chroma. Every model whose corrected p-value is below 0.05 is printed,
    plotted (trained and control data are both drawn, each with its own
    fitted line) and written to a CSV file in ``results_path``.

    Parameters
    ----------
    group : str
        Group label matched against the 'group' column (e.g. 'trained').
    output_suffix : str
        Prefix used for the plot and CSV file names.
    channels : list of str
        Channel names to analyse, compared against 'ch_name' after its text
        from the first space onwards has been removed.

    Notes
    -----
    Reads the module-level ``theta_df_filtered``, which must hold the per-day
    theta values in columns named '1' and '3', and rewrites its 'ch_name'
    column in place. The CSV is appended to when it already exists, so
    re-running the cell adds duplicate rows.
    """
    # Keep only the text before the first space of each channel name so it can
    # be matched against `channels`.
    theta_df_filtered['ch_name'] = theta_df_filtered['ch_name'].str.split(' ').str[0]
    significant_models = []
    total_channels_analyzed = 0
    significant_channels_count = 0

    for chroma in chromas:
        all_p_values = []
        # One tuple per fitted model. It stores the data frame, formula and
        # column names of that model, so each plot is drawn from the channel
        # it belongs to rather than from whatever channel was fitted last.
        model_data = []

        # Rows for the current chroma (channel names were already shortened above)
        theta_dataset = theta_df_filtered[theta_df_filtered['Chroma'] == chroma].copy()

        for condition in conditions:
            print(f"\nStarting analysis for condition: {condition}, chroma: {chroma}, group: {group}\n")
            # Rows for the current condition, restricted to the requested channels
            theta_df_condition = theta_dataset[theta_dataset['Condition'] == condition].copy()
            theta_df_condition = theta_df_condition[theta_df_condition['ch_name'].isin(channels)]

            # One row per subject/group, one column per (day, channel) pair
            theta_pivot = theta_df_condition.pivot_table(index=['subject', 'group'], columns='ch_name', values=['1', '3'])  # 1 for Day 1, 3 for Day 3
            theta_pivot.columns = [f'{day}_{ch_name}' for day, ch_name in theta_pivot.columns]  # Simplify the column names
            theta_pivot.reset_index(inplace=True)

            for channel in channels:
                day1_column = f'1_{channel}'
                day3_column = f'3_{channel}'
                if day1_column not in theta_pivot.columns or day3_column not in theta_pivot.columns:
                    continue

                df = theta_pivot[[day1_column, day3_column, 'group']].dropna()  # drop rows with missing values
                # The model is fitted on the requested group only, so that
                # subset (not just the pooled frame) must be non-empty.
                group_df = df[df['group'] == group]
                if group_df.empty:
                    print(f"  No data available for channel: {channel}, skipping.")
                    continue

                total_channels_analyzed += 1

                # Regression model: Day 3 ~ Day 1. Q("...") quotes the column
                # names, which start with a digit, for the formula parser.
                formula = f'Q("{day3_column}") ~ Q("{day1_column}")'
                model = smf.ols(formula, group_df).fit()
                r_sq = model.rsquared
                p_value_channel = model.pvalues[f'Q("{day1_column}")']

                all_p_values.append(p_value_channel)
                model_data.append((condition, channel, r_sq, p_value_channel, model, formula, day1_column, day3_column, df))

        # Nothing was fitted for this chroma: no correction to apply
        if not all_p_values:
            continue

        _, p_values_corrected, _, _ = multipletests(all_p_values, alpha=0.05, method='fdr_bh')

        # Report and plot only the models that survive FDR correction
        for p_corrected, (condition, channel, r_sq, p_value, model, formula, day1_column, day3_column, df) in zip(p_values_corrected, model_data):
            if not p_corrected < 0.05:
                continue

            print(f"Channel: {channel}, {condition}, chroma: {chroma}, group: {group}\n   R-squared: {r_sq}, corrected p-value: {p_corrected}\n")
            significant_channels_count += 1
            significant_models.append({
                'Condition': condition,
                'Channel': channel,
                'R-squared': r_sq,
                'P-value': p_value,
                'P-value Corrected': p_corrected,
                'Model Summary': model.summary().as_text(),
                'Chroma': chroma,
                'Group': group
            })

            plt.figure(figsize=(8, 6))

            # Scatter plus fitted line for each group that has data
            for plot_group, legend_label, color in (('trained', 'Trained', '#92b6f0'), ('control', 'Control', 'gray')):
                plot_df = df[df['group'] == plot_group]
                if plot_df.empty:
                    continue
                plot_model = smf.ols(formula, plot_df).fit()
                sns.scatterplot(x=plot_df[day1_column], y=plot_df[day3_column], label=legend_label, color=color, s=100)
                sns.lineplot(x=plot_df[day1_column], y=plot_model.predict(plot_df), color=color, linewidth=2)

            xlabel = f'{chroma.upper()} (Day 1) on {channel}'
            ylabel = f'{chroma.upper()} (Day 3)'
            plt.legend(loc='upper right')
            plt.xlabel(xlabel, fontsize=16)
            plt.ylabel(ylabel, fontsize=16)
            plt.title(f'{ylabel} vs.\nCortical Response to {condition} Speech ({channel})', fontsize=16)
            plt.savefig(op.join(results_path, f'{output_suffix}__{group}_{chroma}_{condition}_{channel}_plot.png'))
            plt.close()

    # Save significant models to CSV (append when the file already exists)
    if significant_models:
        significant_models_df = pd.DataFrame(significant_models).sort_values(by='R-squared', ascending=False)
        csv_file = op.join(results_path, f'{output_suffix}_{group}_models.csv')
        if not op.isfile(csv_file):
            significant_models_df.to_csv(csv_file, index=False, mode='w')
        else:
            significant_models_df.to_csv(csv_file, index=False, mode='a', header=False)
    else:
        print(f'No significant models found for group {group}')

    # Print the counter
    print(f"\n{significant_channels_count} out of {total_channels_analyzed} channels were significantly correlated for group {group}.")

# Day-1-vs-day-3 regression, restricted to the trained group
perform_analysis(group='trained', output_suffix='theta_day1_vs_day3', channels=channels_to_analyze)



Code below is still in progress...

In [None]:
# Plot sensors of interest

import mne
import matplotlib.pyplot as plt
import os.path as op

# Use the info from your existing data
proc_path = '../../processed'
fname = op.join(proc_path, '205_1_001_long_hbo_final_raw.fif')
# preload=True already reads the data into memory, so a separate load_data()
# call is not needed.
use = mne.io.read_raw_fif(fname, preload=True)
# Copy only the measurement info instead of duplicating the whole recording
info = use.info.copy()

# Specify the names of the sensors you want to plot; names that are not
# present in the recording are dropped.
selected_sensors = ['S25_D14 hbo']
valid_sensors = [sensor for sensor in selected_sensors if sensor in info['ch_names']]

if not valid_sensors:
    print("No valid sensors selected for plotting.")
else:
    # Increase the font size. rcParams is global, so this also applies to
    # every figure drawn after this cell.
    plt.rcParams.update({'font.size': 20})  # Set the desired font size

    # Pick only the selected channels
    info_picked = mne.pick_info(info, mne.pick_channels(info['ch_names'], valid_sensors))
    # Prefix the names with two spaces; with show_names=False below the names
    # are not drawn, so this only matters if show_names is switched on.
    info_picked.rename_channels({ch: '  ' + ch for ch in info_picked['ch_names']})

    # Plot the picked sensors as large markers without an outline
    fig = mne.viz.plot_sensors(info_picked, kind='topomap', show_names=False, pointsize=800, linewidth=0)

    plt.show()



In [None]:
# Plot some waveforms

import mne
import os.path as op
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load subject group mapping
subject_group_mapping = pd.read_csv('../../subject_group_mapping.csv')
# Drop rows without a subject ID first (the setup cell does the same);
# otherwise they would become the literal subject 'nan' after the string
# conversion below.
subject_group_mapping = subject_group_mapping.dropna(subset=['Subject'])
# Subject IDs as strings without surrounding whitespace or a decimal part
subject_group_mapping['Subject'] = subject_group_mapping['Subject'].astype(str).str.strip().str.split('.').str[0]
# NOTE: this replaces the subject_to_group built in the setup cell, which is
# keyed by integer subject IDs, with one keyed by strings.
subject_to_group = dict(zip(subject_group_mapping['Subject'], subject_group_mapping['Group']))

# All subjects except the ones listed here
subjects = subject_group_mapping['Subject'].tolist()
subjects = [subject for subject in subjects if subject not in ['202', '203', '204', '206', '214', '221', '223', '226', '233']]

# Subjects whose waveforms are plotted below.
# NOTE(review): the variable name says "trained", but the filter keeps the
# 'control' group. The selection is left unchanged; confirm which group is
# meant before interpreting the figures.
trained_subjects = [subject for subject in subjects if subject_to_group[subject] == 'control']

# Specify the channel of interest
channel_of_interest = 'S25_D14'
condition = 'AV'

# Per-day containers of single-subject evoked responses, keyed by chromophore
evoked_dict_day1 = {'HbO': [], 'HbR': []}
evoked_dict_day3 = {'HbO': [], 'HbR': []}

# Read each subject's evoked file for both days and keep the channel of interest
for subject in trained_subjects:
    for day, day_store in (('1', evoked_dict_day1), ('3', evoked_dict_day3)):
        fname = op.join(proc_path, f'{subject}_{day}_final-ave.fif')

        # Evoked response for the chosen condition, baseline-corrected up to time 0
        evokeds = mne.read_evokeds(fname, condition=condition, baseline=(None, 0), verbose=False)

        # One single-channel copy per chromophore
        evoked_hbo = evokeds.copy().pick([f'{channel_of_interest} hbo'])
        evoked_hbr = evokeds.copy().pick([f'{channel_of_interest} hbr'])

        # Drop the last four characters (the chromophore suffix) from the channel names
        evoked_hbo.rename_channels(lambda x: x[:-4])
        evoked_hbr.rename_channels(lambda x: x[:-4])

        # Keep this subject/day only when each pick produced exactly one channel
        if len(evoked_hbo.ch_names) != 1 or len(evoked_hbr.ch_names) != 1:
            continue

        day_store['HbO'].append(evoked_hbo)
        day_store['HbR'].append(evoked_hbr)

# Grand average across subjects, per day and chromophore
grand_avg_dict_day1 = {
    chroma_label: mne.grand_average(subject_evokeds)
    for chroma_label, subject_evokeds in evoked_dict_day1.items()
}
grand_avg_dict_day3 = {
    chroma_label: mne.grand_average(subject_evokeds)
    for chroma_label, subject_evokeds in evoked_dict_day3.items()
}

# Individual handles to the four grand averages
grand_avg_day1_hbo = grand_avg_dict_day1['HbO']
grand_avg_day1_hbr = grand_avg_dict_day1['HbR']
grand_avg_day3_hbo = grand_avg_dict_day3['HbO']
grand_avg_day3_hbr = grand_avg_dict_day3['HbR']

# Same colours on both days; Day 1 is dashed, Day 3 is solid
color_dict_day1 = dict(HbO='red', HbR='blue')
color_dict_day3 = dict(color_dict_day1)
styles_dict_day1 = {chroma_label: {'linestyle': '--'} for chroma_label in ('HbO', 'HbR')}
styles_dict_day3 = {chroma_label: {'linestyle': '-'} for chroma_label in ('HbO', 'HbR')}

# One figure per day, drawn from the grand averages
for day_label, day_averages, day_colors, day_styles in (
    ('1', grand_avg_dict_day1, color_dict_day1, styles_dict_day1),
    ('3', grand_avg_dict_day3, color_dict_day3, styles_dict_day3),
):
    mne.viz.plot_compare_evokeds(
        day_averages,
        ci=0.95,
        colors=day_colors,
        styles=day_styles,
        title=f'{channel_of_interest} ({condition} Condition) - Day {day_label}',
        show=True,
        truncate_xaxis=(-2, 15),  # x-axis limits
        ylim=dict(hbo=(-0.10, 0.10)),  # y-axis limits
    )



# Plot the individual evoked responses (with confidence band) per day.
# evoked_dict_day1 / evoked_dict_day3 were filled by the loading loop earlier
# in this cell and are reused as they are: reading every subject file a second
# time would only rebuild identical lists.

# Check if all participants were included
expected_subjects = len(trained_subjects)
actual_subjects_day1 = len(evoked_dict_day1['HbO'])  # Same number for HbO and HbR
actual_subjects_day3 = len(evoked_dict_day3['HbO'])  # Same number for HbO and HbR

for day_label, n_included in (('1', actual_subjects_day1), ('3', actual_subjects_day3)):
    if n_included != expected_subjects:
        print(f"Warning: Only {n_included} out of {expected_subjects} subjects were included in the Day {day_label} analysis.")
    else:
        print(f"All participants were included in the Day {day_label} analysis.")

# Same colours on both days; Day 1 is dashed, Day 3 is solid
color_dict_day1 = {'HbO': 'red', 'HbR': 'blue'}
color_dict_day3 = {'HbO': 'red', 'HbR': 'blue'}
styles_dict_day1 = {'HbO': {'linestyle': '--'}, 'HbR': {'linestyle': '--'}}
styles_dict_day3 = {'HbO': {'linestyle': '-'}, 'HbR': {'linestyle': '-'}}

# Plot for Day 1 using the individual evoked responses
mne.viz.plot_compare_evokeds(
    {'HbO': evoked_dict_day1['HbO'], 'HbR': evoked_dict_day1['HbR']},
    ci=0.95, colors=color_dict_day1, styles=styles_dict_day1,
    title=f'{channel_of_interest} ({condition} Condition) - Day 1', show=True,
    truncate_xaxis=(-2, 15),  # This line sets the x-axis limits
    ylim=dict(hbo=(-0.10, 0.10))  # This line sets the y-axis limits
)

# Plot for Day 3 using the individual evoked responses
mne.viz.plot_compare_evokeds(
    {'HbO': evoked_dict_day3['HbO'], 'HbR': evoked_dict_day3['HbR']},
    ci=0.95, colors=color_dict_day3, styles=styles_dict_day3,
    title=f'{channel_of_interest} ({condition} Condition) - Day 3', show=True,
    truncate_xaxis=(-2, 15),  # This line sets the x-axis limits
    ylim=dict(hbo=(-0.10, 0.10))  # This line sets the y-axis limits
)


In [None]:
# Compare thetas in 2 channels

# Filter data for the specific channels you want to compare
channels_to_analyze = ['S19_D6', 'S15_D14']

import warnings  # needed below; not part of the notebook's top import cell
from statsmodels.tools.sm_exceptions import ValueWarning
# Silence statsmodels ValueWarning and all UserWarning messages from here on
warnings.filterwarnings("ignore", category=ValueWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Load the datasets
# NOTE(review): behavior_df is loaded here but not used anywhere in this cell.
behavior_df = pd.read_csv(behavior_file)
behavior_df['subject'] = behavior_df['subject'].astype(str)

theta_df_filtered = df_final.copy()
theta_df_filtered['subject'] = theta_df_filtered['subject'].astype(str)
# Make the day labels numeric so the pivoted columns can be addressed as 1 and
# 3 below, even when df_final stores the day as text (the day-1-vs-day-3 cell
# addresses the same pivot with the strings '1' and '3').
theta_df_filtered['day'] = pd.to_numeric(theta_df_filtered['day'])

# Pivot table to get both Day 1 theta and theta_diff (Day 3 - Day 1)
theta_df_filtered = theta_df_filtered.pivot_table(index=['subject', 'group', 'Condition', 'Chroma', 'ch_name'], columns='day', values='theta').reset_index()
theta_df_filtered['theta_diff'] = theta_df_filtered[3] - theta_df_filtered[1]  # Calculate theta_diff
theta_df_filtered['theta_baseline'] = theta_df_filtered[1]  # Baseline theta (Day 1)
# Drop rows whose group label is 'unknown'
theta_df_filtered = theta_df_filtered[theta_df_filtered['group'] != 'unknown']


def perform_channel_comparison(group, output_suffix, channels=('S19_D6', 'S15_D14')):
    """Regress the theta change of one channel on the theta change of another.

    For every chroma in the module-level ``chromas`` and every condition in
    ``conditions``, an OLS model ``theta_diff_<second> ~ theta_diff_<first>``
    is fitted on the rows that belong to ``group``. P-values are FDR-corrected
    (Benjamini-Hochberg) across the conditions of one chroma. Every model
    whose corrected p-value is below 0.05 is printed, plotted and written to
    a CSV file in ``results_path``.

    Parameters
    ----------
    group : str
        Group label matched against the 'group' column (e.g. 'trained').
    output_suffix : str
        Prefix used for the plot and CSV file names.
    channels : sequence of two str, optional
        ``(predictor_channel, response_channel)``. Defaults to
        ``('S19_D6', 'S15_D14')``, the pair this analysis was written for.
        The names are inserted into the model formula unquoted, so they must
        be valid Python identifiers.

    Notes
    -----
    Reads the module-level ``theta_df_filtered`` and rewrites its 'ch_name'
    column in place. The scatter plot shows all groups, while the black line
    is the prediction of the model fitted on ``group`` only. The CSV is
    appended to when it already exists.
    """
    x_channel, y_channel = channels
    x_col = f'theta_diff_{x_channel}'
    y_col = f'theta_diff_{y_channel}'
    comparison = f'{x_channel}_vs_{y_channel}'
    value_columns = [x_col, y_col, f'theta_baseline_{x_channel}', f'theta_baseline_{y_channel}']

    # Keep only the text before the first space of each channel name
    theta_df_filtered['ch_name'] = theta_df_filtered['ch_name'].str.split(' ').str[0]
    significant_models = []

    for chroma in chromas:
        all_p_values = []
        model_data = []

        for condition in conditions:
            print(f"\nStarting analysis for condition: {condition}, chroma: {chroma}, group: {group}\n")
            # Rows for the current chroma/condition
            theta_df_condition = theta_df_filtered[(theta_df_filtered['Chroma'] == chroma) & (theta_df_filtered['Condition'] == condition)]

            # One row per subject/group, one column per (value, channel) pair
            theta_pivot = theta_df_condition.pivot_table(index=['subject', 'group'], columns='ch_name', values=['theta_diff', 'theta_baseline'])
            theta_pivot.columns = ['_'.join(col).strip() if isinstance(col, tuple) else col for col in theta_pivot.columns.values]  # Flatten the MultiIndex
            theta_pivot.reset_index(inplace=True)

            # Both channels must be present in the pivoted dataframe
            if not all(column in theta_pivot.columns for column in value_columns):
                print(f"  One or both channels {x_channel} and {y_channel} are missing for condition {condition}, skipping.")
                continue

            df = theta_pivot[value_columns + ['group']].dropna()
            # The model is fitted on the requested group only, so that subset
            # (not just the pooled frame) must be non-empty.
            group_df = df[df['group'] == group]
            if group_df.empty:
                print("  No data available for the selected channels, skipping.")
                continue

            # Regression of the second channel's change on the first channel's change
            model = smf.ols(f"{y_col} ~ {x_col}", group_df).fit()
            r_sq = model.rsquared
            p_value_channel = model.pvalues[x_col]

            all_p_values.append(p_value_channel)
            model_data.append((condition, r_sq, p_value_channel, model, df))

        # Nothing was fitted for this chroma: no correction to apply
        if not all_p_values:
            continue

        _, p_values_corrected, _, _ = multipletests(all_p_values, alpha=0.05, method='fdr_bh')

        # Report and plot only the models that survive FDR correction
        for p_corrected, (condition, r_sq, p_value, model, df) in zip(p_values_corrected, model_data):
            if not p_corrected < 0.05:
                continue

            print(f"Comparison: {comparison}, {condition}, chroma: {chroma}, group: {group}")
            print(f"   p-value: {p_value}, corrected p-value: {p_corrected}, R-squared: {r_sq}\n")

            significant_models.append({
                'Condition': condition,
                'Comparison': comparison,
                'R-squared': r_sq,
                'P-value': p_value,
                'P-value Corrected': p_corrected,
                'Model Summary': model.summary().as_text(),
                'Chroma': chroma,
                'Group': group
            })

            # Scatter of all groups plus the fitted line of the `group` model
            plt.figure(figsize=(8, 6))
            sns.scatterplot(x=df[x_col], y=df[y_col], hue=df['group'], palette=['#92b6f0', 'gray'], s=100)
            sns.lineplot(x=df[x_col], y=model.predict(df), color='black', linewidth=2)

            xlabel = f'{chroma.upper()} Change in {x_channel} (Day 3 - Day 1)'
            ylabel = f'{chroma.upper()} Change in {y_channel} (Day 3 - Day 1)'
            plt.xlabel(xlabel, fontsize=16)
            plt.ylabel(ylabel, fontsize=16)
            plt.title(f'Comparison of {xlabel} and {ylabel} in {condition} condition', fontsize=16)
            plt.legend(title='Group', loc='upper right')
            plt.savefig(op.join(results_path, f'{output_suffix}__{group}_{chroma}_{condition}_{comparison}_plot.png'))
            plt.close()

    # Save significant models to CSV (append when the file already exists)
    if significant_models:
        significant_models_df = pd.DataFrame(significant_models).sort_values(by='R-squared', ascending=False)
        csv_file = op.join(results_path, f'{output_suffix}_{group}_models.csv')
        if not op.isfile(csv_file):
            significant_models_df.to_csv(csv_file, index=False, mode='w')
        else:
            significant_models_df.to_csv(csv_file, index=False, mode='a', header=False)
    else:
        print(f'No significant models found for group {group}')

# Channel-vs-channel comparison, restricted to the trained group
perform_channel_comparison(group='trained', output_suffix='S19_D6_vs_S15_D14_comparison')



In [None]:
""" # Run paired t-test over days and plot changes over time
# NOTE: this cell is disabled -- its whole body sits inside a string literal.
# It rebinds the notebook-level names `conditions`, `groups` and `use`.

# Load the final combined dataframe
df_final = pd.read_csv(op.join(results_path, f'df_combined_final_cha_{output_suffix}.csv'))
conditions = ['A', 'AV', 'V']
groups = ['trained', 'control']
chromas = ['hbo', 'hbr', 'hbdiff']

# Keep only the part of each channel name that precedes the first space
df_final['ch_name'] = df_final['ch_name'].str.split(' ').str[0]

# Recording whose channel info is reused for every brain plot
fname = op.join(proc_path, '205_1_001_long_hbo_final_raw.fif')
use = mne.io.read_raw_fif(fname, preload=True)  # preload=True already loads the data

# Shorten the channel names the same way as in df_final; only the first channel
# that maps to a given short name is renamed and kept
new_ch_names = {}
seen_names = set()
for ch_name in use.info['ch_names']:
    new_name = ch_name.split(' ')[0]
    if new_name not in seen_names:
        new_ch_names[ch_name] = new_name
        seen_names.add(new_name)

use.rename_channels(new_ch_names)
use = use.pick_channels(list(new_ch_names.values()))

# Channel info and picks are the same for every plot, so build them once
info_of_interest = use.info
picks = np.arange(len(info_of_interest['ch_names']))
index_cols = ['subject', 'group', 'ch_name', 'Condition', 'Chroma']

# Perform analysis for each group and Chroma
for group in groups:
    for chroma in chromas:
        # ---- Statistics -------------------------------------------------
        # None of this depends on the plotted condition, so it is computed
        # once per (group, chroma) instead of once per condition.

        # Filter data for day 1 and day 3 for the specific group and Chroma
        df_day1 = df_final.query(f"group == '{group}' and Chroma == '{chroma}' and day == 1").copy()
        df_day3 = df_final.query(f"group == '{group}' and Chroma == '{chroma}' and day == 3").copy()

        # Ensure ch_name and Condition columns are of the same data type
        df_day1['ch_name'] = df_day1['ch_name'].astype(str)
        df_day1['Condition'] = df_day1['Condition'].astype(str)
        df_day3['ch_name'] = df_day3['ch_name'].astype(str)
        df_day3['Condition'] = df_day3['Condition'].astype(str)

        # Set index and sort
        df_day1 = df_day1.set_index(index_cols).sort_index()
        df_day3 = df_day3.set_index(index_cols).sort_index()

        # Merge dataframes to align day 1 and day 3 data
        df_merged = df_day1[['theta']].rename(columns={'theta': 'theta_day1'}).merge(
            df_day3[['theta']].rename(columns={'theta': 'theta_day3'}),
            left_index=True, right_index=True)

        # Calculate the difference and z-score (z-scored over all merged rows)
        df_merged['theta_diff'] = df_merged['theta_day3'] - df_merged['theta_day1']
        df_merged['z'] = zscore(df_merged['theta_diff'])

        # Perform paired t-test for each channel and condition across subjects.
        # Argument order is (day 1, day 3), so t is positive when day 1 > day 3,
        # i.e. the opposite sign to theta_diff.
        t_stats = []
        p_values = []
        ch_names = []
        condition_list = []

        for (ch_name, cond), group_df in df_merged.groupby(['ch_name', 'Condition']):
            t_stat, p_value = ttest_rel(group_df['theta_day1'], group_df['theta_day3'])
            t_stats.append(t_stat)
            p_values.append(p_value)
            ch_names.append(ch_name)
            condition_list.append(cond)

        # Create a results DataFrame
        results_df = pd.DataFrame({
            'ch_name': ch_names,
            'Condition': condition_list,
            't_stat': t_stats,
            'p_value': p_values
        })

        # Combine with z-score data
        z_scores = df_merged.groupby(['ch_name', 'Condition'])['z'].mean().reset_index()

        # Ensure consistent data types before merging
        z_scores['ch_name'] = z_scores['ch_name'].astype(str)
        z_scores['Condition'] = z_scores['Condition'].astype(str)
        results_df['ch_name'] = results_df['ch_name'].astype(str)
        results_df['Condition'] = results_df['Condition'].astype(str)

        results_df = results_df.merge(z_scores, on=['ch_name', 'Condition'])

        # Correct for multiple comparisons
        print(f'Correcting for {len(results_df["p_value"])} comparisons using FDR')
        _, results_df['P_fdr'] = mne.stats.fdr_correction(results_df['p_value'], method='indep')
        results_df['SIG'] = results_df['P_fdr'] < 0.05

        # Print significant results
        significant_results = results_df.loc[results_df['SIG']]
        print(significant_results)

        # ---- Brain plots: one panel per condition -----------------------
        fig, axes = plt.subplots(1, len(conditions), figsize=(15, 5))
        for idx, condition in enumerate(conditions):
            condition_data = results_df[results_df['Condition'] == condition]

            # Mask with the uncorrected p-value: channels at p >= 0.05 are shown as 0
            masked_z = condition_data['z'].where(condition_data['p_value'] < 0.05, 0.0)
            z_by_channel = dict(zip(condition_data['ch_name'], masked_z))

            # One value per channel of the info, in the info's own order, so the
            # data rows line up with info_of_interest; channels without a result get 0
            z_values = np.array(
                [z_by_channel.get(name, 0.0) for name in info_of_interest['ch_names']],
                dtype=float)

            # Create an EvokedArray for this condition
            evoked = mne.EvokedArray(z_values[:, np.newaxis], info_of_interest)

            stc = mne.stc_near_sensors(
                evoked, trans='fsaverage', subject='fsaverage', mode='weighted',
                distance=0.02, project=True, picks=picks, subjects_dir=subjects_dir)

            # Plot the brain and capture the image in-memory
            brain = stc.plot(hemi='both', views=['lat', 'frontal', 'lat'],
                             cortex='low_contrast', time_viewer=False, show_traces=False,
                             surface='pial', smoothing_steps=0, size=(1200, 400),
                             clim=dict(kind='value', pos_lims=[0, 0.75, 1.5]),
                             colormap='RdBu_r', view_layout='horizontal',
                             colorbar=(0, 1), time_label='', background='w',
                             brain_kwargs=dict(units='m'),
                             add_data_kwargs=dict(colorbar_kwargs=dict(
                                 title_font_size=16, label_font_size=12, n_labels=5,
                                 title='z score')), subjects_dir=subjects_dir)
            brain.show_view('lat', hemi='lh', row=0, col=0)
            brain.show_view(azimuth=270, elevation=90, row=0, col=1)
            brain.show_view('lat', hemi='rh', row=0, col=2)

            # Capture the plot as an image in memory
            screenshot = brain.screenshot(time_viewer=False)
            brain.close()

            # Display the image in the composite figure
            ax = axes[idx]
            ax.imshow(screenshot)
            ax.axis('off')
            ax.set_title(f'{group.capitalize()} - Condition {condition} ({chroma})', fontsize=18)

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(op.join(results_path, f'{group}_{chroma}_composite_brain_plots_insig.png'))
        plt.show()
        plt.close(fig)

        # Release the per-(group, chroma) intermediates before the next pass
        del df_day1, df_day3, df_merged, t_stats, p_values, ch_names, condition_list, results_df, z_scores
        del fig, axes, evoked, stc, brain, screenshot
        gc.collect()
 """

In [None]:
""" # Plot topographic maps of significant models
# NOTE: this cell is disabled -- its whole body sits inside a string literal.
# It needs `use` to be a loaded recording (it calls use.copy().pick_types);
# the setup cell initialises `use = None`, so a cell that loads it must run first.

for group in groups:
    # Per-group model table; keep only rows whose corrected p-value is below 0.05
    df_r2 = pd.read_csv(op.join(results_path, f'{group}_fnirs-behavior-models_{group}.csv'))
    df_filtered = df_r2[(df_r2['P-value Corrected'] < 0.05)]
    ch_names = df_filtered['Channel'].values

    # Work on a copy so the picks below do not modify `use`
    info = use.copy().pick_types(fnirs='hbo', exclude=())
    info_picked = info.pick_channels(ch_names)

    # Sensor layout of the significant channels only
    fig = mne.viz.plot_sensors(info_picked.info, kind='topomap', show_names=True, pointsize=100, linewidth=0)
    plt.savefig(op.join(results_path, f'{group}_sig-p-corr_topomap.png'))
    plt.show()
    plt.close(fig)
 """