Failed to detect onset of EMG burst

In [1]:
# library
import os
import numpy as np
import scipy.io
import scipy.signal as signal
import matplotlib.pyplot as plt



In [2]:
def load_semg_data(data_folder):
    """
    Load sEMG data from the given folder into a layered dictionary.
    Structure: {subject_name: {gesture_name: semg_array, ...}, ...}

    Only sub-folders whose name starts with 'HS' are loaded (for now, only HS
    subjects). Subject folders and gesture files are visited in sorted order.
    os.listdir returns entries in arbitrary order, so sorting keeps the
    dictionary order reproducible, and with it any "first subject" taken from it.

    Each .mat file is expected to hold a single user variable; scipy's
    metadata entries (keys starting with '__', e.g. __header__, __version__,
    __globals__) are ignored. A file without any user variable is skipped
    with a printed message. Without this check it would reuse the array
    loaded from the previous file.

    Args:
        data_folder (str): The path to the Data folder containing subject folders.

    Returns:
        dict: A nested dictionary with the sEMG data.
    """
    semg_data = {}

    for subject in sorted(os.listdir(data_folder)):
        subject_path = os.path.join(data_folder, subject)
        if not (os.path.isdir(subject_path) and subject.startswith('HS')):  # for now, only HS subjects
            continue
        semg_data[subject] = {}

        # each .mat file (gesture)
        for filename in sorted(os.listdir(subject_path)):
            if not filename.endswith('.mat'):
                continue
            gesture_name = os.path.splitext(filename)[0]  # Remove .mat extension
            file_path = os.path.join(subject_path, filename)
            mat_contents = scipy.io.loadmat(file_path)

            # First non-metadata variable in the file; None when there is none.
            semg_array = next(
                (value for key, value in mat_contents.items() if not key.startswith('__')),
                None,
            )
            if semg_array is None:
                print(f"Skipping {subject} - {filename}: no data variable found.")
                continue

            # Store the sEMG data under the gesture name for the subject
            semg_data[subject][gesture_name] = semg_array

    return semg_data

# Example usage:
data_folder_path = 'Data'
structured_data = load_semg_data(data_folder_path)

# structured_data maps each subject folder name to a dictionary of
# gesture name -> sEMG array (expected to be 64 x N; the processing loop
# checks this layout and skips arrays that do not match).
print(structured_data.keys())  # list all subject folder names



dict_keys(['HS1', 'HS2', 'HS3', 'HS4', 'HS5', 'HS6', 'HS7', 'HS8'])


In [3]:
# ==========================
# Hyperparameter Definitions
# ==========================
fs = 2048                # Sampling frequency (Hz)
h = 5                    # Threshold multiplier (for baseline: T = mu + h * std)
bp_order = 4             # Order for the bandpass Butterworth filter
bp_low = 20.0            # Bandpass low cutoff (Hz)
bp_high = 450.0          # Bandpass high cutoff (Hz); must stay below fs / 2
notch_order = 2          # Not passed to any filter: scipy.signal.iirnotch always designs a 2nd-order notch
notch_freq = 50.0        # Notch filter frequency (Hz), mains interference
notch_Q = 50.0           # Notch filter Q factor (bandwidth = notch_freq / notch_Q)
lp_order = 3             # Order for the low-pass filter (for rectified signal)
lp_cutoff = 50.0         # Low-pass cutoff (Hz)
baseline_duration = 1.0  # Length (in seconds) of the baseline window used to compute each channel's threshold
window_length = 50       # Number of consecutive samples that must all exceed the threshold (50 / 2048 Hz ≈ 24 ms)
channel_percent = 0.05   # Fraction of channels required to be active (0.05 = 5%)

# ==========================
# Helper Functions
# ==========================

