In [None]:
from brian2 import *
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import butter, filtfilt, windows
from scipy.stats import linregress
%matplotlib inline
import scipy.io
from scipy.signal import resample, freqz
from scipy.interpolate import interp1d
from scipy.signal import remez, firwin, lfilter
from factor_analyzer import FactorAnalyzer
from sklearn.decomposition import PCA
import random
import matplotlib.cm as cm
import os
import pandas as pd
from scipy.signal import csd, detrend
import pickle
from scipy.fft import fft, fftfreq
import sys
import scipy.ndimage

In [None]:
### PARAMETERS ###
start_scope() # Re-initialize Brian
perform_coherence_analysis = False # Takes quite a long time to run if True
# NOTE(review): a previous note said "If False, will return an error at some point"; the code that
# reads this flag is not in this section - confirm whether False is actually supported.
plt.style.use('_classic_test_patch') # Plotting style

# Set parameters ############################
sim_name = 'sim_test_VL_VM_VI' # also used as the name of the output directory (see SAVE PARAMETERS)
muscle_names = ["VL","VM","VI"]
muscle_colors = ['C0','C1','C2']
muscle_colormaps = ['viridis','plasma','summer']
nb_of_pools_to_simulate = 3 # one MN pool per muscle; the input-loading cell treats pool index 2 (VI) specially

# TIME PARAMETERS
fsamp = 1000  # sampling rate (Hz) of the generated input signals # This is NOT the dt at which the simulation runs. The simulation timesteps are 0.1ms in duration by default
window_beginning_ignore = 1 # in s
window_end_ignore = 1 # in s
true_duration = 20 # in s ; analysed duration, excluding the ignored windows
duration_with_ignored_window = (true_duration+window_beginning_ignore+window_end_ignore)*second
ISI_threshold_for_discontinuity = 0.4*second # in s ; motoneurons whose max(ISI)>threshold will be removed from analysis (so only continuous MNs are kept)
# VOLTAGE THRESHOLDS OF ALL NEURONS
voltage_rest = 0 # arbitrary units ; 0 at rest
voltage_thresh = 1 # arbitrary units ; reaching 1 generates a spike
# NUMBER OF NEURONS SIMULATED - PER POOL
nb_motoneurons = 50 #100 # Mean MNs in experimental data = ~48 for VL, ~43 for GM
nb_renshaw = 10 # Ratio of 377 MNs to 64 RCs in Williams & Baker 2009 (64/377) =~17%

### MOTONEURON PROPERTIES ACCORDING TO THEIR SIZES ##########################
# Everything calculated from Caillet et al 2022 https://elifesciences.org/articles/76489
# Assuming that soma diameter from human motoneurons vary between 50 and 100 micrometers, based on https://journals.physiology.org/doi/full/10.1152/physiol.00021.2018
    # ^ "Scaling of motoneurons, From Mouse to Human" Manuel et al. Physiology (2018)
min_soma_diameter = 50 # in micrometers, for smallest MN
max_soma_diameter = 100 # in micrometers, for largest MN
# Select the range of MNs simulated (0 being smallest, 100 being largest).
# Soma diameters are linearly interpolated between min_soma_diameter and max_soma_diameter,
# e.g. boundaries 0-20 simulate MNs with soma diameters between 50 and 60 micrometers.
min_normalized_boundary = 0 # Normalized between 0 and 100
max_normalized_boundary = 20 # Normalized between 0 and 100 - use a value below 100 if not simulating very fast MNs #20 to go up to 20% MVC
#### Time constant (tau) = tau_constant * soma_diameter**(-tau_exponent) #####
tau_constant = 2.6*(10**4) # Caillet et al 2022
tau_exponent = 1.5 # Caillet et al 2022
    # https://www.desmos.com/calculator/bfhcpgiltr = visualize the curve
    # Smaller MNs have LONGER time constants:
    # Time constant of the smallest MN (50 micrometers) = ~74ms (maximum)
    # Time constant of the biggest MN (100 micrometers) = ~26ms (minimum)
    # From Williams & Baker 2009 = "These decay times correspond with exponential time constants of 24 –26 ms, in keeping with previous models of motoneurons (Matthews, 1997)"
    # From Vertebrate Motoneurons 2022:
        # 3-15ms in cat ; 2-20ms in mouse
        # "Consequently, several labs use the half-decay time of AHP to distinguish between F-fast and S-slow rat motoneurons (F > 20 ms, S < 20 ms, (Gardiner 1993)) but the difference between the FR and FF motoneurons is not well defined."
    # From "Principles of Neural Science":
        # "Typical values of τ for neurons range from 20 to 50ms" (but not motor neurons specifically)
    # From the same book: "Cell membrane time constant; the product of resistance and capacitance of the membrane (typical values 1–20 ms). tau = Rm ⋅ Cm"
    # Maltenfort, Heckman 1998 simulation study = 2.5-10ms time constant for motoneurons
#### Input weight = normalized resistance (resistance_constant * soma_diameter**(-resistance_exponent)),
#### divided by the resistance of the smallest SIMULATED MN, so that the input to that MN is scaled by a factor of 1 #####
resistance_constant = 9.6*(10**5) # Caillet et al 2022
resistance_exponent = 2.4 # Caillet et al 2022
    # https://www.desmos.com/calculator/pbs97zynff = visualize the curve for resistance (ohms) and input weights (between 0 and 1)
    # Input weight of the smallest MN (50 micrometers) = 1 (maximum)
    # Input weight of the biggest MN (100 micrometers) = ~0.19 (minimum)
    # With boundaries 0-20 (50-60 micrometers), the simulated weights range from 1 down to ~0.65
#### Refractory period, not dependent on MN size (Caillet's paper gives equations for AHP duration but not for refractory period duration) #####
refractory_period_MN = 5*ms
    # Manuel et al. 2019 "Scaling of motor output, from Mouse to Humans"
        # "Statistical methods employed at low firing rates indicate the AHP durations of low-threshold human motoneurons, presumably type S and perhaps some type FR, are ~125–140 ms."
    # Herbert & Gandevia 1999 assume a 5ms (absolute?) refractory period
    # Lateva et al 2001 = Absolute refractory period of 3ms in muscle fibers, and relative refractory period of 10ms
    # University of Washington textbook of physiology = in a typical neuron, the absolute refractory period lasts a few ms and the relative period tens of ms

### RENSHAW CELL PROPERTIES ##########################
tau_Renshaw = 8*ms # time constant of Renshaw cells
# Williams & Baker 2009: "time constant of 8ms, similar to experimental data (Desilligny, 1979; Hultborn et al., 1979)"
# Maltenfort, Heckman 1998 simulation study = 8ms time constant and 30ms AHP
refractory_period_RC = 36*ms # refractory period of Renshaw cells => Williams & Baker 2009: "AHP of 36ms similar to experimental data (Deseilligny, 1979; Hultborn et al., 1979)"
# Firing rates of RCs can be expected to go up to >60pps (Moore et al 2015)
# Maltenfort, Heckman 1998 = the maximum steady-state firing rate of Renshaw cells is 200 pps (Cleveland et al. 1981)
#   NOTE(review): the copied quote was garbled (it mixed in an unrelated sentence about dendrites) - re-check against the paper

### CONNECTIVITY BETWEEN MOTONEURONS AND RENSHAW CELLS ##########################
MN_to_Renshaw_connectivity_probability_within_pool = 0.2 #0.2 # 0.5 # 0.25 #if 0.1, each Renshaw cell will receive excitatory input from a random subset of 10% of the homonymous MN pool
Renshaw_to_MN_connectivity_probability_within_pool = 0.5 #0.6 # 0.25 # 0.5 #if 0.3, each Renshaw cell will send inhibitory input to a random subset of 30% of the homonymous MN pool
    # Williams & Baker 2009 = " each motoneuron receives input from 10-20 Renshaw cells, and each Renshaw cell receives input from 20-50 motoneurons."
        # => Motoneurons receive input from [10 to 20]/64 Renshaw cells (0.16 to 0.31), each Renshaw cell receives input from [20-50]/377 MNs  (0.05 to 0.13)
    # From Moore et al 2015: a typical Renshaw cell receives input from ~6-7 motoneurons, and that a Renshaw cell projects back to ~40 motoneurons (so 40/6 ratio )
        # NOTE(review): the quoted text below looks partly garbled by copy-paste (e.g. "is indeed 6") - check the original paper
        # "The results of the present study indicate that the number of contacts from a motoneuron (7.1 +- 1.2) is indeed 6. (...)
        # (...) however, is less than the extrapolated count reported previously (Alvarez et al., 1999). (...)
        # Our estimates of both the number of contacts from all motoneurons and of the convergence of 4 motoneurons contacting individual Renshaw cells can therefore only represent lower bounds.
        # Previous paired recordings (Bhumbra et al., 2014) of the Renshaw cell to motoneuron synapse report an average of 5.5+-0.5 for the number of contacts.
        # The results from the ventral root stimulation experiments of the present study yields an average of 225 release sites, suggesting a convergence quotient of 40 Renshaw cells per motoneuron.
        # These estimates again represent lower bounds because of the slice preparation.
        # Within the limits highlighted above, our data suggest that the degree of convergence for the inhibitory projection may be as much as 10 times greater than that of the excitatory projection."
    # => Ratio of MN->Renshaw and Renshaw->MN comprised between 1/3 and 1/10
    # Edgley, Williams, Baker 2021 = proportion is very different across muscles anyway (primate upper limb)
    # Maltenfort, Heckman 1998 =
        # Simulated probability of connectivity between RCs and MNs according to distance. Max distance (arbitrary) of 15 for connections from RC to MNs,
        # and max distance of 2 for connections from MNs to RC
        # "Each simulated motoneuron therefore synapsed on five Renshaw cells (...) each Renshaw cell could receive input from 20 motoneurons"
        # ^ 256 MNs simulated, so ~0.1 ratio
MN_to_Renshaw_connectivity_probability_across_pool = 0 #0.1 #if 0.1, each Renshaw cell will receive excitatory input from a random subset of 10% of the MNs from the other pools
Renshaw_to_MN_connectivity_probability_across_pool = 0.2 #0.3 #if 0.3, each Renshaw cell will send inhibitory input to a random subset of 30% of the MNs from the other pools
equal_MN_to_RC_connectivity_for_all_MN = False # if true, will apply a softmax function to each row of the MN to RC connectivity matrix, so that it sums to 1.
# If true, the connectivity matrix will no longer be binary
equal_RC_to_MN_connectivity_for_all_MN = False # if true, will apply a softmax function to each column of the RC to MN connectivity matrix, so that it sums to 1.
# If true, the connectivity matrix will no longer be binary

### POST-SYNAPTIC EFFECTS ##########################
# Voltages are in arbitrary units (rest = voltage_rest, spike threshold = voltage_thresh)
MN_to_Renshaw_excit = 0.4 # increase in V in renshaw cell when receiving spike from MN - From Moore et al 2015 = MN-RC pair recordings, with 1 MN spike on average resulting in a probability of 0.3 of RC spike
Renshaw_to_MN_inhib = -0.1 # -0.05 #-0.3 # Decrease in V in MN when receiving spike from Renshaw cell # = parameter for which to test several values
# NOTE: the name below is misspelled ("synpatic"); it is kept as-is because later cells may reference it.
synpatic_delay = 1*ms # Williams & Baker 2009: "A 1 ms conduction delay was introduced for both motoneuron to Renshaw cell, and Renshaw cell to motoneuron contacts. "

