# Dependencies

In [None]:
# %pip install pyxdf && echo "Installation Successful (pyxdf)" || echo "Installation Failed (pyxdf)"
# import sys
# !{sys.executable} -m pip install --upgrade pip
# !{sys.executable} -m pip list --outdated --format=freeze | grep -v '^\-e' | cut -d = -f 1 | xargs -n1 {sys.executable} -m pip install -U

In [None]:
# # In a Jupyter cell - using ! to run shell commands
# !pip install --upgrade numpy pandas matplotlib scikit-learn seaborn

In [None]:
import pyxdf
import pandas as pd
import numpy as np
import os
import glob
import re
import shutil
import io
import matplotlib.pyplot as plt
import pprint
from scipy.interpolate import interp1d

# F(x) XDF Data Processing

In [None]:
# ===========================================================
# Channel Label Extraction
# ===========================================================

def extract_channel_labels(stream):
    """Return the channel labels, types and units described by a stream.

    Channel entries are read from
    ``stream['info']['desc'][0]['channels'][0]['channel']``. Every field of an
    entry is expected to be a one-element list, as in ``{'label': ['ECG']}``.

    When that structure is missing or malformed, generic ``"Channel <i>"``
    labels are generated from ``stream['info']['channel_count']`` instead.

    Args:
        stream: Mapping with an ``'info'`` entry holding the stream metadata.

    Returns:
        tuple: ``(labels, types, units)``, three lists of equal length.
        Types and units are empty strings where the metadata has none.
    """
    try:
        # Navigate the nested metadata structure to find channel entries
        channel_entries = stream['info']['desc'][0]['channels'][0]['channel']

        # A single channel is stored as a dict rather than a list of dicts
        if not isinstance(channel_entries, list):
            channel_entries = [channel_entries]

        labels, types, units = [], [], []
        for i, ch in enumerate(channel_entries):
            # Prefer 'label', then 'name', then a generic positional label
            label = ch.get('label', ch.get('name', [f"Channel {i}"]))[0]
            labels.append(label)
            types.append(ch.get('type', [''])[0])
            units.append(ch.get('unit', [''])[0])

        return labels, types, units

    except (KeyError, IndexError, TypeError):
        # Fallback when metadata is missing or malformed: derive the number of
        # channels from channel_count. A count that is unreadable (None, empty
        # list, non-numeric text) falls back to a single channel instead of
        # raising out of the fallback path.
        try:
            num_channels = int(stream['info'].get('channel_count', ['0'])[0])
        except (ValueError, TypeError, IndexError):
            num_channels = 1

        labels = [f"Channel {i}" for i in range(num_channels)]
        return labels, [''] * num_channels, [''] * num_channels

# ===========================================================
# String Normalization for Robust Matching
# ===========================================================

def normalize_string(s):
    """Return *s* without surrounding whitespace and in lower case.

    Used to compare stream and channel names without regard to case or
    leading/trailing whitespace.
    """
    trimmed = s.strip()
    return trimmed.lower()

# ===========================================================
# Channel Filtering with Robust Matching
# ===========================================================

def filter_stream_channels_robust(stream, channel_names):
    """Return a copy of *stream* restricted to the requested channels.

    Channel names are matched against the stream's labels after
    ``normalize_string`` (case- and surrounding-whitespace-insensitive).
    Columns appear in the order the names were requested. When two labels
    normalize to the same text, the later one is used.

    Args:
        stream: Stream dict with ``'info'``, ``'time_series'`` and optionally
            ``'time_stamps'``. ``time_series`` is indexed as
            ``[:, indices]`` (assumes a 2-D NumPy array — TODO confirm for
            marker/string streams).
        channel_names: Iterable of channel names to keep.

    Returns:
        dict or None: A new stream dict whose metadata describes only the
        selected channels, or ``None`` when no requested channel matches.
        The input stream is left untouched; ``time_stamps`` is passed through
        as the same object.
    """
    # Local import: only this function needs it.
    import copy

    labels, _, _ = extract_channel_labels(stream)

    # Normalized label -> column index, for O(1) lookup per request
    label_index = {normalize_string(label): i for i, label in enumerate(labels)}

    matching_indices = [
        label_index[key]
        for key in (normalize_string(name) for name in channel_names)
        if key in label_index
    ]

    if not matching_indices:
        return None

    filtered_stream = {
        # Deep copy: the channel list nested inside 'desc' is rewritten below.
        # A shallow copy would share that nesting with the caller's stream and
        # silently drop channels from its metadata.
        'info': copy.deepcopy(stream['info']),
        'time_series': stream['time_series'][:, matching_indices],
        'time_stamps': stream.get('time_stamps', None),
    }
    info = filtered_stream['info']

    # Best effort: trim the per-channel metadata to the selected channels.
    # Streams without a usable descriptor simply keep it as it is.
    try:
        desc = info['desc'][0]
        if 'channels' in desc and 'channel' in desc['channels'][0]:
            original_channels = desc['channels'][0]['channel']
            if isinstance(original_channels, list):
                desc['channels'][0]['channel'] = [
                    original_channels[i] for i in matching_indices
                ]
    except (KeyError, IndexError, TypeError):
        pass

    # Always keep the advertised channel count in step with the data, even
    # when the descriptor above could not be updated.
    info['channel_count'] = [str(len(matching_indices))]

    return filtered_stream

# ===========================================================
# Multi-Stream Selection with Diagnostics
# ===========================================================

def select_streams_by_name(streams, selection_dict, verbose=True):
    """Pick streams by name and keep only the requested channels of each.

    Args:
        streams: List of stream dicts; the name is read from
            ``stream['info']['name'][0]``.
        selection_dict: Mapping of requested stream name to the list of
            channel names wanted from that stream.
        verbose: Print the available streams and the outcome of each request.

    Returns:
        list: Filtered stream dicts, in request order, for every request whose
        stream was found and had at least one matching channel.

    Name matching uses ``normalize_string``. A request containing "polar" is
    first resolved to the first stream whose name contains "polar h10"
    (case-insensitive); if there is none, the normal exact lookup applies.
    """
    # Index streams by normalized name; a later stream with the same
    # normalized name replaces an earlier one.
    stream_lookup = {}
    for candidate in streams:
        raw_name = candidate['info']['name'][0]
        stream_lookup[normalize_string(raw_name)] = (candidate, raw_name)

    if verbose:
        print(f"\n🔍 Available streams ({len(streams)}):")
        for _, listed_name in stream_lookup.values():
            print(f"  '{listed_name}'")
        print()

    selected_streams = []

    for requested_name, channels in selection_dict.items():
        normalized_request = normalize_string(requested_name)

        # Resolve the request to a (stream, original name) pair.
        match = None
        via_polar = False
        if "polar" in normalized_request:
            for candidate, raw_name in stream_lookup.values():
                if "polar h10" in raw_name.lower():
                    match = (candidate, raw_name)
                    via_polar = True
                    break

        if match is None:
            match = stream_lookup.get(normalized_request)

        if match is None:
            # Stream name didn't match any available stream
            if verbose:
                print(f"Stream not found: '{requested_name}'")
                print(f"   (normalized as: '{normalized_request}')")
            continue

        stream, original_name = match
        if verbose:
            if via_polar:
                print(f"   Found stream matching 'Polar H10': '{original_name}'")
            else:
                print(f"   Found stream: '{original_name}'")
                print(f"   Searching for channels: {channels}")

        filtered_stream = filter_stream_channels_robust(stream, channels)
        if filtered_stream is None:
            # Stream exists, but none of the requested channels matched
            if verbose:
                print(f"   No matching channels found in '{original_name}'")
            continue

        labels, _, _ = extract_channel_labels(filtered_stream)
        if verbose:
            print(f"   Matched {len(labels)} channel(s): {labels}")
        selected_streams.append(filtered_stream)

    if verbose:
        print(f"\n Total streams selected: {len(selected_streams)}\n")

    return selected_streams

