# EMG Pipeline #

This notebook demonstrates the pipeline for EMG source inversion.

# Imports #
The dependencies are listed in `requirements.txt`. Activate the virtual environment `.venv`, select it as the Python kernel, and make sure the requirements are installed in that environment with `pip install -r requirements.txt`.

In [None]:
from scipy.constants import epsilon_0
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import kv as K0, iv as I0
from scipy.linalg import solve, pinv, svd

from EMGinv_fns import *

from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull, Delaunay
from scipy.optimize import minimize

import pandas as pd
import scipy.io
import mat73
from pathlib import Path
import shutil
import os.path as op
import mne

import gc
import matplotlib.animation as animation
import pydicom

# Remove magic commands if turning into a script
import pyvistaqt
%matplotlib qt
%load_ext autoreload
%autoreload 2

In [None]:
# Loading source space and forward model generated in MNE-python, see Fwd_BEM_MNE.ipynb

# Load electrode positions
electrode_pos = np.load('Data/256simparm_electrode_pos.npy')

# Load forward model (gain matrix: one row per electrode channel) and the source positions
mne_fwd = mne.read_forward_solution('Data/simp_arm_2mm-fwd.fif')
fwd = mne_fwd['sol']['data']
pos = mne_fwd['source_rr']

# Distance between neighbouring dipoles in the source space (2 mm grid, matching simp_arm_2mm-fwd.fif)
xscaling, yscaling, zscaling = (2e-3, 2e-3, 2e-3)

# Transformations specific to dataset
# Need to remove some electrodes - get rid of the middle ones: keep rows [0, 64) and [192, end)
KEEP_FIRST, KEEP_FROM = 64, 192
fwd = np.vstack((fwd[:KEEP_FIRST, :], fwd[KEEP_FROM:, :]))
electrode_pos = np.vstack((electrode_pos[:KEEP_FIRST, :], electrode_pos[KEEP_FROM:, :]))
# Every forward-model row must line up with exactly one electrode position
assert fwd.shape[0] == electrode_pos.shape[0], (fwd.shape, electrode_pos.shape)

# If Blender defaults were used may need to adjust axes = ["X", "Z", "Y"], however this should not be necessary otherwise
# electrode_pos = electrode_pos[:,[0,2,1]] , # pos = pos[:,[0,2,1]]

# Removal of dipole sources - here remove sources lying within cylinders representing bone
# NOTE(review): the numeric arguments look like (centre_x, centre_y, radius) in metres - confirm against bone_remover
pos, fwd = bone_remover(pos, fwd, -9e-3, 0, 5e-3)  # Ulna
pos, fwd = bone_remover(pos, fwd, 5e-3, -9e-3, 5e-3)  # Radius

# Consider fwd model adjustments
# Condense the fwd such that there is only one dipole per voxel (orientation given by dipole_ori)
dipole_ori = [0, 0, 1]
fwd = fwd_convertfixed(fwd, dipole_ori)

In [46]:
# Sanity check: (number of retained electrode channels, number of source columns)
np.shape(fwd)

(128, 20654)

In [30]:
# Alternative: Loading source space and forward model generated from (MRI) image data
# Disabled by default - uncomment to use instead of the MNE-python cell above. Running it overwrites
# electrode_pos, pos, fwd and xscaling/yscaling/zscaling.
# NOTE(review): this path does not apply the middle-electrode removal or bone_remover steps used in the
# MNE-python cell - confirm whether that is intended.
# NOTE(review): 'electrode_pos.npy' is read from the working directory while the forward model is read
# from 'Data/' - confirm the file location.
# # Load electrode positions
# electrode_pos = np.load('electrode_pos.npy')

# # Load source space 
# xscaling, yscaling, zscaling = 1.5e-4, 1.5e-4, 0.5e-2
# pos = load_src_template(filename=None, xscaling=xscaling, yscaling=yscaling, zscaling=zscaling)

