# Peak Alignment (Phase-Based)

This notebook contains functions to extract peak-aligned waveforms based on instantaneous phase, aligning the peak to a specific target phase (default π/2).

In [None]:
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.signal import hilbert # Added for placeholder get_cycle_data
import emd # Assuming emd library is installed and contains necessary submodules
import matplotlib.pyplot as plt # For optional plotting
import copy # Added for placeholder get_cycles_with_conditions

## Helper Functions (Potentially needed from your existing pipeline)

**IMPORTANT:** These functions are placeholders based on the original `cycle_analysis-2.ipynb` and `utils.py`. You **MUST** replace these with your actual, fully functional implementations from your project for the main extraction function to work correctly.

In [None]:
# Placeholder for get_cycles_with_metrics - Replace with your actual function from utils.py
# This function should add various metrics to a Cycles object.
def get_cycles_with_metrics(cycles_obj, data, IA, IF):
    """Attach a minimal set of per-cycle metrics to ``cycles_obj`` (placeholder).

    Stand-in for the project's real implementation. It prints a warning, asks
    ``cycles_obj`` to compute ``duration_samples`` (``len`` applied to ``data``)
    and ``max_amp`` (``np.max`` applied to ``IA``) in ``'augmented'`` mode, and
    then marks every cycle as good.

    Parameters
    ----------
    cycles_obj : object
        Expected to provide ``compute_cycle_metric`` and ``ncycles``.
    data : array-like
        Signal handed to the ``duration_samples`` metric.
    IA : array-like
        Instantaneous amplitude handed to the ``max_amp`` metric.
    IF : array-like
        Instantaneous frequency; accepted but not used by this placeholder.

    Returns
    -------
    object
        The very same ``cycles_obj`` that was passed in, whether or not the
        metrics could be attached (an ``AttributeError`` is reported, not raised).
    """
    print("WARNING: Using placeholder get_cycles_with_metrics. Replace with actual implementation.")

    # (metric name, source signal, reducing function) — extend with real metrics.
    placeholder_metrics = (
        ('duration_samples', data, len),
        ('max_amp', IA, np.max),
    )
    try:
        for metric_name, source_signal, reducer in placeholder_metrics:
            cycles_obj.compute_cycle_metric(metric_name, source_signal, func=reducer, mode='augmented')
        # Placeholder assumption: no cycle is rejected up front.
        cycles_obj.is_good = np.ones(cycles_obj.ncycles, dtype=bool)
    except AttributeError as err:
        print(f"Placeholder get_cycles_with_metrics failed: {err}. Ensure Cycles object has compute_cycle_metric.")
    return cycles_obj

