diff --git a/plotpy/items/image/base.py b/plotpy/items/image/base.py index b228d58..0a251d8 100644 --- a/plotpy/items/image/base.py +++ b/plotpy/items/image/base.py @@ -340,6 +340,15 @@ def get_r_values(self, i0, i1, j0, j1, flag_circle=False): """ return self.get_x_values(i0, i1) + def _recompute_log_data(self) -> None: + """Refresh the cached log10 data from the current ``self.data``. + + Used both when toggling the Z-axis log scale on and when the underlying + data is replaced (e.g. via :meth:`set_data`) while the log scale is + already active. + """ + self._log_data = np.array(np.log10(self.data.clip(1)), dtype=np.float64) + def set_data( self, data: np.ndarray, lut_range: tuple[float, float] | None = None ) -> None: @@ -353,9 +362,15 @@ def set_data( self.histogram_cache = None self.update_bounds() self.update_border() + # Refresh the cached log10 data when log scale is active, otherwise the + # display would keep using the previous (now stale) log data. + if self.get_zaxis_log_state(): + self._recompute_log_data() if not self.param.keep_lut_range: if lut_range is not None: _min, _max = lut_range + elif self.get_zaxis_log_state(): + _min, _max = get_nan_range(self._log_data) else: _min, _max = get_nan_range(data) self.set_lut_range((_min, _max)) @@ -574,7 +589,7 @@ def set_zaxis_log_state(self, state: bool) -> None: if state: self._lin_lut_range = self.get_lut_range() if self._log_data is None: - self._log_data = np.array(np.log10(self.data.clip(1)), dtype=np.float64) + self._recompute_log_data() self.set_lut_range(get_nan_range(self._log_data)) dtype = self._log_data.dtype else: diff --git a/plotpy/tests/unit/test_image_log_set_data.py b/plotpy/tests/unit/test_image_log_set_data.py new file mode 100644 index 0000000..cc26cb1 --- /dev/null +++ b/plotpy/tests/unit/test_image_log_set_data.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the terms of the BSD 3-Clause +# (see plotpy/LICENSE for details) + +"""Regression tests for cached log10 data refresh in ``ImageItem.set_data``. + +When the Z-axis is in logarithmic scale, ``ImageItem`` keeps a cached +``_log_data`` array. Prior to the fix, calling ``set_data`` did not refresh +that cache, so the displayed image kept reflecting the previous values until +the user toggled the log scale off and on again. +""" + +from __future__ import annotations + +import numpy as np +from guidata.qthelpers import qt_app_context + +from plotpy.builder import make + + +def _make_item(data: np.ndarray): + """Return an ``ImageItem`` ready for log-scale tests.""" + return make.image(data, interpolation="nearest") + + +def test_set_data_refreshes_log_data_when_log_scale_enabled() -> None: + """``set_data`` must recompute ``_log_data`` when log scale is active.""" + with qt_app_context(exec_loop=False): + first = np.array([[1.0, 10.0], [100.0, 1000.0]]) + item = _make_item(first) + item.set_zaxis_log_state(True) + np.testing.assert_array_almost_equal(item._log_data, np.log10(first.clip(1))) + + second = np.array([[10.0, 100.0], [1000.0, 10000.0]]) + item.set_data(second) + + # The cache must reflect the new data, not the previous one. + np.testing.assert_array_almost_equal(item._log_data, np.log10(second.clip(1))) + # And the LUT range must be derived from the refreshed log data. + lut_min, lut_max = item.get_lut_range() + assert lut_min == np.log10(second.clip(1)).min() + assert lut_max == np.log10(second.clip(1)).max() + + +def test_set_data_keeps_lut_range_in_log_mode() -> None: + """``keep_lut_range`` must be honored even when log scale is active.""" + with qt_app_context(exec_loop=False): + first = np.array([[1.0, 10.0], [100.0, 1000.0]]) + item = _make_item(first) + item.set_zaxis_log_state(True) + item.set_lut_range((0.5, 2.5)) + item.param.keep_lut_range = True + + second = np.array([[10.0, 100.0], [1000.0, 10000.0]]) + item.set_data(second) + + # Cache must still be refreshed (display correctness)… + np.testing.assert_array_almost_equal(item._log_data, np.log10(second.clip(1))) + # …but the LUT range must remain frozen as requested by the user. + assert item.get_lut_range() == (0.5, 2.5) + + +def test_set_data_does_not_create_log_data_when_log_scale_disabled() -> None: + """When log scale is off, ``set_data`` must not create ``_log_data``.""" + with qt_app_context(exec_loop=False): + item = _make_item(np.array([[1.0, 2.0], [3.0, 4.0]])) + assert item._log_data is None + item.set_data(np.array([[5.0, 6.0], [7.0, 8.0]])) + assert item._log_data is None