# Imports

In [1]:
import mne
import matplotlib.pyplot as plt
import os
import json
import numpy as np
import itertools
import pandas as pd
from collections import defaultdict
from collections import Counter
import re
import seaborn as sns
from statsmodels.formula.api import rlm, mixedlm

from preprocess_nirs import *

from mne_nirs.channels import picks_pair_to_idx, get_long_channels
from mne_nirs.preprocessing import peak_power, scalp_coupling_index_windowed
from mne.preprocessing.nirs import source_detector_distances, scalp_coupling_index
from mne_connectivity import spectral_connectivity_epochs
from mne_connectivity.viz import plot_connectivity_circle
from mne.viz import circular_layout
from mne_nirs.experimental_design import make_first_level_design_matrix
from mne_nirs.statistics import run_glm, statsmodels_to_results
from mne_nirs.visualisation import plot_glm_group_topo

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn.svm import LinearSVC, SVC
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV

# Set the image dpi
dpi = 600

# Number of worker processes: half the logical cores, but never fewer than one.
# os.cpu_count() may return None, and on a single-core machine the plain
# integer division would give 0, which joblib-based n_jobs arguments reject.
n_jobs = max(1, (os.cpu_count() or 2) // 2)

# Pick Analyses to run

In [2]:
# ---- Preprocessing and signal-quality thresholds ----
# Re-run the snirf -> optical density -> haemoglobin pipeline and re-save it
process_data = False
# SCI cut-off for a good channel; also reused as the required fraction of
# good channels / good windows when judging a whole recording
good_threshold = 0.7
# Coefficient-of-variation cut-off (percent) for a good channel
cov_threshold = 15

# ---- Whole-recording SCI / CV summaries ----
get_full_sci = get_full_cv = False

# ---- Sliding-window peak power and SCI tables ----
get_peak_power_sci_df = False
# Minimum peak power for a window to count as good
peak_power_threshold = 0.1

# ---- Per-participant SCI violin plots ----
get_participant_sci_plots = False

# ---- GLM analysis and its plots ----
get_glm_analysis = get_ind_glm_plots = get_group_glm_plots = get_group_contrast_plots = False

# ---- Averaged time-series activity ----
get_avg_timeseries_activity = False

# ---- ERP plots (mean over regions, and per region) ----
get_erp_mean_regions_plots = get_erp_per_region_plots = False

# ---- Topographic plots (per condition, and condition differences) ----
get_topo_condition_plots = get_topo_diff_plots = False

# ---- Connectivity: per participant, averaged, per condition ----
run_ind_connectivity = get_ind_con_plots = False
get_avg_ind_con_plot_avg = False
run_condition_connectivity = False
get_avg_condition_con_plot = False

# ---- Connectivity visualisations: histograms, heatmaps, chord diagrams ----
get_hist_con_plots = get_hist_diff_face_plots = get_hist_diff_emotion_plots = False
get_heatmap_con_dist_plots = get_heatmap_diff_face_plots = get_heatmap_diff_emotion_plots = False
get_chord_con_plots = get_chord_diff_face_plots = get_chord_diff_emotion_plots = False

# ---- Time-series extraction and model runs ----
get_time_series = run_models = False

# Get Participants

In [3]:
# Root folder holding one sub-folder per participant ('P_1', 'P_2', ...)
data_path = r"""C:\Users\super\OneDrive - Ontario Tech University\fNIRS_Emotions\data"""

# Every participant sub-folder directly under data_path
participants = [
    os.path.join(data_path, entry)
    for entry in os.listdir(data_path)
    if os.path.isdir(os.path.join(data_path, entry))
]

# Participants sharing a face presentation order: P_1..P_11 (except P_10) and P_87.
# They are only collected here; they are NOT removed from `participants`.
participants_with_same_order = [
    os.path.join(data_path, f'P_{number}') for number in range(1, 12) if number != 10
]
participants_with_same_order.append(os.path.join(data_path, 'P_87'))

# Excluded participants (list.remove raises ValueError if a folder is absent):
#   P_13 - not recorded, P_50 - ended early, P_54 - used their phone
for excluded in ('P_13', 'P_50', 'P_54'):
    participants.remove(os.path.join(data_path, excluded))

# Every directory (searched recursively) that directly contains a .snirf file;
# a directory is listed once no matter how many .snirf files it holds.
fnirs_folders = [
    root
    for participant in participants
    for root, _subdirs, files in os.walk(participant)
    if any(name.endswith('.snirf') for name in files)
]

# Preprocessing

In [None]:
raws = []
raw_ods = []
raw_haemos = []

if process_data:
    # Load exactly one .snirf recording per folder and attach its description.json
    for folder in fnirs_folders:
        snirf_files = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.snirf')]
        if len(snirf_files) == 0:
            print(f"No .snirf files found in {folder}")
            continue
        if len(snirf_files) > 1:
            raise Exception(f"Multiple .snirf files found in {folder}")
        raw = mne.io.read_raw_snirf(snirf_files[0], optode_frame='mri', preload=True, verbose=False)

        description_files = [f for f in os.listdir(folder) if 'description.json' in f]
        if len(description_files) == 0:
            print(f"No description.json files found in {folder}")
            continue
        if len(description_files) > 1:
            raise Exception(f"Multiple description.json files found in {folder}")
        # Context manager closes the handle instead of leaking it
        with open(os.path.join(folder, description_files[0])) as description_file:
            description = json.load(description_file)

        # Store the (stringified) description on the recording's info
        raw.info['description'] = str(description)
        raws.append(raw)

    # Order recordings chronologically; this order defines the participant numbering
    raws = sorted(raws, key=lambda x: x.info['meas_date'])

    for i, raw in enumerate(raws, 1):
        print(f"Processing participant {i}")
        raw_od, raw_haemo = preprocess_data(raw)
        raw_ods.append(raw_od)
        raw_haemos.append(raw_haemo)

    # Empty the output folders so no stale files from an earlier run remain
    for folder in ['processed_data/raws', 'processed_data/raw_ods', 'processed_data/raw haemos']:
        for f in os.listdir(folder):
            os.remove(os.path.join(folder, f))

    # Save 1-indexed (raw1.fif, raw2.fif, ...) to match the loader below, which
    # reads indices 1..N. Saving from 0 would leave the first file unread and
    # make the loader fail on the last index.
    for i, (raw, raw_od, raw_haemo) in enumerate(zip(raws, raw_ods, raw_haemos), 1):
        raw.save(f'processed_data/raws/raw{i}.fif', overwrite=True)
        raw_od.save(f'processed_data/raw_ods/raw_od{i}.fif', overwrite=True)
        raw_haemo.save(f'processed_data/raw haemos/raw_haemo{i}.fif', overwrite=True)

    # Start from empty lists; everything is reloaded from disk below
    raws = []
    raw_ods = []
    raw_haemos = []

# Number of processed recordings on disk (counted from the haemoglobin folder)
processed_data_count = len([f for f in os.listdir('processed_data/raw haemos') if f.endswith('.fif')])

# Load the processed data; files are numbered 1..processed_data_count
for i in range(1, processed_data_count + 1):
    raw = mne.io.read_raw_fif(f'processed_data/raws/raw{i}.fif', preload=True, verbose=False)
    raws.append(raw)

    raw_od = mne.io.read_raw_fif(f'processed_data/raw_ods/raw_od{i}.fif', preload=True, verbose=False)
    raw_ods.append(raw_od)

    raw_haemo = mne.io.read_raw_fif(f'processed_data/raw haemos/raw_haemo{i}.fif', preload=True, verbose=False)
    raw_haemos.append(raw_haemo)

# Participant Information

### Participant Age/Head Size Plot

In [5]:
# Every description.json found in the participant folders (one per recording)
description_files = [os.path.join(subfolder, f) for subfolder in fnirs_folders for f in os.listdir(subfolder) if 'description.json' in f]

# Parse each description, closing every file handle after reading
descriptions = []
for description_file in description_files:
    with open(description_file) as fh:
        descriptions.append(json.load(fh))

# Age summary statistics (np.std is the population standard deviation)
ages = [float(description['age']) for description in descriptions]
average_age = sum(ages) / len(ages)
min_age = min(ages)
max_age = max(ages)
std_age = np.std(ages)

# 'remarks' holds the head circumference in inches; '' means not recorded -> None
remarks = [float(description['remarks']) if description['remarks'] != '' else None for description in descriptions]
measured_head_sizes = [remark for remark in remarks if remark is not None]
cm_per_inch = 2.54
if measured_head_sizes:
    average_head_circumference = sum(measured_head_sizes) / len(measured_head_sizes)
    std_head_circumference = np.std(measured_head_sizes)
else:
    # No recorded head sizes: avoid ZeroDivisionError and show NaN in the title
    average_head_circumference = std_head_circumference = float('nan')

# Create subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Age histogram
axs[0].hist(ages, bins=40, edgecolor='black')
axs[0].set_xlabel('Age')
axs[0].set_ylabel('Number of participants')
axs[0].set_title(f'Age distribution of participants\n{len(ages)} participants, Mean: {round(average_age, 2)}, Std: {round(std_age, 2)}')

# Head circumference histogram (inches converted to cm)
axs[1].hist([size * cm_per_inch for size in measured_head_sizes], bins=40, edgecolor='black')
axs[1].set_xlabel('Head circumference (cm)')
axs[1].set_ylabel('Number of participants')
axs[1].set_title(f'Head circumference distribution of participants\nMean: {round(average_head_circumference * cm_per_inch, 2)} cm, Std: {round(std_head_circumference * cm_per_inch, 2)} cm')

# Adjust layout
plt.tight_layout()
plt.savefig('plots/participants/participant_info.png', dpi=dpi)
plt.close()

### Measurement Dates

In [6]:
# Chronological view of when each processed recording was made
measurement_dates = pd.to_datetime([recording.info['meas_date'] for recording in raw_haemos])
participant_numbers = range(1, len(measurement_dates) + 1)

plt.figure(figsize=(15, 10))
plt.plot(measurement_dates, participant_numbers, 'o-')
plt.xlabel('Measurement date')
plt.ylabel('Participant number')
plt.title(f'Measurement dates of participants, N = {len(measurement_dates)}')
plt.grid()
plt.tight_layout()
# Low-resolution export: a quarter of the notebook-wide dpi
plt.savefig('plots/participants/measurement_dates.png', dpi=dpi / 4)
plt.close()

### Short Distance Channels Check

In [7]:
# (short, long) channel counts per recording; 0.01 is the 1 cm split point
distance_counts = []
for recording in raws:
    pair_distances = np.array(source_detector_distances(recording.info))
    distance_counts.append((sum(pair_distances < 0.01), sum((pair_distances >= 0.01))))

# How many recordings share each (short, long) layout
unique_distance_counts = Counter(distance_counts)

for (short_count, long_count), participant_count in unique_distance_counts.items():
    print(f"{participant_count} participants with {short_count} short channels (< 1 cm) and {long_count} long channels (>= 1 cm)")

48 participants with 4 short channels (< 1 cm) and 206 long channels (>= 1 cm)
39 participants with 16 short channels (< 1 cm) and 206 long channels (>= 1 cm)


# Mapping brain regions

In [8]:
# Build the brain-region -> channel mapping used by every region-coloured plot.
# The statement order matters: region lists are filled in dict order, and
# group_boundaries records the cumulative end index of each region in the
# concatenated channel list (the last region, Right Occipital, gets no entry).

# get the channel names for hbo (taken from the first processed recording)
ch_names_hbo = [ch_name for ch_name in raw_haemos[0].ch_names if 'hbo' in ch_name]

# Region -> list of 'S<src>_D<det> hbo' channel names; insertion order is the plotting order
ch_mapping_hbo = {
    "Left Frontal": [],
    "Right Frontal": [],
    "Left Prefrontal": [],
    "Right Prefrontal": [],
    "Left Parietal": [],
    "Right Parietal": [],
    "Left Occipital": [],
    "Right Occipital": [],
}

# Cumulative region end indices, starting at 0
group_boundaries = [0]

ch_mapping_hbo["Left Frontal"].append('S1_D1 hbo')

ch_mapping_hbo["Left Frontal"].append('S1_D2 hbo')

ch_mapping_hbo["Left Frontal"].append('S1_D17 hbo')

# find the channels that have 'S2_', 'S3_', 'S4_', 'S5_' in them.
# Substring tests cannot confuse e.g. 'S2_' with 'S12_'/'S22_' because 'S' must
# directly precede the digits (assumes names of the form 'S<src>_D<det> hbo').
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S2_' in ch_name or 'S3_' in ch_name or 'S4_' in ch_name or 'S5_' in ch_name]:
    ch_mapping_hbo["Left Frontal"].append(ch_name)

ch_mapping_hbo["Left Frontal"].append('S6_D2 hbo')

ch_mapping_hbo["Left Frontal"].append('S6_D3 hbo')

ch_mapping_hbo["Left Frontal"].append('S6_D18 hbo')

# End of Left Frontal
group_boundaries.append(len(ch_mapping_hbo["Left Frontal"]))

# find the channels that have 'S9_', 'S10_', 'S11_', 'S12_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S9_' in ch_name or 'S10_' in ch_name or 'S11_' in ch_name or 'S12_' in ch_name]:
    ch_mapping_hbo["Right Frontal"].append(ch_name)

# End of Right Frontal
group_boundaries.append(len(ch_mapping_hbo["Right Frontal"]) + group_boundaries[-1])

ch_mapping_hbo["Left Prefrontal"].append('S6_D31 hbo')

# find the channels that have 'S7_', 'S8_', 'S25_', 'S26_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S7_' in ch_name or 'S8_' in ch_name or 'S25_' in ch_name or 'S26_' in ch_name]:
    ch_mapping_hbo["Left Prefrontal"].append(ch_name)

# End of Left Prefrontal
group_boundaries.append(len(ch_mapping_hbo["Left Prefrontal"]) + group_boundaries[-1])

# find the channels that have 'S13_', 'S14_', 'S15_', 'S16_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S13_' in ch_name or 'S14_' in ch_name or 'S15_' in ch_name or 'S16_' in ch_name]:
    ch_mapping_hbo["Right Prefrontal"].append(ch_name)

# End of Right Prefrontal
group_boundaries.append(len(ch_mapping_hbo["Right Prefrontal"]) + group_boundaries[-1])

# find the channels that have 'S27_', 'S28_', 'S29_', 'S30_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S27_' in ch_name or 'S28_' in ch_name or 'S29_' in ch_name or 'S30_' in ch_name]:
    ch_mapping_hbo["Left Parietal"].append(ch_name)

# End of Left Parietal
group_boundaries.append(len(ch_mapping_hbo["Left Parietal"]) + group_boundaries[-1])

# These four Right Occipital channels are added here, before Right Parietal is
# filled; this is harmless because Right Occipital has no boundary entry.
ch_mapping_hbo["Right Occipital"].append('S21_D13 hbo')

ch_mapping_hbo["Right Occipital"].append('S21_D16 hbo')

ch_mapping_hbo["Right Occipital"].append('S23_D15 hbo')

ch_mapping_hbo["Right Occipital"].append('S23_D16 hbo')

# find the channels that have 'S17_', 'S18_', 'S19_', 'S20_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S17_' in ch_name or 'S18_' in ch_name or 'S19_' in ch_name or 'S20_' in ch_name]:
    ch_mapping_hbo["Right Parietal"].append(ch_name)

# End of Right Parietal
group_boundaries.append(len(ch_mapping_hbo["Right Parietal"]) + group_boundaries[-1])

# find the channels that have 'S32_', 'S31_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S32_' in ch_name or 'S31_' in ch_name]:
    ch_mapping_hbo["Left Occipital"].append(ch_name)

# End of Left Occipital
group_boundaries.append(len(ch_mapping_hbo["Left Occipital"]) + group_boundaries[-1])

# find the channels that have 'S22_', 'S24_' in them
for ch_name in [ch_name for ch_name in ch_names_hbo if 'S22_' in ch_name or 'S24_' in ch_name]:
    ch_mapping_hbo["Right Occipital"].append(ch_name)

ch_mapping_hbo["Right Occipital"].append('S21_D28 hbo')

ch_mapping_hbo["Right Occipital"].append('S23_D30 hbo')

# Same regions with the hbr counterparts of each channel
ch_mapping_hbr = {region: [channel.replace('hbo', 'hbr') for channel in ch_mapping_hbo[region]] for region in ch_mapping_hbo}

# Same regions with hbo channels followed by hbr channels
ch_mapping_all = {region: ch_mapping_hbo[region] + ch_mapping_hbr[region] for region in ch_mapping_hbo}

# concatenate the values of the dictionary into a list (region order)
all_channels_hbo = [channel for region in ch_mapping_hbo.values() for channel in region]

# duplicate all_channels but replace 'hbo' with 'hbr'
all_channels_hbr = [channel.replace('hbo', 'hbr') for channel in all_channels_hbo]

# concatenate all_channels_hbo and all_channels_hbr
all_channels = all_channels_hbo + all_channels_hbr

# Region -> bare 'S<src>_D<det>' names: [:-4] strips the trailing ' hbo'.
# These match the column names of the windowed SCI / peak-power tables.
ch_mapping_names = {region: [channel[:-4] for channel in ch_mapping_hbo[region]] for region in ch_mapping_hbo}

# Signal Quality

### Scalp Coupling Index (SCI)

In [9]:
if get_full_sci:
    # Scalp coupling index of every channel of every recording, computed once
    # (the original computed it twice per recording) and reused for both counts
    sci_per_recording = [scalp_coupling_index(recording, verbose=False) for recording in raw_ods]
    good_channels = [int(np.sum(sci >= good_threshold)) for sci in sci_per_recording]
    bad_channels = [int(np.sum(sci < good_threshold)) for sci in sci_per_recording]
    # A recording is good when at least good_threshold of its channels are good
    good_recordings = sum(good >= good_threshold * (good + bad) for good, bad in zip(good_channels, bad_channels))

    # Reference line at good_threshold of the channel count of the last recording,
    # taken explicitly from raw_ods instead of a leaked loop variable
    # (recordings can differ in their number of short channels)
    reference_nchan = raw_ods[-1].info['nchan'] if raw_ods else 0

    # Plot the good vs bad channels for each recording in a stacked bar chart
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(len(good_channels)), good_channels, label='Good Channels')
    ax.bar(range(len(bad_channels)), bad_channels, bottom=good_channels, label='Bad Channels')
    ax.set_xlabel('Recording')
    ax.set_ylabel('Number of Channels')
    ax.axhline(reference_nchan * good_threshold, color='green', linestyle='--')
    title = f'Good vs Bad Channels (T = {good_threshold})\nGood Recordings: {good_recordings}, N = {len(raw_ods)}, Retention Rate: {good_recordings / len(raw_ods) * 100:.2f}%'
    ax.set_title(title)
    ax.legend()
    plt.savefig('plots/signal quality/Signal Quality (SCI).png', dpi=dpi)
    plt.close()