### INPUT PARAMETERS ##########################
# MOTONEURONS COMMON INPUT & INDEPENDENT INPUT (NOISE)
# COMMON INPUT(S) ARE GIVEN PER MN POOL
nb_common_inputs = 3
experimental_data_PC_or_factor_as_common_inputs = 'PC' #'factor' #'PC' # Any value other than 'factor' or 'PC' will return an error
# When using 'factor' to simulate 3 common inputs and then estimate the number of factors with [Cheung et al 2009]'s method,
# we get sometimes 2 factors instead of 3 because the factors used as input are correlated with each other.
# However, PCs are necessarily orthogonal to each other, and thus we do get 3 components when using 3 common inputs
# Raw strings: these Windows paths contain backslashes ("\T", "\G", "\S", ...) that are not valid escape
# sequences. Raw literals hold exactly the same text without the invalid-escape SyntaxWarning of Python >= 3.12.
# The file names keep their leading backslash because they are concatenated directly onto path_for_input_files.
path_for_input_files = r"D:\THESE\Git_Scripts\Python_Scripts\motoneuron_simulation\Experimental_data_extracted_components"
filename_for_input_file = [r"\S2_VL_latents.mat", # for VL
    r"\S2_VM_latents.mat", # for VM
    r"\S2_VM_latents.mat"] # for VI - not loaded: VI's input is created synthetically later by mixing VL's input, VM's input and 2.5hz low-pass filtered noise

common_input_current_mean = 1.6 #0.6 # 0.6 # For common input, common noise, and independent noise. Corresponds to an offset above the input threshold at which MNs fire.
common_input_std = 0.2
independent_input_std = 0.2
# Experimental data (mean of ~9pps for VL, mean of ~10pps GM)
common_noise_max_freq = 30 # in hz ; low-pass cutoff of the common noise
common_input_to_common_noise_ratio = 1 # if 1, there won't be any common noise
independent_noise_max_freq = 50 # in hz ; low-pass cutoff of the independent noise
common_input_to_independent_noise_ratio = 0.3 # 0.5 # Farina & Negro 2015 (37% common input ; 13% common noise ; 50% independent noise)
# Williams & Baker 2009 ("noise (independent input) 3.2x larger than cortical input (common input)") ;
common_input_current_baseline = voltage_thresh + common_input_current_mean # 'voltage_thresh' so that baseline allows for continuous firing of motoneurons
# RENSHAW CELLS INDEPENDENT INPUT (NOISE)
Renshaw_cell_excit_noise_input_mean = 0.5 #0.9
Renshaw_cell_excit_noise_input_std = 0 #0.35
# Williams & Baker 2009 =
    # "The number of inputs per time step was determined by white Gaussian noise, with a mean and variance of 2.27 inputs per 0.2 ms time step,
    # yielding a background firing rate of 11 Hz. This simulated the known supraspinal input to Renshaw cells which is independent from the motoneurons (Windhorst, 1996)."


In [None]:
# check available styles and what they look like

# figure(figsize=([5,10]))
    # plt.style.use('default')
    # plt.style.use('fivethirtyeight')
    # plt.style.use('_classic_test_patch')
    # plt.style.use('Solarize_Light2')
    # plt.style.use('ggplot')
    # plt.style.use('seaborn-v0_8-colorblind')
# for i in range(10):
#     plt.plot([0,1],[i,i],color=f'C{i}',linewidth=10)

In [None]:
### SAVE PARAMETERS
# Each run writes into a new output directory: sim_name if it does not exist yet, otherwise the
# first free "<sim_name>_iter_<n>" with n = 1..100. If all of them already exist, "<sim_name>_iter_100"
# is reused; the cap prevents an endless search.
new_directory = sim_name
new_filename = 'parameters.txt'

directory_n = 0
while os.path.exists(new_directory) and directory_n < 100:
    directory_n += 1
    new_directory = f"{sim_name}_iter_{directory_n}"
os.makedirs(new_directory, exist_ok=True)
save_file_path = os.path.join(new_directory, new_filename)

# Parameters written to parameters.txt, grouped by section as (label, value) pairs.
# Every parameter is written once, as "   <label>: <value>"; sections are separated by a blank line.
parameters_to_save = [
    ("General parameters", [
        ("nb_of_pools_to_simulate", nb_of_pools_to_simulate),
        ("muscle_names", muscle_names[0:nb_of_pools_to_simulate]),
        ("duration_with_ignored_window", duration_with_ignored_window),
        ("nb_motoneurons_per_pool", nb_motoneurons),
        ("nb_renshaw_per_pool", nb_renshaw),
        ("min_soma_diameter", min_soma_diameter),
        ("max_soma_diameter", max_soma_diameter),
        ("min_normalized_boundary", min_normalized_boundary),
        ("max_normalized_boundary", max_normalized_boundary),
    ]),
    ("Connectivity and synaptic weights parameters", [
        ("MN_to_Renshaw_connectivity_probability_within_pool", MN_to_Renshaw_connectivity_probability_within_pool),
        ("Renshaw_to_MN_connectivity_probability_within_pool", Renshaw_to_MN_connectivity_probability_within_pool),
        ("MN_to_Renshaw_connectivity_probability_across_pool", MN_to_Renshaw_connectivity_probability_across_pool),
        ("Renshaw_to_MN_connectivity_probability_across_pool", Renshaw_to_MN_connectivity_probability_across_pool),
        ("equal_MN_to_RC_connectivity_for_all_MN", equal_MN_to_RC_connectivity_for_all_MN),
        ("equal_RC_to_MN_connectivity_for_all_MN", equal_RC_to_MN_connectivity_for_all_MN),
        ("MN_to_Renshaw_excit", MN_to_Renshaw_excit),
        ("Renshaw_to_MN_inhib", Renshaw_to_MN_inhib),
    ]),
    ("Input parameters", [
        ("nb_common_inputs", nb_common_inputs),
        ("experimental_data_PC_or_factor_as_common_inputs", experimental_data_PC_or_factor_as_common_inputs),
        ("filename_for_input_files", filename_for_input_file),
        ("common_input_current_mean", common_input_current_mean),
        ("common_input_std", common_input_std),
        ("common_noise_max_freq", common_noise_max_freq),
        ("common_input_to_common_noise_ratio", common_input_to_common_noise_ratio),
        ("independent_input_std", independent_input_std),
        ("common_input_to_independent_noise_ratio", common_input_to_independent_noise_ratio),
        ("Renshaw_cell_excit_noise_input_mean", Renshaw_cell_excit_noise_input_mean),
        ("Renshaw_cell_excit_noise_input_std", Renshaw_cell_excit_noise_input_std),
    ]),
]

# Write the variables to the file
with open(save_file_path, 'w', encoding='utf-8') as file:
    for section_number, (section_title, section_parameters) in enumerate(parameters_to_save):
        if section_number > 0:
            file.write("\n")
        file.write(f"{section_title} -----\n")
        for parameter_label, parameter_value in section_parameters:
            file.write(f"   {parameter_label}: {parameter_value}\n")


In [None]:
def lerp(a, b, t):
    """Linearly interpolate between ``a`` (at ``t = 0``) and ``b`` (at ``t = 1``).

    Values of ``t`` outside [0, 1] extrapolate. Works element-wise on NumPy
    arrays as well as on scalars.
    """
    span = b - a
    return a + t * span

####### Generate motoneurons and their properties
# Normalized sizes are spread linearly over [min_normalized_boundary, max_normalized_boundary]
# (percent of the min_soma_diameter..max_soma_diameter span) and mapped to soma diameters.
# Computations are vectorized over all MNs (lerp works element-wise on arrays).
motoneuron_normalized_soma_diameters = np.linspace(min_normalized_boundary, max_normalized_boundary, nb_motoneurons)
motoneuron_soma_diameters = lerp(min_soma_diameter, max_soma_diameter, motoneuron_normalized_soma_diameters / 100)

# Time constant (tau) = tau_constant * diameter**(-tau_exponent) (Caillet et al 2022)
tau_motoneurons = tau_constant * motoneuron_soma_diameters ** (-tau_exponent)
plt.figure(figsize=(10,5))
plt.plot(tau_motoneurons, color='C0')
plt.title("Time constant (tau) distribution of simulated MNs")
plt.xlabel("Motoneuron index (smallest simulated MN is 0 ; largest simulated MN is "+str(nb_motoneurons-1)+")")
plt.ylabel("Time constant (ms)")
plt.show()

# Input resistance = resistance_constant * diameter**(-resistance_exponent) (Caillet et al 2022).
# Input weight = resistance normalized by that of the smallest simulated MN (index 0), whose weight is therefore 1.
resistance_motoneurons = resistance_constant * motoneuron_soma_diameters ** (-resistance_exponent)
input_weight_motoneurons = resistance_motoneurons / resistance_motoneurons[0]

fig, ax1 = plt.subplots(figsize=(10,5))
ax1.plot(resistance_motoneurons, color='C1', label = 'Input resistance')
ax1.set_ylabel("Resistance (ohms)", color='C1')
ax2 = ax1.twinx()
ax2.plot(input_weight_motoneurons, color='C3', label = 'Input weight')
ax2.set_ylabel("Input weight", color='C3')
ax2.set_ylim([0,1])
ax1.legend(loc='upper right', bbox_to_anchor=(1, 1))
ax2.legend(loc='upper right', bbox_to_anchor=(1, 0.9))
ax1.set_xlabel("Motoneuron index (smallest simulated MN is 0 ; largest simulated MN is "+str(nb_motoneurons-1)+")")
plt.title("Resistance distribution of simulated MNs")
plt.show()


In [None]:
# To easily generate legend labels according to the muscle names
def transform_list(lst, n):
    """
    Build legend labels "<item> - input #<i>" for every item and every input index.

    Parameters:
    lst (list of str): Muscle (pool) names, in plotting order.
    n (int): Number of common inputs per pool.

    Returns:
    list of str: len(lst) * n labels, grouped by item then by 1-based input index,
    e.g. ["VL - input #1", "VL - input #2", ...]. Repeated items give repeated labels;
    an empty list is returned when lst is empty or n <= 0.
    """
    # The label depends only on the item and the input index, so no per-item bookkeeping is needed.
    return [f"{item} - input #{i}" for item in lst for i in range(1, n + 1)]

In [None]:
# Define a low-pass filter
def butter_lowpass(cutoff, fs, order=5):
    """
    Design a digital Butterworth low-pass filter.

    Parameters:
    cutoff (float): Cutoff frequency, in the same unit as fs.
    fs (float): Sampling frequency.
    order (int): Filter order.

    Returns:
    tuple: (b, a) numerator and denominator filter coefficients.
    """
    # butter() expects the cutoff expressed as a fraction of the Nyquist frequency
    nyquist_frequency = 0.5 * fs
    return butter(order, cutoff / nyquist_frequency, btype='low', analog=False)

def lowpass_filter(data, cutoff, fs, order=5):
    """
    Low-pass filter data with a Butterworth filter applied forward and backward (zero phase).

    Parameters:
    data (array-like): Signal to filter.
    cutoff (float): Cutoff frequency, in the same unit as fs.
    fs (float): Sampling frequency.
    order (int): Filter order.

    Returns:
    numpy array: The filtered signal, same length as data.
    """
    numerator, denominator = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(numerator, denominator, data)

In [None]:
# Import Matlab files with principal components (derived from experimental data) used as input
# Fail early on an unsupported option: otherwise current_PC_temp below would be undefined (NameError)
# or, in a notebook that is re-run, would silently keep a stale value from a previous run.
if experimental_data_PC_or_factor_as_common_inputs not in ('PC', 'factor'):
    raise ValueError(
        "experimental_data_PC_or_factor_as_common_inputs must be 'PC' or 'factor', "
        f"got {experimental_data_PC_or_factor_as_common_inputs!r}")