# Placeholder for get_cycle_data - Replace with your actual function
# This function should return a dictionary containing at least:
# 'theta_imf': the theta IMF signal (1D numpy array)
# 'IP': instantaneous phase (1D numpy array, same length as theta_imf)
# 'IA': instantaneous amplitude (1D numpy array, same length as theta_imf)
# 'IF': instantaneous frequency (1D numpy array, same length as theta_imf)
# 'cycles': an emd.cycles.Cycles object derived from IP, with metrics computed
def get_cycle_data(imf_channel, fs):
    """Derive phase, amplitude, frequency and cycles for one IMF channel (placeholder).

    Stand-in for the project's EMD/sift pipeline. The channel is run through a
    Hilbert transform to obtain the unwrapped instantaneous phase (IP) and the
    instantaneous amplitude (IA); the instantaneous frequency (IF) is the
    gradient of IP over the sample spacing ``1 / fs`` divided by ``2 * pi``.
    An ``emd.cycles.Cycles`` object is then built from IP and passed through
    ``get_cycles_with_metrics``.

    Parameters
    ----------
    imf_channel : array-like, 1D
        Samples of the IMF to analyse.
    fs : float
        Sampling frequency in Hz. Must be non-zero: ``1.0 / fs`` is used as the
        sample spacing and no guard against ``fs == 0`` is applied here.

    Returns
    -------
    dict or None
        ``None`` when the channel has fewer than two samples, or when the
        cycles object could not be created / lacks ``get_metric_dataframe``.
        Otherwise a dict with keys ``'fs'``, ``'theta_imf'``, ``'IP'``,
        ``'IF'``, ``'IA'`` and ``'cycles'``.
    """
    # np.gradient needs at least two samples, so a 0- or 1-sample channel can
    # never yield an instantaneous frequency; treat both as "nothing to process"
    # instead of letting the single-sample case raise ValueError further down.
    if len(imf_channel) < 2:
        return None

    # Example: basic Hilbert transform (replace with your EMD/sift logic).
    analytic_signal = hilbert(imf_channel)
    IP = np.unwrap(np.angle(analytic_signal))
    IA = np.abs(analytic_signal)
    dt = 1.0 / fs
    IF = np.gradient(IP, dt) / (2 * np.pi)

    # Basic cycle detection (replace with your emd.cycles.Cycles logic).
    cycles_obj = None
    try:
        C = emd.cycles.Cycles(IP)  # Requires the emd library
        # Compute metrics using the (placeholder) helper
        cycles_obj = get_cycles_with_metrics(C, imf_channel, IA, IF)
    except Exception as e:
        # Deliberately broad: this placeholder reports any failure from the
        # third-party cycle detection and lets the caller skip the channel.
        print(f"Placeholder: Failed to create or add metrics to emd.cycles.Cycles object: {e}.")

    if cycles_obj is None or not hasattr(cycles_obj, 'get_metric_dataframe'):
        print("Placeholder: Cycles object creation or metric computation failed.")
        return None  # Callers treat None as "skip this channel"

    return {
        'fs': fs,
        'theta_imf': imf_channel,
        'IP': IP,
        'IF': IF,
        'IA': IA,
        'cycles': cycles_obj
    }

# Placeholder for get_cycles_with_conditions - Replace with your actual function
# This function should take an emd.cycles.Cycles object and conditions,
# and return a *new* filtered emd.cycles.Cycles object or the same object
# modified to indicate the subset (e.g., via an internal mask).
# Crucially, it needs the .iterate(through='subset') method to work.
def get_cycles_with_conditions(cycles_obj, conditions):
    """Return a condition-filtered copy of ``cycles_obj`` (placeholder).

    A deep copy of ``cycles_obj`` is asked to ``pick_cycle_subset(conditions)``.
    The copy is returned when it ends up with a non-empty ``subset_indices``
    attribute; ``None`` is returned when nothing was selected.

    If anything inside the filtering step raises, the failure is printed and
    the function falls back to the *original* object: when it has at least one
    cycle, its ``subset_indices`` is set to every cycle index and it is
    returned; otherwise ``None`` is returned.

    Parameters
    ----------
    cycles_obj : object or None
        Cycles-like object; ``None`` is passed straight through as ``None``.
    conditions : list of str
        Condition expressions forwarded unchanged to ``pick_cycle_subset``
        (e.g. ``'max_amp>0.5'``).

    Returns
    -------
    object or None
        Filtered copy, the original object (fallback path), or ``None``.
    """
    if cycles_obj is None:
        return None

    # Filter a copy so the caller's object is untouched on the success path.
    candidate = copy.deepcopy(cycles_obj)
    try:
        # Example call - replace with actual parsing and application of conditions
        candidate.pick_cycle_subset(conditions)
        selected = getattr(candidate, 'subset_indices', None)
        if selected is None or len(selected) == 0:
            return None  # No cycle met the criteria
        return candidate
    except Exception as err:
        print(f"Placeholder get_cycles_with_conditions failed: {err}. Does Cycles object have pick_cycle_subset?")
        # Fallback for the placeholder only; real code may want to handle this
        # error differently.
        if not cycles_obj.ncycles > 0:
            return None
        print("Placeholder: Returning original cycles object due to filtering error.")
        # Let subset iteration cover every cycle of the original object.
        cycles_obj.subset_indices = np.arange(cycles_obj.ncycles)
        return cycles_obj

