Code block 1: imports and shared constants

In [None]:
from datetime import datetime, time
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, AutoMinorLocator
import numpy as np
import os
import pandas as pd
import pytz
import re
from scipy.signal import savgol_filter
from scipy.signal import savgol_filter

grav_acc = 9.81  # gravitational acceleration; presumably in m/s^2 -- TODO confirm the unit where it is used
import logging
from typing import Tuple, Union

User inputs (a stand-in for a GUI): edit these values to change the segmentation, annotation, and design parameters.

In [None]:
data_dir = 'confidential'   # root directory searched recursively for the trial .csv files
dry_fire = False            # True: dry fire (two force subplots); False: live fire (adds gyro and accelerometer subplots)

# display settings
plot_line_width = 1.0
plot_primary_line_alpha = 1.0
plot_secondary_line_alpha = 0.5
plot_grid_line_alpha = 0.5
plot_title_font_size = 18

full_force_range = range(1, 6)      # force channels F1-F5
full_IMU_range = ['X', 'Y', 'Z']    # IMU axes
IMUs_to_plot = ['X', 'Y', 'Z']      # IMU axes drawn in the live-fire gyro/accelerometer subplots

# Time window shown on the x-axis, given as time(hour, min, sec). It only changes what is displayed,
# not the segmentation. Set BOTH to None to show the whole trial: the window is applied only when
# both a start and an end are given.
plot_start_time = time(0, 0, 9)
plot_end_time = time(0, 1, 44)

# Modify these parameters with caution as they drastically affect the performance of the segmentation code
rolling_window_width = 21           # design choice: size of moving average window
event_derivative_threshold_ADC = 4  # design choice: derivative threshold
horizon = 5                         # design choice: horizon (number of samples per state-machine window)

# Modify these parameters according to use case
n_seconds_x_tick = 5                # the number of seconds between each x-tick in the plot
forces_to_plot = range(1, 6)        # the forces to plot (e.g., [1, 2, 3] or range(1, 4) plots forces F1, F2, and F3); a single force enables annotation mode
magnitude_threshold_ADC = 275       # threshold for eliminating minor peaks to speed up the segmentation and annotation process

In [None]:
data_dict = {}
calibration_dict = {}

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Files are expected at .../<prefix>_<c_num>/<prefix>_<u_num>/<file>.csv.
# 'R_U_*' and 'U_*' recordings go to data_dict; 'Trials*' recordings go to calibration_dict.
for root, _, files in os.walk(data_dir):
    for file in files:
        if not (file.endswith('.csv') and file.startswith(('R_U_', 'U_', 'Trials'))):
            continue

        # Defining trial names and loading pd data frames from saved .csv files
        file_path = os.path.join(root, file)
        parts = file_path.split(os.path.sep)
        c_num = parts[-3].split('_')[1]
        u_num = parts[-2].split('_')[1]
        # NOTE: str.strip('.csv') removes any of the characters '.', 'c', 's', 'v' from both ends
        # (it is not a suffix removal), so e.g. 'Trials.csv' / 'Trials_*.csv' yield the task 'Trial'.
        # The segmentation cell looks calibration data up under '<C>_<U>_Trial', so keep them in sync.
        task = parts[-1].split('_')[0].strip('.csv')
        key = f'C{c_num}_U{u_num}_{task}'

        # A line with more fields than the header (an interrupted row fused with the next one) would
        # make the parser raise; skip such lines with a ParserWarning instead of aborting the load.
        df = pd.read_csv(file_path, low_memory=False, on_bad_lines='warn')

        # Checking for corrupted data and excluding them from analyzed data
        target_num_columns = 8 if file.startswith('R_U_') else 15 if file.startswith('U_') else 33
        if df.shape[1] != target_num_columns:
            logger.warning(f"{key} has {df.shape[1]} columns; expected {target_num_columns}")

        # Rows with any missing value (e.g. a row cut short while it was being written)
        empty_rows = df.index[df.isnull().any(axis=1)].tolist()
        if empty_rows:
            logger.warning(f"Empty rows at indices {empty_rows} for {key} ({len(empty_rows)} rows)")

        if file.startswith(('R_U_', 'U_')):
            data_dict[key] = df.drop(empty_rows)
        elif file.startswith('Trials'):
            calibration_dict[key] = df.drop(empty_rows)

