# Tests



## Set-up

In [3]:
# Set-up: third-party imports plus three local libraries (plots, emg_lib,
# decomposition_lib). Each local library is made importable by appending its
# folder to sys.path. Statement order matters here: every sys.path.append must
# run before the import it enables.
# NOTE(review): PCA, butter, filtfilt and iirnotch are not used in the cells
# visible here; possibly used elsewhere in the notebook, so confirm before removing.
import numpy as np
from sklearn.decomposition import PCA
import sys
import os

# Base directory for the local library folders. "__file__" is a string literal
# here, not the module variable (which notebooks do not define). abspath() resolves
# it against the current working directory, so dirname() yields the current working
# directory itself. Start the notebook from the folder that contains plots/,
# emg_lib/ and decomposition_lib/.
current_dir = os.path.dirname(os.path.abspath("__file__"))

# Make the local 'plots' library importable, then import it
sys.path.append(os.path.join(current_dir, 'plots'))
import plots
# Make the local 'emg_lib' library importable (emgl.filter_emg is called in the
# decomposition cell)
sys.path.append(os.path.join(current_dir, 'emg_lib'))
import emg_lib as emgl
import matplotlib.pyplot as plt


# Make the local 'decomposition_lib' library importable (provides
# dl.decomposition_offline)
sys.path.append(os.path.join(current_dir, 'decomposition_lib'))
import decomposition_lib as dl
import scipy.io 
from scipy.signal import butter, filtfilt, iirnotch



## Offline decomposition example

In [None]:
# Offline decomposition of the sample HD-EMG recording.
# Loads 'EMG' from the .mat file, filters it with emgl.filter_emg, drops 4 s at
# each end, and runs dl.decomposition_offline. The results are stored in
# `sources` / `decompParams` for the plot cell.

fs = 2048  # sampling rate in samples per second (the plot cell divides sample indices by fs to get seconds)
exfactor = 16  # passed to dl.decomposition_offline; presumably the signal extension factor, so confirm there
nbIterations = 20  # number of decomposition iterations passed to dl.decomposition_offline

# Build the path portably. The previous literal 'sample_data\hdemg_SOL_sample.mat'
# only worked on Windows, and '\h' is an invalid escape sequence (SyntaxWarning
# on Python 3.12+).
mat_path = os.path.join('sample_data', 'hdemg_SOL_sample.mat')
mat_file = scipy.io.loadmat(mat_path)

if 'EMG' not in mat_file:
    raise ValueError(f"Variable 'EMG' not found in the MATLAB file {mat_path!r}.")

# Transposed before filtering. Presumably this puts channels on axis 0 and
# samples on axis 1 (the trim below slices axis 1 in samples); confirm against
# emgl.filter_emg.
emg = np.array(mat_file['EMG']).T

# emg_mask and emg_envelope are not used in this cell.
emg_filt, emg_mask, emg_envelope = emgl.filter_emg(emg, fs)

# Discard 4 s of samples at each end of the filtered signal. Fail loudly instead
# of silently decomposing an empty array when the recording is too short.
trim_samples = 4 * fs
if emg_filt.shape[1] <= 2 * trim_samples:
    raise ValueError(
        f"Recording too short to trim {trim_samples} samples from each end "
        f"(only {emg_filt.shape[1]} samples after filtering)."
    )
emg_filt = emg_filt[:, trim_samples:-trim_samples]

sources, decompParams = dl.decomposition_offline(
    emg_filt, fs, exfactor, nbIterations,
    preOptFilters=None, refineStrategy='SIL',
    showPlots=True, h=None, ax=None, peeloff_flag=False,
    removeDuplicates=True, qc_threshold=0.85,
)

Decomposition:   0%|          | 0/20 [00:00<?, ?it/s]

Figure closed!


## Plot example

In [31]:
# Plot the decomposition results: a spike-train raster on the left, and on the
# right each motor unit's instantaneous discharge rate, labelled with its mean rate.
dischargeRates = sources['dischargeRates']
spikeTrains = sources['spikeTrains']
PulseT = sources['PulseT']

# Rows are motor units (MUs) and columns are samples. Assumes spikeTrains and
# dischargeRates share PulseT's shape; TODO confirm against decomposition_lib.
nMUs, nSamples = PulseT.shape
if nMUs == 0:
    # Without this guard np.nanmax below fails with an opaque zero-size-array error.
    raise ValueError("Decomposition returned no motor units; nothing to plot.")

# Rescale so the largest rate across all MUs maps to 40 (display scaling only).
# The plot divides by 40 again, so every trace is scaled by the same global peak.
normalized_dischargeRates = dischargeRates / np.nanmax(dischargeRates) * 40

# Per-MU mean discharge rate, ignoring NaNs. Used as the right-hand y-tick labels.
meanDR = np.nanmean(dischargeRates, axis=1)

# Time axis in seconds
tVec = np.arange(nSamples) / fs

# One viridis colour per MU, shared by both panels
colors = plt.cm.viridis(np.linspace(0, 1, nMUs))

fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=False)

for i, color in enumerate(colors):
    row = i + 1  # MU i is drawn around y = i + 1
    # Raster: spike values scaled by 0.8 and shifted to straddle the row
    # (presumably binary trains, giving row - 0.4 / row + 0.4).
    ax[0].plot(tVec, 0.8 * spikeTrains[i, :] + row - 0.4, color=color)
    # Discharge rates (peak-normalised to 1 via the /40) plotted as dots above row - 0.5
    ax[1].plot(tVec, normalized_dischargeRates[i, :] / 40 + row - 0.5,
               '.', markersize=10, color=color)

mu_rows = np.arange(1, nMUs + 1)

ax[0].set_yticks(mu_rows)
ax[0].set_yticklabels([str(r) for r in mu_rows])

ax[1].set_yticks(mu_rows)
ax[1].set_yticklabels([f'{dr:.1f}' for dr in meanDR])

ax[0].set_title("Spike Trains (Raster Plot)")
ax[0].set_ylabel("Motor Units")
ax[0].set_xlabel("Time (s)")

ax[1].set_title("Normalized Instantaneous Discharge Rates")
# The tick labels are nanmean values, so label them as means (was "Median").
ax[1].set_ylabel("Mean DR (Hz)")
ax[1].set_xlabel("Time (s)")

# Render the figure; plt.pause briefly runs the GUI event loop so it appears.
plt.draw()
plt.pause(0.1)