Skip to content

Commit

Permalink
Added tests for stacked_area_reader
Browse files Browse the repository at this point in the history
  • Loading branch information
Chilipp committed May 10, 2018
1 parent 27950ca commit 52968c5
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 26 deletions.
6 changes: 5 additions & 1 deletion straditize/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,13 @@ def set_x_and_ylim(stradi):
stradi.data_ylim = ylim
if xlim or ylim or full:
set_x_and_ylim(stradi)
if reader_type == 'stacked area':
import straditize.widgets.stacked_area_reader
stradi.init_reader(reader_type)
stradi.data_reader.digitize()
if output:
stradi.data_reader.digitize()
stradi.data_reader.sample_locs, stradi.data_reader.rough_locs = \
stradi.data_reader.find_samples()
stradi.final_df.to_csv(output)
elif exec_:
stradi_widget.refresh()
Expand Down
4 changes: 2 additions & 2 deletions straditize/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def sample_locs(self):
if self.parent._sample_locs is not None:
return self.parent._sample_locs
elif self.parent._full_df is not None:
self.parent._sample_locs = self.parent._full_df.iloc[:0].copy(
True)
self.parent._sample_locs = pd.DataFrame(
[], columns=list(self.parent._full_df.columns))
return self.parent._sample_locs

@sample_locs.setter
Expand Down
2 changes: 2 additions & 0 deletions straditize/label_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,11 @@ def enable_label_selection(self, arr, ncolors, img=None,

def select_all_labels(self):
colors = [self.cunselect, self.cselect]
self._selection_arr = self._orig_selection_arr.copy()
self._select_img.set_cmap(self.copy_cmap(self._select_img.get_cmap(),
colors))
self._select_img.set_norm(self._select_norm)
self._select_img.set_array(self._selection_arr)
self._update_magni_img()

def _update_magni_img(self):
Expand Down
58 changes: 35 additions & 23 deletions straditize/widgets/stacked_area_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def digitize(self):
digitizer.cb_readers.setEnabled(False)
digitizer.tree.expandItem(digitizer.digitize_item)
self.enable_or_disable_navigation_buttons()
self.reset_lbl_col()
elif not digitizing:
# stop digitization
digitizer.btn_digitize.setChecked(False)
Expand Down Expand Up @@ -81,10 +82,12 @@ def reset_lbl_col(self):
def increase_current_col(self):
self._current_col = min(self.columns[-1], self._current_col + 1)
self.reset_lbl_col()
self.enable_or_disable_navigation_buttons()

def decrease_current_col(self):
self._current_col = max(self.columns[0], self._current_col - 1)
self.reset_lbl_col()
self.enable_or_disable_navigation_buttons()

def _remove_digitze_child(self, digitizer):
digitizer.digitize_item.takeChild(
Expand All @@ -98,15 +101,16 @@ def _remove_digitze_child(self, digitizer):
def enable_or_disable_navigation_buttons(self):
disable_all = self.columns is None or len(self.columns) == 1
self.btn_prev.setEnabled(not disable_all and
self._current_col != self.columns[1])
self._current_col != self.columns[0])
self.btn_next.setEnabled(not disable_all and
self._current_col != self.columns[-1])

def select_current_column(self, add_on_apply=False):
image = np.array(self.image.convert('L')).astype(int) + 1
image = self.to_grey_pil(self.image.convert('L')).astype(int) + 1
start = self.start_of_current_col
end = start + self.full_df[self._current_col].values
all_end = start + self.full_df[self.columns[-1]].values
all_end = start + self.full_df.loc[:, self._current_col:].values.sum(
axis=1)
x = np.meshgrid(*map(np.arange, image.shape[::-1]))[0]
image[(x < start[:, np.newaxis]) | (x > all_end[:, np.newaxis])] = 0
labels = skim.label(image, 8)
Expand All @@ -118,7 +122,7 @@ def select_current_column(self, add_on_apply=False):
self.select_all_labels()
# set values outside the current column to 0
self._selection_arr[(x < start[:, np.newaxis]) |
(x > end[:, np.newaxis])] = -1
(x >= end[:, np.newaxis])] = -1
self._select_img.set_array(self._selection_arr)
self.draw_figure()

Expand All @@ -133,11 +137,14 @@ def start_of_current_col(self):
return start

