# AMME Day2 LFP Analyses

<p>
This notebook loads preprocessed AMME data and performs LFP analyses on the first 500 ms of LFP data after image onset (same as the PNAS 2018 analyses).
</p>

---
> Author:    Martina Hollearn    
> Contact:   martina.hollearn@psych.utah.edu   
> Version:   05/16/2024

## 1. Import Libraries

In [1]:
import os
import mne
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.io import loadmat
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.stats import ttest_rel, ttest_ind
from IPython.display import clear_output
import statsmodels.api as sm
from statsmodels.stats.multitest import multipletests
from matplotlib.colors import TwoSlopeNorm
from mne.time_frequency import tfr_multitaper, tfr_array_multitaper, csd_multitaper
from tensorpac import Pac, EventRelatedPac
from tensorpac.utils import ITC
import matplotlib.patches as mpatches
from scipy.interpolate import interp1d
from mne.time_frequency import EpochsSpectrum

%matplotlib inline
%config InlineBackend.figure_format='retina'

## 2. Load Preprocessed Data
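Pulls information from the following preprocessed AMME data files for each subject (the loading cell below reads the data, events, and channel/epoch files from the `PreprocessedData/Martinas_preprocessing` subfolder):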
- PreprocessedData/subject_respcondition.npy: numpy array of response condition for that subject (1-yes, 0-no responses extracted from log files)
- PreprocessedData/subject_stimcondition.npy: numpy array of stimulation condition (PNAS: 1-stim 0-nostim, 9-new; Duration: 1-onesec stim, 3-threesec stim, 0-nostim, 9-new; Timing: 1-before stim, 2-during stim, 3-afterstim, 0-nostim, 9-new)
- PreprocessedData/ChanLabels.csv: channel labels for usable channels (DroppedChans.csv shows the dropped channel labels after visual inspection)
- PreprocessedData/PreprocessedData.npy: the epoched, filtered, and downsampled (fs=500) lfp data (DroppedEpochs.csv shows the dropped epochs after visual inspection)
- PreprocessedData/Events.npy: numpy array of events IDs for recreating epochs

Check that the file names referenced below match the files on disk; file names are inconsistent throughout the AMME dataset.

In [2]:
# Params
verbosity = 0  # control output verbosity passed to MNE calls

# Define paths for locating & saving data.
# AMME_PROJ_DIR overrides the default location so the notebook can run on another machine.
projDir = os.environ.get(
    'AMME_PROJ_DIR',
    '/Users/martinahollearn/Library/CloudStorage/Box-Box/InmanLab/AMME_Data_Emory/AMME_Data',
)
subject = 'amyg072'
datapath = os.path.join(projDir, subject)
preproDataPath = os.path.join(datapath, 'PreprocessedData')
my_preprocessing_path = os.path.join(preproDataPath, 'Martinas_preprocessing')
savepath = os.path.join(my_preprocessing_path, 'LFP_analysis_results')
logfile_path = os.path.join(datapath, f'{subject}_Lamyg_day2.log')  # make sure log file name matches
event_filename = f'{subject}_Lamyg_LFP_day2_trialtimes.mat'
# NOTE(review): datapath already ends in `subject`, so this resolves to
# <projDir>/<subject>/<subject>/<event_filename>. Confirm that nested folder really exists.
events_path = os.path.join(datapath, subject, event_filename)

# Create the results folder (and any missing parent folders) if it doesn't exist.
# exist_ok makes re-running this cell a no-op instead of an error.
os.makedirs(savepath, exist_ok=True)

try:
    # Load preprocessed data (epoched LFP array, MNE events, channel and drop lists)
    data = np.load(os.path.join(my_preprocessing_path, 'PreprocessedData.npy'))
    events = np.load(os.path.join(my_preprocessing_path, 'Events.npy'))
    chans = pd.read_csv(os.path.join(my_preprocessing_path, 'ChanLabels.csv'), index_col=0)['Chan'].to_list()
    bad_chans = pd.read_csv(os.path.join(my_preprocessing_path, 'DroppedChans.csv'), index_col=0)['Dropped Chans'].to_list()
    bad_epochs = pd.read_csv(os.path.join(my_preprocessing_path, 'DroppedEpochs.csv'), index_col=0)['Dropped Epochs'].to_list()
    print(bad_epochs)
except Exception:
    # Report where we were looking, then re-raise the original error unchanged.
    print(f'Error loading data from {my_preprocessing_path}')
    raise


[]


### 2.1 Create `MNE` Objects

In [3]:
# Create MNE Info for all usable channels (sEEG, data downsampled to 500 Hz during preprocessing)
info = mne.create_info(ch_names=chans, ch_types='seeg', sfreq=500, verbose=verbosity)
n_chans = len(info.ch_names)

# Manually specified ROI -> channel labels for this subject. ROIs left empty are skipped below.
ROIs = {
    'Hipp': ['3Ld4'],
    'BLA': ['1Ld4'],
    'EC': ['1Ld2'],
    'PRC': [],
    'PHG': [],
    'MTL': [],  # Leave this empty. Will be auto computed in a few lines below
}

# 'MTL' collects the channels of every other ROI (after anything listed under MTL by hand).
# dict.fromkeys removes duplicates while preserving order, so a channel assigned to two
# ROIs is not double-weighted when MTL power is averaged across channels.
ROIs['MTL'] = list(dict.fromkeys(
    ROIs['MTL']
    + [ch for name, channels in ROIs.items() if name != 'MTL' for ch in channels]
))

for roi_name, roi_channels in ROIs.items():
    # Skip processing if we don't have any channels.
    if not roi_channels:
        continue

    print(f"==== Preprocessing ROIs: {roi_name}")

    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_channels.npy'), np.array(roi_channels))  # save the ROI channels
    print(f"Manually specified ROI {roi_name} channels for:", roi_channels)

    # Find ROI channels that are not among the usable channels (probably dropped during preprocessing)
    missing_ROI_channels = [ch for ch in roi_channels if ch not in info.ch_names]
    print("We are missing the following manually specified ROI channels:", missing_ROI_channels)
    print("number of ROI channels", len(roi_channels))

print("========= Setting Epoch Array object")
# tmin=-5: each epoch starts 5 s before the event, so t=0 is the event (image) onset
epochs = mne.EpochsArray(data, info, events, tmin=-5)


==== Preprocessing ROIs: Hipp
Manually specified ROI Hipp channels for: ['3Ld4']
We are missing the following manually specified ROI channels: []
number of ROI channels 1
==== Preprocessing ROIs: BLA
Manually specified ROI BLA channels for: ['1Ld4']
We are missing the following manually specified ROI channels: []
number of ROI channels 1
==== Preprocessing ROIs: EC
Manually specified ROI EC channels for: ['1Ld2']
We are missing the following manually specified ROI channels: []
number of ROI channels 1
==== Preprocessing ROIs: MTL
Manually specified ROI MTL channels for: ['3Ld4', '1Ld4', '1Ld2']
We are missing the following manually specified ROI channels: []
number of ROI channels 3
Not setting metadata
200 matching events found
No baseline correction applied
0 projection items activated


### 2.2 Re-epoch remembered vs. forgotten data by stim vs. no-stim condition
1. Save remembered and forgotten epochs separately as raw data, without the stim info
2. Filter the original log file and the Signal Detection Response Data file for hits and misses so we can get the stim condition at each hit and miss
3. Double check our numbers in each stimulation and response condition to make sure we didn't make a mistake
4. Construct new event type based on stimulation within remembered and forgotten LFP files
5. Re-epoch the data to have a remembered LFP file with stimulation event types for epoch event markers, and same for forgotten LFP file.
6. Plot the new LFP data to ensure the re-epoching worked

In [4]:
# Split the epochs by response label (event code '1' = remembered, '0' = forgotten)
remembered = epochs['1']
forgotten = epochs['0']

print("Remembered epochs shape:", remembered.get_data().shape)
print("Forgotten epochs shape:", forgotten.get_data().shape)

np.save(os.path.join(my_preprocessing_path, 'remembered_rawdata.npy'), remembered.get_data().copy())
np.save(os.path.join(my_preprocessing_path, 'forgotten_rawdata.npy'), forgotten.get_data().copy())

# Load original log file: tab-delimited, skip the first 2 rows and the last row.
# skipfooter is only supported by the python parser engine; requesting it explicitly
# avoids pandas' ParserWarning about falling back from the C engine.
logfile = pd.read_csv(logfile_path, delimiter='\t', skiprows=2, skipfooter=1, engine='python')
# Load Signal Detection file (per-trial responses such as hit / miss / cr / fa)
trialconditions = pd.read_csv(os.path.join(my_preprocessing_path, f'{subject}_SignalDetection_ResponseData.csv'), index_col=0)

# Merge the two files on their original row indices (assumes both list trials in the same order)
logfile['TRIAL_CONDITION'] = trialconditions['Response']
logfile = logfile[['TRIAL', 'CONDITION', 'TRIAL_CONDITION']]
print('LOGFILE', logfile, "Old logfile shape", logfile.shape)

# Drop 'new' trials; reset_index keeps the old row number in an 'index' column and
# renumbers rows 0..n-1 so positions can be matched against epoch indices.
logfile = logfile[logfile['CONDITION'] != 'new'].reset_index()
print('LOGFILE no new', logfile, "New logfile shape", logfile.shape)

# Now drop trials based on the new index rather than 'TRIAL' column
logfile = logfile.drop(bad_epochs).reset_index(drop=True)
print('LOGFILE after dropping bad epochs', logfile.shape)

# Filter for only miss and hit. .copy() gives an independent frame so the column
# assignment below cannot trigger SettingWithCopyWarning.
logfile = logfile[logfile['TRIAL_CONDITION'].isin(['hit', 'miss'])].copy()
print("===================== TOTALS =====================")
count_df = logfile.groupby(['TRIAL_CONDITION', 'CONDITION']).size().reset_index(name='Count')
print(count_df)
count_df = logfile.groupby(['CONDITION']).size().reset_index(name='Count')
print(count_df)
count_df = logfile.groupby(['TRIAL_CONDITION']).size().reset_index(name='Count')
print(count_df)

# Encode the stimulation condition as the integer event code used for re-epoching.
condition_to_event = {'nostim': 0, 'Before stim': 1, 'During stim': 2, 'After stim': 3}
logfile['EVENT_CONDITION'] = logfile['CONDITION'].map(condition_to_event)
unmapped = logfile.loc[logfile['EVENT_CONDITION'].isna(), 'CONDITION'].unique()
if len(unmapped) > 0:
    # An unmapped label would stay NaN and later be cast to a meaningless integer code.
    raise ValueError(f'Unrecognized stimulation condition(s) in log file: {list(unmapped)}')

# Sanity check: log-file hits/misses must line up with the remembered/forgotten epochs.
n_hits = int((logfile['TRIAL_CONDITION'] == 'hit').sum())
n_misses = int((logfile['TRIAL_CONDITION'] == 'miss').sum())
if (n_hits, n_misses) != (len(remembered), len(forgotten)):
    raise ValueError(
        f'Log file has {n_hits} hits / {n_misses} misses but epochs have '
        f'{len(remembered)} remembered / {len(forgotten)} forgotten'
    )
print(logfile)

Remembered epochs shape: (60, 233, 5001)
Forgotten epochs shape: (140, 233, 5001)


  logfile = pd.read_csv(logfile_path, delimiter='\t', skiprows=2, skipfooter=1) #tab delimited csv file is our log file format, reject first 2 rows and the last row


LOGFILE      TRIAL    CONDITION TRIAL_CONDITION
0        1          new              cr
1        2       nostim            miss
2        3       nostim             hit
3        4          new              cr
4        5  During stim            miss
..     ...          ...             ...
295    296          new              fa
296    297  During stim            miss
297    298  During stim            miss
298    299       nostim             hit
299    300          new              cr

[300 rows x 3 columns] Old logfile shape (300, 3)
LOGFILE no new      index  TRIAL    CONDITION TRIAL_CONDITION
0        1      2       nostim            miss
1        2      3       nostim             hit
2        4      5  During stim            miss
3        5      6       nostim            miss
4        6      7   After stim             hit
..     ...    ...          ...             ...
195    293    294       nostim             hit
196    294    295       nostim            miss
197    296    297  Duri

In [7]:
# Split the hit/miss log into remembered (hit) and forgotten (miss) trials
remembered_logfile = logfile[logfile['TRIAL_CONDITION'] == 'hit']
forgotten_logfile = logfile[logfile['TRIAL_CONDITION'] == 'miss']

# Double check the counts of each trial type: for each subset print totals per
# (response, condition), per condition, and per response.
count_groupings = (['TRIAL_CONDITION', 'CONDITION'], ['CONDITION'], ['TRIAL_CONDITION'])
for banner, subset in (
    ("===================== TOTALS REMEMBERED =====================", remembered_logfile),
    ("===================== TOTALS FORGOTTEN =====================", forgotten_logfile),
):
    print(banner)
    for group_cols in count_groupings:
        count_df = subset.groupby(group_cols).size().reset_index(name='Count')
        print(count_df)


  TRIAL_CONDITION    CONDITION  Count
0             hit   After stim     10
1             hit  Before stim     19
2             hit  During stim     14
3             hit       nostim     17
     CONDITION  Count
0   After stim     10
1  Before stim     19
2  During stim     14
3       nostim     17
  TRIAL_CONDITION  Count
0             hit     60
  TRIAL_CONDITION    CONDITION  Count
0            miss   After stim     40
1            miss  Before stim     31
2            miss  During stim     36
3            miss       nostim     33
     CONDITION  Count
0   After stim     40
1  Before stim     31
2  During stim     36
3       nostim     33
  TRIAL_CONDITION  Count
0            miss    140


In [8]:
print("Events shape:", events.shape)

# Create the new, stimulation-based event codes for re-epoching
# (0 = nostim, 1 = before, 2 = during, 3 = after stim, as stored in EVENT_CONDITION).
remembered_event_array = remembered_logfile['EVENT_CONDITION'].values.astype(int)
forgotten_event_array = forgotten_logfile['EVENT_CONDITION'].values.astype(int)
print('remembered:', len(remembered_event_array), 'forgotten:', len(forgotten_event_array))
print('remembered events', remembered_event_array)
print('forgotten events', forgotten_event_array)

# Select rows of the MNE events array by the response label in its third column
# (0 = forgotten, 1 = remembered). Boolean indexing returns copies, so relabelling
# them below leaves `events` itself untouched.
forgotten_events = events[events[:, 2] == 0]
remembered_events = events[events[:, 2] == 1]

# Stim codes are assigned positionally (assumes log-file rows follow epoch order), so
# the counts must match exactly; a length-1 code array would otherwise broadcast silently.
for label, label_events, label_codes in (
    ('remembered', remembered_events, remembered_event_array),
    ('forgotten', forgotten_events, forgotten_event_array),
):
    if len(label_events) != len(label_codes):
        raise ValueError(f'{label}: {len(label_events)} epochs but {len(label_codes)} log-file trials')

# Replace the response labels with the stimulation-condition codes
forgotten_events[:, 2] = forgotten_event_array
remembered_events[:, 2] = remembered_event_array

# Separate before, during, and after stim remembered events
before_sec_remembered = remembered_events[remembered_events[:, 2] == 1]
during_sec_remembered = remembered_events[remembered_events[:, 2] == 2]
after_sec_remembered = remembered_events[remembered_events[:, 2] == 3]
print('before remembered:', len(before_sec_remembered), 'during remembered:', len(during_sec_remembered), 'after remembered:', len(after_sec_remembered))

# Separate before, during, and after stim forgotten events
before_sec_forgotten = forgotten_events[forgotten_events[:, 2] == 1]
during_sec_forgotten = forgotten_events[forgotten_events[:, 2] == 2]
after_sec_forgotten = forgotten_events[forgotten_events[:, 2] == 3]
print('before forgotten:', len(before_sec_forgotten), 'during forgotten:', len(during_sec_forgotten), 'after forgotten:', len(after_sec_forgotten))

Events shape: (200, 3)
remembered: 60 forgotten: 140
remembered events [0 3 3 3 0 3 0 0 1 1 3 0 1 2 0 0 2 1 3 3 1 0 2 3 1 2 1 1 1 1 2 0 1 3 1 0 1
 0 0 2 2 2 1 2 0 3 2 2 0 1 2 2 1 0 2 1 1 1 0 0]
forgotten events [0 2 0 0 1 3 1 3 2 3 3 0 1 2 0 3 3 0 1 2 1 0 1 1 0 0 2 3 3 1 3 0 2 1 2 3 0
 3 2 2 3 1 0 1 1 3 0 1 3 2 2 3 2 0 3 1 3 3 3 2 2 2 1 0 1 2 2 3 3 3 0 3 2 1
 3 1 3 3 0 1 2 2 0 3 2 3 2 0 0 1 1 0 0 0 2 0 2 2 3 0 2 3 1 2 2 0 1 0 1 0 1
 3 3 0 3 2 2 3 3 3 3 0 1 1 2 1 3 2 3 2 0 1 3 2 1 0 1 0 2 2]
before remembered: 19 during remembered: 14 after remembered: 10


In [9]:
#Re-epoch the data based on the new event arrays (stimulation-based)
remembered_epoched_data = mne.EpochsArray(remembered.get_data().copy(), info, remembered_events, tmin = -5) #tmin - 1 means the epoch starts 1s before the event onset, so the event onset is at 0s. This way i can use the first 1s as baseline (t-1)
forgotten_epoched_data = mne.EpochsArray(forgotten.get_data().copy(), info, forgotten_events, tmin = -5)
print('remembered epochs shape:', remembered_epoched_data.get_data().shape)
print('forgotten epochs shape:', forgotten_epoched_data.get_data().shape)

#Plot the new epoched data to doulbe check epoching is correct (compare to log file)
#remembered_epoched_data.plot(events=remembered_events, title='Remembered Epochs')
#forgotten_epoched_data.plot(events=forgotten_events, title='Forgotten Epochs')

#Save the new epoched data
remembered_epoched_data.save(os.path.join(preproDataPath, 'remembered_epoched_data.npy'), overwrite = True)
forgotten_epoched_data.save(os.path.join(preproDataPath, 'forgotten_epoched_data.npy'), overwrite = True)

  remembered_epoched_data = mne.EpochsArray(remembered.get_data().copy(), info, remembered_events, tmin = -5) #tmin - 1 means the epoch starts 1s before the event onset, so the event onset is at 0s. This way i can use the first 1s as baseline (t-1)


Not setting metadata
60 matching events found
No baseline correction applied
0 projection items activated


  forgotten_epoched_data = mne.EpochsArray(forgotten.get_data().copy(), info, forgotten_events, tmin = -5)


Not setting metadata
140 matching events found
No baseline correction applied
0 projection items activated
remembered epochs shape: (60, 233, 5001)
forgotten epochs shape: (140, 233, 5001)


  print('remembered epochs shape:', remembered_epoched_data.get_data().shape)
  print('forgotten epochs shape:', forgotten_epoched_data.get_data().shape)
  remembered_epoched_data.save(os.path.join(preproDataPath, 'remembered_epoched_data.npy'), overwrite = True)
  forgotten_epoched_data.save(os.path.join(preproDataPath, 'forgotten_epoched_data.npy'), overwrite = True)


## 3. Stim-Evoked Power Modulation
The identification of *responsive* (i.e., modulated) electrodes is modeled after methods reported in [*Solomon et al. 2021*](https://www.brainstimjrnl.com/article/S1935-861X(21)00216-3/fulltext).

### 3.1 Compute Power Spectra
1. Use multitaper method to compute power for both stim and no-stim conditions for the first 500 ms after an image onset across all channels.
2. Compute power for both stim and no-stim conditions for the first 500 ms after an image onset in ROI channels only.
3. Compute power for remembered vs forgotten stim conditions for the first 500 ms after an image onset.

In [10]:
def process_epoch_freqs_separately(epoch_data, tmin=0, tmax=0.5,
                                   low_band=(1, 15), low_bandwidth=12,
                                   high_band=(15, 120), high_bandwidth=20):
    """Compute a multitaper PSD with band-specific smoothing and stitch the two ranges together.

    Low and high frequencies are estimated separately so each range gets its own multitaper
    bandwidth (narrower smoothing for low frequencies, wider for high frequencies). The two
    spectra are then concatenated along the frequency (last) axis.

    Parameters
    ----------
    epoch_data : mne.Epochs
        Epochs to analyse; only its ``compute_psd`` method is used.
    tmin, tmax : float
        Time window in seconds, relative to event onset, used for the PSD.
    low_band, high_band : tuple of (fmin, fmax)
        Frequency limits in Hz for the two ranges. The defaults reproduce 1-15 Hz and 15-120 Hz.
    low_bandwidth, high_bandwidth : float
        Multitaper (full) bandwidth in Hz per range: 12 -> ±6 Hz, 20 -> ±10 Hz.

    Returns
    -------
    combined_data : numpy.ndarray
        PSD values with frequency on the last axis (for Epochs: epochs x channels x freqs).
    combined_freqs : numpy.ndarray
        Strictly increasing frequencies (Hz) matching the last axis of ``combined_data``.
    """
    def _psd(fmin, fmax, bandwidth):
        # Multitaper PSD restricted to [fmin, fmax] and the requested time window
        return epoch_data.compute_psd(
            method='multitaper', fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
            bandwidth=bandwidth, adaptive=True)

    power_low = _psd(low_band[0], low_band[1], low_bandwidth)
    power_high = _psd(high_band[0], high_band[1], high_bandwidth)

    low_freqs = np.asarray(power_low.freqs)
    high_freqs = np.asarray(power_high.freqs)

    # fmin/fmax are inclusive, so a bin exactly on the shared boundary (e.g. 15 Hz with a
    # 1 s window) would appear in both spectra. Keep the low-band estimate and drop
    # high-band bins at or below the last low-band frequency so the axis stays increasing.
    if low_freqs.size:
        keep_high = high_freqs > low_freqs[-1]
    else:
        keep_high = np.ones(high_freqs.shape, dtype=bool)

    combined_data = np.concatenate(
        [power_low.get_data(), power_high.get_data()[..., keep_high]], axis=-1)
    combined_freqs = np.concatenate([low_freqs, high_freqs[keep_high]])
    return combined_data, combined_freqs

In [None]:
#### Sanity check code: average PSD per stimulation condition (remembered trials, all channels)
BeforeEpoch = remembered_epoched_data['1']  # Before stimulation events
DuringEpoch = remembered_epoched_data['2']  # During stimulation events
AfterEpoch = remembered_epoched_data['3']  # After stimulation events
NoStimEpoch = remembered_epoched_data['0']  # NoStim events

# Multitaper settings shared by every condition: 0-0.5 s after image onset, with
# bandwidth 12 Hz for 1-15 Hz and 20 Hz for 15-120 Hz.
psd_low_kwargs = dict(method='multitaper', fmin=1, fmax=15, tmin=0, tmax=0.5, bandwidth=12, adaptive=True)
psd_high_kwargs = dict(method='multitaper', fmin=15, fmax=120, tmin=0, tmax=0.5, bandwidth=20, adaptive=True)

# Compute PSD using multitaper method for low frequencies (1-15 Hz)
ns____power_low = NoStimEpoch.compute_psd(**psd_low_kwargs)
before_power_low = BeforeEpoch.compute_psd(**psd_low_kwargs)
during_power_low = DuringEpoch.compute_psd(**psd_low_kwargs)
after_power_low = AfterEpoch.compute_psd(**psd_low_kwargs)

# Compute PSD using multitaper method for high frequencies (15-120 Hz)
ns____power_high = NoStimEpoch.compute_psd(**psd_high_kwargs)
before_power_high = BeforeEpoch.compute_psd(**psd_high_kwargs)
during_power_high = DuringEpoch.compute_psd(**psd_high_kwargs)
after_power_high = AfterEpoch.compute_psd(**psd_high_kwargs)

# (low PSD, high PSD, color, legend label) per condition. Driving both the plotted lines
# and the legend patches from one table keeps their colors in sync.
sanity_conditions = [
    (ns____power_low, ns____power_high, 'blue', 'No Stim'),
    (before_power_low, before_power_high, 'pink', 'Before'),
    (during_power_low, during_power_high, 'purple', 'During'),
    (after_power_low, after_power_high, 'orange', 'After'),
]

# Plotting
plt.close()
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)  # Create subfigures

# Plot PSDs for low frequency (1-15 Hz), then high frequency (15-120 Hz)
for psd_low, _, color, _ in sanity_conditions:
    psd_low.plot(axes=axes[0], color=color, average=True)
for _, psd_high, color, _ in sanity_conditions:
    psd_high.plot(axes=axes[1], color=color, average=True)

# Clean-up Figures
for ax, band in zip(axes, ('Low', 'High')):
    ax.set_title(f'All channel {band} Freq Power')
    ax.set_ylabel('dB')
    ax.set_xlabel('Frequency (Hz)')
sns.despine(top=True, right=True)

# Manually create legend with custom colored patches (only on the high-frequency panel)
legend_patches = [mpatches.Patch(color=color, label=label, alpha=0.5) for _, _, color, label in sanity_conditions]
axes[1].legend(handles=legend_patches, loc='upper right')
# Remove a legend from the low-frequency panel only if plotting created one; calling
# axes[0].legend() just to remove it builds a throwaway legend (and matplotlib warns
# when no artists carry labels).
low_legend = axes[0].get_legend()
if low_legend is not None:
    low_legend.remove()

# Show the plot
plt.show()


    Using multitaper spectrum estimation with 5 DPSS windows
    Using multitaper spectrum estimation with 5 DPSS windows
    Using multitaper spectrum estimation with 5 DPSS windows
    Using multitaper spectrum estimation with 5 DPSS windows
    Using multitaper spectrum estimation with 9 DPSS windows
    Using multitaper spectrum estimation with 9 DPSS windows
    Using multitaper spectrum estimation with 9 DPSS windows
    Using multitaper spectrum estimation with 9 DPSS windows


In [None]:
# Split remembered trials by stimulation event code: 0 = no stim, 1 = before, 2 = during, 3 = after
NoStimEpoch, BeforeStimEpoch, DuringStimEpoch, AfterStimEpoch = (
    remembered_epoched_data[event_code] for event_code in ('0', '1', '2', '3')
)

# Persist each condition's epochs as a plain numpy array
for epoch_filename, condition_epochs in (
    ('Remembered_NoStimEpoch.npy', NoStimEpoch),
    ('Remembered_BeforeStimEpoch.npy', BeforeStimEpoch),
    ('Remembered_DuringStimEpoch.npy', DuringStimEpoch),
    ('Remembered_AfterStimEpoch.npy', AfterStimEpoch),
):
    np.save(os.path.join(my_preprocessing_path, epoch_filename), condition_epochs.get_data())

# All-channel PSDs (power array, frequencies) for each condition, 0-0.5 s after image onset
(
    (NS_power_np_arr, NS_freqs),
    (BeforeStim_power_np_arr, BeforeStim_freqs),
    (DuringStim_power_np_arr, DuringStim_freqs),
    (AfterStim_power_np_arr, AfterStim_freqs),
) = [
    process_epoch_freqs_separately(condition_epochs, tmin=0, tmax=0.5)
    for condition_epochs in (NoStimEpoch, BeforeStimEpoch, DuringStimEpoch, AfterStimEpoch)
]

# For each ROI set, run the plots!
for roi_name, roi_channels in ROIs.items():
    # Skip if we dont use that ROI stuff
    if roi_channels == []:
        continue

    print(f"Running ROI {roi_name} ...")

    print(f"Creating AvgStim epochs for {roi_name} ...")
    # get epoch data in numpy format
    before_stim_data = BeforeStimEpoch.get_data()
    during_stim_data = DuringStimEpoch.get_data()
    after_stim_data = AfterStimEpoch.get_data()

    # Compute PSDs for ROI channels
    NS_power_np_arr_roi, NS_freqs_roi = process_epoch_freqs_separately(NoStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    BeforeStim_power_np_arr_roi, BeforeStim_freqs_roi = process_epoch_freqs_separately(BeforeStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    DuringStim_power_np_arr_roi, DuringStim_freqs_roi = process_epoch_freqs_separately(DuringStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    AfterStim_power_np_arr_roi, AfterStim_freqs_roi = process_epoch_freqs_separately(AfterStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)

        ###### To get the averaged epoch data, we need to average across the epochs first, then average across the channels ##########
    # Average across epochs first (do not log-transform here)
    NS_power_numpy_arr_roi = np.mean(NS_power_np_arr_roi, axis=0)
    BeforeStim_power_numpy_arr_roi = np.mean(BeforeStim_power_np_arr_roi, axis=0)
    DuringStim_power_numpy_arr_roi = np.mean(DuringStim_power_np_arr_roi, axis=0)
    AfterStim_power_numpy_arr_roi = np.mean(AfterStim_power_np_arr_roi, axis=0)

    # Apply log10 after averaging across epochs
    NS_power_numpy_arr_roi = 10 * np.log10(NS_power_numpy_arr_roi)
    BeforeStim_power_numpy_arr_roi = 10 * np.log10(BeforeStim_power_numpy_arr_roi)
    DuringStim_power_numpy_arr_roi = 10 * np.log10(DuringStim_power_numpy_arr_roi)
    AfterStim_power_numpy_arr_roi = 10 * np.log10(AfterStim_power_numpy_arr_roi)

    # Average across the channels
    NS_power_numpy_arr_roi = NS_power_numpy_arr_roi.mean(axis=0)
    BeforeStim_power_numpy_arr_roi = BeforeStim_power_numpy_arr_roi.mean(axis=0)
    DuringStim_power_numpy_arr_roi = DuringStim_power_numpy_arr_roi.mean(axis=0)
    AfterStim_power_numpy_arr_roi = AfterStim_power_numpy_arr_roi.mean(axis=0)
    AvgStim_power_numpy_arr_roi = (BeforeStim_power_numpy_arr_roi + DuringStim_power_numpy_arr_roi + AfterStim_power_numpy_arr_roi) / 3

    # Save the averaged power for this ROI so that we can use it in the group script
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_NS_power_remembered.npy'), NS_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_AvgStim_power_remembered.npy'), AvgStim_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_BeforeStim_power_remembered.npy'), BeforeStim_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_DuringStim_power_remembered.npy'), DuringStim_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_AfterStim_power_remembered.npy'), AfterStim_power_numpy_arr_roi)

    print("Calculating Theta and Gamma differences ...")
    # Get freqs
    PSD_freqs = NS_freqs_roi

    # Save the avg power for this ROI into a numpy array
    thetaAvgStim = np.array(AvgStim_power_numpy_arr_roi[(PSD_freqs >= 4) & (PSD_freqs <= 8)])
    thetaNoStim = np.array(NS_power_numpy_arr_roi[(PSD_freqs >= 4) & (PSD_freqs <= 8)])
    thetaBeforeStim = np.array(BeforeStim_power_numpy_arr_roi[(PSD_freqs >= 4) & (PSD_freqs <= 8)])
    thetaDuringStim = np.array(DuringStim_power_numpy_arr_roi[(PSD_freqs >= 4) & (PSD_freqs <= 8)])
    thetaAfterStim = np.array(AfterStim_power_numpy_arr_roi[(PSD_freqs >= 4) & (PSD_freqs <= 8)])
    thetaAvgDiff = (thetaAvgStim - thetaNoStim).mean()
    thetaBefDiff = (thetaBeforeStim - thetaNoStim).mean()
    thetaDurDiff = (thetaDuringStim - thetaNoStim).mean()
    thetaAftDiff = (thetaAfterStim - thetaNoStim).mean()

    gammaAvgStim = np.array(AvgStim_power_numpy_arr_roi[(PSD_freqs >= 30) & (PSD_freqs <= 55)])
    gammaNoStim = np.array(NS_power_numpy_arr_roi[(PSD_freqs >= 30) & (PSD_freqs <= 55)])
    gammaBeforeStim = np.array(BeforeStim_power_numpy_arr_roi[(PSD_freqs >= 30) & (PSD_freqs <= 55)])
    gammaDuringStim = np.array(DuringStim_power_numpy_arr_roi[(PSD_freqs >= 30) & (PSD_freqs <= 55)])
    gammaAfterStim = np.array(AfterStim_power_numpy_arr_roi[(PSD_freqs >= 30) & (PSD_freqs <= 55)])
    gammaAvgDiff = (gammaAvgStim - gammaNoStim).mean()
    gammaBefDiff = (gammaBeforeStim - gammaNoStim).mean()
    gammaDurDiff = (gammaDuringStim - gammaNoStim).mean()
    gammaAftDiff = (gammaAfterStim - gammaNoStim).mean()

    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_remembered.npy'), thetaAvgDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_before_remembered.npy'), thetaBefDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_during_remembered.npy'), thetaDurDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_after_remembered.npy'), thetaAftDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_remembered.npy'), gammaAvgDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_before_remembered.npy'), gammaBefDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_during_remembered.npy'), gammaDurDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_after_remembered.npy'), gammaAftDiff)

    print("Generating 0-0.5 sec power minus -0.5-0 sec power ...")
    # Generate power for -0.5 sec to 0 sec (baseline) and for 0 sec to 0.5 sec (analysis window)
    # Then subtract the two to get the difference.
    # BASELINE
    NS_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(NoStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    NS_power_roi_baseline = 10 * np.log10(np.mean(NS_power_np_arr_roi_baseline, axis=0))
    BeforeStim_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(BeforeStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    BeforeStim_power_roi_baseline = 10 * np.log10(np.mean(BeforeStim_power_np_arr_roi_baseline, axis=0))
    DuringStim_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(DuringStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    DuringStim_power_roi_baseline = 10 * np.log10(np.mean(DuringStim_power_np_arr_roi_baseline, axis=0))
    AfterStim_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(AfterStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    AfterStim_power_roi_baseline = 10 * np.log10(np.mean(AfterStim_power_np_arr_roi_baseline, axis=0))

    # ANALYSIS WINDOW
    NS_power_np_arr_roi_post, _ = process_epoch_freqs_separately(NoStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    NS_power_roi_post = 10 * np.log10(np.mean(NS_power_np_arr_roi_post, axis=0))
    BeforeStim_power_np_arr_roi_post, _ = process_epoch_freqs_separately(BeforeStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    BeforeStim_power_roi_post = 10 * np.log10(np.mean(BeforeStim_power_np_arr_roi_post, axis=0))
    DuringStim_power_np_arr_roi_post, _ = process_epoch_freqs_separately(DuringStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    DuringStim_power_roi_post = 10 * np.log10(np.mean(DuringStim_power_np_arr_roi_post, axis=0))
    AfterStim_power_np_arr_roi_post, _ = process_epoch_freqs_separately(AfterStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    AfterStim_power_roi_post = 10 * np.log10(np.mean(AfterStim_power_np_arr_roi_post, axis=0))

    # Average across channels
    NS_power_roi_baseline = NS_power_roi_baseline.mean(axis=0)
    BeforeStim_power_roi_baseline = BeforeStim_power_roi_baseline.mean(axis=0)
    DuringStim_power_roi_baseline = DuringStim_power_roi_baseline.mean(axis=0)
    AfterStim_power_roi_baseline = AfterStim_power_roi_baseline.mean(axis=0)

    NS_power_roi_post = NS_power_roi_post.mean(axis=0)
    BeforeStim_power_roi_post = BeforeStim_power_roi_post.mean(axis=0)
    DuringStim_power_roi_post = DuringStim_power_roi_post.mean(axis=0)
    AfterStim_power_roi_post = AfterStim_power_roi_post.mean(axis=0)

    # Average the roi post (0 -> 0.5 sec)
    AvgStim_power_roi_baseline = (BeforeStim_power_roi_baseline + DuringStim_power_roi_baseline + AfterStim_power_roi_baseline) / 3
    AvgStim_power_roi_post = (BeforeStim_power_roi_post + DuringStim_power_roi_post + AfterStim_power_roi_post) / 3

    # Subtract the baseline from the post
    NS_power_post_minus_baseline = NS_power_roi_post - NS_power_roi_baseline
    AvgStim_power_post_minus_baseline = AvgStim_power_roi_post - AvgStim_power_roi_baseline
    BeforeStim_power_post_minus_baseline = BeforeStim_power_roi_post - BeforeStim_power_roi_baseline
    DuringStim_power_post_minus_baseline = DuringStim_power_roi_post - DuringStim_power_roi_baseline
    AfterStim_power_post_minus_baseline = AfterStim_power_roi_post - AfterStim_power_roi_baseline

    # Save the numpy arrays
    np.save(os.path.join(savepath, f'ROI_{roi_name}_NS_power_post_minus_baseline_remembered.npy'), NS_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_AvgStim_power_post_minus_baseline_remembered.npy'), AvgStim_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_BeforeStim_power_post_minus_baseline_remembered.npy'), BeforeStim_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_DuringStim_power_post_minus_baseline_remembered.npy'), DuringStim_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_AfterStim_power_post_minus_baseline_remembered.npy'), AfterStim_power_post_minus_baseline)

    # Create a range for the x-axis
    x = PSD_freqs

    # Plot the lines
    plt.close()
    plt.figure(figsize=(10, 6))
    plt.plot(x, NS_power_post_minus_baseline.flatten(), label='NS', color='blue', linestyle='-')
    plt.plot(x, AvgStim_power_post_minus_baseline.flatten(), label='AvgStim', color='red', linestyle='--')
    plt.plot(x, BeforeStim_power_post_minus_baseline.flatten(), label='OneStim', color='pink', linestyle='-')
    plt.plot(x, DuringStim_power_post_minus_baseline.flatten(), label='ThreeStim', color='purple', linestyle='-')
    plt.plot(x, AfterStim_power_post_minus_baseline.flatten(), label='AfterStim', color='orange', linestyle='-')

    # Add titles and labels
    plt.title(f'Remembered {roi_name} NoStim & Stim Power ')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power (dB)')
    plt.legend()

    # Show the plot
    plt.show()

    print("Plotting PSDs ...")

    # Power Spectral Density Params (PSD)
    pad = 0.1

    # Figure params
    plt.close()
    fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)

    # Plot PSDs for all channels
    avg_ns_power_np_arr = 10 * np.log10(np.mean(NS_power_np_arr, axis=0))
    avg_beforestim_power_np_arr = 10 * np.log10(np.mean(BeforeStim_power_np_arr, axis=0))
    avg_duringstim_power_np_arr = 10 * np.log10(np.mean(DuringStim_power_np_arr, axis=0))
    avg_afterstim_power_np_arr = 10 * np.log10(np.mean(AfterStim_power_np_arr, axis=0))
    avg_avgstim_power_np_arr = (avg_beforestim_power_np_arr + avg_duringstim_power_np_arr + avg_afterstim_power_np_arr) / 3

    # Plotting mean and std
    ns_power_mean = np.mean(avg_ns_power_np_arr, axis=0)
    ns_power_std = np.std(avg_ns_power_np_arr, axis=0)
    beforestim_power_mean = np.mean(avg_beforestim_power_np_arr, axis=0)
    beforestim_power_std = np.std(avg_beforestim_power_np_arr, axis=0)
    duringstim_power_mean = np.mean(avg_duringstim_power_np_arr, axis=0)
    duringstim_power_std = np.std(avg_duringstim_power_np_arr, axis=0)
    afterstim_power_mean = np.mean(avg_afterstim_power_np_arr, axis=0)
    afterstim_power_std = np.std(avg_afterstim_power_np_arr, axis=0)
    avgstim_power_mean = np.mean(avg_avgstim_power_np_arr, axis=0)
    avgstim_power_std = np.std(avg_avgstim_power_np_arr, axis=0)

    axes[0].plot(NS_freqs, ns_power_mean, label='NS Mean', color='blue')
    axes[0].fill_between(NS_freqs, ns_power_mean - ns_power_std, ns_power_mean + ns_power_std, color='b', alpha=0.2, label='NS SD')
    axes[0].plot(NS_freqs, avgstim_power_mean, label='AvgStim Mean', color='red')
    axes[0].fill_between(NS_freqs, avgstim_power_mean - avgstim_power_std, avgstim_power_mean + avgstim_power_std, color='r', alpha=0.2, label='AvgStim SD')
    axes[0].plot(BeforeStim_freqs, beforestim_power_mean, label='BeforeStim Mean', color='pink')
    axes[0].fill_between(BeforeStim_freqs, beforestim_power_mean - beforestim_power_std, beforestim_power_mean + beforestim_power_std, color='pink', alpha=0.2, label='BeforeStim SD')
    axes[0].plot(DuringStim_freqs, duringstim_power_mean, label='DuringStim Mean', color='purple') 
    axes[0].fill_between(DuringStim_freqs, duringstim_power_mean - duringstim_power_std, duringstim_power_mean + duringstim_power_std, color='purple', alpha=0.2, label='DuringStim SD')
    axes[0].plot(AfterStim_freqs, afterstim_power_mean, label='AfterStim Mean', color='orange')
    axes[0].fill_between(AfterStim_freqs, afterstim_power_mean - afterstim_power_std, afterstim_power_mean + afterstim_power_std, color='orange', alpha=0.2, label='AfterStim SD')

    # Plot PSDs for ROI channels
    avg_ns_power_np_arr_roi = 10 * np.log10(np.mean(NS_power_np_arr_roi, axis=0))
    avg_beforestim_power_np_arr_roi = 10 * np.log10(np.mean(BeforeStim_power_np_arr_roi, axis=0))
    avg_duringstim_power_np_arr_roi = 10 * np.log10(np.mean(DuringStim_power_np_arr_roi, axis=0))
    avg_afterstim_power_np_arr_roi = 10 * np.log10(np.mean(AfterStim_power_np_arr_roi, axis=0))
    avg_avg_power_np_arr_roi = (avg_beforestim_power_np_arr_roi + avg_duringstim_power_np_arr_roi + avg_afterstim_power_np_arr_roi) / 3

    # Plotting mean and std for ROIs
    ns_power_mean_roi = np.mean(avg_ns_power_np_arr_roi, axis=0)
    ns_power_std_roi = np.std(avg_ns_power_np_arr_roi, axis=0)
    avgstim_power_mean_roi = np.mean(avg_avg_power_np_arr_roi, axis=0)
    avgstim_power_std_roi = np.std(avg_avg_power_np_arr_roi, axis=0)
    beforestim_power_mean_roi = np.mean(avg_beforestim_power_np_arr_roi, axis=0)
    beforestim_power_std_roi = np.std(avg_beforestim_power_np_arr_roi, axis=0)
    duringstim_power_mean_roi = np.mean(avg_duringstim_power_np_arr_roi, axis=0)
    duringstim_power_std_roi = np.std(avg_duringstim_power_np_arr_roi, axis=0)
    afterstim_power_mean_roi = np.mean(avg_afterstim_power_np_arr_roi, axis=0)
    afterstim_power_std_roi = np.std(avg_afterstim_power_np_arr_roi, axis=0)

    axes[1].plot(NS_freqs_roi, ns_power_mean_roi, label='NS Mean', color='blue')
    axes[1].fill_between(NS_freqs_roi, ns_power_mean_roi - ns_power_std_roi, ns_power_mean_roi + ns_power_std_roi, color='b', alpha=0.2, label='NS SD')
    axes[1].plot(BeforeStim_freqs_roi, beforestim_power_mean_roi, label='BeforeStim Mean', color='pink')
    axes[1].fill_between(BeforeStim_freqs_roi, beforestim_power_mean_roi - beforestim_power_std_roi, beforestim_power_mean_roi + beforestim_power_std_roi, color='pink', alpha=0.2, label='BeforeStim SD')
    axes[1].plot(DuringStim_freqs_roi, duringstim_power_mean_roi, label='DuringStim Mean', color='purple')
    axes[1].fill_between(DuringStim_freqs_roi, duringstim_power_mean_roi - duringstim_power_std_roi, duringstim_power_mean_roi + duringstim_power_std_roi, color='purple', alpha=0.2, label='DuringStim SD')
    axes[1].plot(AfterStim_freqs_roi, afterstim_power_mean_roi, label='AfterStim Mean', color='orange')
    axes[1].fill_between(AfterStim_freqs_roi, afterstim_power_mean_roi - afterstim_power_std_roi, afterstim_power_mean_roi + afterstim_power_std_roi, color='orange', alpha=0.2, label='AfterStim SD')
    axes[1].plot(NS_freqs_roi, avgstim_power_mean_roi, label='AvgStim Mean', color='red')
    axes[1].fill_between(NS_freqs_roi, avgstim_power_mean_roi - avgstim_power_std_roi, avgstim_power_mean_roi + avgstim_power_std_roi, color='r', alpha=0.2, label='AvgStim SD')

    # Clean-up Figures
    axes[0].set_title('Remembered Power by stim cond for all channels')
    axes[1].set_title(f'Remembered Power by stim cond for {roi_name} ROI channels')
    axes[0].set_ylabel('dB')
    axes[1].set_ylabel('dB')
    axes[0].set_xlabel('Frequency (Hz)')
    axes[1].set_xlabel('Frequency (Hz)')
    sns.despine(top=True, right=True)

    # Manually create legend with custom colored patches
    legend_patches = [
        mpatches.Patch(color='blue', label='No Stim',),
        mpatches.Patch(color='pink', label='One Stim'),
        mpatches.Patch(color='purple', label='Three Stim'),
        mpatches.Patch(color='red', label='Avg Stim'),
    ]
    axes[1].legend(handles=legend_patches, loc='upper right')
    axes[0].legend().remove() # Remove legend from the first subplot

    # Save
    plt.savefig(os.path.join(savepath, f'Remembered_PowerSpectra-{roi_name}.png'), dpi=1200, bbox_inches='tight')
    plt.show()



In [2]:
# Create stim/no-stim epochs. Event keys follow the Timing stim-condition coding
# from Section 2: '0' = no stim, '1' = before stim, '2' = during stim, '3' = after stim.
NoStimEpoch = remembered_epoched_data['0']
BeforeStimEpoch = remembered_epoched_data['1']
DuringStimEpoch = remembered_epoched_data['2']
AfterStimEpoch = remembered_epoched_data['3']

# Save epochs as numpy arrays
np.save(os.path.join(my_preprocessing_path, 'Remembered_NoStimEpoch.npy'), NoStimEpoch.get_data())
np.save(os.path.join(my_preprocessing_path, 'Remembered_BeforeStimEpoch.npy'), BeforeStimEpoch.get_data())
np.save(os.path.join(my_preprocessing_path, 'Remembered_DuringStimEpoch.npy'), DuringStimEpoch.get_data())
np.save(os.path.join(my_preprocessing_path, 'Remembered_AfterStimEpoch.npy'), AfterStimEpoch.get_data())


def _epoch_mean_db(power):
    '''
    Average a PSD array over its first axis (epochs) and convert to decibels (10*log10).

    Arguments:
        power: PSD array with epochs on axis 0, as returned by process_epoch_freqs_separately
            (presumably (epochs, channels, freqs) - TODO confirm)

    Returns:
        Epoch-averaged power in dB, with axis 0 removed
    '''
    return 10 * np.log10(np.mean(power, axis=0))


def _plot_mean_sd(ax, freqs, mean, sd, label, color):
    '''Plot a mean spectrum and a shaded +/- 1 SD band on `ax`.'''
    ax.plot(freqs, mean, label=f'{label} Mean', color=color)
    ax.fill_between(freqs, mean - sd, mean + sd, color=color, alpha=0.2, label=f'{label} SD')


# Compute Stim and No Stim PSDs for all channels (0-0.5 s after image onset)
NS_power_np_arr, NS_freqs = process_epoch_freqs_separately(NoStimEpoch, tmin=0, tmax=0.5)
BeforeStim_power_np_arr, BeforeStim_freqs = process_epoch_freqs_separately(BeforeStimEpoch, tmin=0, tmax=0.5)
DuringStim_power_np_arr, DuringStim_freqs = process_epoch_freqs_separately(DuringStimEpoch, tmin=0, tmax=0.5)
AfterStim_power_np_arr, AfterStim_freqs = process_epoch_freqs_separately(AfterStimEpoch, tmin=0, tmax=0.5)

# The all-channel spectra do not depend on the ROI, so average them across epochs (in dB)
# and summarize them across channels once, instead of on every ROI iteration.
avg_ns_power_np_arr = _epoch_mean_db(NS_power_np_arr)
avg_beforestim_power_np_arr = _epoch_mean_db(BeforeStim_power_np_arr)
avg_duringstim_power_np_arr = _epoch_mean_db(DuringStim_power_np_arr)
avg_afterstim_power_np_arr = _epoch_mean_db(AfterStim_power_np_arr)
avg_avgstim_power_np_arr = (avg_beforestim_power_np_arr + avg_duringstim_power_np_arr + avg_afterstim_power_np_arr) / 3

# Mean and SD across channels (axis 0 once epochs have been averaged out)
ns_power_mean = np.mean(avg_ns_power_np_arr, axis=0)
ns_power_std = np.std(avg_ns_power_np_arr, axis=0)
beforestim_power_mean = np.mean(avg_beforestim_power_np_arr, axis=0)
beforestim_power_std = np.std(avg_beforestim_power_np_arr, axis=0)
duringstim_power_mean = np.mean(avg_duringstim_power_np_arr, axis=0)
duringstim_power_std = np.std(avg_duringstim_power_np_arr, axis=0)
afterstim_power_mean = np.mean(avg_afterstim_power_np_arr, axis=0)
afterstim_power_std = np.std(avg_afterstim_power_np_arr, axis=0)
avgstim_power_mean = np.mean(avg_avgstim_power_np_arr, axis=0)
avgstim_power_std = np.std(avg_avgstim_power_np_arr, axis=0)

# For each ROI set, run the plots!
for roi_name, roi_channels in ROIs.items():
    # Skip ROIs that have no channels
    if not roi_channels:
        continue

    print(f"Running ROI {roi_name} ...")

    # Compute PSDs (0-0.5 s) for the ROI channels only
    NS_power_np_arr_roi, NS_freqs_roi = process_epoch_freqs_separately(NoStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    BeforeStim_power_np_arr_roi, BeforeStim_freqs_roi = process_epoch_freqs_separately(BeforeStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    DuringStim_power_np_arr_roi, DuringStim_freqs_roi = process_epoch_freqs_separately(DuringStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)
    AfterStim_power_np_arr_roi, AfterStim_freqs_roi = process_epoch_freqs_separately(AfterStimEpoch.copy().pick(picks=roi_channels), tmin=0, tmax=0.5)

    ###### To get the averaged epoch data, we need to average across the epochs first, then average across the channels ##########
    # Average across epochs (linear power) and convert to dB. log10 is required for decibels,
    # consistent with every other spectrum in this notebook. Kept per channel for the PSD figure.
    avg_ns_power_np_arr_roi = _epoch_mean_db(NS_power_np_arr_roi)
    avg_beforestim_power_np_arr_roi = _epoch_mean_db(BeforeStim_power_np_arr_roi)
    avg_duringstim_power_np_arr_roi = _epoch_mean_db(DuringStim_power_np_arr_roi)
    avg_afterstim_power_np_arr_roi = _epoch_mean_db(AfterStim_power_np_arr_roi)
    avg_avg_power_np_arr_roi = (avg_beforestim_power_np_arr_roi + avg_duringstim_power_np_arr_roi + avg_afterstim_power_np_arr_roi) / 3

    # Then average across the channels to get one spectrum per condition for the group analysis
    NS_power_numpy_arr_roi = avg_ns_power_np_arr_roi.mean(axis=0)
    BeforeStim_power_numpy_arr_roi = avg_beforestim_power_np_arr_roi.mean(axis=0)
    DuringStim_power_numpy_arr_roi = avg_duringstim_power_np_arr_roi.mean(axis=0)
    AfterStim_power_numpy_arr_roi = avg_afterstim_power_np_arr_roi.mean(axis=0)
    AvgStim_power_numpy_arr_roi = (BeforeStim_power_numpy_arr_roi + DuringStim_power_numpy_arr_roi + AfterStim_power_numpy_arr_roi) / 3

    # Save the averaged power for this ROI so that we can use it in the group script
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_NS_power_remembered.npy'), NS_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_AvgStim_power_remembered.npy'), AvgStim_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_BeforeStim_power_remembered.npy'), BeforeStim_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_DuringStim_power_remembered.npy'), DuringStim_power_numpy_arr_roi)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_AfterStim_power_remembered.npy'), AfterStim_power_numpy_arr_roi)

    print("Calculating Theta and Gamma differences ...")
    # Frequency grid (assumed identical across conditions, since one mask is applied to all)
    PSD_freqs = NS_freqs_roi

    # Band masks (inclusive): theta 4-8 Hz, gamma 30-55 Hz
    theta_band = (PSD_freqs >= 4) & (PSD_freqs <= 8)
    gamma_band = (PSD_freqs >= 30) & (PSD_freqs <= 55)

    # Mean (stim - no stim) dB difference within each band
    thetaAvgStim = AvgStim_power_numpy_arr_roi[theta_band]
    thetaNoStim = NS_power_numpy_arr_roi[theta_band]
    thetaBeforeStim = BeforeStim_power_numpy_arr_roi[theta_band]
    thetaDuringStim = DuringStim_power_numpy_arr_roi[theta_band]
    thetaAfterStim = AfterStim_power_numpy_arr_roi[theta_band]
    thetaAvgDiff = (thetaAvgStim - thetaNoStim).mean()
    thetaBefDiff = (thetaBeforeStim - thetaNoStim).mean()
    thetaDurDiff = (thetaDuringStim - thetaNoStim).mean()
    thetaAftDiff = (thetaAfterStim - thetaNoStim).mean()

    gammaAvgStim = AvgStim_power_numpy_arr_roi[gamma_band]
    gammaNoStim = NS_power_numpy_arr_roi[gamma_band]
    gammaBeforeStim = BeforeStim_power_numpy_arr_roi[gamma_band]
    gammaDuringStim = DuringStim_power_numpy_arr_roi[gamma_band]
    gammaAfterStim = AfterStim_power_numpy_arr_roi[gamma_band]
    gammaAvgDiff = (gammaAvgStim - gammaNoStim).mean()
    gammaBefDiff = (gammaBeforeStim - gammaNoStim).mean()
    gammaDurDiff = (gammaDuringStim - gammaNoStim).mean()
    gammaAftDiff = (gammaAfterStim - gammaNoStim).mean()

    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_remembered.npy'), thetaAvgDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_before_remembered.npy'), thetaBefDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_during_remembered.npy'), thetaDurDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_theta_after_remembered.npy'), thetaAftDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_remembered.npy'), gammaAvgDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_before_remembered.npy'), gammaBefDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_during_remembered.npy'), gammaDurDiff)
    np.save(os.path.join(my_preprocessing_path, f'ROI_{roi_name}_avg_vs_ns_power_gamma_after_remembered.npy'), gammaAftDiff)

    print("Generating 0-0.5 sec power minus -0.5-0 sec power ...")
    # Power for -0.5 to 0 s (baseline) and 0 to 0.5 s (analysis window), each in dB averaged
    # across epochs then channels; the baseline is then subtracted from the analysis window.
    # BASELINE
    NS_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(NoStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    BeforeStim_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(BeforeStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    DuringStim_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(DuringStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    AfterStim_power_np_arr_roi_baseline, _ = process_epoch_freqs_separately(AfterStimEpoch.copy().pick(picks=roi_channels), tmin=-0.5, tmax=0)
    NS_power_roi_baseline = _epoch_mean_db(NS_power_np_arr_roi_baseline).mean(axis=0)
    BeforeStim_power_roi_baseline = _epoch_mean_db(BeforeStim_power_np_arr_roi_baseline).mean(axis=0)
    DuringStim_power_roi_baseline = _epoch_mean_db(DuringStim_power_np_arr_roi_baseline).mean(axis=0)
    AfterStim_power_roi_baseline = _epoch_mean_db(AfterStim_power_np_arr_roi_baseline).mean(axis=0)

    # ANALYSIS WINDOW: the same 0-0.5 s calls were already made for the ROI PSDs above, so
    # reuse those results instead of re-running the multitaper estimate
    # (assumes process_epoch_freqs_separately is deterministic).
    NS_power_np_arr_roi_post = NS_power_np_arr_roi
    BeforeStim_power_np_arr_roi_post = BeforeStim_power_np_arr_roi
    DuringStim_power_np_arr_roi_post = DuringStim_power_np_arr_roi
    AfterStim_power_np_arr_roi_post = AfterStim_power_np_arr_roi
    NS_power_roi_post = NS_power_numpy_arr_roi
    BeforeStim_power_roi_post = BeforeStim_power_numpy_arr_roi
    DuringStim_power_roi_post = DuringStim_power_numpy_arr_roi
    AfterStim_power_roi_post = AfterStim_power_numpy_arr_roi

    # Average the three stim-timing conditions
    AvgStim_power_roi_baseline = (BeforeStim_power_roi_baseline + DuringStim_power_roi_baseline + AfterStim_power_roi_baseline) / 3
    AvgStim_power_roi_post = (BeforeStim_power_roi_post + DuringStim_power_roi_post + AfterStim_power_roi_post) / 3

    # Subtract the baseline from the post
    NS_power_post_minus_baseline = NS_power_roi_post - NS_power_roi_baseline
    AvgStim_power_post_minus_baseline = AvgStim_power_roi_post - AvgStim_power_roi_baseline
    BeforeStim_power_post_minus_baseline = BeforeStim_power_roi_post - BeforeStim_power_roi_baseline
    DuringStim_power_post_minus_baseline = DuringStim_power_roi_post - DuringStim_power_roi_baseline
    AfterStim_power_post_minus_baseline = AfterStim_power_roi_post - AfterStim_power_roi_baseline

    # Save the numpy arrays
    np.save(os.path.join(savepath, f'ROI_{roi_name}_NS_power_post_minus_baseline_remembered.npy'), NS_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_AvgStim_power_post_minus_baseline_remembered.npy'), AvgStim_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_BeforeStim_power_post_minus_baseline_remembered.npy'), BeforeStim_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_DuringStim_power_post_minus_baseline_remembered.npy'), DuringStim_power_post_minus_baseline)
    np.save(os.path.join(savepath, f'ROI_{roi_name}_AfterStim_power_post_minus_baseline_remembered.npy'), AfterStim_power_post_minus_baseline)

    # Plot the baseline-subtracted spectra
    x = PSD_freqs
    plt.close()
    plt.figure(figsize=(10, 6))
    plt.plot(x, NS_power_post_minus_baseline.flatten(), label='NS', color='blue', linestyle='-')
    plt.plot(x, AvgStim_power_post_minus_baseline.flatten(), label='AvgStim', color='red', linestyle='--')
    plt.plot(x, BeforeStim_power_post_minus_baseline.flatten(), label='BeforeStim', color='pink', linestyle='-')
    plt.plot(x, DuringStim_power_post_minus_baseline.flatten(), label='DuringStim', color='purple', linestyle='-')
    plt.plot(x, AfterStim_power_post_minus_baseline.flatten(), label='AfterStim', color='orange', linestyle='-')
    plt.title(f'Remembered {roi_name} NoStim & Stim Power ')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power (dB)')
    plt.legend()
    plt.show()

    print("Plotting PSDs ...")

    # Left panel: all channels; right panel: ROI channels (shared y axis)
    plt.close()
    fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)

    # All channels: mean +/- SD across channels (computed once, before the ROI loop)
    _plot_mean_sd(axes[0], NS_freqs, ns_power_mean, ns_power_std, 'NS', 'blue')
    _plot_mean_sd(axes[0], NS_freqs, avgstim_power_mean, avgstim_power_std, 'AvgStim', 'red')
    _plot_mean_sd(axes[0], BeforeStim_freqs, beforestim_power_mean, beforestim_power_std, 'BeforeStim', 'pink')
    _plot_mean_sd(axes[0], DuringStim_freqs, duringstim_power_mean, duringstim_power_std, 'DuringStim', 'purple')
    _plot_mean_sd(axes[0], AfterStim_freqs, afterstim_power_mean, afterstim_power_std, 'AfterStim', 'orange')

    # ROI channels: mean +/- SD across the ROI's channels
    ns_power_mean_roi = np.mean(avg_ns_power_np_arr_roi, axis=0)
    ns_power_std_roi = np.std(avg_ns_power_np_arr_roi, axis=0)
    avgstim_power_mean_roi = np.mean(avg_avg_power_np_arr_roi, axis=0)
    avgstim_power_std_roi = np.std(avg_avg_power_np_arr_roi, axis=0)
    beforestim_power_mean_roi = np.mean(avg_beforestim_power_np_arr_roi, axis=0)
    beforestim_power_std_roi = np.std(avg_beforestim_power_np_arr_roi, axis=0)
    duringstim_power_mean_roi = np.mean(avg_duringstim_power_np_arr_roi, axis=0)
    duringstim_power_std_roi = np.std(avg_duringstim_power_np_arr_roi, axis=0)
    afterstim_power_mean_roi = np.mean(avg_afterstim_power_np_arr_roi, axis=0)
    afterstim_power_std_roi = np.std(avg_afterstim_power_np_arr_roi, axis=0)

    _plot_mean_sd(axes[1], NS_freqs_roi, ns_power_mean_roi, ns_power_std_roi, 'NS', 'blue')
    _plot_mean_sd(axes[1], BeforeStim_freqs_roi, beforestim_power_mean_roi, beforestim_power_std_roi, 'BeforeStim', 'pink')
    _plot_mean_sd(axes[1], DuringStim_freqs_roi, duringstim_power_mean_roi, duringstim_power_std_roi, 'DuringStim', 'purple')
    _plot_mean_sd(axes[1], AfterStim_freqs_roi, afterstim_power_mean_roi, afterstim_power_std_roi, 'AfterStim', 'orange')
    _plot_mean_sd(axes[1], NS_freqs_roi, avgstim_power_mean_roi, avgstim_power_std_roi, 'AvgStim', 'red')

    # Clean-up Figures
    axes[0].set_title('Remembered Power by stim cond for all channels')
    axes[1].set_title(f'Remembered Power by stim cond for {roi_name} ROI channels')
    axes[0].set_ylabel('dB')
    axes[1].set_ylabel('dB')
    axes[0].set_xlabel('Frequency (Hz)')
    axes[1].set_xlabel('Frequency (Hz)')
    sns.despine(top=True, right=True)

    # Legend with one patch per plotted condition, colors matching the traces
    legend_patches = [
        mpatches.Patch(color='blue', label='No Stim'),
        mpatches.Patch(color='pink', label='Before Stim'),
        mpatches.Patch(color='purple', label='During Stim'),
        mpatches.Patch(color='orange', label='After Stim'),
        mpatches.Patch(color='red', label='Avg Stim'),
    ]
    axes[1].legend(handles=legend_patches, loc='upper right')
    axes[0].legend().remove()  # Remove legend from the first subplot

    # Save
    plt.savefig(os.path.join(savepath, f'Remembered_PowerSpectra-{roi_name}.png'), dpi=1200, bbox_inches='tight')
    plt.show()

NameError: name 'remembered_epoched_data' is not defined

### 3.2 Stim Spectrograms
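Creates and saves spectrograms (4-100 Hz, -0.2 to 0.7 s around image onset, z-log-ratio normalized to a -0.5 to 0 s baseline) for each ROI channel in each stim condition; the first 500 ms after image onset (the analysis window) is shaded.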

In [None]:
def plot_Spectrogram(TFR, chan, stimcond, roi_name=None, spec_path=None):
    '''
    Plot and save the baseline-normalized spectrogram of one channel from an mne TFR object.

    Power is z-log-ratio normalized ('zlogratio') against a -0.5 to 0 s baseline and shown
    for 4-100 Hz over -0.2 to 0.7 s. The 0-0.5 s analysis window is shaded.

    Arguments:
        TFR: mne.TFR object
        chan: channel index to plot (int)
        stimcond: stim condition name, added to the plot title and file name (str)
        roi_name: ROI name for the title and file name (str). Defaults to the notebook
            global ``ROI_name`` (set by the calling loop), as the original version used.
        spec_path: output folder (str). Defaults to the notebook global ``specPath``.

    Returns:
        None
    '''
    # Fall back to the notebook globals this function previously relied on implicitly
    if roi_name is None:
        roi_name = ROI_name
    if spec_path is None:
        spec_path = specPath

    # Fig params
    fig, ax = plt.subplots(1, 1, figsize=(10, 4))
    vmin, vmax = -5.0, 5.0  # Color limits (z units)
    baseline = (-0.5, 0)    # Baseline period (s)

    # Plot spectrogram (4-100 Hz, -200ms-700ms, z-log-ratio against baseline)
    TFR.plot(picks=[chan], baseline=baseline, mode='zlogratio', fmin=4, fmax=100, tmin=-0.2, tmax=0.7,
             vmin=vmin, vmax=vmax, cmap='coolwarm', show=False, axes=ax, verbose=False)

    # Figure aesthetics: mark and shade the 0-0.5 s analysis window
    ch_label = TFR.ch_names[chan]
    ax.axvline(x=0, color='k', lw=1, linestyle='--')
    ax.axvline(x=0.5, color='k', lw=1, linestyle='--')
    ax.axvspan(xmin=0, xmax=0.5, color='grey', alpha=0.25)
    ax.set_yticks([5, 10, 30, 50, 70, 100], ['5', '10', '30', '50', '70', '100'], fontsize='small')
    ax.set_xticks([-0.2, 0, 0.5, 0.7], ['-0.2', '0', '0.5', '0.7'], fontsize='small')
    ax.set_xlabel('Time (s)', fontsize='medium')
    ax.set_ylabel('Frequency (Hz)', fontsize='medium')
    ax.set_title('Remembered ' + roi_name + ' ' + ch_label + ' ' + stimcond, fontsize='large')
    plt.tight_layout()

    # Colorbar aesthetics (raw string: '\i' is not a valid Python escape sequence)
    cb = ax.collections[0].colorbar
    cb.set_ticks(ticks=[-4, -2, 0, 2, 4], labels=['-4', '-2', '0', '2', '4'], fontsize='small')
    cb.set_label(r'$\it{z}$ (log pwr.)', fontsize='small')

    # Save, display, then release the figure so repeated calls do not accumulate figures
    plt.savefig(os.path.join(spec_path, 'Remembered_' + roi_name + '_' + ch_label + stimcond + '_Stim_Spectrogram.png'), dpi=1500, bbox_inches='tight')
    plt.show()
    plt.close(fig)

### Run individual spectrograms

In [None]:
# Create the Spectrogram output folder (no-op if it already exists)
specPath = os.path.join(savepath, 'Spectrograms')
os.makedirs(specPath, exist_ok=True)

# Epochs for each stim-timing condition (copies, so the originals stay untouched)
stimtype = {
    'NoStim': NoStimEpoch.copy(),
    'BeforeStim': BeforeStimEpoch.copy(),
    'DuringStim': DuringStimEpoch.copy(),
    'AfterStim': AfterStimEpoch.copy()
}

# 300 log-spaced frequencies from 4 to 100 Hz. A wider 1-150 Hz range needs more data per
# epoch than is available here: with n_cycles=2 the multitaper window at f Hz is 2/f s long.
specFreqs = np.logspace(np.log10(4), np.log10(100), 300)

for stim, epoch in stimtype.items():
    # Trial-averaged multitaper TFR for this condition
    stimTFR = tfr_multitaper(epoch, freqs=specFreqs, n_cycles=2, time_bandwidth=2, return_itc=False, average=True, n_jobs=-1, verbose=False)

    # Loop through the non-empty ROIs. NOTE: the loop variable must stay named `ROI_name`,
    # because plot_Spectrogram reads that global for its titles and file names.
    for ROI_name, ROI_channels in ROIs.items():
        # Skip empty ROIs; the 'MTL' ROI is skipped as well (reason not recorded here)
        if not ROI_channels or ROI_name == 'MTL':
            continue
        print(f"==== Generating Spectrograms for {stim} in {ROI_name} ROIs")

        # Find indices for the ROI channels that survived preprocessing
        ROI_indices = [stimTFR.ch_names.index(ch) for ch in ROI_channels if ch in stimTFR.ch_names]

        # Generate spectrograms for each ROI channel
        for idx in ROI_indices:
            plot_Spectrogram(stimTFR, idx, stim)
        # Release every figure created for this ROI, not just the current one
        plt.close('all')

## Spectrogram difference plots between stim and no stim

In [None]:
def plot_DiffSpectrogram(diff_BeforeStimTFR, diff_DuringStimTFR, diff_AfterStimTFR, chan_idx, ROI_name):
    '''
    Plot side-by-side difference spectrograms (stim condition minus NoStim) for one channel.

    Arguments:
        diff_BeforeStimTFR: mne TFR object holding BeforeStim minus NoStim
        diff_DuringStimTFR: mne TFR object holding DuringStim minus NoStim
        diff_AfterStimTFR: mne TFR object holding AfterStim minus NoStim
        chan_idx: channel index to plot (int); the three TFRs are assumed to share channel order
        ROI_name: name of the region of interest (str), used in the titles and the file name

    Returns:
        None. Saves '<ROI_name>_<channel>_Diff_Spectrogram_Remembered.png' into the notebook-level
        `specPath` folder and shows the figure.
    '''
    # (difference TFR, condition label) for each panel, left to right
    panels = [
        (diff_BeforeStimTFR, 'BeforeStim'),
        (diff_DuringStimTFR, 'DuringStim'),
        (diff_AfterStimTFR, 'AfterStim'),
    ]
    ch_label = diff_BeforeStimTFR.ch_names[chan_idx]

    # Fig params
    fig, axs = plt.subplots(1, len(panels), figsize=(30, 4))
    vmin, vmax = None, None  # let MNE pick the color limits
    baseline = None  # no baseline correction; `mode` is only applied when a baseline is given

    for ax, (diff_tfr, cond_name) in zip(axs, panels):
        diff_tfr.plot(picks = chan_idx, baseline = baseline, mode = 'zlogratio', fmin = 1, fmax = 100, tmin = -0.2, tmax = 0.7, vmin = vmin, vmax = vmax, cmap = 'coolwarm', show = False, axes = ax, verbose = False)

        ax.set_title(f'Remembered {ROI_name} {diff_tfr.ch_names[chan_idx]} NoStim vs {cond_name}', fontsize='large')
        # Mark the 0-0.5 s window after image onset
        ax.axvline(x=0, color='k', lw=1, linestyle='--')
        ax.axvline(x=0.5, color='k', lw=1, linestyle='--')
        ax.axvspan(xmin=0, xmax=0.5, color='grey', alpha=0.25)
        ax.set_yticks([5, 10, 30, 50, 70, 100], ['5', '10', '30', '50', '70', '100'], fontsize = 'small')
        ax.set_xticks([-0.2, 0, 0.5, 0.7], ['-0.2', '0', '0.5', '0.7'], fontsize = 'small')
        ax.set_xlabel('Time (s)', fontsize='medium')
        ax.set_ylabel('Frequency (Hz)', fontsize='medium')

    plt.tight_layout()

    # Save
    plt.savefig(os.path.join(specPath, f'{ROI_name}_{ch_label}_Diff_Spectrogram_Remembered.png'), dpi=1500, bbox_inches='tight')
    plt.show()

In [None]:
# Create the Spectrogram folder if it doesn't exist (plot_DiffSpectrogram saves into specPath)
specPath = os.path.join(savepath, 'Spectrograms')
os.makedirs(specPath, exist_ok = True)

# Same frequency grid as the individual spectrograms: with n_cycles = 2 a 1 Hz window would need
# 2 s of data per epoch, which these epochs do not have, so start at 4 Hz.
specFreqs = np.logspace(np.log10(4), np.log10(100), 300)

# Shared multitaper settings for every condition
tfr_kwargs = dict(freqs=specFreqs, n_cycles=2, time_bandwidth=2, return_itc=False, average=True, n_jobs=-1, verbose=False)

# Trial-averaged TFR per condition
stimTFR_NoStim = tfr_multitaper(NoStimEpoch, **tfr_kwargs)
stimTFR_BeforeStim = tfr_multitaper(BeforeStimEpoch, **tfr_kwargs)
stimTFR_DuringStim = tfr_multitaper(DuringStimEpoch, **tfr_kwargs)
stimTFR_AfterStim = tfr_multitaper(AfterStimEpoch, **tfr_kwargs)

# Stim condition minus NoStim
diff_BeforeStimTFR = stimTFR_BeforeStim - stimTFR_NoStim
diff_DuringStimTFR = stimTFR_DuringStim - stimTFR_NoStim
diff_AfterStimTFR = stimTFR_AfterStim - stimTFR_NoStim

# Loop through the non-empty ROIs ('MTL' is skipped, as for the individual spectrograms)
for ROI_name, ROI_channels in ROIs.items():
    if not ROI_channels or ROI_name == 'MTL':
        continue
    print(f"==== Generating Difference Spectrograms (Stim - NoStim) in {ROI_name} ROIs")

    # Indices (in the TFR) of the ROI channels that survived preprocessing
    ROI_indices = [stimTFR_NoStim.ch_names.index(ch) for ch in ROI_channels if ch in stimTFR_NoStim.ch_names]

    # Generate difference spectrograms for each channel of THIS ROI
    for idx in ROI_indices:
        plot_DiffSpectrogram(diff_BeforeStimTFR, diff_DuringStimTFR, diff_AfterStimTFR, idx, ROI_name)
        plt.close()

### 3.3 Statistical Analysis of Individual Frequency Change w/ FDR Correction

In [None]:
# Per-epoch power spectra (all channels) for each condition; 'NoStim' is the reference every
# stim condition is contrasted against.
stimtypes = {
    'NoStim': NS_power_allchans,
    'BeforeStim': BeforeStim_power_allchans,
    'DuringStim': DuringStim_power_allchans,
    'AfterStim': AfterStim_power_allchans,
}

stim_comparisons = ['BeforeStim', 'DuringStim', 'AfterStim']

# (label used in print-outs / file names, channel list)
ROI_name_lists_to_run = [
    ['roi_Hipp', ROIs['Hipp']],
    ['roi_BLA', ROIs['BLA']],
    ['roi_EC', ROIs['EC']],
    ['roi_PRC', ROIs['PRC']],
    ['roi_PHG', ROIs['PHG']],
    ['roi_MTL', ROIs['MTL']],
    ['all_ch', list(NoStimEpoch.ch_names)],
]

# Frequencies of the PSD bins, read from the spectrum itself so this cell does not depend on
# PSD_freqs being defined by a later cell, and cannot drift out of sync with the data.
psd_freqs = np.asarray(stimtypes['NoStim'].freqs)

for stimtype in stim_comparisons:
    if not np.allclose(stimtypes[stimtype].freqs, psd_freqs):
        raise ValueError(f'{stimtype} spectrum frequencies do not match the NoStim spectrum')

    for display_ROI_name, channel_name_list in ROI_name_lists_to_run:
        if not channel_name_list:
            print(f"{subject}: Skipping PSD Contrast StimType: {stimtype} vs NoStim for {display_ROI_name} (no channels)")
            continue
        print(f"{subject}: Exporting PSD Contrast StimType: {stimtype} vs NoStim for {display_ROI_name} channels (total channel: {len(channel_name_list)}): {channel_name_list}")

        # Log-transformed (10*log10) PSDs, shaped (epochs, chans, freqs)
        NoStimPSDs = np.log10(stimtypes['NoStim'].copy().pick(channel_name_list).get_data())*10
        StimPSDs = np.log10(stimtypes[stimtype].copy().pick(channel_name_list).get_data())*10

        # One contrast table per frequency bin
        PSDContrasts = [
            power_contrast_full(NoStimPSDs[:, :, i], StimPSDs[:, :, i], freq, channel_name_list)
            for i, freq in enumerate(psd_freqs)
        ]

        # Export
        full_contrast = pd.concat(PSDContrasts)
        full_contrast.to_csv(os.path.join(savepath, f'full_powercontrast_{stimtype}NoStim_{display_ROI_name}.csv'))

### 3.4 Statistical Analysis of Band Power Change w/ Permutation Testing
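Channel-wise t-test comparing NoStim vs. stim-condition band power (paired t-test when both conditions have the same number of epochs, Welch's t-test otherwise), with a label-shuffling permutation test to control for false positives.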

In [None]:
def get_PSD_band(PSDs, freqs, band):
    '''
    Average power over the frequency bins that fall inside a band.

    Arguments:
        PSDs: array shaped (epochs, channels, freqs). Values are used as given; no log transform
            is applied here (callers in this notebook pass 10*log10 power).
        freqs: 1D array-like with the frequency (Hz) of each bin along the last axis of PSDs.
        band: (low, high) band edges in Hz; both edges are inclusive.

    Returns:
        Array shaped (epochs, channels): the mean over the in-band frequency bins.
        Epochs are NOT averaged.

    Raises:
        ValueError: if no frequency bin lies inside the band (the mean would otherwise be NaN).
    '''
    freqs = np.asarray(freqs)
    in_band = (freqs >= band[0]) & (freqs <= band[1])
    if not in_band.any():
        raise ValueError(f'No frequency bins between {band[0]} and {band[1]} Hz')

    # average over the in-band frequencies only (axis 2)
    return np.asarray(PSDs)[:, :, in_band].mean(axis=2)

def _melt_band_power(band_psds, ch_names, type_label):
    '''Turn an (epochs, channels) band-power array into a long DataFrame with columns Chan, Pwr, Type.'''
    pwr_df = pd.DataFrame(band_psds, columns = list(ch_names)).melt()
    pwr_df.columns = ['Chan', 'Pwr']
    pwr_df['Type'] = type_label
    return pwr_df


def _nostim_vs_stim_ttest(nostim_vals, stim_vals):
    '''t-test along axis 0: paired when both conditions have the same number of epochs, Welch's otherwise.'''
    if np.shape(nostim_vals)[0] != np.shape(stim_vals)[0]:
        # equal_var=False -> Welch's T-Test
        return ttest_ind(nostim_vals, stim_vals, nan_policy = 'omit', axis = 0, equal_var = False)
    return ttest_rel(nostim_vals, stim_vals, nan_policy = 'omit', axis = 0)


def power_contrast_band(stim_name, NoStimBandPSDs, StimBandPSDs, ch_names, saveStr = '',
                        n_permutations = 1000, random_state = None, save_dir = None):
    '''
    Channel-wise NoStim vs. stim band-power contrast with a permutation test.

    Arguments:
        stim_name: label of the stim condition (e.g. 'BeforeStim'), used in the long-format table.
        NoStimBandPSDs: (epochs, channels) band power for the NoStim condition.
        StimBandPSDs: (epochs, channels) band power for the stim condition.
        ch_names: channel names, in the column order of the two arrays.
        saveStr: prefix of the exported CSV file names.
        n_permutations: number of label shuffles per channel (default 1000, as before).
        random_state: seed / numpy Generator for the shuffles; None gives non-reproducible results.
        save_dir: folder for the CSV exports; defaults to the notebook-level `savepath`.

    Returns:
        contrastDF with one row per channel and these columns:
            't Stat', 'p'      - NoStim vs. stim t-test (paired if epoch counts match, else Welch's).
            'Chan'             - channel name.
            't Acceptable'     - True when -5 < t < 5 (used to drop outlier channels in plots).
            'Perm Sig'         - True when the empirical t lies outside the 2.5-97.5 percentile
                                 range of the permuted t-statistics.
            'Perm Ts', 'Perm Ps' - the permuted t- and p-values (lists).
        Also writes '<saveStr>PwrDF.csv' and '<saveStr>ContrastDF.csv' into save_dir.
    '''
    NoStimBandPSDs = np.asarray(NoStimBandPSDs)
    StimBandPSDs = np.asarray(StimBandPSDs)
    rng = np.random.default_rng(random_state)
    if save_dir is None:
        save_dir = savepath

    # Long-format table of every epoch's band power, NoStim rows first
    pwrDF = pd.concat([
        _melt_band_power(NoStimBandPSDs, ch_names, 'NoStim'),
        _melt_band_power(StimBandPSDs, ch_names, stim_name),
    ])

    # Empirical t-test for nostim vs stim, over epochs
    tVals, pVals = _nostim_vs_stim_ttest(NoStimBandPSDs, StimBandPSDs)

    # Construct Contrast DF
    contrastDF = pd.DataFrame(tVals, columns = ['t Stat'])
    contrastDF['p'] = pVals
    contrastDF['Chan'] = np.array(ch_names)
    contrastDF['t Acceptable'] = (contrastDF['t Stat'] > -5) & (contrastDF['t Stat'] < 5)

    n_nostim = NoStimBandPSDs.shape[0]
    n_stim = StimBandPSDs.shape[0]
    # True marks a NoStim epoch in the pooled (NoStim then stim) values; shuffling these labels
    # builds the null distribution while keeping both group sizes fixed.
    is_nostim = np.arange(n_nostim + n_stim) < n_nostim

    # Exactly one entry per channel in each list, so they line up with contrastDF's rows
    perm_sig, perm_Ts, perm_Ps = [], [], []
    for ch_i, chan in enumerate(contrastDF['Chan']):
        if n_nostim == 0 or n_stim == 0:
            # One condition has no epochs: nothing to permute, so the channel is not significant
            print(f"Skipping permutation test for channel {chan}: NoStim or {stim_name} has no epochs.")
            perm_sig.append(False)
            perm_Ts.append([np.nan] * n_permutations)
            perm_Ps.append([np.nan] * n_permutations)
            continue

        pooled = np.concatenate([NoStimBandPSDs[:, ch_i], StimBandPSDs[:, ch_i]])
        chan_tVals, chan_pVals = [], []
        for _ in range(n_permutations):
            shuffled = rng.permutation(is_nostim)
            tVal, pVal = _nostim_vs_stim_ttest(pooled[shuffled], pooled[~shuffled])
            chan_tVals.append(tVal)
            chan_pVals.append(pVal)

        if chan_tVals:
            # two-sided p < 0.05 bounds of the permuted t-statistic
            t_UpperBound = np.percentile(chan_tVals, 97.5)
            t_LowerBound = np.percentile(chan_tVals, 2.5)
            trueT = contrastDF['t Stat'].iloc[ch_i]
            perm_sig.append(bool(trueT > t_UpperBound or trueT < t_LowerBound))
        else:
            perm_sig.append(False)
        perm_Ts.append(chan_tVals)
        perm_Ps.append(chan_pVals)

    # Save values from permutation test
    contrastDF['Perm Sig'] = perm_sig
    contrastDF['Perm Ts'] = perm_Ts
    contrastDF['Perm Ps'] = perm_Ps

    # Export DFs to .csv
    pwrDF.to_csv(os.path.join(save_dir, saveStr + 'PwrDF.csv'))
    contrastDF.to_csv(os.path.join(save_dir, saveStr + 'ContrastDF.csv'))

    return contrastDF

### 3.5 Visualize Channel-Wise Band Power Contrasts

#### 3.5.1 Plot all Contrasts (Theta (5-8 Hz), Slow Gamma (30-55 Hz), HFA (70-130 Hz))

In [None]:
### Edit these before running the below cells for comparisons
nostim_PSD = NS_power_allchans # The power data. Make sure its all channels (we will filter below)
stimPSD = {
    'BeforeStim': BeforeStim_power_allchans,
    'DuringStim': DuringStim_power_allchans,
    'AfterStim': AfterStim_power_allchans,
}

# change this manually for either the ROI channels or all channels
roi_choice_name = 'all_chans' #'MTL'
roi_choice = list(NoStimEpoch.ch_names) # List of channels to pick (will use this to filter below)
#roi_choice = ROIs['MTL']

# Frequency (Hz) of each PSD bin, read from the spectrum itself instead of a hard-coded
# np.linspace copy, so band selection always matches the data's frequency axis.
PSD_freqs = np.asarray(nostim_PSD.freqs)

# Structure is:
# first element: name for figures
# second element: short band name, used in the saved CSV file names
# third element: the frequency range in Hz (both edges inclusive)
freq_bands_to_process = [
    ['Theta (5-8 Hz)', 'theta', [5,8]],
    ['Slow Gamma (30-55 Hz)', 'SlowGamma', [30,55]],
    ['HFA (70-130 Hz)', 'HighFreqActivity', [70,130]],
]


In [None]:
# Log-power (10*log10) NoStim PSDs for the chosen channels, shaped (epochs, chans, freqs).
# Identical for every stim condition, so computed once outside the loops.
nostim_log_psd = np.log10(nostim_PSD.copy().pick(roi_choice).get_data())*10

# Color params: symmetric range matching the +/-5 't Acceptable' cut-off
colormap = cm.coolwarm
norm = plt.Normalize(-5, 5)

for stim_name, stim_PSD in stimPSD.items():
    stim_log_psd = np.log10(stim_PSD.copy().pick(roi_choice).get_data())*10

    for band_figure_name, band_name, band_range in freq_bands_to_process:
        print(f'Running {stim_name} {band_name} Contrast Analysis...')
        NoStimBand = get_PSD_band(nostim_log_psd, PSD_freqs, band = band_range)
        StimBand = get_PSD_band(stim_log_psd, PSD_freqs, band = band_range)
        # Include stim condition and channel set in the CSV prefix; a bare band name made every
        # stim condition overwrite the previous one's PwrDF/ContrastDF files.
        bandContrastDF = power_contrast_band(stim_name, NoStimBand, StimBand, roi_choice,
                                             f'{band_name}_{stim_name}_{roi_choice_name}_')

        # Figure params: width scales with the number of plotted channels, with a minimum width
        FigScaling = max(len(roi_choice)/5, 3)
        fig, ax = plt.subplots(1,1, figsize = (FigScaling,3))

        # Remove channels w/ t-Stat greater/less than 5, then renumber rows 0..n-1 to match tick positions
        ContrastDF = bandContrastDF[bandContrastDF['t Acceptable'] == True].reset_index(drop = True)

        # Plot channel-wise contrasts
        sns.scatterplot(x = ContrastDF['Chan'], y = ContrastDF['t Stat'], c = ContrastDF['t Stat'],
                        norm = norm, cmap = colormap, edgecolor='black', linewidth = .25, zorder = 10, ax = ax)
        ax.axhline(0, linestyle = 'dotted', linewidth = 1, color = 'grey') # draw zero line
        ax.set_title(f'{subject} {band_figure_name} Contrast: {stim_name} vs. NoStim ({roi_choice_name})')

        # Figure aesthetics
        ax.set_ylim(-5.5,5.5)
        ax.set_xlabel('Channel')
        ax.set_xticks(range(len(ContrastDF)))
        ax.set_xticklabels(ContrastDF['Chan'], fontsize = 'x-small', rotation = 90)
        ax.set_ylabel('Paired $\\it{t}$-Stat \n (NoStim vs. Stim)')
        sns.despine(top = True, right = True)

        # Highlight permutation-significant channels in bold
        tick_labels = ax.get_xticklabels()
        for i in ContrastDF.index[ContrastDF['Perm Sig'] == True]:
            tick_labels[i].set_fontweight("bold")

        # Save
        plt.savefig(os.path.join(savepath, f'{band_figure_name.replace(" ", "")}Contrast_{stim_name}_{roi_choice_name}.png'), dpi = 1200, bbox_inches = 'tight')
        plt.show()
        plt.close(fig)

# STOP HERE June 2024

## 4. Phase-Amplitude Coupling

In [None]:
#compute Inter-Trial Coherence (ITC) and Event-Related Phase-Amplitude Coupling (ERPAC) for each channel
def plot_ITCxPAC(epoch, chan, stimcond = 'AvgStim'):
    '''
    Plot ITC and event-related PAC for one channel of an epochs object and save the figure.

    Arguments:
        epoch: mne.Epochs object
        chan: channel index to plot (int)
        stimcond: condition label used in the figure title and file name. Defaults to 'AvgStim'
            (the label previously hard-coded here); pass e.g. 'NoStim' when plotting NoStim epochs.

    Returns:
        None. Saves '<channel>_ITC-ERPA_<stimcond>.png' into the notebook-level `pacPath` folder.
    '''

    # Post-stim (0-0.7 s) data from one channel, shaped (epochs, times). Index the channel axis
    # instead of np.squeeze so a single-epoch input keeps its epoch axis.
    sf = epoch.info['sfreq']
    chanLabel = epoch.ch_names[chan]
    chanData = epoch.copy().crop(tmin = 0, tmax = 0.7).pick(chan).get_data()[:, 0, :]
    time = np.arange(chanData.shape[1]) / sf

    # Compute ITC & PAC (tensorpac frequency definitions; presumably (start, stop, width, step) -- see tensorpac docs)
    itc = ITC(chanData, sf = sf, f_pha = (1, 12, 1, .1))
    pac = EventRelatedPac(f_pha=[5, 8], f_amp = (20, 181, 30, .5))
    erpac = pac.filterfit(sf, chanData, method='gc', smooth=50)

    # Figure params
    plt.figure(figsize = (8,6))
    plt.suptitle(chanLabel + '_' + stimcond, fontsize = 'x-large', x = .45)

    # ITC Plot
    plt.subplot(211)
    itc.plot(times = time, cmap = 'magma')
    plt.axvline(x=0, lw = 1, linestyle = '-', color = 'w')
    plt.axvline(x=0.5, lw = 1, linestyle = '--', color = 'w')
    plt.xlabel('')
    plt.xticks([])
    plt.ylabel('Phase Freq. (Hz)')
    plt.yticks([5, 6, 7, 8])
    plt.ylim([5, 8])
    plt.title('Inter-Trial Coherence')

    # PAC Plot
    plt.subplot(212)
    pac.pacplot(erpac.squeeze(), time, pac.yvec, cmap = 'Spectral_r')
    plt.axvline(x=0, lw = 1, linestyle = '-', color = 'black')
    plt.axvline(x=0.5, lw = 1, linestyle = '--', color = 'black')
    plt.xlabel('Time (s)')
    plt.ylabel('Amp Freq. (Hz)')
    plt.yticks([50, 70, 90, 100, 130])
    plt.ylim([50, 130])
    plt.title('Event-Related PAC (Phase Freq: 5-8 Hz)')

    # Save
    plt.savefig(os.path.join(pacPath, chanLabel + '_ITC-ERPA_'+ stimcond +'.png'), dpi = 1500, bbox_inches = 'tight')

In [None]:
# Create PAC folder if it doesn't exist (plot_ITCxPAC saves into pacPath)
pacPath = os.path.join(savepath, 'phase_amp_coupling')
os.makedirs(pacPath, exist_ok = True)

# Which channels to run PAC for (replaces the old commented-out blocks):
#   'single' -> only `channel_to_analyze` (figure is left open so it displays inline)
#   'all'    -> every channel in the selected epochs
#   'roi'    -> every channel listed in the non-empty ROIs of `ROIs` (deduplicated)
PAC_SCOPE = 'single'

stim_epochs = {'NoStim': NoStimEpoch, 'AvgStim': AvgStimEpoch}
stimcond = 'AvgStim'
pac_epochs = stim_epochs[stimcond]

channel_to_analyze = '4Rd2'  # subject-specific channel label; change per subject

if PAC_SCOPE == 'single':
    pac_channels = [channel_to_analyze]
elif PAC_SCOPE == 'all':
    pac_channels = list(pac_epochs.ch_names)
elif PAC_SCOPE == 'roi':
    # dict.fromkeys keeps first-seen order and drops channels shared by several ROIs
    pac_channels = list(dict.fromkeys(ch for ROI_channels in ROIs.values() if ROI_channels for ch in ROI_channels))
else:
    raise ValueError(f"PAC_SCOPE must be 'single', 'all' or 'roi', got {PAC_SCOPE!r}")

for ch in pac_channels:
    if ch not in pac_epochs.ch_names:
        print(f'{ch} not found in {stimcond} epochs, skipping.')
        continue
    chan_idx = pac_epochs.ch_names.index(ch)
    print(f'Computing Inter-Trial Coherence & PAC for {ch}')
    plot_ITCxPAC(pac_epochs, chan_idx)
    if PAC_SCOPE != 'single':
        plt.close()  # batch runs only save figures; free each one

clear_output(wait=True)