Skip to content

Commit

Permalink
Improve imageSegmentation view extension
Browse files Browse the repository at this point in the history
  • Loading branch information
jmancewicz committed Oct 4, 2016
1 parent f8b74e8 commit a9284ab
Show file tree
Hide file tree
Showing 11 changed files with 336 additions and 66 deletions.
37 changes: 37 additions & 0 deletions digits/extensions/view/imageSegmentation/app_begin_template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<!-- Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. -->

<div ng-app="visualization_app" class="ng-cloak">
<div ng-controller="display_controller as dispaly">
<!-- Display Options -->
<div class="pull-right">
<div class="row">
<div class="col-md-12">
<div class="button-group pull-right">
<button type="button" class="btn btn-default btn-sm dropdown-toggle" data-toggle="dropdown">
<span class="glyphicon glyphicon-cog"></span>
<span class="caret"></span>
</button>
<ul class="dropdown-menu"
style="padding:10px"
ng-click="$event.stopPropagation()">
<li>
<small>Opacity {[storage.opacity * 100 | number : 0]}%</small>
<input type="range" min="0" max="1" step="0.01" ng-model="storage.opacity">
</li>
<li>
<small>Mask {[storage.mask * 100 | number : 0]}%</small>
<input type="range" min="0" max="1" step="0.01" ng-model="storage.mask">
</li>
<!-- reset -->
<li>
<button type="button" class="btn btn-default btn-sm dropdown-toggle"
ng-click="storage.opacity = 0.3; storage.mask = 0;"
title="Reset to defaults">
<span class="glyphicon glyphicon-refresh"></span>
</button>
</li>
</ul>
</div>
</div>
</div>
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<!-- Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. -->

</div>
</div>
14 changes: 2 additions & 12 deletions digits/extensions/view/imageSegmentation/header_template.html
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #}

<h3>Legend</h3>

<table>
{% for category in legend %}
{% set image = category['image'] %}
{% set text = category['text'] %}
<tr>
<td><img src="{{image}}" style="max-width:100%;" /></td>
<td>{{ text }}</td>
</tr>
{% endfor %}
</table>
<script type="text/javascript" src="/extension-static/view/image-segmentation/js/app.js"></script>
<link rel="stylesheet" href="/extension-static/view/image-segmentation/css/app.css">
20 changes: 20 additions & 0 deletions digits/extensions/view/imageSegmentation/static/css/app.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/* Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. */

div.vis-div {
position: relative;
}
img.overlay {
float: left;
position: absolute;
left: 0px;
top: 0px;
}
img.fill {
z-index: 1;
}
img.mask {
z-index: 2;
}
caption {
caption-side: bottom;
}
34 changes: 34 additions & 0 deletions digits/extensions/view/imageSegmentation/static/js/app.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.

// Angularjs app, visualization_app
var app = angular.module('visualization_app', ['ngStorage']);

// Controller to handle global display attributes
app.controller('display_controller',
['$scope', '$rootScope', '$localStorage',
function($scope, $rootScope, $localStorage) {
$rootScope.storage = $localStorage.$default({
opacity: .3,
mask: 0.0,
});
$scope.fill_style = {'opacity': $localStorage.opacity};
$scope.mask_style = {'opacity': $localStorage.mask};
// Broadcast to child scopes that the opacity has changed.
$scope.$watch(function() {
return $localStorage.opacity;
}, function() {
$scope.fill_style = {'opacity': $localStorage.opacity};
});
// Broadcast to child scopes that the mask has changed.
$scope.$watch(function() {
return $localStorage.mask;
}, function() {
$scope.mask_style = {'opacity': $localStorage.mask};
});
}]);

// Because jinja uses {{ and }}, tell angular to use {[ and ]}
app.config(['$interpolateProvider', function($interpolateProvider) {
$interpolateProvider.startSymbol('{[');
$interpolateProvider.endSymbol(']}');
}]);
149 changes: 111 additions & 38 deletions digits/extensions/view/imageSegmentation/view.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import json
import os

import matplotlib as mpl
import numpy as np
import PIL.Image
import skfmm

import digits
from digits.utils import subclass, override
Expand All @@ -14,8 +16,10 @@
from ..interface import VisualizationInterface

CONFIG_TEMPLATE = "config_template.html"
VIEW_TEMPLATE = "view_template.html"
HEADER_TEMPLATE = "header_template.html"
APP_BEGIN_TEMPLATE = "app_begin_template.html"
APP_END_TEMPLATE = "app_end_template.html"
VIEW_TEMPLATE = "view_template.html"


@subclass
Expand Down Expand Up @@ -56,9 +60,6 @@ def __init__(self, dataset, **kwargs):
else:
self.class_labels = None

# create array to memorize all classes we found in labels
self.found_classes = np.array([])

@staticmethod
def get_config_form():
return ConfigForm()
Expand All @@ -80,6 +81,32 @@ def get_config_template(form):
os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read()
return (template, {'form': form})

def get_legend_for(self, found_classes, skip_classes=[]):
"""
Return the legend color image squares and text for each class
:param found_classes: list of class indices
:param skip_classes: list of class indices to skip
:return: list of dicts of text hex_color for each class
"""
legend = []
for c in (x for x in found_classes if x not in skip_classes):
# create hex color associated with the category ID
if self.map:
rgb_color = self.map.to_rgba([c])[0, :3]
hex_color = mpl.colors.rgb2hex(rgb_color)
else:
# make a grey scale hex color
h = hex(int(c)).split('x')[1].zfill(2)
hex_color = '#%s%s%s' % (h, h, h)

