Skip to content

Commit

Permalink
Introduce label conversion and TFRecords encoding. Introduce independ…
Browse files Browse the repository at this point in the history
…ent train and test files. Organize code into modules.
  • Loading branch information
andreaazzini committed Dec 14, 2016
1 parent 5685e01 commit d0619a4
Show file tree
Hide file tree
Showing 20 changed files with 579 additions and 302 deletions.
6 changes: 2 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
__pycache__/
logs/
input/old/
input/raw/
input/shuffled/
ckpts/
input/
logs/

*.pyc
vgg16_weights.npz
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# SegNet

SegNet is a TensorFlow implementation of the [segmentation network proposed by Kendall et al.](http://mi.eng.cam.ac.uk/projects/segnet/).

## Configuration

Before running, download the [VGG16 weights file](https://www.cs.toronto.edu/~frossard/vgg16/vgg16_weights.npz)
and save it as `input/vgg16_weights.npz`.
and save it as `input/vgg16_weights.npz` if you want to initialize the encoder weights with the VGG16 ones trained on ImageNet classification dataset.

In `config.py`, choose your working dataset. The dataset name needs to match the data directories you create in your `input` folder.
You can use `segnet-32` and `segnet-13` to replicate the aforementioned Kendall et al. experiments.

## Train and test

Train SegNet with `python segnet.py`.
Train SegNet with `python -m src/train.py`. Analogously, test it with `python -m src/test.py`.
Empty file added __init__.py
Empty file.
56 changes: 0 additions & 56 deletions classifier.py

This file was deleted.

1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
working_dataset = 'single-coke'
Binary file removed input/train.tfrecords
Binary file not shown.
Binary file removed input/train_labels.tfrecords
Binary file not shown.
172 changes: 0 additions & 172 deletions segnet.py

This file was deleted.

Empty file added src/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions src/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import config
import tensorflow as tf
import utils

colors = tf.cast(tf.pack(utils.colors_of_dataset(config.working_dataset)), tf.float32) / 255

def color_mask(tensor, color):
return tf.reduce_all(tf.equal(tensor, color), 3)

def one_hot(labels):
color_tensors = tf.unstack(colors)
channel_tensors = list(map(lambda color: color_mask(labels, color), color_tensors))
one_hot_labels = tf.cast(tf.stack(channel_tensors, 3), 'float32')
return one_hot_labels

def rgb(logits):
softmax = tf.nn.softmax(logits)
argmax = tf.argmax(softmax, 3)
n = colors.get_shape().as_list()[0]
one_hot = tf.one_hot(argmax, n, dtype=tf.float32)
one_hot_matrix = tf.reshape(one_hot, [-1, n])
rgb_matrix = tf.matmul(one_hot_matrix, colors)
rgb_tensor = tf.reshape(rgb_matrix, [-1, 224, 224, 3])
return tf.cast(rgb_tensor, tf.float32)
4 changes: 2 additions & 2 deletions convnet.py → src/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def deconv(x, receptive_field_shape, channels_shape, stride, name):
kernel_shape = receptive_field_shape + channels_shape
bias_shape = [channels_shape[0]]

input_shape = tf.shape(x)
input_shape = x.get_shape().as_list()
batch_size = input_shape[0]
height = input_shape[1]
width = input_shape[2]
Expand All @@ -26,4 +26,4 @@ def deconv(x, receptive_field_shape, channels_shape, stride, name):
return tf.nn.relu(tf.contrib.layers.batch_norm(conv_bias))

def max_pool(x, size, stride, padding='SAME'):
return tf.nn.max_pool_with_argmax(x, ksize=[1, size, size, 1], strides=[1, stride, stride, 1], padding=padding, name='maxpool')
return tf.nn.max_pool(x, ksize=[1, size, size, 1], strides=[1, stride, stride, 1], padding=padding, name='maxpool')
File renamed without changes.
18 changes: 11 additions & 7 deletions inputs.py → src/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def read_and_decode_single_example(filename):
image.set_shape([224, 224, 3])
return image

def inputs(train_filename, train_labels_filename, batch_size):
def inputs(batch_size, train_filename, train_labels_filename=None):
image = read_and_decode_single_example(train_filename)
label = read_and_decode_single_example(train_labels_filename)
images_batch, labels_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size,
capacity=2000,
min_after_dequeue=1000)
return images_batch, labels_batch
if train_labels_filename:
label = read_and_decode_single_example(train_labels_filename)
images_batch, labels_batch = tf.train.shuffle_batch(
[image, label],
batch_size=batch_size,
capacity=2000,
min_after_dequeue=1000)
return images_batch, labels_batch
else:
return tf.train.shuffle_batch([image], batch_size=batch_size, capacity=2000, min_after_dequeue=1000)
Loading

0 comments on commit d0619a4

Please sign in to comment.