Permalink
Browse files

Refactor nearest_neighbor for TF1.0

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
  • Loading branch information...
normanheckscher committed Jan 14, 2017
1 parent 7839ba2 commit 8e0382318172f1ff7fdccf998fd56bde00f9be8d
Showing with 25 additions and 16 deletions.
  1. +2 −2 examples/2_BasicModels/nearest_neighbor.py
  2. +23 −14 notebooks/2_BasicModels/nearest_neighbor.ipynb
@@ -26,14 +26,14 @@
# Nearest Neighbor calculation using L1 Distance
# Calculate L1 Distance
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# Prediction: Get min distance index (Nearest neighbor)
pred = tf.arg_min(distance, 0)
accuracy = 0.
# Initializing the variables
init = tf.initialize_all_variables()
init = tf.global_variables_initializer()
# Launch the graph
with tf.Session() as sess:
@@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {
"collapsed": false
},
@@ -27,10 +27,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
"Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
"Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
"Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
]
}
],
@@ -40,14 +40,14 @@
"\n",
"# Import MINST data\n",
"from tensorflow.examples.tutorials.mnist import input_data\n",
"mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
"mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {
"collapsed": true
"collapsed": false
},
"outputs": [],
"source": [
@@ -61,19 +61,19 @@
"\n",
"# Nearest Neighbor calculation using L1 Distance\n",
"# Calculate L1 Distance\n",
"distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n",
"distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)\n",
"# Prediction: Get min distance index (Nearest neighbor)\n",
"pred = tf.arg_min(distance, 0)\n",
"\n",
"accuracy = 0.\n",
"\n",
"# Initializing the variables\n",
"init = tf.initialize_all_variables()"
"init = tf.global_variables_initializer()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {
"collapsed": false
},
@@ -305,6 +305,15 @@
" print \"Done!\"\n",
" print \"Accuracy:\", accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -316,16 +325,16 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2.0
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}

0 comments on commit 8e03823

Please sign in to comment.