# Real-Time EMG to Angle Prediction with TensorFlow and LSL

This notebook demonstrates a more modular and robust approach to:
- Collecting real-time EMG (EXG) and angle (FingerPercentages) data from LSL streams.
- Applying online standardization to the EMG signals.
- Using a configurable CNN model to predict angles from recent EMG data.

**Features:**
- Adjustable EMG window size for the CNN input.
- More modular code structure with classes and separate functions.
- Graceful handling of missing or invalid data.
- Training on the most recent data available.

**Notes:**
- This example assumes that the LSL streams `filtered_exg` and `FingerPercentages` are available.
- Ensure you have the required dependencies installed: `pylsl`, `tensorflow`, `numpy`.
- Running this in real-time requires live LSL streams.

In [1]:
import time
import numpy as np
import threading
import tensorflow as tf
from pylsl import StreamInlet, resolve_stream, proc_ALL, local_clock

## Online Standardizer
This class maintains a running per-channel mean and standard deviation using a batched form of Welford's algorithm, so incoming data can be standardized online.


In [2]:
class OnlineStandardizer:
    """
    Maintains a running per-feature mean and standard deviation using a batched
    form of Welford's algorithm, and standardizes data with the current stats.

    Statistics are accumulated along axis 0 (samples); trailing axes (e.g.
    channels) are tracked independently. A 1-D input is treated as a sequence
    of scalar samples.
    """
    def __init__(self, epsilon=1e-7):
        self.count = 0          # number of samples folded in so far
        self.mean = 0.0         # running mean (becomes an array after the first update)
        self.M2 = 0.0           # running sum of squared deviations from the mean
        self.epsilon = epsilon  # lower bound on the returned std (avoids division by ~0)

    def update(self, x):
        """Fold a batch of samples (along axis 0) into the running statistics.

        Arguments:
        - x: array-like of shape (n_samples, ...) or (n_samples,).
        """
        x = np.asarray(x, dtype=np.float64)
        batch_count = x.shape[0]
        if batch_count == 0:
            return

        new_count = self.count + batch_count
        # Batched update: mean_new = mean_old + sum(x - mean_old) / n_new and
        # M2_new = M2_old + sum((x - mean_old) * (x - mean_new)); this equals
        # Chan et al.'s pairwise combination, so it is exact for any batch size.
        delta = x - self.mean
        # Rebind instead of `+=` so arrays handed out by get_stats() are never
        # mutated by later updates.
        self.mean = self.mean + np.sum(delta, axis=0) / new_count
        delta2 = x - self.mean
        self.M2 = self.M2 + np.sum(delta * delta2, axis=0)
        self.count = new_count

    def get_stats(self):
        """Return (mean, std) for the samples seen so far.

        std is the sample standard deviation (ddof=1) clamped below at
        `epsilon`. With fewer than two samples, std is all ones so that
        `apply` only centers the data.
        """
        mean = np.asarray(self.mean, dtype=np.float64)
        if self.count < 2:
            return mean, np.ones_like(mean)
        # Clamp tiny negative values caused by floating-point rounding before sqrt.
        var = np.maximum(np.asarray(self.M2, dtype=np.float64) / (self.count - 1), 0.0)
        # np.maximum works for both arrays and 0-d results (1-D input), unlike
        # boolean-mask item assignment.
        std = np.maximum(np.sqrt(var), self.epsilon)
        return mean, std

    def apply(self, x):
        """Return `x` standardized with the current stats: (x - mean) / std."""
        mean, std = self.get_stats()
        return (np.asarray(x) - mean) / std

## EMG-Angle Collector
This class:
- Connects to specified LSL streams for EMG and angles.
- Continuously collects data in a background thread.
- Maintains a buffer of recent data.
- Rectifies the EMG signals (absolute value) and applies online standardization to them.
- Provides a method to extract the most recent data window for training.

