-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Image segmentation view improvement #1131
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
<!-- Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. --> | ||
|
||
<div ng-app="visualization_app" class="ng-cloak"> | ||
<div ng-controller="display_controller as dispaly"> | ||
<!-- Display Options --> | ||
<div class="pull-right"> | ||
<div class="row"> | ||
<div class="col-md-12"> | ||
<div class="button-group pull-right"> | ||
<button type="button" class="btn btn-default btn-sm dropdown-toggle" data-toggle="dropdown"> | ||
<span class="glyphicon glyphicon-cog"></span> | ||
<span class="caret"></span> | ||
</button> | ||
<ul class="dropdown-menu" | ||
style="padding:10px" | ||
ng-click="$event.stopPropagation()"> | ||
<li> | ||
<small>Opacity {[storage.opacity * 100 | number : 0]}%</small> | ||
<input type="range" min="0" max="1" step="0.01" ng-model="storage.opacity"> | ||
</li> | ||
<li> | ||
<small>Mask {[storage.mask * 100 | number : 0]}%</small> | ||
<input type="range" min="0" max="1" step="0.01" ng-model="storage.mask"> | ||
</li> | ||
<!-- reset --> | ||
<li> | ||
<button type="button" class="btn btn-default btn-sm dropdown-toggle" | ||
ng-click="storage.opacity = 0.3; storage.mask = 0;" | ||
title="Reset to defaults"> | ||
<span class="glyphicon glyphicon-refresh"></span> | ||
</button> | ||
</li> | ||
</ul> | ||
</div> | ||
</div> | ||
</div> | ||
</div> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
<!-- Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. --> | ||
|
||
</div> | ||
</div> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,4 @@ | ||
{# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. #} | ||
|
||
<h3>Legend</h3> | ||
|
||
<table> | ||
{% for category in legend %} | ||
{% set image = category['image'] %} | ||
{% set text = category['text'] %} | ||
<tr> | ||
<td><img src="{{image}}" style="max-width:100%;" /></td> | ||
<td>{{ text }}</td> | ||
</tr> | ||
{% endfor %} | ||
</table> | ||
<script type="text/javascript" src="/extension-static/view/image-segmentation/js/app.js"></script> | ||
<link rel="stylesheet" href="/extension-static/view/image-segmentation/css/app.css"> |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. */ | ||
|
||
div.vis-div { | ||
position: relative; | ||
} | ||
img.overlay { | ||
float: left; | ||
position: absolute; | ||
left: 0px; | ||
top: 0px; | ||
} | ||
img.fill { | ||
z-index: 1; | ||
} | ||
img.mask { | ||
z-index: 2; | ||
} | ||
caption { | ||
caption-side: bottom; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. | ||
|
||
// Angularjs app, visualization_app | ||
var app = angular.module('visualization_app', ['ngStorage']); | ||
|
||
// Controller to handle global display attributes | ||
app.controller('display_controller', | ||
['$scope', '$rootScope', '$localStorage', | ||
function($scope, $rootScope, $localStorage) { | ||
$rootScope.storage = $localStorage.$default({ | ||
opacity: .3, | ||
mask: 0.0, | ||
}); | ||
$scope.fill_style = {'opacity': $localStorage.opacity}; | ||
$scope.mask_style = {'opacity': $localStorage.mask}; | ||
// Broadcast to child scopes that the opacity has changed. | ||
$scope.$watch(function() { | ||
return $localStorage.opacity; | ||
}, function() { | ||
$scope.fill_style = {'opacity': $localStorage.opacity}; | ||
}); | ||
// Broadcast to child scopes that the mask has changed. | ||
$scope.$watch(function() { | ||
return $localStorage.mask; | ||
}, function() { | ||
$scope.mask_style = {'opacity': $localStorage.mask}; | ||
}); | ||
}]); | ||
|
||
// Because jinja uses {{ and }}, tell angular to use {[ and ]} | ||
app.config(['$interpolateProvider', function($interpolateProvider) { | ||
$interpolateProvider.startSymbol('{['); | ||
$interpolateProvider.endSymbol(']}'); | ||
}]); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. | ||
from __future__ import absolute_import | ||
|
||
import json | ||
import os | ||
|
||
import matplotlib as mpl | ||
import numpy as np | ||
import PIL.Image | ||
import skfmm | ||
|
||
import digits | ||
from digits.utils import subclass, override | ||
|
@@ -14,8 +16,10 @@ | |
from ..interface import VisualizationInterface | ||
|
||
CONFIG_TEMPLATE = "config_template.html" | ||
VIEW_TEMPLATE = "view_template.html" | ||
HEADER_TEMPLATE = "header_template.html" | ||
APP_BEGIN_TEMPLATE = "app_begin_template.html" | ||
APP_END_TEMPLATE = "app_end_template.html" | ||
VIEW_TEMPLATE = "view_template.html" | ||
|
||
|
||
@subclass | ||
|
@@ -56,9 +60,6 @@ def __init__(self, dataset, **kwargs): | |
else: | ||
self.class_labels = None | ||
|
||
# create array to memorize all classes we found in labels | ||
self.found_classes = np.array([]) | ||
|
||
@staticmethod | ||
def get_config_form(): | ||
return ConfigForm() | ||
|
@@ -80,6 +81,32 @@ def get_config_template(form): | |
os.path.join(extension_dir, CONFIG_TEMPLATE), "r").read() | ||
return (template, {'form': form}) | ||
|
||
def get_legend_for(self, found_classes, skip_classes=[]): | ||
""" | ||
Return the legend color image squares and text for each class | ||
:param found_classes: list of class indices | ||
:param skip_classes: list of class indices to skip | ||
:return: list of dicts of text hex_color for each class | ||
""" | ||
legend = [] | ||
for c in (x for x in found_classes if x not in skip_classes): | ||
# create hex color associated with the category ID | ||
if self.map: | ||
rgb_color = self.map.to_rgba([c])[0, :3] | ||
hex_color = mpl.colors.rgb2hex(rgb_color) | ||
else: | ||
# make a grey scale hex color | ||
h = hex(int(c)).split('x')[1].zfill(2) | ||
hex_color = '#%s%s%s' % (h, h, h) | ||
|
||
if self.class_labels: | ||
text = self.class_labels[int(c)] | ||
else: | ||
text = "Class #%d" % c | ||
|
||
legend.append({'index':c, 'text': text, 'hex_color': hex_color}) | ||
return legend | ||
|
||
@override | ||
def get_header_template(self): | ||
""" | ||
|
@@ -89,24 +116,17 @@ def get_header_template(self): | |
template = open( | ||
os.path.join(extension_dir, HEADER_TEMPLATE), "r").read() | ||
|
||
# show legend | ||
legend = [] | ||
for c in self.found_classes: | ||
# create small square image and fill it with the color | ||
# associated with the category ID | ||
if self.map: | ||
rgb_color = self.map.to_rgba([c])[0,:3]*255 | ||
image = np.zeros(shape=(50, 50, 3)) | ||
image[:,:] = rgb_color | ||
else: | ||
image = np.full(shape=(50, 50), fill_value=c) | ||
image = image.astype('uint8') | ||
image = digits.utils.image.embed_image_html(image) | ||
text = "Class #%d" % c | ||
if self.class_labels: | ||
text = "%s (%s)" % (text, self.class_labels[int(c)]) | ||
legend.append({'image': image, 'text': text}) | ||
return template, {'legend': legend} | ||
return template, {} | ||
|
||
@override | ||
def get_ng_templates(self): | ||
""" | ||
Implements get_ng_templates() method from view extension interface | ||
""" | ||
extension_dir = os.path.dirname(os.path.abspath(__file__)) | ||
header = open(os.path.join(extension_dir, APP_BEGIN_TEMPLATE), "r").read() | ||
footer = open(os.path.join(extension_dir, APP_END_TEMPLATE), "r").read() | ||
return header, footer | ||
|
||
@staticmethod | ||
def get_id(): | ||
|
@@ -116,6 +136,10 @@ def get_id(): | |
def get_title(): | ||
return 'Image Segmentation' | ||
|
||
@staticmethod | ||
def get_dirname(): | ||
return 'imageSegmentation' | ||
|
||
@override | ||
def get_view_template(self, data): | ||
""" | ||
|
@@ -125,7 +149,14 @@ def get_view_template(self, data): | |
- context is a dictionary of context variables to use for rendering | ||
the form | ||
""" | ||
return self.view_template, {'image': digits.utils.image.embed_image_html(data)} | ||
return self.view_template, { | ||
'input_id': data['input_id'], | ||
'input_image': digits.utils.image.embed_image_html(data['input_image']), | ||
'fill_image': digits.utils.image.embed_image_html(data['fill_image']), | ||
'mask_image': digits.utils.image.embed_image_html(data['mask_image']), | ||
'legend': data['legend'], | ||
'class_data': json.dumps(data['class_data'].tolist()), | ||
} | ||
|
||
@override | ||
def process_data(self, input_id, input_data, output_data): | ||
|
@@ -134,24 +165,66 @@ def process_data(self, input_id, input_data, output_data): | |
""" | ||
# assume the only output is a CHW image where C is the number | ||
# of classes, H and W are the height and width of the image | ||
data = output_data[output_data.keys()[0]].astype('float32') | ||
class_data = output_data[output_data.keys()[0]].astype('float32') | ||
# retain only the top class for each pixel | ||
data = np.argmax(data,axis=0) | ||
class_data = np.argmax(class_data, axis=0).astype('uint8') | ||
|
||
# remember the classes we found | ||
found_classes = np.unique(data) | ||
self.found_classes = np.unique(np.concatenate( | ||
(self.found_classes, found_classes))) | ||
found_classes = np.unique(class_data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we still need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Damn, you're good. |
||
|
||
# convert using color map (assume 8-bit output) | ||
if self.map: | ||
data = self.map.to_rgba(data)*255 | ||
# keep RGB values only, remove alpha channel | ||
data = data[:, :, 0:3] | ||
|
||
# convert to uint8 | ||
data = data.astype('uint8') | ||
# convert to PIL image | ||
image = PIL.Image.fromarray(data) | ||
|
||
return image | ||
fill_data = (self.map.to_rgba(class_data)*255).astype('uint8') | ||
else: | ||
fill_data = np.ndarray((class_data.shape[0], class_data.shape[1], 4), dtype='uint8') | ||
for x in xrange(3): | ||
fill_data[:, :, x] = class_data.copy() | ||
|
||
# Assuming that class 0 is the background | ||
mask = np.greater(class_data, 0) | ||
fill_data[:, :, 3] = mask * 255 | ||
|
||
# Black mask of non-segmented pixels | ||
mask_data = np.zeros(fill_data.shape, dtype='uint8') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you need this line? I think the below one is enough (?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The np.zeros create RGBA data. the line below sets the A component. Is there some pythonic trickery that I'm missing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. my mistake, sorry. |
||
mask_data[:, :, 3] = (1 - mask) * 255 | ||
|
||
# Generate outlines around segmented classes | ||
if len(found_classes) > 1: | ||
# Assuming that class 0 is the background. | ||
line_mask = np.zeros(class_data.shape, dtype=bool) | ||
for c in (x for x in found_classes if x != 0): | ||
c_mask = np.equal(class_data, c) | ||
# Find the signed distance from the zero contour | ||
distance = skfmm.distance(c_mask.astype('float32') - 0.5) | ||
# Accumulate the mask for all classes | ||
line_width = 3 | ||
line_mask |= c_mask & np.less(distance, line_width) | ||
|
||
# add the outlines to the input image | ||
for x in xrange(3): | ||
input_data[:, :, x] = (input_data[:, :, x] * (1 - line_mask) + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this doesn't work if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, this should be done within the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gheinrich, did you try this with a (256, 256)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jmancewicz yes, this was an image from the Sunnybrook dataset (with channel conversion = 'none') |
||
fill_data[:, :, x] * line_mask) | ||
|
||
# Input image with outlines | ||
input_image = PIL.Image.fromarray(input_data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this work if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, but that's in an upcoming PR. |
||
input_image.format = 'png' | ||
|
||
# Fill image | ||
fill_image = PIL.Image.fromarray(fill_data) | ||
fill_image.format = 'png' | ||
|
||
# Mask image | ||
mask_image = PIL.Image.fromarray(mask_data) | ||
mask_image.format = 'png' | ||
|
||
# legend for this instance | ||
legend = self.get_legend_for(found_classes, skip_classes=[0]) | ||
|
||
return { | ||
'input_id': input_id, | ||
'input_image': input_image, | ||
'fill_image': fill_image, | ||
'mask_image': mask_image, | ||
'legend': legend, | ||
'class_data': class_data, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lukeyeager usually prefers seeing non-standard package imports below (line 14)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gheinrich, this seems PEP8 compliant (https://www.python.org/dev/peps/pep-0008/#imports) in the third party block.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just move
json
up aboveimport os