In [None]:
# Import general libraries
import numpy as np
import pandas as pd
import re
import pickle
import torch
from brian2 import *
from sbi import utils, inference
import matplotlib.pyplot as plt
import os
from pathlib import Path
import getpass
import psutil
from dataclasses import dataclass, field, asdict
import copy

In [None]:
# Import from other files in the project
from simulator import SimulationParameters, run_simulation
from analyzer import AnalyzesParams, analyze_data

In [None]:
# ---------------------------------------------------------------------------
# User configuration for the analysis run.
# NOTE: the folder below is an absolute, machine-specific Windows path; edit it to
# point to your own copy of the data before running the notebook.
# ---------------------------------------------------------------------------
# Experimental data
# folder_with_data_to_analyze = 'C:\\Users\\franc\\Documents\\Mapping_Recurrent_Inhibition_minimal_datasets_to_run_scripts\\experimental_data'
# Simulated data
folder_with_data_to_analyze = 'C:\\Users\\franc\\Documents\\Mapping_Recurrent_Inhibition_minimal_datasets_to_run_scripts\\Example_small_simulation_batch'

skip_analyses_already_performed = True # True / False. Passed to find_hdf5_files(): when True, .h5 files whose folder already contains "analysis_output.pkl" are skipped
analyze_sim_subset = [] # e.g. [0,1,2,3,4,5], for debugging purposes. Indices into the list of files found. Set to [] to analyze everything (normal behavior)
repeat_analyses_for_different_MN_nb = False # True / False. Check robustness/reproducibility of results when subsampling MNs. Will redo the analyses with N motor units, N being each element of nb_of_MUs_to_subsample
nb_of_MUs_to_subsample = [5, 10, 15, 20, 25, 0] # used only if repeat_analyses_for_different_MN_nb = True. 0 means all motor units
nb_of_subsampling_iterations = [30, 30, 30, 30, 30, 30] # has to be the same length as nb_of_MUs_to_subsample. The analysis for the i-th subsample size is repeated nb_of_subsampling_iterations[i] times
# ANALYSIS PARAMETERS FOR SIMULATED DATA
analyzes_params = AnalyzesParams(is_simulation=True, # True,
                                 remove_discontinuous_MNs=False,
                                 select_random_subset_of_MUs_per_pool_for_analyses=0, # 0 = all motor units
                                 get_coherence = True, # True,
                                 coherence_between_CST_and_common_input = True,
                                 coherence_calc_max_iteration_nb_per_group_size=100,
                                 coherence_max_freq = 150,
                                 coherence_output_figures = True, # This is quite long to run! NOTE(review): the previous comment here described per-MN cross-histogram figures (likely copy-pasted from cross_histogram_output_figures) - confirm which figures this flag produces
                                 get_cross_histogram_measures = True, # True
                                 cross_histogram_ignore_homonymous_pool=False, # if True, do not perform the analysis on MU pairs coming from the same pool (saves time when doing the analysis only for the between-pool case)
                                 cross_histogram_ignore_heteronymous_pool=False, # if True, do not perform the analysis on MU pairs coming from different pools (saves time when doing the analysis only for the within-pool case)
                                 cross_histogram_output_figures=False, # This is quite long to run!
                                 # But it generates a per-MN figure of the cross-histograms (in both directions, with the MN as reference or as comparison) for visual inspection
                                 cross_histogram_save_cross_hists=True, # Saving all histograms/probability distributions generated for the analysis takes up memory space but allows for plotting later.
                                 cross_histogram_measures_min_plateau = 0.02, # in seconds. 0.02 is 10ms on each side of t=0
                                 cross_histogram_measures_min_spikes=1*1e3, # Filtering later, during SBI
                                 # ^ A minimum amount like 1e3 avoids performing the analysis when homonymous or heteronymous pairs are to be ignored (cross_histogram_ignore_homonymous_pool == True or cross_histogram_ignore_heteronymous_pool == True)
                                 cross_histogram_measures_min_r2=0, # Filtering later, during SBI
                                 cross_histogram_measures_null_distrib_nb_iter=0) # if 0, no p values calculated from distribution of troughs (makes it faster)