### Coefficient of Variation (CV)

In [10]:
if get_full_cv:
    # Coefficient of variation (percent, 100 * std / mean) of every long channel
    # in each recording, computed once and reused for the good/bad counts
    cv_per_recording = []
    for recording in raws:
        long_data = get_long_channels(recording).get_data()
        cv_per_recording.append(100 * np.std(long_data, axis=1) / np.mean(long_data, axis=1))

    good_channels = [int(np.sum(cv < cov_threshold)) for cv in cv_per_recording]
    bad_channels = [int(np.sum(cv >= cov_threshold)) for cv in cv_per_recording]
    # A recording is good when at least good_threshold of its long channels are good
    good_recordings = sum(good >= good_threshold * (good + bad) for good, bad in zip(good_channels, bad_channels))

    # Only long channels are counted in the bars, so the reference line must use
    # the long-channel count; info['nchan'] also includes short channels
    n_long_channels = len(cv_per_recording[-1]) if cv_per_recording else 0

    # Plot the good vs bad channels for each recording in a stacked bar chart
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(len(good_channels)), good_channels, label='Good Channels')
    ax.bar(range(len(bad_channels)), bad_channels, bottom=good_channels, label='Bad Channels')
    ax.set_xlabel('Recording')
    ax.set_ylabel('Number of Channels')
    ax.axhline(n_long_channels * good_threshold, color='green', linestyle='--')
    title = f'Good vs Bad Channels (T = {cov_threshold})\nGood Recordings: {good_recordings}, N = {len(raws)}, Retention Rate: {good_recordings / len(raws) * 100:.2f}%'
    ax.set_title(title)
    ax.legend()
    plt.savefig('plots/signal quality/Signal Quality (CV).png', dpi=dpi)
    plt.close()

### Get Peak Power/SCI Sliding Window CSV

In [11]:
if get_peak_power_sci_df:
    def _windowed_scores_to_frame(scores, times, ch_names, participant):
        """Tidy one recording's windowed quality scores into a DataFrame.

        Rows are time windows. Columns are Participant, Start_Time, End_Time,
        Window, then one column per channel: columns containing '850' are
        dropped and a trailing ' 760' suffix is stripped from the rest.
        """
        frame = pd.DataFrame(scores.T, columns=list(ch_names))
        # Final leading order: Start_Time, End_Time, Window, <channels>
        frame.insert(0, 'Window', range(1, len(frame) + 1))
        frame.insert(0, 'End_Time', [t[1] for t in times])
        frame.insert(0, 'Start_Time', [t[0] for t in times])
        frame = frame.loc[:, ~frame.columns.str.contains('850')]
        frame.columns = [col[:-4] if col.endswith(' 760') else col for col in frame.columns]
        frame.insert(0, 'Participant', participant)
        return frame

    pp_frames = []
    sci_frames = []
    for i, raw_od in enumerate(raw_ods, 1):
        _, scores_pp, times_pp = peak_power(raw_od, time_window=5, threshold=peak_power_threshold, verbose=False)
        _, scores_sci, times_sci = scalp_coupling_index_windowed(raw_od, time_window=5, threshold=good_threshold, verbose=False)
        pp_frames.append(_windowed_scores_to_frame(scores_pp, times_pp, raw_od.ch_names, i))
        sci_frames.append(_windowed_scores_to_frame(scores_sci, times_sci, raw_od.ch_names, i))
        print(f"Processed participant {i}")

    # Concatenate once at the end; calling pd.concat inside the loop re-copies
    # the accumulated table on every iteration (quadratic)
    peak_power_df = pd.concat(pp_frames, ignore_index=True)
    sci_df = pd.concat(sci_frames, ignore_index=True)

    # Save the DataFrames to CSV files
    peak_power_df.to_csv('processed_data/windows/peak_power.csv', index=False)
    sci_df.to_csv('processed_data/windows/sci.csv', index=False)

# Load the DataFrames
peak_power_df = pd.read_csv('processed_data/windows/peak_power.csv')
sci_df = pd.read_csv('processed_data/windows/sci.csv')

### Peak Power/SCI Sliding Window

In [12]:
# Per participant and channel: fraction of windows passing each quality check.
# Channel columns begin at index 4 (after Participant, Start_Time, End_Time, Window).
percentage_good_windows_peak_power_df = (
    peak_power_df
    .groupby("Participant")[peak_power_df.columns[4:]]
    .apply(lambda windows: (windows > peak_power_threshold).sum() / len(windows))
)
percentage_good_windows_sci_df = (
    sci_df
    .groupby("Participant")[sci_df.columns[4:]]
    .apply(lambda windows: (windows >= good_threshold).sum() / len(windows))
)

# Per participant: share of channels whose good-window fraction exceeds good_threshold,
# inserted as the first column of each table
pp_good_channel_share = (percentage_good_windows_peak_power_df > good_threshold).sum(axis=1) / len(percentage_good_windows_peak_power_df.columns)
percentage_good_windows_peak_power_df.insert(0, f'Good Recordings (peak_power > {peak_power_threshold} for > {good_threshold * 100}% of channels)', pp_good_channel_share)

good_recordings = (percentage_good_windows_sci_df > good_threshold).sum(axis=1) / len(percentage_good_windows_sci_df.columns)
percentage_good_windows_sci_df.insert(0, f'Good Recordings (SCI > {good_threshold} for > {good_threshold * 100}% of channels)', good_recordings)

# Side-by-side summary: column 0 = peak-power share, column 1 = SCI share
percentage_good_windows_df = pd.merge(
    percentage_good_windows_peak_power_df.iloc[:, :1],
    percentage_good_windows_sci_df.iloc[:, :1],
    on='Participant',
)
pp_share = percentage_good_windows_df.iloc[:, 0]
sci_share = percentage_good_windows_df.iloc[:, 1]

# A recording is good only when both shares exceed good_threshold
percentage_good_windows_df['Good Recording'] = (pp_share > good_threshold) & (sci_share > good_threshold)

n_good = percentage_good_windows_df["Good Recording"].sum()
n_total = len(percentage_good_windows_df)