# # Load forward model
# fwd = np.load('Data/fwd_dipole.npy')
# # Alternative would be to generate with
# # fwd = fwd_generator(dipole_potential, pos, electrode_positions)
# # NOTE(review): `electrode_positions` is not defined in this notebook; presumably `electrode_pos` - confirm.
# # And then saving with np.save('fwd_dipole.npy', fwd) - Very big file

# # Consider fwd model adjustments
# # Condense the fwd such that there is only one dipole per voxel
# dipole_ori = [0, 0, 1]
# fwd = fwd_convertfixed(fwd, dipole_ori )

# # Adjust the fwd to remove some outliers
# # fwd[fwd>1e15] = 1e15
# # fwd = np.log10(fwd+1)  # this doesn't preserve sign
# fwd[np.isnan(fwd)] = 0 # for any nans

In [None]:
# Load muap template waveforms (MATLAB export; not used further in this notebook)
# filename = '/Users/pokhims/Library/CloudStorage/OneDrive-TheUniversityofMelbourne/Documents/Coding/CMU_EMGSL/Data/muaps_frommatlab.mat'
# muaps_1 = scipy.io.loadmat(filename)['muaps']

# Load some data - to construct covariance matrices.
# Earlier session, kept for reference:
# MNE_raw = load_tmsitomne_combine(f_prox = 'Data/Pok_2024_08_21_A_PROX_8.poly5', f_dist='Data/Pok_2024_08_21_B_DIST_8.poly5', scale=1e-6)
# Session location and recording prefix; the PROX (A) and DIST (B) files share both.
TMSI_DIR = 'C:/Data/MetaWB/MCP04_2024_09_17/TMSi'
RECORDING = 'MCP04_2024_09_17'
MNE_raw = load_tmsitomne_combine(
    f_prox=f'{TMSI_DIR}/{RECORDING}_A_PROX_3.poly5',
    f_dist=f'{TMSI_DIR}/{RECORDING}_B_DIST_3.poly5',
    scale=1e-6,
)

In [None]:
# Important to filter the data (highpass filter).
# NOTE: Raw.filter works in place and returns the same object, so re-running this cell
# without reloading the raw data applies the high-pass a second time.
MNE_raw = MNE_raw.filter(l_freq=100, h_freq=None)

# Estimate the noise covariance matrix on Epoched data. This means that the noise covariance
# will be estimated on pre-stimulus periods.
# NOTE(review): 252 appears to be a constant offset on the trigger channel - TODO confirm.
channel_data = MNE_raw['Prox - TRIGGERS'][0][0] - 252
events = tmsi_eventextractor(channel_data)
event_dict = {'Ext': -2, 'Flex': -6}  # Should be correct
epochs = mne.Epochs(MNE_raw, events, event_dict, tmin=-4, tmax=4, baseline=None, preload=True)
# Consider setting an average EEG reference across each panel of 32.
# epochs.set_eeg_reference('average', projection=True)
# epochs.apply_proj()
# epochs.plot(n_epochs=1, scalings='auto', );

# Set tmin and tmax based on experimental conditions. In this case, the participant was not
# moving perfectly to the triggers.
noise_cov = mne.compute_covariance(epochs, method='auto', tmin=-2, tmax=0.01)
data_cov = mne.compute_covariance(epochs, method='auto', tmin=1, tmax=epochs.tmax)

noise_cov.plot(epochs.info);
data_cov.plot(epochs.info);

# Alternative to muaps_from matlab.mat - Load the mask, and then extract from the relevant EMG channels in MNE_raw
filename = '/Users/pokhims/Library/CloudStorage/OneDrive-TheUniversityofMelbourne/Documents/Coding/CMU_EMGSL/Data/muaps_mask.mat'
mask = scipy.io.loadmat(filename)['mask']
# Not sure why some of the masks aren't in the data - and why the end result is so different.
# NOTE(review): loadmat returns the MATLAB values unchanged; if `mask` holds 1-based MATLAB sample
# indices they need `- 1` before being used as NumPy indices below - TODO confirm.
mask = mask[:-2, :]
data = MNE_raw.get_data(picks='data')
# Average the masked snippets of each channel into one template per channel.
# Template length follows the mask width instead of a hard-coded 41.
n_channels = 128  # must match the number of rows in fwd
muaps = np.zeros((n_channels, mask.shape[1]))
for i in range(n_channels):
    tmp = data[i, :]
    snips = tmp[mask]
    muaps[i, :] = np.mean(snips, axis=0)

