In [1]:
import dearpygui.dearpygui as dpg
import threading
import os
import scipy.io
import numpy as np
import pandas as pd
from tkinter import Tk, filedialog, simpledialog
import logging
import json
import time
from datetime import datetime, timedelta
import shutil 
import zipfile
import gc
from scipy.signal import butter, filtfilt, welch, iirnotch, hilbert, get_window
import tracemalloc
import h5py
import seaborn as sns
import pyperclip
import mne
from autoreject import AutoReject
from pynwb import NWBHDF5IO, NWBFile, TimeSeries
from pynwb.behavior import Position, SpatialSeries
from pynwb.file import Subject
from pynwb.misc import AnnotationSeries
from typing import Optional
from tensorpac import Pac
from tensorpac.signals import pac_signals_wavelet
import pickle
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import matplotlib.font_manager as fm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import io
import contextlib
from mne.time_frequency import psd_array_welch
from scipy.stats import linregress
from collections import defaultdict
import re
from pathlib import Path
from collections import defaultdict
from matplotlib.lines import Line2D
from __future__ import annotations

# Create the Logs and Save directory if it doesn't exist
log_dir = "Logs"
os.makedirs(log_dir, exist_ok=True)

save_dir = "Saved Info"
os.makedirs(save_dir, exist_ok=True)
config_file = os.path.join(save_dir, "config.json")

# Get the current timestamp
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Create the log filename with the timestamp
log_filename = os.path.join(log_dir, f'gui_V2_debug_{timestamp}.log')

# Set up logging
logging.basicConfig(level=logging.DEBUG, filename=log_filename, filemode='w', 
                    format='%(asctime)s - %(levelname)s - %(message)s')

global window_counter
window_counter = 0

# Suppress debug messages for the hdmf and pynwb libraries.
logging.getLogger("hdmf").setLevel(logging.WARNING)
logging.getLogger("pynwb").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)

# NWB stores an int16 ADC value NOT the voltage
ADBitVolts = 0.00000030517578125
ADBitmilliVolts = 0.00030517578125
ADBitmicroVolts = 0.30517578125

termes_font = FontProperties(fname="fonts/texgyretermes-regular.otf")
termes_font_bold = FontProperties(fname="fonts/texgyretermes-bold.otf")
#plt.rcParams['font.family'] = termes_font.get_name()

error_string = "Rats! An error has occured in the program. Sometimes this is ok, but \
it's best to reboot when this happens becuase the error could cause more errors.\n\n\
You can send an error log to Kyle to figure out what went wrong. They are located \
in chronological order in the 'Logs' folder. I recommend organizing by date modified \
to always get the most recent one. \n\nIt may also help to describe what you were doing \
when this happened so Kyle can try to reproduce the issue, thanks!."

In [2]:
#######################################################################################################################################################
# Themes
#######################################################################################################################################################

In [3]:
# Color Palettes
TABLEAU_COLORS = [
    '#1f77b4',  # Blue
    '#ff7f0e',  # Orange
    '#2ca02c',  # Green
    '#d62728',  # Red
    '#9467bd',  # Purple
    '#8c564b',  # Brown
    '#e377c2',  # Pink
    '#7f7f7f',  # Gray
]

"""
BRIGHT_COLORS = [
    '#e41a1c',  # Red
    '#377eb8',  # Blue
    '#4daf4a',  # Green
    '#984ea3',  # Purple
    '#ff7f00',  # Orange
    '#ffff33',  # Yellow
    '#a65628',  # Brown
    '#f781bf',  # Pink
]
"""

BRIGHT_COLORS = [
    '#e41a1c',  # Red
    '#ff7f00',  # Orange
    '#d9d904',  # Yellow
    '#4daf4a',  # Green
    '#377eb8',  # Blue
    '#984ea3',  # Purple
    '#a65628',  # Brown
    '#f781bf',  # Pink
]

COLORBLIND_OKABE_ITO = [
    '#E69F00',  # Orange
    '#56B4E9',  # Sky Blue
    '#009E73',  # Bluish Green
    '#F0E442',  # Yellow
    '#0072B2',  # Blue
    '#D55E00',  # Vermilion
    '#CC79A7',  # Reddish Purple
    '#999999',  # Gray
]

COLORBLIND_CUD = [
    '#0072B2',  # Blue
    '#009E73',  # Green
    '#D55E00',  # Vermilion
    '#CC79A7',  # Purple
    '#F0E442',  # Yellow
    '#56B4E9',  # Cyan
    '#E69F00',  # Orange
    '#999999',  # Gray
]


In [4]:
#######################################################################################################################################################
# Classes
#######################################################################################################################################################

In [5]:
class Plot:
    """
    A class to store information about a plot.
    
    Attributes:
        name : The plot's name, auto-generated as "Plot #" if not provided.
        ID : Unique ID number given to each plot, integral to tagging system in DearPyGui
        plot_type : The chosen plot type.
        nwb_file : The chosen NWB file.
        channel : The chosen channel (as an int).
        electrode_mapping : List that maps human-readable list to deafault list sorting from NWB file
        data_start : Starting timestamp for plotted data (in seconds)
        data_end : Ending timestamp for plotted data (in seconds)
    """
    # Class-level counter to generate default plot names.
    _plot_name_counter = 1
    _plot_ID_counter = 1

    # Used to update every set axis when shifting one
    sync_axis_list = []
    
    def __init__(self,
                 name: Optional[str] = None,
                 ID: Optional[str] = None,
                 plot_type: Optional[str] = None,
                 nwb_file: Optional[str] = None,
                 channel: Optional[int] = None,
                 electrode_mapping = None,
                 data_start: Optional[int]  = None,
                 data_end: Optional[int]  = None,
                 data_min = None,
                 data_max = None,
                 sfreq = None,
                 lowcut = None,
                 highcut = None,
                 voltage_scale = None,
                 custom_name = None,
                 plot_output_path = None,
                 export_status = None):
        if name is None:
            # Auto-generate the plot name as "Plot #" and increment the counter.
            self.name = f"Plot {Plot._plot_name_counter}"
            Plot._plot_name_counter += 1
        else:
            self.name = name

        self.ID = Plot._plot_ID_counter
        Plot._plot_ID_counter += 1

        self.plot_type = None
        self.nwb_file = None
        self.channel = None
        self.data_start = 0
        self.data_end = 120
        self.data_min = 0
        self.data_max = 0
        self.sfreq = 0
        self.lowcut = 1
        self.highcut = 200
        self.voltage_scale = 1000
        self.custom_name = "Default"
        self.plot_output_path = ""
        self.export_status = False

    def set_plot_name(self, name):
        self.name = name
    def get_plot_name(self):
        return self.name

    def set_plot_type(self, plot_type):
        self.plot_type = plot_type
    def get_plot_type(self):
        return self.plot_type

    def set_channel(self, channel):
        self.channel = channel
    def get_channel(self):
        return self.channel

    def set_electrode_mapping(self, electrode_mapping):
        self.electrode_mapping = electrode_mapping
    def get_electrode_mapping(self):
        return self.electrode_mapping

    def set_data_start(self, data_start):
        self.data_start = data_start
    def get_data_start(self):
        return self.data_start

    def set_data_end(self, data_end):
        self.data_end = data_end
    def get_data_end(self):
        return self.data_end

    def set_data_min(self, data_min):
        self.data_min = data_min
    def get_data_min(self):
        return self.data_min

    def set_data_max(self, data_max):
        self.data_max = data_max
    def get_data_max(self):
        return self.data_max

    def set_sfreq(self, sfreq):
        self.sfreq = sfreq
    def get_sfreq(self):
        return self.sfreq

    def set_lowcut(self, lowcut):
        self.lowcut = lowcut
    def get_lowcut(self):
        return self.lowcut

    def set_highcut(self, highcut):
        self.highcut = highcut
    def get_highcut(self):
        return self.highcut

    def set_voltage_scale(self, voltage_scale):
            self.voltage_scale = voltage_scale
    def get_voltage_scale(self):
        return self.voltage_scale

    def set_custom_name(self, custom_name):
        self.custom_name = custom_name
    def get_custom_name(self):
        return self.custom_name

    def set_plot_output_path(self, plot_output_path):
        self.plot_output_path = plot_output_path
    def get_plot_output_path(self):
        return self.plot_output_path

    def set_export_status(self, export_status):
        self.export_status = export_status
    def get_export_status(self):
        return self.export_status

    def add_to_sync_list(self, plot_instance_id):
        Plot.sync_axis_list.append(plot_instance_id)
    def remove_from_sync_list(self, plot_instance_id):
        Plot.sync_axis_list.remove(plot_instance_id)
    def get_sync_list():
        return Plot.sync_axis_list
        
    @property
    def get_folder_path(self):
        return NWBFolder.get_folder_path()

    @property
    def get_file_list(self):
        return NWBFolder.get_file_list()

    @property
    def get_rat_and_probe_from_channel(self):
        return NWBFolder.get_rat_and_probe_from_channel(self.nwb_file, self.channel)

    def __repr__(self):
        return (f"Plot(name={self.name!r}, plot_type={self.plot_type!r}, "
                f"nwb_file={self.nwb_file!r}, channel={self.channel!r})")

In [6]:
class Analysis:
    """
    A class to store information about analysis settings
    
    Attributes:

    """
    _analysis_ID_counter = 1
    
    def __init__(self,
                 name='PAC',
                 ID=None,
                 nwb_list=None,
                 lowfreq_override='Custom',
                 lowfreq_lowpass=5,
                 lowfreq_highpass=17,
                 lowfreq_width=2,
                 lowfreq_step=1,
                 
                 highfreq_override='Custom',
                 highfreq_lowpass=40,
                 highfreq_highpass=100,
                 highfreq_width=10,
                 highfreq_step=5,
                 
                 PAC_method='Modulation Index',
                 surrogate_method='Swap Phase/Ampl. Across Trials',
                 normalization_method='Z-score',
                 dcomplex_method='Wavelet',
                 phase_cycles=3,
                 amplitude_cycles=6,
                 morlet_width=7,
                 KLD_or_HRPAC_bins=18,
                 
                 high_pass_filter = 0.0,
                 filter_notch_frequencies='[60, 120, 180, 240, 300]',
                 rejection_threshold=0,
                 detrend_epochs = True,
                 apply_autofilter = False,

                 skip_PAC=False,
                 inject_PAC=False,
                 seed=0,
                 surrogate_count=200,
                 parallel_processes=8,
                 data_length=10,
                 minimum_length=5,
                 epoch_length=20,
                 sampling_rate=2000,
                 downsample_rate=500,

                 PAC_custom_colormap=False,
                 PAC_vmin=-1.96,
                 PAC_vmax=1.96,
                 PAC_comod_interpolation=0.0,
                 
                 PSD_notch_frequencies='None',
                 PSD_high_pass_filter=0.0,
                 PSD_fmin=0.0,
                 PSD_fmax=200.0,
                 PSD_voltage_scale="Microvolts",
                 PSD_FFT_resolution=8,
                 PSD_plot_raw=True,
                 PSD_plot_filtered=False,
                 PSD_plot_grouping_method='No Grouping',
                 PSD_correct_1overf=False,

                 output_directory=None,
                 output_folder_name="GUIDA PAC Session [Timestamp]",
                 save_data=False,
                 export_PNG=True,
                 export_PDF=False,
                 export_SVG=False,
                 export_EPS=False,
                 image_height=5.0,
                 image_width=6.0,
                 image_DPI=300,
                 color_palette='Bright Colors',
                 y_custom_axis=False,
                 yaxis_top=10.0,
                 yaxis_bottom=0.0,
                 alpha=1,
                 
                 selected_channels=None,
                 selected_rats=None):

        self.name = name
        self.ID = Analysis._analysis_ID_counter if ID is None else ID
        self.nwb_list = nwb_list if nwb_list is not None else []

        # Low frequency settings
        self.lowfreq_override = lowfreq_override
        self.lowfreq_lowpass = lowfreq_lowpass
        self.lowfreq_highpass = lowfreq_highpass
        self.lowfreq_width = lowfreq_width
        self.lowfreq_step = lowfreq_step

        # High frequency settings
        self.highfreq_override = highfreq_override
        self.highfreq_lowpass = highfreq_lowpass
        self.highfreq_highpass = highfreq_highpass
        self.highfreq_width = highfreq_width
        self.highfreq_step = highfreq_step

        # Filtering settings
        self.high_pass_filter = high_pass_filter
        self.filter_notch_frequencies = filter_notch_frequencies
        self.rejection_threshold = rejection_threshold
        self.detrend_epochs = detrend_epochs
        self.apply_autofilter = apply_autofilter

        # PAC method settings
        self.PAC_method = PAC_method
        self.surrogate_method = surrogate_method
        self.normalization_method = normalization_method
        self.dcomplex_method = dcomplex_method
        self.phase_cycles = phase_cycles
        self.amplitude_cycles = amplitude_cycles
        self.morlet_width = morlet_width
        self.KLD_or_HRPAC_bins = KLD_or_HRPAC_bins

        # Processing parameters
        self.skip_PAC = skip_PAC
        self.inject_PAC = inject_PAC
        self.seed = seed
        self.surrogate_count = surrogate_count
        self.parallel_processes = parallel_processes
        self.data_length = data_length
        self.minimum_length = minimum_length
        self.epoch_length = epoch_length
        self.sampling_rate = sampling_rate
        self.downsample_rate = downsample_rate

        # PAC export parameters
        self.PAC_custom_colormap = PAC_custom_colormap
        self.PAC_vmin = PAC_vmin
        self.PAC_vmax = PAC_vmax
        self.PAC_comod_interpolation = PAC_comod_interpolation

        # PSD parameters
        self.PSD_notch_frequencies = PSD_notch_frequencies
        self.PSD_high_pass_filter = PSD_high_pass_filter
        self.PSD_fmin = PSD_fmin
        self.PSD_fmax = PSD_fmax
        self.PSD_voltage_scale = PSD_voltage_scale
        self.PSD_FFT_resolution = PSD_FFT_resolution
        self.PSD_plot_raw = PSD_plot_raw
        self.PSD_plot_filtered = PSD_plot_filtered
        self.PSD_plot_grouping_method = PSD_plot_grouping_method
        self.PSD_correct_1overf = PSD_correct_1overf

        # Export parameters
        self.output_directory = output_directory
        self.output_folder_name = output_folder_name
        self.save_data = save_data
        self.export_PNG = export_PNG
        self.export_PDF = export_PDF
        self.export_SVG = export_SVG
        self.export_EPS = export_EPS
        self.image_height = image_height
        self.image_width = image_width
        self.image_DPI = image_DPI
        self.color_palette = color_palette
        self.y_custom_axis = y_custom_axis
        self.yaxis_top = yaxis_top
        self.yaxis_bottom = yaxis_bottom
        self.alpha = alpha

        self.selected_channels = selected_channels if selected_channels is not None else {}
        self.selected_rats = selected_rats if selected_rats is not None else {}

        Analysis._analysis_ID_counter += 1

    def set_name(self, name):
        self.name = name
    def get_name(self):
        return self.name

    # Low frequency settings

    def set_lowfreq_override(self, lowfreq_override):
        self.lowfreq_override = lowfreq_override
    def get_lowfreq_override(self):
        return self.lowfreq_override

    def set_lowfreq_lowpass(self, lowfreq_lowpass):
        self.lowfreq_lowpass = lowfreq_lowpass
    def get_lowfreq_lowpass(self):
        return self.lowfreq_lowpass

    def set_lowfreq_highpass(self, lowfreq_highpass):
        self.lowfreq_highpass = lowfreq_highpass
    def get_lowfreq_highpass(self):
        return self.lowfreq_highpass

    def set_lowfreq_width(self, lowfreq_width):
        self.lowfreq_width = lowfreq_width
    def get_lowfreq_width(self):
        return self.lowfreq_width

    def set_lowfreq_step(self, lowfreq_step):
        self.lowfreq_step = lowfreq_step
    def get_lowfreq_step(self):
        return self.lowfreq_step

    # High frequency settings

    def set_highfreq_override(self, highfreq_override):
        self.highfreq_override = highfreq_override
    def get_highfreq_override(self):
        return self.highfreq_override

    def set_highfreq_lowpass(self, highfreq_lowpass):
        self.highfreq_lowpass = highfreq_lowpass
    def get_highfreq_lowpass(self):
        return self.highfreq_lowpass

    def set_highfreq_highpass(self, highfreq_highpass):
        self.highfreq_highpass = highfreq_highpass
    def get_highfreq_highpass(self):
        return self.highfreq_highpass

    def set_highfreq_width(self, highfreq_width):
        self.highfreq_width = highfreq_width
    def get_highfreq_width(self):
        return self.highfreq_width

    def set_highfreq_step(self, highfreq_step):
        self.highfreq_step = highfreq_step
    def get_highfreq_step(self):
        return self.highfreq_step

    # Filtering settings

    def set_high_pass_filter(self, high_pass_filter):
        self.high_pass_filter = high_pass_filter
    def get_high_pass_filter(self):
        return self.high_pass_filter

    def set_filter_notch_frequencies(self, filter_notch_frequencies):
        self.filter_notch_frequencies = filter_notch_frequencies
    def get_filter_notch_frequencies(self):
        return self.filter_notch_frequencies

    def set_rejection_threshold(self, rejection_threshold):
        self.rejection_threshold = rejection_threshold
    def get_rejection_threshold(self):
        return self.rejection_threshold

    def set_detrend_epochs(self, detrend_epochs):
        self.detrend_epochs = detrend_epochs
    def get_detrend_epochs(self):
        return self.detrend_epochs

    def set_apply_autofilter(self, apply_autofilter):
        self.apply_autofilter = apply_autofilter
    def get_apply_autofilter(self):
        return self.apply_autofilter

    # PAC method settings

    def set_PAC_method(self, PAC_method):
        self.PAC_method = PAC_method
    def get_PAC_method(self):
        return self.PAC_method

    def set_surrogate_method(self, surrogate_method):
        self.surrogate_method = surrogate_method
    def get_surrogate_method(self):
        return self.surrogate_method

    def set_normalization_method(self, normalization_method):
        self.normalization_method = normalization_method
    def get_normalization_method(self):
        return self.normalization_method

    def set_dcomplex_method(self, dcomplex_method):
        self.dcomplex_method = dcomplex_method
    def get_dcomplex_method(self):
        return self.dcomplex_method

    def set_phase_cycles(self, phase_cycles):
        self.phase_cycles = phase_cycles
    def get_phase_cycles(self):
        return self.phase_cycles

    def set_amplitude_cycles(self, amplitude_cycles):
        self.amplitude_cycles = amplitude_cycles
    def get_amplitude_cycles(self):
        return self.amplitude_cycles

    def set_morlet_width(self, morlet_width):
        self.morlet_width = morlet_width
    def get_morlet_width(self):
        return self.morlet_width

    def set_KLD_or_HRPAC_bins(self, KLD_or_HRPAC_bins):
        self.KLD_or_HRPAC_bins = KLD_or_HRPAC_bins
    def get_KLD_or_HRPAC_bins(self):
        return self.KLD_or_HRPAC_bins

    # Processing parameters

    def set_skip_PAC(self, skip_PAC):
        self.skip_PAC = skip_PAC
    def get_skip_PAC(self):
        return self.skip_PAC

    def set_inject_PAC(self, inject_PAC):
        self.inject_PAC = inject_PAC
    def get_inject_PAC(self):
        return self.inject_PAC

    def set_seed(self, seed):
        self.seed = seed
    def get_seed(self):
        return self.seed
    
    def set_downsample_rate(self, downsample_rate):
        self.downsample_rate = downsample_rate
    def get_downsample_rate(self):
        return self.downsample_rate

    def set_sampling_rate(self, sampling_rate):
        self.sampling_rate = sampling_rate
    def get_sampling_rate(self):
        return self.sampling_rate

    def set_epoch_length(self, epoch_length):
        self.epoch_length = epoch_length
    def get_epoch_length(self):
        return self.epoch_length

    def set_data_length(self, data_length):
        self.data_length = data_length
    def get_data_length(self):
        return self.data_length
        
    def set_minimum_length(self, minimum_length):
        self.minimum_length = minimum_length
    def get_minimum_length(self):
        return self.minimum_length

    def set_parallel_processes(self, parallel_processes):
        self.parallel_processes = parallel_processes
    def get_parallel_processes(self):
        return self.parallel_processes

    def set_surrogate_count(self, surrogate_count):
        self.surrogate_count = surrogate_count
    def get_surrogate_count(self):
        return self.surrogate_count
        
    # PAC export parameters

    def set_PAC_custom_colormap(self, PAC_custom_colormap):
        self.PAC_custom_colormap = PAC_custom_colormap
    def get_PAC_custom_colormap(self):
        return self.PAC_custom_colormap

    def set_PAC_vmin(self, PAC_vmin):
        self.PAC_vmin = PAC_vmin
    def get_PAC_vmin(self):
        return self.PAC_vmin

    def set_PAC_vmax(self, PAC_vmax):
        self.PAC_vmax = PAC_vmax
    def get_PAC_vmax(self):
        return self.PAC_vmax

    def set_PAC_comod_interpolation(self, PAC_comod_interpolation):
        self.PAC_comod_interpolation = PAC_comod_interpolation
    def get_PAC_comod_interpolation(self):
        return self.PAC_comod_interpolation

    # PSD parameters

    def set_PSD_notch_frequencies(self, PSD_notch_frequencies):
        self.PSD_notch_frequencies = PSD_notch_frequencies
    def get_PSD_notch_frequencies(self):
        return self.PSD_notch_frequencies

    def set_PSD_high_pass_filter(self, PSD_high_pass_filter):
        self.PSD_high_pass_filter = PSD_high_pass_filter
    def get_PSD_high_pass_filter(self):
        return self.PSD_high_pass_filter

    def set_PSD_fmin(self, PSD_fmin):
        self.PSD_fmin = PSD_fmin
    def get_PSD_fmin(self):
        return self.PSD_fmin

    def set_PSD_fmax(self, PSD_fmax):
        self.PSD_fmax = PSD_fmax
    def get_PSD_fmax(self):
        return self.PSD_fmax

    def set_PSD_voltage_scale(self, PSD_voltage_scale):
        self.PSD_voltage_scale = PSD_voltage_scale
    def get_PSD_voltage_scale(self):
        return self.PSD_voltage_scale

    def set_PSD_FFT_resolution(self, PSD_FFT_resolution):
        self.PSD_FFT_resolution = PSD_FFT_resolution
    def get_PSD_FFT_resolution(self):
        return self.PSD_FFT_resolution

    def set_PSD_plot_raw(self, PSD_plot_raw):
        self.PSD_plot_raw = PSD_plot_raw
    def get_PSD_plot_raw(self):
        return self.PSD_plot_raw

    def set_PSD_plot_filtered(self, PSD_plot_filtered):
        self.PSD_plot_filtered = PSD_plot_filtered
    def get_PSD_plot_filtered(self):
        return self.PSD_plot_filtered

    def set_PSD_plot_grouping_method(self, PSD_plot_grouping_method):
        self.PSD_plot_grouping_method = PSD_plot_grouping_method
    def get_PSD_plot_grouping_method(self):
        return self.PSD_plot_grouping_method
    
    def set_PSD_correct_1overf(self, PSD_correct_1overf):
        self.PSD_correct_1overf = PSD_correct_1overf
    def get_PSD_correct_1overf(self):
        return self.PSD_correct_1overf
    
    # Export parameters

    def set_output_directory(self, output_directory):
        self.output_directory = output_directory
    def get_output_directory(self):
        return self.output_directory

    def set_output_folder_name(self, output_folder_name):
        self.output_folder_name = output_folder_name
    def get_output_folder_name(self):
        return self.output_folder_name

    def set_save_data(self, save_data):
        self.save_data = save_data
    def get_save_data(self):
        return self.save_data

    def set_export_PNG(self, export_PNG):
        self.export_PNG = export_PNG
    def get_export_PNG(self):
        return self.export_PNG

    def set_export_PDF(self, export_PDF):
        self.export_PDF = export_PDF
    def get_export_PDF(self):
        return self.export_PDF

    def set_export_SVG(self, export_SVG):
        self.export_SVG = export_SVG
    def get_export_SVG(self):
        return self.export_SVG

    def set_export_EPS(self, export_EPS):
        self.export_EPS = export_EPS
    def get_export_EPS(self):
        return self.export_EPS

    def set_image_height(self, image_height):
        self.image_height = image_height
    def get_image_height(self):
        return self.image_height

    def set_image_width(self, image_width):
        self.image_width = image_width
    def get_image_width(self):
        return self.image_width

    def set_image_DPI(self, image_DPI):
        self.image_DPI = image_DPI
    def get_image_DPI(self):
        return self.image_DPI

    def set_color_palette(self, color_palette):
        self.color_palette = color_palette
    def get_color_palette(self):
        return self.color_palette

    def set_y_custom_axis(self, y_custom_axis):
        self.y_custom_axis = y_custom_axis
    def get_y_custom_axis(self):
        return self.y_custom_axis

    def set_yaxis_top(self, yaxis_top):
        self.yaxis_top = yaxis_top
    def get_yaxis_top(self):
        return self.yaxis_top

    def set_yaxis_bottom(self, yaxis_bottom):
        self.yaxis_bottom = yaxis_bottom
    def get_yaxis_bottom(self):
        return self.yaxis_bottom

    def set_alpha(self, alpha):
        self.alpha = alpha
    def get_alpha(self):
        return self.alpha

    # Additional Class Functions

    def add_to_nwb_list(self, nwb_list_item):
        self.nwb_list.append(nwb_list_item)
    def remove_from_nwb_list(self, nwb_list_item):
        self.nwb_list.remove(nwb_list_item)
    def get_nwb_list(self):
        return self.nwb_list
    
    def add_selected_channel(self, file, channel_name, data_index):
        if file not in self.selected_channels:
            self.selected_channels[file] = {}
        self.selected_channels[file][channel_name] = data_index
    
    def remove_selected_channel(self, file, channel_name):
        if file in self.selected_channels and channel_name in self.selected_channels[file]:
            del self.selected_channels[file][channel_name]
            if not self.selected_channels[file]:  # If no channels left
                del self.selected_channels[file]

    def get_selected_channels(self):
        return self.selected_channels

    def add_selected_rat(self, file, rat, rat_channels, data_columns):
        if file not in self.selected_rats:
            self.selected_rats[file] = {}
        self.selected_rats[file][rat] = {'channels' : rat_channels, 'data_columns' : data_columns }
    
    def remove_selected_rat(self, file, rat):
        if file in self.selected_rats and rat in self.selected_rats[file]:
            del self.selected_rats[file][rat]
            if not self.selected_rats[file]:  # If no rats left
                del self.selected_rats[file]

    def get_selected_rats(self):
        return self.selected_rats

    def get_all_PAC_parameters(self):

        lowfreq_override = self.lowfreq_override
        lowfreq_lowpass = self.lowfreq_lowpass
        lowfreq_highpass = self.lowfreq_highpass
        lowfreq_width = self.lowfreq_width
        lowfreq_step = self.lowfreq_step

        highfreq_override = self.highfreq_override
        highfreq_lowpass = self.highfreq_lowpass
        highfreq_highpass = self.highfreq_highpass
        highfreq_width = self.highfreq_width
        highfreq_step = self.highfreq_step

        high_pass_filter = self.high_pass_filter
        filter_notch_frequencies = self.filter_notch_frequencies
        rejection_threshold = self.rejection_threshold
        detrend_epochs = self.detrend_epochs
        apply_autofilter = self.apply_autofilter
        
        PAC_method = self.PAC_method
        surrogate_method = self.surrogate_method
        normalization_method = self.normalization_method
        dcomplex_method = self.dcomplex_method
        phase_cycles = self.phase_cycles
        amplitude_cycles = self.amplitude_cycles
        morlet_width = self.morlet_width
        KLD_or_HRPAC_bins = self.KLD_or_HRPAC_bins

        skip_PAC = self.skip_PAC
        inject_PAC = self.inject_PAC
        seed = self.seed
        surrogate_count = self.surrogate_count
        parallel_processes = self.parallel_processes
        data_length = self.data_length
        minimum_length = self.minimum_length
        epoch_length = self.epoch_length
        sampling_rate = self.sampling_rate
        downsample_rate = self.downsample_rate

        if lowfreq_override == 'Custom':
            phase_frequency = (lowfreq_lowpass, lowfreq_highpass, lowfreq_width, lowfreq_step)
        else:
            phase_frequency = lowfreq_override

        if highfreq_override == 'Custom':
            amplitude_frequency = (highfreq_lowpass, highfreq_highpass, highfreq_width, highfreq_step)
        else:
            amplitude_frequency = highfreq_override
        
        if PAC_method == 'Mean Vector Length':
            idpac_a = 1
        elif PAC_method == 'Modulation Index':
            idpac_a = 2
        elif PAC_method == 'Heights Ratio':
            idpac_a = 3
        elif PAC_method == 'ndPAC':
            idpac_a = 4
        elif PAC_method == 'Phase-Locking Value':
            idpac_a = 5
        elif PAC_method == 'Gaussian Copula PAC':
            idpac_a = 6

        if surrogate_method == 'No Surrogates':
            idpac_b = 0
        elif surrogate_method == 'Swap Phase/Ampl. Across Trials':
            idpac_b = 1
        elif surrogate_method == 'Swap Amplitude Time Blocks':
            idpac_b = 2
        elif surrogate_method == 'Time Lag':
            idpac_b = 3

        if normalization_method == 'No Normalization':
            idpac_c = 0
        elif normalization_method == 'Subtract Mean of Surrogtes':
            idpac_c = 1
        elif normalization_method == 'Divide Mean of Surrogates':
            idpac_c = 2
        elif normalization_method == 'Sub+Div Mean of Surrogates':
            idpac_c = 3
        elif normalization_method == 'Z-score':
            idpac_c = 4

        if dcomplex_method == 'Wavelet':
            dcomplex = 'wavelet'
        elif dcomplex_method == 'Hilbert':
            dcomplex = 'hilbert'

        idpac = (idpac_a, idpac_b, idpac_c)
        cycles = (phase_cycles, amplitude_cycles)

        if filter_notch_frequencies == 'None':
            filter_notch_frequencies = []
        elif filter_notch_frequencies == '[60]':
            filter_notch_frequencies = [60]
        elif filter_notch_frequencies == '[60, 120]':
            filter_notch_frequencies = [60, 120]
        elif filter_notch_frequencies == '[60, 120, 180]':
            filter_notch_frequencies = [60, 120, 180]
        elif filter_notch_frequencies == '[60, 120, 180, 240]':
            filter_notch_frequencies = [60, 120, 180, 240]
        elif filter_notch_frequencies == '[60, 120, 180, 240, 300]':
            filter_notch_frequencies = [60, 120, 180, 240, 300]

        return (phase_frequency, amplitude_frequency, idpac, dcomplex, cycles, morlet_width, KLD_or_HRPAC_bins, 
                sampling_rate, epoch_length, data_length, minimum_length, parallel_processes, surrogate_count, 
                filter_notch_frequencies, high_pass_filter, detrend_epochs, apply_autofilter, downsample_rate,
                seed, rejection_threshold, skip_PAC, inject_PAC)

    def get_export_parameters(self):

        output_folder_name = self.output_folder_name
        save_data = self.save_data
        export_PNG = self.export_PNG
        export_PDF = self.export_PDF
        export_SVG = self.export_SVG
        export_EPS = self.export_EPS
        image_height = self.image_height
        image_width = self.image_width
        image_DPI = self.image_DPI
        color_palette = self.color_palette
        y_custom_axis = self.y_custom_axis
        yaxis_top = self.yaxis_top
        yaxis_bottom = self.yaxis_bottom
        alpha = self.alpha

        if color_palette == 'Bright Colors':
            color_palette = BRIGHT_COLORS
        elif color_palette == 'Tableau Colors':
            color_palette = TABLEAU_COLORS
        elif color_palette == 'Colorblind Okabe Ito':
            color_palette = COLORBLIND_OKABE_ITO
        elif color_palette == 'Colorblind Cud':
            color_palette = COLORBLIND_CUD

        return (output_folder_name, 
                export_PNG, 
                export_PDF, 
                export_SVG, 
                export_EPS, 
                image_height, 
                image_width, 
                image_DPI, 
                color_palette, 
                y_custom_axis, 
                yaxis_top, 
                yaxis_bottom,
                alpha,
                save_data)

    def get_all_PSD_parameters(self):
        
        PSD_notch_frequencies = self.PSD_notch_frequencies
        PSD_high_pass_filter = self.PSD_high_pass_filter
        sampling_rate = self.sampling_rate
        output_directory = self.output_directory
        PSD_fmin = self.PSD_fmin
        PSD_fmax = self.PSD_fmax
        PSD_voltage_scale = self.PSD_voltage_scale
        PSD_FFT_resolution = self.PSD_FFT_resolution
        PSD_plot_raw = self.PSD_plot_raw
        PSD_plot_filtered = self.PSD_plot_filtered
        PSD_plot_grouping_method = self.PSD_plot_grouping_method
        PSD_correct_1overf = self.PSD_correct_1overf 

        if PSD_notch_frequencies == 'None':
            PSD_notch_frequencies = []
        elif PSD_notch_frequencies == '[60]':
            PSD_notch_frequencies = [60]
        elif PSD_notch_frequencies == '[60, 120]':
            PSD_notch_frequencies = [60, 120]
        elif PSD_notch_frequencies == '[60, 120, 180]':
            PSD_notch_frequencies = [60, 120, 180]
        elif PSD_notch_frequencies == '[60, 120, 180, 240]':
            PSD_notch_frequencies = [60, 120, 180, 240]
        elif PSD_notch_frequencies == '[60, 120, 180, 240, 300]':
            PSD_notch_frequencies = [60, 120, 180, 240, 300]

        if PSD_voltage_scale == "Volts":
            PSD_voltage_scale = 1
        elif PSD_voltage_scale == "Millivolts":
            PSD_voltage_scale = 1000
        elif PSD_voltage_scale == "Microvolts":
            PSD_voltage_scale = 1000000

        return (PSD_notch_frequencies, PSD_high_pass_filter, PSD_correct_1overf, sampling_rate, 
                output_directory, PSD_fmin, PSD_fmax, PSD_voltage_scale, PSD_FFT_resolution,
                PSD_plot_raw, PSD_plot_filtered, PSD_plot_grouping_method)

    def get_all_PAC_export_parameters(self):

        PAC_custom_colormap = self.PAC_custom_colormap
        PAC_vmin = self.PAC_vmin
        PAC_vmax = self.PAC_vmax
        PAC_comod_interpolation = self.PAC_comod_interpolation

        return (PAC_custom_colormap, PAC_vmin, PAC_vmax, PAC_comod_interpolation)

    @property
    def get_folder_path(self):
        return NWBFolder.get_folder_path()

    @property
    def get_file_list(self):
        return NWBFolder.get_file_list()

    def get_rats_in_file(self, file_name):
        return NWBFolder.get_rats_in_file(file_name)

    def get_day_from_file(file):
        return NWBFolder.get_day_from_file(file)

    def get_rat_and_probe_from_channel(nwb_file, channel):
        return NWBFolder.get_rat_and_probe_from_channel(nwb_file, channel)

    def get_channels_from_rat(self, file_name, rat_number):
        return NWBFolder.get_channels_from_rat(file_name, rat_number)

    def get_channel_names_from_rat(self, rat_number):
        return NWBFolder.get_channel_names_from_rat(rat_number)