print(data_dict.keys())
print(calibration_dict.keys())

In [None]:
def extract_key_components(key: str) -> Tuple[Union[int, float], Union[int, float]]:
    """
    Parse the company and user numbers out of a trial key.

    Only the beginning of the key is matched against 'C<digits>_U<digits>_' followed by 'R' or 'U';
    anything after that is ignored. The fallback value is infinity, so when this function is used as
    a sort key, unparseable keys sort after every valid one.

    Args:
        key (str): Trial key such as 'C123_U456_R' or 'C789_U012_U'.

    Returns:
        tuple: (company_id, user_id) as ints, or (float('inf'), float('inf')) when the key does not match.
    """
    found = re.match(r'C(\d+)_U(\d+)_(R|U)', key)
    if found is None:
        return float('inf'), float('inf')
    company_digits, user_digits, _task = found.groups()
    return int(company_digits), int(user_digits)

def get_trial_name(key: str) -> tuple[str, str]:
    """
    Split a trial key into its '<company>_<user>' base name and its task component.

    Components beyond the third underscore-separated field are ignored.

    Args:
        key (str): Trial key such as 'C123_U456_R' or 'C789_U012_U'.

    Returns:
        tuple[str, str]: (base_name, task), e.g. ('C123_U456', 'R').

    Raises:
        IndexError: If the key has fewer than three underscore-separated fields.
    """
    segments = key.split('_')
    task = segments[2]
    base_name = '_'.join(segments[:2])
    return base_name, task