# Another alternative to the data: fixed sample windows from the first epoch.
# Fetch the epoch array once and copy the windows out. Plain slices would be views that keep the
# whole epoch array alive even after `epochs` is deleted and gc.collect() runs.
epochs_data = epochs.get_data(picks='data')
ext_1 = epochs_data[0, :, 25030:25070].copy()
flex_1 = epochs_data[0, :, 26670:26710].copy()
del epochs_data

In [33]:
# Optional: interactively browse the first epoch (disabled; uncomment to use).
# Must be run before the cell that deletes `epochs`.
# epochs.plot(n_epochs=1, scalings='auto',);

In [None]:
# Release the large intermediates (raw recording, epochs, per-channel scratch arrays)
# now that the covariances and waveform templates have been extracted.
del MNE_raw, epochs
del channel_data, data
del tmp, snips

gc.collect()

In [None]:
# Beamformer - for online processing - Note max_power is only for use when dipole orientation is unknown

w_lcmv = lcmv_beamformer_constructor(fwd, data_cov=data_cov.data, noise_cov=noise_cov.data, pos=pos, arr_gain=True, max_power=False)

# Memory held by the weight matrix. nbytes is already in bytes (not bits), so no division by 8.
# This also avoids relying on `sys` being re-exported by the star import.
print(np.asarray(w_lcmv).nbytes / 1024 / 1024, 'MB')

In [None]:
# Choose the waveform to use - muaps, ext_1, flex_1 (each is channels x samples)
waveform = ext_1

# Apply beamformer: (source components x channels) @ (channels x samples) -> source activity over time
source_activity_time = np.dot(w_lcmv, waveform)

# Memory held by the result. nbytes is already in bytes (not bits), so no division by 8.
print(np.asarray(source_activity_time).nbytes / 1024 / 1024, 'MB')

In [37]:
# Alternative algorithms, can consider looking at minimum norm estimate:
# All disabled. Uncommenting the minimum norm or sLORETA line overwrites source_activity_time
# computed by the beamformer cell above.

# Minimum norm estimate
# source_activity_time = minimum_norm_estimate(fwd, waveform, noise_cov=noise_cov.data, reg=0.1)

# sLoreta
# NOTE(review): unlike the minimum norm line this passes the Covariance object `noise_cov` rather than
# `noise_cov.data`, and `N` is only defined in the ECD block below - check both before enabling.
# source_activity_time = sloreta(fwd[:N,:], waveform, noise_cov=noise_cov, reg=0.1)

# Different optimisers for ECD (Equivalent Current Dipole) fitting - Which did not work too well
# n_dipoles = 1
# data = muaps[:, 20]
# N=data.shape[0]
# optimal_dipoles = ECD_fit_dipoles(data, fwd[:N,:]/np.linalg.norm(fwd[:N,:]), pos, n_dipoles, initial_guess=np.array([0.05, 0.09, 0.03, 1, 1, 1]), local=True) # 
# optimal_dipoles = ECD_fit_dipoles_analytical(data, electrode_pos[:N,:], n_dipoles, local=True)
# print("Optimal dipole parameters (position, strengths):", optimal_dipoles)

# Best dipole in fwd (its best_index is also what the disabled 3D animation cell expects)
# best_index, best_weights, save_arr = best_dipole_infwd(fwd, muaps)
# Pos as obtained from:
# pos[best_index[t],0], pos[best_index[t],1], pos[best_index[t],2]

In [38]:
# Overlay every channel of the selected waveform (rows of `waveform` are channels, columns are samples)
fig_wave, ax_wave = plt.subplots()
ax_wave.plot(waveform.T)
ax_wave.set(title='Selected EMG waveform (all channels)', xlabel='Sample', ylabel='Amplitude');