In [7]:
class JSON:
    """
    A class to store information about JSON file analysis.
    """
    _session_ID_counter = 1
    
    def __init__(self,
                 custom_name = None,
                 plot_output_path = None,
                 session_ID_counter = 1):
        self.plot_output_path = ""
        self.export_status = False
        self.session_ID_counter = session_ID_counter

    def set_custom_name(self, custom_name):
        self.custom_name = custom_name
    def get_custom_name(self):
        return self.custom_name

    def set_plot_output_path(self, plot_output_path):
        self.plot_output_path = plot_output_path
    def get_plot_output_path(self):
        return self.plot_output_path

    def set_session_ID_counter(self, session_ID_counter):
        self.session_ID_counter = session_ID_counter
    def get_session_ID_counter(self):
        return self.session_ID_counter

In [8]:
class NWBFolder:
    '''
    Class for access to folder path to NWB files
    '''
    folder_path = None
    file_list = None

    # Static rat-to-probe mapping from Excel
    rat_probe_map = {
        1: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        2: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        3: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        4: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        5: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        6: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        7: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        8: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 9, 'RDHPC': 7, 'LHPCSCREW': 15, 'RHPCSCREW': 3, 'LPFCSCREW': 16, 'RPFCSCREW': 2},
        9: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        10: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        11: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 2, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        12: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        13: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        14: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        15: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        16: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        17: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        18: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        19: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        20: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 9, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        21: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 16, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        22: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        23: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        24: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 16, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        25: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        26: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        27: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        28: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        29: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        30: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        31: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 16, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        32: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        33: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 16, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        34: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        35: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        36: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        37: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        38: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        39: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        40: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        41: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        42: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        43: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        44: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        45: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        46: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 16, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        47: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        48: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        49: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        50: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        51: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        52: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        53: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 9, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        54: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 16, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        55: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        56: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        57: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        58: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        59: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
        60: {'RAMG': 4, 'RPFC': 1, 'RVHPC': 5, 'RDHPC': 8, 'LHPCSCREW': 12, 'RHPCSCREW': 9, 'LPFCSCREW': 13, 'RPFCSCREW': 2},
    }

    @classmethod
    def set_folder_path(cls, path):
        cls.folder_path = path

    @classmethod
    def get_folder_path(cls):
        return cls.folder_path

    @classmethod
    def set_file_list(cls, list):
        cls.file_list = list

    @classmethod
    def get_file_list(cls):
        return cls.file_list

    @staticmethod
    def get_rats_in_file(file_name):
        base = file_name.split('_')
        if base[0].startswith("DFPCT"):
            first_rat = base[0].split('DFPCT')[1]
            numbers = [first_rat] + [ b for b in base if b.isdigit()]
            return [int(n) for n in numbers if n.isdigit()]
        return []

    @staticmethod
    def get_day_from_file(file):
        split_name = file.split('_')
        for name_piece in split_name:
            if "Day" in name_piece:
                day = ''.join(filter(str.isdigit, name_piece))
                return day, name_piece
    
    @classmethod
    def get_rat_and_probe_from_channel(cls, file_name, global_channel):
        # Pull list of rats from working NWB file
        rat_list = cls.get_rats_in_file(file_name)

        # Convert channel string to int, i.e. 'CSC20' to 20.
        if global_channel.startswith("CSC"):
            global_channel = int(global_channel[3:])

        # Determine what set of 16 channels channel is in, return if invalid
        index = (global_channel - 1) // 16
        if index >= len(rat_list):
            return None, None

        # Shift based on bin, include probe i.e. 'RPFCSCREW' if found
        rat = rat_list[index]
        local_channel = (global_channel - 1) % 16 + 1
        for probe, ch in cls.rat_probe_map.get(rat, {}).items():
            if ch == local_channel:
                return rat, probe
        probe = 'NaN Probe'
        return rat, probe

    @classmethod
    def get_channels_from_rat(cls, file_name, rat_number):
        rat_list = cls.get_rats_in_file(file_name)
        if rat_number not in rat_list:
            return None
        rat_index = rat_list.index(rat_number)
        rat_channels = list(cls.rat_probe_map.get(rat_number, {}).values())

        if rat_channels:
            for i in range(len(rat_channels)):
                rat_channels[i] = rat_index * 16 + rat_channels[i]
        return rat_channels

    @classmethod
    def get_channel_names_from_rat(cls, rat_number):
        return list(cls.rat_probe_map.get(rat_number, {}).keys())


In [9]:
# Instantiate Class to store global info on project files at runtime
NWBFolder_Class = NWBFolder()

In [10]:
#######################################################################################################################################################
# Functions
#######################################################################################################################################################

In [11]:
##### SYNTHETIC Plotting Functions #####

In [12]:
def generate_custom_pac(f_pha=12, f_amp=75, sf=2000, duration=600, amp_strength=1.0, noise_std=0.5, coupling_strength=1.0):
    """
    Generate a synthetic PAC signal manually using an amplitude-modulated carrier.

    Returns:
        pac_signal (1D np.array)
        time vector (1D np.array)
    """
    n_samples = int(duration * sf)
    t = np.arange(n_samples) / sf

    # Phase-providing low-frequency signal
    phase_signal = 10 * np.sin(2 * np.pi * f_pha * t)

    # Amplitude modulation envelope using the low-freq phase
    envelope = 1 + coupling_strength * np.sin(2 * np.pi * f_pha * t)

    # High-frequency carrier
    hf_signal = amp_strength * envelope * np.sin(2 * np.pi * f_amp * t)

    # Additive white noise
    noise = np.random.normal(0, noise_std, size=n_samples)

    pac_signal = hf_signal + noise

    return pac_signal, t

In [13]:
def generate_sinusoidal_noise(scaling_factor, t):
    # Add additional noise to the signal for more authenticity
    power = scaling_factor * np.sin(2 * np.pi * 60 * t)         # 60 Hz interference
    power_harmonic = scaling_factor / 2 * np.sin(2 * np.pi * 120 * t)  # 120 Hz harmonic

    # Add low-frequency movement noise
    low_freq_noise = (
        scaling_factor * 2 * np.sin(2 * np.pi * 0.3 * t) +
        scaling_factor * np.sin(2 * np.pi * 0.7 * t)
    )
    
    # Add signal drift
    drift = 0.5 * np.cumsum(np.random.randn(len(t)))
    drift = drift / np.max(np.abs(drift))  # normalize
    drift *= scaling_factor  # scale

    # --- Add broad-spectrum sinusoidal noise to mimic 1/f ---
    broadband_noise = np.zeros_like(t)
    
    # Loop from 1 Hz to 200 Hz
    freqs = np.arange(1, 201) + np.random.uniform(-0.4, 0.4, 200)
    size = t.shape
    
    for f in freqs:
        for jitter in np.linspace(-0.1, 0.1, 15):  # 15 small neighboring freqs
            freq = f + jitter
            amp = scaling_factor / f
            phase = 2 * np.pi * np.random.rand()
            sine_wave = np.sin(2 * np.pi * freq * t + phase)
            white_noise = np.random.normal(0, 0.5, size)
            noisy_signal = sine_wave + white_noise
            broadband_noise += amp * noisy_signal
    
    # --- Combine everything ---
    sinusoidal_noise = power + power_harmonic + low_freq_noise + drift + broadband_noise
    return sinusoidal_noise

In [14]:
def generate_1overf_noise(n_samples, sf, exponent, amplitude, random_seed=None):
    if random_seed is not None:
        np.random.seed(random_seed)

    # Generate frequency vector
    freqs = np.fft.rfftfreq(n_samples, d=1. / sf)

    # Generate white noise spectrum
    spectrum = np.random.randn(len(freqs)) + 1j * np.random.randn(len(freqs))

    # Avoid division by 0 at DC
    freqs[0] = freqs[1]

    # 1/f scaling with proper amplitude shaping (magnitude = sqrt(power))
    scale = 1.0 / (freqs**(exponent / 2.0))

    # Shape the spectrum
    shaped_spectrum = spectrum * scale

    # Inverse FFT to time domain
    shaped_noise = np.fft.irfft(shaped_spectrum, n=n_samples)

    # Normalize and scale
    shaped_noise -= np.mean(shaped_noise)
    shaped_noise /= np.std(shaped_noise)
    shaped_noise *= amplitude

    return shaped_noise

In [15]:
def save_plot(t, signal, output_dir, filename, title, ylabel="Amplitude", xlabel="Time (s)", legend_label=None, show_plot=False, figsize=(16, 4)):
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(t, signal, label=legend_label if legend_label else None, color="black")
    #ax.set_ylim(top=0.6, bottom=-0.6)
    ax.set_xlabel(xlabel).set_fontproperties(termes_font_bold)
    ax.set_ylabel(ylabel).set_fontproperties(termes_font_bold)
    ax.set_title(title).set_fontproperties(termes_font_bold)
    ax.grid(visible=True, which='major')
    if legend_label:
        ax.legend()
    fig.tight_layout()
    
    filepath = os.path.join(output_dir, filename)
    fig.savefig(filepath + ".png", dpi=600)
    fig.savefig(filepath + ".pdf")
    if show_plot:
        plt.show()
    plt.close(fig)

In [16]:
def plot_minimal_traces(traces, labels, sampling_rate, output_path, figure_letter, name):
    """
    Plot multiple EEG/LFP traces in minimal publication style.

    Parameters:
        traces         : list of 1D numpy arrays (voltage signals)
        labels         : list of strings (region/channel names)
        sampling_rate  : Hz (float)
        output_path    : full path to save figure (e.g., 'Figure_A.png')
    """
    n_traces = len(traces)
    trace_length = len(traces[0])
    duration = trace_length / sampling_rate
    time = np.linspace(0, duration, trace_length)

    # Define spacing between traces
    amplitude = max(np.ptp(tr) for tr in traces)
    bottom_amplitude = np.ptp(traces[0])
    spacing = amplitude * 1.5
    
    # Prepare figure
    fig, ax = plt.subplots(figsize=(9, 6))  # Adjust size as needed
    for i, trace in enumerate(traces):
        offset = i * spacing
        ax.plot(time, trace + offset, color='black', linewidth=0.75)

        # Add label to the left
        txt = ax.text(-0.05 * duration, 
                offset, 
                labels[i], 
                va='center', 
                ha='right', 
                fontsize=10, 
                rotation='vertical')
        txt.set_fontproperties(termes_font_bold)
        txt.set_fontsize(16)

    bar_left = duration - (duration / 5) + (duration / 25)
    bar_right = duration + (duration / 25)
    bar_bottom = 0 - bottom_amplitude * 1
    bar_top = bottom_amplitude * 0.7
    x_center = (bar_right - bar_left) / 2
    y_center = (bar_top - bar_bottom) / 2
    # Position bars
    ax.plot([bar_left, bar_right], [bar_bottom, bar_bottom], color='black', linewidth=2)
    ax.plot([bar_right, bar_right], [bar_bottom, bar_top], color='black', linewidth=2)

    xscale = ax.text(bar_left + x_center, 
            bar_bottom + (bar_bottom / 10), 
            f'{round(bar_right - bar_left, 1)} s', 
            ha='center', 
            va='top', 
            fontsize=9)
    xscale.set_fontproperties(termes_font_bold)
    xscale.set_fontsize(16)
    
    yscale = ax.text(bar_right + (bar_right / 25), 
            bar_bottom + y_center, 
            f'{round(bar_top - bar_bottom, 1)} mV', 
            ha='right', 
            va='center', 
            rotation='vertical', 
            fontsize=9)
    yscale.set_fontproperties(termes_font_bold)
    yscale.set_fontsize(16)
    
    letter = ax.text(-0.05 * duration, 
                     len(traces) * (amplitude * 1.35), 
                     f'{figure_letter}', 
                     ha='center', 
                     va='top',  
                     fontsize=32)

    letter.set_fontproperties(termes_font_bold)
    letter.set_fontsize(32)

    # Remove all axis elements
    ax.set_axis_off()
    plt.tight_layout()
    
    filepath = os.path.join(output_path, f"synthetic_multiplot_{name}")
    fig.savefig(filepath + ".png", dpi=900, bbox_inches='tight')
    fig.savefig(filepath + ".pdf", bbox_inches='tight')
    plt.close(fig)

In [17]:
def compute_PAC_of_channels_SYNTHETIC(Analysis_class):
    """
    This function processes PAC on channels individually. It uses synthetic data to
    validate a correct process for detected PAC.
    """
    try:
        # Retrieve PAC settings
        (pha_freqs, amp_freqs, idpac, dcomplex, cycles, width, n_bins,
         sampling_rate, epoch_len, data_length, minimum_length, parallel_processes,
         surrogate_permutations, notch_freqs, high_pass_filter, detrend_epochs,
         apply_autofilter, downsample_rate, seed, rejection_threshold, skip_PAC, 
         inject_PAC)= Analysis_class.get_all_PAC_parameters()

        # Prepare IO paths
        timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
        folder_path = Analysis_class.get_folder_path
        nwb_list = Analysis_class.get_nwb_list()
        channels = Analysis_class.get_selected_channels()

        output_directory = Analysis_class.get_output_directory()
        output_folder_name = Analysis_class.get_output_folder_name()
        if "[Timestamp]" in output_folder_name:
            output_folder_name = output_folder_name.split("[Timestamp]")[0]
            output_folder_name = f"{output_folder_name}{timestamp} SYNTHETIC TEST"
        
        output_directory = os.path.join(output_directory, output_folder_name)
        os.makedirs(output_directory, exist_ok=True)

        #=========================================================================================================
        ### Make raw data ###

        # TensorPAC method
        f_pha = 12     # frequency phase for the coupling
        f_amp = 75     # frequency amplitude for the coupling
        n_epochs = 30  # number of trials
        epoch_len = 20 # length of epoch in seconds
        sf = 2000.     # sampling frequency
        n_times = int(epoch_len * sf)  # number of time points
        sampling_rate = sf
        tensorpac_epochs, time = pac_signals_wavelet(sf=sf, f_pha=f_pha, f_amp=f_amp, noise=1.0,
                                      n_epochs=n_epochs, n_times=n_times)
        raw_data = tensorpac_epochs.flatten()
        
        # Manual method
        #raw_data, _ = generate_custom_pac(f_pha, f_amp, sf, n_epochs * epoch_len)

        # Scale synthetic PAC signal to realistic amplitude (e.g., ±0.5 mV)
        target_peak_to_peak = 10  # mV total span (±5 mV)
        actual_ptp = np.ptp(raw_data)
        scale_factor = target_peak_to_peak / actual_ptp
        raw_data *= scale_factor * 0.015 # Adjust const to weaken coupling
        
        #=========================================================================================================
        ### Define variables for plotting ###

        # Generate plot of synthetic data
        duration_sec = raw_data.shape[0] / sampling_rate
        t = np.linspace(0, duration_sec, raw_data.shape[0], endpoint=False)
        
        # Define how many seconds to view
        view_seconds = 2
        wide_view_scaling = 5
        samples_to_plot = int(view_seconds * sampling_rate)
        
        # Generate time vector for plotting
        t_plot = t[:samples_to_plot]
        t_plot_wide = t[:samples_to_plot * wide_view_scaling]

        # Save current data
        raw_unfiltered = raw_data.copy()
        signal_01_plot = raw_unfiltered[:samples_to_plot]
        signal_01_plot_wide = raw_unfiltered[:samples_to_plot * wide_view_scaling]

        #=========================================================================================================
        ### Add noise to signal ###
        
        scaling_factor = 0.025

        # Compute noise using both custom methods
        aperiodic = generate_1overf_noise(n_samples=len(t), sf=sf, exponent=1, amplitude=2.0)
        # Add additional noise to the signal for more authenticity
        power = scaling_factor * np.sin(2 * np.pi * 60 * t)         # 60 Hz interference
        power_harmonic = scaling_factor * 0.6 * np.sin(2 * np.pi * 120 * t)  # 120 Hz harmonic
        power_harmonic_2 = scaling_factor * 0.3 * np.sin(2 * np.pi * 180 * t)  # 120 Hz harmonic
        
        #sinusoidal_noise = generate_sinusoidal_noise(scaling_factor, t)

        # Combine with coupled data
        raw_data += scaling_factor * aperiodic
        raw_data += power + power_harmonic + power_harmonic_2

        # Save current data
        raw_noisy = raw_data.copy()
        signal_02_plot = raw_noisy[:samples_to_plot]
        signal_02_plot_wide = raw_noisy[:samples_to_plot * wide_view_scaling]
        
        #=========================================================================================================
        ### Apply notch filter ###

        # Wrap raw data into MNE object for less manual processing steps
        info = mne.create_info(["Synthetic Data"], sampling_rate, ch_types="eeg")
        raw = mne.io.RawArray(raw_data[np.newaxis, :], info, verbose=False)

        # Apply notch filters
        notch_freqs = [60, 120, 180, 240, 300]
        raw.notch_filter(freqs=notch_freqs, picks="Synthetic Data", method='fir', verbose=False)

        # Save current data
        raw_after_notch = raw.get_data()
        signal_03_plot = raw_after_notch[0, :samples_to_plot]
        signal_03_plot_wide = raw_after_notch[0, :samples_to_plot * wide_view_scaling]
        
        #=========================================================================================================
        ### Apply bandpass, downsample, detrend ###

        # Apply bandpass filter
        raw.filter(l_freq=1.0, h_freq=200, picks="Synthetic Data", verbose=False)

        # Save current data
        raw_after_bandpass = raw.get_data()
        signal_04_plot = raw_after_bandpass[0, :samples_to_plot]
        signal_04_plot_wide = raw_after_bandpass[0, :samples_to_plot * wide_view_scaling]

        # Downsample data
        if downsample_rate:
            raw.resample(sfreq=downsample_rate)
        downsampled_sampling_rate = raw.info['sfreq']

        # Epoch data
        events = mne.make_fixed_length_events(raw, duration=epoch_len)
        epochs = mne.Epochs(raw, events, tmin=0, tmax=epoch_len, baseline=None, preload=True)
    
        # Detrend data
        if detrend_epochs:
            data = epochs.get_data()
            order = 1
            data_detrended = mne.filter.detrend(data, order=order)
            cleaned_epochs = data_detrended.squeeze(axis=1)
        else:
            cleaned_epochs = epochs.get_data().squeeze(axis=1)
        
        #=========================================================================================================
        ### Compute PSD using MNE welch method ###
        
        nperseg = int(sf * 8)
        scaled_raw_data = raw_noisy[np.newaxis, :] * 1000
        scaled_filter_data = raw_after_bandpass * 1000
        psd_raw, freqs_raw = psd_array_welch(scaled_raw_data, sfreq=sf, n_fft=nperseg, fmin=0, fmax=200, verbose=False)
        psd_filter, freqs_filter = psd_array_welch(scaled_filter_data, sfreq=sf, n_fft=nperseg, fmin=0, fmax=200, verbose=False)
        
        # Convert to dB
        psd_raw_db = 10 * np.log10(np.where(psd_raw[0] > 0, psd_raw[0], np.nan))
        psd_filter_db = 10 * np.log10(np.where(psd_filter[0] > 0, psd_filter[0], np.nan))

        # Plot
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(freqs_raw, psd_raw_db, color="red", label="Raw")
        ax.plot(freqs_filter, psd_filter_db, label="Filtered")
        title = ax.set_title("Synthetic Data PSD")
        xl = ax.set_xlabel("Frequency (Hz)")
        yl = ax.set_ylabel("Power Spectral Density (dB)")
        title.set_fontproperties(termes_font_bold)
        title.set_fontsize(24)
        xl.set_fontproperties(termes_font_bold)
        xl.set_fontsize(18)
        yl.set_fontproperties(termes_font_bold)
        yl.set_fontsize(18)
        ax.set_ylim(top=35, bottom=0)
        legend = ax.legend()
        for text in legend.get_texts():
            text.set_fontproperties(termes_font)
            text.set_fontsize(14)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontproperties(termes_font)
            label.set_fontsize(14)
        ax.grid(True)
        base_name = "raw_synthetic_PSD"
        fig_path = os.path.join(output_directory, base_name + ".png")
        plt.tight_layout()
        fig.savefig(fig_path, dpi=600) 
        fig_path = os.path.join(output_directory, base_name + ".pdf")
        fig.savefig(fig_path)
        plt.close(fig)

        #=========================================================================================================
        ### Plot Results ###

        # Short plots
        save_plot(
            t_plot, signal_01_plot, output_directory,
            filename="01_raw_synthetic_signal",
            title=f"Raw Synthetic Coupled Signal (First {view_seconds} sec)",
        )

        save_plot(
            t_plot, signal_02_plot, output_directory,
            filename="02_noisy_synthetic_signal",
            title=f"Noisy Synthetic Coupled Signal (First {view_seconds} sec)",
        )

        save_plot(
            t_plot, signal_03_plot, output_directory,
            filename="03_notch_filter_synthetic_signal",
            title=f"Notch Filter Synthetic Coupled Signal (First {view_seconds} sec)",
        )

        save_plot(
            t_plot, signal_04_plot, output_directory,
            filename="04_bandpass_filter_synthetic_signal",
            title=f"Bandpass Filter Synthetic Coupled Signal (First {view_seconds} sec)",
        )

        traces = [signal_04_plot, signal_03_plot, signal_02_plot, signal_01_plot]
        labels = ['Bandpass', 'Notch', 'Noisy', 'Raw' ]
        plot_minimal_traces(traces, labels, sf, output_directory, "A", "short")

        # Wide plots
        save_plot(
            t_plot_wide, signal_01_plot_wide, output_directory,
            filename="wide_01_raw_synthetic_signal",
            title=f"Raw Synthetic Coupled Signal (First {view_seconds * wide_view_scaling} sec)",
        )

        save_plot(
            t_plot_wide, signal_02_plot_wide, output_directory,
            filename="wide_02_noisy_synthetic_signal",
            title=f"Noisy Synthetic Coupled Signal (First {view_seconds * wide_view_scaling} sec)",
        )

        save_plot(
            t_plot_wide, signal_03_plot_wide, output_directory,
            filename="wide_03_notch_filter_synthetic_signal",
            title=f"Notch Filter Synthetic Coupled Signal (First {view_seconds * wide_view_scaling} sec)",
        )

        save_plot(
            t_plot_wide, signal_04_plot_wide, output_directory,
            filename="wide_04_bandpass_filter_synthetic_signal",
            title=f"Bandpass Filter Synthetic Coupled Signal (First {view_seconds * wide_view_scaling} sec)",
        )

        traces = [signal_04_plot_wide, signal_03_plot_wide, signal_02_plot_wide, signal_01_plot_wide]
        labels = ['Bandpass', 'Notch', 'Noisy', 'Raw' ]
        plot_minimal_traces(traces, labels, sf, output_directory, "B", "wide")
        

        #=========================================================================================================
        def compute_PAC(data, name, sampling_rate, minimal_plots=False):
            # Initialize TensorPAC object
            PAC_method, surg_method, norm_method = idpac
            p = Pac(idpac=idpac, f_pha=pha_freqs, f_amp=amp_freqs, dcomplex=dcomplex, cycle=cycles, width=width, n_bins=n_bins, verbose=False)
            
            phases     = p.filter(sampling_rate, data, ftype='phase',     n_jobs=parallel_processes)
            amplitudes = p.filter(sampling_rate, data, ftype='amplitude', n_jobs=parallel_processes)
            pac_map    = p.fit(phases, amplitudes, n_perm=surrogate_permutations, n_jobs=parallel_processes)
            del phases, amplitudes
    
            # Generate comodulogram
            pvals = p.infer_pvalues(p=0.05)
            comod = pac_map.mean(-1)
        
            # Plot Comodulogram
            #plt.figure(figsize=(6,5)) if minimal_plots else plt.figure(figsize=(6,5))
            fig = plt.figure(figsize=(5.236,5), constrained_layout=True)
            fig.set_constrained_layout_pads(w_pad=0.01, h_pad=0.01, wspace=0.01, hspace=0.01)

            # Hide axis labels in minimal mode; keep ticks + title
            xlabel = None if minimal_plots else "Frequency for Phase (Hz)"
            ylabel = None if minimal_plots else "Frequency for Amplitude (Hz)"
            
            p.comodulogram(comod,
                           title=f"{name}",
                           xlabel=xlabel,
                           ylabel=ylabel,
                           fz_labels=10,
                           cmap='viridis', 
                           colorbar=False,
                           vmin=-1.96,
                           vmax=1.96,
                           interp=None)
    
            # Grab the axis and apply fonts
            ax = plt.gca()

            im = ax.images[0] if ax.images else ax.collections[0]  # the heatmap
            if not minimal_plots:
                cbar = ax.figure.colorbar(
                    im, ax=ax,
                    orientation="horizontal",
                    location="bottom",   # Matplotlib ≥3.6; omit if older
                    pad=0.02,            # distance from axes
                    fraction=0.045,      # size of colorbar relative to axes
                    aspect=40            # length/thickness ratio
                )
                # Label font
                cbar.set_label("Z-Score", fontproperties=termes_font_bold, fontsize=24)
                
                # Tick label font/size
                if cbar.orientation == "horizontal":
                    cbar.ax.xaxis.labelpad = 2    # tighten spacing if desired
                    for lab in cbar.ax.get_xticklabels():
                        lab.set_fontproperties(termes_font)
                        lab.set_fontsize(12)
                else:
                    cbar.ax.yaxis.labelpad = 2
                    for lab in cbar.ax.get_yticklabels():
                        lab.set_fontproperties(termes_font)
                        lab.set_fontsize(12)
            
            # Title and axis labels
            ax.set_title(ax.get_title(), fontproperties=termes_font_bold, fontsize=24)
            if not minimal_plots:
                ax.set_xlabel(ax.get_xlabel(), fontproperties=termes_font_bold, fontsize=16)
                ax.set_ylabel(ax.get_ylabel(), fontproperties=termes_font_bold, fontsize=16)
            
            # Tick labels
            for tick in ax.get_xticklabels() + ax.get_yticklabels():
                tick.set_fontproperties(termes_font)
                tick.set_fontsize(12)
            
            PAC_method = Analysis_class.get_PAC_method()
            norm_method = Analysis_class.get_normalization_method()
    
            # Set Colormap name based on PAC settings
            if PAC_method == 'Mean Vector Length':
                PAC_method = 'MVL'
            elif PAC_method == 'Modulation Index':
                PAC_method = 'MI'
            elif PAC_method == 'Heights Ratio':
                PAC_method = 'HR'
            elif PAC_method == 'ndPAC':
                pass
            elif PAC_method == 'Phase-Locking Value':
                PAC_method = 'PLV'
            elif PAC_method == 'Gaussian Copula PAC':
                PAC_method = 'GC PAC'
    
            if norm_method == 'No Normalization':
                norm_method = ''
            elif norm_method == 'Subtract Mean of Surrogtes':
                norm_method = 'Sub Mean of Surg.'
            elif norm_method == 'Divide Mean of Surrogates':
                norm_method = 'Div Mean of Surg.'
            elif norm_method == 'Sub+Div Mean of Surrogates':
                norm_method = 'Sub+Div Mean of Surg.'
            elif norm_method == 'Z-score':
                norm_method = 'Z-Score'
            
            # Colorbar (if shown)
            if ax.images and not minimal_plots:
                cbar = ax.figure.axes[-1]  # Last axis is usually colorbar
                if norm_method == 'Z-Score':
                    cbar.set_ylabel(norm_method, fontproperties=termes_font_bold, fontsize=16)
                    #cbar.set_fontsize(16)
                else:
                    cbar.set_ylabel(f"{PAC_method} ({norm_method})", fontproperties=termes_font_bold, fontsize=16)
                    #cbar.set_fontsize(16)
                for label in cbar.get_yticklabels():
                    label.set_fontproperties(termes_font)
                    label.set_fontsize(12)
    
            # Save comodulogram
            base_name = f"{name}_PAC"
            fig_path = os.path.join(output_directory, base_name + ".pdf")
            plt.savefig(fig_path, bbox_inches="tight", pad_inches=0.1)
            fig_path = os.path.join(output_directory, base_name + ".png")
            plt.savefig(fig_path, dpi=600, bbox_inches="tight", pad_inches=0.1)
            

        # Build raw arrays
        info = mne.create_info(["Synthetic Data"], sampling_rate, ch_types="eeg")
        mne_raw_noisy = mne.io.RawArray(raw_noisy[np.newaxis, :], info, verbose=False)
        mne_filter_notch = mne.io.RawArray(raw_after_notch, info, verbose=False)
        mne_filter_bandpass = mne.io.RawArray(raw_after_bandpass, info, verbose=False)

        # Create events per object
        ev_noisy = mne.make_fixed_length_events(mne_raw_noisy, duration=epoch_len)
        ev_notch = mne.make_fixed_length_events(mne_filter_notch, duration=epoch_len)
        ev_bp = mne.make_fixed_length_events(mne_filter_bandpass, duration=epoch_len)
                
        epochs_raw_noisy = mne.Epochs(mne_raw_noisy, ev_noisy, tmin=0, tmax=epoch_len, baseline=None, preload=True)
        epochs_filter_notch = mne.Epochs(mne_filter_notch, ev_notch, tmin=0, tmax=epoch_len, baseline=None, preload=True)
        epochs_filter_bandpass = mne.Epochs(mne_filter_bandpass, ev_bp, tmin=0, tmax=epoch_len, baseline=None, preload=True)
        
        compute_PAC(epochs_raw_noisy.get_data().squeeze(axis=1), "Raw Noisy Signal", sampling_rate, minimal_plots=True)
        compute_PAC(epochs_filter_notch.get_data().squeeze(axis=1), "Add Notch Filter", sampling_rate, minimal_plots=True)
        compute_PAC(epochs_filter_bandpass.get_data().squeeze(axis=1), "Add Bandpass Filter", sampling_rate, minimal_plots=True)
        compute_PAC(cleaned_epochs, "Add Downsample & Detrend", downsampled_sampling_rate, minimal_plots=True)
        
    except Exception as e:
            logging.error(f"Error in compute_PAC_of_channels: {e}", exc_info=True)
            show_popup(error_string)

In [18]:
##### JSON Plotting Functions #####

In [19]:
# PSD Plotting Functions

In [20]:
def extract_rows(d):
    rows = []
    for key, item in d["Data"].items():
        info   = item["Info"]
        rat    = info["Rat"]
        day    = info["Day"]                    # e.g. "Day0"
        sess   = info["Session"]                # e.g. "0001"
        probe  = info["Probe"]                  # e.g. "RAMG"

        bio = item["BiomarkerSummary"]
        for metric, val in bio.items():         # Broadband_dB, ThetaDelta_dB, …
            if isinstance(val, dict):           # handles {"Raw": .., "Filtered": ..}
                for subk, v in val.items():
                    rows.append([rat, day, sess, probe,
                                 f"{metric}_{subk}",   # → Broadband_dB_Raw
                                 float(v)])
            else:                               # single float
                rows.append([rat, day, sess, probe,
                             metric, float(val)])
    return rows

In [21]:
def load_case_json(json_path):
    with open(json_path, "r") as f:
        raw = json.load(f)
    rows = extract_rows(raw)
    df = pd.DataFrame(rows, columns=[
        "Rat", "Day", "Session", "Probe", "Metric", "Value"
    ])

    # numeric day for sorting (Day0→0, Day1→1, etc.)
    df["DayNum"] = df["Day"].str.extract(r"Day(\d+)").astype(int)
    return df

