Skip to content

Commit 5bf71ba

Browse files
committed
BUG: Incorrectly resolved conflicts created series of bugs
When the choose-ui branch was merged in after the label_images branch it seems that some conflicts were automatically but incorrectly resolved. This PR cleans up those conflicts to support both features.
1 parent b0a7aa3 commit 5bf71ba

File tree

3 files changed

+45
-108
lines changed

3 files changed

+45
-108
lines changed

itkwidgets/_initialization_params.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,11 @@ def init_params_dict(itk_viewer):
2929
'y_slice': itk_viewer.setYSlice,
3030
'z_slice': itk_viewer.setZSlice,
3131
}
32+
33+
def init_key_aliases():
34+
return {
35+
'data': 'image',
36+
'image': 'image',
37+
'label_image': 'labelImage',
38+
'point_sets': 'pointSets',
39+
}

itkwidgets/integrations/__init__.py

Lines changed: 14 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -9,133 +9,51 @@
99
from .xarray import HAVE_XARRAY, xarray_data_array_to_numpy, xarray_data_set_to_numpy
1010
from ..render_types import RenderType
1111

12-
_image_count = 1
1312

14-
async def _set_viewer_image(itk_viewer, image, name=None, is_label=False):
15-
global _image_count
16-
if isinstance(image, itkwasm.Image):
17-
if not name:
18-
name = image.name
19-
if not name:
20-
name = f"image {_image_count}"
21-
_image_count += 1
22-
if is_label:
23-
await itk_viewer.setLabelImage(image)
24-
else:
25-
await itk_viewer.setImage(image, name)
26-
elif isinstance(image, np.ndarray):
27-
if not name:
28-
name = f"image {_image_count}"
29-
_image_count += 1
30-
if is_label:
31-
await itk_viewer.setLabelImage(image)
32-
else:
33-
await itk_viewer.setImage(image, name)
34-
elif isinstance(image, zarr.Group):
35-
if not name:
36-
name = f"image {_image_count}"
37-
_image_count += 1
38-
if is_label:
39-
await itk_viewer.setLabelImage(image)
40-
else:
41-
await itk_viewer.setImage(image, name)
42-
elif HAVE_ITK:
13+
async def _get_viewer_image(image):
14+
if HAVE_ITK:
4315
import itk
4416
if isinstance(image, itk.Image):
45-
wasm_image = itk_image_to_wasm_image(image)
46-
name = image.GetObjectName()
47-
if not name:
48-
name = f"image {_image_count}"
49-
_image_count += 1
50-
if is_label:
51-
await itk_viewer.setLabelImage(wasm_image)
52-
else:
53-
await itk_viewer.setImage(wasm_image, name)
17+
return itk_image_to_wasm_image(image)
5418
if HAVE_VTK:
5519
import vtk
5620
if isinstance(image, vtk.vtkImageData):
57-
ndarray = vtk_image_to_ndarray(image)
58-
if not name:
59-
name = f"image {_image_count}"
60-
_image_count += 1
61-
if is_label:
62-
await itk_viewer.setLabelImage(ndarray)
63-
else:
64-
await itk_viewer.setImage(ndarray, name)
21+
return vtk_image_to_ndarray(image)
6522
if HAVE_DASK:
6623
import dask
6724
if isinstance(image, dask.array.core.Array):
68-
ndarray = dask_array_to_ndarray(image)
69-
name = image.name
70-
if not name:
71-
name = f"image {_image_count}"
72-
_image_count += 1
73-
if is_label:
74-
await itk_viewer.setLabelImage(ndarray)
75-
else:
76-
await itk_viewer.setImage(ndarray, name)
25+
return dask_array_to_ndarray(image)
7726
if HAVE_TORCH:
7827
import torch
7928
if isinstance(image, torch.Tensor):
80-
if not name:
81-
name = f"image {_image_count}"
82-
_image_count += 1
83-
if is_label:
84-
await itk_viewer.setLabelImage(image.numpy())
85-
else:
86-
await itk_viewer.setImage(image.numpy(), name)
29+
return image.numpy()
8730
if HAVE_XARRAY:
8831
import xarray
8932
if isinstance(image, xarray.DataArray):
90-
ndarray = xarray_data_array_to_numpy(image)
91-
name = image.name
92-
if not name:
93-
name = f"image {_image_count}"
94-
_image_count += 1
95-
if is_label:
96-
await itk_viewer.setLabelImage(ndarray)
97-
else:
98-
await itk_viewer.setImage(ndarray, name)
33+
return xarray_data_array_to_numpy(image)
9934
if isinstance(image, xarray.Dataset):
100-
ndarray = xarray_data_set_to_numpy(image)
101-
if not name:
102-
name = f"image {_image_count}"
103-
_image_count += 1
104-
if is_label:
105-
await itk_viewer.setLabelImage(ndarray)
106-
else:
107-
await itk_viewer.setImage(ndarray, name)
35+
return xarray_data_set_to_numpy(image)
10836

10937

