In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

from emgdecomp.decomposition import EmgDecomposition
from emgdecomp.parameters import EmgDecompositionParams

from src.utils import load_config
from src.data.filter import Filter
from src.data.utils import bipolar_conversion, load_data, load_data_deprecated, average_reference

In [14]:
# Locate the recorded data streams for one experiment session.
# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
data_dir = "/Users/johnzhou/research/rumi/data/"
expt_name = "test_2023-08-19-1953_Closed-Loop-Neurofeedback-Interface"
stream_dir = f"{data_dir}{expt_name}/data_streams"
emg_fname = f"{stream_dir}/emg_stream.bin"
filter_fname = f"{stream_dir}/filter_stream.bin"


def _load_stream(fname):
    """Read one stream file, falling back to the legacy reader when the new one cannot decode it."""
    try:
        return load_data(fname)
    except UnicodeDecodeError:
        return load_data_deprecated(fname)


emg_data = _load_stream(emg_fname)
emgbuffer = emg_data['emgbuffer']

# NOTE(review): a later cell builds Buffer(filterbuffer[:, :31]), but filterbuffer is not
# loaded here; on a fresh kernel uncomment these two lines to define it.
# filter_data = _load_stream(filter_fname)
# filterbuffer = filter_data['filterbuffer']

print(emgbuffer.shape)
# print(filterbuffer.shape[0] * 2)

(94568, 67)


In [15]:
class Buffer:
    """Replay a recorded array as if it were a live stream polled in irregular chunks.

    Rows of ``data`` are handed out sequentially. Each poll draws a chunk size in
    ``[min_samples, max_samples)``; with probability ``drop_prob`` the poll returns ``None``
    without advancing, mimicking a poll that found nothing new. After ``data`` is exhausted,
    polls return empty arrays (slicing past the end).

    Args:
        data: 2-D array; rows are consumed in order, the last axis is treated as channels.
        rng: optional ``np.random.Generator`` for reproducible chunking. When omitted the
            global ``np.random`` state is used, exactly as before.
        min_samples: smallest chunk size (inclusive).
        max_samples: chunk-size upper bound (exclusive).
        drop_prob: probability that a poll returns ``None``.
    """

    def __init__(self, data, rng=None, min_samples=40, max_samples=80, drop_prob=0.05):
        self.data = data
        self.num_chans = data.shape[-1]
        self.pointer = 0
        self.min_samples = min_samples
        self.max_samples = max_samples
        self.drop_prob = drop_prob
        # Generator.random() and legacy np.random.rand() both return one float in [0, 1).
        self._rand = rng.random if rng is not None else np.random.rand

    def poll(self):
        """Return the next chunk of rows, or ``None`` for a simulated empty poll."""
        # Draw the chunk size before the drop decision so the random stream is consumed in
        # the same order as the original implementation.
        num_samples = int(self._rand() * (self.max_samples - self.min_samples) + self.min_samples)
        if self._rand() < self.drop_prob:
            return None

        data = self.data[self.pointer:self.pointer + num_samples, :]
        self.pointer += num_samples
        return data

In [16]:
# class Decomposer:
#     def __init__(self, filter_buffer, emgde=None):
#         self.filter_buffer = filter_buffer
#         self.rest_time = 10
#         self.init_time = 60
#         self.Fs = 2000
#         self.num_chans = filter_buffer.num_chans
#         self.threshold_factor = 6
#         self.dead_time_remaining = 0
#         self.threshold_cross_deadtime = int(0.020 * self.Fs)
#         self.window_pre = int(0.010 * self.Fs)
#         self.window_post = int(0.020 * self.Fs)
#         self.decomp_params = EmgDecompositionParams(self.Fs)
#         self.num_crossings_to_add_sources = 50
#         self.crossings_count_without_detection = 0
#         self.crossing_windows_without_detection = np.zeros((self.num_chans, 
#                                                             self.window_pre + self.window_post, 
#                                                             self.num_crossings_to_add_sources))
#         self.threshold = None
#         if emgde is None:
#             self.emgde = EmgDecomposition(self.decomp_params)
#         else:
#             self.emgde = emgde
            
#         self.decomp_buffer = None
#         self.num_firings = 0
#         self.num_samples = 0
            