def plot_trial(trial_name:str, df:pd.DataFrame, dry_fire=False, start=None, end=None) -> list[int] | None:
    """
    Plots force, gyro, and accelerometer data for a trial and saves the figure.

    Subplot 1 shows the force channels only where a spike segment was detected ('Spike_counter' is
    not NaN). Subplot 2 shows the raw force channels. For live fire (dry_fire=False), subplots 3 and 4
    show the gyro and accelerometer channels. The figure is written to '<trial_name>/<trial_name>.png'
    (the directory must already exist), shown, and then closed.

    The channels drawn come from the module-level settings forces_to_plot and IMUs_to_plot.

    Args:
        trial_name (str): Name of the trial, used for titles and for the output path.
        df (pd.DataFrame): Trial data with a datetime 'Timestamp' column, 'Spike_counter', the
            'F<n>_ADC' force columns and, for live fire, 'Gyro_<axis>' / 'Acc_<axis>' columns.
            'F1_ADC' is read whenever exactly one force is plotted.
        dry_fire (bool): True for dry fire (2 subplots), False for live fire (4 subplots).
        start (datetime.time, optional): Start of the displayed x-range; applied only together with `end`.
            This affects what is displayed, but not the actual segmentation.
        end (datetime.time, optional): End of the displayed x-range; applied only together with `start`.
            This affects what is displayed, but not the actual segmentation.

    Returns:
        list[int] | None: When exactly one force is plotted (annotation mode), the index labels of the
            F1_ADC maximum inside each segmented spike, in time order; otherwise None.
    """
    colors = ['b', 'g', 'r', 'c', 'm', 'orange']

    figsize = (30, 12) if dry_fire else (30, 30)
    n_subplots = 2 if dry_fire else 4
    fig, axes = plt.subplots(n_subplots, 1, figsize=figsize)
    ax1, ax2, ax3, ax4 = (axes[0], axes[1], None, None) if dry_fire else axes

    force_ADC_columns = [f'F{force_to_plot}_ADC' for force_to_plot in forces_to_plot]
    gyro_columns = [f'Gyro_{IMU_to_plot}' for IMU_to_plot in IMUs_to_plot]
    acc_columns = [f'Acc_{IMU_to_plot}' for IMU_to_plot in IMUs_to_plot]

    # The displayed window is identical for all subplots, so build it once. Times are anchored to
    # the date of the first sample and interpreted as UTC.
    x_limits = None
    if start is not None and end is not None:
        reference_date = df['Timestamp'].iloc[0].date()
        x_limits = tuple(
            datetime.combine(reference_date, t).replace(microsecond=0, tzinfo=pytz.UTC)
            for t in (start, end)
        )

    for ax in axes:
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
        ax.xaxis.set_major_locator(mdates.SecondLocator(interval=n_seconds_x_tick))
        if x_limits is not None:
            ax.set_xlim(*x_limits)

    # Shared y-axis styling of the two force subplots
    for ax in (ax1, ax2):
        ax.grid(which='minor', alpha=plot_grid_line_alpha, zorder=0)
        ax.yaxis.set_major_locator(MultipleLocator(50))
        ax.yaxis.set_minor_locator(AutoMinorLocator())
        ax.grid(which='major', axis='y', linestyle='-', color='k', alpha=plot_grid_line_alpha, zorder=0)
        ax.yaxis.set_tick_params(which='major', labelright=True)
        ax.set_ylim(220, 470)

    # Subplot 1: force values kept only inside detected spike segments
    in_segment = df['Spike_counter'].notna()
    for color_idx, force_ADC_column in enumerate(force_ADC_columns):
        segmented_force = df[force_ADC_column].where(in_segment)
        ax1.plot(df['Timestamp'], segmented_force, label=force_ADC_column, color=colors[color_idx], linewidth=plot_line_width, zorder=10)
    if force_ADC_columns:
        ax1.legend(loc='upper right', bbox_to_anchor=(1, 1))
        ax1.set_title(f'Segmented Forces for {trial_name}', fontsize=plot_title_font_size)

    # When annotating (exactly one force plotted), each segment's peak is marked with its running
    # number to help the annotator remove undesired segments.
    peak_indices = None
    if len(force_ADC_columns) == 1:
        is_segmented = df[force_ADC_columns[0]].where(in_segment).notna()
        run_ids = (is_segmented != is_segmented.shift(1)).cumsum()
        # NOTE(review): peaks are taken from F1_ADC (the channel the segmentation runs on) even when a
        # different single force is plotted -- confirm this is intended.
        segments = df['F1_ADC'][is_segmented].groupby(run_ids[is_segmented])

        peak_indices = []
        for counter, (_, segment) in enumerate(segments, 1):
            peak_idx = segment.idxmax()
            if pd.isna(peak_idx):
                continue
            peak_indices.append(peak_idx)
            peak_time = df.loc[peak_idx, 'Timestamp']
            ax1.axvline(peak_time, linewidth=plot_line_width, linestyle=':', color='gray', alpha=plot_secondary_line_alpha)
            # Label heights alternate between 400 and 450, presumably so neighbouring numbers do not overlap.
            ax1.text(peak_time, 400 + 50 * (counter % 2), f'{counter}',
                     verticalalignment='top', horizontalalignment='center', fontsize=10, color='gray')

    # Subplot 2: raw force data
    for color_idx, force_ADC_column in enumerate(force_ADC_columns):
        ax2.plot(df['Timestamp'], df[force_ADC_column], label=force_ADC_column, color=colors[color_idx], linewidth=plot_line_width, zorder=10)
    if force_ADC_columns:
        ax2.legend(loc='upper right', bbox_to_anchor=(1, 1))
        ax2.set_title(f'Raw Forces for {trial_name}', fontsize=plot_title_font_size)

    # Subplots 3 and 4 (live fire only): gyro and accelerometer data.
    # Todo: This part is incomplete; it plots but does not provide much use. Will be expanded upon once live fire data is available.
    if not dry_fire:
        for ax, columns, major_step in ((ax3, gyro_columns, 50), (ax4, acc_columns, 5)):
            if not columns:
                continue
            # visible=True is explicit: grid() without it toggles visibility on every call.
            ax.grid(True, which='both')
            ax.yaxis.set_major_locator(MultipleLocator(major_step))
            ax.yaxis.set_minor_locator(AutoMinorLocator())
            ax.grid(which='major', axis='y', linestyle='-', color='k', alpha=plot_grid_line_alpha)
            ax.yaxis.set_tick_params(which='major', labelright=True, labelleft=True)
            for color_idx, column in enumerate(columns):
                ax.plot(df['Timestamp'], df[column], label=column, color=colors[color_idx], linewidth=plot_line_width)
            ax.legend(loc='upper right', bbox_to_anchor=(1, 1))

    # Y-labels in subplot order (ax1, ax2, ax3, ax4); zip stops after the 2 axes in dry fire.
    y_labels = ('Filtered Force (ADC)', 'Force (ADC)', 'Gyro ()', 'Acc ()')
    for ax, y_label in zip(axes, y_labels):
        plt.setp(ax.get_xticklabels(), rotation=90, ha='center')
        ax.set_xlabel('Time (HH:MM:SS)')
        ax.set_ylabel(y_label)

    plt.subplots_adjust(hspace=0.5)
    plt.tight_layout()
    fig.savefig(f'{trial_name}/{trial_name}.png', dpi=600)
    plt.show()
    # Release the figure so that looping over many trials does not keep every figure open
    # (pyplot retains figures until they are closed on backends that do not close them on show()).
    plt.close(fig)

    return peak_indices