# ===========================================================
# Stream-to-DataFrame Conversion with Synchronization
# ===========================================================

def streams_to_dataframe(streams, resample=True, target_freq=1.0, use_timestamps=True, n=0):
    """Merge several streams into one DataFrame on a shared time axis.

    Each stream contributes columns named ``"<stream name>_<channel label>"``.
    Every channel is linearly interpolated (``interp1d`` with extrapolation
    enabled) onto a common ``XDFTime`` grid, and a ``Time`` column that starts
    at 0 is inserted as the first column.

    Args:
        streams: List of stream dicts with ``'info'``, ``'time_series'`` and
            optionally ``'time_stamps'``. ``time_series`` must expose
            ``.shape[1]`` (assumes a 2-D NumPy array — TODO confirm).
        resample: If True, the grid is uniform at ``target_freq`` Hz between
            the latest stream start and the earliest stream end. If False, the
            grid is the sorted union of all streams' own time values.
        target_freq: Grid frequency in Hz. Also used to synthesize time values
            for a stream without timestamps or when ``use_timestamps`` is
            False.
        use_timestamps: Use ``stream['time_stamps']`` when it is present.
        n: Rows to drop from each end of the merged frame. Skipped with a
            printed warning when the frame has ``2 * n`` rows or fewer.

    Returns:
        pandas.DataFrame: ``Time``, ``XDFTime`` and the data columns; an empty
        DataFrame when *streams* is empty.
    """
    if not streams:
        print("No streams provided to streams_to_dataframe()")
        return pd.DataFrame()

    all_data = []

    for stream in streams:
        name = stream['info']['name'][0]
        data = stream['time_series']
        ts = stream.get('time_stamps')
        labels, _, _ = extract_channel_labels(stream)

        # Make the label list exactly as long as the data is wide: drop
        # surplus labels and pad missing ones with generic names, so building
        # the frame cannot fail on a label/column count mismatch.
        num_channels = data.shape[1]
        labels = list(labels[:num_channels])
        labels.extend(f"Channel {i}" for i in range(len(labels), num_channels))

        if use_timestamps and ts is not None:
            time_col = ts
        else:
            # Synthetic time axis assuming a constant rate of target_freq
            time_col = np.arange(len(data)) / target_freq

        df_part = pd.DataFrame(data, columns=[f"{name}_{label}" for label in labels])
        df_part.insert(0, "XDFTime", time_col)
        all_data.append(df_part)

    # Overlap window: latest start and earliest end across all streams
    min_t = max(part["XDFTime"].min() for part in all_data)
    max_t = min(part["XDFTime"].max() for part in all_data)

    if resample:
        new_time = np.arange(min_t, max_t, 1.0 / target_freq)
    else:
        # Sorted, de-duplicated union of every stream's time values
        new_time = np.unique(
            np.concatenate([part["XDFTime"].values for part in all_data])
        )

    merged = {"XDFTime": new_time}

    for df_part in all_data:
        for col in df_part.columns:
            if col == "XDFTime":
                continue
            f = interp1d(df_part["XDFTime"], df_part[col],
                         fill_value="extrapolate",
                         bounds_error=False)
            merged[col] = f(new_time)

    df = pd.DataFrame(merged)

    # Optional symmetric trimming of the merged frame
    if n > 0:
        if len(df) > 2 * n:
            df = df.iloc[n:-n].reset_index(drop=True)
        else:
            print(f"Warning: n={n} too large for dataset length {len(df)} — skipping truncation.")

    # Time is added after truncation so that it starts at 0. On the uniform
    # grid it is simply index / target_freq. On the union grid the spacing is
    # irregular, so elapsed time is taken from XDFTime itself.
    if resample or df.empty:
        elapsed = np.arange(len(df)) / target_freq
    else:
        xdf_time = df["XDFTime"].to_numpy()
        elapsed = xdf_time - xdf_time[0]
    df.insert(0, "Time", elapsed)

    return df
    
# def streams_to_dataframe(streams, resample=True, target_freq=1.0, use_timestamps=True, n=0):

#     # Handle empty input
#     if not streams:
#         print("No streams provided to streams_to_dataframe()")
#         return pd.DataFrame()

#     all_data = []

#     # Process each stream individually
#     for stream in streams:
#         # Extract stream metadata
#         name = stream['info']['name'][0]
#         data = stream['time_series']
#         ts = stream.get('time_stamps')
#         labels, _, _ = extract_channel_labels(stream)

#         # Safety check: ensure we have labels for all data columns
#         num_channels = data.shape[1]
#         labels = labels[:num_channels]

#         # Determine time column
#         if use_timestamps and ts is not None:
#             time_col = ts  # Use original timestamps
#         else:
#             # Generate synthetic timestamps assuming constant sampling rate
#             time_col = np.arange(len(data)) / target_freq

#         # Create DataFrame for this stream with prefixed column names
#         df_part = pd.DataFrame(data, columns=[f"{name}_{label}" for label in labels])
#         df_part.insert(0, "Time", time_col)
#         all_data.append(df_part)

#     # Determine the overlapping time range across all streams
#     # Use the latest start time and earliest end time
#     min_t = max(df["Time"].min() for df in all_data)
#     max_t = min(df["Time"].max() for df in all_data)

#     # Create the new time grid
#     if resample:
#         # Uniform grid at target_freq Hz
#         new_time = np.arange(min_t, max_t, 1.0 / target_freq)
#     else:
#         # Use the union of all original timestamps (sorted and unique)
#         new_time = sorted(set(np.concatenate([df["Time"].values for df in all_data])))

#     # Initialize the merged data dictionary
#     merged = {"Time": new_time}

#     # Interpolate each stream's channels onto the new time grid
#     for df_part in all_data:
#         for col in df_part.columns:
#             if col == "Time":
#                 continue  # Skip the time column itself

#             # Create interpolation function
#             f = interp1d(df_part["Time"], df_part[col],
#                         fill_value="extrapolate",
#                         bounds_error=False)

#             # Apply interpolation to new time grid
#             merged[col] = f(new_time)

#     # Create the final combined DataFrame
#     df = pd.DataFrame(merged)

#     # Truncate n samples from start and end if requested
#     if n > 0:
#         if len(df) > 2 * n:
#             # Remove first n and last n rows
#             df = df.iloc[n:-n].reset_index(drop=True)
#         else:
#             # Not enough data to truncate safely
#             print(f"Warning: n={n} too large for dataset length {len(df)} — skipping truncation.")

