## load dependencies

In [None]:
## Initializing and loading required libraries and subfunctions
import numpy as np
import matplotlib.pyplot as plt
import copy
import yasa
from mne.filter import resample
import pynapple as nap
import seaborn as sns
import pandas as pd
from sklearn.preprocessing import normalize
import requests
from io import BytesIO
import sails
import re
from scipy.stats import entropy

import scipy
from scipy import signal
from scipy.interpolate import griddata
from scipy.signal import correlate
from scipy.stats import pearsonr
from scipy.fft import fft
from scipy.spatial.distance import euclidean
from scipy.signal import spectrogram
from scipy.io import loadmat
import scipy.fft
import scipy.stats
import scipy.io as sio
from scipy.signal import hilbert

import emd as emd
import emd.sift as sift
import emd.spectra as spectra

from neurodsp.sim import sim_combined
from neurodsp.plts import plot_time_series, plot_timefrequency
from neurodsp.utils import create_times
from neurodsp.timefrequency.wavelets import compute_wavelet_transform
from neurodsp.filt import filter_signal

# Load required libraries
import numpy as np
from scipy.io import loadmat
from scipy.signal import hilbert
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import seaborn as sns
from neurodsp.filt import filter_signal, filter_signal_fir, design_fir_filter
import emd
import pandas as pd
from sklearn.preprocessing import Normalizer
from tqdm import tqdm
import plotly.express as px
import copy
import umap.umap_ as umap
import skdim
from scipy.spatial import cKDTree
import pickle

## UTILS
from utils import *
from detect_pt import *

from scipy.io import loadmat
import numpy as np
from neurodsp.filt import filter_signal
import copy
import emd
from scipy.spatial import cKDTree
from tqdm import tqdm

# Apply a global seaborn theme (white style, notebook context) to all figures below.
sns.set(style='white', context='notebook')

In [None]:
# Load the sift configuration from YAML. `config` is read as a module-level
# global by extract_dual_aligned_data_for_rat, which unpacks it into
# sift.mask_sift via extract_imfs_by_pt_intervals.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
config = emd.sift.SiftConfig.from_yaml_file('/Users/amir/Desktop/for Abdel/emd_masksift_CA1_config_2500.yml')

In [None]:
# Load a second sift configuration from YAML.
# NOTE(review): `config_rgs` is not referenced by any code visible in this
# section; presumably used by later cells -- confirm.
config_rgs = emd.sift.SiftConfig.from_yaml_file('/Users/amir/Desktop/for Abdel/emd_masksift_CA1_config.yml')

In [None]:
def extract_pt_intervals(lfpHPC, hypno, fs=2500):
    """
    Split the intervals of sleep state 5 into phasic and tonic sub-intervals.

    Parameters
    ----------
    lfpHPC : array-like
        LFP samples recorded at `fs` Hz (presumably hippocampal, judging by
        the 'HPC' column label -- confirm against the caller).
    hypno : array-like
        Hypnogram passed to `get_start_end` and `detect_phasic`.
    fs : int or float
        Sampling rate of `lfpHPC` in Hz. The signal is downsampled to 500 Hz
        for phasic detection.

    Returns
    -------
    phasic_interval : nap.IntervalSet
        Intervals (in seconds) reported by `detect_phasic`.
    tonic_interval : nap.IntervalSet
        State-5 intervals minus the phasic intervals, keeping only pieces
        that last at least 100 ms.
    lfp : nap.TsdFrame
        The full-rate LFP with a single column named 'HPC'.
    """
    targetFs = 500
    n_down = fs / targetFs
    start, end = get_start_end(hypno=hypno, sleep_state_id=5)
    rem_interval = nap.IntervalSet(start=start, end=end)
    fs = int(n_down * targetFs)

    # Build timestamps from integer sample indices. np.arange with a float
    # step (1 / fs) can return one element more or fewer than len(lfpHPC),
    # which would make the time and data arrays disagree in length.
    t = np.arange(len(lfpHPC)) / fs
    lfp = nap.TsdFrame(t=t, d=lfpHPC, columns=['HPC'])

    # Detect phasic intervals on the downsampled signal
    lfpHPC_down = preprocess(lfpHPC, n_down)
    phREM = detect_phasic(lfpHPC_down, hypno, targetFs)

    # Convert the detected (start, end) sample pairs to seconds
    start, end = [], []
    for rem_idx in phREM:
        for s, e in phREM[rem_idx]:
            start.append(s / targetFs)
            end.append(e / targetFs)
    phasic_interval = nap.IntervalSet(start, end)

    # Tonic = everything in the state-5 intervals that is not phasic
    tonic_interval = rem_interval.set_diff(phasic_interval)
    print(f'Number of detected Tonic intrevals:{len(tonic_interval)}')

    # Apply a 100 ms duration threshold to tonic intervals
    min_duration = 0.1  # 100 ms in seconds
    durations = tonic_interval['end'] - tonic_interval['start']
    valid_intervals = durations >= min_duration
    tonic_interval = nap.IntervalSet(tonic_interval['start'][valid_intervals], tonic_interval['end'][valid_intervals])
    print(f'Number of detected Tonic intrevals after threshold:{len(tonic_interval)}')
    return phasic_interval, tonic_interval, lfp

In [None]:
def get_cycle_data(imf5, fs=2500):
    """
    Compute instantaneous phase/frequency/amplitude and cycles for one IMF.

    Parameters
    ----------
    imf5 : ndarray
        A single IMF time series (callers in this file pass column 5 of an
        IMF array as the theta IMF).
    fs : int or float
        Sampling rate in Hz.

    Returns
    -------
    dict
        Keys: 'fs', 'theta_imf' (the input IMF), 'IP', 'IF', 'IA' (outputs of
        the Hilbert frequency transform) and 'cycles' (the result of
        `get_cycles_with_metrics`).
    """
    # Get cycles using IP
    IP, IF, IA = emd.spectra.frequency_transform(imf5, fs, 'hilbert')
    C = emd.cycles.Cycles(IP)
    cycles = get_cycles_with_metrics(C, imf5, IA, IF)

    # Build the result in one literal. The previous placeholder dict listed
    # "IP" twice and omitted "IA", which was only added by later assignment.
    return {
        'fs': fs,
        'theta_imf': imf5,
        'IP': IP,
        'IF': IF,
        'IA': IA,
        'cycles': cycles,
    }

In [None]:
def extract_imfs_by_pt_intervals(lfp, fs, interval, config, return_imfs_freqs=False):
    """
    Run a mask sift on the LFP segment of every interval.

    Parameters
    ----------
    lfp : sliceable
        Signal indexed by sample position (`lfp[start:end]`).
    fs : int or float
        Sampling rate in Hz, used to convert interval times to sample indices
        and passed on to `imf_freq`.
    interval : IntervalSet-like
        Must support `len()` and `.loc[i, 'start']` / `.loc[i, 'end']`
        (times in seconds).
    config : mapping
        Keyword arguments unpacked into `sift.mask_sift`.
        # NOTE(review): the two-value unpacking below assumes the config asks
        # mask_sift to also return the mask frequencies -- confirm.
    return_imfs_freqs : bool
        If True, also return the per-interval IMF frequencies and raw segments.

    Returns
    -------
    all_imfs : list
        IMFs of every interval whose sift succeeded.
    all_imf_freqs, rem_lfp : list
        Only when `return_imfs_freqs` is True; `imf_freq` output and raw
        segment for the same intervals, index-aligned with `all_imfs`.
    """
    all_imfs = []
    all_imf_freqs = []
    rem_lfp = []

    for ii in range(len(interval)):
        start_idx = int(interval.loc[ii, 'start'] * fs)
        end_idx = int(interval.loc[ii, 'end'] * fs)
        sig = np.array(lfp[start_idx:end_idx])

        try:
            imf, _mask_freq = sift.mask_sift(sig, **config)
        except Exception as e:
            print(f"EMD Sift failed: {e}. Skipping this interval.")
            continue

        imf_frequencies = imf_freq(imf, fs)

        # Record the raw segment only after a successful sift so that the
        # three returned lists stay index-aligned; previously a failed sift
        # still left its segment in rem_lfp.
        all_imfs.append(imf)
        all_imf_freqs.append(imf_frequencies)
        rem_lfp.append(sig)

    if return_imfs_freqs:
        return all_imfs, all_imf_freqs, rem_lfp
    return all_imfs

# phase-aligned & non-phase-aligned waveforms

In [None]:
def extract_waveforms_with_and_without_alignment(imfs, imf_frequencies, max_extract=None, fs=2500):
    """
    Extract phase-aligned and time-normalised (non-phase-aligned) theta cycles.

    For every IMF array, column 5 is treated as the theta IMF. Cycles are kept
    when they are flagged good, last between fs/12 and fs/5 samples, and have
    a maximum amplitude above the 25th percentile of the instantaneous
    amplitude.

    Parameters
    ----------
    imfs : list
        List of 2D IMF arrays (indexed as ``imf[:, 5]``).
    imf_frequencies : list
        Unused; kept for interface compatibility.
    max_extract : int or None
        If set, keep at most this many cycles per IMF array.
    fs : int or float
        Sampling rate in Hz (default 2500, the previously hard-coded value).

    Returns
    -------
    aligned_waveforms : DataFrame
        One row per cycle, 100 phase-aligned points (emd.cycles.phase_align).
    non_aligned_waveforms : DataFrame
        One row per cycle, raw samples linearly resampled to 100 points.
    all_cycle_metrics : DataFrame
        Cycle metrics for the selected cycles.
        # NOTE(review): cycles shorter than 2 samples are skipped only in the
        # non-aligned output, which would shift its rows relative to the other
        # two frames. The duration filter makes this unreachable for the
        # default fs, but the guard is kept.
    """
    from scipy import interpolate

    def _stack(frames):
        # Concatenate once at the end; an empty list yields an empty frame.
        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    aligned_parts = []
    non_aligned_parts = []
    metric_parts = []

    # Loop invariants: duration bounds (in samples) and the resampling grid
    lo_freq_duration = fs / 5    # longest accepted cycle (5 Hz)
    hi_freq_duration = fs / 12   # shortest accepted cycle (12 Hz)
    x_new = np.linspace(0, 1, 100)

    for idx, imf in enumerate(imfs):
        # Get cycle data for the theta IMF (column 5)
        cycle_data = get_cycle_data(imf[:, 5], fs=fs)

        # Apply amplitude and duration thresholds
        amp_thresh = np.percentile(cycle_data['IA'], 25)
        conditions = [
            'is_good==1',
            f'duration_samples<{lo_freq_duration}',
            f'duration_samples>{hi_freq_duration}',
            f'max_amp>{amp_thresh}'
        ]
        all_cycles = get_cycles_with_conditions(cycle_data['cycles'], conditions)

        if all_cycles is None or all_cycles.chain_vect.size == 0:
            print("No cycles satisfy the conditions for IMF", idx)
            continue

        cycle_metrics = all_cycles.get_metric_dataframe(subset=True)

        # 1. Phase-aligned waveforms; transpose so each row is one waveform
        aligned_waves, _ = emd.cycles.phase_align(
            cycle_data['IP'],
            cycle_data['theta_imf'],
            cycles=all_cycles.iterate(through='subset'),
            npoints=100
        )
        aligned_df = pd.DataFrame(aligned_waves.T)

        # 2. Non-phase-aligned waveforms from the raw cycle samples
        theta_imf = cycle_data['theta_imf']
        non_aligned_list = []
        for cycle in all_cycles.iterate(through='subset'):
            try:
                _, inds = cycle
            except (TypeError, ValueError):
                # Not a (cycle_index, indices) pair -- skip it
                continue

            if len(inds) < 2:
                continue  # need at least 2 points for interpolation

            raw_wave = theta_imf[inds]
            x_old = np.linspace(0, 1, len(raw_wave))
            f_interp = interpolate.interp1d(x_old, raw_wave, kind='linear',
                                            bounds_error=False, fill_value='extrapolate')
            non_aligned_list.append(f_interp(x_new))

        non_aligned_df = pd.DataFrame(non_aligned_list)

        # Optionally cap the number of waveforms taken from this IMF
        if max_extract is not None and len(aligned_df) > max_extract:
            aligned_df = aligned_df.iloc[:max_extract]
            non_aligned_df = non_aligned_df.iloc[:max_extract]
            cycle_metrics = cycle_metrics.iloc[:max_extract]

        aligned_parts.append(aligned_df)
        non_aligned_parts.append(non_aligned_df)
        metric_parts.append(cycle_metrics)

    return _stack(aligned_parts), _stack(non_aligned_parts), _stack(metric_parts)


