# Real-Time EMG to Angle Prediction with TensorFlow and LSL

This notebook demonstrates a more modular and robust approach to:
- Collecting real-time EMG (EXG) and angle (FingerPercentages) data from LSL streams.
- Applying online standardization to the EMG signals.
- Using a configurable CNN model to predict angles from recent EMG data.

**Features:**
- Adjustable EMG window size for the CNN input.
- More modular code structure with classes and separate functions.
- Graceful handling of missing or invalid data.
- Training on the most recent data available.

**Notes:**
- This example still assumes that LSL streams `filtered_exg` and `FingerPercentages` are available.
- Ensure you have the required dependencies installed: `pylsl`, `tensorflow`, `numpy`.
- Running this in real-time requires live LSL streams.

In [1]:
import time
import numpy as np
import threading
import tensorflow as tf
from pylsl import StreamInlet, resolve_stream, proc_ALL, local_clock
from collections import deque

## Online Standardizer
This class maintains a running mean and std using Welford's algorithm, allowing for online standardization of incoming data.


In [2]:
class OnlineStandardizer:
    """
    Maintains a running mean and std using Welford's algorithm.
    """
    def __init__(self, epsilon=1e-7):
        self.count = 0
        self.mean = 0.0
        self.M2 = 0.0
        self.epsilon = epsilon

    def update(self, x: np.ndarray):
        x = x.astype(np.float64)
        batch_count = x.shape[0]
        
        if batch_count == 0:
            return

        new_count = self.count + batch_count
        
        delta = x - self.mean
        self.mean += np.sum(delta, axis=0) / new_count
        delta2 = x - self.mean
        self.M2 += np.sum(delta * delta2, axis=0)
        self.count = new_count

    def get_stats(self):
        if self.count < 2:
            return self.mean, np.ones_like(self.mean)
        var = self.M2 / (self.count - 1)
        std = np.sqrt(var)
        std[std < self.epsilon] = self.epsilon
        return self.mean, std

    def apply(self, x: np.ndarray):
        mean, std = self.get_stats()
        x = (x - mean) / std
        return x

## EMG-Angle Collector
This class:
- Connects to specified LSL streams for EMG and angles.
- Continuously collects data in a background thread.
- Maintains a buffer of recent data.
- Applies online standardization to EMG signals.
- Provides a method to extract the most recent data window for training.