# Visualisations

In [None]:
# Fraction of the peak source magnitude below which voxels are hidden
thresh = 0.5
# Look at specific timepoint in the source activity - 20 for matlab template waveform, 30 for other one; ext_1 - 5s and 15s are interesting; flex_1 - 9s
t = 15

# Collapse the orientation components of each voxel into one magnitude.
# Column-major reshape gives one column per voxel and one row per orientation
# (confirmed for 3 orientations, should work for more).
n_orient = source_activity_time.shape[0] // pos.shape[0]
per_voxel = np.asarray(source_activity_time[:, t].reshape((n_orient, -1), order='F'))
source_activity = np.linalg.norm(per_voxel, axis=0)

# Keep only voxels whose magnitude exceeds thresh * peak
keep = np.abs(source_activity) > thresh * np.max(np.abs(source_activity))
source_activity, pos_kept = source_activity[keep], pos[keep]

# 3D view: electrodes coloured by the waveform at sample t, kept sources coloured by magnitude
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(electrode_pos[:, 0], electrode_pos[:, 1], electrode_pos[:, 2],
           c=waveform[:, t], marker='o', cmap='turbo')
ax.scatter(pos_kept[:, 0], pos_kept[:, 1], pos_kept[:, 2],
           c=source_activity, marker='o', alpha=0.8, cmap='viridis')
ax.set(xlabel='X Axis (m)', ylabel='Y Axis (m)', zlabel='Z Axis (m)')

print('Upside should be facing towards us.  i.e. Palmside in positive y-direction.  Electrodes go along x dimension.  Z dimension is proximal to distal.')

In [50]:
# Re-select the raw (un-thresholded) source activity at timepoint t.
# NOTE(review): assumes one component per voxel (fixed orientation) so this aligns with pos - confirm for free orientation.
source_activity = source_activity_time[:, t]

# Faster to reconfigure the scatter points to be in a grid, and then use imshow to plot the activity.
grid = pos_to_3Dgrid_converter(pos, source_activity, (xscaling, yscaling, zscaling))

# Bounding box of the source space, used as imshow extents in the plotting cells
(x_min, y_min, z_min), (x_max, y_max, z_max) = pos[:, :3].min(axis=0), pos[:, :3].max(axis=0)

In [None]:
# Dimensions of the voxel grid; the last axis indexes the slices plotted below
np.shape(grid)

In [None]:
# Load the reference arm MRI here so this shape check does not depend on the later slice-plot
# cell having run (fresh kernel / Restart & Run All). dcmread is the current pydicom reader.
# NOTE(review): absolute local path - move to a configurable data directory.
ds = pydicom.dcmread('/Users/pokhims/Library/CloudStorage/OneDrive-TheUniversityofMelbourne/Documents/Coding/CMU_EMGSL/Data/R_Forearm.dcm')
arm_image = ds.pixel_array
arm_image.shape

In [None]:
# Plot a slice of the source estimate
z = 10

# Display the source estimate and an arm for reference
plt.figure(figsize=(12, 6))

# Load arm MRI file (pydicom.dcmread; the older read_file alias was deprecated and removed in pydicom 3.0)
ds = pydicom.dcmread('/Users/pokhims/Library/CloudStorage/OneDrive-TheUniversityofMelbourne/Documents/Coding/CMU_EMGSL/Data/R_Forearm.dcm')
arm_image = ds.pixel_array

plt.subplot(121)
# Map source-grid slice z onto the MRI's first axis; the row/column crop and final index 2 are hard-coded for this file
mri_slice = z * arm_image.shape[0] // grid.shape[2]
# NOTE(review): extent puts y on the horizontal axis, while the electrode scatter and the axis labels
# use x horizontally - confirm the overlay is not transposed.
plt.imshow(arm_image[mri_slice, 100:500, 330:670, 2], extent=[y_min, y_max, x_min, x_max], cmap='gray', origin='upper')
plt.title('MRI for visualisation of orientation only'), plt.axis('on')
print('MRI for visualisation, does not represent patient nor model arm and needs to be flipped betwen left and right arm.')
plt.scatter(electrode_pos[:, 0], electrode_pos[:, 1], c=waveform[:, t], cmap='turbo', marker='o')
plt.colorbar(label='Electrode Activity')
plt.xlabel('X Axis (m)'), plt.ylabel('Y Axis (m)')