def plot_waveforms(aligned_waveforms, non_aligned_waveforms, num_plot=20, figsize=(12, 8)):
    """
    Draw stacked waveforms in two side-by-side panels.

    The left panel shows rows of `aligned_waveforms`, the right panel rows of
    `non_aligned_waveforms`. Row i is shifted up by 0.5 * i so traces stack.
    At most `num_plot` rows are drawn, capped by the shorter of the two frames.

    Returns the matplotlib figure.
    """
    n_traces = min(num_plot, len(aligned_waveforms), len(non_aligned_waveforms))

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # (axis, data, panel title, x-axis label) for each panel, left to right
    panels = (
        (axes[0], aligned_waveforms, 'Phase-aligned Waveforms', 'Normalized Phase'),
        (axes[1], non_aligned_waveforms, 'Non-phase-aligned Waveforms', 'Normalized Time'),
    )

    for ax, waves, title, xlabel in panels:
        ax.set_title(title)
        for row in range(n_traces):
            trace = waves.iloc[row].values
            grid = np.linspace(0, 1, len(trace))
            # vertical offset of 0.5 per row stacks the traces
            ax.plot(grid, trace + row * 0.5, 'k-', alpha=0.7)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Amplitude (stacked)')
        ax.set_yticks([])

    plt.tight_layout()
    return fig

In [None]:
# Extract both waveform representations, with no per-IMF cap (max_extract=None).
# NOTE(review): `rem_imfs` and `rem_imfs_freqs` are not defined in the cells
# visible here; they must exist in the session before this cell runs.
aligned_waveforms, non_aligned_waveforms, cycle_metrics = extract_waveforms_with_and_without_alignment(
    rem_imfs, rem_imfs_freqs, max_extract=None)

In [None]:
# Stack up to 100 waveforms per panel (phase-aligned vs. non-phase-aligned).
fig = plot_waveforms(aligned_waveforms, non_aligned_waveforms,
                         num_plot=100, figsize=(12, 8))

plt.show()