In [3]:
class EMGAngleCollector:
    def __init__(self, emg_stream_name="filtered_exg", angle_stream_name="FingerPercentages",
                 buffer_seconds=10, finger_thresholds = ((0.8, 0.9), (0.7, 0.875), (0.75, 0.875), (0.7, 0.8), (0.7, 0.825))):
        print("Resolving EMG (EXG) stream...")
        emg_streams = resolve_stream('name', emg_stream_name)
        if not emg_streams:
            raise RuntimeError(f"No EMG stream found with name '{emg_stream_name}'.")
        
        print("Resolving angle (MP) stream...")
        angle_streams = resolve_stream('name', angle_stream_name)
        if not angle_streams:
            raise RuntimeError(f"No angle stream found with name '{angle_stream_name}'.")

        self.emg_inlet = StreamInlet(emg_streams[0], processing_flags=proc_ALL, max_buflen=2)
        self.angle_inlet = StreamInlet(angle_streams[0], processing_flags=proc_ALL, max_buflen=2)
        self.finger_thresholds = np.array(finger_thresholds)

        # Get EMG stream info
        emg_info = self.emg_inlet.info()
        self.emg_rate = emg_info.nominal_srate()
        if self.emg_rate <= 0:
            raise ValueError("EMG stream sampling rate is not set or zero.")
        self.num_exg_channels = emg_info.channel_count()
        if self.num_exg_channels <= 0:
            raise ValueError("EMG stream channel_count is not set or zero.")

        # Get angle stream info
        angle_info = self.angle_inlet.info()
        self.angle_rate = angle_info.nominal_srate()
        if self.angle_rate <= 0:
            raise ValueError("Angle stream sampling rate is not set or zero.")
        self.num_mp_channels = angle_info.channel_count()
        if self.num_mp_channels <= 0:
            raise ValueError("Angle stream channel_count is not set or zero.")

        self.capacity = int(buffer_seconds * self.emg_rate)

        # Deques to store (timestamp, sample)
        self.emg_data_deque = deque(maxlen=self.capacity * 2)
        self.angle_data_deque = deque(maxlen=self.capacity * 2)

        self.lock = threading.Lock()
        self.stop_flag = False
        self.standardizer = OnlineStandardizer()

        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def run(self):
        while not self.stop_flag:
            self.update()
            time.sleep(0.01)

    def update(self):
        emg_data, emg_timestamps = self.emg_inlet.pull_chunk()
        angle_data, angle_timestamps = self.angle_inlet.pull_chunk()

        if not emg_data or not angle_data:
            return

        # Convert to arrays
        emg_data = np.array(emg_data, dtype=np.float32)
        angle_data = np.array(angle_data, dtype=np.float32)

        # Truncate to minimum length
        length = min(len(emg_data), len(angle_data))
        emg_data = emg_data[:length]
        angle_data = angle_data[:length]
        emg_timestamps = emg_timestamps[:length]
        angle_timestamps = angle_timestamps[:length]

        # Now create nan_mask safely
        nan_mask = np.isnan(emg_data).any(axis=1) | np.isnan(angle_data).any(axis=1)
        emg_data = emg_data[~nan_mask]
        angle_data = angle_data[~nan_mask]
        emg_timestamps = np.array(emg_timestamps)[~nan_mask]
        angle_timestamps = np.array(angle_timestamps)[~nan_mask]

        if len(emg_data) == 0:
            return

        with self.lock:
            for i in range(len(emg_data)):
                self.emg_data_deque.append((emg_timestamps[i], emg_data[i]))
            for j in range(len(angle_data)):
                self.angle_data_deque.append((angle_timestamps[j], angle_data[j]))

    def stop(self):
        self.stop_flag = True
        self.thread.join()

    def get_example(self, window_size: int):
        """
        Retrieve a single example aligned by timestamp:
        For an angle sample at time t_angle, we find the EMG samples that are just before t_angle.
        If we can form a full window, return (X, y).
        Otherwise return None.
        """
        with self.lock:
            if len(self.angle_data_deque) == 0:
                return None

            # Take the most recent angle sample
            t_angle, angle_val = self.angle_data_deque[-1]
            
            # Clip based on the finger thresholds
            angle_val = np.clip(angle_val, self.finger_thresholds[:, 0], self.finger_thresholds[:, 1])
            # Normalize to [0, 1]
            angle_val = (angle_val - self.finger_thresholds[:, 0]) / (self.finger_thresholds[:, 1] - self.finger_thresholds[:, 0])

            # Extract EMG samples before t_angle
            emg_times = np.array([x[0] for x in self.emg_data_deque])
            emg_data = np.array([x[1] for x in self.emg_data_deque])

            valid_indices = np.where(emg_times < t_angle)[0]
            if len(valid_indices) < window_size:
                # Not enough EMG data before this angle sample
                return None

            chosen_indices = valid_indices[-window_size:]
            X = emg_data[chosen_indices, :]  # (window_size, num_exg_channels)
            # For simplicity, use the current angle_val as the target (single time step)
            y = angle_val[np.newaxis, ...]   # (1, num_mp_channels)
            return X, y

## Data Generator

This generator yields batches of data to the model for training. It:
- Continuously calls `get_latest_batch` to get the most recent window.
- Checks for NaNs or Infs.
- Adds a batch dimension.


