# aymericdamien/TensorFlow-Examples

Refactor nearest_neighbor for TF1.0

`Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>`
• Loading branch information...
normanheckscher committed Jan 14, 2017
1 parent 7839ba2 commit 8e0382318172f1ff7fdccf998fd56bde00f9be8d
Showing with 25 additions and 16 deletions.
1. +2 −2 examples/2_BasicModels/nearest_neighbor.py
2. +23 −14 notebooks/2_BasicModels/nearest_neighbor.ipynb
 @@ -26,14 +26,14 @@ # Nearest Neighbor calculation using L1 Distance # Calculate L1 Distance distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1) distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1) # Prediction: Get min distance index (Nearest neighbor) pred = tf.arg_min(distance, 0) accuracy = 0. # Initializing the variables init = tf.initialize_all_variables() init = tf.global_variables_initializer() # Launch the graph with tf.Session() as sess:
 @@ -18,7 +18,7 @@ }, { "cell_type": "code", "execution_count": 2, "execution_count": 1, "metadata": { "collapsed": false }, @@ -27,10 +27,10 @@ "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/data/train-images-idx3-ubyte.gz\n", "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n", "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n", "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n" "Extracting MNIST_data/train-images-idx3-ubyte.gz\n", "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n" ] } ], @@ -40,14 +40,14 @@ "\n", "# Import MINST data\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)" "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)" ] }, { "cell_type": "code", "execution_count": 3, "execution_count": 2, "metadata": { "collapsed": true "collapsed": false }, "outputs": [], "source": [ @@ -61,19 +61,19 @@ "\n", "# Nearest Neighbor calculation using L1 Distance\n", "# Calculate L1 Distance\n", "distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n", "distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)\n", "# Prediction: Get min distance index (Nearest neighbor)\n", "pred = tf.arg_min(distance, 0)\n", "\n", "accuracy = 0.\n", "\n", "# Initializing the variables\n", "init = tf.initialize_all_variables()" "init = tf.global_variables_initializer()" ] }, { "cell_type": "code", "execution_count": 4, "execution_count": 3, "metadata": { "collapsed": false }, @@ -305,6 +305,15 @@ " print \"Done!\"\n", " print \"Accuracy:\", accuracy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { @@ -316,16 +325,16 @@ "language_info": { "codemirror_mode": { "name": "ipython", "version": 2.0 "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } }

#### 0 comments on commit `8e03823`

Please sign in to comment.