# Grouped bars: peak power and SCI share next to each other for every participant
bar_width = 0.45
x = np.arange(n_total)
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(x, pp_share, bar_width, label='Peak Power')
ax.bar(x + bar_width, sci_share, bar_width, label='SCI')
ax.axhline(good_threshold, color='green', linestyle='--')
ax.set_xlabel('Participant')
ax.set_ylabel('Percentage of Good Windows')
title = f'Percentage of Good Windows: peak_power > {peak_power_threshold}, SCI > {good_threshold}\nGood Recordings: {n_good}, N = {n_total}, Retention Rate: {n_good / n_total * 100:.2f}%'
ax.set_title(title)
ax.set_xticks(x + bar_width / 2)
ax.set_xticklabels(percentage_good_windows_df.index)
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig(f'plots/signal quality/Percentage of Good Windows.png', dpi=dpi / 2)
plt.close()

### Peak Power/SCI Over Time

In [13]:
# For every channel, fit a straight line (value vs. window index) to each
# participant's windowed peak power and SCI, then average the slopes over
# participants. Channel columns start at index 4 of the windowed tables,
# after Participant, Start_Time, End_Time and Window. (The per-participant
# summary tables have only one leading column, so slicing those at [4:]
# would silently skip three channels.)
pp_groups = {participant: group for participant, group in peak_power_df.groupby('Participant')}
sci_groups = {participant: group for participant, group in sci_df.groupby('Participant')}

channel_slopes = []  # (pp_channel, sci_channel, mean pp slope, mean sci slope)
for pp_channel, sci_channel in zip(peak_power_df.columns[4:], sci_df.columns[4:]):
    pp_slope = []
    sci_slope = []
    for participant in peak_power_df['Participant'].unique():
        pp_array = pp_groups[participant][pp_channel]
        sci_array = sci_groups[participant][sci_channel]
        # degree-1 least-squares fit; element 0 of the coefficients is the slope
        pp_slope.append(np.polyfit(np.arange(len(pp_array)), pp_array, 1)[0])
        sci_slope.append(np.polyfit(np.arange(len(sci_array)), sci_array, 1)[0])
    channel_slopes.append((pp_channel, sci_channel, np.mean(pp_slope), np.mean(sci_slope)))

# Sort once, by peak-power slope, so each SCI bar stays beside the peak-power
# bar of the same channel (sorting the two lists separately mis-pairs them)
channel_slopes.sort(key=lambda item: item[2])
pp_channel_slopes = [(pp_channel, pp_mean) for pp_channel, _, pp_mean, _ in channel_slopes]
sci_channel_slopes = [(sci_channel, sci_mean) for _, sci_channel, _, sci_mean in channel_slopes]

# plot the slopes
fig, ax = plt.subplots(figsize=(20, 6))
bar_width = 0.45
x = np.arange(len(pp_channel_slopes))
ax.bar(x, [slope for channel, slope in pp_channel_slopes], bar_width, label='Peak Power')
ax.bar(x + bar_width, [slope for channel, slope in sci_channel_slopes], bar_width, label='SCI')
ax.set_xlabel('Channel')
ax.set_ylabel('Slope')
ax.set_title('Slope of Peak Power and SCI over Time')
ax.set_xticks(x + bar_width / 2)
ax.set_xticklabels([channel for channel, slope in pp_channel_slopes])
ax.legend()
plt.xticks(rotation=90)
plt.tight_layout()
plt.savefig('plots/signal quality/Slope of Peak Power and SCI over Time.png', dpi=dpi)
plt.close()

### Head Size vs. SCI Windows

In [14]:
cap_size = 58
percentage_good_windows_sci_df_with_head_size = percentage_good_windows_sci_df.copy()

# Head size becomes the first column, initialised to cap_size for everyone;
# participants without a recorded size are dropped further down
if 'Head Size (cm)' not in percentage_good_windows_sci_df_with_head_size.columns:
    percentage_good_windows_sci_df_with_head_size.insert(0, 'Head Size (cm)', cap_size)

no_head_size = []

for i, raw_haemo in enumerate(raw_haemos, 1):
    head_size = get_info(raw_haemo)['remarks']
    if not head_size:
        # No recorded head size: remember the participant so the row is dropped
        no_head_size.append(i)
        continue
    # remarks hold the head circumference in inches; store it in centimetres
    head_size = float(head_size) * 2.54
    percentage_good_windows_sci_df_with_head_size.loc[i, 'Head Size (cm)'] = head_size

# remove the participants with no head size
percentage_good_windows_sci_df_with_head_size = percentage_good_windows_sci_df_with_head_size.drop(no_head_size)

# Column layout: 0 = head size, 1 = good-recording share, 2.. = channels
correlations = []
for channel in percentage_good_windows_sci_df_with_head_size.columns[2:]:
    channel_correlation = percentage_good_windows_sci_df_with_head_size['Head Size (cm)'].corr(percentage_good_windows_sci_df_with_head_size[channel])
    correlations.append((channel, channel_correlation))

# Correlation between head size and the good-recording share (column 1)
correlation = percentage_good_windows_sci_df_with_head_size['Head Size (cm)'].corr(percentage_good_windows_sci_df_with_head_size.iloc[:, 1])

# plot the per-channel correlations, sorted ascending
fig, ax = plt.subplots(figsize=(20, 6))
correlations.sort(key=lambda item: item[1])
bar_width = 0.45
x = np.arange(len(correlations))
ax.bar(x, [corr for _, corr in correlations], bar_width)
ax.set_xlabel('Channel')
ax.set_ylabel('Correlation')
ax.set_title('Correlation between Head Size and SCI, Correlation with Good Recordings: ' + str(correlation) + ', N = ' + str(len(percentage_good_windows_sci_df_with_head_size)))
ax.set_xticks(x)
ax.set_xticklabels([name for name, _ in correlations])
plt.xticks(rotation=90)
plt.tight_layout()
plt.savefig('plots/signal quality/Correlation between Head Size and SCI.png', dpi=dpi / 3)
plt.close()


### Average SCI per Channel across Participants

In [15]:
# Participants whose recordings passed both windowed quality checks
good_participants = percentage_good_windows_df[percentage_good_windows_df['Good Recording']].index

# Per-participant mean SCI of every channel (window metadata columns dropped)
avg_sci_df = sci_df.groupby('Participant').mean().drop(columns=['Window', 'Start_Time', 'End_Time'])

# split into good and bad participants
avg_sci_df_good = avg_sci_df.loc[good_participants]
avg_sci_df_bad = avg_sci_df.drop(index=good_participants)

avg_sci_dfs = [avg_sci_df, avg_sci_df_good, avg_sci_df_bad]
df_names = ['All Participants', 'Good Participants', 'Bad Participants']
color_list = ['red', 'blue', 'green', 'purple', 'orange', 'brown', 'pink', 'gray']

# Channel -> region and region -> colour lookups. Colours are looked up per
# column because the violins follow the DataFrame column order, which is not
# the region order of ch_mapping_names.
channel_region = {channel: region for region, channels in ch_mapping_names.items() for channel in channels}
region_color = dict(zip(ch_mapping_names, color_list))
unmapped_color = 'lightgray'  # channels that belong to no region

for df, df_name in zip(avg_sci_dfs, df_names):
    # Nothing to draw for an empty group (e.g. when every participant is good)
    if df.empty:
        print(f"Skipping '{df_name}': no participants")
        continue

    # violin plot of the average sci for each channel
    fig, ax = plt.subplots(figsize=(35, 6))
    parts = ax.violinplot(df, showmeans=False, widths=1, showextrema=False)

    # colour every violin by the region of its own column
    column_regions = [channel_region.get(column) for column in df.columns]
    for pc, region in zip(parts['bodies'], column_regions):
        pc.set_facecolor(region_color.get(region, unmapped_color))
        pc.set_edgecolor('black')
        pc.set_alpha(1)

    # legend: one entry per region present, in ch_mapping_names order
    present_regions = [region for region in ch_mapping_names if region in column_regions]
    handles = [plt.Rectangle((0, 0), 1, 1, color=region_color.get(region, unmapped_color)) for region in present_regions]
    ax.legend(handles, present_regions, loc='lower left')

    # white markers at the mean sci of each channel
    positions = np.arange(1, len(df.columns) + 1)
    ax.scatter(positions, df.mean(), color='white', zorder=3)

    ax.set_xlabel('Channel')
    ax.set_ylabel('Average SCI')
    ax.set_ylim(0, 1)
    ax.axhline(good_threshold, color='green', linestyle='--')
    ax.set_title(f'Average SCI per Channel: ({df_name}), N = {len(df)}')
    ax.set_xticks(positions)
    ax.set_xticklabels(df.columns)
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(f'plots/signal quality/Average SCI (Windowed) per Channel/{df_name}.png', dpi=dpi / 3)
    plt.close()

### SCI of Windows per Channel for each Participant

In [16]:
if get_participant_sci_plots:
    # Channel -> region / region -> colour lookups, so each violin is coloured by
    # the region of its own column (columns are not in region order)
    channel_region = {channel: region for region, channels in ch_mapping_names.items() for channel in channels}
    region_color = dict(zip(ch_mapping_names, color_list))
    unmapped_color = 'lightgray'  # channels that belong to no region

    for i in sci_df['Participant'].unique():
        df = sci_df[sci_df['Participant'] == i]
        # channel columns follow Participant, Start_Time, End_Time, Window
        channel_columns = df.columns[4:]
        positions = np.arange(1, len(channel_columns) + 1)

        # good recording status from the windowed-quality summary (indexed by participant)
        good_recording = bool(percentage_good_windows_df.loc[i, 'Good Recording'])

        # violin plot of the sci windows of each channel
        fig, ax = plt.subplots(figsize=(35, 6))
        parts = ax.violinplot(df[channel_columns], showmeans=False, widths=1, showextrema=False)

        # colour every violin by the region of its own column
        column_regions = [channel_region.get(column) for column in channel_columns]
        for pc, region in zip(parts['bodies'], column_regions):
            pc.set_facecolor(region_color.get(region, unmapped_color))
            pc.set_edgecolor('black')
            pc.set_alpha(1)

        # legend: one entry per region present, in ch_mapping_names order
        present_regions = [region for region in ch_mapping_names if region in column_regions]
        handles = [plt.Rectangle((0, 0), 1, 1, color=region_color.get(region, unmapped_color)) for region in present_regions]
        ax.legend(handles, present_regions, loc='lower left')

        # white markers at the mean sci of each channel
        ax.scatter(positions, df[channel_columns].mean(), color='white', zorder=3)

        ax.set_xlabel('Channel')
        ax.set_ylabel('SCI')
        ax.set_ylim(0, 1)
        ax.axhline(good_threshold, color='green', linestyle='--')
        ax.set_title(f'SCI per Channel: Participant {i}, Windows = {len(df)}, Good Recording = {good_recording}')
        ax.set_xticks(positions)
        ax.set_xticklabels(channel_columns)
        plt.xticks(rotation=90)
        plt.tight_layout()
        plt.savefig(f'plots/signal quality/Average SCI (Windowed) per Channel/individual/Participant {i}.png', dpi=dpi / 3)
        plt.close()

# Get Good Recordings

In [17]:
# Keep a copy of every recording whose row in percentage_good_windows_df is
# flagged as a good recording. Rows are matched by position; recordings with
# no corresponding row are skipped.
raw_haemo_good_recordings = []
good_recording_flags = percentage_good_windows_df['Good Recording']
for i, raw_haemo in enumerate(raw_haemos, 1):
    if i <= len(percentage_good_windows_df) and good_recording_flags.iloc[i - 1]:
        raw_haemo_good_recordings.append(raw_haemo.copy())

# GLM analysis

In [None]:
channel_types = ['hbo', 'hbr', 'hbt']

modes = ['face_type', 'emotion']
conditions_list = {
    'face_type': ['Real', 'Virt', 'Base'],
    'emotion': ['Joy', 'Fear', 'Anger', 'Disgust', 'Sadness', 'Neutral', 'Surprise']
}


def _source_detector_pair(channel):
    """Return ``[source, detector]`` parsed from a name like ``'S3_D7 hbo'``.

    Only the first two integers in the name are used.
    """
    numbers = re.findall(r'\d+', channel)
    return [int(numbers[0]), int(numbers[1])]


def _concat_or_empty(frames):
    """Concatenate ``frames`` once, or return an empty DataFrame if there are none."""
    return pd.concat(frames) if frames else pd.DataFrame()