## Peak Alignment Function (Corrected)

In [None]:
def _resample_peak_aligned_cycle(raw_wave, ip_cycle, target_phase_axis, target_peak_phase):
    """Resample one cycle onto ``target_phase_axis`` with its maximum at ``target_peak_phase``.

    The cycle's phase is unwrapped and shifted so that the sample holding the
    waveform maximum sits at ``target_peak_phase``; the waveform is then
    linearly interpolated onto ``target_phase_axis``. Axis points outside the
    shifted phase range come back as NaN.

    Returns the resampled 1D array, or ``None`` when fewer than two distinct
    phase values exist or the interpolation raises ``ValueError``.
    """
    ip_unwrapped = np.unwrap(ip_cycle)
    peak_phase_actual = ip_unwrapped[np.argmax(raw_wave)]
    ip_shifted = ip_unwrapped - peak_phase_actual + target_peak_phase

    # np.unique returns the distinct phase values already in ascending order,
    # together with the index of each value's first occurrence, which gives
    # the strictly increasing x-axis the interpolation needs.
    unique_phase, first_idx = np.unique(ip_shifted, return_index=True)
    if len(unique_phase) < 2:
        return None

    try:
        interpolator = interp1d(unique_phase, raw_wave[first_idx], kind='linear',
                                bounds_error=False, fill_value=np.nan)
        return interpolator(target_phase_axis)
    except ValueError:
        return None  # Skip the cycle if interpolation fails


