Skip to content

Commit 30b6a95

Browse files
authored
add mobilenet
1 parent a53d2cc commit 30b6a95

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

CNNs/MobileNet.py

+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""
2+
2017/11/24 ref:https://github.com/Zehaos/MobileNet/blob/master/nets/mobilenet.py
3+
"""
4+
5+
import tensorflow as tf
6+
from tensorflow.python.training import moving_averages
7+
8+
UPDATE_OPS_COLLECTION = "_update_ops_"
9+
10+
# create variable
11+
def create_variable(name, shape, initializer,
                    dtype=tf.float32, trainable=True):
    """Get or create a variable through ``tf.get_variable``.

    :param name: variable name, resolved inside the current variable scope
    :param shape: shape of the variable
    :param initializer: initializer used when the variable is created
    :param dtype: variable dtype (``tf.float32`` unless overridden)
    :param trainable: whether the variable is marked as trainable
    :return: whatever ``tf.get_variable`` returns for these arguments
    """
    variable_options = {
        "shape": shape,
        "dtype": dtype,
        "initializer": initializer,
        "trainable": trainable,
    }
    return tf.get_variable(name, **variable_options)
15+
16+
# batchnorm layer
17+
def bacthnorm(inputs, scope, epsilon=1e-05, momentum=0.99, is_training=True):
    """Batch-normalize ``inputs`` with per-channel (last axis) parameters.

    In training mode the batch statistics are used for normalization, and
    the ops that update the moving statistics are appended to the
    ``UPDATE_OPS_COLLECTION`` graph collection. In inference mode the stored
    moving statistics are used instead.

    :param inputs: Tensor; its static shape is read to size the parameters
    :param scope: variable scope that holds beta/gamma/moving statistics
    :param epsilon: variance epsilon passed to ``tf.nn.batch_normalization``
    :param momentum: decay passed to ``assign_moving_average``
    :param is_training: Python boolean selecting batch vs. moving statistics
    :return: the normalized Tensor
    """
    static_shape = inputs.get_shape().as_list()
    channel_shape = static_shape[-1:]
    # Reduce over every axis except the last (channel) one.
    reduce_axes = list(range(len(static_shape) - 1))

    with tf.variable_scope(scope):
        beta = create_variable("beta", channel_shape,
                               initializer=tf.zeros_initializer())
        gamma = create_variable("gamma", channel_shape,
                                initializer=tf.ones_initializer())
        # Non-trainable running statistics, used at inference time.
        moving_mean = create_variable("moving_mean", channel_shape,
                                      initializer=tf.zeros_initializer(),
                                      trainable=False)
        moving_variance = create_variable("moving_variance", channel_shape,
                                          initializer=tf.ones_initializer(),
                                          trainable=False)

        if not is_training:
            mean, variance = moving_mean, moving_variance
        else:
            mean, variance = tf.nn.moments(inputs, axes=reduce_axes)
            # Register the mean update first, then the variance update.
            tracked_stats = ((moving_mean, mean), (moving_variance, variance))
            for running_stat, batch_stat in tracked_stats:
                update_op = moving_averages.assign_moving_average(
                    running_stat, batch_stat, decay=momentum)
                tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

        return tf.nn.batch_normalization(inputs, mean, variance,
                                         beta, gamma, epsilon)
43+
44+
# depthwise conv2d layer
45+
def depthwise_conv2d(inputs, scope, filter_size=3, channel_multiplier=1, strides=1):
    """Depthwise 2-D convolution with SAME padding and no dilation.

    :param inputs: Tensor whose last static dimension is the channel count
    :param scope: variable scope that holds the "filter" variable
    :param filter_size: height and width of the square kernel
    :param channel_multiplier: last dimension of the kernel variable
    :param strides: stride applied to both spatial dimensions
    :return: the output of ``tf.nn.depthwise_conv2d``
    """
    in_channels = inputs.get_shape().as_list()[-1]
    kernel_shape = [filter_size, filter_size, in_channels, channel_multiplier]
    stride_spec = [1, strides, strides, 1]

    with tf.variable_scope(scope):
        kernel = create_variable(
            "filter", shape=kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=0.01))
        return tf.nn.depthwise_conv2d(inputs, kernel, strides=stride_spec,
                                      padding="SAME", rate=[1, 1])
55+
56+
# conv2d layer
57+
def conv2d(inputs, scope, num_filters, filter_size=1, strides=1):
    """Standard 2-D convolution with SAME padding and no bias term.

    :param inputs: Tensor whose last static dimension is the channel count
    :param scope: variable scope that holds the "filter" variable
    :param num_filters: number of output channels
    :param filter_size: height and width of the square kernel (1 => pointwise)
    :param strides: stride applied to both spatial dimensions
    :return: the output of ``tf.nn.conv2d``
    """
    in_channels = inputs.get_shape().as_list()[-1]
    kernel_shape = [filter_size, filter_size, in_channels, num_filters]
    stride_spec = [1, strides, strides, 1]

    with tf.variable_scope(scope):
        kernel = create_variable(
            "filter", shape=kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=0.01))
        return tf.nn.conv2d(inputs, kernel, strides=stride_spec,
                            padding="SAME")
66+
67+
# avg pool layer
68+
def avg_pool(inputs, pool_size, scope):
    """Average-pool with a square window whose stride equals its size.

    :param inputs: Tensor to pool
    :param pool_size: side length of the window, also used as the stride
    :param scope: scope under which the pooling op is created
    :return: the output of ``tf.nn.avg_pool`` with VALID padding
    """
    window = [1, pool_size, pool_size, 1]
    with tf.variable_scope(scope):
        # The window and the stride are identical.
        return tf.nn.avg_pool(inputs, window, strides=list(window),
                              padding="VALID")
72+
73+
# fully connected layer
74+
def fc(inputs, n_out, scope, use_bias=True):
    """Fully connected layer: ``inputs * weight`` plus an optional bias.

    :param inputs: Tensor whose last static dimension is the input width
    :param n_out: number of output units
    :param scope: variable scope that holds "weight" (and "bias")
    :param use_bias: when False, no bias variable is created or added
    :return: Tensor produced by ``tf.nn.xw_plus_b`` or ``tf.matmul``
    """
    n_in = inputs.get_shape().as_list()[-1]

    with tf.variable_scope(scope):
        weight = create_variable(
            "weight", shape=[n_in, n_out],
            initializer=tf.random_normal_initializer(stddev=0.01))
        if not use_bias:
            return tf.matmul(inputs, weight)

        bias = create_variable("bias", shape=[n_out],
                               initializer=tf.zeros_initializer())
        return tf.nn.xw_plus_b(inputs, weight, bias)
85+
86+
87+
class MobileNet(object):
    """MobileNet classifier graph (ref: https://arxiv.org/abs/1704.04861).

    Constructing an instance builds the whole network under ``scope`` and
    stores the results in ``self.logits`` and ``self.predictions``.
    """

    def __init__(self, inputs, num_classes=1000, is_training=True,
                 width_multiplier=1, scope="MobileNet"):
        """
        Build the MobileNet graph on top of ``inputs``.

        :param inputs: 4-D Tensor of [batch_size, height, width, channels]
        :param num_classes: number of classes
        :param is_training: Boolean, whether or not the model is training
        :param width_multiplier: float, controls the size of model
        :param scope: Optional scope for variables
        """
        self.inputs = inputs
        self.num_classes = num_classes
        self.is_training = is_training
        self.width_multiplier = width_multiplier

        # (base output channels, downsample) for ds_conv_2 .. ds_conv_14.
        # The shape notes assume 224x224 inputs and width_multiplier == 1.
        ds_conv_specs = [
            (64, False),    # ->[N, 112, 112, 64]
            (128, True),    # ->[N, 56, 56, 128]
            (128, False),   # ->[N, 56, 56, 128]
            (256, True),    # ->[N, 28, 28, 256]
            (256, False),   # ->[N, 28, 28, 256]
            (512, True),    # ->[N, 14, 14, 512]
            (512, False),   # ->[N, 14, 14, 512]
            (512, False),   # ->[N, 14, 14, 512]
            (512, False),   # ->[N, 14, 14, 512]
            (512, False),   # ->[N, 14, 14, 512]
            (512, False),   # ->[N, 14, 14, 512]
            (1024, True),   # ->[N, 7, 7, 1024]
            (1024, False),  # ->[N, 7, 7, 1024]
        ]

        # construct model
        with tf.variable_scope(scope):
            # conv1: full 3x3 convolution, stride 2 ->[N, 112, 112, 32]
            stem_filters = round(32 * width_multiplier)
            net = conv2d(inputs, "conv_1", stem_filters, filter_size=3,
                         strides=2)
            net = bacthnorm(net, "conv_1/bn", is_training=self.is_training)
            net = tf.nn.relu(net)

            # Stack of depthwise separable blocks, numbered from 2.
            for layer_index, (base_filters, downsample) in enumerate(
                    ds_conv_specs, start=2):
                net = self._depthwise_separable_conv2d(
                    net, base_filters, self.width_multiplier,
                    "ds_conv_%d" % layer_index, downsample=downsample)

            # 7x7 average pool, then drop the two spatial axes.
            net = avg_pool(net, 7, "avg_pool_15")
            net = tf.squeeze(net, [1, 2], name="SpatialSqueeze")
            self.logits = fc(net, self.num_classes, "fc_16")
            self.predictions = tf.nn.softmax(self.logits)

    def _depthwise_separable_conv2d(self, inputs, num_filters, width_multiplier,
                                    scope, downsample=False):
        """Depthwise 3x3 conv -> BN -> ReLU -> pointwise 1x1 conv -> BN -> ReLU.

        :param inputs: input Tensor
        :param num_filters: base number of pointwise output channels
        :param width_multiplier: factor applied to ``num_filters`` (rounded)
        :param scope: variable scope for the whole block
        :param downsample: when True the depthwise conv uses stride 2
        :return: output Tensor of the block
        """
        out_channels = round(num_filters * width_multiplier)
        if downsample:
            dw_strides = 2
        else:
            dw_strides = 1

        with tf.variable_scope(scope):
            # depthwise stage
            net = depthwise_conv2d(inputs, "depthwise_conv", strides=dw_strides)
            net = bacthnorm(net, "dw_bn", is_training=self.is_training)
            net = tf.nn.relu(net)
            # pointwise (1x1) stage
            net = conv2d(net, "pointwise_conv", out_channels)
            net = bacthnorm(net, "pw_bn", is_training=self.is_training)
            return tf.nn.relu(net)
158+
159+
if __name__ == "__main__":
    # Smoke test: build the network on random data and run one forward pass.
    inputs = tf.random_normal(shape=[4, 224, 224, 3])
    mobileNet = MobileNet(inputs)
    # Dump the graph for TensorBoard. The writer is closed right away so the
    # event file is flushed to disk and its handle is not leaked.
    writer = tf.summary.FileWriter("./logs", graph=tf.get_default_graph())
    writer.close()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        pred = sess.run(mobileNet.predictions)
        # Batch of 4 with the default 1000 classes.
        print(pred.shape)
169+

0 commit comments

Comments
 (0)