if get_glm_analysis:
    # ch_mapping_all is the same for every recording, so parse each region's
    # source/detector pairs once instead of once per participant and mode.
    region_pairs = {
        region: [_source_detector_pair(channel) for channel in channels]
        for region, channels in ch_mapping_all.items()
    }

    for mode in modes:
        # Collect per-participant frames and concatenate once per mode; calling
        # pd.concat inside the loop re-copies all earlier rows every iteration.
        cha_frames, roi_frames, con_frames = [], [], []
        for i, raw_haemo in enumerate(raw_haemo_good_recordings, 1):
            raw_haemo_annots = raw_haemo.copy()
            pick_channels(raw_haemo_annots, channel_types)
            relabel_annotations(raw_haemo_annots, mode=mode)

            # Create a design matrix
            design_matrix = make_first_level_design_matrix(
                raw_haemo_annots,
                drift_model="cosine",
                # The cutoff period (1/high_pass) should be set as the longest period
                # between two trials of the same condition multiplied by 2
                high_pass=0.03125,
                hrf_model="spm",
                stim_dur=16.0,
            )

            # Run GLM
            glm_est = run_glm(raw_haemo_annots, design_matrix)
            cha = glm_est.to_dataframe()

            # Map each region's (source, detector) pairs to channel indices of this recording
            groups = {
                region: picks_pair_to_idx(raw_haemo_annots, pairs, on_missing='ignore')
                for region, pairs in region_pairs.items()
            }

            # Compute region of interest results from channel data
            roi = glm_est.to_dataframe_region_of_interest(
                groups, design_matrix.columns, demographic_info=True
            )

            # Define contrasts: one unit row vector per design-matrix column
            contrast_matrix = np.eye(design_matrix.shape[1])
            basic_conts = {
                column: contrast_matrix[j] for j, column in enumerate(design_matrix.columns)
            }
            # np.unique sorts, so each pair is named "<earlier> - <later>" alphabetically
            unique_annots = np.unique(raw_haemo_annots.annotations.description).tolist()
            pairs = list(itertools.combinations(unique_annots, 2))

            # Compute defined contrast pairs
            contrasts = []
            for pair in pairs:
                con = glm_est.compute_contrast(basic_conts[pair[0]] - basic_conts[pair[1]]).to_dataframe()
                con["contrast_pair"] = f"{pair[0]} - {pair[1]}"
                contrasts.append(con)

            # Tag rows with the participant ID and convert to uM for nicer plotting
            roi["Participant"] = cha["Participant"] = i
            cha["theta"] = cha["theta"] * 1.0e6
            roi["theta"] = roi["theta"] * 1.0e6
            for con in contrasts:
                con["Participant"] = i
                con["effect"] = con["effect"] * 1.0e6

            cha_frames.append(cha)
            roi_frames.append(roi)
            con_frames.extend(contrasts)

        cha_df = _concat_or_empty(cha_frames)
        roi_df = _concat_or_empty(roi_frames)
        con_df = _concat_or_empty(con_frames)

        cha_df.to_csv('processed_data/glm/cha/cha_df_' + mode + '.csv', index=False)
        roi_df.to_csv('processed_data/glm/roi/roi_df_' + mode + '.csv', index=False)
        con_df.to_csv('processed_data/glm/cons/con_df_' + mode + '.csv', index=False)

# load the dataframes
glm = {
    mode: {
        'cha': pd.read_csv(f'processed_data/glm/cha/cha_df_{mode}.csv'),
        'roi': pd.read_csv(f'processed_data/glm/roi/roi_df_{mode}.csv'),
        'con': pd.read_csv(f'processed_data/glm/cons/con_df_{mode}.csv')
    }
    for mode in modes
}

### Individual GLM Results

In [23]:
if get_ind_glm_plots:
    for mode in modes:
        grp_results = glm[mode]['roi'].query(f"Condition in {conditions_list[mode]}")
        grp_results = grp_results.query("Chroma in ['hbo']")

        # Shared y-limits so the per-participant plots are directly comparable
        theta_min = grp_results['theta'].min()
        theta_max = grp_results['theta'].max()

        # Start from an empty output folder: create it if it does not exist yet and
        # delete stale plots. Only regular files are removed (os.remove would fail
        # on a sub-folder).
        out_dir = 'plots/glm/individual_' + mode
        os.makedirs(out_dir, exist_ok=True)
        for f in os.listdir(out_dir):
            file_path = os.path.join(out_dir, f)
            if os.path.isfile(file_path):
                os.remove(file_path)

        for i in grp_results['Participant'].unique():
            # make a scatter plot of the GLM results
            fig, ax = plt.subplots(figsize=(6, 6))
            sns.swarmplot(data=grp_results.query(f"Participant == {i}"), x='Condition', y='theta', hue='ROI', ax=ax, dodge=False)
            ax.set_title(f'GLM Results for Participant {i}')
            ax.set_ylabel('Theta (uM)')
            ax.set_ylim(theta_min, theta_max)
            ax.set_xlabel('Condition')
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.tight_layout()
            plt.savefig(f'plots/glm/individual_{mode}/Participant {i}.png', dpi=dpi / 4)
            plt.close()

### Group GLM Results