def bandpass_and_notch(sig, fs, low=None, high=None, order=None, notch_hz=None, notch_q=None):
    """
    Apply a zero-phase Butterworth bandpass filter, then a zero-phase IIR notch.

    Omitted arguments fall back to the hyperparameter globals at call time
    (bp_low, bp_high, bp_order, notch_freq, notch_Q). Editing those constants
    therefore takes effect without redefining this function. The notch is
    always second order (scipy.signal.iirnotch), so `notch_order` is not used.

    Args:
        sig (array-like): Signal to filter; filtering runs along the last axis.
        fs (float): Sampling frequency in Hz.
        low (float, optional): Bandpass low cutoff in Hz.
        high (float, optional): Bandpass high cutoff in Hz.
        order (int, optional): Butterworth order.
        notch_hz (float, optional): Notch centre frequency in Hz.
        notch_q (float, optional): Notch quality factor.

    Returns:
        np.ndarray: Filtered signal with the same shape as `sig`.

    Raises:
        ValueError: If the cutoffs do not satisfy 0 < low < high < Nyquist,
            or the notch frequency is not strictly between 0 and Nyquist.
    """
    low = bp_low if low is None else low
    high = bp_high if high is None else high
    order = bp_order if order is None else order
    notch_hz = notch_freq if notch_hz is None else notch_hz
    notch_q = notch_Q if notch_q is None else notch_q

    nyq = 0.5 * fs
    # Validate up front so a mismatched fs gives a readable message instead of
    # scipy's generic "critical frequencies must be 0 < Wn < 1" error.
    if not 0 < low < high < nyq:
        raise ValueError(
            f"Bandpass cutoffs must satisfy 0 < low < high < Nyquist ({nyq} Hz); "
            f"got low={low}, high={high}."
        )
    if not 0 < notch_hz < nyq:
        raise ValueError(f"Notch frequency must be in (0, {nyq}) Hz; got {notch_hz}.")

    # filtfilt runs each filter forward and backward, giving zero phase shift.
    b_bp, a_bp = signal.butter(order, [low / nyq, high / nyq], btype='bandpass')
    filtered = signal.filtfilt(b_bp, a_bp, sig)

    # Notch filter design (normalized to Nyquist like the bandpass above)
    b_notch, a_notch = signal.iirnotch(notch_hz / nyq, notch_q)
    return signal.filtfilt(b_notch, a_notch, filtered)

def apply_TKEO(sig):
    """
    Teager–Kaiser Energy Operator for a 1D signal.

    psi[n] = x[n]**2 - x[n-1] * x[n+1]. Samples outside the signal are taken
    to be 0, so for finite input the end points reduce to x[0]**2 and x[-1]**2.
    The output has the same length as the input.
    """
    # Zero-extend by one sample on each side so every sample has two neighbours.
    extended = np.concatenate(([0], sig, [0]))
    current = extended[1:-1]
    previous = extended[:-2]
    following = extended[2:]
    return current * current - previous * following

def lowpass_filter(sig, fs, cutoff=lp_cutoff, order=lp_order):
    """
    Zero-phase Butterworth low-pass filter.

    Args:
        sig: Signal to smooth (filtered along its last axis by filtfilt).
        fs: Sampling frequency in Hz.
        cutoff: Cutoff frequency in Hz (default: lp_cutoff, bound when this
            function is defined).
        order: Filter order (default: lp_order, bound when this function is
            defined).

    Returns:
        The filtered signal, same shape as `sig`.
    """
    wn = cutoff / (fs / 2)  # cutoff as a fraction of the Nyquist frequency
    numerator, denominator = signal.butter(order, wn, btype='low')
    return signal.filtfilt(numerator, denominator, sig)

def process_channel(sig, fs, h, baseline_slice=None):
    """
    Process a single channel signal and compute its activation threshold.

    Steps:
      1. Filtering with bandpass and notch filters.
      2. Applying TKEO.
      3. Full-wave rectification.
      4. Low-pass filtering.
      5. Computing threshold T = mean + h * std over a baseline segment.

    Default baseline segment: let B = int(baseline_duration * fs) samples.
    When the processed signal has at least 3 * B samples, the baseline is
    processed[-2*B:-B]. That is the second-to-last `baseline_duration` seconds
    of the recording, NOT the start. For shorter recordings the whole
    processed signal is used. Pass `baseline_slice` (e.g. ``slice(0, 2 * fs)``
    for the first 2 s) to choose the baseline explicitly.

    Args:
        sig (np.ndarray): Raw signal for one channel.
        fs (float): Sampling frequency in Hz.
        h (float): Threshold multiplier applied to the baseline std.
        baseline_slice (slice, optional): Explicit range of the processed
            signal to use as baseline. Defaults to the rule above.

    Returns:
        tuple: (processed signal as np.ndarray, threshold T).

    Raises:
        ValueError: If the selected baseline segment is empty. Without this
            check the threshold would silently become NaN and every sample
            would be classified inactive.
    """
    filtered = bandpass_and_notch(sig, fs)
    tkeo = apply_TKEO(filtered)
    rectified = np.abs(tkeo)
    processed = lowpass_filter(rectified, fs)

    if baseline_slice is not None:
        baseline = processed[baseline_slice]
    else:
        baseline_samples = int(baseline_duration * fs)
        # NOTE(review): the default baseline comes from near the END of the
        # recording. Confirm the protocol ends with rest, otherwise thresholds
        # are inflated and onsets are missed.
        if len(processed) >= 3 * baseline_samples:
            baseline = processed[-2*baseline_samples:-baseline_samples]
        else:
            baseline = processed

    if baseline.size == 0:
        raise ValueError(
            "Baseline segment is empty; check baseline_slice / baseline_duration."
        )
    mu = np.mean(baseline)
    std = np.std(baseline)
    T = mu + h * std
    return processed, T