def extract_peak_aligned_waveforms_phase_based(imfs, theta_col=5, max_extract=None, npoints=100, target_peak_phase=np.pi/2, fs=2500):
    """
    Extract peak-aligned waveforms based on instantaneous phase, aligning each
    cycle's peak to a specific target phase (default pi/2).

    For every IMF the theta column is passed to ``get_cycle_data``; cycles are
    filtered with ``get_cycles_with_conditions`` using these example thresholds:
    ``is_good==1``, ``fs/12 < duration_samples < fs/5`` (i.e. 5-12 Hz cycles)
    and ``max_amp`` above the 25th percentile of the non-NaN instantaneous
    amplitude. Each surviving cycle, accessed through
    ``.iterate(through='subset')``, is phase-shifted so its maximum lies at
    ``target_peak_phase`` and linearly resampled onto a common phase axis.
    A cycle is kept only if fewer than half of its resampled points are NaN.

    Parameters
    ----------
    imfs : list
        List of IMFs (each a 2D array or array-like, samples x channels).
        Entries that are not 2D or lack column ``theta_col`` are skipped with
        a printed message.
    theta_col : int, default 5
        Column index corresponding to the theta IMF.
    max_extract : int, optional
        Maximum number of waveforms to keep per IMF (if None, keep all).
    npoints : int, default 100
        Number of points for the output waveform (representing 0 to 2*pi).
    target_peak_phase : float, default np.pi/2
        The phase value (between 0 and 2*pi) where the peak should be aligned.
    fs : int, default 2500
        Sampling frequency in Hz.

    Returns
    -------
    peak_aligned_phase_based_waveforms : DataFrame
        One row per extracted cycle, ``npoints`` columns; may contain NaN where
        the phase axis falls outside the cycle's shifted phase range.
    all_cycle_metrics_peak_phase : DataFrame
        Per-cycle metrics, row-aligned with the waveform DataFrame.
    target_phase_axis : ndarray
        The common phase axis (0 to 2*pi) used for alignment.
    """
    target_phase_axis = np.linspace(0, 2 * np.pi, npoints)

    # Per-IMF results are collected here and concatenated once at the end,
    # rather than re-copying a growing DataFrame on every iteration.
    waveform_frames = []
    metric_frames = []

    for idx, imf in enumerate(imfs):
        # Accept any array-like (e.g. nested lists), not only ndarrays.
        imf = np.asarray(imf)
        if imf.ndim != 2 or imf.shape[1] <= theta_col:
            print(f"Skipping IMF {idx}: Invalid shape or theta_col index.")
            continue

        # --- Compute cycle data ---
        cycle_data = get_cycle_data(imf[:, theta_col], fs=fs)
        if cycle_data is None or cycle_data.get('cycles') is None or cycle_data.get('IP') is None:
            continue

        # --- Build thresholds (example values, adjust as needed) ---
        ia = cycle_data.get('IA')
        if ia is None or len(ia) == 0:
            continue
        # Drop NaNs so the percentile is computed on real amplitudes only.
        ia_valid = ia[~np.isnan(ia)]
        if len(ia_valid) == 0:
            continue
        amp_thresh = np.percentile(ia_valid, 25)
        lo_freq_duration = fs / 5   # Samples in one 5 Hz cycle (longest allowed)
        hi_freq_duration = fs / 12  # Samples in one 12 Hz cycle (shortest allowed)
        conditions = [
            'is_good==1',  # Assumes the Cycles object carries this metric
            f'duration_samples<{lo_freq_duration}',
            f'duration_samples>{hi_freq_duration}',
            f'max_amp>{amp_thresh}'
        ]

        # --- Filter cycles ---
        all_cycles_filtered = get_cycles_with_conditions(cycle_data['cycles'], conditions)
        if (all_cycles_filtered is None
                or not hasattr(all_cycles_filtered, 'iterate')
                or not hasattr(all_cycles_filtered, 'get_metric_dataframe')):
            continue

        # Metrics for the *filtered* cycles. The broad except is a deliberate
        # best-effort: any failure inside the cycles object skips this IMF.
        try:
            cycle_metrics_subset = all_cycles_filtered.get_metric_dataframe(subset=True)
        except Exception:
            continue
        if cycle_metrics_subset is None or cycle_metrics_subset.empty:
            continue

        theta_imf = cycle_data['theta_imf']
        instantaneous_phase = cycle_data['IP']

        try:
            cycle_iterator = all_cycles_filtered.iterate(through='subset')
        except Exception:
            continue  # Same best-effort policy as above

        # --- Extract peak-aligned (phase-based) waveforms ---
        # The two lists are always appended together, so row i of the metrics
        # describes waveform i.
        peak_aligned_list = []
        valid_cycle_metrics_list = []
        for cycle_iter_idx, cycle_tuple in enumerate(cycle_iterator):
            try:
                # Presumably (cycle index in the unfiltered object, sample
                # indices of this cycle) - confirm against your Cycles class.
                _cycle_index, inds = cycle_tuple
            except (TypeError, ValueError):
                continue  # Not a 2-item tuple; skip

            if inds is None or len(inds) < 3:  # Need at least 3 points
                continue

            resampled_wave = _resample_peak_aligned_cycle(
                theta_imf[inds], instantaneous_phase[inds],
                target_phase_axis, target_peak_phase)
            if resampled_wave is None:
                continue
            # Reject cycles where half or more of the phase axis is uncovered.
            if np.isnan(resampled_wave).sum() >= npoints / 2:
                continue

            peak_aligned_list.append(resampled_wave)
            # Metrics rows are looked up by iteration position, which is
            # expected to match the order of the subset metrics dataframe.
            if cycle_iter_idx < len(cycle_metrics_subset):
                valid_cycle_metrics_list.append(cycle_metrics_subset.iloc[cycle_iter_idx])
            else:
                # Should not happen if iterator and dataframe match; keep the
                # waveform and pair it with an all-NaN metrics row.
                print(f"Warning: Metrics index mismatch for cycle {cycle_iter_idx} in IMF {idx}")
                valid_cycle_metrics_list.append(pd.Series(index=cycle_metrics_subset.columns, dtype=float))

        if not peak_aligned_list:
            continue  # No valid cycles for this IMF

        peak_aligned_df = pd.DataFrame(peak_aligned_list)
        filtered_cycle_metrics_df = pd.DataFrame(valid_cycle_metrics_list)

        # --- Limit the number of waveforms if requested ---
        if max_extract is not None:
            peak_aligned_df = peak_aligned_df.iloc[:max_extract]
            filtered_cycle_metrics_df = filtered_cycle_metrics_df.iloc[:max_extract]

        waveform_frames.append(peak_aligned_df)
        metric_frames.append(filtered_cycle_metrics_df)

    if not waveform_frames:
        return pd.DataFrame(), pd.DataFrame(), target_phase_axis

    peak_aligned_phase_based_waveforms = pd.concat(waveform_frames, ignore_index=True)
    all_cycle_metrics_peak_phase = pd.concat(metric_frames, ignore_index=True)
    return peak_aligned_phase_based_waveforms, all_cycle_metrics_peak_phase, target_phase_axis