#     def run(self):
#         data = self.filter_buffer.poll()
#         if data is None:
#             return
# #         print("Num samples polled:", data.shape[0])
#         self.num_samples += data.shape[0]
        
#         if self.decomp_buffer is not None:
#             data = np.concatenate((self.decomp_buffer, data))
# #             print("Buffer len:", self.decomp_buffer.shape[0])
# #             print("Total data len:", data.shape[0])
#             self.decomp_buffer = None
        
#         num_samples = data.shape[0]
#         threshold_idxs = self.detect_threshold_crossing(data[self.dead_time_remaining:, ...]) + self.dead_time_remaining
        
#         # If there is not enough samples at the end to create a window around the threshold crossing, leave in the 
#         #  decomp buffer for the next run
#         keep_idxs = np.argwhere(threshold_idxs + self.window_post <= num_samples)
#         leave_idxs = np.argwhere(threshold_idxs + self.window_post > num_samples)
        
#         if leave_idxs.size > 0:
# #             print(f"Leaving {len(leave_idxs)} crossings behind, not enough samples...")
#             leave = threshold_idxs[leave_idxs].flatten()
#             self.decomp_buffer = data[int(leave[0]) - self.window_pre:, ...]
#         if keep_idxs.size > 0:
#             keep = threshold_idxs[keep_idxs].flatten()
# #             print("Decomposing window around idx", keep)
#             decomp_windows = self.create_windows(data,
#                                                  keep,
#                                                  self.window_pre, 
#                                                  self.window_post)
#             firings = self.emgde.transform(np.squeeze(decomp_windows))
#             self.num_firings += len(firings)
#             if len(firings):
#                 print(f"{np.unique(np.array([st[0] for st in firings])).size} sources identified with a "
#               f"total of {len(firings)} spikes.")
#             else:
#                 print("No spikes found")
#             num_decomp_windows = decomp_windows.shape[-1]
#             if len(firings) == 0:
#                 if num_decomp_windows + self.crossings_count_without_detection <= self.num_crossings_to_add_sources:
#                     self.crossing_windows_without_detection[...,
#                         self.crossings_count_without_detection:self.crossings_count_without_detection + decomp_windows.shape[-1]
#                     ] = decomp_windows
#                 else:
#                     remaining_windows = self.crossings_count_without_detection + num_decomp_windows - self.num_crossings_to_add_sources
#                     windows_to_add = np.expand_dims(np.squeeze(decomp_windows[..., 0:remaining_windows]), axis=-1)
#                     self.crossing_windows_without_detection[..., self.crossings_count_without_detection:self.num_crossings_to_add_sources] = windows_to_add
#                 self.crossings_count_without_detection += decomp_windows.shape[-1]
#             if self.crossings_count_without_detection >= self.num_crossings_to_add_sources:
#                 print(f"{self.num_crossings_to_add_sources} threshold crossings without sources id'ed, adding new sources!")
#                 firings = self.emgde.decompose_batch(self.crossing_windows_without_detection)
#                 self.num_firings += len(firings)
#                 if len(firings):
#                     print(f"{np.unique(np.array([st[0] for st in firings])).size} new sources identified with a "
#                   f"total of {len(firings)} spikes.")
#                 else:
#                     print("No new sources found!")
#                 self.crossings_count_without_detection = 0
#         if self.decomp_buffer is None:
#             self.decomp_buffer = data[-self.window_pre:, ...]
#         self.dead_time_remaining = self.window_pre
# #         print("\n")
    
#     @staticmethod
#     def create_windows(data, idxs, window_pre, window_post):
#         num_windows = len(idxs)
#         num_chans = data.shape[-1]
#         window_len = window_pre + window_post
        
#         windows = np.zeros((num_chans, window_len, num_windows))
#         for window_idx, i in enumerate(idxs):
# #             print("Window start", i - window_pre, ", end", i + window_post)
#             start = i - window_pre
#             end = i + window_post
#             if start < 0 or end > data.shape[0]:
#                 raise ValueError("Not enough samples!")
#             windows[..., window_idx] = data[start:end, ...].T
#         return windows
        
