# Real-Time EMG to Angle Prediction with TensorFlow and LSL

This notebook demonstrates a more modular and robust approach to:
- Collecting real-time EMG (EXG) and angle (FingerPercentages) data from LSL streams.
- Applying online standardization to the EMG signals.
- Using a configurable CNN model to predict angles from recent EMG data.

**Features:**
- Adjustable EMG window size for the CNN input.
- More modular code structure with classes and separate functions.
- Graceful handling of missing or invalid data.
- Training on the most recent data available.

**Notes:**
- This example still assumes that LSL streams `filtered_exg` and `FingerPercentages` are available.
- Ensure you have the required dependencies installed: `pylsl`, `tensorflow`, `numpy`.
- Running this in real-time requires live LSL streams.

In [1]:
import time
import numpy as np
import threading
import tensorflow as tf
from pylsl import StreamInlet, resolve_stream, proc_ALL
from collections import deque

## Online Standardizer
This class maintains a running mean and std using Welford's algorithm, allowing for online standardization of incoming data.


In [2]:
class OnlineStandardizer:
    """
    Maintains a running per-feature mean and standard deviation.

    Statistics are accumulated one batch at a time with the batched form of
    Welford's algorithm, so past samples never need to be stored. Features are
    taken along the trailing axes; the first axis of every batch is the
    sample axis.
    """
    def __init__(self, epsilon=1e-7):
        """
        Args:
            epsilon: Lower bound applied to the standard deviation so that
                constant features do not cause a division by zero.
        """
        self.count = 0      # number of samples seen so far
        self.mean = 0.0     # running mean (becomes an array after the first update)
        self.M2 = 0.0       # running sum of squared deviations from the mean
        self.epsilon = epsilon

    def update(self, x: np.ndarray):
        """
        Fold a batch of samples into the running statistics.

        Args:
            x: Array whose first axis indexes samples. An empty batch is a no-op.
        """
        batch_count = x.shape[0]
        if batch_count == 0:
            return

        new_count = self.count + batch_count

        # Deviations from the mean *before* and *after* this batch; their
        # product summed over the batch is the exact increment of M2.
        delta = x - self.mean
        # Rebind instead of using `+=`: an in-place update would silently
        # modify arrays previously returned by get_stats().
        self.mean = self.mean + np.sum(delta, axis=0) / new_count
        delta2 = x - self.mean
        self.M2 = self.M2 + np.sum(delta * delta2, axis=0)
        self.count = new_count

    def get_stats(self):
        """
        Return the current `(mean, std)` pair.

        With fewer than two samples the variance is undefined, so a unit
        standard deviation is returned. Otherwise the sample standard
        deviation (ddof=1) is returned, floored at `epsilon`.
        """
        if self.count < 2:
            return self.mean, np.ones_like(self.mean)
        var = self.M2 / (self.count - 1)
        # np.maximum works for NumPy scalars (1-D input) as well as arrays,
        # unlike boolean-mask assignment.
        std = np.maximum(np.sqrt(var), self.epsilon)
        return self.mean, std

    def apply(self, x: np.ndarray):
        """Standardize `x` with the current statistics and return the result."""
        mean, std = self.get_stats()
        return (x - mean) / std

## EMG-Angle Collector
This class:
- Connects to specified LSL streams for EMG and angles.
- Continuously collects data in a background thread.
- Maintains a buffer of recent data.
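- Updates the running standardization statistics (mean/std) from incoming EMG samples; the stored samples themselves are left unscaled.
- Provides `get_data()` to drain everything collected since the previous call, and `put_numpy_buffers()` to append that data to `NumpyBuffer` instances (a small growable array buffer defined in the same cell).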

In [3]:
class NumpyBuffer:
    """
    Growable, append-only store of timestamped samples backed by NumPy arrays.

    Samples live in `buffer` (shape `(capacity, *shape)`) and their timestamps
    in `timestamps` (shape `(capacity,)`); only the first `size` entries of
    each are valid.
    """
    def __init__(self, capacity, shape, dtype):
        """
        Args:
            capacity: Number of samples to pre-allocate room for.
            shape: Shape of a single sample (e.g. `(num_channels,)`).
            dtype: NumPy dtype of the sample storage.
        """
        self.capacity = capacity
        self.shape = shape
        self.dtype = dtype
        self.buffer = np.zeros((capacity, *shape), dtype=dtype)
        self.timestamps = np.zeros(capacity, dtype=np.float64)
        self.size = 0

    def __len__(self):
        """Return the number of valid samples."""
        return self.size

    def __getitem__(self, idx):
        """
        Return `(timestamp(s), sample(s))` for an integer index or a slice.

        Indexing follows normal sequence semantics over the valid entries.

        Raises:
            IndexError: If an integer index is out of range.
            TypeError: If `idx` is neither an integer nor a slice.
        """
        if isinstance(idx, slice):
            # Slice the valid region directly so empty results, open-ended
            # bounds and negative steps all behave as they do for a list.
            return self.timestamps[:self.size][idx], self.buffer[:self.size][idx]
        if isinstance(idx, (int, np.integer)):
            pos = idx + self.size if idx < 0 else idx
            if not 0 <= pos < self.size:
                raise IndexError("Index out of bounds.")
            return self.timestamps[pos], self.buffer[pos]
        raise TypeError("Invalid index type.")

    def expand(self, min_capacity=None):
        """
        Grow the storage by doubling, preserving the valid entries.

        Args:
            min_capacity: If given, keep doubling until at least this many
                samples fit. Without it the capacity is doubled once.
        """
        new_capacity = max(self.capacity * 2, 1)  # max() lets a 0-capacity buffer grow
        if min_capacity is not None:
            while new_capacity < min_capacity:
                new_capacity *= 2

        new_buffer = np.zeros((new_capacity, *self.shape), dtype=self.dtype)
        new_timestamps = np.zeros(new_capacity, dtype=np.float64)
        # Copy only the valid prefix; the old arrays are `capacity` long,
        # which generally differs from `size`.
        new_buffer[:self.size] = self.buffer[:self.size]
        new_timestamps[:self.size] = self.timestamps[:self.size]
        self.buffer = new_buffer
        self.timestamps = new_timestamps
        self.capacity = new_capacity

    def extend(self, timestamps, samples):
        """
        Append a chunk of samples and their timestamps.

        Args:
            timestamps: Sequence of timestamps, one per sample.
            samples: Sequence/array of samples, each of shape `self.shape`.

        Raises:
            ValueError: If the two sequences differ in length.
        """
        if len(timestamps) != len(samples):
            raise ValueError("Timestamps and samples must have the same length.")

        new_size = len(samples)
        if new_size == 0:
            # Nothing to add; also avoids broadcasting a shapeless empty array.
            return

        required = self.size + new_size
        if required > self.capacity:
            self.expand(required)

        self.buffer[self.size:required] = samples
        self.timestamps[self.size:required] = timestamps
        self.size = required

    def get_numpy_buffers(self):
        """Return views `(timestamps, samples)` over the valid entries."""
        return self.timestamps[:self.size], self.buffer[:self.size]

    def clear(self):
        """Forget all stored samples (the allocated storage is kept)."""
        self.size = 0

class EMGAngleCollector:
    """
    Collects EMG and angle samples from two LSL streams on a background thread.

    Pulled samples are kept (unscaled) in bounded deques until drained with
    `get_data()`. Every accepted EMG chunk also updates `self.standardizer`,
    so running mean/std statistics are available to consumers.
    """
    def __init__(self, emg_stream_name="filtered_exg", angle_stream_name="FingerPercentages",
                 buffer_seconds=5, finger_thresholds = ((0.8, 0.9), (0.7, 0.875), (0.75, 0.875), (0.7, 0.8), (0.7, 0.825))):
        """
        Resolve both streams, read their metadata and start the collection thread.

        Args:
            emg_stream_name: LSL stream name of the EMG (EXG) source.
            angle_stream_name: LSL stream name of the angle source.
            buffer_seconds: Seconds of EMG data (at the nominal rate) that the
                internal deques can hold before the oldest samples are dropped.
            finger_thresholds: Per-finger value pairs, stored as an array on
                `self.finger_thresholds` (not otherwise used by this class).

        Raises:
            RuntimeError: If a stream cannot be resolved.
            ValueError: If a stream reports a non-positive sampling rate or
                channel count.
        """
        print("Resolving EMG (EXG) stream...")
        emg_streams = resolve_stream('name', emg_stream_name)
        if not emg_streams:
            raise RuntimeError(f"No EMG stream found with name '{emg_stream_name}'.")

        print("Resolving angle (MP) stream...")
        angle_streams = resolve_stream('name', angle_stream_name)
        if not angle_streams:
            raise RuntimeError(f"No angle stream found with name '{angle_stream_name}'.")

        self.emg_inlet = StreamInlet(emg_streams[0], processing_flags=proc_ALL, max_buflen=2)
        self.angle_inlet = StreamInlet(angle_streams[0], processing_flags=proc_ALL, max_buflen=2)
        self.finger_thresholds = np.array(finger_thresholds)

        # Get EMG stream info
        emg_info = self.emg_inlet.info()
        self.emg_rate = emg_info.nominal_srate()
        if self.emg_rate <= 0:
            raise ValueError("EMG stream sampling rate is not set or zero.")
        self.num_exg_channels = emg_info.channel_count()
        if self.num_exg_channels <= 0:
            raise ValueError("EMG stream channel_count is not set or zero.")

        # Get angle stream info
        angle_info = self.angle_inlet.info()
        self.angle_rate = angle_info.nominal_srate()
        if self.angle_rate <= 0:
            raise ValueError("Angle stream sampling rate is not set or zero.")
        self.num_mp_channels = angle_info.channel_count()
        if self.num_mp_channels <= 0:
            raise ValueError("Angle stream channel_count is not set or zero.")

        # Both streams share the same sample budget, sized from the EMG rate.
        self.capacity = int(buffer_seconds * self.emg_rate)

        # Deques to store EMG and angle data
        self.emg_deque = deque(maxlen=self.capacity)
        self.emg_timestamp_deque = deque(maxlen=self.capacity)
        self.angle_deque = deque(maxlen=self.capacity)
        self.angle_timestamp_deque = deque(maxlen=self.capacity)

        self.lock = threading.Lock()
        self.stop_flag = False
        self.standardizer = OnlineStandardizer()

        self.update_fps = 50
        self.update_interval = 1 / self.update_fps

        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def run(self):
        """Thread body: call `update()` about `update_fps` times per second until stopped."""
        while not self.stop_flag:
            cycle_start = time.perf_counter()
            self.update()
            # Sleep only for what is left of this cycle, measured from the
            # start of the cycle so the sleep itself is never counted as work.
            remaining = self.update_interval - (time.perf_counter() - cycle_start)
            if remaining > 0:
                time.sleep(remaining)

    @staticmethod
    def _drop_nan_rows(samples, timestamps):
        """Convert a pulled chunk to arrays and drop every row that contains a NaN."""
        data = np.asarray(samples, dtype=np.float32)
        stamps = np.asarray(timestamps, dtype=np.float64)
        if data.shape[0] == 0:
            return data, stamps
        keep = ~np.isnan(data).any(axis=1)
        return data[keep], stamps[keep]

    @staticmethod
    def _stack_rows(rows, num_channels):
        """Stack buffered rows into a 2-D array, keeping the channel axis even when empty."""
        if not rows:
            return np.empty((0, num_channels), dtype=np.float32)
        return np.array(rows)

    def update(self):
        """
        Pull whatever is pending on both inlets and store the NaN-free rows.

        The two streams are handled independently: a chunk from one stream is
        kept even when the other stream delivered nothing in this cycle.
        """
        angle_chunk, angle_stamps = self.angle_inlet.pull_chunk()
        emg_chunk, emg_stamps = self.emg_inlet.pull_chunk()

        emg_data, emg_timestamps = self._drop_nan_rows(emg_chunk, emg_stamps)
        angle_data, angle_timestamps = self._drop_nan_rows(angle_chunk, angle_stamps)

        if len(emg_data) == 0 and len(angle_data) == 0:
            return

        with self.lock:
            if len(emg_data) != 0:
                self.emg_deque.extend(emg_data)
                self.emg_timestamp_deque.extend(emg_timestamps)
                self.standardizer.update(emg_data)
            if len(angle_data) != 0:
                self.angle_deque.extend(angle_data)
                self.angle_timestamp_deque.extend(angle_timestamps)

    def get_data(self):
        """
        Drain and return everything collected since the previous call.

        Returns:
            `((emg_timestamps, emg), (angle_timestamps, angles))` as NumPy
            arrays. The sample arrays are 2-D `(n, channels)`, including when
            `n` is 0.
        """
        with self.lock:
            emg = self._stack_rows(self.emg_deque, self.num_exg_channels)
            emg_timestamp = np.array(self.emg_timestamp_deque)
            angle_deque_np = self._stack_rows(self.angle_deque, self.num_mp_channels)
            angle_timestamp = np.array(self.angle_timestamp_deque)
            self.emg_deque.clear()
            self.emg_timestamp_deque.clear()
            self.angle_deque.clear()
            self.angle_timestamp_deque.clear()
        return (emg_timestamp, emg), (angle_timestamp, angle_deque_np)

    def put_numpy_buffers(self, emg_buffer: "NumpyBuffer", angle_buffer: "NumpyBuffer"):
        """Drain the collected data and append it to the given buffers via their `extend()`."""
        timestamped_emg, timestamped_angle = self.get_data()
        emg_buffer.extend(timestamped_emg[0], timestamped_emg[1])
        angle_buffer.extend(timestamped_angle[0], timestamped_angle[1])

    def stop(self):
        """Signal the collection thread to finish and wait for it to exit."""
        self.stop_flag = True
        self.thread.join()

## Data Generator

This generator yields batches of data to the model for training. It:
- Continuously calls `collector.get_data()` to fetch the samples collected since the previous call.
- Checks for NaNs or Infs.
- Adds a batch dimension.


In [4]:
def _keep_latest(buffer, max_samples):
    """Trim a NumpyBuffer in place so it holds only its newest `max_samples` entries."""
    if len(buffer) <= max_samples:
        return
    stamps, data = buffer.get_numpy_buffers()
    # Copy first: the views alias the storage that is about to be rewritten.
    stamps, data = stamps[-max_samples:].copy(), data[-max_samples:].copy()
    buffer.clear()
    buffer.extend(stamps, data)


def _sample_windows(emg_ts, emg, angle_ts, angles, window_size, batch_size):
    """
    Draw `batch_size` random EMG windows together with their angle labels.

    Each window is `window_size` consecutive EMG samples; its label is the
    latest angle sample whose timestamp is not after the window's last EMG
    sample. Windows are drawn uniformly with replacement.

    Returns:
        `(X, y)` with `X` of shape `(batch_size, window_size, emg_channels)` and
        `y` of shape `(batch_size, angle_channels)`, or `None` if no window can
        be labelled yet.
    """
    if len(emg) < window_size or len(angles) == 0:
        return None

    # Exclusive end index of every complete window.
    ends = np.arange(window_size, len(emg) + 1)
    # NOTE(review): searchsorted assumes angle timestamps are non-decreasing — confirm
    # this holds for the stream's timestamp post-processing.
    label_idx = np.searchsorted(angle_ts, emg_ts[ends - 1], side='right') - 1
    usable = label_idx >= 0  # windows that end before the first angle sample have no label
    if not usable.any():
        return None

    picks = np.random.randint(0, int(usable.sum()), size=batch_size)
    ends = ends[usable][picks]
    label_idx = label_idx[usable][picks]

    # Row indices of every sample of every chosen window: (batch_size, window_size).
    rows = ends[:, None] - window_size + np.arange(window_size)[None, :]
    return emg[rows], angles[label_idx]


def data_generator(collector: EMGAngleCollector, window_size: int, exg_buffer: NumpyBuffer = None,
                   mp_buffer: NumpyBuffer = None, batch_size: int = 512):
    """
    Endlessly yield training batches built from the most recent collected data.

    On every iteration the generator drains the collector into the two
    buffers, trims them to the newest samples, draws random windows,
    standardizes the EMG with the collector's running statistics, and skips
    any batch containing NaN/Inf. While there is not enough data it waits for
    one collector update interval and retries.

    Args:
        collector: Running collector to drain (`get_data()`), which also
            provides `standardizer`, `lock`, `capacity` and channel counts.
        window_size: Number of consecutive EMG samples per example.
        exg_buffer: Buffer accumulating EMG samples; created internally when
            omitted. It is trimmed in place to the newest
            `max(collector.capacity, window_size)` samples.
        mp_buffer: Buffer accumulating angle samples; same handling as
            `exg_buffer`.
        batch_size: Number of examples per yielded batch.

    Yields:
        `(X_batch, y_batch)` as float32 arrays of shape
        `(batch_size, window_size, num_exg_channels)` and
        `(batch_size, num_mp_channels)`.
    """
    max_samples = max(collector.capacity, window_size)
    # Twice the retained length leaves room for one full drain between trims.
    if exg_buffer is None:
        exg_buffer = NumpyBuffer(2 * max_samples, (collector.num_exg_channels,), np.float32)
    if mp_buffer is None:
        mp_buffer = NumpyBuffer(2 * max_samples, (collector.num_mp_channels,), np.float32)

    while True:
        (emg_ts, emg), (angle_ts, angles) = collector.get_data()
        if len(emg):
            exg_buffer.extend(emg_ts, emg)
        if len(angles):
            mp_buffer.extend(angle_ts, angles)
        _keep_latest(exg_buffer, max_samples)
        _keep_latest(mp_buffer, max_samples)

        buffered_emg_ts, buffered_emg = exg_buffer.get_numpy_buffers()
        buffered_angle_ts, buffered_angles = mp_buffer.get_numpy_buffers()
        batch = _sample_windows(buffered_emg_ts, buffered_emg, buffered_angle_ts,
                                buffered_angles, window_size, batch_size)

        if batch is not None:
            X_batch, y_batch = batch
            # The collector thread updates the statistics under this lock.
            with collector.lock:
                X_batch = collector.standardizer.apply(X_batch)
            X_batch = np.asarray(X_batch, dtype=np.float32)  # (batch_size, window_size, num_exg_channels)
            y_batch = np.asarray(y_batch, dtype=np.float32)  # (batch_size, num_mp_channels)
            if not (np.isfinite(X_batch).all() and np.isfinite(y_batch).all()):
                batch = None

        if batch is None:
            # Not enough (or unusable) data yet; give the collector time to pull more.
            time.sleep(collector.update_interval)
            continue

        # print the distribution of the target
        print('Mean', np.mean(y_batch, axis=0))
        print('STD', np.std(y_batch, axis=0))

        yield X_batch, y_batch

## Model Creation

A simple CNN model for sequence data. The kernel size and number of filters can be adjusted.

We make the model creation more modular to easily tweak hyperparameters.


In [5]:
def create_model(sequence_length, num_exg_channels, num_mp_channels, filters=32, kernel_size=3):
    """
    Build and compile the 1-D CNN that maps an EMG window to angle values.

    The network applies Gaussian input noise, four Conv1D + MaxPooling1D
    stages (the last two with twice as many filters), then a small dense head
    with one linear output per angle channel. It is compiled with the Adam
    optimizer and mean-squared-error loss.

    Args:
        sequence_length: Number of EMG samples per input window.
        num_exg_channels: Number of EMG channels per sample.
        num_mp_channels: Number of angle values to predict.
        filters: Filter count of the first two convolution stages.
        kernel_size: Kernel width used by every convolution.

    Returns:
        The compiled `tf.keras.Sequential` model.
    """
    layers = tf.keras.layers

    stack = [
        layers.Input(shape=(sequence_length, num_exg_channels)),
        layers.GaussianNoise(0.05),
    ]

    # Four conv/pool stages; each pooling step halves the time axis.
    for stage_filters in (filters, filters, filters * 2, filters * 2):
        stack.append(layers.Conv1D(filters=stage_filters, kernel_size=kernel_size,
                                   padding='same', activation='relu'))
        stack.append(layers.MaxPooling1D(pool_size=2))

    stack.append(layers.Flatten())
    stack.append(layers.Dense(32, activation='relu'))
    # Output one value per MP channel
    stack.append(layers.Dense(num_mp_channels))

    network = tf.keras.Sequential(stack)
    network.compile(optimizer='adam', loss='mse')
    return network

## Training Loop

Here we:
- Initialize the collector with chosen parameters.
- Set the sequence length (`WINDOW_SIZE`) and the other hyperparameters through the user-configurable constants at the top of the cell.
- Create a `tf.data.Dataset` from the data generator.
- Train the model.

In [7]:
# User-configurable parameters
EMG_STREAM_NAME = "filtered_exg"
ANGLE_STREAM_NAME = "FingerPercentages"
BUFFER_SECONDS = 30
WINDOW_SIZE = 50  # Number of EMG samples per example
BATCH_SIZE = 512      # Examples per training step
STEPS_PER_EPOCH = 10
EPOCHS = 1000
FILTERS = 32
KERNEL_SIZE = 3

# Initialize collector (this also starts its background collection thread)
collector = EMGAngleCollector(
    emg_stream_name=EMG_STREAM_NAME,
    angle_stream_name=ANGLE_STREAM_NAME,
    buffer_seconds=BUFFER_SECONDS,
)

# Everything after the collector exists runs inside try/finally so that the
# background thread is stopped even if dataset or model construction fails;
# otherwise it would keep running in the notebook kernel.
try:
    num_exg_channels = collector.num_exg_channels
    num_mp_channels = collector.num_mp_channels

    # Create dataset
    # NOTE(review): this call requires data_generator to accept a `batch_size`
    # keyword and to work without explicit buffers — confirm against its signature.
    dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(collector, window_size=WINDOW_SIZE, batch_size=BATCH_SIZE),
        output_signature=(
            tf.TensorSpec(shape=(BATCH_SIZE, WINDOW_SIZE, num_exg_channels), dtype=tf.float32),
            tf.TensorSpec(shape=(BATCH_SIZE, num_mp_channels), dtype=tf.float32)
        )
    ).prefetch(1)

    # Create model
    model = create_model(WINDOW_SIZE, num_exg_channels, num_mp_channels, filters=FILTERS, kernel_size=KERNEL_SIZE)

    print("Starting training loop. Press Interrupt to stop.")
    model.fit(dataset, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)
except KeyboardInterrupt:
    print("Training interrupted by user.")
finally:
    collector.stop()
    print("Stopped data collection and training.")

Resolving EMG (EXG) stream...
Resolving angle (MP) stream...
Starting training loop. Press Interrupt to stop.
Epoch 1/1000
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 1/10[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m11s[0m 1s/step - loss: 1593.1334(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 4/10[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m0s[0m 31ms/step - loss: 985.6958(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 6/10[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m0s[0m 38ms/step - loss: 815.3760(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 7/10[0m [32m━━━━━━━━━━━━━━[0m[37m━━━━━━[0m [1m0s[0m 40ms/step - loss: 759.0032(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0.

2024-12-12 21:23:32.294 (  95.223s) [R_FingerPercen  ]      data_receiver.cpp:344    ERR| Stream transmission broke off (Input stream error.); re-connecting...
2024-12-12 21:23:33.971 (  96.900s) [R_filtered_exg  ]      data_receiver.cpp:344    ERR| Stream transmission broke off (Input stream error.); re-connecting...
