# File metadata recovered from a GitHub file-view header (page navigation text
# such as "Switch branches/tags" / "Find file Copy path" removed):
#   Last commit: 112e967 (Sep 17, 2018)
#   Contributors (2): @GeorgeSeif, @mrshu
#   Size: 340 lines (255 sloc), 13.6 KB
from __future__ import print_function

import argparse
import datetime
import math
import os
import random
import subprocess
import sys
import time

import cv2
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

# Use the 'Agg' backend so that plots can be generated even without an X server
# running. This must happen before pyplot (or anything importing it) is loaded.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from utils import utils, helpers
from builders import model_builder
def str2bool(v):
    """Parse a command-line value into a bool, for use as an argparse ``type=``.

    Accepts, case-insensitively, ``yes/true/t/y/1`` as True and
    ``no/false/f/n/0`` as False. A value that is already a ``bool`` is returned
    unchanged; argparse only applies ``type`` to string values, so a non-string
    default such as ``default=False`` may reach this function as a bool.

    Args:
        v: The raw argument value (normally a string).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
# Command-line options for training; parsed once at import time into `args`.
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs to train for')
parser.add_argument('--epoch_start_i', type=int, default=0, help='Start counting epochs from this number')
parser.add_argument('--checkpoint_step', type=int, default=5, help='How often to save checkpoints (epochs)')
parser.add_argument('--validation_step', type=int, default=1, help='How often to perform validation (epochs)')
# NOTE(review): --image is not read anywhere in the visible training code.
parser.add_argument('--image', type=str, default=None, help='The image you want to predict on. Only valid in "predict" mode.')
parser.add_argument('--continue_training', type=str2bool, default=False, help='Whether to continue training from a checkpoint')
parser.add_argument('--dataset', type=str, default="CamVid", help='Dataset you are using.')
parser.add_argument('--crop_height', type=int, default=512, help='Height of cropped input image to network')
parser.add_argument('--crop_width', type=int, default=512, help='Width of cropped input image to network')
parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch')
parser.add_argument('--num_val_images', type=int, default=20, help='The number of images to use for validation')
parser.add_argument('--h_flip', type=str2bool, default=False, help='Whether to randomly flip the image horizontally for data augmentation')
parser.add_argument('--v_flip', type=str2bool, default=False, help='Whether to randomly flip the image vertically for data augmentation')
# "%%" is needed because argparse %-formats help strings.
parser.add_argument('--brightness', type=float, default=None, help='Whether to randomly change the image brightness for data augmentation. Specifies the max brightness change as a factor between 0.0 and 1.0. For example, 0.1 represents a max brightness change of 10%% (+-).')
parser.add_argument('--rotation', type=float, default=None, help='Whether to randomly rotate the image for data augmentation. Specifies the max rotation angle in degrees.')
parser.add_argument('--model', type=str, default="FC-DenseNet56", help='The model you are using. See builders/model_builder.py for supported models')
parser.add_argument('--frontend', type=str, default="ResNet101", help='The frontend you are using. See the builders package for supported frontends')
args = parser.parse_args()
def data_augmentation(input_image, output_image):
    """Randomly crop and augment an (input image, label image) training pair.

    Driven by the module-level ``args``:
      * always: a random crop to ``args.crop_height`` x ``args.crop_width``
        via ``utils.random_crop``;
      * ``args.h_flip``: horizontal flip with probability 1/2;
      * ``args.v_flip``: vertical flip with probability 1/2;
      * ``args.brightness``: scale intensities of the input image only by a
        random factor in ``[1 - brightness, 1 + brightness]``;
      * ``args.rotation``: rotate both images about the centre by a random
        angle in ``[-rotation, rotation]`` degrees.

    Geometric transforms are applied identically to both images so the label
    stays aligned with the input.

    Args:
        input_image: Input image array (presumably 8-bit, as produced by
            ``utils.load_image`` -- TODO confirm).
        output_image: Label image array aligned with ``input_image``.

    Returns:
        tuple: ``(input_image, output_image)`` after augmentation.
    """
    input_image, output_image = utils.random_crop(input_image, output_image, args.crop_height, args.crop_width)

    if args.h_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 1)
        output_image = cv2.flip(output_image, 1)
    if args.v_flip and random.randint(0, 1):
        input_image = cv2.flip(input_image, 0)
        output_image = cv2.flip(output_image, 0)

    if args.brightness:
        factor = 1.0 + random.uniform(-1.0 * args.brightness, args.brightness)
        # Lookup table mapping each 8-bit intensity i to i * factor. Clip before
        # the uint8 cast so brightened pixels saturate at 255 instead of wrapping
        # around (255 * 1.1 = 280 would otherwise become 24). Multiplying i
        # directly also avoids the lossy i / 255 * 255 round trip.
        table = np.clip(np.arange(256) * factor, 0, 255).astype(np.uint8)
        input_image = cv2.LUT(input_image, table)

    if args.rotation:
        angle = random.uniform(-1 * args.rotation, args.rotation)
        height, width = input_image.shape[0], input_image.shape[1]
        M = cv2.getRotationMatrix2D((width // 2, height // 2), angle, 1.0)
        input_image = cv2.warpAffine(input_image, M, (width, height), flags=cv2.INTER_NEAREST)
        # Nearest-neighbour keeps label values exact (no blended class colours).
        output_image = cv2.warpAffine(output_image, M, (output_image.shape[1], output_image.shape[0]), flags=cv2.INTER_NEAREST)

    return input_image, output_image
# Get the names of the classes so we can record the evaluation results.
# `label_values` holds one entry per class and defines the number of output channels.
class_names_list, label_values = helpers.get_label_info(os.path.join(args.dataset, "class_dict.csv"))
# Comma-separated class names, used as the per-class column headers of the
# validation CSV. join() handles empty lists and duplicate names correctly,
# unlike comparing each name against the last one.
class_names_string = ", ".join(class_names_list)
num_classes = len(label_values)
# Session that allocates GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Compute your softmax cross entropy loss
net_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
net_output = tf.placeholder(tf.float32, shape=[None, None, None, num_classes])

network, init_fn = model_builder.build_model(model_name=args.model, frontend=args.frontend, net_input=net_input, num_classes=num_classes, crop_width=args.crop_width, crop_height=args.crop_height, is_training=True)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=network, labels=net_output))

opt = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.995).minimize(loss, var_list=tf.trainable_variables())

# The training loop saves a "latest" checkpoint plus a per-epoch one; Saver's
# default max_to_keep (5) would delete older per-epoch checkpoints.
saver = tf.train.Saver(max_to_keep=1000)
sess.run(tf.global_variables_initializer())

# If a pre-trained ResNet is required, load the weights.
# This must be done AFTER the variables are initialized with
# tf.global_variables_initializer(), otherwise the initializer would overwrite
# the pre-trained values. Presumably init_fn is a slim-style callable taking the
# session -- TODO confirm against model_builder.
if init_fn is not None:
    init_fn(sess)

# Load a previous checkpoint if desired
model_checkpoint_name = "checkpoints/latest_model_" + args.model + "_" + args.dataset + ".ckpt"
if args.continue_training:
    saver.restore(sess, model_checkpoint_name)
    print('Loaded latest model checkpoint')
# Load the file lists for the train / val / test splits of the dataset.
print("Loading the data ...")
(train_input_names, train_output_names,
 val_input_names, val_output_names,
 test_input_names, test_output_names) = utils.prepare_data(dataset_dir=args.dataset)

# Echo the run configuration so the log records what was trained.
print("\n***** Begin training *****")
for _label, _value in (
        ("Dataset -->", args.dataset),
        ("Model -->", args.model),
        ("Crop Height -->", args.crop_height),
        ("Crop Width -->", args.crop_width),
        ("Num Epochs -->", args.num_epochs),
        ("Batch Size -->", args.batch_size),
        ("Num Classes -->", num_classes),
):
    print(_label, _value)

print("Data Augmentation:")
for _label, _value in (
        ("\tVertical Flip -->", args.v_flip),
        ("\tHorizontal Flip -->", args.h_flip),
        ("\tBrightness Alteration -->", args.brightness),
        ("\tRotation -->", args.rotation),
):
    print(_label, _value)
del _label, _value
# Per-epoch history, plotted after every epoch.
avg_loss_per_epoch = []
avg_scores_per_epoch = []
avg_iou_per_epoch = []

# Which validation images do we want: at most --num_val_images of them
# (never negative, never more than exist).
num_vals = max(0, min(args.num_val_images, len(val_input_names)))

# Use a fixed (arbitrary) seed so models are validated on the same validation
# images, making results of different models directly comparable. A dedicated
# Random instance avoids reseeding the global generator that data augmentation
# draws from.
val_indices = random.Random(16).sample(range(len(val_input_names)), num_vals)
# Do the training here.
# For each epoch: shuffle the training set, train on mini-batches, save
# checkpoints, every --validation_step epochs evaluate the images listed in
# `val_indices`, then plot accuracy / loss / IoU versus epoch.
# NOTE(review): this section has lost its indentation and several statements
# (see the NOTE(review) comments below); it does not run as shown.
# TODO: restore it from version control.
for epoch in range(args.epoch_start_i, args.num_epochs):
current_losses = []
# Equivalent to shuffling: a random permutation of training-sample indices.
id_list = np.random.permutation(len(train_input_names))
# Number of full batches; the trailing partial batch is dropped.
num_iters = int(np.floor(len(id_list) / args.batch_size))
st = time.time()
for i in range(num_iters):
# st=time.time()
input_image_batch = []
output_image_batch = []
# Collect a batch of images
for j in range(args.batch_size):
index = i*args.batch_size + j
# NOTE(review): `id` shadows the builtin of the same name.
id = id_list[index]
input_image = utils.load_image(train_input_names[id])
output_image = utils.load_image(train_output_names[id])
# NOTE(review): tf.device only affects TensorFlow ops created inside it; the
# calls below appear to be NumPy/OpenCV, so this scope likely has no effect.
with tf.device('/cpu:0'):
input_image, output_image = data_augmentation(input_image, output_image)
# Prep the data. Scale pixels to [0, 1] and make sure the labels are in one-hot format.
input_image = np.float32(input_image) / 255.0
output_image = np.float32(helpers.one_hot_it(label=output_image, label_values=label_values))
input_image_batch.append(np.expand_dims(input_image, axis=0))
output_image_batch.append(np.expand_dims(output_image, axis=0))
if args.batch_size == 1:
input_image_batch = input_image_batch[0]
output_image_batch = output_image_batch[0]
# NOTE(review): presumably the two lines below belong to a missing `else:`
# branch of the batch_size check above. np.squeeze would also drop any other
# size-1 axis (e.g. a single class channel) -- TODO confirm.
input_image_batch = np.squeeze(np.stack(input_image_batch, axis=1))
output_image_batch = np.squeeze(np.stack(output_image_batch, axis=1))
# Do the training
# NOTE(review): the optimisation step is missing here. Presumably it was
# `_, current = sess.run([opt, loss], feed_dict={net_input: ..., net_output: ...})`
# followed by `current_losses.append(current)`. `cnt` is also never initialised
# in the visible code (presumably `cnt = 0` at the start of each epoch).
cnt = cnt + args.batch_size
if cnt % 20 == 0:
# NOTE(review): string_print is built but never printed or logged here.
string_print = "Epoch = %d Count = %d Current_Loss = %.4f Time = %.2f"%(epoch,cnt,current,time.time()-st)
st = time.time()
# NOTE(review): mean_loss is not appended to avg_loss_per_epoch, which is
# plotted below, in the visible code.
mean_loss = np.mean(current_losses)
# Create directories if needed
# NOTE(review): this `if` has no body here (presumably os.makedirs(...) of the same path).
if not os.path.isdir("%s/%04d"%("checkpoints",epoch)):
# Save latest checkpoint to same file name
# NOTE(review): the next line is garbled (unbalanced parenthesis); presumably it
# was a print followed by saver.save(sess, model_checkpoint_name).
print("Saving latest checkpoint"),model_checkpoint_name)
# NOTE(review): val_indices is a list, so `val_indices != 0` is always True and
# this reduces to the --checkpoint_step test.
if val_indices != 0 and epoch % args.checkpoint_step == 0:
# NOTE(review): garbled like the line above; presumably a print followed by
# saver.save(sess, "<checkpoints>/<epoch>/model.ckpt").
print("Saving checkpoint for this epoch"),"%s/%04d/model.ckpt"%("checkpoints",epoch))
if epoch % args.validation_step == 0:
print("Performing validation")
# NOTE(review): `target` is never opened (or closed) in the visible code;
# presumably a per-epoch CSV file under checkpoints/<epoch>/.
target.write("val_name, avg_accuracy, precision, recall, f1 score, mean iou, %s\n" % (class_names_string))
# Per-image metric accumulators for this validation pass.
scores_list = []
class_scores_list = []
precision_list = []
recall_list = []
f1_list = []
iou_list = []
# Do the validation on a small set of validation images.
# Each image and label is cut to its top-left crop_height x crop_width region.
for ind in val_indices:
input_image = np.expand_dims(np.float32(utils.load_image(val_input_names[ind])[:args.crop_height, :args.crop_width]),axis=0)/255.0
gt = utils.load_image(val_output_names[ind])[:args.crop_height, :args.crop_width]
# Convert the colour label image to per-pixel class indices.
gt = helpers.reverse_one_hot(helpers.one_hot_it(gt, label_values))
# st = time.time()
# NOTE(review): the right-hand side is missing; presumably
# `sess.run(network, feed_dict={net_input: input_image})`.
output_image =,feed_dict={net_input:input_image})
# Take the single image out of the batch and convert it to class indices.
output_image = np.array(output_image[0,:,:,:])
output_image = helpers.reverse_one_hot(output_image)
out_vis_image = helpers.colour_code_segmentation(output_image, label_values)
accuracy, class_accuracies, prec, rec, f1, iou = utils.evaluate_segmentation(pred=output_image, label=gt, num_classes=num_classes)
file_name = utils.filepath_to_name(val_input_names[ind])
# NOTE(review): unlike the header row, no "\n" is written after each data row
# in the visible code, so all rows would land on one line.
target.write("%s, %f, %f, %f, %f, %f"%(file_name, accuracy, prec, rec, f1, iou))
for item in class_accuracies:
target.write(", %f"%(item))
# Save colour-coded prediction and ground truth side by side for inspection.
gt = helpers.colour_code_segmentation(gt, label_values)
file_name = os.path.basename(val_input_names[ind])
file_name = os.path.splitext(file_name)[0]
cv2.imwrite("%s/%04d/%s_pred.png"%("checkpoints",epoch, file_name),cv2.cvtColor(np.uint8(out_vis_image), cv2.COLOR_RGB2BGR))
cv2.imwrite("%s/%04d/%s_gt.png"%("checkpoints",epoch, file_name),cv2.cvtColor(np.uint8(gt), cv2.COLOR_RGB2BGR))
# NOTE(review): the *_list accumulators are never appended to in the visible
# loop, so these means are taken over empty lists and are not meaningful.
# avg_score / avg_iou are also not appended to avg_scores_per_epoch /
# avg_iou_per_epoch, which are plotted below.
avg_score = np.mean(scores_list)
class_avg_scores = np.mean(class_scores_list, axis=0)
avg_precision = np.mean(precision_list)
avg_recall = np.mean(recall_list)
avg_f1 = np.mean(f1_list)
avg_iou = np.mean(iou_list)
print("\nAverage validation accuracy for epoch # %04d = %f"% (epoch, avg_score))
print("Average per class validation accuracies for epoch # %04d:"% (epoch))
for index, item in enumerate(class_avg_scores):
print("%s = %f" % (class_names_list[index], item))
print("Validation precision = ", avg_precision)
print("Validation recall = ", avg_recall)
print("Validation F1 score = ", avg_f1)
print("Validation IoU score = ", avg_iou)
# Convert the estimated remaining time (seconds) into hours / minutes / seconds.
# NOTE(review): remain_time is never computed in the visible code.
m, s = divmod(remain_time, 60)
h, m = divmod(m, 60)
# NOTE(review): the `else:` before the second assignment appears to be missing,
# so as shown train_time is always overwritten; it is also never printed here.
if s!=0:
train_time="Remaining training time = %d hours %d minutes %d seconds\n"%(h,m,s)
train_time="Remaining training time : Training completed.\n"
scores_list = []
# Progress plots over all epochs so far.
# NOTE(review): no savefig/close is visible, so each epoch opens three new
# figures that are neither written out nor released.
fig1, ax1 = plt.subplots(figsize=(11, 8))
ax1.plot(range(epoch+1), avg_scores_per_epoch)
ax1.set_title("Average validation accuracy vs epochs")
ax1.set_ylabel("Avg. val. accuracy")
fig2, ax2 = plt.subplots(figsize=(11, 8))
ax2.plot(range(epoch+1), avg_loss_per_epoch)
ax2.set_title("Average loss vs epochs")
ax2.set_ylabel("Current loss")
fig3, ax3 = plt.subplots(figsize=(11, 8))
ax3.plot(range(epoch+1), avg_iou_per_epoch)
ax3.set_title("Average IoU vs epochs")
ax3.set_ylabel("Current IoU")