#     def set_threshold(self) -> None:
#         print('Setting threshold')
#         buffer_len = self.rest_time * self.Fs
#         thresholding_buffer = np.zeros((buffer_len, self.num_chans))
#         write_pointer = 0
#         while write_pointer < buffer_len:
#             data = self.filter_buffer.poll()
#             if data is not None:
#                 num_samples = data.shape[0]
#                 if num_samples + write_pointer > buffer_len:
#                     remaining_space = buffer_len - write_pointer
#                     data = data[:remaining_space]
#                 thresholding_buffer[write_pointer:write_pointer + num_samples] = data
#                 write_pointer += num_samples

#         channel_means = np.mean(thresholding_buffer, axis=0)
#         channel_stds = np.std(thresholding_buffer, axis=0)
#         self.threshold = channel_means + self.threshold_factor * channel_stds
#         print('Threshold set!')

#     def detect_threshold_crossing(self, data):
#         """Detect threshold crossings in the data, return time indices of threshold crossings."""
#         num_samples = data.shape[0]
#         time_idxs, channel_idxs = np.nonzero(data > self.threshold)
        
#         if time_idxs.size == 0:
#             return np.array([])
#         else:
#             time_idxs = time_idxs  # get correct idxs w.r.t. input

#         sorted_time_idxs = np.sort(np.unique(time_idxs))
#         alive_idxs = []

#         dead_until_idx = -1
#         for idx in sorted_time_idxs:
#             if idx > dead_until_idx:
#                 alive_idxs.append(idx)
#                 dead_until_idx = idx + self.threshold_cross_deadtime
# #         print("Crossing identified at idxs:", alive_idxs)

#         return np.array(alive_idxs)
        
#     def init_model(self) -> None:
#         """Initialize the decomposition model on the first num_seconds of data."""
#         init_buffer = np.zeros((self.init_time * self.Fs, self.num_chans))
#         write_pointer = 0
        
#         while write_pointer < init_buffer.shape[0]:
#             data = self.filter_buffer.poll()
#             if data is not None:
#                 num_samples = data.shape[0]
#                 if num_samples + write_pointer > init_buffer.shape[0]:
#                     remaining_space = init_buffer.shape[0] - write_pointer
#                     data = data[:remaining_space]
#                 init_buffer[write_pointer:write_pointer + num_samples] = data
#                 write_pointer += num_samples

#         # TODO: Need to synchronously empty the buffer while fitting the model, can set
#         #  a gamestate flag and let the display empty it
#         print("Initializing model...")
#         start = time.time()
#         firings = self.emgde.decompose(init_buffer.T)
#         print(f"{np.unique(np.array([st[0] for st in firings])).size} sources identified with a "
#               f"total of {len(firings)} spikes.")
#         print(f"Took {time.time() - start} s")


In [19]:
# Re-check the loaded EMG buffer's shape (same value printed by the data-loading cell).
print(emgbuffer.shape)

(94568, 67)


In [5]:
# Load a previously fitted decomposition model from disk.
# Build the placeholder with a params *instance* (not the EmgDecompositionParams class),
# using the same 2000 Hz sampling rate that Decomposer passes to EmgDecompositionParams.
# NOTE(review): absolute, machine-specific path; the .pkl extension suggests pickle, which
# executes code from the file on load — only load model files you produced yourself.
emgde = EmgDecomposition(EmgDecompositionParams(2000))
with open('/Users/johnzhou/research/emg_decoder/models/emgde.pkl', 'rb') as f:
    emgde = emgde.load(f)
print(emgde.num_sources())

7


In [9]:
class DisplayStage:
    """Template for one display stage; subclasses supply the work, rendering and hand-off."""

    def __init__(self, **kwargs):
        # kwargs are currently unused.
        self.code, self.complete = None, False

    def run(self, code):
        """Advance this stage by one step.

        Run behavior may depend on external codes - consider whether to make a new stage instead.
        """
        if self.complete:
            self.exit()
        # Stage-specific work belongs here, before the frame is rendered.
        self.render()
        raise NotImplementedError

    def render(self):
        """Draw the stage; must be overridden."""
        raise NotImplementedError

    def exit(self):
        """Set signals and select the next stage; must be overridden."""
        raise NotImplementedError


In [10]:
class DecompStage:
    """Template for one decomposition stage that consumes polled data."""

    def __init__(self, **kwargs):
        # kwargs are currently unused.
        self.complete = False

    def run(self, data):
        """Handle one chunk of polled data; must be overridden.

        A completed stage first hands off via ``exit()``.
        """
        if self.complete:
            self.exit()
        raise NotImplementedError

    def exit(self):
        """Set signals and select the next stage; must be overridden."""
        raise NotImplementedError