def check_state(data:pd.Series, threshold:float, previous_state:int) -> int:
    """
    Classify a window of samples as rising (1), falling (-1) or flat (0), with hysteresis.

    Args:
        data (pd.Series): Window of numeric samples to classify.
        threshold (float): Positive magnitude a sample must reach to count as rising or falling.
        previous_state (int): State returned for the previous window (1, -1 or 0). It is only
            consulted when the window is mixed.

    Returns:
        int:
            1 if every sample is >= threshold.
            -1 if every sample is <= -threshold.
            0 if every sample lies strictly inside (-threshold, threshold).
            For a mixed window, the previous direction (1 or -1) is kept when at least one sample
            strictly exceeds the threshold in that direction; otherwise the result is 0.
            NaN samples satisfy none of the "every sample" tests. An empty window returns 1.
    """
    rising = data >= threshold
    falling = data <= -threshold
    flat = abs(data) < threshold

    if rising.all():
        return 1
    if falling.all():
        return -1
    if flat.all():
        return 0

    # Mixed window: hold the previous trend only while some sample still clearly supports it.
    if previous_state > 0:
        return 1 if (data > threshold).any() else 0
    if previous_state < 0:
        return -1 if (data < -threshold).any() else 0
    return 0

def peak_event_detection_consecutive(df:pd.DataFrame, threshold_derivative_filter:float, horizon:int, threshold_magnitude:int) -> pd.DataFrame:
    """
    Segment spike events in force channel 1 with a momentum-aware trend state machine.

    Steps:
      1. The slope (ADC per second) of 'F1_ADC_rolling' is computed from consecutive samples and
         their 'Timestamp' differences.
      2. Each window of `horizon` consecutive slope values is classified by check_state into rising (1),
         falling (-1) or flat (0). The window starting at position p stores its state at p, so the last
         horizon - 1 rows keep state 0. Once rising, the state stays 1 until a window is classified as
         falling.
      3. Every run of consecutive 1s whose maximum 'F1_ADC' is below `threshold_magnitude`, together with
         the run of -1s that immediately follows it, is relabelled 2 (rejected).
      4. Remaining 1/-1 samples are kept; each contiguous kept run is numbered from 1.

    Parameters:
        df : pd.DataFrame
            Must contain 'F1_ADC_rolling' (moving average of the raw force), 'F1_ADC' (raw force) and
            'Timestamp' (datetimes). It is modified in place and also returned.
        threshold_derivative_filter : float
            Slope magnitude a window must reach to leave the flat state.
        horizon : int
            Window length (in samples) for the state machine; expected to be >= 1.
        threshold_magnitude : int
            Minimum 'F1_ADC' peak a rising run needs in order to be kept.

    Returns:
        pd.DataFrame
            The same DataFrame with two new columns:
            'Masks': 1 (rising) or -1 (falling) for kept samples, NaN otherwise.
            'Spike_counter': 1-based number of the kept segment a sample belongs to, NaN otherwise.
            NOTE: a falling run not directly preceded by a rejected rising run is kept as well.
    """
    smoothed_force = df['F1_ADC_rolling']
    n_rows = len(df)

    # Slope of the smoothed F1 signal, in ADC per second
    seconds_between_samples = df['Timestamp'].diff().dt.total_seconds()
    slope = smoothed_force.diff() / seconds_between_samples

    # State machine over the windows [start, start + horizon)
    trend = pd.Series(0, index=smoothed_force.index)
    last_state = 0
    for start in range(0, n_rows - horizon + 1):
        state = check_state(
            data=slope.iloc[start:start + horizon],
            threshold=threshold_derivative_filter,
            previous_state=last_state,
        )
        # Momentum: after a rise, only an explicit falling classification ends it.
        if last_state > 0 and state >= 0:
            state = 1
        trend.iloc[start] = state
        last_state = state

    # Collect every run of consecutive rising samples before any relabelling happens.
    run_labels = (trend != trend.shift(1)).cumsum()
    rising_runs = [run.index for _, run in trend.groupby(run_labels) if (run == 1).all()]

    for run_index in rising_runs:
        # Removes spikes that are clearly not firearm trigger engagement events due to a low peak value
        if df.loc[run_index, 'F1_ADC'].max() < threshold_magnitude:
            trend.loc[run_index] = 2

            # Also reject the falling run that directly follows the rejected rise.
            position = trend.index.get_loc(run_index[-1]) + 1
            following_fall = []
            while position < len(trend) and trend.iloc[position] == -1:
                following_fall.append(trend.index[position])
                position += 1
            if following_fall:
                trend.loc[following_fall] = 2

    # Keep rising/falling samples, blank the rest, and number each contiguous kept run from 1.
    df['Masks'] = trend.where(trend.isin([1, -1]), np.nan)
    kept = df['Masks'].notna()
    kept_run_labels = (kept != kept.shift()).cumsum()[kept]
    df['Spike_counter'] = kept_run_labels.groupby(kept_run_labels).ngroup() + 1

    return df

