Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semantic segmentation work-flow #961

Merged
merged 4 commits into from
Aug 30, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions digits/dataset/generic/views.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import caffe_pb2
import flask
import PIL.Image

# Find the best implementation available
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO

import caffe_pb2
import flask
import matplotlib as mpl
import numpy as np
import PIL.Image

from .forms import GenericDatasetForm
from .job import GenericDatasetJob

from digits import extensions, utils
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from digits.utils.routing import request_wants_json, job_from_request
from digits.utils.lmdbreader import DbReader
from digits.webapp import scheduler
Expand Down Expand Up @@ -144,6 +146,18 @@ def explore():
db_path = job.path(db)
labels = []

if COLOR_PALETTE_ATTRIBUTE in job.extension_userdata:
# assume single-channel 8-bit palette
palette = job.extension_userdata[COLOR_PALETTE_ATTRIBUTE]
palette = np.array(palette).reshape((len(palette)/3,3)) / 255.
# normalize input pixels to [0,1]
norm = mpl.colors.Normalize(vmin=0,vmax=255)
# create map
cmap = mpl.pyplot.cm.ScalarMappable(norm=norm,
cmap=mpl.colors.ListedColormap(palette))
else:
cmap = None

page = int(flask.request.args.get('page', 0))
size = int(flask.request.args.get('size', 25))

Expand All @@ -166,6 +180,13 @@ def explore():
s.write(datum.data)
s.seek(0)
img = PIL.Image.open(s)
if cmap and img.mode in ['L', '1']:
data = np.array(img)
data = cmap.to_rgba(data)*255
data = data.astype('uint8')
# keep RGB values only, remove alpha channel
data = data[:, :, 0:3]
img = PIL.Image.fromarray(data)
imgs.append({"label": None, "b64": utils.image.embed_image_html(img)})
count += 1
if len(imgs) >= size:
Expand Down
2 changes: 2 additions & 0 deletions digits/extensions/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from . import imageGradients
from . import imageProcessing
from . import imageSegmentation
from . import objectDetection

data_extensions = [
Expand All @@ -11,6 +12,7 @@
# editing DIGITS config option 'data_extension_list'
{'class': imageGradients.DataIngestion, 'show': False},
{'class': imageProcessing.DataIngestion, 'show': True},
{'class': imageSegmentation.DataIngestion, 'show': True},
{'class': objectDetection.DataIngestion, 'show': True},
]

Expand Down
4 changes: 4 additions & 0 deletions digits/extensions/data/imageSegmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from .data import DataIngestion
203 changes: 203 additions & 0 deletions digits/extensions/data/imageSegmentation/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import math
import os
import random

import numpy as np
import PIL.Image

from digits.utils import image, subclass, override, constants
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from ..interface import DataIngestionInterface
from .forms import DatasetForm

TEMPLATE = "template.html"