In [5]:
class Decomposer:
    """Threshold-triggered online EMG decomposition on top of a polled sample buffer.

    ``filter_buffer`` must provide ``num_chans`` and ``poll()``, which returns either ``None``
    or an array of shape (num_samples, num_chans). Each :meth:`run` polls one chunk, finds
    threshold crossings, cuts a (num_chans, window_pre + window_post) window around each one
    and passes the windows to ``emgde.transform``. Windows in which no firings are found are
    banked; once ``num_crossings_to_add_sources`` have been collected they are passed to
    ``emgde.decompose_batch`` to look for new sources.
    """

    def __init__(self, filter_buffer, emgde=None):
        self.filter_buffer = filter_buffer
        # Not read by the methods of this class.
        self.game_state = np.zeros(3)
        # Seconds of data used by set_threshold() and init_model() respectively.
        self.rest_time = 10
        self.init_time = 60
        self.Fs = 2000
        self.num_chans = filter_buffer.num_chans
        # threshold = per-channel mean + threshold_factor * per-channel std (see set_threshold).
        self.threshold_factor = 6
        # Leading samples of the next chunk to skip during detection (already scanned).
        self.dead_time_remaining = 0
        # Sample counts derived from Fs: 20 ms refractory period, 10 ms / 20 ms window sides.
        self.threshold_cross_deadtime = int(0.020 * self.Fs)
        self.window_pre = int(0.010 * self.Fs)
        self.window_post = int(0.020 * self.Fs)
        self.decomp_params = EmgDecompositionParams(self.Fs)
        self.num_crossings_to_add_sources = 50
        self.crossings_count_without_detection = 0
        self.crossing_windows_without_detection = np.zeros((self.num_chans, 
                                                            self.window_pre + self.window_post, 
                                                            self.num_crossings_to_add_sources))
        self.threshold = None
        if emgde is None:
            self.emgde = EmgDecomposition(self.decomp_params)
        else:
            self.emgde = emgde
            
        # Samples carried over to the front of the next chunk.
        self.decomp_buffer = None
        self.num_firings = 0
        self.num_samples = 0

    def run(self) -> None:
        """Poll one chunk and decompose windows around any threshold crossings in it.

        Crossings too close to the end of the chunk to fit ``window_post`` samples are
        carried over (with ``window_pre`` samples of context) to the next call.
        """
        data = self.filter_buffer.poll()
        if data is None:
            return
        self.num_samples += data.shape[0]

        if self.decomp_buffer is not None:
            data = np.concatenate((self.decomp_buffer, data))
            self.decomp_buffer = None

        num_samples = data.shape[0]
        threshold_idxs = (self.detect_threshold_crossing(data[self.dead_time_remaining:, ...])
                          + self.dead_time_remaining)
        # A window needs window_pre samples before the crossing. On the very first call
        # (dead_time_remaining == 0) a crossing can sit closer to the start than that, which
        # would make create_windows raise; such crossings are skipped.
        threshold_idxs = threshold_idxs[threshold_idxs >= self.window_pre]

        fits = threshold_idxs + self.window_post <= num_samples
        keep = threshold_idxs[fits]
        leave = threshold_idxs[~fits]

        if leave.size > 0:
            # Not enough trailing samples yet: carry the first incomplete crossing (and
            # everything after it) over to the next run, with window_pre samples of context.
            self.decomp_buffer = data[int(leave[0]) - self.window_pre:, ...]
        if keep.size > 0:
            decomp_windows = self.create_windows(data, keep, self.window_pre, self.window_post)
            self._decompose_windows(decomp_windows)
        if self.decomp_buffer is None:
            # Keep context so a crossing at the start of the next chunk still gets a full window.
            self.decomp_buffer = data[-self.window_pre:, ...]
        # The carried-over prefix was already scanned, so skip it in the next detection pass.
        self.dead_time_remaining = self.window_pre

    def _decompose_windows(self, decomp_windows: np.ndarray) -> None:
        """Transform windows with the current model; bank window batches with no firings."""
        firings = self.emgde.transform(np.squeeze(decomp_windows))
        self.num_firings += len(firings)
        if len(firings):
            print(f"{np.unique(np.array([st[0] for st in firings])).size} sources identified with a "
                  f"total of {len(firings)} spikes.")
        else:
            print("No spikes found")
            self._store_undetected_windows(decomp_windows)
        if self.crossings_count_without_detection >= self.num_crossings_to_add_sources:
            self._add_sources()

    def _store_undetected_windows(self, decomp_windows: np.ndarray) -> None:
        """Copy windows into the free slots of the undetected-window bank (overflow is dropped)."""
        count = self.crossings_count_without_detection
        num_new = decomp_windows.shape[-1]
        # Fill only the remaining free slots, each with its own window (previously the
        # overflow count was used and broadcasting replicated one window into every slot).
        num_to_store = min(num_new, self.num_crossings_to_add_sources - count)
        self.crossing_windows_without_detection[..., count:count + num_to_store] = \
            decomp_windows[..., :num_to_store]
        self.crossings_count_without_detection += num_new

    def _add_sources(self) -> None:
        """Run batch decomposition on the banked windows to look for new sources, then reset the bank."""
        print(f"{self.num_crossings_to_add_sources} threshold crossings without sources id'ed, adding new sources!")
        firings = self.emgde.decompose_batch(self.crossing_windows_without_detection)
        self.num_firings += len(firings)
        if len(firings):
            print(f"{np.unique(np.array([st[0] for st in firings])).size} new sources identified with a "
                  f"total of {len(firings)} spikes.")
        else:
            print("No new sources found!")
        self.crossings_count_without_detection = 0

    @staticmethod
    def create_windows(data, idxs, window_pre, window_post):
        """Cut windows around ``idxs`` from (num_samples, num_chans) ``data``.

        Returns:
            Array of shape (num_chans, window_pre + window_post, len(idxs)).

        Raises:
            ValueError: if a window would extend past either end of ``data``.
        """
        num_chans = data.shape[-1]
        windows = np.zeros((num_chans, window_pre + window_post, len(idxs)))
        for window_idx, i in enumerate(idxs):
            start = i - window_pre
            end = i + window_post
            if start < 0 or end > data.shape[0]:
                raise ValueError("Not enough samples!")
            windows[..., window_idx] = data[start:end, ...].T
        return windows

    def _fill_buffer(self, buffer_len: int) -> np.ndarray:
        """Poll ``filter_buffer`` until ``buffer_len`` samples are collected; surplus samples are discarded.

        NOTE(review): loops forever if the source keeps returning ``None`` or empty chunks.
        """
        buffer = np.zeros((buffer_len, self.num_chans))
        write_pointer = 0
        while write_pointer < buffer_len:
            data = self.filter_buffer.poll()
            if data is None:
                continue
            num_new = min(data.shape[0], buffer_len - write_pointer)
            buffer[write_pointer:write_pointer + num_new] = data[:num_new]
            write_pointer += num_new
        return buffer

    def set_threshold(self) -> None:
        """Set per-channel thresholds from ``rest_time`` seconds of data: mean + threshold_factor * std."""
        print('Setting threshold')
        thresholding_buffer = self._fill_buffer(self.rest_time * self.Fs)
        channel_means = np.mean(thresholding_buffer, axis=0)
        channel_stds = np.std(thresholding_buffer, axis=0)
        self.threshold = channel_means + self.threshold_factor * channel_stds
        print('Threshold set!')

    def detect_threshold_crossing(self, data):
        """Return sample indices where any channel exceeds its threshold.

        After each accepted crossing, further crossings within ``threshold_cross_deadtime``
        samples are ignored. Always returns an integer array (empty if nothing crosses).
        """
        time_idxs, _ = np.nonzero(data > self.threshold)
        if time_idxs.size == 0:
            return np.array([], dtype=int)

        alive_idxs = []
        dead_until_idx = -1
        # np.unique already returns the indices sorted.
        for idx in np.unique(time_idxs):
            if idx > dead_until_idx:
                alive_idxs.append(idx)
                dead_until_idx = idx + self.threshold_cross_deadtime
        return np.array(alive_idxs)

    def init_model(self) -> None:
        """Initialize the decomposition model on the first ``init_time`` seconds of data."""
        init_buffer = self._fill_buffer(self.init_time * self.Fs)

        # TODO: Need to synchronously empty the buffer while fitting the model, can set
        #  a gamestate flag and let the display empty it
        print("Initializing model...")
        start = time.time()
        firings = self.emgde.decompose(init_buffer.T)
        print(f"{np.unique(np.array([st[0] for st in firings])).size} sources identified with a "
              f"total of {len(firings)} spikes.")
        print(f"Took {time.time() - start} s")