#     return df

# ===========================================================
# Post-Processing: Resampling
# ===========================================================

def resample_dataframe(data, target_freq=1.0, verbose=True):
    """Linearly resample every data column onto a uniform ``XDFTime`` grid.

    The new grid is ``np.arange(first, last, 1 / target_freq)`` over the
    input's ``XDFTime`` values, so the last original time point itself is not
    included. An existing ``Time`` column is dropped; callers are expected to
    regenerate it afterwards.

    Args:
        data: DataFrame with an ``XDFTime`` column and numeric data columns.
        target_freq: Output sampling frequency in Hz; must be positive.
        verbose: Print spacing diagnostics and a summary of the result.

    Returns:
        pandas.DataFrame: ``XDFTime`` followed by the resampled data columns.

    Raises:
        ValueError: If ``XDFTime`` is missing, fewer than two rows are given,
            or ``target_freq`` is not positive.
    """
    if "XDFTime" not in data.columns:
        raise ValueError("Input DataFrame must contain a 'XDFTime' column")

    original_time = data["XDFTime"].values

    if len(original_time) < 2:
        raise ValueError("DataFrame must have at least 2 time points for resampling")

    # A zero, negative or NaN frequency would otherwise surface as a
    # ZeroDivisionError or an empty grid further down.
    if not target_freq > 0:
        raise ValueError(f"target_freq must be positive, got target_freq={target_freq}")

    # Diagnose the spacing of the incoming time axis
    time_diffs = np.diff(original_time)
    tolerance = 1e-6
    is_equally_spaced = np.allclose(time_diffs, np.mean(time_diffs), atol=tolerance)

    if verbose:
        if is_equally_spaced:
            print("\n✓ Original time data is equally spaced.")
        else:
            print("\n⚠ Original time data is NOT equally spaced.")
        print(f"  Standard deviation of time differences: {np.std(time_diffs):.6f} seconds")

    new_xdftime = np.arange(original_time[0], original_time[-1], 1.0 / target_freq)

    # Interpolate every column except the two time columns
    resampled_data = {}
    for col in data.columns:
        if col in ("Time", "XDFTime"):
            continue
        interp_func = interp1d(
            original_time,
            data[col].values,
            kind='linear',
            fill_value="extrapolate",
            bounds_error=False
        )
        resampled_data[col] = interp_func(new_xdftime)

    final_dataset = pd.DataFrame(resampled_data)
    final_dataset.insert(0, "XDFTime", new_xdftime)

    # Verify the spacing of the new grid
    resampled_time_diffs = np.diff(new_xdftime)
    is_resampled_equally_spaced = np.allclose(
        resampled_time_diffs,
        np.mean(resampled_time_diffs),
        atol=tolerance
    )

    if verbose:
        if is_resampled_equally_spaced:
            print("\n✓ Resampled XDFTime data is equally spaced.")
        else:
            print("\n⚠ Resampled XDFTime data is NOT equally spaced.")
        print(f"  Standard deviation of resampled XDFTime differences: {np.std(resampled_time_diffs):.6e} seconds")

        print(f"\n📊 Dataset Summary:")
        print(f"  Shape: {final_dataset.shape} (rows × columns)")
        print(f"  Total data points: {final_dataset.size:,}")
        print(f"  XDFTime range: {new_xdftime[0]:.3f} - {new_xdftime[-1]:.3f} seconds")
        print(f"  Duration: {new_xdftime[-1] - new_xdftime[0]:.3f} seconds")
        print(f"  Sampling rate: {target_freq} Hz")
        print(f"  Time step: {1.0/target_freq:.6f} seconds")

    return final_dataset
    
# def resample_dataframe(data, target_freq=1.0, verbose=True):

#     # Validate input
#     if "Time" not in data.columns:
#         raise ValueError("Input DataFrame must contain a 'Time' column")

#     # Extract original time column
#     original_time = data["Time"].values

#     if len(original_time) < 2:
#         raise ValueError("DataFrame must have at least 2 time points for resampling")

#     # Check original time spacing
#     time_diffs = np.diff(original_time)
#     tolerance = 1e-6
#     is_equally_spaced = np.allclose(time_diffs, np.mean(time_diffs), atol=tolerance)

#     if verbose:
#         if is_equally_spaced:
#             print("\n✓ Original time data is equally spaced.")
#         else:
#             print("\n⚠ Original time data is NOT equally spaced.")
#         print(f"  Standard deviation of time differences: {np.std(time_diffs):.6f} seconds")

#     # Create new uniform time grid
#     new_time = np.arange(original_time[0], original_time[-1], 1.0 / target_freq)

#     # Helper function for interpolating a single column
#     def interpolate_column(original_time, col_data, new_time):
#         """Interpolate a single data column onto a new time grid."""
#         if len(original_time) < 2:
#             return np.full(len(new_time), np.mean(col_data))
#         else:
#             interp_func = interp1d(
#                 original_time,
#                 col_data,
#                 kind='linear',
#                 fill_value="extrapolate",
#                 bounds_error=False
#             )
#             return interp_func(new_time)

#     # Resample each data column
#     columns_to_resample = [col for col in data.columns if col != "Time"]
#     resampled_data = {}

#     for col in columns_to_resample:
#         col_data = data[col].values
#         resampled_data[col] = interpolate_column(original_time, col_data, new_time)

#     # Create final resampled DataFrame
#     final_dataset = pd.DataFrame(resampled_data)
#     final_dataset.insert(0, "Time", new_time)

#     # Verify resampled time spacing
#     resampled_time_diffs = np.diff(new_time)
#     is_resampled_equally_spaced = np.allclose(
#         resampled_time_diffs,
#         np.mean(resampled_time_diffs),
#         atol=tolerance
#     )

#     if verbose:
#         if is_resampled_equally_spaced:
#             print("\n✓ Resampled time data is equally spaced.")
#         else:
#             print("\n⚠ Resampled time data is NOT equally spaced.")
#         print(f"  Standard deviation of resampled time differences: {np.std(resampled_time_diffs):.6e} seconds")

#         print(f"\n📊 Dataset Summary:")
#         print(f"  Shape: {final_dataset.shape} (rows × columns)")
#         print(f"  Total data points: {final_dataset.size:,}")
#         print(f"  Time range: {new_time[0]:.3f} - {new_time[-1]:.3f} seconds")
#         print(f"  Duration: {new_time[-1] - new_time[0]:.3f} seconds")
#         print(f"  Sampling rate: {target_freq} Hz")
#         print(f"  Time step: {1.0/target_freq:.6f} seconds")

#     return final_dataset

# ===========================================================
# Post-Processing: Truncation
# ===========================================================

