Fixed data augmentation to work with inception v3 and v4
Steven-N-Hart committed Oct 5, 2017
1 parent c31e97e commit 88a7d9d
Showing 7 changed files with 155 additions and 31 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -1,3 +1,2 @@
-*/*/*/*/*/*/*/*/__pycache__/
+*/__pycache__/
checkpoints/
-nets/
33 changes: 32 additions & 1 deletion README.md
@@ -54,7 +54,7 @@ mv inception_v3_2016_08_28.tar.gz checkpoints/
5. Run the pretrained model on the SPITZ dataset.
```
DATASET_DIR=/data/images/
-TRAIN_DIR=/tmp/train_logs
+TRAIN_DIR=/tmp/from_checkpoint
python scripts/train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_name=spitz \
@@ -65,4 +65,35 @@ python scripts/train_image_classifier.py \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--preprocessing_name spitz
```
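In the command above, `--checkpoint_exclude_scopes` prevents the ImageNet logits weights from being restored from the checkpoint, and `--trainable_scopes` restricts training to those same logits layers, so only the new classifier head is fine-tuned.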

6. Run the naive (from-scratch) model on the SPITZ dataset, sweeping model architecture, optimizer, learning rate, and data augmentation (see the note after this block).
```
TRAIN_DIR=/tmp/from_scratch
for DA in '--DA_flag' ''
do
for model_name in inception_v4 inception_v3
do
for optimizer in adadelta adagrad adam ftrl momentum sgd rmsprop
do
for lr in 0.01 0.05 0.001
do
time python scripts/train_image_classifier.py \
--train_dir=${TRAIN_DIR}/${model_name}_${DA//-/}_${optimizer}_${lr} \
--dataset_name=spitz \
--train_image_size 299 \
--dataset_split_name=train \
--dataset_dir=${DATASET_DIR} \
--model_name=${model_name} \
--preprocessing_name spitz \
--log_every_n_steps 1000 \
--num_clones 4 \
--optimizer ${optimizer} \
--max_number_of_steps 20000 \
${DA}
echo ${model_name}_${DA//-/}_${optimizer}_${lr}
done
done
done
done
```
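The `${DA//-/}` expansion strips the hyphens from the flag string, so each run lands in its own directory such as `/tmp/from_scratch/inception_v3_DA_flag_adam_0.01`. The nested loops sweep 2 models × 7 optimizers × 3 learning rates, each with and without data augmentation, for 84 training runs in total.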
2 changes: 1 addition & 1 deletion nets/nets_factory.py
@@ -49,7 +49,7 @@
}


-def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
+def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False, depth_multiplier=1.0, dropout_keep_prob=0.8):
  """Returns a network_fn such as `logits, end_points = network_fn(images)`.

  Args:
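The body of `get_network_fn` is elided from this diff. A hypothetical sketch of how the two new keyword arguments could be threaded through to the underlying slim network constructor (the `networks_map`/`arg_scopes_map` names follow the surrounding slim code and are assumptions here, as is the premise that the selected net accepts both keywords, which the `val_exist()` check added in train_image_classifier.py is meant to enforce):

```
def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False,
                   depth_multiplier=1.0, dropout_keep_prob=0.8):
  func = networks_map[name]  # assumed lookup table of net constructors
  def network_fn(images):
    # Build the arg scope as before, then forward the new tuning knobs.
    arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training,
                  depth_multiplier=depth_multiplier,
                  dropout_keep_prob=dropout_keep_prob)
  network_fn.default_image_size = func.default_image_size
  return network_fn
```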
4 changes: 2 additions & 2 deletions preprocessing/preprocessing_factory.py
@@ -50,8 +50,8 @@ def get_preprocessing(name, is_training=False):
  if name not in preprocessing_fn_map:
    raise ValueError('Preprocessing name [%s] was not recognized' % name)

-  def preprocessing_fn(image, output_height, output_width, **kwargs):
+  def preprocessing_fn(image, output_height=299, output_width=299):
    return preprocessing_fn_map[name].preprocess_image(
-        image, output_height, output_width, is_training=is_training, **kwargs)
+        image, output_height, output_width)

  return preprocessing_fn
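A usage sketch under the new signature (a minimal illustration, not part of the commit): the primary preprocessing is still called with explicit sizes, while the secondary preprocessing can now be called with the image alone, falling back to the 299×299 Inception default:

```
spitz_fn = preprocessing_factory.get_preprocessing('spitz', is_training=True)
image = spitz_fn(image, 299, 299)            # explicit output size

sec_fn = preprocessing_factory.get_preprocessing('sec_preprocessing')
image_list, num_child_image = sec_fn(image)  # output_height/width default to 299
```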
5 changes: 3 additions & 2 deletions preprocessing/sec_preprocessing.py
@@ -10,7 +10,8 @@
slim = tf.contrib.slim


-def preprocess_image(image):
+def preprocess_image(image, height, width,
+                     is_training=False):
  """Data augmentation. Produce child images from the same input image.

  Args:
@@ -34,5 +35,5 @@ def preprocess_image(image):
  num_child_image = len(angle) + 1
  for ang in angle:
    image_single = tf.contrib.image.rotate(image, ang)
-    image_list.append(image_single)
+    image_list = tf.concat([image_list, [image_single]], 0)
  return (image_list, num_child_image)
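Most of `preprocess_image` is elided above. A minimal sketch of the whole augmentation routine under assumed rotation angles (the actual `angle` list and the initialization of `image_list` are hidden in the collapsed lines; the values below are assumptions):

```
import math
import tensorflow as tf

def preprocess_image(image, height, width, is_training=False):
  """Expand one image into itself plus one rotated copy per angle."""
  angle = [math.pi / 2, math.pi, 3 * math.pi / 2]  # assumed angles, in radians
  image_list = tf.expand_dims(image, 0)            # start with the original, [1, H, W, C]
  num_child_image = len(angle) + 1
  for ang in angle:
    image_single = tf.contrib.image.rotate(image, ang)
    image_list = tf.concat([image_list, [image_single]], 0)
  return (image_list, num_child_image)
```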
2 changes: 1 addition & 1 deletion preprocessing/spitz_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def preprocess_image(image, height, width,
Raises:
ValueError: if user does not provide bounding box
"""
print('Preprcessing image',image,height,width)
#print('Preprcessing image',image,height,width)
if is_training:
return preprocess_for_train(image, height, width)
else:
Expand Down
137 changes: 115 additions & 22 deletions scripts/train_image_classifier.py
@@ -202,6 +202,25 @@
tf.app.flags.DEFINE_integer('max_number_of_steps', None,
'The maximum number of training steps.')

+tf.app.flags.DEFINE_string(
+    'sec_preprocessing_name', 'sec_preprocessing', 'The name of the secondary '
+    'preprocessing to use, which expands each image into a list of child '
+    'images. If left as `None`, no image expansion is used.')
+
+tf.app.flags.DEFINE_boolean(
+    'DA_flag', False, 'Data augmentation option. If True, call the '
+    '"sec_preprocessing" function and average the logits over the child images.')
+
+tf.app.flags.DEFINE_float(
+    'dropout_keep_prob', 0.8,
+    'Probability of keeping a node in the network.')
+
+tf.app.flags.DEFINE_float(
+    'depth_multiplier', 1.0,
+    'Float that scales the number of channels in each layer.')
+
+list_with_multiplier = ['inception_v2', 'inception_v3', 'inception_v4', 'inception_resnet_v2']
+list_with_dropout = ['inception_v2', 'inception_resnet_v2', 'inception_v3', 'inception_v4']
+
#####################
# Fine-Tuning Flags #
#####################
@@ -388,21 +407,51 @@ def _get_variables_to_train():
    variables_to_train.extend(variables)
  return variables_to_train

-def _record_accuracy(predictions, labels):
-  names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
-      'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
-  })
-
-  # Print the summaries to screen.
-  for name, value in names_to_values.items():
-    summary_name = 'eval/%s' % name
-    op = tf.summary.scalar(summary_name, value, collections=[])
-    op = tf.Print(op, [value], summary_name)
-    tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
-  return True
+def val_exist():
+  """Verify that depth_multiplier and dropout_keep_prob are accepted by the
+  chosen architecture.
+  """
+  if FLAGS.model_name not in list_with_multiplier:
+    raise ValueError('The chosen architecture does not support the "depth_multiplier" parameter.')
+  if FLAGS.model_name not in list_with_dropout:
+    raise ValueError('The chosen architecture does not support the "dropout_keep_prob" parameter.')
+  return
+
+def val_range():
+  """Verify that the depth_multiplier and dropout_keep_prob values are in the
+  range [0.0, 1.0].
+  """
+  if FLAGS.depth_multiplier < 0.0 or FLAGS.depth_multiplier > 1.0:
+    raise ValueError('The depth_multiplier value must be in the range [0.0, 1.0].')
+  if FLAGS.dropout_keep_prob < 0.0 or FLAGS.dropout_keep_prob > 1.0:
+    raise ValueError('The dropout_keep_prob value must be in the range [0.0, 1.0].')
+  return
+
+def average_logits(logits, num_child_image, batch_size=FLAGS.batch_size):
+  """Calculate the average of multiple logits. Useful when multiple rotations
+  of the same image are used for a single prediction.
+  """
+  logits_list = []
+  n = 0
+  while n < batch_size / num_child_image:
+    logits_temp = tf.reduce_mean(logits[n*num_child_image:(n+1)*num_child_image], axis=0, keep_dims=True)
+    logits_list.append(logits_temp)
+    n = n + 1
+  logits = tf.concat(logits_list, 0)
+  return logits
+
+def labels_average_logits(labels, num_child_image, batch_size=FLAGS.batch_size):
+  """When multiple rotations of the same image are used for a single prediction,
+  the labels are duplicated as well. This step collapses them back to one label
+  per original image.
+  """
+  label_list = []
+  n = 0
+  while n < batch_size / num_child_image:
+    label_single = labels[n*num_child_image]
+    label_list.append(label_single)
+    n = n + 1
+  labels = tf.concat(label_list, 0)
+  return labels

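A toy illustration (not from the commit) of what `average_logits` computes, assuming the function above is in scope: with `batch_size=4` and `num_child_image=2`, each image's rotated copies are averaged back into a single row of logits:

```
import tensorflow as tf

logits = tf.constant([[2.0, 0.0],    # image A, rotation 0
                      [0.0, 2.0],    # image A, rotation 1
                      [1.0, 1.0],    # image B, rotation 0
                      [3.0, 1.0]])   # image B, rotation 1
averaged = average_logits(logits, num_child_image=2, batch_size=4)
with tf.Session() as sess:
  print(sess.run(averaged))          # [[1. 1.] [2. 1.]]
```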
def _pass():
return True

def main(_):
if not FLAGS.dataset_dir:
@@ -435,10 +484,15 @@ def main(_):
######################
# Select the network #
######################
+    val_exist()
+    val_range()
+
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
+        depth_multiplier=FLAGS.depth_multiplier,
+        dropout_keep_prob=FLAGS.dropout_keep_prob,
        is_training=True)

#####################################
@@ -463,37 +517,78 @@

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size
      image = image_preprocessing_fn(image, train_image_size, train_image_size)
+      #print('Original label shape: {}'.format(label))
+      if FLAGS.DA_flag:  # Data augmentation
+        enqueue_many = True
+        sec_preprocessing_name = FLAGS.sec_preprocessing_name
+        image_sec_preprocessing_fn = preprocessing_factory.get_preprocessing(
+            sec_preprocessing_name)
+        image, num_child_image = image_sec_preprocessing_fn(image)
+        #print('num_child_image: {}'.format(num_child_image))
+        label = [label for i in range(num_child_image)]
+        label = tf.reshape(label, [-1])
+        #print('label: {}'.format(label))
+        #print('image: {}'.format(image))
+      else:
+        enqueue_many = False
+
      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)

+      if FLAGS.DA_flag:
+        # tf.train.batch returns shape [batch_size, num_child_image, height, width, channels];
+        # the network expects [batch_size * num_child_image, height, width, channels].
+        labels = tf.reshape(labels, [-1])
+        labels = tf.reshape(labels, [FLAGS.batch_size * num_child_image, 1])
+
+        images = tf.reshape(images, [-1])
+        images = tf.reshape(images, [FLAGS.batch_size * num_child_image, train_image_size, train_image_size, 3])
+        #print('Batch image shape {}'.format(images))
+        #print('Batch labels shape {}'.format(labels))

      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)
+      #print('Completed model deployment')


####################
# Define the model #
####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
+      #print('Clone function images shape: {}'.format(images))
+      #print('Clone function labels shape: {}'.format(labels))
      logits, end_points = network_fn(images)
+      tf.logging.info('Post network')
+      #print('Clone function logits shape: {}'.format(logits))
+      #print('Clone function end_points shape: {}'.format(end_points))
+      AuxLogits = end_points['AuxLogits']
+      #print('cloning function')
+      if FLAGS.DA_flag:  # Score multiple rotations of the same image
+        logits = average_logits(logits, num_child_image)
+        AuxLogits = average_logits(AuxLogits, num_child_image)
+        labels = labels_average_logits(labels, num_child_image)

#############################
# Specify the loss function #
#############################

      if 'AuxLogits' in end_points:
+        #print('Specifying the loss function for Aux logits {}\n{}'.format(labels, AuxLogits))
        tf.losses.softmax_cross_entropy(
-            labels, end_points['AuxLogits'],
+            labels, AuxLogits,
            label_smoothing=FLAGS.label_smoothing, weights=0.4,
            scope='aux_loss')
+      #print('Specifying the loss function {}\n{}'.format(labels, logits))
      tf.losses.softmax_cross_entropy(
          labels, logits, label_smoothing=FLAGS.label_smoothing, weights=1.0)


#############################
## Calculation of accuracy ##
#############################
@@ -514,11 +609,11 @@ def clone_fn(batch_queue):
    tf.logging.info('Updating ops')
    # Add summaries for end_points.
    end_points = clones[0].outputs
-    for end_point in end_points:
-      x = end_points[end_point]
-      summaries.add(tf.summary.histogram('activations/' + end_point, x))
-      summaries.add(tf.summary.scalar('sparsity/' + end_point,
-                                      tf.nn.zero_fraction(x)))
+    #for end_point in end_points:
+    #  x = end_points[end_point]
+    #  summaries.add(tf.summary.histogram('activations/' + end_point, x))
+    #  summaries.add(tf.summary.scalar('sparsity/' + end_point,
+    #                                  tf.nn.zero_fraction(x)))

# Add summaries for losses.
for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
Expand All @@ -537,11 +632,9 @@ def clone_fn(batch_queue):
    # Stack and take the mean.
    accuracy = tf.reduce_mean(tf.stack(accuracy, axis=0))

-
    # Add summaries for accuracy.
    summaries.add(tf.summary.scalar('accuracy/training', accuracy))

-
    #################################
    # Configure the moving averages #
    #################################
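Taken together, the data-augmentation path is a sequence of reshapes around the batch queue. A toy NumPy trace (illustrative only; the batch size, child count, and class count below are assumptions) of the shapes involved:

```
import numpy as np

batch_size, num_child, size = 8, 4, 299
# After tf.train.batch with DA_flag on:
images = np.zeros([batch_size, num_child, size, size, 3])
labels = np.zeros([batch_size, num_child], dtype=np.int64)
# The reshapes in main() flatten children into the batch dimension:
images = images.reshape([batch_size * num_child, size, size, 3])  # fed to network_fn
labels = labels.reshape([batch_size * num_child, 1])
# clone_fn then collapses logits back to one row per original image:
logits = np.zeros([batch_size * num_child, 5])                    # e.g. 5 classes
averaged = logits.reshape([batch_size, num_child, -1]).mean(axis=1)
print(averaged.shape)                                             # (8, 5)
```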
