Testing new pipeline 

In [None]:
# United SWR detector
import os
import subprocess
import numpy as np
import pandas as pd
from scipy import io, signal, stats
from scipy.signal import lfilter
import scipy.ndimage
from scipy.ndimage import gaussian_filter
from scipy.ndimage import gaussian_filter1d
from scipy import interpolate
import matplotlib.pyplot as plt
import ripple_detection
from ripple_detection import filter_ripple_band
import ripple_detection.simulate as ripsim  # for making our time vectors
from tqdm import tqdm
import time
import traceback
import logging
import logging.handlers
import sys
from multiprocessing import Pool, Process, Queue, Manager, set_start_method
import yaml
import string

# Dataset selection.  An environment-variable override
# (os.environ.get('DATASET_TO_PROCESS', 'ibl').lower()) is currently disabled;
# the dataset is fixed here instead.
DATASET_TO_PROCESS = 'abi_visual_behaviour'
valid_datasets = ['ibl', 'abi_visual_behaviour', 'abi_visual_coding']
if DATASET_TO_PROCESS not in valid_datasets:
    raise ValueError(f"DATASET_TO_PROCESS must be one of {valid_datasets}, got '{DATASET_TO_PROCESS}'")

# Import only the loader class for the selected dataset.
if DATASET_TO_PROCESS == 'abi_visual_behaviour':
    from ABI_visual_behaviour_loader import abi_visual_behaviour_loader
elif DATASET_TO_PROCESS == 'abi_visual_coding':
    from ABI_visual_coding_loader import abi_visual_coding_loader
elif DATASET_TO_PROCESS == 'ibl':
    from IBL_loader import ibl_loader
else:
    # Unreachable after the validation above; kept as a safeguard.
    raise ValueError(f"Unknown dataset type: {DATASET_TO_PROCESS}")

# Load the configuration from a YAML file.  The path can be overridden with
# the CONFIG_PATH environment variable.
config_path = os.environ.get('CONFIG_PATH', 'united_detector_config.yaml')
with open(config_path, "r", encoding="utf-8") as f:
    raw_content = f.read()

# Expand $VAR / ${VAR} references to environment variables before parsing.
# os.path.expandvars matches the whole variable name, so "$DATA_DIR" is never
# clobbered by a shorter variable such as "$DATA".  The old per-variable
# str.replace loop depended on os.environ iteration order for that.
# References to unset variables are left untouched.
raw_content = os.path.expandvars(raw_content)
full_config = yaml.safe_load(raw_content)

# --- Settings shared by every dataset ---------------------------------------
output_dir = full_config.get("output_dir", "")  # unified output directory
pool_size = full_config["pool_sizes"][DATASET_TO_PROCESS]
gamma_event_thresh = full_config["gamma_event_thresh"]
ripple_band_threshold = full_config["ripple_band_threshold"]
movement_artifact_ripple_band_threshold = full_config["movement_artifact_ripple_band_threshold"]
run_name = full_config["run_name"]
save_lfp = full_config["save_lfp"]

# --- Dataset-specific settings ----------------------------------------------
# Each dataset has its own config section, keyed by the dataset name.
# Keys are read in the same order for every dataset so a missing key is
# reported consistently.
dataset_config = full_config[DATASET_TO_PROCESS]
gamma_filters_path = full_config["filters"]["gamma_filters"]
if DATASET_TO_PROCESS == 'ibl':
    oneapi_cache_dir = dataset_config["oneapi_cache_dir"]
swr_output_dir = dataset_config["swr_output_dir"]
dont_wipe_these_sessions = dataset_config["dont_wipe_these_sessions"]
if DATASET_TO_PROCESS == 'ibl':
    session_npz_filepath = dataset_config["session_npz_filepath"]
else:
    # Both ABI datasets (visual behaviour / visual coding) share this key;
    # they have no session_npz_filepath.
    only_brain_observatory_sessions = dataset_config["only_brain_observatory_sessions"]

print(f"Configured for dataset: {DATASET_TO_PROCESS}")
print(f"Pool size: {pool_size}")
print(f"Output directory: {output_dir}")
print(f"SWR output directory: {swr_output_dir}")