In [3]:
class EMGAngleCollector:
    """
    Collects EMG and angle samples from two LSL streams in a background thread.

    Each EMG sample is paired with angle values linearly interpolated at that
    sample's LSL timestamp, so the two streams may run at different sampling
    rates. Paired EMG is rectified and online-standardized, then appended to
    fixed-size buffers holding the most recent `buffer_seconds` of EMG samples.
    """

    # Seconds of angle history kept as interpolation anchors for EMG samples
    # that arrive later than the angle samples bracketing them.
    ANGLE_HISTORY_SECONDS = 1.0

    def __init__(self, emg_stream_name="filtered_exg", angle_stream_name="FingerPercentages",
                 buffer_seconds=10, batch_size=128):
        """
        Initialize the collector and start the background collection thread.

        Arguments:
        - emg_stream_name: Name of the EMG LSL stream.
        - angle_stream_name: Name of the angle LSL stream.
        - buffer_seconds: How many seconds of (EMG-rate) data to keep in buffer.
        - batch_size: The default window size of data extracted for training.
        """

        print("Resolving EMG (EXG) stream...")
        emg_streams = resolve_stream('name', emg_stream_name)
        if not emg_streams:
            raise RuntimeError(f"No EMG stream found with name '{emg_stream_name}'.")

        print("Resolving angle (MP) stream...")
        angle_streams = resolve_stream('name', angle_stream_name)
        if not angle_streams:
            raise RuntimeError(f"No angle stream found with name '{angle_stream_name}'.")

        # proc_ALL: timestamps are clock-synchronized into the local clock and
        # dejittered/monotonized, so EMG and angle timestamps can be compared.
        self.emg_inlet = StreamInlet(emg_streams[0], processing_flags=proc_ALL, max_buflen=2)
        self.angle_inlet = StreamInlet(angle_streams[0], processing_flags=proc_ALL, max_buflen=2)

        # Get EMG stream info
        emg_info = self.emg_inlet.info()
        self.emg_rate = emg_info.nominal_srate()
        if self.emg_rate <= 0:
            raise ValueError("EMG stream sampling rate is not set or zero.")
        self.num_exg_channels = emg_info.channel_count()
        if self.num_exg_channels <= 0:
            raise ValueError("EMG stream channel_count is not set or zero.")

        # Get angle stream info
        angle_info = self.angle_inlet.info()
        self.angle_rate = angle_info.nominal_srate()
        if self.angle_rate <= 0:
            raise ValueError("Angle stream sampling rate is not set or zero.")
        self.num_mp_channels = angle_info.channel_count()
        if self.num_mp_channels <= 0:
            raise ValueError("Angle stream channel_count is not set or zero.")

        self.batch_size = batch_size
        self.capacity = int(buffer_seconds * self.emg_rate)

        # Buffers of paired samples; rows [0, size) are valid, oldest first.
        self.emg_buffer = np.zeros((self.capacity, self.num_exg_channels), dtype=np.float32)
        self.angle_buffer = np.zeros((self.capacity, self.num_mp_channels), dtype=np.float32)

        self.size = 0
        self.lock = threading.Lock()
        self.stop_flag = False

        # Online standardizer for EMG
        self.standardizer = OnlineStandardizer()

        # EMG samples pulled from LSL but not yet labelled with an angle.
        self._pending_emg = np.empty((0, self.num_exg_channels), dtype=np.float32)
        self._pending_emg_ts = np.empty(0, dtype=np.float64)
        # Recent angle samples used as interpolation anchors.
        self._angle_hist = np.empty((0, self.num_mp_channels), dtype=np.float32)
        self._angle_hist_ts = np.empty(0, dtype=np.float64)

        # Start collection thread last, after all state above exists.
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def run(self):
        """Background loop: poll both inlets roughly every 10 ms until stopped."""
        while not self.stop_flag:
            self.update()
            time.sleep(0.01)

    @staticmethod
    def _finite_rows(chunk, timestamps, num_channels):
        """Convert an LSL chunk to arrays and drop samples containing NaN or Inf."""
        values = np.asarray(chunk, dtype=np.float32).reshape(-1, num_channels)
        stamps = np.asarray(timestamps, dtype=np.float64)
        ok = np.isfinite(values).all(axis=1)
        return values[ok], stamps[ok]

    @staticmethod
    def _interpolate_angles(query_ts, angle_ts, angle_values):
        """Linearly interpolate every angle channel at the timestamps `query_ts`.

        `angle_ts` must be non-decreasing. Returns a float32 array of shape
        (len(query_ts), angle_values.shape[1]).
        """
        columns = [np.interp(query_ts, angle_ts, angle_values[:, c])
                   for c in range(angle_values.shape[1])]
        return np.stack(columns, axis=1).astype(np.float32)

    def _append_samples(self, emg, angles):
        """Append paired rows to the buffers, discarding the oldest rows when full."""
        n = len(emg)
        if n == 0:
            return
        if n >= self.capacity:
            # Only the newest `capacity` rows can ever be retained.
            emg = emg[-self.capacity:]
            angles = angles[-self.capacity:]
            n = self.capacity

        with self.lock:
            overflow = self.size + n - self.capacity
            if overflow > 0:
                # Shift the rows we keep to the front to make room at the end.
                keep = self.size - overflow
                self.emg_buffer[:keep] = self.emg_buffer[overflow:self.size]
                self.angle_buffer[:keep] = self.angle_buffer[overflow:self.size]
                self.size = keep
            self.emg_buffer[self.size:self.size + n] = emg
            self.angle_buffer[self.size:self.size + n] = angles
            self.size += n

    def update(self):
        """Pull new samples, label EMG with time-interpolated angles, and buffer them.

        EMG samples newer than the latest angle sample are held back until an
        angle sample at or after their timestamp arrives. EMG samples older than
        the retained angle history cannot be labelled and are dropped.
        """
        emg_chunk, emg_ts = self.emg_inlet.pull_chunk()
        angle_chunk, angle_ts = self.angle_inlet.pull_chunk()

        # Queue each stream independently so a chunk is never discarded just
        # because the other stream had nothing new in this poll.
        if len(angle_chunk):
            values, stamps = self._finite_rows(angle_chunk, angle_ts, self.num_mp_channels)
            self._angle_hist = np.concatenate([self._angle_hist, values])
            self._angle_hist_ts = np.concatenate([self._angle_hist_ts, stamps])
        if len(emg_chunk):
            values, stamps = self._finite_rows(emg_chunk, emg_ts, self.num_exg_channels)
            self._pending_emg = np.concatenate([self._pending_emg, values])
            self._pending_emg_ts = np.concatenate([self._pending_emg_ts, stamps])

        if self._angle_hist_ts.size == 0:
            # No angle reference yet: nothing can be labelled; bound the backlog.
            self._pending_emg = self._pending_emg[-self.capacity:]
            self._pending_emg_ts = self._pending_emg_ts[-self.capacity:]
            return

        first_t = self._angle_hist_ts[0]
        last_t = self._angle_hist_ts[-1]
        ready = (self._pending_emg_ts >= first_t) & (self._pending_emg_ts <= last_t)
        waiting = self._pending_emg_ts > last_t

        emg_ready = self._pending_emg[ready]
        angles = self._interpolate_angles(
            self._pending_emg_ts[ready], self._angle_hist_ts, self._angle_hist)

        # Keep only what future angle data may still label, and recent anchors.
        self._pending_emg = self._pending_emg[waiting][-self.capacity:]
        self._pending_emg_ts = self._pending_emg_ts[waiting][-self.capacity:]
        keep_from = np.searchsorted(self._angle_hist_ts, last_t - self.ANGLE_HISTORY_SECONDS)
        self._angle_hist = self._angle_hist[keep_from:]
        self._angle_hist_ts = self._angle_hist_ts[keep_from:]

        if len(emg_ready) == 0:
            return

        # Rectify EMG data, then update the standardizer and apply scaling.
        emg_ready = np.abs(emg_ready)
        self.standardizer.update(emg_ready)
        emg_ready = self.standardizer.apply(emg_ready).astype(np.float32)

        self._append_samples(emg_ready, angles)

    def stop(self):
        """Signal the collection thread to exit and wait for it."""
        self.stop_flag = True
        self.thread.join()

    def get_latest_batch(self, window_size: int = None):
        """
        Return copies of the most recent `window_size` samples as (X, y).
        If `window_size` is None, it will return self.batch_size samples.
        Returns (None, None) while fewer than `window_size` samples are buffered.

        Raises ValueError if `window_size` exceeds the buffer capacity, since
        such a window could never be filled.
        """
        if window_size is None:
            window_size = self.batch_size
        if window_size > self.capacity:
            raise ValueError(
                f"window_size ({window_size}) exceeds buffer capacity "
                f"({self.capacity} samples); increase buffer_seconds."
            )

        with self.lock:
            if self.size < window_size:
                return None, None
            start = self.size - window_size
            # Copy so callers never observe later in-place buffer shifts.
            X = self.emg_buffer[start:self.size].copy()
            y = self.angle_buffer[start:self.size].copy()
        return X, y