In [None]:
if __name__ == "__main__":

    # Process recordings in (company, user) order; keys that do not parse sort last.
    sorted_keys = sorted(data_dict, key=extract_key_components)
    data_dict = {key: data_dict[key] for key in sorted_keys}

    for key, value in data_dict.items():
        base_name, task = get_trial_name(key)
        company_id, user_id = extract_key_components(key)

        # Only run the segmentor on the sensor data ('U' recordings); other tasks such as 'R' are skipped.
        if not task.startswith('U'):
            continue

        data_df = value.copy()
        # Timestamps look like 'HH:MM:SS:mmmZ'; turn the last ':' into '.' so the milliseconds parse.
        data_df['Timestamp'] = data_df['Timestamp'].str.replace(r'(\d{2}:\d{2}:\d{2}):(\d{3}Z)', r'\1.\2', regex=True)
        data_df['Timestamp'] = pd.to_datetime(data_df['Timestamp'])

        # Segmentation
        # The matching calibration recording is stored under '<company>_<user>_Trial'. It is not used
        # yet, but the lookup raises KeyError when the calibration file is missing.
        calibration_df = calibration_dict[f'{base_name}_Trial'].copy()
        # Todo: Currently hard-coded to channel 1 for dry fire, not sure if changes will be needed
        data_df['F1_ADC_rolling'] = data_df['F1_ADC'].rolling(window=rolling_window_width, center=True, min_periods=1).mean()
        filtered_df = peak_event_detection_consecutive(df=data_df, threshold_derivative_filter=event_derivative_threshold_ADC, horizon=horizon, threshold_magnitude=magnitude_threshold_ADC)

        os.makedirs(base_name, exist_ok=True)
        peak_indices = plot_trial(base_name, filtered_df, dry_fire=dry_fire, start=plot_start_time, end=plot_end_time)
        filtered_df.to_csv(f'{base_name}/{base_name}_filtered.csv', index=False)

        # Segmentation for annotation: when a single force is plotted and no annotation file exists
        # yet, write a template with one row per segmented spike (peak time and running counter).
        if len(forces_to_plot) == 1 and not os.path.isfile(f'{base_name}/{base_name}_annotation.csv'):
            segmentation_df_csv = pd.DataFrame({
                'Company': company_id,
                'User': user_id,
                'Time': data_df.loc[peak_indices, 'Timestamp'],
                'Counter': range(1, len(peak_indices) + 1),
                'Keep': np.nan,
                'Trigger': np.nan,
                'Grip': np.nan
            })

            segmentation_df_csv.to_csv(f'{base_name}/{base_name}.csv', index=False)