# ANALYSIS PARAMETERS FOR EXPERIMENTAL DATA (uncomment this block and comment out the one above to use it)
# analyzes_params = AnalyzesParams(is_simulation=False,
#                                  remove_discontinuous_MNs=False,
#                                  get_firing_rates = False,
#                                  get_ground_truth_RI_connectivity = False,
#                                  get_graph_theory_connectivity_measures = False,
#                                  get_cross_histogram_measures=True,
#                                  cross_histogram_output_figures=False, # This is quite long to run!
#                                  # But it generates per-MN figure of the cross-histograms (in both direction, with MN as reference or MN as comparison) for visual inspection
#                                  cross_histogram_save_cross_hists=True, # Saving all histograms/probability distributions generated for the analysis takes up memory space but allows for plotting later.
#                                  # ^ if set to True, then cross_histogram_output_figures must also be True
#                                  cross_histogram_measures_min_plateau = 0.02, # in seconds. 0.02 is 10ms on each side of t=0
#                                  cross_histogram_measures_min_spikes=1*1e3, # Filtering later, during analysis
#                                  cross_histogram_measures_min_r2=0, # Filtering later, during analysis
#                                  cross_histogram_measures_null_distrib_nb_iter=100, # for p values calculations
#                                  get_coherence = False)

analyze_parallel_cpus = 16 # Number of parallel workers (passed as n_jobs to parallel_analyze). Modify according to your computer

In [None]:
# Save the analysis parameters as a .pkl file in the folder containing all the simulations of the batch.
# Only the fixed analysis parameters are saved (the per-subsample variations built later are not included).
# The path is built with pathlib so that the separator is correct on every OS
# (a hard-coded "\\" only produces a valid path on Windows). It is converted back
# to str so that pkl_path keeps the same type as before.
pkl_path = str(Path(folder_with_data_to_analyze) / "analyzes_parameters.pkl")
with open(pkl_path, 'wb') as f:
    # asdict() turns the dataclass instance into a plain dict before pickling
    pickle.dump(asdict(analyzes_params), f)

print(f"✅ Saved analyses parameters to {pkl_path}")

In [None]:
def find_hdf5_files(root_folder, skip_already_done_analyses=True):
    """
    Recursively collect every .h5 file located under root_folder.

    Parameters
    ----------
    root_folder : str or path-like
        Folder to search (sub-folders are searched too).
    skip_already_done_analyses : bool
        When True, a .h5 file is left out if the directory it sits in also
        contains a file named "analysis_output.pkl".

    Returns
    -------
    list of str
        Paths of the selected .h5 files.
    """
    done_marker = "analysis_output.pkl"
    candidates = list(Path(root_folder).rglob("*.h5"))

    if skip_already_done_analyses:
        # keep only the files whose own folder has no analysis output yet
        candidates = [
            h5_file for h5_file in candidates
            if not (h5_file.parent / done_marker).exists()
        ]

    return [str(h5_file) for h5_file in candidates]

In [None]:
# Import for parallelization
from joblib import Parallel, delayed
from brian2 import prefs, device
import logging

# # # Function to make sure to terminate any Python process that runs in the background (this can happen when the kernel crashes during the parallelized computations)
def kill_other_python_processes():
    """
    Kill every process, other than the current one, whose executable name
    starts with "python" and whose owner equals getpass.getuser().

    Processes that disappear or cannot be accessed while iterating are
    silently skipped. A line is printed for each process killed.
    """
    own_pid = os.getpid()
    current_user = getpass.getuser()
    # NOTE(review): psutil may report usernames as DOMAIN\user on Windows while
    # getpass.getuser() returns the bare user name - confirm that the equality
    # test below actually matches on your platform.
    # NOTE(review): this also targets unrelated Python programs of the same user
    # (possibly including the Jupyter server itself) - use with care.
    for candidate in psutil.process_iter(['pid', 'name', 'username']):
        try:
            details = candidate.info
            # only consider processes owned by this user
            if details['username'] == current_user:
                exe_name = details['name'].lower()
                # matches python, pythonw, python3, etc.; never kill ourselves
                if exe_name.startswith('python') and details['pid'] != own_pid:
                    candidate.kill()   # or candidate.terminate()
                    print(f"{exe_name} process terminated")
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass
# Executed when this code runs as the main module: kill any other Python process
# owned by the current user before launching the parallel computations.
# Warning: kill_other_python_processes() is destructive (see its definition above).
if __name__ == '__main__':
    kill_other_python_processes()
    # now safe to start joblib Parallel(...)

# --- Helper to wrap a single simulation and then reset Brian2 ---
def _run_and_reset(params):
    """
    Run a single simulation and re-initialise the Brian2 device afterwards.

    params is forwarded unchanged to run_simulation (expected to be a
    SimulationParameters instance). The value returned by run_simulation is
    returned as is.
    """
    simulation_result = run_simulation(params)
    # Re-initialise, then re-activate, the Brian2 device once the run is over;
    # the intent is that the next simulation handled by this worker starts fresh.
    device.reinit()
    device.activate()
    return simulation_result

