Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion plotpy/items/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,15 @@ def get_r_values(self, i0, i1, j0, j1, flag_circle=False):
"""
return self.get_x_values(i0, i1)

def _recompute_log_data(self) -> None:
"""Refresh the cached log10 data from the current ``self.data``.

Used both when toggling the Z-axis log scale on and when the underlying
data is replaced (e.g. via :meth:`set_data`) while the log scale is
already active.
"""
self._log_data = np.array(np.log10(self.data.clip(1)), dtype=np.float64)

def set_data(
self, data: np.ndarray, lut_range: tuple[float, float] | None = None
) -> None:
Expand All @@ -353,9 +362,15 @@ def set_data(
self.histogram_cache = None
self.update_bounds()
self.update_border()
# Refresh the cached log10 data when log scale is active, otherwise the
# display would keep using the previous (now stale) log data.
if self.get_zaxis_log_state():
self._recompute_log_data()
if not self.param.keep_lut_range:
if lut_range is not None:
_min, _max = lut_range
elif self.get_zaxis_log_state():
_min, _max = get_nan_range(self._log_data)
else:
_min, _max = get_nan_range(data)
self.set_lut_range((_min, _max))
Expand Down Expand Up @@ -574,7 +589,7 @@ def set_zaxis_log_state(self, state: bool) -> None:
if state:
self._lin_lut_range = self.get_lut_range()
if self._log_data is None:
self._log_data = np.array(np.log10(self.data.clip(1)), dtype=np.float64)
self._recompute_log_data()
self.set_lut_range(get_nan_range(self._log_data))
dtype = self._log_data.dtype
else:
Expand Down
70 changes: 70 additions & 0 deletions plotpy/tests/unit/test_image_log_set_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
#
# Licensed under the terms of the BSD 3-Clause
# (see plotpy/LICENSE for details)

"""Regression tests for cached log10 data refresh in ``ImageItem.set_data``.

When the Z-axis is in logarithmic scale, ``ImageItem`` keeps a cached
``_log_data`` array. Prior to the fix, calling ``set_data`` did not refresh
that cache, so the displayed image kept reflecting the previous values until
the user toggled the log scale off and on again.
"""

from __future__ import annotations

import numpy as np
from guidata.qthelpers import qt_app_context

from plotpy.builder import make


def _make_item(data: np.ndarray):
"""Return an ``ImageItem`` ready for log-scale tests."""
return make.image(data, interpolation="nearest")


def test_set_data_refreshes_log_data_when_log_scale_enabled() -> None:
"""``set_data`` must recompute ``_log_data`` when log scale is active."""
with qt_app_context(exec_loop=False):
first = np.array([[1.0, 10.0], [100.0, 1000.0]])
item = _make_item(first)
item.set_zaxis_log_state(True)
np.testing.assert_array_almost_equal(item._log_data, np.log10(first.clip(1)))

second = np.array([[10.0, 100.0], [1000.0, 10000.0]])
item.set_data(second)

# The cache must reflect the new data, not the previous one.
np.testing.assert_array_almost_equal(item._log_data, np.log10(second.clip(1)))
# And the LUT range must be derived from the refreshed log data.
lut_min, lut_max = item.get_lut_range()
assert lut_min == np.log10(second.clip(1)).min()
assert lut_max == np.log10(second.clip(1)).max()


def test_set_data_keeps_lut_range_in_log_mode() -> None:
"""``keep_lut_range`` must be honored even when log scale is active."""
with qt_app_context(exec_loop=False):
first = np.array([[1.0, 10.0], [100.0, 1000.0]])
item = _make_item(first)
item.set_zaxis_log_state(True)
item.set_lut_range((0.5, 2.5))
item.param.keep_lut_range = True

second = np.array([[10.0, 100.0], [1000.0, 10000.0]])
item.set_data(second)

# Cache must still be refreshed (display correctness)…
np.testing.assert_array_almost_equal(item._log_data, np.log10(second.clip(1)))
# …but the LUT range must remain frozen as requested by the user.
assert item.get_lut_range() == (0.5, 2.5)


def test_set_data_does_not_create_log_data_when_log_scale_disabled() -> None:
"""When log scale is off, ``set_data`` must not create ``_log_data``."""
with qt_app_context(exec_loop=False):
item = _make_item(np.array([[1.0, 2.0], [3.0, 4.0]]))
assert item._log_data is None
item.set_data(np.array([[5.0, 6.0], [7.0, 8.0]]))
assert item._log_data is None
Loading