Skip to content

Commit 9074a43

Browse files
committed
ENH: Add support for label maps
1 parent 10cc3ee commit 9074a43

File tree

2 files changed

+57
-20
lines changed

2 files changed

+57
-20
lines changed

itkwidgets/integrations/__init__.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,34 @@
1111

1212
_image_count = 1
1313

14-
async def _set_viewer_image(itk_viewer, image, name=None):
14+
async def _set_viewer_image(itk_viewer, image, name=None, is_label=False):
1515
global _image_count
1616
if isinstance(image, itkwasm.Image):
1717
if not name:
1818
name = image.name
1919
if not name:
2020
name = f"image {_image_count}"
2121
_image_count += 1
22-
await itk_viewer.setImage(image, name)
22+
if is_label:
23+
await itk_viewer.setLabelImage(image)
24+
else:
25+
await itk_viewer.setImage(image, name)
2326
elif isinstance(image, np.ndarray):
2427
if not name:
2528
name = f"image {_image_count}"
2629
_image_count += 1
27-
await itk_viewer.setImage(image, name)
30+
if is_label:
31+
await itk_viewer.setLabelImage(image)
32+
else:
33+
await itk_viewer.setImage(image, name)
2834
elif isinstance(image, zarr.Group):
2935
if not name:
3036
name = f"image {_image_count}"
3137
_image_count += 1
32-
await itk_viewer.setImage(image, name)
38+
if is_label:
39+
await itk_viewer.setLabelImage(image)
40+
else:
41+
await itk_viewer.setImage(image, name)
3342
elif HAVE_ITK:
3443
import itk
3544
if isinstance(image, itk.Image):
@@ -38,15 +47,21 @@ async def _set_viewer_image(itk_viewer, image, name=None):
3847
if not name:
3948
name = f"image {_image_count}"
4049
_image_count += 1
41-
await itk_viewer.setImage(wasm_image, name)
50+
if is_label:
51+
await itk_viewer.setLabelImage(wasm_image)
52+
else:
53+
await itk_viewer.setImage(wasm_image, name)
4254
if HAVE_VTK:
4355
import vtk
4456
if isinstance(image, vtk.vtkImageData):
4557
ndarray = vtk_image_to_ndarray(image)
4658
if not name:
4759
name = f"image {_image_count}"
4860
_image_count += 1
49-
await itk_viewer.setImage(ndarray, name)
61+
if is_label:
62+
await itk_viewer.setLabelImage(ndarray)
63+
else:
64+
await itk_viewer.setImage(ndarray, name)
5065
if HAVE_DASK:
5166
import dask
5267
if isinstance(image, dask.array.core.Array):
@@ -55,14 +70,20 @@ async def _set_viewer_image(itk_viewer, image, name=None):
5570
if not name:
5671
name = f"image {_image_count}"
5772
_image_count += 1
58-
await itk_viewer.setImage(ndarray, name)
73+
if is_label:
74+
await itk_viewer.setLabelImage(ndarray)
75+
else:
76+
await itk_viewer.setImage(ndarray, name)
5977
if HAVE_TORCH:
6078
import torch
6179
if isinstance(image, torch.Tensor):
6280
if not name:
6381
name = f"image {_image_count}"
6482
_image_count += 1
65-
await itk_viewer.setImage(image.numpy(), name)
83+
if is_label:
84+
await itk_viewer.setLabelImage(image.numpy())
85+
else:
86+
await itk_viewer.setImage(image.numpy(), name)
6687
if HAVE_XARRAY:
6788
import xarray
6889
if isinstance(image, xarray.DataArray):
@@ -71,13 +92,19 @@ async def _set_viewer_image(itk_viewer, image, name=None):
7192
if not name:
7293
name = f"image {_image_count}"
7394
_image_count += 1
74-
await itk_viewer.setImage(ndarray, name)
95+
if is_label:
96+
await itk_viewer.setLabelImage(ndarray)
97+
else:
98+
await itk_viewer.setImage(ndarray, name)
7599
if isinstance(image, xarray.Dataset):
76100
ndarray = xarray_data_set_to_numpy(image)
77101
if not name:
78102
name = f"image {_image_count}"
79103
_image_count += 1
80-
await itk_viewer.setImage(ndarray, name)
104+
if is_label:
105+
await itk_viewer.setLabelImage(ndarray)
106+
else:
107+
await itk_viewer.setImage(ndarray, name)
81108

82109

83110
async def _set_viewer_point_sets(itk_viewer, point_sets):

itkwidgets/viewer.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ def __init__(self, ui_collapsed=True, rotate=False, **add_data_kwargs):
2222
self._init_viewer_kwargs.update(**add_data_kwargs)
2323

2424
def _get_input_data(self):
25-
input_options = ['data', 'image', 'point_sets']
25+
input_options = ['data', 'image', 'point_sets', 'label_image']
26+
inputs = []
2627
for option in input_options:
2728
data = self._init_viewer_kwargs.get(option, None)
2829
if data is not None:
29-
break
30-
return data, option
30+
inputs.append((option, data))
31+
return inputs
3132

3233
async def setup(self):
3334
pass
@@ -43,13 +44,15 @@ async def run(self, ctx):
4344
)
4445
_viewer_count += 1
4546

46-
data, input_type = self._get_input_data()
47-
if data is not None:
48-
render_type = _detect_render_type(data, input_type)
49-
if render_type is RenderType.IMAGE:
50-
await _set_viewer_image(itk_viewer, data)
51-
elif render_type is RenderType.POINT_SET:
52-
await _set_viewer_point_sets(itk_viewer, data)
47+
inputs = self._get_input_data()
48+
if inputs:
49+
for (input_type, data) in inputs:
50+
render_type = _detect_render_type(data, input_type)
51+
if render_type is RenderType.IMAGE:
52+
is_label = input_type == 'label_image'
53+
await _set_viewer_image(itk_viewer, data, is_label)
54+
elif render_type is RenderType.POINT_SET:
55+
await _set_viewer_point_sets(itk_viewer, data)
5356

5457
self.set_default_ui_values(itk_viewer)
5558

@@ -115,6 +118,13 @@ def set_image_shadow_enabled(self, enabled: bool):
115118
def set_image_volume_sample_distance(self, distance: float):
116119
self.viewer_rpc.itk_viewer.setImageVolumeSampleDistance(distance)
117120

121+
async def set_label_image(self, label_image: Image):
122+
render_type = _detect_render_type(label_image, 'image')
123+
if render_type is RenderType.IMAGE:
124+
await _set_viewer_image(self.viewer_rpc.itk_viewer, label_image, is_label=True)
125+
elif render_type is RenderType.POINT_SET:
126+
await _set_viewer_point_sets(self.viewer_rpc.itk_viewer, label_image)
127+
118128
def set_label_image_blend(self, blend: float):
119129
self.viewer_rpc.itk_viewer.setLabelImageBlend(blend)
120130

0 commit comments

Comments
 (0)