# # # PARALLEL SIMULATIONS
def parallel_simulate(params_prior_list, n_jobs=8):
    """
    Run one simulation per element of params_prior_list through joblib.

    params_prior_list : list of SimulationParameters
    n_jobs            : number of parallel workers

    Returns the list produced by joblib, one entry per simulation
    (each entry is what _run_and_reset returned).
    """
    # Note: `prefer="processes"` is the default for `n_jobs>1`
    worker_pool = Parallel(n_jobs=n_jobs)
    queued_runs = (delayed(_run_and_reset)(sim_params) for sim_params in params_prior_list)
    return worker_pool(queued_runs)

In [None]:
# # # PARALLEL ANALYSIS
# Analyzes_to_run should be a list of AnalyzesParams objects (can be a list with a single object if running the same analyses each time)
def parallel_analyze(files_to_analyze, analyzes_to_run, n_jobs=8):
    """
    Run one or several analyses on one or several simulation output files.

    Parameters
    ----------
    files_to_analyze : str or list of str
        Path of a single hdf5 file, or a list of such paths.
    analyzes_to_run : list
        Analysis parameter objects, each passed as second argument to
        analyze_data. Use a one-element list to run the same analysis on
        every file.
    n_jobs : int
        Number of parallel joblib workers.

    Returns
    -------
    - one analysis, one file (str)       : the value returned by analyze_data
    - one analysis, list of files        : list with one entry per file
    - several analyses, one file (str)   : list with one entry per analysis
    - several analyses, list of files    : list with one entry per file, each
      entry being the list of outputs (one per analysis) for that file
    - None when files_to_analyze is neither a str nor a list

    With a single analysis the work is parallelized over files; with several
    analyses it is parallelized over the analyses, files being processed one
    after the other.
    """
    # Note: `prefer="processes"` is the default for `n_jobs>1`
    if isinstance(files_to_analyze, str):
        single_file = True
    elif isinstance(files_to_analyze, list):
        single_file = False
    else:
        print(f"Invalid type for files_to_analyze: {type(files_to_analyze).__name__} (expected str or list of str)")
        return None

    if len(analyzes_to_run) == 1:
        only_analysis = analyzes_to_run[0]
        if single_file:
            print(f"Running a single analysis on a single file")
            return analyze_data(files_to_analyze, only_analysis)
        print(f"Running a single analysis on {len(files_to_analyze)} file(s)")
        return Parallel(n_jobs=n_jobs)(
            delayed(analyze_data)(file_path, only_analysis)
            for file_path in files_to_analyze)

    # Several analyses per file: parallelize over the analyses instead of the files
    if single_file:
        print(f"Running {len(analyzes_to_run)} analyses on a single file")
        return Parallel(n_jobs=n_jobs)(
            delayed(analyze_data)(files_to_analyze, analysis_i)
            for analysis_i in analyzes_to_run)

    print(f"Running {len(analyzes_to_run)} analyses on {len(files_to_analyze)} file(s)")
    # Collect the outputs of every file. Overwriting a single variable on each
    # pass would keep only the last file's results, and would leave nothing to
    # return when the list of files is empty.
    outputs_per_file = []
    for file_path in files_to_analyze:
        outputs_per_file.append(
            Parallel(n_jobs=n_jobs)(
                delayed(analyze_data)(file_path, analysis_i)
                for analysis_i in analyzes_to_run))
    return outputs_per_file

In [None]:
# Capture progress (read from the .log file) and display it directly in the notebook cell
from threading import Thread, Event
import time
from datetime import datetime
import re

# Module-level state shared by start_tail() / stop_tail():
# the background thread currently tailing the log (None if never started),
# and the event used to ask that thread to stop.
# Thread and Event are the names imported via `from threading import Thread, Event`;
# the module name `threading` itself is not imported explicitly by this notebook.
_tail_thread = None
_tail_stop = Event()