# FUNCTIONS
def call_bash_function(bash_command=""):
    """
    Run a shell command and print its standard output, or its error output.

    Parameters
    ----------
    bash_command : str
        Command line handed to the shell, e.g.
        ``"source /path/to/your/bash_script.sh && your_bash_function"``.
        With ``shell=True`` on POSIX the shell is ``/bin/sh``; ``source`` is a
        bash builtin that ``/bin/sh`` may lack (``.`` is the portable form).

    Returns
    -------
    subprocess.CompletedProcess
        The finished process; ``stdout`` and ``stderr`` are bytes.
    """
    # shell=True is intentional: callers pass shell syntax (``&&``, ``source``).
    # Never pass untrusted input here.
    # stderr must be captured as well.  Previously only stdout was piped, so
    # the error stream was always None and the failure branch crashed on
    # ``.decode()``.
    completed = subprocess.run(
        bash_command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    if completed.returncode == 0:
        print("Bash function executed successfully.")
        print("Output:", completed.stdout.decode("utf-8", errors="replace"))
    else:
        print("Error:", completed.stderr.decode("utf-8", errors="replace"))
    return completed


def finitimpresp_filter_for_LFP(
    LFP_array, samplingfreq, lowcut=1, highcut=250, filter_order=101
):
    """
    Band-pass filter an LFP array with a windowed FIR filter.

    Parameters
    ----------
    LFP_array : np.array
        The LFP array; filtering runs along axis 0 (time).
    samplingfreq : float
        The sampling frequency of the LFP array, in Hz.
    lowcut : float
        Lower band edge in Hz.
    highcut : float
        Upper band edge in Hz (must be below samplingfreq / 2).
    filter_order : int
        Number of filter taps (``numtaps`` in ``scipy.signal.firwin``).

    Returns
    -------
    np.array
        The filtered LFP array, same shape as the input.  ``lfilter`` is
        causal, so the output is delayed by ``(filter_order - 1) / 2`` samples.
    """
    # Because ``fs`` is passed, firwin expects the cutoffs in Hz.  The previous
    # version divided by Nyquist *and* passed fs, which turned a 1-250 Hz band
    # into roughly 0.001-0.33 Hz.
    fir_coeff = signal.firwin(
        filter_order,
        [lowcut, highcut],
        pass_zero=False,
        fs=samplingfreq,
    )

    # Apply along time (axis 0) so multi-channel (time, channel) arrays work.
    return signal.lfilter(fir_coeff, 1.0, LFP_array, axis=0)


def event_boundary_detector(
    time,
    five_to_fourty_band_power_df,
    envelope=True,
    minimum_duration=0.02,
    maximum_duration=0.4,
    threshold_sd=2.5,
    envelope_threshold_sd=1,
):
    """
    Detect events (used here for gamma events) as runs of high z-scored power.

    The power trace is z-scored.  Samples with ``z >= threshold_sd`` seed
    events.  When ``envelope`` is True, each seed grows to the surrounding
    contiguous stretch of samples with ``z > envelope_threshold_sd``.  Each
    run of selected samples becomes one event.  An event is kept only when
    ``minimum_duration < duration < maximum_duration`` (both strict).

    Parameters
    ----------
    time : np.array
        The time values for the signal (one per sample).
    five_to_fourty_band_power_df : np.array
        1D power trace to threshold (e.g. gamma-band power).
    envelope : bool
        Whether to grow seeds out to the envelope threshold.
    minimum_duration : float
        Events must be strictly longer than this.
    maximum_duration : float
        Events must be strictly shorter than this.
    threshold_sd : float
        Detection (seed) threshold, in standard deviations.
    envelope_threshold_sd : float
        Envelope threshold, in standard deviations.

    Returns
    -------
    pd.DataFrame
        Columns ``start_time``, ``end_time`` (times of the first and last
        sample of each run) and ``duration`` (``end_time - start_time``).
        The index is each event's position among all runs before duration
        filtering.
    """
    time = np.asarray(time)
    z_scores = np.asarray(stats.zscore(five_to_fourty_band_power_df))
    past_thresh = z_scores >= threshold_sd

    if envelope:
        # Grow each seed across neighbouring samples above the envelope
        # threshold.  Label contiguous runs of (above-envelope OR seed), then
        # keep every run that contains at least one seed.  This is the
        # vectorised equivalent of walking left/right from each seed.
        candidate = (z_scores > envelope_threshold_sd) | past_thresh
        labels, _ = scipy.ndimage.label(candidate)
        seeded_labels = np.unique(labels[past_thresh])
        past_thresh = np.isin(labels, seeded_labels)

    # Find run edges on a False-padded copy so runs touching either end of the
    # recording are handled.  The previous np.roll approach wrapped around and
    # mis-paired starts/ends when the first and last samples were both True.
    edges = np.diff(np.concatenate(([0], past_thresh.astype(np.int8), [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1  # inclusive index of last sample

    start_times = time[starts]
    end_times = time[ends]
    events_df = pd.DataFrame(
        {
            "start_time": start_times,
            "end_time": end_times,
            "duration": end_times - start_times,
        }
    )

    in_duration_range = (events_df.duration > minimum_duration) & (
        events_df.duration < maximum_duration
    )
    return events_df[in_duration_range]


def event_boundary_times(time, past_thresh):
    """
    Turn a boolean mask into a table of event start/end times.

    Each contiguous run of True values in ``past_thresh`` becomes one event.
    Its start and end are the times of the run's first and last samples.

    Parameters
    ----------
    time : np.array
        The time values for the signal (one per sample).
    past_thresh : np.array
        1D boolean array, same length as ``time``.

    Returns
    -------
    pd.DataFrame
        Columns ``start_time``, ``end_time`` and ``duration``
        (``end_time - start_time``), one row per run.
    """
    # The previous version wrote into ``row_of_info`` without defining it
    # (NameError on every call), and used np.roll, which wraps around the
    # array ends.  Pad with False instead so edge runs are paired correctly.
    time = np.asarray(time)
    mask = np.asarray(past_thresh, dtype=bool)
    edges = np.diff(np.concatenate(([0], mask.astype(np.int8), [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1  # inclusive index of last sample

    start_times = time[starts]
    end_times = time[ends]
    return pd.DataFrame(
        {
            "start_time": start_times,
            "end_time": end_times,
            "duration": end_times - start_times,
        }
    )


def peaks_time_of_events(events, time_values, signal_values):
    """
    Return, for every event, the time at which ``signal_values`` peaks.

    The signal is z-scored over its full length.  For each row of ``events``,
    the sample with the largest z-score inside the closed window
    ``[start_time, end_time]`` is located and its time is reported.

    Parameters
    ----------
    events : pd.DataFrame
        Must provide ``start_time`` and ``end_time`` columns.
    time_values : np.array
        Sample times, aligned with ``signal_values``.
    signal_values : np.array
        Signal to search (e.g. ripple-band power).

    Returns
    -------
    np.array
        One peak time per event, in event order.  ``np.argmax`` raises
        ValueError if an event window contains no samples.
    """
    zscored = stats.zscore(signal_values)

    def _peak_time(window_start, window_end):
        in_window = (time_values >= window_start) & (time_values <= window_end)
        return time_values[in_window][np.argmax(zscored[in_window])]

    return np.array(
        [
            _peak_time(window_start, window_end)
            for window_start, window_end in zip(events["start_time"], events["end_time"])
        ]
    )


def resample_signal(signal, times, new_rate):
    """
    Resample a 2D signal array to a new sampling rate by linear interpolation.

    Parameters:
    signal (np.array): 2D array where each ROW is a source and each COLUMN is
        a time point, i.e. shape (n_sources, n_times).
    times (np.array): 1D array of times corresponding to the columns of signal.
    new_rate (float): The new sampling rate in Hz.

    Returns:
    new_signal (np.array): float64 array of shape (n_sources, n_new).
    new_times (np.array): Evenly spaced times, exactly 1/new_rate apart,
        starting at times[0] and ending at or just before times[-1].
    """
    times = np.asarray(times, dtype=float)
    # Base the sample count on the recording *duration*.  The old formula
    # (new_rate * times[-1]) assumed times[0] == 0, so offset timestamps gave
    # far too many samples.  Its linspace spacing was also dur/(n-1), not
    # 1/new_rate.  The 1e-9 guards against float error dropping the last sample.
    duration = times[-1] - times[0]
    nsamples_new = int(np.floor(duration * new_rate + 1e-9)) + 1
    new_times = times[0] + np.arange(nsamples_new) / new_rate

    # One interpolator over all rows at once (interpolating along axis 1).
    # "extrapolate" only matters for float round-off at the final sample.
    interp_func = interpolate.interp1d(
        times, signal, axis=1, bounds_error=False, fill_value="extrapolate"
    )
    new_signal = np.asarray(interp_func(new_times), dtype=np.float64)

    return new_signal, new_times

def listener_process(queue):
    """
    Receive log records from worker processes and write them to a log file.

    The root logger level is set to ``MESSAGE``, a custom level described as
    sitting between INFO and WARNING.  This keeps this code's messages and
    errors but drops most third-party noise.  ``MESSAGE``, ``swr_output_dir``
    and ``run_name`` are module-level globals; ``MESSAGE`` is not defined in
    this section of the file (TODO confirm it is defined before this runs).

    The loop ends when the string ``"kill"`` is received.

    Parameters
    ----------
    queue : multiprocessing.Queue
        The queue to get log records from.

    Returns
    -------
    None
    """
    root = logging.getLogger()
    # NOTE(review): the file name starts with "ibl_detector_" for every
    # dataset, including the ABI ones.
    h = logging.FileHandler(
        f"ibl_detector_{swr_output_dir}_{run_name}_app.log", mode="w"
    )
    h.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s"))
    root.addHandler(h)
    root.setLevel(MESSAGE)  # Set logging level to MESSAGE

    try:
        while True:
            message = queue.get()
            if message == "kill":
                break
            try:
                logging.getLogger(message.name).handle(message)
            except Exception:
                # Keep the listener alive if a single record can't be handled;
                # otherwise every later worker message would be lost.
                print("Problem handling a log record:", file=sys.stderr)
                traceback.print_exc(file=sys.stderr)
    finally:
        # Detach and close the handler so the log file is flushed and its
        # descriptor released, even if the loop exits with an error.
        root.removeHandler(h)
        h.close()

def init_pool(*args):
    """
    Pool initializer: send this worker's log records to the listener's queue.

    Parameters
    ----------
    *args
        Optional.  When given, ``args[0]`` is the queue shared with
        ``listener_process`` (e.g.
        ``Pool(initializer=init_pool, initargs=(queue,))``).  This is needed
        with the "spawn" start method, where module globals set in the parent
        are not inherited.  With no arguments, the module-level global
        ``queue`` is used.

    Raises
    ------
    NameError
        If no queue was passed and no global ``queue`` exists.
    """
    log_queue = args[0] if args else globals().get("queue")
    if log_queue is None:
        raise NameError(
            "init_pool needs a logging queue: pass it via initargs or define a global 'queue'"
        )
    root = logging.getLogger()
    root.addHandler(logging.handlers.QueueHandler(log_queue))
    # MESSAGE is a custom module-level log level (not defined in this section).
    root.setLevel(MESSAGE)


# ABI Loaders
import time
import os
import numpy as np
import yaml
from scipy import signal, interpolate
from allensdk.brain_observatory.behavior.behavior_project_cache import (
    VisualBehaviorNeuropixelsProjectCache,
)

# Use the Allen SDK to get sessions
# NOTE(review): module-level side effect.  Running this line builds a
# VisualBehaviorNeuropixelsProjectCache (from_s3_cache) rooted at a hard-coded
# path.  The loader class below builds its own cache in set_up() (stored on
# self.cache) and does not reference this global.
cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir='/space/scratch/allen_visbehave_data')

class abi_visual_behaviour_loader_test:
    """
    Loader for one Allen Visual Behavior Neuropixels session (test copy).

    It wraps ``VisualBehaviorNeuropixelsProjectCache`` to:

    - find probes that have CA1 channels;
    - extract, per probe, the CA1 channel with the highest ripple-band power,
      plus two randomly chosen control channels outside the hippocampal
      formation.

    All extracted signals are resampled to 1500 Hz.
    """

    def __init__(self, session_id):
        """
        Initialize the ABI loader with a session ID.

        Parameters
        ----------
        session_id : int
            The ABI ecephys session ID
        """
        self.session_id = session_id
        self.cache = None
        self.session = None
        self.probe_id_list = None
        self.probes_of_interest = None

    def set_up(self, cache_directory=None):
        """
        Create the project cache and load the session and its channel table.

        Parameters
        ----------
        cache_directory : str, optional
            Local directory for the AllenSDK S3 cache.  When None (the
            default), ``abi_visual_behaviour.sdk_cache_dir`` is read from the
            YAML file named by the CONFIG_PATH environment variable (default
            ``united_detector_config.yaml``).  Previously this argument was
            ignored.

        Returns
        -------
        self : abi_loader
            Returns the instance for method chaining.
        """
        if cache_directory is None:
            cache_directory = self._sdk_cache_dir_from_config()
        self.cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(
            cache_dir=cache_directory
        )

        self.session = self.cache.get_ecephys_session(ecephys_session_id=self.session_id)
        self.session.channels = self.session.get_channels()

        print(f"Session {self.session_id} loaded")
        return self

    @staticmethod
    def _sdk_cache_dir_from_config():
        """Read ``abi_visual_behaviour.sdk_cache_dir`` from the YAML config."""
        config_path = os.environ.get('CONFIG_PATH', 'united_detector_config.yaml')
        with open(config_path, "r") as f:
            # Expand $VAR / ${VAR} environment references in the config text.
            full_config = yaml.safe_load(os.path.expandvars(f.read()))
        return full_config["abi_visual_behaviour"]["sdk_cache_dir"]

    def has_ca1_channels(self):
        """
        Check whether any channel in the session is labelled CA1.

        Returns
        -------
        bool
            True if CA1 channels exist, False otherwise
        """
        has_ca1 = bool((self.session.channels.structure_acronym == "CA1").any())

        if not has_ca1:
            print(f"Session {self.session_id} does not have CA1 channels")

        return has_ca1

    def get_probes_with_ca1(self):
        """
        List this session's probes that have LFP data and at least one CA1 channel.

        Also stores all LFP-bearing probe IDs of the session in
        ``self.probe_id_list`` and the CA1 subset in ``self.probes_of_interest``.

        Returns
        -------
        list
            List of probe IDs with CA1 channels
        """
        probes_table_df = self.cache.get_probe_table()
        valid_lfp = probes_table_df[probes_table_df["has_lfp_data"]]
        self.probe_id_list = list(
            valid_lfp[valid_lfp.ecephys_session_id == self.session_id].index
        )

        # Build the set of CA1-bearing probes once, instead of re-filtering
        # the channel table for every probe.
        channels = self.session.channels
        ca1_probe_ids = set(channels.loc[channels.structure_acronym == "CA1", "probe_id"])
        self.probes_of_interest = [
            probe_id for probe_id in self.probe_id_list if probe_id in ca1_probe_ids
        ]

        print(f"Found {len(self.probes_of_interest)} probes with CA1 channels")
        return self.probes_of_interest

    def process_probe(self, probe_id, filter_ripple_band_func=None):
        """
        Extract CA1 and control-channel LFP for one probe, resampled to 1500 Hz.

        Parameters
        ----------
        probe_id : int
            ID of the probe to process
        filter_ripple_band_func : function, optional
            Ripple-band filter applied to the (n_samples, n_channels) CA1 array
            along time.  Without it, no peak channel is chosen.

        Returns
        -------
        dict or None
            None when the probe has fewer than two non-hippocampal channels
            or its CA1 LFP contains NaNs.  Otherwise a dict with keys:

            - ``probe_id``;
            - ``lfp_time_index`` (1500 Hz);
            - ``ca1_chans``;
            - ``control_lfps`` (two arrays of shape (n_samples, 1));
            - ``control_channels``;
            - ``peak_ripple_chan_idx``, ``peak_ripple_chan_id``,
              ``peak_ripple_chan_raw_lfp``, ``chan_id_string``, ``rippleband``.

            The peak-channel entries are None when no filter function is given.
        """
        print(f"Processing probe: {probe_id}")

        lfp = self.session.get_lfp(probe_id)
        og_lfp_obj_time_vals = lfp.time.values

        # Channels on this probe that are present in the LFP, ordered by depth.
        channels = self.session.channels
        organisedprobechans = channels[channels.probe_id == probe_id].sort_values(
            by="probe_vertical_position"
        )
        organisedprobechans = organisedprobechans[
            np.isin(organisedprobechans.index.values, lfp.channel.values)
        ]

        # Control channels come from outside the hippocampal formation.
        not_a_ca1_chan = np.logical_not(
            np.isin(
                organisedprobechans.structure_acronym,
                ["CA3", "CA2", "CA1", "HPF", "EC", "DG"],
            )
        )
        candidate_controls = organisedprobechans.index[not_a_ca1_chan]
        if len(candidate_controls) < 2:
            # np.random.choice(..., 2, replace=False) would raise here.
            print(f"Fewer than 2 channels outside the hippocampus on probe {probe_id}, skipping")
            return None
        take_two = np.random.choice(candidate_controls, 2, replace=False)

        control_channels = []
        for channel_outside_hp in take_two:
            raw_control = lfp.sel(channel=channel_outside_hp).to_numpy()
            resampled_control, lfp_time_index = self.resample_signal(
                raw_control, og_lfp_obj_time_vals, 1500.0
            )
            # The ripple detector expects (n_samples, n_channels).
            control_channels.append(resampled_control[:, None])

        # CA1 channels on this probe that are present in the LFP.
        ca1_chans = channels.probe_channel_number[
            (channels.probe_id == probe_id) & (channels.structure_acronym == "CA1")
        ]
        ca1_idx = lfp.channel.values[np.isin(lfp.channel.values, ca1_chans.index.values)]

        lfp_ca1 = lfp.sel(channel=ca1_idx)
        del lfp  # release the full-probe LFP early
        lfp_ca1 = lfp_ca1.to_pandas()
        lfp_ca1_chans = lfp_ca1.columns
        lfp_ca1 = lfp_ca1.to_numpy()

        if np.isnan(lfp_ca1).any():
            print(f"NaN detected in LFP data for probe {probe_id}, skipping")
            return None

        # (n_samples, n_channels) at 1500 Hz
        lfp_ca1, lfp_time_index = self.resample_signal(
            lfp_ca1, og_lfp_obj_time_vals, 1500.0
        )

        if filter_ripple_band_func is not None:
            lfp_ca1_rippleband = filter_ripple_band_func(lfp_ca1)
            # Envelope along TIME (axis 0).  scipy.signal.hilbert defaults to
            # axis=-1, which on this (time, channel) array ran across channels
            # and could select the wrong peak channel.
            ripple_power = np.abs(signal.hilbert(lfp_ca1_rippleband, axis=0)) ** 2
            highest_rip_power = ripple_power.max(axis=0)

            peak_chan_idx = highest_rip_power.argmax()
            this_chan_id = int(lfp_ca1_chans[peak_chan_idx])
            peakrippleband = lfp_ca1_rippleband[:, peak_chan_idx]
            peakripchan_lfp_ca1 = lfp_ca1[:, lfp_ca1_chans == this_chan_id]
        else:
            peak_chan_idx = None
            this_chan_id = None
            peakrippleband = None
            peakripchan_lfp_ca1 = None
        del lfp_ca1

        return {
            'probe_id': probe_id,
            'lfp_time_index': lfp_time_index,
            'ca1_chans': lfp_ca1_chans,
            'control_lfps': control_channels,
            'control_channels': take_two,
            'peak_ripple_chan_idx': peak_chan_idx,
            'peak_ripple_chan_id': this_chan_id,
            'peak_ripple_chan_raw_lfp': peakripchan_lfp_ca1,
            'chan_id_string': str(this_chan_id) if this_chan_id is not None else None,
            'rippleband': peakrippleband
        }

    def resample_signal(self, signal_data, time_values, target_fs=1500.0):
        """
        Resample a signal to ``target_fs`` by linear interpolation.

        Parameters
        ----------
        signal_data : numpy.ndarray
            1D signal, or 2D array of shape (n_samples, n_channels).
        time_values : numpy.ndarray
            Time values corresponding to the signal samples (axis 0).
        target_fs : float, optional
            Target sampling frequency

        Returns
        -------
        tuple
            (resampled_signal, new_time_values).  New times start at
            ``time_values[0]`` and are spaced exactly ``1 / target_fs`` apart.
        """
        t_start = time_values[0]
        t_end = time_values[-1]
        dt_new = 1.0 / target_fs
        n_samples = int(np.ceil((t_end - t_start) / dt_new))
        new_time_values = t_start + np.arange(n_samples) * dt_new

        if signal_data.ndim == 1:
            interp_func = interpolate.interp1d(
                time_values, signal_data, bounds_error=False, fill_value="extrapolate"
            )
            resampled = interp_func(new_time_values)
        else:
            # Interpolate each channel (column) independently.
            resampled = np.zeros((len(new_time_values), signal_data.shape[1]))
            for i in range(signal_data.shape[1]):
                interp_func = interpolate.interp1d(
                    time_values, signal_data[:, i], bounds_error=False, fill_value="extrapolate"
                )
                resampled[:, i] = interp_func(new_time_values)

        return resampled, new_time_values

    def cleanup(self):
        """
        Drop the reference to the loaded session so its memory can be freed.
        """
        self.session = None
        

session_id = 
"""
This function takes in a session_id (eid in the IBL) and loops through the probes in that session,
for each probe it finds the CA1 channel with the highest ripple power and uses that
channel to detect SWR events.  It also detects gamma events and movement artifacts
on two channels outside of the brain.

Parameters
----------
session_id : int
    The session id for the session to be processed.
queue : multiprocessing.Queue
    The queue to send messages to the listener process for recording errors.

Returns
-------
None
but...
Saves the following files to the folder specified by swr_output_dir_path.

Notes:
- The LFP is interpolated to 1500 Hz for all channels used.
- The SWR detector used is the Karlsson ripple detector from the ripple_detection module.
- The folders are titled by session and all files contain the name of the probe and the channel they originated from
"""

process_stage = f"Starting the process, session{str(session_id)}"  # for debugging
probe_id = "Not Loaded Yet"
one_exists = False
# Add this near the beginning of the function
data_files = None
process_stage = "Starting the process"  # for debugging
probe_id = "Not Loaded Yet"

# Create session subfolder
session_subfolder = "swrs_session_" + str(session_id)
session_subfolder = os.path.join(swr_output_dir_path, session_subfolder)

try:
    # Set up brain atlas
    process_stage = "Setting up brain atlas"
    #ba = AllenAtlas()
    #br = BrainRegions()
    
    process_stage = "Session loaded, checking if directory exists"
    # Check if directory already exists
    if os.path.exists(session_subfolder):
        raise FileExistsError(f"The directory {session_subfolder} already exists.")
    else:
        os.makedirs(session_subfolder)
        
    if save_lfp == True:
        # Create subfolder for lfp data
        session_lfp_subfolder = "lfp_session_" + str(session_id)
        session_lfp_subfolder = os.path.join(lfp_output_dir_path, session_lfp_subfolder)
        os.makedirs(session_lfp_subfolder, exist_ok=True)
    
    # Initialize and set up the IBL loader
    process_stage = "Setting up IBL loader"
    
    if DATASET_TO_PROCESS == 'ibl':
        loader = ibl_loader(session_id)
    elif DATASET_TO_PROCESS == 'abi_visual_behaviour':
        loader = abi_visual_behaviour_loader(session_id)
    elif DATASET_TO_PROCESS == 'abi_visual_coding':
        loader = abi_visual_coding_loader(session_id)
    loader.set_up()
    one_exists = True  # Mark that we have a connection for error handling
    
    # Get probe IDs and names
    process_stage = "Getting probe IDs and names"
    if DATASET_TO_PROCESS == 'abi_visual_coding':
        probenames = None
        probelist = loader.get_probes_with_ca1()
    elif DATASET_TO_PROCESS == 'abi_visual_behaviour':
        probenames = None
        probelist = loader.get_probes_with_ca1()
    elif DATASET_TO_PROCESS == 'ibl':
        probelist, probenames = loader.get_probe_ids_and_names()

    process_stage = "Running through the probes in the session"
    icount = 0
    # Process each probe
    for this_probe in range(len(probelist)):
        if icount > 0:
            break
        icount = icount + 1
        
        if DATASET_TO_PROCESS == 'ibl':
            probe_name = probenames[this_probe]
        probe_id = probelist[this_probe]  # Always get the probe_id from probelist
        print(f"Processing probe: {str(probe_id)}")

        # Process the probe and get results
        process_stage = f"Processing probe with id {str(probe_id)}"
        if DATASET_TO_PROCESS == 'abi_visual_coding':
            results = loader.process_probe(probe_id, filter_ripple_band)  # Use probe_id, not this_probe
        elif DATASET_TO_PROCESS == 'abi_visual_behaviour':
            results = loader.process_probe(probe_id, filter_ripple_band)  # Use probe_id, not this_probe
        elif DATASET_TO_PROCESS == 'ibl':
            results = loader.process_probe(this_probe, filter_ripple_band)  # Use probe_id, not this_probe
        # Skip if no results (no CA1 channels or no bin file)
        if results is None:
            print(f"No results for probe {probe_id}, skipping...")
            continue

        # Extract results
        #lfp_ca1 = results['lfp_ca1']
        peakripple_chan_raw_lfp = results['peak_ripple_chan_raw_lfp']
        lfp_time_index = results['lfp_time_index']
        ca1_chans = results['ca1_chans']
        outof_hp_chans_lfp = results['control_lfps']
        take_two = results['control_channels']
        peakrippleband = results['rippleband']
        this_chan_id = results['peak_ripple_chan_id']

        # Filter to gamma band
        gamma_band_ca1 = np.convolve(
            peakripple_chan_raw_lfp.reshape(-1), gamma_filter, mode="same"
        )

        # write our lfp to file
        np.savez(
            os.path.join(
                session_lfp_subfolder,
                f"probe_{probe_id}_channel_{this_chan_id}_lfp_ca1_peakripplepower.npz",
            ),
            lfp_ca1=peakripple_chan_raw_lfp,
        )
        np.savez(
            os.path.join(
                session_lfp_subfolder,
                f"probe_{probe_id}_channel_{this_chan_id}_lfp_time_index_1500hz.npz",
            ),
            lfp_time_index = lfp_time_index,
        )
        print(f"outof_hp_chans_lfp : {outof_hp_chans_lfp}")
        for i in range(2):
            channel_outside_hp = take_two[i]
            channel_outside_hp = "channelsrawInd_" + str(channel_outside_hp)
            np.savez(
                os.path.join(
                    session_lfp_subfolder,
                    f"probe_{probe_id}_channel_{channel_outside_hp}_lfp_control_channel.npz",
                ),
                lfp_control_channel=outof_hp_chans_lfp[i],
            )

        # create a dummy speed vector
        dummy_speed = np.zeros_like(peakrippleband)
        print("Detecting Putative Ripples")
        # we add a dimension to peakrippleband because the ripple detector needs it
        process_stage = f"Detecting Putative Ripples on probe with id {str(probe_id)}"
        
        Karlsson_ripple_times = ripple_detection.Karlsson_ripple_detector(
            time=lfp_time_index,
            zscore_threshold=ripple_band_threshold,
            filtered_lfps=peakrippleband[:, None],
            speed=dummy_speed,
            sampling_frequency=1500.0,
        )

        Karlsson_ripple_times = Karlsson_ripple_times[
            Karlsson_ripple_times.duration < 0.25
        ]
        print("Done")
        # adds some stuff we want to the file

        # ripple band power
        peakrippleband_power = np.abs(signal.hilbert(peakrippleband)) ** 2
        Karlsson_ripple_times["Peak_time"] = peaks_time_of_events(
            events=Karlsson_ripple_times,
            time_values=lfp_time_index,
            signal_values=peakrippleband_power,
        )
        speed_cols = [
            col for col in Karlsson_ripple_times.columns if "speed" in col
        ]
        Karlsson_ripple_times = Karlsson_ripple_times.drop(columns=speed_cols)
        csv_filename = (
            f"probe_{probe_id}_channel_{this_chan_id}_karlsson_detector_events.csv"
        )
        csv_path = os.path.join(session_subfolder, csv_filename)
        Karlsson_ripple_times.to_csv(csv_path, index=True, compression="gzip")
        print("Writing to file.")
        print("Detecting gamma events.")

        # compute this later, I will have a seperate script called SWR filtering which will do this
        process_stage = f"Detecting Gamma Events on probe with id {str(probe_id)}"
        
        gamma_power = np.abs(signal.hilbert(gamma_band_ca1)) ** 2
        gamma_times = event_boundary_detector(
            time=lfp_time_index,
            threshold_sd=gamma_event_thresh,
            envelope=False,
            minimum_duration=0.015,
            maximum_duration=float("inf"),
            five_to_fourty_band_power_df=gamma_power,
        )
        print("Done")
        csv_filename = (
            f"probe_{probe_id}_channel_{this_chan_id}_gamma_band_events.csv"
        )
        csv_path = os.path.join(session_subfolder, csv_filename)
        gamma_times.to_csv(csv_path, index=True, compression="gzip")

        # movement artifact detection
        process_stage = f"Detecting Movement Artifacts on probe with id {probe_id}"
        
        for i in [0, 1]:
            channel_outside_hp = take_two[i]
            process_stage = f"Detecting Movement Artifacts on control channel {channel_outside_hp} on probe {probe_id}"
            # process control channel ripple times
            ripple_band_control = outof_hp_chans_lfp[i]
            dummy_speed = np.zeros_like(ripple_band_control)
            ripple_band_control = filter_ripple_band(ripple_band_control)
            rip_power_controlchan = np.abs(signal.hilbert(ripple_band_control)) ** 2
            
            print(f"ripple_band_control shape: {ripple_band_control.shape}, length: {len(ripple_band_control)}")
            print(f"lfp_time_index shape: {lfp_time_index.shape}, length: {len(lfp_time_index)}")
            print(f"dummy_speed shape: {dummy_speed.shape}, length: {len(dummy_speed)}")
            
            if DATASET_TO_PROCESS == 'abi_visual_behaviour':
                lfp_time_index = lfp_time_index.reshape(-1)
                dummy_speed = dummy_speed.reshape(-1)
            if DATASET_TO_PROCESS == 'ibl':
                # Reshape to ensure consistent (n_samples, n_channels) format for detector
                # Prevents memory error when pd.notnull() creates boolean arrays with shape (n, n)
                rip_power_controlchan = rip_power_controlchan.reshape(-1,1)
            
            movement_controls = ripple_detection.Karlsson_ripple_detector(
                time=lfp_time_index.reshape(-1),  # if this doesnt work try adding .reshape(-1)
                filtered_lfps=rip_power_controlchan,  # indexing [:,None] is not needed here, rip_power_controlchan is already 2d (nsamples, 1)
                speed=dummy_speed.reshape(-1),  # if this doesnt work try adding .reshape(-1)
                zscore_threshold=movement_artifact_ripple_band_threshold,
                sampling_frequency=1500.0,
            )
            speed_cols = [
                col for col in movement_controls.columns if "speed" in col
            ]
            movement_controls = movement_controls.drop(columns=speed_cols)
            # write to file name
            channel_outside_hp = "channelsrawInd_" + str(channel_outside_hp)  # no cjannel id in IBL dataset, so this will do instead
            csv_filename = f"probe_{probe_id}_channel_{channel_outside_hp}_movement_artifacts.csv"
            csv_path = os.path.join(session_subfolder, csv_filename)
            movement_controls.to_csv(csv_path, index=True, compression="gzip")
            print("Done Probe id " + str(probe_id))

    # deleting the session folder
    # del one  # so that we can delete the session folder, note sr and ssl need to be deleted as well, already done earlier
    if 'loader' in locals() and loader is not None:
        loader.cleanup()
    process_stage = "All processing done, Deleting the session folder"

    # in the session
    logging.log(MESSAGE, f"Processing complete for id {session_id}.")



# --- Logging setup -------------------------------------------------------
# Custom level placed between INFO (20) and WARNING (30); progress messages in
# this file are emitted with logging.log(MESSAGE, ...).
MESSAGE = 25
logging.addLevelName(MESSAGE, "MESSAGE")

# Log file path: taken from the LOG_FILE environment variable when set,
# otherwise built from the dataset name, SWR output directory and run name.
log_file = os.environ.get(
    'LOG_FILE',
    f"{DATASET_TO_PROCESS}_detector_{swr_output_dir}_{run_name}_app.log",
)

# File handler with a "name - level - message" layout. mode="w" truncates any
# existing file of the same name instead of appending to it.
formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler(log_file, mode="w")
file_handler.setFormatter(formatter)

# Attach the handler to the root logger. Handlers that are already installed
# are deliberately left in place.
root_logger = logging.getLogger()
root_logger.addHandler(file_handler)
root_logger.setLevel(MESSAGE)  # records below MESSAGE (e.g. INFO) are dropped

# Stop these third-party loggers from propagating to the root logger, so
# their records never reach the root handlers (including the log file above).
for logger_name in ('hdmf', 'pynwb', 'spikeglx', 'ripple_detection'):
    logger = logging.getLogger(logger_name)
    logger.propagate = False


# Load the gamma filter array. Filtering with it creates artifacts in the first
# and last ~3.5 seconds of a recording, so remember to clip those off.
# (Author's note: this may not be needed, as the artifact is at the start of
# the files.)
# np.load on an .npz returns a lazily-reading NpzFile that keeps the file open;
# reading inside `with` closes the handle once "arr_0" (the default key np.savez
# gives an unnamed array) has been loaded into memory.
with np.load(gamma_filters_path) as gamma_filter_npz:
    gamma_filter = gamma_filter_npz["arr_0"]

# Brain region targeted when searching for sessions.
brain_acronym = "CA1"
# Remote IBL query, kept for reference (session ids are now read from disk):
# sessions, sess_details = one.search(atlas_acronym=brain_acronym, query_type='remote', details=True)

# Output directory for the SWR detector results; created if it does not exist.
swr_output_dir_path = os.path.join(output_dir, swr_output_dir)
os.makedirs(swr_output_dir_path, exist_ok=True)

# Starts empty. It is not used in the code visible here; presumably it is used
# elsewhere in the file (TODO confirm).
sessions_without_ca1 = np.array([])

# Companion directory for saved LFP traces, only when save_lfp is enabled.
# NOTE: the explicit `== True` comparison is kept on purpose, because a plain
# truthiness test would behave differently for non-bool config values.
if save_lfp == True:  # noqa: E712
    lfp_output_dir_path = os.path.join(output_dir, swr_output_dir + "_lfp_data")
    os.makedirs(lfp_output_dir_path, exist_ok=True)

# --- Load the list of session ids for the selected dataset ---------------
# Each dataset has its CA1 session ids stored in an .npz under session_id_lists/.
# The npz is read inside `with` so its file handle is closed right after the id
# array has been loaded into memory.
if DATASET_TO_PROCESS == "abi_visual_coding":
    # If processing Allen data
    data_file_path = os.path.join("session_id_lists", "allen_viscoding_ca1_session_ids.npz")
    with np.load(data_file_path) as data:
        all_sesh_with_ca1_eid = data["data"]
    del data
    print(f"Loaded {len(all_sesh_with_ca1_eid)} sessions from {data_file_path}")

elif DATASET_TO_PROCESS == "abi_visual_behaviour":
    # If processing Allen data
    data_file_path = os.path.join("session_id_lists", "allen_visbehave_ca1_session_ids.npz")
    with np.load(data_file_path) as data:
        all_sesh_with_ca1_eid = data["data"]
    del data
    print(f"Loaded {len(all_sesh_with_ca1_eid)} sessions from {data_file_path}")

elif DATASET_TO_PROCESS == "ibl":
    # If processing IBL data
    session_file_path = os.path.join("session_id_lists", session_npz_filepath)
    with np.load(session_file_path) as data:
        all_sesh_with_ca1_eid = data["all_sesh_with_ca1_eid_unique"]
    del data
    print(f"Loaded {len(all_sesh_with_ca1_eid)} sessions from {session_file_path}")

else:
    # Defensive: every supported dataset must select a session list above.
    raise ValueError(f"Unknown dataset type: {DATASET_TO_PROCESS}")

# Sessions to run are the slice [start:stop) of the id list. The defaults
# (10, 11) keep this test run to the single session at index 10. Override them
# with the SESSION_START / SESSION_STOP environment variables; an empty
# SESSION_STOP means "through the end of the list".
session_start = int(os.environ.get("SESSION_START", "10"))
session_stop_env = os.environ.get("SESSION_STOP", "11").strip()
session_stop = int(session_stop_env) if session_stop_env else None
sessions_to_run = all_sesh_with_ca1_eid[session_start:session_stop]

# --- Run the pool --------------------------------------------------------
# The shared queue is handed to a separate listener process and to every pool
# worker (via init_pool). Presumably the workers forward their log records
# through it and "kill" is the listener's shutdown sentinel (TODO confirm
# against listener_process). The listener is started only after the session
# list has loaded, so a missing list file cannot leave it running.
queue = Queue()
listener = Process(target=listener_process, args=(queue,))
listener.start()
try:
    # run the processes with the specified number of cores:
    with Pool(pool_size, initializer=init_pool, initargs=(queue,)) as p:
        p.map(process_session, sessions_to_run)
finally:
    # Always send the sentinel and join the listener, even if the pool raises.
    # The listener is a non-daemonic child, so leaving it running would block
    # interpreter exit.
    queue.put("kill")
    listener.join()

# Remove per-session output folders ("swrs_session_<id>") that contain no
# files. Each removal is logged at the MESSAGE level and printed.
print(f"Checking for empty session folders in {swr_output_dir_path}")
session_prefix = "swrs_session_"
empty_folder_count = 0

for folder_name in os.listdir(swr_output_dir_path):
    folder_path = os.path.join(swr_output_dir_path, folder_name)
    is_session_dir = folder_name.startswith(session_prefix) and os.path.isdir(folder_path)

    # Skip anything that is not a session folder, or that still has content.
    if not is_session_dir or os.listdir(folder_path):
        continue

    session_id = folder_name.replace(session_prefix, "")
    logging.log(MESSAGE, f"Empty session folder found and removed: {session_id}")
    print(f"Removing empty session folder: {folder_path}")

    os.rmdir(folder_path)  # only succeeds on an empty directory
    empty_folder_count += 1

print(f"Removed {empty_folder_count} empty session folders")
logging.log(MESSAGE, f"Processing complete. Removed {empty_folder_count} empty session folders.")


Configured for dataset: abi_visual_coding
Pool size: 6
Output directory: /space/scratch/SWR_final_pipeline/testing_dir
SWR output directory: allen_viscoding_swr_murphylab2024