110-
async def _set_viewer_point_sets(itk_viewer, point_sets):
111-
if isinstance(point_sets, itkwasm.PointSet):
112-
await itk_viewer.setPointSets(point_sets)
113-
elif isinstance(point_sets, np.ndarray):
114-
await itk_viewer.setPointSets(point_sets)
115-
elif isinstance(point_sets, zarr.Group):
116-
await itk_viewer.setPointSets(point_sets)
38+
async def _get_viewer_point_sets(itk_viewer, point_sets):
11739
if HAVE_VTK:
11840
import vtk
11941
if isinstance(point_sets, vtk.vtkPolyData):
120-
vtkjs_polydata = vtk_polydata_to_vtkjs(point_sets)
121-
await itk_viewer.setPointSets(vtkjs_polydata)
42+
return vtk_polydata_to_vtkjs(point_sets)
12243
if HAVE_DASK:
12344
import dask
12445
if isinstance(point_sets, dask.array.core.Array):
125-
ndarray = dask_array_to_ndarray(point_sets)
126-
await itk_viewer.setPointSets(ndarray)
46+
return dask_array_to_ndarray(point_sets)
12747
if HAVE_TORCH:
12848
import torch
12949
if isinstance(point_sets, torch.Tensor):
130-
await itk_viewer.setPointSets(point_sets.numpy())
50+
return point_sets.numpy()
13151
if HAVE_XARRAY:
13252
import xarray
13353
if isinstance(point_sets, xarray.DataArray):
134-
ndarray = xarray_data_array_to_numpy(point_sets)
135-
await itk_viewer.setPointSets(ndarray)
54+
return xarray_data_array_to_numpy(point_sets)
13655
if isinstance(point_sets, xarray.Dataset):
137-
ndarray = xarray_data_set_to_numpy(point_sets)
138-
await itk_viewer.setPointSets(ndarray)
56+
return xarray_data_set_to_numpy(point_sets)
13957

14058

14159
def _detect_render_type(data, input_type) -> RenderType:

itkwidgets/viewer.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from typing import List
33

44
from ._type_aliases import Gaussians, Style, Image, Point_Sets
5-
from ._initialization_params import init_params_dict
6-
from .integrations import _detect_render_type, _set_viewer_image, _set_viewer_point_sets
5+
from ._initialization_params import init_params_dict, init_key_aliases
6+
from .integrations import _detect_render_type, _get_viewer_image, _get_viewer_point_sets
77
from .render_types import RenderType
88

99
__all__ = [
@@ -23,9 +23,11 @@ def __init__(
2323
"""Create a viewer."""
2424
self._init_viewer_kwargs = dict(ui_collapsed=ui_collapsed, rotate=rotate, ui=ui)
2525
self._init_viewer_kwargs.update(**add_data_kwargs)
26+
self.init_data = {}
2627

2728
def _get_input_data(self):
2829
input_options = ["data", "image", "point_sets"]
30+
inputs = []
2931
for option in input_options:
3032
data = self._init_viewer_kwargs.get(option, None)
3133
if data is not None:
@@ -58,27 +60,32 @@ async def run(self, ctx):
5860
else:
5961
config = {}
6062

61-
data, input_type = self._get_input_data()
63+
inputs = self._get_input_data()
6264

63-
init_data = None
64-
if data is not None:
65+
self.init_data.clear()
66+
for (input_type, data) in inputs:
6567
render_type = _detect_render_type(data, input_type)
68+
key = init_key_aliases()[input_type]
6669
if render_type is RenderType.IMAGE:
67-
init_data = {"image": data}
70+
result = await _get_viewer_image(data)
6871
elif render_type is RenderType.POINT_SET:
69-
init_data = {"pointSets": data}
72+
result = await _get_viewer_point_sets(data)
73+
if not result:
74+
result = data
75+
self.init_data[key] = result
7076

7177
itk_viewer = await api.createWindow(
7278
name=f"itkwidgets viewer {_viewer_count}",
7379
type="itk-vtk-viewer",
7480
src="https://kitware.github.io/itk-vtk-viewer/app",
7581
fullscreen=False,
76-
data=init_data,
82+
data=self.init_data,
7783
# config should be a python data dictionary and can't be a string e.g. 'pydata-sphinx',
7884
config=config,
7985
)
8086
_viewer_count += 1
8187

88+
self.set_default_ui_values(itk_viewer)
8289
self.itk_viewer = itk_viewer
8390

8491
def set_default_ui_values(self, itk_viewer):
@@ -112,9 +119,11 @@ def set_background_color(self, bgColor: List[float]):
112119
async def set_image(self, image: Image):
113120
render_type = _detect_render_type(image, 'image')
114121
if render_type is RenderType.IMAGE:
115-
await _set_viewer_image(self.viewer_rpc.itk_viewer, image)
122+
image = _get_viewer_image(image)
123+
await self.viewer_rpc.itk_viewer.setImage(image)
116124
elif render_type is RenderType.POINT_SET:
117-
await _set_viewer_point_sets(self.viewer_rpc.itk_viewer, image)
125+
image = _get_viewer_point_sets(image)
126+
await self.viewer_rpc.itk_viewer.setPointSets(image)
118127

119128
def set_image_blend_mode(self, mode: str):
120129
self.viewer_rpc.itk_viewer.setImageBlendMode(mode)
@@ -152,9 +161,11 @@ def set_image_volume_sample_distance(self, distance: float):
152161
async def set_label_image(self, label_image: Image):
153162
render_type = _detect_render_type(label_image, 'image')
154163
if render_type is RenderType.IMAGE:
155-
await _set_viewer_image(self.viewer_rpc.itk_viewer, label_image, is_label=True)
164+
label_image = _get_viewer_image(label_image, is_label=True)
165+
await self.viewer_rpc.itk_viewer.setImage(label_image)
156166
elif render_type is RenderType.POINT_SET:
157-
await _set_viewer_point_sets(self.viewer_rpc.itk_viewer, label_image)
167+
label_image = _get_viewer_point_sets(label_image, is_label=True)
168+
await self.viewer_rpc.itk_viewer.setPointSets(label_image)
158169

159170
def set_label_image_blend(self, blend: float):
160171
self.viewer_rpc.itk_viewer.setLabelImageBlend(blend)

0 commit comments

Comments
 (0)