In [None]:
def plot_random_waveform_pairs(aligned_waveforms, non_aligned_waveforms, num_pairs=20, figsize=(15, 10)):
    """
    Overlay randomly chosen aligned / non-aligned waveform pairs.

    Parameters
    ----------
    aligned_waveforms, non_aligned_waveforms : DataFrame
        One waveform per row; row i of both frames is plotted together.
    num_pairs : int
        Number of random rows to draw (capped by the shorter frame).
    figsize : tuple
        Figure size.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Grid of subplots, 5 columns wide. The grid has at least 4 rows and
        grows when more than 20 pairs are requested; unused panels are hidden.
    """
    # Only sample indices that exist in BOTH frames (same rule as
    # plot_random_alignment_triplets); otherwise .iloc could run past the
    # end of the shorter frame.
    total_waves = min(len(aligned_waveforms), len(non_aligned_waveforms))
    num_pairs = min(num_pairs, total_waves)

    # Randomly select indices
    random_indices = np.random.choice(total_waves, size=num_pairs, replace=False)

    # 4x5 by default; add rows instead of silently dropping pairs beyond 20
    cols = 5
    rows = max(4, -(-num_pairs // cols))  # ceil division
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    fig.suptitle('Random Pairs of Aligned (blue) vs Non-aligned (red) Waveforms', fontsize=12)

    axes_flat = axes.flatten()

    for i, (ax, idx) in enumerate(zip(axes_flat, random_indices)):
        aligned = aligned_waveforms.iloc[idx].values
        non_aligned = non_aligned_waveforms.iloc[idx].values

        x = np.linspace(0, 1, len(aligned))

        ax.plot(x, aligned, 'b-', label='Aligned', alpha=0.7)
        ax.plot(x, non_aligned, 'r-', label='Non-aligned', alpha=0.7)

        # Title shows the row index of the pair
        ax.set_title(f'Pair {idx}', fontsize=8)

        ax.set_xticks([])
        ax.set_yticks([])

        # Only add legend to the first subplot
        if i == 0:
            ax.legend(fontsize=8)

    # Hide panels that received no data
    for ax in axes_flat[num_pairs:]:
        ax.axis('off')

    plt.tight_layout()
    return fig


In [None]:
# Show a random sample of aligned vs. non-aligned pairs (default num_pairs=20).
fig = plot_random_waveform_pairs(aligned_waveforms, non_aligned_waveforms)
plt.show()

# peak aligned waveforms


In [None]:
def prepare_data_for_dual_alignment(imfs, imf_frequencies, fs=2500):
    """
    Prepare data for UMAP analysis with both peak-aligned and phase-aligned waveforms.

    For every IMF array, column 5 is treated as the theta IMF. Cycles are kept
    when they are flagged good, last between fs/12 and fs/5 samples, and have
    a maximum amplitude above the 25th percentile of the instantaneous
    amplitude.

    Parameters
    ----------
    imfs : list
        List of IMFs (each should be 2D, samples x channels)
    imf_frequencies : list
        List of frequencies for each IMF
    fs : int or float
        Sampling rate in Hz (default 2500, the previously hard-coded value).

    Returns
    -------
    peak_aligned_waveforms : DataFrame
        DataFrame with peak-aligned waveforms (100 points per cycle)
    phase_aligned_waveforms : DataFrame
        DataFrame with phase-aligned waveforms (using emd.cycles.phase_align)
    trials : DataFrame
        DataFrame with cycle metrics plus 'mode_freqs' and 'entropy' columns
    all_FPPs : list
        List of frequency phase profiles, one per processed IMF array

    Notes
    -----
    NOTE(review): cycles with fewer than 3 samples are skipped only in the
    peak-aligned output, which would shift its rows relative to the other
    frames. The duration filter makes this unreachable for the default fs.
    """
    from scipy.interpolate import interp1d

    def _stack(frames):
        # Concatenate once at the end; an empty list yields an empty frame.
        return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    peak_parts = []
    phase_parts = []
    trial_parts = []
    all_FPPs = []

    # Define parameters
    theta_range = [5, 12]
    frequencies = np.arange(15, 141, 1)
    angles = np.linspace(-180, 180, 19)

    # Loop invariants
    lo_freq_duration = fs / 5    # longest accepted cycle (5 Hz), in samples
    hi_freq_duration = fs / 12   # shortest accepted cycle (12 Hz), in samples
    npoints_aug = 100            # number of points for the augmented grid
    aug_span = 5 * np.pi / 2     # width of the augmented phase range [-pi/2, 2*pi]
    phase_new = np.linspace(-np.pi/2, 2 * np.pi, npoints_aug)

    # Process each IMF
    for idx, imf in enumerate(imfs):
        # Get cycle data for theta IMF (assumed to be column 5)
        cycle_data = get_cycle_data(imf[:, 5], fs=fs)

        # Apply thresholds
        amp_thresh = np.percentile(cycle_data['IA'], 25)
        conditions = [
            'is_good==1',
            f'duration_samples<{lo_freq_duration}',
            f'duration_samples>{hi_freq_duration}',
            f'max_amp>{amp_thresh}'
        ]

        # Get cycles that meet conditions
        all_cycles = get_cycles_with_conditions(cycle_data['cycles'], conditions)

        if all_cycles is None or all_cycles.chain_vect.size == 0:
            print(f"No cycles satisfy the conditions for IMF {idx}")
            continue

        # Get cycle metrics and indices
        subset_cycles_df = all_cycles.get_metric_dataframe(subset=True)
        subset_indices = subset_cycles_df['index'].values

        # Get cycle indices for frequency phase profile
        all_cycles_inds = get_cycle_inds(all_cycles, subset_indices)
        cycles_inds = arrange_cycle_inds(all_cycles_inds)

        # Sum the IMFs that tg_split assigns to the supra-theta group
        freqs = imf_frequencies[idx]
        sub_theta, theta, supra_theta = tg_split(freqs, theta_range)
        supra_theta_sig = np.sum(imf.T[supra_theta], axis=0)

        # Wavelet power of the supra-theta signal, z-scored along axis 1
        raw_data = sails.wavelet.morlet(
            supra_theta_sig, freqs=frequencies,
            sample_rate=fs, ncycles=5,
            ret_mode='power', normalise=None
        )
        supraPlot = scipy.stats.zscore(raw_data, axis=1)
        FPP = bin_tf_to_fpp(cycles_inds, supraPlot, bin_count=19)
        all_FPPs.append(FPP)

        # Compute features
        mode_freqs, entropies = compute_mode_frequency_and_entropy(FPP, frequencies, angles)

        # 1. Phase-aligned waveforms; transpose so each row is one waveform
        phase_aligned_arr, _ = emd.cycles.phase_align(
            cycle_data['IP'],
            cycle_data['theta_imf'],
            cycles=all_cycles.iterate(through='subset'),
            npoints=100,
        )
        phase_aligned_df = pd.DataFrame(phase_aligned_arr.T)

        # 2. Peak-aligned waveforms
        peak_aligned_list = []
        theta_imf = cycle_data['theta_imf']

        for cycle in all_cycles.iterate(through='subset'):
            try:
                _, inds = cycle
            except (TypeError, ValueError):
                # Not a (cycle_index, indices) pair -- skip it
                continue
            if len(inds) < 3:
                continue

            # Extract raw waveform for the current cycle and locate its peak
            raw_wave = theta_imf[inds]
            peak_idx = np.argmax(raw_wave)

            # Normalized time grid for the cycle and the peak position on it
            x_orig = np.linspace(0, 1, len(raw_wave))
            peak_loc = x_orig[peak_idx]

            # Shift the grid so that the peak sits at 0.25, i.e. a quarter of
            # the way into the output window.
            # NOTE(review): an earlier comment said "centered at 0.5" while
            # the code uses 0.25; the code is left unchanged -- confirm intent.
            x_shifted = x_orig - peak_loc + 0.25

            # Linear map onto the augmented phase range:
            # x_shifted == 0 -> -pi/2, x_shifted == 1 -> 2*pi
            phase_orig = -np.pi/2 + x_shifted * (5 * np.pi/2)

            # Tile the cycle one augmented span to either side so the
            # interpolation covers the whole output window after the shift
            phase_extended = np.concatenate([phase_orig - aug_span, phase_orig, phase_orig + aug_span])
            wave_extended = np.concatenate([raw_wave, raw_wave, raw_wave])

            f = interp1d(phase_extended, wave_extended, kind='linear',
                        bounds_error=False, fill_value='extrapolate')

            # Resample the waveform on the uniform augmented phase grid
            peak_aligned_list.append(f(phase_new))

        peak_aligned_df = pd.DataFrame(peak_aligned_list)

        # Add cycle metrics and features
        trial = all_cycles.get_metric_dataframe(subset=True)
        trial['mode_freqs'] = mode_freqs
        trial['entropy'] = entropies

        peak_parts.append(peak_aligned_df)
        phase_parts.append(phase_aligned_df)
        trial_parts.append(trial)

    return _stack(peak_parts), _stack(phase_parts), _stack(trial_parts), all_FPPs


def extract_dual_aligned_data_for_rat(rat_id):
    """
    Extract data for a specific rat, including both peak-aligned and phase-aligned waveforms.

    Walks ``<base_path>/<rat_id>/<recording>/<post-trial 2..5>/``, loads the
    HPC LFP and the states file of each trial, splits state-5 sleep into
    tonic and phasic intervals, sifts each interval with the module-level
    ``config`` and collects the waveforms of both cycle types. A trial is
    processed only when it has both phasic and tonic intervals.

    Parameters
    ----------
    rat_id : int or str
        Rat ID to process

    Returns
    -------
    all_peak_aligned_waveforms : DataFrame
        DataFrame with peak-aligned waveforms from all recordings
    all_phase_aligned_waveforms : DataFrame
        DataFrame with phase-aligned waveforms from all recordings
    all_trials : DataFrame
        DataFrame with trial metrics from all recordings

    Each frame also carries the metadata columns 'rat_id', 'condition',
    'trial', 'cycle_type', 'SD' and 'date'. ``(None, None, None)`` is
    returned when the rat folder is missing, has no recording folders, or no
    peak-aligned waveform could be extracted.
    """
    # Local import: `os` is not among the imports at the top of this file
    import os

    # Define the base path to OS Basic datasets
    base_path = '/Users/amir/Desktop/for Abdel/OS Basic'
    fs = 2500  # Sample frequency

    folder_pattern = re.compile(r'^Rat-OS-Ephys_(Rat\d+)_SD(\d+)_([\w-]+)_([\d-]+)$')
    trial_pattern = re.compile(r'(?i)post[\-_]?trial[\-_]?([2-5])')

    # Per-interval-type frames, concatenated once at the end
    peak_parts = []
    phase_parts = []
    trial_parts = []

    rat_path = os.path.join(base_path, str(rat_id))

    # Check if the specified rat folder exists
    if not os.path.isdir(rat_path):
        print(f"Rat folder {rat_id} does not exist.")
        return None, None, None

    # List all recording folders in the rat directory
    recording_folders = [
        f for f in os.listdir(rat_path)
        if os.path.isdir(os.path.join(rat_path, f))
    ]

    if not recording_folders:
        print(f"No recording folders found for Rat {rat_id}.")
        return None, None, None

    for recording_folder in recording_folders:
        print(f"Processing recording folder: {recording_folder}")
        recording_path = os.path.join(rat_path, recording_folder)

        # Parse the folder name
        match = folder_pattern.match(recording_folder)
        if not match:
            print(f"Unexpected folder name format: {recording_folder}. Skipping...")
            continue

        rat_id_part = match.group(1)       # e.g., 'Rat6'
        sd_number = match.group(2)         # e.g., '4'
        condition = match.group(3)         # e.g., 'CON'
        date_part = match.group(4)         # e.g., '22-02-2018'

        rat_id_from_folder = ''.join(filter(str.isdigit, rat_id_part))

        if rat_id_from_folder != str(rat_id):
            print(f"Rat ID mismatch in folder {recording_folder}. Expected Rat{rat_id}, found Rat{rat_id_from_folder}. Skipping...")
            continue

        # Detect all trial folders and filter for post_trial2 to post_trial5
        trial_folders = [
            f for f in os.listdir(recording_path)
            if os.path.isdir(os.path.join(recording_path, f)) and
            trial_pattern.search(f)
        ]

        if not trial_folders:
            print(f"No trial folders found in {recording_folder}.")
            continue

        for trial_folder in trial_folders:
            print(f"Processing trial folder: {trial_folder}")
            trial_path = os.path.join(recording_path, trial_folder)

            # Search for LFP and state files in the trial folder
            lfp_file = None
            state_file = None
            for file_name in os.listdir(trial_path):
                if 'HPC' in file_name and file_name.endswith('.mat'):
                    lfp_file = os.path.join(trial_path, file_name)
                elif ('states' in file_name.lower()) and file_name.endswith('.mat'):
                    state_file = os.path.join(trial_path, file_name)

            if not lfp_file or not state_file:
                print(f"Missing LFP or state file in {trial_path}. Skipping...")
                continue

            # trial_folders was filtered with the same pattern, so the search
            # always matches here
            trial_number = int(trial_pattern.search(trial_folder).group(1))

            try:
                # Load LFP and hypnogram data
                lfpHPC, hypno, _ = get_data(lfp_file, state_file)

                # Extract phasic and tonic intervals
                try:
                    phasic_interval, tonic_interval, lfp = extract_pt_intervals(lfpHPC, hypno)
                except ValueError:
                    print(f"No REM sleep found in {trial_folder} for Rat {rat_id}, Condition {condition}. Skipping...")
                    continue

                # Only trials with both interval types are processed
                if len(phasic_interval) > 0 and len(tonic_interval) > 0:
                    # Tonic first, then phasic (row order of the output)
                    for cycle_type, interval in (('tonic', tonic_interval),
                                                 ('phasic', phasic_interval)):
                        imfs, freqs, _ = extract_imfs_by_pt_intervals(
                            lfp, fs, interval, config, return_imfs_freqs=True)

                        if not imfs:
                            continue

                        peak_df, phase_df, trials_df, _ = prepare_data_for_dual_alignment(
                            imfs, freqs)

                        # Tag non-empty frames with metadata and queue them
                        for df, bucket in ((peak_df, peak_parts),
                                           (phase_df, phase_parts),
                                           (trials_df, trial_parts)):
                            if df.empty:
                                continue
                            df['rat_id'] = rat_id
                            df['condition'] = condition
                            df['trial'] = trial_number
                            df['cycle_type'] = cycle_type
                            df['SD'] = sd_number
                            df['date'] = date_part
                            bucket.append(df)

            except FileNotFoundError:
                print(f"Data not found in {trial_path}. Skipping...")
            except Exception as e:
                # Best effort: report the failing trial and move on
                print(f"Error processing {trial_path}: {str(e)}")
                continue

    # Check if any data was extracted
    if not peak_parts:
        print(f"No data extracted for Rat {rat_id}.")
        return None, None, None

    all_peak_aligned_waveforms = pd.concat(peak_parts, ignore_index=True)
    all_phase_aligned_waveforms = (
        pd.concat(phase_parts, ignore_index=True) if phase_parts else pd.DataFrame())
    all_trials = (
        pd.concat(trial_parts, ignore_index=True) if trial_parts else pd.DataFrame())

    print(f"Extracted {len(all_peak_aligned_waveforms)} peak-aligned waveforms")
    print(f"Extracted {len(all_phase_aligned_waveforms)} phase-aligned waveforms")
    print(f"Extracted {len(all_trials)} trials")

    return all_peak_aligned_waveforms, all_phase_aligned_waveforms, all_trials


def plot_umaps(peak_aligned_waveforms, phase_aligned_waveforms, trials_df, 
               color_by='peak_values', perplexity=30, min_dist=0.1, figsize=(12, 5)):
    """
    Plot 2D UMAP embeddings of peak-aligned and phase-aligned waveforms side by side.

    Each waveform set is reduced to its first 100 columns, standardized with
    StandardScaler and embedded with UMAP (random_state=42). Points in both
    panels are colored by the same column of `trials_df`, with the color
    scale clipped to its 1st-99th percentile.

    Parameters
    ----------
    peak_aligned_waveforms : DataFrame
        Peak-aligned waveforms, one per row
    phase_aligned_waveforms : DataFrame
        Phase-aligned waveforms, one per row
    trials_df : DataFrame
        Per-waveform metrics; supplies the color values
    color_by : str
        Column of `trials_df` used to color the points
    perplexity : int
        Passed to UMAP as `n_neighbors`
    min_dist : float
        Passed to UMAP as `min_dist`
    figsize : tuple
        Figure size

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with the two scatter panels and one shared colorbar
    """
    import umap
    import matplotlib.pyplot as plt
    from sklearn.preprocessing import StandardScaler

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Shared coloring for both panels
    color_values = trials_df[color_by].values
    vmin = np.percentile(color_values, 1)
    vmax = np.percentile(color_values, 99)

    # (axis, waveform frame, panel title), left to right
    panels = (
        (axes[0], peak_aligned_waveforms, 'Peak-Aligned Waveforms'),
        (axes[1], phase_aligned_waveforms, 'Phase-Aligned Waveforms'),
    )

    scatters = []
    for ax, waveforms, title in panels:
        # Keep only the 100 sample columns
        samples = waveforms.iloc[:, :100].values
        samples_scaled = StandardScaler().fit_transform(samples)

        # A fresh reducer per panel, seeded identically
        reducer = umap.UMAP(n_neighbors=perplexity, min_dist=min_dist, random_state=42)
        embedding = reducer.fit_transform(samples_scaled)

        scatters.append(
            ax.scatter(
                embedding[:, 0],
                embedding[:, 1],
                c=color_values,
                vmin=vmin,
                vmax=vmax,
                cmap='hot',
                s=5
            )
        )
        ax.set_title(title)

    # One colorbar for both panels, taken from the first scatter
    cbar = fig.colorbar(scatters[0], ax=axes.ravel().tolist())
    cbar.set_label(color_by)

    plt.tight_layout()

    return fig

In [None]:

# Run the full extraction for rat '11'. All three results are None when the
# rat folder is missing or nothing could be extracted.
peak_aligned_waveforms, phase_aligned_waveforms, trials_df = extract_dual_aligned_data_for_rat('11')

In [None]:
def plot_random_alignment_triplets(peak_aligned_waveforms, phase_aligned_waveforms, num_triplets=20, figsize=(15, 10)):
    """Overlay randomly chosen peak-aligned and phase-aligned waveforms, one pair per panel.

    Row positions are drawn without replacement from the range shared by both
    tables. For every drawn row the first 100 values of the peak-aligned
    (blue) and the phase-aligned (red) waveform are plotted against a
    normalised 0-1 x-axis.

    Parameters
    ----------
    peak_aligned_waveforms, phase_aligned_waveforms : pandas.DataFrame
        Waveform tables indexed positionally; row ``i`` of one is compared
        with row ``i`` of the other.
    num_triplets : int, default 20
        Number of pairs to plot. Capped at the number of rows available in
        both tables.
    figsize : tuple, default (15, 10)
        Figure size passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the panel grid.
    """
    # Make sure we don't try to select more pairs than available
    total_waves = min(len(peak_aligned_waveforms), len(phase_aligned_waveforms))
    num_triplets = min(num_triplets, total_waves)

    # Randomly select row positions
    random_indices = np.random.choice(total_waves, size=num_triplets, replace=False)

    # Size the grid from the number of pairs actually drawn (5 per row) so that
    # requests above 20 are not silently truncated; the default of 20 still
    # yields the 4 x 5 layout. At least one row is kept so an empty selection
    # still returns a valid figure.
    cols = 5
    rows = max(1, -(-int(num_triplets) // cols))
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    fig.suptitle('Peak-aligned (blue) vs Phase-aligned (red)', fontsize=12)

    # Flatten axes for easier iteration
    axes_flat = axes.flatten()

    for i, (ax, idx) in enumerate(zip(axes_flat, random_indices)):
        # Only the first 100 columns hold waveform samples that are plotted
        peak_aligned = peak_aligned_waveforms.iloc[idx].values[:100]
        phase_aligned = phase_aligned_waveforms.iloc[idx].values[:100]

        # Normalised x-axis shared by both traces
        x = np.linspace(0, 1, len(peak_aligned))

        ax.plot(x, peak_aligned, 'b-', label='Peak-aligned', alpha=0.7)
        ax.plot(x, phase_aligned, 'r-', label='Phase-aligned', alpha=0.7)

        # Row position of the pair as panel title
        ax.set_title(f'Pair {idx}', fontsize=8)

        # Remove ticks for cleaner look
        ax.set_xticks([])
        ax.set_yticks([])

        # Only add legend to the first subplot
        if i == 0:
            ax.legend(fontsize=8)

    # Hide panels that received no pair instead of leaving empty frames
    for ax in axes_flat[len(random_indices):]:
        ax.set_visible(False)

    plt.tight_layout()
    return fig


In [None]:
fig = plot_random_alignment_triplets(peak_aligned_waveforms, phase_aligned_waveforms)
plt.show()

In [None]:

peak_aligned_waveforms_combined = peak_aligned_waveforms.drop(columns=['rat_id', 'condition', 'trial', 'cycle_type', 'SD', 'date'])

In [None]:

phase_aligned_waveforms_combined = phase_aligned_waveforms.drop(columns=['rat_id', 'condition', 'trial', 'cycle_type', 'SD', 'date'])

In [None]:
phase_aligned_umap_embedder = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=4, metric='euclidean', random_state=42)
phase_aligned_embedding_rem = phase_aligned_umap_embedder.fit_transform(phase_aligned_waveforms_combined.to_numpy())

In [None]:
peak_aligned_umap_embedder = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=4, metric='euclidean', random_state=42)
peak_aligned_embedding_rem = peak_aligned_umap_embedder.fit_transform(peak_aligned_waveforms_combined.to_numpy())

In [None]:
defaul_peaks = np.array(trials_df['peak_values'])

In [None]:
trial_asc2desc = np.array(trials_df['asc2desc'])

In [None]:
plt.figure(figsize=(8,6))  # Adjusted size for a single plot

plt.scatter(peak_aligned_embedding_rem[:, 0], peak_aligned_embedding_rem[:, 1], 
            c=trials_df['peak_values'],
            vmin=np.percentile(trials_df['peak_values'], 1),
            vmax=np.percentile(trials_df['peak_values'], 99),
            cmap='hot', s=50)

plt.colorbar()
plt.title('Peak aligned UMAP - Rat 11')
plt.show()

In [None]:
plt.figure(figsize=(8,6))  # Adjusted size for a single plot

plt.scatter(peak_aligned_embedding_rem[:, 0], peak_aligned_embedding_rem[:, 1], 
            c=trial_asc2desc,
            vmin=np.percentile(trial_asc2desc, 1),
            vmax=np.percentile(trial_asc2desc, 99),
            cmap='hot', s=50)

plt.colorbar()
plt.title('Peak aligned UMAP - Rat 11')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Split data into phasic and tonic
phasic_mask = trials_df['cycle_type'] == 'phasic'
tonic_mask = trials_df['cycle_type'] == 'tonic'

# Get indices for phasic and tonic cycles
phasic_indices = np.where(phasic_mask)[0]
tonic_indices = np.where(tonic_mask)[0]

# Extract asc2desc values for each type
phasic_asc2desc = trial_asc2desc[phasic_indices]
tonic_asc2desc = trial_asc2desc[tonic_indices]

# Extract UMAP coordinates for each type
phasic_x = peak_aligned_embedding_rem[phasic_indices, 0]
phasic_y = peak_aligned_embedding_rem[phasic_indices, 1]
tonic_x = peak_aligned_embedding_rem[tonic_indices, 0]
tonic_y = peak_aligned_embedding_rem[tonic_indices, 1]

# Create figure and axis
plt.figure(figsize=(10, 8))

# Calculate percentiles for color scaling
vmin_phasic = np.percentile(phasic_asc2desc, 1)
vmax_phasic = np.percentile(phasic_asc2desc, 99)
vmin_tonic = np.percentile(tonic_asc2desc, 1)
vmax_tonic = np.percentile(tonic_asc2desc, 99)

# Create custom colormaps
tonic_cmap = plt.cm.Blues
phasic_cmap = plt.cm.Reds

# Plot tonic cycles first (bottom layer)
sc1 = plt.scatter(tonic_x, tonic_y,
                 c=tonic_asc2desc,
                 vmin=vmin_tonic, vmax=vmax_tonic,
                 cmap=tonic_cmap, s=40, 
                 alpha=0.8, label='Tonic')

# Plot phasic cycles on top (top layer)
sc2 = plt.scatter(phasic_x, phasic_y,
                 c=phasic_asc2desc,
                 vmin=vmin_phasic, vmax=vmax_phasic,
                 cmap=phasic_cmap, s=40,
                 alpha=0.8, label='Phasic', marker='^')

# Add colorbars
cbar_ax1 = plt.colorbar(sc1)
cbar_ax1.set_label('Tonic asc2desc')
cbar_ax2 = plt.colorbar(sc2)
cbar_ax2.set_label('Phasic asc2desc')

# Add title and legend
plt.title('Peak aligned UMAP - Rat 11 (Phasic vs Tonic cycles)', fontsize=14)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.legend()

# Improve aesthetics
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# --- Assume these variables are defined in your workspace ---
# trials_df: DataFrame with a 'cycle_type' column (values: 'phasic' or 'tonic')
# trial_asc2desc: numpy array of asc2desc values for each cycle
# peak_aligned_embedding_rem: numpy array with UMAP coordinates (at least 2D)

# Split data into phasic and tonic masks
phasic_mask = trials_df['cycle_type'] == 'phasic'
tonic_mask = trials_df['cycle_type'] == 'tonic'

# Get indices for phasic and tonic cycles
phasic_indices = np.where(phasic_mask)[0]
tonic_indices = np.where(tonic_mask)[0]

# Extract asc2desc values for each type
phasic_asc2desc = trial_asc2desc[phasic_indices]
tonic_asc2desc = trial_asc2desc[tonic_indices]

# Extract corresponding UMAP coordinates
phasic_x = peak_aligned_embedding_rem[phasic_indices, 0]
phasic_y = peak_aligned_embedding_rem[phasic_indices, 1]
tonic_x  = peak_aligned_embedding_rem[tonic_indices, 0]
tonic_y  = peak_aligned_embedding_rem[tonic_indices, 1]

# ---- Discretize the continuous asc2desc values into 3 bins ----
tol = 0.05  # tolerance for "approximately symmetric"

def discretize(values, tol=0.05):
    """
    Map asc2desc values into:
      0: value < 1 - tol               (Faster ascent, slower descent)
      1: 1 - tol <= value <= 1 + tol   (Symmetric cycle)
      2: value > 1 + tol               (Slower ascent, faster descent)
     -1: value is NaN                  (cannot be classified)

    Returns an integer array with the same shape as ``values``.
    """
    values = np.asarray(values, dtype=float)
    # Start from the symmetric bin and carve out the two tails. Testing the
    # middle bin separately with |value - 1| <= tol is not the exact
    # complement of the tail tests in floating point (e.g. 0.95 with
    # tol=0.05 matches neither), which would leave entries unassigned.
    discrete = np.ones(values.shape, dtype=int)
    discrete[values < 1 - tol] = 0
    discrete[values > 1 + tol] = 2
    # NaN compares False against both thresholds; flag it explicitly rather
    # than reporting it as a symmetric cycle.
    discrete[np.isnan(values)] = -1
    return discrete

phasic_bins = discretize(phasic_asc2desc, tol)
tonic_bins  = discretize(tonic_asc2desc, tol)

# ---- Create discrete colormaps ----
# For phasic cycles: based on Reds
phasic_base_cmap = plt.cm.Reds
colors_phasic = phasic_base_cmap(np.linspace(0.5, 0.8, 3))
discrete_phasic_cmap = ListedColormap(colors_phasic)
norm_phasic = BoundaryNorm([0, 1, 2, 3], discrete_phasic_cmap.N)

# For tonic cycles: based on Blues
tonic_base_cmap = plt.cm.Blues
colors_tonic = tonic_base_cmap(np.linspace(0.5, 0.8, 3))
discrete_tonic_cmap = ListedColormap(colors_tonic)
norm_tonic = BoundaryNorm([0, 1, 2, 3], discrete_tonic_cmap.N)

# ---- Create the plot ----
plt.figure(figsize=(10, 8))

# Plot tonic cycles first (bottom layer)
sc1 = plt.scatter(tonic_x, tonic_y,
                  c=tonic_bins,
                  cmap=discrete_tonic_cmap,
                  norm=norm_tonic,
                  s=40,
                  alpha=0.8,
                  label='Tonic')

# Plot phasic cycles on top (marker style changed for clarity)
sc2 = plt.scatter(phasic_x, phasic_y,
                  c=phasic_bins,
                  cmap=discrete_phasic_cmap,
                  norm=norm_phasic,
                  s=40,
                  alpha=0.8,
                  label='Phasic',
                  marker='^')

# Define custom colorbar ticks and labels (center positions for each bin)
tick_locs = [0.5, 1.5, 2.5]
tick_labels = ['Fast Ascent\nSlow Descent', 'Symmetric', 'Slow Ascent\nFast Descent']

# Add a colorbar for tonic cycles
cbar1 = plt.colorbar(sc1, ticks=tick_locs)
cbar1.ax.set_yticklabels(tick_labels)
cbar1.set_label('Tonic asc2desc')

# Add a colorbar for phasic cycles
cbar2 = plt.colorbar(sc2, ticks=tick_locs)
cbar2.ax.set_yticklabels(tick_labels)
cbar2.set_label('Phasic asc2desc')

# Add titles, labels, and legends
plt.title('Peak Aligned UMAP with Discrete asc2desc Categories', fontsize=14)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming phase_aligned_waveforms_combined_cleaned and phase_aligned_embedding_rem are already created.

# Create a boolean mask of the rows that were kept.
inf_mask = np.isinf(phase_aligned_waveforms_combined.to_numpy()).any(axis=1)
kept_rows_mask = ~inf_mask

# Filter trials_df to keep the rows that were kept in the waveform dataframe.
trials_df_filtered = trials_df[kept_rows_mask]

# Now create the plot.
plt.figure(figsize=(8,6))

plt.scatter(phase_aligned_embedding_rem[:, 0], phase_aligned_embedding_rem[:, 1], 
            c=defaul_peaks,
            vmin=np.percentile(defaul_peaks, 1),
            vmax=np.percentile(defaul_peaks, 99),
            cmap='hot', s=50)

plt.colorbar()
plt.title('Phase aligned UMAP - Rat 11')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming phase_aligned_waveforms_combined_cleaned and phase_aligned_embedding_rem are already created.

# Create a boolean mask of the rows that were kept.
inf_mask = np.isinf(phase_aligned_waveforms_combined.to_numpy()).any(axis=1)
kept_rows_mask = ~inf_mask

# Filter trials_df to keep the rows that were kept in the waveform dataframe.
trials_df_filtered = trials_df[kept_rows_mask]

# Now create the plot.
plt.figure(figsize=(8,6))

plt.scatter(phase_aligned_embedding_rem[:, 0], phase_aligned_embedding_rem[:, 1], 
            c=trials_df_filtered['asc2desc'],
            vmin=np.percentile(trials_df_filtered['asc2desc'], 1),
            vmax=np.percentile(trials_df_filtered['asc2desc'], 99),
            cmap='hot', s=50)

plt.colorbar()
plt.title('Phase aligned UMAP - Rat 11')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Split data into phasic and tonic
phasic_mask = trials_df['cycle_type'] == 'phasic'
tonic_mask = trials_df['cycle_type'] == 'tonic'

# Get indices for phasic and tonic cycles
phasic_indices = np.where(phasic_mask)[0]
tonic_indices = np.where(tonic_mask)[0]

# Extract asc2desc values for each type
phasic_asc2desc = trial_asc2desc[phasic_indices]
tonic_asc2desc = trial_asc2desc[tonic_indices]

# Extract UMAP coordinates for each type
phasic_x = phase_aligned_embedding_rem[phasic_indices, 0]
phasic_y = phase_aligned_embedding_rem[phasic_indices, 1]
tonic_x = phase_aligned_embedding_rem[tonic_indices, 0]
tonic_y = phase_aligned_embedding_rem[tonic_indices, 1]

# Create figure and axis
plt.figure(figsize=(10, 8))

# Calculate percentiles for color scaling
vmin_phasic = np.percentile(phasic_asc2desc, 1)
vmax_phasic = np.percentile(phasic_asc2desc, 99)
vmin_tonic = np.percentile(tonic_asc2desc, 1)
vmax_tonic = np.percentile(tonic_asc2desc, 99)

# Create custom colormaps
tonic_cmap = plt.cm.Blues
phasic_cmap = plt.cm.Reds

# Plot tonic cycles first (bottom layer)
sc1 = plt.scatter(tonic_x, tonic_y,
                 c=tonic_asc2desc,
                 vmin=vmin_tonic, vmax=vmax_tonic,
                 cmap=tonic_cmap, s=40, 
                 alpha=0.8, label='Tonic')

# Plot phasic cycles on top (top layer)
sc2 = plt.scatter(phasic_x, phasic_y,
                 c=phasic_asc2desc,
                 vmin=vmin_phasic, vmax=vmax_phasic,
                 cmap=phasic_cmap, s=40,
                 alpha=0.8, label='Phasic', marker='^')

# Add colorbars
cbar_ax1 = plt.colorbar(sc1)
cbar_ax1.set_label('Tonic asc2desc')
cbar_ax2 = plt.colorbar(sc2)
cbar_ax2.set_label('Phasic asc2desc')

# Add title and legend
plt.title('Phase aligned UMAP - Rat 11 (Phasic vs Tonic cycles)', fontsize=14)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.legend()

# Improve aesthetics
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# --- Assume these variables are defined in your workspace ---
# trials_df: DataFrame with a 'cycle_type' column (values: 'phasic' or 'tonic')
# trial_asc2desc: numpy array of asc2desc values for each cycle
# phase_aligned_embedding_rem: numpy array with UMAP coordinates (at least 2D)

# Split data into phasic and tonic masks
phasic_mask = trials_df['cycle_type'] == 'phasic'
tonic_mask = trials_df['cycle_type'] == 'tonic'

# Get indices for phasic and tonic cycles
phasic_indices = np.where(phasic_mask)[0]
tonic_indices = np.where(tonic_mask)[0]

# Extract asc2desc values for each type
phasic_asc2desc = trial_asc2desc[phasic_indices]
tonic_asc2desc = trial_asc2desc[tonic_indices]

# Extract corresponding UMAP coordinates
phasic_x = phase_aligned_embedding_rem[phasic_indices, 0]
phasic_y = phase_aligned_embedding_rem[phasic_indices, 1]
tonic_x  = phase_aligned_embedding_rem[tonic_indices, 0]
tonic_y  = phase_aligned_embedding_rem[tonic_indices, 1]

# ---- Discretize the continuous asc2desc values into 3 bins ----
tol = 0.05  # tolerance for "approximately symmetric"

def discretize(values, tol=0.05):
    """
    Map asc2desc values into:
      0: value < 1 - tol               (Faster ascent, slower descent)
      1: 1 - tol <= value <= 1 + tol   (Symmetric cycle)
      2: value > 1 + tol               (Slower ascent, faster descent)
     -1: value is NaN                  (cannot be classified)

    Returns an integer array with the same shape as ``values``.
    """
    values = np.asarray(values, dtype=float)
    # Start from the symmetric bin and carve out the two tails. Testing the
    # middle bin separately with |value - 1| <= tol is not the exact
    # complement of the tail tests in floating point (e.g. 0.95 with
    # tol=0.05 matches neither), which would leave entries unassigned.
    discrete = np.ones(values.shape, dtype=int)
    discrete[values < 1 - tol] = 0
    discrete[values > 1 + tol] = 2
    # NaN compares False against both thresholds; flag it explicitly rather
    # than reporting it as a symmetric cycle.
    discrete[np.isnan(values)] = -1
    return discrete

phasic_bins = discretize(phasic_asc2desc, tol)
tonic_bins  = discretize(tonic_asc2desc, tol)

# ---- Create discrete colormaps ----
# For phasic cycles: based on Reds
phasic_base_cmap = plt.cm.Reds
colors_phasic = phasic_base_cmap(np.linspace(0.5, 0.8, 3))
discrete_phasic_cmap = ListedColormap(colors_phasic)
norm_phasic = BoundaryNorm([0, 1, 2, 3], discrete_phasic_cmap.N)

# For tonic cycles: based on Blues
tonic_base_cmap = plt.cm.Blues
colors_tonic = tonic_base_cmap(np.linspace(0.5, 0.8, 3))
discrete_tonic_cmap = ListedColormap(colors_tonic)
norm_tonic = BoundaryNorm([0, 1, 2, 3], discrete_tonic_cmap.N)

# ---- Create the plot ----
plt.figure(figsize=(10, 8))

# Plot tonic cycles first (bottom layer)
sc1 = plt.scatter(tonic_x, tonic_y,
                  c=tonic_bins,
                  cmap=discrete_tonic_cmap,
                  norm=norm_tonic,
                  s=40,
                  alpha=0.8,
                  label='Tonic')

# Plot phasic cycles on top (marker style changed for clarity)
sc2 = plt.scatter(phasic_x, phasic_y,
                  c=phasic_bins,
                  cmap=discrete_phasic_cmap,
                  norm=norm_phasic,
                  s=40,
                  alpha=0.8,
                  label='Phasic',
                  marker='^')

# Define custom colorbar ticks and labels (center positions for each bin)
tick_locs = [0.5, 1.5, 2.5]
tick_labels = ['Fast Ascent\nSlow Descent', 'Symmetric', 'Slow Ascent\nFast Descent']

# Add a colorbar for tonic cycles
cbar1 = plt.colorbar(sc1, ticks=tick_locs)
cbar1.ax.set_yticklabels(tick_labels)
cbar1.set_label('Tonic asc2desc')

# Add a colorbar for phasic cycles
cbar2 = plt.colorbar(sc2, ticks=tick_locs)
cbar2.ax.set_yticklabels(tick_labels)
cbar2.set_label('Phasic asc2desc')

# Add titles, labels, and legends
plt.title('Phase Aligned UMAP with Discrete asc2desc Categories', fontsize=14)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Create a figure with two 3D subplots
fig = plt.figure(figsize=(15, 6))

# Peak-aligned 3D plot
ax1 = fig.add_subplot(121, projection='3d')
scatter1 = ax1.scatter(peak_aligned_embedding_rem[:, 0], 
                      peak_aligned_embedding_rem[:, 1], 
                      peak_aligned_embedding_rem[:, 2],
                      c=trials_df['peak_values'],
                      vmin=np.percentile(trials_df['peak_values'], 1),
                      vmax=np.percentile(trials_df['peak_values'], 99),
                      cmap='hot',
                      s=50)
ax1.set_title('Peak-aligned UMAP 3D')
ax1.set_xlabel('UMAP 1')
ax1.set_ylabel('UMAP 2')
ax1.set_zlabel('UMAP 3')

# Phase-aligned 3D plot
ax2 = fig.add_subplot(122, projection='3d')
scatter2 = ax2.scatter(phase_aligned_embedding_rem[:, 0], 
                      phase_aligned_embedding_rem[:, 1], 
                      phase_aligned_embedding_rem[:, 2],
                      c=trials_df['peak_values'],
                      vmin=np.percentile(trials_df['peak_values'], 1),
                      vmax=np.percentile(trials_df['peak_values'], 99),
                      cmap='hot',
                      s=50)
ax2.set_title('Phase-aligned UMAP 3D')
ax2.set_xlabel('UMAP 1')
ax2.set_ylabel('UMAP 2')
ax2.set_zlabel('UMAP 3')

# Add colorbars
plt.colorbar(scatter1, ax=ax1, label='Peak Values')
plt.colorbar(scatter2, ax=ax2, label='Peak Values')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Create a figure with two 3D subplots
fig = plt.figure(figsize=(15, 6))

# Peak-aligned 3D plot
ax1 = fig.add_subplot(121, projection='3d')
scatter1 = ax1.scatter(peak_aligned_embedding_rem[:, 0], 
                      peak_aligned_embedding_rem[:, 1], 
                      peak_aligned_embedding_rem[:, 2],
                      c=trials_df['peak_values'],
                      vmin=np.percentile(trials_df['peak_values'], 1),
                      vmax=np.percentile(trials_df['peak_values'], 99),
                      cmap='hot',
                      s=50)
ax1.set_title('Peak-aligned UMAP 3D')
ax1.set_xlabel('UMAP 1')
ax1.set_ylabel('UMAP 2')
ax1.set_zlabel('UMAP 3')

# Phase-aligned 3D plot
ax2 = fig.add_subplot(122, projection='3d')
scatter2 = ax2.scatter(phase_aligned_embedding_rem[:, 0], 
                      phase_aligned_embedding_rem[:, 1], 
                      phase_aligned_embedding_rem[:, 2],
                      c=trials_df['peak_values'],
                      vmin=np.percentile(trials_df['peak_values'], 1),
                      vmax=np.percentile(trials_df['peak_values'], 99),
                      cmap='hot',
                      s=50)
ax2.set_title('Phase-aligned UMAP 3D')
ax2.set_xlabel('UMAP 1')
ax2.set_ylabel('UMAP 2')
ax2.set_zlabel('UMAP 3')

# Add colorbars
plt.colorbar(scatter1, ax=ax1, label='Peak Values')
plt.colorbar(scatter2, ax=ax2, label='Peak Values')

plt.tight_layout()
plt.show()

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# First plot: Peak-aligned UMAP 3D
fig1 = px.scatter_3d(
    x=peak_aligned_embedding_rem[:, 0],
    y=peak_aligned_embedding_rem[:, 1],
    z=peak_aligned_embedding_rem[:, 2],
    color=trials_df['peak_values'],
    color_continuous_scale='hot',
    range_color=(np.percentile(trials_df['peak_values'], 1),
                 np.percentile(trials_df['peak_values'], 99)),
    labels={'color': 'Peak Values'},
    title='Peak-aligned UMAP 3D'
)

# Update layout for better visualization
fig1.update_layout(
    scene=dict(
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        zaxis_title='UMAP 3'
    ),
    width=800,
    height=800
)

# Show the first plot
fig1.show()

In [None]:
# Second plot: Phase-aligned UMAP 3D
fig2 = px.scatter_3d(
    x=phase_aligned_embedding_rem[:, 0],
    y=phase_aligned_embedding_rem[:, 1],
    z=phase_aligned_embedding_rem[:, 2],
    color=trials_df_filtered['peak_values'],
    color_continuous_scale='hot',
    range_color=(np.percentile(trials_df_filtered['peak_values'], 1),
                 np.percentile(trials_df_filtered['peak_values'], 99)),
    labels={'color': 'Peak Values'},
    title='Phase-aligned UMAP 3D'
)

# Update layout for better visualization
fig2.update_layout(
    scene=dict(
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        zaxis_title='UMAP 3'
    ),
    width=800,
    height=800
)

# Show the second plot
fig2.show()

In [None]:
# SI 

In [None]:
from structure_index import compute_structure_index, draw_graph

In [None]:
# Rotation and Translation of dataset to align it to another dataset (template) based on a feature
# Arguments:
# data = Dataset of size N to be aligned
# feature = an array of size N representing alignment feature
# template_data = Dataset to which (data) will be aligned
# template_feature = an array representing alignment feature of the template data
def align_point_cloud(data, feature,
                      template_data,
                      template_feature,
                      n_bins=15,
                      n_neighbors=15,
                      dims=None,
                      distance_metric='euclidean',
                      discrete_label=False,
                      num_shuffles=10,
                      verbose=False):
    """Rigidly align ``data`` to ``template_data`` using feature-bin centroids.

    Both point clouds are binned by their feature with
    ``compute_structure_index``; the centroids of matching bins are the
    landmarks. A rotation ``R`` and translation ``T`` are fitted to the
    landmarks by SVD and applied as ``data @ R + T``.

    Parameters
    ----------
    data : array of shape (N, d)
        Point cloud to be aligned (one point per row).
    feature : array-like of length N
        Alignment feature for ``data``.
    template_data : array of shape (M, d)
        Point cloud that ``data`` is aligned to.
    template_feature : array-like of length M
        Alignment feature for ``template_data``.
    n_bins, n_neighbors, dims, distance_metric, discrete_label, num_shuffles, verbose
        Forwarded unchanged to ``compute_structure_index``.

    Returns
    -------
    numpy.ndarray
        ``data`` after rotation and translation.

    Raises
    ------
    ValueError
        If no bin is populated in both point clouds.
    """
    params = {
        'n_bins': n_bins,
        'n_neighbors': n_neighbors,
        'dims': dims,
        'distance_metric': distance_metric,
        'discrete_label': discrete_label,
        'num_shuffles': num_shuffles,
        'verbose': verbose,
    }

    # Only the bin labels are needed here; the SI values themselves are discarded.
    # NOTE(review): assumes binLabel[0] holds one bin id per point - confirm
    # against the structure_index version in use.
    _, binLabel, _, _ = compute_structure_index(data, np.array(feature), **params)
    _, binLabel_temp, _, _ = compute_structure_index(template_data, np.array(template_feature), **params)
    labels = np.asarray(binLabel[0])
    labels_temp = np.asarray(binLabel_temp[0])

    # Landmarks are the centroids of bins populated in BOTH clouds. An empty
    # bin would otherwise yield a NaN centroid and poison the SVD below.
    shared_bins = [i for i in range(n_bins)
                   if np.any(labels == i) and np.any(labels_temp == i)]
    if not shared_bins:
        raise ValueError("No feature bin is populated in both datasets; cannot align.")

    # Centroids of bins (p and p'); these are the points that will be aligned
    p = np.array([np.mean(data[labels == i], axis=0) for i in shared_bins])
    p_temp = np.array([np.mean(template_data[labels_temp == i], axis=0) for i in shared_bins])

    # Deviations of the centroids from their means (q and q')
    p_mean = np.mean(p, axis=0)
    p_temp_mean = np.mean(p_temp, axis=0)
    q = p - p_mean
    q_temp = p_temp - p_temp_mean

    # Points are rows and the transform is applied as ``x @ R``, so the
    # cross-covariance must be q^T q'. With H = U S Vh, R = U Vh minimises
    # ||q @ R - q'||. (Using q'^T q here would give the transpose, i.e. the
    # inverse rotation, for this row-vector convention.)
    H = np.dot(q.T, q_temp)
    U, S, Vh = np.linalg.svd(H, full_matrices=True)
    R = np.dot(U, Vh)
    if np.linalg.det(R) < 0:
        # Improper rotation (reflection): flip the last singular direction
        Vh[-1, :] *= -1
        R = np.dot(U, Vh)

    # Translation that maps the rotated source centroid mean onto the template's
    T = p_temp_mean - np.dot(p_mean, R)
    new_data = np.dot(data, R) + T
    return new_data

In [None]:
SI, binLabel, overlapMat, sSI = compute_structure_index(phase_aligned_embedding_rem, np.array(defaul_peaks), **params)

In [None]:
print(f"Structure Index (SI): {SI}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plot the overlap matrix
plt.figure(figsize=(10, 8))
sns.heatmap(overlapMat, annot=True, fmt=".2f")
plt.title('Overlap Matrix - Phase Aligned')
plt.xlabel('Bin Group')
plt.ylabel('Bin Group')
plt.show()

# Plot the UMAP embedding colored by bin labels
plt.figure(figsize=(10, 8))
scatter = plt.scatter(phase_aligned_embedding_rem[:, 0], phase_aligned_embedding_rem[:, 1], c=binLabel[0], alpha=0.7)
plt.title('UMAP Embedding for Phase Aligned Colored by Bin Labels')
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.colorbar(scatter, label='Bin Label')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Assuming the 'cycle_type' column in trials_df indicates 'phasic' or 'tonic'
trials_df['bin_label'] = binLabel[0]

# Count the number of phasic and tonic cycles in each bin
phasic_counts = trials_df[trials_df['cycle_type'] == 'phasic'].groupby('bin_label').size()
tonic_counts = trials_df[trials_df['cycle_type'] == 'tonic'].groupby('bin_label').size()

# Fill missing bins with 0
phasic_counts = phasic_counts.reindex(range(n_bins), fill_value=0)
tonic_counts = tonic_counts.reindex(range(n_bins), fill_value=0)

# Calculate the percentage for each bin
total_counts = phasic_counts + tonic_counts
phasic_percentage = (phasic_counts / total_counts) * 100
tonic_percentage = (tonic_counts / total_counts) * 100

# Setup the grouped bar plot
x = np.arange(n_bins)       # the label locations
width = 0.35                # the width of the bars

plt.figure(figsize=(10, 8))
plt.bar(x - width/2, phasic_percentage, width, alpha=0.7, label='Phasic')
plt.bar(x + width/2, tonic_percentage, width, alpha=0.7, label='Tonic')

plt.xlabel('SI Bin')
plt.ylabel('Percentage')
plt.title('Percentage of Phasic and Tonic Cycles in Each SI Bin - Phase Aligned')
plt.xticks(x, [f"Bin {i}" for i in x])
plt.legend()
plt.show()

In [None]:
SI_pa, binLabel_pa, overlapMat_pa, sSI_pa = compute_structure_index(peak_aligned_embedding_rem, np.array(defaul_peaks), **params)

In [None]:
print(f"Structure Index (SI): {SI_pa}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plot the overlap matrix of the peak-aligned structure index.
# The peak-aligned results live in overlapMat_pa / binLabel_pa; the un-suffixed
# overlapMat / binLabel belong to the phase-aligned embedding.
plt.figure(figsize=(10, 8))
sns.heatmap(overlapMat_pa, annot=True, fmt=".2f")
plt.title('Overlap Matrix - Peak Aligned')
plt.xlabel('Bin Group')
plt.ylabel('Bin Group')
plt.show()

# Plot the peak-aligned UMAP embedding colored by its own bin labels
plt.figure(figsize=(10, 8))
scatter = plt.scatter(peak_aligned_embedding_rem[:, 0], peak_aligned_embedding_rem[:, 1], c=binLabel_pa[0], alpha=0.7)
plt.title('UMAP Embedding for Peak Aligned Colored by Bin Labels')
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.colorbar(scatter, label='Bin Label')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Assuming the 'cycle_type' column in trials_df indicates 'phasic' or 'tonic'
trials_df['bin_label'] = binLabel_pa[0]

# Count the number of phasic and tonic cycles in each bin
phasic_counts = trials_df[trials_df['cycle_type'] == 'phasic'].groupby('bin_label').size()
tonic_counts = trials_df[trials_df['cycle_type'] == 'tonic'].groupby('bin_label').size()

# Fill missing bins with 0
phasic_counts = phasic_counts.reindex(range(n_bins), fill_value=0)
tonic_counts = tonic_counts.reindex(range(n_bins), fill_value=0)

# Calculate the percentage for each bin
total_counts = phasic_counts + tonic_counts
phasic_percentage = (phasic_counts / total_counts) * 100
tonic_percentage = (tonic_counts / total_counts) * 100

# Setup the grouped bar plot
x = np.arange(n_bins)       # the label locations
width = 0.35                # the width of the bars

plt.figure(figsize=(10, 8))
plt.bar(x - width/2, phasic_percentage, width, alpha=0.7, label='Phasic')
plt.bar(x + width/2, tonic_percentage, width, alpha=0.7, label='Tonic')

plt.xlabel('SI Bin')
plt.ylabel('Percentage')
plt.title('Percentage of Phasic and Tonic Cycles in Each SI Bin - Peak Aligned')
plt.xticks(x, [f"Bin {i}" for i in x])
plt.legend()
plt.show()

# compare controls from OS Basic and RGS

In [None]:
import os
import pandas as pd
import re

def extract_data_for_rat(rat_id):
    """Collect phasic and tonic waveform/cycle tables for one OS Basic rat.

    Walks ``<base_path>/<rat_id>/<recording>/<post-trial 2..5>`` folders. For
    every trial that has both an HPC LFP ``.mat`` file and a states ``.mat``
    file, the phasic and tonic REM intervals are extracted, decomposed into
    IMFs and converted to waveform/cycle tables. Each table is tagged with
    ``rat_id``, ``condition``, ``trial``, ``cycle_type``, ``SD`` and ``date``
    columns before being combined.

    Parameters
    ----------
    rat_id : int or str
        Rat identifier; also the name of the rat folder under the base path.

    Returns
    -------
    tuple
        ``(waveforms, trials)`` as DataFrames with a fresh integer index, or
        ``(None, None)`` when the rat folder is missing, has no recording
        folders, or yields no data.
    """
    # Define the base path to OS Basic datasets
    base_path = '/Users/amir/Desktop/for Abdel/OS Basic'
    fs = 2500  # Sample frequency

    # Compiled once; both patterns are applied to every folder name below.
    recording_re = re.compile(r'^Rat-OS-Ephys_(Rat\d+)_SD(\d+)_([\w-]+)_([\d-]+)$')
    post_trial_re = re.compile(r'(?i)post[\-_]?trial[\-_]?([2-5])')

    # Per-trial tables are collected here and concatenated once at the end;
    # growing a DataFrame with pd.concat inside the loop re-copies all
    # previously collected rows on every iteration.
    waveform_parts = []
    trial_parts = []

    rat_path = os.path.join(base_path, str(rat_id))

    # Check if the specified rat folder exists
    if not os.path.isdir(rat_path):
        print(f"Rat folder {rat_id} does not exist.")
        return None, None

    # List all recording folders in the rat directory
    recording_folders = [
        f for f in os.listdir(rat_path)
        if os.path.isdir(os.path.join(rat_path, f))
    ]

    if not recording_folders:
        print(f"No recording folders found for Rat {rat_id}.")
        return None, None

    for recording_folder in recording_folders:
        print(f"Processing recording folder: {recording_folder}")
        recording_path = os.path.join(rat_path, recording_folder)

        # Parse the folder name, e.g. groups 'Rat6', '4', 'CON', '22-02-2018'
        match = recording_re.match(recording_folder)
        if not match:
            print(f"Unexpected folder name format: {recording_folder}. Skipping...")
            continue
        rat_id_part, sd_number, condition, date_part = match.groups()

        rat_id_from_folder = ''.join(filter(str.isdigit, rat_id_part))

        # Skip recordings whose folder name refers to a different rat
        if rat_id_from_folder != str(rat_id):
            print(f"Rat ID mismatch in folder {recording_folder}. Expected Rat{rat_id}, found Rat{rat_id_from_folder}. Skipping...")
            continue

        # Keep post_trial2..post_trial5 folders (various spellings) together
        # with the trial number captured by the same regex match.
        trial_folders = []
        for f in os.listdir(recording_path):
            trial_match = post_trial_re.search(f)
            if trial_match and os.path.isdir(os.path.join(recording_path, f)):
                trial_folders.append((f, int(trial_match.group(1))))

        if not trial_folders:
            print(f"No trial folders found in {recording_folder}.")
            continue

        for trial_folder, trial_number in trial_folders:
            print(f"Processing trial folder: {trial_folder}")
            trial_path = os.path.join(recording_path, trial_folder)

            # Search for LFP and state files in the trial folder; when several
            # files match, the last one listed wins.
            lfp_file = None
            state_file = None
            for file_name in os.listdir(trial_path):
                if not file_name.endswith('.mat'):
                    continue
                if 'HPC' in file_name:
                    lfp_file = os.path.join(trial_path, file_name)
                elif 'states' in file_name or 'States' in file_name:
                    state_file = os.path.join(trial_path, file_name)

            # Ensure both LFP and state files were found
            if not lfp_file or not state_file:
                print(f"Missing LFP or state file in {trial_path}. Skipping...")
                continue

            try:
                lfpHPC, hypno, _ = get_data(lfp_file, state_file)

                # A ValueError from the interval extraction is treated as
                # "no REM sleep in this trial"; the trial then adds no rows.
                try:
                    phasic_interval, tonic_interval, lfp = extract_pt_intervals(lfpHPC, hypno)
                except ValueError:
                    print(f"No REM sleep found in {trial_folder} for Rat {rat_id}, Condition {condition}. Filling with empty intervals.")
                    phasic_interval, tonic_interval, lfp = [[], [], []]

                # Both cycle types must be present for the trial to be used
                if phasic_interval and tonic_interval:
                    # 'config' is the sift configuration defined at notebook level
                    tonic_imfs, tonic_freqs, tonic_lpf = extract_imfs_by_pt_intervals(
                        lfp, fs, tonic_interval, config, return_imfs_freqs=True)
                    phasic_imfs, phasic_freqs, phasic_lpf = extract_imfs_by_pt_intervals(
                        lfp, fs, phasic_interval, config, return_imfs_freqs=True)

                    # Prepare UMAP data for both phasic and tonic
                    phasic_waveforms, phasic_trials, _ = prepare_data_for_umap(phasic_imfs, phasic_freqs)
                    tonic_waveforms, tonic_trials, _ = prepare_data_for_umap(tonic_imfs, tonic_freqs)

                    # Add metadata columns, including cycle type labels
                    tagged = (
                        ('phasic', (phasic_waveforms, phasic_trials)),
                        ('tonic', (tonic_waveforms, tonic_trials)),
                    )
                    for cycle_type, frames in tagged:
                        for df in frames:
                            df['rat_id'] = rat_id
                            df['condition'] = condition
                            df['trial'] = trial_number
                            df['cycle_type'] = cycle_type
                            df['SD'] = sd_number
                            df['date'] = date_part

                    waveform_parts.extend([phasic_waveforms, tonic_waveforms])
                    trial_parts.extend([phasic_trials, tonic_trials])

            except FileNotFoundError:
                print(f"Data not found in {trial_path}. Skipping...")

    # Single concatenation of everything collected above
    if waveform_parts:
        all_combined_waveforms = pd.concat(waveform_parts, ignore_index=True)
        all_combined_trials = pd.concat(trial_parts, ignore_index=True)
    else:
        all_combined_waveforms = pd.DataFrame()
        all_combined_trials = pd.DataFrame()

    if all_combined_waveforms.empty:
        print(f"No data extracted for Rat {rat_id}.")
        return None, None

    return all_combined_waveforms, all_combined_trials

In [None]:
import pandas as pd
import os
import re

# Rat IDs, grouped by the dataset they belong to
rgs_control_ids = [1, 2, 6, 9]              # control rats in the RGS dataset
os_basic_rat_ids = [1, 3, 4, 6, 9, 11, 13]  # OS Basic rats

# Accumulators for the combined data; every rat's rows are appended to these
all_waveforms, all_trials = pd.DataFrame(), pd.DataFrame()

# Step 1: pull the data of every RGS control rat
print("Extracting data for RGS control rats...")
for rat_id in rgs_control_ids:
    print(f"Processing RGS control rat {rat_id}...")
    waveforms, trials = extract_data_for_rat_rgs(rat_id)

    # Nothing usable came back for this rat: report it and move on
    if waveforms is None or waveforms.empty:
        print(f"No data extracted for RGS control rat {rat_id}")
        continue

    # Tag both frames with their dataset of origin before merging them in
    for frame in (waveforms, trials):
        frame['dataset'] = 'RGS_control'

    all_waveforms = pd.concat([all_waveforms, waveforms], ignore_index=True)
    all_trials = pd.concat([all_trials, trials], ignore_index=True)


In [None]:
# Step 2: pull the data of every OS Basic rat and append it to the frames
# that already hold the RGS control rows
print("\nExtracting data for OS Basic rats...")
for rat_id in os_basic_rat_ids:
    print(f"Processing OS Basic rat {rat_id}...")
    waveforms, trials = extract_data_for_rat(rat_id)

    # Nothing usable came back for this rat: report it and move on
    if waveforms is None or waveforms.empty:
        print(f"No data extracted for OS Basic rat {rat_id}")
        continue

    # Tag both frames with their dataset of origin before merging them in
    for frame in (waveforms, trials):
        frame['dataset'] = 'OS_Basic'

    all_waveforms = pd.concat([all_waveforms, waveforms], ignore_index=True)
    all_trials = pd.concat([all_trials, trials], ignore_index=True)

# Summary of everything collected so far (row counts of both frames)
print("\nData extraction complete!")
print(f"Total waveforms: {all_waveforms.shape[0]}")
print(f"Total trials: {all_trials.shape[0]}")

In [None]:
# Sanity check: number of rows in all_waveforms tagged as OS Basic
print(f"OS Basic data points: {(all_waveforms['dataset'] == 'OS_Basic').sum()}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
import matplotlib.cm as cm
from mpl_toolkits.mplot3d import Axes3D
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

# Combined 2D UMAP view of both datasets, coloured by peak value.
# Relies on `all_waveforms` / `all_trials` built by the extraction cells and
# on `umap` imported at the top of the notebook.

# Fail fast, before the expensive UMAP fit: the peak values come from
# all_trials but are indexed below with masks built from all_waveforms, which
# only works when the two frames have the same number of rows.
if len(all_trials) != len(all_waveforms):
    raise ValueError(
        f"all_trials ({len(all_trials)} rows) and all_waveforms "
        f"({len(all_waveforms)} rows) must have the same number of rows"
    )

# Row masks for each dataset. The sizes are computed once with the vectorised
# Series.sum() instead of calling the builtin sum() on the Series repeatedly.
os_basic_mask = all_waveforms['dataset'] == 'OS_Basic'
rgs_mask = all_waveforms['dataset'] == 'RGS_control'
n_os_basic = int(os_basic_mask.sum())
n_rgs = int(rgs_mask.sum())

# Make sure both datasets actually have data before proceeding
print(f"OS Basic data points: {n_os_basic}")
print(f"RGS Control data points: {n_rgs}")

# Step 1: Drop metadata columns for UMAP processing (keep only the waveform data)
waveforms_for_umap = all_waveforms.drop(
    columns=['rat_id', 'condition', 'trial', 'cycle_type', 'SD', 'date', 'dataset'],
    errors='ignore',
)

# Step 2: Create UMAP model with 3 components for both 2D and 3D visualization
print("Creating UMAP embedding...")
umap_embedder = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=3, metric='euclidean', random_state=42)
embedding = umap_embedder.fit_transform(waveforms_for_umap.to_numpy())

# Step 3: Peak values and a shared colour normalisation clipped to the
# 5th-95th percentile range of all peak values
peak_values = all_trials['peak_values'].values
vmin = np.percentile(peak_values, 5)
vmax = np.percentile(peak_values, 95)
norm = Normalize(vmin=vmin, vmax=vmax)

# ================ IMPROVED 2D PLOT ================
plt.figure(figsize=(14, 10))
# NOTE(review): this resets the global matplotlib style (including the seaborn
# theme set at the top of the notebook) for every later cell, and it runs after
# the figure above has already been created - confirm that is intended.
plt.style.use('default')

# White background with a dashed grid
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.facecolor'] = 'white'
plt.grid(True, linestyle='--', alpha=0.7)

# OS Basic goes first (drawn underneath), using the 'viridis' colormap
if n_os_basic > 0:
    sc1 = plt.scatter(
        embedding[os_basic_mask, 0],
        embedding[os_basic_mask, 1],
        c=peak_values[os_basic_mask],
        cmap='viridis',
        norm=norm,
        s=70,
        alpha=0.8,
        edgecolor='black',
        linewidth=0.5,
        label="OS Basic"
    )

# RGS Control is drawn on top, using the contrasting 'plasma' colormap
if n_rgs > 0:
    sc2 = plt.scatter(
        embedding[rgs_mask, 0],
        embedding[rgs_mask, 1],
        c=peak_values[rgs_mask],
        cmap='plasma',
        norm=norm,
        s=60,
        alpha=0.75,
        edgecolor='black',
        linewidth=0.3,
        label="RGS Control"
    )

if n_os_basic > 0 and n_rgs > 0:
    # Both scatters share `norm`, so a single colorbar gives the value scale
    # for both datasets; it is drawn with the RGS ('plasma') colours.
    cbar = plt.colorbar(sc2, pad=0.02)
    cbar.set_label('Peak Values', fontsize=12, weight='bold')

    # Proxy legend entries: one fixed representative colour per dataset,
    # since the real markers are coloured by peak value
    from matplotlib.lines import Line2D

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#21918c', markersize=10,
               markeredgecolor='black', markeredgewidth=0.5, label=f'OS Basic (n={n_os_basic})'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#f0605d', markersize=10,
               markeredgecolor='black', markeredgewidth=0.3, label=f'RGS Control (n={n_rgs})')
    ]
    plt.legend(handles=legend_elements, loc='upper right', fontsize=12, framealpha=0.9)
elif n_os_basic > 0:
    # Only OS Basic has points
    plt.colorbar(sc1, label='Peak Values')
    plt.legend([f'OS Basic (n={n_os_basic})'], loc='upper right')
elif n_rgs > 0:
    # Only RGS Control has points
    plt.colorbar(sc2, label='Peak Values')
    plt.legend([f'RGS Control (n={n_rgs})'], loc='upper right')

# Title and axis labels
plt.title('UMAP Visualization of Theta Cycles by Dataset', fontsize=16, weight='bold')
plt.xlabel('UMAP Dimension 1', fontsize=12, weight='bold')
plt.ylabel('UMAP Dimension 2', fontsize=12, weight='bold')

# Statistics box in the lower-left corner of the axes
stats_text = (
    f"Total cycles: {len(all_waveforms)}\n"
    f"OS Basic: {n_os_basic} cycles\n"
    f"RGS Control: {n_rgs} cycles\n"
    f"Peak value range: {vmin:.2f} to {vmax:.2f}"
)
plt.annotate(
    stats_text,
    xy=(0.02, 0.02),
    xycoords='axes fraction',
    fontsize=12,
    bbox=dict(boxstyle="round,pad=0.5", fc='white', ec="gray", alpha=0.9)
)

plt.tight_layout()
plt.show()


In [None]:

# 2D and 3D UMAP views with a separate colormap per dataset.
# Reuses `embedding` and `waveforms_for_umap` from the cell that fitted UMAP.

# Row masks for each dataset and their sizes
os_basic_mask = all_waveforms['dataset'] == 'OS_Basic'
rgs_mask = all_waveforms['dataset'] == 'RGS_control'
n_os_basic = int(os_basic_mask.sum())
n_rgs = int(rgs_mask.sum())

# Peak value of every cycle, used for colouring
peak_values = all_trials['peak_values'].values

# Per-dataset colour limits: the 5th-95th percentile of that dataset's peak
# values, or the overall min/max when the dataset has no points
if n_os_basic > 0:
    vmin_os, vmax_os = (np.percentile(peak_values[os_basic_mask], q) for q in (5, 95))
else:
    vmin_os, vmax_os = np.min(peak_values), np.max(peak_values)
norm_os = Normalize(vmin=vmin_os, vmax=vmax_os)

if n_rgs > 0:
    vmin_rgs, vmax_rgs = (np.percentile(peak_values[rgs_mask], q) for q in (5, 95))
else:
    vmin_rgs, vmax_rgs = np.min(peak_values), np.max(peak_values)
norm_rgs = Normalize(vmin=vmin_rgs, vmax=vmax_rgs)

# ---- 2D plot ----
plt.figure(figsize=(12, 9))

# OS Basic first, in blue, with its colorbar on the left
if n_os_basic > 0:
    os_points = embedding[os_basic_mask]
    sc1 = plt.scatter(
        os_points[:, 0],
        os_points[:, 1],
        c=peak_values[os_basic_mask],
        cmap='Blues',
        norm=norm_os,
        s=40,
        alpha=0.8,
        label=f'OS Basic (n={n_os_basic})'
    )
    cbar1 = plt.colorbar(sc1, location='left', pad=0.1)
    cbar1.set_label('OS Basic Peak Values')

# RGS Control on top, in red, with its colorbar on the right
if n_rgs > 0:
    rgs_points = embedding[rgs_mask]
    sc2 = plt.scatter(
        rgs_points[:, 0],
        rgs_points[:, 1],
        c=peak_values[rgs_mask],
        cmap='Reds',
        norm=norm_rgs,
        s=40,
        alpha=0.8,
        label=f'RGS Control (n={n_rgs})'
    )
    cbar2 = plt.colorbar(sc2, location='right', pad=0.1)
    cbar2.set_label('RGS Control Peak Values')

plt.title('UMAP Visualization of Theta Cycles by Dataset', fontsize=14)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.legend(loc='upper center')

plt.tight_layout()
plt.show()

# ---- Optional 3D plot ----
from mpl_toolkits.mplot3d import Axes3D

# Reuse the existing embedding when it already has a third component;
# otherwise fit a fresh 3-component UMAP on the same waveform matrix
if embedding.shape[1] >= 3:
    embedding_3d = embedding
else:
    umap_embedder_3d = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=3, metric='euclidean', random_state=42)
    embedding_3d = umap_embedder_3d.fit_transform(waveforms_for_umap.to_numpy())

fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# NOTE(review): unlike the 2D plot, the 3D scatters do not pass `norm=`, so
# their colours span each dataset's full value range rather than its
# 5th-95th percentile range - confirm this is intended.
if n_os_basic > 0:
    os_points_3d = embedding_3d[os_basic_mask]
    sc1_3d = ax.scatter(
        os_points_3d[:, 0],
        os_points_3d[:, 1],
        os_points_3d[:, 2],
        c=peak_values[os_basic_mask],
        cmap='Blues',
        s=30,
        alpha=0.8,
        label=f'OS Basic (n={n_os_basic})'
    )

if n_rgs > 0:
    rgs_points_3d = embedding_3d[rgs_mask]
    sc2_3d = ax.scatter(
        rgs_points_3d[:, 0],
        rgs_points_3d[:, 1],
        rgs_points_3d[:, 2],
        c=peak_values[rgs_mask],
        cmap='Reds',
        s=30,
        alpha=0.8,
        label=f'RGS Control (n={n_rgs})'
    )

ax.set_title('3D UMAP Visualization of Theta Cycles', fontsize=14)
ax.set_xlabel('UMAP Dimension 1')
ax.set_ylabel('UMAP Dimension 2')
ax.set_zlabel('UMAP Dimension 3')
ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# One UMAP figure per dataset, plus one per dataset split by cycle type.
# Reuses `embedding` from the cell that fitted UMAP. Every figure is created
# only when its dataset has points, so no empty figure is left open.

# Row masks for each dataset; sizes are computed once with Series.sum()
os_basic_mask = all_waveforms['dataset'] == 'OS_Basic'
rgs_mask = all_waveforms['dataset'] == 'RGS_control'
n_os_basic = int(os_basic_mask.sum())
n_rgs = int(rgs_mask.sum())

# Peak value of every cycle, used for colouring
peak_values = all_trials['peak_values'].values

# Common normalization (5th-95th percentile of all peak values) so that the
# colours are comparable across all plots below
vmin = np.percentile(peak_values, 5)
vmax = np.percentile(peak_values, 95)
norm = Normalize(vmin=vmin, vmax=vmax)

# 1. Plot OS Basic dataset
if n_os_basic > 0:
    plt.figure(figsize=(10, 8))
    os_basic_peaks = peak_values[os_basic_mask]
    sc1 = plt.scatter(
        embedding[os_basic_mask, 0],
        embedding[os_basic_mask, 1],
        c=os_basic_peaks,
        cmap='hot',  # same colormap as the RGS Control plot
        norm=norm,
        s=50,
        alpha=0.8,
        edgecolor='black',
        linewidth=0.3
    )

    # Add colorbar
    cbar1 = plt.colorbar(sc1)
    cbar1.set_label('Peak Values', fontsize=12)

    # Add title and labels
    plt.title(f'OS Basic Dataset UMAP (n={n_os_basic})', fontsize=14)
    plt.xlabel('UMAP Dimension 1', fontsize=12)
    plt.ylabel('UMAP Dimension 2', fontsize=12)

    # Text box with the (unclipped) peak-value statistics of this dataset
    os_basic_peak_min = np.min(os_basic_peaks)
    os_basic_peak_max = np.max(os_basic_peaks)
    os_basic_peak_mean = np.mean(os_basic_peaks)

    stats_text = (
        f"Total cycles: {n_os_basic}\n"
        f"Peak value range: {os_basic_peak_min:.2f} to {os_basic_peak_max:.2f}\n"
        f"Mean peak value: {os_basic_peak_mean:.2f}"
    )

    plt.annotate(
        stats_text,
        xy=(0.02, 0.02),
        xycoords='axes fraction',
        fontsize=10,
        bbox=dict(boxstyle="round,pad=0.5", fc='white', ec="gray", alpha=0.8)
    )

    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('os_basic_umap.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No OS Basic data points to plot")

# 2. Plot RGS Control dataset
if n_rgs > 0:
    plt.figure(figsize=(10, 8))
    rgs_peaks = peak_values[rgs_mask]
    sc2 = plt.scatter(
        embedding[rgs_mask, 0],
        embedding[rgs_mask, 1],
        c=rgs_peaks,
        cmap='hot',  # same colormap as the OS Basic plot
        norm=norm,  # Use the same normalization for consistency
        s=50,
        alpha=0.8,
        edgecolor='black',
        linewidth=0.3
    )

    # Add colorbar
    cbar2 = plt.colorbar(sc2)
    cbar2.set_label('Peak Values', fontsize=12)

    # Add title and labels
    plt.title(f'RGS Control Dataset UMAP (n={n_rgs})', fontsize=14)
    plt.xlabel('UMAP Dimension 1', fontsize=12)
    plt.ylabel('UMAP Dimension 2', fontsize=12)

    # Text box with the (unclipped) peak-value statistics of this dataset
    rgs_peak_min = np.min(rgs_peaks)
    rgs_peak_max = np.max(rgs_peaks)
    rgs_peak_mean = np.mean(rgs_peaks)

    stats_text = (
        f"Total cycles: {n_rgs}\n"
        f"Peak value range: {rgs_peak_min:.2f} to {rgs_peak_max:.2f}\n"
        f"Mean peak value: {rgs_peak_mean:.2f}"
    )

    plt.annotate(
        stats_text,
        xy=(0.02, 0.02),
        xycoords='axes fraction',
        fontsize=10,
        bbox=dict(boxstyle="round,pad=0.5", fc='white', ec="gray", alpha=0.8)
    )

    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('rgs_control_umap.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No RGS Control data points to plot")

# 3. BONUS: Create plots with cycle type (phasic vs tonic) highlighted
# NOTE(review): in both plots below the colorbar is built from the phasic
# ('Reds') scatter only; tonic points share the same `norm` but are drawn in
# 'Blues', so the bar's colours do not describe them - confirm this is intended.

# For OS Basic
if n_os_basic > 0:
    plt.figure(figsize=(12, 10))

    # Create masks for phasic and tonic within OS Basic
    os_basic_phasic_mask = os_basic_mask & (all_waveforms['cycle_type'] == 'phasic')
    os_basic_tonic_mask = os_basic_mask & (all_waveforms['cycle_type'] == 'tonic')

    # Plot tonic first (triangles)
    plt.scatter(
        embedding[os_basic_tonic_mask, 0],
        embedding[os_basic_tonic_mask, 1],
        c=peak_values[os_basic_tonic_mask],
        cmap='Blues',
        norm=norm,
        s=50,
        alpha=0.8,
        marker='v',
        label=f'OS Basic Tonic (n={int(os_basic_tonic_mask.sum())})'
    )

    # Plot phasic on top (circles)
    sc3 = plt.scatter(
        embedding[os_basic_phasic_mask, 0],
        embedding[os_basic_phasic_mask, 1],
        c=peak_values[os_basic_phasic_mask],
        cmap='Reds',
        norm=norm,
        s=50,
        alpha=0.8,
        marker='o',
        label=f'OS Basic Phasic (n={int(os_basic_phasic_mask.sum())})'
    )

    plt.colorbar(sc3, label='Peak Values')
    plt.title('OS Basic Dataset UMAP by Cycle Type', fontsize=14)
    plt.xlabel('UMAP Dimension 1', fontsize=12)
    plt.ylabel('UMAP Dimension 2', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('os_basic_by_cycle_type.png', dpi=300, bbox_inches='tight')
    plt.show()

# For RGS Control
if n_rgs > 0:
    plt.figure(figsize=(12, 10))

    # Create masks for phasic and tonic within RGS Control
    rgs_phasic_mask = rgs_mask & (all_waveforms['cycle_type'] == 'phasic')
    rgs_tonic_mask = rgs_mask & (all_waveforms['cycle_type'] == 'tonic')

    # Plot tonic first (triangles)
    plt.scatter(
        embedding[rgs_tonic_mask, 0],
        embedding[rgs_tonic_mask, 1],
        c=peak_values[rgs_tonic_mask],
        cmap='Blues',
        norm=norm,
        s=50,
        alpha=0.8,
        marker='v',
        label=f'RGS Tonic (n={int(rgs_tonic_mask.sum())})'
    )

    # Plot phasic on top (circles)
    sc4 = plt.scatter(
        embedding[rgs_phasic_mask, 0],
        embedding[rgs_phasic_mask, 1],
        c=peak_values[rgs_phasic_mask],
        cmap='Reds',
        norm=norm,
        s=50,
        alpha=0.8,
        marker='o',
        label=f'RGS Phasic (n={int(rgs_phasic_mask.sum())})'
    )

    plt.colorbar(sc4, label='Peak Values')
    plt.title('RGS Control Dataset UMAP by Cycle Type', fontsize=14)
    plt.xlabel('UMAP Dimension 1', fontsize=12)
    plt.ylabel('UMAP Dimension 2', fontsize=12)
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('rgs_by_cycle_type.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
import plotly.graph_objects as go


# Interactive 3D UMAP (plotly) with one trace per dataset.
# Reuses `embedding_3d` from the cell that drew the matplotlib 3D plot.

# Row masks for each dataset and their sizes
os_basic_mask = all_waveforms['dataset'] == 'OS_Basic'
rgs_mask = all_waveforms['dataset'] == 'RGS_control'
n_os_basic = int(os_basic_mask.sum())
n_rgs = int(rgs_mask.sum())

# Peak value of every cycle, used for colouring
peak_values = all_trials['peak_values'].values

# Traces are collected in drawing order: OS Basic first, then RGS Control
data = []

# OS Basic trace (orange scale, no colorbar)
if n_os_basic > 0:
    os_points_3d = embedding_3d[os_basic_mask]
    scatter_os = go.Scatter3d(
        x=os_points_3d[:, 0],
        y=os_points_3d[:, 1],
        z=os_points_3d[:, 2],
        mode='markers',
        marker={
            'size': 5,
            'color': peak_values[os_basic_mask],
            'colorscale': 'Oranges',
            'opacity': 0.8,
        },
        name=f'OS Basic (n={n_os_basic})'
    )
    data.append(scatter_os)

# RGS Control trace (blue scale); this is the only trace that carries a colorbar
if n_rgs > 0:
    rgs_points_3d = embedding_3d[rgs_mask]
    scatter_rgs = go.Scatter3d(
        x=rgs_points_3d[:, 0],
        y=rgs_points_3d[:, 1],
        z=rgs_points_3d[:, 2],
        mode='markers',
        marker={
            'size': 5,
            'color': peak_values[rgs_mask],
            'colorscale': 'Blues',
            'opacity': 0.8,
            'colorbar': {
                'title': "Peak Values",
                'thickness': 20,
                'len': 0.7,
            },
        },
        name=f'RGS Control (n={n_rgs})'
    )
    data.append(scatter_rgs)

# Assemble the figure from whichever traces were built
fig = go.Figure(data=data)

# Title, size, axis titles, legend anchored top-right, tight margins
fig.update_layout(
    title="Interactive 3D UMAP: OS Basic (Orange) vs RGS Control (Blue)",
    width=900,
    height=700,
    scene={
        'xaxis': {'title': 'UMAP Dimension 1'},
        'yaxis': {'title': 'UMAP Dimension 2'},
        'zaxis': {'title': 'UMAP Dimension 3'},
    },
    legend={
        'yanchor': "top",
        'y': 0.99,
        'xanchor': "right",
        'x': 0.99,
    },
    margin={'l': 0, 'r': 0, 'b': 0, 't': 40}
)

# Write a standalone HTML copy, then display the figure in the notebook
fig.write_html("theta_cycles_3d_umap.html")
fig.show()