Skip to content

Commit

Permalink
Merge pull request #36 from AminAlam/dev
Browse files Browse the repository at this point in the history
BUG FIX - channels_list and ignore_channels in command line
  • Loading branch information
AminAlam committed Jan 30, 2024
2 parents 7aef5b5 + 45abccc commit 925e801
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 69 deletions.
5 changes: 2 additions & 3 deletions elecphys/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,11 @@ def load_all_npz_files(npz_folder: str, ignore_channels: [
files_list.remove(file_name)
files_list = utils.sort_file_names(files_list)
all_channels_in_folder = list(range(0, len(files_list)))
channels_list = utils.convert_string_to_list(channels_list)
if channels_list is None:
channels_list = all_channels_in_folder
else:
channels_list = utils.convert_string_to_list(channels_list)
ignore_channels = utils.convert_string_to_list(ignore_channels)
if ignore_channels is not None:
ignore_channels = utils.convert_string_to_list(ignore_channels)
ignore_channels = [i - 1 for i in ignore_channels]
else:
ignore_channels = []
Expand Down
8 changes: 3 additions & 5 deletions elecphys/fourier_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def freq_bands_power_over_time(
----------
input_npz_folder: str
path to input npz folder containing signal npz files (in time domain)
freq_bands: tuple
freq_bands: tuple, list
tuple or list of frequency bands to calculate power over time for. It should be a tuple or list of lists, where each list contains two elements: the lower and upper frequency bounds of the band (in Hz). For example, freq_bands = [[1, 4], [4, 8], [8, 12]] would calculate power over time for the delta, theta, and alpha bands.
channels_list: str
list of channels to include in analysis
Expand All @@ -427,10 +427,8 @@ def freq_bands_power_over_time(
Returns
----------
"""
if channels_list is not None:
channels_list = utils.convert_string_to_list(channels_list)
if ignore_channels is not None:
ignore_channels = utils.convert_string_to_list(ignore_channels)
channels_list = utils.convert_string_to_list(channels_list)
ignore_channels = utils.convert_string_to_list(ignore_channels)

data_all, fs, channels_map = data_io.load_all_npz_files(input_npz_folder, ignore_channels, channels_list)
# if freq_bands only has one list, we should make sure it is a list of lists
Expand Down
64 changes: 42 additions & 22 deletions elecphys/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def normalize_npz(ctx, input_npz_folder: str,
required=True, type=str, default='output_npz_avg_rereferenced', show_default=True)
@click.option('--ignore_channels',
'-ic',
help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored',
help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored (e.g. --ignore_channels "[1,2,3]").',
required=False,
type=list,
type=str,
default=None,
show_default=True)
@click.option('--rr_channel',
Expand All @@ -191,7 +191,7 @@ def normalize_npz(ctx, input_npz_folder: str,
@click.pass_context
@error_handler
def re_reference_npz(ctx, input_npz_folder: str, output_npz_folder: str = 'output_npz_avg_rereferenced',
ignore_channels: list = None, rr_channel: int = None) -> None:
ignore_channels: str = None, rr_channel: int = None) -> None:
""" Re-references NPZ files and save them as NPZ files
Parameters
Expand All @@ -200,7 +200,7 @@ def re_reference_npz(ctx, input_npz_folder: str, output_npz_folder: str = 'outpu
path to input npz folder
output_npz_folder: str
path to output npz folder. If the folder already exists, it will be overwritten. If not specified, the default value is 'output_npz_avg_rereferenced'
ignore_channels: list
ignore_channels: str
list of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored. Either a list of channel indexes or a string of channel indexes separated by commas
rr_channel: int
channel to re-reference signals to. If None, signals will be re-referenced to the average of all channels. If not specified, the default value is None
Expand Down Expand Up @@ -357,13 +357,18 @@ def frequncy_domain_filter(ctx, input_npz_folder: str, output_npz_folder: str =
type=str)
@click.option('--channels_list',
'-cl',
help='List of channels to compute power for. If None, then all of the channels will be used. It should be a string of comma-separated channel numbers (e.g. "[1,2,3]").',
help='List of channels to compute power for. If None, then all of the channels will be used. It should be a string of comma-separated channel numbers (e.g. --channels_list "[1,2,3]").',
required=False,
type=str,
default=None,
show_default=True)
@click.option('--ignore_channels',
'-ic',
help='List of channels to ignore. If None, then no channels will be ignored (e.g. --ignore_channels "[1,2,3]").',
required=False,
type=str,
default=None,
show_default=True)
@click.option('--ignore_channels', '-ic', help='List of channels to ignore. If None, then no channels will be ignored',
required=False, type=str, default=None, show_default=True)
@click.option('--window_size', '-w', help='Window size in seconds',
required=False, type=float, default=1, show_default=True)
@click.option('--overlap', '-ov', help='Overlap in seconds',
Expand Down Expand Up @@ -546,7 +551,7 @@ def plot_avg_stft(
t_max: float = None,
db_min: float = None,
db_max: float = None,
channels_list: list = None) -> None:
channels_list: str = None) -> None:
""" Plots average STFT from NPZ files
Parameters
Expand All @@ -567,7 +572,7 @@ def plot_avg_stft(
minimum dB to plot. If not specified, the default value is None and the minimum dB will be the minimum dB of the signal
db_max: float
maximum dB to plot. If not specified, the default value is None and the maximum dB will be the maximum dB of the signal
channels_list: list
channels_list: str
list of channels to plot. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be plotted.
Returns
Expand Down Expand Up @@ -597,8 +602,13 @@ def plot_avg_stft(
required=False, type=float, default=None, show_default=True)
@click.option('--t_max', '-tmax', help='End of time interval to plot',
required=False, type=float, default=None, show_default=True)
@click.option('--channels_list', '-cl', help='List of channels to plot, if None then all of the channels will be plotted',
required=False, type=list, default=None, show_default=True)
@click.option('--channels_list',
'-cl',
help='List of channels to plot, if None then all of the channels will be plotted (e.g. --channels_list "[1,2,3]").',
required=False,
type=str,
default=None,
show_default=True)
@click.option('--normalize', '-n', help='Normalize signals. If true, each channel will be normalized',
required=False, type=bool, default=False, show_default=True)
@click.option('--scale_bar', '-sb', help='Scale bar. If true, a scale bar will be added to the plot',
Expand All @@ -612,7 +622,7 @@ def plot_avg_stft(
show_default=True)
@click.option('--ignore_channels',
'-ic',
help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored',
help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored (e.g. --ignore_channels "[1,2,3]").',
required=False,
type=str,
default=None,
Expand All @@ -632,7 +642,7 @@ def plot_signal(
output_plot_file: str,
t_min: float = None,
t_max: float = None,
channels_list: list = None,
channels_list: str = None,
normalize: bool = False,
scale_bar: bool = True,
re_reference: bool = False,
Expand All @@ -650,7 +660,7 @@ def plot_signal(
Start of time interval to plot. If not specified, the default value is None and the minimum time will be 0 seconds
t_max: float
End of time interval to plot. If not specified, the default value is None and the maximum time will be the total duration of the signal
channels_list: list
channels_list: str
list of channels to plot. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be plotted.
normalize: bool
normalize signals. If true, each channel will be normalized. If not specified, the default value is False.
Expand Down Expand Up @@ -696,8 +706,13 @@ def plot_signal(
required=False, type=float, default=None, show_default=True)
@click.option('--f_max', '-fmax', help='Maximum frequency to plot in Hz',
required=False, type=float, default=None, show_default=True)
@click.option('--channels_list', '-cl', help='List of channels to plot, if None then all of the channels will be plotted',
required=False, type=list, default=None, show_default=True)
@click.option('--channels_list',
'-cl',
help='List of channels to plot, if None then all of the channels will be plotted (e.g. --channels_list "[1,2,3]").',
required=False,
type=str,
default=None,
show_default=True)
@click.option('--plot_type',
'-pt',
help='Plot type. If "all_channels", then all channels will be plotted in one figure. If "average_of_channels", then average of channels will be plotted in one figure with errorbar',
Expand All @@ -710,7 +725,7 @@ def plot_signal(
@click.pass_context
@error_handler
def plot_dft(ctx, input_npz_folder: str, output_plot_file: str, f_min: float = None, f_max: float = None,
channels_list: list = None, plot_type: str = 'average_of_channels', conv_window_size: int = None) -> None:
channels_list: str = None, plot_type: str = 'average_of_channels', conv_window_size: int = None) -> None:
""" Plots DFT from NPZ file
Parameters
Expand All @@ -723,7 +738,7 @@ def plot_dft(ctx, input_npz_folder: str, output_plot_file: str, f_min: float = N
minimum frequency to plot in Hz. If not specified, the default value is None and the minimum frequency will be 0 Hz
f_max: float
maximum frequency to plot in Hz. If not specified, the default value is None and the maximum frequency will be the Nyquist frequency
channels_list: list
channels_list: str
list of channels to plot. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be plotted.
plot_type: str
plot type. If not specified, the default value is 'average_of_channels'. It should be 'all_channels' or 'average_of_channels'
Expand Down Expand Up @@ -817,12 +832,17 @@ def plot_filter_freq_response(ctx, filter_type: str = 'LPF', freq_cutoff: str =
type=bool,
default=False,
show_default=True)
@click.option('--channels_list', '-cl', help='List of channels to apply PCA, if None then all of the channels will be applied',
required=False, type=list, default=None, show_default=True)
@click.option('--channels_list',
'-cl',
help='List of channels to apply PCA, if None then all of the channels will be applied (e.g. --channels_list "[1,2,3]").',
required=False,
type=list,
default=None,
show_default=True)
@click.pass_context
@error_handler
def pca_from_npz(ctx, input_npz_folder: str, output_npz_folder: str = 'output_npz_pca',
n_components: int = None, matrix_whitenning: bool = False, channels_list: list = None) -> None:
n_components: int = None, matrix_whitenning: bool = False, channels_list: str = None) -> None:
""" Computes PCA from NPZ files
Parameters
Expand All @@ -835,7 +855,7 @@ def pca_from_npz(ctx, input_npz_folder: str, output_npz_folder: str = 'output_np
number of components to keep after applying the PCA. If not specified, the default value is None
matrix_whitenning: bool
matrix whitening boolean. If true, the singular values are divided by n_samples. If not specified, the default value is False
channels_list: list
channels_list: str
list of channels to apply PCA. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be applied.
Returns
Expand Down
7 changes: 3 additions & 4 deletions elecphys/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def re_reference_npz(input_npz_folder: str, output_npz_folder: str, ignore_chann
path to input npz folder
output_npz_folder: str
path to output npz folder
ignore_channels: list
ignore_channels: list, str
list of channels to be ignored. Either a list of channel indexes or a string of channel indexes separated by commas. If None, no channels will be ignored
rr_channel: int
channel to be used as reference. If None, average re-referencing will be used
Expand Down Expand Up @@ -171,7 +171,7 @@ def re_reference(data: np.ndarray, ignore_channels: [
----------
data: numpy.ndarray
data to be re-referenced. Shape: (n_channels, n_samples)
ignore_channels: list
ignore_channels: str, list
list of channels to be ignored. Either a list of channel indexes or a string of channel indexes separated by commas. If None, no channels will be ignored
rr_channel: int
channel to be used as reference. If None, average re-referencing will be used
Expand All @@ -181,15 +181,14 @@ def re_reference(data: np.ndarray, ignore_channels: [
data_rereferenced: numpy.ndarray
re-referenced data. Shape: (n_channels, n_samples)
"""
ignore_channels = utils.convert_string_to_list(ignore_channels)
if ignore_channels is not None:
ignore_channels = utils.convert_string_to_list(ignore_channels)
ignore_channels = [i - 1 for i in ignore_channels]
channels_list = [
i for i in range(
data.shape[0]) if i not in ignore_channels]
else:
channels_list = [i for i in range(data.shape[0])]

rr_channel = rr_channel - 1 if rr_channel is not None else None
data_rereferenced = data.copy()
if rr_channel is not None:
Expand Down
2 changes: 1 addition & 1 deletion elecphys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def convert_string_to_list(string):
output: list
converted list
"""
if string is None:
if string is None or string == '' or string == 'None':
return None
if not isinstance(string, str):
string = remove_non_numeric(string)
Expand Down
40 changes: 10 additions & 30 deletions elecphys/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,10 @@ def plot_avg_stft_from_npz(npz_folder_path: str, output_plot_file: str, f_min: i

npz_files = os.listdir(npz_folder_path)
npz_files = utils.sort_file_names(npz_files)

channels_list = utils.convert_string_to_list(channels_list)
if channels_list is None:
channels_list = tuple(range(1, len(npz_files) + 1))
else:
channels_list = utils.convert_string_to_list(channels_list)
channels_list = sorted(channels_list)

for channel_index, channel in enumerate(channels_list):
Expand Down Expand Up @@ -246,11 +245,10 @@ def plot_signals_from_npz(
npz_files = os.listdir(npz_folder_path)
# remove non-NPZ files
npz_files = utils.sort_file_names(npz_files)

channels_list = utils.convert_string_to_list(channels_list)
if channels_list is None:
channels_list = tuple(range(1, len(npz_files) + 1))
else:
channels_list = utils.convert_string_to_list(channels_list)
channels_list = sorted(channels_list)

if len(channels_list) > 20:
Expand All @@ -269,29 +267,12 @@ def plot_signals_from_npz(
else:
ax = fig.subplots(len(channels_list), 1, sharex=True)

for row_no, channel_index in enumerate(channels_list):
channel_index = channel_index - 1
npz_file = npz_files[channel_index]
signal_chan, fs = data_io.load_npz(
os.path.join(npz_folder_path, npz_file))
if normalize:
signal_chan = preprocessing.normalize(signal_chan)
t = np.linspace(0, len(signal_chan) / fs, len(signal_chan))
if t_min is None:
t_min = np.min(t)
if t_max is None:
t_max = np.max(t)

desired_time_index_low = np.where(
np.min(abs(t - t_min)) == abs(t - t_min))[0][0]
desired_time_index_high = np.where(
np.min(abs(t - t_max)) == abs(t - t_max))[0][0]
signal_chan = signal_chan[desired_time_index_low:desired_time_index_high]
t = t[desired_time_index_low:desired_time_index_high]
if channel_index == 0:
data_all = np.zeros((len(channels_list), len(signal_chan)))
data_all[row_no, :] = signal_chan

data_all, fs, channels_map = data_io.load_all_npz_files(npz_folder_path, channels_list=channels_list)
t = np.arange(0, data_all.shape[1] / fs, 1 / fs)
if t_min is None:
t_min = np.min(t)
if t_max is None:
t_max = np.max(t)
if _rereference_args is not None:
ignore_channels = _rereference_args['ignore_channels']
rr_channel = _rereference_args['rr_channel']
Expand All @@ -307,7 +288,7 @@ def plot_signals_from_npz(
data_all = preprocessing.re_reference(
data_all, ignore_channels, rr_channel)

for row_no, channel_index in enumerate(channels_list):
for row_no, channel_index in enumerate(channels_map):
signal_chan = data_all[row_no, :]
if not scale_bar:
ax[row_no].plot(t, signal_chan, color='k')
Expand Down Expand Up @@ -393,11 +374,10 @@ def plot_dft_from_npz(npz_folder_path: str, output_plot_file: str, f_min: int, f
npz_files = os.listdir(npz_folder_path)
npz_files = utils.keep_npz_files(npz_files)
npz_files = utils.sort_file_names(npz_files)

channels_list = utils.convert_string_to_list(channels_list)
if channels_list is None:
channels_list = tuple(range(1, len(npz_files) + 1))
else:
channels_list = utils.convert_string_to_list(channels_list)
channels_list = sorted(channels_list)

if len(channels_list) > 20:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read_description():

setup(
name="ElecPhys",
version="0.0.54",
version="0.0.55",
author='Amin Alam',
description='Electrophysiology data processing',
long_description=read_description(),
Expand Down
Loading

0 comments on commit 925e801

Please sign in to comment.