In [22]:
def plot_trajectories(json_path,
                      metrics=("Broadband_dB_Raw", "ThetaDelta_dB_Raw"),
                      probes=None,
                      agg_fn="mean",
                      save=False,
                      output_dir=None,
                      palette=None):

    # ---- 1) hard-coded colors (edit here) --------------------------
    SERIES_COLOR = "#111111"                   # one color for all figures
    # Or metric-specific:
    METRIC_COLORS = {
        "Broadband_dB_Raw":   "#111111",
        "ThetaDelta_dB_Raw":  "#1f78b4",
        "Aperiodic_Exponent": "#e31a1c",
    }
        
    df = load_case_json(json_path)

    # ---- Day0 session split and even-spacing keys ----------------
    # D0 sessions become D0:0..D0:3 using last digit; others become Day N
    def time_key(row):
        d = row["DayNum"]
        if d == 0:
            return f"D0:{int(str(row['Session'])[-1])}"
        return f"Day {d}"

    df["TimeKey"] = df.apply(time_key, axis=1)

    # Desired order (only keep keys that exist in this file)
    desired_order = ["D0:0", "D0:1", "D0:2", "D0:3",
                     "Day 1", "Day 3", "Day 7", "Day 14"]
    present = [k for k in desired_order if k in df["TimeKey"].unique()]
    df["TimeKey"] = pd.Categorical(df["TimeKey"], categories=present, ordered=True)

    if probes is not None:
        df = df[df["Probe"].isin(probes)]

    # ---- aggregate: mean ± SEM over rats/probes available --------
    agg = df.groupby(["TimeKey", "Metric"]).agg(
        mean=("Value", "mean"),
        sem =("Value", lambda x: x.std(ddof=1) / np.sqrt(len(x)))
    ).reset_index()

    # ---- output folder (once per call) ---------------------------
    if save and output_dir:
        output_dir = Path(output_dir)
        timestamp      = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
        session_folder = output_dir / f"PSD JSON Statistics {timestamp}"
        session_folder.mkdir(parents=True, exist_ok=True)
        stats_path = session_folder / f"{Path(json_path).stem}_trajectory_stats.tsv"
        stats_rows = []            # collect rows, write at the end

    # ---- plotting ------------------------------------------------
    for metric in metrics:
        # Messy exception for the TDR since it should be linear not logarithmic
        if metric in ("ThetaDelta_dB_Raw", "ThetaDelta_dB_Filtered"):
            """
            # Convert EACH value from dB → linear ratio, then aggregate
            sub_df = df[df["Metric"] == metric].copy()
            sub_df["Value_lin"] = 10 ** (sub_df["Value"] / 10.0)
    
            g = sub_df.groupby("TimeKey")["Value_lin"]
            sub = g.mean().reset_index(name="mean")
            sub["sem"] = g.std(ddof=1) / np.sqrt(g.count())
    
            sub = sub.sort_values("TimeKey")
            x_labels = sub["TimeKey"].astype(str).tolist()
            x_pos    = np.arange(len(x_labels))
            y        = sub["mean"].to_numpy()
            e        = sub["sem"].to_numpy()
            """

            sub_df = df[df["Metric"] == metric].copy()
            sub_df["Value_lin"] = 10 ** (sub_df["Value"] / 10.0)
            
            grp = (
                sub_df.groupby("TimeKey")["Value_lin"]
                .agg(mean="mean", std="std", n="count")
            )
            
            grp["sem"] = grp["std"] / np.sqrt(grp["n"])
            grp["sem"] = grp["sem"].fillna(0.0)     # in case n==1 → std NaN
            
            sub = grp.reset_index().sort_values("TimeKey")
            
            x_labels = sub["TimeKey"].astype(str).tolist()
            x_pos    = np.arange(len(x_labels))
            y        = sub["mean"].to_numpy(dtype=float)
            e        = sub["sem"].to_numpy(dtype=float)

            # record rows for the stats table
            for _, row in sub.iterrows():
                stats_rows.append({
                    "Metric"  : metric.replace("_", " "),
                    "TimeKey" : row["TimeKey"],
                    "Mean"    : row["mean"],
                    "SEM"     : row["sem"],
                    "N"       : row["n"] if "n" in row else np.nan   # optional
                })
    
            fig = plt.figure(figsize=(10, 5), constrained_layout=True)
            fig.set_constrained_layout_pads(w_pad=0.05, h_pad=0.05, wspace=0.05, hspace=0.05)
            color = METRIC_COLORS.get(metric, SERIES_COLOR)
            plt.errorbar(x_pos, y, yerr=e, marker="o", linestyle="-",
                         color=color, capsize=4)

            title = f"Trajectory of {metric}"
            ylabel = metric.replace("_", " ")
            if metric == "ThetaDelta_dB_Raw":
                title = "PSD Theta-Delta Ratio (4-12 Hz, 1-4 Hz)"
                ylabel = "Theta-Delta Ratio (Unitless)"
    
            fig = plt.gcf(); ax = plt.gca()
            xl = ax.set_xlabel("Time-point", fontproperties=termes_font_bold)
            yl = ax.set_ylabel(f"{ylabel}", fontproperties=termes_font_bold)
            t = ax.set_title(f"{title}", fontproperties=termes_font_bold)
            xl.set_fontsize(20)
            yl.set_fontsize(20)
            t.set_fontsize(26)

            # Hard-coded fix to label actualy shown on x-axis, donn't want to disturb df access
            x_map = {"D0:0": "Control Case", "D0:1": "DFP Inj.", "D0:2": "MDZ Interv."}
            x_labels_shown = [x_map.get(lbl, lbl) for lbl in x_labels]

            ax.set_xticks(x_pos, x_labels_shown)
            for lab in ax.get_xticklabels() + ax.get_yticklabels():
                lab.set_fontproperties(termes_font)
                lab.set_fontsize(16)
            ax.grid(True, alpha=.3)
            # Optional: clamp y to start at 0
            ax.set_ylim(bottom=0)
    
            fig.tight_layout()
    
            if save:
                png_name = session_folder / f"{Path(json_path).stem}_{metric}_LINEAR.png"
                pdf_name = session_folder / f"{Path(json_path).stem}_{metric}_LINEAR.pdf"
                plt.savefig(png_name, dpi=600, bbox_inches="tight", pad_inches=0.05)
                plt.savefig(pdf_name, bbox_inches="tight", pad_inches=0.05)
                plt.close(fig)
            else:
                plt.show()
                plt.close(fig)
    
            continue  # skip the default dB path below
        
        sub = agg[agg["Metric"] == metric].sort_values("TimeKey")
        x_labels = sub["TimeKey"].astype(str).tolist()
        x_pos    = np.arange(len(x_labels))            # even spacing
        y        = sub["mean"].to_numpy()
        e        = sub["sem"].to_numpy()

        # record rows for the stats table
        for _, row in sub.iterrows():
            stats_rows.append({
                "Metric"  : metric.replace("_", " "),
                "TimeKey" : row["TimeKey"],
                "Mean"    : row["mean"],
                "SEM"     : row["sem"],
                "N"       : row["n"] if "n" in row else np.nan   # optional
            })

        fig = plt.figure(figsize=(10, 5), constrained_layout=True)
        fig.set_constrained_layout_pads(w_pad=0.05, h_pad=0.05, wspace=0.05, hspace=0.05)
        color = METRIC_COLORS.get(metric, SERIES_COLOR)
        plt.errorbar(x_pos, y, yerr=e, marker="o", linestyle="-",
                     color=color, capsize=4)

        fig = plt.gcf()
        ax  = plt.gca()

        # axis labels & title (your thesis fonts)
        title = f"Trajectory of {metric}"
        ylabel = metric.replace("_", " ")
        if metric == "Broadband_dB_Raw":
            title = "PSD Broadband Power (1-200Hz)"
            ylabel = "Broadband (dB µV²/Hz)"
        elif metric == "ThetaDelta_dB_Raw":
            title = "PSD Theta-Delta Ratio (1-200Hz)"
            ylabel = "Theta-Delta Ratio (Unitless)"
        elif metric == "Aperiodic_Exponent":
            title = "PSD Aperiodic Exponent (2-40Hz)"
            ylabel = "Aperiodic Exponent (Unitless)"
        
        xl = ax.set_xlabel("Time-point", fontproperties=termes_font_bold)
        yl = ax.set_ylabel(ylabel, fontproperties=termes_font_bold)
        t = ax.set_title(f"{title}", fontproperties=termes_font_bold)
        xl.set_fontsize(20)
        yl.set_fontsize(20)
        t.set_fontsize(26)

        # Hard-coded fix to label actualy shown on x-axis, donn't want to disturb df access
        x_map = {"D0:0": "Control Case", "D0:1": "DFP Inj.", "D0:2": "MDZ Interv."}
        x_labels_shown = [x_map.get(lbl, lbl) for lbl in x_labels]
        
        # evenly spaced categorical ticks
        ax.set_xticks(x_pos, x_labels_shown)

        # tick label fonts
        for lab in ax.get_xticklabels() + ax.get_yticklabels():
            lab.set_fontproperties(termes_font)
            lab.set_fontsize(16)

        ax.grid(True, alpha=.3)
        fig.tight_layout()

        # ------------------------------------------------------------------
        # write stats table
        if save and stats_rows:
            import csv
            with stats_path.open("w", newline="") as f:
                writer = csv.DictWriter(f,
                                        fieldnames=["Metric", "TimeKey", "Mean", "SEM", "N"],
                                        delimiter="\t")
                writer.writeheader()
                writer.writerows(stats_rows)

        if save:
            png_name = session_folder / f"{Path(json_path).stem}_{metric}.png"
            pdf_name = session_folder / f"{Path(json_path).stem}_{metric}.pdf"
            plt.savefig(png_name, dpi=600, bbox_inches="tight", pad_inches=0.05)
            plt.savefig(pdf_name, bbox_inches="tight", pad_inches=0.05)
            plt.close(fig)
        else:
            plt.show()

In [23]:
# PAC Plotting Functions

In [24]:
def load_pac_summary_json(json_path):
    """
    Load PAC summary data from a JSON file.

    Returns:
        metadata (dict): Dictionary containing session-level metadata.
        pac_values (list of dict): List of entries with rat, day, channel, probe, and average_pac.
    """
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)

        metadata = data.get("Metadata", {})
        pac_values = data.get("Data", {})
        return metadata, pac_values

    except Exception as e:
        logging.error(f"Error in load_pac_summary_json: {e}", exc_info=True)
        return {}, []

In [25]:
def create_JSON_window(JSON_class):
    """
    Create an analysis window that displays JSON file data and plots average PAC scores.
    """
    try:
        dpg.add_button(label="Import JSON File",
                       parent="JSON_buttons_group",
                       callback=import_JSON_file_callback,
                       user_data=(JSON_class, "JSON_meta_window", "JSON_data_window", "JSON_plot_window"))

        with dpg.group(tag="JSON_window_group",
                           parent="JSON_child_window",
                           horizontal=True):

            with dpg.child_window(tag="JSON_meta_window", 
                                  border=True, 
                                  width=300, 
                                  autosize_y=True, 
                                  horizontal_scrollbar=True, 
                                  no_scrollbar=False, 
                                  menubar=True, 
                                  parent="JSON_window_group"):
                with dpg.menu_bar():
                        dpg.add_menu(label="JSON File Metadata")

            with dpg.child_window(tag="JSON_data_window", 
                                  border=True, 
                                  width=300, 
                                  autosize_y=True, 
                                  horizontal_scrollbar=True, 
                                  no_scrollbar=False, 
                                  menubar=True, 
                                  parent="JSON_window_group"):
                with dpg.menu_bar():
                        dpg.add_menu(label="JSON Data")

            with dpg.child_window(tag="JSON_plot_window", 
                                  border=True, 
                                  autosize_x=True,
                                  autosize_y=True, 
                                  parent="JSON_window_group"):
                with dpg.plot(label="Plot", tag="JSON_plot", width=-1, height=-1,no_title=True):        
                    with dpg.plot_axis(dpg.mvXAxis, tag="JSON_x_axis"):
                        pass 
                    with dpg.plot_axis(dpg.mvYAxis, tag="JSON_y_axis"):
                        pass
                        
    except Exception as e:
        logging.error(f"Error in create_JSON_window: {e}", exc_info=True)
        show_popup(error_string)

In [26]:
##### PAC Sanity Check Functions #####

In [27]:
def _bandpass_mne(x, fs, l_freq, h_freq):
    # Zero-phase FIR like your pipeline; same defaults each time
    return mne.filter.filter_data(
        x, sfreq=fs, l_freq=l_freq, h_freq=h_freq,
        method='fir', phase='zero', fir_window='hamming', verbose=False
    )

