Skip to content

Commit d657c79

Browse files
committed
ENH: Add initial label_map support
1 parent 6064d7a commit d657c79

File tree

3 files changed

+189
-56
lines changed

3 files changed

+189
-56
lines changed

examples/3DImage.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@
4747
{
4848
"data": {
4949
"application/vnd.jupyter.widget-view+json": {
50-
"model_id": "ff7e3fd0c36d47829ba3326ebc805032",
50+
"model_id": "57448261fbe849c4bacc9af9a83f9f14",
5151
"version_major": 2,
5252
"version_minor": 0
5353
},
5454
"text/plain": [
55-
"Viewer(geometries=[], gradient_opacity=0.4, point_sets=[], rendered_image=<itkImagePython.itkImageSS3; proxy o…"
55+
"Viewer(geometries=[], gradient_opacity=0.9, point_sets=[], rendered_image=<itkImagePython.itkImageSS3; proxy o…"
5656
]
5757
},
5858
"metadata": {},
@@ -61,7 +61,7 @@
6161
],
6262
"source": [
6363
"image = itk.imread(file_name)\n",
64-
"view(image, rotate=True, vmin=4000, vmax=17000, gradient_opacity=0.4)"
64+
"view(image, rotate=True, vmin=4000, vmax=17000, gradient_opacity=0.9)"
6565
]
6666
},
6767
{

itkwidgets/widget_viewer.py

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,17 @@ class Viewer(ViewerParent):
132132
_rendering_image = CBool(
133133
default_value=False,
134134
help="We are currently volume rendering the image.").tag(sync=True)
135+
label_map = ITKImage(
136+
default_value=None,
137+
allow_none=True,
138+
help="Label map for the image.").tag(
139+
sync=False,
140+
**itkimage_serialization)
141+
rendered_label_map = ITKImage(
142+
default_value=None,
143+
allow_none=True).tag(
144+
sync=True,
145+
**itkimage_serialization)
135146
interpolation = CBool(
136147
default_value=True,
137148
help="Use linear interpolation in slicing planes.").tag(sync=True)
@@ -298,23 +309,28 @@ def __init__(self, **kwargs): # noqa: C901
298309

299310
super(Viewer, self).__init__(**kwargs)
300311

301-
if not self.image:
312+
if not self.image and not self.label_map:
302313
return
303-
dimension = self.image.GetImageDimension()
304-
largest_region = self.image.GetLargestPossibleRegion()
314+
if self.image:
315+
image = self.image
316+
else:
317+
image = self.label_map
318+
dimension = image.GetImageDimension()
319+
largest_region = image.GetLargestPossibleRegion()
305320
size = largest_region.GetSize()
306321

307322
# Cache this so we do not need to recompute on it when resetting the
308323
# roi
309324
self._largest_roi_rendered_image = None
325+
self._largest_roi_rendered_label_map = None
310326
self._largest_roi = np.zeros((2, 3), dtype=np.float64)
311327
if not np.any(self.roi):
312328
largest_index = largest_region.GetIndex()
313329
self.roi[0][:dimension] = np.array(
314-
self.image.TransformIndexToPhysicalPoint(largest_index))
330+
image.TransformIndexToPhysicalPoint(largest_index))
315331
largest_index_upper = largest_index + size
316332
self.roi[1][:dimension] = np.array(
317-
self.image.TransformIndexToPhysicalPoint(largest_index_upper))
333+
image.TransformIndexToPhysicalPoint(largest_index_upper))
318334
self._largest_roi = self.roi.copy()
319335

320336
if dimension == 2:
@@ -325,32 +341,39 @@ def __init__(self, **kwargs): # noqa: C901
325341
for dim in range(dimension):
326342
if size[dim] > self.size_limit_3d[dim]:
327343
self._downsampling = True
328-
if self._downsampling:
344+
if self._downsampling and self.image:
329345
self.extractor = itk.ExtractImageFilter.New(self.image)
330346
self.shrinker = itk.BinShrinkImageFilter.New(self.extractor)
347+
if self._downsampling and self.label_map:
348+
self.label_map_extractor = itk.ExtractImageFilter.New(self.label_map)
349+
self.label_map_shrinker = itk.ShrinkImageFilter.New(self.label_map_extractor)
331350
self._update_rendered_image()
332351
if self._downsampling:
333352
self.observe(self._on_roi_changed, ['roi'])
334353

335354
self.observe(self._on_reset_crop_requested, ['_reset_crop_requested'])
336-
self.observe(self.update_rendered_image, ['image'])
355+
self.observe(self.update_rendered_image, ['image', 'label_map'])
337356

338357
def _on_roi_changed(self, change=None):
339358
if self._downsampling:
340359
self._update_rendered_image()
341360