def truncate_dataframe(data, n=0, verbose=True):
    """Drop *n* rows from each end of *data* and restart ``Time`` at 0.

    After trimming, any existing ``Time`` column is replaced by
    ``index / frequency``. The frequency is the reciprocal of the spacing
    between the first two remaining ``XDFTime`` values. Without an ``XDFTime``
    column the spacing of the old ``Time`` column is used instead. If neither
    column exists, fewer than two rows remain, or the spacing is not positive,
    1 Hz is assumed.

    Args:
        data: DataFrame to trim.
        n: Number of rows to remove from the start and from the end.
        verbose: Print a summary of what was removed.

    Returns:
        pandas.DataFrame: The trimmed frame with ``Time`` as first column.
        *data* itself is returned unchanged when ``n == 0`` or when it has
        ``2 * n`` rows or fewer.

    Raises:
        ValueError: If *n* is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got n={n}")

    if n == 0:
        if verbose:
            print("\nℹ No truncation performed (n=0)")
        return data

    original_length = len(data)

    if original_length <= 2 * n:
        if verbose:
            print(f"\n⚠ Warning: n={n} is too large for dataset length {original_length} — skipping truncation.")
            print(f"   Minimum length required: {2*n + 1} (to remove {n} from each end)")
        return data

    if verbose:
        print(f"\n📊 Truncation Summary:")
        print(f"  Original shape: {data.shape}")
        print(f"  Removing {n} rows from start and end ({2*n} total)")

    truncated_data = data.iloc[n:-n].reset_index(drop=True)

    # Derive the sample spacing before the old Time column is dropped, because
    # it is the fallback source when XDFTime is absent.
    if "XDFTime" in truncated_data.columns:
        spacing_source = "XDFTime"
    elif "Time" in truncated_data.columns:
        spacing_source = "Time"
    else:
        spacing_source = None

    target_freq = 1.0
    if spacing_source is not None and len(truncated_data) > 1:
        axis = truncated_data[spacing_source]
        step = axis.iloc[1] - axis.iloc[0]
        # Zero, negative or NaN spacing cannot define a frequency
        if step > 0:
            target_freq = 1.0 / step

    # Regenerate Time so that it starts from 0 after truncation
    if "Time" in truncated_data.columns:
        truncated_data = truncated_data.drop(columns=["Time"])
    truncated_data.insert(0, "Time", np.arange(len(truncated_data)) / target_freq)

    if verbose:
        print(f"  Final shape: {truncated_data.shape}")

        if "XDFTime" in truncated_data.columns:
            time_removed_start = data.iloc[n]["XDFTime"] - data.iloc[0]["XDFTime"]
            time_removed_end = data.iloc[-1]["XDFTime"] - data.iloc[-n-1]["XDFTime"]
            print(f"  XDFTime removed from start: {time_removed_start:.3f} seconds")
            print(f"  XDFTime removed from end: {time_removed_end:.3f} seconds")
            print(f"  New XDFTime range: {truncated_data['XDFTime'].iloc[0]:.3f} - {truncated_data['XDFTime'].iloc[-1]:.3f} seconds")
        print(f"  New Time range: {truncated_data['Time'].iloc[0]:.3f} - {truncated_data['Time'].iloc[-1]:.3f} seconds")

    return truncated_data
    
# def truncate_dataframe(data, n=0, verbose=True):

#     if n < 0:
#         raise ValueError(f"n must be non-negative, got n={n}")

#     if n == 0:
#         if verbose:
#             print("\nℹ No truncation performed (n=0)")
#         return data

#     original_length = len(data)

#     if original_length <= 2 * n:
#         if verbose:
#             print(f"\n⚠ Warning: n={n} is too large for dataset length {original_length} — skipping truncation.")
#             print(f"   Minimum length required: {2*n + 1} (to remove {n} from each end)")
#         return data

#     if verbose:
#         print(f"\n📊 Truncation Summary:")
#         print(f"  Original shape: {data.shape}")
#         print(f"  Removing {n} rows from start and end ({2*n} total)")

#     truncated_data = data.iloc[n:-n].reset_index(drop=True)

#     if verbose:
#         print(f"  Final shape: {truncated_data.shape}")

#         if "Time" in truncated_data.columns:
#             time_removed_start = data.iloc[n]["Time"] - data.iloc[0]["Time"]
#             time_removed_end = data.iloc[-1]["Time"] - data.iloc[-n-1]["Time"]
#             print(f"  Time removed from start: {time_removed_start:.3f} seconds")
#             print(f"  Time removed from end: {time_removed_end:.3f} seconds")
#             print(f"  New time range: {truncated_data['Time'].iloc[0]:.3f} - {truncated_data['Time'].iloc[-1]:.3f} seconds")

#     return truncated_data

# ===========================================================
# Diagnostic and Utility Functions
# ===========================================================

def print_stream_info(streams):
    """Print each stream's name followed by one bullet line per channel.

    A channel line shows the label, then ``(type)`` and ``[unit]`` when the
    metadata provides them.

    Args:
        streams: List of stream dicts; the name is read from
            ``stream['info']['name'][0]``.
    """
    print(f"\n{'='*60}")
    print(f"Found {len(streams)} stream(s):")
    print(f"{'='*60}\n")

    for stream in streams:
        name = stream['info']['name'][0]
        labels, types, units = extract_channel_labels(stream)

        print(f" Stream: {name}")
        print(f"   Channels ({len(labels)}):")

        for label, signal_type, unit in zip(labels, types, units):
            # Build the line from the parts that exist. This keeps the bullet
            # indentation and avoids a double space when the type is empty.
            parts = [f"      • {label}"]
            if signal_type:
                parts.append(f"({signal_type})")
            if unit:
                parts.append(f"[{unit}]")
            print(" ".join(parts))

        print(f"{'-'*60}\n")


def get_copyable_format(streams):
    """Print a ready-to-paste ``selection_dict`` listing every stream's channels.

    One ``"<stream name>": [<labels>],`` line is printed per stream, wrapped
    in ``selection_dict = {`` ... ``}`` so the output can be copied into a
    cell and trimmed to the channels of interest.
    """
    header_lines = (
        "\n Copy this template and fill in your desired channels:\n",
        "selection_dict = {",
    )
    for line in header_lines:
        print(line)

    for entry in streams:
        stream_name = entry['info']['name'][0]
        channel_labels = extract_channel_labels(entry)[0]
        print(f'    "{stream_name}": {channel_labels},')

    print("}\n")

# ===========================================================
# Complete Pipeline: All-in-One Processing
# ===========================================================

def process_xdf_streams(streams, selection_dict, target_freq=50.0, truncate_n=0, verbose=True):
    """Run the full pipeline: select, synchronize, resample and truncate.

    Steps:
        1. ``select_streams_by_name`` picks the requested streams/channels.
        2. ``streams_to_dataframe`` merges them on a uniform grid.
        3. ``resample_dataframe`` re-checks and resamples that grid.
        4. ``truncate_dataframe`` trims ``truncate_n`` rows from each end
           (only when ``truncate_n > 0``).

    Args:
        streams: List of stream dicts to choose from.
        selection_dict: Mapping of stream name to the channel names to keep.
        target_freq: Sampling frequency of the result in Hz.
        truncate_n: Rows to drop from each end of the resampled data.
        verbose: Print progress and summaries for every step.

    Returns:
        pandas.DataFrame: The final data with a ``Time`` column starting at 0.
        An empty DataFrame is returned when no stream was selected or the
        synchronized frame is empty.

    Raises:
        ValueError: Propagated from ``resample_dataframe`` (for example when
            the synchronized frame has fewer than two rows).
    """
    def _banner(title, rule):
        # Section heading framed by a 70-character rule; silent when not verbose
        if verbose:
            print("\n" + rule * 70)
            print(title)
            print(rule * 70)

    _banner("XDF STREAM PROCESSING PIPELINE", "=")
    if verbose:
        print("\n📋 Configuration:")
        print(f"  Target frequency: {target_freq} Hz")
        print(f"  Truncation: {truncate_n} samples from each end")
        print(f"  Streams requested: {len(selection_dict)}")

    # STEP 1: Stream Selection
    _banner("STEP 1: Stream Selection", "-")
    selected_streams = select_streams_by_name(streams, selection_dict, verbose=verbose)

    if not selected_streams:
        if verbose:
            print("\n❌ No streams selected. Check your selection_dict.")
        return pd.DataFrame()

    # STEP 2: Synchronization and Initial Processing
    _banner("STEP 2: Synchronization and Resampling", "-")
    df = streams_to_dataframe(
        selected_streams,
        resample=True,
        target_freq=target_freq,
        use_timestamps=True,
        n=0
    )

    if df.empty:
        if verbose:
            print("\n❌ Empty DataFrame after synchronization.")
        return df

    if verbose:
        print(f"\n✓ Synchronized DataFrame created")
        print(f"  Shape: {df.shape}")
        print(f"  Time range: {df['Time'].iloc[0]:.3f} - {df['Time'].iloc[-1]:.3f} seconds")
        print(f"  Duration: {df['Time'].iloc[-1] - df['Time'].iloc[0]:.3f} seconds")

    # STEP 3: Final Resampling
    _banner("STEP 3: Final Resampling Verification", "-")
    df_resampled = resample_dataframe(df, target_freq=target_freq, verbose=verbose)

    # STEP 4: Truncation
    _banner("STEP 4: Edge Truncation", "-")
    if truncate_n > 0:
        df_final = truncate_dataframe(df_resampled, n=truncate_n, verbose=verbose)
    else:
        df_final = df_resampled
        if verbose:
            print("\nℹ Truncation skipped (truncate_n=0)")

    # resample_dataframe drops the Time column and truncate_dataframe only
    # rebuilds it when it actually trims. Make sure the result always carries
    # a Time column starting at 0, including when truncation was skipped
    # because the data was too short.
    if "Time" not in df_final.columns:
        df_final.insert(0, "Time", np.arange(len(df_final)) / target_freq)

    # Final Summary
    _banner("PROCESSING COMPLETE ✓", "=")
    if verbose:
        print(f"\n📊 Final Dataset Summary:")
        print(f"  Shape: {df_final.shape[0]:,} rows × {df_final.shape[1]} columns")
        print(f"  Time range: {df_final['Time'].iloc[0]:.3f} - {df_final['Time'].iloc[-1]:.3f} seconds")
        print(f"  Duration: {df_final['Time'].iloc[-1] - df_final['Time'].iloc[0]:.3f} seconds")
        print(f"  Sampling frequency: {target_freq} Hz")
        print(f"  Columns: {list(df_final.columns)}")
        print("\n" + "="*70 + "\n")

    return df_final
    
# def process_xdf_streams(streams, selection_dict, target_freq=50.0, truncate_n=0, verbose=True):

#     if verbose:
#         print("\n" + "="*70)
#         print("XDF STREAM PROCESSING PIPELINE")
#         print("="*70)
#         print("\n📋 Configuration:")
#         print(f"  Target frequency: {target_freq} Hz")
#         print(f"  Truncation: {truncate_n} samples from each end")
#         print(f"  Streams requested: {len(selection_dict)}")

#     # STEP 1: Stream Selection
#     if verbose:
#         print("\n" + "-"*70)
#         print("STEP 1: Stream Selection")
#         print("-"*70)

#     selected_streams = select_streams_by_name(streams, selection_dict, verbose=verbose)

#     if not selected_streams:
#         if verbose:
#             print("\n❌ No streams selected. Check your selection_dict.")
#         return pd.DataFrame()

#     # STEP 2: Synchronization and Initial Processing
#     if verbose:
#         print("\n" + "-"*70)
#         print("STEP 2: Synchronization and Resampling")
#         print("-"*70)

#     df = streams_to_dataframe(
#         selected_streams,
#         resample=True,
#         target_freq=target_freq,
#         use_timestamps=True,
#         n=0
#     )

#     if df.empty:
#         if verbose:
#             print("\n❌ Empty DataFrame after synchronization.")
#         return df

#     if verbose:
#         print(f"\n✓ Synchronized DataFrame created")
#         print(f"  Shape: {df.shape}")
#         print(f"  Time range: {df['Time'].iloc[0]:.3f} - {df['Time'].iloc[-1]:.3f} seconds")
#         print(f"  Duration: {df['Time'].iloc[-1] - df['Time'].iloc[0]:.3f} seconds")

#     # STEP 3: Final Resampling
#     if verbose:
#         print("\n" + "-"*70)
#         print("STEP 3: Final Resampling Verification")
#         print("-"*70)

#     df_resampled = resample_dataframe(df, target_freq=target_freq, verbose=verbose)

#     # STEP 4: Truncation
#     if truncate_n > 0:
#         if verbose:
#             print("\n" + "-"*70)
#             print("STEP 4: Edge Truncation")
#             print("-"*70)

#         df_final = truncate_dataframe(df_resampled, n=truncate_n, verbose=verbose)
#     else:
#         df_final = df_resampled
#         if verbose:
#             print("\n" + "-"*70)
#             print("STEP 4: Edge Truncation")
#             print("-"*70)
#             print("\nℹ Truncation skipped (truncate_n=0)")

#     # Final Summary
#     if verbose:
#         print("\n" + "="*70)
#         print("PROCESSING COMPLETE ✓")
#         print("="*70)
#         print(f"\n📊 Final Dataset Summary:")
#         print(f"  Shape: {df_final.shape[0]:,} rows × {df_final.shape[1]} columns")
#         print(f"  Time range: {df_final['Time'].iloc[0]:.3f} - {df_final['Time'].iloc[-1]:.3f} seconds")
#         print(f"  Duration: {df_final['Time'].iloc[-1] - df_final['Time'].iloc[0]:.3f} seconds")
#         print(f"  Sampling frequency: {target_freq} Hz")
#         print(f"  Columns: {list(df_final.columns)}")
#         print("\n" + "="*70 + "\n")

#     return df_final

# Survey

In [None]:
## Data Upload

# Load the raw survey export into `df`.
# Note: despite its name, `directory` is passed straight to pd.read_csv, so it
# must be the full path of the CSV file itself, not of a folder.
directory = r'placeholder' # Full path of the survey CSV file
df = pd.read_csv(directory)
print(f"Original data shape: {df.shape}")
# Folder the cleaned survey CSV is written to by the cleanup cell.
survey_save_to = r'placeholder' # Save Directory

# # Display available columns with indices
# print("Available columns:")
# for i, col in enumerate(df.columns):
#     print(f"  ({i}) {col}")

In [None]:
# # Quick one-liner for checking column index
# for idx, col in enumerate(df.columns):
#     print(f"[{idx}] {col}")

In [None]:
# Select Columns and Rows
# Columns are picked by position: 18-34 followed by 35-50, i.e. the contiguous
# run of positions 18 through 50 of the loaded CSV.
selected_columns = df.columns[list(range(18, 35)) + list(range(35, 51))]
# Positional row slice: the last 34 rows of the file.
selected_rows = slice(-34, None)

# Create final DataFrame (column selection first, then the row slice)
selected_data = df[selected_columns].iloc[selected_rows]

# Show result
print(f"\nSelected {len(selected_columns)} columns and {len(selected_data)} rows")
print(f"Final shape: {selected_data.shape}")
print("\nFinal DataFrame:")
# NOTE(review): `display` is not imported in the visible cells — presumably the
# IPython/Jupyter builtin; confirm before running this outside a notebook.
display(selected_data)

In [None]:
## Survey Data Cleanup

# Work on an independent copy so the assignments below do not trigger
# SettingWithCopyWarning on the slice taken from `df`.
selected_data = selected_data.copy()

# 'pid': cast to text and remove every '#' (e.g. '#5' -> '5').
selected_data['pid'] = selected_data['pid'].astype(str).str.replace('#', '', regex=False)

# Keep only rows whose cleaned pid starts with a digit. Rows whose pid starts
# with anything else (letters, or 'nan' from missing values) are dropped.
selected_data = selected_data[selected_data['pid'].str.match(r'^\d', na=False)].copy()

# 'flt_hrs': cast to text and remove thousands separators (commas).
selected_data['flt_hrs'] = selected_data['flt_hrs'].astype(str).str.replace(',', '', regex=False)

# 'experience': cast to text first. The per-value cleanup that follows works
# on strings.
selected_data['experience'] = selected_data['experience'].astype(str)

# Values that look like a calendar year (>= 1900) are converted to elapsed
# years relative to the reference year; other numeric values are kept as
# whole-number experience.
def clean_experience(val, reference_year=2025):
    """
    Normalise one 'experience' entry to a whole number of years, as a string.

    Args:
        val: Raw entry, expected to be a string (commas are ignored).
        reference_year: Year that year-like entries are subtracted from.
            Defaults to 2025, the value previously hard-coded here.

    Returns:
        str: ``str(int(reference_year - num))`` when the parsed number is
        >= 1900, otherwise ``str(int(num))``. When the entry cannot be
        converted (non-string input, non-numeric text, NaN or infinity),
        ``val`` is returned unchanged.
    """
    try:
        num = float(val.replace(',', ''))  # Remove commas if any
        if num >= 1900:  # Likely a year
            return str(int(reference_year - num))
        return str(int(num))  # Already an experience value
    except (AttributeError, TypeError, ValueError, OverflowError):
        # AttributeError/TypeError: val is not a str; ValueError: not numeric
        # or NaN; OverflowError: infinity. Anything else (e.g.
        # KeyboardInterrupt) is no longer swallowed.
        return val  # Keep original if conversion fails

# Normalise every experience entry with clean_experience
experience_cleaned = selected_data['experience'].apply(clean_experience)
selected_data['experience'] = experience_cleaned

# Keep the cleaned result under its own name
clean_survey_data = selected_data.copy()

# Summary of the cleaned columns
print("\nAfter cleaning pid, flt_hrs, and experience columns:")
print(f"Final shape: {clean_survey_data.shape}")
print("\nCleaned pid values:")
print(clean_survey_data['pid'].unique())
print("\nCleaned flt_hrs sample:")
print(clean_survey_data['flt_hrs'].head(10))
print("\nCleaned experience values:")
print(clean_survey_data['experience'].unique())

# Save as CSV (path built once and reused for the confirmation message)
survey_csv_path = os.path.join(survey_save_to, 'clean_survey_data.csv')
clean_survey_data.to_csv(survey_csv_path, index=False)
print(f"\n✓ Saved to: {survey_csv_path}")
display(clean_survey_data)

In [None]:
# # Quick one-liner for checking column index
# for idx, col in enumerate(clean_survey_data.columns):
#     print(f"[{idx}] {col}")

# XDF Test (Single File)

In [None]:
# Locations for the single-file test: the input recording and the output folder
file_path = r'placeholder'
test_save = r'placeholder'

# Load Data: pyxdf returns the list of streams and the file header
streams, header = pyxdf.load_xdf(file_path)

In [None]:
def extract_channel_labels(stream):
    """
    Return channel labels, types and units for one stream.

    Channel entries are read from
    ``stream['info']['desc'][0]['channels'][0]['channel']``. A single channel
    stored as a bare dict is treated as a one-element list. Each entry's
    label comes from ``'label'``, else ``'name'``, else ``"Channel <i>"``;
    type and unit default to an empty string.

    When that metadata path is missing or malformed, generic labels are
    generated from ``stream['info']['channel_count']`` instead (0 channels if
    the key is absent, 1 channel if its value cannot be converted to int).

    Args:
        stream: Stream dictionary with an ``'info'`` mapping whose values
            are wrapped in lists.

    Returns:
        tuple: ``(labels, types, units)`` -- three lists of equal length.
    """
    try:
        channel_entries = stream['info']['desc'][0]['channels'][0]['channel']
        if not isinstance(channel_entries, list):
            channel_entries = [channel_entries]

        labels, types, units = [], [], []
        for i, ch in enumerate(channel_entries):
            labels.append(ch.get('label', ch.get('name', [f"Channel {i}"]))[0])
            types.append(ch.get('type', [''])[0])
            units.append(ch.get('unit', [''])[0])

        return labels, types, units

    except (KeyError, IndexError, TypeError, AttributeError):
        # Fallback: use generic channel names.
        # AttributeError covers a channel entry that is not a dict.
        count_field = stream['info'].get('channel_count', ['0'])
        try:
            num_channels = int(count_field[0])
        except (ValueError, TypeError, IndexError):
            # Non-numeric text, None, or an empty list: previously only
            # ValueError was handled and the other cases crashed here.
            num_channels = 1
        labels = [f"Channel {i}" for i in range(num_channels)]
        return labels, [''] * num_channels, [''] * num_channels

# Print an overview of every loaded stream and its channels
print("=== Streams and Channels ===")
print(f"Number of streams: {len(streams)}")
print("=" * 50)

for i, stream in enumerate(streams):
    info = stream['info']
    stream_name = info['name'][0]
    num_channels = int(info['channel_count'][0])
    labels, types, units = extract_channel_labels(stream)

    print(f"Stream {i}: '{stream_name}' with {num_channels} channels")
    print("Channel details:")
    # The three lists returned by extract_channel_labels have equal length
    for j, (label, ch_type, ch_unit) in enumerate(zip(labels, types, units)):
        type_info = f" ({ch_type})" if ch_type else ""
        unit_info = f" [{ch_unit}]" if ch_unit else ""
        print(f"  {j}: {label}{type_info}{unit_info}")
    print("-" * 50)

# # Print full metadata for the first stream
# print("\n=== Full Metadata for First Stream ===")
# pprint.pprint(streams[0]['info'])

In [None]:
# Channels to extract from each stream, keyed by stream name
selection = {
    "VarjoEyeMetrics": [
        "focusDistance", "stability", "interPupillaryDist",
        "leftPupilIrisRatio", "rightPupilIrisRatio",
        "leftPupilDiam", "rightPupilDiam",
        "leftIrisDiam", "rightIrisDiam",
        "leftOpenness", "rightOpenness",
    ],
    "VarjoGaze": ["fwdX", "fwdY", "fwdZ"],
    "XPlaneData": [
        "sim/flightmodel/position/latitude",
        "sim/flightmodel/position/longitude",
        "sim/cockpit2/gauges/indicators/altitude_ft_pilot",
        # NOTE(review): by its name this dataref is the autopilot airspeed
        # value -- confirm it is the signal intended for "Airspeed"
        "sim/cockpit/autopilot/airspeed",
    ],
    "Polar H10": ["HR", "RRI"],
}

# Display names for the processed columns.
# Keys for the first three streams have the form "<stream>_<channel>".
column_rename_map = {
    # --- VarjoEyeMetrics ---
    "VarjoEyeMetrics_focusDistance":       "Focus Distance",
    "VarjoEyeMetrics_stability":           "Gaze Stability",
    "VarjoEyeMetrics_interPupillaryDist":  "Interpupillary Distance",
    "VarjoEyeMetrics_leftPupilIrisRatio":  "Left Pupil Iris Ratio",
    "VarjoEyeMetrics_rightPupilIrisRatio": "Right Pupil Iris Ratio",
    "VarjoEyeMetrics_leftPupilDiam":       "Left Pupil Diameter",
    "VarjoEyeMetrics_rightPupilDiam":      "Right Pupil Diameter",
    "VarjoEyeMetrics_leftIrisDiam":        "Left Iris Diameter",
    "VarjoEyeMetrics_rightIrisDiam":       "Right Iris Diameter",
    "VarjoEyeMetrics_leftOpenness":        "Left Eye Openness",
    "VarjoEyeMetrics_rightOpenness":       "Right Eye Openness",

    # --- VarjoGaze ---
    "VarjoGaze_fwdX": "Gaze Forward X",
    "VarjoGaze_fwdY": "Gaze Forward Y",
    "VarjoGaze_fwdZ": "Gaze Forward Z",

    # --- XPlaneData (simplified names) ---
    "XPlaneData_sim/flightmodel/position/latitude":                "Latitude",
    "XPlaneData_sim/flightmodel/position/longitude":               "Longitude",
    "XPlaneData_sim/cockpit2/gauges/indicators/altitude_ft_pilot": "Altitude",
    "XPlaneData_sim/cockpit/autopilot/airspeed":                   "Airspeed",

    # --- Polar H10: bare measurement names, matched by rename_polar_columns ---
    "HR":  "Heart Rate",
    "RRI": "R-R Interval",
}

# Function to rename Polar columns dynamically
def rename_polar_columns(df, rename_map=None):
    """
    Rename columns whose name contains 'Polar' to standardized names.

    The measurement name is the text after the last underscore of the column
    name (the whole name when it has no underscore). If the measurement is a
    key of the rename map, the column takes the mapped name; otherwise it
    becomes ``"Polar <measurement>"``. Columns without 'Polar' in their name
    are left untouched.

    Args:
        df: DataFrame whose column names are inspected; it is not modified.
        rename_map: Optional mapping from measurement name to display name.
            When omitted, the module-level ``column_rename_map`` is looked up
            at call time (the original behaviour). Passing it explicitly
            avoids picking up a map left over from a different cell.

    Returns:
        The DataFrame returned by ``df.rename(columns=...)``.
    """
    if rename_map is None:
        rename_map = column_rename_map

    new_columns = {}
    for col in df.columns:
        if 'Polar' not in col:
            continue
        # rsplit with maxsplit=1 yields the whole name when no '_' is present
        measurement = col.rsplit('_', 1)[-1]
        new_columns[col] = rename_map.get(measurement, f"Polar {measurement}")

    return df.rename(columns=new_columns)

In [None]:
# Run the full processing pipeline on the single test file and save the result.
# Any failure is reported with a traceback instead of aborting the notebook.
try:
    # Run the complete pipeline for a single file
    df_complete = process_xdf_streams(
        streams=streams,
        selection_dict=selection,
        target_freq=4,
        truncate_n=20,
        verbose=True
    )

    # First rename non-Polar columns using the static mapping
    df_complete = df_complete.rename(columns=column_rename_map)

    # Then dynamically rename any Polar columns
    df_complete = rename_polar_columns(df_complete)

    # Display the resulting dataframe
    print("Processed Dataframe:")
    display(df_complete)
    print(f"Data shape: {df_complete.shape}")
    print(f"Data columns: {df_complete.columns.tolist()}")

    # Create CSV filename in the test_save directory.
    # splitext swaps only the final extension, so a name containing '.xdf'
    # elsewhere is not mangled and a path without '.xdf' still gets '.csv'.
    # (os is already imported at the top of the notebook.)
    stem, _ = os.path.splitext(os.path.basename(file_path))
    filename = f"{stem}.csv"
    csv_filename = os.path.join(test_save, filename)
    df_complete.to_csv(csv_filename, index=False)
    print(f"✓ Saved processed data to: {csv_filename}")

except Exception as e:
    print(f"Error during processing: {e}")
    import traceback
    traceback.print_exc()

# XDF b1: Baseline HRV (Multiple Files)

In [None]:
# Input and output locations for the baseline HRV batch
directory = r'placeholder'    # Data Directory
hrv_save_to = r'placeholder'  # Save Directory

# Only the Polar H10 channels are extracted for this batch
selection = {"Polar H10": ["HR", "RRI"]}

# Display names for the Polar measurements
column_rename_map = {"HR": "Heart Rate", "RRI": "R-R Interval"}

# Function to rename Polar columns dynamically
def rename_polar_columns(df, rename_map=None):
    """
    Rename columns whose name contains 'Polar' to standardized names.

    The measurement name is the text after the last underscore of the column
    name (the whole name when it has no underscore). If the measurement is a
    key of the rename map, the column takes the mapped name; otherwise it
    becomes ``"Polar <measurement>"``. Columns without 'Polar' in their name
    are left untouched.

    Args:
        df: DataFrame whose column names are inspected; it is not modified.
        rename_map: Optional mapping from measurement name to display name.
            When omitted, the module-level ``column_rename_map`` is looked up
            at call time (the original behaviour). Passing it explicitly
            avoids picking up a map left over from a different cell.

    Returns:
        The DataFrame returned by ``df.rename(columns=...)``.
    """
    if rename_map is None:
        rename_map = column_rename_map

    new_columns = {}
    for col in df.columns:
        if 'Polar' not in col:
            # Non-Polar columns need no entry: rename leaves unmapped columns as is
            continue
        # rsplit with maxsplit=1 yields the whole name when no '_' is present
        measurement = col.rsplit('_', 1)[-1]
        new_columns[col] = rename_map.get(measurement, f"Polar {measurement}")

    return df.rename(columns=new_columns)

# Convert every .xdf file in `directory` to a .csv of the same base name in
# `hrv_save_to`. A failure while loading, processing or saving one file is
# reported and the loop moves on to the next file.
files = sorted(os.listdir(directory))  # sorted for a reproducible processing order
for file in files:
    if not file.endswith('.xdf'):
        continue

    # Create full path to the file
    filepath = os.path.join(directory, file)

    try:
        streams, header = pyxdf.load_xdf(filepath)
    except Exception as e:
        print(f"Failed to load {file}: {e}")
        continue

    try:
        df_complete = process_xdf_streams(
            streams=streams,
            selection_dict=selection,
            target_freq=4.0,
            truncate_n=20,
            verbose=True
        )

        # Rename Polar columns dynamically
        df_complete = rename_polar_columns(df_complete)

    except Exception as e:
        print(f"Failed to process streams from {file}: {e}")
        continue

    # Swap only the trailing extension; str.replace would also rewrite any
    # earlier '.xdf' occurring inside the file name
    savename = os.path.splitext(file)[0] + '.csv'
    # Save to the hrv_save_to directory
    savepath = os.path.join(hrv_save_to, savename)

    try:
        df_complete.to_csv(savepath, index=False)
        print(f"Saved {savename} with {len(df_complete)} rows to {hrv_save_to}")
        print(f"Columns: {list(df_complete.columns)}")
    except Exception as e:
        print(f"Failed to save {savename}: {e}")

# XDF b3: Station B (Multiple Files)

In [None]:
# b3 = Station B

# Input and output locations for the Station B batch
directory = r'placeholder'    # b3 = Station B
b3_save_to = r'placeholder3'  # Save Directory

# Channels to extract from each stream, keyed by stream name
selection = {
    "VarjoEyeMetrics": [
        "focusDistance", "stability", "interPupillaryDist",
        "leftPupilIrisRatio", "rightPupilIrisRatio",
        "leftPupilDiam", "rightPupilDiam",
        "leftIrisDiam", "rightIrisDiam",
        "leftOpenness", "rightOpenness",
    ],
    "VarjoGaze": ["fwdX", "fwdY", "fwdZ"],
    "XPlaneData": [
        "sim/flightmodel/position/latitude",
        "sim/flightmodel/position/longitude",
        "sim/cockpit2/gauges/indicators/altitude_ft_pilot",
    ],
    "Polar H10": ["HR", "RRI"],
}

# Display names for the processed columns.
# Keys for the first three streams have the form "<stream>_<channel>".
column_rename_map = {
    # --- VarjoEyeMetrics ---
    "VarjoEyeMetrics_focusDistance":       "Focus Distance",
    "VarjoEyeMetrics_stability":           "Gaze Stability",
    "VarjoEyeMetrics_interPupillaryDist":  "Interpupillary Distance",
    "VarjoEyeMetrics_leftPupilIrisRatio":  "Left Pupil Iris Ratio",
    "VarjoEyeMetrics_rightPupilIrisRatio": "Right Pupil Iris Ratio",
    "VarjoEyeMetrics_leftPupilDiam":       "Left Pupil Diameter",
    "VarjoEyeMetrics_rightPupilDiam":      "Right Pupil Diameter",
    "VarjoEyeMetrics_leftIrisDiam":        "Left Iris Diameter",
    "VarjoEyeMetrics_rightIrisDiam":       "Right Iris Diameter",
    "VarjoEyeMetrics_leftOpenness":        "Left Eye Openness",
    "VarjoEyeMetrics_rightOpenness":       "Right Eye Openness",

    # --- VarjoGaze ---
    "VarjoGaze_fwdX": "Gaze Forward X",
    "VarjoGaze_fwdY": "Gaze Forward Y",
    "VarjoGaze_fwdZ": "Gaze Forward Z",

    # --- XPlaneData (simplified names) ---
    "XPlaneData_sim/flightmodel/position/latitude":                "Latitude",
    "XPlaneData_sim/flightmodel/position/longitude":               "Longitude",
    "XPlaneData_sim/cockpit2/gauges/indicators/altitude_ft_pilot": "Altitude",

    # --- Polar H10: bare measurement names, matched by rename_polar_columns ---
    "HR":  "Heart Rate",
    "RRI": "R-R Interval",
}

# Function to rename Polar columns dynamically
def rename_polar_columns(df, rename_map=None):
    """
    Rename columns whose name contains 'Polar' to standardized names.

    The measurement name is the text after the last underscore of the column
    name (the whole name when it has no underscore). If the measurement is a
    key of the rename map, the column takes the mapped name; otherwise it
    becomes ``"Polar <measurement>"``. Columns without 'Polar' in their name
    are left untouched.

    Args:
        df: DataFrame whose column names are inspected; it is not modified.
        rename_map: Optional mapping from measurement name to display name.
            When omitted, the module-level ``column_rename_map`` is looked up
            at call time (the original behaviour). Passing it explicitly
            avoids picking up a map left over from a different cell.

    Returns:
        The DataFrame returned by ``df.rename(columns=...)``.
    """
    if rename_map is None:
        rename_map = column_rename_map

    new_columns = {}
    for col in df.columns:
        if 'Polar' not in col:
            # Non-Polar columns need no entry: rename leaves unmapped columns as is
            continue
        # rsplit with maxsplit=1 yields the whole name when no '_' is present
        measurement = col.rsplit('_', 1)[-1]
        new_columns[col] = rename_map.get(measurement, f"Polar {measurement}")

    return df.rename(columns=new_columns)

# Convert every .xdf file in `directory` to a .csv of the same base name in
# `b3_save_to`. A failure while loading, processing or saving one file is
# reported and the loop moves on to the next file.
files = sorted(os.listdir(directory))  # sorted for a reproducible processing order
for file in files:
    if not file.endswith('.xdf'):
        continue

    # Create full path to the file
    filepath = os.path.join(directory, file)

    try:
        streams, header = pyxdf.load_xdf(filepath)
    except Exception as e:
        print(f"Failed to load {file}: {e}")
        continue

    try:
        df_complete = process_xdf_streams(
            streams=streams,
            selection_dict=selection,
            target_freq=4.0,
            truncate_n=20,
            verbose=True
        )

        # First rename non-Polar columns using the static mapping
        df_complete = df_complete.rename(columns=column_rename_map)

        # Then dynamically rename any Polar columns
        df_complete = rename_polar_columns(df_complete)

    except Exception as e:
        print(f"Failed to process streams from {file}: {e}")
        continue

    # Swap only the trailing extension; str.replace would also rewrite any
    # earlier '.xdf' occurring inside the file name
    savename = os.path.splitext(file)[0] + '.csv'
    # Save to the b3_save_to directory
    savepath = os.path.join(b3_save_to, savename)

    try:
        df_complete.to_csv(savepath, index=False)
        print(f"Saved {savename} with {len(df_complete)} rows to {b3_save_to}")
        print(f"Columns: {list(df_complete.columns)}")
    except Exception as e:
        print(f"Failed to save {savename}: {e}")