## Data Generator

This generator yields batches of data to the model for training. It:
- Continuously calls `get_latest_batch` to get the most recent window.
- Checks for NaNs or Infs.
- Adds a batch dimension.


In [4]:
def data_generator(collector: EMGAngleCollector, window_size: int, retry_delay: float = 0.1):
    """
    Yield the collector's most recent window forever, as single-example batches.

    Arguments:
    - collector: source of windows via `collector.get_latest_batch(window_size)`.
    - window_size: number of time steps per window.
    - retry_delay: seconds to wait before polling again when too little data is
      buffered or the latest window contains NaN/Inf values.

    Yields (X, y) with a leading batch dimension of 1:
    X of shape (1, window_size, num_exg_channels) and
    y of shape (1, window_size, num_mp_channels).
    """
    while True:
        X, y = collector.get_latest_batch(window_size)
        if X is None:
            # Not enough data yet
            time.sleep(retry_delay)
            continue

        # Validate data. Wait before retrying: polling again immediately would
        # just return the same window, spinning the CPU and flooding the log.
        if not np.isfinite(X).all():
            print("Invalid values detected in input data (X). Skipping this batch.")
            time.sleep(retry_delay)
            continue
        if not np.isfinite(y).all():
            print("Invalid values detected in target data (y). Skipping this batch.")
            time.sleep(retry_delay)
            continue

        # Treat the entire window as a single example (batch=1).
        yield X[np.newaxis, ...], y[np.newaxis, ...]

