diff --git a/bluepyefe/extract.py b/bluepyefe/extract.py index 03d7cc58..691f2314 100644 --- a/bluepyefe/extract.py +++ b/bluepyefe/extract.py @@ -987,7 +987,6 @@ def extract_efeatures( rheobase_strategy, rheobase_settings ) - efeatures, protocol_definitions, current = create_feature_protocol_files( cells, protocols, @@ -996,7 +995,6 @@ def extract_efeatures( write_files=write_files, default_std_value=default_std_value ) - if pickle_cells: path_cells = pathlib.Path(output_directory) path_cells.mkdir(parents=True, exist_ok=True) @@ -1004,7 +1002,7 @@ def extract_efeatures( if plot: plot_all_recordings_efeatures( - cells, protocols, output_dir=output_directory + cells, protocols, output_dir=output_directory, mapper=map_function ) if extract_per_cell and write_files: diff --git a/bluepyefe/plotting.py b/bluepyefe/plotting.py index 88846f19..d39f1ca4 100644 --- a/bluepyefe/plotting.py +++ b/bluepyefe/plotting.py @@ -22,6 +22,7 @@ import logging import pathlib from itertools import cycle +from functools import partial import matplotlib.pyplot as plt import numpy @@ -108,11 +109,13 @@ def _plot_legend(colors, markers, output_dir, show=False): return fig, axs -def plot_all_recordings(cells, output_dir, show=False): - """Plot recordings for all cells and all protocols""" +def _plot(cell, output_dir, show=False): + cell.plot_all_recordings(output_dir, show=show) - for cell in cells: - cell.plot_all_recordings(output_dir, show=show) + +def plot_all_recordings(cells, output_dir, show=False, mapper=map): + """Plot recordings for all cells and all protocols""" + mapper(partial(_plot, output_dir=output_dir, show=show), cells) def plot_efeature( @@ -250,6 +253,20 @@ def plot_efeatures( ) +def _plot_ind(cell, output_dir, protocols, key_amp, colors, markers, show): + for protocol_name in cell.get_protocol_names(): + _ = plot_efeatures( + cells=[cell], + protocol_name=protocol_name, + output_dir=output_dir, + protocols=protocols, + key_amp=key_amp, + colors=colors, + markers=markers, + show=show + ) + + def plot_individual_efeatures( cells, protocols, @@ -257,27 +274,26 @@ def plot_individual_efeatures( colors=None, markers=None, key_amp="amp", - show=False + show=False, + mapper=map, ): """Generate efeatures plots for all each cell individually""" if not colors or not markers: colors, markers = _get_colors_markers_wheels(cells) - for cell in cells: - - for protocol_name in cell.get_protocol_names(): - - _ = plot_efeatures( - cells=[cell], - protocol_name=protocol_name, - output_dir=output_dir, - protocols=protocols, - key_amp=key_amp, - colors=colors, - markers=markers, - show=show - ) + mapper( + partial( + _plot_ind, + output_dir=output_dir, + protocols=protocols, + key_amp=key_amp, + colors=colors, + markers=markers, + show=show, + ), + cells, + ) def plot_grouped_efeatures( @@ -317,14 +333,14 @@ def plot_grouped_efeatures( def plot_all_recordings_efeatures( - cells, protocols, output_dir=None, show=False + cells, protocols, output_dir=None, show=False, mapper=map ): """Generate plots for all recordings and efeatures both for individual cells and across all cells.""" colors, markers = _get_colors_markers_wheels(cells) - plot_all_recordings(cells, output_dir) + plot_all_recordings(cells, output_dir, mapper=mapper) for key_amp in ["amp", "amp_rel"]: plot_individual_efeatures( @@ -334,7 +350,8 @@ def plot_all_recordings_efeatures( colors=colors, markers=markers, key_amp=key_amp, - show=show + show=show, + mapper=mapper, ) plot_grouped_efeatures( cells,