In [28]:
def inject_spike_into_raw_data(
    x, fs,
    phase_band=(6, 10),          # driver (e.g., theta)
    amp_band=(40, 80),           # modulated gamma
    depth=0.3,                   # modulation depth in [0, 1)
    target_hf_snr_db=None,       # added HF power (in amp_band) relative to *existing* HF power
    hf_scale=None,               # alternative: scale of injected HF vs sd(x) if SNR not given
    burst_duty=1.0,              # fraction of time with PAC on (0..1]
    burst_len_s=2.0,             # length of each burst if duty<1
    burst_random=True,           # randomize burst placement
    fade_s=0.05,                 # taper edges of bursts
    random_state=None,
):
    """
    Injects PAC into real data:
      - Envelope e(t) = 1 + depth*cos(phi(t)), where phi(t) is low-freq phase from `phase_band`
      - Carrier: band-passed noise in `amp_band`
      - Optional bursts via `burst_duty`, `burst_len_s`
      - Optional target SNR in `amp_band`

    Returns y (x with PAC added) and a dict with details + the injected component y_add.
    """
    rng = np.random.default_rng(random_state)
    x = np.asarray(x, dtype=float)
    n = x.size

    # 1) Low-frequency phase from the *real* signal
    lf = _bandpass_mne(x, fs, phase_band[0], phase_band[1])
    phi = np.angle(hilbert(lf))  # [-pi, pi]

    # 2) High-frequency carrier (band-limited noise)
    carrier = rng.normal(size=n)
    carrier = _bandpass_mne(carrier, fs, amp_band[0], amp_band[1])
    # normalize carrier variance
    carrier /= np.std(carrier) + 1e-12

    # 3) Phase-locked envelope (strictly positive)
    depth = float(np.clip(depth, 0.0, 0.99))
    env = 1.0 + depth * np.cos(phi)

    # 4) Optional bursts mask
    mask = np.ones(n, dtype=float)
    if burst_duty < 1.0:
        burst_len = max(1, int(round(burst_len_s * fs)))
        total_on = int(round(burst_duty * n))
        n_bursts = max(1, total_on // burst_len)
        mask[:] = 0.0
        indices = []
        if burst_random:
            # choose non-overlapping starts as best as possible
            max_start = max(1, n - burst_len - 1)
            starts = rng.choice(max_start, size=n_bursts, replace=False)
            starts.sort()
            indices = [slice(int(s), int(s) + burst_len) for s in starts]
        else:
            # evenly spaced
            step = max(1, (n - burst_len) // n_bursts)
            starts = np.arange(0, n_bursts * step, step)
            indices = [slice(int(s), int(s) + burst_len) for s in starts]

        for sl in indices:
            mask[sl] = 1.0

        # fade edges for continuity
        fade = int(max(1, round(fade_s * fs)))
        if fade > 1:
            win = get_window(("tukey", 1.0), 2 * fade)
            up, down = win[:fade], win[fade:]
            for sl in indices:
                a, b = sl.start, min(sl.stop, n)
                # fade in
                a2 = a
                b2 = min(a + fade, b)
                mask[a2:b2] *= up[: b2 - a2]
                # fade out
                a3 = max(a, b - fade)
                b3 = b
                mask[a3:b3] *= down[fade - (b3 - a3):]

    # 5) Form the injected HF component
    y_add = env * carrier * mask

    # 6) Scale to target SNR in amp_band (preferred), else hf_scale
    def _band_power(sig):
        sig_hf = _bandpass_mne(sig, fs, amp_band[0], amp_band[1])
        _, pxx = _safe_welch(sig_hf, fs)
        return float(np.trapz(pxx))  # rough but consistent for matching

    if target_hf_snr_db is not None:
        Px = _band_power(x)
        Py = _band_power(y_add)
        if Py < 1e-18:
            scale = 0.0
        else:
            target_ratio = 10 ** (target_hf_snr_db / 10.0)
            scale = np.sqrt((target_ratio * Px) / (Py + 1e-18))
        y_add *= scale
        scale_info = {"mode": "snr_db", "target_hf_snr_db": float(target_hf_snr_db), "scale": float(scale)}
    elif hf_scale is not None:
        y_add *= float(hf_scale) * (np.std(x) + 1e-12)
        scale_info = {"mode": "hf_scale", "hf_scale": float(hf_scale)}
    else:
        # default: roughly match HF band power to 25% of existing
        Px = _band_power(x)
        Py = _band_power(y_add)
        scale = 0.5 if Py < 1e-18 else np.sqrt((0.25 * Px) / (Py + 1e-18))
        y_add *= scale
        scale_info = {"mode": "auto_0.25_bandpower", "scale": float(scale)}

    y = x + y_add

    info = {
        "phase_band": tuple(phase_band),
        "amp_band": tuple(amp_band),
        "depth": float(depth),
        "burst_duty": float(burst_duty),
        "burst_len_s": float(burst_len_s),
        "fade_s": float(fade_s),
        "random_state": int(random_state) if random_state is not None else None,
        **scale_info,
    }
    return y, info, y_add


In [29]:
# Functions run during PAC Analysis

In [30]:
def _safe_welch(x, fs, nperseg=None, noverlap=None):
    try:
        n = len(x)
        if nperseg is None:
            # ~4–8 seconds is a decent default; clamp by length
            nperseg = int(min(n, max(fs * 4, 256)))
        if noverlap is None:
            noverlap = int(0.5 * nperseg)
        f, pxx = welch(x, fs=fs, nperseg=nperseg, noverlap=noverlap, detrend='constant')
        return f, pxx
    except Exception as e:
        import logging
        logging.error(f"Error in _safe_welch: {e}", exc_info=True)
        show_popup(error_string)

In [31]:
def phase_band_snr(
    x, fs, band=(4, 10), neighbor_bw=2.0, exclude_bw=0.5,
    remove_aperiodic=True, fit_range=(2, 200)
):
    """
    Estimate SNR (dB) of the dominant peak inside `band` vs adjacent flanks.
    If remove_aperiodic=True, fit 1/f on log-log outside the band and subtract.
    Returns (snr_db, info_dict).
    """
    try:
        x = np.asarray(x, dtype=float)
        f, pxx = _safe_welch(x, fs)
    
        # Optional 1/f correction
        pxx_work = pxx.copy()
        if remove_aperiodic:
            vr = (f >= fit_range[0]) & (f <= fit_range[1]) & ~((f >= band[0]) & (f <= band[1]))
            vf = f[vr]
            yp = np.log10(pxx[vr] + np.finfo(float).eps)
            xp = np.log10(vf + np.finfo(float).eps)
            # robust-ish linear fit
            A = np.vstack([np.ones_like(xp), xp]).T
            coeff, *_ = np.linalg.lstsq(A, yp, rcond=None)  # y = a + b*x
            a, b = coeff
            pxx_fit = 10 ** (a + b * np.log10(f + 1e-12))
            pxx_work = pxx / (pxx_fit + 1e-18)
    
        # Peak in band
        in_band = (f >= band[0]) & (f <= band[1])
        if not np.any(in_band):
            return np.nan, {"reason": "band outside PSD range"}
    
        idx_peak = np.argmax(pxx_work[in_band])
        f_band = f[in_band]
        p_band = pxx_work[in_band]
        f0 = f_band[idx_peak]
    
        # Define small windows around peak and its neighbors
        pk_win = (f >= max(f0 - exclude_bw, f[1])) & (f <= f0 + exclude_bw)
        lo_win = (f >= max(f0 - neighbor_bw - exclude_bw, f[1])) & (f < f0 - exclude_bw)
        hi_win = (f > f0 + exclude_bw) & (f <= f0 + neighbor_bw)
    
        if not (np.any(lo_win) and np.any(hi_win) and np.any(pk_win)):
            return np.nan, {"reason": "insufficient bins around peak", "f0": float(f0)}
    
        p_peak = np.mean(pxx_work[pk_win])
        p_flank = np.mean(np.concatenate([pxx_work[lo_win], pxx_work[hi_win]]))
    
        snr_db = 10.0 * np.log10((p_peak + 1e-18) / (p_flank + 1e-18))
        info = {
            "SNR dB" : snr_db,
            "f0": float(f0),
            "p_peak": float(p_peak),
            "p_flank": float(p_flank),
            "aperiodic_removed": bool(remove_aperiodic),
            "band": tuple(band),
            "neighbor_bw": float(neighbor_bw),
            "exclude_bw": float(exclude_bw),
        }
        return snr_db, info
    except Exception as e:
        import logging
        logging.error(f"Error in phase_band_snr: {e}", exc_info=True)
        show_popup(error_string)

In [32]:
def hf_band_power_metrics(
    x, fs,
    hf_band=(50, 120),          # your amplitude band of interest
    ref_band=(20, 150),         # broader HF reference for context
    remove_aperiodic=True,
    fit_range=(2, 200),         # range to fit 1/f outside bands
    _welch_fn=_safe_welch       # reuse the one we already wrote
):
    """
    Returns a dict with integrated HF power (aperiodic-corrected),
    its ratio to a broader HF reference band, and a percentile score
    telling you how 'elevated' the HF band is within the reference band.
    """
    try:
        x = np.asarray(x, dtype=float)
    
        f, pxx = _welch_fn(x, fs)
        pxx_work = pxx.copy()
    
        if remove_aperiodic:
            # Fit log10(pxx) = a + b*log10(f) outside both bands
            mask_fit = (f >= fit_range[0]) & (f <= fit_range[1])
            # Exclude hf_band from the fit (optional: also exclude ref_band edges)
            ex = (f >= hf_band[0]) & (f <= hf_band[1])
            vr = mask_fit & (~ex)
            vf = f[vr]
            if np.count_nonzero(vf) >= 10:
                xp = np.log10(vf + 1e-12)
                yp = np.log10(pxx[vr] + 1e-18)
                A = np.vstack([np.ones_like(xp), xp]).T
                a, b = np.linalg.lstsq(A, yp, rcond=None)[0]
                pxx_fit = 10 ** (a + b * np.log10(f + 1e-12))
                pxx_work = pxx / (pxx_fit + 1e-18)
    
        # Slices
        hf_sel  = (f >= hf_band[0])  & (f <= hf_band[1])
        ref_sel = (f >= ref_band[0]) & (f <= ref_band[1])
    
        if not np.any(hf_sel) or not np.any(ref_sel):
            return {
                "reason": "no bins in band(s)",
                "hf_band": tuple(hf_band), "ref_band": tuple(ref_band),
                "aperiodic_removed": bool(remove_aperiodic)
            }
    
        # Integrated 'periodic' power (area under PSD)
        hf_power  = float(np.trapz(pxx_work[hf_sel], f[hf_sel]))
        ref_power = float(np.trapz(pxx_work[ref_sel], f[ref_sel]))
        rel_power = hf_power / (ref_power + 1e-18)
    
        # Percentile of HF mean within the reference-bin distribution
        hf_mean = float(np.mean(pxx_work[hf_sel]))
        ref_vals = pxx_work[ref_sel]
        hf_percentile = float(100.0 * np.mean(ref_vals <= hf_mean))
    
        # Optional: dB of integrated power (on corrected spectrum)
        hf_power_db = float(10.0 * np.log10(hf_power + 1e-18))
    
        return {
            "hf_band": tuple(hf_band),
            "ref_band": tuple(ref_band),
            "aperiodic_removed": bool(remove_aperiodic),
            "hf_power": hf_power,
            "hf_power_db": hf_power_db,
            "ref_power": ref_power,
            "hf_rel_power": rel_power,             # fraction of ref power
            "hf_mean_bin_power": hf_mean,          # mean per-bin power in HF
            "hf_percentile_in_ref": hf_percentile  # 0–100
        }
    except Exception as e:
        import logging
        logging.error(f"Error in hf_band_power_metrics: {e}", exc_info=True)
        show_popup(error_string)

In [33]:
def parse_freq_at(freq_vector, i, mode="center"):
    """
    freq_vector: p.f_pha or p.f_amp. Elements may be scalars or [low, high].
    i: index along that axis
    mode: "center" | "band" | "low" | "high"
    """
    try:
        # Robust indexing (handles lists, np arrays, object arrays)
        f_i = np.asarray(freq_vector, dtype=object)[i]
    
        f_i = np.asarray(f_i)
        if f_i.ndim == 0:
            lo = hi = float(f_i)                     # scalar -> same low/high
        elif f_i.size == 2:
            lo, hi = float(f_i[0]), float(f_i[1])    # band [low, high]
        else:
            # Very rare: if something odd comes back, be defensive
            lo, hi = float(np.min(f_i)), float(np.max(f_i))
    
        center = (lo + hi) / 2.0
    
        if mode == "center":
            value = center
        elif mode == "low":
            value = lo
        elif mode == "high":
            value = hi
        elif mode == "band":
            value = (lo, hi)
        else:
            value = center
    
        return {"low": lo, "high": hi, "center": center, "value": value}
    except Exception as e:
        import logging
        logging.error(f"Error in parse_freq_at: {e}", exc_info=True)
        show_popup(error_string)

In [34]:
# Funtions run during Crunch PAC JSON Data - Coupling AVG Traces

In [35]:
def _apply_thesis_fonts(ax, title, xlabel, ylabel, termes_font=None, termes_font_bold=None):
    if termes_font_bold:
        title = ax.set_title(title, fontproperties=termes_font_bold)
        xl = ax.set_xlabel(xlabel, fontproperties=termes_font_bold)
        yl = ax.set_ylabel(ylabel, fontproperties=termes_font_bold)
    else:
        title = ax.set_title(title)
        xl = ax.set_xlabel(xlabel)
        yl = ax.set_ylabel(ylabel)
    if termes_font:
        for lab in ax.get_xticklabels() + ax.get_yticklabels():
            lab.set_fontproperties(termes_font)

    xl.set_fontsize(20)
    yl.set_fontsize(20)
    title.set_fontsize(26)

In [36]:
def _find_session_summary(path_like: str | os.PathLike) -> Path:
    """Accept a folder OR a specific json. Return the Session Summary path."""
    p = Path(path_like)
    if p.is_file():
        return p
    # Look for a likely “Session Summary” JSON in the folder tree
    candidates = list(p.glob("**/Session Summary.json")) + list(p.glob("**/*Session Summary*.json"))
    if not candidates:
        raise FileNotFoundError("No 'Session Summary' JSON found under folder.")
    # Prefer the most-recent
    candidates.sort(key=lambda x: x.stat().st_mtime, reverse=True)
    return candidates[0]

In [37]:
def load_pac_json_to_long_df(json_or_folder: str | os.PathLike,
                             metrics=None,
                             probes: list[str] | None = None,
                             rats: list[int] | None = None) -> pd.DataFrame:
    """
    Load the Session Summary JSON and return a long DataFrame with columns:
      ['Rat','DayNum','Session','Channel','Probe','TimeKey',
       'primary_summary_name','Metric','Value']
    """
    path = _find_session_summary(json_or_folder)
    with open(path, "r") as f:
        root = json.load(f)

    # Check if JSON file contains PAC analysis
    skip = bool(root.get("Metadata", {}).get("skip_PAC", False))
    if skip:
        return None

    data = root.get("Data", {})
    if metrics is None:
        # Default set you asked for
        metrics = [
            "primary_summary_value",
            "peak_value",
            "peak_phase_hz",
            "peak_amplitude_hz",
            "z_abs_mean",
            "z_topk_mean",
            "frac_sig_ge_1.96",
        ]

    rows = []
    # Expect structure: Data[rat][day][session][channel] -> dict
    for rat_str, days_dict in data.items():
        try:
            rat = int(rat_str)
        except Exception:
            rat = rat_str  # fallback if non-numeric
        for day_str, sessions_dict in days_dict.items():
            try:
                day = int(day_str)
            except Exception:
                day = day_str
            for sess_code, chans_dict in sessions_dict.items():
                for ch_name, ch_info in chans_dict.items():
                    if not isinstance(ch_info, dict):
                        continue
                    probe = ch_info.get("Probe Type") or ch_info.get("Probe")
                    # Optional filtering
                    if probes is not None and probe not in probes:
                        continue
                    if rats is not None and (isinstance(rat, int) and rat not in rats):
                        continue

                    # Collect chosen metrics present for this channel
                    present_vals = {m: ch_info[m] for m in metrics if m in ch_info}
                    if not present_vals:
                        continue

                    primary_name = ch_info.get("primary_summary_name")
                    
                    # Build time key: D0:x for Day 0 sessions, else "Day N"
                    if day == 0:
                        # 0000 -> D0:0, 0001 -> D0:1, 0002 -> D0:2, etc.
                        try:
                            suffix = str(sess_code)[-1]
                        except Exception:
                            suffix = sess_code

                        #time_key = f"D0:{suffix}"
                        if suffix == "0":
                            time_key = "CRTL"
                        elif suffix == "1":
                            time_key = "DFP"
                        elif suffix == "2":
                            time_key = "MDZ"
                        
                    else:
                        time_key = f"Day {day}"

                    base = {
                        "Rat": rat,
                        "DayNum": day,
                        "Session": str(sess_code),
                        "Channel": str(ch_name),
                        "Probe": probe,
                        "TimeKey": time_key,
                        "primary_summary_name": primary_name,
                    }
                    # wide → rows (we’ll melt to long below; this is still fine)
                    base.update(present_vals)
                    rows.append(base)

    if not rows:
        raise ValueError("No rows parsed; check filters, metrics, or JSON structure.")

    df_wide = pd.DataFrame(rows)

    # Melt to long
    value_vars = [m for m in metrics if m in df_wide.columns]
    df_long = df_wide.melt(
        id_vars=["Rat", "DayNum", "Session", "Channel", "Probe", "TimeKey", "primary_summary_name"],
        value_vars=value_vars,
        var_name="Metric",
        value_name="Value",
    )

    # Enforce consistent categorical order only when you plot (case-specific)
    return df_long

In [38]:
def aggregate_mean_sem(df_long: pd.DataFrame) -> pd.DataFrame:
    """Aggregate mean ± SEM by TimeKey, Metric."""
    g = (
        df_long.groupby(["TimeKey", "Metric"])["Value"]
        .agg(mean="mean", std="std", n="count")
        .reset_index()
    )
    g["sem"] = g["std"] / np.sqrt(g["n"])
    g["sem"] = g["sem"].fillna(0.0)  # n==1 -> std NaN -> sem 0
    return g

In [39]:
def _desired_order_for_case(df_long: pd.DataFrame, case: str) -> list[str]:
    uniq = df_long["TimeKey"].unique().tolist()
    if case == "D0_sessions":
        order = ["D0:0", "D0:1", "D0:2", "D0:3"]
    elif case == "D0_1_to_14":
        order = ["D0:1", "Day 1", "Day 3", "Day 7", "Day 14"]
    elif case == "Full_House":
        order = ["CRTL", "DFP", "MDZ", "Day 1", "Day 3", "Day 7", "Day 14"]
    else:
        # Fallback: natural-ish order: D0:* first, then Day N ascending
        d0 = sorted([x for x in uniq if str(x).startswith("D0:")])
        later = sorted(
            [x for x in uniq if str(x).startswith("Day ")],
            key=lambda s: int(str(s).split()[1])
        )
        order = d0 + later
    return [x for x in order if x in uniq]

In [40]:
def plot_pac_trajectories(df_long: pd.DataFrame,
                          metrics: list[str] | None = None,
                          case: str = "D0_sessions",
                          save: bool = False,
                          output_dir: str | os.PathLike | None = None,
                          palette: dict[str, str] | None = None,
                          termes_font=None,
                          termes_font_bold=None,
                          clamp_y0: bool = False):
    """
    Plot mean ± SEM trajectories for selected metrics.
      case = "D0_sessions" or "D0_1_to_14" (your two study setups)
    """
    if metrics is None:
        metrics = sorted(df_long["Metric"].unique())

    order = _desired_order_for_case(df_long, case)
    if not order:
        raise ValueError("No timepoints found for the requested case.")

    # Filter to the timepoints of interest and lock the plotting order
    df_plot = df_long[df_long["TimeKey"].isin(order)].copy()
    df_plot["TimeKey"] = pd.Categorical(df_plot["TimeKey"], categories=order, ordered=True)

    agg = aggregate_mean_sem(df_plot)

    # Prepare optional save folder and stats file
    if save and output_dir:
        out = Path(output_dir)
        ts = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
        session_folder = out / f"Z-score Trajectories"
        session_folder.mkdir(parents=True, exist_ok=True)
        stats_path = session_folder / "trajectory_stats.tsv"
        # Write one combined stats table for everything we plot
        agg_sorted = agg.sort_values(["Metric", "TimeKey"])
        agg_sorted.to_csv(stats_path, sep="\t", index=False)

    # Helper to title/labels by metric
    def _metric_labels(metric: str) -> tuple[str, str]:
        # Y-axis label and title
        if metric == "primary_summary_value":
            # Try to surface the primary_summary_name for clarity
            names = df_plot.loc[df_plot["Metric"] == metric, "primary_summary_name"].dropna().unique()
            detail = f" ({names[0]})" if len(names) else ""
            return f"Primary{detail}", f"Primary Summary Value{detail}"
        elif metric == "z_abs_mean":
            return "Z-score (|Z|)", "Mean Comodulogram Z-score"
        elif metric == "z_topk_mean":
            return "Z-score (top-k mean)", "Top-k Mean Z"
        elif metric == "peak_value":
            return "Z-score", "Peak Z-score"
        elif metric == "peak_phase_hz":
            return "Frequency (Hz)", "Average PSD Theta Frequency Peak (4-12 Hz)"
        elif metric == "peak_amplitude_hz":
            return "Frequency (Hz)", "Peak Amplitude Frequency"
        elif metric == "frac_sig_ge_1.96":
            return "Fraction of significant bins (z≥1.96)", "Number of Bins (z≥1.96) Over Total Bins"
        else:
            return "Value", metric

    BRIGHT_COLORS = [
        '#d9d904',  # Yellow
        '#ff7f00',  # Orange
        '#e41a1c',  # Red
        '#a65628',  # Brown
        '#377eb8',  # Blue
        '#984ea3',  # Purple
        '#4daf4a',  # Green
        '#f781bf',  # Pink
    ]

    # Colors
    SERIES_COLOR = "#111111"
    default_palette = {m: BRIGHT_COLORS[i] for i, m in enumerate(metrics)}
    if palette:
        default_palette.update(palette)

    # Plot one figure per metric
    for metric in metrics:
        sub = agg[agg["Metric"] == metric].sort_values("TimeKey")
        if sub.empty:
            continue

        x_labels = sub["TimeKey"].astype(str).tolist()
        x = np.arange(len(x_labels))
        y = sub["mean"].to_numpy(dtype=float)
        e = sub["sem"].to_numpy(dtype=float)

        fig = plt.figure(figsize=(10, 5), constrained_layout=True)
        fig.set_constrained_layout_pads(w_pad=0.05, h_pad=0.05, wspace=0.05, hspace=0.05)
        ax = plt.gca()
        ax.errorbar(x, y, yerr=e, marker="o", linestyle="-",
                    color=default_palette.get(metric, SERIES_COLOR), capsize=4)

        ylabel, title = _metric_labels(metric)
        _apply_thesis_fonts(ax, title=title, xlabel="Time-point", ylabel=ylabel,
                            termes_font=termes_font, termes_font_bold=termes_font_bold)

        # Hard-coded fix to label actualy shown on x-axis, donn't want to disturb df access
        x_map = {"CRTL": "Control Case", "DFP": "DFP Inj.", "MDZ": "MDZ Interv."}
        x_labels_shown = [x_map.get(lbl, lbl) for lbl in x_labels]

        ax.set_xticks(x, x_labels_shown)
        for lab in ax.get_xticklabels() + ax.get_yticklabels():
            lab.set_fontproperties(termes_font)
            lab.set_fontsize(16)
        
        ax.grid(True, alpha=0.3)
        if clamp_y0:
            ax.set_ylim(bottom=0)

        #fig.tight_layout()

        if save and output_dir:
            base = f"{case}_{metric}".replace(" ", "_")
            png = session_folder / f"{base}.png"
            pdf = session_folder / f"{base}.pdf"
            plt.savefig(png, dpi=600, bbox_inches="tight", pad_inches=0.05)
            plt.savefig(pdf, bbox_inches="tight", pad_inches=0.05)
            plt.close(fig)
        else:
            plt.show()
            plt.close(fig)

In [41]:
# Funtions run during Crunch PAC JSON Data - SNR LF Box Plots

In [42]:
def load_snr_df(json_path, condition_map=None, bands_keep=None):
    """
    # --- 1) Parse SNR from your JSON into a tidy DataFrame ---
    Reads your Session Summary.json and returns a DataFrame with columns:
      ['rat','day','session','condition','channel','probe','band','snr_db','f0',
       'z_abs_mean','z_topk_mean','frac_sig_ge_1.96','run_json_path']
    condition_map maps session codes to labels, e.g. {'0000':'Control','0001':'DFP','0002':'MDZ'}.
    bands_keep limits to certain band names (as stored in JSON: 'Delta','Theta',...)
    """
    with open(json_path, "r") as f:
        obj = json.load(f)

    D = obj.get("Data", {})
    rows = []

    # If not provided, default to your three conditions; unknown codes pass through
    if condition_map is None:
        #condition_map = {"0000": "Control", "0001": "DFP", "0002": "MDZ"}
        condition_map = {"0000": "CRTL", 
                         "0001": "DFP", 
                         "0002": "MDZ",
                         "1": "Day 1",
                         "3": "Day 3",
                         "7": "Day 7",
                         "14": "Day 14"}

    # Walk nested dicts: rat -> day -> session -> channel
    for rat_key, rat_val in D.items():
        for day_key, day_val in rat_val.items():
            for session_key, session_val in day_val.items():
                if day_key != "0":
                    condition = condition_map.get(day_key, day_key)
                else:
                    condition = condition_map.get(session_key, session_key)
                if not isinstance(session_val, dict):
                    continue

                for channel, ch_dict in session_val.items():
                    if not isinstance(ch_dict, dict):
                        continue

                    # Pull optional PAC metrics
                    probe   = ch_dict.get("Probe Type", "")
                    z_abs   = ch_dict.get("z_abs_mean", np.nan)
                    z_topk  = ch_dict.get("z_topk_mean", np.nan)
                    frac196 = ch_dict.get("frac_sig_ge_1.96", np.nan)
                    runjp   = ch_dict.get("run_json_path", "")

                    # Find SNR band blocks (they’re the dicts with "SNR dB")
                    for band_name, band_dict in ch_dict.items():
                        if bands_keep is not None and band_name not in bands_keep:
                            continue
                        if isinstance(band_dict, dict) and "SNR dB" in band_dict:
                            rows.append({
                                "rat": str(rat_key),
                                "day": str(day_key),
                                "session": str(session_key),
                                "condition": str(condition),
                                "channel": str(channel),
                                "probe": str(probe),
                                "band": str(band_name),
                                "snr_db": float(band_dict.get("SNR dB", np.nan)),
                                "f0": float(band_dict.get("f0", np.nan)),
                                "z_abs_mean": float(z_abs) if z_abs is not None else np.nan,
                                "z_topk_mean": float(z_topk) if z_topk is not None else np.nan,
                                "frac_sig_ge_1.96": float(frac196) if frac196 is not None else np.nan,
                                "run_json_path": str(runjp),
                            })

    df = pd.DataFrame(rows)
    # Optional: order condition
    cond_order = ["CRTL", "DFP", "MDZ", "Day 1", "Day 3", "Day 7", "Day 14"]
    df["condition"] = pd.Categorical(df["condition"], categories=cond_order, ordered=True)
    return df

In [43]:
def summarize_snr(df):
    """
    # --- 3) Tabular summary (median/IQR) you can paste into the thesis ---
    Returns a wide table with metrics (n, median, q1, q3, iqr) by band x condition.
    Columns are a MultiIndex: (metric, condition).
    """
    if df is None or df.empty:
        return pd.DataFrame()

    # Make sure expected columns exist
    needed = {"band", "condition", "snr_db"}
    missing = needed - set(df.columns)
    if missing:
        raise ValueError(f"summarize_snr: missing columns {missing}")

    #print("summarize_snr df cols:", df.columns.tolist())

    # Compute stats cleanly with agg (no weird column levels)
    stats = (
        df.dropna(subset=["snr_db"])
          .groupby(["band", "condition"])["snr_db"]
          .agg(
              n="count",
              median="median",
              q1=lambda s: np.percentile(s, 25),
              q3=lambda s: np.percentile(s, 75),
          )
          .reset_index()
    )
    #print("stats cols:", stats.columns.tolist())
    stats["iqr"] = stats["q3"] - stats["q1"]

    # Wide table (band as rows, metrics x condition as columns)
    wide = stats.pivot(index="band", columns="condition", values=["n", "median", "q1", "q3", "iqr"])
    # Optional: sort column levels consistently
    wide = wide.sort_index(axis=1, level=0)

    return wide


In [44]:
def plot_snr_distributions(df, save_dir=None, kind="box", thresholds=(3.0, 6.0), show_probe_legend=True):
    BRIGHT_COLORS = [
        '#e41a1c',  # Red
        '#ff7f00',  # Orange
        '#d9d904',  # Yellow
        '#4daf4a',  # Green
        '#377eb8',  # Blue
        '#984ea3',  # Purple
        '#a65628',  # Brown
        '#f781bf',  # Pink
    ]

    band_range_str = {'Delta' : '(1-4 Hz)', 'Theta' : '(6-10 Hz)', 'Wide Theta' : '(4-12 Hz)', 'High Theta': '(10-16 Hz)', 'Beta' : '(13-30 Hz)'}
    PROBES = ['RAMG', 'RPFC', 'RVHPC', 'RDHPC', 'LHPCSCREW', 'RHPCSCREW', 'LPFCSCREW', 'RPFCSCREW']
    probe_color_map = {probe: BRIGHT_COLORS[i % len(BRIGHT_COLORS)] for i, probe in enumerate(PROBES)}

    figs = {}
    bands = list(df["band"].dropna().unique())
    # If condition is categorical great; if not, just use unique order:
    conds = list(df["condition"].cat.categories) if hasattr(df["condition"], "cat") else list(df["condition"].unique())
    conds = [c for c in conds if c in df["condition"].unique()]

    for band in bands:
        sub = df[df["band"] == band]
        data = [sub[sub["condition"] == c]["snr_db"].dropna().values for c in conds]

        fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
        fig.set_constrained_layout_pads(w_pad=0.05, h_pad=0.05, wspace=0.05, hspace=0.05)
        if kind == "violin":
            ax.violinplot(data, showmeans=True, showextrema=False)
            ax.set_xticks(range(1, len(conds) + 1))
            ax.set_xticklabels(conds)
        else:
            ax.boxplot(data, labels=conds, showfliers=False)

        # ---- probe-colored jitter points ----
        handles_needed = {}
        for i, c in enumerate(conds, start=1):
            cdf = sub[(sub["condition"] == c) & (~sub["snr_db"].isna())]
            if cdf.empty:
                continue
            # jittered x per-point
            xj = np.random.normal(loc=i, scale=0.05, size=len(cdf))
            colors = [probe_color_map.get(p, "#444444") for p in cdf["probe"]]
            ax.scatter(xj, cdf["snr_db"].values, s=12, alpha=0.7, zorder=3, c=colors, edgecolors="none")
            # track probes present for legend
            for p in cdf["probe"].unique():
                if p not in handles_needed:
                    handles_needed[p] = Line2D([0], [0], marker='o', linestyle='',
                                               markerfacecolor=probe_color_map.get(p, "#444444"),
                                               markersize=6, label=p)

        yl = ax.set_ylabel("Phase-band SNR (dB)")
        yl.set_fontproperties(termes_font_bold)
        yl.set_fontsize(16)
        
        title = ax.set_title(f"SNR: {band} {band_range_str[band]}")
        # your fixed y-lims (adjust if you want negatives visible)
        ax.set_ylim(top=8, bottom=-2)

        for thr in thresholds:
            ax.axhline(thr, linestyle="--", linewidth=1)
            #ax.text(0.02, (thr - ax.get_ylim()[0])/(ax.get_ylim()[1]-ax.get_ylim()[0]),
            #        f"{thr:.0f} dB", transform=ax.transAxes, va="bottom", fontsize=8)

        if show_probe_legend and handles_needed:
            legend = ax.legend(handles=list(handles_needed.values()), loc='upper right', ncols=2, fontsize=8)

        # Adjust font
        for text in legend.get_texts():
            text.set_fontproperties(termes_font)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontproperties(termes_font_bold)
        title.set_fontproperties(termes_font_bold)
        title.set_fontsize(20)
        
        #ax.grid(True, axis="y", alpha=0.3)
        fig.tight_layout()

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            base = os.path.join(save_dir, f"SNR_{band.replace(' ','_')}_{kind}")
            fig.savefig(base + ".png", dpi=600, bbox_inches="tight", pad_inches=0.2)
            fig.savefig(base + ".pdf", bbox_inches="tight", pad_inches=0.2)

        figs[band] = fig
        plt.close(fig)

    return figs

In [45]:
# Funtions run during Crunch PAC JSON Data - SNR HF Box Plots

In [46]:
def load_hf_df(json_path, condition_map=None):
    """
    Parse 'HF Power' blocks from Session Summary.json into a tidy DataFrame.
    Columns:
      ['rat','day','session','condition','channel','probe',
       'hf_band_low','hf_band_high','ref_band_low','ref_band_high',
       'hf_power','hf_power_db','ref_power','hf_rel_power','hf_percentile_in_ref',
       'run_json_path']
    """
    with open(json_path, "r") as f:
        obj = json.load(f)

    D = obj.get("Data", {})
    rows = []

    if condition_map is None:
        condition_map = {"0000": "CRTL", 
                         "0001": "DFP", 
                         "0002": "MDZ",
                         "1": "Day 1",
                         "3": "Day 3",
                         "7": "Day 7",
                         "14": "Day 14"}

    for rat_key, rat_val in D.items():
        for day_key, day_val in rat_val.items():
            for session_key, session_val in day_val.items():
                if day_key != "0":
                    condition = condition_map.get(day_key, day_key)
                else:
                    condition = condition_map.get(session_key, session_key)
                if not isinstance(session_val, dict):
                    continue

                for channel, ch_dict in session_val.items():
                    if not isinstance(ch_dict, dict):
                        continue

                    hf = ch_dict.get("HF Power")
                    if not isinstance(hf, dict):
                        continue  # skip channels without HF entry

                    probe   = ch_dict.get("Probe Type", "")
                    runjp   = ch_dict.get("run_json_path", "")

                    hf_band = hf.get("hf_band", [np.nan, np.nan])
                    ref_band = hf.get("ref_band", [np.nan, np.nan])

                    rows.append({
                        "rat": str(rat_key),
                        "day": str(day_key),
                        "session": str(session_key),
                        "condition": str(condition),
                        "channel": str(channel),
                        "probe": str(probe),

                        "hf_band_low":  float(hf_band[0]) if len(hf_band) >= 2 else np.nan,
                        "hf_band_high": float(hf_band[1]) if len(hf_band) >= 2 else np.nan,
                        "ref_band_low":  float(ref_band[0]) if len(ref_band) >= 2 else np.nan,
                        "ref_band_high": float(ref_band[1]) if len(ref_band) >= 2 else np.nan,

                        "hf_power": float(hf.get("hf_power", np.nan)),
                        "hf_power_db": float(hf.get("hf_power_db", np.nan)),
                        "ref_power": float(hf.get("ref_power", np.nan)),
                        "hf_rel_power": float(hf.get("hf_rel_power", np.nan)),
                        "hf_percentile_in_ref": float(hf.get("hf_percentile_in_ref", np.nan)),

                        "run_json_path": str(runjp),
                    })

    df = pd.DataFrame(rows)
    # Optional: ordered categories for condition
    cond_order = ["CRTL", "DFP", "MDZ", "Day 1", "Day 3", "Day 7", "Day 14"]
    df["condition"] = pd.Categorical(df["condition"], categories=cond_order, ordered=True)
    return df

In [47]:
def plot_hf_distributions(df, metric="hf_rel_power", save_dir=None, kind="box",
                          thresholds=None, probe_color_map=None, show_probe_legend=True):
    """
    One figure per metric (grouped by condition), dots colored by probe.
    thresholds: e.g., for rel_power use (0.1, 0.2), for percentile use (50,), etc.
    """
    BRIGHT_COLORS = [
        '#e41a1c',  # Red
        '#ff7f00',  # Orange
        '#d9d904',  # Yellow
        '#4daf4a',  # Green
        '#377eb8',  # Blue
        '#984ea3',  # Purple
        '#a65628',  # Brown
        '#f781bf',  # Pink
    ]

    PROBES = ['RAMG', 'RPFC', 'RVHPC', 'RDHPC', 'LHPCSCREW', 'RHPCSCREW', 'LPFCSCREW', 'RPFCSCREW']
    probe_color_map = {probe: BRIGHT_COLORS[i % len(BRIGHT_COLORS)] for i, probe in enumerate(PROBES)}

    figs = {}
    # Guard
    if metric not in df.columns:
        raise ValueError(f"metric '{metric}' not found in DataFrame")

    conds = list(df["condition"].cat.categories) if hasattr(df["condition"], "cat") else list(df["condition"].unique())
    conds = [c for c in conds if c in df["condition"].unique()]

    fig, ax = plt.subplots(figsize=(10, 5))
    data = [df[df["condition"] == c][metric].dropna().values for c in conds]

    if kind == "violin":
        ax.violinplot(data, showmeans=True, showextrema=False)
        ax.set_xticks(range(1, len(conds) + 1))
        ax.set_xticklabels(conds)
    else:
        ax.boxplot(data, labels=conds, showfliers=False)

    # probe-colored dots
    handles_needed = {}
    for i, c in enumerate(conds, start=1):
        cdf = df[(df["condition"] == c) & (~df[metric].isna())]
        if cdf.empty:
            continue
        xj = np.random.normal(loc=i, scale=0.05, size=len(cdf))
        colors = [probe_color_map.get(p, "#444444") for p in cdf["probe"]]
        ax.scatter(xj, cdf[metric].values, s=12, alpha=0.7, zorder=3, c=colors, edgecolors="none")
        for p in cdf["probe"].unique():
            if p not in handles_needed:
                handles_needed[p] = Line2D([0], [0], marker='o', linestyle='',
                                           markerfacecolor=probe_color_map.get(p, "#444444"),
                                           markersize=6, label=p)

    # Labels
    ylab = {
        "hf_rel_power": "HF Relative Power (HF / Ref)",
        "hf_percentile_in_ref": "HF Mean Percentile in Ref (%)",
        "hf_power_db": "HF Integrated Power (dB)"
    }.get(metric, metric)
    yl = ax.set_ylabel(ylab)
    yl.set_fontproperties(termes_font_bold)
    yl.set_fontsize(16)
    title = ax.set_title(f"{ylab}")

    if thresholds:
        for thr in thresholds:
            ax.axhline(thr, linestyle="--", linewidth=1)

    if show_probe_legend and handles_needed:
        legend = ax.legend(handles=list(handles_needed.values()), loc="upper right", ncols=2, fontsize=8)

    # Adjust font
    for text in legend.get_texts():
        text.set_fontproperties(termes_font)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontproperties(termes_font_bold)
    title.set_fontproperties(termes_font_bold)
    title.set_fontsize(20)

    #ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        base = os.path.join(save_dir, f"HF_{metric}_{kind}")
        fig.savefig(base + ".png", dpi=600, bbox_inches="tight", pad_inches=0.2)
        fig.savefig(base + ".pdf", bbox_inches="tight", pad_inches=0.2)

    figs[metric] = fig
    plt.close(fig)
    return figs

In [48]:
##### Analysis Functions #####

In [49]:
def generate_synthetic_data():
    try:
        # TensorPAC method
        f_pha = 12     # frequency phase for the coupling
        f_amp = 75     # frequency amplitude for the coupling
        n_epochs = 30  # number of trials
        epoch_len = 20 # length of epoch in seconds
        sf = 2000.     # sampling frequency
        n_times = int(epoch_len * sf)  # number of time points
        n_samples = int(n_epochs * epoch_len * sf)
        sampling_rate = sf
        tensorpac_epochs, time = pac_signals_wavelet(sf=sf, f_pha=f_pha, f_amp=f_amp, noise=1.0,
                                      n_epochs=n_epochs, n_times=n_times)
        raw_data = tensorpac_epochs.flatten()
        
        # Manual method
        #raw_data, _ = generate_custom_pac(f_pha, f_amp, sf, n_epochs * epoch_len)
        
        # Scale synthetic PAC signal to realistic amplitude (e.g., ±0.5 mV)
        target_peak_to_peak = 10  # mV total span (±5 mV)
        actual_ptp = np.ptp(raw_data)
        scale_factor = target_peak_to_peak / actual_ptp
        raw_data *= scale_factor * 0.015 # Adjust const to weaken coupling
        
        ### Add noise to signal ###
        scaling_factor = 0.025
        
        # Compute noise using both custom methods
        aperiodic = generate_1overf_noise(n_samples=n_samples, sf=sf, exponent=1, amplitude=2.0)
        # Add additional noise to the signal for more authenticity
        power = scaling_factor * np.sin(2 * np.pi * 60 * n_samples)         # 60 Hz interference
        power_harmonic = scaling_factor * 0.6 * np.sin(2 * np.pi * 120 * n_samples)  # 120 Hz harmonic
        power_harmonic_2 = scaling_factor * 0.3 * np.sin(2 * np.pi * 180 * n_samples)  # 120 Hz harmonic

        # Combine with coupled data
        raw_data += scaling_factor * aperiodic
        raw_data += power + power_harmonic + power_harmonic_2
        return raw_data
    except Exception as e:
        import logging
        logging.error(f"Error in generate_synthetic_data: {e}", exc_info=True)
        show_popup(error_string)

In [50]:
def capture_stdout(func, *args, **kwargs):
    """
    Utility function that performs a function but directs the console output
    to a buffer that can be stored in a JSON or txt. NOTE data is passed by 
    reference, the original values are modified, not copies.
    
    result = original result of function, use _ if no variable is set to result
    buf.getvalue() = console output
    None/e = returns if an error occurs
    
    Example use:
    raw = mne.io.RawArray(raw_data[np.newaxis, :], info)
    Turns into,
    raw, log, e = capture_stdout(mne.io.RawArray, raw_data[np.newaxis, :], info)
    """
    buf = io.StringIO()
    try:
        with contextlib.redirect_stdout(buf):
            result = func(*args, **kwargs)
        return result, buf.getvalue(), None
    except Exception as e:
        return None, buf.getvalue(), e

In [51]:
def save_pac_outputs(
    Analysis_class, output_directory, rat, channel_name, day_num, full_day_desc, start_s, end_s, 
    timestamp, p, log_lines, pac_map, json_path, probe, day_session, SNR_JSON_data, HF_JSON_data,
    skip_PAC
):
    """
    Save PAC comodulogram figure(s) and JSON outputs.

    Writes:
      1) A per-run JSON beside the figure with full details (comod matrix, peak, logs, metadata).
      2) Updates the aggregator JSON at `json_path` with compact per-channel/day summaries.
    """
    try:
        # ---- Export settings
        (output_folder_name,
         export_PNG, export_PDF, export_SVG, export_EPS,
         image_height, image_width, image_DPI, color_palette,
         y_custom_axis, yaxis_top, yaxis_bottom,
         alpha, save_data) = Analysis_class.get_export_parameters()

        # Hard-coded setting to simplify plots for placing in thesis
        minimal_plots = True

        if skip_PAC == False:
            (PAC_custom_colormap, PAC_vmin, PAC_vmax, PAC_comod_interpolation) = Analysis_class.get_all_PAC_export_parameters()
    
            if PAC_custom_colormap is False:
                PAC_vmin, PAC_vmax = None, None
    
            if PAC_comod_interpolation == 0.0:
                PAC_comod_interpolation = None
            else:
                PAC_comod_interpolation = (PAC_comod_interpolation, PAC_comod_interpolation)
    
            PAC_method = Analysis_class.get_PAC_method()
            norm_method = Analysis_class.get_normalization_method()
    
            # Label short-hands
            PAC_method_map = {
                'Mean Vector Length': 'MVL',
                'Modulation Index': 'MI',
                'Heights Ratio': 'HR',
                'ndPAC': 'ndPAC',
                'Phase-Locking Value': 'PLV',
                'Gaussian Copula PAC': 'GC PAC'
            }
            PAC_method = PAC_method_map.get(PAC_method, PAC_method)
    
            norm_method_map = {
                'No Normalization': '',
                'Subtract Mean of Surrogates': 'Sub Mean of Surg.',
                'Divide Mean of Surrogates': 'Div Mean of Surg.',
                'Sub+Div Mean of Surrogates': 'Sub+Div Mean of Surg.',
                'Z-score': 'Z-Score'
            }
            norm_method_label = norm_method_map.get(norm_method, norm_method)
    
            # ---- Comod & p-values
            # pac_map: typically (n_phase, n_amp, n_epochs). We average over last axis:
            comod = pac_map.mean(-1)  # 2D array [n_phase, n_amp]
            # Be careful: infer_pvalues may return numeric p-values or a boolean significance mask.
            pvals = p.infer_pvalues(p=0.05)

            # Hide axis labels in minimal mode; keep ticks + title
            xlabel = None if minimal_plots else "Frequency for Phase (Hz)"
            ylabel = None if minimal_plots else "Frequency for Amplitude (Hz)"
    
            # ---- Plot
            if minimal_plots:
                fig = plt.figure(figsize=(5.236,5), constrained_layout=True)
                fig.set_constrained_layout_pads(w_pad=0.1, h_pad=0.1, wspace=0.1, hspace=0.1)
                title = f"{probe}"
                colorbar = False
            else:
                plt.figure(figsize=(image_width, image_height))
                title=f"Rat{rat} | {full_day_desc} | {day_session} | {probe}"
                colorbar = True
            p.comodulogram(
                comod,
                title=title,
                xlabel=xlabel,
                ylabel=ylabel,
                fz_labels=10,
                cmap='viridis',
                colorbar=colorbar,
                vmin=PAC_vmin,
                vmax=PAC_vmax,
                interp=PAC_comod_interpolation
            )
    
            ax = plt.gca()
            ax.set_title(ax.get_title(), fontproperties=termes_font_bold, fontsize=24)
            if not minimal_plots:
                ax.set_xlabel(ax.get_xlabel(), fontproperties=termes_font_bold, fontsize=16)
                ax.set_ylabel(ax.get_ylabel(), fontproperties=termes_font_bold, fontsize=16)
            for tick in ax.get_xticklabels() + ax.get_yticklabels():
                tick.set_fontproperties(termes_font)
    
            if ax.images and not minimal_plots:
                cbar = ax.figure.axes[-1]
                if norm_method_label == 'Z-Score':
                    cbar.set_ylabel(norm_method_label, fontproperties=termes_font_bold)
                else:
                    cbar.set_ylabel((PAC_method + ("  (" + norm_method_label + ")" if norm_method_label else "")),
                                    fontproperties=termes_font_bold)
                for label in cbar.get_yticklabels():
                    label.set_fontproperties(termes_font)
    
            # ---- Paths
            rat_folder = os.path.join(output_directory, f"Rat{rat}")
            day_folder = os.path.join(rat_folder, f"{full_day_desc}")
            os.makedirs(day_folder, exist_ok=True)
    
            start_min = int(start_s / 60)
            end_min   = int(end_s   / 60)
            base_name = f"Rat{rat}_{channel_name}_min{start_min}-{end_min}_{timestamp}"
    
            # ---- Save figure(s)
            if export_PNG:
                plt.savefig(os.path.join(day_folder, base_name + ".png"), dpi=image_DPI)
            if export_PDF:
                plt.savefig(os.path.join(day_folder, base_name + ".pdf"))
            if export_SVG:
                plt.savefig(os.path.join(day_folder, base_name + ".svg"))
            if export_EPS:
                plt.savefig(os.path.join(day_folder, base_name + ".eps"))
            plt.close()
    
            # Peak index (guard against all-NaN just in case)
            if np.all(np.isnan(comod)):
                raise ValueError("Comodulogram is all NaN; cannot compute a peak.")
            
            idx = np.unravel_index(np.nanargmax(comod, axis=None), comod.shape)
            
            pha_info = parse_freq_at(p.f_pha, idx[0], mode="center")
            amp_info = parse_freq_at(p.f_amp, idx[1], mode="center")
            
            pha_f    = float(pha_info["center"])   # for single-number reporting
            amp_f    = float(amp_info["center"])
            peak_val = float(comod[idx])
    
            # pval/Significance at peak
            peak_p = None
            peak_sig = None
            if pvals is not None:
                # Heuristics: detect shape & dtype
                pv = pvals[idx]
                if isinstance(pv, (float, np.floating)):
                    peak_p = float(pv)
                elif isinstance(pv, (bool, np.bool_)):
                    peak_sig = bool(pv)
                # else: leave None
    
            # ---- Summaries suited for Z-score (fall back gracefully if not Z)
            is_z = (norm_method == 'Z-score')
            z = comod if is_z else comod  # if not Z, these are just "values"; still useful
    
            # Numerical summaries (NaN-safe)
            def safe_mean(a): return float(np.nanmean(a))
            def safe_percentile(a, q): return float(np.nanpercentile(a, q))
    
            z_abs_mean  = safe_mean(np.abs(z))
            z_mean_pos  = safe_mean(np.clip(z, 0, None))
            z_mean      = safe_mean(z)
            z_p95       = safe_percentile(z, 95.0)
            # Top-k mean (5% of bins)
            flat = z.ravel()
            k = max(1, int(0.05 * flat.size))
            if k < flat.size:
                kth = np.partition(flat, flat.size - k)[-k:]
                z_topk_mean = float(np.nanmean(kth))
            else:
                z_topk_mean = z_p95
    
            # Fraction significant & excess area (only meaningful if Z)
            z_thr = 1.96
            frac_sig = float(np.mean((z >= z_thr))) if is_z else None
            excess_area = float(np.mean(np.clip(z - z_thr, 0, None))) if is_z else None
    
            # ---- Per-run JSON (full details; can be large)
            # To keep size sane, round to 4 decimals (you can change this).
            comod_rounded = np.round(comod.astype(np.float32), 4).tolist()
            # You can also choose to omit p-values matrix if it’s large; here we store just peak info.
    
            run_json = {
                "meta": {
                    "rat": int(rat),
                    "day_number": int(day_num),
                    "full_day_desc": full_day_desc,
                    "channel": channel_name,
                    "chunk_min": [start_min, end_min],
                    "probe": probe,
                    "day_session": day_session,
                    "timestamp": timestamp,
                    "pac_method": PAC_method,
                    "normalization": norm_method,          # full string
                    "normalization_label": norm_method_label,  # short label used on plots
                    "vmin": PAC_vmin,
                    "vmax": PAC_vmax,
                    "interp": PAC_comod_interpolation
                },
                "comodulogram": {
                    "shape": [int(comod.shape[0]), int(comod.shape[1])],
                    "values": comod_rounded,
                },
                "peak": {
                    "phase_center_hz": pha_info["center"],
                    "phase_low_hz": pha_info["low"],
                    "phase_high_hz": pha_info["high"],
                    "amp_center_hz": amp_info["center"],
                    "amp_low_hz": amp_info["low"],
                    "amp_high_hz": amp_info["high"],
                    "value": peak_val,
                    "p_value": peak_p,
                    "significant_at_alpha": peak_sig
                },
                "summaries": {
                    "is_z": bool(is_z),
                    "z_mean": z_mean,
                    "z_mean_pos": z_mean_pos,
                    "z_abs_mean": z_abs_mean,
                    "z_topk_mean": z_topk_mean,
                    "z_p95": z_p95,
                    "frac_sig_ge_1.96": frac_sig,
                    "excess_area_over_1.96": excess_area,
                    "primary_summary": "z_mean_pos" if is_z else "z_mean"
                },
                "log": list(log_lines)
            }
    
            run_json_path = os.path.join(day_folder, base_name + ".json")
            with open(run_json_path, "w") as f:
                json.dump(run_json, f, indent=2)
    
            # ---- Aggregator JSON (compact; for plotting over time)
            rat_str, day_str, day_session_str = str(rat), str(day_num), str(day_session)
    
            # Load or init aggregator
            if os.path.exists(json_path):
                with open(json_path, "r") as f:
                    pac_data = json.load(f)
            else:
                pac_data = {}
    
            pac_data.setdefault("Data", {})
            pac_data["Data"].setdefault(rat_str, {})
            pac_data["Data"][rat_str].setdefault(day_str, {})
            pac_data["Data"][rat_str][day_str].setdefault(day_session_str, {})
    
            # Choose which metric is your “average coupling” line:
            # If Z -> use z_mean_pos (no cancellation). Else -> mean of values.
            average_coupling = z_mean_pos if is_z else z_mean
    
            pac_data["Data"][rat_str][day_str][day_session_str][channel_name] = {
                "Probe Type": probe,
                "normalization": norm_method,
                "pac_method": PAC_method,
                "primary_summary_value": average_coupling,
                "primary_summary_name": "mean_positive_z" if is_z else "mean_value",
                "peak_value": peak_val,
                "peak_phase_hz": pha_f,
                "peak_amplitude_hz": amp_f,
                "run_json_path": run_json_path  # so you can reload full details later
            }
    
            # Optional: stash a couple more summaries you may plot later
            pac_data["Data"][rat_str][day_str][day_session_str][channel_name].update({
                "z_abs_mean": z_abs_mean,
                "z_topk_mean": z_topk_mean,
                "frac_sig_ge_1.96": frac_sig
            })
    
            if SNR_JSON_data:
                pac_data["Data"][rat_str][day_str][day_session_str][channel_name].update(SNR_JSON_data)
    
            if HF_JSON_data:
                pac_data["Data"][rat_str][day_str][day_session_str][channel_name].update(HF_JSON_data)
    
            with open(json_path, "w") as f:
                json.dump(pac_data, f, indent=2)
            #####################################################################################################################
        else:
            # ---- Paths
            rat_folder = os.path.join(output_directory, f"Rat{rat}")
            day_folder = os.path.join(rat_folder, f"{full_day_desc}")
            os.makedirs(day_folder, exist_ok=True)
            base_name = f"Rat{rat}_{channel_name}_{timestamp}"
            
            run_json = {
                "meta": {
                    "rat": int(rat),
                    "day_number": int(day_num),
                    "full_day_desc": full_day_desc,
                    "channel": channel_name,
                    "probe": probe,
                    "day_session": day_session,
                    "timestamp": timestamp,
                },
                "log": list(log_lines)
            }
    
            run_json_path = os.path.join(day_folder, base_name + ".json")
            with open(run_json_path, "w") as f:
                json.dump(run_json, f, indent=2)
    
            # ---- Aggregator JSON (compact; for plotting over time)
            rat_str, day_str, day_session_str = str(rat), str(day_num), str(day_session)
    
            # Load or init aggregator
            if os.path.exists(json_path):
                with open(json_path, "r") as f:
                    pac_data = json.load(f)
            else:
                pac_data = {}
    
            pac_data.setdefault("Data", {})
            pac_data["Data"].setdefault(rat_str, {})
            pac_data["Data"][rat_str].setdefault(day_str, {})
            pac_data["Data"][rat_str][day_str].setdefault(day_session_str, {})
    
            pac_data["Data"][rat_str][day_str][day_session_str][channel_name] = {
                "Probe Type": probe,
                "run_json_path": run_json_path  # so you can reload full details later
            }
    
            if SNR_JSON_data:
                pac_data["Data"][rat_str][day_str][day_session_str][channel_name].update(SNR_JSON_data)
    
            if HF_JSON_data:
                pac_data["Data"][rat_str][day_str][day_session_str][channel_name].update(HF_JSON_data)
    
            with open(json_path, "w") as f:
                json.dump(pac_data, f, indent=2)

    except Exception as e:
        import logging
        logging.error(f"Error in save_pac_outputs: {e}", exc_info=True)
        show_popup(error_string)

In [52]:
def compute_PAC_of_channels(Analysis_class):
    """
    This function processes PAC on channels individually. It is the original way of
    performing the computation but by design cannot use the autofilter library which
    necessitates processing all channels of a rat at once. It is kept to retain the
    felxiblity that comes with choosing channels to process but is not able to remove
    poor epochs from data.
    """
    try:
        # Retrieve PAC settings
        (pha_freqs, amp_freqs, idpac, dcomplex, cycles, width, n_bins,
         sampling_rate, epoch_len, data_length, minimum_length, parallel_processes,
         surrogate_permutations, notch_freqs, high_pass_filter, detrend_epochs,
         apply_autofilter, downsample_rate, seed, rejection_threshold, skip_PAC,
         inject_PAC)= Analysis_class.get_all_PAC_parameters()

        # Initialize TensorPAC object
        p, PAC_msg, e = capture_stdout(Pac, idpac=idpac, f_pha=pha_freqs, f_amp=amp_freqs,
                                   dcomplex=dcomplex, cycle=cycles, width=width, n_bins=n_bins, verbose=True)

        # Prepare IO paths
        timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
        folder_path = Analysis_class.get_folder_path
        nwb_list = Analysis_class.get_nwb_list()
        channels = Analysis_class.get_selected_channels()

        output_directory = Analysis_class.get_output_directory()
        output_folder_name = Analysis_class.get_output_folder_name()
        if "[Timestamp]" in output_folder_name:
            output_folder_name = output_folder_name.split("[Timestamp]")[0]
            output_folder_name = f"{output_folder_name}{timestamp}"
        
        output_directory = os.path.join(output_directory, output_folder_name)
        os.makedirs(output_directory, exist_ok=True)
        json_path = os.path.join(output_directory, "Session Summary.json")

        # Add Session-level information to JSON output file
        if os.path.exists(json_path):
            with open(json_path, "r") as f:
                pac_data = json.load(f)
        else:
            pac_data = {}

        metadata = "Metadata"
        pac_data[metadata] = {
            "pha_freqs" : pha_freqs, 
            "amp_freqs" : amp_freqs, 
            "idpac" : idpac, 
            "dcomplex" : dcomplex, 
            "cycles" : cycles, 
            "width" : width, 
            "n_bins" : n_bins,
            "sampling_rate" : sampling_rate, 
            "epoch_len" : epoch_len, 
            "data_length" : data_length, 
            "minimum_length" : minimum_length, 
            "parallel_processes" : parallel_processes,
            "surrogate_permutations"  : surrogate_permutations, 
            "notch_freqs" : notch_freqs, 
            "high_pass_filter" : high_pass_filter, 
            "detrend_epochs" : detrend_epochs,
            "apply_autofilter" : apply_autofilter, 
            "downsample_rate" : downsample_rate, 
            "seed" : seed, 
            "rejection_threshold" : rejection_threshold,
            "skip_PAC" : skip_PAC,
            "inject_PAC" : inject_PAC,
            "timestamp" : timestamp,
            "folder_path" : folder_path,
            "nwb_list" : nwb_list,
            "nwb_list" : nwb_list,
            "channels" : channels,
            "output_directory" : output_directory,
            "output_folder_name" : output_folder_name,
            "json_path" : json_path,
        }

        pac_data["Data"] = {}
        
        # Write back to JSON
        with open(json_path, "w") as f:
            json.dump(pac_data, f, indent=4)

        if seed == 0:
            seed = None

        # Beginning of PAC Algorithm
        for file in nwb_list:
            file_path = os.path.join(folder_path, file)
            selected  = channels.get(file, {})
            if not selected:
                # No channels selected for file, skipping.
                continue
            
            for channel_name, data_index in selected.items():
                ##### 1) Load Data & Create Objects #####

                # Fetch corresponding rat and it's probe from channel/file
                rat, probe = Analysis.get_rat_and_probe_from_channel(file, channel_name)
                day_num, full_day_desc = Analysis.get_day_from_file(file)
                
                # Obtain current rat and probe from channel
                channel_num = int(channel_name.split('CSC')[1])
                current_rat, current_probe = Analysis.get_rat_and_probe_from_channel(file, channel_name)
        
                # Grab the session (i.e. Day0_0001.nwb -> 0001)
                split_name = file.split('_')  
                if split_name[-1]:
                    day_session = split_name[-1].split('.nwb')[0]
                full_day_desc = full_day_desc + f"_{day_session}"

                # Log information to save to txt
                log_lines = []
                log_lines.append(f"############### Started PAC processing for Rat{rat}, Channel {channel_name} on Day {day_num}_{day_session}. ###############")
                log_lines.append(PAC_msg)
                
                # Load raw data
                times, raw_data = get_raw_data(file_path, data_index, start=0, end=-1)
                # Determine actual sampling frequency, based on length of data
                sampling_rate = 1.0 / np.median(np.diff(times))
                del times
                log_lines.append(f"Extracted NWB data from file {file}, column {data_index}, shape {np.shape(raw_data)}.\n")
        
                # Wrap raw data into MNE object for less manual processing steps
                info = mne.create_info([channel_name], sampling_rate, ch_types="eeg")
                raw = mne.io.RawArray(raw_data[np.newaxis, :], info, verbose=False)
                log_lines.append(f"Wrapped raw data in MNE object with info: {info}")

                #-------------------------------------------------------------------------------------------------#
                ##### Data tampering section #####

                # Debugging w/ synthetic data
                #raw_data = generate_synthetic_data()

                # Spike injection to raw data
                x = raw.get_data(picks=channel_name)[0]   # shape (N,)
                
                #----- 1.1) Inject PAC for Sensitivity Test -----#
                if inject_PAC:
                    fs=float(raw.info['sfreq'])
                    y, pac_meta, y_add = inject_spike_into_raw_data(
                        x, fs, phase_band=(6, 10), amp_band=(80, 120),
                        depth=0.3,      # e.g., 0.1 .. 0.6
                        target_hf_snr_db=20,  # e.g., +3 to +10 dB
                        burst_duty=0.5,  # e.g., 0.3 .. 1.0
                        burst_len_s=2.0, burst_random=True, random_state=42
                    )
                    log_lines.append(f"Spike-in PAC: {pac_meta}")
                    # replace raw with the spiked data and continue pipeline:
                    raw = mne.io.RawArray(y[np.newaxis, :], raw.info, verbose=False)


                    # Plot Simple PSD (Injection Visualizer)
                    psd_filt, freqs_filt = psd_array_welch(x[np.newaxis, :], 
                                                           sfreq=fs, 
                                                           n_fft=int(fs * 4), 
                                                           fmin=0, 
                                                           fmax=200, 
                                                           n_jobs=parallel_processes, 
                                                           verbose=False)
                    psd_injected, freqs_injected = psd_array_welch(y[np.newaxis, :], 
                                                           sfreq=fs, 
                                                           n_fft=int(fs * 4), 
                                                           fmin=0, 
                                                           fmax=200, 
                                                           n_jobs=parallel_processes, 
                                                           verbose=False)
                    
                    psd_filt_db = 10 * np.log10(np.where(psd_filt[0] > 0, psd_filt[0], np.nan))
                    psd_injected_db = 10 * np.log10(np.where(psd_injected[0] > 0, psd_injected[0], np.nan))

                    fig, ax = plt.subplots(figsize=(10, 5))
                    ax.plot(freqs_filt, psd_filt_db, label=f"Raw Data", color="black")
                    ax.plot(freqs_injected, psd_injected_db, label=f"PAC Injection")
                    ax.set_title("PAC Spike Injection").set_fontproperties(termes_font_bold)
                    ax.set_xlabel("Frequency (Hz)").set_fontproperties(termes_font_bold)
                    ax.set_ylabel("Power Spectral Density (dB)").set_fontproperties(termes_font_bold)
                    #ax.set_ylim(top=35, bottom=0)
                    legend = ax.legend()
                    for text in legend.get_texts():
                        text.set_fontproperties(termes_font)
                    for label in ax.get_xticklabels() + ax.get_yticklabels():
                        label.set_fontproperties(termes_font)
                    ax.grid(True)
                    base_name = "PAC_spike_injection_PSD"
                    plt.tight_layout()
                    fig_path = os.path.join(output_directory, base_name + ".png")
                    fig.savefig(fig_path, dpi=600) 
                    fig_path = os.path.join(output_directory, base_name + ".pdf")
                    fig.savefig(fig_path)
                    plt.close(fig)
                
                #-------------------------------------------------------------------------------------------------#

                ##### 2) Apply Channel-Wide Changes #####
        
                # Apply notch filters
                if notch_freqs:
                    raw.notch_filter(notch_freqs, picks=channel_name, method='fir', verbose=False)
                    log_lines.append(f"Data notch filtered for frequencies: {notch_freqs}")
        
                # Apply high-pass filter
                if high_pass_filter and high_pass_filter > 0.0:
                    raw.filter(l_freq=high_pass_filter, h_freq=200, picks=channel_name, verbose=False)
                    log_lines.append(f"Data bandpass filtered from {high_pass_filter} to 200 Hz.")

                # Downsample data
                if downsample_rate:
                    raw.resample(sfreq=downsample_rate, verbose=False)
                    log_lines.append(f"Data downsampled from {sampling_rate} to {downsample_rate}.")
                downsampled_sampling_rate = raw.info['sfreq']

                #-------------------------------------------------------------------------------------------------#

                ##### 2) Track SNR and HF Band Power #####
                
                #----- 2.1) Phase-Band SNR Check -----#
                check_SNR = True
                SNR_JSON_data = None
                if check_SNR == True:
                    bands_name = ['Delta', 'Theta', 'Wide Theta', 'High Theta', 'Beta']
                    bands = [(1,4), (6,10), (4,12), (10,16), (13,30)]
                    SNR_JSON_data = {}
                    x = raw.get_data(picks=channel_name)[0]
                    fs = float(raw.info['sfreq'])
                    for i, b in enumerate(bands):
                        snr_db, snr_info = phase_band_snr(x, fs, band=b, neighbor_bw=2.0, exclude_bw=0.5, remove_aperiodic=True)
                        SNR_JSON_data[f"{bands_name[i]}"] = snr_info
                        log_lines.append(f"Phase-band SNR: {snr_db} dB at ~{snr_info.get('f0', np.nan)} Hz "
                                         f"(aperiodic_removed={snr_info['aperiodic_removed']})")

                #----- 2.2) HF Band Power -----#
                check_HF = True
                HF_JSON_data = None
                if check_HF == True:
                    HF_JSON_data = {}
                    # Set HF band to process
                    hf_low, hf_high = 50, 120
                    
                    hf_metrics = hf_band_power_metrics(
                        x, fs,
                        hf_band=(hf_low, hf_high),
                        ref_band=(20, 150),
                        remove_aperiodic=True,
                        fit_range=(2, 200)
                    )
                    HF_JSON_data["HF Power"] = hf_metrics
                    log_lines.append(
                        "HF Power: "
                        f"band={hf_metrics.get('hf_band')}, ref={hf_metrics.get('ref_band')}, "
                        f"rel={hf_metrics.get('hf_rel_power', np.nan):.4f}, "
                        f"pct_in_ref={hf_metrics.get('hf_percentile_in_ref', np.nan):.1f}%"
                    )   

                #-------------------------------------------------------------------------------------------------#

                # Skip entire PAC calculation if only want additional diagnostic info on data
                if skip_PAC == False:
                    ##### 3) Trim Sample, Epoch Data, Detrend #####
    
                    # Determine end of segment to process, in seconds
                    start_sec = 0
                    target_sec = data_length * 60.0 if data_length != -1 else raw.times[-1]
                    target_min = int(target_sec / 60)
                    tmax = min(target_sec, raw.times[-1])
                    # Snap to an exact number of samples:
                    target_n = int(np.floor(min(tmax, raw.times[-1]) * downsampled_sampling_rate))
                    tmax_exact = (target_n - 1) / downsampled_sampling_rate if target_n > 0 else 0
                    raw.crop(tmin=start_sec, tmax=tmax_exact, include_tmax=True, verbose=False)
                    log_lines += [
                        f"Raw fs before resample: {sampling_rate} Hz",
                        f"Raw fs after resample:  {downsampled_sampling_rate} Hz",
                        f"Total samples: {raw.n_times}, total_sec: {raw.times[-1]}",
                    ]
    
                    # Skip file if shorter than allowed minimum
                    if target_min < minimum_length:
                        log_lines.append(f"{file} < {minimum_length} min. Skipping.")
                        continue
    
                    reject_criteria = None if rejection_threshold == 0.0 else dict(eeg=rejection_threshold * 1e-3)  # scaled to mV
    
                    events = mne.make_fixed_length_events(raw, duration=epoch_len)
                    epochs = mne.Epochs(raw, 
                                        events, 
                                        tmin=0, 
                                        tmax=epoch_len, 
                                        reject=reject_criteria,  
                                        baseline=None, 
                                        preload=True)
                    log_lines.append(f"Data Epoched to a length of {epoch_len} seconds.")
    
                    if len(epochs) == 0:
                        log_lines.append(f"All epochs rejected. Channel skipped as bad channel.")
                        continue 
                    
                    # Detrend data
                    if detrend_epochs:
                        data = epochs.get_data()
                        order = 1
                        data_detrended = mne.filter.detrend(data, order=order)
                        log_lines.append(f"Epochs detrended with an order of {order}.")
                        cleaned_epochs = data_detrended.squeeze(axis=1)
                    else:
                        cleaned_epochs = epochs.get_data().squeeze(axis=1)
    
                    ##### 4) Extract phase and amplitdue components of epochs, compute PAC map, permute surrogate data #####
                    
                    phases, msg, e = capture_stdout(p.filter, downsampled_sampling_rate, cleaned_epochs,
                                                    ftype='phase', n_jobs=parallel_processes)
                    log_lines.append(msg)
                    
                    amplitudes, msg, e = capture_stdout(p.filter, downsampled_sampling_rate, cleaned_epochs,
                                                        ftype='amplitude', n_jobs=parallel_processes)
                    log_lines.append(msg)
            
                    pac_map, msg, e = capture_stdout(p.fit, phases, amplitudes, n_perm=surrogate_permutations,
                                                     n_jobs=parallel_processes, random_state=seed, verbose=False)
                    log_lines.append(msg)
                    
                    del phases, amplitudes

                    ##### 5) Generate comodulogram, export plot and algorithm info to txt, and session info to JSON #####
                    
                    save_pac_outputs(Analysis_class, output_directory, rat, channel_name, day_num, full_day_desc, start_sec, target_sec, 
                                     timestamp, p, log_lines, pac_map, json_path, probe, day_session, SNR_JSON_data, HF_JSON_data, skip_PAC)
                else:
                    start_sec = None
                    target_sec = None
                    p = None
                    pac_map = None
                    save_pac_outputs(Analysis_class, output_directory, rat, channel_name, day_num, full_day_desc, start_sec, target_sec, 
                                     timestamp, p, log_lines, pac_map, json_path, probe, day_session, SNR_JSON_data, HF_JSON_data, skip_PAC)
            
    except Exception as e:
            logging.error(f"Error in compute_PAC_of_channels: {e}", exc_info=True)
            show_popup(error_string)

In [53]:
def compute_PAC_of_rats(Analysis_class):
    """
    This function processes PAC on channels individually. It is the original way of
    performing the computation but by design cannot use the autofilter library which
    necessitates processing all channels of a rat at once. It is kept to retain the
    felxiblity that comes with choosing channels to process but is not able to remove
    poor epochs from data.
    """
    try:
        # Retrieve PAC settings
        (pha_freqs, amp_freqs, idpac, dcomplex, cycles, width, n_bins,
         sampling_rate, epoch_len, MIN_CHUNK_MIN, CHUNK_MIN, CHUNK_LIMIT,
         parallel_processes, surrogate_permutations, notch_freqs,
         high_pass_filter, detrend_epochs, apply_autofilter
        )= Analysis_class.get_all_PAC_parameters()

        # Initialize TensorPAC object
        p = Pac(idpac=idpac, f_pha=pha_freqs, f_amp=amp_freqs,
                dcomplex=dcomplex, cycle=cycles, width=width, n_bins=n_bins)

        # Prepare IO paths
        output_dir = r"C:\Users\holot\Desktop\Summer Research\Research GUI V2\PAC Comodulogram"
        folder_path = Analysis_class.get_folder_path
        nwb_list = Analysis_class.get_nwb_list()


        for file in nwb_list:
            file_path = os.path.join(folder_path, file)

            # Nested dict with file, rat, channel, and column info
            selected_rats_info = Analysis_class.get_selected_rats()
            rat_keys = selected_rats_info[file].keys()
            for rat in rat_keys:
                data_columns = selected_rats_info[file][rat]["data_columns"]
                channels = selected_rats_info[file][rat]["channels"]
                channel_names = Analysis_class.get_channel_names_from_rat(rat)

                # Sort columns for HDF5 to be able to grab from matrix
                # Match sorting with channels to retain channel/column/name pairing
                # a) zip → (column, channel, name)
                linked_col_ch_nm = list(zip(data_columns, channels, channel_names))
                
                # b) sort by column index (first element of each tuple)
                linked_col_ch_nm.sort(key=lambda x: x[0])
                
                # c) unzip into synced, sorted lists
                data_columns, channels, channel_names = map(list, zip(*linked_col_ch_nm))

                # Grab raw data (n_ch, n_samples)
                data = get_raw_matrix(file_path, data_columns)  
                
                # Build MNE object
                
                #channel_names = []
                #for channel_num in sorted_channels:
                #    channel_names.append('CSC' + str(channel_num))
                
                info = mne.create_info(ch_names=channel_names, sfreq=sampling_rate, ch_types="eeg")
                info.set_montage(rat_probe_montage(channel_names))  # coords for AutoReject
                raw  = mne.io.RawArray(data, info)
            
                # Apply channel-wide filters
                if notch_freqs:
                    raw.notch_filter(freqs=notch_freqs, method="fir")
                if high_pass_filter:
                    raw.filter(l_freq=high_pass_filter, h_freq=None)
            
                # Epoch data
                events = mne.make_fixed_length_events(raw, duration=epoch_len)
                epochs = mne.Epochs(raw, events, tmin=0, tmax=epoch_len,
                                    baseline=None, preload=True)
            
                # Apply Autoreject filter
                ar = AutoReject(n_interpolate=[1, 2], n_jobs=8)
                epochs, reject_log = ar.fit_transform(epochs, return_log=True)
                print(f"Rejected {sum(reject_log.bad_epochs)} bad epochs")
                reject_log.plot('horizontal')

                # ---------- PAC channel-by-channel ----------
                ep_array = epochs.get_data()  # (n_epochs, n_ch, n_times)
                for i, channel_name in enumerate(channel_names):
                    cleaned = ep_array[:, i, :]               # shape (n_epochs, n_times)
                    phases  = p.filter(sampling_rate, cleaned, ftype="phase",     n_jobs=parallel_processes)
                    amps    = p.filter(sampling_rate, cleaned, ftype="amplitude", n_jobs=parallel_processes)
                    pac_map = p.fit(phases, amps, n_perm=surrogate_permutations,
                                    n_jobs=parallel_processes)
                    
                    del phases, amps
            
                    pvals = p.infer_pvalues(p=0.05)
                    comod = pac_map.mean(-1)
                
                    # Plot Comodulogram
                    fig, ax = plt.subplots(figsize=(6,5))
                    p.comodulogram(comod,
                                   title=f"{file} | {channel_name}",
                                   fz_labels=10, cmap='viridis', colorbar=True)
                
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    #start_min = int(start_s / 60)
                    #end_min   = int(end_s   / 60)
                    #base = f"file_{file}_ch_{channel_name}_{start_min}-{end_min}min_{timestamp}"
                    base = f"file_{file}_ch_{channel_name}_{timestamp}"
                    fig.savefig(os.path.join(output_dir, base + ".png"), dpi=300)
                    plt.close(fig)
                
                    idx    = np.unravel_index(np.argmax(comod, axis=None), comod.shape)
                    pha_f  = p.f_pha[idx[0]]
                    amp_f  = p.f_amp[idx[1]]
                    val    = comod[idx]
                    peak_p = pvals[idx]
                
                    # write summary without formatting errors
                    with open(os.path.join(output_dir, base + ".txt"), "w") as fh:
                        fh.write(f"channel:         {channel_name}\n")
                        fh.write(f"phase_freq (Hz): {pha_f}\n")
                        fh.write(f"amp_freq  (Hz):  {amp_f}\n")
                        fh.write(f"max_PAC_value:   {val}\n")
                        fh.write(f"p_value_peak:    {peak_p}\n")
                        
    except Exception as e:
            logging.error(f"Error in compute_PAC_of_rats: {e}", exc_info=True)
            show_popup(error_string)

In [54]:
def rat_probe_montage(ch_names):
    """
    Sets very rough spatial locations for autoreject to compare
    neighboring channels.
    """
    try:
        coords = {}
        for idx, name in enumerate(ch_names):
            # simplistic left/right & anterior/posterior mapping
            side = -0.01 if "L" in name else 0.01
            ant  =  0.01 if "PFC" in name else -0.01 if "HPC" in name else 0.0
            vent = -0.01 if "AMG" or "D" in name else 0.01 if "V" in name else 0.0
            coords[name] = np.array([side, ant, vent])
        return mne.channels.make_dig_montage(coords)
    except Exception as e:
            logging.error(f"Error in rat_probe_montage: {e}", exc_info=True)
            show_popup(error_string)

In [55]:
def get_raw_matrix(file_path, col_indices, start_sec=0, end_sec=-1):
    """
    Read multiple columns from an NWB ElectricalSeries in one pass.
    """
    try:
        with NWBHDF5IO(file_path, "r") as io:
            nwb = io.read()
            elec_series = nwb.acquisition["ElectricalSeries"]
            sampling_rate = elec_series.rate  # in Hz
            conversion = elec_series.conversion  # NCS scaling factor (bits to volts)
    
            # slice rows once, all cols at once
            start_idx = int(start_sec * sampling_rate)
            stop_idx  = int(elec_series.data.shape[0] if end_sec == -1 else end_sec * sampling_rate)
    
            # fancy-index the second axis → shape (n_samples, n_sel)
            arr = elec_series.data[start_idx:stop_idx, col_indices]
            data = np.asarray(arr).T * conversion       # -> (n_sel, n_samples)
    
        return data
    except Exception as e:
            logging.error(f"Error in get_raw_matrix: {e}", exc_info=True)
            show_popup(error_string)

In [56]:
def compute_PSD(Analysis_class):
    """
    Compute and save power spectral density (PSD) plots for raw and filtered data
    on a single channel. Includes optional notch filter, high-pass filter, and
    1/f correction via log-log linear detrending.
    """
    try:
        # Load PSD parameters
        (notch_freqs, 
         high_pass_filter, 
         correct_1overf, 
         sampling_rate, 
         output_directory,
         fmin,
         fmax,
         voltage_scale, 
         PSD_FFT_resolution, 
         PSD_plot_raw, 
         PSD_plot_filtered, 
         PSD_plot_grouping_method) = Analysis_class.get_all_PSD_parameters()

        # Load output parameters
        (output_folder_name, 
         export_PNG, 
         export_PDF, 
         export_SVG, 
         export_EPS, 
         image_height, 
         image_width, 
         image_DPI,
         color_palette,
         y_custom_axis,
         yaxis_top,
         yaxis_bottom,
         alpha,
         save_data) = Analysis_class.get_export_parameters()
    
        folder_path = Analysis_class.get_folder_path
        nwb_list    = Analysis_class.get_nwb_list()
        channels    = Analysis_class.get_selected_channels()
        fig_dict = {}

        parallel_processes = Analysis_class.get_parallel_processes()

        # Band definitions 
        BANDS_HZ = {
            "Delta"      : (1,   4),
            "Theta"      : (4,  12),
            "Beta"       : (12, 30),
            "Gamma"      : (30, 80),
            "HighGamma"  : (80, 200),
        }

        # ------------------------------------------------------------------
        def summarise_bands(freqs_hz: np.ndarray, psd_db: np.ndarray):
            """
            Return {band: {'peak_db': ..., 'avg_db': ...}, ...}
            using inclusive lower-edge, exclusive upper-edge bins.
            """
            band_summary = {}
            for band, (low, high) in BANDS_HZ.items():
                sel = (freqs_hz >= low) & (freqs_hz < high)
                if not sel.any():          # band not covered by current f-range
                    band_summary[band] = {"peak_db": None, "avg_db": None}
                    continue
        
                band_vals = psd_db[sel]
                band_summary[band] = {
                    "peak_db": float(np.nanmax(band_vals)),
                    "avg_db" : float(np.nanmean(band_vals)),
                    # optional: the frequency at the peak
                    # "peak_Hz": float(freqs_hz[sel][np.nanargmax(band_vals)]),
                }
            return band_summary
        # ------------------------------------------------------------------

        # Create probe color map
        PROBES = ['RAMG', 'RPFC', 'RVHPC', 'RDHPC', 'LHPCSCREW', 'RHPCSCREW', 'LPFCSCREW', 'RPFCSCREW']
        DAYS = ['Day0', 'Day1', 'Day3', 'Day7', 'Day14', 'Day21', 'Day0control', 'PreDay0']
        PROBE_COLOR_MAP = {probe: color_palette[i] for i, probe in enumerate(PROBES)}
        DAY_COLOR_MAP = {probe: color_palette[i] for i, probe in enumerate(DAYS)}
        
        # Create Output Folder
        timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
        if "[Timestamp]" in output_folder_name:
            output_folder_name = output_folder_name.split("[Timestamp]")[0]
            output_folder_name = f"{output_folder_name}{timestamp}"
    
        output_folder_name = output_folder_name.replace("PAC","PSD")
        output_directory = os.path.join(output_directory, output_folder_name)
        os.makedirs(output_directory, exist_ok=True)
        json_path = os.path.join(output_directory, "PSD Session Summary.json")
        json_data_path = os.path.join(output_directory, "PSD Session Data.json")

        # Add Session-level information to JSON output file
        if os.path.exists(json_path):
            with open(json_path, "r") as f:
                psd_metadata = json.load(f)
        else:
            psd_metadata = {}

        metadata = "Metadata"
        psd_metadata[metadata] = {
            "timestamp" : timestamp,
            "notch_freqs" : notch_freqs, 
            "high_pass_filter" : high_pass_filter, 
            "correct_1overf" : correct_1overf, 
            "sampling_rate" : sampling_rate, 
            "output_directory" : output_directory,
            "fmin" : fmin,
            "fmax" : fmax,
            "voltage_scale" : voltage_scale, 
            "PSD_FFT_resolution" : PSD_FFT_resolution, 
            "PSD_plot_raw" : PSD_plot_raw, 
            "PSD_plot_filtered" : PSD_plot_filtered, 
            "PSD_plot_grouping_method" : PSD_plot_grouping_method,
            "output_folder_name" : output_folder_name, 
            "export_PNG" : export_PNG, 
            "export_PDF" : export_PDF, 
            "export_SVG" : export_SVG, 
            "export_EPS" : export_EPS, 
            "image_height" : image_height, 
            "image_width" : image_width, 
            "image_DPI" : image_DPI,
            "color_palette" : color_palette,
            "y_custom_axis" : y_custom_axis,
            "yaxis_top" : yaxis_top,
            "yaxis_bottom" : yaxis_bottom,
            "alpha" : alpha,
            "folder_path" : folder_path,
            "nwb_list" : nwb_list,
            "channels" : channels,   
        }

        # Rat-specific data in metadata JSON file
        psd_metadata["Data"] = {}
        
        # Write back to JSON
        with open(json_path, "w") as f:
            json.dump(psd_metadata, f, indent=4)

        # Separate JSON file for large data files
        if save_data:
            # Add Session-level information to JSON output file
            if os.path.exists(json_data_path):
                with open(json_data_path, "r") as f:
                    psd_data = json.load(f)
            else:
                psd_data = {}
            
            # Dict to store computed results
            psd_data["Data"] = {}
            
            # Write back to JSON
            with open(json_data_path, "w") as f:
                json.dump(psd_data, f, indent=4)

        # Loop through every selected NWB file
        for file in nwb_list:
            file_path = os.path.join(folder_path, file)
            selected  = channels.get(file, {})
            if not selected:
                continue

            # Loop through every channel selected in current NWB file
            for channel_name, data_index in selected.items():
                # ---------------------------------------------------------------------------------------
                ########## Setup ##########
                # ---------------------------------------------------------------------------------------
                
                # Obtain current rat and probe from channel
                channel_num = int(channel_name.split('CSC')[1])
                current_rat, current_probe = Analysis.get_rat_and_probe_from_channel(file, channel_name)

                # Determine day from file name
                file_day = "Day N/A"
                split_name = file.split('_')
                for name_piece in split_name:
                    if "Day" in name_piece:
                        file_day = name_piece
                        file_day_num = ''.join(filter(str.isdigit, name_piece))

                # Grab the session (i.e. Day0_0001.nwb -> 0001)
                if split_name[-1]:
                    file_day_session = split_name[-1].split('.nwb')[0]

                # Create dictionary entiries of the user wishes to group channels on plots
                if PSD_plot_grouping_method == "Group by Rat":

                    # Check if this rat already has a dict
                    if current_rat not in fig_dict:
                        fig_dict[current_rat] = {}

                    # Check if this rat already has this day
                    if file_day not in fig_dict[current_rat]:
                        fig_dict[current_rat][file_day] = {}

                    # Check if this rat already has this session
                    if file_day_session not in fig_dict[current_rat][file_day]:
                        fig, ax = plt.subplots(figsize=(image_width, image_height))
                        fig_dict[current_rat][file_day][file_day_session] = (fig, ax)
                        
                elif PSD_plot_grouping_method == "Group by Channel":
                    # Check if this channel already has a figure
                    if current_probe not in fig_dict:
                        fig, ax = plt.subplots(figsize=(image_width, image_height))
                        fig_dict[current_probe] = (fig, ax)
                        
                # ---------------------------------------------------------------------------------------
                ########## Filtering ##########
                # ---------------------------------------------------------------------------------------
                        
                # Load raw data and timestamps
                times, data = get_raw_data(file_path, data_index, start=0, end=-1)
                data = data * voltage_scale
                raw_data = data.copy()
                fs = 1.0 / np.median(np.diff(times))
                #del times

                # Wrap in MNE object
                info = mne.create_info([channel_name], fs, ch_types="eeg")
                raw = mne.io.RawArray(data[np.newaxis, :], info, verbose=False)

                # Apply notch filters
                if notch_freqs:
                    raw.notch_filter(notch_freqs, picks=channel_name, method='fir', verbose=False)
        
                # Apply high-pass filter
                if high_pass_filter and high_pass_filter > 0.0:
                    raw.filter(l_freq=high_pass_filter, h_freq=200, picks=channel_name, verbose=False)

                # Pull data out of MNE object
                data_filt = raw.get_data(picks=channel_name)[0]
            
                # Determine number of FFT values to compute. 
                # Higher resolution = more plotting points, MAX nperseg = # data samples in data / sampling rate
                nperseg = int(fs * PSD_FFT_resolution)

                # Compute PSD using MNE welch method
                psd_raw, freqs_raw = psd_array_welch(raw_data[np.newaxis, :], 
                                                     sfreq=fs, 
                                                     n_fft=nperseg, 
                                                     fmin=fmin, 
                                                     fmax=fmax, 
                                                     n_jobs=parallel_processes, 
                                                     verbose=False)
                
                psd_filt, freqs_filt = psd_array_welch(data_filt[np.newaxis, :], 
                                                       sfreq=fs, 
                                                       n_fft=nperseg, 
                                                       fmin=fmin, 
                                                       fmax=fmax, 
                                                       n_jobs=parallel_processes, 
                                                       verbose=False)
            
                # Linear PSD value (unused, dB conversion interfaces better with matplotlib)
                #psd_raw_linear = psd_raw[0]
                #psd_filt_linear = psd_filt[0]
                
                # Convert to dB
                psd_raw_db = 10 * np.log10(np.where(psd_raw[0] > 0, psd_raw[0], np.nan))
                psd_filt_db = 10 * np.log10(np.where(psd_filt[0] > 0, psd_filt[0], np.nan))

                # Apply 1/f correction [Applies to FILTERED data]
                if correct_1overf:
                    log_freqs = np.log10(freqs_filt[1:])
                    log_psd = np.log10(np.where(psd_filt[0][1:] > 0, psd_filt[0][1:], np.nan))
                    # Remove NaNs before fitting
                    valid_idx = np.isfinite(log_freqs) & np.isfinite(log_psd)
                    slope, intercept, *_ = linregress(log_freqs, log_psd)
                    trend = slope * log_freqs + intercept
                    psd_filt_corrected = log_psd - trend
                    freqs_final = freqs_filt[1:]

                # ---------------------------------------------------------------------------------------
                ########## Plotting ##########
                # ---------------------------------------------------------------------------------------

                # Access current plot based on grouping method
                if PSD_plot_grouping_method == "Group by Rat":
                    fig, ax = fig_dict[current_rat][file_day][file_day_session]
                elif PSD_plot_grouping_method == "Group by Channel":
                    fig, ax = fig_dict[current_probe]
                elif PSD_plot_grouping_method == "No Grouping":
                    fig, ax = plt.subplots(figsize=(image_width, image_height))

                # Choose Label based on user selection
                if PSD_plot_raw and PSD_plot_filtered:
                    raw_label = f" Raw"
                    filter_label = f" Filt"
                    oneoverf_label = f" (1/f)"
                else:
                    raw_label = f""
                    filter_label = f""
                    oneoverf_label = f""

                # Set correct color and label based on grouping method
                if PSD_plot_grouping_method == "Group by Channel":
                    selection_criteria = f"R{current_rat} " + file_day.replace('Day', 'D')
                    # Select color map based on day/probe, black as backup
                    color = DAY_COLOR_MAP.get(file_day, 'black')
                else:
                    selection_criteria = current_probe
                    # Select color map based on day/probe, black as backup
                    color = PROBE_COLOR_MAP.get(current_probe, 'black')

                # Plot data according to user selection
                if PSD_plot_raw:
                    ax.plot(freqs_raw, psd_raw_db, label=f"{selection_criteria}{raw_label}", color=color, alpha=alpha)

                if PSD_plot_filtered:
                    ax.plot(freqs_filt, psd_filt_db, label=f"{selection_criteria}{filter_label}", color=color, alpha=alpha)

                if correct_1overf:
                    ax.plot(freqs_final, psd_filt_corrected, label=f"{selection_criteria}{oneoverf_label}", color=color, alpha=alpha)

                # ---------------------------------------------------------------------------------------
                ########## Statistic Calculating ##########
                # ---------------------------------------------------------------------------------------

                band_stats_raw  = summarise_bands(freqs_raw,  psd_raw_db)
                band_stats_filt = summarise_bands(freqs_filt, psd_filt_db)
                
                # 1. Broadband power (dB mean over 1–200 Hz)
                bb_mask = (freqs_raw >= 1) & (freqs_raw < 200)
                broadband_raw  = float(np.nanmean(psd_raw_db[bb_mask]))
                broadband_filt = float(np.nanmean(psd_filt_db[bb_mask]))
                
                # 2. Theta-Delta Ratio
                tdr_raw  = band_stats_raw['Theta']['avg_db'] - band_stats_raw['Delta']['avg_db']   # dB diff = log-ratio
                tdr_filt = band_stats_filt['Theta']['avg_db'] - band_stats_filt['Delta']['avg_db']
                
                # 3. Aperiodic exponent & offset (fit on log10(power) vs log10(freq))
                fit_mask = (freqs_raw > 0) & (freqs_raw >= 2) & (freqs_raw < 40)
                
                x = np.log10(freqs_raw[fit_mask])
                y = psd_raw_db[fit_mask] / 10.0      # convert dB → log10(power)
                
                valid = np.isfinite(x) & np.isfinite(y)
                slope, offset_log10 = np.polyfit(x[valid], y[valid], 1)
                
                aperiodic_exp = float(-slope)        # unitless k
                aperiodic_off = float(10 * offset_log10)  # dB µV²/Hz

                biomarkers = {
                    "Broadband_dB": {"Raw": broadband_raw, "Filtered": broadband_filt},
                    "ThetaDelta_dB": {"Raw": tdr_raw, "Filtered": tdr_filt},
                    "Aperiodic": {"Exponent": aperiodic_exp, "Offset": aperiodic_off},
                }

                # ---------------------------------------------------------------------------------------
                ########## JSON Exporting ##########
                # ---------------------------------------------------------------------------------------

                # Access JSON file to store info
                if os.path.exists(json_path):
                    with open(json_path, "r") as f:
                        psd_metadata = json.load(f)
                else:
                    psd_metadata = {}

                # Access JSON file to store data
                if os.path.exists(json_data_path):
                    with open(json_data_path, "r") as f:
                        psd_data = json.load(f)
                else:
                    psd_data = {}

                # Save plot info to JSON file
                data_key = f"{current_rat}_{file_day}_{file_day_session}_{current_probe}"
                if data_key not in psd_metadata["Data"]:
                    # Save rat metadata
                    psd_metadata["Data"][data_key] = {}
                    psd_metadata["Data"][data_key]["Info"] = {
                        "Rat" : current_rat,
                        "Day" : file_day,
                        "Session" : file_day_session,
                        "Probe" : current_probe
                    }

                    # Save band-specific info
                    band_dict = {}
                    if PSD_plot_raw:
                        band_dict["Raw"] = band_stats_raw
                    if PSD_plot_filtered:
                        band_dict["Filtered"] = band_stats_filt
                    psd_metadata["Data"][data_key]["BandSummary"] = band_dict

                    # Save general biomarkers
                    psd_metadata["Data"][data_key]["BiomarkerSummary"] = biomarkers

                    # Write back to JSON
                    with open(json_path, "w") as f:
                        json.dump(psd_metadata, f, indent=4)

                
                # Save PSD data
                if save_data == True and data_key not in psd_data["Data"]:
                    psd_data["Data"][data_key] = {}
                    metadata_dict = {
                        "Rat" : current_rat,
                        "Day" : file_day,
                        "Session" : file_day_session,
                        "Probe" : current_probe
                    }

                    psd_dict = {}
                    if PSD_plot_raw:
                        psd_dict |= {"Raw_x": freqs_raw.tolist(), "Raw_y": psd_raw_db.tolist()}
                    if PSD_plot_filtered:
                        psd_dict |= {"Filt_x": freqs_filt.tolist(), "Filt_y": psd_filt_db.tolist()}
                    if correct_1overf:
                        psd_dict |= {"Norm_x": freqs_final.tolist(), "Norm_y": psd_filt_corrected.tolist()}
                    psd_data["Data"][data_key] = {"Info" : metadata_dict, "Data" : psd_dict}

                    # Write back to JSON
                    with open(json_data_path, "w") as f:
                        json.dump(psd_data, f, indent=4)

                # ---------------------------------------------------------------------------------------
                ########## Figure Exporting ##########
                # ---------------------------------------------------------------------------------------
                                    
                # Immediately configure and save figure if there is no grouping
                if PSD_plot_grouping_method == "No Grouping":
                    title = ax.set_title(f"PSD | Rat{current_rat} | {file_day} | {file_day_session} | {current_probe}")
                    title.set_fontproperties(termes_font_bold)
                    title.set_fontsize(24)
                    xl = ax.set_xlabel("Frequency (Hz)")
                    xl.set_fontproperties(termes_font_bold)
                    xl.set_fontsize(18)
                    yl = ax.set_ylabel("Power Spectral Density (dB µ$V^2$/Hz)")
                    yl.set_fontproperties(termes_font_bold)
                    yl.set_fontsize(18)
                    if y_custom_axis:
                        ax.set_ylim(top=yaxis_top, bottom=yaxis_bottom)
                    legend = ax.legend()
                    for text in legend.get_texts():
                        text.set_fontproperties(termes_font)
                        text.set_fontsize(14)
                    for label in ax.get_xticklabels() + ax.get_yticklabels():
                        label.set_fontproperties(termes_font)
                        label.set_fontsize(14)
                    ax.grid(True)
                
                    # Save figure
                    timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
                    base_name = f"{channel_name}_PSD_{timestamp}"
    
                    if export_PNG:
                        fig_path = os.path.join(output_directory, base_name + ".png")
                        fig.tight_layout()
                        fig.savefig(fig_path, dpi=image_DPI)
                        
                    if export_PDF:
                        fig_path = os.path.join(output_directory, base_name + ".pdf")
                        fig.tight_layout()
                        fig.savefig(fig_path)
            
                    if export_SVG:
                        fig_path = os.path.join(output_directory, base_name + ".svg")
                        fig.tight_layout()
                        fig.savefig(fig_path)
            
                    if export_EPS:
                        fig_path = os.path.join(output_directory, base_name + ".eps")
                        fig.tight_layout()
                        fig.savefig(fig_path)
            
                    plt.close(fig)
                
                #############################################################################################################

        # Process dictonaries outside for-loop, makes applying sorting method easier
        if PSD_plot_grouping_method == "Group by Rat":
            for rat_id, day_dict in fig_dict.items():
                for file_day, day_session_dict in day_dict.items():
                    for file_day_session, (fig, ax) in day_session_dict.items():
                        extension = ""
                        if file_day_session == "0000":
                            extension = " | Control Case"
                        elif file_day_session == "0001":
                            extension = " | DFP Injection"
                        elif file_day_session == "0002":
                            extension = " | MDZ Response"

                        title = ax.set_title(f"PSD | Rat{rat_id} | {file_day}{extension}")
                        title.set_fontproperties(termes_font_bold)
                        title.set_fontsize(24)
                        xl = ax.set_xlabel("Frequency (Hz)")
                        xl.set_fontproperties(termes_font_bold)
                        xl.set_fontsize(18)
                        yl = ax.set_ylabel("Power Spectral Density (dB µV²/Hz)")
                        yl.set_fontproperties(termes_font_bold)
                        yl.set_fontsize(18)
                        if y_custom_axis:
                            ax.set_ylim(top=yaxis_top, bottom=yaxis_bottom)
                        legend = ax.legend()
                        for text in legend.get_texts():
                            text.set_fontproperties(termes_font)
                            text.set_fontsize(14)
                        for label in ax.get_xticklabels() + ax.get_yticklabels():
                            label.set_fontproperties(termes_font)
                            label.set_fontsize(14)
                        ax.grid(True)
        
                        # Save figure
                        timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
                        base_name = f"{rat_id}_{file_day}_{file_day_session}_PSD_{timestamp}"
        
                        if export_PNG:
                            fig_path = os.path.join(output_directory, base_name + ".png")
                            fig.tight_layout()
                            fig.savefig(fig_path, dpi=image_DPI)
                            
                        if export_PDF:
                            fig_path = os.path.join(output_directory, base_name + ".pdf")
                            fig.tight_layout()
                            fig.savefig(fig_path)
                
                        if export_SVG:
                            fig_path = os.path.join(output_directory, base_name + ".svg")
                            fig.tight_layout()
                            fig.savefig(fig_path)
                
                        if export_EPS:
                            fig_path = os.path.join(output_directory, base_name + ".eps")
                            fig.tight_layout()
                            fig.savefig(fig_path)
                
                        plt.close(fig)

        if PSD_plot_grouping_method == "Group by Channel":                        
            for current_probe, (fig, ax) in fig_dict.items():   
                title = ax.set_title(f"PSD | {current_probe}")
                title.set_fontproperties(termes_font_bold)
                title.set_fontsize(24)
                xl = ax.set_xlabel("Frequency (Hz)")
                xl.set_fontproperties(termes_font_bold)
                xl.set_fontsize(18)
                yl = ax.set_ylabel("1/f Corrected Power Spectral Density (dB µV²)").set_fontproperties(termes_font_bold)
                yl.set_fontproperties(termes_font_bold)
                yl.set_fontsize(18)
                if y_custom_axis:
                    ax.set_ylim(top=yaxis_top, bottom=yaxis_bottom)
                legend = ax.legend() 
                for text in legend.get_texts():
                    text.set_fontproperties(termes_font)
                    text.set_fontsize(14)
                for label in ax.get_xticklabels() + ax.get_yticklabels():
                    label.set_fontproperties(termes_font)
                    label.set_fontsize(14)
                ax.grid(True)

                # Save figure
                timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
                base_name = f"{current_probe}_PSD_{timestamp}"

                if export_PNG:
                    fig_path = os.path.join(output_directory, base_name + ".png")
                    fig.tight_layout()
                    fig.savefig(fig_path, dpi=image_DPI)
                    
                if export_PDF:
                    fig_path = os.path.join(output_directory, base_name + ".pdf")
                    fig.tight_layout()
                    fig.savefig(fig_path)
        
                if export_SVG:
                    fig_path = os.path.join(output_directory, base_name + ".svg")
                    fig.tight_layout()
                    fig.savefig(fig_path)
        
                if export_EPS:
                    fig_path = os.path.join(output_directory, base_name + ".eps")
                    fig.tight_layout()
                    fig.savefig(fig_path)
        
                plt.close(fig)

    except Exception as e:
        logging.error(f"Error in compute_PSD: {e}", exc_info=True)
        show_popup(error_string)

In [57]:
def compute_PSD_depreciated(Analysis_class):
    """
    Compute and save power spectral density (PSD) plots for raw and filtered data
    for each selected NWB file and channel. The raw PSD will be plotted alongside
    a series of notch filters (60, 120, 180, 240 Hz) and a 2 Hz high-pass filtered PSD
    to illustrate noise removal.
    """
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from datetime import datetime
    from scipy.signal import welch, iirnotch, butter, filtfilt

    try:
        # Prepare I/O paths
        output_dir = r"C:\Users\holot\Desktop\Summer Research\Research GUI V2\PAC Comodulogram"
        os.makedirs(output_dir, exist_ok=True)
        folder_path = Analysis_class.get_folder_path
        nwb_list    = Analysis_class.get_nwb_list()
        channels    = Analysis_class.get_selected_channels()

        notch_freqs = [60.0, 120.0, 180.0, 240.0]

        for file in nwb_list:
            file_path = os.path.join(folder_path, file)
            selected  = channels.get(file, {})
            if not selected:
                print(f"No channels selected for {file}. Skipping PSD computation.")
                continue

            for channel_name, data_index in selected.items():
                print(f"Computing PSD for {file} | channel {channel_name}")

                # Load raw data and timestamps
                times, data = get_raw_data(file_path, data_index, start=0, end=-1)
                fs = 1.0 / np.median(np.diff(times))
                del times

                # Compute raw PSD
                nperseg = min(len(data), int(4 * fs))
                f_raw, Pxx_raw = welch(data, fs=fs, nperseg=nperseg)

                # Apply cascaded notch filters at specified harmonics
                data_notch = data.copy()
                for f0 in notch_freqs:
                    b_notch, a_notch = iirnotch(f0, Q=30.0, fs=fs)
                    data_notch = filtfilt(b_notch, a_notch, data_notch)

                # Apply 2 Hz high-pass filter to remove slow drift
                b_hp, a_hp = butter(4, 2.0 / (fs / 2.0), btype='high')
                data_filt = filtfilt(b_hp, a_hp, data_notch)

                # Compute filtered PSD
                f_filt, Pxx_filt = welch(data_filt, fs=fs, nperseg=nperseg)

                # Plot and save both PSDs
                plt.figure(figsize=(8, 4))
                plt.semilogy(f_raw,  Pxx_raw,  label='Raw')
                plt.semilogy(f_filt, Pxx_filt, label='Notches+HPF')
                plt.xlabel('Frequency (Hz)')
                plt.ylabel('Power Spectral Density')
                plt.title(f"PSD: {file} | Channel {channel_name}")
                plt.legend()

                # Build filename and save
                timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
                base      = f"{file}_ch_{channel_name}_PSD_{timestamp}"
                png_path  = os.path.join(output_dir, base + ".png")
                plt.tight_layout()
                plt.savefig(png_path, dpi=300)
                plt.close()

    except Exception as e:
        import logging
        logging.error(f"Error in compute_PSD: {e}", exc_info=True)
        show_popup(error_string)

In [58]:
def create_analysis_window(parent, Analysis_class):
    """
    Create an analysis window with all options for file selection, preprocessing, and results.
    """
    try:
        # Set a default output directory for the analysis
        base_dir = get_default_output_directory()
        output_dir = os.path.join(base_dir, "PAC Output")
        os.makedirs(output_dir, exist_ok=True)
        Analysis_class.set_output_directory(output_dir)
        
        with dpg.child_window(tag="ana_buttons_window",
                              border=True, 
                              autosize_x=True,
                              height=35,
                              parent="ana_child_window"):
            with dpg.group(horizontal=True, parent="ana_buttons", tag="ana_buttons_group"):
                dpg.add_text("PAC Options:")
                                
        nwb_files = Analysis_class.get_file_list
        dpg.add_button(label="Compute PAC",
                       parent="ana_buttons_group",
                       callback=PAC_callback,
                       user_data=Analysis_class)
        dpg.add_button(label="Compute Synthetic PAC",
                       parent="ana_buttons_group",
                       callback=lambda:compute_PAC_of_channels_SYNTHETIC(Analysis_class))
        dpg.add_button(label="Crunch PAC JSON Data",
                       parent="ana_buttons_group",
                       callback=crunch_PAC_JSON_data_callback,
                       user_data=Analysis_class)

        dpg.add_text("  PSD Options:", parent="ana_buttons_group")
        dpg.add_button(label="Compute PSD",
                       parent="ana_buttons_group",
                       callback=lambda:compute_PSD(Analysis_class))
        dpg.add_button(label="Crunch PSD JSON Data",
                       parent="ana_buttons_group",
                       callback=crunch_PSD_JSON_data_callback,
                       user_data=Analysis_class)

        dpg.add_text("  Extra:", parent="ana_buttons_group")
        dpg.add_button(label="Open Output Location",
                       parent="ana_buttons_group",
                       callback=open_output_location_callback,
                       user_data=Analysis_class)
        
        with dpg.child_window(parent=parent,
                              tag=f"ana{Analysis_class.ID}_window",
                              autosize_y=True,
                              autosize_x=True,
                              border=True,
                              horizontal_scrollbar=True):
            # NWB File Subwindow
            with dpg.group(tag=f"ana{Analysis_class.ID}_window_group",
                           parent=f"ana{Analysis_class.ID}_window",
                           horizontal=True):
                # NWB File Selected Column
                with dpg.child_window(tag=f"ana{Analysis_class.ID}_NWB_window", 
                                      border=True, 
                                      width=275, 
                                      autosize_y=True, 
                                      horizontal_scrollbar=True, 
                                      no_scrollbar=False, 
                                      menubar=True, 
                                      parent=f"ana{Analysis_class.ID}_window_group"):
                    with dpg.menu_bar():
                            dpg.add_menu(label="NWB Files")

                    # Build rat-to-files dictionary
                    rat_to_files = {i: [] for i in range(1, 65)}  # Rats 1 through 64
            
                    for file in nwb_files:
                        rats_in_file = NWBFolder.get_rats_in_file(file)
                        for rat in rats_in_file:
                            if 1 <= rat <= 64:
                                rat_to_files[rat].append(file)
            
                    # Build UI elements grouped by rat
                    for rat_num in range(1, 65):
                        files = rat_to_files[rat_num]
                        if not files:
                            continue  # Skip rats with no files
            
                        with dpg.collapsing_header(label=f"Rat {rat_num}", 
                                                   tag=f"ana{Analysis_class.ID}_{rat_num}_header", 
                                                   default_open=False, 
                                                   parent=f"ana{Analysis_class.ID}_NWB_window"):
                            for file in sorted(files):
                                checkbox_tag = f"{parent}_{rat_num}_{file}"
                                dpg.add_checkbox(
                                    label=file,
                                    tag=checkbox_tag,
                                    callback=update_ana_nwb_list_callback,
                                    user_data=(Analysis_class, file, f"ana{Analysis_class.ID}_channel_window")
                                )

                    """
                    for file in nwb_files:
                        dpg.add_checkbox(label=file,
                                      tag=f"ana{Analysis_class.ID}_window_{file}",
                                      parent=f"ana{Analysis_class.ID}_NWB_window",
                                      callback=update_ana_nwb_list_callback,
                                      user_data=(Analysis_class, file, f"ana{Analysis_class.ID}_channel_window"))
                    """

                # Channel Selection Column
                with dpg.child_window(tag=f"ana{Analysis_class.ID}_channel_window", 
                                      border=True, 
                                      width=275, 
                                      autosize_y=True,
                                      horizontal_scrollbar=True, 
                                      no_scrollbar=False, 
                                      menubar=True, 
                                      parent=f"ana{Analysis_class.ID}_window_group"):
                    with dpg.menu_bar():
                            dpg.add_menu(label="Channels")              

                # PAC Options Column
                with dpg.child_window(tag=f"ana{Analysis_class.ID}_parameter_window", 
                                      border=True, 
                                      width=335, 
                                      autosize_y=True,
                                      horizontal_scrollbar=True, 
                                      no_scrollbar=False, 
                                      menubar=True, 
                                      parent=f"ana{Analysis_class.ID}_window_group"):
                    with dpg.menu_bar():
                                      dpg.add_menu(label="Algorithm Parameters")
                    with dpg.tab_bar(tag=f"ana{Analysis_class.ID}_parameter_tabs", 
                                      parent=f"ana{Analysis_class.ID}_parameter_window"):
                        with dpg.tab(label="PAC", tag="analysis_PAC_tab", 
                                      parent=f"ana{Analysis_class.ID}_parameter_tabs"):
        
                            dpg.add_text("Low Frequency Parameters")
        
                            dpg.add_combo(label="Resolution Override",
                                              tag="PAC_options_low_override",
                                              items=['lres', 'mres',
                                                     'hres', 'No Bins', 'Custom'],
                                              width=150,
                                              default_value=Analysis_class.get_lowfreq_override(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Lowpass",
                                              tag="PAC_options_low_lowpass",
                                              width=150,
                                              default_value=Analysis_class.get_lowfreq_lowpass(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Highpass",
                                              tag="PAC_options_low_highpass",
                                              width=150,
                                              default_value=Analysis_class.get_lowfreq_highpass(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Width",
                                              tag="PAC_options_low_width",
                                              width=150,
                                              default_value=Analysis_class.get_lowfreq_width(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Step",
                                              tag="PAC_options_low_step",
                                              width=150,
                                              default_value=Analysis_class.get_lowfreq_step(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_separator()
                            dpg.add_text("High Frequency Parameters")
        
                            dpg.add_combo(label="Resolution Override",
                                              tag="PAC_options_high_override",
                                              items=['lres', 'mres',
                                                     'hres','No Bins', 'Custom'],
                                              width=150,
                                              default_value=Analysis_class.get_highfreq_override(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Lowpass",
                                              tag="PAC_options_high_lowpass",
                                              width=150,
                                              default_value=Analysis_class.get_highfreq_lowpass(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Highpass",
                                              tag="PAC_options_high_highpass",
                                              width=150,
                                              default_value=Analysis_class.get_highfreq_highpass(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Width",
                                              tag="PAC_options_high_width",
                                              width=150,
                                              default_value=Analysis_class.get_highfreq_width(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Step",
                                              tag="PAC_options_high_step",
                                              width=150,
                                              default_value=Analysis_class.get_highfreq_step(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_separator()
                            dpg.add_text("Filtering")
        
                            dpg.add_input_float(label="High-Pass Filter",
                                              tag="PAC_options_high_pass_filter",
                                              width=150,
                                              default_value=Analysis_class.get_high_pass_filter(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_combo(label="Filter Notch Freq.",
                                              tag="PAC_options_filter_notch_frequencies",
                                              items=['None',
                                                     '[60]',
                                                     '[60, 120]',
                                                     '[60, 120, 180]',
                                                     '[60, 120, 180, 240]',
                                                     '[60, 120, 180, 240, 300]'],
                                              width=150,
                                              default_value=Analysis_class.get_filter_notch_frequencies(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_input_int(label="Rejection Threshold",
                                              tag="PAC_options_rejection_threshold",
                                              width=150,
                                              default_value=Analysis_class.get_rejection_threshold(),
                                              callback=lambda sender, app_data, user_data: user_data.set_rejection_threshold(app_data),
                                              user_data=Analysis_class)

                            dpg.add_text("^ 1 = 1mV, Recc. ~3")
        
                            dpg.add_checkbox(label="Detrend Epochs", 
                                              tag="PAC_options_detrend_epochs",
                                              default_value=Analysis_class.get_detrend_epochs(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_checkbox(label="Apply Autofilter", 
                                              tag="PAC_options_apply_autofilter",
                                              default_value=Analysis_class.get_apply_autofilter(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
                            dpg.add_text("(^ Must Select a Rat!)")

                            dpg.add_separator()
                            dpg.add_text("Phase-Amplitdue Coupling Parameters")
                                
                            dpg.add_combo(label="PAC Methods",
                                              tag="PAC_options_PAC_methods",
                                              items=['Mean Vector Length', 'Modulation Index',
                                                     'Heights Ratio', 'ndPAC',
                                                     'Phase-Locking Value', 'Gaussian Copula PAC'],
                                              width=150,
                                              default_value=Analysis_class.get_PAC_method(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
                            
                            dpg.add_combo(label="Surrogate Method",
                                              tag="PAC_options_surrogate_method",
                                              items=['No Surrogates',
                                                     'Swap Phase/Ampl. Across Trials',
                                                     'Swap Amplitude Time Blocks','Time Lag'],
                                              width=150,
                                              default_value=Analysis_class.get_surrogate_method(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
                            
                            dpg.add_combo(label="Normalization Method",
                                              tag="PAC_options_normalization_method",
                                              items=['No Normalization','Subtract Mean of Surrogtes',
                                                     'Divide Mean of Surrogates',
                                                     'Sub+Div Mean of Surrogates','Z-score'],
                                              width=150,
                                              default_value=Analysis_class.get_normalization_method(),
                                              callback=lambda sender, app_data, user_data: user_data.set_normalization_method(app_data),
                                              user_data=Analysis_class)

                            dpg.add_separator()
                            dpg.add_text("Situational")

                            dpg.add_checkbox(label="Skip PAC", 
                                              tag="PAC_options_skip_PAC",
                                              default_value=Analysis_class.get_skip_PAC(),
                                              callback=lambda sender, app_data, user_data: user_data.set_skip_PAC(app_data),
                                              user_data=Analysis_class)

                            dpg.add_checkbox(label="Inject PAC", 
                                              tag="PAC_options_inject_PAC",
                                              default_value=Analysis_class.get_inject_PAC(),
                                              callback=lambda sender, app_data, user_data: user_data.set_inject_PAC(app_data),
                                              user_data=Analysis_class)
        
                            dpg.add_combo(label="Dcomplex Method",
                                              tag="PAC_options_dcomplex_method",
                                              items=['Wavelet', 'Hilbert'],
                                              width=150,
                                              default_value=Analysis_class.get_dcomplex_method(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
                            
                            dpg.add_input_int(label="Phase Cycles",
                                              tag="PAC_options_phase_cycles",
                                              width=150,
                                              default_value=Analysis_class.get_phase_cycles(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
            
                            dpg.add_input_int(label="Amplitude Cycles",
                                              tag="PAC_options_amplitude_cycles",
                                              width=150,
                                              default_value=6,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Morlets Wavelet Width",
                                              tag="PAC_options_morlet’s_wavelet_width",
                                              width=150,
                                              default_value=7,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="KLD or HRPAC Bins",
                                              tag="PAC_options_KLD_or_HRPAC_bins",
                                              width=150,
                                              default_value=18,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_separator()
                            dpg.add_text("Performance")
                            dpg.add_text("* -1 = Maximum Possible")

                            dpg.add_input_int(label="Set Seed (0=None)",
                                              tag="Performance_options_seed",
                                              width=150,
                                              default_value=Analysis_class.get_seed(),
                                              callback=lambda sender, app_data, user_data: user_data.set_seed(app_data),
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Sampling Rate",
                                              tag="Performance_options_sampling_rate",
                                              width=150,
                                              default_value=2000,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Downsample Rate",
                                              tag="Performance_options_downsample_rate",
                                              width=150,
                                              default_value=500,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Epoch Length (sec)",
                                              tag="Performance_options_epoch_length",
                                              width=150,
                                              default_value=5,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Data Length (min)",
                                              tag="Performance_options_data_length",
                                              width=150,
                                              default_value=10,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Minimum Length",
                                              tag="Performance_options_minimum_length",
                                              width=150,
                                              default_value=5,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Parallel Processes*",
                                              tag="Performance_options_parallel_processes",
                                              width=150,
                                              default_value=8,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
        
                            dpg.add_input_int(label="Surrogate Count",
                                              tag="Performance_options_surrogate_count",
                                              width=150,
                                              default_value=200,
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
    
    
                        with dpg.tab(label="PSD",
                                      tag="analysis_PSD_tab", 
                                      parent=f"ana{Analysis_class.ID}_parameter_tabs"):
                            
                            dpg.add_text("Parameters")

                            dpg.add_input_float(label="Min Frequency",
                                              tag="PSD_options_fmin",
                                              width=150,
                                              default_value=Analysis_class.get_PSD_fmin(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_input_float(label="Max Frequency",
                                              tag="PSD_options_fmax",
                                              width=150,
                                              default_value=Analysis_class.get_PSD_fmax(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_input_float(label="High-Pass Filter",
                                              tag="PSD_options_high_pass_filter",
                                              width=150,
                                              default_value=Analysis_class.get_PSD_high_pass_filter(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_combo(label="Filter Notch Freq.",
                                              tag="PSD_options_notch_frequencies",
                                              items=['None',
                                                     '[60]',
                                                     '[60, 120]',
                                                     '[60, 120, 180]',
                                                     '[60, 120, 180, 240]',
                                                     '[60, 120, 180, 240, 300]'],
                                              width=150,
                                              default_value=Analysis_class.get_PSD_notch_frequencies(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            dpg.add_combo(label="Voltage Scale",
                                              tag="PSD_options_voltage_scale",
                                              items=['Volts',
                                                     'Millivolts',
                                                     'Microvolts'],
                                              width=150,
                                              default_value=Analysis_class.get_PSD_voltage_scale(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                            
                            dpg.add_input_int(label="FFT Resolution",
                                              tag="PSD_options_FFT_resolution",
                                              width=150,
                                              default_value=Analysis_class.get_PSD_FFT_resolution(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PSD_FFT_resolution(app_data),
                                              user_data=Analysis_class)

                            dpg.add_text("^ 256 Max/per 5 min of data")
                            dpg.add_separator()

                            dpg.add_combo(label="Plot Grouping Method",
                                              tag="PSD_options_plot_grouping_method",
                                              items=['No Grouping',
                                                     'Group by Rat',
                                                     'Group by Channel'],
                                              width=150,
                                              default_value=Analysis_class.get_PSD_plot_grouping_method(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PSD_plot_grouping_method(app_data),
                                              user_data=Analysis_class)

                            dpg.add_checkbox(label="Plot Raw Data", 
                                              tag="PSD_options_plot_raw",
                                              default_value=Analysis_class.get_PSD_plot_raw(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PSD_plot_raw(app_data),
                                              user_data=Analysis_class)

                            dpg.add_checkbox(label="Plot Filtered Data", 
                                              tag="PSD_options_plot_filtered",
                                              default_value=Analysis_class.get_PSD_plot_filtered(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PSD_plot_filtered(app_data),
                                              user_data=Analysis_class)

                            dpg.add_checkbox(label="Correct 1/f", 
                                              tag="PSD_options_correct_1overf",
                                              default_value=Analysis_class.get_PSD_correct_1overf(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PSD_correct_1overf(app_data),
                                              user_data=Analysis_class)
                            
                # Performance & Export Options Column
                with dpg.child_window(tag=f"ana{Analysis_class.ID}_perform_window", 
                                      border=True, 
                                      width=325, 
                                      autosize_y=True,
                                      horizontal_scrollbar=True, 
                                      no_scrollbar=False, 
                                      menubar=True, 
                                      parent=f"ana{Analysis_class.ID}_window_group"):
                    with dpg.menu_bar():
                            dpg.add_menu(label="Export")

                    dpg.add_button(label="Set Output Location",
                                  callback=set_PAC_output_directory_callback,
                                  user_data=Analysis_class)

                    dpg.add_text("Output Folder Name (Editable)")

                    dpg.add_input_text(label="",
                                      tag="Export_options_output_folder_name",
                                      default_value="GUIDA PAC Session [Timestamp]",
                                      callback=PAC_options_callback,
                                      user_data=Analysis_class)

                    dpg.add_checkbox(label="Export Data to JSON", 
                                      tag="Export_options_save_data",
                                      default_value=Analysis_class.get_save_data(),
                                      callback=lambda sender, app_data, user_data: user_data.set_save_data(app_data),
                                      user_data=Analysis_class)

                    dpg.add_separator()
                    dpg.add_text("Export Image As:")

                    dpg.add_checkbox(label="PNG", 
                                      tag="Export_options_export_PNG",
                                      default_value=Analysis_class.get_export_PNG(),
                                      callback=PAC_options_callback,
                                      user_data=Analysis_class)

                    dpg.add_checkbox(label="PDF", 
                                      tag="Export_options_export_PDF",
                                      default_value=Analysis_class.get_export_PDF(),
                                      callback=PAC_options_callback,
                                      user_data=Analysis_class)

                    dpg.add_checkbox(label="SVG", 
                                      tag="Export_options_export_SVG",
                                      default_value=Analysis_class.get_export_SVG(),
                                      callback=PAC_options_callback,
                                      user_data=Analysis_class)
                    
                    dpg.add_checkbox(label="EPS", 
                                      tag="Export_options_export_EPS",
                                      default_value=Analysis_class.get_export_EPS(),
                                      callback=PAC_options_callback,
                                      user_data=Analysis_class)
                    
                    dpg.add_separator()
                    
                    dpg.add_input_float(label="Image Width (In.)",
                                              tag="Export_options_image_width",
                                              width=150,
                                              default_value=Analysis_class.get_image_width(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Image Height (In.)",
                                              tag="Export_options_image_height",
                                              width=150,
                                              default_value=Analysis_class.get_image_height(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)
                    
                    dpg.add_input_int(label="Image DPI",
                                              tag="Export_options_image_DPI",
                                              width=150,
                                              default_value=Analysis_class.get_image_DPI(),
                                              callback=PAC_options_callback,
                                              user_data=Analysis_class)

                    dpg.add_text("Comodulogram Reccomendation")
                    dpg.add_text("Width = 6   Height = 5")
                    dpg.add_text("PSD Reccomendation")
                    dpg.add_text("Width = 10  Height = 5")

                    dpg.add_separator()
                    dpg.add_text("PAC Options")

                    dpg.add_checkbox(label="Use Custom Colormap Scale", 
                                              tag="PAC_export_options_custom_colormap",
                                              default_value=Analysis_class.get_PAC_custom_colormap(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PAC_custom_colormap(app_data),
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Set Vmin",
                                              tag="PAC_export_options_vmin",
                                              width=150,
                                              default_value=Analysis_class.get_PAC_vmin(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PAC_vmin(app_data),
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Set Vmax",
                                              tag="PAC_export_options_vmax",
                                              width=150,
                                              default_value=Analysis_class.get_PAC_vmax(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PAC_vmax(app_data),
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Comod. Interp. *",
                                              tag="PAC_export_options_comod_interpolation",
                                              width=150,
                                              default_value=Analysis_class.get_PAC_comod_interpolation(),
                                              callback=lambda sender, app_data, user_data: user_data.set_PAC_comod_interpolation(app_data),
                                              user_data=Analysis_class)

                    dpg.add_text("* Interpolates Only Image NOT Data")
                    dpg.add_text("0.1 = (Row x Columns) x 10")
                    dpg.add_text("0.5 = (Row x Columns) x 2")

                    dpg.add_separator()
                    dpg.add_text("PSD Options")
                    
                    dpg.add_checkbox(label="Use Custom Y-axis", 
                                              tag="Export_options_y_custom_axis",
                                              default_value=Analysis_class.get_y_custom_axis(),
                                              callback=lambda sender, app_data, user_data: user_data.set_y_custom_axis(app_data),
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Y-axis Top (dB)",
                                              tag="Export_options_yaxis_top",
                                              width=150,
                                              default_value=Analysis_class.get_yaxis_top(),
                                              callback=lambda sender, app_data, user_data: user_data.set_yaxis_top(app_data),
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Y-axis Bottom (dB)",
                                              tag="Export_options_yaxis_bottom",
                                              width=150,
                                              default_value=Analysis_class.get_yaxis_bottom(),
                                              callback=lambda sender, app_data, user_data: user_data.set_yaxis_bottom(app_data),
                                              user_data=Analysis_class)

                    dpg.add_combo(label="Color Palette",
                                              tag="Export_options_color_palette",
                                              items=['Bright Colors',
                                                     'Tableau Colors',
                                                     'Colorblind Okabe Ito',
                                                     'Colorblind Cud'],
                                              width=150,
                                              default_value=Analysis_class.get_color_palette(),
                                              callback=lambda sender, app_data, user_data: user_data.set_color_palette(app_data),
                                              user_data=Analysis_class)

                    dpg.add_input_float(label="Line Opacity (0-1)",
                                              tag="Export_options_alpha",
                                              width=150,
                                              default_value=Analysis_class.get_alpha(),
                                              callback=lambda sender, app_data, user_data: user_data.set_alpha(app_data),
                                              user_data=Analysis_class)


    except Exception as e:
        logging.error(f"Error in create_plot_window: {e}", exc_info=True)
        show_popup(error_string)

In [59]:
##### Plotting Functions #####

In [60]:
def plot_data(plot_instance):
    """
    Function for plotting provided data in dearpygui's plot widget.
    """
    try:
        # Determine data to access
        electrode_mapping = plot_instance.get_electrode_mapping()
        selected_channel = plot_instance.get_channel()
        raw_data_column_index = electrode_mapping[selected_channel]
        voltage_scaling_factor = plot_instance.get_voltage_scale()
    
        # Construct full file path from the selected NWB file
        file_path = os.path.join(plot_instance.get_folder_path, plot_instance.nwb_file)
        # Get the raw data
        x, y = get_raw_data(file_path, raw_data_column_index, plot_instance.get_data_start(), plot_instance.get_data_end())

        # Scale to be volts, millivolts, or microvolts
        if voltage_scaling_factor != 1:
            y = y * voltage_scaling_factor

        # Apply bandpass filter if that plot type is selected
        if plot_instance.get_plot_type() == "Filter":
            with NWBHDF5IO(file_path, 'r') as io:
                nwb = io.read()
                elec_series = nwb.acquisition["ElectricalSeries"]
                sampling_rate = elec_series.rate  # in Hz

            # Convert y to an acceptable data format for mne int16 -> float64
            y = y.astype(np.float64)
            y = filter_data(data=y, sfreq=sampling_rate, lowcut=plot_instance.get_lowcut(), highcut=plot_instance.get_highcut())

        # Update the plot data
        series_tag = f"{plot_instance.ID}_plot_series"
        # For a line series in DearPyGui, we update its x and y data.
        dpg.set_value(series_tag, [x.tolist(), y.tolist()])
    except Exception as e:
        logging.error(f"Error in plot_data: {e}", exc_info=True)
        show_popup(error_string)

In [61]:
def get_channels_for_file(file_path, plot_instance):
    """
    Given the full file path, return a list of available timeseries channels.
    """
    try:
        with NWBHDF5IO(file_path, 'r') as io:
            nwb = io.read()
            
            # Convert the electrodes table to a pandas DataFrame for easier inspection.
            df_electrodes = nwb.electrodes.to_dataframe()
            
            # Extract the 'channel_name' column
            channel_names = df_electrodes['channel_name'].tolist()
        
            # Sort using a key that extracts the numeric part:
            sorted_channel_names = sorted(channel_names, key=lambda x: int(x.replace('CSC', '')))
        
            # Map displayed order (1,2,3,4,...64) to default order (10,11,12,...,1,11)
            mapping = {row['channel_name']: idx for idx, row in df_electrodes.iterrows()}
            plot_instance.set_electrode_mapping(mapping)
        
            return sorted_channel_names
    except Exception as e:
        logging.error(f"Error in get_channels_for_file: {e}", exc_info=True)
        show_popup(error_string)

In [62]:
def get_raw_data(file_path, data_index, start, end):
    """
    Retrieve raw data for the given file and channel (as an index) for the duration of start and end.
    Use end = -1 to automatically go to the end of the recording.
    """
    try:
        with NWBHDF5IO(file_path, 'r') as io:
            nwb = io.read()
            elec_series = nwb.acquisition["ElectricalSeries"]
            sampling_rate = elec_series.rate  # in Hz
            conversion = elec_series.conversion  # NCS scaling factor (bits to volts)
            # Get starting time; if None, assume 0.
            t0 = elec_series.starting_time if elec_series.starting_time is not None else 0
    
            total_samples = elec_series.data.shape[0]
            sample_start = int(start * sampling_rate)

            if end == -1:
                sample_end = total_samples
            else:
                sample_end = int(end * sampling_rate)

            # Ensure bounds are valid
            sample_start = max(0, min(sample_start, total_samples))
            sample_end = max(sample_start + 1, min(sample_end, total_samples))
            
            # Generate time vector (x):
            if elec_series.timestamps is not None:
                # If explicit timestamps exist, slice them.
                x = np.array(elec_series.timestamps[sample_start:sample_end])
            else:
                # Otherwise, compute time from starting_time and sampling_rate.
                # np.arange(sample_start, sample_end) gives the sample numbers.
                x = t0 + np.arange(sample_start, sample_end) / sampling_rate
            
            # Extract y data from the specified channel (data_index).
            y = np.array(elec_series.data[sample_start:sample_end, data_index]) * conversion
            
        return x, y

    except Exception as e:
        logging.error(f"Error in get_raw_data: {e}", exc_info=True)
        show_popup(error_string)

In [63]:
def get_raw_data_column_length(file_path):
    """
    Retrieve the total number of time samples in the ElectricalSeries
    and return the duration in seconds.
    """
    try:
        with NWBHDF5IO(file_path, 'r') as io:
            nwb = io.read()
            elec_series = nwb.acquisition["ElectricalSeries"]
            sampling_rate = elec_series.rate  # Hz

            # Get total number of rows (samples)
            total_samples = elec_series.data.shape[0]

            duration_sec = int(total_samples / sampling_rate)
            return duration_sec

    except Exception as e:
        logging.error(f"Error in get_raw_data_column_length: {e}", exc_info=True)
        show_popup(error_string)


In [64]:
def create_plot_window(parent, plot_instance):
    """
    Create a plot window containing a settings panel on the left and a plot on the right.
    
    Parameters:
        parent: The parent tag or window to which this plot window will be attached.
        plot_instance: An instance of the Plot class containing plot-related settings.
    """
    try:
        with dpg.group(tag=f"viz_list_plot_group_{plot_instance.ID}", horizontal=True, parent=parent):
            # Create the overall child window for the plot window.
            with dpg.child_window(parent=f"viz_list_plot_group_{plot_instance.ID}", tag=f"{plot_instance.ID}_window", height=300, autosize_x=True, border=True):
                
                # Create a horizontal group to split the window into left and right sections.
                with dpg.group(horizontal=True):
                    
                    # Left Section: Settings Panel
                    with dpg.child_window(width=220, autosize_y=True, border=False, menubar=True, tag=f"{plot_instance.ID}_settings_panel"):
                        with dpg.menu_bar():
                            dpg.add_menu(label=f"", tag=f"{plot_instance.ID}_plot_name")

                        with dpg.tab_bar(label="Plot Settings"):
                            with dpg.tab(label="Import", tag=f"{plot_instance.ID}_import_tab"):
                                dpg.add_combo(label="Plot Type",
                                              items=["Raw", "Filter"],
                                              default_value=plot_instance.plot_type if plot_instance.plot_type else "Raw",
                                              width=138,
                                              callback=change_plot_type_callback,
                                              user_data=plot_instance)
                
                                dpg.add_separator()
                
                                NWB_file_list = plot_instance.get_file_list
                                dpg.add_combo(label="NWB File",
                                              tag=f"{plot_instance.ID}_nwb_file_combo",
                                              items=NWB_file_list,
                                              width=138,
                                              callback=select_NWB_file_callback,
                                              user_data=plot_instance)
                
                                dpg.add_combo(label="Channel",
                                              tag=f"{plot_instance.ID}_channel_combo",
                                              items=[],
                                              width=138,
                                              callback=select_NWB_channel_callback,
                                              user_data=plot_instance)

                            with dpg.tab(label="Plotting", tag=f"{plot_instance.ID}_plotting_tab"):
                                dpg.add_text(f"Data Range: {plot_instance.get_data_min()}-{plot_instance.get_data_max()}",
                                             tag=f"{plot_instance.ID}_range_text")
            
                                dpg.add_input_int(label="Start (s)",
                                                  tag=f"{plot_instance.ID}_start_int",
                                                  width=138,
                                                  callback=set_start_callback,
                                                  user_data=plot_instance,
                                                  on_enter=True,
                                                  min_value=plot_instance.get_data_min(),
                                                  max_value=plot_instance.get_data_max(),
                                                  min_clamped=True,
                                                  max_clamped=True)
            
                                dpg.add_input_int(label="End (s)",
                                                  tag=f"{plot_instance.ID}_end_int",
                                                  width=138,
                                                  callback=set_end_callback,
                                                  user_data=plot_instance,
                                                  on_enter=True,
                                                  min_value=plot_instance.get_data_min(),
                                                  max_value=plot_instance.get_data_max(),
                                                  min_clamped=True,
                                                  max_clamped=True)
        
                                dpg.add_separator()
        
                                dpg.add_input_int(label="Shift (s)",
                                                  tag=f"{plot_instance.ID}_shift_int",
                                                  width=138,
                                                  on_enter=True)
        
                                with dpg.group(horizontal=True):
        
                                    dpg.add_text(f"Shift Keys: ")
        
                                    dpg.add_button(label="Button", 
                                                   tag=f"{plot_instance.ID}_shift_left",
                                                   callback=shift_data_left_callback, 
                                                   user_data=plot_instance,
                                                   arrow=True, 
                                                   direction=dpg.mvDir_Left)
                                    
                                    dpg.add_button(label="Button", 
                                                   tag=f"{plot_instance.ID}_shift_right",
                                                   callback=shift_data_right_callback, 
                                                   user_data=plot_instance,
                                                   arrow=True, 
                                                   direction=dpg.mvDir_Right)

                                dpg.add_checkbox(label="Sync Axis", 
                                                 tag=f"{plot_instance.ID}_sync_status",
                                                 callback=sync_axis_callback,
                                                 user_data=plot_instance,
                                                 default_value=False)

                                dpg.add_combo(label="Y Scale",
                                              tag=f"{plot_instance.ID}_voltage_combo",
                                              items=["Volts","Millivolts","Microvolts"],
                                              width=138,
                                              default_value="Millivolts",
                                              callback=voltage_scale_callback,
                                              user_data=plot_instance)

                            with dpg.tab(label="Export", tag=f"{plot_instance.ID}_export_tab"):
                                dpg.add_input_text(label="Custom Name",
                                                   tag=f"{plot_instance.ID}_custom_plot_name",
                                                   #hint="Leave blank for default",
                                                   callback=lambda sender, app_data,
                                                   user_data : user_data.set_export_status(export_status),
                                                   user_data=plot_instance)

                                dpg.add_checkbox(label="Use Custom Output Path", 
                                                 tag=f"{plot_instance.ID}_export_status",
                                                 callback=lambda sender, app_data,
                                                 user_data : user_data.set_custom_name(app_data),
                                                 user_data=plot_instance,
                                                 default_value=False)
                                
                                dpg.add_button(label="Figure Export", 
                                                 tag=f"{plot_instance.ID}_export_trigger",
                                                 callback=export_trigger_callback,
                                                 user_data=plot_instance)
                    
                    # Right Section: Plot Display
                    # This child window auto-sizes in both directions.
                    with dpg.child_window(autosize_x=True, autosize_y=True, border=False, tag=f"{plot_instance.ID}_plot_area_panel"):
                        # Create a plot widget that fills this section.
                        # Note: Setting width and height to -1 lets the plot auto-size to its container.
                        with dpg.plot(label="Plot", tag=f"{plot_instance.ID}_plot_widget", width=-1, height=-1,no_title=True):        
                            # Add the X Axis (using the DearPyGui constant for the X Axis).
                            with dpg.plot_axis(dpg.mvXAxis, tag=f"{plot_instance.ID}_x_axis"):
                                pass  # You can add series to the axis later.
                            
                            # Add the Y Axis.
                            with dpg.plot_axis(dpg.mvYAxis, tag=f"{plot_instance.ID}_y_axis"):
                                # Create an empty line series that will be updated with data.
                                dpg.add_line_series([], [], label="Raw Data", tag=f"{plot_instance.ID}_plot_series")

        dpg.add_button(label="Add Plot", 
                       tag="add_plot_button",
                       parent=parent, 
                       user_data=parent, 
                       callback=add_plot_callback,
                       show=True)
    except Exception as e:
        logging.error(f"Error in create_plot_window: {e}", exc_info=True)
        show_popup(error_string)

In [65]:
def sync_axis(plot_instance):
    """
    Function to sync every other selected synced plot to the newest change in a plot.
    """
    try:
        # Wait for rendered plot to update before fetching axis data
        dpg.split_frame()
        
        x_min, x_max = dpg.get_axis_limits(f"{plot_instance.ID}_x_axis")
        y_min, y_max = dpg.get_axis_limits(f"{plot_instance.ID}_y_axis")

        logging.debug(f"x_min and x_max are: {x_min}, {x_max}")
        logging.debug(f"y_min and y_max are: {y_min}, {y_max}")

        for plot_instance_ID in Plot.get_sync_list():
            logging.debug(f"Processing plot ID {plot_instance_ID}")
            if plot_instance_ID != plot_instance.ID:  # Avoid syncing to itself
                logging.debug(f"Applying limit from plot {plot_instance.name} to ID {plot_instance_ID}")
                dpg.set_axis_limits_auto(f"{plot_instance_ID}_x_axis")
                dpg.set_axis_limits_auto(f"{plot_instance_ID}_y_axis")
                dpg.set_axis_limits(f"{plot_instance_ID}_x_axis", x_min, x_max)
                dpg.set_axis_limits(f"{plot_instance_ID}_y_axis", y_min, y_max)
  
    except Exception as e:
        logging.error(f"Error in sync_axis: {e}", exc_info=True)
        show_popup(error_string)

In [66]:
def filter_data(data, sfreq, lowcut, highcut):
    """
    Function to apply bandpass filter to raw data.
    """
    try:
        filtered_data = mne.filter.filter_data(data, sfreq, l_freq=lowcut, h_freq=highcut, method='fir', fir_design='firwin', phase='zero')
        return filtered_data
    except Exception as e:
        logging.error(f"Error in filter_data: {e}", exc_info=True)
        show_popup(error_string)


In [67]:
def export_plot(x_data, y_data, x_min, x_max, y_min, y_max, name, export_path):
    """
    Generate and save a matplotlib plot from provided data and axis limits.

    Parameters:
        x_data (array-like): X-axis data
        y_data (array-like): Y-axis data
        x_min (float): Minimum x-axis limit
        x_max (float): Maximum x-axis limit
        y_min (float): Minimum y-axis limit
        y_max (float): Maximum y-axis limit
        name (str): Title of the plot and base of filename
        export_path (str): Directory to save the output image
    """
    try:
        # Create figure and axis
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(x_data, y_data, linewidth=1)

        # Set axis limits
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        # Set title and labels
        ax.set_title(name)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Signal")

        # Ensure export path exists
        os.makedirs(export_path, exist_ok=True)

        # Construct full filename
        filename = f"{name}.png"
        filepath = os.path.join(export_path, filename)

        # Save figure
        plt.savefig(filepath, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Plot saved to: {filepath}")

    except Exception as e:
        print(f"Failed to export plot: {e}")

In [68]:
##### I/O & Misc Functions #####

In [69]:
def get_default_output_directory():
    '''
    # Function to get the current file path of the executable.
    '''
    try:
        if getattr(sys, 'frozen', False):
            # Running as compiled executable
            base_dir = os.path.dirname(sys.executable)
        elif '__file__' in globals():
            # Running as a Python script
            base_dir = os.path.dirname(os.path.abspath(__file__))
        else:
            # Interactive environment (e.g., Jupyter, IPython)
            base_dir = os.getcwd()

        return base_dir

    except Exception as e:
        logging.error(f"Error in get_default_output_directory: {e}", exc_info=True)
        show_popup(error_string)

In [70]:
def get_NWB_folder_filepath():
    '''
    # Function to get where NWB files are stored
    '''
    try:
        # Open a folder selection dialog
        root = Tk()
        root.withdraw()  # Hide the root window
        project_folder = filedialog.askdirectory(initialdir="Projects", title="Select NWB Data Folder")
        if not project_folder:
            return  # No folder selected
        return project_folder

    except Exception as e:
        logging.error(f"Error in select_NWB_folder: {e}", exc_info=True)
        show_popup(error_string)

In [71]:
def display_metadata(file_path):
    """
    Load metadata from an NWB file and display it in the UI.
    """
    try:
        with NWBHDF5IO(file_path, mode='r') as io:
            nwbfile = io.read()
            metadata = {
                "Identifier": nwbfile.identifier,
                "Session Description": nwbfile.session_description,
                "Experimenter": nwbfile.experimenter,
                "Session ID" : nwbfile.session_id,
                "Institution": nwbfile.institution,
                "Experiment Description" : nwbfile.experiment_description,
                
                "Session Start Time": nwbfile.session_start_time.isoformat(),
                "File Create Date" : nwbfile.file_create_date,
                "Timestamps Reference Time" : nwbfile.timestamps_reference_time.isoformat(),
            
                "Notes" : nwbfile.notes
            }

        # Update or create the metadata display window
        if dpg.does_item_exist("import_window_grouping_stageB"):
            dpg.delete_item("import_window_grouping_stageB")  # Clear previous content
            
        with dpg.group(tag="import_window_grouping_stageB", parent="import_window_grouping_stageA"):
            with dpg.child_window(tag="project_info_window", menubar=True, border=True, autosize_y=True, autosize_x=True, parent="import_window_grouping_stageB"):
                with dpg.menu_bar():
                    dpg.add_menu(label="Project Info")
                for key, value in metadata.items():
                    dpg.add_text(f"{key}: {value}", wrap=0)

    except Exception as e:
        logging.error(f"Error in display_metadata: {e}", exc_info=True)
        metadata = {"Error": "Unable to load metadata"}
        show_popup(error_string)

In [72]:
def create_NWB_file_window(parent_tag, parent, nwb_files, folder_path, NWBFolder_Class):
    """
    Creates the lefthand UI for NWB files under the specified parent.
    """
    try:
        #print(f"NWB File List: {nwb_files}")
        # Check if the window already exists and close it if necessary
        if dpg.does_item_exist(f"{parent}_window_grouping_stageA"):
            dpg.delete_item(f"{parent}_window_grouping_stageA")
        
        with dpg.group(tag=f"{parent}_window_grouping_stageA", parent=parent_tag, horizontal=True):
            with dpg.child_window(tag=f"{parent}_NWB_window", 
                                  border=True, width=300, 
                                  autosize_y=True, 
                                  horizontal_scrollbar=False, 
                                  no_scrollbar=False, 
                                  menubar=True, 
                                  parent=f"{parent}_window_grouping_stageA"):
                with dpg.menu_bar():
                    dpg.add_menu(label="NWB Files")
                """
                for file in nwb_files:
                    dpg.add_checkbox(
                        label=file,
                        tag=f"{parent}_{file}",  # Ensure tags are unique across parents
                        callback=checkbox_callback,
                        user_data=(folder_path, f"{parent}_NWB_window", file),  # Pass all tags and folder path
                        parent=f"{parent}_NWB_window"
                    )
                """

                # Build rat-to-files dictionary
                rat_to_files = {i: [] for i in range(1, 65)}  # Rats 1 through 64
        
                for file in nwb_files:
                    rats_in_file = NWBFolder.get_rats_in_file(file)
                    for rat in rats_in_file:
                        if 1 <= rat <= 64:
                            rat_to_files[rat].append(file)
        
                # Build UI elements grouped by rat
                for rat_num in range(1, 65):
                    files = rat_to_files[rat_num]
                    if not files:
                        continue  # Skip rats with no files
        
                    with dpg.collapsing_header(label=f"Rat {rat_num}",tag=f"{parent}_{rat_num}_header", default_open=False, parent=f"{parent}_NWB_window"):
                        for file in sorted(files):
                            checkbox_tag = f"{parent}_{rat_num}_{file}"
                            dpg.add_checkbox(
                                label=file,
                                tag=checkbox_tag,
                                callback=checkbox_callback,
                                user_data=(folder_path, f"{parent}_{rat_num}_header", file)
                            )


                """
                # List of recognized days in proper order
                DAYS = ['PreDay0', 'Day0Control', 'Day0', 'Day1', 'Day3', 'Day7', 'Day14', 'Day21']
        
                # Build rat-to-files dictionary
                rat_to_files = {i: [] for i in range(1, 65)}  # Rats 1 through 64
        
                for file in nwb_files:
                    rats_in_file = NWBFolder.get_rats_in_file(file)
                    for rat in rats_in_file:
                        if 1 <= rat <= 64:
                            rat_to_files[rat].append(file)
        
                # Build UI elements grouped by rat, then by day
                for rat_num in range(1, 65):
                    files = rat_to_files[rat_num]
                    if not files:
                        continue
        
                    with dpg.collapsing_header(label=f"Rat {rat_num}", default_open=False, parent=f"{parent}_NWB_window"):
                        
                        # Build a dictionary to group files by day
                        day_to_files = {}
                        for file in files:
                            day_number, day_label = NWBFolder.get_day_from_file(file)
                            if day_label is None:
                                continue
                            if day_label not in day_to_files:
                                day_to_files[day_label] = []
                            day_to_files[day_label].append(file)
        
                        # Sort days using the custom DAY order
                        for day_label in DAYS:
                            if day_label not in day_to_files:
                                continue
                            with dpg.collapsing_header(label=f"{day_label}",tag=f"", default_open=False):
                                for file in sorted(day_to_files[day_label]):
                                    checkbox_tag = f"{parent}_{file}"
                                    dpg.add_checkbox(
                                        label=file,
                                        tag=checkbox_tag,
                                        callback=checkbox_callback,
                                        user_data=(folder_path, f"{parent}_NWB_window", file)
                                    )
                """

    except Exception as e:
        logging.error(f"Error in create_NWB_file_window: {e}", exc_info=True)
        show_popup(error_string)

In [73]:
def display_NWB_files(NWBFolder_Class):
    """
    Function to display NWB files found in input folder.
    """
    try:
        folder_path = NWBFolder_Class.get_folder_path()

        # Get list of NWB files in folder
        nwb_files = [f for f in os.listdir(folder_path) if f.endswith('.nwb')]

        # Clear previous list and update class
        NWBFolder.set_file_list(nwb_files)

        # Clear existing NWB file UI if present
        #if dpg.does_item_exist("import_child_window"):
        #    dpg.delete_item("import_child_window", children_only=True)

        # Recreate NWB file list
        create_NWB_file_window(parent_tag="import_child_window",
                               parent="import", 
                               nwb_files=nwb_files, 
                               folder_path=folder_path,
                               NWBFolder_Class=NWBFolder_Class)

        # Clear and rebuild visualization window
        if dpg.does_item_exist("viz_child_window"):
            dpg.delete_item("viz_child_window", children_only=True)

        with dpg.group(tag="viz_buttons", horizontal=True, parent="viz_child_window"):
            with dpg.child_window(tag="plot_buttons", border=True, autosize_x=True, height=35, parent="viz_buttons"):
                with dpg.group(horizontal=True, parent="plot_buttons"):
                    dpg.add_text("Plot Options:")

        # Create new plot instance and add to plot window
        plot_class_1 = Plot()

        # Clear existing Plot UI if present
        if dpg.does_item_exist("viz_child_window"):
            dpg.delete_item("viz_child_window", children_only=True)
            
        create_plot_window(parent="viz_child_window", plot_instance=plot_class_1)

        # Instantiate Analysis class for Analysis window
        Analysis_class = Analysis()

        # Clear existing Plot UI if present
        if dpg.does_item_exist("ana_child_window"):
            dpg.delete_item("ana_child_window", children_only=True)

        # Fill analysis tab with starting info
        create_analysis_window(parent="ana_child_window", Analysis_class=Analysis_class)

        # Create new JSON instance and add to JSON window
        JSON_class = JSON()
        create_JSON_window(JSON_class)

    except Exception as e:
        logging.error(f"Error in display_NWB_files: {e}", exc_info=True)
        show_popup(error_string)


In [74]:
def load_config():
    """
    Function to load the configuration on startup.
    """
    try:
        if os.path.exists(config_file):
            try:
                with open(config_file, "r") as f:
                    config_data = json.load(f)
                    logging.debug(f"Config data: {config_data}")
                return config_data
            except Exception as e:
                logging.error(f"Error loading config: {e}", exc_info=True)
        return {}
    except Exception as e:
        logging.error(f"Error in load_config: {e}", exc_info=True)
        show_popup(error_string)

In [75]:
def copy_log_to_clipboard():
    '''
    # Function to copy current error log to clipboard
    '''
    try:
        with open(log_filename, 'r') as log_file:
            log_content = log_file.read()
            clipboard_content = f"Log File: {os.path.basename(log_filename)}\n\n{log_content}"
            pyperclip.copy(clipboard_content)
            logging.debug("Log file copied to clipboard with file name")
    except Exception as e:
        logging.error(f"Error in copy_log_to_clipboard: {e}", exc_info=True)
        show_popup(error_string)

In [76]:
def show_popup(message):
    '''
    # Function to create a pop-up window in Dear PyGui
    '''
    try:
        if dpg.does_item_exist("popup_window"):
            dpg.delete_item("popup_window")
        with dpg.window(label="Notification", modal=True, no_title_bar=True, autosize=True, tag="popup_window", show=False):
            dpg.add_text(message, wrap=600)
            with dpg.group(tag="popup_window_buttons", parent="popup_window", horizontal=True):
                dpg.add_button(label="Close", callback=lambda: dpg.delete_item("popup_window"))
                dpg.add_button(label="Open Log", callback=lambda: os.startfile(log_filename))
                dpg.add_button(label="Copy Log to Clipboard", callback=copy_log_to_clipboard)
    
        dpg.show_item("popup_window")
        
        # Allow the window to calculate its size
        dpg.split_frame()
    
        # Get the viewport size
        viewport_width = dpg.get_viewport_client_width()
        viewport_height = dpg.get_viewport_client_height()
    
        # Get the window size
        window_width = dpg.get_item_width("popup_window")
        window_height = dpg.get_item_height("popup_window")
    
        # Calculate the position
        pos_x = (viewport_width - window_width) // 2
        pos_y = (viewport_height - window_height) // 2
    
        # Set the window position
        dpg.set_item_pos("popup_window", [pos_x, pos_y])
    except Exception as e:
        logging.error(f"Error in show_popup: {e}", exc_info=True)

In [77]:
#######################################################################################################################################################
# Callbacks
#######################################################################################################################################################

In [78]:
##### JSON Plotting Callbacks #####

In [79]:
def crunch_PAC_JSON_data_callback(sender, app_data, user_data):
    """
    Callback to plot data of PAC JSON data.
    """
    try:
        Analysis_class = user_data
        output_dir = Analysis_class.get_output_directory()
        
        # Open a folder selection dialog
        root = Tk()
        root.withdraw() # hide the little Tk window
    
        JSON_path = filedialog.askopenfilename(
            initialdir=output_dir,
            title="Select JSON file",
            filetypes=[("JSON files", "*.json")],
        )
        root.destroy()
        if not JSON_path:
            return  # No folder selected

        #----- LF SNR Plotting -----#

        ### 1) Load JSON's into pandas dataframe
        df = load_snr_df(JSON_path, condition_map=None,
                         bands_keep=["Delta","Theta","Wide Theta","High Theta","Beta"])

        ### 2) Track channels with high SNR
        # Ensure snr_db is numeric (paranoia)
        df["snr_db"] = pd.to_numeric(df["snr_db"], errors="coerce")
        
        # Keep rows with SNR >= 6 dB
        hits = df.loc[df["snr_db"].ge(6)].copy()
        
        # Keep only the useful columns
        cols_order = ["rat","day","session","condition","probe","channel","band","snr_db"]
        hits = hits[[c for c in cols_order if c in hits.columns]]
        
        ### 3) Plot per-band distributions
        output_dir = Path(output_dir)
        timestamp = datetime.now().strftime("%Y.%m.%d-%H.%M.%S")
        session_folder = output_dir / f"PAC JSON Statistics {timestamp}"
        session_folder.mkdir(parents=True, exist_ok=True)
        
        figs = plot_snr_distributions(df, save_dir=session_folder, kind="box", thresholds=(3,6))
        
        ### 4) Get a table of medians/IQRs to report
        table = summarize_snr(df)

        ### 5) Save SNR results
        # Table data
        out = os.path.join(session_folder, f"SNR_report.csv")
        table.to_csv(out, float_format="%.4f")

        # High SNR data
        for band, g in hits.groupby("band"):
            out = os.path.join(session_folder, f"High_SNR_channels_in_{band}.csv")
            g.to_csv(out, index=False, float_format="%.4f")

        #----- HF SNR Plotting -----#

        # 1) Load df with HF info
        hf_df = load_hf_df(JSON_path, condition_map=None)
        
        # 2) Compute HF hits and distributions
        # Save high-HF rows if you like (e.g., rel_power ≥ 0.2)
        hf_hits = hf_df[hf_df["hf_rel_power"].ge(0.4)].copy()
        out = os.path.join(session_folder, f"hf_hits_relpower_ge0.2.csv")

        # 3) Keep only the useful columns
        cols_order = ["rat","day","session","condition","probe","hf_band_low", "hf_band_high", "hf_power", "hf_power_db", "ref_power", "hf_rel_power", "hf_percentile_in_ref"]
        hf_hits = hf_hits[[c for c in cols_order if c in hf_hits.columns]]
        hf_hits.to_csv(out, index=False, float_format="%.4f")
        
        # 4) Plot distributions
        plot_hf_distributions(hf_df, metric="hf_rel_power", save_dir=session_folder, thresholds=(0.3, 0.5))
        plot_hf_distributions(hf_df, metric="hf_percentile_in_ref", save_dir=session_folder, thresholds=(60,80))


        #----- PAC AVG Trajectory Plotting -----#
        # Load data from JSON file
        df_PAC_AVG = load_pac_json_to_long_df(JSON_path)
        if df_PAC_AVG is not None:
            """
            # Case A: Day 0 sessions only (D0:0, D0:1, D0:2)
            plot_pac_trajectories(df_PAC_AVG,
                                  metrics=["primary_summary_value","peak_value","z_abs_mean",
                                           "z_topk_mean","frac_sig_ge_1.96","peak_phase_hz","peak_amplitude_hz"],
                                  case="Full_House",
                                  save=True,
                                  output_dir=session_folder,
                                  termes_font=termes_font, termes_font_bold=termes_font_bold,
                                  clamp_y0=False)

            
            # Case B: D0:1 → Day 1 → Day 3 → Day 7 → Day 14
            plot_pac_trajectories(df_PAC_AVG,
                                  metrics=["primary_summary_value","peak_value","z_abs_mean",
                                           "z_topk_mean","frac_sig_ge_1.96","peak_phase_hz","peak_amplitude_hz"],
                                  case="D0_1_to_14",
                                  save=True,
                                  output_dir=session_folder,
                                  termes_font=termes_font, termes_font_bold=termes_font_bold,
                                  clamp_y0=False)
            """
            # Case study of all time points of interest
            plot_pac_trajectories(df_PAC_AVG,
                                  metrics=["primary_summary_value","peak_value","z_abs_mean",
                                           "z_topk_mean","frac_sig_ge_1.96","peak_phase_hz","peak_amplitude_hz"],
                                  case="Full_House",
                                  save=True,
                                  output_dir=session_folder,
                                  termes_font=termes_font, termes_font_bold=termes_font_bold,
                                  clamp_y0=False)

    except Exception as e:
        logging.error(f"Error in crunch_PAC_JSON_data_callback: {e}", exc_info=True)
        show_popup(error_string)

In [80]:
def crunch_PSD_JSON_data_callback(sender, app_data, user_data):
    """
    Callback to plot data of JSON checkbox.
    """
    try:
        Analysis_class = user_data
        output_dir = Analysis_class.get_output_directory()
        
        # Open a folder selection dialog
        root = Tk()
        root.withdraw() # hide the little Tk window
    
        JSON_path = filedialog.askopenfilename(
            initialdir=output_dir,
            title="Select JSON file",
            filetypes=[("JSON files", "*.json")],
        )
        root.destroy()
        if not JSON_path:
            return  # No folder selected

        plot_trajectories(JSON_path,
                          metrics=["Broadband_dB_Raw",
                                   "ThetaDelta_dB_Raw",
                                   "Aperiodic_Exponent"],
                          probes=None, # or None to include all
                          save=True,
                          output_dir = output_dir)
    except Exception as e:
        logging.error(f"Error in crunch_PSD_JSON_data_callback: {e}", exc_info=True)
        show_popup(error_string)

In [81]:
def JSON_plot_data_callback(sender, app_data, user_data):
    """
    Callback to plot data of JSON checkbox.
    """
    try:
        rat_id, day_id, channel_name, probe, data, session_id = user_data
        if app_data == True:
            dpg.add_scatter_series([int(day_id)],
                                   [data],
                                   label=f"R{rat_id}_D{day_id}_{channel_name}_{probe}",
                                   parent="JSON_y_axis",
                                   tag=f"F{session_id}_R{rat_id}_D{day_id}_{channel_name}_{probe}")
        else:
            if dpg.does_item_exist(f"F{session_id}_R{rat_id}_D{day_id}_{channel_name}_{probe}"):
                dpg.delete_item(f"F{session_id}_R{rat_id}_D{day_id}_{channel_name}_{probe}")
    
    except Exception as e:
        logging.error(f"Error in JSON_plot_data_callback: {e}", exc_info=True)
        show_popup(error_string)

In [82]:
def JSON_plot_data_series_callback(sender, app_data, user_data):
    """
    Callback to plot data of JSON checkbox.
    """
    try:
        rat_id, day_series, channel_name, probe, average_pac_series, session_id = user_data
        if app_data == True:
            dpg.add_scatter_series(day_series,
                                   average_pac_series,
                                   label=f"R{rat_id}_{channel_name}_{probe}",
                                   parent="JSON_y_axis",
                                   tag=f"F{session_id}_R{rat_id}_{channel_name}_{probe}")
        else:
            if dpg.does_item_exist(f"F{session_id}_R{rat_id}_{channel_name}_{probe}"):
                dpg.delete_item(f"F{session_id}_R{rat_id}_{channel_name}_{probe}")
    
    except Exception as e:
        logging.error(f"Error in JSON_plot_data_series_callback: {e}", exc_info=True)
        show_popup(error_string)

In [83]:
def import_JSON_file_callback(sender, app_data, user_data):
    """
    Callback to import user-selected JSON file to program.
    """
    try:
        JSON_class, meta_parent, data_parent, plot_parent = user_data
        # Open a folder selection dialog
        root = Tk()
        root.withdraw()  # Hide the root window
        JSON_file = filedialog.askopenfilename(initialdir="PAC Output",
                                            title="Select JSON File",
                                            filetypes=[("JSON files", "json")])
        if not JSON_file:
            return  # No folder selected
        metadata, pac_scores = load_pac_summary_json(JSON_file)

        if "Session Name" in metadata:
            session_name = metadata["Session Name"]
        else:
            session_name = "Session (Name DNE)"

        session_id = JSON_class.get_session_ID_counter()
        JSON_class.set_session_ID_counter(session_id + 1)
        dpg.add_tree_node(label=session_name,
                          tag=f"JSON_{session_id}_meta_tree",
                          parent=meta_parent)

        for key in metadata:
            if key != "Session Name":
                dpg.add_text(f"{key}: {metadata[key]}",
                             parent=f"JSON_{session_id}_meta_tree")

        # Loop through rat → day → channel
        dpg.add_tree_node(label=session_name,
                          tag=f"JSON_{session_id}_data_tree",
                          parent=data_parent)

        # Create individual rat-day-channel checkboxes
        for rat_id, rat_data in pac_scores.items():
            dpg.add_tree_node(label=f"Rat {rat_id}",
                              tag=f"JSON_{session_id}_{rat_id}_data_tree",
                              parent=f"JSON_{session_id}_data_tree")

            dpg.add_tree_node(label="Channels",
                          tag=f"JSON_{session_id}_{rat_id}_channel_tree",
                          parent=f"JSON_{session_id}_{rat_id}_data_tree")
            
            for day_id, day_data in rat_data.items():
                dpg.add_tree_node(label=f"Day {day_id}",
                              tag=f"JSON_{session_id}_{rat_id}_{day_id}_data_tree",
                              parent=f"JSON_{session_id}_{rat_id}_data_tree")
                
                for channel_name, channel_data in day_data.items():
                    # Reduce channel number to 1-16
                    ch_split = channel_name.split('CSC')
                    ch_num = int(ch_split[1])
                    ch_num = (ch_num - 1) % 16 + 1
                    channel_name = ch_split[0] + str(ch_num)

                    # Add checkbox for individual rat-channel-day
                    average_pac = channel_data['average_pac']
                    probe = channel_data['Probe Type']
                    dpg.add_checkbox(label=f"Channel {channel_name}",
                                     tag=f"JSON_{session_id}_{rat_id}_{day_id}_{channel_name}_data_tree",
                                     parent=f"JSON_{session_id}_{rat_id}_{day_id}_data_tree",
                                     default_value=False,
                                     callback=JSON_plot_data_callback,
                                     user_data=(rat_id,
                                                day_id,
                                                channel_name,
                                                probe,
                                                average_pac,
                                                session_id))

                    # Add checkbox for channel across all days
                    if dpg.does_item_exist(f"JSON_{session_id}_{rat_id}_{channel_name}_checkbox"):
                        # Get current user_data
                        (rat_id, day_series, channel_name, probe, average_pac_series,
                         session_id) = dpg.get_item_user_data(f"JSON_{session_id}_{rat_id}_{channel_name}_checkbox")
                        # Append new average_pac to series and day
                        average_pac_series.append(average_pac)
                        day_series.append(int(day_id))
                        # Update user_data
                        dpg.set_item_user_data(f"JSON_{session_id}_{rat_id}_{channel_name}_checkbox",
                                               (rat_id, day_series, channel_name, probe,
                                               average_pac_series, session_id))
                    else:
                        # Make checkbox if this is the first time enountering channel
                        dpg.add_checkbox(label=f"Channel {channel_name}",
                                         tag=f"JSON_{session_id}_{rat_id}_{channel_name}_checkbox",
                                         parent=f"JSON_{session_id}_{rat_id}_channel_tree",
                                         default_value=False,
                                         callback=JSON_plot_data_series_callback,
                                         user_data=(rat_id,
                                                    [int(day_id)],
                                                    channel_name,
                                                    probe,
                                                    [average_pac],
                                                    session_id))
                    
    except Exception as e:
        logging.error(f"Error in import_JSON_file_callback: {e}", exc_info=True)
        show_popup(error_string)

In [84]:
##### Analysis Callbacks #####

In [85]:
def PAC_options_callback(sender, app_data, user_data):
    """
    Callback to update selected PAC Options.
    """
    try:
        setting = sender
        updated_value = app_data
        Analysis_class = user_data

        if setting == "PAC_options_low_override":
            Analysis_class.set_lowfreq_override(updated_value)
        elif setting == "PAC_options_low_lowpass":
            Analysis_class.set_lowfreq_lowpass(updated_value)
        elif setting == "PAC_options_low_highpass":
            Analysis_class.set_lowfreq_highpass(updated_value)
        elif setting == "PAC_options_low_width":
            Analysis_class.set_lowfreq_width(updated_value)
        elif setting == "PAC_options_low_step":
            Analysis_class.set_lowfreq_step(updated_value)
        
        elif setting == "PAC_options_high_override":
            Analysis_class.set_highfreq_override(updated_value)
        elif setting == "PAC_options_high_lowpass":
            Analysis_class.set_highfreq_lowpass(updated_value)
        elif setting == "PAC_options_high_highpass":
            Analysis_class.set_highfreq_highpass(updated_value)
        elif setting == "PAC_options_high_width":
            Analysis_class.set_highfreq_width(updated_value)
        elif setting == "PAC_options_high_step":
            Analysis_class.set_highfreq_step(updated_value)

        elif setting == "PAC_options_high_pass_filter":
            Analysis_class.set_high_pass_filter(updated_value)
        elif setting == "PAC_options_filter_notch_frequencies":
            Analysis_class.set_filter_notch_frequencies(updated_value)
        elif setting == "PAC_options_detrend_epochs":
            Analysis_class.set_detrend_epochs(updated_value)
        elif setting == "PAC_options_apply_autofilter":
            Analysis_class.set_apply_autofilter(updated_value)

        elif setting == "PAC_options_PAC_methods":
            Analysis_class.set_PAC_method(updated_value)
        elif setting == "PAC_options_surrogate_method":
            Analysis_class.set_surrogate_method(updated_value)
        elif setting == "PAC_options_normalization_method":
            Analysis_class.set_normalization_method(updated_value)
        elif setting == "PAC_options_dcomplex_method":
            Analysis_class.set_dcomplex_method(updated_value)
        elif setting == "PAC_options_phase_cycles":
            Analysis_class.set_phase_cycles(updated_value)
        elif setting == "PAC_options_amplitude_cycles":
            Analysis_class.set_amplitude_cycles(updated_value)
        elif setting == "PAC_options_morlet’s_wavelet_width":
            Analysis_class.set_morlet_width(updated_value)
        elif setting == "PAC_options_KLD_or_HRPAC_bins":
            Analysis_class.set_KLD_or_HRPAC_bins(updated_value)

        elif setting == "Performance_options_sampling_rate":
            Analysis_class.set_sampling_rate(updated_value)
        elif setting == "Performance_options_downsample_rate":
            Analysis_class.set_downsample_rate(updated_value)
        elif setting == "Performance_options_epoch_length":
            Analysis_class.set_epoch_length(updated_value)
        elif setting == "Performance_options_data_length":
            Analysis_class.set_data_length(updated_value)  
        elif setting == "Performance_options_minimum_length":
            Analysis_class.set_minimum_length(updated_value)
        elif setting == "Performance_options_parallel_processes":
            Analysis_class.set_parallel_processes(updated_value)
        elif setting == "Performance_options_surrogate_count":
            Analysis_class.set_surrogate_count(updated_value)

        elif setting == "PSD_options_high_pass_filter":
            Analysis_class.set_PSD_high_pass_filter(updated_value)
        elif setting == "PSD_options_notch_frequencies":
            Analysis_class.set_PSD_notch_frequencies(updated_value)
        elif setting == "PSD_options_correct_1overf":
            Analysis_class.set_PSD_correct_1overf(updated_value)
        elif setting == "PSD_options_fmin":
            Analysis_class.set_PSD_fmin(updated_value)
        elif setting == "PSD_options_fmax":
            Analysis_class.set_PSD_fmax(updated_value)
        elif setting == "PSD_options_voltage_scale":
            Analysis_class.set_PSD_voltage_scale(updated_value)

        elif setting == "Export_options_output_folder_name":
            Analysis_class.set_output_folder_name(updated_value)
        elif setting == "Export_options_export_PNG":
            Analysis_class.set_export_PNG(updated_value)
        elif setting == "Export_options_export_PDF":
            Analysis_class.set_export_PDF(updated_value)
        elif setting == "Export_options_export_SVG":
            Analysis_class.set_export_SVG(updated_value)
        elif setting == "Export_options_export_EPS":
            Analysis_class.set_export_EPS(updated_value)
        elif setting == "Export_options_image_height":
            Analysis_class.set_image_height(updated_value)
        elif setting == "Export_options_image_width":
            Analysis_class.set_image_width(updated_value)
        elif setting == "Export_options_image_DPI":
            Analysis_class.set_image_DPI(updated_value)
        
    except Exception as e:
        logging.error(f"Error in PAC_options_callback: {e}", exc_info=True)
        show_popup(error_string)

In [86]:
def PAC_callback(sender, app_data, user_data):
    """
    Callback to start PAC analysis of selected files.
    Determines how to process PAC as individual channels
    or all associated channels of a rat, based on the 
    user's decision to use autofilter.
    """
    try:
        Analysis_class = user_data
        autofilter_status = Analysis_class.get_apply_autofilter()
        if autofilter_status:
            compute_PAC_of_rats(Analysis_class)
        else:
            compute_PAC_of_channels(Analysis_class)
    except Exception as e:
        logging.error(f"Error in PAC_callback: {e}", exc_info=True)
        show_popup(error_string)

In [87]:
def update_ana_nwb_list_callback(sender, app_data, user_data):
    """
    Callback to update list of selected NWB files in analysis class.
    """
    try:
        Analysis_class, file, channel_column_location = user_data
        checked = app_data
        if checked == True:
            Analysis_class.add_to_nwb_list(file)

            dpg.add_tree_node(label=file, tag=f"{file}_tree", parent=channel_column_location)
            folder_path = Analysis_class.get_folder_path
            file_path = os.path.join(folder_path, file)
            with NWBHDF5IO(file_path, 'r') as io:
                nwb = io.read()
                # Convert electrodes table to a pandas DataFrame
                df_electrodes = nwb.electrodes.to_dataframe()

                # Create channel_name -> data index mapping
                channel_name_to_index = {row['channel_name']: idx for idx, row in df_electrodes.iterrows()}
            
                # Extract and sort channel names by number (e.g., CSC10 -> 10)
                sorted_channel_names = sorted(
                    channel_name_to_index.keys(),
                    key=lambda name: int(name.replace('CSC', ''))
                )

                # Make QOL Buttons to select all channels for each rat
                rats_in_file = Analysis_class.get_rats_in_file(file)
                for rat in rats_in_file:
                    rat_channels = Analysis_class.get_channels_from_rat(file, rat)

                    # Get the acutal column # of data in NWB file corresponding to this channel
                    data_columns = []
                    for channel in rat_channels:
                        channel_name = 'CSC' + str(channel)
                        if channel_name in channel_name_to_index:
                            data_columns.append(channel_name_to_index[channel_name])
                    
                    dpg.add_checkbox(
                        label=f"Rat {rat}",
                        tag=f"{file}_{rat}",
                        parent=f"{file}_tree",
                        callback=update_selected_rat_callback,
                        user_data=(Analysis_class, file, rat, rat_channels, data_columns)
                    )

                dpg.add_separator(parent=f"{file}_tree")

                # Add buttons for each channel
                for channel_name in sorted_channel_names:
                    if channel_name not in channel_name_to_index:
                        continue  # skip weird cases
                    
                    data_index = channel_name_to_index[channel_name]
                    dpg.add_checkbox(
                        label=channel_name,
                        tag=f"{file}_{channel_name}",
                        parent=f"{file}_tree",
                        callback=update_selected_PAC_channels_callback,
                        user_data=(Analysis_class, file, channel_name, data_index)
                    )
        else:
            # Remove from list of selected nwb files
            Analysis_class.remove_from_nwb_list(file)

            # Remove NWB file entries in selected channels dict
            if file in Analysis_class.selected_channels:
                del Analysis_class.selected_channels[file]

            # Delete the dearpygui tree node and children
            if dpg.does_item_exist(f"{file}_tree"):
                dpg.delete_item(f"{file}_tree")

    except Exception as e:
        logging.error(f"Error in update_ana_nwb_list_callback: {e}", exc_info=True)
        show_popup(error_string)

In [88]:
def update_selected_rat_callback(sender, app_data, user_data):
    """
    Callback to select all channels associated to a rat.
    """
    try:
        Analysis_class, file, rat, rat_channels, data_columns = user_data
        is_checked = app_data

        # Update rat selection
        if is_checked:
            Analysis_class.add_selected_rat(file, rat, rat_channels, data_columns)
        else:
            Analysis_class.remove_selected_rat(file, rat)

        # Update corresponding channel selections
        for channel_name in rat_channels:
            channel_name = 'CSC' + str(channel_name)
            checkbox_tag = f"{file}_{channel_name}"
            if dpg.does_item_exist(checkbox_tag):
                dpg.set_value(checkbox_tag, is_checked)
                
                config = dpg.get_item_configuration(checkbox_tag)
                callback = config.get("callback")
                user_data = config.get("user_data")
    
                if callback:
                    callback(checkbox_tag, is_checked, user_data)
    except Exception as e:
        logging.error(f"Error in update_selected_rat_callback: {e}", exc_info=True)
        show_popup(error_string)

In [89]:
def update_selected_PAC_channels_callback(sender, app_data, user_data):
    """
    Callback to update selected channels for each NWB file.
    """
    try:
        Analysis_class, file, channel_name, data_index = user_data
        checked = app_data  # True if checked, False if unchecked

        if checked:
            Analysis_class.add_selected_channel(file, channel_name, data_index)
        else:
            Analysis_class.remove_selected_channel(file, channel_name)

    except Exception as e:
        logging.error(f"Error in update_selected_PAC_channels_callback: {e}", exc_info=True)
        show_popup(str(e))

In [90]:
def open_analysis_callback(sender, app_data, user_data):
    """
    Callback to instantiate a window in the analysis tab for data analysis.
    """
    try:
        # Instantiate first instance of analysis class
        Analysis_class = Analysis()

        # Fill analysis tab with starting info
        create_analysis_window(parent="ana_child_window", Analysis_class=Analysis_class)
    except Exception as e:
        logging.error(f"Error in open_analysis_callback: {e}", exc_info=True)
        show_popup(error_string)

In [91]:
def set_PAC_output_directory_callback(sender, app_data, user_data):
    '''
    # Function to get where NWB files are stored
    '''
    try:
        Analysis_class = user_data
        # Open a folder selection dialog
        root = Tk()
        root.withdraw()  # Hide the root window
        project_folder = filedialog.askdirectory(initialdir="PAC Output", title="Set Output Folder Location")
        if not project_folder:
            return  # No folder selected
        Analysis_class.set_output_directory(project_folder)

    except Exception as e:
        logging.error(f"Error in set_PAC_output_directory: {e}", exc_info=True)
        show_popup(error_string)

In [92]:
##### Plotting Callbacks #####

In [93]:
def select_NWB_file_callback(sender, app_data, user_data):
    """
    Callback when an NWB file is selected.
    Updates the Plot instance's nwb_file attribute and updates the channel combo.
    """
    try:
        plot_instance = user_data  # Passed in as user_data from the callback
        selected_file = app_data  # The NWB file name selected from the combo
        logging.debug(f"Plot '{plot_instance.ID}': Selected NWB file: {selected_file}")
        plot_instance.nwb_file = selected_file # Update selected file in class
    
        # Construct full file path
        file_path = os.path.join(plot_instance.get_folder_path, selected_file)
        # Get channels available in this file
        channels = get_channels_for_file(file_path,plot_instance)
        # Update the channel combo widget using its unique tag.
        channel_combo_tag = f"{plot_instance.ID}_channel_combo"
        dpg.configure_item(channel_combo_tag, items=channels)

        # Clear channel selection and plot if anything has already been selected
        dpg.set_value(channel_combo_tag, "")
        dpg.set_value(f"{plot_instance.ID}_plot_series", [[],[]])
        dpg.configure_item(f"{plot_instance.ID}_plot_name", label=f"")
        
    except Exception as e:
        logging.error(f"Error in selected_NWB_file_callback: {e}", exc_info=True)
        show_popup(error_string)

In [94]:
def select_NWB_channel_callback(sender, app_data, user_data):
    """
    Callback when a channel is selected.
    Updates the Plot instance's channel attribute, gets the raw data,
    and updates the plot with the new data.
    """
    try:
        plot_instance = user_data
        selected_channel = app_data
        # Update selected channel
        plot_instance.set_channel(selected_channel)
        # Construct full file path from the selected NWB file
        file_path = os.path.join(plot_instance.get_folder_path, plot_instance.nwb_file)
        
        plot_data(plot_instance)

        # Extract column of interest for updating GUI info
        electrode_mapping = plot_instance.get_electrode_mapping()
        selected_channel = plot_instance.get_channel()
        raw_data_column_index = electrode_mapping[selected_channel]

        # Update maximum allowed range for selected channel
        max_length = get_raw_data_column_length(file_path)
        plot_instance.set_data_max(max_length)
        updated_range = f"Data Range: {plot_instance.get_data_min()}-{max_length}"
        dpg.set_value(f"{plot_instance.ID}_range_text", updated_range)
        dpg.configure_item(f"{plot_instance.ID}_end_int", max_value=max_length)
        dpg.configure_item(f"{plot_instance.ID}_start_int", max_value=max_length)

        # Update name of plot
        rat_num, probe = plot_instance.get_rat_and_probe_from_channel
        plot_instance.set_plot_name(rat_num)
        dpg.configure_item(f"{plot_instance.ID}_plot_name",
                           label=f"Rat {rat_num} - {selected_channel} - ({probe})")

        # Update start and end widgets
        dpg.set_value(f"{plot_instance.ID}_start_int", plot_instance.get_data_start())
        dpg.set_value(f"{plot_instance.ID}_end_int", plot_instance.get_data_end())

        # Fit both axes to the new data
        dpg.fit_axis_data(f"{plot_instance.ID}_x_axis")
        dpg.fit_axis_data(f"{plot_instance.ID}_y_axis")
        
    except Exception as e:
        logging.error(f"Error in select_NWB_channel_callback: {e}", exc_info=True)
        show_popup(error_string)

In [95]:
def voltage_scale_callback(sender, app_data, user_data):
    """
    Callback when a voltage scale is changed, updates corresponding plot.
    """
    try:
        plot_instance = user_data
        new_voltage_scale = app_data

        # Convert string representation to int
        if new_voltage_scale == "Millivolts":
            new_voltage_scale = 1000
        elif new_voltage_scale == "Microvolts":
            new_voltage_scale = 1000000
        else:
            new_voltage_scale = 1

        # Get current scale and set new scale
        current_voltage_scale = plot_instance.get_voltage_scale()
        plot_instance.set_voltage_scale(new_voltage_scale)

        # Compute ratio and get plot data to adjust y axis
        scaling_ratio = new_voltage_scale / current_voltage_scale
        x, y = dpg.get_value(f"{plot_instance.ID}_plot_series")
        
        # Convert y to a NumPy array if it isn’t already, then scale
        y_scaled = (np.array(y) * scaling_ratio).tolist()
        
        # Update the line series
        dpg.set_value(f"{plot_instance.ID}_plot_series", [x, y_scaled])

        # Fit both axes to the new data
        dpg.fit_axis_data(f"{plot_instance.ID}_x_axis")
        dpg.fit_axis_data(f"{plot_instance.ID}_y_axis")
    except Exception as e:
        logging.error(f"Error in voltage_scale_callback: {e}", exc_info=True)
        show_popup(error_string)

In [96]:
def set_start_callback(sender, app_data, user_data):
    """
    Set starting x-axis values in seconds for selected plot
    """
    try:
        # Update left-most point to observe
        plot_instance = user_data
        new_data_start = app_data
        plot_instance.set_data_start(new_data_start)

        # Update right-most point if necessary
        data_end = plot_instance.get_data_end()
        if data_end < new_data_start:
            plot_instance.set_data_end(new_data_start)
            dpg.set_value(f"{plot_instance.ID}_end_int", new_data_start)

        plot_data(plot_instance)
        
    except Exception as e:
        logging.error(f"Error in set_start_callback: {e}", exc_info=True)
        show_popup(error_string)

In [97]:
def set_end_callback(sender, app_data, user_data):
    """
    Set ending x-axis values in seconds for selected plot
    """
    try:
        # Update right-most point to observe
        plot_instance = user_data
        new_data_end = app_data
        plot_instance.set_data_end(new_data_end)

        # Update left-most point if necessary
        data_start = plot_instance.get_data_start()
        if data_start > new_data_end:
            plot_instance.set_data_start(new_data_end)
            dpg.set_value(f"{plot_instance.ID}_start_int", new_data_end)

        plot_data(plot_instance)
    except Exception as e:
        logging.error(f"Error in set_end_callback: {e}", exc_info=True)
        show_popup(error_string)

In [98]:
def shift_data_right_callback(sender, app_data, user_data):
    """
    Shifts observed data to the right in seconds by shift length specified in shift widget
    """
    try:
        # Calculate new window of data to observe
        plot_instance = user_data
        shift_length = dpg.get_value(f"{plot_instance.ID}_shift_int")
        data_start = plot_instance.get_data_start() + shift_length
        data_end = plot_instance.get_data_end() + shift_length
        
        # Check if shift has exceeded max allowable value
        max_value = plot_instance.get_data_max()
        if data_start > max_value:
            data_start = max_value
        if data_end > max_value:
            data_end = max_value
        
        # Update start and end widgets
        plot_instance.set_data_start(data_start)
        plot_instance.set_data_end(data_end)
        dpg.set_value(f"{plot_instance.ID}_start_int", data_start)
        dpg.set_value(f"{plot_instance.ID}_end_int", data_end)

        plot_data(plot_instance)

        # Fit both axes to the new data
        dpg.fit_axis_data(f"{plot_instance.ID}_x_axis")
        dpg.fit_axis_data(f"{plot_instance.ID}_y_axis")

        sync_axis(plot_instance)
    except Exception as e:
        logging.error(f"Error in shift_data_right_callback: {e}", exc_info=True)
        show_popup(error_string)

In [99]:
def shift_data_left_callback(sender, app_data, user_data):
    """
    Shifts observed data to the left in seconds by shift length specified in shift widget
    """
    try:
        # Calculate new window of data to observe
        plot_instance = user_data
        shift_length = dpg.get_value(f"{plot_instance.ID}_shift_int")
        data_start = plot_instance.get_data_start() - shift_length
        data_end = plot_instance.get_data_end() - shift_length
        
        # Check if shift has exceeded min allowable value
        min_value = plot_instance.get_data_min()
        if data_start < min_value:
            data_start = min_value
        if data_end < min_value:
            data_end = min_value

        # Update start and end widgets
        plot_instance.set_data_start(data_start)
        plot_instance.set_data_end(data_end)
        dpg.set_value(f"{plot_instance.ID}_start_int", data_start)
        dpg.set_value(f"{plot_instance.ID}_end_int", data_end)

        plot_data(plot_instance)
        
        # Fit both axes to the new data
        dpg.fit_axis_data(f"{plot_instance.ID}_x_axis")
        dpg.fit_axis_data(f"{plot_instance.ID}_y_axis")

        sync_axis(plot_instance)
    except Exception as e:
        logging.error(f"Error in shift_data_left_callback: {e}", exc_info=True)
        show_popup(error_string)

In [100]:
def change_plot_type_callback(sender, app_data, user_data):
    """
    Callback that updates chosen plot type.
    """
    try:
        plot_instance = user_data
        plot_type = app_data
        plot_instance.set_plot_type(plot_type)

        if plot_type == "Filter":
            dpg.add_separator(parent=f"{plot_instance.ID}_import_tab",
                              tag=f"{plot_instance.ID}_filter_separator")
            dpg.add_text(default_value=f"Filter Settings",
                         parent=f"{plot_instance.ID}_import_tab",
                         tag=f"{plot_instance.ID}_filter_label")
            dpg.add_input_int(label="Lowcut (f)",
                              tag=f"{plot_instance.ID}_lowcut",
                              parent=f"{plot_instance.ID}_import_tab",
                              width=200,
                              callback=set_lowcut_callback,
                              user_data=plot_instance,
                              on_enter=True,
                              min_value=0,
                              min_clamped=True,
                              default_value=1)
            dpg.add_input_int(label="Highcut (f)",
                              tag=f"{plot_instance.ID}_highcut",
                              parent=f"{plot_instance.ID}_import_tab",
                              width=200,
                              callback=set_highcut_callback,
                              user_data=plot_instance,
                              on_enter=True,
                              min_value=0,
                              min_clamped=True,
                              default_value=200)
        if plot_type == "Raw":
            if dpg.does_item_exist(f"{plot_instance.ID}_filter_separator"):
                dpg.delete_item(f"{plot_instance.ID}_filter_separator")
            if dpg.does_item_exist(f"{plot_instance.ID}_filter_label"):
                dpg.delete_item(f"{plot_instance.ID}_filter_label")
            if dpg.does_item_exist(f"{plot_instance.ID}_lowcut"):
                dpg.delete_item(f"{plot_instance.ID}_lowcut")
            if dpg.does_item_exist(f"{plot_instance.ID}_highcut"):
                dpg.delete_item(f"{plot_instance.ID}_highcut")
            plot_data(plot_instance)
    
    except Exception as e:
        logging.error(f"Error in change_plot_type_callback: {e}", exc_info=True)
        show_popup(error_string)

In [101]:
def set_lowcut_callback(sender, app_data, user_data):
    """
    Callback that updates lowcut value.
    """
    try:
        plot_instance = user_data
        lowcut = app_data
        plot_instance.set_lowcut(lowcut)
        plot_data(plot_instance)
    except Exception as e:
        logging.error(f"Error in set_lowcut_callback: {e}", exc_info=True)
        show_popup(error_string)

In [102]:
def set_highcut_callback(sender, app_data, user_data):
    """
    Callback that updates highcut value.
    """
    try:
        plot_instance = user_data
        highcut = app_data
        plot_instance.set_highcut(highcut)
        plot_data(plot_instance)
    except Exception as e:
        logging.error(f"Error in set_highcut_callback: {e}", exc_info=True)
        show_popup(error_string)

In [103]:
def sync_axis_callback(sender, app_data, user_data):
    """
    Callback that syncs every plot to match the new axis.
    """
    try:
        plot_instance = user_data
        checkbox = app_data

        # Crucially remove lock on plot anytime sync is called
        dpg.set_axis_limits_auto(f"{plot_instance.ID}_x_axis")
        dpg.set_axis_limits_auto(f"{plot_instance.ID}_y_axis")

        if checkbox == True:
            plot_instance.add_to_sync_list(plot_instance.ID)
        else:
            plot_instance.remove_from_sync_list(plot_instance.ID)

        sync_axis(plot_instance)
    except Exception as e:
        logging.error(f"Error in sync_axis_callback: {e}", exc_info=True)
        show_popup(error_string)

In [104]:
def export_trigger_callback(sender, app_data, user_data):
    """
    Callback that sets paremeters for generating figure of chosen plots.
    """
    try:
        plot_instance = user_data
        rat_num, probe = plot_instance.get_rat_and_probe_from_channel
        selected_channel = plot_instance.get_channel()

        # Use either default output path or custom
        custom_name = plot_instance.get_custom_name()
        output_choice = plot_instance.get_export_status()
        if output_choice == True:
            name = custom_name
        else:
            name = f"Rat {rat_num} - {selected_channel} - {probe}"

        # Grab data
        x_data, y_data = dpg.get_value(f"{plot_instance.ID}_plot_series")

        # Grab plot boundaries
        x_min, x_max = dpg.get_axis_limits(f"{plot_instance.ID}_x_axis")
        y_min, y_max = dpg.get_axis_limits(f"{plot_instance.ID}_y_axis")

        # Grab export path, use program folder location as default
        export_path = plot_instance.get_plot_output_path()

        if not export_path:  # If user hasn't set a path (default value)
            export_path = os.path.join(os.getcwd(), "Plot Output")
            os.makedirs(export_path, exist_ok=True)  # Ensure folder exists

        export_plot(x_data, y_data, x_min, x_max, y_min, y_max, name, export_path)
    except Exception as e:
        logging.error(f"Error in export_trigger_callback: {e}", exc_info=True)
        show_popup(error_string)

In [105]:
def custom_plot_name_callback(sender, app_data, user_data):
    """
    Callback that overrides default naming scheme
    """
    try:
        plot_instance = user_data
        custom_name = app_data

        plot_instance.set_custom_name(app_data)
    except Exception as e:
        logging.error(f"Error in export_trigger_callback: {e}", exc_info=True)
        show_popup(error_string)

In [106]:
def update_export_status_callback(sender, app_data, user_data):
    """
    Callback that updates if plot should use a custom filepath.
    """
    try:
        plot_instance = user_data
        export_status = app_data
        plot_instance.set_export_status(export_status)
    except Exception as e:
        logging.error(f"Error in update_export_status_callback: {e}", exc_info=True)
        show_popup(error_string)

In [107]:
def select_plot_output_callback(sender, app_data, user_data):
    """
    Callback that sets a custom output filepath for plots.
    """
    try:
        plot_instance = user_data
        
        # Open a folder selection dialog
        root = Tk()
        root.withdraw()  # Hide the root window
        plot_output_folder = filedialog.askdirectory(initialdir="Projects", title="Select Output Folder")
        root.destroy()  # Destroy the Tk instance
        if not plot_output_folder:
            return  # No folder selected

        plot_instance.set_plot_output_path(plot_output_folder)
        logging.debug(f"User-selected plot output path: {plot_output_folder}")
    except Exception as e:
        logging.error(f"Error in select_plot_output_callback: {e}", exc_info=True)
        show_popup(error_string)

In [108]:
def add_plot_callback(sender, app_data, user_data):
    """
    Button to add an additional plot
    """
    try:
        parent = user_data
        plot_class_2 = Plot()
        if dpg.does_item_exist(sender):
            dpg.delete_item(sender)
        create_plot_window(parent, plot_class_2)

    except Exception as e:
        logging.error(f"Error in add_plot_callback: {e}", exc_info=True)
        show_popup(error_string)

In [109]:
##### I/O & Misc Callbacks #####

In [110]:
def open_output_location_callback(sender, app_data, user_data):
    '''
    # Function to open current analysis output location.
    '''
    try:
        Analysis_class = user_data
        output_dir = Analysis_class.get_output_directory()

        if not os.path.isdir(output_dir):
            logging.warning(f"Output directory does not exist: {output_dir}")
            show_popup(f"Output directory does not exist:\n{output_dir}")
            return

        os.startfile(output_dir)
        
    except Exception as e:
        logging.error(f"Error in open_output_location_callback: {e}", exc_info=True)
        show_popup(error_string)

In [111]:
def select_NWB_folder_callback(sender, app_data):
    '''
    # Callback to find then display NWB files
    '''
    try:
        # Get filepath to NWB folder from user, Display NWB files found
        new_folder = get_NWB_folder_filepath()
        if new_folder is None:
            # The user may have canceled the folder selection.
            logging.debug(f"Cancelled folder selection in select_NWB_folder_callback.")
            return
        NWBFolder.set_folder_path(new_folder)
        display_NWB_files(NWBFolder_Class)
    except Exception as e:
        logging.error(f"Error in select_NWB_folder_callback: {e}", exc_info=True)
        show_popup(error_string)

In [112]:
def save_NWB_folder_callback(sender, app_data, user_data):
    """
    Callback to save the NWB folder and checkbox state to JSON.
    """
    try:
        folder_path = NWBFolder.get_folder_path()
        is_checked = dpg.get_value("save_NWB_folder")
        config_data = {
            "nwb_folder": folder_path,
            "save_nwb_folder": is_checked
        }
        with open(config_file, "w") as f:
            json.dump(config_data, f)
    except Exception as e:
        logging.error(f"Error in save_NWB_folder_callback: {e}", exc_info=True)
        show_popup(error_string)

In [113]:
def checkbox_callback(sender, app_data, user_data):
    """
    Callback for checkbox interaction.
    Clears other checkboxes within the same parent context and displays metadata of the selected file.
    """
    try:
        # Obtain information on sender and relevant tags
        folder_path, parent, file = user_data
        children = dpg.get_item_children(parent)
        children_ids = children[1]
        tag_list = [dpg.get_item_alias(child_id) for child_id in children_ids]
        logging.debug(f"Found children '{children}' and children ids '{children_ids}' and tag_list '{tag_list}' for parent '{parent}'.")

        # Iterate through the child IDs
        for child_id in children_ids:
            # Ensure the item exists and has a valid tag
            if not dpg.does_item_exist(child_id):
                continue

            # Get the tag of the child
            tag = dpg.get_item_alias(child_id)

            # Skip invalid or empty tags
            if not tag:
                continue

            # Check if the child is a checkbox
            if dpg.get_item_type(child_id) == "mvAppItemType::mvCheckbox" and tag != sender:
                # Uncheck the checkbox
                dpg.set_value(tag, False)

        # Display update metadata under Import tab
        file_path = os.path.join(folder_path, file)
        display_metadata(file_path)
            
    except Exception as e:
        logging.error(f"Error in checkbox_callback: {e}", exc_info=True)
        show_popup(error_string)

In [114]:
def resize_callback(sender, app_data):
    '''
    # Callback to adjust the size of GUIDA to fit size of entire window
    '''
    try:
        height = dpg.get_viewport_client_height()
        width = dpg.get_viewport_client_width()
        #if dpg.does_item_exist("plot"):
        #    dpg.configure_item("plot", width=width-16, height=height-110)
        if dpg.does_item_exist("GUI_Tools_win"):
            dpg.configure_item("GUI_Tools_win", width=width, height=height)
    except Exception as e:
        logging.error(f"Error in resize_callback: {e}", exc_info=True)
        show_popup(error_string)

In [115]:
#######################################################################################################################################################
# Main
#######################################################################################################################################################

In [None]:
def run_GUI():
    '''
    # Function to start the Dear PyGui context in a thread
    '''
    try:
        dpg.create_context()

        with dpg.window(label="GUI Tools", 
                        width=1150, 
                        height=550, 
                        no_resize=True, 
                        no_title_bar=True,
                        no_move=True, 
                        no_collapse=True, 
                        no_close=True,
                        tag="GUI_Tools_win"):
            
            with dpg.tab_bar(label="tabs"):
                with dpg.tab(label="Import", tag="import_tab"):
                    with dpg.child_window(tag="import_child_window", 
                                          border=True, 
                                          autosize_x=True,
                                          autosize_y=True, 
                                          parent="import_tab"):
                        with dpg.child_window(tag="import_buttons_child_window", 
                            border=True, 
                            autosize_x=True,
                            height=35,
                            parent="import_child_window"):
                            with dpg.group(tag="import_buttons", horizontal=True):
                                dpg.add_text("Project Options:")
                                dpg.add_button(label="Select NWB Folder", tag="select_NWB_folder", callback=select_NWB_folder_callback)
                                dpg.add_checkbox(label="Save NWB Location", tag="save_NWB_folder", callback=save_NWB_folder_callback)
                                dpg.add_text("  Display Options:")
                                dpg.add_button(label="FullScreen", callback= lambda: dpg.toggle_viewport_fullscreen())
                                
                with dpg.tab(label="Visualization", tag="viz_tab"):
                    with dpg.child_window(tag="viz_child_window", 
                                          border=True, 
                                          autosize_x=True,
                                          autosize_y=True, 
                                          parent="viz_tab"):
                        with dpg.group(tag="viz_buttons", horizontal=True, parent="viz_child_window"):
                            with dpg.child_window(tag="plot_buttons",
                                                  border=True, 
                                                  autosize_x=True,
                                                  height=35, 
                                                  parent="viz_buttons"):
                                with dpg.group(horizontal=True, parent="plot_buttons"):
                                    dpg.add_text("Plot Options:")
                                    
                with dpg.tab(label="Analysis", tag="ana_tab"):
                    with dpg.child_window(tag="ana_child_window", 
                                          border=True, 
                                          autosize_x=True,
                                          autosize_y=True, 
                                          parent="ana_tab"):
                        with dpg.child_window(tag="ana_buttons_window",
                                              border=True, 
                                              autosize_x=True,
                                              height=35):
                            with dpg.group(horizontal=True, parent="ana_buttons", tag="ana_buttons_group"):
                                dpg.add_text("Analysis Options:")
                                
                with dpg.tab(label="JSON", tag="JSON_tab"):
                    with dpg.child_window(tag="JSON_child_window", 
                                          border=True, 
                                          autosize_x=True,
                                          autosize_y=True, 
                                          parent="JSON_tab"):
                        with dpg.child_window(tag="JSON_buttons_window",
                                              border=True, 
                                              autosize_x=True,
                                              height=35):
                            with dpg.group(horizontal=True, parent="JSON_buttons", tag="JSON_buttons_group"):
                                dpg.add_text("Analysis Options:")
                                    

        # At startup, load the config and update NWBFolder and the checkbox accordingly.
        config = load_config()
        if config.get("save_nwb_folder", False):
            saved_folder = config.get("nwb_folder", "")
            if saved_folder:
                # Instantiate Class, Get filepath to NWB folder from user, Display NWB files found
                NWBFolder_Class = NWBFolder()
                NWBFolder.set_folder_path(saved_folder)
                display_NWB_files(NWBFolder_Class)
            try:
                dpg.set_value("save_NWB_folder", True)
            except Exception:
                pass
                
        dpg.create_viewport(title='GUIDA', 
                            width=1400, 
                            height=700)
        
        dpg.setup_dearpygui()
        dpg.show_viewport()
        dpg.set_viewport_resize_callback(resize_callback)
        # Set the exit callback
        dpg.set_exit_callback(close_gui)
        dpg.start_dearpygui()
        dpg.destroy_context()

    except Exception as e:
        logging.error(f"Error in run_GUI: {e}", exc_info=True)
        show_popup(error_string)

# Function to properly close the GUI and clean up
def close_gui():
    try:
        global gui_running
        gui_running = False
        time.sleep(1)  # Give some time for the loop to exit
        logging.info("GUI has been closed and cleaned up.")
    except Exception as e:
        logging.error(f"Error in close_gui: {e}", exc_info=True)
        show_popup(error_string)

run_GUI()

  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()
  fig.tight_layout()
  .groupby(["band", "condition"])["snr_db"]
  df_long.groupby(["TimeKey", "Metric"])["Value"]