plt.subplot(122)
plt.imshow(grid[:, :, z], extent=[y_min, y_max, x_min, x_max], origin='upper', cmap='viridis')
plt.colorbar(label='Source Activity')
plt.scatter(electrode_pos[:, 0], electrode_pos[:, 1], c=waveform[:, t], cmap='turbo', marker='o')
plt.xlabel('X Axis (m)'), plt.ylabel('Y Axis (m)')

plt.title('Source Estimate'), plt.axis('on')
plt.show()

In [42]:
# Plot all the slices
# This is hard to use when there are too many slices!  May need to change the matplotlib viewer if too many slices
# %matplotlib inline

num_slices = grid.shape[2]
# Shared colour scale so slices are comparable; computed once instead of on every iteration.
# NOTE(review): vmin=0 hides negative values if grid holds signed source activity - confirm intended.
grid_vmax = np.nanmax(grid)

# Create subplots. squeeze=False keeps `axes` 2-D even when there is only one slice.
fig, axes = plt.subplots(num_slices, 1, figsize=(8, 200), squeeze=False)

# Plot each slice
for z in range(num_slices):
    ax_z = axes[z, 0]
    # Plot activity
    ax_z.imshow(grid[:, :, z], extent=[y_min, y_max, x_min, x_max], origin='upper', cmap='viridis', vmin=0, vmax=grid_vmax)
    # Plot electrode position
    ax_z.scatter(electrode_pos[:, 0], electrode_pos[:, 1], c=waveform[:, t], cmap='turbo', marker='o')
    # Offset from the first slice in mm; round() avoids int() truncating e.g. 13.999... down to 13
    ax_z.set_title(f'{int(round(1000 * z * zscaling))}mm')
    ax_z.axis('off')

# fig.colorbar(axes[0].imshow(labels[:, :, 0], cmap='viridis'), ax=axes, orientation='vertical')

# Display the plot
# plt.tight_layout()
plt.show()

In [43]:
# # If want to do 3D animated plot
# Disabled sketch: animated 3D scatter of the electrodes, recoloured with the muaps template at each
# timepoint (one frame per column of muaps).
# NOTE(review): update_plot sets the title on the global `ax`, and `pos_move` needs `best_index` from the
# disabled best_dipole_infwd call in the alternative-algorithms cell - run that first if re-enabling.

# def update_plot(i, data, scat, ):
#     scat.set_array(data[i])
#     ax.set_title(f'Timepoint {i}')
#     return scat,

# numframes = muaps.shape[1]
# colour_data = muaps.T
# pos_move = pos[best_index,:]

# # Plot 3D image of electrode positions and select dipole positions

# fig = plt.figure(figsize=(10, 10))
# ax = fig.add_subplot(111, projection='3d')

# # Plot the electrode positions
# p = ax.scatter(electrode_pos[:, 0], electrode_pos[:, 1], electrode_pos[:, 2], c=muaps[:,0], marker='o', ) #vmin=muaps.min(), vmax=muaps.max())
# # fig.colorbar(p, ax=ax)
# # Colourbar and vmin and vmax may not be useful as only one electrode may seem to change colour
# # p2 = ax.scatter(pos[best_index[:],0], pos[best_index[:],1], pos[best_index[:],2], c='r', marker='x', s=50, )

# # Set labels
# ax.set_xlabel('X (m)')
# ax.set_ylabel('Y (m)')
# ax.set_zlabel('Z (m)')

# ani = animation.FuncAnimation(fig, update_plot, frames=range(numframes),
#                                 fargs=(colour_data, p, ))
# plt.show()