# Exploring Grand Averages

This notebook provides a means to investigate the grand averages for the YAC study.

Researchers and students are not expected to make changes here.

In [None]:
import glob
import mne
import numpy as np
import re
import pandas as pd
from IPython.display import display
import matplotlib.pyplot as plt

In [None]:
# Per-condition store filled by the loading cell below. Each value is a list with one
# (subject_id, evoked, power, itc) tuple per subject.
conditions = ['VO21', 'VO22', 'VO23', 'VO24', 'VO25']
frequencies = np.arange(7, 30)  # 7..29 Hz in 1 Hz steps (stop value is exclusive)
averaging_dict = dict((label, []) for label in conditions)

In [None]:
# Load every subject's segmented epochs. For each condition, store the subject ID, the
# condition ERP, and the Morlet time-frequency power/ITC of the first power channel.
power_channels = ['E6', 'E75', 'E55']
all_subject_paths = sorted(glob.glob('derivatives/segmented/*.fif'))

# Reset the store so re-running this cell does not append duplicate subjects.
for stored_entries in averaging_dict.values():
    stored_entries.clear()

for sub in all_subject_paths:
    # Subject ID = first run of digits enclosed by underscores in the file path.
    subject_id = re.findall(r'_(\d+)_', sub)[0]
    new_epoch = mne.read_epochs(sub)
    # Crop and baseline ONCE per subject, on a copy. Previously crop() ran in place inside the
    # condition loop. The first condition's TFR was therefore computed on uncropped epochs,
    # while all later conditions (and every ERP) used cropped ones, giving mismatched time axes.
    trimmed_epoch = new_epoch.copy().crop(tmin=-0.1, tmax=1.0).apply_baseline((-0.1, 0))
    for condition in conditions:
        condition_epochs = trimmed_epoch[condition]
        # Explicit copy before pick() so the ERP below still has all channels.
        power, itc = mne.time_frequency.tfr_morlet(
            condition_epochs.copy().pick(power_channels[0]),
            n_cycles=2, return_itc=True, freqs=frequencies, decim=2,
        )
        averaging_dict[condition].append((subject_id, condition_epochs.average(), power, itc))

In [None]:
def condition_summary(condition_label):
    """Show the grand-average ERP for one condition.

    Prints which condition is being processed, then builds the grand average of
    every subject's ``mne.Evoked`` stored in ``averaging_dict[condition_label]``.
    It displays that average, draws its waveform plot, and draws scalp topomaps
    at 100 ms steps starting at -0.1 s (``np.arange``, so 1.0 s is excluded).
    The topomap figure is titled with the condition label.
    """
    print('Working on: ', condition_label)
    subject_evokeds = [entry[1] for entry in averaging_dict[condition_label]]
    grand_avg = mne.grand_average(subject_evokeds)
    display(grand_avg)
    grand_avg.plot()
    topo_fig = grand_avg.plot_topomap(times=np.arange(-0.1, 1.0, .1), colorbar=True)
    topo_fig.suptitle(condition_label)

In [None]:
# Summarise every condition, in the order defined in `conditions`.
for condition in conditions:
    condition_summary(condition_label=condition)

In [None]:
# Overlay per-condition ERPs at channel E71. Each condition maps to its list of subject ERPs.
mne.viz.plot_compare_evokeds(
    {label: [entry[1] for entry in subject_entries]
     for label, subject_entries in averaging_dict.items()},
    picks=['E71'],
)

In [None]:
# Overlay per-condition ERPs at channel E6. Each condition maps to its list of subject ERPs.
mne.viz.plot_compare_evokeds(
    {label: [entry[1] for entry in subject_entries]
     for label, subject_entries in averaging_dict.items()},
    picks=['E6'],
)

## Between Groups

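The groups below are taken from the emailed spreadsheet. The subject-ID lists they reference (`top_masc_report`, `bot_masc_report`) must be defined before the comparison cell is run.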

When making a new comparison, copy the structure below,
change its values, and pass it to the `evoked_compare_wrapper`
function. Keep `groups` in its current format:
a list of `(group label, list of string subject IDs)` tuples.

In [None]:
# Generates a evoked comparison of variable groups based
# on defined comparison structure above.
# Parameters:
#     comparison_dict: See previous cells for definition.
#     ci: If true, draw confidence intervals. Defaults true.
def evoked_compare_wrapper(comparison_dict, ci=True):
    evokeds = {}
    for label, group_subjects in comparison_dict['groups']:
        evokeds[label] = []
        for cond in comparison_dict['condition']:
            for subject, data, _, _ in averaging_dict[cond]:
                if subject in group_subjects: 
                    evokeds[label].append(data)
    if not ci:
        for k,v in evokeds.items():
            evokeds[k] = mne.grand_average(evokeds[k])
    mne.viz.plot_compare_evokeds(evokeds, picks=comparison_dict['channels'],
                                 title=comparison_dict['figure_title'], combine='mean')

In [None]:
# NOTE(review): `top_masc_report` / `bot_masc_report` (lists of subject-ID strings from the
# emailed spreadsheet) are not defined in any cell shown here. Define them before running;
# otherwise this cell fails on a fresh kernel. Fail early with a clear message instead.
_missing_groups = [name for name in ('top_masc_report', 'bot_masc_report')
                   if name not in globals()]
if _missing_groups:
    raise NameError(f"Define the subject-ID lists {_missing_groups} before running this comparison.")

sample_comparison = {
    'groups': [
               ('Top MASC Report', top_masc_report),
               ('Bot MASC Report', bot_masc_report),
              ],
    'condition': ['VO24'],
    'channels': ['E6'],
    'figure_title': 'test',
}
evoked_compare_wrapper(sample_comparison, ci=True)

In [None]:
%matplotlib inline
# This is invoked - time freq of whole conditions ERP


# confidence interval erps - f diff is the two confidence intervals, 
# also plot the difference of erps with ci envelope
# measure of width of confidence interval
# f_obs -> how many hops of the width of the confidence interval to gwet back to zero
# f_obs_plot - > the mask is like the mne confi where the confidence interval of the difference does not include zero
#      threshold 6 is how you control not including zero

cond1 = mne.grand_average([item[3] for item in averaging_dict['VO21']])
cond2 = mne.grand_average([item[3] for item in averaging_dict['VO22']])

epochs_power_1 = np.array([item[3].data for item in averaging_dict['VO21']])[:, 0, :, :]
epochs_power_2 = np.array([item[3].data for item in averaging_dict['VO22']])[:, 0, :, :]

F_obs, clusters, cluster_p_values, H0 = mne.stats.permutation_cluster_test(
    [epochs_power_1, epochs_power_2],
    out_type="mask",
    n_permutations=100,
    threshold=6.0,
    tail=0,
) # returns F difference, sampled, zscore

times = 1e3 * averaging_dict['VO21'][0][2].times  # for changing the unit to ms

evoked_power_contrast = epochs_power_1.mean(axis=0) - epochs_power_2.mean(axis=0)
signs = np.sign(evoked_power_contrast)

F_obs_plot = np.nan * np.ones_like(F_obs)
for c, p_val in zip(clusters, cluster_p_values):
    if p_val <= 0.05:
        F_obs_plot[c] = F_obs[c] * signs[c]
max_F = np.nanmax(abs(F_obs_plot))

fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 4))
ax.imshow(
    F_obs,
    # extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
    aspect="auto",
    origin="lower",
    cmap="gray",
)

ax.imshow(
    F_obs_plot,
    # extent=[times[0], times[-1], frequencies[0], frequencies[-1]],
    aspect="auto",
    origin="lower",
    cmap="RdBu_r",
    vmin=-max_F,
    vmax=max_F,
)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Frequency (Hz)")
ax.set_title("Induced power")

evoked_contrast = mne.combine_evoked(
    [cond1, cond2], weights=[1, -1]
)
evoked_contrast.plot(axes=ax2)