mat_files = {}
downsampled_input_PCs = {}
for filei in range(nb_of_pools_to_simulate):
    downsampled_input_PCs[filei] = []
    if (filei < 2): # VL (0) and VM (1) are loaded from file; VI is 2, because of the zero-indexing start
        # The file names start with a backslash and are appended directly to the directory path
        mat_files[filei] = scipy.io.loadmat(str(path_for_input_files+filename_for_input_file[filei]))
        # Access variables: data = mat_file['variable_name']
        input_sample_size = np.size(mat_files[filei]['PCA_components'],0) # number of samples
        input_fsamp = 2048 # sampling rate (Hz) of the experimental components
        num_samples_for_resampling = int(input_sample_size * fsamp / input_fsamp)
        for factori in range(nb_common_inputs):
            if experimental_data_PC_or_factor_as_common_inputs == 'PC':
                current_PC_temp = mat_files[filei]['PCA_components'][:,factori]
            else: # 'factor' (validated above)
                current_PC_temp = mat_files[filei]['FA_factors'][0,nb_common_inputs-1][:,factori]
            # Resample from input_fsamp to fsamp
            resampled_signal_temp = resample(current_PC_temp,num_samples_for_resampling)
            # cut signal to be of the appropriate duration
            resampled_signal_temp = resampled_signal_temp[0:int(fsamp*duration_with_ignored_window)]
            # Z-transform the signal (mean-center, std of 1)
            resampled_signal_temp = (resampled_signal_temp - np.mean(resampled_signal_temp)) / np.std(resampled_signal_temp)
            downsampled_input_PCs[filei].append(resampled_signal_temp)
    elif (filei==2): # for VI, mix VL component, VM component and noise 2.5hz low-pass filtered
        # NOTE(review): the effective weights are 1/2 (VL) + 1/2 (VM) + 1/2 (z-scored noise) and the mixture
        # is not z-scored again, whereas the parameter comments mention thirds - confirm the intended mix.
        # The ignored windows are given in seconds; convert them to samples at fsamp.
        nb_samples_beginning_ignore = int(np.round(window_beginning_ignore * fsamp))
        nb_samples_end_ignore = int(np.round(window_end_ignore * fsamp))
        for factori in range(nb_common_inputs):
            temp_common_input_signal = downsampled_input_PCs[filei-2][factori]
            temp_common_input_signal = temp_common_input_signal+ downsampled_input_PCs[filei-1][factori]
            temp_common_input_signal = temp_common_input_signal / 2
            noise_temp = np.random.normal(0, 1, int(duration_with_ignored_window * fsamp))
            # remove noise edges (zero the noise inside the ignored windows; previously the
            # seconds were used directly as sample counts, so only 1 sample was zeroed per edge)
            noise_temp[0:nb_samples_beginning_ignore] = 0
            noise_temp[len(noise_temp)-nb_samples_end_ignore:len(noise_temp)] = 0
            noise_temp = lowpass_filter(noise_temp, 2.5, fsamp)
            noise_temp = (noise_temp-np.mean(noise_temp))/np.std(noise_temp)
            temp_common_input_signal = temp_common_input_signal + (noise_temp/2)
            downsampled_input_PCs[filei].append(temp_common_input_signal)

legend_labels = transform_list(muscle_names, nb_common_inputs)

plt.figure()
for filei in range(nb_of_pools_to_simulate):
    for inputi in range(nb_common_inputs):
        plt.plot(downsampled_input_PCs[filei][inputi], color = muscle_colors[filei], alpha = 1/nb_common_inputs)
plt.ylabel("Amplitude (arbitrary, z-transformed)")
plt.xlabel("Time (ms)")
plt.title("Input signal(s)")
plt.legend(legend_labels)
plt.show()