In [81]:
# `filterbuffer` is produced by the data-loading cell, where its loading lines are commented
# out; on a fresh kernel fail with an actionable message rather than a bare NameError.
if 'filterbuffer' not in globals():
    raise NameError("filterbuffer is not defined: uncomment the filter_stream loading lines "
                    "in the data-loading cell and re-run it")

# NOTE(review): presumably the first 31 columns are the channels the loaded model expects — confirm.
filter_buffer = Buffer(filterbuffer[:, :31])
module = Decomposer(filter_buffer, emgde=emgde)
# module.init_model()  # re-enable to fit a new model instead of using the one loaded from disk
module.set_threshold()

Setting threshold
Threshold set!


In [82]:
# Replay the recording through the decomposer. Once the simulated buffer is exhausted every
# further poll yields no new samples, so stop there; 40000 runs stays as an upper bound.
max_runs = 40000
for _ in range(max_runs):
    if module.filter_buffer.pointer >= module.filter_buffer.data.shape[0]:
        break
    module.run()

2 sources identified with a total of 3 spikes.
No spikes found
1 sources identified with a total of 2 spikes.
2 sources identified with a total of 2 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
4 sources identified with a total of 5 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
3 sources identified with a total of 3 spikes.
No spikes found
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
No spikes found
3 sources identified with a total of 3 spikes.
No spikes found
No spikes found
2 sources identified with a total of 3 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
2 sources identified with a total of 2 spikes.
No spikes found
No spikes found
No spikes found
1 sources identified with a total of 1 

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


