Skip to content

Commit

Permalink
Merge pull request #32 from AminAlam/dev
Browse files Browse the repository at this point in the history
BUG fix + feature addition
  • Loading branch information
AminAlam committed Jan 15, 2024
2 parents 686003e + dd58238 commit 3bd7d77
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 123 deletions.
14 changes: 11 additions & 3 deletions elecphys/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def re_reference_npz(ctx, input_npz_folder: str, output_npz_folder: str = 'outpu
----------
"""

print('--- Average re-referencing NPZ files...')
print('--- Re-referencing NPZ files...')
preprocessing.re_reference_npz(
input_npz_folder,
output_npz_folder,
Expand Down Expand Up @@ -463,15 +463,16 @@ def plot_avg_stft(ctx, input_npz_folder: str, output_plot_file: str, f_min: floa
required=False, type=list, default=None, show_default=True)
@click.option('--normalize', '-n', help='Normalize signals. If true, each channel will be normalized',
required=False, type=bool, default=False, show_default=True)
@click.option('--scale_bar', '-sb', help='Scale bar. If true, a scale bar will be added to the plot', required=False, type=bool, default=True, show_default=True)
@click.option('--re_reference', '-rr', help='Re-reference signals. If -rr_channel not specified, signals will be re-referenced to the average of all channels, unless they will be re-referenced to the given channel. If --ignore_channels is specified, specified channels will not be re referenced or taken into account for avg rereferencing', required=False, type=bool, default=False, show_default=True)
@click.option('--ignore_channels', '-ic', help='List of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored',
required=False, type=list, default=None, show_default=True)
required=False, type=str, default=None, show_default=True)
@click.option('--rr_channel', '-rrc', help='Channel to re-reference signals to. If None, signals will be re-referenced to the average of all channels',
required=False, type=int, default=None, show_default=True)
@click.pass_context
@error_handler
def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float = None, t_max: float = None, channels_list: list = None,
normalize: bool = False, re_reference: bool = False, ignore_channels: list = None, rr_channel: int = None) -> None:
normalize: bool = False, scale_bar: bool = True, re_reference: bool = False, ignore_channels: str = None, rr_channel: int = None) -> None:
""" Plots signals from NPZ file
Parameters
Expand All @@ -488,6 +489,12 @@ def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float
list of channels to plot. either a string of comma-separated channel numbers or a list of integers. If not specified, the default value is None and all of the channels will be plotted.
normalize: bool
normalize signals. If true, each channel will be normalized. If not specified, the default value is False.
scale_bar: bool
scale bar. If true, a scale bar will be added to the plot. If not specified, the default value is True.
re_reference: bool
re-reference signals. If true, signals will be re-referenced. If not specified, the default value is False.
ignore_channels: str
list of channels to ignore (e.g EMG, EOG, etc.). If None, then no channels will be ignored. Either a list of channel indexes or a string of channel indexes separated by commas. If not specified, the default value is None.
rr_channel: int
channel to re-reference signals to. If None, signals will be re-referenced to the average of all channels. If not specified, the default value is None.
Expand All @@ -510,6 +517,7 @@ def plot_signal(ctx, input_npz_folder: str, output_plot_file: str, t_min: float
t_max,
channels_list,
normalize,
scale_bar,
_rereference_args)
print('--- Plotting complete.\n\n')

Expand Down
8 changes: 6 additions & 2 deletions elecphys/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,16 @@ def re_reference(data: np.ndarray, ignore_channels: [
channels_list = [i for i in range(data.shape[0])]

rr_channel = rr_channel - 1 if rr_channel is not None else None

data_rereferenced = data.copy()

if rr_channel is not None:
reference = data[rr_channel, :].reshape(1, -1)
data_rereferenced[channels_list, :] = data[channels_list,
:] - data[rr_channel, :].reshape(1, -1)
:] - np.repeat(reference, len(channels_list), axis=0)
else:
reference = np.mean(data[channels_list, :], axis=0).reshape(1, -1)
data_rereferenced[channels_list, :] = data[channels_list,
:] - np.mean(data[channels_list, :], axis=0).reshape(1, -1)
:] - np.repeat(reference, len(channels_list), axis=0)

return data_rereferenced
1 change: 1 addition & 0 deletions elecphys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,5 @@ def convert_string_to_list(string):
string = string.split(',')
output = list(map(int, string))
output = np.unique(output)
output = output.tolist()
return output
76 changes: 52 additions & 24 deletions elecphys/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def plot_stft_from_array(Zxx: np.ndarray, t: np.ndarray, f: np.ndarray, f_min: i


def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: float, t_max: float, channels_list: [
str, list] = None, normalize: bool = False, _rereference_args: dict = None) -> None:
str, list] = None, normalize: bool = False, scale_bar: bool = True, _rereference_args: dict = None) -> None:
""" Plots signals from NPZ file
Parameters
Expand All @@ -224,6 +224,8 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl
list of channels to plot (can be a strin of comma-separated values or a list of integers)
normalize: bool
whether to normalize the signal
scale_bar: bool
whether to plot a scale bar
_rereference_args: dict
dictionary containing rereferencing parameters. If None, no rereferencing will be applied
Expand All @@ -245,7 +247,16 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl
else:
fig = plt.figure(figsize=(30, 10))

ax = fig.subplots(len(channels_list), 1, sharex=True)
if normalize:
scale_bar = False

if scale_bar:
ax = fig.subplots(
len(channels_list), 2, gridspec_kw={
'width_ratios': [
10, 1]})
else:
ax = fig.subplots(len(channels_list), 1, sharex=True)

for row_no, channel_index in enumerate(channels_list):
channel_index = channel_index - 1
Expand Down Expand Up @@ -275,39 +286,56 @@ def plot_signals_from_npz(npz_folder_path: str, output_plot_file: str, t_min: fl
rr_channel = _rereference_args['rr_channel']
ignore_channels = utils.convert_string_to_list(ignore_channels)
if rr_channel is not None:
print(rr_channel, channels_list)
if rr_channel not in channels_list:
raise ValueError('rr_channel must be in channels_list')
if ignore_channels is not None:
if not set(ignore_channels).issubset(set(channels_list)):
raise ValueError(
'All channels in ignore_channels must be in channels_list, as it does not make sense to ignore a channel for re-referencing if it is not in channels_list')

for row_no, channel_index in enumerate(channels_list):
channel_index = channel_index - 1
if rr_channel is not None:
if channel_index == rr_channel - 1:
rr_channel = row_no
if ignore_channels is not None:
if channel_index in ignore_channels:
ignore_channels[ignore_channels.index(
channel_index)] = row_no
data_all = preprocessing.re_reference(
data_all, ignore_channels, rr_channel)

for row_no, channel_index in enumerate(channels_list):

ax[row_no].plot(t, signal_chan, color='k')
ax[row_no].set_ylabel(f'Channel {channel_index}')
ax[row_no].spines['top'].set_visible(False)
ax[row_no].spines['right'].set_visible(False)
if row_no != len(channels_list) - 1:
ax[row_no].spines['bottom'].set_visible(False)
ax[row_no].tick_params(axis='both', which='both', length=0)
ax[row_no].set_yticks([])

ax[-1].set_xlabel('Time (s)')
ax[-1].set_xlim(t_min, t_max)
signal_chan = data_all[row_no, :]
if not scale_bar:
ax[row_no].plot(t, signal_chan, color='k')
ax[row_no].set_ylabel(f'Channel {channel_index}')
ax[row_no].spines['top'].set_visible(False)
ax[row_no].spines['right'].set_visible(False)
if row_no != len(channels_list) - 1:
ax[row_no].spines['bottom'].set_visible(False)
ax[row_no].tick_params(axis='both', which='both', length=0)
ax[row_no].set_yticks([])
ax[row_no].set_xlim(t_min, t_max)
else:
scale_bar_size = np.max(signal_chan) - np.min(signal_chan)
ax[row_no, 0].plot(t, signal_chan, color='k')
ax[row_no, 0].set_ylabel(f'Channel {channel_index}')
ax[row_no, 0].spines['top'].set_visible(False)
ax[row_no, 0].spines['right'].set_visible(False)
if row_no != len(channels_list) - 1:
ax[row_no, 0].spines['bottom'].set_visible(False)
ax[row_no, 0].tick_params(axis='both', which='both', length=0)
ax[row_no, 0].set_yticks([])
# symmetric error bar for scale bar
ax[row_no, 1].errorbar(
0, 0, yerr=scale_bar_size / 2, color='k', capsize=5)
ax[row_no, 1].text(
1 / fs, 0, f'{scale_bar_size:.2f} uV', ha='left')
ax[row_no, 1].axis('off')
ax[row_no, 0].set_xlim(t_min, t_max)
ax[row_no, 0].set_xticks([])
ax[row_no, 1].set_xticks([])

if not scale_bar:
ax[-1].set_xlabel('Time (s)')
ax[-1].set_xlim(t_min, t_max)
else:
ax[-1, 0].set_xlabel('Time (s)')
ax[-1, 1].set_xlabel('Scale bar')
ax[-1, 0].set_xlim(t_min, t_max)
ax[-1, 1].set_ylim(-scale_bar_size / 2, scale_bar_size / 2)
plt.tight_layout()

if output_plot_file is None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def read_description():

setup(
name="ElecPhys",
version="0.0.52",
version="0.0.6",
author='Amin Alam',
description='Electrophysiology data processing',
long_description=read_description(),
Expand Down

0 comments on commit 3bd7d77

Please sign in to comment.