In [3]:
import time
import datetime
import numpy as np
import threading
import tensorflow as tf
from pylsl import StreamInlet, resolve_stream, proc_ALL, local_clock

class OnlineStandardizer:
    """
    Maintains a running mean and std using Welford's algorithm.
    """
    def __init__(self, epsilon=1e-7):
        self.count = 0
        self.mean = 0.0
        self.M2 = 0.0
        self.epsilon = epsilon

    def update(self, x: np.ndarray):
        x = x.astype(np.float64)
        batch_count = x.shape[0]

        new_count = self.count + batch_count
        delta = x - self.mean
        self.mean += np.sum(delta, axis=0) / new_count
        delta2 = x - self.mean
        self.M2 += np.sum(delta * delta2, axis=0)
        self.count = new_count

    def get_stats(self):
        if self.count < 2:
            return self.mean, np.ones_like(self.mean)
        var = self.M2 / (self.count - 1)
        std = np.sqrt(var)
        std[std < self.epsilon] = self.epsilon
        return self.mean, std

    def apply(self, x: np.ndarray):
        mean, std = self.get_stats()
        x = (x - mean) / std
        return x

class EMGAngleCollector:
    def __init__(self, emg_stream_name="filtered_exg", angle_stream_name="FingerPercentages",
                 buffer_seconds=10, batch_size=128):
        print("Resolving EMG (EXG) stream...")
        emg_streams = resolve_stream('name', emg_stream_name)
        if not emg_streams:
            raise RuntimeError(f"No EMG stream found with name '{emg_stream_name}'.")

        print("Resolving angle (MP) stream...")
        angle_streams = resolve_stream('name', angle_stream_name)
        if not angle_streams:
            raise RuntimeError(f"No angle stream found with name '{angle_stream_name}'.")

        self.emg_inlet = StreamInlet(emg_streams[0], processing_flags=proc_ALL, max_buflen=2)
        self.angle_inlet = StreamInlet(angle_streams[0], processing_flags=proc_ALL, max_buflen=2)

        # Get EMG stream info
        emg_info = self.emg_inlet.info()
        self.emg_rate = emg_info.nominal_srate()
        if self.emg_rate <= 0:
            raise ValueError("EMG stream sampling rate is not set or zero.")
        self.num_exg_channels = emg_info.channel_count()
        if self.num_exg_channels <= 0:
            raise ValueError("EMG stream channel_count is not set or zero.")

        # Get angle stream info
        angle_info = self.angle_inlet.info()
        self.angle_rate = angle_info.nominal_srate()
        if self.angle_rate <= 0:
            raise ValueError("Angle stream sampling rate is not set or zero.")
        self.num_mp_channels = angle_info.channel_count()
        if self.num_mp_channels <= 0:
            raise ValueError("Angle stream channel_count is not set or zero.")

        self.batch_size = batch_size
        self.capacity = int(buffer_seconds * self.emg_rate)

        # Buffers
        self.emg_buffer = np.zeros((self.capacity, self.num_exg_channels), dtype=np.float32)
        self.angle_buffer = np.zeros((self.capacity, self.num_mp_channels), dtype=np.float32)

        self.size = 0
        self.lock = threading.Lock()
        self.stop_flag = False

        # Online standardizer for EMG
        self.standardizer = OnlineStandardizer()

        # Start collection thread
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def run(self):
        while not self.stop_flag:
            self.update()
            time.sleep(0.01)

    def update(self):
        # Fetch data
        emg_data, emg_timestamps = self.emg_inlet.pull_chunk()
        angle_data, angle_timestamps = self.angle_inlet.pull_chunk()

        if emg_data and angle_data:
            length = min(len(emg_data), len(angle_data))

            emg_data = np.array(emg_data[:length], dtype=np.float32)  # shape: (length, num_exg_channels)
            angle_data = np.array(angle_data[:length], dtype=np.float32)  # shape: (length, num_mp_channels)

            # Rectify EMG data
            emg_data = np.abs(emg_data)

            # Update standardizer and apply scaling
            self.standardizer.update(emg_data)
            emg_data = self.standardizer.apply(emg_data)

            with self.lock:
                space_left = self.capacity - self.size
                if length > space_left:
                    # Overwrite oldest data
                    shift = length - space_left
                    self.emg_buffer = np.roll(self.emg_buffer, -shift, axis=0)
                    self.angle_buffer = np.roll(self.angle_buffer, -shift, axis=0)
                    self.size = self.capacity - length

                start_idx = self.size
                end_idx = self.size + length
                self.emg_buffer[start_idx:end_idx] = emg_data
                self.angle_buffer[start_idx:end_idx] = angle_data
                self.size += length

    def stop(self):
        self.stop_flag = True
        self.thread.join()

    def get_latest_batch(self):
        with self.lock:
            if self.size < self.batch_size:
                return None, None
            X = self.emg_buffer[self.size - self.batch_size:self.size]
            y = self.angle_buffer[self.size - self.batch_size:self.size]
        return X, y

def data_generator(collector: EMGAngleCollector):
    while True:
        X, y = collector.get_latest_batch()
        if X is None:
            time.sleep(0.1)
            continue
        # Reshape X and y to include batch dimension
        X = X[np.newaxis, ...]  # Shape: (1, batch_size, num_exg_channels)
        y = y[np.newaxis, ...]  # Shape: (1, batch_size, num_mp_channels)
        yield X, y

def create_model(sequence_length, num_exg_channels, num_mp_channels):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(sequence_length, num_exg_channels)),  # Fixed sequence length
        tf.keras.layers.Conv1D(filters=32, kernel_size=5, padding='same', activation='relu'),
        tf.keras.layers.Conv1D(filters=64, kernel_size=5, padding='same', activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        # Output one value per MP channel (e.g., multiple fingers)
        tf.keras.layers.Dense(num_mp_channels)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

# Initialize the collector and let it determine channel counts
collector = EMGAngleCollector(
    emg_stream_name="filtered_exg",
    angle_stream_name="FingerPercentages",
    buffer_seconds=10,
    batch_size=128
)

# Define the fixed sequence length (must match batch_size)
sequence_length = collector.batch_size

num_exg_channels = collector.num_exg_channels
num_mp_channels = collector.num_mp_channels

# Create the dataset with appropriate shapes
dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(collector),
    output_signature=(
        tf.TensorSpec(shape=(1, sequence_length, num_exg_channels), dtype=tf.float32),
        tf.TensorSpec(shape=(1, sequence_length, num_mp_channels), dtype=tf.float32)
    )
).prefetch(1)

# Create the model with fixed sequence length
model = create_model(sequence_length, num_exg_channels, num_mp_channels)

print("Starting training loop. Press Interrupt to stop.")
try:
    model.fit(dataset, steps_per_epoch=10, epochs=5)
except KeyboardInterrupt:
    print("Training interrupted by user.")
finally:
    collector.stop()
    print("Stopped data collection and training.")


Resolving EMG (EXG) stream...
Resolving angle (MP) stream...
Starting training loop. Press Interrupt to stop.
Epoch 1/5
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - loss: nan
Epoch 2/5
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - loss: nan
Epoch 3/5
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - loss: nan
Epoch 4/5
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: nan
Epoch 5/5
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - loss: nan
Stopped data collection and training.


2024-12-12 00:07:15.598 ( 703.051s) [R_filtered_exg  ]      data_receiver.cpp:344    ERR| Stream transmission broke off (Input stream error.); re-connecting...