No new sources found!
1 sources identified with a total of 1 spikes.
4 sources identified with a total of 5 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
No spikes found
3 sources identified with a total of 3 spikes.
4 sources identified with a total of 4 spikes.
No spikes found
6 sources identified with a total of 7 spikes.
5 sources identified with a total of 6 spikes.
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
3 sources identified with a total of 4 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
2 sources identified with a total of 2 spikes.
No spikes found
2 sources identified with a total of 3 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
No spikes found
No spikes found
1 sources identified with a total of 1 spikes.
1 sources identified with a total of 1 spikes.
1 sources identified with a total of 1 spikes.
3 sources identified with a total

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


No new sources found!
1 sources identified with a total of 1 spikes.
No spikes found
No spikes found
2 sources identified with a total of 2 spikes.
No spikes found
3 sources identified with a total of 3 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
2 sources identified with a total of 2 spikes.
2 sources identified with a total of 2 spikes.
No spikes found
2 sources identified with a total of 2 spikes.
2 sources identified with a total of 2 spikes.
2 sources identified with a total of 2 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
2 sources identified with a total of 2 spikes.
2 sources identified with a total of 3 spikes.
No spikes found
No spikes found
No spikes found
No spikes found
1 sources identified with a total of 1 spikes.
2 sources identified with a total of 2 spikes.
2

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


No new sources found!
No spikes found
2 sources identified with a total of 2 spikes.
No spikes found
3 sources identified with a total of 4 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
3 sources identified with a total of 3 spikes.
No spikes found
No spikes found
3 sources identified with a total of 3 spikes.
1 sources identified with a total of 1 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
2 sources identified with a total of 3 spikes.
2 sources identified with a total of 2 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
1 sources identified with a total of 1 spikes.
No spikes found
No spikes found
No spikes found
No spikes found
No spikes found
2 sources identified with a total of 2 spikes.
No spikes found
2 sources identified with a total of 2 spikes.
No spikes found
No spikes found
No spikes found
2 sources identi

In [83]:
# Shape of the model's raw source array. `_raw_sources` is private by naming convention,
# so this may break across emgdecomp versions.
emgde._raw_sources.shape

(496, 7)

In [86]:
# Total spikes assigned during the replay, and the replayed duration in seconds.
elapsed_s = module.num_samples / module.Fs
print(module.num_firings, elapsed_s)

338 94.819