@subclass
class DataIngestion(DataIngestionInterface):
    """
    A data ingestion extension for an image segmentation dataset

    Feature images and label images are read from two folders and paired
    by file name (extension ignored). Label images must be single-channel
    (PIL modes 'P', 'L' or '1') and must all share the palette of the
    first label image found.
    """

    def __init__(self, **kwargs):
        """
        the parent __init__ method automatically populates this
        instance with attributes from the form
        """
        super(DataIngestion, self).__init__(**kwargs)

        # permutation shared by the feature and label lists, created
        # lazily by split_image_list()
        self.random_indices = None

        if 'seed' not in self.userdata:
            # choose random seed and add to userdata so it gets persisted
            self.userdata['seed'] = random.randint(0, 1000)

        random.seed(self.userdata['seed'])

        # open first image in label folder to retrieve palette
        # all label images must use the same palette - this is enforced
        # during dataset creation
        # (local names deliberately avoid shadowing the `image` module)
        first_label_file = self.make_image_list(self.label_folder)[0]
        first_label = self.load_label(first_label_file)
        self.userdata[COLOR_PALETTE_ATTRIBUTE] = first_label.getpalette()

        # get labels if those were provided
        if self.class_labels_file:
            with open(self.class_labels_file) as f:
                self.userdata['class_labels'] = f.read().splitlines()

    @override
    def encode_entry(self, entry):
        """
        Encode one (feature file, label file) pair

        parameters:
        - entry: sequence whose first item is the feature image path and
        whose second item is the label image path
        returns:
        - (feature, label) tuple of numpy.ndarray in CHW order
        raises:
        - ValueError if the label image palette differs from the palette
        recorded when this instance was created
        """
        feature_image_file = entry[0]
        label_image_file = entry[1]

        # feature image
        feature_image = self.encode_PIL_Image(
            image.load_image(feature_image_file),
            self.channel_conversion)

        # label image
        label_image = self.load_label(label_image_file)
        if label_image.getpalette() != self.userdata[COLOR_PALETTE_ATTRIBUTE]:
            raise ValueError("All label images must use the same palette")
        label_image = self.encode_PIL_Image(label_image)

        return feature_image, label_image

    def encode_PIL_Image(self, image, channel_conversion='none'):
        """
        Convert a PIL image into a CHW numpy.ndarray

        parameters:
        - image: PIL image to encode
        - channel_conversion: PIL mode to convert the image to, or 'none'
        to keep the image mode unchanged
        returns:
        - numpy.ndarray with the channel axis first
        raises:
        - ValueError if the decoded array is neither 2-D nor 3-D
        """
        if channel_conversion != 'none':
            if image.mode != channel_conversion:
                # convert to different image mode if necessary
                image = image.convert(channel_conversion)
        # convert to numpy array
        image = np.array(image)
        # add channel axis if input is grayscale image
        if image.ndim == 2:
            image = image[..., np.newaxis]
        elif image.ndim != 3:
            raise ValueError("Unhandled number of channels: %d" % image.ndim)
        # transpose to CHW
        image = image.transpose(2, 0, 1)
        return image

    @staticmethod
    @override
    def get_category():
        """
        Return the category name of this extension
        """
        return "Images"

    @staticmethod
    @override
    def get_id():
        """
        Return the identifier of this extension
        """
        return "image-segmentation"

    @staticmethod
    @override
    def get_dataset_form():
        """
        Return a new, empty dataset creation form
        """
        return DatasetForm()

    @staticmethod
    @override
    def get_dataset_template(form):
        """
        parameters:
        - form: form returned by get_dataset_form(). This may be populated
        with values if the job was cloned
        returns:
        - (template, context) tuple
        - template is a Jinja template to use for rendering dataset creation
        options
        - context is a dictionary of context variables to use for rendering
        the form
        """
        extension_dir = os.path.dirname(os.path.abspath(__file__))
        # use a context manager so the file handle is released promptly
        with open(os.path.join(extension_dir, TEMPLATE), "r") as f:
            template = f.read()
        context = {'form': form}
        return (template, context)

    @staticmethod
    @override
    def get_title():
        """
        Return the title of this extension
        """
        return "Segmentation"

    @override
    def itemize_entries(self, stage):
        """
        List the (feature file, label file) pairs that belong to a stage

        parameters:
        - stage: one of constants.TRAIN_DB, constants.VAL_DB or
        constants.TEST_DB
        returns:
        - list of (feature path, label path) tuples; always empty for the
        test stage
        raises:
        - ValueError if the feature and label folders hold a different
        number of images, or if a feature/label pair does not share the
        same base file name
        """
        if stage == constants.TEST_DB:
            # don't return anything for the test stage
            return []

        if stage == constants.TRAIN_DB or (not self.has_val_folder):
            feature_image_list = self.make_image_list(self.feature_folder)
            label_image_list = self.make_image_list(self.label_folder)
        else:
            # separate validation images
            feature_image_list = self.make_image_list(self.validation_feature_folder)
            label_image_list = self.make_image_list(self.validation_label_folder)

        # make sure filenames match
        if len(feature_image_list) != len(label_image_list):
            raise ValueError(
                "Expect same number of images in feature and label folders (%d!=%d)"
                % (len(feature_image_list), len(label_image_list)))

        # both lists are sorted, so matching entries sit at the same index
        for feature_path, label_path in zip(feature_image_list, label_image_list):
            feature_name = os.path.splitext(os.path.basename(feature_path))[0]
            label_name = os.path.splitext(os.path.basename(label_path))[0]
            if feature_name != label_name:
                raise ValueError("No corresponding feature/label pair found for (%s,%s)"
                                 % (feature_name, label_name))

        # split lists if there is no val folder
        if not self.has_val_folder:
            feature_image_list = self.split_image_list(feature_image_list, stage)
            label_image_list = self.split_image_list(label_image_list, stage)

        # materialize the pairs so a list is returned even where zip() is lazy
        return list(zip(feature_image_list, label_image_list))

    def load_label(self, filename):
        """
        Load a label image

        parameters:
        - filename: path of the label image
        returns:
        - PIL image in mode 'P', 'L' or '1'
        raises:
        - ValueError if the image uses any other mode
        """
        label = PIL.Image.open(filename)
        if label.mode not in ['P', 'L', '1']:
            raise ValueError("Labels are expected to be single-channel (paletted or "
                             " grayscale) images - %s mode is '%s'"
                             % (filename, label.mode))
        return label

    def make_image_list(self, folder):
        """
        Recursively collect the supported image files found under a folder

        parameters:
        - folder: root folder to walk (symbolic links are followed)
        returns:
        - sorted list of image file paths
        raises:
        - ValueError if no supported image is found
        """
        image_files = []
        for dirpath, dirnames, filenames in os.walk(folder, followlinks=True):
            for filename in filenames:
                if filename.lower().endswith(image.SUPPORTED_EXTENSIONS):
                    # join with dirpath (not the top-level folder) so that
                    # files located in sub-folders get a valid path
                    image_files.append(os.path.join(dirpath, filename))
        if not image_files:
            raise ValueError("Unable to find supported images in %s" % folder)
        return sorted(image_files)

    def split_image_list(self, filelist, stage):
        """
        Return the part of a file list that belongs to a stage

        The permutation is drawn on the first call and cached in
        self.random_indices, so that later calls (e.g. for the label list
        after the feature list) reorder their list identically.

        parameters:
        - filelist: list of file paths to split
        - stage: constants.VAL_DB or constants.TRAIN_DB
        returns:
        - the first folder_pct_val percent (rounded down) of the shuffled
        list for the validation stage, the remainder for the train stage
        raises:
        - ValueError if filelist does not have the same length as the
        cached permutation, or if stage is unknown
        """
        if self.random_indices is None:
            # random.shuffle() needs a mutable sequence
            self.random_indices = list(range(len(filelist)))
            random.shuffle(self.random_indices)
        elif len(filelist) != len(self.random_indices):
            raise ValueError(
                "Expect same number of images in folders (%d!=%d)"
                % (len(filelist), len(self.random_indices)))
        filelist = [filelist[idx] for idx in self.random_indices]
        pct_val = int(self.folder_pct_val)
        n_val_entries = int(math.floor(len(filelist) * pct_val / 100))
        if stage == constants.VAL_DB:
            return filelist[:n_val_entries]
        elif stage == constants.TRAIN_DB:
            return filelist[n_val_entries:]
        else:
            raise ValueError("Unknown stage: %s" % stage)
Loading