## Model Creation

A simple CNN model for sequence data. The kernel size and number of filters can be adjusted.

We make the model creation more modular to easily tweak hyperparameters.


In [5]:
def create_model(sequence_length, num_exg_channels, num_mp_channels, filters=32, kernel_size=3):
    """
    Build and compile a 1-D CNN mapping an EMG window to one angle vector.

    Architecture: Gaussian input noise (stddev 0.05), then four stages of
    Conv1D('same' padding, ReLU) + MaxPooling1D(2) with widths
    filters, filters, 2*filters, 2*filters. These are followed by Flatten, a
    64-unit ReLU layer and a linear Dense head with one output per MP channel.
    Each pooling stage halves the time axis, so `sequence_length` should be at
    least 16. The model is compiled with Adam and mean squared error.
    """
    layers = tf.keras.layers
    stack = [
        layers.Input(shape=(sequence_length, num_exg_channels)),
        layers.GaussianNoise(0.05),
    ]
    for width in (filters, filters, filters * 2, filters * 2):
        stack.append(layers.Conv1D(filters=width, kernel_size=kernel_size,
                                   padding='same', activation='relu'))
        stack.append(layers.MaxPooling1D(pool_size=2))
    stack.extend([
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_mp_channels),  # linear head: one value per MP channel
    ])

    net = tf.keras.Sequential(stack)
    net.compile(optimizer='adam', loss='mse')
    return net

## Training Loop

Here we:
- Initialize the collector with chosen parameters.
- Set `sequence_length` (the CNN input window) from the configurable `WINDOW_SIZE` constant.
- Create a `tf.data.Dataset` from the data generator.
- Train the model.

In [6]:
# User-configurable parameters
EMG_STREAM_NAME = "filtered_exg"
ANGLE_STREAM_NAME = "FingerPercentages"
BUFFER_SECONDS = 5
WINDOW_SIZE = 512  # The size of the EMG window for the CNN
STEPS_PER_EPOCH = 10
EPOCHS = 100
FILTERS = 32
KERNEL_SIZE = 3

# Initialize the collector
collector = EMGAngleCollector(
    emg_stream_name=EMG_STREAM_NAME,
    angle_stream_name=ANGLE_STREAM_NAME,
    buffer_seconds=BUFFER_SECONDS,
    batch_size=WINDOW_SIZE
)

# Extract shape info
sequence_length = WINDOW_SIZE
num_exg_channels = collector.num_exg_channels
num_mp_channels = collector.num_mp_channels

# The buffer must be able to hold one full window; otherwise the data
# generator would wait forever for enough samples.
if collector.capacity < sequence_length:
    collector.stop()
    raise ValueError(
        f"BUFFER_SECONDS={BUFFER_SECONDS} holds only {collector.capacity} EMG samples "
        f"at {collector.emg_rate} Hz, fewer than WINDOW_SIZE={WINDOW_SIZE}."
    )

# Create dataset.
# The generator yields the angle sequence for the whole window, but the model
# predicts a single angle vector. Train against the angles at the most recent
# time step, giving a target of shape (1, num_mp_channels). Without this, the MSE
# loss broadcasts the (1, C) prediction against the (1, window, C) target and
# effectively fits the window-average angle instead.
dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(collector, window_size=sequence_length),
    output_signature=(
        tf.TensorSpec(shape=(1, sequence_length, num_exg_channels), dtype=tf.float32),
        tf.TensorSpec(shape=(1, sequence_length, num_mp_channels), dtype=tf.float32)
    )
).map(lambda X, y: (X, y[:, -1, :])).prefetch(1)

# Create model
model = create_model(sequence_length, num_exg_channels, num_mp_channels, filters=FILTERS, kernel_size=KERNEL_SIZE)

print("Starting training loop. Press Interrupt to stop.")
try:
    model.fit(dataset, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)
except KeyboardInterrupt:
    print("Training interrupted by user.")
finally:
    # Always stop the background collection thread, even on errors.
    collector.stop()
    print("Stopped data collection and training.")

Resolving EMG (EXG) stream...
Resolving angle (MP) stream...
Starting training loop. Press Interrupt to stop.
Epoch 1/100


2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:102   INFO| 	IPv4 addr: 7f000001
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:105   INFO| 	IPv6 addr: ::1
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:105   INFO| 	IPv6 addr: fe80::1%lo0
2024-12-12 20:41:04.147 (   4.610s) [          15F0AF]      netinterfaces.cpp:91    I

[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - loss: 0.6292
Epoch 2/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0268
Epoch 3/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0113
Epoch 4/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0122
Epoch 5/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0125
Epoch 6/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0077
Epoch 7/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0130
Epoch 8/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0118
Epoch 9/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0116
Epoch 10/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0141
Epoch 11/1