def update_col(self):
"""Update the current column based on the selection"""
"""Update the current column based on the selection.
This method updates the end of the current column and adds or removes
the changes from the columns to the right."""
current = self._current_col
start = self.start_of_current_col
selected = self.selected_part
end = (self.binary.shape[1] - selected[:, ::-1].argmax(axis=1) - 1 -
end = (self.binary.shape[1] - selected[:, ::-1].argmax(axis=1) -
start)
not_selected = ~selected.any()
end[not_selected] = 0
Expand Down Expand Up @@ -167,7 +174,7 @@ def increase_col_nums(df):
current = self._current_col
start = self.start_of_current_col
selected = self.selected_part
end = (self.binary.shape[1] - selected[:, ::-1].argmax(axis=1) - 1 -
end = (self.binary.shape[1] - selected[:, ::-1].argmax(axis=1) -
start)
not_selected = ~selected.any()
end[not_selected] = 0
Expand All @@ -190,13 +197,9 @@ def increase_col_nums(df):
full_df = self.parent._full_df
increase_col_nums(full_df)
# increase column numbers in samples
samples = self.parent.sample_locs
samples = self.parent._sample_locs
if samples is not None:
increase_col_nums(samples)
# increase column numbers in rough locations
rough_locs = self.parent.rough_locs
if rough_locs is not None:
increase_col_nums(rough_locs)

# ----- Update of DataFrames -----
# update the current column in full_df and add the new one
Expand All @@ -207,9 +210,14 @@ def increase_col_nums(df):
if samples is not None:
new_samples = full_df.loc[samples.index, current]
samples.loc[:, current + 1] -= new_samples
samples[:, current] = new_samples
samples[current] = new_samples
samples.sort_index(axis=1, inplace=True)
rough_locs = self.parent.rough_locs
if rough_locs is not None:
rough_locs[current] = 0
rough_locs[(current + 1, 'vmin')] = rough_locs[(current, 'vmin')]
rough_locs[(current + 1, 'vmax')] = rough_locs[(current, 'vmax')]
rough_locs.loc[:, current] = -1
rough_locs.sort_index(inplace=True, level=0)
self.reset_lbl_col()
self.enable_or_disable_navigation_buttons()

Expand All @@ -226,33 +234,37 @@ def plot_full_df(self, ax=None):
x = np.zeros_like(vals[:, 0]) + starts[0]
for i in range(vals.shape[1]):
x += vals[:, i]
lines.extend(ax.plot(x, y, lw=2.0))
lines.extend(ax.plot(x.copy(), y, lw=2.0))

def plot_potential_samples(self, excluded=False, ax=None,
*args, **kwargs):
def plot_potential_samples(self, excluded=False, ax=None, plot_kws={},
*args, **kwargs):
"""Plot the ranges for potential samples"""
vals = self.full_df.values.copy()
starts = self.column_starts.copy()
self.sample_ranges = lines = []
y = np.arange(np.shape(self.image)[0])
ax = ax or self.ax
plot_kws = dict(plot_kws)
plot_kws.setdefault('marker', '+')
if self.extent is not None:
y += self.extent[-1]
starts = starts + self.extent[0]
x = np.zeros(vals.shape[0]) + starts[0]
for i, arr in enumerate(vals.T):
for i, (col, arr) in enumerate(zip(self.columns, vals.T)):
all_indices, excluded_indices = self.find_potential_samples(
i, *args, **kwargs)
if excluded:
all_indices = excluded_indices
if not all_indices:
x += arr
continue
indices = list(chain.from_iterable(all_indices))
mask = np.ones(arr.size, dtype=bool)
mask[indices] = False
for l in all_indices:
lines.extend(ax.plot(np.where(mask, np.nan, arr)[l] + x[l],
y[l], marker='+'))
for imin, imax in all_indices:
mask[imin:imax] = False
for imin, imax in all_indices:
lines.extend(ax.plot(
np.where(mask, np.nan, arr)[imin:imax] + x[imin:imax],
y[imin:imax], **plot_kws))
x += arr


Expand Down
Binary file added tests/test_figures/stacked_diagram.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
130 changes: 130 additions & 0 deletions tests/widgets/test_stacked_area.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Test the straditize.widgets.stacked_area_reader module"""
import os.path as osp
import pandas as pd
import _base_testing as bt
import unittest
from straditize.widgets.stacked_area_reader import StackedReader
from psyplot_gui.compat.qtcompat import QTest, Qt


class StackedAreaReaderTest(bt.StraditizeWidgetsTestCase):
"""A test case for the BarReader"""

@property
def toolbar(self):
return self.straditizer_widgets.selection_toolbar

def select_rectangle(self, x0, x1, y0, y1):
tb = self.toolbar
slx, sly = tb.get_xy_slice(x0, y0, x1, y1)
tb.select_rect(slx, sly)

def init_reader(self, fname='stacked_diagram.png', *args, **kwargs):
"""Reimplemented to make sure, we intiailize a bar diagram"""
self.digitizer.cb_reader_type.setCurrentText('stacked area')
super(StackedAreaReaderTest, self).init_reader(fname, *args, **kwargs)
self.assertIsInstance(self.reader, StackedReader)

def test_init_reader(self):
self.init_reader()

def test_digitize(self):
self.init_reader()
QTest.mouseClick(self.digitizer.btn_column_starts, Qt.LeftButton)
QTest.mouseClick(self.straditizer_widgets.apply_button,
Qt.LeftButton)
QTest.mouseClick(self.digitizer.btn_digitize, Qt.LeftButton)
self.assertTrue(self.digitizer.btn_digitize.isCheckable())
self.assertTrue(self.digitizer.btn_digitize.isChecked())
self.assertEqual(list(self.reader._full_df.columns), [0])
tb = self.toolbar

# select the first column
QTest.mouseClick(self.reader.btn_add, Qt.LeftButton)
tb.set_color_wand_mode()
tb.wand_action.setChecked(True)
tb.toggle_selection()
self.select_rectangle(10.5, 10.5, 10.5, 10.5)
QTest.mouseClick(self.digitizer.apply_button, Qt.LeftButton)
self.assertEqual(list(self.reader._full_df.columns), [0, 1])

# now select the second column
QTest.mouseClick(self.reader.btn_next, Qt.LeftButton)
QTest.mouseClick(self.reader.btn_add, Qt.LeftButton)
tb.set_color_wand_mode()
tb.wand_action.setChecked(True)
tb.toggle_selection()
self.select_rectangle(13.5, 13.5, 10.5, 10.5)
QTest.mouseClick(self.digitizer.apply_button, Qt.LeftButton)
self.assertEqual(list(self.reader._full_df.columns), [0, 1, 2])

# test the digitization result
full_df = self.reader.full_df
ref = pd.read_csv(self.get_fig_path(osp.join('data', 'full_data.csv')),
index_col=0, dtype=float)
self.assertEqual(list(map(str, full_df.columns)),
list(map(str, ref.columns)))
ref.columns = full_df.columns
self.assertFrameEqual(full_df, ref, check_index_type=False)

QTest.mouseClick(self.reader.btn_prev, Qt.LeftButton)

# end digitizing
QTest.mouseClick(self.digitizer.btn_digitize, Qt.LeftButton)

def test_edit_col(self):
"""Test the editing of a column"""
self.test_digitize()
ref = self.reader.full_df.copy(True)
tb = self.toolbar
# restart the digitization
QTest.mouseClick(self.digitizer.btn_digitize, Qt.LeftButton)
QTest.mouseClick(self.reader.btn_edit, Qt.LeftButton)

# deselect one pixel
tb.set_rect_select_mode()
tb.select_action.setChecked(True)
tb.remove_select_action.setChecked(True)
tb.toggle_selection()
self.select_rectangle(11.5, 11.5, 10.5, 10.5)

QTest.mouseClick(self.digitizer.apply_button, Qt.LeftButton)

# now the first cell should be lowered
self.assertEqual(self.reader.full_df.iloc[0, 0], ref.iloc[0, 0] - 1)

# end digitizing
QTest.mouseClick(self.digitizer.btn_digitize, Qt.LeftButton)

def test_plot_full_df(self):
"""Test the visualization of the full df"""
self.test_digitize()
self.reader.plot_full_df()
x0 = self.straditizer.data_xlim[0]
self.assertEqual(list(self.reader.lines[0].get_xdata()),
list(x0 + self.reader.full_df.iloc[:, 0]))
self.assertEqual(list(self.reader.lines[1].get_xdata()),
list(x0 + self.reader.full_df.iloc[:, :2].sum(axis=1))
)
self.assertEqual(list(self.reader.lines[2].get_xdata()),
list(x0 + self.reader.full_df.sum(axis=1)))

def test_plot_potential_samples(self):
"""Test the visualization of the full df"""
self.test_digitize()
self.reader.plot_potential_samples()
x0 = self.straditizer.data_xlim[0]
j = 0
for col in range(3):
for i, (s, e) in enumerate(
self.reader.find_potential_samples(col)[0]):
self.assertEqual(
list(self.reader.sample_ranges[j].get_xdata()),
list(x0 +
self.reader.full_df.iloc[s:e, :col + 1].sum(axis=1)),
msg='Failed at sample %i in column %i' % (i, col))
j += 1


if __name__ == '__main__':
unittest.main()

0 comments on commit 52968c5

Please sign in to comment.