if self.class_labels:
text = self.class_labels[int(c)]
else:
text = "Class #%d" % c

legend.append({'index':c, 'text': text, 'hex_color': hex_color})
return legend

@override
def get_header_template(self):
"""
Expand All @@ -89,24 +116,17 @@ def get_header_template(self):
template = open(
os.path.join(extension_dir, HEADER_TEMPLATE), "r").read()

# show legend
legend = []
for c in self.found_classes:
# create small square image and fill it with the color
# associated with the category ID
if self.map:
rgb_color = self.map.to_rgba([c])[0,:3]*255
image = np.zeros(shape=(50, 50, 3))
image[:,:] = rgb_color
else:
image = np.full(shape=(50, 50), fill_value=c)
image = image.astype('uint8')
image = digits.utils.image.embed_image_html(image)
text = "Class #%d" % c
if self.class_labels:
text = "%s (%s)" % (text, self.class_labels[int(c)])
legend.append({'image': image, 'text': text})
return template, {'legend': legend}
return template, {}

@override
def get_ng_templates(self):
"""
Implements get_ng_templates() method from view extension interface
"""
extension_dir = os.path.dirname(os.path.abspath(__file__))
header = open(os.path.join(extension_dir, APP_BEGIN_TEMPLATE), "r").read()
footer = open(os.path.join(extension_dir, APP_END_TEMPLATE), "r").read()
return header, footer

@staticmethod
def get_id():
Expand All @@ -116,6 +136,10 @@ def get_id():
def get_title():
return 'Image Segmentation'

@staticmethod
def get_dirname():
return 'imageSegmentation'

@override
def get_view_template(self, data):
"""
Expand All @@ -125,7 +149,14 @@ def get_view_template(self, data):
- context is a dictionary of context variables to use for rendering
the form
"""
return self.view_template, {'image': digits.utils.image.embed_image_html(data)}
return self.view_template, {
'input_id': data['input_id'],
'input_image': digits.utils.image.embed_image_html(data['input_image']),
'fill_image': digits.utils.image.embed_image_html(data['fill_image']),
'mask_image': digits.utils.image.embed_image_html(data['mask_image']),
'legend': data['legend'],
'class_data': json.dumps(data['class_data'].tolist()),
}

@override
def process_data(self, input_id, input_data, output_data):
Expand All @@ -134,24 +165,66 @@ def process_data(self, input_id, input_data, output_data):
"""
# assume the only output is a CHW image where C is the number
# of classes, H and W are the height and width of the image
data = output_data[output_data.keys()[0]].astype('float32')
class_data = output_data[output_data.keys()[0]].astype('float32')
# retain only the top class for each pixel
data = np.argmax(data,axis=0)
class_data = np.argmax(class_data, axis=0).astype('uint8')

# remember the classes we found
found_classes = np.unique(data)
self.found_classes = np.unique(np.concatenate(
(self.found_classes, found_classes)))
found_classes = np.unique(class_data)

# convert using color map (assume 8-bit output)
if self.map:
data = self.map.to_rgba(data)*255
# keep RGB values only, remove alpha channel
data = data[:, :, 0:3]

# convert to uint8
data = data.astype('uint8')
# convert to PIL image
image = PIL.Image.fromarray(data)

return image
fill_data = (self.map.to_rgba(class_data)*255).astype('uint8')
else:
fill_data = np.ndarray((class_data.shape[0], class_data.shape[1], 4), dtype='uint8')
for x in xrange(3):
fill_data[:, :, x] = class_data.copy()

# Assuming that class 0 is the background
mask = np.greater(class_data, 0)
fill_data[:, :, 3] = mask * 255

# Black mask of non-segmented pixels
mask_data = np.zeros(fill_data.shape, dtype='uint8')
mask_data[:, :, 3] = (1 - mask) * 255

# Generate outlines around segmented classes
if len(found_classes) > 1:
# Assuming that class 0 is the background.
line_mask = np.zeros(class_data.shape, dtype=bool)
for c in (x for x in found_classes if x != 0):
c_mask = np.equal(class_data, c)
# Find the signed distance from the zero contour
distance = skfmm.distance(c_mask.astype('float32') - 0.5)
# Accumulate the mask for all classes
line_width = 3
line_mask |= c_mask & np.less(distance, line_width)

# add the outlines to the input image
for x in xrange(3):
input_data[:, :, x] = (input_data[:, :, x] * (1 - line_mask) +
fill_data[:, :, x] * line_mask)

# Input image with outlines
input_image = PIL.Image.fromarray(input_data)
input_image.format = 'png'

# Fill image
fill_image = PIL.Image.fromarray(fill_data)
fill_image.format = 'png'

# Mask image
mask_image = PIL.Image.fromarray(mask_data)
mask_image.format = 'png'

# legend for this instance
legend = self.get_legend_for(found_classes, skip_classes=[0])

return {
'input_id': input_id,
'input_image': input_image,
'fill_image': fill_image,
'mask_image': mask_image,
'legend': legend,
'class_data': class_data,
}
Loading

0 comments on commit a9284ab

Please sign in to comment.