## Example Usage

**Note:** Replace the placeholder data loading and helper functions with your actual project code.

In [None]:
# --- Load your IMF data here ---
# Replace this loading logic with your own. The extraction below expects
# 'imf_input' to be a list of 2D numpy arrays shaped [samples x channels].
# If an earlier cell already defined 'rem_imfs' as a list it is used as-is;
# otherwise synthetic placeholder data is generated in the except branch.
try:
    # Example of loading from disk (use phasic_imfs or tonic_imfs instead,
    # depending on your goal):
    # rem_imfs = np.load('path/to/your/rem_imfs.npy', allow_pickle=True) 
    # NOTE(review): allow_pickle=True unpickles the file contents - only use it on files you trust.
    # The NameError is raised on purpose so that both "not defined" and
    # "defined but not a list" fall through to the placeholder branch.
    if 'rem_imfs' not in locals() or not isinstance(rem_imfs, list):
        raise NameError("'rem_imfs' not found or not a list, using placeholder data.")
    imf_input = rem_imfs
    fs_actual = 2500 # Assumed sampling rate (Hz) of the loaded data - adjust if yours differs
    print(f"--- USING LOADED 'rem_imfs' ({len(imf_input)} IMFs) ---")
except NameError as e:
    print(f"--- {e} ---")
    # --- Placeholder Data (REMOVE THIS IN YOUR PIPELINE) ---
    # Two synthetic 5-second recordings sampled at fs_actual, each stored in
    # column 5 of an otherwise all-zero 6-column array.
    print("--- USING PLACEHOLDER IMF DATA ---")
    fs_actual = 2500
    t_example = np.arange(0, 5*fs_actual)/fs_actual
    # 7 Hz sine plus a half-amplitude 14 Hz component and Gaussian noise
    sig1 = np.sin(2*np.pi*7*t_example) + 0.5*np.sin(2*np.pi*14*t_example) + 0.1*np.random.randn(len(t_example))
    # 8 Hz sine with a pi/2 phase offset and stronger Gaussian noise
    sig2 = np.sin(2*np.pi*8*t_example + np.pi/2) + 0.3*np.random.randn(len(t_example))
    dummy_imf1 = np.zeros((len(t_example), 6))
    dummy_imf1[:, 5] = sig1 # Column 5 matches theta_col=5 used in the call below
    dummy_imf2 = np.zeros((len(t_example), 6))
    dummy_imf2[:, 5] = sig2
    imf_input = [dummy_imf1, dummy_imf2] 
    print(f"Placeholder data: {len(imf_input)} IMFs, shape of first: {imf_input[0].shape}")
    print("--- END PLACEHOLDER IMF DATA ---")
    # --- End Placeholder Data ---