def aggregated_activity(processed_data, thresholds, window_length=window_length, channel_percent=channel_percent):
    """
    Compute an aggregated binary time sequence for a gesture.

    A channel counts as active at time index i when all `window_length`
    samples starting at i are strictly above that channel's threshold. Windows
    that run past the end of the recording are padded by repeating each
    channel's last sample. Index i is marked active (1) when at least
    ceil(channel_percent * num_channels) channels are active, otherwise
    inactive (0).

    The sliding-window test uses a cumulative sum over the above-threshold
    mask, which is O(num_channels * N). This replaces re-slicing and
    re-padding a window for every sample in a Python loop; the result is the
    same.

    Args:
        processed_data (np.ndarray): 2D array of shape (num_channels, N).
        thresholds (array-like): Per-channel thresholds, one per row of
            processed_data.
        window_length (int): Number of consecutive samples that must exceed
            the threshold; must be >= 1.
        channel_percent (float): Fraction of channels that must be active.

    Returns:
        np.ndarray: Integer array of length N containing 0/1.

    Raises:
        ValueError: If window_length < 1.
    """
    processed_data = np.asarray(processed_data)
    num_channels, N = processed_data.shape
    required_channels = int(np.ceil(channel_percent * num_channels))
    if window_length < 1:
        raise ValueError(f"window_length must be >= 1, got {window_length}")
    if N == 0:
        return np.zeros(0, dtype=int)

    # Mask of samples strictly above their channel's threshold, shape (num_channels, N).
    above = processed_data > np.asarray(thresholds)[:, None]

    # Append window_length - 1 copies of the last column, so windows near the
    # end see the repeated last sample (same padding rule as before).
    if window_length > 1:
        tail = np.repeat(above[:, -1:], window_length - 1, axis=1)
        above = np.concatenate((above, tail), axis=1)

    # csum[:, k] = number of above-threshold samples in columns [0, k).
    csum = np.cumsum(above, axis=1, dtype=np.int64)
    csum = np.concatenate((np.zeros((above.shape[0], 1), dtype=np.int64), csum), axis=1)
    # counts[:, i] = number of above-threshold samples in [i, i + window_length).
    counts = csum[:, window_length:window_length + N] - csum[:, :N]

    active_channels = np.count_nonzero(counts == window_length, axis=0)
    return (active_channels >= required_channels).astype(int)

def consecutive_segments(binary_signal):
    """
    Run-length encode a sequence.

    Args:
        binary_signal: Indexable sequence (list or 1D array), typically of 0/1 values.

    Returns:
        A list of (run_length, value) tuples in order of appearance, where
        `value` is the first element of each run. Empty input gives [].
    """
    total = len(binary_signal)
    if total == 0:
        return []

    runs = []
    run_start = 0
    for idx in range(1, total):
        # A new run begins as soon as an element differs from the run's first element.
        if not binary_signal[idx] == binary_signal[run_start]:
            runs.append((idx - run_start, binary_signal[run_start]))
            run_start = idx
    runs.append((total - run_start, binary_signal[run_start]))
    return runs

# ==========================
# Main Processing Loop
# ==========================
# Assume structured_data is already loaded with the format:
# { subject_name: { gesture_name: sEMG_array (64 x N), ... }, ... }
processed_data = {}  # new dictionary to hold aggregated binary sequences

# For now only one subject is processed, to keep iteration fast. Sorting the
# keys makes the choice of "first" subject independent of filesystem order.
if not structured_data:
    raise RuntimeError("structured_data is empty: no HS subject folders were loaded.")
