Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image segmentation view improvement #1131

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions digits/extensions/view/imageSegmentation/app_begin_template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<!-- Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. -->

<div ng-app="visualization_app" class="ng-cloak">
<div ng-controller="display_controller as dispaly">
<!-- Display Options -->
<div class="pull-right">
<div class="row">
<div class="col-md-12">
<div class="button-group pull-right">
<button type="button" class="btn btn-default btn-sm dropdown-toggle" data-toggle="dropdown">
<span class="glyphicon glyphicon-cog"></span>
<span class="caret"></span>
</button>
<ul class="dropdown-menu"
style="padding:10px"
ng-click="$event.stopPropagation()">
<li>
<small>Opacity {[storage.opacity * 100 | number : 0]}%</small>
<input type="range" min="0" max="1" step="0.01" ng-model="storage.opacity">
</li>
<li>
<small>Mask {[storage.mask * 100 | number : 0]}%</small>
<input type="range" min="0" max="1" step="0.01" ng-model="storage.mask">
</li>
<!-- reset -->
<li>
<button type="button" class="btn btn-default btn-sm dropdown-toggle"
ng-click="storage.opacity = 0.3; storage.mask = 0;"
title="Reset to defaults">
<span class="glyphicon glyphicon-refresh"></span>
</button>
</li>
</ul>
</div>
</div>
</div>
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<!-- Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. -->

</div>
</div>
14 changes: 2 additions & 12 deletions digits/extensions/view/imageSegmentation/header_template.html
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

<h3>Legend</h3>

<table>
{% for category in legend %}
{% set image = category['image'] %}
{% set text = category['text'] %}
<tr>
<td><img src="{{image}}" style="max-width:100%;" /></td>
<td>{{ text }}</td>
</tr>
{% endfor %}
</table>
<script type="text/javascript" src="/extension-static/view/image-segmentation/js/app.js"></script>
<link rel="stylesheet" href="/extension-static/view/image-segmentation/css/app.css">
20 changes: 20 additions & 0 deletions digits/extensions/view/imageSegmentation/static/css/app.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/* Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. */

div.vis-div {
position: relative;
}
img.overlay {
float: left;
position: absolute;
left: 0px;
top: 0px;
}
img.fill {
z-index: 1;
}
img.mask {
z-index: 2;
}
caption {
caption-side: bottom;
}
34 changes: 34 additions & 0 deletions digits/extensions/view/imageSegmentation/static/js/app.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.

// Angularjs app, visualization_app
var app = angular.module('visualization_app', ['ngStorage']);

// Controller to handle global display attributes
app.controller('display_controller',
['$scope', '$rootScope', '$localStorage',
function($scope, $rootScope, $localStorage) {
$rootScope.storage = $localStorage.$default({
opacity: .3,
mask: 0.0,
});
$scope.fill_style = {'opacity': $localStorage.opacity};
$scope.mask_style = {'opacity': $localStorage.mask};
// Broadcast to child scopes that the opacity has changed.
$scope.$watch(function() {
return $localStorage.opacity;
}, function() {
$scope.fill_style = {'opacity': $localStorage.opacity};
});
// Broadcast to child scopes that the mask has changed.
$scope.$watch(function() {
return $localStorage.mask;
}, function() {
$scope.mask_style = {'opacity': $localStorage.mask};
});
}]);

// Because jinja uses {{ and }}, tell angular to use {[ and ]}
app.config(['$interpolateProvider', function($interpolateProvider) {
$interpolateProvider.startSymbol('{[');
$interpolateProvider.endSymbol(']}');
}]);
149 changes: 111 additions & 38 deletions digits/extensions/view/imageSegmentation/view.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import json
import os

import matplotlib as mpl
import numpy as np
import PIL.Image
import skfmm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lukeyeager usually prefers seeing non-standard package imports below (line 14)

Copy link
Contributor Author

