From 02ae6fe3c1dd4ef8effda72de4b80ab07f12338f Mon Sep 17 00:00:00 2001 From: Amin Alam Date: Thu, 4 Jan 2024 16:24:47 +0100 Subject: [PATCH 1/6] BUG fixed in re-referencing Instead of re-referenced data, old data was shown --- elecphys/main.py | 10 +++++++--- elecphys/preprocessing.py | 10 +++++++--- elecphys/utils.py | 1 + elecphys/visualization.py | 12 +----------- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/elecphys/main.py b/elecphys/main.py index aa255c1..1a6b546 100644 --- a/elecphys/main.py +++ b/elecphys/main.py @@ -199,7 +199,7 @@ def re_reference_npz(ctx, input_npz_folder: str, output_npz_folder: str = 'outpu ---------- """ - print('--- Average re-referencing NPZ files...') + print('--- Re-referencing NPZ files...') preprocessing.re_reference_npz( input_npz_folder, output_npz_folder, @@ -465,13 +465,13 @@ def plot_avg_stft(ctx, input_npz_folder: str, output_plot_file: str, f_min: floa required=False, type=bool, default=False, show_default=True) @click.option('--re_reference', '-rr', help='Re-reference signals. If -rr_channel not specified, signals will be re-referenced to the average of all channels, unless they will be re-referenced to the given channel. If --ignore_channels is specified, specified channels will not be re referenced or taken into account for avg rereferencing', required=False, type=bool, default=False, show_default=True) @click.option('--ignore_channels', '-ic', help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored', - required=False, type=list, default=None, show_default=True) + required=False, type=str, default=None, show_default=True) @click.option('--rr_channel', '-rrc', help='Channel to re-reference signals to. If None, signals will be re-referenced to the average of all channels', required=False, type=int, default=None, show_default=True) @click.pass_context @error_handler def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float = None, t_max: float = None, channels_list: list = None, - normalize: bool = False, re_reference: bool = False, ignore_channels: list = None, rr_channel: int = None) -> None: + normalize: bool = False, re_reference: bool = False, ignore_channels: str = None, rr_channel: int = None) -> None: """ Plots signals from NPZ file Parameters @@ -488,6 +488,10 @@ def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float list of channels to plot. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be plotted. normalize: bool normalize signals. If true, each channel will be normalized. If not specified, the default value is False. + re_reference: bool + re-reference signals. If true, signals will be re-referenced. If not specified, the default value is False. + ignore_channels: str + list of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored. Either a list of channel indexes or a string of channel indexes separated by commas. If not specified, the default value is None. rr_channel: int channel to re-reference signals to. If None, signals will be re-referenced to the average of all channels. If not specified, the default value is None. diff --git a/elecphys/preprocessing.py b/elecphys/preprocessing.py index a2ccd5c..00cc67f 100644 --- a/elecphys/preprocessing.py +++ b/elecphys/preprocessing.py @@ -189,14 +189,18 @@ def re_reference(data: np.ndarray, ignore_channels: [ data.shape[0]) if i not in ignore_channels] else: channels_list = [i for i in range(data.shape[0])] - rr_channel = rr_channel - 1 if rr_channel is not None else None + print(rr_channel, ignore_channels) + data_rereferenced = data.copy() if rr_channel is not None: + reference = data[rr_channel, :].reshape(1, -1) data_rereferenced[channels_list, :] = data[channels_list, - :] - data[rr_channel, :].reshape(1, -1) + :] - np.repeat(reference, len(channels_list), axis=0) else: + reference = np.mean(data[channels_list, :], axis=0).reshape(1, -1) data_rereferenced[channels_list, :] = data[channels_list, - :] - np.mean(data[channels_list, :], axis=0).reshape(1, -1) + :] - np.repeat(reference, len(channels_list), axis=0) + return data_rereferenced diff --git a/elecphys/utils.py b/elecphys/utils.py index 9d723b5..2d0db02 100644 --- a/elecphys/utils.py +++ b/elecphys/utils.py @@ -116,4 +116,5 @@ def convert_string_to_list(string): string = string.split(',') output = list(map(int, string)) output = np.unique(output) + output = output.tolist() return output diff --git a/elecphys/visualization.py b/elecphys/visualization.py index 192315d..d12ce67 100644 --- a/elecphys/visualization.py +++ b/elecphys/visualization.py @@ -275,7 +275,6 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl rr_channel = _rereference_args['rr_channel'] ignore_channels = utils.convert_string_to_list(ignore_channels) if rr_channel is not None: - print(rr_channel, channels_list) if rr_channel not in channels_list: raise ValueError('rr_channel must be in channels_list') if ignore_channels is not None: @@ -283,20 +282,11 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl raise ValueError( 'All channels in ignore_channels must be in channels_list, as it does not make sense to ignore a channel for re-referencing if it is not in channels_list') - for row_no, channel_index in enumerate(channels_list): - channel_index = channel_index - 1 - if rr_channel is not None: - if channel_index == rr_channel - 1: - rr_channel = row_no - if ignore_channels is not None: - if channel_index in ignore_channels: - ignore_channels[ignore_channels.index( - channel_index)] = row_no data_all = preprocessing.re_reference( data_all, ignore_channels, rr_channel) for row_no, channel_index in enumerate(channels_list): - + signal_chan = data_all[row_no, :] ax[row_no].plot(t, signal_chan, color='k') ax[row_no].set_ylabel(f'Channel {channel_index}') ax[row_no].spines['top'].set_visible(False) From 5de2ccc2aa3d16d7ebe4a4822028742f854cee8a Mon Sep 17 00:00:00 2001 From: Amin Alam Date: Mon, 15 Jan 2024 10:53:47 +0100 Subject: [PATCH 2/6] codes reformated --- elecphys/preprocessing.py | 2 +- test/main.py | 214 ++++++++++++++++++++++++-------------- 2 files changed, 139 insertions(+), 77 deletions(-) diff --git a/elecphys/preprocessing.py b/elecphys/preprocessing.py index 00cc67f..345c4f6 100644 --- a/elecphys/preprocessing.py +++ b/elecphys/preprocessing.py @@ -189,8 +189,8 @@ def re_reference(data: np.ndarray, ignore_channels: [ data.shape[0]) if i not in ignore_channels] else: channels_list = [i for i in range(data.shape[0])] + rr_channel = rr_channel - 1 if rr_channel is not None else None - print(rr_channel, ignore_channels) data_rereferenced = data.copy() diff --git a/test/main.py b/test/main.py index c07f808..b7962d7 100644 --- a/test/main.py +++ b/test/main.py @@ -1,27 +1,29 @@ +import elecphys.data_io as data_io +import elecphys.visualization as visualization +import elecphys.fourier_analysis as fourier_analysis +import elecphys.preprocessing as preprocessing +import elecphys.conversion as conversion import os import sys import unittest import shutil sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -import elecphys.conversion as conversion -import elecphys.preprocessing as preprocessing -import elecphys.fourier_analysis as fourier_analysis -import elecphys.visualization as visualization -import elecphys.data_io as data_io class TestCases_0_conversion(unittest.TestCase): def test_1_rhd_to_mat(self): if not MATLAB_TEST: self.assertTrue(True) - return + return folder_path = os.path.join(os.path.dirname(__file__), 'data', 'rhd') - output_mat_file = os.path.join(os.path.dirname(__file__), 'data', 'mat', 'sample.mat') + output_mat_file = os.path.join( + os.path.dirname(__file__), 'data', 'mat', 'sample.mat') for ds_factor in [1, 20]: if os.path.exists(output_mat_file): os.remove(output_mat_file) - conversion.convert_rhd_to_mat(folder_path, output_mat_file, ds_factor) + conversion.convert_rhd_to_mat( + folder_path, output_mat_file, ds_factor) self.assertTrue(os.path.exists(output_mat_file)) os.remove(output_mat_file) @@ -29,36 +31,44 @@ def test_1_rhd_to_mat(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_mat_file)) - def test_2_mat_to_npz(self): - mat_file = os.path.join(os.path.dirname(__file__), 'data', 'mat', 'sample.mat') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') + mat_file = os.path.join( + os.path.dirname(__file__), + 'data', + 'mat', + 'sample.mat') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') for notch_filter_freq in [0, 50, 60]: if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) - conversion.convert_mat_to_npz(mat_file, output_npz_folder, notch_filter_freq) + conversion.convert_mat_to_npz( + mat_file, output_npz_folder, notch_filter_freq) self.assertTrue(os.path.exists(output_npz_folder)) - + shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main convert_mat_to_npz --mat_file {mat_file} --output_npz_folder {output_npz_folder} --notch_filter_freq {notch_filter_freq}' os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) - class TestCases_1_preprocessing(unittest.TestCase): def test_apply_notch(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') npz_files = os.listdir(npz_files_folder) npz_file = npz_files[0] - _signal_chan, fs = data_io.load_npz(os.path.join(npz_files_folder, npz_file)) - output = preprocessing.apply_notch(_signal_chan, {'Q':60, 'fs':fs, 'f0':50}) + _signal_chan, fs = data_io.load_npz( + os.path.join(npz_files_folder, npz_file)) + output = preprocessing.apply_notch( + _signal_chan, {'Q': 60, 'fs': fs, 'f0': 50}) self.assertTrue(output.shape == _signal_chan.shape) - def test_zscore_normalize_npz(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_zscore') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_zscore') if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) preprocessing.zscore_normalize_npz(npz_files_folder, output_npz_folder) @@ -69,10 +79,11 @@ def test_zscore_normalize_npz(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) - def test_normalize_npz(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_normalized') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_normalized') if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) preprocessing.normalize_npz(npz_files_folder, output_npz_folder) @@ -83,15 +94,17 @@ def test_normalize_npz(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) - def test_re_reference_npz(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_avg_reref') - for ignore_channels in [[1,2], "[1,4,6]", None]: + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_avg_reref') + for ignore_channels in [[1, 2], "[1,4,6]", None]: for rr_channel in [1, 4]: if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) - preprocessing.re_reference_npz(npz_files_folder, output_npz_folder, ignore_channels, rr_channel) + preprocessing.re_reference_npz( + npz_files_folder, output_npz_folder, ignore_channels, rr_channel) self.assertTrue(os.path.exists(output_npz_folder)) shutil.rmtree(output_npz_folder) @@ -107,14 +120,17 @@ def test_re_reference_npz(self): class TestCases_2_fourier_analysis(unittest.TestCase): def test_stft_numeric_output_from_npz(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_stft') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_stft') if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) window_size = 1 overlap = 0.5 for window_type in ['hann', 'kaiser 5']: - fourier_analysis.stft_numeric_output_from_npz(npz_files_folder, output_npz_folder, window_size, overlap, window_type) + fourier_analysis.stft_numeric_output_from_npz( + npz_files_folder, output_npz_folder, window_size, overlap, window_type) self.assertTrue(os.path.exists(output_npz_folder)) shutil.rmtree(output_npz_folder) @@ -122,13 +138,15 @@ def test_stft_numeric_output_from_npz(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) - def test_dft_numeric_output_from_npz(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_dft') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_dft') if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) - fourier_analysis.dft_numeric_output_from_npz(npz_files_folder, output_npz_folder) + fourier_analysis.dft_numeric_output_from_npz( + npz_files_folder, output_npz_folder) self.assertTrue(os.path.exists(output_npz_folder)) shutil.rmtree(output_npz_folder) @@ -136,17 +154,20 @@ def test_dft_numeric_output_from_npz(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) - def test_frequency_filtering(self): filter_order = 2 - for filter_args in [{'filter_type': 'LPF', 'freq_cutoff': 100}, {'filter_type': 'HPF', 'freq_cutoff': 100}, {'filter_type': 'BPF', 'freq_cutoff': [50, 100]}]: + for filter_args in [{'filter_type': 'LPF', 'freq_cutoff': 100}, { + 'filter_type': 'HPF', 'freq_cutoff': 100}, {'filter_type': 'BPF', 'freq_cutoff': [50, 100]}]: filter_args['filter_order'] = filter_order - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_filtered') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_filtered') if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) - - fourier_analysis.butterworth_filtering_from_npz(npz_files_folder, output_npz_folder, filter_args) + + fourier_analysis.butterworth_filtering_from_npz( + npz_files_folder, output_npz_folder, filter_args) self.assertTrue(os.path.exists(output_npz_folder)) shutil.rmtree(output_npz_folder) @@ -154,27 +175,29 @@ def test_frequency_filtering(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) - def test_cfc_from_npz(self): return - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_npz_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_cfc') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_npz_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_cfc') if os.path.exists(output_npz_folder): shutil.rmtree(output_npz_folder) freq_phase = list(range(2, 9)) freq_amp = list(range(35, 46)) - fourier_analysis.calc_cfc_from_npz(npz_files_folder, output_npz_folder, freq_amp, freq_phase) + fourier_analysis.calc_cfc_from_npz( + npz_files_folder, output_npz_folder, freq_amp, freq_phase) self.assertTrue(os.path.exists(output_npz_folder)) - - class TestCases_3_visualization(unittest.TestCase): def test_plot_stft(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_stft') + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_stft') npz_files = os.listdir(npz_files_folder) npz_file = npz_files[0] - output_plot_file = os.path.join(os.path.dirname(__file__), 'data', 'plots', 'stft_plot.png') + output_plot_file = os.path.join(os.path.dirname( + __file__), 'data', 'plots', 'stft_plot.png') f_min = None f_max = None @@ -191,19 +214,31 @@ def test_plot_stft(self): for db_max in [None, 50]: if os.path.exists(output_plot_file): os.remove(output_plot_file) - visualization.plot_stft_from_npz(os.path.join(npz_files_folder, npz_file), output_plot_file, f_min, f_max, t_min, t_max, db_min, db_max) - self.assertTrue(os.path.exists(output_plot_file)) - + visualization.plot_stft_from_npz( + os.path.join( + npz_files_folder, + npz_file), + output_plot_file, + f_min, + f_max, + t_min, + t_max, + db_min, + db_max) + self.assertTrue( + os.path.exists(output_plot_file)) + os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_stft --input_npz_file "{os.path.join(npz_files_folder, npz_file)}" --output_plot_file {output_plot_file} --f_min {f_min} --f_max {f_max} --t_min {t_min} --t_max {t_max} --db_min {db_min} --db_max {db_max}' os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) - def test_plot_avg_stft(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_stft') - output_plot_file = os.path.join(os.path.dirname(__file__), 'data', 'plots', 'avg_stft_plot.png') - + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_stft') + output_plot_file = os.path.join(os.path.dirname( + __file__), 'data', 'plots', 'avg_stft_plot.png') + f_min = None f_max = None t_min = None @@ -219,24 +254,36 @@ def test_plot_avg_stft(self): for db_max in [None, 50]: if os.path.exists(output_plot_file): os.remove(output_plot_file) - visualization.plot_avg_stft_from_npz(npz_files_folder, output_plot_file, f_min, f_max, t_min, t_max, db_min, db_max, channels_list) - self.assertTrue(os.path.exists(output_plot_file)) + visualization.plot_avg_stft_from_npz( + npz_files_folder, + output_plot_file, + f_min, + f_max, + t_min, + t_max, + db_min, + db_max, + channels_list) + self.assertTrue( + os.path.exists(output_plot_file)) os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_avg_stft --input_npz_folder "{npz_files_folder}" --output_plot_file {output_plot_file} --f_min {f_min} --f_max {f_max} --t_min {t_min} --t_max {t_max} --db_min {db_min} --db_max {db_max} --channels_list "{[1, 2, 3, 4, 5, 6, 7, 12, 15]}"' os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) - def test_plot_signal(self): - npz_folder_path = os.path.join(os.path.dirname(__file__), 'data', 'npz') - output_plot_file = os.path.join(os.path.dirname(__file__), 'data', 'plots', 'signal_plot.png') + npz_folder_path = os.path.join( + os.path.dirname(__file__), 'data', 'npz') + output_plot_file = os.path.join(os.path.dirname( + __file__), 'data', 'plots', 'signal_plot.png') t_min = None t_max = None for channels_list in [None, [1, 2, 3, 4, 5, 6, 7, 12, 15]]: if os.path.exists(output_plot_file): os.remove(output_plot_file) - visualization.plot_signals_from_npz(npz_folder_path, output_plot_file, t_min, t_max, channels_list) + visualization.plot_signals_from_npz( + npz_folder_path, output_plot_file, t_min, t_max, channels_list) self.assertTrue(os.path.exists(output_plot_file)) os.remove(output_plot_file) @@ -250,7 +297,8 @@ def test_plot_signal(self): self.assertTrue(os.path.exists(output_plot_file)) re_reference = True - for channels_list, ignore_channels in zip([[1, 2, 3, 4, 5, 6, 7, 12, 15], [1, 2, 3, 4, 5]], [[1, 5], [5]]): + for channels_list, ignore_channels in zip( + [[1, 2, 3, 4, 5, 6, 7, 12, 15], [1, 2, 3, 4, 5]], [[1, 5], [5]]): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --channels_list "{channels_list}" --ignore_channels "{ignore_channels}" --re_reference {re_reference}' os.system(command_prompt) @@ -268,17 +316,19 @@ def test_plot_signal(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) - for channels_list, ignore_channels in zip([[1, 2, 3, 4, 5, 6, 7, 12, 15], [1, 2, 3, 4, 5]], [[1, 5], [5]]): + for channels_list, ignore_channels in zip( + [[1, 2, 3, 4, 5, 6, 7, 12, 15], [1, 2, 3, 4, 5]], [[1, 5], [5]]): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --channels_list "{channels_list}" --ignore_channels "{ignore_channels}" --re_reference {re_reference} --rr_channel {rr_channel}' os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) - def test_plot_dft(self): - npz_files_folder = os.path.join(os.path.dirname(__file__), 'data', 'npz_dft') - output_plot_file = os.path.join(os.path.dirname(__file__), 'data', 'plots', 'dft_plot.png') - + npz_files_folder = os.path.join( + os.path.dirname(__file__), 'data', 'npz_dft') + output_plot_file = os.path.join(os.path.dirname( + __file__), 'data', 'plots', 'dft_plot.png') + f_min = None f_max = 150 for channels_list in [None, [1, 2, 3]]: @@ -286,7 +336,14 @@ def test_plot_dft(self): for plot_type in ['all_channels', 'average_of_channels']: if os.path.exists(output_plot_file): os.remove(output_plot_file) - visualization.plot_dft_from_npz(npz_files_folder, output_plot_file, f_min, f_max, plot_type, conv_window_size=conv_window_size, channels_list=channels_list) + visualization.plot_dft_from_npz( + npz_files_folder, + output_plot_file, + f_min, + f_max, + plot_type, + conv_window_size=conv_window_size, + channels_list=channels_list) self.assertTrue(os.path.exists(output_plot_file)) os.remove(output_plot_file) @@ -294,17 +351,23 @@ def test_plot_dft(self): os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) - def test_plot_filter_freq_response(self): - output_plot_file = os.path.join(os.path.dirname(__file__), 'data', 'plots', 'filter_freq_response_plot.png') - filter_freq_response_json_file_path = os.path.join(os.path.dirname(__file__), 'data', 'npz_filtered', 'filter_freq_response.json') + output_plot_file = os.path.join( + os.path.dirname(__file__), + 'data', + 'plots', + 'filter_freq_response_plot.png') + filter_freq_response_json_file_path = os.path.join(os.path.dirname( + __file__), 'data', 'npz_filtered', 'filter_freq_response.json') if os.path.exists(output_plot_file): os.remove(output_plot_file) - visualization.plot_filter_freq_response_from_json(filter_freq_response_json_file_path, output_plot_file) + visualization.plot_filter_freq_response_from_json( + filter_freq_response_json_file_path, output_plot_file) self.assertTrue(os.path.exists(output_plot_file)) os.remove(output_plot_file) - for filter_type, freq_cutoff in zip(['LPF', 'HPF', 'BPF'], [60, 60, [50, 100]]): + for filter_type, freq_cutoff in zip( + ['LPF', 'HPF', 'BPF'], [60, 60, [50, 100]]): for filter_order in [2, 4]: if os.path.exists(output_plot_file): os.remove(output_plot_file) @@ -318,10 +381,9 @@ def test_get_matlab_engine(self): pass - if __name__ == '__main__': os.system('pip3 uninstall elecphys -y') MATLAB_TEST = int(sys.argv[1]) os.environ['ELECPHYS_DEBUG'] = 'True' os.environ['ELECPHYS_VERBOSE'] = 'True' - unittest.main(argv=['first-arg-is-ignored'], exit=False) \ No newline at end of file + unittest.main(argv=['first-arg-is-ignored'], exit=False) From 218dd7fee3b4dca22763f5c3bcf247bba81e6673 Mon Sep 17 00:00:00 2001 From: Amin Alam Date: Mon, 15 Jan 2024 11:30:25 +0100 Subject: [PATCH 3/6] codes reformat --- test/main.py | 71 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/test/main.py b/test/main.py index b7962d7..37a8010 100644 --- a/test/main.py +++ b/test/main.py @@ -1,12 +1,12 @@ -import elecphys.data_io as data_io -import elecphys.visualization as visualization -import elecphys.fourier_analysis as fourier_analysis -import elecphys.preprocessing as preprocessing +import shutil +import unittest import elecphys.conversion as conversion -import os +import elecphys.preprocessing as preprocessing +import elecphys.fourier_analysis as fourier_analysis +import elecphys.visualization as visualization +import elecphys.data_io as data_io import sys -import unittest -import shutil +import os sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -28,7 +28,8 @@ def test_1_rhd_to_mat(self): os.remove(output_mat_file) command_prompt = f'python3 -m elecphys.main convert_rhd_to_mat --folder_path {folder_path} --output_mat_file {output_mat_file} --ds_factor {ds_factor}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_mat_file)) def test_2_mat_to_npz(self): @@ -48,7 +49,8 @@ def test_2_mat_to_npz(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main convert_mat_to_npz --mat_file {mat_file} --output_npz_folder {output_npz_folder} --notch_filter_freq {notch_filter_freq}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) @@ -76,7 +78,8 @@ def test_zscore_normalize_npz(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main zscore_normalize_npz --input_npz_folder {npz_files_folder} --output_npz_folder {output_npz_folder}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) def test_normalize_npz(self): @@ -91,7 +94,8 @@ def test_normalize_npz(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main normalize_npz --input_npz_folder {npz_files_folder} --output_npz_folder {output_npz_folder}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) def test_re_reference_npz(self): @@ -109,11 +113,13 @@ def test_re_reference_npz(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main re_reference_npz --input_npz_folder {npz_files_folder} --output_npz_folder {output_npz_folder} --ignore_channels "{ignore_channels}" --rr_channel {rr_channel}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main re_reference_npz --input_npz_folder {npz_files_folder} --output_npz_folder {output_npz_folder}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) @@ -135,7 +141,8 @@ def test_stft_numeric_output_from_npz(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main stft_numeric_output_from_npz --input_npz_folder "{npz_files_folder}" --output_npz_folder {output_npz_folder} --window_size {window_size} --overlap {overlap} --window_type "{window_type}"' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) def test_dft_numeric_output_from_npz(self): @@ -151,7 +158,8 @@ def test_dft_numeric_output_from_npz(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main dft_numeric_output_from_npz --input_npz_folder "{npz_files_folder}" --output_npz_folder {output_npz_folder}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) def test_frequency_filtering(self): @@ -172,7 +180,8 @@ def test_frequency_filtering(self): shutil.rmtree(output_npz_folder) command_prompt = f'python3 -m elecphys.main frequncy_domain_filter --input_npz_folder "{npz_files_folder}" --output_npz_folder {output_npz_folder} --filter_type {filter_args["filter_type"]} --filter_order {filter_args["filter_order"]} --freq_cutoff "{filter_args["freq_cutoff"]}"' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_npz_folder)) def test_cfc_from_npz(self): @@ -230,7 +239,8 @@ def test_plot_stft(self): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_stft --input_npz_file "{os.path.join(npz_files_folder, npz_file)}" --output_plot_file {output_plot_file} --f_min {f_min} --f_max {f_max} --t_min {t_min} --t_max {t_max} --db_min {db_min} --db_max {db_max}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) def test_plot_avg_stft(self): @@ -269,7 +279,8 @@ def test_plot_avg_stft(self): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_avg_stft --input_npz_folder "{npz_files_folder}" --output_plot_file {output_plot_file} --f_min {f_min} --f_max {f_max} --t_min {t_min} --t_max {t_max} --db_min {db_min} --db_max {db_max} --channels_list "{[1, 2, 3, 4, 5, 6, 7, 12, 15]}"' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) def test_plot_signal(self): @@ -288,12 +299,14 @@ def test_plot_signal(self): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file "{output_plot_file}"' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --channels_list "{[1, 2, 3, 4, 5, 6, 7, 12, 15]}"' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) re_reference = True @@ -301,26 +314,30 @@ def test_plot_signal(self): [[1, 2, 3, 4, 5, 6, 7, 12, 15], [1, 2, 3, 4, 5]], [[1, 5], [5]]): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --channels_list "{channels_list}" --ignore_channels "{ignore_channels}" --re_reference {re_reference}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --re_reference {re_reference}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) rr_channel = 2 os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --re_reference {re_reference} --rr_channel {rr_channel}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) for channels_list, ignore_channels in zip( [[1, 2, 3, 4, 5, 6, 7, 12, 15], [1, 2, 3, 4, 5]], [[1, 5], [5]]): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_signal --input_npz_folder "{npz_folder_path}" --output_plot_file {output_plot_file} --channels_list "{channels_list}" --ignore_channels "{ignore_channels}" --re_reference {re_reference} --rr_channel {rr_channel}' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) def test_plot_dft(self): @@ -348,7 +365,8 @@ def test_plot_dft(self): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_dft --input_npz_folder "{npz_files_folder}" --output_plot_file {output_plot_file} --plot_type {plot_type} --conv_window_size {conv_window_size} --channels_list "{[1, 2, 3]}"' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) def test_plot_filter_freq_response(self): @@ -372,7 +390,8 @@ def test_plot_filter_freq_response(self): if os.path.exists(output_plot_file): os.remove(output_plot_file) command_prompt = f'python3 -m elecphys.main plot_filter_freq_response --filter_type {filter_type} --filter_order {filter_order} --freq_cutoff "{freq_cutoff}" --output_plot_file {output_plot_file} -fs 1000' - os.system(command_prompt) + for _ in range(2): + os.system(command_prompt) self.assertTrue(os.path.exists(output_plot_file)) From d722abc4a22cb8ae8d1b779abd46010e89e4043b Mon Sep 17 00:00:00 2001 From: Amin Alam Date: Mon, 15 Jan 2024 13:18:49 +0100 Subject: [PATCH 4/6] scale bar is added --- elecphys/main.py | 6 +++- elecphys/visualization.py | 64 +++++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/elecphys/main.py b/elecphys/main.py index 1a6b546..bf13acb 100644 --- a/elecphys/main.py +++ b/elecphys/main.py @@ -463,6 +463,7 @@ def plot_avg_stft(ctx, input_npz_folder: str, output_plot_file: str, f_min: floa required=False, type=list, default=None, show_default=True) @click.option('--normalize', '-n', help='Normalize signals. If true, each channel will be normalized', required=False, type=bool, default=False, show_default=True) +@click.option('--scale_bar', '-sb', help='Scale bar. If true, a scale bar will be added to the plot', required=False, type=bool, default=True, show_default=True) @click.option('--re_reference', '-rr', help='Re-reference signals. If -rr_channel not specified, signals will be re-referenced to the average of all channels, unless they will be re-referenced to the given channel. If --ignore_channels is specified, specified channels will not be re referenced or taken into account for avg rereferencing', required=False, type=bool, default=False, show_default=True) @click.option('--ignore_channels', '-ic', help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored', required=False, type=str, default=None, show_default=True) @@ -471,7 +472,7 @@ def plot_avg_stft(ctx, input_npz_folder: str, output_plot_file: str, f_min: floa @click.pass_context @error_handler def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float = None, t_max: float = None, channels_list: list = None, - normalize: bool = False, re_reference: bool = False, ignore_channels: str = None, rr_channel: int = None) -> None: + normalize: bool = False, scale_bar: bool = True, re_reference: bool = False, ignore_channels: str = None, rr_channel: int = None) -> None: """ Plots signals from NPZ file Parameters @@ -488,6 +489,8 @@ def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float list of channels to plot. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be plotted. normalize: bool normalize signals. If true, each channel will be normalized. If not specified, the default value is False. + scale_bar: bool + scale bar. If true, a scale bar will be added to the plot. If not specified, the default value is True. re_reference: bool re-reference signals. If true, signals will be re-referenced. If not specified, the default value is False. ignore_channels: str @@ -514,6 +517,7 @@ def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float t_max, channels_list, normalize, + scale_bar, _rereference_args) print('--- Plotting complete.\n\n') diff --git a/elecphys/visualization.py b/elecphys/visualization.py index d12ce67..e3aaf77 100644 --- a/elecphys/visualization.py +++ b/elecphys/visualization.py @@ -207,7 +207,7 @@ def plot_stft_from_array(Zxx: np.ndarray, t: np.ndarray, f: np.ndarray, f_min: i def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: float, t_max: float, channels_list: [ - str, list] = None, normalize: bool = False, _rereference_args: dict = None) -> None: + str, list] = None, normalize: bool = False, scale_bar: bool = True, _rereference_args: dict = None) -> None: """ Plots signals from NPZ file Parameters @@ -224,6 +224,8 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl list of channels to plot (can be a strin of comma-separated values or a list of integers) normalize: bool whether to normalize the signal + scale_bar: bool + whether to plot a scale bar _rereference_args: dict dictionary containing rereferencing parameters. If None, no rereferencing will be applied @@ -245,7 +247,16 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl else: fig = plt.figure(figsize=(30, 10)) - ax = fig.subplots(len(channels_list), 1, sharex=True) + if normalize: + scale_bar = False + + if scale_bar: + ax = fig.subplots( + len(channels_list), 2, gridspec_kw={ + 'width_ratios': [ + 10, 1]}) + else: + ax = fig.subplots(len(channels_list), 1, sharex=True) for row_no, channel_index in enumerate(channels_list): channel_index = channel_index - 1 @@ -287,17 +298,44 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl for row_no, channel_index in enumerate(channels_list): signal_chan = data_all[row_no, :] - ax[row_no].plot(t, signal_chan, color='k') - ax[row_no].set_ylabel(f'Channel {channel_index}') - ax[row_no].spines['top'].set_visible(False) - ax[row_no].spines['right'].set_visible(False) - if row_no != len(channels_list) - 1: - ax[row_no].spines['bottom'].set_visible(False) - ax[row_no].tick_params(axis='both', which='both', length=0) - ax[row_no].set_yticks([]) - - ax[-1].set_xlabel('Time (s)') - ax[-1].set_xlim(t_min, t_max) + if not scale_bar: + ax[row_no].plot(t, signal_chan, color='k') + ax[row_no].set_ylabel(f'Channel {channel_index}') + ax[row_no].spines['top'].set_visible(False) + ax[row_no].spines['right'].set_visible(False) + if row_no != len(channels_list) - 1: + ax[row_no].spines['bottom'].set_visible(False) + ax[row_no].tick_params(axis='both', which='both', length=0) + ax[row_no].set_yticks([]) + ax[row_no].set_xlim(t_min, t_max) + else: + scale_bar_size = np.max(signal_chan) - np.min(signal_chan) + ax[row_no, 0].plot(t, signal_chan, color='k') + ax[row_no, 0].set_ylabel(f'Channel {channel_index}') + ax[row_no, 0].spines['top'].set_visible(False) + ax[row_no, 0].spines['right'].set_visible(False) + if row_no != len(channels_list) - 1: + ax[row_no, 0].spines['bottom'].set_visible(False) + ax[row_no, 0].tick_params(axis='both', which='both', length=0) + ax[row_no, 0].set_yticks([]) + # symmetric error bar for scale bar + ax[row_no, 1].errorbar( + 0, 0, yerr=scale_bar_size / 2, color='k', capsize=5) + ax[row_no, 1].text( + 1 / fs, 0, f'{scale_bar_size:.2f} uV', ha='left') + ax[row_no, 1].axis('off') + ax[row_no, 0].set_xlim(t_min, t_max) + ax[row_no, 0].set_xticks([]) + ax[row_no, 1].set_xticks([]) + + if not scale_bar: + ax[-1].set_xlabel('Time (s)') + ax[-1].set_xlim(t_min, t_max) + else: + ax[-1, 0].set_xlabel('Time (s)') + ax[-1, 1].set_xlabel('Scale bar') + ax[-1, 0].set_xlim(t_min, t_max) + ax[-1, 1].set_ylim(-scale_bar_size / 2, scale_bar_size / 2) plt.tight_layout() if output_plot_file is None: From 8b84be81b6e8dcbd014df019cf8e16007ace0b43 Mon Sep 17 00:00:00 2001 From: Amin Alam Date: Mon, 15 Jan 2024 13:22:18 +0100 Subject: [PATCH 5/6] path append before the imports --- test/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/main.py b/test/main.py index 37a8010..9775eab 100644 --- a/test/main.py +++ b/test/main.py @@ -1,3 +1,8 @@ +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import shutil import unittest import elecphys.conversion as conversion @@ -5,10 +10,7 @@ import elecphys.fourier_analysis as fourier_analysis import elecphys.visualization as visualization import elecphys.data_io as data_io -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) class TestCases_0_conversion(unittest.TestCase): From dd58238d3e88be321f7c74e9f99abfc6cdc2c602 Mon Sep 17 00:00:00 2001 From: Amin Alam Date: Mon, 15 Jan 2024 13:22:31 +0100 Subject: [PATCH 6/6] new version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e7af738..31c1242 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def read_description(): setup( name="ElecPhys", - version="0.0.52", + version="0.0.6", author='Amin Alam', description='Electrophysiology data processing', long_description=read_description(),