-
Notifications
You must be signed in to change notification settings - Fork 6
/
inception_v3_imagenet.py
62 lines (53 loc) · 2.19 KB
/
inception_v3_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from utils import optimistic_restore, t_optimistic_restore
t_optimistic_restore, t_optimistic_restore, t_optimistic_restore
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import functools
import os
# Square input side length (pixels) for Inception v3; presumably matches
# ``nets.inception.inception_v3.default_image_size`` — verify.
SIZE = 299

# Name of the checkpoint file, inside the local ``data`` directory, that
# ``optimistic_restore`` loads weights from. Exactly one value is in effect.
# Other checkpoint names previously used here (kept for reference):
#   - 'inception_v3.ckpt': download
#     http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
#     and decompress it in the ``data`` directory.
#   - 'ens4_adv_inception_v3.ckpt'
_INCEPTION_CHECKPOINT_NAME = 'model_v1.ckpt'

# Absolute-or-relative path (depending on how this module was imported) to
# the checkpoint, resolved next to this source file.
INCEPTION_CHECKPOINT_PATH = os.path.join(
    os.path.dirname(__file__),
    'data',
    _INCEPTION_CHECKPOINT_NAME,
)
def _get_model(reuse):
    """Return a callable that applies slim's Inception v3 to a batch of images.

    Args:
        reuse: forwarded unchanged as the ``reuse`` argument of
            ``nets.inception.inception_v3`` (lets a later build share the
            variables created by an earlier one).

    Returns:
        ``network_fn(images)``, which runs Inception v3 with 1001 output
        classes and ``is_training=False`` inside the Inception v3 arg scope
        built with ``weight_decay=0.0``. The callable carries the wrapped
        function's metadata (``functools.wraps``) and, when the wrapped
        function defines it, its ``default_image_size``.
    """
    inception = nets.inception.inception_v3
    scope_kwargs = nets.inception.inception_v3_arg_scope(weight_decay=0.0)

    def network_fn(images):
        with slim.arg_scope(scope_kwargs):
            return inception(images, 1001, is_training=False, reuse=reuse)

    # Equivalent to decorating ``network_fn`` with ``@functools.wraps``.
    network_fn = functools.wraps(inception)(network_fn)

    if hasattr(inception, 'default_image_size'):
        network_fn.default_image_size = inception.default_image_size
    return network_fn
def _preprocess(image, height, width, scope=None):
    """Resize images to ``height`` x ``width`` and map values via (x - 0.5) * 2.

    Args:
        image: image batch tensor; presumably 4-D ``[batch, h, w, channels]``
            as required by ``tf.image.resize_bilinear`` — TODO confirm.
            Non-float32 inputs are first converted with
            ``tf.image.convert_image_dtype``.
        height: target height passed to the bilinear resize.
        width: target width passed to the bilinear resize.
        scope: optional name scope; defaults to ``'eval_image'``.

    Returns:
        The resized float32 tensor with values shifted by -0.5 and scaled by
        2.0 (so inputs in [0, 1] land in [-1, 1]).
    """
    with tf.name_scope(scope, 'eval_image', [image, height, width]):
        is_float32 = image.dtype == tf.float32
        if not is_float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        resized = tf.image.resize_bilinear(
            image, [height, width], align_corners=False)
        centered = tf.subtract(resized, 0.5)
        return tf.multiply(centered, 2.0)
# True once ``model`` has built the Inception v3 variables and restored the
# checkpoint; subsequent calls build the graph with ``reuse=True`` and skip
# the restore.
_inception_initialized = False


def model(sess, image):
    """Build Inception v3 logits and predictions for ``image``.

    On the first call the checkpoint at ``INCEPTION_CHECKPOINT_PATH`` is
    restored into ``sess`` via ``optimistic_restore``. Later calls reuse the
    already-created variables.

    Args:
        sess: TensorFlow session that receives the restored weights.
        image: image batch; earlier notes describe it as
            ``[batch, 256, 256, 3]`` with pixels in ``[0, 1]`` (TODO confirm).
            It is resized to the network's input size by ``_preprocess``.

    Returns:
        ``(logits, predictions)``:
        - ``logits`` is ``[batch, 1000]``. It is the network's 1001-way
          output with column 0 (the background class) removed.
        - ``predictions`` is ``tf.argmax(logits, 1)``. Because the
          background column was dropped, its class indices are one lower
          than the network's 1001-class indexing.
    """
    global _inception_initialized
    network_fn = _get_model(reuse=_inception_initialized)
    # Fall back to the module-level SIZE when the network callable does not
    # advertise its expected input resolution, instead of raising
    # AttributeError.
    size = getattr(network_fn, 'default_image_size', SIZE)
    preprocessed = _preprocess(image, size, size)
    logits, _ = network_fn(preprocessed)
    logits = logits[:, 1:]  # ignore background class
    predictions = tf.argmax(logits, 1)
    if not _inception_initialized:
        optimistic_restore(sess, INCEPTION_CHECKPOINT_PATH)
        _inception_initialized = True
    return logits, predictions