In [24]:
if get_group_glm_plots:
    for mode in modes:
        grp_results = glm[mode]['roi'].query(f"Condition in {conditions_list[mode]}")
        grp_results = grp_results.query("Chroma in ['hbo']")

        # Group-level mixed linear model: one coefficient per ROI:Condition:Chroma
        # cell, with observations grouped by participant
        roi_model = mixedlm("theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["Participant"]).fit(method="nm")
        #roi_model = rlm('theta ~ -1 + ROI:Condition:Chroma', grp_results).fit()

        # Save the model results. Pass the path straight to to_csv: writing the
        # string returned by to_csv() through a text-mode file can double the
        # carriage return on Windows ("\r\r\n"), because the string already
        # uses os.linesep line endings.
        roi_model_results = statsmodels_to_results(roi_model)
        roi_model_results.to_csv('processed_data/glm/results/roi_model_results_' + mode + '.csv')

        # plot the model coefficients per condition, coloured by ROI
        fig, ax = plt.subplots(figsize=(6, 6))
        sns.swarmplot(data=roi_model_results, x='Condition', y='Coef.', hue='ROI', ax=ax, dodge=False)
        ax.set_title('GLM Results')
        ax.set_ylabel('Coefficient')
        ax.set_xlabel('Condition')
        ax.set_ylim(roi_model_results['Coef.'].min() - 1, roi_model_results['Coef.'].max() + 1)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        # Rotate the tick labels before tight_layout so the layout accounts for them
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(f'plots/glm/group_results/results_{mode}.png', dpi=dpi / 4)
        plt.close()

### Group Contrasts

In [49]:
if get_group_contrast_plots:
    # Channel layout used to order the model results and draw the topomaps.
    # It is taken explicitly from the last good recording instead of relying on
    # a `raw_haemo` loop variable left over from whichever earlier cell ran last
    # (in a top-to-bottom run that could be a rejected recording).
    hbo_template = raw_haemo_good_recordings[-1].copy().pick(picks="hbo")

    for mode in modes:
        hbo_contrasts = glm[mode]['con'].query("Chroma in ['hbo']")
        pairs = list(itertools.combinations(conditions_list[mode], 2))
        for pair in pairs:
            contrast_name = f"{pair[0]} - {pair[1]}"
            con_summary = hbo_contrasts[hbo_contrasts["contrast_pair"] == contrast_name]

            if len(con_summary) == 0:
                # The GLM cell names its pairs from np.unique (alphabetically sorted)
                # annotations, so e.g. "Real - Base" is stored as "Base - Real".
                # The reversed contrast is the same estimate with the opposite sign.
                reversed_name = f"{pair[1]} - {pair[0]}"
                con_summary = hbo_contrasts[hbo_contrasts["contrast_pair"] == reversed_name].copy()
                con_summary["effect"] = -con_summary["effect"]
                con_summary["contrast_pair"] = contrast_name

            # Neither orientation of this pair was estimated
            if len(con_summary) == 0:
                continue

            # Run group level model and convert to dataframe
            con_model = mixedlm("effect ~ -1 + ch_name:Chroma", con_summary, groups=con_summary["Participant"]).fit(method="nm")

            con_model_df = statsmodels_to_results(con_model, order=hbo_template.ch_names)

            # Symmetric colour limits around zero
            vlim = max(abs(con_model_df['Coef.'].min()), abs(con_model_df['Coef.'].max()))
            vlim = (-vlim, vlim)

            plot_glm_group_topo(
                hbo_template.copy(), con_model_df, colorbar=True, extrapolate="head", threshold=True, vlim=vlim
            )
            plt.suptitle(f"Contrast: {pair[0]} - {pair[1]}")
            plt.savefig(f'plots/glm/topomaps/GLM_Result_{pair[0]}-{pair[1]}.png', dpi=dpi / 3)
            plt.close()

# Average Timeseries Activity

In [23]:
if get_avg_timeseries_activity:
    # Build one table with a row per (participant, epoch) holding region-averaged
    # activity over the tmin..tmax crop, label-encode its text columns, and save
    # both the table and the encodings.

    # create an empty dataframe
    average_timeseries_activity = pd.DataFrame()

    # crop window passed to Epochs.crop below
    tmin = 4
    tmax = 16

    # 1-based participant number, incremented once per recording
    i = 1
    for raw_haemo in raw_haemo_good_recordings:
        # relabel_annotations(...) is used here as returning Epochs (they are
        # cropped and converted to a DataFrame below)
        face_epochs = relabel_annotations(raw_haemo.copy(), mode='face_type')
        emotion_epochs = relabel_annotations(raw_haemo.copy(), mode='emotion')

        # crop the epochs to tmin-tmax
        face_epochs.crop(tmin=tmin, tmax=tmax)

        # convert the epochs to a dataframe
        face_epochs_df = face_epochs.to_data_frame()

        # remove the baseline condition
        # NOTE(review): where(...).dropna() also drops any non-baseline row that has
        # a NaN in any column and upcasts integer columns to float; boolean
        # indexing (df[df['condition'] != 'Base']) would not — confirm intent.
        face_epochs_df = face_epochs_df.where(face_epochs_df['condition'] != 'Base').dropna()

        # average every remaining column (time and channels) within each (epoch, condition) group
        face_epochs_df = face_epochs_df.groupby(['epoch', 'condition']).mean().reset_index()

        # crop the epochs to tmin-tmax
        emotion_epochs.crop(tmin=tmin, tmax=tmax)

        # convert the epochs to a dataframe
        emotion_epochs_df = emotion_epochs.to_data_frame()

        # remove the baseline condition (same where/dropna caveat as above)
        emotion_epochs_df = emotion_epochs_df.where(emotion_epochs_df['condition'] != 'Base').dropna()

        # average every remaining column (time and channels) within each (epoch, condition) group
        emotion_epochs_df = emotion_epochs_df.groupby(['epoch', 'condition']).mean().reset_index()

        # add the condition column from the emotion_epochs_df to the face_epochs_df as 'emotion'
        # NOTE(review): this assignment aligns on the row index, so it assumes both
        # frames list the same epochs in the same order after filtering and
        # grouping — confirm.
        face_epochs_df['emotion'] = emotion_epochs_df['condition']

        # reorder so 'emotion' is the third column; columns[2:-1] are the
        # averaged value columns (time + channels), excluding the new 'emotion'
        all_epochs_df = face_epochs_df[['epoch', 'condition', 'emotion'] + face_epochs_df.columns[2:-1].tolist()]

        # rename the condition column to face type
        # NOTE(review): all_epochs_df is a column selection of face_epochs_df, so
        # the in-place edits below may trigger pandas' SettingWithCopyWarning; an
        # explicit .copy() on the selection would avoid that.
        all_epochs_df.rename(columns={'condition': 'face type'}, inplace=True)

        # floor-divide the epoch number by 2 and convert it to an integer
        # NOTE(review): presumably two consecutive epochs form one trial — confirm.
        all_epochs_df['epoch'] = (all_epochs_df['epoch'] // 2).astype(int)

        # remove the time column (it was averaged by the groupby above and is not needed)
        all_epochs_df.drop(columns='time', inplace=True)

        # add one '<region> Average Hbo' column per region (mean across that region's channels)
        for region, channels in ch_mapping_hbo.items():
            # Ensure the channels exist in the dataframe to avoid errors
            valid_channels = [channel for channel in channels if channel in all_epochs_df.columns]
            if valid_channels:
                # Create a new column for the region's average
                all_epochs_df[region + ' Average Hbo'] = all_epochs_df[valid_channels].mean(axis=1)

        # add one '<region> Average Hbr' column per region (mean across that region's channels)
        for region, channels in ch_mapping_hbr.items():
            # Ensure the channels exist in the dataframe to avoid errors
            valid_channels = [channel for channel in channels if channel in all_epochs_df.columns]
            if valid_channels:
                # Create a new column for the region's average
                all_epochs_df[region + ' Average Hbr'] = all_epochs_df[valid_channels].mean(axis=1)

        # drop all the individual channel columns, keeping only the region averages
        all_epochs_df.drop(columns=all_channels, inplace=True)

        # add participant number column
        all_epochs_df['Participant'] = i

        # make the participant number the first column
        all_epochs_df = all_epochs_df[['Participant'] + all_epochs_df.columns[:-1].tolist()]

        # add measurement date column
        #all_epochs_df['Measurement Date'] = raw_haemo.info['meas_date']

        # add an empty 'Repetition' column at position 4, i.e. right after
        # Participant, epoch, face type, emotion
        all_epochs_df.insert(4, 'Repetition', '')

        # number the occurrences of each face type-emotion combination in row order (1, 2, ...)
        conditions = defaultdict(int)
        for index, row in all_epochs_df.iterrows():
            # add the condition-emotion pair to the conditions dictionary and increment the count
            conditions[f"{row['face type']}-{row['emotion']}"] += 1

            # add the count to the repetition column
            all_epochs_df.at[index, 'Repetition'] = conditions[f"{row['face type']}-{row['emotion']}"]

        # add an empty 'Sex' column right after the Repetition column
        all_epochs_df.insert(5, 'Sex', '')

        # fill the Sex column (same value for every row) from the recording's 'gender' info
        all_epochs_df['Sex'] = get_info(raw_haemo)['gender']

        # add the dataframe to the average_timeseries_activity dataframe
        average_timeseries_activity = pd.concat([average_timeseries_activity, all_epochs_df])

        i += 1

    # reset the index
    average_timeseries_activity.reset_index(drop=True, inplace=True)

    # name the index column 'Observation'
    average_timeseries_activity.index.name = 'Observation'

    # replace any spaces in the column names with underscores
    average_timeseries_activity.columns = average_timeseries_activity.columns.str.replace(' ', '_')

    # capitalize the column names (str.capitalize also lower-cases every other
    # letter, e.g. 'X_Average_Hbo' becomes 'X_average_hbo')
    average_timeseries_activity.columns = average_timeseries_activity.columns.str.capitalize()

    # Label-encode every object (text) column: each distinct value is replaced by
    # its index in order of first appearance, and the mapping is kept for decoding.
    # NOTE(review): 'Repetition' was created from '' (object dtype) and filled via
    # .at, so it is likely encoded here too, turning counts 1, 2, ... into codes — confirm.
    mappings = {}
    for col in average_timeseries_activity.select_dtypes(include=['object']).columns:
        # Create a mapping dictionary for the column
        unique_values = average_timeseries_activity[col].unique()
        col_mapping = {val: idx for idx, val in enumerate(unique_values)}
        mappings[col] = col_mapping

        # Replace the column values in the DataFrame with numeric values
        average_timeseries_activity[col] = average_timeseries_activity[col].map(col_mapping)

    # Save mappings to a JSON file
    with open('processed_data/average_timeseries_activity/mappings.json', 'w') as json_file:
        json.dump(mappings, json_file, indent=4)

    # save the dataframe to a csv file
    average_timeseries_activity.to_csv('processed_data/average_timeseries_activity/average_timeseries_activity.csv')

# Generate Epochs for all Conditions

In [24]:
all_epochs_faces = []
all_epochs_emotions = []

# Put each good recording's channels into the shared all_channels order (in
# place), then epoch it twice: once labelled by face type, once by emotion.
for raw_haemo in raw_haemo_good_recordings:
    raw_haemo.reorder_channels(all_channels)
    all_epochs_faces.append(relabel_annotations(raw_haemo.copy(), mode='face_type'))
    all_epochs_emotions.append(relabel_annotations(raw_haemo.copy(), mode='emotion'))

# Pool the epochs of every participant per condition. Keys are the labels used
# by the plotting cells; face-type values select the event name used in the
# per-recording epochs (the virtual condition is called "Virt" there).
_face_event_names = {"Real": "Real", "Virtual": "Virt"}
_emotion_event_names = ["Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]

all_epochs = {
    label: mne.concatenate_epochs([epochs[event] for epochs in all_epochs_faces])
    for label, event in _face_event_names.items()
}
for emotion_name in _emotion_event_names:
    all_epochs[emotion_name] = mne.concatenate_epochs(
        [epochs[emotion_name] for epochs in all_epochs_emotions]
    )

  "Real": mne.concatenate_epochs([epochs['Real'] for epochs in all_epochs_faces]),


Not setting metadata
1092 matching events found
Applying baseline correction (mode: mean)


  "Virtual": mne.concatenate_epochs([epochs['Virt'] for epochs in all_epochs_faces]),


Not setting metadata
1092 matching events found
Applying baseline correction (mode: mean)


  "Joy": mne.concatenate_epochs([epochs['Joy'] for epochs in all_epochs_emotions]),


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


  "Fear": mne.concatenate_epochs([epochs['Fear'] for epochs in all_epochs_emotions]),


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


  "Anger": mne.concatenate_epochs([epochs['Anger'] for epochs in all_epochs_emotions]),


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


  "Disgust": mne.concatenate_epochs([epochs['Disgust'] for epochs in all_epochs_emotions]),


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


  "Sadness": mne.concatenate_epochs([epochs['Sadness'] for epochs in all_epochs_emotions]),


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


  "Neutral": mne.concatenate_epochs([epochs['Neutral'] for epochs in all_epochs_emotions]),


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


  "Surprise": mne.concatenate_epochs([epochs['Surprise'] for epochs in all_epochs_emotions])


Not setting metadata
312 matching events found
Applying baseline correction (mode: mean)


# ERP Analysis

### Plot settings

In [25]:
include_hbr = False

# set the y-axis limit
lims = dict(hbo=[-6, 6], hbr=[-6, 6])

# Get all emotions in the dataset
emotions = ["Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]

# Create pairwise combinations of emotions
emotion_pairs = list(itertools.combinations(emotions, 2))

# Chromophores to plot: HbO (solid line) always, HbR (dotted line) when include_hbr
chromas = ["Hbo", "Hbr"] if include_hbr else ["Hbo"]
chroma_linestyles = {"Hbo": "solid", "Hbr": "dotted"}

face_colors = {"Real": "r", "Virtual": "b"}
emotion_colors = {
    "Joy": "yellow",       # A bright and classic representation of happiness
    "Fear": "purple",      # Associated with tension or mystery
    "Anger": "red",        # Classic association with anger
    "Disgust": "green",
    "Sadness": "blue",     # Commonly associated with sadness
    "Neutral": "gray",     # Neutral tones
    "Surprise": "orange",  # Bright, attention-grabbing orange
}


def _evokeds_for(condition, chroma):
    """Return one Evoked per epoch of ``all_epochs[condition]`` for one chromophore.

    ``chroma`` is "Hbo" or "Hbr". The epochs are copied before picking because
    ``Epochs.pick`` modifies the instance in place: picking the shared epochs
    directly stripped the other chromophore from ``all_epochs`` for every later
    cell and made a second pick of the same condition fail.
    """
    return list(all_epochs[condition].copy().pick(picks=chroma.lower()).iter_evoked())


def _plot_dicts(colors_by_condition):
    """Build the evoked, colour and line-style dicts keyed '<Condition><Chroma>'.

    Keys are ordered condition by condition, with every chromophore in
    ``chromas`` for that condition before moving to the next one.
    """
    keys = [(cond, chroma) for cond in colors_by_condition for chroma in chromas]
    evoked = {f"{cond}{chroma}": _evokeds_for(cond, chroma) for cond, chroma in keys}
    colors = {f"{cond}{chroma}": colors_by_condition[cond] for cond, chroma in keys}
    styles = {f"{cond}{chroma}": dict(linestyle=chroma_linestyles[chroma]) for cond, chroma in keys}
    return evoked, colors, styles


evoked_dict_faces, color_dict_faces, styles_dict_faces = _plot_dicts(face_colors)
evoked_dict_emotions, color_dict_emotions, styles_dict_emotions = _plot_dicts(
    {emotion: emotion_colors[emotion] for emotion in emotions}
)

### ERP Plots Mean Region

In [26]:
if get_erp_mean_regions_plots:
    # Plot the real vs virtual evoked responses, averaged over all channels
    mne.viz.plot_compare_evokeds(
        evoked_dict_faces,
        combine="mean",
        ci=0.95,
        colors=color_dict_faces,
        styles=styles_dict_faces,
        show=False,
        ylim=lims,
        title="Real vs Virtual",
        legend="lower left",
        truncate_yaxis=True,
    )

    plt.savefig('plots/erp/mean_all_regions/Real vs Virtual_All Regions.png', dpi=dpi)
    plt.close()

    # Iterate through each pair of emotions and plot them
    for emo1, emo2 in emotion_pairs:
        # Curves for this pair: HbO always, plus HbR whenever the settings cell
        # built it (include_hbr), matching what the Real vs Virtual plot shows.
        evoked_pair_dict = {
            key: evoked_dict_emotions[key]
            for emo in (emo1, emo2)
            for key in (f"{emo}Hbo", f"{emo}Hbr")
            if key in evoked_dict_emotions
        }

        # Create a color and style dictionary for the current pair
        color_pair_dict = {key: color_dict_emotions[key] for key in evoked_pair_dict}
        styles_pair_dict = {key: styles_dict_emotions[key] for key in evoked_pair_dict}

        # Plot the pair
        mne.viz.plot_compare_evokeds(
            evoked_pair_dict,
            combine="mean",
            ci=0.95,
            show=False,
            colors=color_pair_dict,
            styles=styles_pair_dict,
            ylim=lims,
            title=f"{emo1} vs {emo2}",
            legend="lower left",
            truncate_yaxis=True,
        )

        # Pairs involving Neutral go to their own folder
        folder = 'mean_all_regions_neutral' if "Neutral" in (emo1, emo2) else 'mean_all_regions'
        plt.savefig(f'plots/erp/{folder}/{emo1} vs {emo2}_All Regions.png', dpi=dpi)
        plt.close()

### ERP Plots Per Region

In [27]:
if get_erp_per_region_plots:
    n_boundaries = len(group_boundaries)

    # Iterate through each region in ch_mapping_hbo. group_boundaries[k] is used
    # as the first channel index of the k-th region; the next boundary (if any)
    # ends the slice.
    for region_idx, region in enumerate(ch_mapping_hbo):
        start_idx = group_boundaries[region_idx]
        end_idx = group_boundaries[region_idx + 1] if region_idx + 1 < n_boundaries else None

        # Restrict every evoked response to the current region's channels
        region_evoked_faces = {
            key: [evoked.copy().pick(evoked.ch_names[start_idx:end_idx]) for evoked in evokeds]
            for key, evokeds in evoked_dict_faces.items()
        }
        region_evoked_emotions = {
            key: [evoked.copy().pick(evoked.ch_names[start_idx:end_idx]) for evoked in evokeds]
            for key, evokeds in evoked_dict_emotions.items()
        }

        # Plot Real vs Virtual for the current region
        mne.viz.plot_compare_evokeds(
            region_evoked_faces,
            combine="mean",
            ci=0.95,
            colors=color_dict_faces,
            styles=styles_dict_faces,
            show=False,
            ylim=lims,
            title=f"Real vs Virtual - {region}",
            legend="lower left",
            truncate_yaxis=True,
        )
        plt.savefig(f'plots/erp/per_region/Real_vs_Virtual_{region}.png', dpi=dpi)
        plt.close()

        # Plot each pair of emotions for the current region
        for emo1, emo2 in emotion_pairs:
            # HbO always, plus HbR whenever the settings cell built it (include_hbr)
            evoked_pair_dict = {
                key: region_evoked_emotions[key]
                for emo in (emo1, emo2)
                for key in (f"{emo}Hbo", f"{emo}Hbr")
                if key in region_evoked_emotions
            }

            # Create a color and style dictionary for the current pair
            color_pair_dict = {key: color_dict_emotions[key] for key in evoked_pair_dict}
            styles_pair_dict = {key: styles_dict_emotions[key] for key in evoked_pair_dict}

            # Plot the pair
            mne.viz.plot_compare_evokeds(
                evoked_pair_dict,
                combine="mean",
                ci=0.95,
                show=False,
                colors=color_pair_dict,
                styles=styles_pair_dict,
                ylim=lims,
                title=f"{emo1} vs {emo2} - {region}",
                legend="lower left",
                truncate_yaxis=True,
            )

            # Pairs involving Neutral go to their own folder
            folder = 'per_region_neutral' if "Neutral" in (emo1, emo2) else 'per_region'
            plt.savefig(f'plots/erp/{folder}/{emo1}_vs_{emo2}_{region}.png', dpi=dpi)
            plt.close()

# Topographic Maps

In [28]:
if get_topo_condition_plots:
    # Conditions (keys of all_epochs) and chromophores to map
    topomap_conditions = ["Real", "Virtual", "Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]
    topomap_chromas = ["hbo", "hbr"] if include_hbr else ["hbo"]

    # Grand-average Evoked per condition and chromophore, keyed
    # "<Condition> (<Chroma>)", e.g. "Real (Hbo)" / "Real (Hbr)". HbR entries
    # now follow the same naming as HbO ones (they were "RealHbr", ...), so
    # titles and file names are consistent.
    evoked_dict_topomap = {
        f"{condition} ({chroma.capitalize()})": all_epochs[condition].average(picks=chroma)
        for condition in topomap_conditions
        for chroma in topomap_chromas
    }

    # Options shared by every map; the fixed vlim keeps colour scales comparable
    topomap_kwargs = dict(extrapolate="head", colorbar=True, size=2, vlim=(-15, 15), show=False)

    # One map per condition at time 8 with average=16 (folder "average for all 16")
    for condition, evoked in evoked_dict_topomap.items():
        evoked.plot_topomap(times=[8], average=16, **topomap_kwargs)
        plt.suptitle(f'{condition}')
        plt.savefig(f'plots/topomaps/average for all 16/{condition}.png', dpi=dpi)
        plt.close()

    # Four maps per condition at times 4, 8, 12 and 16, each with average=4
    for condition, evoked in evoked_dict_topomap.items():
        evoked.plot_topomap(times=[4, 8, 12, 16], average=4, **topomap_kwargs)
        plt.suptitle(f'{condition}')
        plt.savefig(f'plots/topomaps/4-8-12-16/{condition}.png', dpi=dpi)
        plt.close()

### Topographic Difference Maps

In [29]:
if get_topo_diff_plots:
    # Get all emotions in the dataset
    emotions = ["Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]

    # HbO grand averages per condition, computed here (the same values as the
    # topomap cell's "<Condition> (Hbo)" entries) so this cell also works when
    # get_topo_condition_plots is False and evoked_dict_topomap was never built.
    hbo_averages = {
        condition: all_epochs[condition].average(picks="hbo")
        for condition in ["Real", "Virtual"] + emotions
    }

    # Differences to map, in both directions: Real/Virtual, then every emotion pair
    difference_pairs = [("Real", "Virtual"), ("Virtual", "Real")]
    for first, second in itertools.combinations(emotions, 2):
        difference_pairs += [(first, second), (second, first)]

    evoked_dict_differences_topomap = {
        f"{first} - {second} (Hbo)": mne.combine_evoked(
            [hbo_averages[first], hbo_averages[second]], weights=[1, -1]
        )
        for first, second in difference_pairs
    }

    for condition, evoked in evoked_dict_differences_topomap.items():
        evoked.plot_topomap(
            times=[8],
            average=16,
            extrapolate="head",
            colorbar=True,
            size=2,
            vlim=(-15, 15),
            show=False
        )
        # Figure-level title, as in the per-condition topomaps: plt.title would
        # target whichever axes is current after plotting (a topomap or the colorbar)
        plt.suptitle(f'{condition}')
        if 'Neutral' in condition:
            plt.savefig(f'plots/topomaps/average differences for all 16_neutral/{condition}.png', dpi=dpi)
        else:
            plt.savefig(f'plots/topomaps/average differences for all 16/{condition}.png', dpi=dpi)
        plt.close()

# Individual Connectivity Analysis

In [30]:
# pick the channels
pick_chs = 'hbo'

# pick the connectivity method
method = "coh"

# pick the mode
mode = "cwt_morlet"

# pick the frequency range
cwt_freqs = np.linspace(0.01, 0.5, 10)

# pick the number of cycles
cwt_n_cycles = 1

# average the connectivity matrices across frequencies
faverage = True

if run_ind_connectivity:
    ind_con_dir = 'processed_data\\connectivity\\individual'

    # Empty the output folder first: the loader counts the con_<i>.npy files in it,
    # so leftovers from an earlier run with more recordings would be picked up.
    # Create the folder when it does not exist instead of failing in os.listdir.
    os.makedirs(ind_con_dir, exist_ok=True)
    for file in os.listdir(ind_con_dir):
        file_path = f'{ind_con_dir}\\{file}'
        if os.path.isfile(file_path):
            os.remove(file_path)

    # One connectivity estimate per good recording, saved as con_1.npy, con_2.npy, ...
    last_raw = None
    for i, raw_haemo in enumerate(raw_haemo_good_recordings, 1):
        # spectral_connectivity_epochs on the relabelled, channel-picked copy
        con = spectral_connectivity_epochs(
            data=relabel_annotations(raw_haemo.copy(), mode='all').pick(picks=pick_chs),
            method=method,
            mode=mode,
            cwt_freqs=cwt_freqs,
            cwt_n_cycles=cwt_n_cycles,
            faverage=faverage,
            n_jobs=n_jobs,
            verbose=True
        )
        np.save(f"{ind_con_dir}\\con_{i}.npy", con.get_data())
        del con
        last_raw = raw_haemo

    # Without this check an empty recording list would make the ch_names lookup
    # below reuse a `raw_haemo` left over from another cell (or raise NameError).
    if last_raw is None:
        raise ValueError("raw_haemo_good_recordings is empty: no individual connectivity was computed")

    # make a dictionary to store the connectivity parameters (ch_names gives the
    # channel order of the saved matrices, taken from the last recording)
    connectivity_params = {
        "pick_chs": pick_chs,
        "method": method,
        "mode": mode,
        "cwt_freqs": cwt_freqs.tolist(),
        "cwt_n_cycles": cwt_n_cycles,
        "faverage": faverage,
        "ch_names": last_raw.copy().pick(picks=pick_chs).ch_names
    }

    # save the connectivity parameters next to the individual matrices
    with open(f"{ind_con_dir}\\connectivity_params.json", "w") as f:
        json.dump(connectivity_params, f)

### Loading Individual Connectivity Data

In [31]:
ind_con_dir = 'processed_data\\connectivity\\individual'

# Count only the con_<i>.npy files written by the individual-connectivity cell, so an
# unrelated .npy file in the folder does not make the loader request a missing
# con_<n+1>.npy.
num_files = sum(1 for name in os.listdir(ind_con_dir) if re.fullmatch(r'con_\d+\.npy', name))

# load the matrices in recording order (con_1, con_2, ...)
ind_con = [np.load(f"{ind_con_dir}\\con_{i}.npy") for i in range(1, num_files + 1)]

# load the connectivity parameters from disk
with open(f"{ind_con_dir}\\connectivity_params.json", "r") as f:
    ind_connectivity_params = json.load(f)

### Individual Connectivity Heatmap Plots

In [32]:
if get_ind_con_plots:
    heatmap_dir = 'plots\\connectivity\\heatmaps\\individual'

    # Remove the images of the previous run before writing new ones
    for old_file in os.listdir(heatmap_dir):
        os.remove(f'{heatmap_dir}\\{old_file}')

    channel_labels = ind_connectivity_params['ch_names']

    for idx, con in enumerate(ind_con, 1):
        # Collapse the frequency and time axes -> one value per raveled channel pair
        pair_means = np.mean(con, axis=(1, 2))
        n_ch = int(np.sqrt(pair_means.size))

        # Square matrix mirrored onto both triangles; diagonal painted with the maximum
        matrix = pair_means.reshape((n_ch, n_ch))
        matrix = matrix + matrix.T
        np.fill_diagonal(matrix, np.max(matrix))

        fig, ax = plt.subplots(figsize=(20, 20))
        ax.set_title(f'Connectivity Matrix for con_{idx}')
        im = ax.imshow(matrix, cmap='viridis')
        ax.set_xlabel('Channel')
        ax.set_ylabel('Channel')

        # one tick per channel, labelled with the saved channel names
        ticks = np.arange(n_ch)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(channel_labels)
        ax.set_yticklabels(channel_labels)
        plt.xticks(rotation=90)

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Connectivity Value')
        plt.savefig(f'plots/connectivity/heatmaps/individual/con_{idx}.png', dpi=dpi / 3)
        plt.close()

### Average Individual Connectivity Over Time

In [33]:
if get_avg_ind_con_plot_avg:
    # Each saved matrix is raveled as (n_channels * n_channels, freqs, times), but only
    # the estimated channel pairs are non-zero; the heatmap cells symmetrise with
    # H + H.T and the histogram cells drop zeros for this reason. Averaging over every
    # row would mix in those placeholder zeros and dilute the curve, so only the
    # estimated pairs are averaged.
    per_individual = []
    for con in ind_con:
        estimated = np.any(con != 0, axis=(1, 2))
        per_individual.append(np.mean(con[estimated], axis=(0, 1)))

    # mean over individuals -> one value per time point
    all_ind_avg = np.mean(per_individual, axis=0)

    # plot the connectivity over the 99 time points
    plt.figure(figsize=(10, 6))
    plt.plot(all_ind_avg)
    plt.title(f'Average Channel Connectivity over time for all individuals, {ind_connectivity_params["pick_chs"]}, {ind_connectivity_params["method"]}, {ind_connectivity_params["mode"]}')
    plt.xlabel('Time (s)')
    plt.xticks(np.arange(0, 99, 11), np.arange(0, 17, 2))
    plt.ylabel('Connectivity')
    plt.savefig(f'plots/connectivity/over_time/Average Channel Connectivity over time for all individuals.png', dpi=dpi / 4)
    plt.close()

# Condition Connectivity

In [34]:
# pick the channels
pick_chs = 'hbo'

# pick the connectivity method
method = "coh"

# pick the mode
mode = "cwt_morlet"

# pick the frequency range
cwt_freqs = np.linspace(0.01, 0.5, 10)

# pick the number of cycles
cwt_n_cycles = 1

# average the connectivity matrices across frequencies
faverage = True

if run_condition_connectivity:
    con_dir = 'processed_data\\connectivity'
    os.makedirs(con_dir, exist_ok=True)

    # for each group of epochs, calculate the connectivity matrix
    for condition, epoch in all_epochs.items():
        con = spectral_connectivity_epochs(
            # Epochs.pick modifies the object in place, so pick on a copy; otherwise
            # the non-selected channels would be stripped from all_epochs for every
            # later cell (the ch_names lookup below already uses .copy())
            data=epoch.copy().pick(picks=pick_chs),
            method=method,
            mode=mode,
            cwt_freqs=cwt_freqs,
            cwt_n_cycles=cwt_n_cycles,
            faverage=faverage,
            n_jobs=n_jobs,
            verbose=True
        )
        np.save(f"{con_dir}\\{condition}_connectivity.npy", con.get_data())
        del con

    # make a dictionary to store the connectivity parameters
    connectivity_params = {
        "pick_chs": pick_chs,
        "method": method,
        "mode": mode,
        "cwt_freqs": cwt_freqs.tolist(),
        "cwt_n_cycles": cwt_n_cycles,
        "faverage": faverage,
        "ch_names": all_epochs["Real"].copy().pick(picks=pick_chs).ch_names
    }

    # save the connectivity parameters to disk in processed_data\connectivity
    with open(f"{con_dir}\\connectivity_params.json", "w") as f:
        json.dump(connectivity_params, f)

### Loading Condition Connectivity Data

In [35]:
# Conditions whose connectivity matrices were saved as <name>_connectivity.npy
condition_names = ["Real", "Virtual", "Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]

# condition name -> connectivity array, in the order listed above
all_con = {
    name: np.load(f"processed_data\\connectivity\\{name}_connectivity.npy")
    for name in condition_names
}

# Reload the parameters only when they were saved; otherwise keep any in memory
params_path = "processed_data\\connectivity\\connectivity_params.json"
if os.path.exists(params_path):
    with open(params_path, "r") as f:
        connectivity_params = json.load(f)

### Average Condition Connectivity Over Time

In [36]:
if get_avg_condition_con_plot:
    # Average over the estimated channel pairs only: rows of the raveled full matrix
    # that are all zero are unestimated placeholders (the heatmap cells symmetrise with
    # H + H.T and the histogram cells drop zeros), and including them dilutes the mean.
    all_con_avg = {}
    for key, con in all_con.items():
        estimated = np.any(con != 0, axis=(1, 2))
        # (freqs, times) -> with faverage a single row, turned into a plain list
        all_con_avg[key] = list(np.mean(con[estimated], axis=0)[0])

    # plot the connectivity over the 99 time points
    plt.figure(figsize=(10, 6))
    for condition, con in all_con_avg.items():
        plt.plot(con, label=condition)
    plt.title(f'Average Channel Connectivity over time for all conditions, {connectivity_params["pick_chs"]}, {connectivity_params["method"]}, {connectivity_params["mode"]}')
    plt.xlabel('Time (s)')
    plt.xticks(np.arange(0, 99, 11), np.arange(0, 17, 2))
    plt.ylabel('Connectivity')
    plt.legend()
    plt.savefig('plots/connectivity/over_time/Average Channel Connectivity over time for all conditions.png', dpi=dpi / 4)
    plt.close()

### Connectivity Histogram Plots

In [37]:
if get_hist_con_plots:
    for condition in all_con:
        # Mean over frequencies and times per channel pair; zero entries are left out
        pair_means = np.mean(all_con[condition], axis=(1, 2))
        nonzero = pair_means[pair_means != 0]

        # one histogram image per condition
        plt.hist(nonzero, bins=40, edgecolor='black')
        plt.xlabel('Connectivity Value')
        plt.ylabel('Number of Values')
        plt.title(f'Connectivity Histogram ({condition})')
        plt.savefig(f'plots/connectivity/histograms/conditions/{condition}_connectivity_histogram.png', dpi=dpi)
        plt.close()

### Histogram difference in connectivity between Real/Virtual Faces

In [38]:
if get_hist_diff_face_plots:
    # Mean connectivity per channel pair (over frequencies and times) for each face type
    real_mean = np.mean(all_con["Real"], axis=(1, 2))
    virtual_mean = np.mean(all_con["Virtual"], axis=(1, 2))
    real_virtual_diff = real_mean - virtual_mean
    virtual_real_diff = virtual_mean - real_mean

    # One histogram per direction of the difference, zero entries excluded
    for title, diff in zip(["Real - Virtual", "Virtual - Real"], [real_virtual_diff, virtual_real_diff]):
        nonzero = diff[diff != 0]

        plt.hist(nonzero, bins=40, edgecolor='black')
        plt.xlabel('Connectivity Value')
        plt.ylabel('Number of Values')
        plt.title(f'Connectivity Histogram ({title})')
        plt.savefig(f'plots/connectivity/histograms/differences/{title}_connectivity_histogram.png', dpi=dpi)
        plt.close()

### Histogram difference in connectivity between Emotions

In [39]:
if get_hist_diff_emotion_plots:
    # Get all emotions in the dataset
    emotions = ["Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]

    # Pairwise differences of the mean connectivity per channel pair, in both directions
    diff_results = {}
    for first, second in itertools.combinations(emotions, 2):
        first_mean = np.mean(all_con[first], axis=(1, 2))
        second_mean = np.mean(all_con[second], axis=(1, 2))
        diff_results[f"{first}-{second}"] = first_mean - second_mean
        diff_results[f"{second}-{first}"] = second_mean - first_mean

    for label in list(diff_results):
        # Drop zero entries; the filtered array replaces the stored one
        nonzero = diff_results[label][diff_results[label] != 0]
        diff_results[label] = nonzero

        plt.hist(nonzero, bins=40, edgecolor='black')
        plt.xlabel('Connectivity Value')
        plt.ylabel('Number of Values')
        plt.title(f'Connectivity Histogram ({label})')

        # pairs involving Neutral go to their own folder
        subfolder = 'differences_neutral' if 'Neutral' in label else 'differences'
        plt.savefig(f'plots/connectivity/histograms/{subfolder}/{label}_connectivity_histogram.png', dpi=dpi)
        plt.close()

### Connectivity Heatmap Plots / Distance Measure Heatmap

In [40]:
if get_heatmap_con_dist_plots:
    # (condition, symmetric connectivity matrix) pairs used for the distance measure.
    # The matrices are stored BEFORE the display-only diagonal fill below, so each
    # matrix's own maximum does not leak into the pairwise distances.
    heatmaps = []

    for con in all_con:
        # Average across frequencies and time points -> one value per raveled channel pair
        averaged_data = np.mean(all_con[con], axis=(1, 2))  # Shape becomes (10609,)

        # Get the grid size
        grid_size = int(np.sqrt(averaged_data.size))

        # Reshape the data into a 2D grid
        heatmap_data = averaged_data.reshape((grid_size, grid_size))

        # Make the matrix symmetric (the stored matrix presumably fills only one triangle)
        symmetric_data = heatmap_data + heatmap_data.T
        heatmaps.append((con, symmetric_data.copy()))

        # For display only: set the diagonal to the highest value
        np.fill_diagonal(symmetric_data, np.max(symmetric_data))

        # Plot the heatmap
        fig, ax = plt.subplots(figsize=(20, 20))
        ax.set_title(con)
        im = ax.imshow(symmetric_data, cmap='viridis')
        ax.set_xlabel('Channel')
        ax.set_ylabel('Channel')
        # set the x and y ticks to connectivity_params['ch_names']
        ax.set_xticks(np.arange(grid_size))
        ax.set_yticks(np.arange(grid_size))
        ax.set_xticklabels(connectivity_params['ch_names'])
        ax.set_yticklabels(connectivity_params['ch_names'])
        plt.xticks(rotation=90)

        # Add a single colorbar for the entire figure
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Connectivity Value')
        plt.savefig(f'plots/connectivity/heatmaps/conditions/{con}.png', dpi=dpi / 3)
        plt.close()

    # sum of absolute element-wise differences between every pair of condition matrices
    diff_heatmaps = np.zeros((len(heatmaps), len(heatmaps)))
    for i, (con1, heatmap1) in enumerate(heatmaps):
        for j, (con2, heatmap2) in enumerate(heatmaps):
            diff_heatmaps[i, j] = np.sum(np.abs(heatmap1 - heatmap2))

    # plot the heatmap of the differences, with the numbers on the heatmap
    condition_labels = [con for con, _ in heatmaps]
    plt.figure(figsize=(7, 6))
    plt.imshow(diff_heatmaps, cmap='viridis')
    plt.title('Condition Differences')
    plt.ylabel('Condition')
    plt.xticks(range(len(heatmaps)), condition_labels, rotation=45)
    plt.yticks(range(len(heatmaps)), condition_labels)
    plt.colorbar(label='Difference', fraction=0.046, pad=0.04)
    # add the numbers to the heatmap
    for i in range(len(heatmaps)):
        for j in range(len(heatmaps)):
            plt.text(j, i, f'{diff_heatmaps[i, j]:.0f}', ha='center', va='center', color='white')
    plt.savefig(f'plots/connectivity/heatmaps/condition_sum_of_differences.png', dpi=dpi)
    plt.close()

### Heatmap difference in connectivity between Real/Virtual Faces

In [41]:
if get_heatmap_diff_face_plots:
    # Mean connectivity per raveled channel pair (over frequencies and times) per face type
    real_mean = np.mean(all_con["Real"], axis=(1, 2))
    virtual_mean = np.mean(all_con["Virtual"], axis=(1, 2))
    real_virtual_diff = real_mean - virtual_mean
    virtual_real_diff = virtual_mean - real_mean

    tick_labels = connectivity_params['ch_names']
    for title, diff in (('Real-Virtual', real_virtual_diff), ('Virtual-Real', virtual_real_diff)):
        n_ch = int(np.sqrt(diff.size))

        # Square, mirrored matrix with the diagonal painted with its minimum value
        matrix = diff.reshape((n_ch, n_ch))
        matrix = matrix + matrix.T
        np.fill_diagonal(matrix, np.min(matrix))

        fig, ax = plt.subplots(figsize=(20, 20))
        ax.set_title(title)
        im = ax.imshow(matrix, cmap='viridis')
        ax.set_xlabel('Channel')
        ax.set_ylabel('Channel')

        ticks = np.arange(n_ch)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(tick_labels)
        ax.set_yticklabels(tick_labels)
        plt.xticks(rotation=90)

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Difference in Connectivity Strength')
        plt.savefig(f'plots/connectivity/heatmaps/differences/{title}.png', dpi=dpi / 3)
        plt.close()

### Heatmap difference in connectivity between Emotions

In [42]:
if get_heatmap_diff_emotion_plots:
    # Get all emotions in the dataset
    emotions = ["Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]

    # Pairwise differences of the mean connectivity per channel pair, in both directions
    diff_results = {}
    for first, second in itertools.combinations(emotions, 2):
        first_mean = np.mean(all_con[first], axis=(1, 2))
        second_mean = np.mean(all_con[second], axis=(1, 2))
        diff_results[f"{first}-{second}"] = first_mean - second_mean
        diff_results[f"{second}-{first}"] = second_mean - first_mean

    tick_labels = connectivity_params['ch_names']
    for label, diff in diff_results.items():
        n_ch = int(np.sqrt(diff.size))

        # Square, mirrored matrix with the diagonal painted with its minimum value
        matrix = diff.reshape((n_ch, n_ch))
        matrix = matrix + matrix.T
        np.fill_diagonal(matrix, np.min(matrix))

        plt.figure(figsize=(20, 20))
        plt.imshow(matrix, cmap='viridis')
        plt.xlabel('Channel')
        plt.ylabel('Channel')
        plt.title(f'Difference in Connectivity Matrices: {label}')
        plt.xticks(range(n_ch), tick_labels, rotation=90)
        plt.yticks(range(n_ch), tick_labels)

        cbar = plt.colorbar(fraction=0.046, pad=0.04)
        cbar.set_label('Difference in Connectivity Strength')

        # pairs involving Neutral go to their own folder
        subfolder = 'differences_neutral' if 'Neutral' in label else 'differences'
        plt.savefig(f'plots/connectivity/heatmaps/{subfolder}/{label}.png', dpi=dpi / 3)
        plt.close()

### Chord plots for Connectivity

In [43]:
if get_chord_con_plots:
    # Shared styling for the chord diagrams (facecolor2/textcolor2 are not used below)
    colorscheme = {
        'facecolor': 'white',
        'textcolor': 'black',
        'colormap': 'hot',
        'facecolor2': 'black',
        'textcolor2': 'white',
    }

    # Node positions around the circle, grouped by group_boundaries, starting at 90 degrees
    node_angles = circular_layout(
        ch_names_hbo, all_channels_hbo, start_pos=90, group_boundaries=group_boundaries
    )

    for con, con_data in all_con.items():
        # Mean over frequencies and times, reshaped into the square channel x channel matrix
        pair_means = np.mean(con_data, axis=(1, 2))  # Shape becomes (10609,)
        side = int(np.sqrt(pair_means.size))
        matrix = pair_means.reshape((side, side))

        plot_connectivity_circle(
            matrix,
            node_names=ch_names_hbo,
            node_angles=node_angles,
            n_lines=10000,
            title=con,
            colorbar_size=1,
            fontsize_colorbar=8,
            facecolor=colorscheme['facecolor'],
            textcolor=colorscheme['textcolor'],
            colormap=colorscheme['colormap'],
            padding=3,
            vmin=0.2,
            vmax=0.6,
            fontsize_title=24,
            colorbar=True,
            show=False
        )

        plt.savefig(f'plots/connectivity/chord_plots/conditions/{con}.png', dpi=dpi)
        plt.close()

### Chord plot difference in connectivity between Real/Virtual Faces

In [44]:
if get_chord_diff_face_plots:
    # Same styling and node layout as the per-condition chord plots. They are defined
    # here so this cell also runs when that cell is switched off (otherwise NameError).
    colorscheme = dict(
        facecolor='white',
        textcolor='black',
        colormap='hot',
        facecolor2='black',
        textcolor2='white',
    )
    node_angles = circular_layout(
        ch_names_hbo, all_channels_hbo, start_pos=90, group_boundaries=group_boundaries
    )

    # Mean connectivity per raveled channel pair for each face type, and both differences
    real_mean = np.mean(all_con["Real"], axis=(1, 2))
    virtual_mean = np.mean(all_con["Virtual"], axis=(1, 2))
    real_virtual_diff = real_mean - virtual_mean
    virtual_real_diff = virtual_mean - real_mean

    for title, diff in (('Real-Virtual', real_virtual_diff), ('Virtual-Real', virtual_real_diff)):
        grid_size = int(np.sqrt(diff.size))
        heatmap_data = diff.reshape((grid_size, grid_size))
        plot_connectivity_circle(
            heatmap_data,
            node_names=ch_names_hbo,
            node_angles=node_angles,
            n_lines=10000,
            title=title,
            colorbar_size=1,
            fontsize_colorbar=8,
            facecolor=colorscheme['facecolor'],
            textcolor=colorscheme['textcolor'],
            colormap=colorscheme['colormap'],
            padding=3,
            vmin=0.035,
            vmax=0.08,
            fontsize_title=24,
            colorbar=True,
            show=False
        )

        plt.savefig(f'plots/connectivity/chord_plots/differences/{title}.png', dpi=dpi)
        plt.close()

### Chord plot difference in connectivity between Emotions

In [45]:
if get_chord_diff_emotion_plots:
    # emotions, colour scheme and node layout are defined here so this cell does not
    # depend on other optional cells having run (otherwise NameError)
    emotions = ["Joy", "Fear", "Anger", "Disgust", "Sadness", "Neutral", "Surprise"]
    colorscheme = dict(
        facecolor='white',
        textcolor='black',
        colormap='hot',
        facecolor2='black',
        textcolor2='white',
    )
    node_angles = circular_layout(
        ch_names_hbo, all_channels_hbo, start_pos=90, group_boundaries=group_boundaries
    )

    # Pairwise differences of the mean connectivity per channel pair, in both directions
    diff_results = {}
    for cond1, cond2 in itertools.combinations(emotions, 2):
        mean_1 = np.mean(all_con[cond1], axis=(1, 2))
        mean_2 = np.mean(all_con[cond2], axis=(1, 2))
        diff_results[f"{cond1}-{cond2}"] = mean_1 - mean_2
        diff_results[f"{cond2}-{cond1}"] = mean_2 - mean_1

    for diff, values in diff_results.items():
        grid_size = int(np.sqrt(values.size))
        heatmap_data = values.reshape((grid_size, grid_size))
        plot_connectivity_circle(
            heatmap_data,
            node_names=ch_names_hbo,
            node_angles=node_angles,
            n_lines=10000,
            title=diff,
            colorbar_size=1,
            fontsize_colorbar=8,
            facecolor=colorscheme['facecolor'],
            textcolor=colorscheme['textcolor'],
            colormap=colorscheme['colormap'],
            padding=3,
            vmin=0.075,
            vmax=0.15,
            fontsize_title=24,
            colorbar=True,
            show=False
        )

        if 'Neutral' in diff:
            plt.savefig(f'plots/connectivity/chord_plots/differences_neutral/{diff}.png', dpi=dpi)
        else:
            plt.savefig(f'plots/connectivity/chord_plots/differences/{diff}.png', dpi=dpi)
        plt.close()

# Time Series Analysis

### Get Time Series Data

In [46]:
modes = ['face_type', 'emotion']

if get_time_series:
    for mode in modes:
        # Per-recording frames, concatenated once after the loop
        frames = []

        for raw_haemo in raw_haemo_good_recordings:
            raw_haemo_annots = raw_haemo.copy()
            relabel_annotations(raw_haemo_annots, mode=mode)

            df = raw_haemo_annots.to_data_frame()

            annots = raw_haemo_annots.annotations.to_data_frame(time_format='ms')

            # drop the duration column
            annots.drop(columns=['duration'], inplace=True)

            # onsets relative to the first annotation, converted from ms to seconds
            annots['onset'] = (annots['onset'] - annots['onset'][0]) / 1000

            # 'event' column right after 'time'. It has object dtype with None as the
            # "not labelled yet" marker: a float NaN column would have to be upcast when
            # the string labels are written into it, which pandas >= 2.1 deprecates.
            df.insert(1, 'event', pd.Series(None, index=df.index, dtype=object))

            # Append an empty row and shift the descriptions down one row, so each row
            # pairs the onset of annotation k with the description of annotation k-1
            annots.loc[len(annots)] = None
            annots['description'] = annots['description'].shift(1)

            # remove the first row (it has no preceding description)
            annots = annots.iloc[1:]

            # still-unlabelled samples before each onset belong to the previous annotation
            for onset, description in zip(annots['onset'], annots['description']):
                df.loc[(df['time'] < onset) & (df['event'].isna()), 'event'] = description

            # the appended row carries the last annotation's description
            last_event = annots['description'].iloc[-1]

            # fill all the events that occur after the last event with the last event
            df.loc[df['event'].isna(), 'event'] = last_event

            frames.append(df)

        # a single concat (repeated concat in the loop is quadratic); empty if no recordings
        time_series = pd.concat(frames) if frames else pd.DataFrame()

        # save to csv
        time_series.to_csv(f'processed_data/time_series/datasets/time_series_{mode}.csv', index=False)

# load the time series from disk
all_time_series = {
    mode: pd.read_csv(f'processed_data/time_series/datasets/time_series_{mode}.csv')
    for mode in modes
}

# drop the Base event from both time series
drop_base_event = True
if drop_base_event:
    for mode in modes:
        all_time_series[mode] = all_time_series[mode][all_time_series[mode]['event'] != 'Base']

# class code -> event name in LabelEncoder order (sorted classes get codes 0..k-1),
# i.e. the same encoding the classifier cell applies to these filtered events
label_mapping = {
    mode: dict(enumerate(LabelEncoder().fit(all_time_series[mode]["event"]).classes_))
    for mode in modes
}

### Run Classifiers

In [47]:
if run_models:
    models_to_run = [
        RandomForestClassifier(n_estimators=250, class_weight='balanced', random_state=42, n_jobs=n_jobs),
        # XGBClassifier(n_jobs=n_jobs, random_state=42),
        # LGBMClassifier(n_jobs=n_jobs, random_state=42),  # Faster alternative to XGBoost
        # HistGradientBoostingClassifier(random_state=42),  # Optimized for large datasets
        # LogisticRegression(multi_class="multinomial", solver="lbfgs", max_iter=1000, random_state=42, n_jobs=n_jobs),  # Linear classifier
        KNeighborsClassifier(n_neighbors=2, algorithm='auto', leaf_size=10, weights='distance', n_jobs=n_jobs),  # Simple and effective for certain datasets
        # LinearSVC(dual=False, random_state=42),  # dual=False is faster for large datasets
        # SGDClassifier(loss="hinge", random_state=42, n_jobs=n_jobs),  # Hinge loss = linear SVM
        # MLPClassifier(
        #     hidden_layer_sizes=(256, 128, 64),  # Deeper network for complex patterns
        #     activation='relu',  # Best default choice
        #     solver='adam',  # 'adam' works well for most cases, but 'sgd' can be tested
        #     alpha=1e-4,  # L2 regularization to reduce overfitting
        #     batch_size=256,  # Larger batch sizes improve speed
        #     learning_rate='adaptive',  # Adjusts learning rate based on performance
        #     max_iter=500,  # More iterations for large datasets
        #     early_stopping=True,  # Stops if validation loss doesn’t improve
        #     random_state=42
        # )
    ]

    # Parameter grids for models to be tuned, keyed by model class name
    param_grids = {
    }

    # one results folder per model class
    for model in models_to_run:
        os.makedirs(f'processed_data/time_series/results/{model.__class__.__name__}', exist_ok=True)

    for mode, time_series in all_time_series.items():
        # Work on a copy: all_time_series keeps its string labels, and pandas does not
        # raise SettingWithCopyWarning on the frame filtered by the Base-drop step
        time_series = time_series.copy()

        # Convert 'event' to numerical labels (same sorted-class encoding as label_mapping)
        label_encoder = LabelEncoder()
        time_series["event"] = label_encoder.fit_transform(time_series["event"])

        # Convert 'time' to numeric type
        time_series["time"] = pd.to_numeric(time_series["time"], errors="coerce")

        # Features are the sensor columns; the target is the event code
        X = time_series.drop(columns=["event", "time"])
        y = time_series["event"]

        # NOTE(review): rows are consecutive samples of continuous recordings, so a random
        # split puts near-identical neighbouring samples in both sets; splitting by
        # recording or trial would give a less optimistic estimate.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

        # Normalize the sensor data with statistics from the training split only;
        # fitting the scaler on all rows leaks test-set information into training
        scaler = StandardScaler()
        X_train = pd.DataFrame(scaler.fit_transform(X_train), index=X_train.index, columns=X.columns)
        X_test = pd.DataFrame(scaler.transform(X_test), index=X_test.index, columns=X.columns)

        for model in models_to_run:
            model_name = model.__class__.__name__
            model_dir = f'processed_data/time_series/results/{model_name}'

            if model_name in param_grids:
                # Perform GridSearchCV to find the best hyperparameters
                grid_search = GridSearchCV(estimator=model, param_grid=param_grids[model_name], cv=3, n_jobs=4, scoring='accuracy', verbose=3)
                grid_search.fit(X_train, y_train)
                # best_estimator_ is already refit on the whole training split
                model = grid_search.best_estimator_
                print(model.get_params())
            else:
                model.fit(X_train, y_train)

            # Make predictions
            y_pred = model.predict(X_test)

            # per-class metrics plus accuracy/average rows, as a DataFrame
            report = pd.DataFrame(classification_report(y_test, y_pred, output_dict=True)).transpose()

            # replace the numeric class codes in the index with the event names
            report.rename(
                index={index: label_mapping[mode][int(index)] for index in report.index if index.isdigit()},
                inplace=True,
            )

            # save feature importances for models that expose them
            if hasattr(model, 'feature_importances_'):
                feature_importances = pd.DataFrame(model.feature_importances_, index=X.columns, columns=['importance'])
                feature_importances.sort_values(by='importance', ascending=False, inplace=True)
                feature_importances.to_csv(f'{model_dir}/feature_importances_' + mode + '.csv')

            # save the report to csv
            report.to_csv(f'{model_dir}/report_' + mode + '.csv')

results_dir = 'processed_data/time_series/results'

# get all folder names from the results folder
models = list(os.listdir(results_dir))


def _read_result_csv(path):
    """Return the CSV at *path* indexed by its first column, or None if it was never written."""
    return pd.read_csv(path, index_col=0) if os.path.exists(path) else None


# results[model_name][mode]['feature_importances'] / results[model_name][mode]['report']
results = {
    model: {
        mode: {
            'feature_importances': _read_result_csv(f'{results_dir}/{model}/feature_importances_{mode}.csv'),
            'report': _read_result_csv(f'{results_dir}/{model}/report_{mode}.csv'),
        }
        for mode in modes
    }
    for model in models
}

### Plot Classifier Accuracy Results

In [48]:
for mode in modes:
    plt.figure(figsize=(10, 6))
    class_names = list(label_mapping[mode].values())
    x = np.arange(len(class_names))  # the label locations
    width = 0.1  # the width of the bars

    model_index = 0
    for model in models:
        report = results[model][mode]['report']
        if report is None:
            continue
        # Per-class accuracy is the class recall (share of that class's samples predicted
        # correctly); the original plotted precision under an "Accuracy" label.
        # reindex aligns rows with the x positions, so a class missing from a report
        # becomes NaN instead of shifting every following bar.
        per_class_accuracy = report.reindex(class_names)['recall']
        plt.bar(x + width * model_index, per_class_accuracy, width, label=model)
        model_index += 1

    plt.title(f'Accuracy per Condition for {mode}')
    plt.ylabel('Accuracy')
    plt.xlabel('Condition')
    # centre each tick under its group of bars
    plt.xticks(x + width * (model_index - 1) / 2, class_names)
    plt.ylim(0, 1)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    plt.savefig(f'plots/time_series/accuracy/accuracy_per_condition_{mode}.png', dpi=dpi / 4)
    plt.close()

for mode in modes:
    plt.figure(figsize=(10, 6))

    # Only models that saved a report for this mode get a bar, in folder order
    models_used = [name for name in models if results[name][mode]['report'] is not None]
    for position, name in enumerate(models_used):
        # the 'accuracy' row repeats the overall accuracy in every metric column
        overall_accuracy = results[name][mode]['report'].loc['accuracy', 'precision']
        plt.bar(position, overall_accuracy, label=name)

    plt.title(f'Total Accuracy for {mode}')
    plt.ylabel('Accuracy')
    plt.xlabel('Model')
    plt.xticks(range(len(models_used)), models_used, rotation=90)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.savefig(f'plots/time_series/accuracy/total_accuracy_{mode}.png', dpi=dpi / 4)
    plt.close()

### Plot Feature Importances per Classifier

In [49]:
# plot the feature importances for each model (skipping models that saved none)
for model in models:
    for mode in modes:
        importances = results[model][mode]['feature_importances']
        if importances is None:
            continue

        # tall figure: one horizontal bar per feature
        plt.figure(figsize=(6, 30))
        plt.barh(importances.index, importances['importance'])
        plt.title(f'Feature Importances for {model} ({mode})')
        plt.xlabel('Importance')
        plt.ylabel('Feature')
        plt.tight_layout()
        plt.savefig(f'plots/time_series/feature_importances/{model}_{mode}.png', dpi=dpi / 4)
        plt.close()