first_subject_key = sorted(structured_data)[0]
truncated_structured_data = {first_subject_key: structured_data[first_subject_key]}
for subject, gestures in truncated_structured_data.items():
    print(f"Processing subject: {subject}...")
    processed_data[subject] = {}

    for gesture, data in gestures.items():
        # Expect each gesture to be a 2D NumPy array with shape (64, N)
        if isinstance(data, np.ndarray) and data.ndim == 2 and data.shape[0] == 64:
            num_channels = data.shape[0]
            proc_channels = np.zeros_like(data, dtype=float)  # processed signal for each channel
            thresholds = np.zeros(num_channels)

            # Process each channel independently: filter -> TKEO -> rectify ->
            # smooth, plus a per-channel threshold from its baseline segment.
            for ch in range(num_channels):
                proc_channels[ch, :], thresholds[ch] = process_channel(data[ch, :], fs, h)

            # Aggregate binary activity: a time index is active when at least
            # `channel_percent` of the channels stay above their own threshold
            # for `window_length` consecutive samples starting at that index.
            agg_binary = aggregated_activity(proc_channels, thresholds, window_length, channel_percent)
            processed_data[subject][gesture] = agg_binary
        else:
            print(f"Skipping {subject} - {gesture}: Data format unexpected.")

print("All subjects processed.")
            

Processing subject: HS1...
All subjects processed.


In [None]:
# Print the run-length encoded activity of one subject, e.g. 'HS1' (change as needed).
subject_to_debug = 'HS1'  # Adjust to one of the keys in processed_data
min_segment_warn = None   # Set to e.g. 500 to flag segments shorter than this many samples

if subject_to_debug in processed_data:
    print(f"Debug information for subject: {subject_to_debug}")
    for gesture, bin_seq in processed_data[subject_to_debug].items():
        segs = consecutive_segments(bin_seq)
        # Build a string of segments: e.g., "3000 (0) 3000 (1) 3000 (0) ..."
        seg_str = " ".join([f"{length} ({val})" for (length, val) in segs])
        print(f"Gesture {gesture}: [{seg_str}]")
        # Optional: flag very short segments, which usually indicate chattering
        # around the threshold rather than real onsets/offsets.
        if min_segment_warn is not None:
            for length, val in segs:
                if length < min_segment_warn:
                    print(f"  Warning: In gesture '{gesture}', segment of value {val} with length {length} (<{min_segment_warn})")
else:
    print(f"Subject {subject_to_debug} not found in processed_data.")

print("Processing complete. 'processed_data' now holds the aggregated binary sequences.")




Debug information for subject: HS1
Gesture closehand: [201 (0) 13 (1) 13552 (0) 19 (1) 56 (0) 6827 (1) 18 (0) 10 (1) 81 (0) 45 (1) 5411 (0) 88 (1) 9 (0) 11 (1) 60 (0) 99 (1) 55 (0) 5787 (1) 48 (0) 259 (1) 7 (0) 9 (1) 1 (0) 317 (1) 62 (0) 47 (1) 67 (0) 40 (1) 8 (0) 17 (1) 215 (0) 17 (1) 5186 (0) 176 (1) 13 (0) 5736 (1) 65 (0) 94 (1) 14 (0) 407 (1) 240 (0) 38 (1) 97 (0) 23 (1) 77 (0) 21 (1) 5022 (0) 182 (1) 69 (0) 5928 (1) 61 (0) 155 (1) 6 (0) 143 (1) 22 (0) 31 (1) 74 (0) 39 (1) 43 (0) 4 (1) 362 (0) 18 (1) 5275 (0) 1 (1) 5 (0) 31 (1) 8 (0) 71 (1) 85 (0) 6113 (1) 39 (0) 14 (1) 25 (0) 56 (1) 1 (0) 1 (1) 95 (0) 71 (1) 64 (0) 21 (1) 82 (0) 112 (1) 73 (0) 65 (1) 56 (0) 97 (1) 64 (0) 10 (1) 4942 (0) 5 (1) 18 (0) 129 (1) 69 (0) 5381 (1) 24 (0) 384 (1) 118 (0) 39 (1) 200 (0) 128 (1) 26 (0) 264 (1) 55 (0) 70 (1) 64 (0) 115 (1) 67 (0) 52 (1) 110 (0) 48 (1) 93 (0) 1 (1) 7 (0) 30 (1) 151 (0) 7 (1) 1 (0) 5 (1) 2 (0) 2 (1) 273 (0) 64 (1) 4122 (0) 96 (1) 8 (0) 31 (1) 17 (0) 6518 (1) 47 (0) 262 (1) 64 (