diff --git a/License.txt b/License.txt new file mode 100644 index 0000000..7adbac8 --- /dev/null +++ b/License.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 Sully Chen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7b4a661..e40320b 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,18 @@ -# Nvidia-Autopilot-TensorFlow -A TensorFlow implementation of this [Nvidia paper](https://arxiv.org/pdf/1604.07316.pdf) with some changes. +# Autopilot-TensorFlow +A TensorFlow implementation of this [Nvidia paper](https://arxiv.org/pdf/1604.07316.pdf) with some changes. For a summary of the design process and FAQs, see [this medium article I wrote](https://medium.com/@sullyfchen/how-a-high-school-junior-made-a-self-driving-car-705fa9b6e860). 
-#How to Use -Download the [dataset](https://drive.google.com/file/d/0B-KJCaaF7ellem5pSVM2NTNQcDg/view?usp=sharing) and extract into the repository folder +# IMPORTANT +Absolutely, under NO circumstance, should one ever pilot a car using computer vision software trained with this code (or any homemade software for that matter). It is extremely dangerous to use your own self-driving software in a car, even if you think you know what you're doing, not to mention it is quite illegal in most places and any accidents will land you in huge lawsuits. + +This code is purely for research and statistics, absolutely NOT for application or testing of any sort. + +# How to Use +Download the [dataset](https://github.com/SullyChen/driving-datasets) and extract into the repository folder Use `python train.py` to train the model Use `python run.py` to run the model on a live webcam feed Use `python run_dataset.py` to run the model on the dataset + +To visualize training using TensorBoard use `tensorboard --logdir=./logs`, then open http://0.0.0.0:6006/ in your web browser. 
diff --git a/driving_data.py b/driving_data.py index cdc8f8c..97fbb73 100644 --- a/driving_data.py +++ b/driving_data.py @@ -1,11 +1,13 @@ -import scipy.misc +import cv2 import random +import numpy as np xs = [] ys = [] #points to the end of the last batch -batch_pointer = 0 +train_batch_pointer = 0 +val_batch_pointer = 0 #read data.txt with open("driving_dataset/data.txt") as f: @@ -14,7 +16,7 @@ #the paper by Nvidia uses the inverse of the turning radius, #but steering wheel angle is proportional to the inverse of turning radius #so the steering wheel angle in radians is used as the output - ys.append(float(line.split()[1]) * scipy.pi / 180) + ys.append(float(line.split()[1]) * np.pi / 180) #get number of images num_images = len(xs) @@ -24,12 +26,35 @@ random.shuffle(c) xs, ys = zip(*c) -def LoadBatch(batch_size): - global batch_pointer +#split at a single index so the training and validation sets are disjoint and no sample is dropped +split_index = int(len(xs) * 0.8) +train_xs = xs[:split_index] +train_ys = ys[:split_index] + +val_xs = xs[split_index:] +val_ys = ys[split_index:] + +num_train_images = len(train_xs) +num_val_images = len(val_xs) + +#returns the next batch_size training images (last 150 rows, resized to 200x66, scaled by 1/255) and their angles, wrapping around the training set +def LoadTrainBatch(batch_size): + global train_batch_pointer + x_out = [] + y_out = [] + for i in range(0, batch_size): + x_out.append(cv2.resize(cv2.imread(train_xs[(train_batch_pointer + i) % num_train_images])[-150:], (200, 66)) / 255.0) + y_out.append([train_ys[(train_batch_pointer + i) % num_train_images]]) + train_batch_pointer += batch_size + return x_out, y_out + +#returns the next batch_size validation images (same preprocessing as the training batches) and their angles, wrapping around the validation set +def LoadValBatch(batch_size): + global val_batch_pointer x_out = [] y_out = [] for i in range(0, batch_size): - x_out.append(scipy.misc.imresize(scipy.misc.imread(xs[(batch_pointer + i) % num_images])[-150:], [66, 200]) / 255.0) - y_out.append([ys[(batch_pointer + i) % num_images]]) - batch_pointer += batch_size + x_out.append(cv2.resize(cv2.imread(val_xs[(val_batch_pointer + i) % num_val_images])[-150:], (200, 66)) / 255.0) + y_out.append([val_ys[(val_batch_pointer + i) % num_val_images]]) + val_batch_pointer += batch_size return x_out, y_out diff 
--git a/model.py b/model.py index 0bc4576..4847174 100644 --- a/model.py +++ b/model.py @@ -1,4 +1,5 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() import scipy def weight_variable(shape): @@ -85,4 +86,4 @@ def conv2d(x, W, stride): W_fc5 = weight_variable([10, 1]) b_fc5 = bias_variable([1]) -y = tf.mul(tf.atan(tf.matmul(h_fc4_drop, W_fc5) + b_fc5), 2) #scale the atan output +y = tf.multiply(tf.atan(tf.matmul(h_fc4_drop, W_fc5) + b_fc5), 2) #scale the atan output diff --git a/run.py b/run.py index 0fe20ac..9212ecb 100644 --- a/run.py +++ b/run.py @@ -1,8 +1,14 @@ -import tensorflow as tf -import scipy.misc +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() import model import cv2 from subprocess import call +import os + +#check if on windows OS +windows = False +if os.name == 'nt': + windows = True sess = tf.InteractiveSession() saver = tf.train.Saver() @@ -11,15 +17,25 @@ img = cv2.imread('steering_wheel_image.jpg',0) rows,cols = img.shape +smoothed_angle = 0 + cap = cv2.VideoCapture(0) while(cv2.waitKey(10) != ord('q')): ret, frame = cap.read() - image = scipy.misc.imresize(frame, [66, 200]) / 255.0 - degrees = model.y.eval(feed_dict={model.x: [image], model.keep_prob: 1.0})[0][0] * 180 / scipy.pi - call("clear") + #stop when the camera delivers no frame instead of handing an empty image to cv2.resize + if not ret: + break + image = cv2.resize(frame, (200, 66)) / 255.0 + degrees = model.y.eval(feed_dict={model.x: [image], model.keep_prob: 1.0})[0][0] * 180 / 3.14159265 + if not windows: + call("clear") print("Predicted steering angle: " + str(degrees) + " degrees") cv2.imshow('frame', frame) - M = cv2.getRotationMatrix2D((cols/2,rows/2),-degrees,1) + #make smooth angle transitions by turning the steering wheel based on the difference of the current angle + #and the predicted angle; skipped when both are equal, where the sign term below would be 0/0 + if degrees != smoothed_angle: + smoothed_angle += 0.2 * pow(abs((degrees - smoothed_angle)), 2.0 / 3.0) * (degrees - smoothed_angle) / abs(degrees - smoothed_angle) + M = cv2.getRotationMatrix2D((cols/2,rows/2),-smoothed_angle,1) dst = cv2.warpAffine(img,M,(cols,rows)) 
cv2.imshow("steering wheel", dst) diff --git a/run_dataset.py b/run_dataset.py index 4ba2abd..276311e 100644 --- a/run_dataset.py +++ b/run_dataset.py @@ -1,8 +1,14 @@ -import tensorflow as tf -import scipy.misc +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() import model import cv2 from subprocess import call +import os + +#check if on windows OS +windows = False +if os.name == 'nt': + windows = True sess = tf.InteractiveSession() saver = tf.train.Saver() @@ -11,14 +17,25 @@ img = cv2.imread('steering_wheel_image.jpg',0) rows,cols = img.shape +smoothed_angle = 0 + i = 0 while(cv2.waitKey(10) != ord('q')): - image = scipy.misc.imresize(scipy.misc.imread("driving_dataset/" + str(i) + ".jpg")[-150:], [66, 200]) / 255.0 - degrees = model.y.eval(feed_dict={model.x: [image], model.keep_prob: 1.0})[0][0] * 180.0 / scipy.pi - call("clear") + full_image = cv2.imread("driving_dataset/" + str(i) + ".jpg") + #cv2.imread returns None when the file cannot be read, e.g. past the last frame of the dataset + if full_image is None: + break + image = cv2.resize(full_image[-150:], (200, 66)) / 255.0 + degrees = model.y.eval(feed_dict={model.x: [image], model.keep_prob: 1.0})[0][0] * 180.0 / 3.14159265 + if not windows: + call("clear") print("Predicted steering angle: " + str(degrees) + " degrees") - cv2.imshow("frame", image) - M = cv2.getRotationMatrix2D((cols/2,rows/2),-degrees,1) + cv2.imshow("frame", full_image) + #make smooth angle transitions by turning the steering wheel based on the difference of the current angle + #and the predicted angle; skipped when both are equal, where the sign term below would be 0/0 + if degrees != smoothed_angle: + smoothed_angle += 0.2 * pow(abs((degrees - smoothed_angle)), 2.0 / 3.0) * (degrees - smoothed_angle) / abs(degrees - smoothed_angle) + M = cv2.getRotationMatrix2D((cols/2,rows/2),-smoothed_angle,1) dst = cv2.warpAffine(img,M,(cols,rows)) cv2.imshow("steering wheel", dst) i += 1 diff --git a/save/checkpoint b/save/checkpoint new file mode 100644 index 0000000..febd7d5 --- /dev/null +++ b/save/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "model.ckpt" +all_model_checkpoint_paths: "model.ckpt" diff --git a/save/model.ckpt b/save/model.ckpt new file 
mode 100644 index 0000000..ebf7cb2 Binary files /dev/null and b/save/model.ckpt differ diff --git a/save/model.ckpt.meta b/save/model.ckpt.meta new file mode 100644 index 0000000..0ef5016 Binary files /dev/null and b/save/model.ckpt.meta differ diff --git a/steering_wheel_image.jpg b/steering_wheel_image.jpg index b236679..e6c6505 100755 Binary files a/steering_wheel_image.jpg and b/steering_wheel_image.jpg differ diff --git a/train.py b/train.py index 8aefeb8..2a6d1f9 100644 --- a/train.py +++ b/train.py @@ -1,22 +1,58 @@ -import tensorflow as tf +import os +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() +from tensorflow.core.protobuf import saver_pb2 import driving_data import model +LOGDIR = './save' + sess = tf.InteractiveSession() -loss = tf.reduce_mean(tf.square(tf.sub(model.y_, model.y))) -train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss) -sess.run(tf.initialize_all_variables()) - -saver = tf.train.Saver() - -#train over the dataset about 30 times -for i in range(int(driving_data.num_images * 0.3)): - xs, ys = driving_data.LoadBatch(100) - train_step.run(feed_dict={model.x: xs, model.y_: ys, model.keep_prob: 0.8}) - if i % 10 == 0: - print("step %d, train loss %g"%(i, loss.eval(feed_dict={ - model.x:xs, model.y_: ys, model.keep_prob: 1.0}))) - if i % 100 == 0: - save_path = saver.save(sess, "save/model.ckpt") - print("Model saved in file: %s" % save_path) +L2NormConst = 0.001 + +train_vars = tf.trainable_variables() + +loss = tf.reduce_mean(tf.square(tf.subtract(model.y_, model.y))) + tf.add_n([tf.nn.l2_loss(v) for v in train_vars]) * L2NormConst +train_step = tf.train.AdamOptimizer(1e-4).minimize(loss) +sess.run(tf.global_variables_initializer()) + +# create a summary to monitor cost tensor +tf.summary.scalar("loss", loss) +# merge all summaries into a single op +merged_summary_op = tf.summary.merge_all() + +saver = tf.train.Saver(write_version = saver_pb2.SaverDef.V2) + +# op to write logs to Tensorboard +logs_path = './logs' 
+summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph()) + +epochs = 30 +batch_size = 100 +steps_per_epoch = int(driving_data.num_images/batch_size) # iterations per epoch; also the global-step stride between epochs + +# train over the dataset about 30 times +for epoch in range(epochs): + for i in range(steps_per_epoch): + xs, ys = driving_data.LoadTrainBatch(batch_size) + train_step.run(feed_dict={model.x: xs, model.y_: ys, model.keep_prob: 0.8}) + if i % 10 == 0: + xs, ys = driving_data.LoadValBatch(batch_size) + loss_value = loss.eval(feed_dict={model.x:xs, model.y_: ys, model.keep_prob: 1.0}) + print("Epoch: %d, Step: %d, Loss: %g" % (epoch, epoch * steps_per_epoch + i, loss_value)) + + # write logs at every iteration + summary = merged_summary_op.eval(feed_dict={model.x:xs, model.y_: ys, model.keep_prob: 1.0}) + summary_writer.add_summary(summary, epoch * steps_per_epoch + i) + + if i % batch_size == 0: + if not os.path.exists(LOGDIR): + os.makedirs(LOGDIR) + checkpoint_path = os.path.join(LOGDIR, "model.ckpt") + filename = saver.save(sess, checkpoint_path) + print("Model saved in file: %s" % filename) + +print("Run the command line:\n" \ + "--> tensorboard --logdir=./logs " \ + "\nThen open http://0.0.0.0:6006/ into your web browser")