342361
def _on_reset_crop_requested(self, change=None):
343362
if change.new is True and self._downsampling:
344-
dimension = self.image.GetImageDimension()
345-
largest_region = self.image.GetLargestPossibleRegion()
363+
if self.image:
364+
image = self.image
365+
else:
366+
image = self.label_map
367+
dimension = image.GetImageDimension()
368+
largest_region = image.GetLargestPossibleRegion()
346369
size = largest_region.GetSize()
347370
largest_index = largest_region.GetIndex()
348371
new_roi = self.roi.copy()
349372
new_roi[0][:dimension] = np.array(
350-
self.image.TransformIndexToPhysicalPoint(largest_index))
373+
image.TransformIndexToPhysicalPoint(largest_index))
351374
largest_index_upper = largest_index + size
352375
new_roi[1][:dimension] = np.array(
353-
self.image.TransformIndexToPhysicalPoint(largest_index_upper))
376+
image.TransformIndexToPhysicalPoint(largest_index_upper))
354377
self._largest_roi = new_roi.copy()
355378
self.roi = new_roi
356379
if change.new is True:
@@ -359,6 +382,7 @@ def _on_reset_crop_requested(self, change=None):
359382
@debounced(delay_seconds=0.2, method=True)
360383
def update_rendered_image(self, change=None):
361384
self._largest_roi_rendered_image = None
385+
self._largest_roi_rendered_label_map = None
362386
self._largest_roi = np.zeros((2, 3), dtype=np.float64)
363387
self._update_rendered_image()
364388

@@ -371,7 +395,7 @@ def _find_scale_factors(limit, dimension, size):
371395
return scale_factors
372396

373397
def _update_rendered_image(self):
374-
if self.image is None:
398+
if self.image is None and self.label_map is None:
375399
return
376400
if self._rendering_image:
377401
@yield_for_change(self, '_rendering_image')
@@ -382,10 +406,14 @@ def f():
382406
self._rendering_image = True
383407