# --- Run the extraction ---
# Any exception is reported together with its traceback rather than aborting
# the cell; in that case the three result names below are left unassigned.
try:
    peak_aligned_waveforms, cycle_metrics_aligned, phase_axis = \
        extract_peak_aligned_waveforms_phase_based(
            imf_input, 
            theta_col=5,      # Must match the column holding the theta IMF in your data
            max_extract=None, # Set to a number to limit cycles per IMF
            npoints=100, 
            target_peak_phase=np.pi/2, # Phase at which each cycle's peak is placed
            fs=fs_actual # Sampling frequency of imf_input
        )
    
    print(f"\nSuccessfully extracted {len(peak_aligned_waveforms)} peak-aligned waveforms.")
    if not peak_aligned_waveforms.empty:
        print(f"Waveform shape: {peak_aligned_waveforms.shape}")
        print(f"Metrics shape: {cycle_metrics_aligned.shape}")
        
        # Show the head of both result tables; rows correspond one-to-one
        print("\nFirst 5 Peak-Aligned Waveforms:")
        print(peak_aligned_waveforms.head())
        print("\nFirst 5 Corresponding Cycle Metrics:")
        print(cycle_metrics_aligned.head())
    else:
        print("\nNo waveforms were extracted (check input data and helper functions).")

except Exception as e:
    print(f"\nAn error occurred during extraction: {e}")
    # Imported lazily: only needed on the failure path
    import traceback
    traceback.print_exc()

## Optional: Visualization

In [None]:
def plot_aligned_waveforms(waveforms_df, phase_axis, title="Aligned Waveforms", n_examples=20,
                           target_peak_phase=np.pi / 2, offset_factor=0.2):
    """Plot the first few aligned waveforms stacked with a vertical offset.

    Parameters
    ----------
    waveforms_df : DataFrame
        One waveform per row, one column per point of ``phase_axis``.
    phase_axis : array-like
        Phase values (radians) used as the x-axis.
    title : str, default "Aligned Waveforms"
        Title prefix; the number of plotted examples is appended.
    n_examples : int, default 20
        Maximum number of rows (taken from the top) to draw.
    target_peak_phase : float, default np.pi/2
        Phase at which the vertical marker line is drawn. Pass the same value
        that was used for the alignment so the marker matches the data.
    offset_factor : float, default 0.2
        Fraction of the average per-waveform standard deviation by which
        consecutive waveforms are shifted upwards.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (left empty when there is nothing to plot).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    n_plot = min(n_examples, len(waveforms_df))
    if n_plot <= 0:
        # Covers an empty DataFrame as well as a non-positive n_examples.
        print("No waveforms to plot.")
        return fig

    # Spacing unit: mean of the per-row standard deviations (pandas skips NaN
    # inside each row; nanmean skips rows whose std is itself NaN).
    subset_to_plot = waveforms_df.iloc[:n_plot]
    all_std = np.nanmean(subset_to_plot.std(axis=1))
    if np.isnan(all_std) or all_std == 0:
        all_std = 1  # Fall back to unit spacing when no usable std exists

    for i in range(n_plot):
        waveform_data = subset_to_plot.iloc[i].values
        offset = i * all_std * offset_factor
        ax.plot(phase_axis, waveform_data + offset, alpha=0.7)

    # Mark the phase the peaks were aligned to.
    ax.axvline(target_peak_phase, color='r', linestyle='--', label=f'Target Peak Phase ({target_peak_phase:.2f} rad)')

    ax.set_title(f"{title} (First {n_plot} Examples)")
    ax.set_xlabel("Phase (radians)")
    ax.set_ylabel("Amplitude (stacked with offset)")
    ax.set_xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi])
    ax.set_xticklabels(['0', 'π/2', 'π', '3π/2', '2π'])
    ax.legend()
    plt.tight_layout()
    plt.show()
    return fig

# Draw the plot only when the extraction cell left behind a non-empty result;
# otherwise just say why nothing is shown.
if 'peak_aligned_waveforms' not in locals() or peak_aligned_waveforms.empty:
    print("\nSkipping plot: No peak-aligned waveforms were generated or an error occurred.")
else:
    plot_aligned_waveforms(peak_aligned_waveforms, phase_axis, title="Peak-Aligned (Phase-Based, Peak at π/2)")