def start_tail(logfile="simulations_progress_log.log", poll_interval=1.0):
    """
    Start a background (daemon) thread that echoes new lines of `logfile`.

    Only lines appended after the thread has opened the file are read. A line
    is printed when the timestamp parsed from its first two space-separated
    fields ("%Y-%m-%d %H:%M:%S,%f") is >= the moment start_tail() was called;
    lines without a parsable timestamp are always printed. Everything before
    the log-level marker (DEBUG:, INFO:, WARNING:, ERROR:, CRITICAL:) is
    stripped from the printed line.

    Parameters
    ----------
    logfile : str
        Path of the log file to follow. It is created if it does not exist.
    poll_interval : float
        Seconds to wait before checking again when no new line is available.

    Only one tail thread runs at a time: if one is already alive, a message is
    printed and no new thread is started.
    """
    global _tail_thread, _tail_stop

    # make sure the file exists (touch it)
    open(logfile, "a").close()

    # remember "now" and clear any previous stop flag
    start_dt = datetime.now()
    _tail_stop.clear()

    def _tail_loop():
        level_re = re.compile(r'\b(?:DEBUG|INFO|WARNING|ERROR|CRITICAL)\b:\s*')
        with open(logfile, "r", encoding="utf-8") as f:
            # seek to end: we only want new lines
            f.seek(0, 2)
            while not _tail_stop.is_set():
                line = f.readline()
                if not line:
                    time.sleep(poll_interval)
                    continue

                # try to parse timestamp at the very start
                try:
                    ts_str = " ".join(line.split(" ")[:2]).rstrip(",")
                    ts = datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S,%f")
                except ValueError:
                    ts = start_dt  # force print for non-timestamped lines

                if ts >= start_dt:
                    # strip off everything before the level marker
                    m = level_re.search(line)
                    if m:
                        print(line[m.start():], end="")
                    else:
                        print(line, end="")

    # fire up the thread (only one at a time)
    if _tail_thread is None or not _tail_thread.is_alive():
        _tail_thread = Thread(target=_tail_loop, daemon=True)
        _tail_thread.start()
    else:
        print("Tail already running; call `stop_tail()` first if you want to restart.")

def stop_tail():
    """Ask the background tail thread to stop and wait until it has finished."""
    _tail_stop.set()
    if _tail_thread is not None:
        _tail_thread.join()



In [None]:
# Collect the .h5 files found under the data folder (optionally leaving out the ones already analyzed)
simulation_output_files = find_hdf5_files(
    folder_with_data_to_analyze,
    skip_already_done_analyses=skip_analyses_already_performed,
)
# Optional debugging restriction: keep only the files at the requested positions
if len(analyze_sim_subset) > 0:
    simulation_output_files = [
        simulation_output_files[file_idx] for file_idx in analyze_sim_subset
    ]
print(f"Nb of files to analyze = {len(simulation_output_files)}")

In [None]:
# Build the list of analyses to run on each file.
# - Subsampling mode: one deep copy of analyzes_params per (subsample size, iteration),
#   differing only by the number of MUs to subsample and by the output name.
# - Normal mode: a one-element list holding analyzes_params itself.
analyzes_params_list = []
if repeat_analyses_for_different_MN_nb:
    # Both lists are walked in parallel, so they must describe the same number of
    # subsample sizes; fail early with an explicit message otherwise.
    if len(nb_of_subsampling_iterations) != len(nb_of_MUs_to_subsample):
        raise ValueError(
            "nb_of_subsampling_iterations must have the same length as nb_of_MUs_to_subsample "
            f"(got {len(nb_of_subsampling_iterations)} and {len(nb_of_MUs_to_subsample)})")
    for subsample_i, nb_iterations_i in zip(nb_of_MUs_to_subsample, nb_of_subsampling_iterations):
        # 0 means "all motor units" and gets a dedicated label in the output name
        subsample_label = "all_MNs" if subsample_i == 0 else f"{subsample_i}MNs"
        for iteration_i in range(nb_iterations_i):
            params_copy = copy.deepcopy(analyzes_params) # make a fresh copy each time
            params_copy.select_random_subset_of_MUs_per_pool_for_analyses = subsample_i
            params_copy.analysis_output_name = f"subsample_{subsample_label}_iter{iteration_i}"
            analyzes_params_list.append(params_copy)
else: # Create a one-element list
    analyzes_params_list.append(analyzes_params)

print(f"Nb of analyzes to perform on each file = {len(analyzes_params_list)}")

In [None]:
# Run the analyses while echoing the progress log in the cell output.
# Normal mode runs the single analyzes_params on every file; subsampling mode runs
# every entry of analyzes_params_list on every file.
if not repeat_analyses_for_different_MN_nb:
    analyses_for_this_run = [analyzes_params]
else:
    analyses_for_this_run = analyzes_params_list

start_tail()
try:
    analyzis_output = parallel_analyze(simulation_output_files, analyses_for_this_run, n_jobs=analyze_parallel_cpus)
    time.sleep(1)  # give it a moment to print the last lines
finally:
    # always stop the tail thread, even when the analysis raises, so that it does
    # not keep running in the background and block the next start_tail()
    stop_tail()