384408
if self._downsampling:
385-
dimension = self.image.GetImageDimension()
386-
index = self.image.TransformPhysicalPointToIndex(
409+
if self.image:
410+
image = self.image
411+
else:
412+
image = self.label_map
413+
dimension = image.GetImageDimension()
414+
index = image.TransformPhysicalPointToIndex(
387415
self.roi[0][:dimension])
388-
upper_index = self.image.TransformPhysicalPointToIndex(
416+
upper_index = image.TransformPhysicalPointToIndex(
389417
self.roi[1][:dimension])
390418
size = upper_index - index
391419

@@ -396,43 +424,72 @@ def f():
396424
scale_factors = self._find_scale_factors(
397425
self.size_limit_3d, dimension, size)
398426
self._scale_factors = np.array(scale_factors, dtype=np.uint8)
399-
self.shrinker.SetShrinkFactors(scale_factors[:dimension])
427+
if self.image:
428+
self.shrinker.SetShrinkFactors(scale_factors[:dimension])
429+
if self.label_map:
430+
self.label_map_shrinker.SetShrinkFactors(scale_factors[:dimension])
400431

401432
region = itk.ImageRegion[dimension]()
402433
region.SetIndex(index)
403434
region.SetSize(tuple(size))
404435
# Account for rounding
405436
# truncation issues
406437
region.PadByRadius(1)
407-
region.Crop(self.image.GetLargestPossibleRegion())
438+
region.Crop(image.GetLargestPossibleRegion())
408439

409-
self.extractor.SetInput(self.image)
410-
self.extractor.SetExtractionRegion(region)
440+
if self.image:
441+
self.extractor.SetInput(self.image)
442+
self.extractor.SetExtractionRegion(region)
443+
if self.label_map:
444+
self.label_map_extractor.SetInput(self.label_map)
445+
self.label_map_extractor.SetExtractionRegion(region)
411446

412447
size = region.GetSize()
413448

414449
is_largest = False
415450
if np.any(self._largest_roi) and np.all(
416451
self._largest_roi == self.roi):
417452
is_largest = True
418-
if self._largest_roi_rendered_image is not None:
419-
self.rendered_image = self._largest_roi_rendered_image
453+
if self._largest_roi_rendered_image is not None or self._largest_roi_rendered_label_map is not None:
454+
if self.image:
455+
self.rendered_image = self._largest_roi_rendered_image
456+
if self.label_map:
457+
self.rendered_label_map = self._largest_roi_rendered_label_map
420458
return
421459

422-
self.shrinker.UpdateLargestPossibleRegion()
460+
if self.image:
461+
self.shrinker.UpdateLargestPossibleRegion()
462+
if self.label_map:
463+
self.label_map_shrinker.UpdateLargestPossibleRegion()
423464
if is_largest:
424-
self._largest_roi_rendered_image = self.shrinker.GetOutput()
425-
self._largest_roi_rendered_image.DisconnectPipeline()
426-
self._largest_roi_rendered_image.SetOrigin(
427-
self.roi[0][:dimension])
428-
self.rendered_image = self._largest_roi_rendered_image
465+
if self.image:
466+
self._largest_roi_rendered_image = self.shrinker.GetOutput()
467+
self._largest_roi_rendered_image.DisconnectPipeline()
468+
self._largest_roi_rendered_image.SetOrigin(
469+
self.roi[0][:dimension])
470+
self.rendered_image = self._largest_roi_rendered_image
471+
if self.label_map:
472+
self._largest_roi_rendered_label_map = self.label_map_shrinker.GetOutput()
473+
self._largest_roi_rendered_label_map.DisconnectPipeline()
474+
self._largest_roi_rendered_label_map.SetOrigin(
475+
self.roi[0][:dimension])
476+
self.rendered_label_map = self._largest_roi_rendered_label_map
429477
return
430-
shrunk = self.shrinker.GetOutput()
431-
shrunk.DisconnectPipeline()
432-
shrunk.SetOrigin(self.roi[0][:dimension])
433-
self.rendered_image = shrunk
478+
if self.image:
479+
shrunk = self.shrinker.GetOutput()
480+
shrunk.DisconnectPipeline()
481+
shrunk.SetOrigin(self.roi[0][:dimension])
482+
self.rendered_image = shrunk
483+
if self.label_map:
484+
shrunk = self.label_map_shrinker.GetOutput()
485+
shrunk.DisconnectPipeline()
486+
shrunk.SetOrigin(self.roi[0][:dimension])
487+
self.rendered_label_map = shrunk
434488
else:
435-
self.rendered_image = self.image
489+
if self.image:
490+
self.rendered_image = self.image
491+
if self.label_map:
492+
self.rendered_label_map = self.image
436493

437494
@validate('gradient_opacity')
438495
def _validate_gradient_opacity(self, proposal):
@@ -541,23 +598,31 @@ def _on_geometries_changed(self, change=None):
541598

542599
def roi_region(self):
543600
"""Return the itk.ImageRegion corresponding to the roi."""
544-
dimension = self.image.GetImageDimension()
545-
index = self.image.TransformPhysicalPointToIndex(
601+
if self.image:
602+
image = self.image
603+
else:
604+
image = self.label_map
605+
dimension = image.GetImageDimension()
606+
index = image.TransformPhysicalPointToIndex(
546607
tuple(self.roi[0][:dimension]))
547-
upper_index = self.image.TransformPhysicalPointToIndex(
608+
upper_index = image.TransformPhysicalPointToIndex(
548609
tuple(self.roi[1][:dimension]))
549610
size = upper_index - index
550611
for dim in range(dimension):
551612
size[dim] += 1
552613
region = itk.ImageRegion[dimension]()
553614
region.SetIndex(index)
554615
region.SetSize(tuple(size))
555-
region.Crop(self.image.GetLargestPossibleRegion())
616+
region.Crop(image.GetLargestPossibleRegion())
556617
return region
557618

558619
def roi_slice(self):
559620
"""Return the numpy array slice corresponding to the roi."""
560-
dimension = self.image.GetImageDimension()
621+
if self.image:
622+
image = self.image
623+
else:
624+
image = self.label_map
625+
dimension = image.GetImageDimension()
561626
region = self.roi_region()
562627
index = region.GetIndex()
563628
upper_index = np.array(index) + np.array(region.GetSize())
@@ -568,6 +633,7 @@ def roi_slice(self):
568633

569634

570635
def view(image=None, # noqa: C901
636+
label_map=None, # noqa: C901
571637
cmap=cm.viridis,
572638
select_roi=False,
573639
interpolation=True,
@@ -584,7 +650,8 @@ def view(image=None, # noqa: C901
584650
Creates and returns an ipywidget to visualize an image, and/or point sets
585651
and/or geometries .
586652
587-
The image can be 2D or 3D.
653+
The image can be 2D or 3D. A label map that corresponds to the image can
654+
also be provided. The image and label map must have the same size.
588655
589656
The type of the image can be an numpy.array, itk.Image,
590657
vtk.vtkImageData, pyvista.UniformGrid, imglyb.ReferenceGuardingRandomAccessibleInterval,
@@ -634,6 +701,10 @@ def view(image=None, # noqa: C901
634701
image : array_like, itk.Image, or vtk.vtkImageData
635702
The 2D or 3D image to visualize.
636703
704+
label_map : array_like, itk.Image, or vtk.vtkImageData
705+
The 2D or 3D label map to visualize. If an image is also provided, the
706+
label map must have the same size.
707+
637708
vmin: float, optional, default: None
638709
Value that maps to the minimum of image colormap. Defaults to minimum of
639710
the image pixel buffer.
@@ -798,6 +869,7 @@ def view(image=None, # noqa: C901
798869
image = images[0]
799870

800871
viewer = Viewer(image=image,
872+
label_map=label_map,
801873
cmap=cmap,
802874
select_roi=select_roi,
803875
interpolation=interpolation,

0 commit comments

Comments
 (0)