In [4]:
def data_generator(collector: EMGAngleCollector, window_size: int, batch_size: int):
    while True:
        X_batch = []
        y_batch = []

        for _ in range(batch_size):
            example = None
            # Wait until we get a valid example
            while example is None:
                example = collector.get_example(window_size)
                if example is None:
                    time.sleep(0.01)
            X, y = example
            X_batch.append(X)
            y_batch.append(y)

        X_batch = np.array(X_batch, dtype=np.float32)  # (batch_size, window_size, num_exg_channels)
        y_batch = np.array(y_batch, dtype=np.float32)  # (batch_size, 1, num_mp_channels)

        # Reshape y to (batch_size, num_mp_channels)
        y_batch = np.squeeze(y_batch, axis=1)
        
        print(X_batch.shape, y_batch.shape)
        # print the distribution of the target
        print('Mean', np.mean(y_batch, axis=0))
        print('STD', np.std(y_batch, axis=0))
        

        yield X_batch, y_batch

## Model Creation

A simple CNN model for sequence data. The kernel size and number of filters can be adjusted.

We make the model creation more modular to easily tweak hyperparameters.


In [5]:
def create_model(sequence_length, num_exg_channels, num_mp_channels, filters=32, kernel_size=3):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(sequence_length, num_exg_channels)),
        tf.keras.layers.GaussianNoise(0.05),
        tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel_size, padding='same', activation='relu'),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel_size, padding='same', activation='relu'),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Conv1D(filters=filters*2, kernel_size=kernel_size, padding='same', activation='relu'),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Conv1D(filters=filters*2, kernel_size=kernel_size, padding='same', activation='relu'),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        # Output one value per MP channel
        tf.keras.layers.Dense(num_mp_channels)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

## Training Loop

Here we:
- Initialize the collector with chosen parameters.
- Determine the sequence_length (window_size) from user input.
- Create a `tf.data.Dataset` from the data generator.
- Train the model.

In [6]:
# User-configurable parameters
EMG_STREAM_NAME = "filtered_exg"
ANGLE_STREAM_NAME = "FingerPercentages"
BUFFER_SECONDS = 2
WINDOW_SIZE = 50  # Number of EMG samples per example
BATCH_SIZE = 512      # Examples per training step
STEPS_PER_EPOCH = 10
EPOCHS = 1000
FILTERS = 32
KERNEL_SIZE = 3

# Initialize collector
collector = EMGAngleCollector(
    emg_stream_name=EMG_STREAM_NAME,
    angle_stream_name=ANGLE_STREAM_NAME,
    buffer_seconds=BUFFER_SECONDS,
)

num_exg_channels = collector.num_exg_channels
num_mp_channels = collector.num_mp_channels

# Create dataset
dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(collector, window_size=WINDOW_SIZE, batch_size=BATCH_SIZE),
    output_signature=(
        tf.TensorSpec(shape=(BATCH_SIZE, WINDOW_SIZE, num_exg_channels), dtype=tf.float32),
        tf.TensorSpec(shape=(BATCH_SIZE, num_mp_channels), dtype=tf.float32)
    )
).prefetch(1)

# Create model
model = create_model(WINDOW_SIZE, num_exg_channels, num_mp_channels, filters=FILTERS, kernel_size=KERNEL_SIZE)

print("Starting training loop. Press Interrupt to stop.")
try:
    model.fit(dataset, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)
except KeyboardInterrupt:
    print("Training interrupted by user.")
finally:
    collector.stop()
    print("Stopped data collection and training.")

Resolving EMG (EXG) stream...
Resolving angle (MP) stream...


2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:102   INFO| 	IPv4 addr: 7f000001
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:105   INFO| 	IPv6 addr: ::1
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:105   INFO| 	IPv6 addr: fe80::1%lo0
2024-12-12 21:15:15.591 (   0.038s) [          166927]      netinterfaces.cpp:91    I

Starting training loop. Press Interrupt to stop.
Epoch 1/1000
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 1/10[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m24s[0m 3s/step - loss: 740.3913(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 4/10[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m0s[0m 20ms/step - loss: 454.0529(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 6/10[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m0s[0m 23ms/step - loss: 383.5713(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0. 0. 0. 0. 0.]
[1m 8/10[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m0s[0m 26ms/step - loss: 335.8624(512, 50, 16) (512, 5)
Mean [1. 1. 1. 1. 1.]
STD [0