# Check power spectrum of common input (one-sided power spectrum |X|^2/N, summed over frequencies).
# After this loop, common_input_power_integral holds the mean integral of the LAST simulated pool only;
# later cells use that single scalar as the reference power for every pool.
# NOTE(review): confirm that one shared reference (rather than one per pool) is intended.
plt.figure()
for filei in range(nb_of_pools_to_simulate):
    common_input_power_integral = []
    for inputi in range(nb_common_inputs):
        N = len(downsampled_input_PCs[filei][inputi])
        yf = fft(downsampled_input_PCs[filei][inputi])
        xf = fftfreq(N, 1 / fsamp)
        power_spectrum_temp = (np.abs(yf[:N//2])**2) / N
        plt.plot(xf[:N//2], power_spectrum_temp, color = muscle_colors[filei], alpha = 1/nb_common_inputs)
        common_input_power_integral.append(np.sum(power_spectrum_temp))
        print(f"Common input #{inputi} integral ({muscle_names[filei]}) = {common_input_power_integral[inputi]}")
    common_input_power_integral = np.sum(common_input_power_integral)/len(common_input_power_integral)
    print(f"Mean common input integral (across all inputs) ({muscle_names[filei]}) = {common_input_power_integral}")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power")
plt.title("Power spectrum of the common input")
plt.xlim([0,5])
plt.legend(legend_labels)
new_filename = 'Signal_common_input_power_spectrum.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

In [None]:
## GENERATE A COMMON NOISE IN THE 0-common_noise_max_freq HZ BANDWIDTH THAT IS COMMON TO ALL MNs WITHIN A POOL
# One low-pass filtered Gaussian noise signal per pool, sampled at fsamp over the full duration (ignored windows included)
common_noise_MN_specific_bandwidth = {}
for pooli in range(nb_of_pools_to_simulate):
    common_noise_MN_specific_bandwidth[pooli] = np.random.normal(0, 1, int(duration_with_ignored_window * fsamp))
    # Apply low-pass filter to the Gaussian noise
    common_noise_MN_specific_bandwidth[pooli]  = lowpass_filter(common_noise_MN_specific_bandwidth[pooli] , common_noise_max_freq, fsamp)

plt.figure()
for pooli in range(nb_of_pools_to_simulate):
    plt.plot(common_noise_MN_specific_bandwidth[pooli], color = muscle_colors[pooli], linewidth = 1.5, alpha =0.7)
plt.xlabel("Time (ms)")
plt.ylabel("Net input")
plt.title(f"Common noise with specific 0-{common_noise_max_freq}hz bandwidth to MNs")
plt.legend(muscle_names)
plt.show()

# Check power spectrum of common noise (one-sided power spectrum |X|^2/N, summed over frequencies)
plt.figure()
common_noise_power_integral = {}
for pooli in range(nb_of_pools_to_simulate):
    N = len(common_noise_MN_specific_bandwidth[pooli])
    yf = fft(common_noise_MN_specific_bandwidth[pooli])
    xf = fftfreq(N, 1 / fsamp)
    # Plot the frequency domain signal
    power_spectrum_temp = (np.abs(yf[:N//2])**2) / N
    plt.plot(xf[:N//2], power_spectrum_temp, color = muscle_colors[pooli], linewidth = 1.5, alpha =0.7)
    common_noise_power_integral[pooli] = np.sum(power_spectrum_temp)
    print(f"Common noise power integral ({muscle_names[pooli]}) = {common_noise_power_integral[pooli]}")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power")
plt.title("Power spectrum of the common noise (not boosted)")
plt.xlim([0,50])
plt.legend(muscle_names)
plt.show()


# Boosting common noise power to match common input power: amplitude gain = sqrt(power ratio).
# NOTE(review): common_input_power_integral is a single scalar (from the last pool processed in the
# common-input cell), so every pool is scaled to the same reference power - confirm this is intended.
power_common_input_noise_ratio = {}
plt.figure()
for pooli in range(nb_of_pools_to_simulate):
    power_common_input_noise_ratio[pooli] = np.sqrt(common_input_power_integral / common_noise_power_integral[pooli])
    common_noise_MN_specific_bandwidth[pooli] = common_noise_MN_specific_bandwidth[pooli] * power_common_input_noise_ratio[pooli]
    plt.plot(common_noise_MN_specific_bandwidth[pooli], color=muscle_colors[pooli], linewidth = 1.5, alpha =0.7)
plt.xlabel("Time (ms)")
plt.ylabel("Net input")
plt.title(f"Common noise (boosted) with specific 0-{common_noise_max_freq}hz bandwidth to MNs")
plt.legend(muscle_names)
new_filename = 'Signal_common_noise_input.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

# Check power spectrum of the boosted common noise. Its integrals are stored in their own dict so that
# common_noise_power_integral keeps the per-pool (not boosted) values instead of being overwritten by a scalar.
boosted_common_noise_power_integral = {}
plt.figure()
for pooli in range(nb_of_pools_to_simulate):
    N = len(common_noise_MN_specific_bandwidth[pooli])
    yf = fft(common_noise_MN_specific_bandwidth[pooli])
    xf = fftfreq(N, 1 / fsamp)
    # Plot the frequency domain signal
    power_spectrum_temp = (np.abs(yf[:N//2])**2) / N
    plt.plot(xf[:N//2], power_spectrum_temp, color = muscle_colors[pooli], linewidth = 1.5, alpha =0.7)
    boosted_common_noise_power_integral[pooli] = np.sum(power_spectrum_temp)
    print(f"Boosted common noise power integral ({muscle_names[pooli]}) = {boosted_common_noise_power_integral[pooli]}")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power")
plt.title("Power spectrum of the common noise (boosted)")
plt.xlim([0,50])
plt.legend(muscle_names)
new_filename = 'Signal_common_noise_power_spectrum.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()


In [None]:
## COMMON INPUTS AND COMMON INPUTS DISTRIBUTION

# Time axis of the simulation inputs: one sample per ms over the whole duration (ignored windows included)
time = np.arange(0, int(duration_with_ignored_window/ms), 1) * ms
common_input_current = {}

plt.figure()
for pooli in range(nb_of_pools_to_simulate):
    common_input_current[pooli] = [downsampled_input_PCs[pooli][inputi] for inputi in range(nb_common_inputs)]
    plt.plot(np.transpose(common_input_current[pooli]), color = muscle_colors[pooli], alpha = 1/nb_common_inputs)
plt.xlabel("Time (ms)")
plt.ylabel("Common input(s) - arbitrary amplitude")
plt.legend(legend_labels)
new_filename = 'Signal_common_inputs.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

# Distribution of common inputs: input k gets a triangular ("tent") weight over the MN index, obtained by
# linearly interpolating row k of the identity matrix from nb_common_inputs points to nb_motoneurons points.
# For every MN, the weights of all inputs therefore sum to 1. MNs are then shuffled within each pool.
common_inputs_distrib = {}
# squeeze=False + [0]: always a 1-D array of Axes, so ax[pooli] also works when a single pool is simulated
fig, ax = plt.subplots(1,nb_of_pools_to_simulate, figsize=(15,4), squeeze=False)
ax = ax[0]
# Loop invariants (identical for every pool)
common_inputs_distrib_identity_mat = np.eye(nb_common_inputs) # Create identity matrix
interpol_range = np.linspace(0,1,nb_motoneurons)
for pooli in range(nb_of_pools_to_simulate):
    common_inputs_distrib[pooli] = np.vstack([
        np.interp(interpol_range, np.linspace(0,1,nb_common_inputs), common_inputs_distrib_identity_mat[inputi,:])
        for inputi in range(nb_common_inputs)])
    # Shuffle distribution of inputs (useful if simulating different types of motoneurons, so that it's not the same MN always receiving the same input)
    shuffled_mu_indices = np.random.permutation(common_inputs_distrib[pooli].shape[1])
    common_inputs_distrib[pooli] = common_inputs_distrib[pooli][:,shuffled_mu_indices]

    # Stored as a list (one entry per input) of per-MN weight lists
    common_inputs_distrib[pooli] = [row.tolist() for row in common_inputs_distrib[pooli]]

    x_plot_mns = range(nb_motoneurons)
    bottom_barplot = np.zeros(nb_motoneurons)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works in old and new versions
    colormap_temp = plt.get_cmap(muscle_colormaps[pooli])
    for inputi in range(nb_common_inputs):
        ax[pooli].bar(x_plot_mns,
            common_inputs_distrib[pooli][inputi],
            color = colormap_temp(0.5+(inputi/(nb_common_inputs*3))),
            bottom = bottom_barplot)
        bottom_barplot += common_inputs_distrib[pooli][inputi]
    ax[pooli].set_xlabel('Motoneurons')
    ax[pooli].set_ylabel('Proportion of common inputs')
    ax[pooli].set_title(f'{muscle_names[pooli]}')
plt.tight_layout(rect=[0,0,1,0.96])
plt.suptitle("Distribution of common excitatory inputs")
new_filename = 'Signal_common_input_distribution_to_MN.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

In [None]:
# Generate independent input to MNs
# Each MN gets its own Gaussian white noise, low-pass filtered below independent_noise_max_freq, then
# rescaled so that its power integral (one-sided |X|^2/N, summed) equals common_input_power_integral.
# NOTE(review): common_input_power_integral is a single scalar (left by the common-input cell for the
# last pool it processed), so every pool's noise is normalized to the same reference - confirm intent.
independent_input_noise = {}
plt.figure()
for pooli in range(nb_of_pools_to_simulate):
    independent_input_noise[pooli] = randn(nb_motoneurons,len(time)) # noise input, with mean zero and std 1 (default setting) ; one row per MN, one column per time sample
    # Apply low-pass filter to the Gaussian noise, one MN (row) at a time
    for mni in range(nb_motoneurons):
        independent_input_noise[pooli][mni,:] = lowpass_filter(independent_input_noise[pooli][mni,:], independent_noise_max_freq, fsamp)
        # Normalize power with common input - get power with FFT
        N = len(independent_input_noise[pooli][mni,:])
        yf = fft(independent_input_noise[pooli][mni,:])
        xf = fftfreq(N, 1 / fsamp)
        power_spectrum_temp = (np.abs(yf[:N//2])**2) / N
        independent_noise_power_integral = np.sum(power_spectrum_temp)
        power_ratio = np.sqrt(common_input_power_integral / independent_noise_power_integral) # amplitude gain = sqrt of the power ratio
        independent_input_noise[pooli][mni,:] = independent_input_noise[pooli][mni,:] * power_ratio
        # Plot the frequency domain independent noise signal after rescaling (power scales with gain**2)
        plt.plot(xf[:N//2], power_spectrum_temp * power_ratio**2, color = muscle_colors[pooli], alpha = 0.03/nb_of_pools_to_simulate)
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power")
plt.title("Power spectrum of the independent noise inputs (power normalized to common input power)")
plt.xlim([0,80])
new_filename = f'Signal_independent_noise_power_spectrum.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()


In [None]:
# FINAL INPUT TO EACH MOTONEURON, IN EACH POOL
# For MN m of a pool, the input over time is:
#   common_input_current_baseline
#   + sum_k common_input_k * common_input_std * common_input_to_independent_noise_ratio
#                          * common_input_to_common_noise_ratio * common_inputs_distrib[k][m]
#   + common_noise * common_input_std * common_input_to_independent_noise_ratio * (1 - common_input_to_common_noise_ratio)
#   + independent_noise_m * independent_input_std * (1 - common_input_to_independent_noise_ratio)
input_array = {}
all_cells_input_array = np.array([])

for pooli in range(nb_of_pools_to_simulate):
    # Start from the constant baseline: one row per MN, one column per time sample
    input_array[pooli] = np.ones(len(time))*common_input_current_baseline
    input_array[pooli] = np.repeat([input_array[pooli]], nb_motoneurons, axis=0)
    for mni in range(nb_motoneurons):
        for inputi in range(nb_common_inputs):
            # add common input
            input_array[pooli][mni,:] = input_array[pooli][mni,:] + (
                common_input_current[pooli][inputi] * common_input_std  * common_input_to_independent_noise_ratio
                * common_input_to_common_noise_ratio * common_inputs_distrib[pooli][inputi][mni])
        # add common noise
        input_array[pooli][mni,:] = input_array[pooli][mni,:] + (
            common_noise_MN_specific_bandwidth[pooli] * common_input_std * common_input_to_independent_noise_ratio
            * (1-common_input_to_common_noise_ratio))
    # add independent noise
    input_array[pooli] = input_array[pooli] + (independent_input_noise[pooli] * independent_input_std * (1-common_input_to_independent_noise_ratio))
    # Transpose to one row per time sample, one column per MN of this pool
    input_array[pooli] = np.transpose(input_array[pooli])
    # Stack pools side by side: column j of all_cells_input_array is global MN index j (pool 0 first)
    if all_cells_input_array.size == 0:
        all_cells_input_array = np.array(input_array[pooli])
    else:
        all_cells_input_array = np.column_stack((all_cells_input_array,input_array[pooli]))

# I(t, i) = input to global MN i at time t, with values stepping every 1 ms
I = TimedArray(all_cells_input_array,dt=1*ms)

figure()
plot(I(time,0), linewidth = 1)
xlabel("Time (ms)")
ylabel("Net input")
title("Input to simulated MN#0 of pool #0 (common input & noise + independent noise)")


In [None]:
# Generate independent input to RCs
# Each RC receives a constant offset (Renshaw_cell_excit_noise_input_mean) plus Gaussian white noise
# scaled by Renshaw_cell_excit_noise_input_std (with std = 0 the input is constant).
independent_input_RC_noise = {}
all_cells_independent_input_RC_noise = np.array([])
for pooli in range(nb_of_pools_to_simulate):
    independent_input_RC_noise[pooli] = np.ones(len(time))*Renshaw_cell_excit_noise_input_mean # mean offset (baseline input to RC)
    independent_input_RC_noise[pooli] = np.repeat([independent_input_RC_noise[pooli]], nb_renshaw, axis=0) # one row per RC
    independent_input_RC_noise[pooli] = independent_input_RC_noise[pooli] + (randn(nb_renshaw,len(time))*Renshaw_cell_excit_noise_input_std) # Add noise
    independent_input_RC_noise[pooli] = np.transpose(independent_input_RC_noise[pooli]) # -> one row per time sample, one column per RC
    # Stack pools side by side: column j is global RC index j (pool 0 first)
    if all_cells_independent_input_RC_noise.size == 0:
        all_cells_independent_input_RC_noise = np.array(independent_input_RC_noise[pooli])
    else:
        all_cells_independent_input_RC_noise = np.column_stack((all_cells_independent_input_RC_noise,independent_input_RC_noise[pooli]))

# I_RC(t, i) = input to global RC i at time t, with values stepping every 1 ms
I_RC = TimedArray(all_cells_independent_input_RC_noise , dt=1*ms)

figure()
plot(I_RC(time,0), linewidth=1)
xlabel("Time (ms)")
ylabel("Net input")
title("Input to simulated RC#0 from pool #0 (independent noise)")

In [None]:
# Define the softmax function, which will ensure that each MN receives an equal amount of inhibition from RCs on average (if the option is set to true)
def softmax_with_temperature(logits, temperature=1.0, axis=None):
    """
    Compute the softmax of a list of logits with a temperature parameter.

    Parameters:
    logits (list or numpy array): The input logits.
    temperature (float): The temperature parameter. Higher values give flatter
        distributions. Must be non-zero.
    axis (int or None): Axis along which the probabilities are normalized.
        None (default) normalizes over all elements; for a 2-D matrix,
        axis=1 makes every row sum to 1 and axis=0 every column.

    Returns:
    numpy array: The softmax probabilities, same shape as logits.

    Raises:
    ValueError: If temperature is 0 (the result would be NaN).
    """
    if temperature == 0:
        raise ValueError("temperature must be non-zero")

    # Convert logits to numpy array and apply the temperature parameter
    logits = np.array(logits) / temperature

    # Subtract the max for numerical stability (it cancels out in the ratio below)
    exp_logits = np.exp(logits - np.max(logits, axis=axis, keepdims=True))

    # Normalize along the requested axis
    return exp_logits / np.sum(exp_logits, axis=axis, keepdims=True)

In [None]:
# Probabilistic rounding: 2.7 gives 3 with 70% chance and 2 with 30% chance.
def probabilistic_round(number):
    """
    Round `number` at random to one of its two neighbouring integers.

    The upper neighbour is chosen with probability equal to the fractional part, so
    the expected value equals `number`. Examples: 2.7 -> 3 (p=0.7) or 2 (p=0.3);
    -2.7 -> -2 (p=0.3) or -3 (p=0.7). Integral inputs are returned unchanged.

    Parameters:
    number (float): value to round.

    Returns:
    int: floor(number) or floor(number) + 1.
    """
    import math

    # floor (not int(), which truncates toward zero) so the fractional part is always in
    # [0, 1); with int() a negative input had a negative "fractional part" and was never
    # rounded down.
    lower = math.floor(number)
    fractional_part = number - lower
    return lower + 1 if random.random() < fractional_part else lower

In [None]:
# Create connectivity matrices (global indices; pools are stacked one after the other)
#   MN_to_Renshaw_connectivity_matrix[mn, rc]  : weight of the MN -> RC synapse
#   Renshaw_to_MNs_connectivity_matrix[rc, mn] : weight of the RC -> MN synapse
MN_to_Renshaw_connectivity_matrix = np.zeros([nb_motoneurons*nb_of_pools_to_simulate,nb_renshaw*nb_of_pools_to_simulate])
Renshaw_to_MNs_connectivity_matrix = np.zeros([nb_renshaw*nb_of_pools_to_simulate,nb_motoneurons*nb_of_pools_to_simulate])

# Global indices of the MNs / RCs that belong to each pool
idx_of_MNs_according_to_pools = {}
idx_of_RCs_according_to_pools = {}
for pooli in range(nb_of_pools_to_simulate):
    idx_of_MNs_according_to_pools[pooli] = np.arange(nb_motoneurons*pooli,nb_motoneurons*(pooli+1))
    idx_of_RCs_according_to_pools[pooli] = np.arange(nb_renshaw*pooli,nb_renshaw*(pooli+1))

# Build connectivity: each RC draws (without replacement) a fixed number of presynaptic
# MNs and of target MNs, first in its own pool and then in every other pool.
for pooli in range(nb_of_pools_to_simulate):
    other_pools_idx = np.delete(np.arange(nb_of_pools_to_simulate), pooli)
    # within pool
    nb_Renshaw_to_connect_temp = round(Renshaw_to_MN_connectivity_probability_within_pool*nb_motoneurons)  # target MNs per RC
    nb_motoneurons_to_connect_temp = round(MN_to_Renshaw_connectivity_probability_within_pool*nb_motoneurons)  # presynaptic MNs per RC
    homonymous_MNs = idx_of_MNs_according_to_pools[pooli].tolist()
    for renshawi in idx_of_RCs_according_to_pools[pooli]:
        renshawi = int(renshawi)
        MN_to_Renshaw_connectivity_matrix[random.sample(homonymous_MNs, nb_motoneurons_to_connect_temp), renshawi] = 1
        Renshaw_to_MNs_connectivity_matrix[renshawi, random.sample(homonymous_MNs, nb_Renshaw_to_connect_temp)] = 1
    # across pools
    for heteronymous_pooli in other_pools_idx:
        nb_Renshaw_to_connect_temp = round(Renshaw_to_MN_connectivity_probability_across_pool*nb_motoneurons)
        nb_motoneurons_to_connect_temp = round(MN_to_Renshaw_connectivity_probability_across_pool*nb_motoneurons)
        heteronymous_MNs = idx_of_MNs_according_to_pools[heteronymous_pooli].tolist()
        for renshawi in idx_of_RCs_according_to_pools[pooli]:
            renshawi = int(renshawi)
            MN_to_Renshaw_connectivity_matrix[random.sample(heteronymous_MNs, nb_motoneurons_to_connect_temp), renshawi] = 1
            Renshaw_to_MNs_connectivity_matrix[renshawi, random.sample(heteronymous_MNs, nb_Renshaw_to_connect_temp)] = 1


def _normalize_columns(matrix):
    """Scale each column of `matrix` in place so that its entries sum to 1.

    Columns without any connection (all zeros) are left at zero.
    """
    column_sums = matrix.sum(axis=0)
    connected = column_sums > 0
    matrix[:, connected] = matrix[:, connected] / column_sums[connected]


# Normalize connectivity weights so that the total weight converging onto each
# postsynaptic cell is 1, shared equally by its presynaptic partners. This is applied
# to the columns of ALL pools. Exact normalization is used instead of a low-temperature
# softmax: softmax gives every *unconnected* entry a small nonzero weight (and an
# all-zero column uniform weights), which the np.nonzero()-based synapse creation
# would turn into real synapses.
if equal_MN_to_RC_connectivity_for_all_MN:
    _normalize_columns(MN_to_Renshaw_connectivity_matrix)

if equal_RC_to_MN_connectivity_for_all_MN:
    _normalize_columns(Renshaw_to_MNs_connectivity_matrix)

# Plot the connectivity arrays (rows = presynaptic cells, columns = postsynaptic cells)
plt.figure(figsize=(5,10))
plt.imshow(MN_to_Renshaw_connectivity_matrix, vmin = 0.0, vmax = np.max(MN_to_Renshaw_connectivity_matrix), cmap='gray')
plt.colorbar()
plt.title("MN (y axis) to Renshaw (x axis) connectivity \n black = unconnected, colored = connected")
plt.xlabel("Renshaw cells")
plt.ylabel("Motoneurons")
plt.xticks([])
plt.yticks([])
new_filename = f'Connectivity_MN_to_RC_B&W.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

plt.figure(figsize=(10,5))
plt.imshow(Renshaw_to_MNs_connectivity_matrix, vmin = 0.0, vmax = np.max(Renshaw_to_MNs_connectivity_matrix), cmap='gray')
plt.colorbar()
# rows (y axis) are Renshaw cells, columns (x axis) are motoneurons
plt.title("Renshaw (y axis) to MN (x axis) connectivity \n black = unconnected, colored = connected")
plt.ylabel("Renshaw cells")
plt.xlabel("Motoneurons")
plt.xticks([])
plt.yticks([])
new_filename = f'Connectivity_RC_to_MN_B&W.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

In [None]:
# Set the motoneuron properties values
# Slow and fast refractory periods are both set to refractory_period_MN, so the linspace
# below yields the same refractory period for every MN of a pool.
refractory_period_slow_MN = refractory_period_MN
refractory_period_fast_MN = refractory_period_MN
refactory_period_values = linspace(refractory_period_slow_MN  ,refractory_period_fast_MN , nb_motoneurons)

# Initialize network - groups of motoneurons and synapses
# Both cell types are leaky integrators: v relaxes toward the external drive with time
# constant tau and is not integrated during the refractory period ("unless refractory").
# MN drive = I(t,i) scaled by a per-neuron input_weight; RC drive = I_RC(t,i).
# NOTE: the equations, thresholds, resets and on_pre statements refer by name to
# module-level globals (I, I_RC, voltage_thresh, voltage_rest, MN_to_Renshaw_excit,
# Renshaw_to_MN_inhib), which Brian2 resolves from the namespace run() is called in,
# so this cell is kept at module level.
eqs_motoneuron = '''
dv/dt = ((I(t,i)*input_weight)-v)/tau: 1 (unless refractory)
tau : second
refractory_period : second
input_weight : 1
'''
eqs_renshaw = '''
dv/dt = (I_RC(t,i)-v)/tau: 1 (unless refractory)
tau : second
'''
# Groups of neurons: one group for all pools; pool p occupies the contiguous index
# range idx_of_MNs_according_to_pools[p]
motoneurons = NeuronGroup(nb_motoneurons*nb_of_pools_to_simulate, eqs_motoneuron,
                          threshold='v>voltage_thresh',
                          reset='v=voltage_rest',
                          refractory='refractory_period',
                          method='exact')
for pooli in range(nb_of_pools_to_simulate):
    # per-pool subgroup; every pool receives the same property values
    motoneurons[idx_of_MNs_according_to_pools[pooli]].tau = tau_motoneurons*(1/1000)*second # tau_motoneurons presumably in ms -> Brian time quantity
    motoneurons[idx_of_MNs_according_to_pools[pooli]].refractory_period = refactory_period_values
    motoneurons[idx_of_MNs_according_to_pools[pooli]].input_weight = input_weight_motoneurons

# Renshaw cells: threshold 1 and reset 0 are written literally here (the MNs use
# voltage_thresh / voltage_rest instead)
renshaw_cells = NeuronGroup(nb_renshaw*nb_of_pools_to_simulate, eqs_renshaw,
                          threshold='v>1',
                          reset='v=0',
                          refractory=refractory_period_RC,
                          method='exact')
renshaw_cells.tau = tau_Renshaw

# Display MN properties (all pools, x axis = global MN index)
plt.plot(motoneurons.tau)
plt.plot(motoneurons.input_weight)
plt.plot(motoneurons.refractory_period)
plt.xlabel("MN indx")
plt.title("MNs properties")
plt.legend(["Time constant (seconds)","Input weight (0-1 scaling factor)","refractory period (seconds)"])
new_filename = f'Properties_MN.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

# Synapses (connectivity) between neurons
# Each presynaptic spike adds, after synpatic_delay, the per-synapse weight w times a
# global gain to the postsynaptic v: MN_to_Renshaw_excit for MN -> RC and
# Renshaw_to_MN_inhib for RC -> MN (presumably negative to inhibit - TODO confirm).
S_MN_to_Renshaw = Synapses(motoneurons, renshaw_cells, 'w : 1',
                           on_pre='v += MN_to_Renshaw_excit*w',
                           delay = synpatic_delay)
S_Renshaw_to_MN = Synapses(renshaw_cells, motoneurons, 'w : 1',
                           on_pre='v += Renshaw_to_MN_inhib*w',
                           delay = synpatic_delay)
# Create connectivity network: one synapse per nonzero matrix entry, weighted by that entry.
# NOTE(review): assumes Brian2 stores the synapses in the order of the i/j arrays passed to
# connect(), so that w lines up with weights_to_assign - TODO confirm
pre_indices, post_indices = np.nonzero(MN_to_Renshaw_connectivity_matrix)
weights_to_assign = MN_to_Renshaw_connectivity_matrix[pre_indices,post_indices]
S_MN_to_Renshaw.connect(i=pre_indices, j=post_indices)
S_MN_to_Renshaw.w = weights_to_assign
# S_MN_to_Renshaw.w = 'MN_to_Renshaw_connectivity_matrix[i,j]'
pre_indices, post_indices = np.nonzero(Renshaw_to_MNs_connectivity_matrix)
weights_to_assign = Renshaw_to_MNs_connectivity_matrix[pre_indices,post_indices]
S_Renshaw_to_MN.connect(i=pre_indices, j=post_indices)
S_Renshaw_to_MN.w = weights_to_assign

# visualize as scatter plot (show weights) = MN to Renshaw
# (the third positional argument of scatter is the marker size, here 10 x weight)
plt.figure(figsize=(12,6))
pre_indices, post_indices = np.nonzero(MN_to_Renshaw_connectivity_matrix)
plt.scatter(pre_indices, post_indices, S_MN_to_Renshaw.w*10)
plt.xlabel('Source neuron (motoneuron)')
plt.ylabel('Target neuron (Renshaw)')
plt.title("Scatter plot of motoneuron to Renshaw connectivity matrix")
new_filename = f'Connectivity_MN_to_RC.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()
# visualize as scatter plot (show weights) = Renshaw to MN
plt.figure(figsize=(6,12))
pre_indices, post_indices = np.nonzero(Renshaw_to_MNs_connectivity_matrix)
plt.scatter(pre_indices, post_indices, S_Renshaw_to_MN.w*10, color = 'C4')
plt.xlabel('Source neuron (Renshaw)')
plt.ylabel('Target neuron (motoneuron)')
plt.title("Scatter plot of Renshaw to motoneuron connectivity matrix")
new_filename = f'Connectivity_RC_to_MN.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()
    

# Monitors: every variable of every cell at each recorded step, plus all spikes
# (memory use grows with network size and simulated duration)
monitor_state_motoneurons = StateMonitor(motoneurons, variables=True, record=True)
monitor_spikes_motoneurons = SpikeMonitor(motoneurons, record=True)
monitor_state_renshaw = StateMonitor(renshaw_cells, variables=True, record=True)
monitor_spikes_renshaw = SpikeMonitor(renshaw_cells, record=True)

In [None]:
# Initialize voltage of MNs and Renshaw cells randomly: each cell gets its OWN value
# drawn from a uniform distribution on [0, 1). (A bare rand() call - numpy's, via the
# star imports - returns a single scalar, which would start every cell at the same
# voltage and synchronize the network at onset.)
motoneurons.v = np.random.rand(len(motoneurons))
renshaw_cells.v = np.random.rand(len(renshaw_cells))

In [None]:
# Run sim for the full duration, including the beginning/end windows that are cut
# from the smoothed signals later.
# Brian2's module-level run() simulates the Brian objects created in this namespace
# and resolves the external names used in the equations (I, I_RC, voltage_thresh, ...)
# from here.
run(duration_with_ignored_window)

# Careful! Re-running the simulation without resetting the parameters will result in the simulation simply being ADDED to the previous sim
# (i.e. it continues from the current simulation time and the monitors keep accumulating)

In [None]:
## Plot results for Renshaw cells
# Membrane potential of the first Renshaw cell (global index 0, i.e. first pool)
plt.figure(num=0, figsize=(20, 3))
rc_voltage_ax = plt.gca()
rc_voltage_ax.plot(monitor_state_renshaw.t/second, monitor_state_renshaw.v[0], label='v')
rc_voltage_ax.set_xlabel('Time (s)')
rc_voltage_ax.set_ylabel('Voltage')
rc_voltage_ax.set_title("Example voltage of RC#0")

# Raster of every Renshaw cell: one dot per spike (x = time, y = global RC index)
plt.figure(num=1, figsize=(20, 3))
rc_raster_ax = plt.gca()
rc_raster_ax.plot(monitor_spikes_renshaw.t/second, monitor_spikes_renshaw.i, '.k')
rc_raster_ax.set_xlabel('Time (s)')
rc_raster_ax.set_ylabel('Renshaw cell index')
rc_raster_ax.set_title("Raster plot of RCs")

In [None]:
## Plot results for MNs
# Membrane potential of the first motoneuron (global index 0, i.e. first pool)
plt.figure(num=0, figsize=(20, 3))
mn_voltage_ax = plt.gca()
mn_voltage_ax.plot(monitor_state_motoneurons.t/second, monitor_state_motoneurons.v[0], label='v')
mn_voltage_ax.set_xlabel('Time (s)')
mn_voltage_ax.set_ylabel('Voltage')
mn_voltage_ax.set_title("Voltage (MN #0)")

# Raster of every motoneuron: one dot per spike (x = time, y = global MN index)
plt.figure(num=1, figsize=(20, 3))
mn_raster_ax = plt.gca()
mn_raster_ax.plot(monitor_spikes_motoneurons.t/second, monitor_spikes_motoneurons.i, '.k')
mn_raster_ax.set_xlabel('Time (s)')
mn_raster_ax.set_ylabel('Motoneuron index')
mn_raster_ax.set_title("Raster plot of motoneuron spikes")

In [None]:
### Get discharge characteristics of Renshaw cells
# Per pool: spike trains, per-cell mean firing rate (spike count / full simulated
# duration, ignored windows included), their mean and SD, and a histogram.
spike_trains_RC = {}
firing_rates_RC = {}
mean_firing_rate_RC = {}
std_firing_rate_RC = {}
# spike_trains() builds a dict over ALL cells: compute it once, not once per cell
all_RC_spike_trains = monitor_spikes_renshaw.spike_trains()
# squeeze=False keeps `axs` 2-D so indexing also works when a single pool is simulated
fig, axs = plt.subplots(1, nb_of_pools_to_simulate, squeeze=False)
for pooli in range(nb_of_pools_to_simulate):
    ax = axs[0, pooli]
    # Retrieve spikes (global RC indices of this pool)
    spike_trains_RC[pooli] = [all_RC_spike_trains[rc_idx] for rc_idx in idx_of_RCs_according_to_pools[pooli]]
    # Firing rate of each Renshaw cell over the whole simulated duration
    firing_rates_RC[pooli] = np.array([len(spikes) / duration_with_ignored_window
                                       for spikes in spike_trains_RC[pooli]])
    # Calculate mean and standard deviation of the firing rates
    mean_firing_rate_RC[pooli] = np.mean(firing_rates_RC[pooli])
    std_firing_rate_RC[pooli] = np.std(firing_rates_RC[pooli])
    # Renshaw cells' firing rates results
    print(f"Mean firing rate of Renshaw cells (pool #{pooli}): {mean_firing_rate_RC[pooli]:.2f} Hz")
    print(f"Standard deviation of firing rates of Renshaw cells (pool #{pooli}): {std_firing_rate_RC[pooli]:.2f} Hz")
    ax.hist(firing_rates_RC[pooli], edgecolor='white', color=muscle_colors[pooli], alpha=0.75)
    ax.axvline(x = mean_firing_rate_RC[pooli], color = muscle_colors[pooli], linestyle='--', linewidth=2, label='Mean firing rate')
    ax.set_xlabel("Mean firing rate (pps)")
    ax.set_ylabel("Renshaw cell count")
    ax.set_title(f"{muscle_names[pooli]}")
plt.tight_layout(rect=[0,0,1,0.96])
plt.suptitle("Histogram of renshaw cells' firing rate")
new_filename = f'Hist_RC_Discharge_rates.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
# plt.show() takes no figure argument (its first parameter is a flag, not a figure)
plt.show()
    


In [None]:
### Get discharge characteristics of motoneurons
# Per pool: spike trains, per-MN mean firing rate (spike count / full simulated
# duration, ignored windows included) and the largest inter-spike interval of each MN.
# MNs with fewer than 2 spikes get the full duration as their largest ISI.
spike_trains = {}
firing_rates = {}
highest_ISIs = {}
mean_firing_rate = {}
std_firing_rate = {}
# spike_trains() builds a dict over ALL MNs: compute it once, not once per MN
all_MN_spike_trains = monitor_spikes_motoneurons.spike_trains()
# squeeze=False keeps `axs` 2-D so indexing also works when a single pool is simulated
fig, axs = plt.subplots(1, nb_of_pools_to_simulate, squeeze=False)
for pooli in range(nb_of_pools_to_simulate):
    ax = axs[0, pooli]
    # Retrieve spikes (global MN indices of this pool)
    spike_trains[pooli] = [all_MN_spike_trains[mn_idx] for mn_idx in idx_of_MNs_according_to_pools[pooli]]
    firing_rates[pooli] = []
    highest_ISIs[pooli] = []
    for spikes in spike_trains[pooli]:
        if len(spikes) <= 1:
            highest_ISIs[pooli].append(duration_with_ignored_window)
        else:
            highest_ISIs[pooli].append(max(diff(spikes)))
        firing_rates[pooli].append(len(spikes) / duration_with_ignored_window)
    # Convert to a numpy array for easier calculations
    firing_rates[pooli] = np.array(firing_rates[pooli])
    # Calculate mean and standard deviation of the firing rates
    mean_firing_rate[pooli] = np.mean(firing_rates[pooli])
    std_firing_rate[pooli] = np.std(firing_rates[pooli])
    # Motoneurons' firing rates results
    ax.hist(firing_rates[pooli], edgecolor='white', color=muscle_colors[pooli], alpha=0.75)
    ax.axvline(x = mean_firing_rate[pooli], color = muscle_colors[pooli], linestyle='--', linewidth=2, label='Mean firing rate')
    ax.set_xlabel("Mean firing rate (pps)")
    ax.set_ylabel("Motoneuron count")
    ax.set_title(f"{muscle_names[pooli]}")
plt.tight_layout(rect=[0,0,1,0.96])
plt.suptitle("Histogram of motoneurons' firing rate (all motoneurons)")

In [None]:
### SELECT ONLY VALID (continuous) MOTONEURONS
# An MN is kept ("valid") unless its largest ISI exceeds ISI_threshold_for_discontinuity
# or it fired fewer than min_nb_of_spikes_for_valid_MU spikes over the whole simulation.
min_nb_of_spikes_for_valid_MU = 20
discontinuous_MUs_idx = {}
valid_MUs_idx = {}
# squeeze=False keeps `axs` 2-D so indexing also works when a single pool is simulated
fig, axs = plt.subplots(1, nb_of_pools_to_simulate, squeeze=False)

for pooli in range(nb_of_pools_to_simulate):
    ax = axs[0, pooli]
    # Within-pool indices of discontinuous MNs (max ISI above threshold) and of MNs with
    # too few spikes; merged and de-duplicated
    MUs_with_long_ISI = [mni for mni, isi in enumerate(highest_ISIs[pooli]) if isi > ISI_threshold_for_discontinuity]
    MUs_with_few_spikes = [mni for mni in range(nb_motoneurons)
                           if len(spike_trains[pooli][mni]) < min_nb_of_spikes_for_valid_MU]
    discontinuous_MUs_idx[pooli] = unique(append(MUs_with_long_ISI, MUs_with_few_spikes))

    valid_MUs_idx[pooli] = [mni for mni in range(nb_motoneurons) if mni not in discontinuous_MUs_idx[pooli]]
    print("Number of invalid MUs (discontinuous, max(ISI) > ",ISI_threshold_for_discontinuity,", or MUs with too few spikes ) = ", len(discontinuous_MUs_idx[pooli]), " out of ", nb_motoneurons," (",muscle_names[pooli],")")
    ax.hist(highest_ISIs[pooli], edgecolor='white', color=muscle_colors[pooli], alpha=0.5)
    ax.axvline(x = np.mean(highest_ISIs[pooli]), color = muscle_colors[pooli], linestyle='--', linewidth=2, label='Mean highest ISI')
    ax.axvline(x = ISI_threshold_for_discontinuity/second, color = 'black', linestyle='-', linewidth=2, label='ISI threshold')
    ax.set_xlabel("Max ISI (s)")
    ax.set_ylabel("Motoneuron count")
    ax.set_title(f"{muscle_names[pooli]}")
plt.tight_layout(rect=[0,0,1,0.92])
plt.suptitle("Histogram of motoneurons' max ISI \n (colored line is mean ; black line is threshold)")
new_filename = f'Hist_MN_ISIs.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
# plt.show() takes no figure argument (its first parameter is a flag, not a figure)
plt.show()
    


In [None]:
# Firing rate results - only valid (continuous) MNs
# Per pool: mean and standard deviation (np.std, as for the other pools' statistics)
# of the valid MNs' firing rates, plus a histogram.
mean_firing_rate_valid = {}
std_firing_rate_valid = {}
# squeeze=False keeps `axs` 2-D so indexing also works when a single pool is simulated
fig, axs = plt.subplots(1, nb_of_pools_to_simulate, squeeze=False)
for pooli in range(nb_of_pools_to_simulate):
    ax = axs[0, pooli]
    valid_firing_rates = firing_rates[pooli][valid_MUs_idx[pooli]]
    mean_firing_rate_valid[pooli] = np.mean(valid_firing_rates)
    std_firing_rate_valid[pooli] = np.std(valid_firing_rates)
    ax.hist(valid_firing_rates, edgecolor='white', color=muscle_colors[pooli], alpha=0.75)
    ax.axvline(x = mean_firing_rate_valid[pooli], color = muscle_colors[pooli], linestyle='--', linewidth=2, label='Mean firing rate')
    ax.set_xlabel("Mean firing rate (pps)")
    ax.set_ylabel("Motoneuron count")
    ax.set_title(f"{muscle_names[pooli]}")
plt.tight_layout(rect=[0,0,1,0.96])
plt.suptitle("Histogram of motoneurons' firing rate (only valid motoneurons)")
new_filename = f'Hist_MN_Discharge_rates.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
# plt.show() takes no figure argument (its first parameter is a flag, not a figure)
plt.show()
    

In [None]:
## SMOOTHING SPIKE TRAINS
# Binarize each MN's spike train on 1 ms bins, convolve it forward and backward
# (filtfilt) with a unit-area Hann window, and cut the ignored edge windows.

# Define time bins (1 ms wide, covering the whole simulation)
time_bins = np.arange(0, duration_with_ignored_window/ms) * ms
# Setting the kernel
Wind_s = 0.4  # hanning window duration. 0.4 for 2.5hz low-pass, 0.2 for 5hz low-pass
HanningW = 2 / round(fsamp * Wind_s) * windows.hann(round(fsamp * Wind_s))  # unitary area

# Number of samples removed at the start / end of every smoothed signal
nb_samples_cut_start = int(window_beginning_ignore * fsamp)
nb_samples_cut_end = int(window_end_ignore * fsamp)

binary_spike_trains = {}
smoothed_signal = {}
for pooli in range(nb_of_pools_to_simulate):
    # Convert spike times to a binary spike train (MNs x time bins)
    binary_spike_trains[pooli] = np.zeros((nb_motoneurons, len(time_bins)))
    for neuron_idx in range(nb_motoneurons):
        spike_indices = np.searchsorted(time_bins, spike_trains[pooli][neuron_idx])
        binary_spike_trains[pooli][neuron_idx, spike_indices-1] = 1 #-1 because of offset due to 0-indexing

    # Smoothed discharge rate (pps) of each MN
    smoothed_signal[pooli] = np.array([filtfilt(HanningW, 1, binary_spike_trains[pooli][mni, :] * fsamp)
                                       for mni in range(nb_motoneurons)])

    # cut edges of signal. An explicit stop index is used because a `[:, :-0]` slice
    # (window_end_ignore = 0) would return an EMPTY array instead of keeping everything.
    stop_sample = smoothed_signal[pooli].shape[1] - nb_samples_cut_end
    smoothed_signal[pooli] = smoothed_signal[pooli][:, nb_samples_cut_start:stop_sample]

    fig, ax = plt.subplots(figsize=(10, 5))
    # Smooth color blend from the pool's colormap, indexed by MN size rank.
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colormap_temp = plt.get_cmap(muscle_colormaps[pooli])
    for mni in valid_MUs_idx[pooli]:
        ax.plot(smoothed_signal[pooli][mni], color=colormap_temp(mni/max(nb_motoneurons-1, 1)))
    plt.title(f"Smoothed signals of only valid MUs ({muscle_names[pooli]}) \n (dark = small MNs ; light = large MNs)")
    plt.ylabel("Smoothed discharge rate (pps)")
    plt.xlabel("Time (ms)")
    new_filename = f'Discharge_rates_{muscle_names[pooli]}.png'
    save_file_path = os.path.join(new_directory, new_filename)
    plt.savefig(save_file_path)
    plt.show()
    

In [None]:
## DIMENSIONALITY REDUCTION
# Do it only on the VL population (quite long otherwise)

## FA ######
# Factor analysis with 1 .. max_nb_of_factors_to_extract-1 factors on the smoothed
# discharge rates of the valid MNs of pool 0. For each model, every MN's signal is
# reconstructed from loadings x scores and scored by the squared Pearson correlation
# with the original; reconstructed_Rsquared[k] is the mean R² with k factors
# (index 0 is set to 0).
max_nb_of_factors_to_extract = 10


def _VL_FA_input():
    """Return a fresh (time samples x valid MNs) array of pool-0 smoothed rates for FA."""
    return transpose(smoothed_signal[0][valid_MUs_idx[0]])


loadings = []
scores = []
reconstructed_Rsquared = []
for nb_factors in range(1, max_nb_of_factors_to_extract):
    fa = FactorAnalyzer(n_factors=nb_factors, rotation='promax')
    fa.fit(_VL_FA_input())
    loadings.append(fa.loadings_)
    scores.append(fa.transform(_VL_FA_input()))
    # reconstruction: one row per valid MN, one column per time sample
    reconstructed_data_temp = np.dot(loadings[-1], scores[-1].T)
    Rsquared_per_MN_temp = [
        np.corrcoef(smoothed_signal[0][mn_idx], reconstructed_data_temp[row])[0, 1] ** 2
        for row, mn_idx in enumerate(valid_MUs_idx[0])
    ]
    reconstructed_Rsquared.append(np.mean(Rsquared_per_MN_temp))
reconstructed_Rsquared = np.insert(reconstructed_Rsquared, 0, 0.0)

plt.figure()
Rsquared_ax = plt.gca()
Rsquared_ax.plot(reconstructed_Rsquared, color=muscle_colors[0])
Rsquared_ax.set_title(f"R² of reconstructed signals ({muscle_names[0]})")
Rsquared_ax.set_xlabel('Nb of factors')
Rsquared_ax.set_ylabel('R²')
Rsquared_ax.set_ylim([0, 1])

In [None]:
# Create surrogate random signal
# Only for VL too
# For every MN of pool 0, the order of its inter-spike intervals (ISIs) is shuffled and
# a spike train is rebuilt from their cumulative sum. Each MN keeps its ISI distribution
# but loses the timing shared with other MNs. The surrogate is smoothed and analysed
# with factor analysis exactly like the real data.

nb_shuffle_iter = 1
shuffled_smoothed_signal = list(range(nb_shuffle_iter))
shuffled_explained_variance = list(range(nb_shuffle_iter))
for shuffli in range(nb_shuffle_iter):
    shuffled_ISI = list(range(nb_motoneurons))
    shuffled_binary_spike_trains = np.zeros((nb_motoneurons, len(time_bins)))   # Initialize the binary spike train array
    shuffled_smoothed_signal[shuffli] = []
    for mni in range(nb_motoneurons):
        shuffled_ISI[mni] = np.random.permutation(diff(spike_trains[0][mni])) # list of random ISIs (one list per motoneuron)
        # Rebuilt spike times in seconds. They span [first shuffled ISI, last - first spike],
        # so the end of the array stays empty (hence the extra end cut below).
        spikes_shuffled = np.asarray(cumsum(shuffled_ISI[mni]))
        # Sample index = round(t * fsamp) - 1: the "-1" 0-indexing offset is applied in
        # SAMPLES. Subtracting 1 before multiplying by fsamp would shift every spike 1 s
        # earlier and, through negative indexing, wrap the first second onto the array end.
        shuffled_spike_samples = np.clip(np.round(spikes_shuffled*fsamp).astype(int) - 1, 0, None)
        shuffled_binary_spike_trains[mni, shuffled_spike_samples] = 1
        shuffled_smoothed_signal[shuffli].append(filtfilt(HanningW, 1, shuffled_binary_spike_trains[mni, :] * fsamp))
    shuffled_smoothed_signal[shuffli] = np.array(shuffled_smoothed_signal[shuffli])
    # cut edges of signal
    shuffled_smoothed_signal[shuffli] = shuffled_smoothed_signal[shuffli][:, :-int(window_end_ignore * fsamp + (duration_with_ignored_window*fsamp/true_duration/second))]  # remove the end by "(duration_with_ignored_window/20)s"
    # the "/ second" just ensures that there is no dimension/unit associated to the index value
    # => the extra end cut removes the spike-free tail left because the time before the 1st
    #    spike and after the last spike are not part of the shuffled ISIs
    shuffled_smoothed_signal[shuffli] = shuffled_smoothed_signal[shuffli][:, int(window_beginning_ignore * fsamp):]  # remove the start

    # FA on shuffled signal #####
    # Same procedure as for the real data: mean R² between each valid MN's surrogate
    # signal and its FA reconstruction, for 1 .. max_nb_of_factors_to_extract-1 factors
    loadings_shuffled = []
    scores_shuffled = []
    shuffled_explained_variance[shuffli] = list(range(max_nb_of_factors_to_extract-1))
    for factori in range(max_nb_of_factors_to_extract-1):
        fa_shuffled = FactorAnalyzer(n_factors=factori+1, rotation='promax')
        fa_shuffled.fit(transpose(shuffled_smoothed_signal[shuffli][valid_MUs_idx[0]]))
        loadings_shuffled.append(fa_shuffled.loadings_)
        scores_shuffled.append(fa_shuffled.transform(transpose(shuffled_smoothed_signal[shuffli][valid_MUs_idx[0]])))
        # get Rsquared value of FA's data reconstruction (squared Pearson r per MN)
        reconstructed_data_temp = np.dot(loadings_shuffled[factori], scores_shuffled[factori].T)
        Rsquared_per_MN_temp = [
            np.corrcoef(shuffled_smoothed_signal[shuffli][mn_idx], reconstructed_data_temp[row])[0, 1] ** 2
            for row, mn_idx in enumerate(valid_MUs_idx[0])
        ]
        shuffled_explained_variance[shuffli][factori] = np.mean(Rsquared_per_MN_temp)
    # prepend R² = 0 for zero factors (same convention as the real data)
    shuffled_explained_variance[shuffli] = np.insert(shuffled_explained_variance[shuffli], 0, 0.0)

In [None]:
# Sanity check = plot the first shuffle
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
colormap_temp = plt.get_cmap(muscle_colormaps[0])
fig, ax = plt.subplots(figsize=(10, 5))
# valid MNs only, colored by MN size rank (dark = small)
for mni in valid_MUs_idx[0]:
    ax.plot(shuffled_smoothed_signal[0][mni], color=colormap_temp(mni/max(nb_motoneurons-1, 1)))
plt.title(f"Randomized smoothed signals of only valid MUs ({muscle_names[0]}) \n (dark = small MNs ; light = large MNs)")
plt.ylabel("Smoothed discharge rate (pps)")
plt.xlabel("Time (ms)")
new_filename = f'Discharge_rates_randomized_{muscle_names[0]}.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
# plt.show() takes no figure argument (its first parameter is a flag, not a figure)
plt.show()

# R² curve of the first surrogate (index k = number of factors)
plt.figure()
plt.plot(shuffled_explained_variance[0])
plt.title("Randomized signal cumulative variance explained (R²)")
plt.xlabel('Nb of components / factors')
plt.ylabel('Variance explained (R²)')
plt.ylim([0,1])

In [None]:
# Calculate random slope at each point on the curve
# Get slope of variance explained for simulated VS random data
# For each number of components c = 1 .. max_nb_of_factors_to_extract-2, fit a straight
# line through the R² values at c-1, c and c+1 and keep its slope, together with the
# square root of the summed squared residuals of the fit (the *_RMS lists).
# The surrogate curve is the mean over all shuffle iterations (identical to the single
# shuffle when nb_shuffle_iter == 1), rather than whichever iteration the loop variable
# `shuffli` happened to be left at by the previous cell.
surrogate_explained_variance = np.mean(np.array(shuffled_explained_variance, dtype=float), axis=0)


def _local_linear_fit(x_values, y_values):
    """Return (slope, sqrt(sum of squared residuals)) of a least-squares line through the points."""
    fit = linregress(x_values, y_values)
    residuals = y_values - (fit.slope * np.array(x_values) + fit.intercept)
    return fit.slope, np.sqrt(np.sum(residuals**2))


slope_VAF = []
slope_VAF_RMS = []
slope_VAF_random = []
slope_VAF_random_RMS = []
x_axis_components = arange(max_nb_of_factors_to_extract-2)+1
for componenti in range(1, max_nb_of_factors_to_extract-1):
    x_temp = [componenti-1, componenti, componenti+1]
    # For simulated data
    slope_temp, residual_norm_temp = _local_linear_fit(x_temp, reconstructed_Rsquared[x_temp])
    slope_VAF.append(slope_temp)
    slope_VAF_RMS.append(residual_norm_temp)
    # For shuffled (surrogate) data
    slope_temp, residual_norm_temp = _local_linear_fit(x_temp, surrogate_explained_variance[x_temp])
    slope_VAF_random.append(slope_temp)
    slope_VAF_random_RMS.append(residual_norm_temp)
plt.figure()
plt.plot(x_axis_components, slope_VAF, label="Simulated data slope of R² curve")
plt.plot(x_axis_components, slope_VAF_random, label="Shuffled data slope of R² curve")
plt.ylabel("Slope")
plt.xlabel("Nb of components")
plt.title(f"Slope values ({muscle_names[0]})")

# Index (into slope_VAF) of the first point where the simulated slope is not above the
# surrogate slope. Stays [] when that never happens (downstream cells rely on this).
nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data = []
crossing_found = False
for componenti, (slope_sim, slope_rand) in enumerate(zip(slope_VAF, slope_VAF_random)):
    if slope_sim <= slope_rand:
        nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data = componenti
        crossing_found = True
        break
# Without a crossing there is no position to mark (axvline([]) would raise)
if crossing_found:
    plt.axvline(x = nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data, color = 'crimson', linewidth=2,
                label="Nb of components AFTER which slope of R² <= slope of surrogate (random) R²")
plt.legend()

In [None]:
# Summary plot: R² curves of the simulated and surrogate data (surrogate = mean over the
# shuffle iterations), the true number of common inputs, and the slope-based estimate.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(reconstructed_Rsquared, label="Simulation")
ax.plot(np.mean(np.array(shuffled_explained_variance, dtype=float), axis=0),
        label="Randomized (noise)")  # plot line for randomized signal
ax.axvline(nb_common_inputs, ls='-', c='C2', lw=5, alpha = 0.5,
           label="Nb of excit. components")  # nb of excitatory inputs
# Slope of true variance explained <= slope of variance explained for noise/random data.
# The slope cell leaves this at [] when no such point exists: nothing to draw then.
if not isinstance(nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data, list):
    ax.axvline(nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data, ls='-', c='crimson', lw=2,
               label="Nb of components AFTER which slope of R² <= slope of surrogate (random) R²")
ax.set_xlabel('Nb of components')
ax.set_ylabel('Variance explained / R²')
ax.set_ylim(0, 1)
ax.legend()
plt.title(f"R² with factor analysis ({muscle_names[0]})")
new_filename = 'Rsquared_curve.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
# plt.show() takes no figure argument (its first parameter is a flag, not a figure)
plt.show()

In [None]:
### SAVE RESULTS AND DATA
# Text summary: FA-based dimensionality estimate (VL) and per-pool firing statistics
new_filename = 'results.txt'
save_file_path = os.path.join(new_directory, new_filename)

with open(save_file_path, 'w', encoding='utf-8') as file:
    file.write(f"General results (dimensionality reduction on VL) -----\n")
    file.write(f"   reconstructed_Rsquared: {reconstructed_Rsquared}\n")
    file.write(f"   nb of common input(s): {nb_common_inputs}\n")
    file.write(f"   nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data: {nb_of_factors_above_which_Rsquared_is_below_Rsquared_of_shuffled_data}\n")
    for pooli in range(nb_of_pools_to_simulate):
        file.write(f"\n")
        file.write(f"{muscle_names[pooli]} -----\n")
        file.write(f"   nb of valid MNs: {len(valid_MUs_idx[pooli])}\n")
        file.write(f"   mean firing rate (valid MNs): {np.mean(firing_rates[pooli][valid_MUs_idx[pooli]]):.2f}\n")
        file.write(f"   std firing rate (valid MNs): {np.std(firing_rates[pooli][valid_MUs_idx[pooli]]):.2f}\n")
        file.write(f"   mean firing rate (Renshaw cells): {mean_firing_rate_RC[pooli]:.2f}\n")
        file.write(f"   std firing rate (Renshaw cells): {std_firing_rate_RC[pooli]:.2f}\n")

# Min-max normalize smoothed discharge rates so that min=0 and max=1 (for each MN)
# Only for VL. Rows that are constant (max == min) are set to 0.
VL_valid_rates = smoothed_signal[0][valid_MUs_idx[0],:]  # fancy indexing returns a copy
row_min = VL_valid_rates.min(axis=1, keepdims=True)
row_span = VL_valid_rates.max(axis=1, keepdims=True) - row_min
normalized_smoothed_DR = np.divide(VL_valid_rates - row_min, row_span,
                                   out=np.zeros_like(VL_valid_rates), where=row_span != 0)

# Display normalized discharge rates, colored by the MN's index in the pool (as in the
# other discharge-rate plots), not by its row in the valid subset.
fig, ax = plt.subplots(figsize=(10, 5))
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9
colormap_temp = plt.get_cmap(muscle_colormaps[0])
for row, mn_idx in enumerate(valid_MUs_idx[0]):
    ax.plot(normalized_smoothed_DR[row], color=colormap_temp(mn_idx/max(nb_motoneurons-1, 1)), alpha=0.5)
plt.title(f"Smoothed and normalized discharge rates of valid MUs ({muscle_names[0]}) \n (dark = small MNs ; light = large MNs)")
plt.ylabel("Normalized discharge rates (a.u.)")
plt.xlabel("Time (ms)")
new_filename = f'Discharge_rates_normalized_{muscle_names[0]}.png'
save_file_path = os.path.join(new_directory, new_filename)
plt.savefig(save_file_path)
plt.show()

# SAVE SMOOTHED NORMALIZED DISCHARGE RATES AS CSV (one row per valid MN, no header)
new_filename = f'Discharge_rates_{muscle_names[0]}.csv'
save_file_path = os.path.join(new_directory, new_filename)
df = pd.DataFrame(normalized_smoothed_DR)
df.to_csv(save_file_path, index=False, header=False)


In [None]:
# COHERENCE - Initialize
# (only for VL)

# do not execute if "perform_coherence_analysis = False"
if not perform_coherence_analysis:
    print("Not performing coherence analysis")
    sys.exit()

# Binary spike trains of the valid VL MUs over the analysed window only
# (the beginning/end windows are removed and t = 0 is the end of the beginning window)
nb_of_valid_MUs = len(valid_MUs_idx[0])
nb_of_analysed_samples = int(np.round(true_duration*fsamp))
analysed_duration = duration_with_ignored_window-((window_end_ignore+window_beginning_ignore)*second)
valid_spike_trains = []
valid_binary_spike_trains = np.zeros([nb_of_valid_MUs, nb_of_analysed_samples]) # Convert into binary spike train matrix
for mni in range(nb_of_valid_MUs):
    shifted_spikes = spike_trains[0][valid_MUs_idx[0][mni]] - window_beginning_ignore*second
    # remove beginning and end (window edges to cut)
    outside_window = (shifted_spikes < 0*second) | (shifted_spikes > analysed_duration)
    valid_spike_trains.append(shifted_spikes[~outside_window])
    # sample index = round(t * fsamp) - 1 (0-indexing offset). A spike within half a
    # sample of t = 0 would give -1 and, through negative indexing, be written into the
    # LAST bin; clamp indices onto the valid range instead.
    spike_samples = np.round(np.asarray(valid_spike_trains[mni] / second) * fsamp).astype(int) - 1
    spike_samples = np.clip(spike_samples, 0, nb_of_analysed_samples - 1)
    valid_binary_spike_trains[mni, spike_samples] = 1

# # # Sanity check of binary matrix (copy, so the real matrix is not modified)
# test = valid_binary_spike_trains.copy()
# for mni in range(nb_of_valid_MUs):
#     test[mni] = test[mni]+mni
# plt.plot(test.T)

max_nb_of_MUs_per_group = nb_of_valid_MUs/2 - 1
NbOfRandomPermutations = 100
windowCOH = 1 # in seconds
frequencies_per_FFT_window = 10
frequencyBandForNoise = [150, 500] #no coherence other than noise expected within this frequency band (Hz)
frequencyBandForCOH = [0, 5] #coherence in the frequency band relevant for force production (Hz)
frequencyRangeToDisplayForCOHresults = 100 # in hz

# Initialize variables
raw_coherence = {}
pooled_coherence = {}
pooled_coherence_std = {}
zscore_coherence = {}

# Determine the correct length of the frequency array from the first CSD computation
dummy_signal = np.zeros(valid_binary_spike_trains.shape[1])
f, Pxx_dummy = csd(dummy_signal, dummy_signal, window=windows.hann(round(windowCOH * fsamp)), noverlap=0, nfft=frequencies_per_FFT_window * fsamp, fs=fsamp)
freq_length = len(Pxx_dummy)

# (permutation, group size - 1, frequency) spectra; all real-valued
COH_intragroup_X_mat = np.zeros((NbOfRandomPermutations, int(np.round(max_nb_of_MUs_per_group + 1)), freq_length))
COH_intragroup_Y_mat = np.zeros((NbOfRandomPermutations, int(np.round(max_nb_of_MUs_per_group + 1)), freq_length))
COH_intergroup_mat = np.zeros((NbOfRandomPermutations, int(np.round(max_nb_of_MUs_per_group + 1)), freq_length))


In [None]:
# COMPUTE COHERENCE
# For each group size, split the valid MUs at random into two disjoint groups (first /
# last `MUPerGroup` entries of a random permutation), sum each group's binary spike
# trains into a cumulative spike train (CST) and estimate the auto- and cross-spectra
# with scipy.signal.csd (Hann windows of windowCOH s, no overlap).
csd_kwargs = dict(window=windows.hann(round(windowCOH * fsamp)), noverlap=0,
                  nfft=frequencies_per_FFT_window * fsamp, fs=fsamp)
# Number of non-overlapping windows averaged by csd; used in the z-transform below
nb_of_COH_windows = true_duration / windowCOH

for groupi in range(1, int(np.round(max_nb_of_MUs_per_group))):

    MUPerGroup = groupi

    for permuti in range(1, int(NbOfRandomPermutations + 1)):
        random_shuffling_of_MUs_idx = np.random.permutation(nb_of_valid_MUs)

        CST1 = np.sum(valid_binary_spike_trains[random_shuffling_of_MUs_idx[:MUPerGroup], :], axis=0)
        CST2 = np.sum(valid_binary_spike_trains[random_shuffling_of_MUs_idx[-MUPerGroup:], :], axis=0)

        print(f"Iteration n° {permuti} / {NbOfRandomPermutations} for groups of {groupi} MUs (up to {int(max_nb_of_MUs_per_group)})")

        CST1_detrended = detrend(CST1)
        CST2_detrended = detrend(CST2)
        # Auto-spectrum of group 1, auto-spectrum of group 2, cross-spectrum between groups
        f, Pxx = csd(CST1_detrended, CST1_detrended, **csd_kwargs)
        f, Pyy = csd(CST2_detrended, CST2_detrended, **csd_kwargs)
        f, Pxy = csd(CST1_detrended, CST2_detrended, **csd_kwargs)

        # Collect results. csd returns complex arrays: keep the (real) auto-spectra and
        # the MAGNITUDE of the cross-spectrum. Assigning the complex Pxy to the real-valued
        # matrix would silently drop its imaginary part and bias the coherence downwards.
        COH_intragroup_X_mat[permuti - 1, groupi - 1, :] = Pxx.real
        COH_intragroup_Y_mat[permuti - 1, groupi - 1, :] = Pyy.real
        COH_intergroup_mat[permuti - 1, groupi - 1, :] = np.abs(Pxy)

    raw_coherence[groupi] = np.mean(np.abs(COH_intergroup_mat[:, groupi - 1, :]), axis=0)

    # Magnitude-squared coherence per permutation, then pooled over permutations
    COH = (np.abs(COH_intergroup_mat[:, groupi - 1, :]) ** 2) / (COH_intragroup_X_mat[:, groupi - 1, :] * COH_intragroup_Y_mat[:, groupi - 1, :])
    pooled_coherence[groupi] = np.mean(COH, axis=0)
    pooled_coherence_std[groupi] = np.std(COH, axis=0)

    # z-transform (arctanh of sqrt(coherence), scaled by sqrt(2 * nb of windows)); the bias
    # is estimated in a band where only noise is expected. Bin index = Hz * bins per Hz.
    bias = np.mean(np.arctanh(np.sqrt(np.mean(COH[:, (frequencyBandForNoise[0] * frequencies_per_FFT_window):(frequencyBandForNoise[1] * frequencies_per_FFT_window)], axis=1))) * np.sqrt(2 * nb_of_COH_windows))
    coh_zscore = np.arctanh(np.sqrt(np.mean(COH, axis=0))) * np.sqrt(2 * nb_of_COH_windows) - bias
    zscore_coherence[groupi] = coh_zscore

In [None]:
# Define plotting functions
def plot_coherence(show_legend=False):
    """Plot and save the pooled coherence spectra and their spread across iterations.

    Reads module-level results of the coherence loop:
    - ``pooled_coherence`` and ``pooled_coherence_std``, indexed by group size starting at 1.
    - ``max_nb_of_MUs_per_group``, ``NbOfRandomPermutations``,
      ``frequencies_per_FFT_window`` and ``frequencyRangeToDisplayForCOHresults``.
    - ``muscle_names`` and ``new_directory``.

    Two figures are saved in ``new_directory`` and shown:
    ``COH_<muscle>.png`` (mean coherence) and ``COH_<muscle>_std.png`` (STD across
    iterations).

    Parameters
    ----------
    show_legend : bool, optional
        Add a legend naming the group size of every curve. It is off by default
        because one curve is drawn per group size.
    """
    nb_groups = int(np.round(max_nb_of_MUs_per_group))
    colors_plots = plt.cm.winter(np.linspace(0, 1, int(np.round(max_nb_of_MUs_per_group + 1))))
    muscle = muscle_names[0]

    def _plot_per_group(spectra, label_fmt):
        """Draw one curve per group size (1 .. nb_groups-1) against frequency in Hz."""
        for groupi in range(nb_groups - 1):
            group_size = groupi + 1
            # csd-derived spectra are complex (zero imaginary part for these ratios);
            # take the real part explicitly rather than letting matplotlib warn.
            spectrum = np.real(np.asarray(spectra[group_size]))
            # Bin k sits at k / frequencies_per_FFT_window Hz (nfft = frequencies_per_FFT_window * fsamp).
            # Using a real frequency axis avoids relabelling ticks by integer truncation,
            # which mislabelled fractional frequencies.
            freqs = np.arange(len(spectrum)) / frequencies_per_FFT_window
            plt.plot(freqs, spectrum, color=colors_plots[groupi], linewidth=2,
                     label=label_fmt.format(group_size))
        plt.xlim([0, frequencyRangeToDisplayForCOHresults])
        plt.xlabel('Frequency (Hz)')

    def _finish(filename):
        """Add the optional legend and grid, save into new_directory, then display."""
        if show_legend:
            plt.legend()
        plt.grid()
        plt.savefig(os.path.join(new_directory, filename))
        plt.show()

    # Pooled coherence plot
    plt.figure(figsize=(10, 8))
    _plot_per_group(pooled_coherence, "Pooled coherence calculated over random groups of {} MUs")
    plt.ylim([0, 1])
    plt.ylabel('Coherence (pooled cross power spectral density)')
    plt.title(f"Pooled coherence of continuous MUs {muscle}\n with groups of 2 up to {int(max_nb_of_MUs_per_group)} MUs")
    _finish(f'COH_{muscle}.png')

    # Standard deviation plot
    plt.figure(figsize=(10, 8))
    _plot_per_group(pooled_coherence_std, "STD of pooled coherence over random groups of {} MUs")
    plt.ylabel(f"STD of iterations ({NbOfRandomPermutations} iterations)")
    plt.title(f"Standard deviation of pooled coherence of continuous MUs {muscle} \n with groups of 2 up to {int(max_nb_of_MUs_per_group)} MUs")
    _finish(f'COH_{muscle}_std.png')

def plot_zscore_coherence(show_legend=False):
    """Plot and save the z-scored pooled coherence for every group size.

    Reads module-level results of the coherence loop:
    - ``zscore_coherence``, indexed by group size starting at 1.
    - ``max_nb_of_MUs_per_group``, ``frequencies_per_FFT_window`` and
      ``frequencyRangeToDisplayForCOHresults``.
    - ``muscle_names`` and ``new_directory``.

    The figure is saved as ``COH_<muscle>_zscore.png`` in ``new_directory`` and shown.

    Parameters
    ----------
    show_legend : bool, optional
        Add a legend naming the group size of every curve (off by default).
    """
    nb_groups = int(np.round(max_nb_of_MUs_per_group))
    colors_plots = plt.cm.turbo(np.linspace(0, 1, int(np.round(max_nb_of_MUs_per_group + 1))))
    muscle = muscle_names[0]

    plt.figure(figsize=(10, 8))
    for groupi in range(nb_groups - 1):
        group_size = groupi + 1
        # arctanh/sqrt of csd-derived (complex) coherence yields complex values with a
        # zero imaginary part; plot the real part explicitly.
        zscores = np.real(np.asarray(zscore_coherence[group_size]))
        # Bin k sits at k / frequencies_per_FFT_window Hz. Plotting in Hz avoids
        # relabelling ticks by integer truncation, which mislabelled fractional frequencies.
        freqs = np.arange(len(zscores)) / frequencies_per_FFT_window
        plt.plot(freqs, zscores, color=colors_plots[groupi], linewidth=2,
                 label=f"Z-score of pooled coherence calculated over random groups of {group_size} MUs")

    plt.xlim([0, frequencyRangeToDisplayForCOHresults])
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Coherence (z-scores)')
    plt.title(f"Z-scores of pooled coherence of continuous MUs {muscle}\n with groups of 2 up to {int(max_nb_of_MUs_per_group)} MUs")
    if show_legend:
        plt.legend()
    plt.grid()
    plt.savefig(os.path.join(new_directory, f'COH_{muscle}_zscore.png'))
    plt.show()

# Render and save the coherence figures in order: mean/STD spectra first, then z-scores.
for _coherence_plotter in (plot_coherence, plot_zscore_coherence):
    _coherence_plotter()