@jmancewicz jmancewicz Oct 3, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gheinrich, this seems PEP8 compliant (https://www.python.org/dev/peps/pep-0008/#imports) in the third party block.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just move json up above import os


import digits
from digits.utils import subclass, override
Expand All @@ -14,8 +16,10 @@
from ..interface import VisualizationInterface

CONFIG_TEMPLATE = "config_template.html"
VIEW_TEMPLATE = "view_template.html"
HEADER_TEMPLATE = "header_template.html"
APP_BEGIN_TEMPLATE = "app_begin_template.html"
APP_END_TEMPLATE = "app_end_template.html"
VIEW_TEMPLATE = "view_template.html"


@subclass
Expand Down Expand Up @@ -56,9 +60,6 @@ def __init__(self, dataset, **kwargs):
else:
self.class_labels = None

# create array to memorize all classes we found in labels
self.found_classes = np.array([])

@staticmethod
def get_config_form():
return ConfigForm()
Expand All @@ -80,6 +81,32 @@ def get_config_template(form):
os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read()
return (template, {'form': form})

def get_legend_for(self, found_classes, skip_classes=[]):
"""
Return the legend color image squares and text for each class
:param found_classes: list of class indices
:param skip_classes: list of class indices to skip
:return: list of dicts of text hex_color for each class
"""
legend = []
for c in (x for x in found_classes if x not in skip_classes):
# create hex color associated with the category ID
if self.map:
rgb_color = self.map.to_rgba([c])[0, :3]
hex_color = mpl.colors.rgb2hex(rgb_color)
else:
# make a grey scale hex color
h = hex(int(c)).split('x')[1].zfill(2)
hex_color = '#%s%s%s' % (h, h, h)

if self.class_labels:
text = self.class_labels[int(c)]
else:
text = "Class #%d" % c

legend.append({'index':c, 'text': text, 'hex_color': hex_color})
return legend

@override
def get_header_template(self):
"""
Expand All @@ -89,24 +116,17 @@ def get_header_template(self):
template = open(
os.path.join(extension_dir, HEADER_TEMPLATE), "r").read()

# show legend
legend = []
for c in self.found_classes:
# create small square image and fill it with the color
# associated with the category ID
if self.map:
rgb_color = self.map.to_rgba([c])[0,:3]*255
image = np.zeros(shape=(50, 50, 3))
image[:,:] = rgb_color
else:
image = np.full(shape=(50, 50), fill_value=c)
image = image.astype('uint8')
image = digits.utils.image.embed_image_html(image)
text = "Class #%d" % c
if self.class_labels:
text = "%s (%s)" % (text, self.class_labels[int(c)])
legend.append({'image': image, 'text': text})
return template, {'legend': legend}
return template, {}

@override
def get_ng_templates(self):
"""
Implements get_ng_templates() method from view extension interface
"""
extension_dir = os.path.dirname(os.path.abspath(__file__))
header = open(os.path.join(extension_dir, APP_BEGIN_TEMPLATE), "r").read()
footer = open(os.path.join(extension_dir, APP_END_TEMPLATE), "r").read()
return header, footer

@staticmethod
def get_id():
Expand All @@ -116,6 +136,10 @@ def get_id():
def get_title():
return 'Image Segmentation'

@staticmethod
def get_dirname():
return 'imageSegmentation'

@override
def get_view_template(self, data):
"""
Expand All @@ -125,7 +149,14 @@ def get_view_template(self, data):
- context is a dictionary of context variables to use for rendering
the form
"""
return self.view_template, {'image': digits.utils.image.embed_image_html(data)}
return self.view_template, {
'input_id': data['input_id'],
'input_image': digits.utils.image.embed_image_html(data['input_image']),
'fill_image': digits.utils.image.embed_image_html(data['fill_image']),
'mask_image': digits.utils.image.embed_image_html(data['mask_image']),
'legend': data['legend'],
'class_data': json.dumps(data['class_data'].tolist()),
}

@override
def process_data(self, input_id, input_data, output_data):
Expand All @@ -134,24 +165,66 @@ def process_data(self, input_id, input_data, output_data):
"""
# assume the only output is a CHW image where C is the number
# of classes, H and W are the height and width of the image
data = output_data[output_data.keys()[0]].astype('float32')
class_data = output_data[output_data.keys()[0]].astype('float32')
# retain only the top class for each pixel
data = np.argmax(data,axis=0)
class_data = np.argmax(class_data, axis=0).astype('uint8')

# remember the classes we found
found_classes = np.unique(data)
self.found_classes = np.unique(np.concatenate(
(self.found_classes, found_classes)))
found_classes = np.unique(class_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we still need self.found_classes below? Now that we have one legend per image I think this can go away.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn, you're good.


# convert using color map (assume 8-bit output)
if self.map:
data = self.map.to_rgba(data)*255
# keep RGB values only, remove alpha channel
data = data[:, :, 0:3]

# convert to uint8
data = data.astype('uint8')
# convert to PIL image
image = PIL.Image.fromarray(data)

return image
fill_data = (self.map.to_rgba(class_data)*255).astype('uint8')
else:
fill_data = np.ndarray((class_data.shape[0], class_data.shape[1], 4), dtype='uint8')
for x in xrange(3):
fill_data[:, :, x] = class_data.copy()

# Assuming that class 0 is the background
mask = np.greater(class_data, 0)
fill_data[:, :, 3] = mask * 255

# Black mask of non-segmented pixels
mask_data = np.zeros(fill_data.shape, dtype='uint8')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need this line? I think the below one is enough (?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The np.zeros create RGBA data. the line below sets the A component. Is there some pythonic trickery that I'm missing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my mistake, sorry.

mask_data[:, :, 3] = (1 - mask) * 255

# Generate outlines around segmented classes
if len(found_classes) > 1:
# Assuming that class 0 is the background.
line_mask = np.zeros(class_data.shape, dtype=bool)
for c in (x for x in found_classes if x != 0):
c_mask = np.equal(class_data, c)
# Find the signed distance from the zero contour
distance = skfmm.distance(c_mask.astype('float32') - 0.5)
# Accumulate the mask for all classes
line_width = 3
line_mask |= c_mask & np.less(distance, line_width)

# add the outlines to the input image
for x in xrange(3):
input_data[:, :, x] = (input_data[:, :, x] * (1 - line_mask) +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't work if input_data is a grayscale image (with shape e.g. [256,256])

Copy link
Contributor

@gheinrich gheinrich Oct 4, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, this should be done within the if len(found_classes)>1 otherwise line_mask is uninitialized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gheinrich, did you try this with a (256, 256)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jmancewicz yes, this was an image from the Sunnybrook dataset (with channel conversion = 'none')

fill_data[:, :, x] * line_mask)

# Input image with outlines
input_image = PIL.Image.fromarray(input_data)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work if input_data.dtype!='uint8'?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but that's in an upcoming PR.

input_image.format = 'png'

# Fill image
fill_image = PIL.Image.fromarray(fill_data)
fill_image.format = 'png'

# Mask image
mask_image = PIL.Image.fromarray(mask_data)
mask_image.format = 'png'

# legend for this instance
legend = self.get_legend_for(found_classes, skip_classes=[0])

return {
'input_id': input_id,
'input_image': input_image,
'fill_image': fill_image,
'mask_image': mask_image,
'legend': legend,
'class_data': class_data,
}
Loading